configurable concurrency (#132)
configurable upload concurrency, disabled by default

add retry mechanism for http errors on binary upload
itamarga authored Jan 23, 2024
1 parent 60b36b0 commit 08e0eb4
Showing 5 changed files with 68 additions and 30 deletions.
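A minimal usage sketch of the new option. The parameter name max_concurrent_uploads comes from this commit; the API key, scan directory path, and the set_global_api/send flow below are illustrative assumptions about the SDK's usual usage, not part of the diff:

    from intezer_sdk import api
    from intezer_sdk.endpoint_analysis import EndpointAnalysis

    # Hypothetical API key and offline scan directory, for illustration only.
    api.set_global_api('<api-key>')

    analysis = EndpointAnalysis(
        offline_scan_directory=r'C:\scans\scan_myhost_20240123',
        max_concurrent_uploads=4,  # new in 1.19.13; defaults to 1, i.e. sequential uploads
    )
    analysis.send()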
4 changes: 4 additions & 0 deletions CHANGES
@@ -1,3 +1,7 @@
1.19.13
_______
- Configurable concurrency for scan uploads

1.19.12
_______
- Raise AnalysisSkippedByRuleError when analysis is skipped by rule on server
2 changes: 1 addition & 1 deletion intezer_sdk/__init__.py
@@ -1 +1 @@
__version__ = '1.19.12'
__version__ = '1.19.13'
41 changes: 31 additions & 10 deletions intezer_sdk/_endpoint_analysis_api.py
@@ -1,17 +1,22 @@
import gzip
import logging
from typing import List

import requests

from intezer_sdk.api import IntezerApiClient
from intezer_sdk.api import raise_for_status
from intezer_sdk.consts import SCAN_MAX_UPLOAD_RETRIES


class EndpointScanApi:
def __init__(self, scan_id: str, api: IntezerApiClient):
def __init__(self, scan_id: str, api: IntezerApiClient, max_upload_retries: int = SCAN_MAX_UPLOAD_RETRIES):
self.api = api
if not scan_id:
raise ValueError('scan_id must be provided')
self.scan_id = scan_id
self.base_url = f"{api.base_url.replace('/api/','')}/scans/scans/{scan_id}"
self.max_upload_retries = max_upload_retries

def request_with_refresh_expired_access_token(self, *args, **kwargs):
return self.api.request_with_refresh_expired_access_token(base_url=self.base_url, *args, **kwargs)
@@ -28,6 +33,12 @@ def send_processes_info(self, processes_info: dict):
method='POST')
raise_for_status(response)

def send_all_loaded_modules_info(self, all_loaded_modules_info: dict):
response = self.request_with_refresh_expired_access_token(path=f'/processes/loaded-modules-info',
data=all_loaded_modules_info,
method='POST')
raise_for_status(response)

def send_loaded_modules_info(self, pid, loaded_modules_info: dict):
response = self.request_with_refresh_expired_access_token(path=f'/processes/{pid}/loaded-modules-info',
data=loaded_modules_info,
@@ -75,16 +86,26 @@ def send_memory_module_dump_info(self, memory_modules_info: dict) -> List[str]:
return response.json()['result']

def upload_collected_binary(self, file_path: str, collected_from: str):
with open(file_path, 'rb') as file_to_upload:
file_data = file_to_upload.read()
compressed_data = gzip.compress(file_data, compresslevel=9)
response = self.request_with_refresh_expired_access_token(
path=f'/{collected_from}/collected-binaries',
data=compressed_data,
headers={'Content-Type': 'application/octet-stream', 'Content-Encoding': 'gzip'},
method='POST')
file_data = open(file_path, 'rb').read()
compressed_data = gzip.compress(file_data, compresslevel=9)
logger = logging.getLogger(__name__)
# we have builtin retry for connection errors, but we want to retry on 500 errors as well
for retry_count in range(self.max_upload_retries):
try:
response = self.request_with_refresh_expired_access_token(
path=f'/{collected_from}/collected-binaries',
data=compressed_data,
headers={'Content-Type': 'application/octet-stream', 'Content-Encoding': 'gzip'},
method='POST')
raise_for_status(response)
return
except requests.HTTPError:
if self.max_upload_retries - retry_count <= 1:
raise
logger.warning(f'Failed to upload {file_path}, retrying')
except Exception:
raise

raise_for_status(response)

def end_scan(self, scan_summary: dict):
response = self.request_with_refresh_expired_access_token(path='/end',
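The upload change above retries on HTTP errors (such as transient 500s) up to max_upload_retries times and re-raises on the final attempt. A standalone sketch of the same pattern, where upload_once is a hypothetical callable standing in for the actual request:

    import logging
    import requests

    logger = logging.getLogger(__name__)

    def upload_with_retries(upload_once, max_retries: int = 3):
        # Connection errors are retried by the underlying client; this loop
        # additionally retries failed HTTP responses (e.g. 500s).
        for attempt in range(max_retries):
            try:
                upload_once()
                return
            except requests.HTTPError:
                # Re-raise when this was the last allowed attempt.
                if max_retries - attempt <= 1:
                    raise
                logger.warning('Upload attempt %d of %d failed, retrying', attempt + 1, max_retries)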
2 changes: 2 additions & 0 deletions intezer_sdk/consts.py
@@ -104,3 +104,5 @@ class OnPremiseVersion(enum.IntEnum):
USER_AGENT = f'intezer-python-sdk-{__version__}'
CHECK_STATUS_INTERVAL = 1
SCAN_TYPE_OFFLINE_ENDPOINT_SCAN = 'offline_endpoint_scan'
SCAN_DEFAULT_MAX_WORKERS = 1
SCAN_MAX_UPLOAD_RETRIES = 3
49 changes: 30 additions & 19 deletions intezer_sdk/endpoint_analysis.py
@@ -15,6 +15,7 @@
from intezer_sdk.api import get_global_api
from intezer_sdk.base_analysis import Analysis
from intezer_sdk.consts import EndpointAnalysisEndReason
from intezer_sdk.consts import SCAN_DEFAULT_MAX_WORKERS
from intezer_sdk.sub_analysis import SubAnalysis

logger = logging.getLogger(__name__)
@@ -35,7 +36,8 @@ class EndpointAnalysis(Analysis):
def __init__(self,
api: IntezerApiClient = None,
scan_api: EndpointScanApi = None,
offline_scan_directory: str = None):
offline_scan_directory: str = None,
max_concurrent_uploads: int = None):
"""
Initializes an EndpointAnalysis object.
Supports offline scan mode, run Scanner.exe with the '-o' flag to generate the offline scan directory.
@@ -45,6 +47,7 @@ def __init__(self,
:param offline_scan_directory: The directory of the offline scan. (example: C:\scans\scan_%computername%_%time%)
"""
super().__init__(api)
self.max_workers = max_concurrent_uploads or SCAN_DEFAULT_MAX_WORKERS
self._scan_api = scan_api
if offline_scan_directory:
files_dir = os.path.join(offline_scan_directory, '..', 'files')
@@ -124,18 +127,14 @@ def _send_analyze_to_api(self, **additional_parameters) -> str:
raise ValueError('Scan directory is not set')
if not os.path.isdir(self._offline_scan_directory):
raise ValueError('Scan directory does not exist')
if not os.path.isdir(self._files_dir):
raise ValueError('Files directory does not exist')
if not os.path.isdir(self._fileless_dir):
raise ValueError('Fileless directory does not exist')
if not os.path.isdir(self._memory_modules_dir):
raise ValueError('Memory modules directory does not exist')

self._scan_id, self.analysis_id = self._create_scan()

self.status = consts.AnalysisStatusCode.IN_PROGRESS
self._initialize_endpoint_api()

logger.info(f'Uploading {os.path.basename(os.path.abspath(self._offline_scan_directory))}')

self._send_host_info()
self._send_scheduled_tasks_info()
self._send_processes_info()
@@ -201,6 +200,13 @@ def _send_scheduled_tasks_info(self):

def _send_loaded_modules_info(self):
logger.info(f'Endpoint analysis: {self.analysis_id}, uploading loaded modules info')
unified_modules_file_path = os.path.join(self._offline_scan_directory, 'all_loaded_modules_info.json')
if os.path.isfile(unified_modules_file_path):
with open(unified_modules_file_path, encoding='utf-8') as f:
loaded_modules_info = json.load(f)
self._scan_api.send_all_loaded_modules_info(loaded_modules_info)
return

for loaded_module_info_file in glob.glob(os.path.join(self._offline_scan_directory,
'*_loaded_modules_info.json')):
with open(loaded_module_info_file, encoding='utf-8') as f:
@@ -211,15 +217,14 @@ def _send_loaded_modules_info(self):

def _send_files_info_and_upload_required(self):
logger.info(f'Endpoint analysis: {self.analysis_id}, uploading files info and uploading required files')
with concurrent.futures.ThreadPoolExecutor() as executor:
for files_info_file in glob.glob(os.path.join(self._offline_scan_directory, 'files_info_*.json')):

logger.debug(f'Endpoint analysis: {self.analysis_id}, uploading {files_info_file}')
with open(files_info_file, encoding='utf-8') as f:
files_info = json.load(f)
files_to_upload = self._scan_api.send_files_info(files_info)

futures = []
for files_info_file in glob.glob(os.path.join(self._offline_scan_directory, 'files_info_*.json')):
logger.debug(f'Endpoint analysis: {self.analysis_id}, uploading {files_info_file}')
with open(files_info_file, encoding='utf-8') as f:
files_info = json.load(f)
files_to_upload = self._scan_api.send_files_info(files_info)

futures = []
with concurrent.futures.ThreadPoolExecutor(self.max_workers) as executor:
for file_to_upload in files_to_upload:
file_path = os.path.join(self._files_dir, f'{file_to_upload}.sample')
if os.path.isfile(file_path):
@@ -232,20 +237,26 @@ def _send_files_info_and_upload_required(self):
future.result()

def _send_module_differences(self):
file_module_differences_file_path = os.path.join(self._offline_scan_directory, 'file_module_differences.json')
if not os.path.isfile(file_module_differences_file_path):
return
logger.info(f'Endpoint analysis: {self.analysis_id}, uploading file module differences info')
with open(os.path.join(self._offline_scan_directory, 'file_module_differences.json'), encoding='utf-8') as f:
with open(file_module_differences_file_path, encoding='utf-8') as f:
file_module_differences = json.load(f)
self._scan_api.send_file_module_differences(file_module_differences)

def _send_injected_modules_info(self):
injected_modules_info_path = os.path.join(self._offline_scan_directory, 'injected_modules_info.json')
if not os.path.isfile(injected_modules_info_path):
return
logger.info(f'Endpoint analysis: {self.analysis_id}, uploading injected modules info')
with open(os.path.join(self._offline_scan_directory, 'injected_modules_info.json'), encoding='utf-8') as f:
with open(injected_modules_info_path, encoding='utf-8') as f:
injected_modules_info = json.load(f)
self._scan_api.send_injected_modules_info(injected_modules_info)

def _send_memory_module_dump_info_and_upload_required(self):
logger.info(f'Endpoint analysis: {self.analysis_id}, uploading memory module dump info')
with concurrent.futures.ThreadPoolExecutor() as executor:
with concurrent.futures.ThreadPoolExecutor(self.max_workers) as executor:
for memory_module_dump_info_file in glob.glob(os.path.join(self._offline_scan_directory,
'memory_module_dump_info_*.json')):

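The file and memory-module uploads above now go through a ThreadPoolExecutor bounded by max_workers, which defaults to the new SCAN_DEFAULT_MAX_WORKERS = 1, so concurrency stays disabled unless the caller opts in. A generic sketch of that pattern, where upload_one is a hypothetical per-file upload callable:

    import concurrent.futures

    def upload_files(paths, upload_one, max_workers: int = 1):
        # With max_workers=1 the uploads run one at a time; larger values
        # upload several files in parallel.
        futures = []
        with concurrent.futures.ThreadPoolExecutor(max_workers) as executor:
            for path in paths:
                futures.append(executor.submit(upload_one, path))
            # Surface the first exception, if any, once all uploads finish.
            for future in concurrent.futures.as_completed(futures):
                future.result()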
