mirror of
https://github.com/fuckpiracyshield/api.git
synced 2024-12-22 02:40:48 +01:00
Various fixes.
This commit is contained in:
parent
0fb6d08a7f
commit
e32bc9d4b7
19 changed files with 452 additions and 59 deletions
|
@ -14,9 +14,11 @@ class PingHandler(RequestHandler):
|
||||||
Handle simple pings to check the API availability.
|
Handle simple pings to check the API availability.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# let's be fun
|
||||||
responses = [
|
responses = [
|
||||||
'Pong!',
|
'Pong!',
|
||||||
'Do APIs dream of electric requests?'
|
'Do APIs dream of electric requests?',
|
||||||
|
'So long and thanks for all the requests'
|
||||||
]
|
]
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
|
|
|
@ -7,7 +7,7 @@ class LoggerInterceptor:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def execute(self, r):
|
async def execute(self, r):
|
||||||
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.request.remote_ip}`')
|
r.application.logger.debug(f'> GET `{r.request.uri}` from `{r.ip_address}`')
|
||||||
|
|
||||||
class LoggerInterceptorException(Exception):
|
class LoggerInterceptorException(Exception):
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,10 @@ class BlacklistInterceptor:
|
||||||
async def execute(self, r):
|
async def execute(self, r):
|
||||||
await self._prepare_modules()
|
await self._prepare_modules()
|
||||||
|
|
||||||
|
# if we have a valid IP address, let's check if it has been blacklisted
|
||||||
|
if r.ip_address:
|
||||||
if self.security_blacklist_exists_by_ip_address_service.execute(
|
if self.security_blacklist_exists_by_ip_address_service.execute(
|
||||||
ip_address = r.request.remote_ip
|
ip_address = r.ip_address
|
||||||
) == True:
|
) == True:
|
||||||
raise BlacklistInterceptorException()
|
raise BlacklistInterceptorException()
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ class RateLimitInterceptor:
|
||||||
SUPPORTED_METHODS = ("GET", "POST")
|
SUPPORTED_METHODS = ("GET", "POST")
|
||||||
|
|
||||||
# max requests allowed in a second
|
# max requests allowed in a second
|
||||||
MAX_REQUESTS_PER_SECOND = 100
|
MAX_REQUESTS_PER_SECOND = 1000
|
||||||
|
|
||||||
# requests container
|
# requests container
|
||||||
REQUESTS = {}
|
REQUESTS = {}
|
||||||
|
|
|
@ -31,7 +31,24 @@ class RequestHandler(ResponseHandler):
|
||||||
RateLimitInterceptor
|
RateLimitInterceptor
|
||||||
]
|
]
|
||||||
|
|
||||||
|
ip_address = None
|
||||||
|
|
||||||
async def prepare(self):
|
async def prepare(self):
|
||||||
|
self.ip_address = self.request.remote_ip
|
||||||
|
|
||||||
|
# honor the IP passed via reverse proxy
|
||||||
|
if 'X-Forwarded-For' in self.request.headers:
|
||||||
|
# TODO: we need validation as this could be easily spoofed.
|
||||||
|
|
||||||
|
# we could have multiple IPs here
|
||||||
|
forwarded_ip_addresses = str(self.request.headers.get("X-Forwarded-For")).split(',')
|
||||||
|
|
||||||
|
# get only the forwarded client IP
|
||||||
|
self.ip_address = forwarded_ip_addresses[0]
|
||||||
|
|
||||||
|
# if any port is specified just ignore it and keep only the IP
|
||||||
|
self.ip_address.split(':')[0]
|
||||||
|
|
||||||
await self.run_interceptors()
|
await self.run_interceptors()
|
||||||
|
|
||||||
async def run_interceptors(self):
|
async def run_interceptors(self):
|
||||||
|
|
47
tests/01_general/test_0101_valid_ipv6.py
Normal file
47
tests/01_general/test_0101_valid_ipv6.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import Validator
|
||||||
|
from piracyshield_component.validation.rules.ipv6 import IPv6
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
|
||||||
|
|
||||||
|
class TestGeneral:
|
||||||
|
|
||||||
|
valid_ipv6_list = [
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:0370:7334",
|
||||||
|
"fe80:0000:0000:0000:0204:61ff:fe9d:f156",
|
||||||
|
"2001:0db8:0000:0000:0000:0000:0000:0001",
|
||||||
|
"fe80:0000:0000:0000:0204:61ff:fe9d:f157",
|
||||||
|
"2001:0db8:1234:5678:90ab:cdef:0000:0000",
|
||||||
|
"2606:2800:220:1:248:1893:25c8:1946",
|
||||||
|
"2001:4860:4860:0:0:0:0:6464",
|
||||||
|
"2001:4860:4860:0:0:0:0:8844",
|
||||||
|
"2001:4860:4860:0:0:0:0:8888",
|
||||||
|
"2001:4860:4860:0:0:0:0:64",
|
||||||
|
"2606:4700:4700:0:0:0:0:64",
|
||||||
|
"2606:4700:4700:0:0:0:0:1001",
|
||||||
|
"2606:4700:4700:0:0:0:0:1111",
|
||||||
|
"2606:4700:4700:0:0:0:0:6400",
|
||||||
|
"2a01:4f8:10a:1::4",
|
||||||
|
"::1",
|
||||||
|
"::"
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_valid_ipv6(self):
|
||||||
|
"""
|
||||||
|
Check if the IPv6 list is valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for ipv6 in self.valid_ipv6_list:
|
||||||
|
rules = [
|
||||||
|
IPv6()
|
||||||
|
]
|
||||||
|
|
||||||
|
v = Validator(ipv6, rules)
|
||||||
|
|
||||||
|
v.validate()
|
||||||
|
|
||||||
|
if len(v.errors) != 0:
|
||||||
|
print(ipv6, v.errors)
|
||||||
|
|
||||||
|
assert len(v.errors) == 0
|
78
tests/01_general/test_0102_non_valid_ipv6.py
Normal file
78
tests/01_general/test_0102_non_valid_ipv6.py
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import Validator
|
||||||
|
from piracyshield_component.validation.rules.ipv6 import IPv6
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
|
||||||
|
|
||||||
|
class TestGeneral:
|
||||||
|
|
||||||
|
valid_ipv6_list = [
|
||||||
|
# wrong length
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:0370:7334:1234",
|
||||||
|
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329:ABCD:EF12",
|
||||||
|
"1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234",
|
||||||
|
|
||||||
|
# non valid characters
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:0370:73G4",
|
||||||
|
"FE80:0000:0000:0000:0202:B3FF:FE1E:832Z",
|
||||||
|
"2001:0db8::85a3::7334",
|
||||||
|
|
||||||
|
# too many segments
|
||||||
|
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329:1234",
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:0370:7334:5678",
|
||||||
|
|
||||||
|
# too few segments
|
||||||
|
"2001:0db8:85a3:0000:8a2e:0370",
|
||||||
|
"FE80:0000:B3FF:FE1E:8329",
|
||||||
|
|
||||||
|
# consecutive double colons
|
||||||
|
"2001:0db8::85a3::7334",
|
||||||
|
"::1234::5678::9ABC",
|
||||||
|
|
||||||
|
# leading zeros in a quad
|
||||||
|
|
||||||
|
"2001:0db8:85a3:00001:0000:8a2e:0370:7334",
|
||||||
|
"FE80:00000:0000:0000:0202:B3FF:FE1E:8329",
|
||||||
|
|
||||||
|
# non valid dot in notation
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:0370.7334",
|
||||||
|
"FE80:0000:0000:0000:0202:B3FF:FE1E:8329.1234",
|
||||||
|
|
||||||
|
# mixed ipv6/ipv4
|
||||||
|
|
||||||
|
"2001:0db8:85a3:0000:0000:8a2e:192.168.1.1.1",
|
||||||
|
"::ffff:192.168.0.256",
|
||||||
|
|
||||||
|
# segment with more than 4 hex characters
|
||||||
|
|
||||||
|
"2001:0db8:85a3:00000:0000:8a2e:0370:7334",
|
||||||
|
"FE80:0000:0000:10000:0202:B3FF:FE1E:8329",
|
||||||
|
|
||||||
|
# segment with non-hex characters
|
||||||
|
"2001:0db8:85g3:0000:0000:8a2e:0370:7334",
|
||||||
|
"FE80:0000:0000:0000:0202:B3ZZ:FE1E:8329",
|
||||||
|
|
||||||
|
# incorrect ipv4 mapping
|
||||||
|
"::ffff:192.168.0.999",
|
||||||
|
"::ffff:299.255.255.255"
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_valid_ipv6(self):
|
||||||
|
"""
|
||||||
|
Check if the IPv6 list is valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for ipv6 in self.valid_ipv6_list:
|
||||||
|
rules = [
|
||||||
|
IPv6()
|
||||||
|
]
|
||||||
|
|
||||||
|
v = Validator(ipv6, rules)
|
||||||
|
|
||||||
|
v.validate()
|
||||||
|
|
||||||
|
if len(v.errors) == 0:
|
||||||
|
print(ipv6, v.errors)
|
||||||
|
|
||||||
|
assert len(v.errors) != 0
|
46
tests/01_general/test_0103_cidr_ipv4.py
Normal file
46
tests/01_general/test_0103_cidr_ipv4.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import Validator
|
||||||
|
from piracyshield_component.validation.rules.cidr_syntax_ipv4 import CIDRSyntaxIPv4
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
|
||||||
|
|
||||||
|
from random import randint
|
||||||
|
|
||||||
|
class TestGeneral:
|
||||||
|
|
||||||
|
max_cidr_classes = 10000
|
||||||
|
|
||||||
|
valid_cidr_ipv4_list = []
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.valid_cidr_ipv4_list = self.__generate_random_list(self.max_cidr_classes)
|
||||||
|
|
||||||
|
def test_valid_cidr_ipv4(self):
|
||||||
|
"""
|
||||||
|
Check if the CIDR IPv4 syntax is valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for cidr_ipv4 in self.valid_cidr_ipv4_list:
|
||||||
|
rules = [
|
||||||
|
CIDRSyntaxIPv4()
|
||||||
|
]
|
||||||
|
|
||||||
|
v = Validator(cidr_ipv4, rules)
|
||||||
|
|
||||||
|
v.validate()
|
||||||
|
|
||||||
|
if len(v.errors) != 0:
|
||||||
|
print(cidr_ipv4, v.errors)
|
||||||
|
|
||||||
|
assert len(v.errors) == 0
|
||||||
|
|
||||||
|
def __generate_random_list(self, size: int):
|
||||||
|
cidr_list = []
|
||||||
|
subnet_variety = [8, 16, 24] # Different subnet masks for variety
|
||||||
|
|
||||||
|
for i in range(0, size):
|
||||||
|
for subnet in subnet_variety:
|
||||||
|
cidr_list.append(f"{randint(1, 254)}.0.0.0/{subnet}")
|
||||||
|
|
||||||
|
return cidr_list
|
54
tests/01_general/test_0103_cidr_ipv6.py
Normal file
54
tests/01_general/test_0103_cidr_ipv6.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import Validator
|
||||||
|
from piracyshield_component.validation.rules.cidr_syntax_ipv6 import CIDRSyntaxIPv6
|
||||||
|
|
||||||
|
from piracyshield_component.validation.validator import ValidatorRuleNonValidException
|
||||||
|
|
||||||
|
from random import randint, getrandbits
|
||||||
|
|
||||||
|
from ipaddress import IPv6Address
|
||||||
|
|
||||||
|
class TestGeneral:
|
||||||
|
|
||||||
|
max_cidr_classes = 10000
|
||||||
|
|
||||||
|
valid_cidr_ipv6_list = []
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.valid_cidr_ipv6_list = self.__generate_random_list(self.max_cidr_classes)
|
||||||
|
|
||||||
|
def test_valid_cidr_ipv6(self):
|
||||||
|
"""
|
||||||
|
Check if the CIDR IPv6 syntax is valid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
for cidr_ipv6 in self.valid_cidr_ipv6_list:
|
||||||
|
rules = [
|
||||||
|
CIDRSyntaxIPv6()
|
||||||
|
]
|
||||||
|
|
||||||
|
v = Validator(cidr_ipv6, rules)
|
||||||
|
|
||||||
|
v.validate()
|
||||||
|
|
||||||
|
if len(v.errors) != 0:
|
||||||
|
print(cidr_ipv6, v.errors)
|
||||||
|
|
||||||
|
assert len(v.errors) == 0
|
||||||
|
|
||||||
|
def __generate_random_list(self, size: int):
|
||||||
|
cidr_list = []
|
||||||
|
|
||||||
|
for _ in range(size):
|
||||||
|
# Generating a random IPv6 address
|
||||||
|
random_ip = IPv6Address(getrandbits(128))
|
||||||
|
|
||||||
|
# Choosing a random subnet mask
|
||||||
|
subnet_mask = randint(1, 128)
|
||||||
|
|
||||||
|
# Combining to form CIDR notation
|
||||||
|
cidr = f"{random_ip}/{subnet_mask}"
|
||||||
|
cidr_list.append(cidr)
|
||||||
|
|
||||||
|
return cidr_list
|
|
@ -15,14 +15,14 @@ class TestReporterCreateTicket:
|
||||||
ticket_wait_time = 76
|
ticket_wait_time = 76
|
||||||
|
|
||||||
ticket_parameters = {
|
ticket_parameters = {
|
||||||
'dda_id': '002ad48ea02a43db9003b4f15f1da9b3',
|
'dda_id': '7b3d774097ca477687f29ad0968833ac',
|
||||||
'description': '__MOCK_TICKET__',
|
'description': '__MOCK_TICKET__',
|
||||||
'forensic_evidence': {
|
'forensic_evidence': {
|
||||||
'hash': {}
|
'hash': {}
|
||||||
},
|
},
|
||||||
'fqdn': [
|
'fqdn': [
|
||||||
'mock-website.com',
|
'mock-website.com',
|
||||||
'google.com'
|
'mock-website-two.com'
|
||||||
],
|
],
|
||||||
'ipv4': [
|
'ipv4': [
|
||||||
'9.8.7.6',
|
'9.8.7.6',
|
||||||
|
@ -44,8 +44,6 @@ class TestReporterCreateTicket:
|
||||||
|
|
||||||
create_response = authenticated_post_request('/api/v1/ticket/create', self.access_token, self.ticket_parameters)
|
create_response = authenticated_post_request('/api/v1/ticket/create', self.access_token, self.ticket_parameters)
|
||||||
|
|
||||||
print(" RES -> ", create_response.json())
|
|
||||||
|
|
||||||
assert create_response.status_code == 200
|
assert create_response.status_code == 200
|
||||||
assert create_response.json()['status'] == 'success'
|
assert create_response.json()['status'] == 'success'
|
||||||
|
|
|
@ -30,6 +30,15 @@ class TestProviderSetTicketItems:
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()['status'] == 'success'
|
assert response.json()['status'] == 'success'
|
||||||
|
|
||||||
|
def test_set_unprocessed_fqdn(self):
|
||||||
|
response = authenticated_post_request('/api/v1/ticket/item/set/unprocessed', self.access_token, {
|
||||||
|
'value': 'mock-website-two.com',
|
||||||
|
'reason': 'ALREADY_BLOCKED'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()['status'] == 'success'
|
||||||
|
|
||||||
def test_set_processed_ipv4(self):
|
def test_set_processed_ipv4(self):
|
||||||
response = authenticated_post_request('/api/v1/ticket/item/set/processed', self.access_token, {
|
response = authenticated_post_request('/api/v1/ticket/item/set/processed', self.access_token, {
|
||||||
'value': '9.8.7.6'
|
'value': '9.8.7.6'
|
|
@ -39,8 +39,6 @@ class CreateProviderAccountHandler(ProtectedHandler):
|
||||||
if self.handle_post(self.required_fields) == False:
|
if self.handle_post(self.required_fields) == False:
|
||||||
return
|
return
|
||||||
|
|
||||||
print(self.request_data)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ensure that we have the proper permissions for this operation
|
# ensure that we have the proper permissions for this operation
|
||||||
self.permission_service.can_create_account()
|
self.permission_service.can_create_account()
|
||||||
|
|
|
@ -41,7 +41,7 @@ class AuthenticationLoginHandler(BaseHandler):
|
||||||
authentication_authenticate_service.execute,
|
authentication_authenticate_service.execute,
|
||||||
self.request_data.get('email'),
|
self.request_data.get('email'),
|
||||||
self.request_data.get('password'),
|
self.request_data.get('password'),
|
||||||
self.request.remote_ip
|
self.ip_address
|
||||||
)
|
)
|
||||||
|
|
||||||
# store the refresh token in a http-only secure cookie
|
# store the refresh token in a http-only secure cookie
|
||||||
|
|
|
@ -57,7 +57,7 @@ class AuthenticationRefreshHandler(BaseHandler):
|
||||||
None,
|
None,
|
||||||
authentication_refresh_access_token_service.execute,
|
authentication_refresh_access_token_service.execute,
|
||||||
refresh_token,
|
refresh_token,
|
||||||
self.request.remote_ip
|
self.ip_address
|
||||||
)
|
)
|
||||||
|
|
||||||
# return the access_token
|
# return the access_token
|
||||||
|
|
67
v1/handlers/dda/get_by_identifier.py
Normal file
67
v1/handlers/dda/get_by_identifier.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# I hate python imports
|
||||||
|
current = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
parent = os.path.dirname(current)
|
||||||
|
sys.path.append(parent)
|
||||||
|
|
||||||
|
import tornado
|
||||||
|
|
||||||
|
from ioutils.protected import ProtectedHandler
|
||||||
|
from ioutils.errors import ErrorCode, ErrorMessage
|
||||||
|
|
||||||
|
from piracyshield_service.dda.get_by_identifier import DDAGetByIdentifierService
|
||||||
|
from piracyshield_service.dda.get_by_identifier_for_reporter import DDAGetByIdentifierForReporterService
|
||||||
|
|
||||||
|
from piracyshield_data_model.account.role.model import AccountRoleModel
|
||||||
|
|
||||||
|
from piracyshield_component.exception import ApplicationException
|
||||||
|
|
||||||
|
class GetByIdentifierDDAHandler(ProtectedHandler):
|
||||||
|
|
||||||
|
"""
|
||||||
|
Handles getting a single DDA by its identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
required_fields = [
|
||||||
|
'dda_id'
|
||||||
|
]
|
||||||
|
|
||||||
|
async def post(self):
|
||||||
|
if self.initialize_account() == False:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.handle_post(self.required_fields) == False:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# check what level of view we have
|
||||||
|
if self.account_data.get('role') == AccountRoleModel.INTERNAL.value:
|
||||||
|
dda_get_by_identifier_service = DDAGetByIdentifierService()
|
||||||
|
|
||||||
|
response = await tornado.ioloop.IOLoop.current().run_in_executor(
|
||||||
|
None,
|
||||||
|
dda_get_by_identifier_service.execute,
|
||||||
|
self.request_data.get('dda_id')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.success(data = response)
|
||||||
|
|
||||||
|
elif self.account_data.get('role') == AccountRoleModel.REPORTER.value:
|
||||||
|
dda_get_by_identifier_for_reporter_service = DDAGetByIdentifierForReporterService()
|
||||||
|
|
||||||
|
response = await tornado.ioloop.IOLoop.current().run_in_executor(
|
||||||
|
None,
|
||||||
|
dda_get_by_identifier_for_reporter_service.execute,
|
||||||
|
self.request_data.get('dda_id'),
|
||||||
|
self.account_data.get('account_id')
|
||||||
|
)
|
||||||
|
|
||||||
|
self.success(data = response)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.error(status_code = 403, error_code = ErrorCode.PERMISSION_DENIED, message = ErrorMessage.PERMISSION_DENIED)
|
||||||
|
|
||||||
|
except ApplicationException as e:
|
||||||
|
self.error(status_code = 400, error_code = e.code, message = e.message)
|
|
@ -1,53 +1,102 @@
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
# I hate python imports
|
|
||||||
current = os.path.dirname(os.path.realpath(__file__))
|
|
||||||
parent = os.path.dirname(current)
|
|
||||||
sys.path.append(parent)
|
|
||||||
|
|
||||||
import tornado
|
import tornado
|
||||||
|
|
||||||
from ioutils.protected import ProtectedHandler
|
from ioutils.protected import ProtectedHandler
|
||||||
from ioutils.errors import ErrorCode, ErrorMessage
|
from ioutils.errors import ErrorCode, ErrorMessage
|
||||||
|
|
||||||
from piracyshield_service.forensic.create_archive import ForensicCreateArchiveService
|
from piracyshield_service.forensic.create_archive import ForensicCreateArchiveService
|
||||||
|
from piracyshield_component.environment import Environment
|
||||||
from piracyshield_component.exception import ApplicationException
|
from piracyshield_component.exception import ApplicationException
|
||||||
|
|
||||||
class UploadForensicHandler(ProtectedHandler):
|
class UploadForensicHandler(ProtectedHandler):
|
||||||
|
|
||||||
"""
|
def get_received_chunks_info_path(self, filename):
|
||||||
Appends the forensic evidence.
|
return f"{Environment.CACHE_PATH}/{filename}_info.json"
|
||||||
"""
|
|
||||||
|
def update_received_chunks_info(self, filename, chunk_index):
|
||||||
|
info_path = self.get_received_chunks_info_path(filename)
|
||||||
|
|
||||||
|
if os.path.exists(info_path):
|
||||||
|
with open(info_path, 'r') as file:
|
||||||
|
info = json.load(file)
|
||||||
|
|
||||||
|
received_chunks = set(info['received_chunks'])
|
||||||
|
else:
|
||||||
|
received_chunks = set()
|
||||||
|
|
||||||
|
received_chunks.add(chunk_index)
|
||||||
|
|
||||||
|
with open(info_path, 'w') as file:
|
||||||
|
json.dump({"received_chunks": list(received_chunks)}, file)
|
||||||
|
|
||||||
|
def check_all_chunks_received(self, filename, total_chunks):
|
||||||
|
info_path = self.get_received_chunks_info_path(filename)
|
||||||
|
|
||||||
|
if os.path.exists(info_path):
|
||||||
|
with open(info_path, 'r') as file:
|
||||||
|
info = json.load(file)
|
||||||
|
|
||||||
|
return len(info['received_chunks']) == total_chunks
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
async def post(self, ticket_id):
|
async def post(self, ticket_id):
|
||||||
if self.initialize_account() == False:
|
if not self.initialize_account():
|
||||||
return
|
return
|
||||||
|
|
||||||
# verify permissions
|
# verify permissions
|
||||||
self.permission_service.can_upload_ticket()
|
self.permission_service.can_upload_ticket()
|
||||||
|
|
||||||
# TODO: we could add some mime type filters but we might need to interrogate the ticket model to have a dynamic implementation.
|
if 'archive' not in self.request.files or 'chunkIndex' not in self.request.arguments or 'totalChunks' not in self.request.arguments:
|
||||||
|
|
||||||
# no archive passed?
|
|
||||||
if 'archive' not in self.request.files:
|
|
||||||
return self.error(status_code = 400, error_code = ErrorCode.MISSING_FILE, message = ErrorMessage.MISSING_FILE)
|
return self.error(status_code = 400, error_code = ErrorCode.MISSING_FILE, message = ErrorMessage.MISSING_FILE)
|
||||||
|
|
||||||
# get the file data
|
|
||||||
archive_handler = self.request.files['archive'][0]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
archive_handler = self.request.files['archive'][0]
|
||||||
|
chunk_index = int(self.request.arguments['chunkIndex'][0])
|
||||||
|
total_chunks = int(self.request.arguments['totalChunks'][0])
|
||||||
|
original_filename = self.get_argument('originalFileName')
|
||||||
|
|
||||||
|
filename = f"{ticket_id}_{original_filename}"
|
||||||
|
|
||||||
|
temp_storage_path = f"{Environment.CACHE_PATH}/{filename}"
|
||||||
|
|
||||||
|
chunk_temp_path = temp_storage_path + f"_part{chunk_index}"
|
||||||
|
|
||||||
|
with open(chunk_temp_path, "wb") as temp_file:
|
||||||
|
temp_file.write(archive_handler['body'])
|
||||||
|
|
||||||
|
# update received chunks
|
||||||
|
self.update_received_chunks_info(filename, chunk_index)
|
||||||
|
|
||||||
|
# did we get every chunk?
|
||||||
|
if self.check_all_chunks_received(filename, total_chunks):
|
||||||
|
# proceed to assemble all the chunks into the a single file
|
||||||
|
with open(temp_storage_path, "wb") as final_file:
|
||||||
|
for i in range(total_chunks):
|
||||||
|
part_file_path = temp_storage_path + f"_part{i}"
|
||||||
|
|
||||||
|
if not os.path.exists(part_file_path):
|
||||||
|
raise ApplicationException("Missing file chunk: " + str(i))
|
||||||
|
|
||||||
|
with open(part_file_path, "rb") as part_file:
|
||||||
|
final_file.write(part_file.read())
|
||||||
|
|
||||||
|
os.remove(part_file_path)
|
||||||
|
|
||||||
|
# process the final file
|
||||||
forensic_create_archive_service = ForensicCreateArchiveService()
|
forensic_create_archive_service = ForensicCreateArchiveService()
|
||||||
|
|
||||||
await tornado.ioloop.IOLoop.current().run_in_executor(
|
await tornado.ioloop.IOLoop.current().run_in_executor(
|
||||||
None,
|
None,
|
||||||
forensic_create_archive_service.execute,
|
forensic_create_archive_service.execute,
|
||||||
ticket_id,
|
ticket_id,
|
||||||
archive_handler['filename'],
|
filename
|
||||||
archive_handler['body']
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# clean stored info
|
||||||
|
os.remove(self.get_received_chunks_info_path(filename))
|
||||||
|
|
||||||
|
# always answer or the frontend won't send anything more
|
||||||
self.success()
|
self.success()
|
||||||
|
|
||||||
except ApplicationException as e:
|
except ApplicationException as e:
|
||||||
|
|
|
@ -52,6 +52,7 @@ class CreateTicketHandler(ProtectedHandler):
|
||||||
|
|
||||||
ticket_create_service = TicketCreateService()
|
ticket_create_service = TicketCreateService()
|
||||||
|
|
||||||
|
if self.account_data.get('role') == AccountRoleModel.INTERNAL.value:
|
||||||
ticket_id, revoke_time = await tornado.ioloop.IOLoop.current().run_in_executor(
|
ticket_id, revoke_time = await tornado.ioloop.IOLoop.current().run_in_executor(
|
||||||
None,
|
None,
|
||||||
ticket_create_service.execute,
|
ticket_create_service.execute,
|
||||||
|
@ -70,5 +71,28 @@ class CreateTicketHandler(ProtectedHandler):
|
||||||
note = f'Ticket created. If this is a mistake, you have {revoke_time} seconds to remove it before it gets visible to the providers.'
|
note = f'Ticket created. If this is a mistake, you have {revoke_time} seconds to remove it before it gets visible to the providers.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.account_data.get('role') == AccountRoleModel.REPORTER.value:
|
||||||
|
ticket_id, revoke_time = await tornado.ioloop.IOLoop.current().run_in_executor(
|
||||||
|
None,
|
||||||
|
ticket_create_service.execute,
|
||||||
|
self.request_data.get('dda_id'),
|
||||||
|
self.request_data.get('forensic_evidence'),
|
||||||
|
self.request_data.get('fqdn') or [],
|
||||||
|
self.request_data.get('ipv4') or [],
|
||||||
|
self.request_data.get('ipv6') or [],
|
||||||
|
self.request_data.get('assigned_to') or [], # TEMPORARY
|
||||||
|
#[], # currently, we do not allow any choice by the reporter to which provider will receive the ticket
|
||||||
|
self.account_data.get('account_id'),
|
||||||
|
self.request_data.get('description') or None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.success(
|
||||||
|
data = { 'ticket_id': ticket_id },
|
||||||
|
note = f'Ticket created. If this is a mistake, you have {revoke_time} seconds to remove it before it gets visible to the providers.'
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.error(status_code = 403, error_code = ErrorCode.PERMISSION_DENIED, message = ErrorMessage.PERMISSION_DENIED)
|
||||||
|
|
||||||
except ApplicationException as e:
|
except ApplicationException as e:
|
||||||
self.error(status_code = 400, error_code = e.code, message = e.message)
|
self.error(status_code = 400, error_code = e.code, message = e.message)
|
||||||
|
|
|
@ -87,6 +87,7 @@ from .handlers.whitelist.set_status import SetStatusActiveWhitelistItemHandler,
|
||||||
from .handlers.whitelist.remove import RemoveWhitelistItemHandler
|
from .handlers.whitelist.remove import RemoveWhitelistItemHandler
|
||||||
|
|
||||||
from .handlers.dda.create import CreateDDAHandler
|
from .handlers.dda.create import CreateDDAHandler
|
||||||
|
from .handlers.dda.get_by_identifier import GetByIdentifierDDAHandler
|
||||||
from .handlers.dda.get_all import GetAllDDAHandler
|
from .handlers.dda.get_all import GetAllDDAHandler
|
||||||
from .handlers.dda.get_all_by_account import GetAllByAccountDDAHandler
|
from .handlers.dda.get_all_by_account import GetAllByAccountDDAHandler
|
||||||
from .handlers.dda.get_global import GetGlobalDDAHandler
|
from .handlers.dda.get_global import GetGlobalDDAHandler
|
||||||
|
@ -212,6 +213,7 @@ class APIv1:
|
||||||
|
|
||||||
# DDA management
|
# DDA management
|
||||||
(r"/dda/create", CreateDDAHandler),
|
(r"/dda/create", CreateDDAHandler),
|
||||||
|
(r"/dda/get", GetByIdentifierDDAHandler),
|
||||||
(r"/dda/get/all", GetAllDDAHandler),
|
(r"/dda/get/all", GetAllDDAHandler),
|
||||||
(r"/dda/get/all/by_account", GetAllByAccountDDAHandler),
|
(r"/dda/get/all/by_account", GetAllByAccountDDAHandler),
|
||||||
(r"/dda/get/global", GetGlobalDDAHandler),
|
(r"/dda/get/global", GetGlobalDDAHandler),
|
||||||
|
|
Loading…
Reference in a new issue