Various fixes.

This commit is contained in:
Daniele Maglie 2024-02-07 14:54:36 +01:00
parent 0fb6d08a7f
commit e32bc9d4b7
19 changed files with 452 additions and 59 deletions

View file

@ -14,9 +14,11 @@ class PingHandler(RequestHandler):
Handle simple pings to check the API availability.
"""
# let's be fun
responses = [
'Pong!',
'Do APIs dream of electric requests?'
'Do APIs dream of electric requests?',
'So long and thanks for all the requests'
]
def get(self):

View file

@ -7,7 +7,7 @@ class LoggerInterceptor:
"""
async def execute(self, r):
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`')
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.ip_address}`')
class LoggerInterceptorException(Exception):

View file

@ -13,8 +13,10 @@ class BlacklistInterceptor:
async def execute(self, r):
await self._prepare_modules()
# if we have a valid IP address, let's check if it has been blacklisted
if r.ip_address:
if self.security_blacklist_exists_by_ip_address_service.execute(
ip_address = r.request.remote_ip
ip_address = r.ip_address
) == True:
raise BlacklistInterceptorException()

View file

@ -14,7 +14,7 @@ class RateLimitInterceptor:
SUPPORTED_METHODS = ("GET", "POST")
# max requests allowed in a second
MAX_REQUESTS_PER_SECOND = 100
MAX_REQUESTS_PER_SECOND = 1000
# requests container
REQUESTS = {}

View file

@ -31,7 +31,24 @@ class RequestHandler(ResponseHandler):
RateLimitInterceptor
]
ip_address = None
async def prepare(self):
self.ip_address = self.request.remote_ip
# honor the IP passed via reverse proxy
if 'X-Forwarded-For' in self.request.headers:
# TODO: we need validation as this could be easily spoofed.
# we could have multiple IPs here
forwarded_ip_addresses = str(self.request.headers.get("X-Forwarded-For")).split(',')
# get only the forwarded client IP
self.ip_address = forwarded_ip_addresses[0]
# if any port is specified just ignore it and keep only the IP
self.ip_address.split(':')[0]
await self.run_interceptors()
async def run_interceptors(self):

View file

@ -0,0 +1,47 @@
import pytest
from piracyshield_component.validation.validator import Validator
from piracyshield_component.validation.rules.ipv6 import IPv6
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
class TestGeneral:
valid_ipv6_list = [
"2001:0db8:85a3:0000:0000:8a2e:0370:7334",
"fe80:0000:0000:0000:0204:61ff:fe9d:f156",
"2001:0db8:0000:0000:0000:0000:0000:0001",
"fe80:0000:0000:0000:0204:61ff:fe9d:f157",
"2001:0db8:1234:5678:90ab:cdef:0000:0000",
"2606:2800:220:1:248:1893:25c8:1946",
"2001:4860:4860:0:0:0:0:6464",
"2001:4860:4860:0:0:0:0:8844",
"2001:4860:4860:0:0:0:0:8888",
"2001:4860:4860:0:0:0:0:64",
"2606:4700:4700:0:0:0:0:64",
"2606:4700:4700:0:0:0:0:1001",
"2606:4700:4700:0:0:0:0:1111",
"2606:4700:4700:0:0:0:0:6400",
"2a01:4f8:10a:1::4",
"::1",
"::"
]
def test_valid_ipv6(self):
"""
Check if the IPv6 list is valid.
"""
for ipv6 in self.valid_ipv6_list:
rules = [
IPv6()
]
v = Validator(ipv6, rules)
v.validate()
if len(v.errors) != 0:
print(ipv6, v.errors)
assert len(v.errors) == 0

View file

@ -0,0 +1,78 @@
import pytest
from piracyshield_component.validation.validator import Validator
from piracyshield_component.validation.rules.ipv6 import IPv6
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
class TestGeneral:
valid_ipv6_list = [
# wrong length
"2001:0db8:85a3:0000:0000:8a2e:0370:7334:1234",
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329:ABCD:EF12",
"1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234",
# non valid characters
"2001:0db8:85a3:0000:0000:8a2e:0370:73G4",
"FE80:0000:0000:0000:0202:B3FF:FE1E:832Z",
"2001:0db8::85a3::7334",
# too many segments
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329:1234",
"2001:0db8:85a3:0000:0000:8a2e:0370:7334:5678",
# too few segments
"2001:0db8:85a3:0000:8a2e:0370",
"FE80:0000:B3FF:FE1E:8329",
# consecutive double colons
"2001:0db8::85a3::7334",
"::1234::5678::9ABC",
# leading zeros in a quad
"2001:0db8:85a3:00001:0000:8a2e:0370:7334",
"FE80:00000:0000:0000:0202:B3FF:FE1E:8329",
# non valid dot in notation
"2001:0db8:85a3:0000:0000:8a2e:0370.7334",
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329.1234",
# mixed ipv6/ipv4
"2001:0db8:85a3:0000:0000:8a2e:192.168.1.1.1",
"::ffff:192.168.0.256",
# segment with more than 4 hex characters
"2001:0db8:85a3:00000:0000:8a2e:0370:7334",
"FE80:0000:0000:10000:0202:B3FF:FE1E:8329",
# segment with non-hex characters
"2001:0db8:85g3:0000:0000:8a2e:0370:7334",
"FE80:0000:0000:0000:0202:B3ZZ:FE1E:8329",
# incorrect ipv4 mapping
"::ffff:192.168.0.999",
"::ffff:299.255.255.255"
]
def test_valid_ipv6(self):
"""
Check if the IPv6 list is valid.
"""
for ipv6 in self.valid_ipv6_list:
rules = [
IPv6()
]
v = Validator(ipv6, rules)
v.validate()
if len(v.errors) == 0:
print(ipv6, v.errors)
assert len(v.errors) != 0

View file

@ -0,0 +1,46 @@
import pytest
from piracyshield_component.validation.validator import Validator
from piracyshield_component.validation.rules.cidr_syntax_ipv4 import CIDRSyntaxIPv4
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
from random import randint
class TestGeneral:
max_cidr_classes = 10000
valid_cidr_ipv4_list = []
def setup_method(self):
self.valid_cidr_ipv4_list = self.__generate_random_list(self.max_cidr_classes)
def test_valid_cidr_ipv4(self):
"""
Check if the CIDR IPv4 syntax is valid.
"""
for cidr_ipv4 in self.valid_cidr_ipv4_list:
rules = [
CIDRSyntaxIPv4()
]
v = Validator(cidr_ipv4, rules)
v.validate()
if len(v.errors) != 0:
print(cidr_ipv4, v.errors)
assert len(v.errors) == 0
def __generate_random_list(self, size: int):
cidr_list = []
subnet_variety = [8, 16, 24] # Different subnet masks for variety
for i in range(0, size):
for subnet in subnet_variety:
cidr_list.append(f"{randint(1, 254)}.0.0.0/{subnet}")
return cidr_list

View file

@ -0,0 +1,54 @@
import pytest
from piracyshield_component.validation.validator import Validator
from piracyshield_component.validation.rules.cidr_syntax_ipv6 import CIDRSyntaxIPv6
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
from random import randint, getrandbits
from ipaddress import IPv6Address
class TestGeneral:
max_cidr_classes = 10000
valid_cidr_ipv6_list = []
def setup_method(self):
self.valid_cidr_ipv6_list = self.__generate_random_list(self.max_cidr_classes)
def test_valid_cidr_ipv6(self):
"""
Check if the CIDR IPv6 syntax is valid.
"""
for cidr_ipv6 in self.valid_cidr_ipv6_list:
rules = [
CIDRSyntaxIPv6()
]
v = Validator(cidr_ipv6, rules)
v.validate()
if len(v.errors) != 0:
print(cidr_ipv6, v.errors)
assert len(v.errors) == 0
def __generate_random_list(self, size: int):
cidr_list = []
for _ in range(size):
# Generating a random IPv6 address
random_ip = IPv6Address(getrandbits(128))
# Choosing a random subnet mask
subnet_mask = randint(1, 128)
# Combining to form CIDR notation
cidr = f"{random_ip}/{subnet_mask}"
cidr_list.append(cidr)
return cidr_list

View file

@ -15,14 +15,14 @@ class TestReporterCreateTicket:
ticket_wait_time = 76
ticket_parameters = {
'dda_id': '002ad48ea02a43db9003b4f15f1da9b3',
'dda_id': '7b3d774097ca477687f29ad0968833ac',
'description': '__MOCK_TICKET__',
'forensic_evidence': {
'hash': {}
},
'fqdn': [
'mock-website.com',
'google.com'
'mock-website-two.com'
],
'ipv4': [
'9.8.7.6',
@ -44,8 +44,6 @@ class TestReporterCreateTicket:
create_response = authenticated_post_request('/api/v1/ticket/create', self.access_token, self.ticket_parameters)
print(" RES -> ", create_response.json())
assert create_response.status_code == 200
assert create_response.json()['status'] == 'success'

View file

@ -30,6 +30,15 @@ class TestProviderSetTicketItems:
assert response.status_code == 200
assert response.json()['status'] == 'success'
def test_set_unprocessed_fqdn(self):
response = authenticated_post_request('/api/v1/ticket/item/set/unprocessed', self.access_token, {
'value': 'mock-website-two.com',
'reason': 'ALREADY_BLOCKED'
})
assert response.status_code == 200
assert response.json()['status'] == 'success'
def test_set_processed_ipv4(self):
response = authenticated_post_request('/api/v1/ticket/item/set/processed', self.access_token, {
'value': '9.8.7.6'

View file

@ -39,8 +39,6 @@ class CreateProviderAccountHandler(ProtectedHandler):
if self.handle_post(self.required_fields) == False:
return
print(self.request_data)
try:
# ensure that we have the proper permissions for this operation
self.permission_service.can_create_account()

View file

@ -41,7 +41,7 @@ class AuthenticationLoginHandler(BaseHandler):
authentication_authenticate_service.execute,
self.request_data.get('email'),
self.request_data.get('password'),
self.request.remote_ip
self.ip_address
)
# store the refresh token in a http-only secure cookie

View file

@ -57,7 +57,7 @@ class AuthenticationRefreshHandler(BaseHandler):
None,
authentication_refresh_access_token_service.execute,
refresh_token,
self.request.remote_ip
self.ip_address
)
# return the access_token

View file

@ -0,0 +1,67 @@
import sys
import os
# I hate python imports
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)
import tornado
from ioutils.protected import ProtectedHandler
from ioutils.errors import ErrorCode, ErrorMessage
from piracyshield_service.dda.get_by_identifier import DDAGetByIdentifierService
from piracyshield_service.dda.get_by_identifier_for_reporter import DDAGetByIdentifierForReporterService
from piracyshield_data_model.account.role.model import AccountRoleModel
from piracyshield_component.exception import ApplicationException
class GetByIdentifierDDAHandler(ProtectedHandler):
"""
Handles getting a single DDA by its identifier.
"""
required_fields = [
'dda_id'
]
async def post(self):
if self.initialize_account() == False:
return
if self.handle_post(self.required_fields) == False:
return
try:
# check what level of view we have
if self.account_data.get('role') == AccountRoleModel.INTERNAL.value:
dda_get_by_identifier_service = DDAGetByIdentifierService()
response = await tornado.ioloop.IOLoop.current().run_in_executor(
None,
dda_get_by_identifier_service.execute,
self.request_data.get('dda_id')
)
self.success(data = response)
elif self.account_data.get('role') == AccountRoleModel.REPORTER.value:
dda_get_by_identifier_for_reporter_service = DDAGetByIdentifierForReporterService()
response = await tornado.ioloop.IOLoop.current().run_in_executor(
None,
dda_get_by_identifier_for_reporter_service.execute,
self.request_data.get('dda_id'),
self.account_data.get('account_id')
)
self.success(data = response)
else:
self.error(status_code = 403, error_code = ErrorCode.PERMISSION_DENIED, message = ErrorMessage.PERMISSION_DENIED)
except ApplicationException as e:
self.error(status_code = 400, error_code = e.code, message = e.message)

View file

@ -1,53 +1,102 @@
import sys
import os
# I hate python imports
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)
import json
import tornado
from ioutils.protected import ProtectedHandler
from ioutils.errors import ErrorCode, ErrorMessage
from piracyshield_service.forensic.create_archive import ForensicCreateArchiveService
from piracyshield_component.environment import Environment
from piracyshield_component.exception import ApplicationException
class UploadForensicHandler(ProtectedHandler):
"""
Appends the forensic evidence.
"""
def get_received_chunks_info_path(self, filename):
return f"{Environment.CACHE_PATH}/{filename}_info.json"
def update_received_chunks_info(self, filename, chunk_index):
info_path = self.get_received_chunks_info_path(filename)
if os.path.exists(info_path):
with open(info_path, 'r') as file:
info = json.load(file)
received_chunks = set(info['received_chunks'])
else:
received_chunks = set()
received_chunks.add(chunk_index)
with open(info_path, 'w') as file:
json.dump({"received_chunks": list(received_chunks)}, file)
def check_all_chunks_received(self, filename, total_chunks):
info_path = self.get_received_chunks_info_path(filename)
if os.path.exists(info_path):
with open(info_path, 'r') as file:
info = json.load(file)
return len(info['received_chunks']) == total_chunks
return False
async def post(self, ticket_id):
if self.initialize_account() == False:
if not self.initialize_account():
return
# verify permissions
self.permission_service.can_upload_ticket()
# TODO: we could add some mime type filters but we might need to interrogate the ticket model to have a dynamic implementation.
# no archive passed?
if 'archive' not in self.request.files:
if 'archive' not in self.request.files or 'chunkIndex' not in self.request.arguments or 'totalChunks' not in self.request.arguments:
return self.error(status_code = 400, error_code = ErrorCode.MISSING_FILE, message = ErrorMessage.MISSING_FILE)
# get the file data
archive_handler = self.request.files['archive'][0]
try:
archive_handler = self.request.files['archive'][0]
chunk_index = int(self.request.arguments['chunkIndex'][0])
total_chunks = int(self.request.arguments['totalChunks'][0])
original_filename = self.get_argument('originalFileName')
filename = f"{ticket_id}_{original_filename}"
temp_storage_path = f"{Environment.CACHE_PATH}/{filename}"
chunk_temp_path = temp_storage_path + f"_part{chunk_index}"
with open(chunk_temp_path, "wb") as temp_file:
temp_file.write(archive_handler['body'])
# update received chunks
self.update_received_chunks_info(filename, chunk_index)
# did we get every chunk?
if self.check_all_chunks_received(filename, total_chunks):
# proceed to assemble all the chunks into the a single file
with open(temp_storage_path, "wb") as final_file:
for i in range(total_chunks):
part_file_path = temp_storage_path + f"_part{i}"
if not os.path.exists(part_file_path):
raise ApplicationException("Missing file chunk: " + str(i))
with open(part_file_path, "rb") as part_file:
final_file.write(part_file.read())
os.remove(part_file_path)
# process the final file
forensic_create_archive_service = ForensicCreateArchiveService()
await tornado.ioloop.IOLoop.current().run_in_executor(
None,
forensic_create_archive_service.execute,
ticket_id,
archive_handler['filename'],
archive_handler['body']
filename
)
# clean stored info
os.remove(self.get_received_chunks_info_path(filename))
# always answer or the frontend won't send anything more
self.success()
except ApplicationException as e:

View file

@ -52,6 +52,7 @@ class CreateTicketHandler(ProtectedHandler):
ticket_create_service = TicketCreateService()
if self.account_data.get('role') == AccountRoleModel.INTERNAL.value:
ticket_id, revoke_time = await tornado.ioloop.IOLoop.current().run_in_executor(
None,
ticket_create_service.execute,
@ -70,5 +71,28 @@ class CreateTicketHandler(ProtectedHandler):
note = f'Ticket created. If this is a mistake, you have {revoke_time} seconds to remove it before it gets visible to the providers.'
)
elif self.account_data.get('role') == AccountRoleModel.REPORTER.value:
ticket_id, revoke_time = await tornado.ioloop.IOLoop.current().run_in_executor(
None,
ticket_create_service.execute,
self.request_data.get('dda_id'),
self.request_data.get('forensic_evidence'),
self.request_data.get('fqdn') or [],
self.request_data.get('ipv4') or [],
self.request_data.get('ipv6') or [],
self.request_data.get('assigned_to') or [], # TEMPORARY
#[], # currently, we do not allow any choice by the reporter to which provider will receive the ticket
self.account_data.get('account_id'),
self.request_data.get('description') or None
)
self.success(
data = { 'ticket_id': ticket_id },
note = f'Ticket created. If this is a mistake, you have {revoke_time} seconds to remove it before it gets visible to the providers.'
)
else:
self.error(status_code = 403, error_code = ErrorCode.PERMISSION_DENIED, message = ErrorMessage.PERMISSION_DENIED)
except ApplicationException as e:
self.error(status_code = 400, error_code = e.code, message = e.message)

View file

@ -87,6 +87,7 @@ from .handlers.whitelist.set_status import SetStatusActiveWhitelistItemHandler,
from .handlers.whitelist.remove import RemoveWhitelistItemHandler
from .handlers.dda.create import CreateDDAHandler
from .handlers.dda.get_by_identifier import GetByIdentifierDDAHandler
from .handlers.dda.get_all import GetAllDDAHandler
from .handlers.dda.get_all_by_account import GetAllByAccountDDAHandler
from .handlers.dda.get_global import GetGlobalDDAHandler
@ -212,6 +213,7 @@ class APIv1:
# DDA management
(r"/dda/create", CreateDDAHandler),
(r"/dda/get", GetByIdentifierDDAHandler),
(r"/dda/get/all", GetAllDDAHandler),
(r"/dda/get/all/by_account", GetAllByAccountDDAHandler),
(r"/dda/get/global", GetGlobalDDAHandler),