Added requests interceptors support. IP addresses can now be banned.

This commit is contained in:
Daniele Maglie 2024-01-19 15:25:53 +01:00
parent f3ed518b65
commit 183fa3e27a
14 changed files with 196 additions and 208 deletions

View file

@ -34,4 +34,4 @@ sast:
bench:
@echo "Running benchmark"
ab -v 2 -n 20 -c 10 -k -T application/json -p tests/bench_params.txt 127.0.0.1:58008/api/v1/authentication/login
#ab -v 2 -n 20 -c 10 -k 127.0.0.1:58008/api/v1/ping
#ab -v 2 -n 200 -c 10 -k 127.0.0.1:58008/api/v1/ping

View file

@ -11,7 +11,13 @@ class Application(tornado.web.Application):
Application initializer.
"""
def __init__(self, debug: bool, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str):
domain = 'localhost'
allowed_methods = 'POST, GET'
max_age = 3600
def __init__(self, debug: bool, domain: str, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str):
"""
This is a list of handlers that directly extend the RequestHandler class, instead of relying on a BaseHandler class.
@ -23,6 +29,9 @@ class Application(tornado.web.Application):
:return: a list of routes and their handlers.
"""
if domain:
self.domain = domain
if prefix:
version = f'{prefix}/{version}'

View file

@ -7,13 +7,17 @@ from v1.routes import APIv1
from application import Application
api_config = Config('api').get('general')
application_config = Config('application').get('general')
api_config = Config('application').get('api')
if __name__ == "__main__":
app = Application(
# wether to run the application using debug mode
debug = api_config['debug'],
# useful for CORS settings
domain = application_config['domain'],
# load current routes
handlers = APIv1.routes,

18
interceptors/logger.py Normal file
View file

@ -0,0 +1,18 @@
from piracyshield_component.log.logger import Logger
class LoggerInterceptor:
"""
Logs requests
"""
async def execute(self, r):
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`')
class LoggerInterceptorException(Exception):
"""
We need this as a placeholder.
"""
pass

View file

@ -0,0 +1,34 @@
from piracyshield_service.security.blacklist.exists_by_ip_address import SecurityBlacklistExistsByIPAddressService
from ioutils.errors import ErrorCode, ErrorMessage
class BlacklistInterceptor:
"""
Applies bans.
"""
security_blacklist_exists_by_ip_address_service = None
async def execute(self, r):
await self._prepare_modules()
if self.security_blacklist_exists_by_ip_address_service.execute(
ip_address = r.request.remote_ip
) == True:
raise BlacklistInterceptorException()
async def _prepare_modules(self):
self.security_blacklist_exists_by_ip_address_service = SecurityBlacklistExistsByIPAddressService()
class BlacklistInterceptorException(Exception):
"""
The IP address is temporary banned.
"""
status_code = 403
error_code = ErrorCode.IP_ADDRESS_BLACKLISTED
error_message = ErrorMessage.IP_ADDRESS_BLACKLISTED

View file

@ -0,0 +1,68 @@
from tornado.locks import Lock
from collections import deque
import time
from ioutils.errors import ErrorCode, ErrorMessage
class RateLimitInterceptor:
"""
Blocks flooding.
"""
# override the default methods
SUPPORTED_METHODS = ("GET", "POST")
# max requests allowed in a second
MAX_REQUESTS_PER_SECOND = 100
# requests container
REQUESTS = {}
# needs a mutex to manage the async threads writes
mutex = Lock()
async def execute(self, _):
# get the current timestamp in seconds
timestamp = int(time.time())
async with self.mutex:
if timestamp not in self.REQUESTS:
self.REQUESTS[timestamp] = deque(maxlen = self.MAX_REQUESTS_PER_SECOND)
# check if we exceeded the rate limit
if len(self.REQUESTS[timestamp]) >= self.MAX_REQUESTS_PER_SECOND:
raise RateLimitInterceptorException()
# add the current request
self.REQUESTS[timestamp].append(time.time())
# trigger the old timestamps cleaning
await self._clean_timestamps(timestamp)
async def _clean_timestamps(self, current_timestamp):
"""
Cleanup the old timestamps that are no longer relevant.
NOTE: this is quite inefficient and slow for Python, but that's all we can do to manage this in the application.
:param current_timestamp: current timestamp.
"""
async with self.mutex:
# keep only the keys for the current and previous second
for timestamp in list(self.REQUESTS.keys()):
if timestamp < current_timestamp - 1:
del self.REQUESTS[timestamp]
class RateLimitInterceptorException(Exception):
"""
Too many requests.
"""
status_code = 429
error_code = ErrorCode.TOO_MANY_REQUESTS
error_message = ErrorMessage.TOO_MANY_REQUESTS

View file

@ -31,9 +31,9 @@ class ErrorCode:
MISSING_FILE = '1012'
NO_DATA_AVAILABLE = '1013'
CHANGE_PASSWORD = '1013'
CHANGE_PASSWORD = '1014'
IP_ADDRESS_BLACKLISTED = '1014'
class ErrorMessage:
@ -75,8 +75,8 @@ class ErrorMessage:
MISSING_FILE = 'Missing required file.'
NO_DATA_AVAILABLE = 'No data available for this request.'
# account settings
CHANGE_PASSWORD = 'A password change has been activated for your account. You must first authenticate via web app and follow the instructions.'
IP_ADDRESS_BLACKLISTED = 'Your IP address is temporary blacklisted.'

View file

@ -1,143 +0,0 @@
from piracyshield_component.security.filter import Filter
import json
class JSONParametersHandler:
"""
Pre-handle for the JSON request.
"""
request_body = {}
default_sanitization_rules = {
'string': [
'strip'
]
}
def __init__(self, request_body):
self.request_body = request_body
def process_request(self, required_fields: list = None, optional_fields: list = None, sanitization_rules: list = None) -> dict:
"""
Validates and sanitizes incoming JSON request data.
# TODO: need to determine a sanitization template for the rules as this is still a generic approach.
:param required_fields: list of required fields.
:param optional_fields: list of non-mandatory fields.
:param sanitization_rules: list of required sanitizations.
"""
# try to load the JSON data, this will raise an exception if the content is not valid
try:
self.request_body = json.loads(self.request_body)
except Exception:
raise JSONParametersNonValidException()
if required_fields:
self._validate_input(required_fields, optional_fields)
if sanitization_rules:
self.request_body = self._sanitize_input(self.request_body, sanitization_rules)
else:
self.request_body = self._sanitize_input(self.request_body, self.default_sanitization_rules)
return self.request_body
def _validate_input(self, required_fields: list, optional_fields: list) -> None | Exception:
"""
Validates that the incoming JSON request contains all required fields.
:param required_fields: list of required fields.
:param optional_fields: list of non-mandatory fields.
"""
if not optional_fields:
# if there's no optional field we want the exact number of parameters
if len(self.request_body) > len(required_fields):
raise JSONParametersTooManyException()
missing_fields = []
for field in required_fields:
if field not in self.request_body:
missing_fields.append(field)
if missing_fields:
raise JSONParametersMissingException()
# TODO: should we report the missing fields back to the user?
return None
def _sanitize_input(self, data: any, sanitization_rules: list) -> any:
"""
Cleans the input data.
:param data: any parameter in the request.
:param sanitization_rules: list of cleaning rules.
:return: the cleaned data.
"""
if isinstance(data, dict):
sanitized_data = {}
for key, value in data.items():
if value:
sanitized_value = self._sanitize_input(value, sanitization_rules)
if sanitized_value:
sanitized_data[key] = sanitized_value
else:
sanitized_data[key] = value
return sanitized_data
elif isinstance(data, list):
sanitized_data = []
for item in data:
sanitized_value = self._sanitize_input(item, sanitization_rules)
if item:
sanitized_data.append(sanitized_value)
return sanitized_data
elif isinstance(data, str):
if 'string' in sanitization_rules.keys():
if 'strip' in sanitization_rules['string']:
return Filter.strip(data)
else:
return data
class JSONParametersNonValidException(Exception):
"""
Not JSON data.
"""
pass
class JSONParametersMissingException(Exception):
"""
The parameters we're looking for are completely missing.
"""
pass
class JSONParametersTooManyException(Exception):
"""
More parameters than expected is an error.
"""
pass

View file

@ -27,10 +27,10 @@ class ProtectedHandler(BaseHandler):
account_data = {}
def prepare(self):
async def prepare(self):
self.authentication_verify_access_token_service = AuthenticationVerifyAccessTokenService()
super().prepare()
await super().prepare()
def initialize_account(self):
"""

View file

@ -1,58 +1,49 @@
from piracyshield_component.log.logger import Logger
from piracyshield_component.exception import ApplicationException
from .response import ResponseHandler
import tornado.web
import time
from collections import deque
from tornado.locks import Lock
from interceptors.logger import LoggerInterceptor
from interceptors.security.blacklist import BlacklistInterceptor
from interceptors.security.rate_limit import RateLimitInterceptor
from ioutils.errors import ErrorCode, ErrorMessage
class RequestHandler(ResponseHandler):
"""
Requests gateway.
The requestor gets blocked here before proceding any further.
Interceptors are implemented to allow certain features to take part of the action.
Current actived interceptors:
- security: takes care of the IP addresses exceeding limits (ex. applies the ban rules when triggered).
- rate limit: the IP will get a 429 when trying to perform too many requests on a given timeframe.
"""
# override the default methods
SUPPORTED_METHODS = ("GET", "POST")
interceptors = [
LoggerInterceptor,
BlacklistInterceptor,
RateLimitInterceptor
]
# max requests allowed in a second
MAX_REQUESTS_PER_SECOND = 100
async def prepare(self):
await self.run_interceptors()
# requests container
REQUESTS = {}
async def run_interceptors(self):
for interceptor in self.interceptors:
try:
await interceptor().execute(self)
def prepare(self) -> None:
"""
Handles the request general procedures.
This method implements a very simple request limit check.
"""
self.application.logger.debug(f'> GET `{self.request.uri}` from `{self.request.remote_ip}`')
# get the current timestamp in seconds
timestamp = int(time.time())
# TODO: this should be better handled and also provide a way to temporary ban each IP when flooding.
# check if the number of requests for this second has exceeded the limit
if timestamp in self.REQUESTS:
if self.REQUESTS[timestamp] >= self.MAX_REQUESTS_PER_SECOND:
self.error(status_code = 429, error_code = ErrorCode.TOO_MANY_REQUESTS, message = ErrorMessage.TOO_MANY_REQUESTS)
except Exception as InterceptorException:
self.error(
status_code = InterceptorException.status_code,
error_code = InterceptorException.error_code,
message = InterceptorException.error_message
)
return
# increment the number of requests for this second
self.REQUESTS[timestamp] = self.REQUESTS.get(timestamp, 0) + 1
# decrement the number of requests after one second
tornado.ioloop.IOLoop.current().call_later(1.0, self._decrement_requests_count, timestamp)
def _decrement_requests_count(self, timestamp):
"""
Decrement the requests count per timestamp.
:param timestamp: current timestamp.
"""
if timestamp in self.REQUESTS:
self.REQUESTS[timestamp] -= 1

View file

@ -1,14 +1,17 @@
import tornado.web
from tornado.web import RequestHandler
from piracyshield_component.config import Config
from .errors import ErrorCode, ErrorMessage
import datetime
import json
class ResponseHandler(tornado.web.RequestHandler):
class ResponseHandler(RequestHandler):
"""
Response handler.
Prepares the response parameters for the next layer (Request).
"""
def set_default_headers(self) -> None:
@ -16,11 +19,12 @@ class ResponseHandler(tornado.web.RequestHandler):
Sets the default headers.
"""
self.set_header('Access-Control-Allow-Origin', '*')
# allow CORS only from current domain
self.set_header('Access-Control-Allow-Origin', self.application.domain)
self.set_header('Access-Control-Allow-Methods', 'POST, GET')
self.set_header('Access-Control-Allow-Methods', self.application.allowed_methods)
self.set_header('Access-Control-Max-Age', 3600)
self.set_header('Access-Control-Max-Age', self.application.max_age)
# response for this API is always a JSON
self.set_header('Content-Type', 'application/json')
@ -44,8 +48,8 @@ class ResponseHandler(tornado.web.RequestHandler):
value = value,
expires_days = 1,
httponly = True,
samesite = "Strict"
# secure = True
samesite = "Strict",
secure = True
)
def get_refresh_cookie(self) -> any:
@ -106,8 +110,9 @@ class ResponseHandler(tornado.web.RequestHandler):
'status': 'success'
}
# this will return empty data as well
response['data'] = data
# ensures these data are appended in any case, even if empty
if isinstance(data, str) or isinstance(data, list) or isinstance(data, dict):
response['data'] = data
# generic purpose informations that we want to communicate
if note:

View file

@ -1,2 +1,2 @@
tornado
tornado ~= 6.3.3
pytest

View file

@ -40,7 +40,8 @@ class AuthenticationLoginHandler(BaseHandler):
None,
self.process,
self.request_data.get('email'),
self.request_data.get('password')
self.request_data.get('password'),
self.request.remote_ip
)
self.success(data = {
@ -51,13 +52,14 @@ class AuthenticationLoginHandler(BaseHandler):
except ApplicationException as e:
self.error(status_code = 400, error_code = e.code, message = e.message)
def process(self, email: str, password: str) -> tuple:
def process(self, email: str, password: str, ip_address: str) -> tuple:
authentication_authenticate_service = AuthenticationAuthenticateService()
# try to authenticate
payload = authentication_authenticate_service.execute(
email = email,
password = password
password = password,
ip_address = ip_address
)
authentication_generate_access_token_service = AuthenticationGenerateAccessTokenService()

View file

@ -36,7 +36,7 @@ class AuthenticationRefreshHandler(BaseHandler):
return
if not self.request_data.get('refresh_token'):
return self.error(status_code = 403, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN)
return self.error(status_code = 401, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN)
refresh_token = self.request_data.get('refresh_token')