Added requests interceptors support. IP addresses can now be banned.

This commit is contained in:
Daniele Maglie 2024-01-19 15:25:53 +01:00
parent f3ed518b65
commit 183fa3e27a
14 changed files with 196 additions and 208 deletions

View file

@ -34,4 +34,4 @@ sast:
bench: bench:
@echo "Running benchmark" @echo "Running benchmark"
ab -v 2 -n 20 -c 10 -k -T application/json -p tests/bench_params.txt 127.0.0.1:58008/api/v1/authentication/login ab -v 2 -n 20 -c 10 -k -T application/json -p tests/bench_params.txt 127.0.0.1:58008/api/v1/authentication/login
#ab -v 2 -n 20 -c 10 -k 127.0.0.1:58008/api/v1/ping #ab -v 2 -n 200 -c 10 -k 127.0.0.1:58008/api/v1/ping

View file

@ -11,7 +11,13 @@ class Application(tornado.web.Application):
Application initializer. Application initializer.
""" """
def __init__(self, debug: bool, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str): domain = 'localhost'
allowed_methods = 'POST, GET'
max_age = 3600
def __init__(self, debug: bool, domain: str, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str):
""" """
This is a list of handlers that directly extend the RequestHandler class, instead of relying on a BaseHandler class. This is a list of handlers that directly extend the RequestHandler class, instead of relying on a BaseHandler class.
@ -23,6 +29,9 @@ class Application(tornado.web.Application):
:return: a list of routes and their handlers. :return: a list of routes and their handlers.
""" """
if domain:
self.domain = domain
if prefix: if prefix:
version = f'{prefix}/{version}' version = f'{prefix}/{version}'

View file

@ -7,13 +7,17 @@ from v1.routes import APIv1
from application import Application from application import Application
api_config = Config('api').get('general') application_config = Config('application').get('general')
api_config = Config('application').get('api')
if __name__ == "__main__": if __name__ == "__main__":
app = Application( app = Application(
# wether to run the application using debug mode # wether to run the application using debug mode
debug = api_config['debug'], debug = api_config['debug'],
# useful for CORS settings
domain = application_config['domain'],
# load current routes # load current routes
handlers = APIv1.routes, handlers = APIv1.routes,

18
interceptors/logger.py Normal file
View file

@ -0,0 +1,18 @@
from piracyshield_component.log.logger import Logger
class LoggerInterceptor:
"""
Logs requests
"""
async def execute(self, r):
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`')
class LoggerInterceptorException(Exception):
"""
We need this as a placeholder.
"""
pass

View file

@ -0,0 +1,34 @@
from piracyshield_service.security.blacklist.exists_by_ip_address import SecurityBlacklistExistsByIPAddressService
from ioutils.errors import ErrorCode, ErrorMessage
class BlacklistInterceptor:
"""
Applies bans.
"""
security_blacklist_exists_by_ip_address_service = None
async def execute(self, r):
await self._prepare_modules()
if self.security_blacklist_exists_by_ip_address_service.execute(
ip_address = r.request.remote_ip
) == True:
raise BlacklistInterceptorException()
async def _prepare_modules(self):
self.security_blacklist_exists_by_ip_address_service = SecurityBlacklistExistsByIPAddressService()
class BlacklistInterceptorException(Exception):
"""
The IP address is temporary banned.
"""
status_code = 403
error_code = ErrorCode.IP_ADDRESS_BLACKLISTED
error_message = ErrorMessage.IP_ADDRESS_BLACKLISTED

View file

@ -0,0 +1,68 @@
from tornado.locks import Lock
from collections import deque
import time
from ioutils.errors import ErrorCode, ErrorMessage
class RateLimitInterceptor:
"""
Blocks flooding.
"""
# override the default methods
SUPPORTED_METHODS = ("GET", "POST")
# max requests allowed in a second
MAX_REQUESTS_PER_SECOND = 100
# requests container
REQUESTS = {}
# needs a mutex to manage the async threads writes
mutex = Lock()
async def execute(self, _):
# get the current timestamp in seconds
timestamp = int(time.time())
async with self.mutex:
if timestamp not in self.REQUESTS:
self.REQUESTS[timestamp] = deque(maxlen = self.MAX_REQUESTS_PER_SECOND)
# check if we exceeded the rate limit
if len(self.REQUESTS[timestamp]) >= self.MAX_REQUESTS_PER_SECOND:
raise RateLimitInterceptorException()
# add the current request
self.REQUESTS[timestamp].append(time.time())
# trigger the old timestamps cleaning
await self._clean_timestamps(timestamp)
async def _clean_timestamps(self, current_timestamp):
"""
Cleanup the old timestamps that are no longer relevant.
NOTE: this is quite inefficient and slow for Python, but that's all we can do to manage this in the application.
:param current_timestamp: current timestamp.
"""
async with self.mutex:
# keep only the keys for the current and previous second
for timestamp in list(self.REQUESTS.keys()):
if timestamp < current_timestamp - 1:
del self.REQUESTS[timestamp]
class RateLimitInterceptorException(Exception):
"""
Too many requests.
"""
status_code = 429
error_code = ErrorCode.TOO_MANY_REQUESTS
error_message = ErrorMessage.TOO_MANY_REQUESTS

View file

@ -31,9 +31,9 @@ class ErrorCode:
MISSING_FILE = '1012' MISSING_FILE = '1012'
NO_DATA_AVAILABLE = '1013' CHANGE_PASSWORD = '1013'
CHANGE_PASSWORD = '1014' IP_ADDRESS_BLACKLISTED = '1014'
class ErrorMessage: class ErrorMessage:
@ -75,8 +75,8 @@ class ErrorMessage:
MISSING_FILE = 'Missing required file.' MISSING_FILE = 'Missing required file.'
NO_DATA_AVAILABLE = 'No data available for this request.'
# account settings # account settings
CHANGE_PASSWORD = 'A password change has been activated for your account. You must first authenticate via web app and follow the instructions.' CHANGE_PASSWORD = 'A password change has been activated for your account. You must first authenticate via web app and follow the instructions.'
IP_ADDRESS_BLACKLISTED = 'Your IP address is temporary blacklisted.'

View file

@ -1,143 +0,0 @@
from piracyshield_component.security.filter import Filter
import json
class JSONParametersHandler:
"""
Pre-handle for the JSON request.
"""
request_body = {}
default_sanitization_rules = {
'string': [
'strip'
]
}
def __init__(self, request_body):
self.request_body = request_body
def process_request(self, required_fields: list = None, optional_fields: list = None, sanitization_rules: list = None) -> dict:
"""
Validates and sanitizes incoming JSON request data.
# TODO: need to determine a sanitization template for the rules as this is still a generic approach.
:param required_fields: list of required fields.
:param optional_fields: list of non-mandatory fields.
:param sanitization_rules: list of required sanitizations.
"""
# try to load the JSON data, this will raise an exception if the content is not valid
try:
self.request_body = json.loads(self.request_body)
except Exception:
raise JSONParametersNonValidException()
if required_fields:
self._validate_input(required_fields, optional_fields)
if sanitization_rules:
self.request_body = self._sanitize_input(self.request_body, sanitization_rules)
else:
self.request_body = self._sanitize_input(self.request_body, self.default_sanitization_rules)
return self.request_body
def _validate_input(self, required_fields: list, optional_fields: list) -> None | Exception:
"""
Validates that the incoming JSON request contains all required fields.
:param required_fields: list of required fields.
:param optional_fields: list of non-mandatory fields.
"""
if not optional_fields:
# if there's no optional field we want the exact number of parameters
if len(self.request_body) > len(required_fields):
raise JSONParametersTooManyException()
missing_fields = []
for field in required_fields:
if field not in self.request_body:
missing_fields.append(field)
if missing_fields:
raise JSONParametersMissingException()
# TODO: should we report the missing fields back to the user?
return None
def _sanitize_input(self, data: any, sanitization_rules: list) -> any:
"""
Cleans the input data.
:param data: any parameter in the request.
:param sanitization_rules: list of cleaning rules.
:return: the cleaned data.
"""
if isinstance(data, dict):
sanitized_data = {}
for key, value in data.items():
if value:
sanitized_value = self._sanitize_input(value, sanitization_rules)
if sanitized_value:
sanitized_data[key] = sanitized_value
else:
sanitized_data[key] = value
return sanitized_data
elif isinstance(data, list):
sanitized_data = []
for item in data:
sanitized_value = self._sanitize_input(item, sanitization_rules)
if item:
sanitized_data.append(sanitized_value)
return sanitized_data
elif isinstance(data, str):
if 'string' in sanitization_rules.keys():
if 'strip' in sanitization_rules['string']:
return Filter.strip(data)
else:
return data
class JSONParametersNonValidException(Exception):
"""
Not JSON data.
"""
pass
class JSONParametersMissingException(Exception):
"""
The parameters we're looking for are completely missing.
"""
pass
class JSONParametersTooManyException(Exception):
"""
More parameters than expected is an error.
"""
pass

View file

@ -27,10 +27,10 @@ class ProtectedHandler(BaseHandler):
account_data = {} account_data = {}
def prepare(self): async def prepare(self):
self.authentication_verify_access_token_service = AuthenticationVerifyAccessTokenService() self.authentication_verify_access_token_service = AuthenticationVerifyAccessTokenService()
super().prepare() await super().prepare()
def initialize_account(self): def initialize_account(self):
""" """

View file

@ -1,58 +1,49 @@
from piracyshield_component.log.logger import Logger from piracyshield_component.log.logger import Logger
from piracyshield_component.exception import ApplicationException from piracyshield_component.exception import ApplicationException
from .response import ResponseHandler from .response import ResponseHandler
import tornado.web
import time import time
from collections import deque
from tornado.locks import Lock
from interceptors.logger import LoggerInterceptor
from interceptors.security.blacklist import BlacklistInterceptor
from interceptors.security.rate_limit import RateLimitInterceptor
from ioutils.errors import ErrorCode, ErrorMessage
class RequestHandler(ResponseHandler): class RequestHandler(ResponseHandler):
""" """
Requests gateway. Requests gateway.
The requestor gets blocked here before proceding any further.
Interceptors are implemented to allow certain features to take part of the action.
Current actived interceptors:
- security: takes care of the IP addresses exceeding limits (ex. applies the ban rules when triggered).
- rate limit: the IP will get a 429 when trying to perform too many requests on a given timeframe.
""" """
# override the default methods interceptors = [
SUPPORTED_METHODS = ("GET", "POST") LoggerInterceptor,
BlacklistInterceptor,
RateLimitInterceptor
]
# max requests allowed in a second async def prepare(self):
MAX_REQUESTS_PER_SECOND = 100 await self.run_interceptors()
# requests container async def run_interceptors(self):
REQUESTS = {} for interceptor in self.interceptors:
try:
await interceptor().execute(self)
def prepare(self) -> None: except Exception as InterceptorException:
""" self.error(
Handles the request general procedures. status_code = InterceptorException.status_code,
This method implements a very simple request limit check. error_code = InterceptorException.error_code,
""" message = InterceptorException.error_message
)
self.application.logger.debug(f'> GET `{self.request.uri}` from `{self.request.remote_ip}`')
# get the current timestamp in seconds
timestamp = int(time.time())
# TODO: this should be better handled and also provide a way to temporary ban each IP when flooding.
# check if the number of requests for this second has exceeded the limit
if timestamp in self.REQUESTS:
if self.REQUESTS[timestamp] >= self.MAX_REQUESTS_PER_SECOND:
self.error(status_code = 429, error_code = ErrorCode.TOO_MANY_REQUESTS, message = ErrorMessage.TOO_MANY_REQUESTS)
return return
# increment the number of requests for this second
self.REQUESTS[timestamp] = self.REQUESTS.get(timestamp, 0) + 1
# decrement the number of requests after one second
tornado.ioloop.IOLoop.current().call_later(1.0, self._decrement_requests_count, timestamp)
def _decrement_requests_count(self, timestamp):
"""
Decrement the requests count per timestamp.
:param timestamp: current timestamp.
"""
if timestamp in self.REQUESTS:
self.REQUESTS[timestamp] -= 1

View file

@ -1,14 +1,17 @@
import tornado.web from tornado.web import RequestHandler
from piracyshield_component.config import Config
from .errors import ErrorCode, ErrorMessage from .errors import ErrorCode, ErrorMessage
import datetime
import json import json
class ResponseHandler(tornado.web.RequestHandler): class ResponseHandler(RequestHandler):
""" """
Response handler. Response handler.
Prepares the response parameters for the next layer (Request).
""" """
def set_default_headers(self) -> None: def set_default_headers(self) -> None:
@ -16,11 +19,12 @@ class ResponseHandler(tornado.web.RequestHandler):
Sets the default headers. Sets the default headers.
""" """
self.set_header('Access-Control-Allow-Origin', '*') # allow CORS only from current domain
self.set_header('Access-Control-Allow-Origin', self.application.domain)
self.set_header('Access-Control-Allow-Methods', 'POST, GET') self.set_header('Access-Control-Allow-Methods', self.application.allowed_methods)
self.set_header('Access-Control-Max-Age', 3600) self.set_header('Access-Control-Max-Age', self.application.max_age)
# response for this API is always a JSON # response for this API is always a JSON
self.set_header('Content-Type', 'application/json') self.set_header('Content-Type', 'application/json')
@ -44,8 +48,8 @@ class ResponseHandler(tornado.web.RequestHandler):
value = value, value = value,
expires_days = 1, expires_days = 1,
httponly = True, httponly = True,
samesite = "Strict" samesite = "Strict",
# secure = True secure = True
) )
def get_refresh_cookie(self) -> any: def get_refresh_cookie(self) -> any:
@ -106,7 +110,8 @@ class ResponseHandler(tornado.web.RequestHandler):
'status': 'success' 'status': 'success'
} }
# this will return empty data as well # ensures these data are appended in any case, even if empty
if isinstance(data, str) or isinstance(data, list) or isinstance(data, dict):
response['data'] = data response['data'] = data
# generic purpose informations that we want to communicate # generic purpose informations that we want to communicate

View file

@ -1,2 +1,2 @@
tornado tornado ~= 6.3.3
pytest pytest

View file

@ -40,7 +40,8 @@ class AuthenticationLoginHandler(BaseHandler):
None, None,
self.process, self.process,
self.request_data.get('email'), self.request_data.get('email'),
self.request_data.get('password') self.request_data.get('password'),
self.request.remote_ip
) )
self.success(data = { self.success(data = {
@ -51,13 +52,14 @@ class AuthenticationLoginHandler(BaseHandler):
except ApplicationException as e: except ApplicationException as e:
self.error(status_code = 400, error_code = e.code, message = e.message) self.error(status_code = 400, error_code = e.code, message = e.message)
def process(self, email: str, password: str) -> tuple: def process(self, email: str, password: str, ip_address: str) -> tuple:
authentication_authenticate_service = AuthenticationAuthenticateService() authentication_authenticate_service = AuthenticationAuthenticateService()
# try to authenticate # try to authenticate
payload = authentication_authenticate_service.execute( payload = authentication_authenticate_service.execute(
email = email, email = email,
password = password password = password,
ip_address = ip_address
) )
authentication_generate_access_token_service = AuthenticationGenerateAccessTokenService() authentication_generate_access_token_service = AuthenticationGenerateAccessTokenService()

View file

@ -36,7 +36,7 @@ class AuthenticationRefreshHandler(BaseHandler):
return return
if not self.request_data.get('refresh_token'): if not self.request_data.get('refresh_token'):
return self.error(status_code = 403, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN) return self.error(status_code = 401, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN)
refresh_token = self.request_data.get('refresh_token') refresh_token = self.request_data.get('refresh_token')