mirror of
https://github.com/fuckpiracyshield/component.git
synced 2024-12-22 02:20:50 +01:00
Initial commit.
This commit is contained in:
commit
ca324d66c7
37 changed files with 1510 additions and 0 deletions
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
__pycache__/
|
||||
build/
|
||||
eggs/
|
||||
.eggs/
|
||||
*.egg
|
||||
*.egg-info/
|
3
README.md
Normal file
3
README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
### Component
|
||||
|
||||
Necessary components needed by every other package.
|
5
pyproject.toml
Normal file
5
pyproject.toml
Normal file
|
@ -0,0 +1,5 @@
|
|||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=54",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
18
setup.cfg
Normal file
18
setup.cfg
Normal file
|
@ -0,0 +1,18 @@
|
|||
[metadata]
|
||||
name = piracyshield_component
|
||||
version = 1.0.0
|
||||
description = Base Components
|
||||
|
||||
[options]
|
||||
package_dir=
|
||||
=src
|
||||
packages = find:
|
||||
python_requires = >= 3.10
|
||||
install_requires =
|
||||
toml
|
||||
pyjwt
|
||||
argon2-cffi
|
||||
pytz
|
||||
|
||||
[options.packages.find]
|
||||
where = src
|
1
src/piracyshield_component/__init__.py
Normal file
1
src/piracyshield_component/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
80
src/piracyshield_component/config.py
Normal file
80
src/piracyshield_component/config.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
from piracyshield_component.environment import Environment
|
||||
|
||||
import os
|
||||
import toml
|
||||
|
||||
class Config:
|
||||
|
||||
"""
|
||||
Configuration utility that supports TOML config creation.
|
||||
"""
|
||||
|
||||
config_path = None
|
||||
|
||||
config_content = None
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
"""
|
||||
Handles the setting of the path of a config file.
|
||||
This is representable by: Environment.CONFIG_PATH/<config_path>/my_config.toml
|
||||
|
||||
:param config_path: a valid path present in CONFIG_PATH directory.
|
||||
"""
|
||||
|
||||
self.config_path = config_path
|
||||
|
||||
self.config_content = self.load()
|
||||
|
||||
def load(self) -> dict | Exception:
|
||||
"""
|
||||
Loads the whole file.
|
||||
|
||||
:return: returns the content of the file.
|
||||
"""
|
||||
|
||||
file_path = f'{Environment.CONFIG_PATH}/{self.config_path}.toml'
|
||||
|
||||
try:
|
||||
return toml.load(file_path)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise ConfigNotFound(f'Impossibile trovare file {file_path}')
|
||||
|
||||
def get(self, key: str) -> str | Exception:
|
||||
"""
|
||||
Gets a single key from the loaded content.
|
||||
|
||||
:param key: a valid key.
|
||||
:return: value of the dictionary key.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.config_content[key]
|
||||
|
||||
except KeyError:
|
||||
raise ConfigKeyNotFound(f'Impossibile trovare chiave {key}')
|
||||
|
||||
def get_all(self, key: str = None) -> any:
|
||||
"""
|
||||
Returns the whole content of a configuration or its path.
|
||||
|
||||
:return: different types of data.
|
||||
"""
|
||||
|
||||
return self.config_content[key] if key else self.config_path
|
||||
|
||||
class ConfigNotFound(Exception):
|
||||
|
||||
"""
|
||||
No config found.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class ConfigKeyNotFound(Exception):
|
||||
|
||||
"""
|
||||
Key passed not found.
|
||||
"""
|
||||
|
||||
pass
|
23
src/piracyshield_component/environment.py
Normal file
23
src/piracyshield_component/environment.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import os
|
||||
|
||||
# TODO: provide development and production environment types
|
||||
# to have customization of the code in the other packages.
|
||||
|
||||
class Environment:
|
||||
|
||||
"""
|
||||
Environment management.
|
||||
|
||||
DATA_PATH consists in the root of the next folders.
|
||||
The structure should be:
|
||||
|
||||
data/
|
||||
config/
|
||||
cache/
|
||||
"""
|
||||
|
||||
DATA_PATH = os.environ['PIRACYSHIELD_DATA_PATH']
|
||||
|
||||
CONFIG_PATH = os.environ['PIRACYSHIELD_CONFIG_PATH']
|
||||
|
||||
CACHE_PATH = os.environ['PIRACYSHIELD_CACHE_PATH']
|
53
src/piracyshield_component/exception.py
Normal file
53
src/piracyshield_component/exception.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
import traceback
|
||||
|
||||
from piracyshield_component.log.logger import Logger
|
||||
|
||||
class ApplicationException(Exception):
|
||||
|
||||
"""
|
||||
Global application exception.
|
||||
This class is invoked as a last gateway from the error.
|
||||
"""
|
||||
|
||||
def __init__(self, code: str, message: str, unrecovered_exception = None):
|
||||
"""
|
||||
Provides context for the exception.
|
||||
|
||||
:param code: predefined code that identifies the error.
|
||||
:param message: short description of the issue.
|
||||
"""
|
||||
|
||||
self._code = code
|
||||
|
||||
self._message = message
|
||||
|
||||
self._unrecovered_exception = unrecovered_exception
|
||||
|
||||
logger = Logger('application')
|
||||
|
||||
logger.debug(f'{code}: {message}')
|
||||
|
||||
if unrecovered_exception:
|
||||
self._traceback = traceback.format_exc()
|
||||
|
||||
logger.error(f'Unrecovered exception: {unrecovered_exception} {self._traceback}')
|
||||
|
||||
@property
|
||||
def code(self) -> str:
|
||||
"""
|
||||
Sets the options.
|
||||
|
||||
:return: string containing the error code.
|
||||
"""
|
||||
|
||||
return self._code
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
"""
|
||||
Sets the options.
|
||||
|
||||
:return: string of the error message.
|
||||
"""
|
||||
|
||||
return self._message
|
1
src/piracyshield_component/io/__init__.py
Normal file
1
src/piracyshield_component/io/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
28
src/piracyshield_component/io/filesystem.py
Normal file
28
src/piracyshield_component/io/filesystem.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import os
|
||||
|
||||
class Filesystem:
|
||||
|
||||
"""
|
||||
Class to manage elements on the filesystem.
|
||||
"""
|
||||
|
||||
def get_size(self, absolute_file_path) -> int | Exception:
|
||||
"""
|
||||
Returns the size in Bytes format.
|
||||
|
||||
:param file_path: the absolute path of the file.
|
||||
"""
|
||||
|
||||
if os.path.exists(absolute_file_path):
|
||||
return os.path.getsize(absolute_file_path)
|
||||
|
||||
else:
|
||||
raise FilesystemNotFoundException()
|
||||
|
||||
class FilesystemNotFoundException(Exception):
|
||||
|
||||
"""
|
||||
Element not found in the filesystem.
|
||||
"""
|
||||
|
||||
pass
|
1
src/piracyshield_component/log/__init__.py
Normal file
1
src/piracyshield_component/log/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
68
src/piracyshield_component/log/formatter.py
Normal file
68
src/piracyshield_component/log/formatter.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import logging
|
||||
|
||||
class ColorFormatter(logging.Formatter):
|
||||
|
||||
"""
|
||||
Manages color output for the console.
|
||||
"""
|
||||
|
||||
def __init__(self, fmt = None, datefmt = None, style = '%'):
|
||||
"""
|
||||
Defines the final format of the message.
|
||||
|
||||
# TODO: this should be better explained.
|
||||
|
||||
Further info can be found on the official documentation:
|
||||
- https://docs.python.org/3/library/logging.html#formatter-objects
|
||||
|
||||
:param fmt: optional format of the string.
|
||||
:param datefmt: optional date format of the string.
|
||||
:param style: formatting string style.
|
||||
"""
|
||||
|
||||
super().__init__(fmt = fmt, datefmt = datefmt, style = style)
|
||||
|
||||
self.color_map = {
|
||||
logging.DEBUG: Color.DEBUG,
|
||||
logging.INFO: Color.INFO,
|
||||
logging.WARNING: Color.WARNING,
|
||||
logging.ERROR: Color.ERROR,
|
||||
logging.CRITICAL: Color.CRITICAL,
|
||||
}
|
||||
|
||||
self.reset_color = '\033[0m'
|
||||
|
||||
def format(self, record) -> str:
|
||||
"""
|
||||
Applies formatting to the passed string.
|
||||
|
||||
:param record: string to be formatted.
|
||||
:return: the formatted string.
|
||||
"""
|
||||
|
||||
levelname = record.levelname
|
||||
|
||||
message = super().format(record)
|
||||
|
||||
color = self.color_map.get(record.levelno)
|
||||
|
||||
if color:
|
||||
message = color + message + self.reset_color
|
||||
|
||||
return message
|
||||
|
||||
class Color:
|
||||
|
||||
"""
|
||||
Default formatting colors for each logging level.
|
||||
"""
|
||||
|
||||
DEBUG = '\033[1;37m'
|
||||
|
||||
INFO = '\033[1;32m'
|
||||
|
||||
WARNING = '\033[1;33m'
|
||||
|
||||
ERROR = '\033[1;31m'
|
||||
|
||||
CRITICAL = '\033[1;31m'
|
172
src/piracyshield_component/log/logger.py
Normal file
172
src/piracyshield_component/log/logger.py
Normal file
|
@ -0,0 +1,172 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from piracyshield_component.environment import Environment
|
||||
from piracyshield_component.config import Config
|
||||
|
||||
from piracyshield_component.log.output.console import ConsoleOutput
|
||||
from piracyshield_component.log.output.filesystem import FilesystemOutput
|
||||
|
||||
import logging
|
||||
|
||||
class Logger:
|
||||
|
||||
"""
|
||||
Application logging management class.
|
||||
"""
|
||||
|
||||
name = None
|
||||
|
||||
logger = None
|
||||
|
||||
general_config = None
|
||||
|
||||
console_config = None
|
||||
|
||||
filesystem_config = None
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Initializes a new logging instance using a nem as identifier.
|
||||
|
||||
:param name: a valid string representing the type of service that wants to log the operations.
|
||||
"""
|
||||
|
||||
self._prepare_configs()
|
||||
|
||||
self.name = name
|
||||
|
||||
self.logger = logging.getLogger(name)
|
||||
|
||||
self.logger.setLevel(self._get_level(self.general_config['level']))
|
||||
|
||||
self._register_handlers()
|
||||
|
||||
def _register_handlers(self) -> None:
|
||||
"""
|
||||
Registers different kind of handlers.
|
||||
|
||||
Supporterd outputs:
|
||||
- console
|
||||
- filesystem
|
||||
"""
|
||||
|
||||
if self.console_config['enabled'] == True and not self._has_handler(ConsoleOutput):
|
||||
self.logger.addHandler(
|
||||
ConsoleOutput(
|
||||
format_syntax = self.general_config['format'],
|
||||
colorize = self.console_config['colorize']
|
||||
)
|
||||
)
|
||||
|
||||
if self.filesystem_config['enabled'] == True and not self._has_handler(FilesystemOutput):
|
||||
self.logger.addHandler(
|
||||
FilesystemOutput(
|
||||
name = self.name,
|
||||
path = self.filesystem_config['path'],
|
||||
format_syntax = self.general_config['format']
|
||||
)
|
||||
)
|
||||
|
||||
def _has_handler(self, instance) -> bool:
|
||||
"""
|
||||
Avoids duplications of the already added handlers.
|
||||
|
||||
:return: true or false if the handler is already present or not.
|
||||
"""
|
||||
|
||||
return any(isinstance(h, instance) for h in self.logger.handlers)
|
||||
|
||||
def _get_level(self, level: str) -> int | Exception:
|
||||
"""
|
||||
Handles the level based on a fixed list of options.
|
||||
|
||||
:return: the real integer value of the logging levels.
|
||||
"""
|
||||
|
||||
match level:
|
||||
case 'debug':
|
||||
return logging.DEBUG
|
||||
|
||||
case 'info':
|
||||
return logging.INFO
|
||||
|
||||
case 'warning':
|
||||
return logging.WARNING
|
||||
|
||||
case 'error':
|
||||
return logging.ERROR
|
||||
|
||||
case 'critical':
|
||||
return logging.CRITICAL
|
||||
|
||||
case _:
|
||||
raise LoggerLevelNotFound()
|
||||
|
||||
def debug(self, message: str) -> None:
|
||||
"""
|
||||
Handles debugging messages.
|
||||
Logs minor informations useful for debugging.
|
||||
|
||||
:param message: the string we want to log.
|
||||
"""
|
||||
|
||||
self.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""
|
||||
Handles info messages.
|
||||
Logs ordnary operations.
|
||||
|
||||
:param message: the string we want to log.
|
||||
"""
|
||||
|
||||
self.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""
|
||||
Handles warning messages.
|
||||
Logs issues of any kind.
|
||||
|
||||
:param message: the string we want to log.
|
||||
"""
|
||||
|
||||
self.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""
|
||||
Handles error messages.
|
||||
Logs recoverable errors.
|
||||
|
||||
:param message: the string we want to log.
|
||||
"""
|
||||
|
||||
self.logger.error(message)
|
||||
|
||||
def critical(self, message: str) -> None:
|
||||
"""
|
||||
Handles critical messages.
|
||||
This should be considered a step before quitting the application.
|
||||
|
||||
:param message: the string we want to log.
|
||||
"""
|
||||
|
||||
self.logger.critical(message)
|
||||
|
||||
def _prepare_configs(self):
|
||||
"""
|
||||
Register configurations.
|
||||
"""
|
||||
|
||||
self.general_config = Config('logger').get('general')
|
||||
|
||||
self.console_config = Config('logger').get('console')
|
||||
|
||||
self.filesystem_config = Config('logger').get('filesystem')
|
||||
|
||||
class LoggerLevelNotFound(Exception):
|
||||
|
||||
"""
|
||||
Log level specified is non valid.
|
||||
"""
|
||||
|
||||
pass
|
1
src/piracyshield_component/log/output/__init__.py
Normal file
1
src/piracyshield_component/log/output/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
30
src/piracyshield_component/log/output/console.py
Normal file
30
src/piracyshield_component/log/output/console.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
import logging
|
||||
import sys
|
||||
|
||||
from piracyshield_component.log.formatter import ColorFormatter
|
||||
|
||||
class ConsoleOutput(logging.StreamHandler):
|
||||
|
||||
"""
|
||||
Manages the console output.
|
||||
"""
|
||||
|
||||
def __init__(self, format_syntax: str, colorize: bool = True):
|
||||
"""
|
||||
Sets the output options.
|
||||
|
||||
:param format_syntax: the string format to apply.
|
||||
:param colorize: whether to apply or not colorization of the output.
|
||||
"""
|
||||
|
||||
super().__init__(stream = sys.stdout)
|
||||
|
||||
formatter = None
|
||||
|
||||
if colorize == True:
|
||||
formatter = ColorFormatter(format_syntax)
|
||||
|
||||
else:
|
||||
formatter = logging.Formatter(format_syntax)
|
||||
|
||||
self.setFormatter(formatter)
|
37
src/piracyshield_component/log/output/filesystem.py
Normal file
37
src/piracyshield_component/log/output/filesystem.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from piracyshield_component.utils.time import Time
|
||||
|
||||
import logging
|
||||
|
||||
class FilesystemOutput(logging.FileHandler):
|
||||
|
||||
"""
|
||||
Manages filesystem output.
|
||||
"""
|
||||
|
||||
FORMAT = "%d-%m-%Y"
|
||||
|
||||
def __init__(self, name: str, path: str, format_syntax: str):
|
||||
"""
|
||||
Sets the output options.
|
||||
|
||||
:param name: the filename of the log.
|
||||
:param path: absolute path of the logging directory.
|
||||
:param format_syntax: formatting string style.
|
||||
"""
|
||||
|
||||
super().__init__(self._get_filename(name, path))
|
||||
|
||||
self.setFormatter(logging.Formatter(format_syntax))
|
||||
|
||||
def _get_filename(self, name: str, path: str) -> str:
|
||||
"""
|
||||
Resolves the file position.
|
||||
|
||||
:param name: the filename of the log.
|
||||
:param path: absolute path of the log.
|
||||
:return: absolute filename path.
|
||||
"""
|
||||
|
||||
now = Time.now_format(self.FORMAT)
|
||||
|
||||
return f'/{path}/{now}-{name}.log'
|
1
src/piracyshield_component/security/__init__.py
Normal file
1
src/piracyshield_component/security/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
90
src/piracyshield_component/security/checksum.py
Normal file
90
src/piracyshield_component/security/checksum.py
Normal file
|
@ -0,0 +1,90 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
|
||||
class Checksum:
|
||||
|
||||
"""
|
||||
Helper class for checksum calculations.
|
||||
"""
|
||||
|
||||
def from_string(self, algorithm: str, string: str) -> str:
|
||||
"""
|
||||
Calculates checksum of a given string.
|
||||
|
||||
:param str: the string to process.
|
||||
:return: checksum string.
|
||||
"""
|
||||
|
||||
try:
|
||||
instance = hashlib.new(algorithm)
|
||||
|
||||
# convert the string into bytes before handling
|
||||
instance.update(bytes(string, 'utf-8'))
|
||||
|
||||
except ValueError:
|
||||
raise ChecksumParameterException()
|
||||
|
||||
except UnicodeEncodeError:
|
||||
raise ChecksumUnicodeException()
|
||||
|
||||
except AttributeError:
|
||||
raise ChecksumParameterException()
|
||||
|
||||
return instance.hexdigest()
|
||||
|
||||
def from_file(self, algorithm: str, file_path: str) -> str:
|
||||
"""
|
||||
Calculates checksum of a given file.
|
||||
|
||||
:param str: the absolute file path.
|
||||
:return: checksum string.
|
||||
"""
|
||||
|
||||
try:
|
||||
instance = hashlib.new(algorithm)
|
||||
|
||||
with open(file_path, 'rb') as handle:
|
||||
while True:
|
||||
data = handle.read(65536)
|
||||
|
||||
# EOF
|
||||
if not data:
|
||||
break
|
||||
|
||||
instance.update(data)
|
||||
|
||||
return instance.hexdigest()
|
||||
|
||||
except ValueError:
|
||||
raise ChecksumParameterException()
|
||||
|
||||
except AttributeError:
|
||||
raise ChecksumParameterException()
|
||||
|
||||
except:
|
||||
raise ChecksumCalculationException()
|
||||
|
||||
class ChecksumCalculationException(Exception):
|
||||
|
||||
"""
|
||||
Cannot process the file.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class ChecksumUnicodeException(Exception):
|
||||
|
||||
"""
|
||||
Cannot encode the data to UTF-8.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class ChecksumParameterException(Exception):
|
||||
|
||||
"""
|
||||
Most likely not a string.
|
||||
"""
|
||||
|
||||
pass
|
23
src/piracyshield_component/security/filter.py
Normal file
23
src/piracyshield_component/security/filter.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
import re
|
||||
|
||||
class Filter:
|
||||
|
||||
"""
|
||||
Generic input filter utility.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def strip(value: str, character: str = ' '):
|
||||
"""
|
||||
Strips the character from the start and the end of a string.
|
||||
"""
|
||||
|
||||
return value.strip(character)
|
||||
|
||||
@staticmethod
|
||||
def remove_whitespace(value: str):
|
||||
"""
|
||||
Removes whitespaces from a string.
|
||||
"""
|
||||
|
||||
return re.sub(r'\s+', '', value)
|
78
src/piracyshield_component/security/hasher.py
Normal file
78
src/piracyshield_component/security/hasher.py
Normal file
|
@ -0,0 +1,78 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import argon2
|
||||
|
||||
class Hasher:
|
||||
|
||||
"""
|
||||
Helper class to encode and verify strings. Used as password encoder.
|
||||
"""
|
||||
|
||||
hasher_instance = None
|
||||
|
||||
def __init__(self, time_cost: int, memory_cost: int, parallelism: int, hash_length: int, salt_length: int):
|
||||
"""
|
||||
Initialize the instance and sets the options.
|
||||
|
||||
!! THESE PARAMETERS SHOULD NOT BE CHANGED AFTER THE FIRST USAGE !!
|
||||
|
||||
:param time_cost: execution time cost of the hashing operation.
|
||||
:param memory_cost: memory usage of the hasing operation.
|
||||
:param parallelism: quantity of parallel threads used.
|
||||
:param hash_length: length of the hash.
|
||||
:param salt_length: length of the generated salt.
|
||||
"""
|
||||
|
||||
self.hasher_instance = argon2.PasswordHasher(
|
||||
time_cost = time_cost,
|
||||
memory_cost = memory_cost,
|
||||
parallelism = parallelism,
|
||||
hash_len = hash_length,
|
||||
salt_len = salt_length,
|
||||
type = argon2.low_level.Type.ID
|
||||
)
|
||||
|
||||
def encode_string(self, string: str) -> str:
|
||||
"""
|
||||
Encode the string as an argon2 hash.
|
||||
|
||||
:param str: the string to hash.
|
||||
:return: the encoded argon2 hash.
|
||||
"""
|
||||
|
||||
try:
|
||||
# generate the hash
|
||||
hashed_string = self.hasher_instance.hash(string)
|
||||
|
||||
except argon2.exceptions.HashingError:
|
||||
raise HasherGenericException()
|
||||
|
||||
return hashed_string
|
||||
|
||||
def verify_hash(self, string: str, hashed_string: str) -> bool | Exception:
|
||||
"""
|
||||
Verify the string against the hash.
|
||||
|
||||
:param string: the string to verify.
|
||||
:param hashed_string: the argon2 hash to verify against.
|
||||
:return: True if the plaintext password matches the hash, otherwise False.
|
||||
"""
|
||||
|
||||
try:
|
||||
return self.hasher_instance.verify(hashed_string, string)
|
||||
|
||||
except argon2.exceptions.VerificationError:
|
||||
raise HasherNonValidException()
|
||||
|
||||
|
||||
class HasherGenericException(Exception):
|
||||
|
||||
"""
|
||||
Raised during the encoding of the plain text string.
|
||||
"""
|
||||
|
||||
class HasherNonValidException(Exception):
|
||||
|
||||
"""
|
||||
Raised during the verification procedure, if the string is not matching the hash.
|
||||
"""
|
33
src/piracyshield_component/security/identifier.py
Normal file
33
src/piracyshield_component/security/identifier.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import secrets
|
||||
|
||||
class Identifier:
|
||||
|
||||
"""
|
||||
Class for identifiers generation.
|
||||
"""
|
||||
|
||||
def generate(self) -> str:
|
||||
"""
|
||||
Generates RFC 4122 compliant ids.
|
||||
|
||||
:return: a string based on the UUIDv4, minus the dashes.
|
||||
"""
|
||||
|
||||
string = uuid.uuid4()
|
||||
|
||||
# converts to string and removes dashes producing a 32 characters value
|
||||
return string.hex
|
||||
|
||||
def generate_short_unsafe(self, length: int = 8) -> str:
|
||||
"""
|
||||
Generates a short handy string to be used in situations where we need to add a prefix or something like that.
|
||||
Does not guarantees the uniqueness, but that's not an issue for this kind of identifier usage.
|
||||
|
||||
:param length: custom length of the string.
|
||||
:return: an alphanumeric string.
|
||||
"""
|
||||
|
||||
return secrets.token_hex(length)
|
167
src/piracyshield_component/security/token.py
Normal file
167
src/piracyshield_component/security/token.py
Normal file
|
@ -0,0 +1,167 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from piracyshield_component.utils.time import Time, TimeValueException
|
||||
|
||||
import jwt
|
||||
import datetime
|
||||
|
||||
class JWTToken:
|
||||
|
||||
"""
|
||||
Class for the management of the JWT token.
|
||||
"""
|
||||
|
||||
secret_key = None
|
||||
|
||||
algorithm = None
|
||||
|
||||
expiration_time = None
|
||||
|
||||
def __init__(self, access_secret_key: str, refresh_secret_key: str, access_expiration_time: int, refresh_expiration_time: int, algorithm: str):
|
||||
"""
|
||||
Sets the options.
|
||||
|
||||
!! THESE PARAMETERS SHOULD NOT BE CHANGED AFTER THE FIRST USAGE !!
|
||||
|
||||
:param access_secret_key: the secret key for the access token generation.
|
||||
:param access_secret_key: the secret key for the refresh token generation.
|
||||
:param access_expiration_time: time duration for the access token, set in seconds.
|
||||
:param access_expiration_time: time duration for the refresh token, set in seconds.
|
||||
:param algorithm: algorithm for encoding and decoding (default: HS256).
|
||||
"""
|
||||
|
||||
self.access_secret_key = access_secret_key
|
||||
|
||||
self.refresh_secret_key = refresh_secret_key
|
||||
|
||||
self.access_expiration_time = access_expiration_time
|
||||
|
||||
self.refresh_expiration_time = refresh_expiration_time
|
||||
|
||||
self.algorithm = algorithm
|
||||
|
||||
def generate_access_token(self, payload: dict) -> str | Exception:
|
||||
"""
|
||||
Handles the generation of an access token using the configuration data.
|
||||
It's a short lived token that needs to be refreshed periodically.
|
||||
|
||||
:param payload: dictionary containing the data to encode.
|
||||
:return: a JWT token signed with the secret key and encoded in base64.
|
||||
"""
|
||||
|
||||
return self.generate_token(payload, self.access_secret_key, self.access_expiration_time)
|
||||
|
||||
def generate_refresh_token(self, payload: dict) -> str | Exception:
|
||||
"""
|
||||
Handles the generation of a refresh token using the configuration data.
|
||||
This is intended as a long lived token, used to periodically refresh the access token.
|
||||
|
||||
:param payload: dictionary containing the data to encode.
|
||||
:return: a JWT token signed with the secret key and encoded in base64.
|
||||
"""
|
||||
|
||||
return self.generate_token(payload, self.refresh_secret_key, self.refresh_expiration_time)
|
||||
|
||||
def generate_token(self, payload: dict, secret_key: str, period: int) -> str | Exception:
|
||||
"""
|
||||
Generate a JWT token with the given payload.
|
||||
|
||||
:param payload: a dictionary containing the data to be encoded in the JWT token.
|
||||
:return: a JWT token signed with the secret key and encoded in base64.
|
||||
"""
|
||||
|
||||
# calculate the expiration time
|
||||
now = Time.now()
|
||||
|
||||
# add the issued time
|
||||
payload['iat'] = now
|
||||
|
||||
# add the expiration time
|
||||
payload['exp'] = now + datetime.timedelta(seconds = period)
|
||||
|
||||
try:
|
||||
# generate the final token
|
||||
token = jwt.encode(payload, secret_key, algorithm = self.algorithm)
|
||||
|
||||
return token
|
||||
|
||||
# TODO: get a more granular error reporting here
|
||||
except Exception as e:
|
||||
raise JWTTokenGenericException()
|
||||
|
||||
def verify_access_token(self, token: any) -> dict | Exception:
|
||||
"""
|
||||
Verifies the access token with the configured access secret key.
|
||||
|
||||
:param token: a valid JWT access token.
|
||||
:return: the original payload.
|
||||
"""
|
||||
|
||||
return self.verify_token(token, self.access_secret_key)
|
||||
|
||||
def verify_refresh_token(self, token: any) -> dict | Exception:
|
||||
"""
|
||||
Verifies the refresh token with the configured refresh secret key.
|
||||
|
||||
:param token: a valid JWT access token.
|
||||
:return: the original payload.
|
||||
"""
|
||||
|
||||
return self.verify_token(token, self.refresh_secret_key)
|
||||
|
||||
def verify_token(self, token: any, secret_key: str) -> dict | Exception:
|
||||
"""
|
||||
Verify the given JWT token and return the decoded payload if the token is valid.
|
||||
|
||||
:param token: a JWT token encoded in base64.
|
||||
:return: the decoded payload if the token is valid, otherwise an exception will be raised.
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, secret_key, algorithms = [self.algorithm])
|
||||
|
||||
# check if it's expired
|
||||
try:
|
||||
if Time.timestamp_to_datetime(payload['exp']) < Time.now():
|
||||
raise JWTTokenExpiredException()
|
||||
|
||||
except TimeValueException:
|
||||
raise JWTTokenNonValidException()
|
||||
|
||||
return payload
|
||||
|
||||
except jwt.exceptions.ExpiredSignatureError:
|
||||
raise JWTTokenExpiredException()
|
||||
|
||||
# TODO: should deal with more explicit exceptions on our side
|
||||
except (
|
||||
jwt.exceptions.InvalidTokenError,
|
||||
jwt.exceptions.InvalidSignatureError,
|
||||
jwt.exceptions.DecodeError,
|
||||
jwt.exceptions.InvalidAlgorithmError
|
||||
):
|
||||
raise JWTTokenNonValidException()
|
||||
|
||||
class JWTTokenGenericException(Exception):
|
||||
|
||||
"""
|
||||
Generic exception as a last option.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class JWTTokenExpiredException(Exception):
|
||||
|
||||
"""
|
||||
Exception raised on token expired time.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class JWTTokenNonValidException(Exception):
|
||||
|
||||
"""
|
||||
Raised during the token verification, if there's no valid match.
|
||||
"""
|
||||
|
||||
pass
|
1
src/piracyshield_component/utils/__init__.py
Normal file
1
src/piracyshield_component/utils/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
122
src/piracyshield_component/utils/time.py
Normal file
122
src/piracyshield_component/utils/time.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
import datetime
|
||||
import time
|
||||
|
||||
import pytz
|
||||
|
||||
class Time:
|
||||
|
||||
"""
|
||||
Handy utility for date and/or time generation.
|
||||
"""
|
||||
|
||||
TIMEZONE = 'Europe/Rome'
|
||||
|
||||
@staticmethod
|
||||
def now() -> str:
|
||||
"""
|
||||
Returns the current date and time.
|
||||
|
||||
:return: the current date and time as a string.
|
||||
"""
|
||||
|
||||
timezone = pytz.timezone(Time.TIMEZONE)
|
||||
|
||||
utc = datetime.datetime.now()
|
||||
|
||||
return utc.astimezone(timezone)
|
||||
|
||||
def now_format(datetime_format: str) -> str:
|
||||
"""
|
||||
Returns the current date and time using a custom format.
|
||||
|
||||
:return: the current date and time as a string.
|
||||
"""
|
||||
|
||||
now = Time.now()
|
||||
|
||||
return now.strftime(datetime_format)
|
||||
|
||||
@staticmethod
|
||||
def now_iso8601() -> str:
|
||||
"""
|
||||
Returns the current date and time in ISO 8601 format.
|
||||
|
||||
:return: the current date and time as a string.
|
||||
"""
|
||||
|
||||
now = Time.now()
|
||||
|
||||
return now.isoformat()
|
||||
|
||||
@staticmethod
|
||||
def timestamp():
|
||||
"""
|
||||
Returns the current Unix timestamp as an integer.
|
||||
|
||||
:return: the current Unix timestamp.
|
||||
"""
|
||||
|
||||
now = Time.now()
|
||||
|
||||
return int(now.timestamp())
|
||||
|
||||
@staticmethod
|
||||
def timestamp_to_datetime(timestamp: int) -> bool | Exception:
|
||||
try:
|
||||
# convert the timestamp to a datetime with timezone information
|
||||
return datetime.datetime.fromtimestamp(timestamp, tz = pytz.timezone(Time.TIMEZONE))
|
||||
|
||||
except ValueError:
|
||||
raise TimeValueException()
|
||||
|
||||
@staticmethod
|
||||
def is_expired(date: int, expiration_time: int) -> bool:
|
||||
"""
|
||||
Checks if the provided date is expired.
|
||||
|
||||
:param date: a date in ISO8601 format.
|
||||
:param expiration_time: the distance in seconds for the date to expiry.
|
||||
:return: true if expired.
|
||||
"""
|
||||
|
||||
# convert the date string to a datetime object
|
||||
datetime_object = datetime.datetime.fromisoformat(date)
|
||||
|
||||
# sum the expired time to the date
|
||||
converted_date = datetime_object + datetime.timedelta(seconds = expiration_time)
|
||||
|
||||
current_date = Time.now()
|
||||
|
||||
return current_date > converted_date
|
||||
|
||||
def is_valid_iso8601(date: int) -> bool | Exception:
|
||||
"""
|
||||
Validates the date against the ISO8601 format.
|
||||
|
||||
:param date: a date in ISO8601 format.
|
||||
:return: true if the date format is correct.
|
||||
"""
|
||||
|
||||
try:
|
||||
datetime.datetime.fromisoformat(date)
|
||||
|
||||
return True
|
||||
|
||||
except:
|
||||
raise TimeFormatException()
|
||||
|
||||
class TimeValueException(Exception):
|
||||
|
||||
"""
|
||||
Wrong input.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class TimeFormatException(Exception):
|
||||
|
||||
"""
|
||||
Wrong date format.
|
||||
"""
|
||||
|
||||
pass
|
1
src/piracyshield_component/validation/__init__.py
Normal file
1
src/piracyshield_component/validation/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
33
src/piracyshield_component/validation/rule.py
Normal file
33
src/piracyshield_component/validation/rule.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
class Rule(ABC):
|
||||
|
||||
"""
|
||||
Basic rule class.
|
||||
"""
|
||||
|
||||
errors = []
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize errors list.
|
||||
"""
|
||||
|
||||
self.errors = []
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self):
|
||||
"""
|
||||
Method invoked for the rule processing.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
def register_error(self, message) -> None:
|
||||
"""
|
||||
Registers all of the errors occurred during the validation.
|
||||
|
||||
:param message: error message.
|
||||
"""
|
||||
|
||||
self.errors.append(message)
|
1
src/piracyshield_component/validation/rules/__init__.py
Normal file
1
src/piracyshield_component/validation/rules/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
|
31
src/piracyshield_component/validation/rules/as_code.py
Normal file
31
src/piracyshield_component/validation/rules/as_code.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
import re
|
||||
|
||||
class ASCode(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid AS code.
|
||||
"""
|
||||
|
||||
message = 'AS code not valid'
|
||||
|
||||
expression = r'^(AS)?[0-9]{1,10}$'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: str) -> None:
|
||||
"""
|
||||
Checks the validity of the AS code.
|
||||
It's flexible enough to allow also non AS/A prefix strings.
|
||||
|
||||
:param value: a valid string.
|
||||
"""
|
||||
|
||||
if not re.search(self.expression, value):
|
||||
self.register_error(self.message)
|
30
src/piracyshield_component/validation/rules/dda.py
Normal file
30
src/piracyshield_component/validation/rules/dda.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
import re
|
||||
|
||||
class DDA(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid DDA identifier.
|
||||
"""
|
||||
|
||||
message = 'DDA identifier not valid'
|
||||
|
||||
expression = r'^[0-9]{3}\/[0-9]{2}\/DDA$'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: str) -> None:
|
||||
"""
|
||||
Checks the validity of the DDA identifier.
|
||||
|
||||
:param value: a valid string.
|
||||
"""
|
||||
|
||||
if not re.search(self.expression, value):
|
||||
self.register_error(self.message)
|
31
src/piracyshield_component/validation/rules/email.py
Normal file
31
src/piracyshield_component/validation/rules/email.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
import re
|
||||
|
||||
class Email(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid e-mail address.
|
||||
"""
|
||||
|
||||
message = 'The value must be a valid e-mail address'
|
||||
|
||||
# support also second level domains
|
||||
expression = r'^[a-z0-9]+[\._]?[a-z0-9]+[@]([a-z\-]+.)?\w+[.]\w{2,18}$'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: any) -> None:
|
||||
"""
|
||||
Matches our regular expression against the given value.
|
||||
|
||||
param: value: a valid string.
|
||||
"""
|
||||
|
||||
if not re.search(self.expression, value):
|
||||
self.register_error(self.message)
|
30
src/piracyshield_component/validation/rules/fqdn.py
Normal file
30
src/piracyshield_component/validation/rules/fqdn.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
import re
|
||||
|
||||
class FQDN(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid fully qualified domain name.
|
||||
"""
|
||||
|
||||
message = 'FQDN not valid'
|
||||
|
||||
expression = r'^([a-zA-Z0-9]([a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: any) -> None:
|
||||
"""
|
||||
Checks for a valid FQDN.
|
||||
|
||||
:param value: a valid string.
|
||||
"""
|
||||
|
||||
if not re.search(self.expression, value):
|
||||
self.register_error(self.message)
|
54
src/piracyshield_component/validation/rules/ipv4.py
Normal file
54
src/piracyshield_component/validation/rules/ipv4.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class IPv4(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid IPv4.
|
||||
"""
|
||||
|
||||
octets_message = 'IPv4 not valid, expecting four octects, got {}'
|
||||
|
||||
octets_digits_message = 'IPv4 not valid, expecting four octets of digits'
|
||||
|
||||
octets_length_message = 'IPv4 not valid, one or more octet(s) too long'
|
||||
|
||||
octets_digits_size_message = 'IPv4 not valid, expecting digits from 0 to 255'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: str) -> None:
|
||||
"""
|
||||
Checks the validity of the passed string.
|
||||
Instead of relying on regular expression, is it better to check the single octects.
|
||||
This also allow for better error reporting.
|
||||
|
||||
:param value: a valid string.
|
||||
"""
|
||||
|
||||
octets = value.split('.')
|
||||
|
||||
octets_size = len(octets)
|
||||
|
||||
# we're expecting 4 octects
|
||||
if octets_size != 4:
|
||||
self.register_error(self.octets_message.format(octets_size))
|
||||
|
||||
for octet in octets:
|
||||
single_octet_size = len(octet)
|
||||
|
||||
# each octect must be an integer
|
||||
if not octet.isdigit():
|
||||
self.register_error(self.octets_digits_message)
|
||||
|
||||
# with a maximum length of 3
|
||||
if single_octet_size > 3:
|
||||
self.register_error(self.octets_length_message)
|
||||
|
||||
# between 0 and 255
|
||||
if single_octet_size < 0 or single_octet_size > 255:
|
||||
self.register_error(self.octets_digits_size_message)
|
60
src/piracyshield_component/validation/rules/ipv6.py
Normal file
60
src/piracyshield_component/validation/rules/ipv6.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class IPv6(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks for a valid IPv6.
|
||||
"""
|
||||
|
||||
hextets_message = 'IPv6 not valid, expecting eight hextets, got {}'
|
||||
|
||||
hextets_digits_message = 'IPv6 not valid, expecting eight hextets of hexadecimal digits'
|
||||
|
||||
hextets_length_message = 'IPv6 not valid, one or more hextet(s) too long'
|
||||
|
||||
hextets_digits_size_message = 'IPv6 not valid, expecting hexadecimal digits from 0 to FFFF'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: str) -> None:
|
||||
"""
|
||||
Checks the validity of the passed string.
|
||||
Instead of relying on regular expression, it's better to check the individual hextets.
|
||||
This also allows for better error reporting.
|
||||
|
||||
:param value: a valid string.
|
||||
"""
|
||||
|
||||
hextets = value.split(':')
|
||||
|
||||
hextets_size = len(hextets)
|
||||
|
||||
# we're expecting 8 hextets
|
||||
if hextets_size != 8:
|
||||
self.register_error(self.hextets_message.format(hextets_size))
|
||||
|
||||
for hextet in hextets:
|
||||
single_hextet_size = len(hextet)
|
||||
|
||||
# each hextet must be composed of hexadecimal digits
|
||||
if not all(c in '0123456789ABCDEFabcdef' for c in hextet):
|
||||
self.register_error(self.hextets_digits_message)
|
||||
|
||||
# with a maximum length of 4
|
||||
if single_hextet_size > 4:
|
||||
self.register_error(self.hextets_length_message)
|
||||
|
||||
try:
|
||||
# convert the hextet to an integer in base 16
|
||||
int_value = int(hextet, 16)
|
||||
|
||||
# check if the integer value is within the valid range (0~0xFFFF)
|
||||
if not (0 <= int_value <= 0xFFFF):
|
||||
self.register_error(self.hextets_digits_size_message)
|
||||
|
||||
except ValueError:
|
||||
self.register_error(self.hextets_digits_size_message)
|
37
src/piracyshield_component/validation/rules/length.py
Normal file
37
src/piracyshield_component/validation/rules/length.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class Length(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks the length of a value.
|
||||
"""
|
||||
|
||||
message = 'The value must be comprised between {} and {} characters'
|
||||
|
||||
minimum = None
|
||||
|
||||
maximum = None
|
||||
|
||||
def __init__(self, minimum: int, maximum: int):
|
||||
"""
|
||||
Initialize parent __init__ and set the minimum and maximum allowed length.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.minimum = minimum
|
||||
|
||||
self.maximum = maximum
|
||||
|
||||
def __call__(self, value: any) -> None:
|
||||
"""
|
||||
Check if our value has the correct specified length.
|
||||
|
||||
param: value: a valid string.
|
||||
"""
|
||||
|
||||
# TODO: might want to handle this in depth.
|
||||
length = len(value) if isinstance(value, str) else len(str(value))
|
||||
|
||||
if length < self.minimum or length > self.maximum:
|
||||
self.register_error(self.message.format(self.minimum, self.maximum))
|
36
src/piracyshield_component/validation/rules/required.py
Normal file
36
src/piracyshield_component/validation/rules/required.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class Required(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks if a string is empty.
|
||||
"""
|
||||
|
||||
message = 'A valid string is required'
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize parent __init__.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, value: any) -> None:
|
||||
"""
|
||||
Stores and executes the code.
|
||||
|
||||
param: value: any value.
|
||||
"""
|
||||
|
||||
self.value = value
|
||||
|
||||
self.is_empty()
|
||||
|
||||
def is_empty(self) -> None:
|
||||
"""
|
||||
Check if the string is empty.
|
||||
No filters are applied in this phase as we want to avoid input and output misalignements.
|
||||
"""
|
||||
|
||||
if not self.value:
|
||||
self.register_error(self.message)
|
55
src/piracyshield_component/validation/rules/string.py
Normal file
55
src/piracyshield_component/validation/rules/string.py
Normal file
|
@ -0,0 +1,55 @@
|
|||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class String(Rule):
|
||||
|
||||
"""
|
||||
Rule that checks if a string has valid characters.
|
||||
"""
|
||||
|
||||
message = 'The value must be a string containing letters, numbers'
|
||||
|
||||
message_string = 'The value must be a valid string'
|
||||
|
||||
characters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
|
||||
def __init__(self, allowed: str = ''):
|
||||
"""
|
||||
Initialize the parent __init__.
|
||||
|
||||
TODO: we might want to extend this to allow a full customization of the characters list
|
||||
(ie. when we don't really want all the characters spectrum)
|
||||
|
||||
:param allowed: a string containing allowed characters other than the default ones.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
|
||||
# join the allowed characters to our main list and convert them to a unique values dictionary
|
||||
self.characters = set(''.join([allowed, self.characters]))
|
||||
|
||||
def __call__(self, value: any) -> None:
|
||||
"""
|
||||
Stores and executes the code.
|
||||
|
||||
param: value: a valid string.
|
||||
"""
|
||||
|
||||
self.value = value
|
||||
|
||||
self.has_characters()
|
||||
|
||||
def has_characters(self) -> None:
|
||||
"""
|
||||
Checks if the value is a valid string instance and if the characters are allowed.
|
||||
"""
|
||||
|
||||
if not isinstance(self.value, str):
|
||||
self.register_error(self.message_string)
|
||||
|
||||
else:
|
||||
# converts the value into a dictionary with no duplicates
|
||||
exploded_value = set(self.value)
|
||||
|
||||
# check if the characters appear in our list
|
||||
if exploded_value.issubset(self.characters) == False:
|
||||
self.register_error(self.message)
|
69
src/piracyshield_component/validation/validator.py
Normal file
69
src/piracyshield_component/validation/validator.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from piracyshield_component.validation.rule import Rule
|
||||
|
||||
class Validator:
|
||||
|
||||
"""
|
||||
Validation utility class.
|
||||
"""
|
||||
|
||||
value = None
|
||||
|
||||
rules = None
|
||||
|
||||
errors = []
|
||||
|
||||
def __init__(self, value, rules: list):
|
||||
"""
|
||||
Sets the options.
|
||||
|
||||
:param value: the value we want to analyze.
|
||||
:param rules: a list of validation rule classes.
|
||||
"""
|
||||
|
||||
self.value = value
|
||||
|
||||
self.rules = rules
|
||||
|
||||
self.errors = []
|
||||
|
||||
self._check_rules()
|
||||
|
||||
def validate(self) -> None:
|
||||
"""
|
||||
Cycles through the set rules collecting errors, if any.
|
||||
"""
|
||||
|
||||
for rule in self.rules:
|
||||
rule(self.value)
|
||||
|
||||
# merge the errors lists
|
||||
self.errors = self.errors + rule.errors
|
||||
|
||||
def is_valid(self) -> list | bool:
|
||||
"""
|
||||
Returns true when the errors array is filled.
|
||||
"""
|
||||
|
||||
return not self.errors
|
||||
|
||||
def _check_rules(self) -> None | Exception:
|
||||
"""
|
||||
Ensures the validity of the passed rules.
|
||||
"""
|
||||
|
||||
for rule in self.rules:
|
||||
if not isinstance(rule, Rule):
|
||||
raise ValidatorRuleNonValidException()
|
||||
|
||||
# reset errors on each validation
|
||||
rule.errors = []
|
||||
|
||||
class ValidatorRuleNonValidException(Exception):
|
||||
|
||||
"""
|
||||
Raised if the class doesn't inherit the Rule class.
|
||||
"""
|
||||
|
||||
pass
|
Loading…
Reference in a new issue