mirror of
https://github.com/fuckpiracyshield/data-storage.git
synced 2024-11-22 13:19:46 +01:00
Added support to sessions. More changes to the blacklist methods.
This commit is contained in:
parent
4acfb7e91c
commit
5397063bbd
4 changed files with 332 additions and 2 deletions
|
@ -0,0 +1 @@
|
|||
|
169
src/piracyshield_data_storage/account/session/memory.py
Normal file
169
src/piracyshield_data_storage/account/session/memory.py
Normal file
|
@ -0,0 +1,169 @@
|
|||
from piracyshield_data_storage.database.redis.document import DatabaseRedisDocument, DatabaseRedisSetException, DatabaseRedisGetException
|
||||
|
||||
class AccountSessionMemory(DatabaseRedisDocument):
|
||||
|
||||
session_prefix = 'session'
|
||||
|
||||
def __init__(self, database: int):
|
||||
super().__init__()
|
||||
|
||||
self.establish(database)
|
||||
|
||||
def add_long_session(self, account_id: str, refresh_token: str, data: dict, duration: int) -> bool | Exception:
|
||||
"""
|
||||
Store an access token generated from a refresh token.
|
||||
|
||||
:param refresh_token: a valid refresh token.
|
||||
:param access_token: a valid access token.
|
||||
:param duration: refresh token expire time.
|
||||
:return: true if the item has been stored.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.hset_with_expiry(
|
||||
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}',
|
||||
mapping = data,
|
||||
expiry = duration
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise AccountSessionMemorySetException()
|
||||
|
||||
def add_short_session(self, account_id: str, refresh_token: str, access_token: str, data: dict, duration: int) -> bool | Exception:
|
||||
"""
|
||||
Store an access token generated from a refresh token.
|
||||
|
||||
:param refresh_token: a valid refresh token.
|
||||
:param access_token: a valid access token.
|
||||
:param duration: refresh token expire time.
|
||||
:return: true if the item has been stored.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.hset_with_expiry(
|
||||
key = f'{self.session_prefix}:{account_id}:short:{access_token}',
|
||||
mapping = data,
|
||||
expiry = duration
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise AccountSessionMemorySetException()
|
||||
|
||||
def get_all_by_account(self, account_id: str) -> list | Exception:
|
||||
"""
|
||||
Retrieves all the active long and short sessions.
|
||||
|
||||
:param account_id: a valid account identifier.
|
||||
:return: the requested data.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.keys(
|
||||
key = f'{self.session_prefix}:{account_id}:*:*'
|
||||
)
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise AccountSessionMemoryGetException()
|
||||
|
||||
def get_all_short_by_account(self, account_id: str) -> list | Exception:
|
||||
"""
|
||||
Retrieves all the active short sessions.
|
||||
|
||||
:param account_id: a valid account identifier.
|
||||
:return: the requested data.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.keys(
|
||||
key = f'{self.session_prefix}:{account_id}:short:*'
|
||||
)
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise AccountSessionMemoryGetException()
|
||||
|
||||
def get_session(self, session: str) -> list | Exception:
|
||||
"""
|
||||
Retrieves a single session.
|
||||
|
||||
:param token: a valid refresh or access identifier.
|
||||
:return: the requested data.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.hgetall(
|
||||
key = session
|
||||
)
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise AccountSessionMemoryGetException()
|
||||
|
||||
def find_long_session(self, refresh_token: str) -> list | Exception:
|
||||
"""
|
||||
Retrieves a single session.
|
||||
|
||||
:param token: a valid refresh or access identifier.
|
||||
:return: the requested data.
|
||||
"""
|
||||
|
||||
try:
|
||||
response = self.keys(
|
||||
key = f'{self.session_prefix}:*:long:{refresh_token}'
|
||||
)
|
||||
|
||||
if isinstance(response, list) and len(response):
|
||||
return response.__getitem__(0)
|
||||
|
||||
return response
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise AccountSessionMemoryGetException()
|
||||
|
||||
def remove_long_session(self, account_id: str, refresh_token: str) -> bool | Exception:
|
||||
"""
|
||||
Removes a long session.
|
||||
|
||||
:param account_id: a valid account identifier.
|
||||
:param refresh_token: a valid active refresh token.
|
||||
:return: true if the item has been removed.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.delete(
|
||||
key = f'{self.session_prefix}:{account_id}:long:{refresh_token}'
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise AccountSessionMemorySetException()
|
||||
|
||||
def remove_short_session(self, account_id: str, access_token: str) -> bool | Exception:
|
||||
"""
|
||||
Removes a short session.
|
||||
|
||||
:param account_id: a valid account identifier.
|
||||
:param access_token: a valid active access token.
|
||||
:return: true if the item has been removed.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.delete(
|
||||
key = f'{self.session_prefix}:{account_id}:short:{access_token}'
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise AccountSessionMemorySetException()
|
||||
|
||||
class AccountSessionMemorySetException(Exception):
|
||||
|
||||
"""
|
||||
Cannot set the value.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class AccountSessionMemoryGetException(Exception):
|
||||
|
||||
"""
|
||||
Cannot get the value.
|
||||
"""
|
||||
|
||||
pass
|
|
@ -2,12 +2,43 @@ from piracyshield_data_storage.database.redis.connection import DatabaseRedisCon
|
|||
|
||||
class DatabaseRedisDocument(DatabaseRedisConnection):
|
||||
|
||||
def keys(self, key: str) -> any:
|
||||
try:
|
||||
return self.instance.keys(
|
||||
name = key
|
||||
)
|
||||
|
||||
except:
|
||||
raise DatabaseRedisGetException()
|
||||
|
||||
# string
|
||||
|
||||
def set_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
|
||||
if self.instance.set(key, value, ex = expiry) == True:
|
||||
return True
|
||||
|
||||
raise DatabaseRedisSetException()
|
||||
|
||||
def setnx_with_expiry(self, key: str, value: any, expiry: int) -> bool | Exception:
|
||||
pipeline = self.instance.pipeline()
|
||||
|
||||
pipeline.setnx(
|
||||
name = key,
|
||||
value = value
|
||||
)
|
||||
|
||||
pipeline.expire(
|
||||
name = key,
|
||||
time = expiry
|
||||
)
|
||||
|
||||
result = pipeline.execute()
|
||||
|
||||
if result:
|
||||
return True
|
||||
|
||||
raise DatabaseRedisSetException()
|
||||
|
||||
def incr(self, key: str, amount: int = 1) -> bool | Exception:
|
||||
return self.instance.incr(name = key, amount = amount)
|
||||
|
||||
|
@ -17,6 +48,59 @@ class DatabaseRedisDocument(DatabaseRedisConnection):
|
|||
def delete(self, key: str) -> any:
|
||||
return self.instance.delete(key)
|
||||
|
||||
# hash
|
||||
|
||||
def hset_with_expiry(self, key: str, mapping: list, expiry: int) -> bool | Exception:
|
||||
pipeline = self.instance.pipeline()
|
||||
|
||||
pipeline.hset(
|
||||
name = key,
|
||||
mapping = mapping
|
||||
)
|
||||
|
||||
pipeline.expire(
|
||||
name = key,
|
||||
time = expiry
|
||||
)
|
||||
|
||||
result = pipeline.execute()
|
||||
|
||||
if result:
|
||||
return True
|
||||
|
||||
raise DatabaseRedisSetException()
|
||||
|
||||
def hgetall(self, key: str) -> any:
|
||||
try:
|
||||
return self.instance.hgetall(
|
||||
name = key
|
||||
)
|
||||
|
||||
except:
|
||||
raise DatabaseRedisGetException()
|
||||
|
||||
# list
|
||||
|
||||
def lpush_with_expiry(self, key: str, value: str, expiry: int) -> bool | Exception:
|
||||
pipeline = self.instance.pipeline()
|
||||
|
||||
pipeline.lpush(
|
||||
key,
|
||||
value
|
||||
)
|
||||
|
||||
pipeline.expire(
|
||||
name = key,
|
||||
time = expiry
|
||||
)
|
||||
|
||||
result = pipeline.execute()
|
||||
|
||||
if result:
|
||||
return True
|
||||
|
||||
raise DatabaseRedisSetException()
|
||||
|
||||
class DatabaseRedisSetException(Exception):
|
||||
|
||||
"""
|
||||
|
|
|
@ -34,7 +34,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
|||
"""
|
||||
Verifies if an IP address is in the blacklist.
|
||||
|
||||
:param item: a valid IP address.
|
||||
:param ip_address: a valid IP address.
|
||||
:return: returns the TTL of the item.
|
||||
"""
|
||||
|
||||
|
@ -53,7 +53,7 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
|||
"""
|
||||
Removes an IP address from the blacklist.
|
||||
|
||||
:param item: a valid IP address.
|
||||
:param ip_address: a valid IP address.
|
||||
:return: true if the item has been removed.
|
||||
"""
|
||||
|
||||
|
@ -65,6 +65,82 @@ class SecurityBlacklistMemory(DatabaseRedisDocument):
|
|||
except DatabaseRedisSetException:
|
||||
raise SecurityBlacklistMemorySetException()
|
||||
|
||||
def add_refresh_token(self, refresh_token: str, duration: int) -> bool | Exception:
|
||||
"""
|
||||
Blacklists a refresh token.
|
||||
|
||||
:param refresh_token: an active refresh token.
|
||||
:param duration: duration of the blacklist in seconds.
|
||||
:return: true if the item has been stored.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.set_with_expiry(
|
||||
key = f'{self.token_prefix}:{refresh_token}',
|
||||
value = '1',
|
||||
expiry = duration
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise SecurityBlacklistMemorySetException()
|
||||
|
||||
def exists_by_refresh_token(self, refresh_token: str) -> bool | Exception:
|
||||
"""
|
||||
Verifies if a refresh token is in the blacklist.
|
||||
|
||||
:param refresh_token: a valid refresh token.
|
||||
:return: returns the TTL of the item.
|
||||
"""
|
||||
|
||||
try:
|
||||
response = self.get(key = f'{self.token_prefix}:{refresh_token}')
|
||||
|
||||
if response:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise SecurityBlacklistMemoryGetException()
|
||||
|
||||
def add_access_token(self, access_token: str, duration: int) -> bool | Exception:
|
||||
"""
|
||||
Blacklists an access token.
|
||||
|
||||
:param access_token: an active access token.
|
||||
:param duration: duration of the blacklist in seconds.
|
||||
:return: true if the item has been stored.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.set_with_expiry(
|
||||
key = f'{self.token_prefix}:{access_token}',
|
||||
value = '1',
|
||||
expiry = duration
|
||||
)
|
||||
|
||||
except DatabaseRedisSetException:
|
||||
raise SecurityBlacklistMemorySetException()
|
||||
|
||||
def exists_by_access_token(self, access_token: str) -> bool | Exception:
|
||||
"""
|
||||
Verifies if an access token is in the blacklist.
|
||||
|
||||
:param access_token: a valid access token.
|
||||
:return: returns the TTL of the item.
|
||||
"""
|
||||
|
||||
try:
|
||||
response = self.get(key = f'{self.token_prefix}:{access_token}')
|
||||
|
||||
if response:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except DatabaseRedisGetException:
|
||||
raise SecurityBlacklistMemoryGetException()
|
||||
|
||||
class SecurityBlacklistMemorySetException(Exception):
|
||||
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue