mirror of
https://github.com/fuckpiracyshield/api.git
synced 2024-12-21 10:20:49 +01:00
Added requests interceptors support. IP addresses can now be banned.
This commit is contained in:
parent
f3ed518b65
commit
183fa3e27a
14 changed files with 196 additions and 208 deletions
2
Makefile
2
Makefile
|
@ -34,4 +34,4 @@ sast:
|
|||
bench:
|
||||
@echo "Running benchmark"
|
||||
ab -v 2 -n 20 -c 10 -k -T application/json -p tests/bench_params.txt 127.0.0.1:58008/api/v1/authentication/login
|
||||
#ab -v 2 -n 20 -c 10 -k 127.0.0.1:58008/api/v1/ping
|
||||
#ab -v 2 -n 200 -c 10 -k 127.0.0.1:58008/api/v1/ping
|
||||
|
|
|
@ -11,7 +11,13 @@ class Application(tornado.web.Application):
|
|||
Application initializer.
|
||||
"""
|
||||
|
||||
def __init__(self, debug: bool, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str):
|
||||
domain = 'localhost'
|
||||
|
||||
allowed_methods = 'POST, GET'
|
||||
|
||||
max_age = 3600
|
||||
|
||||
def __init__(self, debug: bool, domain: str, handlers: list, version: str, prefix: str, cookie_secret: str, cache_path: str):
|
||||
"""
|
||||
This is a list of handlers that directly extend the RequestHandler class, instead of relying on a BaseHandler class.
|
||||
|
||||
|
@ -23,6 +29,9 @@ class Application(tornado.web.Application):
|
|||
:return: a list of routes and their handlers.
|
||||
"""
|
||||
|
||||
if domain:
|
||||
self.domain = domain
|
||||
|
||||
if prefix:
|
||||
version = f'{prefix}/{version}'
|
||||
|
||||
|
|
6
boot.py
6
boot.py
|
@ -7,13 +7,17 @@ from v1.routes import APIv1
|
|||
|
||||
from application import Application
|
||||
|
||||
api_config = Config('api').get('general')
|
||||
application_config = Config('application').get('general')
|
||||
api_config = Config('application').get('api')
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = Application(
|
||||
# wether to run the application using debug mode
|
||||
debug = api_config['debug'],
|
||||
|
||||
# useful for CORS settings
|
||||
domain = application_config['domain'],
|
||||
|
||||
# load current routes
|
||||
handlers = APIv1.routes,
|
||||
|
||||
|
|
18
interceptors/logger.py
Normal file
18
interceptors/logger.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from piracyshield_component.log.logger import Logger
|
||||
|
||||
class LoggerInterceptor:
|
||||
|
||||
"""
|
||||
Logs requests
|
||||
"""
|
||||
|
||||
async def execute(self, r):
|
||||
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`')
|
||||
|
||||
class LoggerInterceptorException(Exception):
|
||||
|
||||
"""
|
||||
We need this as a placeholder.
|
||||
"""
|
||||
|
||||
pass
|
34
interceptors/security/blacklist.py
Normal file
34
interceptors/security/blacklist.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from piracyshield_service.security.blacklist.exists_by_ip_address import SecurityBlacklistExistsByIPAddressService
|
||||
|
||||
from ioutils.errors import ErrorCode, ErrorMessage
|
||||
|
||||
class BlacklistInterceptor:
|
||||
|
||||
"""
|
||||
Applies bans.
|
||||
"""
|
||||
|
||||
security_blacklist_exists_by_ip_address_service = None
|
||||
|
||||
async def execute(self, r):
|
||||
await self._prepare_modules()
|
||||
|
||||
if self.security_blacklist_exists_by_ip_address_service.execute(
|
||||
ip_address = r.request.remote_ip
|
||||
) == True:
|
||||
raise BlacklistInterceptorException()
|
||||
|
||||
async def _prepare_modules(self):
|
||||
self.security_blacklist_exists_by_ip_address_service = SecurityBlacklistExistsByIPAddressService()
|
||||
|
||||
class BlacklistInterceptorException(Exception):
|
||||
|
||||
"""
|
||||
The IP address is temporary banned.
|
||||
"""
|
||||
|
||||
status_code = 403
|
||||
|
||||
error_code = ErrorCode.IP_ADDRESS_BLACKLISTED
|
||||
|
||||
error_message = ErrorMessage.IP_ADDRESS_BLACKLISTED
|
68
interceptors/security/rate_limit.py
Normal file
68
interceptors/security/rate_limit.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from tornado.locks import Lock
|
||||
from collections import deque
|
||||
import time
|
||||
|
||||
from ioutils.errors import ErrorCode, ErrorMessage
|
||||
|
||||
class RateLimitInterceptor:
|
||||
|
||||
"""
|
||||
Blocks flooding.
|
||||
"""
|
||||
|
||||
# override the default methods
|
||||
SUPPORTED_METHODS = ("GET", "POST")
|
||||
|
||||
# max requests allowed in a second
|
||||
MAX_REQUESTS_PER_SECOND = 100
|
||||
|
||||
# requests container
|
||||
REQUESTS = {}
|
||||
|
||||
# needs a mutex to manage the async threads writes
|
||||
mutex = Lock()
|
||||
|
||||
async def execute(self, _):
|
||||
# get the current timestamp in seconds
|
||||
timestamp = int(time.time())
|
||||
|
||||
async with self.mutex:
|
||||
if timestamp not in self.REQUESTS:
|
||||
self.REQUESTS[timestamp] = deque(maxlen = self.MAX_REQUESTS_PER_SECOND)
|
||||
|
||||
# check if we exceeded the rate limit
|
||||
if len(self.REQUESTS[timestamp]) >= self.MAX_REQUESTS_PER_SECOND:
|
||||
raise RateLimitInterceptorException()
|
||||
|
||||
# add the current request
|
||||
self.REQUESTS[timestamp].append(time.time())
|
||||
|
||||
# trigger the old timestamps cleaning
|
||||
await self._clean_timestamps(timestamp)
|
||||
|
||||
async def _clean_timestamps(self, current_timestamp):
|
||||
"""
|
||||
Cleanup the old timestamps that are no longer relevant.
|
||||
|
||||
NOTE: this is quite inefficient and slow for Python, but that's all we can do to manage this in the application.
|
||||
|
||||
:param current_timestamp: current timestamp.
|
||||
"""
|
||||
|
||||
async with self.mutex:
|
||||
# keep only the keys for the current and previous second
|
||||
for timestamp in list(self.REQUESTS.keys()):
|
||||
if timestamp < current_timestamp - 1:
|
||||
del self.REQUESTS[timestamp]
|
||||
|
||||
class RateLimitInterceptorException(Exception):
|
||||
|
||||
"""
|
||||
Too many requests.
|
||||
"""
|
||||
|
||||
status_code = 429
|
||||
|
||||
error_code = ErrorCode.TOO_MANY_REQUESTS
|
||||
|
||||
error_message = ErrorMessage.TOO_MANY_REQUESTS
|
|
@ -31,9 +31,9 @@ class ErrorCode:
|
|||
|
||||
MISSING_FILE = '1012'
|
||||
|
||||
NO_DATA_AVAILABLE = '1013'
|
||||
CHANGE_PASSWORD = '1013'
|
||||
|
||||
CHANGE_PASSWORD = '1014'
|
||||
IP_ADDRESS_BLACKLISTED = '1014'
|
||||
|
||||
class ErrorMessage:
|
||||
|
||||
|
@ -75,8 +75,8 @@ class ErrorMessage:
|
|||
|
||||
MISSING_FILE = 'Missing required file.'
|
||||
|
||||
NO_DATA_AVAILABLE = 'No data available for this request.'
|
||||
|
||||
# account settings
|
||||
|
||||
CHANGE_PASSWORD = 'A password change has been activated for your account. You must first authenticate via web app and follow the instructions.'
|
||||
|
||||
IP_ADDRESS_BLACKLISTED = 'Your IP address is temporary blacklisted.'
|
||||
|
|
|
@ -1,143 +0,0 @@
|
|||
from piracyshield_component.security.filter import Filter
|
||||
|
||||
import json
|
||||
|
||||
class JSONParametersHandler:
|
||||
|
||||
"""
|
||||
Pre-handle for the JSON request.
|
||||
"""
|
||||
|
||||
request_body = {}
|
||||
|
||||
default_sanitization_rules = {
|
||||
'string': [
|
||||
'strip'
|
||||
]
|
||||
}
|
||||
|
||||
def __init__(self, request_body):
|
||||
self.request_body = request_body
|
||||
|
||||
def process_request(self, required_fields: list = None, optional_fields: list = None, sanitization_rules: list = None) -> dict:
|
||||
"""
|
||||
Validates and sanitizes incoming JSON request data.
|
||||
|
||||
# TODO: need to determine a sanitization template for the rules as this is still a generic approach.
|
||||
|
||||
:param required_fields: list of required fields.
|
||||
:param optional_fields: list of non-mandatory fields.
|
||||
:param sanitization_rules: list of required sanitizations.
|
||||
"""
|
||||
|
||||
# try to load the JSON data, this will raise an exception if the content is not valid
|
||||
try:
|
||||
self.request_body = json.loads(self.request_body)
|
||||
|
||||
except Exception:
|
||||
raise JSONParametersNonValidException()
|
||||
|
||||
if required_fields:
|
||||
self._validate_input(required_fields, optional_fields)
|
||||
|
||||
if sanitization_rules:
|
||||
self.request_body = self._sanitize_input(self.request_body, sanitization_rules)
|
||||
|
||||
else:
|
||||
self.request_body = self._sanitize_input(self.request_body, self.default_sanitization_rules)
|
||||
|
||||
return self.request_body
|
||||
|
||||
def _validate_input(self, required_fields: list, optional_fields: list) -> None | Exception:
|
||||
"""
|
||||
Validates that the incoming JSON request contains all required fields.
|
||||
|
||||
:param required_fields: list of required fields.
|
||||
:param optional_fields: list of non-mandatory fields.
|
||||
"""
|
||||
|
||||
if not optional_fields:
|
||||
# if there's no optional field we want the exact number of parameters
|
||||
if len(self.request_body) > len(required_fields):
|
||||
raise JSONParametersTooManyException()
|
||||
|
||||
missing_fields = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in self.request_body:
|
||||
missing_fields.append(field)
|
||||
|
||||
if missing_fields:
|
||||
raise JSONParametersMissingException()
|
||||
|
||||
# TODO: should we report the missing fields back to the user?
|
||||
|
||||
return None
|
||||
|
||||
def _sanitize_input(self, data: any, sanitization_rules: list) -> any:
|
||||
"""
|
||||
Cleans the input data.
|
||||
|
||||
:param data: any parameter in the request.
|
||||
:param sanitization_rules: list of cleaning rules.
|
||||
:return: the cleaned data.
|
||||
"""
|
||||
|
||||
if isinstance(data, dict):
|
||||
sanitized_data = {}
|
||||
|
||||
for key, value in data.items():
|
||||
if value:
|
||||
sanitized_value = self._sanitize_input(value, sanitization_rules)
|
||||
|
||||
if sanitized_value:
|
||||
sanitized_data[key] = sanitized_value
|
||||
|
||||
|
||||
else:
|
||||
sanitized_data[key] = value
|
||||
|
||||
return sanitized_data
|
||||
|
||||
elif isinstance(data, list):
|
||||
sanitized_data = []
|
||||
|
||||
for item in data:
|
||||
sanitized_value = self._sanitize_input(item, sanitization_rules)
|
||||
|
||||
if item:
|
||||
sanitized_data.append(sanitized_value)
|
||||
|
||||
return sanitized_data
|
||||
|
||||
elif isinstance(data, str):
|
||||
if 'string' in sanitization_rules.keys():
|
||||
if 'strip' in sanitization_rules['string']:
|
||||
return Filter.strip(data)
|
||||
|
||||
else:
|
||||
return data
|
||||
|
||||
class JSONParametersNonValidException(Exception):
|
||||
|
||||
"""
|
||||
Not JSON data.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class JSONParametersMissingException(Exception):
|
||||
|
||||
"""
|
||||
The parameters we're looking for are completely missing.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
class JSONParametersTooManyException(Exception):
|
||||
|
||||
"""
|
||||
More parameters than expected is an error.
|
||||
"""
|
||||
|
||||
pass
|
|
@ -27,10 +27,10 @@ class ProtectedHandler(BaseHandler):
|
|||
|
||||
account_data = {}
|
||||
|
||||
def prepare(self):
|
||||
async def prepare(self):
|
||||
self.authentication_verify_access_token_service = AuthenticationVerifyAccessTokenService()
|
||||
|
||||
super().prepare()
|
||||
await super().prepare()
|
||||
|
||||
def initialize_account(self):
|
||||
"""
|
||||
|
|
|
@ -1,58 +1,49 @@
|
|||
from piracyshield_component.log.logger import Logger
|
||||
from piracyshield_component.exception import ApplicationException
|
||||
|
||||
from .response import ResponseHandler
|
||||
|
||||
import tornado.web
|
||||
import time
|
||||
from collections import deque
|
||||
from tornado.locks import Lock
|
||||
|
||||
from interceptors.logger import LoggerInterceptor
|
||||
from interceptors.security.blacklist import BlacklistInterceptor
|
||||
from interceptors.security.rate_limit import RateLimitInterceptor
|
||||
|
||||
from ioutils.errors import ErrorCode, ErrorMessage
|
||||
|
||||
class RequestHandler(ResponseHandler):
|
||||
|
||||
"""
|
||||
Requests gateway.
|
||||
|
||||
The requestor gets blocked here before proceding any further.
|
||||
Interceptors are implemented to allow certain features to take part of the action.
|
||||
|
||||
Current actived interceptors:
|
||||
- security: takes care of the IP addresses exceeding limits (ex. applies the ban rules when triggered).
|
||||
- rate limit: the IP will get a 429 when trying to perform too many requests on a given timeframe.
|
||||
"""
|
||||
|
||||
# override the default methods
|
||||
SUPPORTED_METHODS = ("GET", "POST")
|
||||
interceptors = [
|
||||
LoggerInterceptor,
|
||||
BlacklistInterceptor,
|
||||
RateLimitInterceptor
|
||||
]
|
||||
|
||||
# max requests allowed in a second
|
||||
MAX_REQUESTS_PER_SECOND = 100
|
||||
async def prepare(self):
|
||||
await self.run_interceptors()
|
||||
|
||||
# requests container
|
||||
REQUESTS = {}
|
||||
async def run_interceptors(self):
|
||||
for interceptor in self.interceptors:
|
||||
try:
|
||||
await interceptor().execute(self)
|
||||
|
||||
def prepare(self) -> None:
|
||||
"""
|
||||
Handles the request general procedures.
|
||||
This method implements a very simple request limit check.
|
||||
"""
|
||||
|
||||
self.application.logger.debug(f'> GET `{self.request.uri}` from `{self.request.remote_ip}`')
|
||||
|
||||
# get the current timestamp in seconds
|
||||
timestamp = int(time.time())
|
||||
|
||||
# TODO: this should be better handled and also provide a way to temporary ban each IP when flooding.
|
||||
|
||||
# check if the number of requests for this second has exceeded the limit
|
||||
if timestamp in self.REQUESTS:
|
||||
if self.REQUESTS[timestamp] >= self.MAX_REQUESTS_PER_SECOND:
|
||||
self.error(status_code = 429, error_code = ErrorCode.TOO_MANY_REQUESTS, message = ErrorMessage.TOO_MANY_REQUESTS)
|
||||
except Exception as InterceptorException:
|
||||
self.error(
|
||||
status_code = InterceptorException.status_code,
|
||||
error_code = InterceptorException.error_code,
|
||||
message = InterceptorException.error_message
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# increment the number of requests for this second
|
||||
self.REQUESTS[timestamp] = self.REQUESTS.get(timestamp, 0) + 1
|
||||
|
||||
# decrement the number of requests after one second
|
||||
tornado.ioloop.IOLoop.current().call_later(1.0, self._decrement_requests_count, timestamp)
|
||||
|
||||
def _decrement_requests_count(self, timestamp):
|
||||
"""
|
||||
Decrement the requests count per timestamp.
|
||||
|
||||
:param timestamp: current timestamp.
|
||||
"""
|
||||
|
||||
if timestamp in self.REQUESTS:
|
||||
self.REQUESTS[timestamp] -= 1
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import tornado.web
|
||||
from tornado.web import RequestHandler
|
||||
|
||||
from piracyshield_component.config import Config
|
||||
|
||||
from .errors import ErrorCode, ErrorMessage
|
||||
|
||||
import datetime
|
||||
import json
|
||||
|
||||
class ResponseHandler(tornado.web.RequestHandler):
|
||||
class ResponseHandler(RequestHandler):
|
||||
|
||||
"""
|
||||
Response handler.
|
||||
|
||||
Prepares the response parameters for the next layer (Request).
|
||||
"""
|
||||
|
||||
def set_default_headers(self) -> None:
|
||||
|
@ -16,11 +19,12 @@ class ResponseHandler(tornado.web.RequestHandler):
|
|||
Sets the default headers.
|
||||
"""
|
||||
|
||||
self.set_header('Access-Control-Allow-Origin', '*')
|
||||
# allow CORS only from current domain
|
||||
self.set_header('Access-Control-Allow-Origin', self.application.domain)
|
||||
|
||||
self.set_header('Access-Control-Allow-Methods', 'POST, GET')
|
||||
self.set_header('Access-Control-Allow-Methods', self.application.allowed_methods)
|
||||
|
||||
self.set_header('Access-Control-Max-Age', 3600)
|
||||
self.set_header('Access-Control-Max-Age', self.application.max_age)
|
||||
|
||||
# response for this API is always a JSON
|
||||
self.set_header('Content-Type', 'application/json')
|
||||
|
@ -44,8 +48,8 @@ class ResponseHandler(tornado.web.RequestHandler):
|
|||
value = value,
|
||||
expires_days = 1,
|
||||
httponly = True,
|
||||
samesite = "Strict"
|
||||
# secure = True
|
||||
samesite = "Strict",
|
||||
secure = True
|
||||
)
|
||||
|
||||
def get_refresh_cookie(self) -> any:
|
||||
|
@ -106,8 +110,9 @@ class ResponseHandler(tornado.web.RequestHandler):
|
|||
'status': 'success'
|
||||
}
|
||||
|
||||
# this will return empty data as well
|
||||
response['data'] = data
|
||||
# ensures these data are appended in any case, even if empty
|
||||
if isinstance(data, str) or isinstance(data, list) or isinstance(data, dict):
|
||||
response['data'] = data
|
||||
|
||||
# generic purpose informations that we want to communicate
|
||||
if note:
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
tornado
|
||||
tornado ~= 6.3.3
|
||||
pytest
|
||||
|
|
|
@ -40,7 +40,8 @@ class AuthenticationLoginHandler(BaseHandler):
|
|||
None,
|
||||
self.process,
|
||||
self.request_data.get('email'),
|
||||
self.request_data.get('password')
|
||||
self.request_data.get('password'),
|
||||
self.request.remote_ip
|
||||
)
|
||||
|
||||
self.success(data = {
|
||||
|
@ -51,13 +52,14 @@ class AuthenticationLoginHandler(BaseHandler):
|
|||
except ApplicationException as e:
|
||||
self.error(status_code = 400, error_code = e.code, message = e.message)
|
||||
|
||||
def process(self, email: str, password: str) -> tuple:
|
||||
def process(self, email: str, password: str, ip_address: str) -> tuple:
|
||||
authentication_authenticate_service = AuthenticationAuthenticateService()
|
||||
|
||||
# try to authenticate
|
||||
payload = authentication_authenticate_service.execute(
|
||||
email = email,
|
||||
password = password
|
||||
password = password,
|
||||
ip_address = ip_address
|
||||
)
|
||||
|
||||
authentication_generate_access_token_service = AuthenticationGenerateAccessTokenService()
|
||||
|
|
|
@ -36,7 +36,7 @@ class AuthenticationRefreshHandler(BaseHandler):
|
|||
return
|
||||
|
||||
if not self.request_data.get('refresh_token'):
|
||||
return self.error(status_code = 403, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN)
|
||||
return self.error(status_code = 401, error_code = ErrorCode.MISSING_REFRESH_TOKEN, message = ErrorMessage.MISSING_REFRESH_TOKEN)
|
||||
|
||||
refresh_token = self.request_data.get('refresh_token')
|
||||
|
||||
|
|
Loading…
Reference in a new issue