Added support to sessions. More changes to the blacklist methods.

This commit is contained in:
Daniele Maglie 2024-01-21 14:21:05 +01:00
parent 4acfb7e91c
commit 5397063bbd
4 changed files with 332 additions and 2 deletions

View file

@ -0,0 +1 @@

View file

@ -0,0 +1,169 @@
from piracyshield_data_storage.database.redis.document import DatabaseRedisDocument, DatabaseRedisSetException, DatabaseRedisGetException
class AccountSessionMemory(DatabaseRedisDocument):
session_prefix = 'session'
def __init__(self, database: int):
super().__init__()
self.establish(database)
def add_long_session(self, account_id: str, refresh_token: str, data: dict, duration: int) -> bool | Exception:
"""
Store an access token generated from a refresh token.
:param refresh_token: a valid refresh token.
:param access_token: a valid access token.
:param duration: refresh token expire time.
:return: true if the item has been stored.
"""
try:
return self.hset_with_expiry(
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}',
mapping = data,
expiry = duration
)
except DatabaseRedisSetException:
raise AccountSessionMemorySetException()
def add_short_session(self, account_id: str, refresh_token: str, access_token: str, data: dict, duration: int) -> bool | Exception:
"""
Store an access token generated from a refresh token.
:param refresh_token: a valid refresh token.
:param access_token: a valid access token.
:param duration: refresh token expire time.
:return: true if the item has been stored.
"""
try:
return self.hset_with_expiry(
key = f'{self.session_prefix}:{account_id}:short:{access_token}',
mapping = data,
expiry = duration
)
except DatabaseRedisSetException:
raise AccountSessionMemorySetException()
def get_all_by_account(self, account_id: str) -> list | Exception:
"""
Retrieves all the active long and short sessions.
:param account_id: a valid account identifier.
:return: the requested data.
"""
try:
return self.keys(
key = f'{self.session_prefix}:{account_id}:*:*'
)
except DatabaseRedisGetException:
raise AccountSessionMemoryGetException()
def get_all_short_by_account(self, account_id: str) -> list | Exception:
"""
Retrieves all the active short sessions.
:param account_id: a valid account identifier.
:return: the requested data.
"""
try:
return self.keys(
key = f'{self.session_prefix}:{account_id}:short:*'
)
except DatabaseRedisGetException:
raise AccountSessionMemoryGetException()
def get_session(self, session: str) -> list | Exception:
"""
Retrieves a single session.
:param token: a valid refresh or access identifier.
:return: the requested data.
"""
try:
return self.hgetall(
key = session
)
except DatabaseRedisGetException:
raise AccountSessionMemoryGetException()
def find_long_session(self, refresh_token: str) -> list | Exception:
"""
Retrieves a single session.
:param token: a valid refresh or access identifier.
:return: the requested data.
"""
try:
response = self.keys(
key = f'{self.session_prefix}:*:long:{refresh_token}'
)
if isinstance(response, list) and len(response):
return response.__getitem__(0)
return response
except DatabaseRedisGetException:
raise AccountSessionMemoryGetException()
def remove_long_session(self, account_id: str, refresh_token: str) -> bool | Exception:
"""
Removes a long session.
:param account_id: a valid account identifier.
:param refresh_token: a valid active refresh token.
:return: true if the item has been removed.
"""
try:
return self.delete(
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}'
)
except DatabaseRedisSetException:
raise AccountSessionMemorySetException()
def remove_short_session(self, account_id: str, access_token: str) -> bool | Exception:
"""
Removes a short session.
:param account_id: a valid account identifier.
:param access_token: a valid active access token.
:return: true if the item has been removed.
"""
try:
return self.delete(
key = f'{self.session_prefix}:{account_id}:short:{access_token}'
)
except DatabaseRedisSetException:
raise AccountSessionMemorySetException()
class AccountSessionMemorySetException(Exception):
"""
Cannot set the value.
"""
pass
class AccountSessionMemoryGetException(Exception):
"""
Cannot get the value.
"""
pass

View file

@ -2,12 +2,43 @@ from piracyshield_data_storage.database.redis.connection import DatabaseRedisCon
class DatabaseRedisDocument(DatabaseRedisConnection): class DatabaseRedisDocument(DatabaseRedisConnection):
def keys(self, key: str) -> any:
try:
return self.instance.keys(
name = key
)
except:
raise DatabaseRedisGetException()
# string
def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception: def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
if self.instance.set(key, value, ex = expiry) == True: if self.instance.set(key, value, ex = expiry) == True:
return True return True
raise DatabaseRedisSetException() raise DatabaseRedisSetException()
def setnx_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
pipeline = self.instance.pipeline()
pipeline.setnx(
name = key,
value = value
)
pipeline.expire(
name = key,
time = expiry
)
result = pipeline.execute()
if result:
return True
raise DatabaseRedisSetException()
def incr(self, key: str, amount: int = 1) -> bool | Exception: def incr(self, key: str, amount: int = 1) -> bool | Exception:
return self.instance.incr(name = key, amount = amount) return self.instance.incr(name = key, amount = amount)
@ -17,6 +48,59 @@ class DatabaseRedisDocument(DatabaseRedisConnection):
def delete(self, key: str) -> any: def delete(self, key: str) -> any:
return self.instance.delete(key) return self.instance.delete(key)
# hash
def hset_with_expiry(self, key: str, mapping: list, expiry: int) -> bool | Exception:
pipeline = self.instance.pipeline()
pipeline.hset(
name = key,
mapping = mapping
)
pipeline.expire(
name = key,
time = expiry
)
result = pipeline.execute()
if result:
return True
raise DatabaseRedisSetException()
def hgetall(self, key: str) -> any:
try:
return self.instance.hgetall(
name = key
)
except:
raise DatabaseRedisGetException()
# list
def lpush_with_expiry(self, key: str, value: str, expiry: int) -> bool | Exception:
pipeline = self.instance.pipeline()
pipeline.lpush(
key,
value
)
pipeline.expire(
name = key,
time = expiry
)
result = pipeline.execute()
if result:
return True
raise DatabaseRedisSetException()
class DatabaseRedisSetException(Exception): class DatabaseRedisSetException(Exception):
""" """

View file

@ -34,7 +34,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
""" """
Verifies if an IP address is in the blacklist. Verifies if an IP address is in the blacklist.
:param item: a valid IP address. :param ip_address: a valid IP address.
:return: returns the TTL of the item. :return: returns the TTL of the item.
""" """
@ -53,7 +53,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
""" """
Removes an IP address from the blacklist. Removes an IP address from the blacklist.
:param item: a valid IP address. :param ip_address: a valid IP address.
:return: true if the item has been removed. :return: true if the item has been removed.
""" """
@ -65,6 +65,82 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
except DatabaseRedisSetException: except DatabaseRedisSetException:
raise SecurityBlacklistMemorySetException() raise SecurityBlacklistMemorySetException()
def add_refresh_token(self, refresh_token: str, duration: int) -> bool | Exception:
"""
Blacklists a refresh token.
:param refresh_token: an active refresh token.
:param duration: duration of the blacklist in seconds.
:return: true if the item has been stored.
"""
try:
return self.set_with_expiry(
key = f'{self.token_prefix}:{refresh_token}',
value = '1',
expiry = duration
)
except DatabaseRedisSetException:
raise SecurityBlacklistMemorySetException()
def exists_by_refresh_token(self, refresh_token: str) -> bool | Exception:
"""
Verifies if a refresh token is in the blacklist.
:param refresh_token: a valid refresh token.
:return: returns the TTL of the item.
"""
try:
response = self.get(key = f'{self.token_prefix}:{refresh_token}')
if response:
return True
return False
except DatabaseRedisGetException:
raise SecurityBlacklistMemoryGetException()
def add_access_token(self, access_token: str, duration: int) -> bool | Exception:
"""
Blacklists an access token.
:param access_token: an active access token.
:param duration: duration of the blacklist in seconds.
:return: true if the item has been stored.
"""
try:
return self.set_with_expiry(
key = f'{self.token_prefix}:{access_token}',
value = '1',
expiry = duration
)
except DatabaseRedisSetException:
raise SecurityBlacklistMemorySetException()
def exists_by_access_token(self, access_token: str) -> bool | Exception:
"""
Verifies if an access token is in the blacklist.
:param access_token: a valid access token.
:return: returns the TTL of the item.
"""
try:
response = self.get(key = f'{self.token_prefix}:{access_token}')
if response:
return True
return False
except DatabaseRedisGetException:
raise SecurityBlacklistMemoryGetException()
class SecurityBlacklistMemorySetException(Exception): class SecurityBlacklistMemorySetException(Exception):
""" """