diff --git a/Makefile b/Makefile index 4088abc..7805b15 100644 --- a/Makefile +++ b/Makefile @@ -34,4 +34,4 @@ sast: bench: @echo "Running benchmark" ab -v 2 -n 20 -c 10 -k -T application/json -p tests/bench_params.txt 127.0.0.1:58008/api/v1/authentication/login - #ab -v 2 -n 20 -c 10 -k 127.0.0.1:58008/api/v1/ping + #ab -v 2 -n 200 -c 10 -k 127.0.0.1:58008/api/v1/ping diff --git a/application.py b/application.py index bce47bf..13c19c2 100644 --- a/application.py +++ b/application.py @@ -11,7 +11,13 @@ class Application(tornado.web.Application): Application initializer. """ - def __init__(self, debug: bool, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str): + domain = 'localhost' + + allowed_methods = 'POST, GET' + + max_age = 3600 + + def __init__(self, debug: bool, domain: str, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str): """ This is a list of handlers that directly extend the RequestHandler class, instead of relying on a BaseHandler class. @@ -23,6 +29,9 @@ class Application(tornado.web.Application): :return: a list of routes and their handlers. """ + if domain: + self.domain = domain + if prefix: version = f'{prefix}/{version}' diff --git a/boot.py b/boot.py index 6a54f6f..0058ae1 100644 --- a/boot.py +++ b/boot.py @@ -7,13 +7,17 @@ from v1.routes import APIv1 from application import Application -api_config = Config('api').get('general') +application_config = Config('application').get('general') +api_config = Config('application').get('api') if __name__ == "__main__": app = Application( # wether to run the application using debug mode debug = api_config['debug'], + # useful for CORS settings + domain = application_config['domain'], + # load current routes handlers = APIv1.routes, diff --git a/interceptors/logger.py b/interceptors/logger.py new file mode 100644 index 0000000..12d83dd --- /dev/null +++ b/interceptors/logger.py @@ -0,0 +1,18 @@ +from piracyshield_component.log.logger import Logger + +class LoggerInterceptor: + + """ + Logs requests + """ + + async def execute(self, r): + r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`') + +class LoggerInterceptorException(Exception): + + """ + We need this as a placeholder. + """ + + pass diff --git a/interceptors/security/blacklist.py b/interceptors/security/blacklist.py new file mode 100644 index 0000000..dbef568 --- /dev/null +++ b/interceptors/security/blacklist.py @@ -0,0 +1,34 @@ +from piracyshield_service.security.blacklist.exists_by_ip_address import SecurityBlacklistExistsByIPAddressService + +from ioutils.errors import ErrorCode, ErrorMessage + +class BlacklistInterceptor: + + """ + Applies bans. + """ + + security_blacklist_exists_by_ip_address_service = None + + async def execute(self, r): + await self._prepare_modules() + + if self.security_blacklist_exists_by_ip_address_service.execute( + ip_address = r.request.remote_ip + ) == True: + raise BlacklistInterceptorException() + + async def _prepare_modules(self): + self.security_blacklist_exists_by_ip_address_service = SecurityBlacklistExistsByIPAddressService() + +class BlacklistInterceptorException(Exception): + + """ + The IP address is temporary banned. + """ + + status_code = 403 + + error_code = ErrorCode.IP_ADDRESS_BLACKLISTED + + error_message = ErrorMessage.IP_ADDRESS_BLACKLISTED diff --git a/interceptors/security/rate_limit.py b/interceptors/security/rate_limit.py new file mode 100644 index 0000000..af9fa85 --- /dev/null +++ b/interceptors/security/rate_limit.py @@ -0,0 +1,68 @@ +from tornado.locks import Lock +from collections import deque +import time + +from ioutils.errors import ErrorCode, ErrorMessage + +class RateLimitInterceptor: + + """ + Blocks flooding. + """ + + # override the default methods + SUPPORTED_METHODS = ("GET", "POST") + + # max requests allowed in a second + MAX_REQUESTS_PER_SECOND = 100 + + # requests container + REQUESTS = {} + + # needs a mutex to manage the async threads writes + mutex = Lock() + + async def execute(self, _): + # get the current timestamp in seconds + timestamp = int(time.time()) + + async with self.mutex: + if timestamp not in self.REQUESTS: + self.REQUESTS[timestamp] = deque(maxlen = self.MAX_REQUESTS_PER_SECOND) + + # check if we exceeded the rate limit + if len(self.REQUESTS[timestamp]) >= self.MAX_REQUESTS_PER_SECOND: + raise RateLimitInterceptorException() + + # add the current request + self.REQUESTS[timestamp].append(time.time()) + + # trigger the old timestamps cleaning + await self._clean_timestamps(timestamp) + + async def _clean_timestamps(self, current_timestamp): + """ + Cleanup the old timestamps that are no longer relevant. + + NOTE: this is quite inefficient and slow for Python, but that's all we can do to manage this in the application. + + :param current_timestamp: current timestamp. + """ + + async with self.mutex: + # keep only the keys for the current and previous second + for timestamp in list(self.REQUESTS.keys()): + if timestamp < current_timestamp - 1: + del self.REQUESTS[timestamp] + +class RateLimitInterceptorException(Exception): + + """ + Too many requests. + """ + + status_code = 429 + + error_code = ErrorCode.TOO_MANY_REQUESTS + + error_message = ErrorMessage.TOO_MANY_REQUESTS diff --git a/ioutils/errors.py b/ioutils/errors.py index e08eb64..1ddadcc 100644 --- a/ioutils/errors.py +++ b/ioutils/errors.py @@ -31,9 +31,9 @@ class ErrorCode: MISSING_FILE = '1012' - NO_DATA_AVAILABLE = '1013' + CHANGE_PASSWORD = '1013' - CHANGE_PASSWORD = '1014' + IP_ADDRESS_BLACKLISTED = '1014' class ErrorMessage: @@ -75,8 +75,8 @@ class ErrorMessage: MISSING_FILE = 'Missing required file.' - NO_DATA_AVAILABLE = 'No data available for this request.' - # account settings CHANGE_PASSWORD = 'A password change has been activated for your account. You must first authenticate via web app and follow the instructions.' + + IP_ADDRESS_BLACKLISTED = 'Your IP address is temporary blacklisted.' diff --git a/ioutils/parameters.py.save b/ioutils/parameters.py.save deleted file mode 100644 index b462907..0000000 --- a/ioutils/parameters.py.save +++ /dev/null @@ -1,143 +0,0 @@ -from piracyshield_component.security.filter import Filter - -import json - -class JSONParametersHandler: - - """ - Pre-handle for the JSON request. - """ - - request_body = {} - - default_sanitization_rules = { - 'string': [ - 'strip' - ] - } - - def __init__(self, request_body): - self.request_body = request_body - - def process_request(self, required_fields: list = None, optional_fields: list = None, sanitization_rules: list = None) -> dict: - """ - Validates and sanitizes incoming JSON request data. - - # TODO: need to determine a sanitization template for the rules as this is still a generic approach. - - :param required_fields: list of required fields. - :param optional_fields: list of non-mandatory fields. - :param sanitization_rules: list of required sanitizations. - """ - - # try to load the JSON data, this will raise an exception if the content is not valid - try: - self.request_body = json.loads(self.request_body) - - except Exception: - raise JSONParametersNonValidException() - - if required_fields: - self._validate_input(required_fields, optional_fields) - - if sanitization_rules: - self.request_body = self._sanitize_input(self.request_body, sanitization_rules) - - else: - self.request_body = self._sanitize_input(self.request_body, self.default_sanitization_rules) - - return self.request_body - - def _validate_input(self, required_fields: list, optional_fields: list) -> None | Exception: - """ - Validates that the incoming JSON request contains all required fields. - - :param required_fields: list of required fields. - :param optional_fields: list of non-mandatory fields. - """ - - if not optional_fields: - # if there's no optional field we want the exact number of parameters - if len(self.request_body) > len(required_fields): - raise JSONParametersTooManyException() - - missing_fields = [] - - for field in required_fields: - if field not in self.request_body: - missing_fields.append(field) - - if missing_fields: - raise JSONParametersMissingException() - - # TODO: should we report the missing fields back to the user? - - return None - - def _sanitize_input(self, data: any, sanitization_rules: list) -> any: - """ - Cleans the input data. - - :param data: any parameter in the request. - :param sanitization_rules: list of cleaning rules. - :return: the cleaned data. - """ - - if isinstance(data, dict): - sanitized_data = {} - - for key, value in data.items(): - if value: - sanitized_value = self._sanitize_input(value, sanitization_rules) - - if sanitized_value: - sanitized_data[key] = sanitized_value - - - else: - sanitized_data[key] = value - - return sanitized_data - - elif isinstance(data, list): - sanitized_data = [] - - for item in data: - sanitized_value = self._sanitize_input(item, sanitization_rules) - - if item: - sanitized_data.append(sanitized_value) - - return sanitized_data - - elif isinstance(data, str): - if 'string' in sanitization_rules.keys(): - if 'strip' in sanitization_rules['string']: - return Filter.strip(data) - - else: - return data - -class JSONParametersNonValidException(Exception): - - """ - Not JSON data. - """ - - pass - -class JSONParametersMissingException(Exception): - - """ - The parameters we're looking for are completely missing. - """ - - pass - -class JSONParametersTooManyException(Exception): - - """ - More parameters than expected is an error. - """ - - pass diff --git a/ioutils/protected.py b/ioutils/protected.py index 991b0f6..4464b03 100644 --- a/ioutils/protected.py +++ b/ioutils/protected.py @@ -27,10 +27,10 @@ class ProtectedHandler(BaseHandler): account_data = {} - def prepare(self): + async def prepare(self): self.authentication_verify_access_token_service = AuthenticationVerifyAccessTokenService() - super().prepare() + await super().prepare() def initialize_account(self): """ diff --git a/ioutils/request.py b/ioutils/request.py index b6da808..1fa79b0 100644 --- a/ioutils/request.py +++ b/ioutils/request.py @@ -1,58 +1,49 @@ from piracyshield_component.log.logger import Logger from piracyshield_component.exception import ApplicationException - from .response import ResponseHandler -import tornado.web import time +from collections import deque +from tornado.locks import Lock + +from interceptors.logger import LoggerInterceptor +from interceptors.security.blacklist import BlacklistInterceptor +from interceptors.security.rate_limit import RateLimitInterceptor + +from ioutils.errors import ErrorCode, ErrorMessage class RequestHandler(ResponseHandler): """ Requests gateway. + + The requestor gets blocked here before proceding any further. + Interceptors are implemented to allow certain features to take part of the action. + + Current actived interceptors: + - security: takes care of the IP addresses exceeding limits (ex. applies the ban rules when triggered). + - rate limit: the IP will get a 429 when trying to perform too many requests on a given timeframe. """ - # override the default methods - SUPPORTED_METHODS = ("GET", "POST") + interceptors = [ + LoggerInterceptor, + BlacklistInterceptor, + RateLimitInterceptor + ] - # max requests allowed in a second - MAX_REQUESTS_PER_SECOND = 100 + async def prepare(self): + await self.run_interceptors() - # requests container - REQUESTS = {} + async def run_interceptors(self): + for interceptor in self.interceptors: + try: + await interceptor().execute(self) - def prepare(self) -> None: - """ - Handles the request general procedures. - This method implements a very simple request limit check. - """ - - self.application.logger.debug(f'> GET `{self.request.uri}` from `{self.request.remote_ip}`') - - # get the current timestamp in seconds - timestamp = int(time.time()) - - # TODO: this should be better handled and also provide a way to temporary ban each IP when flooding. - - # check if the number of requests for this second has exceeded the limit - if timestamp in self.REQUESTS: - if self.REQUESTS[timestamp] >= self.MAX_REQUESTS_PER_SECOND: - self.error(status_code = 429, error_code = ErrorCode.TOO_MANY_REQUESTS, message = ErrorMessage.TOO_MANY_REQUESTS) + except Exception as InterceptorException: + self.error( + status_code = InterceptorException.status_code, + error_code = InterceptorException.error_code, + message = InterceptorException.error_message + ) return - - # increment the number of requests for this second - self.REQUESTS[timestamp] = self.REQUESTS.get(timestamp, 0) + 1 - - # decrement the number of requests after one second - tornado.ioloop.IOLoop.current().call_later(1.0, self._decrement_requests_count, timestamp) - - def _decrement_requests_count(self, timestamp): - """ - Decrement the requests count per timestamp. - - :param timestamp: current timestamp. - """ - - if timestamp in self.REQUESTS: - self.REQUESTS[timestamp] -= 1 diff --git a/ioutils/response.py b/ioutils/response.py index 5d0d720..54784a2 100644 --- a/ioutils/response.py +++ b/ioutils/response.py @@ -1,14 +1,17 @@ -import tornado.web +from tornado.web import RequestHandler + +from piracyshield_component.config import Config from .errors import ErrorCode, ErrorMessage -import datetime import json -class ResponseHandler(tornado.web.RequestHandler): +class ResponseHandler(RequestHandler): """ Response handler. + + Prepares the response parameters for the next layer (Request). """ def set_default_headers(self) -> None: @@ -16,11 +19,12 @@ class ResponseHandler(tornado.web.RequestHandler): Sets the default headers. """ - self.set_header('Access-Control-Allow-Origin', '*') + # allow CORS only from current domain + self.set_header('Access-Control-Allow-Origin', self.application.domain) - self.set_header('Access-Control-Allow-Methods', 'POST, GET') + self.set_header('Access-Control-Allow-Methods', self.application.allowed_methods) - self.set_header('Access-Control-Max-Age', 3600) + self.set_header('Access-Control-Max-Age', self.application.max_age) # response for this API is always a JSON self.set_header('Content-Type', 'application/json') @@ -44,8 +48,8 @@ class ResponseHandler(tornado.web.RequestHandler): value = value, expires_days = 1, httponly = True, - samesite = "Strict" - # secure = True + samesite = "Strict", + secure = True ) def get_refresh_cookie(self) -> any: @@ -106,8 +110,9 @@ class ResponseHandler(tornado.web.RequestHandler): 'status': 'success' } - # this will return empty data as well - response['data'] = data + # ensures these data are appended in any case, even if empty + if isinstance(data, str) or isinstance(data, list) or isinstance(data, dict): + response['data'] = data # generic purpose informations that we want to communicate if note: diff --git a/requirements.txt b/requirements.txt index 75b3ae4..6866d09 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -tornado +tornado ~= 6.3.3 pytest diff --git a/v1/handlers/authentication/login.py b/v1/handlers/authentication/login.py index d4cbbf5..4709ca3 100644 --- a/v1/handlers/authentication/login.py +++ b/v1/handlers/authentication/login.py @@ -40,7 +40,8 @@ class AuthenticationLoginHandler(BaseHandler): None, self.process, self.request_data.get('email'), - self.request_data.get('password') + self.request_data.get('password'), + self.request.remote_ip ) self.success(data = { @@ -51,13 +52,14 @@ class AuthenticationLoginHandler(BaseHandler): except ApplicationException as e: self.error(status_code = 400, error_code = e.code, message = e.message) - def process(self, email: str, password: str) -> tuple: + def process(self, email: str, password: str, ip_address: str) -> tuple: authentication_authenticate_service = AuthenticationAuthenticateService() # try to authenticate payload = authentication_authenticate_service.execute( email = email, - password = password + password = password, + ip_address = ip_address ) authentication_generate_access_token_service = AuthenticationGenerateAccessTokenService() diff --git a/v1/handlers/authentication/refresh.py b/v1/handlers/authentication/refresh.py index 9faf860..1942d59 100644 --- a/v1/handlers/authentication/refresh.py +++ b/v1/handlers/authentication/refresh.py @@ -36,7 +36,7 @@ class AuthenticationRefreshHandler(BaseHandler): return if not self.request_data.get('refresh_token'): - return self.error(status_code = 403, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN) + return self.error(status_code = 401, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN) refresh_token = self.request_data.get('refresh_token')