mirror of
https://github.com/fuckpiracyshield/data-storage.git
synced 2024-11-22 13:19:46 +01:00
Added support to sessions. More changes to the blacklist methods.
This commit is contained in:
parent
4acfb7e91c
commit
5397063bbd
4 changed files with 332 additions and 2 deletions
|
@ -0,0 +1 @@
|
||||||
|
|
169
src/piracyshield_data_storage/account/session/memory.py
Normal file
169
src/piracyshield_data_storage/account/session/memory.py
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
from piracyshield_data_storage.database.redis.document import DatabaseRedisDocument, DatabaseRedisSetException, DatabaseRedisGetException
|
||||||
|
|
||||||
|
class AccountSessionMemory(DatabaseRedisDocument):
|
||||||
|
|
||||||
|
session_prefix = 'session'
|
||||||
|
|
||||||
|
def __init__(self, database: int):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.establish(database)
|
||||||
|
|
||||||
|
def add_long_session(self, account_id: str, refresh_token: str, data: dict, duration: int) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Store an access token generated from a refresh token.
|
||||||
|
|
||||||
|
:param refresh_token: a valid refresh token.
|
||||||
|
:param access_token: a valid access token.
|
||||||
|
:param duration: refresh token expire time.
|
||||||
|
:return: true if the item has been stored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.hset_with_expiry(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}',
|
||||||
|
mapping = data,
|
||||||
|
expiry = duration
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise AccountSessionMemorySetException()
|
||||||
|
|
||||||
|
def add_short_session(self, account_id: str, refresh_token: str, access_token: str, data: dict, duration: int) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Store an access token generated from a refresh token.
|
||||||
|
|
||||||
|
:param refresh_token: a valid refresh token.
|
||||||
|
:param access_token: a valid access token.
|
||||||
|
:param duration: refresh token expire time.
|
||||||
|
:return: true if the item has been stored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.hset_with_expiry(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:short:{access_token}',
|
||||||
|
mapping = data,
|
||||||
|
expiry = duration
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise AccountSessionMemorySetException()
|
||||||
|
|
||||||
|
def get_all_by_account(self, account_id: str) -> list | Exception:
|
||||||
|
"""
|
||||||
|
Retrieves all the active long and short sessions.
|
||||||
|
|
||||||
|
:param account_id: a valid account identifier.
|
||||||
|
:return: the requested data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.keys(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:*:*'
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise AccountSessionMemoryGetException()
|
||||||
|
|
||||||
|
def get_all_short_by_account(self, account_id: str) -> list | Exception:
|
||||||
|
"""
|
||||||
|
Retrieves all the active short sessions.
|
||||||
|
|
||||||
|
:param account_id: a valid account identifier.
|
||||||
|
:return: the requested data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.keys(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:short:*'
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise AccountSessionMemoryGetException()
|
||||||
|
|
||||||
|
def get_session(self, session: str) -> list | Exception:
|
||||||
|
"""
|
||||||
|
Retrieves a single session.
|
||||||
|
|
||||||
|
:param token: a valid refresh or access identifier.
|
||||||
|
:return: the requested data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.hgetall(
|
||||||
|
key = session
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise AccountSessionMemoryGetException()
|
||||||
|
|
||||||
|
def find_long_session(self, refresh_token: str) -> list | Exception:
|
||||||
|
"""
|
||||||
|
Retrieves a single session.
|
||||||
|
|
||||||
|
:param token: a valid refresh or access identifier.
|
||||||
|
:return: the requested data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.keys(
|
||||||
|
key = f'{self.session_prefix}:*:long:{refresh_token}'
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(response, list) and len(response):
|
||||||
|
return response.__getitem__(0)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise AccountSessionMemoryGetException()
|
||||||
|
|
||||||
|
def remove_long_session(self, account_id: str, refresh_token: str) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Removes a long session.
|
||||||
|
|
||||||
|
:param account_id: a valid account identifier.
|
||||||
|
:param refresh_token: a valid active refresh token.
|
||||||
|
:return: true if the item has been removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.delete(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}'
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise AccountSessionMemorySetException()
|
||||||
|
|
||||||
|
def remove_short_session(self, account_id: str, access_token: str) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Removes a short session.
|
||||||
|
|
||||||
|
:param account_id: a valid account identifier.
|
||||||
|
:param access_token: a valid active access token.
|
||||||
|
:return: true if the item has been removed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.delete(
|
||||||
|
key = f'{self.session_prefix}:{account_id}:short:{access_token}'
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise AccountSessionMemorySetException()
|
||||||
|
|
||||||
|
class AccountSessionMemorySetException(Exception):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Cannot set the value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AccountSessionMemoryGetException(Exception):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Cannot get the value.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
|
@ -2,12 +2,43 @@ from piracyshield_data_storage.database.redis.connection import DatabaseRedisCon
|
||||||
|
|
||||||
class DatabaseRedisDocument(DatabaseRedisConnection):
|
class DatabaseRedisDocument(DatabaseRedisConnection):
|
||||||
|
|
||||||
|
def keys(self, key: str) -> any:
|
||||||
|
try:
|
||||||
|
return self.instance.keys(
|
||||||
|
name = key
|
||||||
|
)
|
||||||
|
|
||||||
|
except:
|
||||||
|
raise DatabaseRedisGetException()
|
||||||
|
|
||||||
|
# string
|
||||||
|
|
||||||
def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
|
def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
|
||||||
if self.instance.set(key, value, ex = expiry) == True:
|
if self.instance.set(key, value, ex = expiry) == True:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
raise DatabaseRedisSetException()
|
raise DatabaseRedisSetException()
|
||||||
|
|
||||||
|
def setnx_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
|
||||||
|
pipeline = self.instance.pipeline()
|
||||||
|
|
||||||
|
pipeline.setnx(
|
||||||
|
name = key,
|
||||||
|
value = value
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline.expire(
|
||||||
|
name = key,
|
||||||
|
time = expiry
|
||||||
|
)
|
||||||
|
|
||||||
|
result = pipeline.execute()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise DatabaseRedisSetException()
|
||||||
|
|
||||||
def incr(self, key: str, amount: int = 1) -> bool | Exception:
|
def incr(self, key: str, amount: int = 1) -> bool | Exception:
|
||||||
return self.instance.incr(name = key, amount = amount)
|
return self.instance.incr(name = key, amount = amount)
|
||||||
|
|
||||||
|
@ -17,6 +48,59 @@ class DatabaseRedisDocument(DatabaseRedisConnection):
|
||||||
def delete(self, key: str) -> any:
|
def delete(self, key: str) -> any:
|
||||||
return self.instance.delete(key)
|
return self.instance.delete(key)
|
||||||
|
|
||||||
|
# hash
|
||||||
|
|
||||||
|
def hset_with_expiry(self, key: str, mapping: list, expiry: int) -> bool | Exception:
|
||||||
|
pipeline = self.instance.pipeline()
|
||||||
|
|
||||||
|
pipeline.hset(
|
||||||
|
name = key,
|
||||||
|
mapping = mapping
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline.expire(
|
||||||
|
name = key,
|
||||||
|
time = expiry
|
||||||
|
)
|
||||||
|
|
||||||
|
result = pipeline.execute()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise DatabaseRedisSetException()
|
||||||
|
|
||||||
|
def hgetall(self, key: str) -> any:
|
||||||
|
try:
|
||||||
|
return self.instance.hgetall(
|
||||||
|
name = key
|
||||||
|
)
|
||||||
|
|
||||||
|
except:
|
||||||
|
raise DatabaseRedisGetException()
|
||||||
|
|
||||||
|
# list
|
||||||
|
|
||||||
|
def lpush_with_expiry(self, key: str, value: str, expiry: int) -> bool | Exception:
|
||||||
|
pipeline = self.instance.pipeline()
|
||||||
|
|
||||||
|
pipeline.lpush(
|
||||||
|
key,
|
||||||
|
value
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline.expire(
|
||||||
|
name = key,
|
||||||
|
time = expiry
|
||||||
|
)
|
||||||
|
|
||||||
|
result = pipeline.execute()
|
||||||
|
|
||||||
|
if result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
raise DatabaseRedisSetException()
|
||||||
|
|
||||||
class DatabaseRedisSetException(Exception):
|
class DatabaseRedisSetException(Exception):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -34,7 +34,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
||||||
"""
|
"""
|
||||||
Verifies if an IP address is in the blacklist.
|
Verifies if an IP address is in the blacklist.
|
||||||
|
|
||||||
:param item: a valid IP address.
|
:param ip_address: a valid IP address.
|
||||||
:return: returns the TTL of the item.
|
:return: returns the TTL of the item.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
||||||
"""
|
"""
|
||||||
Removes an IP address from the blacklist.
|
Removes an IP address from the blacklist.
|
||||||
|
|
||||||
:param item: a valid IP address.
|
:param ip_address: a valid IP address.
|
||||||
:return: true if the item has been removed.
|
:return: true if the item has been removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -65,6 +65,82 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
||||||
except DatabaseRedisSetException:
|
except DatabaseRedisSetException:
|
||||||
raise SecurityBlacklistMemorySetException()
|
raise SecurityBlacklistMemorySetException()
|
||||||
|
|
||||||
|
def add_refresh_token(self, refresh_token: str, duration: int) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Blacklists a refresh token.
|
||||||
|
|
||||||
|
:param refresh_token: an active refresh token.
|
||||||
|
:param duration: duration of the blacklist in seconds.
|
||||||
|
:return: true if the item has been stored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.set_with_expiry(
|
||||||
|
key = f'{self.token_prefix}:{refresh_token}',
|
||||||
|
value = '1',
|
||||||
|
expiry = duration
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise SecurityBlacklistMemorySetException()
|
||||||
|
|
||||||
|
def exists_by_refresh_token(self, refresh_token: str) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Verifies if a refresh token is in the blacklist.
|
||||||
|
|
||||||
|
:param refresh_token: a valid refresh token.
|
||||||
|
:return: returns the TTL of the item.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.get(key = f'{self.token_prefix}:{refresh_token}')
|
||||||
|
|
||||||
|
if response:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise SecurityBlacklistMemoryGetException()
|
||||||
|
|
||||||
|
def add_access_token(self, access_token: str, duration: int) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Blacklists an access token.
|
||||||
|
|
||||||
|
:param access_token: an active access token.
|
||||||
|
:param duration: duration of the blacklist in seconds.
|
||||||
|
:return: true if the item has been stored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.set_with_expiry(
|
||||||
|
key = f'{self.token_prefix}:{access_token}',
|
||||||
|
value = '1',
|
||||||
|
expiry = duration
|
||||||
|
)
|
||||||
|
|
||||||
|
except DatabaseRedisSetException:
|
||||||
|
raise SecurityBlacklistMemorySetException()
|
||||||
|
|
||||||
|
def exists_by_access_token(self, access_token: str) -> bool | Exception:
|
||||||
|
"""
|
||||||
|
Verifies if an access token is in the blacklist.
|
||||||
|
|
||||||
|
:param access_token: a valid access token.
|
||||||
|
:return: returns the TTL of the item.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.get(key = f'{self.token_prefix}:{access_token}')
|
||||||
|
|
||||||
|
if response:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
except DatabaseRedisGetException:
|
||||||
|
raise SecurityBlacklistMemoryGetException()
|
||||||
|
|
||||||
class SecurityBlacklistMemorySetException(Exception):
|
class SecurityBlacklistMemorySetException(Exception):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in a new issue