diff --git a/src/piracyshield_data_storage/account/session/__init__.py b/src/piracyshield_data_storage/account/session/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/src/piracyshield_data_storage/account/session/__init__.py @@ -0,0 +1 @@ + diff --git a/src/piracyshield_data_storage/account/session/memory.py b/src/piracyshield_data_storage/account/session/memory.py new file mode 100644 index 0000000..fcdf947 --- /dev/null +++ b/src/piracyshield_data_storage/account/session/memory.py @@ -0,0 +1,169 @@ +from piracyshield_data_storage.database.redis.document import DatabaseRedisDocument, DatabaseRedisSetException, DatabaseRedisGetException + +class AccountSessionMemory(DatabaseRedisDocument): + + session_prefix = 'session' + + def __init__(self, database: int): + super().__init__() + + self.establish(database) + + def add_long_session(self, account_id: str, refresh_token: str, data: dict, duration: int) -> bool | Exception: + """ + Store an access token generated from a refresh token. + + :param refresh_token: a valid refresh token. + :param access_token: a valid access token. + :param duration: refresh token expire time. + :return: true if the item has been stored. + """ + + try: + return self.hset_with_expiry( + key = f'{self.session_prefix}:{account_id}:long:{refresh_token}', + mapping = data, + expiry = duration + ) + + except DatabaseRedisSetException: + raise AccountSessionMemorySetException() + + def add_short_session(self, account_id: str, refresh_token: str, access_token: str, data: dict, duration: int) -> bool | Exception: + """ + Store an access token generated from a refresh token. + + :param refresh_token: a valid refresh token. + :param access_token: a valid access token. + :param duration: refresh token expire time. + :return: true if the item has been stored. + """ + + try: + return self.hset_with_expiry( + key = f'{self.session_prefix}:{account_id}:short:{access_token}', + mapping = data, + expiry = duration + ) + + except DatabaseRedisSetException: + raise AccountSessionMemorySetException() + + def get_all_by_account(self, account_id: str) -> list | Exception: + """ + Retrieves all the active long and short sessions. + + :param account_id: a valid account identifier. + :return: the requested data. + """ + + try: + return self.keys( + key = f'{self.session_prefix}:{account_id}:*:*' + ) + + except DatabaseRedisGetException: + raise AccountSessionMemoryGetException() + + def get_all_short_by_account(self, account_id: str) -> list | Exception: + """ + Retrieves all the active short sessions. + + :param account_id: a valid account identifier. + :return: the requested data. + """ + + try: + return self.keys( + key = f'{self.session_prefix}:{account_id}:short:*' + ) + + except DatabaseRedisGetException: + raise AccountSessionMemoryGetException() + + def get_session(self, session: str) -> list | Exception: + """ + Retrieves a single session. + + :param token: a valid refresh or access identifier. + :return: the requested data. + """ + + try: + return self.hgetall( + key = session + ) + + except DatabaseRedisGetException: + raise AccountSessionMemoryGetException() + + def find_long_session(self, refresh_token: str) -> list | Exception: + """ + Retrieves a single session. + + :param token: a valid refresh or access identifier. + :return: the requested data. + """ + + try: + response = self.keys( + key = f'{self.session_prefix}:*:long:{refresh_token}' + ) + + if isinstance(response, list) and len(response): + return response.__getitem__(0) + + return response + + except DatabaseRedisGetException: + raise AccountSessionMemoryGetException() + + def remove_long_session(self, account_id: str, refresh_token: str) -> bool | Exception: + """ + Removes a long session. + + :param account_id: a valid account identifier. + :param refresh_token: a valid active refresh token. + :return: true if the item has been removed. + """ + + try: + return self.delete( + key = f'{self.session_prefix}:{account_id}:long:{refresh_token}' + ) + + except DatabaseRedisSetException: + raise AccountSessionMemorySetException() + + def remove_short_session(self, account_id: str, access_token: str) -> bool | Exception: + """ + Removes a short session. + + :param account_id: a valid account identifier. + :param access_token: a valid active access token. + :return: true if the item has been removed. + """ + + try: + return self.delete( + key = f'{self.session_prefix}:{account_id}:short:{access_token}' + ) + + except DatabaseRedisSetException: + raise AccountSessionMemorySetException() + +class AccountSessionMemorySetException(Exception): + + """ + Cannot set the value. + """ + + pass + +class AccountSessionMemoryGetException(Exception): + + """ + Cannot get the value. + """ + + pass diff --git a/src/piracyshield_data_storage/database/redis/document.py b/src/piracyshield_data_storage/database/redis/document.py index 8f4e8e9..abcfee2 100644 --- a/src/piracyshield_data_storage/database/redis/document.py +++ b/src/piracyshield_data_storage/database/redis/document.py @@ -2,12 +2,43 @@ from piracyshield_data_storage.database.redis.connection import DatabaseRedisCon class DatabaseRedisDocument(DatabaseRedisConnection): + def keys(self, key: str) -> any: + try: + return self.instance.keys( + name = key + ) + + except: + raise DatabaseRedisGetException() + + # string + def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception: if self.instance.set(key, value, ex = expiry) == True: return True raise DatabaseRedisSetException() + def setnx_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception: + pipeline = self.instance.pipeline() + + pipeline.setnx( + name = key, + value = value + ) + + pipeline.expire( + name = key, + time = expiry + ) + + result = pipeline.execute() + + if result: + return True + + raise DatabaseRedisSetException() + def incr(self, key: str, amount: int = 1) -> bool | Exception: return self.instance.incr(name = key, amount = amount) @@ -17,6 +48,59 @@ class DatabaseRedisDocument(DatabaseRedisConnection): def delete(self, key: str) -> any: return self.instance.delete(key) + # hash + + def hset_with_expiry(self, key: str, mapping: list, expiry: int) -> bool | Exception: + pipeline = self.instance.pipeline() + + pipeline.hset( + name = key, + mapping = mapping + ) + + pipeline.expire( + name = key, + time = expiry + ) + + result = pipeline.execute() + + if result: + return True + + raise DatabaseRedisSetException() + + def hgetall(self, key: str) -> any: + try: + return self.instance.hgetall( + name = key + ) + + except: + raise DatabaseRedisGetException() + + # list + + def lpush_with_expiry(self, key: str, value: str, expiry: int) -> bool | Exception: + pipeline = self.instance.pipeline() + + pipeline.lpush( + key, + value + ) + + pipeline.expire( + name = key, + time = expiry + ) + + result = pipeline.execute() + + if result: + return True + + raise DatabaseRedisSetException() + class DatabaseRedisSetException(Exception): """ diff --git a/src/piracyshield_data_storage/security/blacklist/memory.py b/src/piracyshield_data_storage/security/blacklist/memory.py index dd9b499..199bc15 100644 --- a/src/piracyshield_data_storage/security/blacklist/memory.py +++ b/src/piracyshield_data_storage/security/blacklist/memory.py @@ -34,7 +34,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument): """ Verifies if an IP address is in the blacklist. - :param item: a valid IP address. + :param ip_address: a valid IP address. :return: returns the TTL of the item. """ @@ -53,7 +53,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument): """ Removes an IP address from the blacklist. - :param item: a valid IP address. + :param ip_address: a valid IP address. :return: true if the item has been removed. """ @@ -65,6 +65,82 @@ class SecurityBlacklistMemory(DatabaseRedisDocument): except DatabaseRedisSetException: raise SecurityBlacklistMemorySetException() + def add_refresh_token(self, refresh_token: str, duration: int) -> bool | Exception: + """ + Blacklists a refresh token. + + :param refresh_token: an active refresh token. + :param duration: duration of the blacklist in seconds. + :return: true if the item has been stored. + """ + + try: + return self.set_with_expiry( + key = f'{self.token_prefix}:{refresh_token}', + value = '1', + expiry = duration + ) + + except DatabaseRedisSetException: + raise SecurityBlacklistMemorySetException() + + def exists_by_refresh_token(self, refresh_token: str) -> bool | Exception: + """ + Verifies if a refresh token is in the blacklist. + + :param refresh_token: a valid refresh token. + :return: returns the TTL of the item. + """ + + try: + response = self.get(key = f'{self.token_prefix}:{refresh_token}') + + if response: + return True + + return False + + except DatabaseRedisGetException: + raise SecurityBlacklistMemoryGetException() + + def add_access_token(self, access_token: str, duration: int) -> bool | Exception: + """ + Blacklists an access token. + + :param access_token: an active access token. + :param duration: duration of the blacklist in seconds. + :return: true if the item has been stored. + """ + + try: + return self.set_with_expiry( + key = f'{self.token_prefix}:{access_token}', + value = '1', + expiry = duration + ) + + except DatabaseRedisSetException: + raise SecurityBlacklistMemorySetException() + + def exists_by_access_token(self, access_token: str) -> bool | Exception: + """ + Verifies if an access token is in the blacklist. + + :param access_token: a valid access token. + :return: returns the TTL of the item. + """ + + try: + response = self.get(key = f'{self.token_prefix}:{access_token}') + + if response: + return True + + return False + + except DatabaseRedisGetException: + raise SecurityBlacklistMemoryGetException() + class SecurityBlacklistMemorySetException(Exception): """