From f99ae043fd389c1c208d487bbe27c1040ad9cb39 Mon Sep 17 00:00:00 2001 From: Nia Catlin <5470374+ncatlin@users.noreply.github.com> Date: Thu, 12 Feb 2026 13:50:39 +0000 Subject: [PATCH 1/2] New feature: Cape dynamic test framework (#2904) * initial ui and DB implementation commit * database.py refactored into multiple mixins * backend implementation of objectives, ui alignment * implement test updating/removal during reload * hide inactive tests from tasking * implement audit session deletion * tests are now fully evaluated and reported * session auto-refresh, task storage deletion * queue clearing, config and timing display improvements * implement audit session list paging * remove tasks from analysis search. implement tags_tasks_not_like filter in list tasks. * code tidying and moving files to final structure * audit framework: implement conf enable toggle. add auth decorator to all paths * fix some audit bugs * clear old audit py module structure, fix queue bugs * add config edit functionality * pluralize table row counts * prepare audit_packages dir * prepare audit_packages dir * resolve lingering issues from fork * these modules largely use the tasking functions, so label as tasking mixin to improve intellisense * fix infuriating visual studio django auto formatting * more minor fixes - exception handling, imports * final minor changes * linting, improve config update workflow, test cascade * more linting * cleared ruff issues * fix pytest issues * fix pytest issues * resolve odd fstring complaint from tests * typo * Use test status constants; fix audit task handling Standardize status checks by replacing hard-coded status strings with TEST_* constants across audit code and web views. Improve TestLoader payload handling to treat an extracted directory with multiple items as the payload rather than raising an error. Fix a typo in a docstring. 
Optimize update_audit_tasks_status to only update/evaluate runs when the status actually changes and commit the DB session once if any changes occurred. Also prevent re-queuing tests that aren't in the unqueued state and use constants when unqueuing/queuing runs. * Update imports: remove unused statuses, add TASK_RUNNING Remove unused test status constants (TEST_FAILED, TEST_UNQUEUED) from lib/cuckoo/core/data/audits.py imports, add TASK_RUNNING import to modules/machinery/az.py so the Azure machinery can reference the running state, and clean up a stray trailing whitespace in lib/cuckoo/common/audit_utils.py. These changes tidy up imports and prevent missing-constant usage in the AZ module. * rework audit db usage, improve config edits, improve add test UX * Update audits.py * Revert "Update audits.py" This reverts commit f142b507c3cdf120ab00132421ddffb99afa1d2b. * improve db and error handling of test reloading * ruff blank lines * docs --------- Co-authored-by: doomedraven --- .gitignore | 3 +- conf/default/web.conf.default | 3 + docs/book/src/usage/audit.rst | 158 ++ docs/book/src/usage/index.rst | 1 + lib/cuckoo/common/abstracts.py | 4 +- lib/cuckoo/common/audit_utils.py | 193 ++ lib/cuckoo/common/cleaners_utils.py | 9 +- lib/cuckoo/common/web_utils.py | 8 +- lib/cuckoo/core/analysis_manager.py | 5 +- lib/cuckoo/core/data/__init__.py | 0 lib/cuckoo/core/data/audit_data.py | 157 ++ lib/cuckoo/core/data/audits.py | 496 ++++ lib/cuckoo/core/data/db_common.py | 115 + lib/cuckoo/core/data/guests.py | 91 + lib/cuckoo/core/data/machines.py | 399 +++ lib/cuckoo/core/data/samples.py | 407 +++ lib/cuckoo/core/data/task.py | 178 ++ lib/cuckoo/core/data/tasking.py | 1367 +++++++++ lib/cuckoo/core/database.py | 2432 +---------------- lib/cuckoo/core/machinery_manager.py | 4 +- lib/cuckoo/core/scheduler.py | 6 +- lib/cuckoo/core/startup.py | 3 +- modules/machinery/aws.py | 2 +- modules/machinery/az.py | 3 +- modules/reporting/callback.py | 3 +- 
tests/audit_packages/readme.md | 3 + tests/test_analysis_manager.py | 5 +- tests/test_database.py | 22 +- tests/web/test_apiv2.py | 2 +- utils/process.py | 7 +- web/analysis/views.py | 7 +- web/apiv2/views.py | 5 +- web/audit/__init__.py | 3 + web/audit/urls.py | 22 + web/audit/views.py | 495 ++++ web/dashboard/views.py | 10 +- web/templates/audit/index.html | 446 +++ .../audit/partials/objective_item.html | 56 + .../audit/partials/session_status_header.html | 199 ++ .../audit/partials/session_test_run.html | 115 + web/templates/audit/partials/task_config.html | 145 + web/templates/audit/partials/timing_item.html | 42 + web/templates/audit/session.html | 371 +++ web/templates/header.html | 3 + web/web/settings.py | 3 + web/web/urls.py | 2 + 46 files changed, 5552 insertions(+), 2458 deletions(-) create mode 100644 docs/book/src/usage/audit.rst create mode 100644 lib/cuckoo/common/audit_utils.py create mode 100644 lib/cuckoo/core/data/__init__.py create mode 100644 lib/cuckoo/core/data/audit_data.py create mode 100644 lib/cuckoo/core/data/audits.py create mode 100644 lib/cuckoo/core/data/db_common.py create mode 100644 lib/cuckoo/core/data/guests.py create mode 100644 lib/cuckoo/core/data/machines.py create mode 100644 lib/cuckoo/core/data/samples.py create mode 100644 lib/cuckoo/core/data/task.py create mode 100644 lib/cuckoo/core/data/tasking.py create mode 100644 tests/audit_packages/readme.md create mode 100644 web/audit/__init__.py create mode 100644 web/audit/urls.py create mode 100644 web/audit/views.py create mode 100644 web/templates/audit/index.html create mode 100644 web/templates/audit/partials/objective_item.html create mode 100644 web/templates/audit/partials/session_status_header.html create mode 100644 web/templates/audit/partials/session_test_run.html create mode 100644 web/templates/audit/partials/task_config.html create mode 100644 web/templates/audit/partials/timing_item.html create mode 100644 web/templates/audit/session.html diff --git 
a/.gitignore b/.gitignore index 6c7756c82d9..ff14b51a369 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ tests/test_bson.bson.compressed installer/cape-config.sh installer/kvm-config.sh -docs/book/src/_build \ No newline at end of file +docs/book/src/_build +/.vs diff --git a/conf/default/web.conf.default b/conf/default/web.conf.default index 665d1032bf6..001b8f55222 100644 --- a/conf/default/web.conf.default +++ b/conf/default/web.conf.default @@ -246,3 +246,6 @@ enabled = no [pcap_ng] enabled = no + +[audit_framework] +enabled = no \ No newline at end of file diff --git a/docs/book/src/usage/audit.rst b/docs/book/src/usage/audit.rst new file mode 100644 index 00000000000..7de09f6b82b --- /dev/null +++ b/docs/book/src/usage/audit.rst @@ -0,0 +1,158 @@ +.. _audit_framework: + +=============== +Audit Framework +=============== + +The Audit Framework is a specialized subsystem within CAPE designed to verify the correctness and reliability of the sandbox's analysis capabilities. It allows operators to define specific test cases ("Audit Packages") that run known samples with expected behavioral outcomes ("Objectives"). This is particularly useful for validating that CAPE is correctly capturing specific behaviors (e.g., shellcode injection, network beacons) after updates or configuration changes. + +Concepts +======== + +* **Available Test**: A test case definition stored on the disk. It consists of a payload (e.g., a malware sample) and a Python script defining the success criteria. +* **Test Session**: A collection of test runs. You can group multiple tests into a session to validate a specific aspect of the system (e.g., "Weekly Regression Test"). +* **Test Run**: A single execution of an *Available Test* within a *Test Session*. It links to a standard CAPE Task ID. +* **Objective**: A specific criterion that must be met for a test to pass (e.g., "DNS request to evil.com observed", "File dropped in AppData"). 
+ +Configuration +============= + +To enable the Audit Framework, ensure the feature is enabled in your web configuration. + +Edit ``conf/web.conf``: + +.. code-block:: ini + + [audit_framework] + enabled = yes + +The framework looks for test packages in ``tests/audit_packages/`` by default. + +Creating Audit Packages +======================= + +Audit packages are directory-based. Each package must be a subdirectory inside ``tests/audit_packages/`` (or the configured path) containing at least two files: + +1. ``payload.zip``: A zip file containing the sample to be analyzed. + * *Note*: If the zip contains a single file, that file is treated as the payload. If it contains multiple files, the extracted directory is treated as the payload (useful for packages requiring dependencies). +2. ``test.py``: A Python script defining the test metadata, objectives, and evaluation logic. + +Directory Structure Example +--------------------------- + +.. code-block:: text + + tests/audit_packages/ + ├── Emotet_Network_Beacon/ + │ ├── payload.zip + │ └── test.py + └── AsyncRAT_Config_Extract/ + ├── payload.zip + └── test.py + +The ``test.py`` Structure +------------------------- + +The Python script must define a class named ``CapeDynamicTest`` that implements the following methods: + +* ``get_metadata()``: Returns a dictionary of test settings. +* ``get_objectives()``: Returns a list of objective objects. +* ``evaluate_results(task_dir)``: Analyzes the analysis results. +* ``get_results()``: Returns the final status of objectives. + +**Example `test.py`:** + +.. code-block:: python + + import os + import json + + class TestObjective: + def __init__(self, name, requirement, children=None): + self.name = name + self.requirement = requirement + self.children = children or [] + + class CapeDynamicTest: + def __init__(self): + self._results = {} + + def get_metadata(self): + """ + Define high-level test information. 
+ """ + return { + "Name": "Emotet Beacon Test", + "Description": "Verifies that CAPE detects the C2 network connection.", + "Package": "exe", # CAPE analysis package to use + "Timeout": 200, # Analysis timeout in seconds + "Zip Password": "infected" # Password for payload.zip (optional) + } + + def get_objectives(self): + """ + Define the criteria for success. + """ + return [ + TestObjective("network_c2", "Must connect to C2 server 1.2.3.4"), + TestObjective("dropped_payload", "Must drop the second stage loader") + ] + + def evaluate_results(self, task_dir): + """ + Parse the CAPE report to verify objectives. + task_dir: Path to the storage directory for this task (contains report.json, etc.) + """ + report_path = os.path.join(task_dir, "reports", "report.json") + + # Default state + self._results = { + "network_c2": {"state": "failure", "state_reason": "IP not found"}, + "dropped_payload": {"state": "failure", "state_reason": "File not found"} + } + + if not os.path.exists(report_path): + return + + with open(report_path, "r") as f: + report = json.load(f) + + # Check Network + for host in report.get("network", {}).get("hosts", []): + if host == "1.2.3.4": + self._results["network_c2"] = {"state": "success", "state_reason": "Connection found"} + + # Check Dropped Files + if "dropped" in report: + self._results["dropped_payload"] = {"state": "success", "state_reason": "Dropped files present"} + + def get_results(self): + """ + Return the dictionary of results calculated in evaluate_results. + Keys must match the Objective names. + """ + return self._results + +Web Interface Usage +=================== + +Access the Audit interface via the sidebar menu or at ``/audit/``. + +1. **Manage Tests**: + The main dashboard lists all available tests. + * If you have added new tests to the disk, click **Reload Tests** to update the database. + +2. **Create Session**: + * Select the checkboxes next to the tests you wish to run. + * Click **Create Session**. 
+ * You will be redirected to the Session view. + +3. **Run Audit**: + * In the Session view, you can see the status of each test (Unqueued, Queued, Running, Complete). + * Click **Queue All** to submit all unqueued tests to CAPE. + * The status will update automatically as CAPE processes the tasks. + +4. **View Results**: + * Once a test is ``Complete``, the framework automatically runs the ``evaluate_results`` logic from your `test.py`. + * The UI will display a **Pass** (Green) or **Fail** (Red) badge for each objective. + * You can expand a test row to see detailed reasons for failure or success. diff --git a/docs/book/src/usage/index.rst b/docs/book/src/usage/index.rst index 75cd714c64a..cb3a537a5c7 100644 --- a/docs/book/src/usage/index.rst +++ b/docs/book/src/usage/index.rst @@ -12,6 +12,7 @@ This chapter explains how to use CAPE. submit web api + audit dist cluster_administration packages diff --git a/lib/cuckoo/common/abstracts.py b/lib/cuckoo/common/abstracts.py index 631f48dbfe0..5cdb9d43646 100644 --- a/lib/cuckoo/common/abstracts.py +++ b/lib/cuckoo/common/abstracts.py @@ -40,7 +40,9 @@ from lib.cuckoo.common.path_utils import path_exists, path_mkdir from lib.cuckoo.common.url_validate import url as url_validator from lib.cuckoo.common.utils import create_folder, get_memdump_path, load_categories -from lib.cuckoo.core.database import Database, Machine, _Database, Task +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.task import Task +from lib.cuckoo.core.data.machines import Machine try: import re2 as re diff --git a/lib/cuckoo/common/audit_utils.py b/lib/cuckoo/common/audit_utils.py new file mode 100644 index 00000000000..e9eb83cb02f --- /dev/null +++ b/lib/cuckoo/common/audit_utils.py @@ -0,0 +1,193 @@ +import os +import logging +import zipfile +import shutil +from pathlib import Path +from typing import Any, List, Dict +import importlib.util +from lib.cuckoo.core.data import task as db_task +from 
lib.cuckoo.core.data.audit_data import TEST_RUNNING, TEST_COMPLETE, TEST_FAILED, TEST_QUEUED + +log = logging.getLogger(__name__) + +def load_module(module_path): + module_name = "test_py_module" + spec = importlib.util.spec_from_file_location(module_name, str(module_path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if not hasattr(module, 'CapeDynamicTest'): + log.warning(str(dir(module))) + raise ValueError("Module has no CapeDynamicTest class") + tester = module.CapeDynamicTest() + + if not hasattr(tester, 'get_metadata'): + raise ValueError(f"CapeDynamicTest from {module_path} lacks get_metadata() function") + return tester + + +class TestLoader(): + def __init__(self, tests_directory): + if not os.path.exists(tests_directory): + raise ValueError(f"Tests directory '{tests_directory}' does not exist.") + self.tests_root = tests_directory + + def _extract_payload(self, payload_archive, payload_output_dir, zip_password=None): + + # Verify payload ZIP integrity + try: + with zipfile.ZipFile(payload_archive, 'r') as z: + # If a password is provided in JSON, verify we can access the list + if zip_password: + z.setpassword(zip_password.encode()) + # Test if the zip is actually readable/not corrupt + z.testzip() + except zipfile.BadZipFile: + if zip_password: + raise zipfile.BadZipFile(f"{payload_archive} is not usable with the given password") + else: + raise zipfile.BadZipFile(f"{payload_archive} is corrupt") + + # delete the unwrapped payload in case a new zip has been uploaded + if os.path.exists(payload_output_dir): + shutil.rmtree(payload_output_dir) + + with zipfile.ZipFile(payload_archive, 'r') as zip_ref: + if zip_password: + zip_ref.extractall(payload_output_dir, pwd=zip_password) + else: + zip_ref.extractall(payload_output_dir) + + payload_path = None + if not os.path.isdir(payload_output_dir): + raise NotADirectoryError("Bad payload directory extracted") + + dir_path = Path(payload_output_dir) + dir_contents = 
list(dir_path.iterdir()) + if not dir_contents: + raise FileNotFoundError("Nothing in extracted payload directory") + + if len(dir_contents) == 1: + payload_path = str(dir_contents[0]) + else: + # If multiple items, treat the directory itself as the payload + payload_path = payload_output_dir + + if not os.path.exists(payload_path): + raise FileNotFoundError("Nothing extracted from payload archive or it could not be written to disk") + + return payload_path + + def validate_test_directory(self, test_path: str) -> Dict[str, Any]: + """ + Validates a single test directory and returns the metadata from the test module. + Raises ValueError if the anything is invalid. + """ + payload_archive = os.path.join(test_path, "payload.zip") + module_path = os.path.join(test_path, "test.py") + + # Check for required files + if not os.path.exists(payload_archive): + raise ValueError(f"Missing payload.zip in {payload_archive}") + if not os.path.exists(module_path): + raise ValueError(f"Missing test.py in {module_path}") + + test_metadata = {} + test_metadata['module_path'] = module_path + + # Load and instantiate the python test module and fetch metadata + try: + tester = load_module(module_path) + test_metadata['info'] = tester.get_metadata() + + test_metadata['objectives'] = [] + + def load_objective(objective): + objdict = {'name': objective.name, + 'requirement': objective.requirement, + 'children': [load_objective(child) for child in objective.children] + } + return objdict + for objective in tester.get_objectives(): + test_metadata['objectives'].append(load_objective(objective)) + + except Exception as e: + raise ValueError(f"Failed to load test module or fetch metadata from {module_path}: {e}") + + conf = test_metadata['info'].get("Task Config", None) + if conf: + if conf.get("Request Options",None) is None: + test_metadata['info']["Request Options"] = "" + + if 'Name' not in test_metadata['info']: + raise ValueError(f"Metadata in {module_path} missing 'Name' field") + if 
'Package' not in test_metadata['info']: + raise ValueError(f"Metadata in {module_path} missing 'Package' field") + + zip_password = test_metadata['info'].get("Zip Password", None) + payload_output_dir = os.path.join(test_path, "payload") + test_metadata['payload_path'] = self._extract_payload(payload_archive, payload_output_dir, zip_password) + + # Return prepared metadata for DB ingest + return test_metadata + + def load_tests(self) -> List[Dict[str, Any]]: + """ + Walks the root directory and yields validated test configurations. + """ + available_tests = [] + unavailable_tests = [] + + if not os.path.exists(self.tests_root): + log.error("Tests root %s does not exist.", self.tests_root) + return {"error": f"Tests root {self.tests_root} does not exist."} + + for entry in os.scandir(self.tests_root): + if entry.is_dir(): + test_config = None + try: + test_config = self.validate_test_directory(entry.path) + available_tests.append(test_config) + log.info("Loaded test: %s",test_config['info']['Name']) + except Exception as e: + log.exception("Skipping directory %s due to exception",entry.path) + unavailable_tests.append({"module_path":entry.path, "error":str(e)}) + + return {'available':available_tests, 'unavailable': unavailable_tests} + + +class TestResultValidator(): + def __init__(self, test_module_path:str, task_storage_directory: str): + if os.path.isdir(task_storage_directory): + self.task_directory = task_storage_directory + else: + raise NotADirectoryError(f"Invalid task directory: {task_storage_directory}") + + try: + self.test_module = load_module(test_module_path) + except Exception as e: + raise ValueError(f"Failed to load test evaluation module {test_module_path}: {e}") + + def evaluate(self): + self.test_module.evaluate_results(self.task_directory) + return self.test_module.get_results() + +def task_status_to_run_status(cape_task_status): + if cape_task_status == db_task.TASK_REPORTED: + return TEST_COMPLETE + if cape_task_status == 
db_task.TASK_PENDING: + return TEST_QUEUED + if cape_task_status in [db_task.TASK_RUNNING, + db_task.TASK_DISTRIBUTED, + db_task.TASK_RECOVERED, + db_task.TASK_COMPLETED, + db_task.TASK_DISTRIBUTED_COMPLETED]: + return TEST_RUNNING + if cape_task_status in [db_task.TASK_BANNED, + db_task.TASK_FAILED_ANALYSIS, + db_task.TASK_FAILED_PROCESSING, + db_task.TASK_FAILED_REPORTING + ]: + return TEST_FAILED + + raise Exception(f"Unknown cape task status: {cape_task_status}") diff --git a/lib/cuckoo/common/cleaners_utils.py b/lib/cuckoo/common/cleaners_utils.py index d264e73f6e9..ee36657f3f3 100644 --- a/lib/cuckoo/common/cleaners_utils.py +++ b/lib/cuckoo/common/cleaners_utils.py @@ -14,17 +14,16 @@ from lib.cuckoo.common.dist_db import create_session from lib.cuckoo.common.exceptions import CuckooOperationalError from lib.cuckoo.common.path_utils import path_delete, path_exists, path_get_date, path_is_dir, path_mkdir -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.samples import Sample +from lib.cuckoo.core.data.task import ( TASK_FAILED_ANALYSIS, TASK_FAILED_PROCESSING, TASK_FAILED_REPORTING, TASK_PENDING, TASK_RECOVERED, TASK_REPORTED, - Database, - Sample, - Task, - _Database, + Task ) from lib.cuckoo.core.startup import create_structure, init_console_logging diff --git a/lib/cuckoo/common/web_utils.py b/lib/cuckoo/common/web_utils.py index a8028a32eb4..5755283b553 100644 --- a/lib/cuckoo/common/web_utils.py +++ b/lib/cuckoo/common/web_utils.py @@ -33,17 +33,17 @@ validate_referrer, validate_ttp, ) -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.data.task import ( ALL_DB_STATUSES, TASK_FAILED_ANALYSIS, TASK_FAILED_PROCESSING, TASK_FAILED_REPORTING, TASK_RECOVERED, TASK_REPORTED, - Database, - Sample, Task, -) + ) +from lib.cuckoo.core.data.samples import Sample +from lib.cuckoo.core.database import Database from lib.cuckoo.core.rooter import _load_socks5_operational, vpns from 
lib.downloaders import Downloaders diff --git a/lib/cuckoo/core/analysis_manager.py b/lib/cuckoo/core/analysis_manager.py index 808db42b640..366a8f7d569 100644 --- a/lib/cuckoo/core/analysis_manager.py +++ b/lib/cuckoo/core/analysis_manager.py @@ -21,7 +21,10 @@ from lib.cuckoo.common.objects import File from lib.cuckoo.common.path_utils import path_delete, path_exists, path_mkdir from lib.cuckoo.common.utils import convert_to_printable, create_folder, get_memdump_path -from lib.cuckoo.core.database import TASK_COMPLETED, TASK_PENDING, TASK_RUNNING, Database, Guest, Machine, Task, _Database +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.task import TASK_COMPLETED, TASK_PENDING, TASK_RUNNING, Task +from lib.cuckoo.core.data.machines import Machine +from lib.cuckoo.core.data.guests import Guest from lib.cuckoo.core.guest import GuestManager from lib.cuckoo.core.machinery_manager import MachineryManager from lib.cuckoo.core.plugins import RunAuxiliary diff --git a/lib/cuckoo/core/data/__init__.py b/lib/cuckoo/core/data/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/cuckoo/core/data/audit_data.py b/lib/cuckoo/core/data/audit_data.py new file mode 100644 index 00000000000..87a65389434 --- /dev/null +++ b/lib/cuckoo/core/data/audit_data.py @@ -0,0 +1,157 @@ +from datetime import datetime +from typing import List, Optional +from sqlalchemy import (Column, DateTime, ForeignKey, Integer, String, Table, Text, JSON, Boolean) +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from .db_common import _utcnow_naive, Base + +TEST_COMPLETE = "complete" +TEST_UNQUEUED = "unqueued" +TEST_QUEUED = "queued" +TEST_RUNNING = "running" +TEST_FAILED = "failed" + +class TestSession(Base): + """Test session table for tracking test runs.""" + + __tablename__ = "test_sessions" + id: Mapped[int] = mapped_column(primary_key=True) + added_on: Mapped[datetime] = mapped_column(DateTime(timezone=False), 
default=_utcnow_naive, nullable=False) + runs: Mapped[List["TestRun"]] = relationship(back_populates="session", cascade="all, delete-orphan") + + def __repr__(self): + return f"" + + @property + def unqueued_run_count(self): + return sum(1 for run in self.runs if run.status == TEST_UNQUEUED) + + @property + def queued_run_count(self): + return sum(1 for run in self.runs if run.status == TEST_QUEUED) + + +class AvailableTest(Base): + """A test case available for running against a CAPE sandbox + installation with the test harness""" + + __tablename__ = "available_tests" + + # db ID for the test + id: Mapped[int] = mapped_column(primary_key=True) + + # unique human readable name for the test + name: Mapped[str] = mapped_column(String(255), unique=True, nullable=False) + + # description of test concept, objectives + description: Mapped[str | None] = mapped_column(Text, nullable=True) + + # describe the payload (format, arch, malice) + payload_notes: Mapped[str | None] = mapped_column(Text, nullable=True) + + # give useful info on what to expect from or how to interpret results + result_notes: Mapped[str | None] = mapped_column(Text, nullable=True) + + # password to unwrap the zip, if it is encrypted + zip_password: Mapped[str | None] = mapped_column(Text, nullable=True) + + # CAPE analysis package to use: exe, archive, doc2016, etc + package: Mapped[str] = mapped_column(String(64), nullable=False) + + # CAPE timeout parameter + timeout: Mapped[int | None] = mapped_column(Integer, nullable=True) + + # list of operating systems this test is expected to work on + targets: Mapped[List[str] | None] = mapped_column(Text, nullable=True) + + # Store absolute paths for execution and verification + payload_path: Mapped[str] = mapped_column(Text, nullable=False) + module_path: Mapped[str] = mapped_column(Text, nullable=False) + + # Store 'Task Config' and other metadata as a JSON blob + # we shouldn't need the details in the web view, just parse it + # in the test tasking logic + 
task_config: Mapped[dict] = mapped_column(JSON, nullable=False) + + objective_templates: Mapped[List["TestObjectiveTemplate"]] = relationship(secondary="test_template_association") + runs: Mapped[List["TestRun"]] = relationship(back_populates="test_definition") + + is_active: Mapped[str | None] = mapped_column(Boolean, default=True, nullable=False) + + +test_template_association = Table( + "test_template_association", + Base.metadata, + Column("test_id", ForeignKey("available_tests.id", ondelete="CASCADE"), primary_key=True), + Column("template_id", ForeignKey("test_objectives_templates.id", ondelete="CASCADE"), primary_key=True), +) + + +class TestObjectiveTemplate(Base): + """A measure of success of a single objective of a dynamic analysis + test run. eg: a certain flag was found in a dropped file.""" + + __tablename__ = "test_objectives_templates" + + # metadata true for all instances of this objective over all tests + id: Mapped[int] = mapped_column(primary_key=True) + full_name: Mapped[str] = mapped_column(String(512), unique=True, nullable=False) + name: Mapped[str] = mapped_column(String(255), nullable=False) + requirement: Mapped[str | None] = mapped_column(Text, nullable=True) + parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("test_objectives_templates.id")) + children: Mapped[List["TestObjectiveTemplate"]] = relationship(back_populates="parent", cascade="all, delete-orphan") + parent: Mapped[Optional["TestObjectiveTemplate"]] = relationship(back_populates="children", remote_side=[id]) + + +class TestObjectiveInstance(Base): + """A measure of success of a single objective of a dynamic analysis + test run. 
eg: a certain flag was found in a dropped file.""" + + __tablename__ = "test_objective_instances" + id: Mapped[int] = mapped_column(primary_key=True) + + # The Link to the objective template + template_id: Mapped[int] = mapped_column(ForeignKey("test_objectives_templates.id")) + template: Mapped["TestObjectiveTemplate"] = relationship() + + # link back to the test run + run_id: Mapped[int] = mapped_column(ForeignKey("test_runs.id"), nullable=False) + run: Mapped["TestRun"] = relationship(back_populates="objectives") + parent_id: Mapped[Optional[int]] = mapped_column(ForeignKey("test_objective_instances.id")) + children: Mapped[List["TestObjectiveInstance"]] = relationship( + back_populates="parent", + cascade="all, delete-orphan", + lazy="selectin" + ) + + parent: Mapped[Optional["TestObjectiveInstance"]] = relationship(back_populates="children", remote_side=[id]) + + # per-run state of this objective + state: Mapped[str | None] = mapped_column(Text, nullable=True) + state_reason: Mapped[str | None] = mapped_column(Text, nullable=True) + + +class TestRun(Base): + """Details of a single run of an AvailableTest within a TestSession.""" + + __tablename__ = "test_runs" + + id: Mapped[int] = mapped_column(primary_key=True) + session_id: Mapped[int] = mapped_column(ForeignKey("test_sessions.id")) + test_id: Mapped[int] = mapped_column(ForeignKey("available_tests.id")) + + # CAPE Specifics + cape_task_id: Mapped[Optional[int]] = mapped_column(nullable=True) # ID returned by CAPE API + status: Mapped[str] = mapped_column(String(50), default="unqueued") # pending, running, completed, failed + + # Results + started_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + completed_at: Mapped[Optional[datetime]] = mapped_column(DateTime) + logs: Mapped[Optional[str]] = mapped_column(Text()) + raw_results: Mapped[Optional[dict]] = mapped_column(JSON) # Store summary JSON from CAPE + + session: Mapped["TestSession"] = relationship(back_populates="runs") + test_definition: 
class AuditsMixIn:
    """Database operations for the audit (dynamic test) framework.

    Mixed into the main Database class, so it relies on members provided
    elsewhere: ``self.session`` (scoped session exposing a ``session_factory``
    for standalone sessions) plus the tasking mixin's ``view_task`` and
    ``add`` methods.
    """

    def list_test_sessions(
        self,
        db_session: Session,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> List[TestSession]:
        """Retrieve list of test harness test sessions.
        @param db_session: session to execute the query with.
        @param limit: specify a limit of entries.
        @param offset: list offset.
        @return: list of test sessions, newest first.
        """
        stmt = (
            select(TestSession)
            .order_by(TestSession.added_on.desc())
            .options(selectinload(TestSession.runs))
        )
        if limit is not None:
            stmt = stmt.limit(limit)
        if offset is not None:
            stmt = stmt.offset(offset)
        # unique() is required because of the eager-loaded runs collection.
        return list(db_session.scalars(stmt).unique().all())

    def get_session_id_range(self) -> Tuple[Optional[int], Optional[int]]:
        """Get the (oldest, newest) ids from the TestSession table for pagination.

        @return: (min_id, max_id), or (None, None) if the table is empty.
        """
        stmt = select(func.min(TestSession.id), func.max(TestSession.id))
        result = self.session.execute(stmt).tuples().first()
        return result if result else (None, None)

    def get_test(self, availabletest_id: int, db_session: Session) -> Optional[AvailableTest]:
        """Fetch a single AvailableTest by primary key, or None if missing."""
        stmt = select(AvailableTest).where(AvailableTest.id == availabletest_id)
        return db_session.execute(stmt).unique().scalar_one_or_none()

    def list_available_tests(
        self,
        db_session: Session,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        active_only: bool = True,
    ) -> List[AvailableTest]:
        """Retrieve list of loaded and correctly parsed testcases.
        @param db_session: session to execute the query with.
        @param limit: specify a limit of entries.
        @param offset: list offset.
        @param active_only: hide tests that are no longer present on disk.
        @return: list of tests.
        """
        stmt = select(AvailableTest).order_by(AvailableTest.name.desc())
        if active_only:
            stmt = stmt.where(AvailableTest.is_active)
        if limit is not None:
            stmt = stmt.limit(limit)
        if offset is not None:
            stmt = stmt.offset(offset)
        return list(db_session.scalars(stmt).all())

    def count_test_sessions(self) -> int:
        """Count number of test sessions created.
        @return: number of test sessions.
        """
        return self.session.scalar(select(func.count(TestSession.id)))

    def count_available_tests(self, active_only: bool = True) -> int:
        """Count number of loaded and correctly parsed test cases.
        @param active_only: only count tests still present on disk.
        @return: number of available tests.
        """
        stmt = select(func.count(AvailableTest.id))
        if active_only:
            stmt = stmt.where(AvailableTest.is_active)
        return self.session.scalar(stmt)

    def _load_test(self, test, db_session: Session):
        """Upsert one parsed test definition (and its objective templates).

        @param test: dict with "info", "payload_path", "module_path", "objectives".
        @param db_session: session owned by the caller's transaction.
        @return: dict with "module_path" plus "added"/"updated" flags on
                 success, or an "errormsg" key describing the failure.
        """
        result = {"module_path": test["module_path"]}
        try:
            info = test["info"]
            test_name = info.get("Name")

            stmt = select(AvailableTest).where(AvailableTest.name == test_name)
            test_template = db_session.execute(stmt).scalar_one_or_none()
            if not test_template:
                test_template = AvailableTest(name=test_name)
                db_session.add(test_template)
                result["added"] = True
            else:
                result["updated"] = True

            test_template.description = info.get("Description")
            test_template.payload_notes = info.get("Payload Notes")
            test_template.result_notes = info.get("Result Notes")
            test_template.zip_password = info.get("Zip Password")
            test_template.timeout = info.get("Timeout")
            test_template.package = info.get("Package")
            test_template.payload_path = test["payload_path"]
            test_template.module_path = test["module_path"]
            test_template.targets = info.get("Targets")
            test_template.task_config = info.get("Task Config", {})
            test_template.is_active = True

            def sync_objective(test_name, obj_data, parent_obj=None):
                # Recursive upsert for objectives. Templates are globally
                # keyed on "<test name>::<objective name>".
                full_name = f"{test_name}::{obj_data.get('name')}"
                stmt = select(TestObjectiveTemplate).where(TestObjectiveTemplate.full_name == full_name)
                objective_template = db_session.execute(stmt).scalar_one_or_none()
                if not objective_template:
                    objective_template = TestObjectiveTemplate(full_name=full_name)
                    db_session.add(objective_template)

                objective_template.name = obj_data.get("name")
                objective_template.requirement = obj_data.get("requirement")
                objective_template.parent = parent_obj

                for child_data in obj_data.get("children", []):
                    sync_objective(test_name, child_data, parent_obj=objective_template)
                return objective_template

            test_template.objective_templates = [
                sync_objective(test_name, obj_data) for obj_data in test["objectives"]
            ]

        except Exception as ex:
            result["errormsg"] = f"Error preparing test entry for {test['info'].get('Name','unknown')}: {ex}"
            log.exception(result["errormsg"])

        return result

    def reload_tests(self, available_tests, unavailable_tests):
        """Load parsed test info into the database.
        @param available_tests: dictionaries of successfully parsed test metadata.
        @param unavailable_tests: dictionaries of paths and errors for failed
            test loads; extended in place with tests that fail here.
        @return: number of available tests after the reload.
        """
        log.info("Reloading available tests into database, currently there are %d", self.count_available_tests())

        current_test_names = []
        stats = {"added": 0, "updated": 0, "error": 0}
        with self.session.session_factory() as db_session, db_session.begin():
            for test in available_tests:
                load_result = self._load_test(test, db_session)
                # BUGFIX: _load_test reports failures under "errormsg"; the old
                # check for "error" never matched, so failed loads were silently
                # counted as loaded and never surfaced in unavailable_tests.
                if "errormsg" in load_result:
                    unavailable_tests.append(load_result)
                    stats["error"] += 1
                else:
                    current_test_names.append(test["info"].get("Name"))
                    if load_result.get("added", False):
                        stats["added"] += 1
                    if load_result.get("updated", False):
                        stats["updated"] += 1
            # The begin() context commits on exit; no explicit commit needed.

        test_count_after_add = self.count_available_tests()
        self.purge_unreferenced_tests(current_test_names)
        test_count_after_clean = self.count_available_tests()
        # BUGFIX: purging shrinks the count, so removed = before - after
        # (the old order produced a negative number).
        removed = test_count_after_add - test_count_after_clean
        log.info(
            "Reloaded tests, there are now %d available (%d added, %d updated, %d removed, %d errored)",
            test_count_after_clean,
            stats["added"],
            stats["updated"],
            removed,
            stats["error"],
        )
        return test_count_after_clean

    def purge_unreferenced_tests(self, loaded_test_names):
        """Cleanup function to remove tests and test objectives which were not
        loaded by the previous reload, and are not referenced by any stored
        test sessions.
        @param loaded_test_names: names of all tests that were recently loaded.
        """
        with self.session.session_factory() as db_session, db_session.begin():
            # Delete tests not in the current loaded set and not referenced
            # by any previous test session.
            retired_tests_stmt = delete(AvailableTest).where(
                AvailableTest.name.notin_(loaded_test_names),
                ~exists().where(TestRun.test_id == AvailableTest.id),
            )
            db_session.execute(retired_tests_stmt)

            # Mark deleted tests that are still referenced by past sessions
            # as inactive so they no longer appear in the tasking UI.
            db_session.execute(
                update(AvailableTest)
                .where(AvailableTest.name.notin_(loaded_test_names))
                .values(is_active=False)
            )

            # Only delete templates that are NOT used by ANY test AND NOT
            # referenced by ANY historical results.
            orphaned_tpl_stmt = delete(TestObjectiveTemplate).where(
                ~exists().where(test_template_association.c.template_id == TestObjectiveTemplate.id),
                ~exists().where(TestObjectiveInstance.template_id == TestObjectiveTemplate.id),
            )

            # Pass 1: delete leaf nodes, pass 2: delete root nodes, so the
            # parent FK never blocks the delete.
            db_session.execute(orphaned_tpl_stmt.where(TestObjectiveTemplate.parent_id.is_not(None)))
            db_session.execute(orphaned_tpl_stmt.where(TestObjectiveTemplate.parent_id.is_(None)))
            # begin() commits on exit.

    def store_objective_results(self, run_id: int, results: dict):
        """Persist evaluated objective states onto a run's objective instances.

        @param run_id: id of the TestRun the results belong to.
        @param results: nested dict keyed by objective name with "state",
            "state_reason" and optional "children" entries.
        """
        with self.session.session_factory() as db_sess, db_sess.begin():
            # Fetch the run with its objectives and templates pre-loaded.
            stmt = (
                select(TestRun)
                .options(joinedload(TestRun.objectives).joinedload(TestObjectiveInstance.template))
                .where(TestRun.id == run_id)
            )
            run = db_sess.execute(stmt).unique().scalar_one_or_none()

            if not run:
                log.error("Run %d not found for result assignment", run_id)
                return

            def apply_results(instances, current_results_level):
                # Walk the results dict in step with the instance tree and
                # copy the evaluated state onto each matching instance.
                for inst in instances:
                    name = inst.template.name
                    if name in current_results_level:
                        data = current_results_level[name]
                        inst.state = data.get("state")
                        inst.state_reason = data.get("state_reason")
                        # Recurse into children present in both DB and dict.
                        if inst.children and "children" in data:
                            apply_results(inst.children, data["children"])

            # Only pass top-level objectives (those without a parent).
            top_level_instances = [obj for obj in run.objectives if obj.parent_id is None]
            apply_results(top_level_instances, results)

            # Store the raw evaluation output for posterity.
            run.raw_results = results
            log.info("Updated objective states for Run %d", run_id)
            # begin() commits on exit.

    def evaluate_objective_results(self, test_run: TestRun):
        """Run the test's validator against the task's stored artifacts and
        persist the resulting objective states.

        @param test_run: completed TestRun linked to a CAPE task.
        """
        log.info("Starting evaluation of test run #%d", test_run.id)
        task_storage = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(test_run.cape_task_id))
        validator = TestResultValidator(test_run.test_definition.module_path, task_storage)
        results_dict = validator.evaluate()
        self.store_objective_results(test_run.id, results_dict)

    def update_audit_tasks_status(self, db_session: Session, audit_session: TestSession):
        """Refresh run statuses from their underlying CAPE tasks.

        Runs whose task just reached completion are evaluated. The session is
        committed only if at least one status actually changed.
        @param db_session: session the audit_session is attached to.
        @param audit_session: TestSession whose runs should be refreshed.
        """
        changed = False
        for run in audit_session.runs:
            if not run.cape_task_id:
                continue
            cape_task = self.view_task(run.cape_task_id)
            if not cape_task:
                continue
            new_status = task_status_to_run_status(cape_task.status)
            if run.status == new_status:
                continue
            # Evaluate results exactly once, on the transition to complete.
            if run.status != TEST_COMPLETE and new_status == TEST_COMPLETE:
                self.evaluate_objective_results(run)
            run.status = new_status
            changed = True
        if changed:
            db_session.commit()

    def get_test_session(self, session_id: int) -> Optional[TestSession]:
        """Fetch a TestSession with runs, definitions and objective trees
        eagerly loaded, refreshing run statuses just-in-time.

        @param session_id: id of the session to fetch.
        @return: detached TestSession, or None if not found.
        """
        with self.session.session_factory() as db_session, db_session.begin():
            stmt = (
                select(TestSession)
                .options(
                    # Branch A: load the test definitions for the runs.
                    selectinload(TestSession.runs).joinedload(TestRun.test_definition),
                    # Branch B: load the objectives, their templates AND children.
                    selectinload(TestSession.runs)
                    .selectinload(TestRun.objectives)
                    .options(
                        joinedload(TestObjectiveInstance.template),
                        selectinload(TestObjectiveInstance.children).joinedload(TestObjectiveInstance.template),
                    ),
                )
                .where(TestSession.id == session_id)
            )

            test_session = db_session.execute(stmt).unique().scalar_one_or_none()
            if test_session:
                # Just-in-time refresh of test run statuses before detaching.
                self.update_audit_tasks_status(db_session, test_session)
                db_session.expunge_all()
                return test_session
            return None

    def delete_test_session(self, session_id: int, purge_storage: bool = True) -> bool:
        """Deletes a specific TestSession and all associated objective instances.

        @param session_id: audit session to delete.
        @param purge_storage: if true, also delete the task storage directories
            of all the test runs.
        @return: True on success, False if missing or runs are still running.
        """
        session_id = int(session_id)
        with self.session.session_factory() as db_session, db_session.begin():
            stmt = select(TestSession).where(TestSession.id == session_id)
            session_obj = db_session.execute(stmt).unique().scalar_one_or_none()

            if not session_obj:
                log.warning("Attempted to delete non-existent TestSession ID: %d", session_id)
                return False

            # Safety check: don't delete sessions with active runs.
            stmt = select(func.count(TestRun.id)).where(
                TestRun.session_id == session_id,
                TestRun.status == TEST_RUNNING,
            )
            active_runs = db_session.execute(stmt).scalar()
            if active_runs > 0:
                log.warning("Cannot delete Session %d: one or more runs are still 'running'", session_id)
                return False

            if purge_storage:
                for run in session_obj.runs:
                    cape_task_id = run.cape_task_id
                    if isinstance(cape_task_id, int):
                        task_storage_dir = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(cape_task_id))
                        if os.path.isdir(task_storage_dir):
                            shutil.rmtree(task_storage_dir)

            db_session.delete(session_obj)
            log.info("Deleted TestSession %d and all its objective results.", session_id)
            # begin() commits on exit.
            return True

    def create_session_from_tests(self, test_ids: list) -> int:
        """Create a new TestSession containing one TestRun (with a fresh
        objective-instance tree) per requested test.

        @param test_ids: iterable of AvailableTest ids to include.
        @return: id of the newly created session.
        @raise ValueError: if a test id does not exist.
        """

        def init_objective(run, obj_template):
            # Recursively mirror the objective template tree as instances.
            children = [init_objective(run, obj_child) for obj_child in obj_template.children]
            return TestObjectiveInstance(
                run_id=run.id, template_id=obj_template.id, children=children, state="untested"
            )

        # The begin() context commits on success and rolls back on exception;
        # the session_factory context closes the session either way.
        with self.session.session_factory() as db_session, db_session.begin():
            try:
                new_test_session = TestSession()
                db_session.add(new_test_session)
                # Flush so the DB generates an ID for the session without
                # committing the whole transaction yet.
                db_session.flush()

                for t_id in test_ids:
                    test_def = db_session.get(AvailableTest, int(t_id))
                    if test_def is None:
                        raise ValueError(f"Unknown test id: {t_id}")
                    run = TestRun(session_id=new_test_session.id, test_id=test_def.id)
                    db_session.add(run)
                    for template in test_def.objective_templates:
                        run.objectives.append(init_objective(run, template))

                return new_test_session.id
            except Exception as e:
                log.error("Failed to create test session: %s", str(e))
                raise

    def get_audit_session_test(self, session_id: int, testrun_id: int) -> Optional[TestRun]:
        """Fetch a single run (with definition and objective tree) belonging
        to a session, detached from its session.

        @param session_id: id of the owning TestSession.
        @param testrun_id: id of the TestRun.
        @return: detached TestRun, or None if not found in that session.
        """
        stmt = (
            select(TestRun)
            .options(
                joinedload(TestRun.test_definition),
                selectinload(TestRun.objectives).joinedload(TestObjectiveInstance.template),
                selectinload(TestRun.objectives).selectinload(TestObjectiveInstance.children),
            )
            .where(TestRun.id == testrun_id)
            .where(TestRun.session_id == session_id)
        )

        with self.session.session_factory() as db_sess, db_sess.begin():
            result = db_sess.execute(stmt).unique().scalar_one_or_none()
            if result:
                db_sess.expunge_all()
            return result

    def set_audit_run_status(self, session_id: int, testrun_id: int, new_status: str) -> None:
        """Set the status of a single test run.

        @param session_id: id of the owning TestSession.
        @param testrun_id: id of the TestRun to update.
        @param new_status: one of the TEST_* status constants.
        """
        # BUGFIX: the old implementation fetched the run via
        # get_audit_session_test(), which uses its own session and detaches
        # the object — so the status assignment was never persisted. Query
        # the run inside this session instead.
        with self.session.session_factory() as db_sess, db_sess.begin():
            stmt = (
                select(TestRun)
                .where(TestRun.id == testrun_id)
                .where(TestRun.session_id == session_id)
            )
            run = db_sess.execute(stmt).unique().scalar_one_or_none()
            if run:
                run.status = new_status
                # begin() commits on exit.

    def assign_cape_task_to_testrun(self, run_id: int, cape_task_id: int) -> bool:
        """Updates a TestRun with the ID returned from the CAPE sandbox.

        @param run_id: id of the TestRun to link.
        @param cape_task_id: id of the submitted CAPE task.
        @return: True if the run was found and linked, False otherwise.
        """
        with self.session.session_factory() as db_sess, db_sess.begin():
            stmt = select(TestRun).where(TestRun.id == run_id)
            run = db_sess.execute(stmt).unique().scalar_one_or_none()

            if run:
                run.cape_task_id = cape_task_id
                run.status = TEST_QUEUED
                log.info("TestRun %d successfully linked to CAPE Task %d", run_id, cape_task_id)
                return True
            log.error("Failed to link task and task ID: TestRun %d not found.", run_id)
            return False

    def queue_audit_test(self, session_id, testrun_id, user_id=0):
        """Submit a test run's payload to the sandbox as a new analysis task.

        @param session_id: id of the owning TestSession.
        @param testrun_id: id of the TestRun to queue.
        @param user_id: id of the submitting user.
        @return: id of the newly created CAPE task.
        @raise ValueError: if the run does not exist in the session.
        """
        test_instance = self.get_audit_session_test(session_id, testrun_id)
        if test_instance is None:
            raise ValueError(f"Test run {testrun_id} not found in session {session_id}")
        test_definition = test_instance.test_definition

        conf = test_definition.task_config
        task_options = conf.get("Request Options", "")
        if task_options is None:  # if None -> pending forever
            task_options = ""

        new_task_id = self.add(
            File(test_definition.payload_path),
            timeout=test_definition.timeout,
            package=test_definition.package,
            options=task_options,
            priority=1,
            custom=conf.get("Custom Request Params", ""),
            tags=conf.get("Tags", None),
            enforce_timeout=conf.get("Enforce Timeout", None),
            route=conf.get("Route", None),
            tags_tasks=["audit"],
            user_id=user_id,
            source_url=False,
        )
        return new_task_id
# ToDo verify variable declaration in Mapped
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model in the data package."""

    pass


cfg = Config("cuckoo")
tz_name = cfg.cuckoo.get("timezone", "utc")


def _utcnow_naive():
    """Returns the current time in the configured timezone as a naive datetime object.

    Falls back to UTC when the configured timezone name is unknown; tzinfo is
    stripped because the DateTime columns are declared with timezone=False.
    """
    try:
        tz = pytz.timezone(tz_name)
    except pytz.UnknownTimeZoneError:
        tz = timezone.utc
    return datetime.now(tz).replace(tzinfo=None)


# Secondary table used in association Machine - Tag.
machines_tags = Table(
    "machines_tags",
    Base.metadata,
    Column("machine_id", Integer, ForeignKey("machines.id")),
    Column("tag_id", Integer, ForeignKey("tags.id")),
)

# Secondary table used in association Task - Tag.
tasks_tags = Table(
    "tasks_tags",
    Base.metadata,
    Column("task_id", Integer, ForeignKey("tasks.id", ondelete="cascade")),
    Column("tag_id", Integer, ForeignKey("tags.id", ondelete="cascade")),
)


class Tag(Base):
    """Tag describing anything you want."""

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False, unique=True)
    machines: Mapped[List["Machine"]] = relationship(secondary=machines_tags, back_populates="tags")
    tasks: Mapped[List["Task"]] = relationship(secondary=tasks_tags, back_populates="tags")

    def __repr__(self):
        # BUGFIX: previously returned an empty f"" string; restore an
        # informative repr for debugging.
        return f"<Tag('{self.id}','{self.name}')>"

    def __init__(self, name):
        self.name = name


class Error(Base):
    """Analysis errors."""

    __tablename__ = "errors"
    # Hard cap on the stored message length (column is String(MAX_LENGTH)).
    MAX_LENGTH = 1024

    id: Mapped[int] = mapped_column(primary_key=True)
    message: Mapped[str] = mapped_column(String(MAX_LENGTH), nullable=False)
    task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id"), nullable=False)
    task: Mapped["Task"] = relationship(back_populates="errors")

    def to_dict(self):
        """Converts object to dict.
        @return: dict
        """
        d = {}
        for column in self.__table__.columns:
            d[column.name] = getattr(self, column.name)
        return d

    def to_json(self):
        """Converts object to JSON.
        @return: JSON data
        """
        return json.dumps(self.to_dict())

    def __init__(self, message, task_id):
        if len(message) > self.MAX_LENGTH:
            # Make sure that we don't try to insert an error message longer than what's allowed
            # in the database. Provide the beginning and the end of the error.
            left_of_ellipses = self.MAX_LENGTH // 2 - 2
            right_of_ellipses = self.MAX_LENGTH - left_of_ellipses - 3
            message = "...".join((message[:left_of_ellipses], message[-right_of_ellipses:]))
        self.message = message
        self.task_id = task_id

    def __repr__(self):
        # BUGFIX: previously returned an empty f"" string; restore an
        # informative repr for debugging.
        return f"<Error('{self.id}','{self.message}','{self.task_id}')>"
class GuestsMixIn:
    """Database helpers for Guest rows (one per guest VM run)."""

    def _guest_by_task(self, task_id: int):
        # Look up the Guest row belonging to a task (task_id is unique).
        return self.session.scalar(select(Guest).where(Guest.task_id == task_id))

    def guest_get_status(self, task_id: int):
        """Gets the status for a given guest."""
        guest = self._guest_by_task(task_id)
        if guest is None:
            return None
        return guest.status

    def guest_set_status(self, task_id: int, status: str):
        """Sets the status for a given guest."""
        guest = self._guest_by_task(task_id)
        if guest is not None:
            guest.status = status

    def guest_remove(self, guest_id):
        """Removes a guest start entry."""
        if (guest := self.session.get(Guest, guest_id)) is not None:
            self.session.delete(guest)

    def guest_stop(self, guest_id):
        """Logs guest stop.
        @param guest_id: guest log entry id
        """
        if (guest := self.session.get(Guest, guest_id)) is not None:
            guest.shutdown_on = _utcnow_naive()
class MachinesMixIn:
    """Database helpers for guest Machine rows: registration, filtering,
    locking/unlocking and status bookkeeping."""

    def clean_machines(self):
        """Clean old stored machines and related tables."""
        # Secondary table.
        # TODO: this is better done via cascade delete.
        # self.engine.execute(machines_tags.delete())
        # ToDo : If your ForeignKey has "ON DELETE CASCADE", deleting a Machine
        # would automatically delete its entries in machines_tags.
        # If not, deleting them manually first is correct.
        self.session.execute(delete(machines_tags))
        self.session.execute(delete(Machine))

    def delete_machine(self, name) -> bool:
        """Delete a single machine entry from DB.
        @param name: machine name to delete.
        @return: True if a row was deleted, False if no such machine exists.
        """

        stmt = select(Machine).where(Machine.name == name)
        machine = self.session.scalar(stmt)

        if machine:
            # Deleting a specific ORM instance remains the same
            self.session.delete(machine)
            return True
        else:
            log.warning("%s does not exist in the database.", name)
            return False

    def add_machine(
        self, name, label, arch, ip, platform, tags, interface, snapshot, resultserver_ip, resultserver_port, reserved, locked=False
    ) -> Machine:
        """Add a guest machine.
        @param name: machine id
        @param label: machine label
        @param arch: machine arch
        @param ip: machine IP address
        @param platform: machine supported platform
        @param tags: list of comma separated tags
        @param interface: sniffing interface for this machine
        @param snapshot: snapshot name to use instead of the current one, if configured
        @param resultserver_ip: IP address of the Result Server
        @param resultserver_port: port of the Result Server
        @param reserved: True if the machine can only be used when specifically requested
        @param locked: if True, mark the machine as locked immediately
        @return: the new (pending) Machine object added to the session
        """

        machine = Machine(
            name=name,
            label=label,
            arch=arch,
            ip=ip,
            platform=platform,
            interface=interface,
            snapshot=snapshot,
            resultserver_ip=resultserver_ip,
            resultserver_port=resultserver_port,
            reserved=reserved,
        )

        if tags:
            # no_autoflush: tag lookups must not flush the half-built machine.
            with self.session.no_autoflush:
                for tag in tags.replace(" ", "").split(","):
                    machine.tags.append(self._get_or_create(Tag, name=tag))
        if locked:
            machine.locked = True

        self.session.add(machine)
        return machine

    def set_machine_interface(self, label, interface):
        """Set the sniffing interface of the machine with the given label.
        @return: the updated Machine, or None if the label was not found.
        """
        stmt = select(Machine).filter_by(label=label)
        machine = self.session.scalar(stmt)

        if machine is None:
            log.debug("Database error setting interface: %s not found", label)
            return

        # This part remains the same
        machine.interface = interface
        return machine

    def create_guest(self, machine: Machine, manager: str, task: Task) -> Guest:
        """Create a Guest row (status "init") tying a machine to a task run.
        @param machine: machine the guest runs on.
        @param manager: name of the machinery manager in charge.
        @param task: task being serviced by the guest.
        @return: the new Guest object added to the session.
        """
        guest = Guest(machine.name, machine.label, machine.platform, manager, task.id)
        guest.status = "init"
        self.session.add(guest)
        return guest

    def _package_vm_requires_check(self, package: str) -> list:
        """
        We allow to users use their custom tags to tag properly any VM that can run this package
        """
        return [vm_tag.strip() for vm_tag in web_conf.packages.get(package).split(",")] if web_conf.packages.get(package) else []

    def find_machine_to_service_task(self, task: Task) -> Optional[Machine]:
        """Find a machine that is able to service the given task.
        Returns: The Machine if an available machine was found; None if there is at least 1 machine
        that *could* service it, but they are all currently in use.
        Raises: CuckooUnserviceableTaskError if there are no machines in the pool that would be able
        to service it.
        """
        task_archs, task_tags = self._task_arch_tags_helper(task)
        os_version = self._package_vm_requires_check(task.package)

        base_stmt = select(Machine).options(subqueryload(Machine.tags))

        # This helper now encapsulates the final ordering, locking, and execution.
        # It takes a Select statement as input.
        def get_locked_machine(stmt: Select) -> Optional[Machine]:
            # Unlocked machines sort first; with_for_update row-locks the pick.
            final_stmt = stmt.order_by(Machine.locked, Machine.locked_changed_on).with_for_update(of=Machine)
            return self.session.scalars(final_stmt).first()

        filter_kwargs = {
            "statement": base_stmt,
            "label": task.machine,
            "platform": task.platform,
            "tags": task_tags,
            "archs": task_archs,
            "os_version": os_version,
        }

        filtered_stmt = self.filter_machines_to_task(include_reserved=False, **filter_kwargs)
        machine = get_locked_machine(filtered_stmt)

        if machine is None and not task.machine and task_tags:
            # The task was given at least 1 tag, but there are no non-reserved machines
            # that could satisfy the request. So let's check "reserved" machines.
            filtered_stmt = self.filter_machines_to_task(include_reserved=True, **filter_kwargs)
            machine = get_locked_machine(filtered_stmt)

        if machine is None:
            raise CuckooUnserviceableTaskError
        if machine.locked:
            # There aren't any machines that can service the task NOW, but there is at
            # least one in the pool that could service it once it's available.
            return None
        return machine

    @staticmethod
    def filter_machines_by_arch(statement: Select, arch: list) -> Select:
        """Adds a filter to the given select statement for the machine architecture.
        Allows x64 machines to be returned when requesting x86.
        """
        if arch:
            if "x86" in arch:
                # Prefer x86 machines over x64 if x86 is what was requested.
                statement = statement.where(Machine.arch.in_(("x64", "x86"))).order_by(Machine.arch.desc())
            else:
                statement = statement.where(Machine.arch.in_(arch))
        return statement

    def filter_machines_to_task(
        self, statement: Select, label=None, platform=None, tags=None, archs=None, os_version=None, include_reserved=False
    ) -> Select:
        """Adds filters to the given select statement based on the task.

        @param statement: A `select()` statement to add filters to.
        @param label: restrict to a machine with this exact label.
        @param platform: restrict to machines of this platform.
        @param tags: iterable of tag names the machine must all carry.
        @param archs: acceptable architectures (see filter_machines_by_arch).
        @param os_version: tag names of which the machine must carry at least one.
        @param include_reserved: when no label is given, also allow reserved machines.
        @return: the statement with the filters applied.
        """
        if label:
            statement = statement.where(Machine.label == label)
        elif not include_reserved:
            # Use .is_(False) for boolean checks
            statement = statement.where(Machine.reserved.is_(False))

        if platform:
            statement = statement.where(Machine.platform == platform)

        statement = self.filter_machines_by_arch(statement, archs)

        if tags:
            for tag in tags:
                statement = statement.where(Machine.tags.any(name=tag))

        if os_version:
            statement = statement.where(Machine.tags.any(Tag.name.in_(os_version)))

        return statement

    def list_machines(
        self,
        locked=None,
        label=None,
        platform=None,
        tags=None,
        arch=None,
        include_reserved=False,
        os_version=None,
    ) -> List[Machine]:
        """Lists virtual machines.
        @return: list of virtual machines, e.g.:
            id | name  | label | arch |
            ---+-------+-------+------+
            77 | cape1 | win7  | x86  |
            78 | cape2 | win10 | x64  |
        """
        # ToDo do we really need it
        with self.session.begin_nested():
            # with self.session.no_autoflush:
            stmt = select(Machine).options(subqueryload(Machine.tags))

            if locked is not None:
                stmt = stmt.where(Machine.locked.is_(locked))

            stmt = self.filter_machines_to_task(
                statement=stmt,
                label=label,
                platform=platform,
                tags=tags,
                archs=arch,
                os_version=os_version,
                include_reserved=include_reserved,
            )
            return self.session.execute(stmt).unique().scalars().all()

    def assign_machine_to_task(self, task: Task, machine: Optional[Machine]) -> Task:
        """Record (or clear, when machine is None) the machine assignment on a task.
        @return: the updated task.
        """
        if machine:
            task.machine = machine.label
            task.machine_id = machine.id
        else:
            task.machine = None
            task.machine_id = None
        self.session.add(task)
        return task

    def lock_machine(self, machine: Machine) -> Machine:
        """Places a lock on a free virtual machine.
        @param machine: the Machine to lock
        @return: locked machine
        """
        machine.locked = True
        machine.locked_changed_on = _utcnow_naive()
        self.set_machine_status(machine, MACHINE_RUNNING)
        self.session.add(machine)

        return machine

    def unlock_machine(self, machine: Machine) -> Machine:
        """Remove lock from a virtual machine.
        @param machine: The Machine to unlock.
        @return: unlocked machine
        """
        machine.locked = False
        machine.locked_changed_on = _utcnow_naive()
        self.session.merge(machine)
        return machine

    def count_machines_available(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=None):
        """How many (relevant) virtual machines are ready for analysis.
        @param label: machine ID.
        @param platform: machine platform.
        @param tags: machine tags
        @param arch: machine arch
        @param include_reserved: include 'reserved' machines in the result, regardless of whether or not a 'label' was provided.
        @param os_version: tag names of which the machine must carry at least one.
        @return: free virtual machines count
        """
        stmt = select(func.count(Machine.id)).where(Machine.locked.is_(False))
        stmt = self.filter_machines_to_task(
            statement=stmt,
            label=label,
            platform=platform,
            tags=tags,
            archs=arch,
            os_version=os_version,
            include_reserved=include_reserved,
        )

        return self.session.scalar(stmt)

    def get_available_machines(self) -> List[Machine]:
        """Which machines are available"""
        stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.locked.is_(False))
        return self.session.scalars(stmt).all()

    def count_machines_running(self) -> int:
        """Counts how many machines are currently locked (running)."""
        stmt = select(func.count(Machine.id)).where(Machine.locked.is_(True))
        return self.session.scalar(stmt)

    def set_machine_status(self, machine_or_label: Union[str, Machine], status):
        """Set status for a virtual machine.
        @param machine_or_label: a Machine object, or a machine label to look up.
        @param status: new status string; also stamps status_changed_on.
        """
        if isinstance(machine_or_label, str):
            stmt = select(Machine).where(Machine.label == machine_or_label)
            machine = self.session.scalar(stmt)
        else:
            machine = machine_or_label

        if machine:
            machine.status = status
            machine.status_changed_on = _utcnow_naive()
            # No need for session.add() here; the ORM tracks changes to loaded objects.

    def view_machine(self, name: str) -> Optional[Machine]:
        """Shows virtual machine details by name."""
        stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.name == name)
        return self.session.scalar(stmt)

    def view_machine_by_label(self, label: str) -> Optional[Machine]:
        """Shows virtual machine details by label."""
        stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.label == label)
        return self.session.scalar(stmt)
= mapped_column(ForeignKey("samples.id"), primary_key=True) + + # This is the crucial column that links to the specific child's task + task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="CASCADE"), primary_key=True) + + # Relationships from the association object itself + parent: Mapped["Sample"] = relationship(foreign_keys=[parent_id], back_populates="child_links") + child: Mapped["Sample"] = relationship(foreign_keys=[child_id], back_populates="parent_links") + task: Mapped["Task"] = relationship(back_populates="association") + + +class Sample(Base): + """Submitted files details.""" + + __tablename__ = "samples" + + id: Mapped[int] = mapped_column(primary_key=True) + file_size: Mapped[int] = mapped_column(BigInteger, nullable=False) + file_type: Mapped[str] = mapped_column(Text(), nullable=False) + md5: Mapped[str] = mapped_column(String(32), nullable=False) + crc32: Mapped[str] = mapped_column(String(8), nullable=False) + sha1: Mapped[str] = mapped_column(String(40), nullable=False) + sha256: Mapped[str] = mapped_column(String(64), nullable=False) + sha512: Mapped[str] = mapped_column(String(128), nullable=False) + ssdeep: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + source_url: Mapped[Optional[str]] = mapped_column(String(2000), nullable=True) + tasks: Mapped[List["Task"]] = relationship(back_populates="sample", cascade="all, delete-orphan") + + child_links: Mapped[List["SampleAssociation"]] = relationship( + foreign_keys=[SampleAssociation.parent_id], back_populates="parent" + ) + # When this Sample is a child, this gives you its association links + parent_links: Mapped[List["SampleAssociation"]] = relationship( + foreign_keys=[SampleAssociation.child_id], back_populates="child" + ) + + # ToDo replace with index=True + __table_args__ = ( + Index("md5_index", "md5"), + Index("sha1_index", "sha1"), + Index("sha256_index", "sha256", unique=True), + ) + + def __repr__(self): + return f"" + + def to_dict(self): + """Converts 
object to dict. + @return: dict + """ + d = {} + for column in self.__table__.columns: + d[column.name] = getattr(self, column.name) + return d + + def to_json(self): + """Converts object to JSON. + @return: JSON data + """ + return json.dumps(self.to_dict()) + + def __init__(self, md5, crc32, sha1, sha256, sha512, file_size, file_type=None, ssdeep=None, parent_sample=None, source_url=None): + self.md5 = md5 + self.sha1 = sha1 + self.crc32 = crc32 + self.sha256 = sha256 + self.sha512 = sha512 + self.file_size = file_size + if file_type: + self.file_type = file_type + if ssdeep: + self.ssdeep = ssdeep + # if parent_sample: + # self.parent_sample = parent_sample + if source_url: + self.source_url = source_url + +class SamplesMixIn: + def register_sample(self, obj, source_url=False): + if isinstance(obj, (File, PCAP, Static)): + fileobj = File(obj.file_path) + file_type = fileobj.get_type() + file_md5 = fileobj.get_md5() + sample = None + # check if hash is known already + try: + # get or create + sample = self.session.scalar(select(Sample).where(Sample.md5 == file_md5)) + if sample is None: + with self.session.begin_nested(): + sample = Sample( + md5=file_md5, + crc32=fileobj.get_crc32(), + sha1=fileobj.get_sha1(), + sha256=fileobj.get_sha256(), + sha512=fileobj.get_sha512(), + file_size=fileobj.get_size(), + file_type=file_type, + ssdeep=fileobj.get_ssdeep(), + source_url=source_url, + ) + self.session.add(sample) + except IntegrityError as e: + log.exception(e) + return sample + + + def check_file_uniq(self, sha256: str, hours: int = 0): + # TODO This function is poorly named. It returns True if a sample with the given + # sha256 already exists in the database, rather than returning True if the given + # sha256 is unique. 
+ uniq = False + if hours and sha256: + date_since = _utcnow_naive() - timedelta(hours=hours) + + stmt = ( + select(Task) + .join(Sample, Task.sample_id == Sample.id) + .where(Sample.sha256 == sha256) + .where(Task.added_on >= date_since) + ) + return self.session.scalar(select(stmt.exists())) + else: + if not self.find_sample(sha256=sha256): + uniq = False + else: + uniq = True + return uniq + + def get_file_types(self) -> List[str]: + """Gets a sorted list of unique sample file types.""" + # .distinct() is cleaner than group_by() for a single column. + stmt = select(Sample.file_type).distinct().order_by(Sample.file_type) + return self.session.scalars(stmt).all() + + def view_sample(self, sample_id): + """Retrieve information on a sample given a sample id. + @param sample_id: ID of the sample to query. + @return: details on the sample used in sample: sample_id. + """ + return self.session.get(Sample, sample_id) + + def get_children_by_parent_id(self, parent_id: int) -> List[Sample]: + """ + Finds all child Samples using an explicit join. 
+ """ + # Create an alias to represent the Child Sample in the query + ChildSample = aliased(Sample, name="child") + + # This query selects child samples by joining through the association table + stmt = ( + select(ChildSample) + .join(SampleAssociation, ChildSample.id == SampleAssociation.child_id) + .where(SampleAssociation.parent_id == parent_id) + ) + + return self.session.scalars(stmt).all() + + def find_sample( + self, md5: str = None, sha1: str = None, sha256: str = None, parent: int = None, task_id: int = None, sample_id: int = None + ) -> Union[Optional[Sample], List[Sample], List[Task]]: + """Searches for samples or tasks based on different criteria.""" + + if md5: + return self.session.scalar(select(Sample).where(Sample.md5 == md5)) + + if sha1: + return self.session.scalar(select(Sample).where(Sample.sha1 == sha1)) + + if sha256: + return self.session.scalar(select(Sample).where(Sample.sha256 == sha256)) + + if parent is not None: + return self.get_children_by_parent_id(parent) + + if sample_id is not None: + # Using session.get() is much more efficient than a select query. + # We wrap the result in a list to match the original function's behavior. + sample = self.session.get(Sample, sample_id) + return [sample] if sample else [] + + if task_id is not None: + # Note: This branch returns a list of Task objects. + stmt = select(Task).join(Sample, Task.sample_id == Sample.id).options(joinedload(Task.sample)).where(Task.id == task_id) + return self.session.scalars(stmt).all() + + return None + + def sample_still_used(self, sample_hash: str, task_id: int): + """Retrieve information if sample is used by another task(s). + @param sample_hash: sha256. 
+ @param task_id: task_id + @return: bool + """ + stmt = ( + select(Task) + .join(Sample, Task.sample_id == Sample.id) + .where(Sample.sha256 == sample_hash) + .where(Task.id != task_id) + .where(Task.status.in_((TASK_PENDING, TASK_RUNNING, TASK_DISTRIBUTED))) + ) + + # select(stmt.exists()) creates a `SELECT EXISTS(...)` query. + # session.scalar() executes it and returns True or False directly. + return self.session.scalar(select(stmt.exists())) + + def _hash_file_in_chunks(self, path: str, hash_algo) -> str: + """Helper function to hash a file efficiently in chunks.""" + hasher = hash_algo() + buffer_size = 65536 # 64kb + with open(path, "rb") as f: + while chunk := f.read(buffer_size): + hasher.update(chunk) + return hasher.hexdigest() + + def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): + """Retrieve information on a sample location by given hash. + @param hash: md5/sha1/sha256/sha256. + @param task_id: task_id + @return: samples path(s) as list. + """ + sizes = { + 32: Sample.md5, + 40: Sample.sha1, + 64: Sample.sha256, + 128: Sample.sha512, + } + + hashlib_sizes = { + 32: hashlib.md5, + 40: hashlib.sha1, + 64: hashlib.sha256, + 128: hashlib.sha512, + } + + sizes_mongo = { + 32: "md5", + 40: "sha1", + 64: "sha256", + 128: "sha512", + } + + if task_id: + file_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(task_id), "binary") + if path_exists(file_path): + return [file_path] + + # binary also not stored in binaries, perform hash lookup + stmt = select(Sample).join(Task, Sample.id == Task.sample_id).where(Task.id == task_id) + db_sample = self.session.scalar(stmt) + if db_sample: + path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) + if path_exists(path): + return [path] + + sample_hash = db_sample.sha256 + + if not sample_hash: + return [] + + query_filter = sizes.get(len(sample_hash), "") + sample = [] + # check storage/binaries + if query_filter: + stmt = select(Sample).where(query_filter == 
sample_hash) + db_sample = self.session.scalar(stmt) + if db_sample is not None: + path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) + if path_exists(path): + sample = [path] + + if not sample: + tasks = [] + if repconf.mongodb.enabled and web_conf.general.check_sample_in_mongodb: + tasks = mongo_find( + "files", + {sizes_mongo.get(len(sample_hash), ""): sample_hash}, + {"_info_ids": 1, "sha256": 1}, + ) + """ deprecated code + elif repconf.elasticsearchdb.enabled: + tasks = [ + d["_source"] + for d in es.search( + index=get_analysis_index(), + body={"query": {"match": {f"CAPE.payloads.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}}}, + _source=["CAPE.payloads", "info.id"], + )["hits"]["hits"] + ] + """ + if tasks: + for task in tasks: + for id in task.get("_task_ids", []): + # ToDo suricata path - "suricata.files.file_info.path + for category in ("files", "procdump", "CAPE"): + file_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(id), category, task["sha256"]) + if path_exists(file_path): + sample = [file_path] + break + if sample: + break + + if not sample: + # search in temp folder if not found in binaries + stmt = select(Task).join(Sample, Task.sample_id == Sample.id).where(query_filter == sample_hash) + db_sample = self.session.scalars(stmt).all() + + if db_sample is not None: + """ + samples = [_f for _f in [tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample] if _f] + # hash validation and if exist + samples = [file_path for file_path in samples if path_exists(file_path)] + for path in samples: + with open(path, "rb") as f: + if sample_hash == hashlib_sizes[len(sample_hash)](f.read()).hexdigest(): + sample = [path] + break + """ + # Use a generator expression for memory efficiency + target_paths = (tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample) + + # Filter for paths that exist + existing_paths = (p for p in target_paths if p and path_exists(p)) + # ToDo review if we really 
want/need this + for path in existing_paths: + if sample_hash == self._hash_file_in_chunks(path, hashlib_sizes[len(sample_hash)]): + sample = [path] + break + return sample + + def count_samples(self) -> int: + """Counts the amount of samples in the database.""" + stmt = select(func.count(Sample.id)) + return self.session.scalar(stmt) + + def get_source_url(self, sample_id: int = None) -> Optional[str]: + """Retrieves the source URL for a given sample ID.""" + if not sample_id: + return None + + try: + stmt = select(Sample.source_url).where(Sample.id == int(sample_id)) + return self.session.scalar(stmt) + except (TypeError, ValueError): + # Handle cases where sample_id is not a valid integer. + return None + + def get_parent_sample_from_task(self, task_id: int) -> Optional[Sample]: + """Finds the Parent Sample using the ID of the child's Task.""" + + # This query joins the Sample table (as the parent) to the + # association object and filters by the task_id. + stmt = ( + select(Sample) + .join(SampleAssociation, Sample.id == SampleAssociation.parent_id) + .where(SampleAssociation.task_id == task_id) + ) + return self.session.scalar(stmt) diff --git a/lib/cuckoo/core/data/task.py b/lib/cuckoo/core/data/task.py new file mode 100644 index 00000000000..d04d2572ca5 --- /dev/null +++ b/lib/cuckoo/core/data/task.py @@ -0,0 +1,178 @@ +import json +from datetime import datetime +from typing import List, Optional, TYPE_CHECKING +from lib.cuckoo.common.exceptions import CuckooDependencyError +if TYPE_CHECKING: + from lib.cuckoo.core.data.samples import Sample, SampleAssociation + from lib.cuckoo.core.data.guests import Guest + from lib.cuckoo.core.data.db_common import Tag, Error + +from .db_common import Base, _utcnow_naive, tasks_tags +try: + from sqlalchemy.orm import Mapped, mapped_column, relationship + from sqlalchemy import ( + Boolean, + DateTime, + Enum, + ForeignKey, + Index, + Integer, + String, + Text, + ) +except ImportError: # pragma: no cover + raise 
CuckooDependencyError("Unable to import sqlalchemy (install with `poetry install`)") + +TASK_BANNED = "banned" +TASK_PENDING = "pending" +TASK_RUNNING = "running" +TASK_DISTRIBUTED = "distributed" +TASK_COMPLETED = "completed" +TASK_RECOVERED = "recovered" +TASK_REPORTED = "reported" +TASK_FAILED_ANALYSIS = "failed_analysis" +TASK_FAILED_PROCESSING = "failed_processing" +TASK_FAILED_REPORTING = "failed_reporting" +TASK_DISTRIBUTED_COMPLETED = "distributed_completed" + + +ALL_DB_STATUSES = ( + TASK_BANNED, + TASK_PENDING, + TASK_RUNNING, + TASK_DISTRIBUTED, + TASK_COMPLETED, + TASK_RECOVERED, + TASK_REPORTED, + TASK_FAILED_ANALYSIS, + TASK_FAILED_PROCESSING, + TASK_FAILED_REPORTING, + TASK_DISTRIBUTED_COMPLETED, +) + + +class Task(Base): + """Analysis task queue.""" + + __tablename__ = "tasks" + + id: Mapped[int] = mapped_column(Integer(), primary_key=True) + target: Mapped[str] = mapped_column(Text(), nullable=False) + category: Mapped[str] = mapped_column(String(255), nullable=False) + cape: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) + timeout: Mapped[int] = mapped_column(Integer(), server_default="0", nullable=False) + priority: Mapped[int] = mapped_column(Integer(), server_default="1", nullable=False) + custom: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + machine: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + package: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + route: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, default=False) + # Task tags + tags_tasks: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) + # Virtual machine tags + tags: Mapped[List["Tag"]] = relationship(secondary=tasks_tags, back_populates="tasks", passive_deletes=True) + options: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) + platform: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + memory: Mapped[bool] = mapped_column(Boolean, 
nullable=False, default=False) + enforce_timeout: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + clock: Mapped[datetime] = mapped_column( + DateTime(timezone=False), + default=_utcnow_naive, + nullable=False, + ) + added_on: Mapped[datetime] = mapped_column( + DateTime(timezone=False), + default=_utcnow_naive, + nullable=False, + ) + started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + completed_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + status: Mapped[str] = mapped_column( + Enum( + TASK_BANNED, + TASK_PENDING, + TASK_RUNNING, + TASK_COMPLETED, + TASK_DISTRIBUTED, + TASK_REPORTED, + TASK_RECOVERED, + TASK_FAILED_ANALYSIS, + TASK_FAILED_PROCESSING, + TASK_FAILED_REPORTING, + name="status_type", + ), + server_default=TASK_PENDING, + nullable=False, + ) + + # Statistics data to identify broken Cuckoos servers or VMs + # Also for doing profiling to improve speed + dropped_files: Mapped[Optional[int]] = mapped_column(nullable=True) + running_processes: Mapped[Optional[int]] = mapped_column(nullable=True) + api_calls: Mapped[Optional[int]] = mapped_column(nullable=True) + domains: Mapped[Optional[int]] = mapped_column(nullable=True) + signatures_total: Mapped[Optional[int]] = mapped_column(nullable=True) + signatures_alert: Mapped[Optional[int]] = mapped_column(nullable=True) + files_written: Mapped[Optional[int]] = mapped_column(nullable=True) + registry_keys_modified: Mapped[Optional[int]] = mapped_column(nullable=True) + crash_issues: Mapped[Optional[int]] = mapped_column(nullable=True) + anti_issues: Mapped[Optional[int]] = mapped_column(nullable=True) + analysis_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + analysis_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + processing_started_on: Mapped[Optional[datetime]] = 
mapped_column(DateTime(timezone=False), nullable=True) + processing_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + signatures_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + signatures_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + reporting_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + reporting_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) + timedout: Mapped[bool] = mapped_column(nullable=False, default=False) + + sample_id: Mapped[Optional[int]] = mapped_column(ForeignKey("samples.id"), nullable=True) + sample: Mapped["Sample"] = relationship(back_populates="tasks") # , lazy="subquery" + machine_id: Mapped[Optional[int]] = mapped_column(nullable=True) + guest: Mapped["Guest"] = relationship( + back_populates="task", uselist=False, cascade="all, delete-orphan" # This is crucial for a one-to-one relationship + ) + errors: Mapped[List["Error"]] = relationship( + back_populates="task", cascade="all, delete-orphan" # This MUST match the attribute name on the Error model + ) + + tlp: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) + user_id: Mapped[Optional[int]] = mapped_column(nullable=True) + + # The Task is linked to one specific parent/child association event + association: Mapped[Optional["SampleAssociation"]] = relationship(back_populates="task", cascade="all, delete-orphan") + + __table_args__ = ( + Index("category_index", "category"), + Index("status_index", "status"), + Index("added_on_index", "added_on"), + Index("completed_on_index", "completed_on"), + ) + + def to_dict(self): + """Converts object to dict. 
+ @return: dict + """ + d = {} + for column in self.__table__.columns: + value = getattr(self, column.name) + if isinstance(value, datetime): + d[column.name] = value.strftime("%Y-%m-%d %H:%M:%S") + else: + d[column.name] = value + + # Tags are a relation so no column to iterate. + d["tags"] = [tag.name for tag in self.tags] + return d + + def to_json(self): + """Converts object to JSON. + @return: JSON data + """ + return json.dumps(self.to_dict()) + + def __init__(self, target=None): + self.target = target + + def __repr__(self): + return f"" diff --git a/lib/cuckoo/core/data/tasking.py b/lib/cuckoo/core/data/tasking.py new file mode 100644 index 00000000000..e3716ba4dd6 --- /dev/null +++ b/lib/cuckoo/core/data/tasking.py @@ -0,0 +1,1367 @@ +from .db_common import _utcnow_naive +import logging +from typing import List, Optional, Tuple, Dict +from datetime import datetime, timedelta, timezone + +from lib.cuckoo.common.config import Config +from lib.cuckoo.common.integrations.parse_pe import PortableExecutable +from lib.cuckoo.common.objects import PCAP, URL, File, Static +from lib.cuckoo.common.exceptions import CuckooDependencyError +from lib.cuckoo.common.utils import bytes2str, get_options +from lib.cuckoo.common.demux import demux_sample +from lib.cuckoo.common.cape_utils import static_config_lookup, static_extraction +from lib.cuckoo.common.path_utils import path_delete, path_exists +from .samples import Sample, SampleAssociation +from .db_common import Tag, Error +from .task import (Task, TASK_PENDING, TASK_RUNNING, TASK_DISTRIBUTED, + TASK_COMPLETED, TASK_RECOVERED, TASK_REPORTED, + TASK_FAILED_PROCESSING, TASK_DISTRIBUTED_COMPLETED, + TASK_FAILED_REPORTING, TASK_BANNED + ) + +# Sflock does a good filetype recon +from sflock.abstracts import File as SflockFile +from sflock.ident import identify as sflock_identify + +try: + from sqlalchemy.exc import SQLAlchemyError + from sqlalchemy import ( + delete, + func, + not_, + select, + update, + ) + from 
sqlalchemy.orm import joinedload, subqueryload +except ImportError: # pragma: no cover + raise CuckooDependencyError("Unable to import sqlalchemy (install with `poetry install`)") + + +log = logging.getLogger(__name__) +conf = Config("cuckoo") +distconf = Config("distributed") +web_conf = Config("web") + +LINUX_STATIC = web_conf.linux.static_only +DYNAMIC_ARCH_DETERMINATION = web_conf.general.dynamic_arch_determination + +sandbox_packages = ( + "access", + "archive", + "nsis", + "cpl", + "reg", + "regsvr", + "dll", + "exe", + "pdf", + "pub", + "doc", + "xls", + "ppt", + "jar", + "zip", + "rar", + "swf", + "python", + "msi", + "msix", + "ps1", + "msg", + "nodejs", + "eml", + "js", + "html", + "hta", + "xps", + "wsf", + "mht", + "doc", + "vbs", + "lnk", + "chm", + "hwp", + "inp", + "vbs", + "js", + "vbejse", + "msbuild", + "sct", + "xslt", + "shellcode", + "shellcode_x64", + "generic", + "iso", + "vhd", + "udf", + "one", + "inf", +) + +class TasksMixIn: + def add( + self, + obj, + *, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + parent_sample=None, + tlp=None, + static=False, + source_url=False, + route=None, + cape=False, + tags_tasks=False, + user_id=0, + ): + """Add a task to database. + @param obj: object to add (File or URL). + @param timeout: selected timeout. + @param options: analysis options. + @param priority: analysis priority. + @param custom: custom options. + @param machine: selected machine. + @param platform: platform. + @param tags: optional tags that must be set for machine selection + @param memory: toggle full memory dump. + @param enforce_timeout: toggle full timeout execution. 
+ @param clock: virtual machine clock time + @param parent_id: parent task id + @param parent_sample: original sample in case of archive + @param static: try static extraction first + @param tlp: TLP sharing designation + @param source_url: url from where it was downloaded + @param route: Routing route + @param cape: CAPE options + @param tags_tasks: Task tags so users can tag their jobs + @param user_id: Link task to user if auth enabled + @return: cursor or None. + """ + # Convert empty strings and None values to a valid int + + if isinstance(obj, (File, PCAP, Static)): + fileobj = File(obj.file_path) + file_type = fileobj.get_type() + file_md5 = fileobj.get_md5() + # check if hash is known already + # ToDo consider migrate to _get_or_create? + sample = self.session.scalar(select(Sample).where(Sample.md5 == file_md5)) + if not sample: + try: + with self.session.begin_nested(): + sample = Sample( + md5=file_md5, + crc32=fileobj.get_crc32(), + sha1=fileobj.get_sha1(), + sha256=fileobj.get_sha256(), + sha512=fileobj.get_sha512(), + file_size=fileobj.get_size(), + file_type=file_type, + ssdeep=fileobj.get_ssdeep(), + source_url=source_url, + ) + self.session.add(sample) + except Exception as e: + log.exception(e) + + if DYNAMIC_ARCH_DETERMINATION: + # Assign architecture to task to fetch correct VM type + + # This isn't 100% fool proof + _tags = tags.split(",") if isinstance(tags, str) else [] + arch_tag = fileobj.predict_arch() + if package.endswith("_x64"): + _tags.append("x64") + elif arch_tag: + _tags.append(arch_tag) + tags = ",".join(set(_tags)) + task = Task(obj.file_path) + task.sample_id = sample.id + + if isinstance(obj, (PCAP, Static)): + # since no VM will operate on this PCAP + task.started_on = _utcnow_naive() + + elif isinstance(obj, URL): + task = Task(obj.url) + _tags = tags.split(",") if isinstance(tags, str) else [] + _tags.append("x64") + _tags.append("x86") + tags = ",".join(set(_tags)) + + else: + return None + + task.category = 
obj.__class__.__name__.lower() + task.timeout = timeout + task.package = package + task.options = options + task.priority = priority + task.custom = custom + task.machine = machine + task.platform = platform + task.memory = bool(memory) + task.enforce_timeout = enforce_timeout + task.tlp = tlp + task.route = route + task.cape = cape + task.tags_tasks = tags_tasks + # Deal with tags format (i.e., foo,bar,baz) + if tags: + for tag in tags.split(","): + tag_name = tag.strip() + if tag_name and tag_name not in [tag.name for tag in task.tags]: + # "Task" object is being merged into a Session along the backref cascade path for relationship "Tag.tasks"; in SQLAlchemy 2.0, this reverse cascade will not take place. + # Set cascade_backrefs to False in either the relationship() or backref() function for the 2.0 behavior; or to set globally for the whole Session, set the future=True flag + # (Background on this error at: https://sqlalche.me/e/14/s9r1) (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) + task.tags.append(self._get_or_create(Tag, name=tag_name)) + + if clock: + if isinstance(clock, str): + try: + task.clock = datetime.strptime(clock, "%m-%d-%Y %H:%M:%S") + except ValueError: + log.warning("The date you specified has an invalid format, using current timestamp") + task.clock = datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None) + + else: + task.clock = clock + else: + task.clock = datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None) + + task.user_id = user_id + + if parent_sample: + association = SampleAssociation( + parent=parent_sample, + child=sample, + task=task, + ) + self.session.add(association) + + # Use a nested transaction so that we can return an ID. 
+ with self.session.begin_nested(): + self.session.add(task) + + return task.id + + def add_path( + self, + file_path, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + tlp=None, + static=False, + source_url=False, + route=None, + cape=False, + tags_tasks=False, + user_id=0, + parent_sample = None, + ): + """Add a task to database from file path. + @param file_path: sample path. + @param timeout: selected timeout. + @param options: analysis options. + @param priority: analysis priority. + @param custom: custom options. + @param machine: selected machine. + @param platform: platform. + @param tags: Tags required in machine selection + @param memory: toggle full memory dump. + @param enforce_timeout: toggle full timeout execution. + @param clock: virtual machine clock time + @param parent_id: parent analysis id + @param parent_sample: sample object if archive + @param static: try static extraction first + @param tlp: TLP sharing designation + @param route: Routing route + @param cape: CAPE options + @param tags_tasks: Task tags so users can tag their jobs + @user_id: Allow link task to user if auth enabled + @parent_sample: Sample object, if archive + @return: cursor or None. 
+ """ + if not file_path or not path_exists(file_path): + log.warning("File does not exist: %s", file_path) + return None + + # Convert empty strings and None values to a valid int + if not timeout: + timeout = 0 + if not priority: + priority = 1 + if file_path.endswith((".htm", ".html")) and not package: + package = web_conf.url_analysis.package + + return self.add( + File(file_path), + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + tlp=tlp, + source_url=source_url, + route=route, + cape=cape, + tags_tasks=tags_tasks, + user_id=user_id, + parent_sample=parent_sample, + ) + + def _identify_aux_func(self, file: bytes, package: str, check_shellcode: bool = True) -> tuple: + # before demux we need to check as msix has zip mime and we don't want it to be extracted: + tmp_package = False + if not package: + f = SflockFile.from_path(file) + try: + tmp_package = sflock_identify(f, check_shellcode=check_shellcode) + except Exception as e: + log.error("Failed to sflock_ident due to %s", str(e)) + tmp_package = "generic" + + if tmp_package and tmp_package in sandbox_packages: + # This probably should be way much bigger list of formats + if tmp_package in ("iso", "udf", "vhd"): + package = "archive" + elif tmp_package in ("zip", "rar"): + package = "" + elif tmp_package in ("html",): + package = web_conf.url_analysis.package + else: + package = tmp_package + + return package, tmp_package + + def demux_sample_and_add_to_db( + self, + file_path, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + tlp=None, + static=False, + source_url=False, + only_extraction=False, + tags_tasks=False, + route=None, + cape=False, + user_id=0, + category=None, + ): + """ + Handles ZIP file submissions, submitting each 
extracted file to the database + Returns a list of added task IDs + """ + task_id = False + task_ids = [] + config = {} + details = {} + + if not isinstance(file_path, bytes): + file_path = file_path.encode() + + ( + static, + priority, + machine, + platform, + custom, + memory, + clock, + unique, + referrer, + tlp, + tags_tasks, + route, + cape, + options, + timeout, + enforce_timeout, + package, + tags, + category, + ) = self.recon( + file_path, + options, + timeout=timeout, + enforce_timeout=enforce_timeout, + package=package, + tags=tags, + static=static, + priority=priority, + machine=machine, + platform=platform, + custom=custom, + memory=memory, + clock=clock, + tlp=tlp, + tags_tasks=tags_tasks, + route=route, + cape=cape, + category=category, + ) + + if category == "static": + # force change of category + task_ids += self.add_static( + file_path=file_path, + priority=priority, + tlp=tlp, + user_id=user_id, + options=options, + package=package, + ) + return task_ids, details + + check_shellcode = True + if options and "check_shellcode=0" in options: + check_shellcode = False + + if not package: + if "file=" in options: + # set zip as package when specifying file= in options + package = "zip" + else: + # Checking original file as some filetypes doesn't require demux + package, _ = self._identify_aux_func(file_path, package, check_shellcode=check_shellcode) + + parent_sample = None + # extract files from the (potential) archive + extracted_files, demux_error_msgs = demux_sample(file_path, package, options, platform=platform) + # check if len is 1 and the same file, if diff register file, and set parent + if extracted_files and not any(file_path == path for path, _ in extracted_files): + parent_sample = self.register_sample(File(file_path), source_url=source_url) + if conf.cuckoo.delete_archive: + path_delete(file_path.decode()) + + # create tasks for each file in the archive + for file, platform in extracted_files: + if not path_exists(file): + 
log.error("Extracted file doesn't exist: %s", file) + continue + # ToDo we lose package here and send APKs to windows + if platform in ("linux", "darwin") and LINUX_STATIC: + task_ids += self.add_static( + file_path=file_path, + priority=priority, + tlp=tlp, + user_id=user_id, + options=options, + package=package, + parent_sample=parent_sample, + ) + continue + if static: + # On huge loads this just become a bottleneck + config = False + if web_conf.general.check_config_exists: + config = static_config_lookup(file) + if config: + task_ids.append(config["id"]) + else: + config = static_extraction(file) + if config or only_extraction: + task_ids += self.add_static( + file_path=file, priority=priority, tlp=tlp, user_id=user_id, options=options, parent_sample=parent_sample, + ) + + if not config and not only_extraction: + if not package: + package, tmp_package = self._identify_aux_func(file, "", check_shellcode=check_shellcode) + + if not tmp_package: + log.info("Do sandbox packages need an update? Sflock identifies as: %s - %s", tmp_package, file) + + if package == "dll" and "function" not in options: + with PortableExecutable(file.decode()) as pe: + dll_export = pe.choose_dll_export() + if dll_export == "DllRegisterServer": + package = "regsvr" + elif dll_export == "xlAutoOpen": + package = "xls" + elif dll_export: + if options: + options += f",function={dll_export}" + else: + options = f"function={dll_export}" + + # ToDo better solution? 
- Distributed mode here: + # Main node is storage so try to extract before submit to vm isn't propagated to workers + if static and not config and distconf.distributed.enabled: + if options: + options += ",dist_extract=1" + else: + options = "dist_extract=1" + + task_id = self.add_path( + file_path=file.decode(), + timeout=timeout, + priority=priority, + options=options, + package=package, + machine=machine, + platform=platform, + memory=memory, + custom=custom, + enforce_timeout=enforce_timeout, + tags=tags, + clock=clock, + tlp=tlp, + source_url=source_url, + route=route, + tags_tasks=tags_tasks, + cape=cape, + user_id=user_id, + parent_sample=parent_sample, + ) + package = None + if task_id: + task_ids.append(task_id) + + if config and isinstance(config, dict): + details = {"config": config.get("cape_config", {})} + if demux_error_msgs: + details["errors"] = demux_error_msgs + # this is aim to return custom data, think of this as kwargs + return task_ids, details + + def add_pcap( + self, + file_path, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + tlp=None, + user_id=0, + ): + return self.add( + PCAP(file_path.decode()), + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + tlp=tlp, + user_id=user_id, + ) + + def add_static( + self, + file_path, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + tlp=None, + static=True, + user_id=0, + parent_sample=None, + ): + extracted_files, demux_error_msgs = demux_sample(file_path, package, options) + + # check if len is 1 and the same file, if diff register file, and set parent + if not isinstance(file_path, bytes): + file_path = 
file_path.encode() + + # ToDo callback maybe or inside of the self.add + if extracted_files and ((file_path, platform) not in extracted_files and (file_path, "") not in extracted_files): + if not parent_sample: + parent_sample = self.register_sample(File(file_path)) + if conf.cuckoo.delete_archive: + # ToDo keep as info for now + log.info("Deleting archive: %s. conf.cuckoo.delete_archive is enabled. %s", file_path, str(extracted_files)) + path_delete(file_path) + + task_ids = [] + # create tasks for each file in the archive + for file, platform in extracted_files: + task_id = self.add( + Static(file.decode()), + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + tlp=tlp, + static=static, + parent_sample=parent_sample, + user_id=user_id, + ) + if task_id: + task_ids.append(task_id) + + return task_ids + + def add_url( + self, + url, + timeout=0, + package="", + options="", + priority=1, + custom="", + machine="", + platform="", + tags=None, + memory=False, + enforce_timeout=False, + clock=None, + tlp=None, + route=None, + cape=False, + tags_tasks=False, + user_id=0, + ): + """Add a task to database from url. + @param url: url. + @param timeout: selected timeout. + @param options: analysis options. + @param priority: analysis priority. + @param custom: custom options. + @param machine: selected machine. + @param platform: platform. + @param tags: tags for machine selection + @param memory: toggle full memory dump. + @param enforce_timeout: toggle full timeout execution. + @param clock: virtual machine clock time + @param tlp: TLP sharing designation + @param route: Routing route + @param cape: CAPE options + @param tags_tasks: Task tags so users can tag their jobs + @param user_id: Link task to user + @return: cursor or None. 
+ """ + + # Convert empty strings and None values to a valid int + if not timeout: + timeout = 0 + if not priority: + priority = 1 + if not package: + package = web_conf.url_analysis.package + + return self.add( + URL(url), + timeout=timeout, + package=package, + options=options, + priority=priority, + custom=custom, + machine=machine, + platform=platform, + tags=tags, + memory=memory, + enforce_timeout=enforce_timeout, + clock=clock, + tlp=tlp, + route=route, + cape=cape, + tags_tasks=tags_tasks, + user_id=user_id, + ) + + def set_vnc_port(self, task_id: int, port: int): + stmt = select(Task).where(Task.id == task_id) + task = self.session.scalar(stmt) + + if task is None: + log.debug("Database error setting VPN port: For task %s", task_id) + return + + # This logic remains the same + if task.options: + task.options += f",vnc_port={port}" + else: + task.options = f"vnc_port={port}" + + def _task_arch_tags_helper(self, task: Task): + # Are there available machines that match up with a task? 
+ task_archs = [tag.name for tag in task.tags if tag.name in ("x86", "x64")] + task_tags = [tag.name for tag in task.tags if tag.name not in task_archs] + + return task_archs, task_tags + + def update_clock(self, task_id): + row = self.session.get(Task, task_id) + + if not row: + return + # datetime.fromtimestamp(0, tz=timezone.utc) + if row.clock == datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None): + if row.category == "file": + # datetime.now(timezone.utc) + row.clock = _utcnow_naive() + timedelta(days=self.cfg.cuckoo.daydelta) + else: + # datetime.now(timezone.utc) + row.clock = _utcnow_naive() + return row.clock + + def set_task_status(self, task: Task, status) -> Task: + if status != TASK_DISTRIBUTED_COMPLETED: + task.status = status + + if status in (TASK_RUNNING, TASK_DISTRIBUTED): + task.started_on = _utcnow_naive() + elif status in (TASK_COMPLETED, TASK_DISTRIBUTED_COMPLETED): + task.completed_on = _utcnow_naive() + elif status == TASK_REPORTED: + task.reporting_finished_on = _utcnow_naive() + + self.session.add(task) + return task + + def set_status(self, task_id: int, status) -> Optional[Task]: + """Set task status. + @param task_id: task identifier + @param status: status string + @return: operation status + """ + log.info("setstat task %d status %s",task_id,status) + task = self.session.get(Task, task_id) + + if not task: + return None + + return self.set_task_status(task, status) + + def fetch_task(self, categories: list = None): + """Fetches a task waiting to be processed and locks it for running. + @return: None or task + """ + stmt = ( + select(Task) + .where(Task.status == TASK_PENDING) + .where(not_(Task.options.contains("node="))) + .order_by(Task.priority.desc(), Task.added_on) + ) + + if categories: + stmt = stmt.where(Task.category.in_(categories)) + + # 2. 
Execute the statement and get the first result object + row = self.session.scalars(stmt).first() + + if not row: + return None + + # This business logic remains the same + self.set_status(task_id=row.id, status=TASK_RUNNING) + + return row + + def add_error(self, message, task_id): + """Add an error related to a task.""" + # This function already uses modern, correct SQLAlchemy 2.0 patterns. + # No changes are needed. + error = Error(message=message, task_id=task_id) + # Use a separate session so that, regardless of the state of a transaction going on + # outside of this function, the error will always be committed to the database. + with self.session.session_factory() as sess, sess.begin(): + sess.add(error) + + def reschedule(self, task_id): + """Reschedule a task. + @param task_id: ID of the task to reschedule. + @return: ID of the newly created task. + """ + task = self.view_task(task_id) + + if not task: + return None + + if task.category == "file": + add = self.add_path + elif task.category == "url": + add = self.add_url + elif task.category == "pcap": + add = self.add_pcap + elif task.category == "static": + add = self.add_static + + # Change status to recovered. + self.session.get(Task, task_id).status = TASK_RECOVERED + + # Normalize tags. + if task.tags: + tags = ",".join(tag.name for tag in task.tags) + else: + tags = task.tags + + def _ensure_valid_target(task): + if task.category == "url": + # URL tasks always have valid targets, return it as-is. + return task.target + + # All other task types have a "target" pointing to a temp location, + # so get a stable path "target" based on the sample hash. 
+ paths = self.sample_path_by_hash(task.sample.sha256, task_id) + paths = [file_path for file_path in paths if path_exists(file_path)] + if not paths: + return None + + if task.category == "pcap": + # PCAP task paths are represented as bytes + return paths[0].encode() + return paths[0] + + task_target = _ensure_valid_target(task) + if not task_target: + log.warning("Unable to find valid target for task: %s", task_id) + return + + new_task_id = None + if task.category in ("file", "url"): + new_task_id = add( + task_target, + task.timeout, + task.package, + task.options, + task.priority, + task.custom, + task.machine, + task.platform, + tags, + task.memory, + task.enforce_timeout, + task.clock, + tlp=task.tlp, + route=task.route, + ) + elif task.category in ("pcap", "static"): + new_task_id = add( + task_target, + task.timeout, + task.package, + task.options, + task.priority, + task.custom, + task.machine, + task.platform, + tags, + task.memory, + task.enforce_timeout, + task.clock, + tlp=task.tlp, + ) + + self.session.get(Task, task_id).custom = f"Recovery_{new_task_id}" + + return new_task_id + + def count_matching_tasks(self, category=None, status=None, not_status=None): + """Retrieve list of task. + @param category: filter by category + @param status: filter by task status + @param not_status: exclude this task status from filter + @return: number of tasks. + """ + stmt = select(func.count(Task.id)) + + if status: + stmt = stmt.where(Task.status == status) + if not_status: + stmt = stmt.where(Task.status != not_status) + if category: + stmt = stmt.where(Task.category == category) + + # 2. Execute the statement and return the single integer result. 
+ return self.session.scalar(stmt) + + def list_tasks( + self, + limit=None, + details=False, + category=None, + offset=None, + status=None, + sample_id=None, + not_status=None, + completed_after=None, + order_by=None, + added_before=None, + id_before=None, + id_after=None, + options_like=False, + options_not_like=False, + tags_tasks_like=False, + tags_tasks_not_like=False, + task_ids=False, + include_hashes=False, + user_id=None, + for_update=False, + ) -> List[Task]: + """Retrieve list of task. + @param limit: specify a limit of entries. + @param details: if details about must be included + @param category: filter by category + @param offset: list offset + @param status: filter by task status + @param sample_id: filter tasks for a sample + @param not_status: exclude this task status from filter + @param completed_after: only list tasks completed after this timestamp + @param order_by: definition which field to sort by + @param added_before: tasks added before a specific timestamp + @param id_before: filter by tasks which is less than this value + @param id_after filter by tasks which is greater than this value + @param options_like: filter tasks by specific option inside of the options + @param options_not_like: filter tasks by specific option not inside of the options + @param tags_tasks_like: filter tasks by specific tag + @param tags_tasks_not_like: filter tasks by specific tag not inside of task tags + @param task_ids: list of task_id + @param include_hashes: return task+samples details + @param user_id: list of tasks submitted by user X + @param for_update: If True, use "SELECT FOR UPDATE" in order to create a row-level lock on the selected tasks. + @return: list of tasks. 
+ """ + tasks: List[Task] = [] + stmt = select(Task).options(joinedload(Task.guest), subqueryload(Task.errors), subqueryload(Task.tags)) + if include_hashes: + stmt = stmt.options(joinedload(Task.sample)) + if status: + if "|" in status: + stmt = stmt.where(Task.status.in_(status.split("|"))) + else: + stmt = stmt.where(Task.status == status) + if not_status: + stmt = stmt.where(Task.status != not_status) + if category: + stmt = stmt.where(Task.category.in_([category] if isinstance(category, str) else category)) + if sample_id is not None: + stmt = stmt.where(Task.sample_id == sample_id) + if id_before is not None: + stmt = stmt.where(Task.id < id_before) + if id_after is not None: + stmt = stmt.where(Task.id > id_after) + if completed_after: + stmt = stmt.where(Task.completed_on > completed_after) + if added_before: + stmt = stmt.where(Task.added_on < added_before) + if options_like: + stmt = stmt.where(Task.options.like(f"%{options_like.replace('*', '%')}%")) + if options_not_like: + stmt = stmt.where(Task.options.notlike(f"%{options_not_like.replace('*', '%')}%")) + if tags_tasks_like: + stmt = stmt.where(Task.tags_tasks.like(f"%{tags_tasks_like}%")) + if tags_tasks_not_like: + stmt = stmt.where(Task.tags_tasks.notlike(f"%{tags_tasks_not_like}%")) + if task_ids: + stmt = stmt.where(Task.id.in_(task_ids)) + if user_id is not None: + stmt = stmt.where(Task.user_id == user_id) + + # 3. Chaining for ordering, pagination, and locking remains the same + if order_by is not None and isinstance(order_by, tuple): + stmt = stmt.order_by(*order_by) + elif order_by is not None: + stmt = stmt.order_by(order_by) + else: + stmt = stmt.order_by(Task.added_on.desc()) + + stmt = stmt.limit(limit).offset(offset) + if for_update: + stmt = stmt.with_for_update(of=Task) + + tasks = self.session.scalars(stmt).all() + return tasks + + def delete_task(self, task_id): + """Delete information on a task. + @param task_id: ID of the task to query. + @return: operation status. 
+ """ + task = self.session.get(Task, task_id) + if task is None: + return False + self.session.delete(task) + # ToDo missed commits everywhere, check if autocommit is possible + return True + + def delete_tasks( + self, + category=None, + status=None, + sample_id=None, + not_status=None, + completed_after=None, + added_before=None, + id_before=None, + id_after=None, + options_like=False, + options_not_like=False, + tags_tasks_like=False, + task_ids=False, + user_id=None, + ): + """Delete tasks based on parameters. If no filters are provided, no tasks will be deleted. + + Args: + category: filter by category + status: filter by task status + sample_id: filter tasks for a sample + not_status: exclude this task status from filter + completed_after: only list tasks completed after this timestamp + added_before: tasks added before a specific timestamp + id_before: filter by tasks which is less than this value + id_after: filter by tasks which is greater than this value + options_like: filter tasks by specific option inside of the options + options_not_like: filter tasks by specific option not inside of the options + tags_tasks_like: filter tasks by specific tag + task_ids: list of task_id + user_id: list of tasks submitted by user X + + Returns: + bool: True if the operation was successful (including no tasks to delete), False otherwise. + """ + delete_stmt = delete(Task) + filters_applied = False + + # 2. 
Chain .where() clauses for all filters + if status: + if "|" in status: + delete_stmt = delete_stmt.where(Task.status.in_(status.split("|"))) + else: + delete_stmt = delete_stmt.where(Task.status == status) + filters_applied = True + if not_status: + delete_stmt = delete_stmt.where(Task.status != not_status) + filters_applied = True + if category: + delete_stmt = delete_stmt.where(Task.category.in_([category] if isinstance(category, str) else category)) + filters_applied = True + if sample_id is not None: + delete_stmt = delete_stmt.where(Task.sample_id == sample_id) + filters_applied = True + if id_before is not None: + delete_stmt = delete_stmt.where(Task.id < id_before) + filters_applied = True + if id_after is not None: + delete_stmt = delete_stmt.where(Task.id > id_after) + filters_applied = True + if completed_after: + delete_stmt = delete_stmt.where(Task.completed_on > completed_after) + filters_applied = True + if added_before: + delete_stmt = delete_stmt.where(Task.added_on < added_before) + filters_applied = True + if options_like: + delete_stmt = delete_stmt.where(Task.options.like(f"%{options_like.replace('*', '%')}%")) + filters_applied = True + if options_not_like: + delete_stmt = delete_stmt.where(Task.options.notlike(f"%{options_not_like.replace('*', '%')}%")) + filters_applied = True + if tags_tasks_like: + delete_stmt = delete_stmt.where(Task.tags_tasks.like(f"%{tags_tasks_like}%")) + filters_applied = True + if task_ids: + delete_stmt = delete_stmt.where(Task.id.in_(task_ids)) + filters_applied = True + if user_id is not None: + delete_stmt = delete_stmt.where(Task.user_id == user_id) + filters_applied = True + + if not filters_applied: + log.warning("No filters provided for delete_tasks. 
No tasks will be deleted.") + return True + + # ToDo Transaction Handling + # The transaction logic (commit/rollback) is kept the same for a direct port, + # but the more idiomatic SQLAlchemy 2.0 approach would be to wrap the execution + # in a with self.session.begin(): block, which handles transactions automatically. + try: + result = self.session.execute(delete_stmt) + log.info("Deleted %d tasks matching the criteria.", result.rowcount) + self.session.commit() + return True + except SQLAlchemyError as e: + log.error("Error deleting tasks: %s", str(e)) + self.session.rollback() + return False + + # ToDo replace with delete_tasks + def clean_timed_out_tasks(self, timeout: int): + """Deletes PENDING tasks that were added more than `timeout` seconds ago.""" + if timeout <= 0: + return + + # Calculate the cutoff time before which tasks are considered timed out. + timeout_threshold = _utcnow_naive() - timedelta(seconds=timeout) + + # Build a single, efficient DELETE statement that filters in the database. + delete_stmt = delete(Task).where(Task.status == TASK_PENDING).where(Task.added_on < timeout_threshold) + + # Execute the bulk delete statement. + # The transaction should be handled by the calling code, + # typically with a `with session.begin():` block. + result = self.session.execute(delete_stmt) + + if result.rowcount > 0: + log.info("Deleted %d timed-out PENDING tasks.", result.rowcount) + + def minmax_tasks(self) -> Tuple[int, int]: + """Finds the minimum start time and maximum completion time for all tasks.""" + # A single query is more efficient than two separate ones. + stmt = select(func.min(Task.started_on), func.max(Task.completed_on)) + min_val, max_val = self.session.execute(stmt).one() + + if min_val and max_val: + # .timestamp() is the modern way to get a unix timestamp. 
+ return int(min_val.replace(tzinfo=timezone.utc).timestamp()), int(max_val.replace(tzinfo=timezone.utc).timestamp()) + + return 0, 0 + + def get_tlp_tasks(self) -> List[int]: + """Retrieves a list of task IDs that have TLP enabled.""" + # Selecting just the ID is more efficient than fetching full objects. + stmt = select(Task.id).where(Task.tlp == "true") + # .scalars() directly yields the values from the single selected column. + return self.session.scalars(stmt).all() + + + + def get_tasks_status_count(self) -> Dict[str, int]: + """Counts tasks, grouped by status.""" + stmt = select(Task.status, func.count(Task.status)).group_by(Task.status) + # .execute() returns rows, which can be directly converted to a dict. + return dict(self.session.execute(stmt).all()) + + def count_tasks(self, status: str = None, mid: int = None) -> int: + """Counts tasks in the database, with optional filters.""" + # Build a `SELECT COUNT(...)` query from the start for efficiency. + stmt = select(func.count(Task.id)) + if mid: + stmt = stmt.where(Task.machine_id == mid) + if status: + stmt = stmt.where(Task.status == status) + + # .scalar() executes the query and returns the single integer result. + return self.session.scalar(stmt) + + def view_task(self, task_id, details=False) -> Optional[Task]: + """Retrieve information on a task. + @param task_id: ID of the task to query. + @return: details on the task. + """ + query = select(Task).where(Task.id == task_id) + if details: + query = query.options( + joinedload(Task.guest), subqueryload(Task.errors), subqueryload(Task.tags), joinedload(Task.sample) + ) + else: + query = query.options(subqueryload(Task.tags), joinedload(Task.sample)) + return self.session.scalar(query) + + # This function is used by the runstatistics community module. + def add_statistics_to_task(self, task_id, details): # pragma: no cover + """add statistic to task + @param task_id: ID of the task to query. + @param: details statistic. + @return true of false. 
+ """ + # ToDo do we really need this? does it need commit? + task = self.session.get(Task, task_id) + if task: + task.dropped_files = details["dropped_files"] + task.running_processes = details["running_processes"] + task.api_calls = details["api_calls"] + task.domains = details["domains"] + task.signatures_total = details["signatures_total"] + task.signatures_alert = details["signatures_alert"] + task.files_written = details["files_written"] + task.registry_keys_modified = details["registry_keys_modified"] + task.crash_issues = details["crash_issues"] + task.anti_issues = details["anti_issues"] + return True + + + def ban_user_tasks(self, user_id: int): + """ + Bans all PENDING tasks submitted by a given user. + @param user_id: user id + """ + + update_stmt = update(Task).where(Task.user_id == user_id, Task.status == TASK_PENDING).values(status=TASK_BANNED) + + # 2. Execute the statement. + # The transaction should be handled by the calling code, + # ToDo e.g., with a `with session.begin():` block. + self.session.execute(update_stmt) + + def tasks_reprocess(self, task_id: int): + """common func for api and views""" + task = self.view_task(task_id) + if not task: + return True, "Task ID does not exist in the database", "" + + if task.status not in { + # task status suitable for reprocessing + # allow reprocessing of tasks already processed (maybe detections changed) + TASK_REPORTED, + # allow reprocessing of tasks that were rescheduled + TASK_RECOVERED, + # allow reprocessing of tasks that previously failed the processing stage + TASK_FAILED_PROCESSING, + # allow reprocessing of tasks that previously failed the reporting stage + TASK_FAILED_REPORTING, + # TASK_COMPLETED, + }: + return True, f"Task ID {task_id} cannot be reprocessed in status {task.status}", task.status + + # Save the old_status, because otherwise, in the call to set_status(), + # sqlalchemy will use the cached Task object that `task` is already a reference + # to and update that in place. 
That would result in `task.status` in this + # function being set to TASK_COMPLETED and we don't want to return that. + old_status = task.status + self.set_status(task_id, TASK_COMPLETED) + return False, "", old_status + + def view_errors(self, task_id: int) -> List[Error]: + """Gets all errors related to a task.""" + stmt = select(Error).where(Error.task_id == task_id) + return self.session.scalars(stmt).all() + + # Submission hooks to manipulate arguments of tasks execution + def recon( + self, + filename, + orig_options, + timeout=0, + enforce_timeout=False, + package="", + tags=None, + static=False, + priority=1, + machine="", + platform="", + custom="", + memory=False, + clock=None, + unique=False, + referrer=None, + tlp=None, + tags_tasks=False, + route=None, + cape=False, + category=None, + ): + # Get file filetype to ensure self extracting archives run longer + if not isinstance(filename, str): + filename = bytes2str(filename) + + lowered_filename = filename.lower() + + # sfx = File(filename).is_sfx() + + if "malware_name" in lowered_filename: + orig_options += "" + # if sfx: + # orig_options += ",timeout=500,enforce_timeout=1,procmemdump=1,procdump=1" + # timeout = 500 + # enforce_timeout = True + + if web_conf.general.yara_recon: + hits = File(filename).get_yara("binaries") + for hit in hits: + cape_name = hit["meta"].get("cape_type", "") + if not cape_name.endswith(("Crypter", "Packer", "Obfuscator", "Loader", "Payload")): + continue + + orig_options_parsed = get_options(orig_options) + parsed_options = get_options(hit["meta"].get("cape_options", "")) + if "tags" in parsed_options: + tags = "," + parsed_options["tags"] if tags else parsed_options["tags"] + del parsed_options["tags"] + # custom packages should be added to lib/cuckoo/core/database.py -> sandbox_packages list + # Do not overwrite user provided package + if not package and "package" in parsed_options: + package = parsed_options["package"] + del parsed_options["package"] + + if "category" in 
parsed_options: + category = parsed_options["category"] + del parsed_options["category"] + + orig_options_parsed.update(parsed_options) + orig_options = ",".join([f"{k}={v}" for k, v in orig_options_parsed.items()]) + + return ( + static, + priority, + machine, + platform, + custom, + memory, + clock, + unique, + referrer, + tlp, + tags_tasks, + route, + cape, + orig_options, + timeout, + enforce_timeout, + package, + tags, + category, + ) diff --git a/lib/cuckoo/core/database.py b/lib/cuckoo/core/database.py index 22d7cb582bc..c88096ab4c3 100644 --- a/lib/cuckoo/core/database.py +++ b/lib/cuckoo/core/database.py @@ -5,72 +5,40 @@ # https://blog.miguelgrinberg.com/post/what-s-new-in-sqlalchemy-2-0 # https://docs.sqlalchemy.org/en/20/changelog/migration_20.html# -import hashlib -import json import logging import os import sys from contextlib import suppress -from datetime import datetime, timedelta, timezone -from typing import Any, List, Optional, Union, Tuple, Dict -import pytz +from typing import Any, Optional - -# Sflock does a good filetype recon -from sflock.abstracts import File as SflockFile -from sflock.ident import identify as sflock_identify - -from lib.cuckoo.common.cape_utils import static_config_lookup, static_extraction from lib.cuckoo.common.colors import red from lib.cuckoo.common.constants import CUCKOO_ROOT -from lib.cuckoo.common.demux import demux_sample from lib.cuckoo.common.exceptions import ( CuckooDatabaseError, CuckooDatabaseInitializationError, CuckooDependencyError, - CuckooOperationalError, - CuckooUnserviceableTaskError, + CuckooOperationalError ) from lib.cuckoo.common.config import Config -from lib.cuckoo.common.integrations.parse_pe import PortableExecutable -from lib.cuckoo.common.objects import PCAP, URL, File, Static -from lib.cuckoo.common.path_utils import path_delete, path_exists -from lib.cuckoo.common.utils import bytes2str, create_folder, get_options +from lib.cuckoo.common.path_utils import path_exists +from 
lib.cuckoo.common.utils import create_folder + +from .data.db_common import Base +from .data.tasking import TasksMixIn +from .data.machines import MachinesMixIn +from .data.samples import SamplesMixIn +from .data.guests import GuestsMixIn +from .data.audits import AuditsMixIn + # ToDo postgresql+psycopg2 in connection try: from sqlalchemy.engine import make_url - from sqlalchemy import ( - Boolean, - BigInteger, - Column, - DateTime, - Enum, - ForeignKey, - Index, - Integer, - String, - Table, - Text, - create_engine, - # event, - func, - not_, - or_, - select, - Select, - delete, - update, - ) - from sqlalchemy.exc import IntegrityError, SQLAlchemyError + from sqlalchemy import String, create_engine, func, select + from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import ( - aliased, - joinedload, - subqueryload, - relationship, scoped_session, sessionmaker, - DeclarativeBase, Mapped, mapped_column, ) @@ -78,70 +46,9 @@ except ImportError: # pragma: no cover raise CuckooDependencyError("Unable to import sqlalchemy (install with `poetry install`)") -cfg = Config("cuckoo") -tz_name = cfg.cuckoo.get("timezone", "utc") - -def _utcnow_naive(): - """Returns the current time in the configured timezone as a naive datetime object.""" - try: - tz = pytz.timezone(tz_name) - except pytz.UnknownTimeZoneError: - tz = timezone.utc - return datetime.now(tz).replace(tzinfo=None) -sandbox_packages = ( - "access", - "archive", - "nsis", - "cpl", - "reg", - "regsvr", - "dll", - "exe", - "pdf", - "pub", - "doc", - "xls", - "ppt", - "jar", - "zip", - "rar", - "swf", - "python", - "msi", - "msix", - "ps1", - "msg", - "nodejs", - "eml", - "js", - "html", - "hta", - "xps", - "wsf", - "mht", - "doc", - "vbs", - "lnk", - "chm", - "hwp", - "inp", - "vbs", - "js", - "vbejse", - "msbuild", - "sct", - "xslt", - "shellcode", - "shellcode_x64", - "generic", - "iso", - "vhd", - "udf", - "one", - "inf", -) +SCHEMA_VERSION = "2b3c4d5e6f7g" log = logging.getLogger(__name__) conf = 
Config("cuckoo") @@ -152,445 +59,16 @@ def _utcnow_naive(): LINUX_STATIC = web_conf.linux.static_only DYNAMIC_ARCH_DETERMINATION = web_conf.general.dynamic_arch_determination -if repconf.mongodb.enabled: - from dev_utils.mongodb import mongo_find if repconf.elasticsearchdb.enabled: from dev_utils.elasticsearchdb import elastic_handler # , get_analysis_index - es = elastic_handler -SCHEMA_VERSION = "2b3c4d5e6f7g" -TASK_BANNED = "banned" -TASK_PENDING = "pending" -TASK_RUNNING = "running" -TASK_DISTRIBUTED = "distributed" -TASK_COMPLETED = "completed" -TASK_RECOVERED = "recovered" -TASK_REPORTED = "reported" -TASK_FAILED_ANALYSIS = "failed_analysis" -TASK_FAILED_PROCESSING = "failed_processing" -TASK_FAILED_REPORTING = "failed_reporting" -TASK_DISTRIBUTED_COMPLETED = "distributed_completed" - -ALL_DB_STATUSES = ( - TASK_BANNED, - TASK_PENDING, - TASK_RUNNING, - TASK_DISTRIBUTED, - TASK_COMPLETED, - TASK_RECOVERED, - TASK_REPORTED, - TASK_FAILED_ANALYSIS, - TASK_FAILED_PROCESSING, - TASK_FAILED_REPORTING, - TASK_DISTRIBUTED_COMPLETED, -) - -MACHINE_RUNNING = "running" - -# ToDo verify variable declaration in Mapped - - -class Base(DeclarativeBase): - pass - - -# Secondary table used in association Machine - Tag. -machines_tags = Table( - "machines_tags", - Base.metadata, - Column("machine_id", Integer, ForeignKey("machines.id")), - Column("tag_id", Integer, ForeignKey("tags.id")), -) - -# Secondary table used in association Task - Tag. 
-tasks_tags = Table( - "tasks_tags", - Base.metadata, - Column("task_id", Integer, ForeignKey("tasks.id", ondelete="cascade")), - Column("tag_id", Integer, ForeignKey("tags.id", ondelete="cascade")), -) - def get_count(q, property): count_q = q.statement.with_only_columns(func.count(property)).order_by(None) count = q.session.execute(count_q).scalar() return count -class SampleAssociation(Base): - __tablename__ = "sample_associations" - - # Each column is part of a composite primary key - parent_id: Mapped[int] = mapped_column(ForeignKey("samples.id"), primary_key=True) - child_id: Mapped[int] = mapped_column(ForeignKey("samples.id"), primary_key=True) - - # This is the crucial column that links to the specific child's task - task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="CASCADE"), primary_key=True) - - # Relationships from the association object itself - parent: Mapped["Sample"] = relationship(foreign_keys=[parent_id], back_populates="child_links") - child: Mapped["Sample"] = relationship(foreign_keys=[child_id], back_populates="parent_links") - task: Mapped["Task"] = relationship(back_populates="association") - -class Machine(Base): - """Configured virtual machines to be used as guests.""" - - __tablename__ = "machines" - - id: Mapped[int] = mapped_column(Integer(), primary_key=True) - name: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) - label: Mapped[str] = mapped_column(String(255), nullable=False, unique=True) - arch: Mapped[str] = mapped_column(String(255), nullable=False) - ip: Mapped[str] = mapped_column(String(255), nullable=False) - platform: Mapped[str] = mapped_column(String(255), nullable=False) - tags: Mapped[List["Tag"]] = relationship(secondary=machines_tags, back_populates="machines") - interface: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - snapshot: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - locked: Mapped[bool] = mapped_column(Boolean(), 
nullable=False, default=False) - locked_changed_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - status: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - status_changed_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - resultserver_ip: Mapped[str] = mapped_column(String(255), nullable=False) - resultserver_port: Mapped[str] = mapped_column(String(255), nullable=False) - reserved: Mapped[bool] = mapped_column(Boolean(), nullable=False, default=False) - - def __repr__(self): - return f"" - - def to_dict(self): - """Converts object to dict. - @return: dict - """ - d = {} - for column in self.__table__.columns: - value = getattr(self, column.name) - if isinstance(value, datetime): - d[column.name] = value.strftime("%Y-%m-%d %H:%M:%S") - else: - d[column.name] = value - - # Tags are a relation so no column to iterate. - d["tags"] = [tag.name for tag in self.tags] - return d - - def to_json(self): - """Converts object to JSON. 
- @return: JSON data - """ - return json.dumps(self.to_dict()) - - def __init__(self, name, label, arch, ip, platform, interface, snapshot, resultserver_ip, resultserver_port, reserved): - self.name = name - self.label = label - self.arch = arch - self.ip = ip - self.platform = platform - self.interface = interface - self.snapshot = snapshot - self.resultserver_ip = resultserver_ip - self.resultserver_port = resultserver_port - self.reserved = reserved - - -class Tag(Base): - """Tag describing anything you want.""" - - __tablename__ = "tags" - - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False, unique=True) - machines: Mapped[List["Machine"]] = relationship(secondary=machines_tags, back_populates="tags") - tasks: Mapped[List["Task"]] = relationship(secondary=tasks_tags, back_populates="tags") - - def __repr__(self): - return f"" - - def __init__(self, name): - self.name = name - - -class Guest(Base): - """Tracks guest run.""" - - __tablename__ = "guests" - - id: Mapped[int] = mapped_column(primary_key=True) - status: Mapped[str] = mapped_column(nullable=False) - name: Mapped[str] = mapped_column(nullable=False) - label: Mapped[str] = mapped_column(nullable=False) - platform: Mapped[str] = mapped_column(nullable=False) - manager: Mapped[str] = mapped_column(nullable=False) - - started_on: Mapped[datetime] = mapped_column( - DateTime(timezone=False), default=_utcnow_naive, nullable=False - ) - shutdown_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id", ondelete="cascade"), nullable=False, unique=True) - task: Mapped["Task"] = relationship(back_populates="guest") - - def __repr__(self): - return f"" - - def to_dict(self): - """Converts object to dict. 
- @return: dict - """ - d = {} - for column in self.__table__.columns: - value = getattr(self, column.name) - if isinstance(value, datetime): - d[column.name] = value.strftime("%Y-%m-%d %H:%M:%S") - else: - d[column.name] = value - return d - - def to_json(self): - """Converts object to JSON. - @return: JSON data - """ - return json.dumps(self.to_dict()) - - def __init__(self, name, label, platform, manager, task_id): - self.name = name - self.label = label - self.platform = platform - self.manager = manager - self.task_id = task_id - - -class Sample(Base): - """Submitted files details.""" - - __tablename__ = "samples" - - id: Mapped[int] = mapped_column(primary_key=True) - file_size: Mapped[int] = mapped_column(BigInteger, nullable=False) - file_type: Mapped[str] = mapped_column(Text(), nullable=False) - md5: Mapped[str] = mapped_column(String(32), nullable=False) - crc32: Mapped[str] = mapped_column(String(8), nullable=False) - sha1: Mapped[str] = mapped_column(String(40), nullable=False) - sha256: Mapped[str] = mapped_column(String(64), nullable=False) - sha512: Mapped[str] = mapped_column(String(128), nullable=False) - ssdeep: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - source_url: Mapped[Optional[str]] = mapped_column(String(2000), nullable=True) - tasks: Mapped[List["Task"]] = relationship(back_populates="sample", cascade="all, delete-orphan") - - child_links: Mapped[List["SampleAssociation"]] = relationship( - foreign_keys=[SampleAssociation.parent_id], back_populates="parent" - ) - # When this Sample is a child, this gives you its association links - parent_links: Mapped[List["SampleAssociation"]] = relationship( - foreign_keys=[SampleAssociation.child_id], back_populates="child" - ) - - # ToDo replace with index=True - __table_args__ = ( - Index("md5_index", "md5"), - Index("sha1_index", "sha1"), - Index("sha256_index", "sha256", unique=True), - ) - - def __repr__(self): - return f"" - - def to_dict(self): - """Converts object to 
dict. - @return: dict - """ - d = {} - for column in self.__table__.columns: - d[column.name] = getattr(self, column.name) - return d - - def to_json(self): - """Converts object to JSON. - @return: JSON data - """ - return json.dumps(self.to_dict()) - - def __init__(self, md5, crc32, sha1, sha256, sha512, file_size, file_type=None, ssdeep=None, parent_sample=None, source_url=None): - self.md5 = md5 - self.sha1 = sha1 - self.crc32 = crc32 - self.sha256 = sha256 - self.sha512 = sha512 - self.file_size = file_size - if file_type: - self.file_type = file_type - if ssdeep: - self.ssdeep = ssdeep - # if parent_sample: - # self.parent_sample = parent_sample - if source_url: - self.source_url = source_url - - -class Error(Base): - """Analysis errors.""" - - __tablename__ = "errors" - MAX_LENGTH = 1024 - - id: Mapped[int] = mapped_column(primary_key=True) - message: Mapped[str] = mapped_column(String(MAX_LENGTH), nullable=False) - task_id: Mapped[int] = mapped_column(ForeignKey("tasks.id"), nullable=False) - task: Mapped["Task"] = relationship(back_populates="errors") - - def to_dict(self): - """Converts object to dict. - @return: dict - """ - d = {} - for column in self.__table__.columns: - d[column.name] = getattr(self, column.name) - return d - - def to_json(self): - """Converts object to JSON. - @return: JSON data - """ - return json.dumps(self.to_dict()) - - def __init__(self, message, task_id): - if len(message) > self.MAX_LENGTH: - # Make sure that we don't try to insert an error message longer than what's allowed - # in the database. Provide the beginning and the end of the error. 
- left_of_ellipses = self.MAX_LENGTH // 2 - 2 - right_of_ellipses = self.MAX_LENGTH - left_of_ellipses - 3 - message = "...".join((message[:left_of_ellipses], message[-right_of_ellipses:])) - self.message = message - self.task_id = task_id - - def __repr__(self): - return f"" - - -class Task(Base): - """Analysis task queue.""" - - __tablename__ = "tasks" - - id: Mapped[int] = mapped_column(Integer(), primary_key=True) - target: Mapped[str] = mapped_column(Text(), nullable=False) - category: Mapped[str] = mapped_column(String(255), nullable=False) - cape: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) - timeout: Mapped[int] = mapped_column(Integer(), server_default="0", nullable=False) - priority: Mapped[int] = mapped_column(Integer(), server_default="1", nullable=False) - custom: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - machine: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - package: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - route: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, default=False) - # Task tags - tags_tasks: Mapped[Optional[str]] = mapped_column(String(256), nullable=True) - # Virtual machine tags - tags: Mapped[List["Tag"]] = relationship(secondary=tasks_tags, back_populates="tasks", passive_deletes=True) - options: Mapped[Optional[str]] = mapped_column(Text(), nullable=True) - platform: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - memory: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - enforce_timeout: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - clock: Mapped[datetime] = mapped_column( - DateTime(timezone=False), - default=_utcnow_naive, - nullable=False, - ) - added_on: Mapped[datetime] = mapped_column( - DateTime(timezone=False), - default=_utcnow_naive, - nullable=False, - ) - started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), 
nullable=True) - completed_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - status: Mapped[str] = mapped_column( - Enum( - TASK_BANNED, - TASK_PENDING, - TASK_RUNNING, - TASK_COMPLETED, - TASK_DISTRIBUTED, - TASK_REPORTED, - TASK_RECOVERED, - TASK_FAILED_ANALYSIS, - TASK_FAILED_PROCESSING, - TASK_FAILED_REPORTING, - name="status_type", - ), - server_default=TASK_PENDING, - nullable=False, - ) - - # Statistics data to identify broken Cuckoos servers or VMs - # Also for doing profiling to improve speed - dropped_files: Mapped[Optional[int]] = mapped_column(nullable=True) - running_processes: Mapped[Optional[int]] = mapped_column(nullable=True) - api_calls: Mapped[Optional[int]] = mapped_column(nullable=True) - domains: Mapped[Optional[int]] = mapped_column(nullable=True) - signatures_total: Mapped[Optional[int]] = mapped_column(nullable=True) - signatures_alert: Mapped[Optional[int]] = mapped_column(nullable=True) - files_written: Mapped[Optional[int]] = mapped_column(nullable=True) - registry_keys_modified: Mapped[Optional[int]] = mapped_column(nullable=True) - crash_issues: Mapped[Optional[int]] = mapped_column(nullable=True) - anti_issues: Mapped[Optional[int]] = mapped_column(nullable=True) - analysis_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - analysis_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - processing_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - processing_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - signatures_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - signatures_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - reporting_started_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), 
nullable=True) - reporting_finished_on: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=False), nullable=True) - timedout: Mapped[bool] = mapped_column(nullable=False, default=False) - - sample_id: Mapped[Optional[int]] = mapped_column(ForeignKey("samples.id"), nullable=True) - sample: Mapped["Sample"] = relationship(back_populates="tasks") # , lazy="subquery" - machine_id: Mapped[Optional[int]] = mapped_column(nullable=True) - guest: Mapped["Guest"] = relationship( - back_populates="task", uselist=False, cascade="all, delete-orphan" # This is crucial for a one-to-one relationship - ) - errors: Mapped[List["Error"]] = relationship( - back_populates="task", cascade="all, delete-orphan" # This MUST match the attribute name on the Error model - ) - - tlp: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) - user_id: Mapped[Optional[int]] = mapped_column(nullable=True) - - # The Task is linked to one specific parent/child association event - association: Mapped[Optional["SampleAssociation"]] = relationship(back_populates="task", cascade="all, delete-orphan") - - __table_args__ = ( - Index("category_index", "category"), - Index("status_index", "status"), - Index("added_on_index", "added_on"), - Index("completed_on_index", "completed_on"), - ) - - def to_dict(self): - """Converts object to dict. - @return: dict - """ - d = {} - for column in self.__table__.columns: - value = getattr(self, column.name) - if isinstance(value, datetime): - d[column.name] = value.strftime("%Y-%m-%d %H:%M:%S") - else: - d[column.name] = value - - # Tags are a relation so no column to iterate. - d["tags"] = [tag.name for tag in self.tags] - return d - - def to_json(self): - """Converts object to JSON. 
- @return: JSON data - """ - return json.dumps(self.to_dict()) - - def __init__(self, target=None): - self.target = target - - def __repr__(self): - return f"" - class AlembicVersion(Base): """Table used to pinpoint actual database schema release.""" @@ -600,7 +78,12 @@ class AlembicVersion(Base): version_num: Mapped[str] = mapped_column(String(32), nullable=False, primary_key=True) -class _Database: + +class _Database(TasksMixIn, + GuestsMixIn, + MachinesMixIn, + SamplesMixIn, + AuditsMixIn): """Analysis queue database. This class handles the creation of the database user for internal queue @@ -737,1879 +220,6 @@ def drop(self): except SQLAlchemyError as e: raise CuckooDatabaseError(f"Unable to create or connect to database: {e}") - def clean_machines(self): - """Clean old stored machines and related tables.""" - # Secondary table. - # TODO: this is better done via cascade delete. - # self.engine.execute(machines_tags.delete()) - # ToDo : If your ForeignKey has "ON DELETE CASCADE", deleting a Machine - # would automatically delete its entries in machines_tags. - # If not, deleting them manually first is correct. - self.session.execute(delete(machines_tags)) - self.session.execute(delete(Machine)) - - def delete_machine(self, name) -> bool: - """Delete a single machine entry from DB.""" - - stmt = select(Machine).where(Machine.name == name) - machine = self.session.scalar(stmt) - - if machine: - # Deleting a specific ORM instance remains the same - self.session.delete(machine) - return True - else: - log.warning("%s does not exist in the database.", name) - return False - - def add_machine( - self, name, label, arch, ip, platform, tags, interface, snapshot, resultserver_ip, resultserver_port, reserved, locked=False - ) -> Machine: - """Add a guest machine. 
- @param name: machine id - @param label: machine label - @param arch: machine arch - @param ip: machine IP address - @param platform: machine supported platform - @param tags: list of comma separated tags - @param interface: sniffing interface for this machine - @param snapshot: snapshot name to use instead of the current one, if configured - @param resultserver_ip: IP address of the Result Server - @param resultserver_port: port of the Result Server - @param reserved: True if the machine can only be used when specifically requested - """ - - machine = Machine( - name=name, - label=label, - arch=arch, - ip=ip, - platform=platform, - interface=interface, - snapshot=snapshot, - resultserver_ip=resultserver_ip, - resultserver_port=resultserver_port, - reserved=reserved, - ) - - if tags: - with self.session.no_autoflush: - for tag in tags.replace(" ", "").split(","): - machine.tags.append(self._get_or_create(Tag, name=tag)) - if locked: - machine.locked = True - - self.session.add(machine) - return machine - - def set_machine_interface(self, label, interface): - stmt = select(Machine).filter_by(label=label) - machine = self.session.scalar(stmt) - - if machine is None: - log.debug("Database error setting interface: %s not found", label) - return - - # This part remains the same - machine.interface = interface - return machine - - def set_vnc_port(self, task_id: int, port: int): - stmt = select(Task).where(Task.id == task_id) - task = self.session.scalar(stmt) - - if task is None: - log.debug("Database error setting VPN port: For task %s", task_id) - return - - # This logic remains the same - if task.options: - task.options += f",vnc_port={port}" - else: - task.options = f"vnc_port={port}" - - def update_clock(self, task_id): - row = self.session.get(Task, task_id) - - if not row: - return - # datetime.fromtimestamp(0, tz=timezone.utc) - if row.clock == datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None): - if row.category == "file": - # 
datetime.now(timezone.utc) - row.clock = _utcnow_naive() + timedelta(days=self.cfg.cuckoo.daydelta) - else: - # datetime.now(timezone.utc) - row.clock = _utcnow_naive() - return row.clock - - def set_task_status(self, task: Task, status) -> Task: - if status != TASK_DISTRIBUTED_COMPLETED: - task.status = status - - if status in (TASK_RUNNING, TASK_DISTRIBUTED): - task.started_on = _utcnow_naive() - elif status in (TASK_COMPLETED, TASK_DISTRIBUTED_COMPLETED): - task.completed_on = _utcnow_naive() - - self.session.add(task) - return task - - def set_status(self, task_id: int, status) -> Optional[Task]: - """Set task status. - @param task_id: task identifier - @param status: status string - @return: operation status - """ - task = self.session.get(Task, task_id) - - if not task: - return None - - return self.set_task_status(task, status) - - def create_guest(self, machine: Machine, manager: str, task: Task) -> Guest: - guest = Guest(machine.name, machine.label, machine.platform, manager, task.id) - guest.status = "init" - self.session.add(guest) - return guest - - def _package_vm_requires_check(self, package: str) -> list: - """ - We allow to users use their custom tags to tag properly any VM that can run this package - """ - return [vm_tag.strip() for vm_tag in web_conf.packages.get(package).split(",")] if web_conf.packages.get(package) else [] - - def _task_arch_tags_helper(self, task: Task): - # Are there available machines that match up with a task? - task_archs = [tag.name for tag in task.tags if tag.name in ("x86", "x64")] - task_tags = [tag.name for tag in task.tags if tag.name not in task_archs] - - return task_archs, task_tags - - def find_machine_to_service_task(self, task: Task) -> Optional[Machine]: - """Find a machine that is able to service the given task. - Returns: The Machine if an available machine was found; None if there is at least 1 machine - that *could* service it, but they are all currently in use. 
- Raises: CuckooUnserviceableTaskError if there are no machines in the pool that would be able - to service it. - """ - task_archs, task_tags = self._task_arch_tags_helper(task) - os_version = self._package_vm_requires_check(task.package) - - base_stmt = select(Machine).options(subqueryload(Machine.tags)) - - # This helper now encapsulates the final ordering, locking, and execution. - # It takes a Select statement as input. - def get_locked_machine(stmt: Select) -> Optional[Machine]: - final_stmt = stmt.order_by(Machine.locked, Machine.locked_changed_on).with_for_update(of=Machine) - return self.session.scalars(final_stmt).first() - - filter_kwargs = { - "statement": base_stmt, - "label": task.machine, - "platform": task.platform, - "tags": task_tags, - "archs": task_archs, - "os_version": os_version, - } - - filtered_stmt = self.filter_machines_to_task(include_reserved=False, **filter_kwargs) - machine = get_locked_machine(filtered_stmt) - - if machine is None and not task.machine and task_tags: - # The task was given at least 1 tag, but there are no non-reserved machines - # that could satisfy the request. So let's check "reserved" machines. - filtered_stmt = self.filter_machines_to_task(include_reserved=True, **filter_kwargs) - machine = get_locked_machine(filtered_stmt) - - if machine is None: - raise CuckooUnserviceableTaskError - if machine.locked: - # There aren't any machines that can service the task NOW, but there is at - # least one in the pool that could service it once it's available. - return None - return machine - - def fetch_task(self, categories: list = None): - """Fetches a task waiting to be processed and locks it for running. - @return: None or task - """ - stmt = ( - select(Task) - .where(Task.status == TASK_PENDING) - .where(not_(Task.options.contains("node="))) - .order_by(Task.priority.desc(), Task.added_on) - ) - - if categories: - stmt = stmt.where(Task.category.in_(categories)) - - # 2. 
Execute the statement and get the first result object - row = self.session.scalars(stmt).first() - - if not row: - return None - - # This business logic remains the same - self.set_status(task_id=row.id, status=TASK_RUNNING) - - return row - - def guest_get_status(self, task_id: int): - """Gets the status for a given guest.""" - stmt = select(Guest).where(Guest.task_id == task_id) - guest = self.session.scalar(stmt) - return guest.status if guest else None - - def guest_set_status(self, task_id: int, status: str): - """Sets the status for a given guest.""" - stmt = select(Guest).where(Guest.task_id == task_id) - guest = self.session.scalar(stmt) - if guest is not None: - guest.status = status - - def guest_remove(self, guest_id): - """Removes a guest start entry.""" - guest = self.session.get(Guest, guest_id) - if guest: - self.session.delete(guest) - - def guest_stop(self, guest_id): - """Logs guest stop. - @param guest_id: guest log entry id - """ - guest = self.session.get(Guest, guest_id) - if guest: - guest.shutdown_on = _utcnow_naive() - - @staticmethod - def filter_machines_by_arch(statement: Select, arch: list) -> Select: - """Adds a filter to the given select statement for the machine architecture. - Allows x64 machines to be returned when requesting x86. - """ - if arch: - if "x86" in arch: - # Prefer x86 machines over x64 if x86 is what was requested. - statement = statement.where(Machine.arch.in_(("x64", "x86"))).order_by(Machine.arch.desc()) - else: - statement = statement.where(Machine.arch.in_(arch)) - return statement - - def filter_machines_to_task( - self, statement: Select, label=None, platform=None, tags=None, archs=None, os_version=None, include_reserved=False - ) -> Select: - """Adds filters to the given select statement based on the task. - - @param statement: A `select()` statement to add filters to. 
- """ - if label: - statement = statement.where(Machine.label == label) - elif not include_reserved: - # Use .is_(False) for boolean checks - statement = statement.where(Machine.reserved.is_(False)) - - if platform: - statement = statement.where(Machine.platform == platform) - - statement = self.filter_machines_by_arch(statement, archs) - - if tags: - for tag in tags: - statement = statement.where(Machine.tags.any(name=tag)) - - if os_version: - statement = statement.where(Machine.tags.any(Tag.name.in_(os_version))) - - return statement - - def list_machines( - self, - locked=None, - label=None, - platform=None, - tags=None, - arch=None, - include_reserved=False, - os_version=None, - ) -> List[Machine]: - """Lists virtual machines. - @return: list of virtual machines - """ - """ - id | name | label | arch | - ----+-------+-------+------+ - 77 | cape1 | win7 | x86 | - 78 | cape2 | win10 | x64 | - """ - # ToDo do we really need it - with self.session.begin_nested(): - # with self.session.no_autoflush: - stmt = select(Machine).options(subqueryload(Machine.tags)) - - if locked is not None: - stmt = stmt.where(Machine.locked.is_(locked)) - - stmt = self.filter_machines_to_task( - statement=stmt, - label=label, - platform=platform, - tags=tags, - archs=arch, - os_version=os_version, - include_reserved=include_reserved, - ) - return self.session.execute(stmt).unique().scalars().all() - - def assign_machine_to_task(self, task: Task, machine: Optional[Machine]) -> Task: - if machine: - task.machine = machine.label - task.machine_id = machine.id - else: - task.machine = None - task.machine_id = None - self.session.add(task) - return task - - def lock_machine(self, machine: Machine) -> Machine: - """Places a lock on a free virtual machine. 
- @param machine: the Machine to lock - @return: locked machine - """ - machine.locked = True - machine.locked_changed_on = _utcnow_naive() - self.set_machine_status(machine, MACHINE_RUNNING) - self.session.add(machine) - - return machine - - def unlock_machine(self, machine: Machine) -> Machine: - """Remove lock from a virtual machine. - @param machine: The Machine to unlock. - @return: unlocked machine - """ - machine.locked = False - machine.locked_changed_on = _utcnow_naive() - self.session.merge(machine) - return machine - - def count_machines_available(self, label=None, platform=None, tags=None, arch=None, include_reserved=False, os_version=None): - """How many (relevant) virtual machines are ready for analysis. - @param label: machine ID. - @param platform: machine platform. - @param tags: machine tags - @param arch: machine arch - @param include_reserved: include 'reserved' machines in the result, regardless of whether or not a 'label' was provided. - @return: free virtual machines count - """ - stmt = select(func.count(Machine.id)).where(Machine.locked.is_(False)) - stmt = self.filter_machines_to_task( - statement=stmt, - label=label, - platform=platform, - tags=tags, - archs=arch, - os_version=os_version, - include_reserved=include_reserved, - ) - - return self.session.scalar(stmt) - - def get_available_machines(self) -> List[Machine]: - """Which machines are available""" - stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.locked.is_(False)) - return self.session.scalars(stmt).all() - - def count_machines_running(self) -> int: - """Counts how many machines are currently locked (running).""" - stmt = select(func.count(Machine.id)).where(Machine.locked.is_(True)) - return self.session.scalar(stmt) - - def set_machine_status(self, machine_or_label: Union[str, Machine], status): - """Set status for a virtual machine.""" - if isinstance(machine_or_label, str): - stmt = select(Machine).where(Machine.label == machine_or_label) - machine = 
self.session.scalar(stmt) - else: - machine = machine_or_label - - if machine: - machine.status = status - machine.status_changed_on = _utcnow_naive() - # No need for session.add() here; the ORM tracks changes to loaded objects. - - def add_error(self, message, task_id): - """Add an error related to a task.""" - # This function already uses modern, correct SQLAlchemy 2.0 patterns. - # No changes are needed. - error = Error(message=message, task_id=task_id) - # Use a separate session so that, regardless of the state of a transaction going on - # outside of this function, the error will always be committed to the database. - with self.session.session_factory() as sess, sess.begin(): - sess.add(error) - - def register_sample(self, obj, source_url=False): - if isinstance(obj, (File, PCAP, Static)): - fileobj = File(obj.file_path) - file_type = fileobj.get_type() - file_md5 = fileobj.get_md5() - sample = None - # check if hash is known already - try: - # get or create - sample = self.session.scalar(select(Sample).where(Sample.md5 == file_md5)) - if sample is None: - with self.session.begin_nested(): - sample = Sample( - md5=file_md5, - crc32=fileobj.get_crc32(), - sha1=fileobj.get_sha1(), - sha256=fileobj.get_sha256(), - sha512=fileobj.get_sha512(), - file_size=fileobj.get_size(), - file_type=file_type, - ssdeep=fileobj.get_ssdeep(), - source_url=source_url, - ) - self.session.add(sample) - except IntegrityError as e: - log.exception(e) - return sample - - def add( - self, - obj, - *, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - parent_sample=None, - tlp=None, - static=False, - source_url=False, - route=None, - cape=False, - tags_tasks=False, - user_id=0, - ): - """Add a task to database. - @param obj: object to add (File or URL). - @param timeout: selected timeout. - @param options: analysis options. - @param priority: analysis priority. 
- @param custom: custom options. - @param machine: selected machine. - @param platform: platform. - @param tags: optional tags that must be set for machine selection - @param memory: toggle full memory dump. - @param enforce_timeout: toggle full timeout execution. - @param clock: virtual machine clock time - @param parent_id: parent task id - @param parent_sample: original sample in case of archive - @param static: try static extraction first - @param tlp: TLP sharing designation - @param source_url: url from where it was downloaded - @param route: Routing route - @param cape: CAPE options - @param tags_tasks: Task tags so users can tag their jobs - @param user_id: Link task to user if auth enabled - @return: cursor or None. - """ - # Convert empty strings and None values to a valid int - - if isinstance(obj, (File, PCAP, Static)): - fileobj = File(obj.file_path) - file_type = fileobj.get_type() - file_md5 = fileobj.get_md5() - # check if hash is known already - # ToDo consider migrate to _get_or_create? 
- sample = self.session.scalar(select(Sample).where(Sample.md5 == file_md5)) - if not sample: - try: - with self.session.begin_nested(): - sample = Sample( - md5=file_md5, - crc32=fileobj.get_crc32(), - sha1=fileobj.get_sha1(), - sha256=fileobj.get_sha256(), - sha512=fileobj.get_sha512(), - file_size=fileobj.get_size(), - file_type=file_type, - ssdeep=fileobj.get_ssdeep(), - source_url=source_url, - ) - self.session.add(sample) - except Exception as e: - log.exception(e) - - if DYNAMIC_ARCH_DETERMINATION: - # Assign architecture to task to fetch correct VM type - - # This isn't 100% fool proof - _tags = tags.split(",") if isinstance(tags, str) else [] - arch_tag = fileobj.predict_arch() - if package.endswith("_x64"): - _tags.append("x64") - elif arch_tag: - _tags.append(arch_tag) - tags = ",".join(set(_tags)) - task = Task(obj.file_path) - task.sample_id = sample.id - - if isinstance(obj, (PCAP, Static)): - # since no VM will operate on this PCAP - task.started_on = _utcnow_naive() - - elif isinstance(obj, URL): - task = Task(obj.url) - _tags = tags.split(",") if isinstance(tags, str) else [] - _tags.append("x64") - _tags.append("x86") - tags = ",".join(set(_tags)) - - else: - return None - - task.category = obj.__class__.__name__.lower() - task.timeout = timeout - task.package = package - task.options = options - task.priority = priority - task.custom = custom - task.machine = machine - task.platform = platform - task.memory = bool(memory) - task.enforce_timeout = enforce_timeout - task.tlp = tlp - task.route = route - task.cape = cape - task.tags_tasks = tags_tasks - # Deal with tags format (i.e., foo,bar,baz) - if tags: - for tag in tags.split(","): - tag_name = tag.strip() - if tag_name and tag_name not in [tag.name for tag in task.tags]: - # "Task" object is being merged into a Session along the backref cascade path for relationship "Tag.tasks"; in SQLAlchemy 2.0, this reverse cascade will not take place. 
- # Set cascade_backrefs to False in either the relationship() or backref() function for the 2.0 behavior; or to set globally for the whole Session, set the future=True flag - # (Background on this error at: https://sqlalche.me/e/14/s9r1) (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9) - task.tags.append(self._get_or_create(Tag, name=tag_name)) - - if clock: - if isinstance(clock, str): - try: - task.clock = datetime.strptime(clock, "%m-%d-%Y %H:%M:%S") - except ValueError: - log.warning("The date you specified has an invalid format, using current timestamp") - task.clock = datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None) - - else: - task.clock = clock - else: - task.clock = datetime.fromtimestamp(0, timezone.utc).replace(tzinfo=None) - - task.user_id = user_id - - if parent_sample: - association = SampleAssociation( - parent=parent_sample, - child=sample, - task=task, - ) - self.session.add(association) - - # Use a nested transaction so that we can return an ID. - with self.session.begin_nested(): - self.session.add(task) - - return task.id - - def add_path( - self, - file_path, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - tlp=None, - static=False, - source_url=False, - route=None, - cape=False, - tags_tasks=False, - user_id=0, - parent_sample = None, - ): - """Add a task to database from file path. - @param file_path: sample path. - @param timeout: selected timeout. - @param options: analysis options. - @param priority: analysis priority. - @param custom: custom options. - @param machine: selected machine. - @param platform: platform. - @param tags: Tags required in machine selection - @param memory: toggle full memory dump. - @param enforce_timeout: toggle full timeout execution. 
- @param clock: virtual machine clock time - @param parent_id: parent analysis id - @param parent_sample: sample object if archive - @param static: try static extraction first - @param tlp: TLP sharing designation - @param route: Routing route - @param cape: CAPE options - @param tags_tasks: Task tags so users can tag their jobs - @user_id: Allow link task to user if auth enabled - @parent_sample: Sample object, if archive - @return: cursor or None. - """ - if not file_path or not path_exists(file_path): - log.warning("File does not exist: %s", file_path) - return None - - # Convert empty strings and None values to a valid int - if not timeout: - timeout = 0 - if not priority: - priority = 1 - if file_path.endswith((".htm", ".html")) and not package: - package = web_conf.url_analysis.package - - return self.add( - File(file_path), - timeout=timeout, - package=package, - options=options, - priority=priority, - custom=custom, - machine=machine, - platform=platform, - tags=tags, - memory=memory, - enforce_timeout=enforce_timeout, - clock=clock, - tlp=tlp, - source_url=source_url, - route=route, - cape=cape, - tags_tasks=tags_tasks, - user_id=user_id, - parent_sample=parent_sample, - ) - - def _identify_aux_func(self, file: bytes, package: str, check_shellcode: bool = True) -> tuple: - # before demux we need to check as msix has zip mime and we don't want it to be extracted: - tmp_package = False - if not package: - f = SflockFile.from_path(file) - try: - tmp_package = sflock_identify(f, check_shellcode=check_shellcode) - except Exception as e: - log.error("Failed to sflock_ident due to %s", str(e)) - tmp_package = "generic" - - if tmp_package and tmp_package in sandbox_packages: - # This probably should be way much bigger list of formats - if tmp_package in ("iso", "udf", "vhd"): - package = "archive" - elif tmp_package in ("zip", "rar"): - package = "" - elif tmp_package in ("html",): - package = web_conf.url_analysis.package - else: - package = tmp_package - - 
return package, tmp_package - - # Submission hooks to manipulate arguments of tasks execution - def recon( - self, - filename, - orig_options, - timeout=0, - enforce_timeout=False, - package="", - tags=None, - static=False, - priority=1, - machine="", - platform="", - custom="", - memory=False, - clock=None, - unique=False, - referrer=None, - tlp=None, - tags_tasks=False, - route=None, - cape=False, - category=None, - ): - # Get file filetype to ensure self extracting archives run longer - if not isinstance(filename, str): - filename = bytes2str(filename) - - lowered_filename = filename.lower() - - # sfx = File(filename).is_sfx() - - if "malware_name" in lowered_filename: - orig_options += "" - # if sfx: - # orig_options += ",timeout=500,enforce_timeout=1,procmemdump=1,procdump=1" - # timeout = 500 - # enforce_timeout = True - - if web_conf.general.yara_recon: - hits = File(filename).get_yara("binaries") - for hit in hits: - cape_name = hit["meta"].get("cape_type", "") - if not cape_name.endswith(("Crypter", "Packer", "Obfuscator", "Loader", "Payload")): - continue - - orig_options_parsed = get_options(orig_options) - parsed_options = get_options(hit["meta"].get("cape_options", "")) - if "tags" in parsed_options: - tags = "," + parsed_options["tags"] if tags else parsed_options["tags"] - del parsed_options["tags"] - # custom packages should be added to lib/cuckoo/core/database.py -> sandbox_packages list - # Do not overwrite user provided package - if not package and "package" in parsed_options: - package = parsed_options["package"] - del parsed_options["package"] - - if "category" in parsed_options: - category = parsed_options["category"] - del parsed_options["category"] - - orig_options_parsed.update(parsed_options) - orig_options = ",".join([f"{k}={v}" for k, v in orig_options_parsed.items()]) - - return ( - static, - priority, - machine, - platform, - custom, - memory, - clock, - unique, - referrer, - tlp, - tags_tasks, - route, - cape, - orig_options, - 
timeout, - enforce_timeout, - package, - tags, - category, - ) - - def demux_sample_and_add_to_db( - self, - file_path, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - tlp=None, - static=False, - source_url=False, - only_extraction=False, - tags_tasks=False, - route=None, - cape=False, - user_id=0, - category=None, - ): - """ - Handles ZIP file submissions, submitting each extracted file to the database - Returns a list of added task IDs - """ - task_id = False - task_ids = [] - config = {} - details = {} - - if not isinstance(file_path, bytes): - file_path = file_path.encode() - - ( - static, - priority, - machine, - platform, - custom, - memory, - clock, - unique, - referrer, - tlp, - tags_tasks, - route, - cape, - options, - timeout, - enforce_timeout, - package, - tags, - category, - ) = self.recon( - file_path, - options, - timeout=timeout, - enforce_timeout=enforce_timeout, - package=package, - tags=tags, - static=static, - priority=priority, - machine=machine, - platform=platform, - custom=custom, - memory=memory, - clock=clock, - tlp=tlp, - tags_tasks=tags_tasks, - route=route, - cape=cape, - category=category, - ) - - if category == "static": - # force change of category - task_ids += self.add_static( - file_path=file_path, - priority=priority, - tlp=tlp, - user_id=user_id, - options=options, - package=package, - ) - return task_ids, details - - check_shellcode = True - if options and "check_shellcode=0" in options: - check_shellcode = False - - if not package: - if "file=" in options: - # set zip as package when specifying file= in options - package = "zip" - else: - # Checking original file as some filetypes doesn't require demux - package, _ = self._identify_aux_func(file_path, package, check_shellcode=check_shellcode) - - parent_sample = None - # extract files from the (potential) archive - extracted_files, demux_error_msgs = 
demux_sample(file_path, package, options, platform=platform) - # check if len is 1 and the same file, if diff register file, and set parent - if extracted_files and not any(file_path == path for path, _ in extracted_files): - parent_sample = self.register_sample(File(file_path), source_url=source_url) - if conf.cuckoo.delete_archive: - path_delete(file_path.decode()) - - # create tasks for each file in the archive - for file, platform in extracted_files: - if not path_exists(file): - log.error("Extracted file doesn't exist: %s", file) - continue - # ToDo we lose package here and send APKs to windows - if platform in ("linux", "darwin") and LINUX_STATIC: - task_ids += self.add_static( - file_path=file_path, - priority=priority, - tlp=tlp, - user_id=user_id, - options=options, - package=package, - parent_sample=parent_sample, - ) - continue - if static: - # On huge loads this just become a bottleneck - config = False - if web_conf.general.check_config_exists: - config = static_config_lookup(file) - if config: - task_ids.append(config["id"]) - else: - config = static_extraction(file) - if config or only_extraction: - task_ids += self.add_static( - file_path=file, priority=priority, tlp=tlp, user_id=user_id, options=options, parent_sample=parent_sample, - ) - - if not config and not only_extraction: - if not package: - package, tmp_package = self._identify_aux_func(file, "", check_shellcode=check_shellcode) - - if not tmp_package: - log.info("Do sandbox packages need an update? Sflock identifies as: %s - %s", tmp_package, file) - - if package == "dll" and "function" not in options: - with PortableExecutable(file.decode()) as pe: - dll_export = pe.choose_dll_export() - if dll_export == "DllRegisterServer": - package = "regsvr" - elif dll_export == "xlAutoOpen": - package = "xls" - elif dll_export: - if options: - options += f",function={dll_export}" - else: - options = f"function={dll_export}" - - # ToDo better solution? 
- Distributed mode here: - # Main node is storage so try to extract before submit to vm isn't propagated to workers - if static and not config and distconf.distributed.enabled: - if options: - options += ",dist_extract=1" - else: - options = "dist_extract=1" - - task_id = self.add_path( - file_path=file.decode(), - timeout=timeout, - priority=priority, - options=options, - package=package, - machine=machine, - platform=platform, - memory=memory, - custom=custom, - enforce_timeout=enforce_timeout, - tags=tags, - clock=clock, - tlp=tlp, - source_url=source_url, - route=route, - tags_tasks=tags_tasks, - cape=cape, - user_id=user_id, - parent_sample=parent_sample, - ) - package = None - if task_id: - task_ids.append(task_id) - - if config and isinstance(config, dict): - details = {"config": config.get("cape_config", {})} - if demux_error_msgs: - details["errors"] = demux_error_msgs - # this is aim to return custom data, think of this as kwargs - return task_ids, details - - def add_pcap( - self, - file_path, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - tlp=None, - user_id=0, - ): - return self.add( - PCAP(file_path.decode()), - timeout=timeout, - package=package, - options=options, - priority=priority, - custom=custom, - machine=machine, - platform=platform, - tags=tags, - memory=memory, - enforce_timeout=enforce_timeout, - clock=clock, - tlp=tlp, - user_id=user_id, - ) - - def add_static( - self, - file_path, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - tlp=None, - static=True, - user_id=0, - parent_sample=None, - ): - extracted_files, demux_error_msgs = demux_sample(file_path, package, options) - - # check if len is 1 and the same file, if diff register file, and set parent - if not isinstance(file_path, bytes): - file_path = 
file_path.encode() - - # ToDo callback maybe or inside of the self.add - if extracted_files and ((file_path, platform) not in extracted_files and (file_path, "") not in extracted_files): - if not parent_sample: - parent_sample = self.register_sample(File(file_path)) - if conf.cuckoo.delete_archive: - # ToDo keep as info for now - log.info("Deleting archive: %s. conf.cuckoo.delete_archive is enabled. %s", file_path, str(extracted_files)) - path_delete(file_path) - - task_ids = [] - # create tasks for each file in the archive - for file, platform in extracted_files: - task_id = self.add( - Static(file.decode()), - timeout=timeout, - package=package, - options=options, - priority=priority, - custom=custom, - machine=machine, - platform=platform, - tags=tags, - memory=memory, - enforce_timeout=enforce_timeout, - clock=clock, - tlp=tlp, - static=static, - parent_sample=parent_sample, - user_id=user_id, - ) - if task_id: - task_ids.append(task_id) - - return task_ids - - def add_url( - self, - url, - timeout=0, - package="", - options="", - priority=1, - custom="", - machine="", - platform="", - tags=None, - memory=False, - enforce_timeout=False, - clock=None, - tlp=None, - route=None, - cape=False, - tags_tasks=False, - user_id=0, - ): - """Add a task to database from url. - @param url: url. - @param timeout: selected timeout. - @param options: analysis options. - @param priority: analysis priority. - @param custom: custom options. - @param machine: selected machine. - @param platform: platform. - @param tags: tags for machine selection - @param memory: toggle full memory dump. - @param enforce_timeout: toggle full timeout execution. - @param clock: virtual machine clock time - @param tlp: TLP sharing designation - @param route: Routing route - @param cape: CAPE options - @param tags_tasks: Task tags so users can tag their jobs - @param user_id: Link task to user - @return: cursor or None. 
- """ - - # Convert empty strings and None values to a valid int - if not timeout: - timeout = 0 - if not priority: - priority = 1 - if not package: - package = web_conf.url_analysis.package - - return self.add( - URL(url), - timeout=timeout, - package=package, - options=options, - priority=priority, - custom=custom, - machine=machine, - platform=platform, - tags=tags, - memory=memory, - enforce_timeout=enforce_timeout, - clock=clock, - tlp=tlp, - route=route, - cape=cape, - tags_tasks=tags_tasks, - user_id=user_id, - ) - - def reschedule(self, task_id): - """Reschedule a task. - @param task_id: ID of the task to reschedule. - @return: ID of the newly created task. - """ - task = self.view_task(task_id) - - if not task: - return None - - if task.category == "file": - add = self.add_path - elif task.category == "url": - add = self.add_url - elif task.category == "pcap": - add = self.add_pcap - elif task.category == "static": - add = self.add_static - - # Change status to recovered. - self.session.get(Task, task_id).status = TASK_RECOVERED - - # Normalize tags. - if task.tags: - tags = ",".join(tag.name for tag in task.tags) - else: - tags = task.tags - - def _ensure_valid_target(task): - if task.category == "url": - # URL tasks always have valid targets, return it as-is. - return task.target - - # All other task types have a "target" pointing to a temp location, - # so get a stable path "target" based on the sample hash. 
- paths = self.sample_path_by_hash(task.sample.sha256, task_id) - paths = [file_path for file_path in paths if path_exists(file_path)] - if not paths: - return None - - if task.category == "pcap": - # PCAP task paths are represented as bytes - return paths[0].encode() - return paths[0] - - task_target = _ensure_valid_target(task) - if not task_target: - log.warning("Unable to find valid target for task: %s", task_id) - return - - new_task_id = None - if task.category in ("file", "url"): - new_task_id = add( - task_target, - task.timeout, - task.package, - task.options, - task.priority, - task.custom, - task.machine, - task.platform, - tags, - task.memory, - task.enforce_timeout, - task.clock, - tlp=task.tlp, - route=task.route, - ) - elif task.category in ("pcap", "static"): - new_task_id = add( - task_target, - task.timeout, - task.package, - task.options, - task.priority, - task.custom, - task.machine, - task.platform, - tags, - task.memory, - task.enforce_timeout, - task.clock, - tlp=task.tlp, - ) - - self.session.get(Task, task_id).custom = f"Recovery_{new_task_id}" - - return new_task_id - - def count_matching_tasks(self, category=None, status=None, not_status=None): - """Retrieve list of task. - @param category: filter by category - @param status: filter by task status - @param not_status: exclude this task status from filter - @return: number of tasks. - """ - stmt = select(func.count(Task.id)) - - if status: - stmt = stmt.where(Task.status == status) - if not_status: - stmt = stmt.where(Task.status != not_status) - if category: - stmt = stmt.where(Task.category == category) - - # 2. Execute the statement and return the single integer result. - return self.session.scalar(stmt) - - def check_file_uniq(self, sha256: str, hours: int = 0): - # TODO This function is poorly named. It returns True if a sample with the given - # sha256 already exists in the database, rather than returning True if the given - # sha256 is unique. 
- uniq = False - if hours and sha256: - date_since = _utcnow_naive() - timedelta(hours=hours) - - stmt = ( - select(Task) - .join(Sample, Task.sample_id == Sample.id) - .where(Sample.sha256 == sha256) - .where(Task.added_on >= date_since) - ) - return self.session.scalar(select(stmt.exists())) - else: - if not self.find_sample(sha256=sha256): - uniq = False - else: - uniq = True - - return uniq - - - def get_parent_sample_from_task(self, task_id: int) -> Optional[Sample]: - """Finds the Parent Sample using the ID of the child's Task.""" - - # This query joins the Sample table (as the parent) to the - # association object and filters by the task_id. - stmt = ( - select(Sample) - .join(SampleAssociation, Sample.id == SampleAssociation.parent_id) - .where(SampleAssociation.task_id == task_id) - ) - return self.session.scalar(stmt) - - def list_tasks( - self, - limit=None, - details=False, - category=None, - offset=None, - status=None, - sample_id=None, - not_status=None, - completed_after=None, - order_by=None, - added_before=None, - id_before=None, - id_after=None, - options_like=False, - options_not_like=False, - tags_tasks_like=False, - task_ids=False, - include_hashes=False, - user_id=None, - for_update=False, - ) -> List[Task]: - """Retrieve list of task. - @param limit: specify a limit of entries. 
- @param details: if details about must be included - @param category: filter by category - @param offset: list offset - @param status: filter by task status - @param sample_id: filter tasks for a sample - @param not_status: exclude this task status from filter - @param completed_after: only list tasks completed after this timestamp - @param order_by: definition which field to sort by - @param added_before: tasks added before a specific timestamp - @param id_before: filter by tasks which is less than this value - @param id_after filter by tasks which is greater than this value - @param options_like: filter tasks by specific option inside of the options - @param options_not_like: filter tasks by specific option not inside of the options - @param tags_tasks_like: filter tasks by specific tag - @param task_ids: list of task_id - @param include_hashes: return task+samples details - @param user_id: list of tasks submitted by user X - @param for_update: If True, use "SELECT FOR UPDATE" in order to create a row-level lock on the selected tasks. - @return: list of tasks. 
- """ - tasks: List[Task] = [] - stmt = select(Task).options(joinedload(Task.guest), subqueryload(Task.errors), subqueryload(Task.tags)) - if include_hashes: - stmt = stmt.options(joinedload(Task.sample)) - if status: - if "|" in status: - stmt = stmt.where(Task.status.in_(status.split("|"))) - else: - stmt = stmt.where(Task.status == status) - if not_status: - stmt = stmt.where(Task.status != not_status) - if category: - stmt = stmt.where(Task.category.in_([category] if isinstance(category, str) else category)) - if sample_id is not None: - stmt = stmt.where(Task.sample_id == sample_id) - if id_before is not None: - stmt = stmt.where(Task.id < id_before) - if id_after is not None: - stmt = stmt.where(Task.id > id_after) - if completed_after: - stmt = stmt.where(Task.completed_on > completed_after) - if added_before: - stmt = stmt.where(Task.added_on < added_before) - if options_like: - stmt = stmt.where(Task.options.like(f"%{options_like.replace('*', '%')}%")) - if options_not_like: - # Fix: SQL NULL NOT LIKE returns NULL, not TRUE - # Must explicitly check for NULL and empty string - stmt = stmt.where( - or_( - Task.options is None, - Task.options == "", - not_(Task.options.like(f"%{options_not_like.replace('*', '%')}%")) - ) - ) - if tags_tasks_like: - stmt = stmt.where(Task.tags_tasks.like(f"%{tags_tasks_like}%")) - if task_ids: - stmt = stmt.where(Task.id.in_(task_ids)) - if user_id is not None: - stmt = stmt.where(Task.user_id == user_id) - - # 3. Chaining for ordering, pagination, and locking remains the same - if order_by is not None and isinstance(order_by, tuple): - stmt = stmt.order_by(*order_by) - elif order_by is not None: - stmt = stmt.order_by(order_by) - else: - stmt = stmt.order_by(Task.added_on.desc()) - - stmt = stmt.limit(limit).offset(offset) - if for_update: - stmt = stmt.with_for_update(of=Task) - - tasks = self.session.scalars(stmt).all() - return tasks - - def delete_task(self, task_id): - """Delete information on a task. 
- @param task_id: ID of the task to query. - @return: operation status. - """ - task = self.session.get(Task, task_id) - if task is None: - return False - self.session.delete(task) - # ToDo missed commits everywhere, check if autocommit is possible - return True - - def delete_tasks( - self, - category=None, - status=None, - sample_id=None, - not_status=None, - completed_after=None, - added_before=None, - id_before=None, - id_after=None, - options_like=False, - options_not_like=False, - tags_tasks_like=False, - task_ids=False, - user_id=None, - ): - """Delete tasks based on parameters. If no filters are provided, no tasks will be deleted. - - Args: - category: filter by category - status: filter by task status - sample_id: filter tasks for a sample - not_status: exclude this task status from filter - completed_after: only list tasks completed after this timestamp - added_before: tasks added before a specific timestamp - id_before: filter by tasks which is less than this value - id_after: filter by tasks which is greater than this value - options_like: filter tasks by specific option inside of the options - options_not_like: filter tasks by specific option not inside of the options - tags_tasks_like: filter tasks by specific tag - task_ids: list of task_id - user_id: list of tasks submitted by user X - - Returns: - bool: True if the operation was successful (including no tasks to delete), False otherwise. - """ - delete_stmt = delete(Task) - filters_applied = False - - # 2. 
Chain .where() clauses for all filters - if status: - if "|" in status: - delete_stmt = delete_stmt.where(Task.status.in_(status.split("|"))) - else: - delete_stmt = delete_stmt.where(Task.status == status) - filters_applied = True - if not_status: - delete_stmt = delete_stmt.where(Task.status != not_status) - filters_applied = True - if category: - delete_stmt = delete_stmt.where(Task.category.in_([category] if isinstance(category, str) else category)) - filters_applied = True - if sample_id is not None: - delete_stmt = delete_stmt.where(Task.sample_id == sample_id) - filters_applied = True - if id_before is not None: - delete_stmt = delete_stmt.where(Task.id < id_before) - filters_applied = True - if id_after is not None: - delete_stmt = delete_stmt.where(Task.id > id_after) - filters_applied = True - if completed_after: - delete_stmt = delete_stmt.where(Task.completed_on > completed_after) - filters_applied = True - if added_before: - delete_stmt = delete_stmt.where(Task.added_on < added_before) - filters_applied = True - if options_like: - delete_stmt = delete_stmt.where(Task.options.like(f"%{options_like.replace('*', '%')}%")) - filters_applied = True - if options_not_like: - delete_stmt = delete_stmt.where(Task.options.notlike(f"%{options_not_like.replace('*', '%')}%")) - filters_applied = True - if tags_tasks_like: - delete_stmt = delete_stmt.where(Task.tags_tasks.like(f"%{tags_tasks_like}%")) - filters_applied = True - if task_ids: - delete_stmt = delete_stmt.where(Task.id.in_(task_ids)) - filters_applied = True - if user_id is not None: - delete_stmt = delete_stmt.where(Task.user_id == user_id) - filters_applied = True - - if not filters_applied: - log.warning("No filters provided for delete_tasks. 
No tasks will be deleted.") - return True - - # ToDo Transaction Handling - # The transaction logic (commit/rollback) is kept the same for a direct port, - # but the more idiomatic SQLAlchemy 2.0 approach would be to wrap the execution - # in a with self.session.begin(): block, which handles transactions automatically. - try: - result = self.session.execute(delete_stmt) - log.info("Deleted %d tasks matching the criteria.", result.rowcount) - self.session.commit() - return True - except SQLAlchemyError as e: - log.error("Error deleting tasks: %s", str(e)) - self.session.rollback() - return False - - # ToDo replace with delete_tasks - def clean_timed_out_tasks(self, timeout: int): - """Deletes PENDING tasks that were added more than `timeout` seconds ago.""" - if timeout <= 0: - return - - # Calculate the cutoff time before which tasks are considered timed out. - timeout_threshold = _utcnow_naive() - timedelta(seconds=timeout) - - # Build a single, efficient DELETE statement that filters in the database. - delete_stmt = delete(Task).where(Task.status == TASK_PENDING).where(Task.added_on < timeout_threshold) - - # Execute the bulk delete statement. - # The transaction should be handled by the calling code, - # typically with a `with session.begin():` block. - result = self.session.execute(delete_stmt) - - if result.rowcount > 0: - log.info("Deleted %d timed-out PENDING tasks.", result.rowcount) - - def minmax_tasks(self) -> Tuple[int, int]: - """Finds the minimum start time and maximum completion time for all tasks.""" - # A single query is more efficient than two separate ones. - stmt = select(func.min(Task.started_on), func.max(Task.completed_on)) - min_val, max_val = self.session.execute(stmt).one() - - if min_val and max_val: - # .timestamp() is the modern way to get a unix timestamp. 
- return int(min_val.replace(tzinfo=timezone.utc).timestamp()), int(max_val.replace(tzinfo=timezone.utc).timestamp()) - - return 0, 0 - - def get_tlp_tasks(self) -> List[int]: - """Retrieves a list of task IDs that have TLP enabled.""" - # Selecting just the ID is more efficient than fetching full objects. - stmt = select(Task.id).where(Task.tlp == "true") - # .scalars() directly yields the values from the single selected column. - return self.session.scalars(stmt).all() - - def get_file_types(self) -> List[str]: - """Gets a sorted list of unique sample file types.""" - # .distinct() is cleaner than group_by() for a single column. - stmt = select(Sample.file_type).distinct().order_by(Sample.file_type) - return self.session.scalars(stmt).all() - - def get_tasks_status_count(self) -> Dict[str, int]: - """Counts tasks, grouped by status.""" - stmt = select(Task.status, func.count(Task.status)).group_by(Task.status) - # .execute() returns rows, which can be directly converted to a dict. - return dict(self.session.execute(stmt).all()) - - def count_tasks(self, status: str = None, mid: int = None) -> int: - """Counts tasks in the database, with optional filters.""" - # Build a `SELECT COUNT(...)` query from the start for efficiency. - stmt = select(func.count(Task.id)) - if mid: - stmt = stmt.where(Task.machine_id == mid) - if status: - stmt = stmt.where(Task.status == status) - - # .scalar() executes the query and returns the single integer result. - return self.session.scalar(stmt) - - def view_task(self, task_id, details=False) -> Optional[Task]: - """Retrieve information on a task. - @param task_id: ID of the task to query. - @return: details on the task. 
- """ - query = select(Task).where(Task.id == task_id) - if details: - query = query.options( - joinedload(Task.guest), subqueryload(Task.errors), subqueryload(Task.tags), joinedload(Task.sample) - ) - else: - query = query.options(subqueryload(Task.tags), joinedload(Task.sample)) - return self.session.scalar(query) - - # This function is used by the runstatistics community module. - def add_statistics_to_task(self, task_id, details): # pragma: no cover - """add statistic to task - @param task_id: ID of the task to query. - @param: details statistic. - @return true of false. - """ - # ToDo do we really need this? does it need commit? - task = self.session.get(Task, task_id) - if task: - task.dropped_files = details["dropped_files"] - task.running_processes = details["running_processes"] - task.api_calls = details["api_calls"] - task.domains = details["domains"] - task.signatures_total = details["signatures_total"] - task.signatures_alert = details["signatures_alert"] - task.files_written = details["files_written"] - task.registry_keys_modified = details["registry_keys_modified"] - task.crash_issues = details["crash_issues"] - task.anti_issues = details["anti_issues"] - return True - - def view_sample(self, sample_id): - """Retrieve information on a sample given a sample id. - @param sample_id: ID of the sample to query. - @return: details on the sample used in sample: sample_id. - """ - return self.session.get(Sample, sample_id) - - def get_children_by_parent_id(self, parent_id: int) -> List[Sample]: - """ - Finds all child Samples using an explicit join. 
- """ - # Create an alias to represent the Child Sample in the query - ChildSample = aliased(Sample, name="child") - - # This query selects child samples by joining through the association table - stmt = ( - select(ChildSample) - .join(SampleAssociation, ChildSample.id == SampleAssociation.child_id) - .where(SampleAssociation.parent_id == parent_id) - ) - - return self.session.scalars(stmt).all() - - def find_sample( - self, md5: str = None, sha1: str = None, sha256: str = None, parent: int = None, task_id: int = None, sample_id: int = None - ) -> Union[Optional[Sample], List[Sample], List[Task]]: - """Searches for samples or tasks based on different criteria.""" - - if md5: - return self.session.scalar(select(Sample).where(Sample.md5 == md5)) - - if sha1: - return self.session.scalar(select(Sample).where(Sample.sha1 == sha1)) - - if sha256: - return self.session.scalar(select(Sample).where(Sample.sha256 == sha256)) - - if parent is not None: - return self.get_children_by_parent_id(parent) - - if sample_id is not None: - # Using session.get() is much more efficient than a select query. - # We wrap the result in a list to match the original function's behavior. - sample = self.session.get(Sample, sample_id) - return [sample] if sample else [] - - if task_id is not None: - # Note: This branch returns a list of Task objects. - stmt = select(Task).join(Sample, Task.sample_id == Sample.id).options(joinedload(Task.sample)).where(Task.id == task_id) - return self.session.scalars(stmt).all() - - return None - - def sample_still_used(self, sample_hash: str, task_id: int): - """Retrieve information if sample is used by another task(s). - @param sample_hash: sha256. 
- @param task_id: task_id - @return: bool - """ - stmt = ( - select(Task) - .join(Sample, Task.sample_id == Sample.id) - .where(Sample.sha256 == sample_hash) - .where(Task.id != task_id) - .where(Task.status.in_((TASK_PENDING, TASK_RUNNING, TASK_DISTRIBUTED))) - ) - - # select(stmt.exists()) creates a `SELECT EXISTS(...)` query. - # session.scalar() executes it and returns True or False directly. - return self.session.scalar(select(stmt.exists())) - - def _hash_file_in_chunks(self, path: str, hash_algo) -> str: - """Helper function to hash a file efficiently in chunks.""" - hasher = hash_algo() - buffer_size = 65536 # 64kb - with open(path, "rb") as f: - while chunk := f.read(buffer_size): - hasher.update(chunk) - return hasher.hexdigest() - - def sample_path_by_hash(self, sample_hash: str = False, task_id: int = False): - """Retrieve information on a sample location by given hash. - @param hash: md5/sha1/sha256/sha256. - @param task_id: task_id - @return: samples path(s) as list. - """ - sizes = { - 32: Sample.md5, - 40: Sample.sha1, - 64: Sample.sha256, - 128: Sample.sha512, - } - - hashlib_sizes = { - 32: hashlib.md5, - 40: hashlib.sha1, - 64: hashlib.sha256, - 128: hashlib.sha512, - } - - sizes_mongo = { - 32: "md5", - 40: "sha1", - 64: "sha256", - 128: "sha512", - } - - if task_id: - file_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(task_id), "binary") - if path_exists(file_path): - return [file_path] - - # binary also not stored in binaries, perform hash lookup - stmt = select(Sample).join(Task, Sample.id == Task.sample_id).where(Task.id == task_id) - db_sample = self.session.scalar(stmt) - if db_sample: - path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) - if path_exists(path): - return [path] - - sample_hash = db_sample.sha256 - - if not sample_hash: - return [] - - query_filter = sizes.get(len(sample_hash), "") - sample = [] - # check storage/binaries - if query_filter: - stmt = select(Sample).where(query_filter == 
sample_hash) - db_sample = self.session.scalar(stmt) - if db_sample is not None: - path = os.path.join(CUCKOO_ROOT, "storage", "binaries", db_sample.sha256) - if path_exists(path): - sample = [path] - - if not sample: - tasks = [] - if repconf.mongodb.enabled and web_conf.general.check_sample_in_mongodb: - tasks = mongo_find( - "files", - {sizes_mongo.get(len(sample_hash), ""): sample_hash}, - {"_info_ids": 1, "sha256": 1}, - ) - """ deprecated code - elif repconf.elasticsearchdb.enabled: - tasks = [ - d["_source"] - for d in es.search( - index=get_analysis_index(), - body={"query": {"match": {f"CAPE.payloads.{sizes_mongo.get(len(sample_hash), '')}": sample_hash}}}, - _source=["CAPE.payloads", "info.id"], - )["hits"]["hits"] - ] - """ - if tasks: - for task in tasks: - for id in task.get("_task_ids", []): - # ToDo suricata path - "suricata.files.file_info.path - for category in ("files", "procdump", "CAPE"): - file_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(id), category, task["sha256"]) - if path_exists(file_path): - sample = [file_path] - break - if sample: - break - - if not sample: - # search in temp folder if not found in binaries - stmt = select(Task).join(Sample, Task.sample_id == Sample.id).where(query_filter == sample_hash) - db_sample = self.session.scalars(stmt).all() - - if db_sample is not None: - """ - samples = [_f for _f in [tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample] if _f] - # hash validation and if exist - samples = [file_path for file_path in samples if path_exists(file_path)] - for path in samples: - with open(path, "rb") as f: - if sample_hash == hashlib_sizes[len(sample_hash)](f.read()).hexdigest(): - sample = [path] - break - """ - # Use a generator expression for memory efficiency - target_paths = (tmp_sample.to_dict().get("target", "") for tmp_sample in db_sample) - - # Filter for paths that exist - existing_paths = (p for p in target_paths if p and path_exists(p)) - # ToDo review if we really 
want/need this - for path in existing_paths: - if sample_hash == self._hash_file_in_chunks(path, hashlib_sizes[len(sample_hash)]): - sample = [path] - break - return sample - - def count_samples(self) -> int: - """Counts the amount of samples in the database.""" - stmt = select(func.count(Sample.id)) - return self.session.scalar(stmt) - - def view_machine(self, name: str) -> Optional[Machine]: - """Shows virtual machine details by name.""" - stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.name == name) - return self.session.scalar(stmt) - - def view_machine_by_label(self, label: str) -> Optional[Machine]: - """Shows virtual machine details by label.""" - stmt = select(Machine).options(subqueryload(Machine.tags)).where(Machine.label == label) - return self.session.scalar(stmt) - - def view_errors(self, task_id: int) -> List[Error]: - """Gets all errors related to a task.""" - stmt = select(Error).where(Error.task_id == task_id) - return self.session.scalars(stmt).all() - - def get_source_url(self, sample_id: int = None) -> Optional[str]: - """Retrieves the source URL for a given sample ID.""" - if not sample_id: - return None - - try: - stmt = select(Sample.source_url).where(Sample.id == int(sample_id)) - return self.session.scalar(stmt) - except (TypeError, ValueError): - # Handle cases where sample_id is not a valid integer. - return None - - def ban_user_tasks(self, user_id: int): - """ - Bans all PENDING tasks submitted by a given user. - @param user_id: user id - """ - - update_stmt = update(Task).where(Task.user_id == user_id, Task.status == TASK_PENDING).values(status=TASK_BANNED) - - # 2. Execute the statement. - # The transaction should be handled by the calling code, - # ToDo e.g., with a `with session.begin():` block. 
- self.session.execute(update_stmt) - - def tasks_reprocess(self, task_id: int): - """common func for api and views""" - task = self.view_task(task_id) - if not task: - return True, "Task ID does not exist in the database", "" - - if task.status not in { - # task status suitable for reprocessing - # allow reprocessing of tasks already processed (maybe detections changed) - TASK_REPORTED, - # allow reprocessing of tasks that were rescheduled - TASK_RECOVERED, - # allow reprocessing of tasks that previously failed the processing stage - TASK_FAILED_PROCESSING, - # allow reprocessing of tasks that previously failed the reporting stage - TASK_FAILED_REPORTING, - # TASK_COMPLETED, - }: - return True, f"Task ID {task_id} cannot be reprocessed in status {task.status}", task.status - - # Save the old_status, because otherwise, in the call to set_status(), - # sqlalchemy will use the cached Task object that `task` is already a reference - # to and update that in place. That would result in `task.status` in this - # function being set to TASK_COMPLETED and we don't want to return that. 
- old_status = task.status - self.set_status(task_id, TASK_COMPLETED) - return False, "", old_status - - _DATABASE: Optional[_Database] = None diff --git a/lib/cuckoo/core/machinery_manager.py b/lib/cuckoo/core/machinery_manager.py index 02d7efaedbc..7437ff39417 100644 --- a/lib/cuckoo/core/machinery_manager.py +++ b/lib/cuckoo/core/machinery_manager.py @@ -7,7 +7,9 @@ from lib.cuckoo.common.abstracts import Machinery from lib.cuckoo.common.config import Config from lib.cuckoo.common.exceptions import CuckooCriticalError, CuckooMachineError -from lib.cuckoo.core.database import Database, Machine, Task, _Database +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.machines import Machine +from lib.cuckoo.core.data.task import Task from lib.cuckoo.core.plugins import list_plugins from lib.cuckoo.core.rooter import rooter, vpns diff --git a/lib/cuckoo/core/scheduler.py b/lib/cuckoo/core/scheduler.py index 14241866a32..5188ec53678 100644 --- a/lib/cuckoo/core/scheduler.py +++ b/lib/cuckoo/core/scheduler.py @@ -21,7 +21,10 @@ from lib.cuckoo.common.exceptions import CuckooUnserviceableTaskError from lib.cuckoo.common.utils import CATEGORIES_NEEDING_VM, load_categories from lib.cuckoo.core.analysis_manager import AnalysisManager -from lib.cuckoo.core.database import TASK_FAILED_ANALYSIS, TASK_PENDING, Database, Machine, Task, _Database, _utcnow_naive +from lib.cuckoo.core.data.db_common import _utcnow_naive +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.machines import Machine +from lib.cuckoo.core.data.task import Task, TASK_FAILED_ANALYSIS, TASK_PENDING from lib.cuckoo.core.machinery_manager import MachineryManager log = logging.getLogger(__name__) @@ -238,6 +241,7 @@ def find_pending_task_to_service(self) -> Tuple[Optional[Task], Optional[Machine "Please check your machinery configuration." 
) + if self.cfg.cuckoo.fail_unserviceable: log.info( log_message.format( diff --git a/lib/cuckoo/core/startup.py b/lib/cuckoo/core/startup.py index 602eaeb2fa5..8a7a7b26fbd 100644 --- a/lib/cuckoo/core/startup.py +++ b/lib/cuckoo/core/startup.py @@ -52,7 +52,8 @@ from lib.cuckoo.common.exceptions import CuckooOperationalError, CuckooStartupError from lib.cuckoo.common.path_utils import path_exists from lib.cuckoo.common.utils import create_folders -from lib.cuckoo.core.database import TASK_FAILED_ANALYSIS, TASK_RUNNING, Database +from lib.cuckoo.core.database import Database +from lib.cuckoo.core.data.task import TASK_FAILED_ANALYSIS, TASK_RUNNING from lib.cuckoo.core.log import init_logger from lib.cuckoo.core.plugins import import_package, import_plugin, list_plugins from lib.cuckoo.core.rooter import rooter, socks5s, vpns diff --git a/modules/machinery/aws.py b/modules/machinery/aws.py index c247a473bfc..5bb06a65668 100644 --- a/modules/machinery/aws.py +++ b/modules/machinery/aws.py @@ -3,7 +3,7 @@ import time from lib.cuckoo.common.config import Config -from lib.cuckoo.core.database import Machine +from lib.cuckoo.core.data.machines import Machine cfg = Config() HAVE_BOTO3 = False diff --git a/modules/machinery/az.py b/modules/machinery/az.py index b116213d7e5..38f4e829ef0 100644 --- a/modules/machinery/az.py +++ b/modules/machinery/az.py @@ -24,7 +24,8 @@ CuckooMachineError, CuckooUnserviceableTaskError, ) -from lib.cuckoo.core.database import TASK_PENDING, TASK_RUNNING, Machine, Task +from lib.cuckoo.core.data.task import TASK_PENDING, TASK_RUNNING, Task +from lib.cuckoo.core.data.machines import Machine HAVE_AZURE = False cfg = Config() diff --git a/modules/reporting/callback.py b/modules/reporting/callback.py index 448af1c95b5..029030f2ff6 100644 --- a/modules/reporting/callback.py +++ b/modules/reporting/callback.py @@ -4,7 +4,8 @@ import requests from lib.cuckoo.common.abstracts import Report -from lib.cuckoo.core.database import TASK_REPORTED, Database 
+from lib.cuckoo.core.database import Database +from lib.cuckoo.core.data.task import TASK_REPORTED log = logging.getLogger(__name__) main_db = Database() diff --git a/tests/audit_packages/readme.md b/tests/audit_packages/readme.md new file mode 100644 index 00000000000..ce55118458d --- /dev/null +++ b/tests/audit_packages/readme.md @@ -0,0 +1,3 @@ +Audit packages go here +Each test has a directory containing 'payload.zip' and a 'test.py' +All directories should be owned or at least writeable by the cape user \ No newline at end of file diff --git a/tests/test_analysis_manager.py b/tests/test_analysis_manager.py index 76ead8f4a2d..e283558d2ed 100644 --- a/tests/test_analysis_manager.py +++ b/tests/test_analysis_manager.py @@ -12,7 +12,10 @@ from lib.cuckoo.common.abstracts import Machinery from lib.cuckoo.common.config import Config, ConfigMeta from lib.cuckoo.core.analysis_manager import AnalysisManager -from lib.cuckoo.core.database import TASK_RUNNING, Guest, Machine, Task, _Database +from lib.cuckoo.core.data.task import TASK_RUNNING, Task +from lib.cuckoo.core.data.guests import Guest +from lib.cuckoo.core.data.machines import Machine +from lib.cuckoo.core.database import _Database from lib.cuckoo.core.machinery_manager import MachineryManager from lib.cuckoo.core.scheduler import Scheduler diff --git a/tests/test_database.py b/tests/test_database.py index 706d6322154..71a2fc0c4be 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -18,23 +18,19 @@ from lib.cuckoo.common.exceptions import CuckooUnserviceableTaskError from lib.cuckoo.common.path_utils import path_mkdir from lib.cuckoo.common.utils import store_temp_file -from lib.cuckoo.core import database -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.data import tasking +from lib.cuckoo.core.data.task import ( TASK_BANNED, TASK_COMPLETED, TASK_PENDING, TASK_REPORTED, TASK_RUNNING, - Error, - Guest, - Machine, - Sample, - Tag, - Task, - _Database, - _utcnow_naive, - 
machines_tags, -) + Task) +from lib.cuckoo.core.data.guests import Guest +from lib.cuckoo.core.data.samples import Sample +from lib.cuckoo.core.data.machines import Machine, machines_tags +from lib.cuckoo.core.data.db_common import _utcnow_naive, Tag, Error +from lib.cuckoo.core.database import _Database @dataclasses.dataclass @@ -447,7 +443,7 @@ def test_update_clock_url(self, db: _Database, monkeypatch, freezer): with db.session.begin(): task_id = db.add_url("https://www.google.com") now = _utcnow_naive() - monkeypatch.setattr(database.datetime, "utcnow", lambda: now) + monkeypatch.setattr(tasking.datetime, "utcnow", lambda: now) # URL's are unaffected by the daydelta setting. monkeypatch.setattr(db.cfg.cuckoo, "daydelta", 1) assert db.update_clock(task_id) == now diff --git a/tests/web/test_apiv2.py b/tests/web/test_apiv2.py index 796eefa6949..204eea858b5 100644 --- a/tests/web/test_apiv2.py +++ b/tests/web/test_apiv2.py @@ -5,7 +5,7 @@ from django.test import SimpleTestCase from lib.cuckoo.common.config import ConfigMeta -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.data.task import ( TASK_BANNED, TASK_COMPLETED, TASK_DISTRIBUTED, diff --git a/utils/process.py b/utils/process.py index 6cc440219ac..e41a82d5bd3 100644 --- a/utils/process.py +++ b/utils/process.py @@ -40,14 +40,13 @@ from lib.cuckoo.common.constants import CUCKOO_ROOT from lib.cuckoo.common.path_utils import path_delete, path_exists, path_mkdir from lib.cuckoo.common.utils import get_options -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.database import Database, init_database +from lib.cuckoo.core.data.task import ( TASK_COMPLETED, TASK_FAILED_PROCESSING, TASK_FAILED_REPORTING, TASK_REPORTED, - Database, - Task, - init_database, + Task ) from lib.cuckoo.core.plugins import RunProcessing, RunReporting, RunSignatures from lib.cuckoo.core.startup import ConsoleHandler, check_linux_dist, init_modules diff --git a/web/analysis/views.py b/web/analysis/views.py index 
c493340aba2..65494c02332 100644 --- a/web/analysis/views.py +++ b/web/analysis/views.py @@ -36,7 +36,8 @@ from lib.cuckoo.common.path_utils import path_exists, path_get_size, path_mkdir, path_read_file, path_safe from lib.cuckoo.common.utils import delete_folder, yara_detected from lib.cuckoo.common.web_utils import category_all_files, my_rate_minutes, my_rate_seconds, perform_search, rateblock, statistics -from lib.cuckoo.core.database import TASK_PENDING, Database, Task +from lib.cuckoo.core.database import Database, TasksMixIn +from lib.cuckoo.core.data.task import TASK_PENDING, Task from modules.reporting.report_doc import CHUNK_CALL_SIZE try: @@ -161,7 +162,7 @@ if enabledconf["mongodb"] or enabledconf["elasticsearchdb"]: DISABLED_WEB = False -db = Database() +db: TasksMixIn = Database() anon_not_viewable_func_list = ( "file", @@ -339,7 +340,7 @@ def index(request, page=1): analyses_pcaps = [] analyses_static = [] - tasks_files = db.list_tasks(limit=TASK_LIMIT, offset=off, category="file", not_status=TASK_PENDING) + tasks_files = db.list_tasks(limit=TASK_LIMIT, offset=off, category="file", not_status=TASK_PENDING, tags_tasks_not_like="audit") tasks_static = db.list_tasks(limit=TASK_LIMIT, offset=off, category="static", not_status=TASK_PENDING) tasks_urls = db.list_tasks(limit=TASK_LIMIT, offset=off, category="url", not_status=TASK_PENDING) tasks_pcaps = db.list_tasks(limit=TASK_LIMIT, offset=off, category="pcap", not_status=TASK_PENDING) diff --git a/web/apiv2/views.py b/web/apiv2/views.py index 909cc12eac2..0657abd08a8 100644 --- a/web/apiv2/views.py +++ b/web/apiv2/views.py @@ -53,12 +53,11 @@ statistics, validate_task, ) -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.database import Database, _Database +from lib.cuckoo.core.data.task import ( TASK_RECOVERED, TASK_RUNNING, - Database, Task, - _Database, ) from lib.cuckoo.core.rooter import _load_socks5_operational, vpns diff --git a/web/audit/__init__.py b/web/audit/__init__.py new file mode 
100644 index 00000000000..48d6094651a --- /dev/null +++ b/web/audit/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2010-2015 Cuckoo Foundation. +# This file is part of Cuckoo Sandbox - http://www.cuckoosandbox.org +# See the file "docs/LICENSE" for copying permission. diff --git a/web/audit/urls.py b/web/audit/urls.py new file mode 100644 index 00000000000..13d423430c4 --- /dev/null +++ b/web/audit/urls.py @@ -0,0 +1,22 @@ +# Copyright (C) 2010-2015 Cuckoo Foundation. +# This file is part of Cuckoo Sandbox - http://www.cuckoosandbox.org +# See the file "docs/LICENSE" for copying permission. + +from django.urls import re_path, path +from audit import views + +urlpatterns = [ + re_path(r"^$", views.audit_index, name="audit_index"), + re_path(r"^page/(?P\d+)/$", views.audit_index, name="audit_index"), + re_path(r"^session/(?P\d+)/$", views.session_index, name="test_session"), + re_path(r"^session/(?P\d+)/status$", views.session_status, name="session_status"), + re_path(r"^session/(?P\d+)/run_update//", views.get_run_update, name="get_run_update"), + re_path(r"^reload_available_tests/", views.reload_available_tests, name="reload_available_tests"), + re_path(r"^create_test_session/$", views.create_test_session, name="create_test_session"), + re_path(r"^delete_test_session/(?P\d+)/$", views.delete_test_session, name="delete_test_session"), + path(r"session//queue_tests/", views.queue_all_tests, name="queue_all_tests"), + path(r"session//unqueue_tests/", views.unqueue_all_tests, name="unqueue_all_tests"), + path(r"session//queue_tests//", views.queue_test, name="queue_test"), + path(r"session//unqueue_tests//", views.unqueue_test, name="unqueue_test"), + re_path(r"^update_task_config/(?P\d+)/$", views.update_task_config, name="update_task_config") +] diff --git a/web/audit/views.py b/web/audit/views.py new file mode 100644 index 00000000000..1f828bbf8e4 --- /dev/null +++ b/web/audit/views.py @@ -0,0 +1,495 @@ +import json +import os +import sys +import logging +from typing 
import Optional, Dict + +from django.conf import settings +from django.contrib import messages +from django.contrib.auth.decorators import login_required +from django.http import JsonResponse, HttpResponseNotFound, HttpResponseForbidden +from django.shortcuts import redirect, render +from django import template +from django.template.loader import render_to_string +from django.views.decorators.http import require_POST +from django.urls import reverse + +register = template.Library() + +sys.path.append(settings.CUCKOO_PATH) + +logger = logging.getLogger(__name__) + +from lib.cuckoo.common.config import Config +from lib.cuckoo.common.audit_utils import TestLoader +from lib.cuckoo.core.database import Database +from lib.cuckoo.core.data.audits import AuditsMixIn, TestSession +from lib.cuckoo.core.data.task import TASK_PENDING, Task +from lib.cuckoo.core.data.db_common import _utcnow_naive +from lib.cuckoo.core.data.audit_data import (TestRun, TEST_QUEUED, TEST_COMPLETE, TEST_FAILED, TEST_RUNNING, TEST_UNQUEUED) + +''' +try: + from django_ratelimit.decorators import ratelimit +except ImportError: + try: + from ratelimit.decorators import ratelimit + except ImportError: + print("missed dependency: poetry install") +''' + +SESSIONS_PER_PAGE = 10 +AUDIT_PACKAGES_ROOT = os.path.join(settings.CUCKOO_PATH, "tests", "audit_packages") +processing_cfg = Config("processing") +reporting_cfg = Config("reporting") +integrations_cfg = Config("integrations") +web_cfg = Config("web") +db: AuditsMixIn = Database() + +anon_not_viewable_func_list = ( +) + +# Conditional decorator for web authentication +class conditional_login_required: + def __init__(self, dec, condition): + self.decorator = dec + self.condition = condition + + def __call__(self, func): + if not hasattr(web_cfg, 'audit_framework') or \ + not hasattr(web_cfg.audit_framework, 'enabled') or \ + not web_cfg.audit_framework.enabled: + def fail(*args, **kwargs): + return HttpResponseForbidden("Audit Framework is not set to 
enabled in web config.") + return fail + + if settings.ANON_VIEW and func.__name__ not in anon_not_viewable_func_list: + return func + if not self.condition: + return func + return self.decorator(func) + +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def audit_index(request, page:int = 1): + """ + The main index function for the /audit page with lists of sessions and tests + Currently only handles paging for sessions as tests are probably better + viewed as a single page while being picked through for a new session + """ + with db.session.session_factory() as db_session, db_session.begin(): + available_tests = db.list_available_tests(db_session=db_session) + for test in available_tests: + # We create a new "virtual" attribute on the object + if test.task_config: + test.task_config_pretty = json.dumps(test.task_config, indent=2) + else: + test.task_config_pretty = "{}" + + paging = {} + first_last_session = db.get_session_id_range() + page = int(page) + if page == 0: + page = 1 + offset = (page - 1) * SESSIONS_PER_PAGE + + test_sessions = db.list_test_sessions(db_session=db_session, offset=offset, limit=SESSIONS_PER_PAGE) + for audit_session in test_sessions: + for run in audit_session.runs: + if run.status not in [TEST_COMPLETE, TEST_FAILED]: + db.update_audit_tasks_status(db_session=db_session, audit_session=audit_session) + break + audit_session.stats = get_session_stats(audit_session) + + paging["show_session_prev"] = "hide" + paging["show_session_next"] = "hide" + + if test_sessions: + if test_sessions[0].id != first_last_session[1]: + paging["show_session_prev"] = "show" + if test_sessions[-1].id != first_last_session[0]: + paging["show_session_next"] = "show" + + sessions_count = db.count_test_sessions() + pages_sessions_num = int(sessions_count / SESSIONS_PER_PAGE + 1) + + sessions_pages = [] + if pages_sessions_num < 11 or page < 6: + sessions_pages = list(range(1, min(10, pages_sessions_num) + 1)) + elif page > 5: + 
sessions_pages = list(range(min(page - 5, pages_sessions_num - 10) + 1, min(page + 5, pages_sessions_num) + 1)) + + paging["sessions_page_range"] = sessions_pages + paging["next_page"] = str(page + 1) + paging["prev_page"] = str(page - 1) + paging["current_page"] = page + + return render( + request, + "audit/index.html", + { + "available_tests": available_tests, + "total_sessions": sessions_count, + "sessions": test_sessions, + "paging": paging, + }, + ) + + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def create_test_session(request): + """ + Takes a list of test IDs in the test_ids POST value and generates + a new audit session from them. + Redirects to the new session on success. + Redirects back to the session list on failure. + """ + if request.method != "POST": + return redirect("audit_index") + + test_ids = request.POST.getlist("test_ids") + if not test_ids: + messages.warning(request, "No tests were selected.") + return redirect("audit_index") + + try: + # This calls the SQLAlchemy logic we discussed to create the TestSession + TestRuns + session_id = db.create_session_from_tests(test_ids) + + # Success! Now we go to the "Mission Control" page + return redirect("test_session", session_id=session_id) + + except Exception as e: + messages.error(request, "Error creating session: %s",str(e)) + return redirect("audit_index") + + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def reload_available_tests(request): + """ + Triggers the TestLoader to refresh the AvailableTests table. 
+ By default reads from tests/audit_packages + """ + # Path where your test subdirectories live + if not os.path.isdir(AUDIT_PACKAGES_ROOT): + errmsg = "reload_available_tests::Audit packages root is not an existing directory: " + AUDIT_PACKAGES_ROOT + messages.error(request, errmsg) + return redirect(reverse("audit_index") + "#available-tests") + + loader = TestLoader(AUDIT_PACKAGES_ROOT) + result = loader.load_tests() + logger.info("Test load results: %s", json.dumps(result)) + + try: + # This calls the method you added to your Mixin + count = db.reload_tests(result["available"], result["unavailable"]) + if result["unavailable"]: + if not result["available"]: + errmg = "Failed to load %d tests from %s [%s]."%\ + (len(result["unavailable"]), + AUDIT_PACKAGES_ROOT, + str(result['unavailable'])) + messages.error(request,errmg) + else: + messages.warning( + request, + f"Partial failure to reload tests from {AUDIT_PACKAGES_ROOT} [Failed: {result['unavailable']}].", + ) + else: + messages.success(request, f"Successfully reloaded all {count} tests from {AUDIT_PACKAGES_ROOT}") + + except Exception as e: + messages.error(request, f"reload_available_tests:: Error reloading tests: {str(e)}") + logger.exception("reload_available_tests::exception") + + return redirect(reverse("audit_index") + "#available-tests") + + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def delete_test_session(request, session_id: int): + """ + Purges a tests session and its task storage directory + Fails if any of the tests are active + """ + try: + session = db.get_test_session(session_id) + if session: + if not db.delete_test_session(session_id): + messages.warning(request, f"Could not delete active session #{session_id}.") + except Exception as e: + messages.error(request, f"Error deleting session: {str(e)}") + logger.error("Error deleting session: %s",str(e)) + return redirect("audit_index") + + +@conditional_login_required(login_required, 
settings.WEB_AUTHENTICATION) +def session_index(request, session_id: int): + """ + The index function for an invididual audit session + """ + session_data = db.get_test_session(session_id) + stats = get_session_stats(session_data) + + if not session_data: + messages.warning(request, "Session not found.") + return redirect("audit_index") + + run_html = {} + for run in session_data.runs: + run_status = _render_run_update(request, session_id, run.id) + run_html[run.id] = run_status["html"] + + return render( + request, + "audit/session.html", + {"session": session_data, "runs": session_data.runs, "run_html": run_html, "stats": stats}, + ) + + +def generate_task_diagnostics(task: Task, test_run: TestRun): + """ + Gathers CAPE task timestamps + """ + diagnostics = {} + timenow = _utcnow_naive() + if task.added_on: + if task.started_on: + diagnostics["start_wait"] = task.started_on - task.added_on + else: + diagnostics["start_wait"] = timenow - task.added_on + + if task.started_on: + if task.completed_on: + diagnostics["run_time"] = task.completed_on - task.started_on + else: + diagnostics["run_time"] = timenow - task.started_on + + if task.completed_on: + if task.reporting_finished_on: + diagnostics["report_wait"] = task.reporting_finished_on - task.completed_on + else: + # implementing reporting_finished_on was a recent change, it may not be there + if test_run.status == TEST_RUNNING: + diagnostics["report_wait"] = timenow - task.completed_on + + return diagnostics + +def _format_json_config(config_raw): + return json.dumps(config_raw, indent=2) + +def _render_run_update(request, session_id: int, testrun_id: int): + """ + The mechanics of rendering sessions is here + This framework is currently lazy loaded, so test sessions will be updated + and objectives evaluated when this is called, making it potentially slow on + some occasions. 
+ """ + db_test_session = db.get_test_session(session_id) + test_run = next((r for r in db_test_session.runs if r.id == testrun_id), None) + cape_task_info = None + diagnostics = None + if test_run.cape_task_id is not None: + cape_task_info = db.view_task(test_run.cape_task_id) + if cape_task_info: + diagnostics = generate_task_diagnostics(cape_task_info, test_run) + + if test_run.test_definition.task_config: + test_run.task_config_pretty = _format_json_config(test_run.test_definition.task_config) + + # Render just the partial file with the updated 'run' object + html = render_to_string( + "audit/partials/session_test_run.html", + {"run": test_run, "cape_task": cape_task_info, "diagnostics": diagnostics}, + request=request, + ) + return {"html": html, "status": test_run.status, "id": test_run.id} + + +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def get_run_update(request, session_id: int, testrun_id: int): + """ + Get an update for a test run of a session without having to reload the whole page + """ + return JsonResponse(_render_run_update(request, session_id, testrun_id)) + +def get_session_stats(db_test_session: TestSession) -> Optional[Dict]: + """ + Fetch test and objective statistics + """ + if db_test_session is None: + return None + runs = db_test_session.runs + stats = { + "tests": { TEST_QUEUED: 0, TEST_UNQUEUED: 0, TEST_COMPLETE: 0, TEST_RUNNING: 0, TEST_FAILED: 0}, + "objectives": {"untested": 0, "skipped": 0, "success": 0, "failure": 0, "info": 0, "error": 0}, + "complete_but_unevaluated": 0, + } + for run in runs: + stats["tests"][run.status] += 1 + for objective in run.objectives: + stats["objectives"][objective.state] += 1 + if run.status == TEST_COMPLETE and objective.state == "untested": + stats["complete_but_unevaluated"] += 1 + return stats + + +@register.filter +def get_item(dictionary, key): + return dictionary.get(key) + +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def 
session_status(request, session_id: int): + """ + The call used when a session page is being refreshed + """ + db_test_session = db.get_test_session(session_id) + if db_test_session is None: + logger.warning("Tried to view session_status with invalid session %s", str(session_status)) + return HttpResponseNotFound + + runs = db_test_session.runs + stats = get_session_stats(db_test_session) + + results = [] + for run in runs: + results.append(_render_run_update(request, session_id, run.id)) + + status_box = render_to_string( + "audit/partials/session_status_header.html", {"stats": stats, "session": db_test_session}, request=request + ) + + return JsonResponse( + { + "test_cards": results, + "status_box_card": status_box, + "stats": stats, + "count_unqueued": db_test_session.queued_run_count, + "count_queued": db_test_session.unqueued_run_count, + } + ) + + +def inner_queue_test(request, session_id: int, testrun_id: int) -> Optional[int]: + """ + Queue a test from an audit session as a CAPE task + Returns the cape task id if success, or None if failure + """ + user_id = request.user.id or 0 + cape_task_id = None + try: + cape_task_id = db.queue_audit_test(session_id, testrun_id, user_id) + db.assign_cape_task_to_testrun(testrun_id, cape_task_id) + logger.info("CAPE queued task %d to service audit [session:%d test:%d user:%d]", + cape_task_id, session_id, testrun_id, user_id) + return cape_task_id + except Exception as ex: + messages.error(request, f"Task Exception: {ex}") + return None + + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def queue_test(request, session_id: int, testrun_id: int): + """ + Path to queue a single test of a session + """ + cape_task_id = inner_queue_test(request, session_id, testrun_id) + if cape_task_id: + return JsonResponse({"status": "success", "message": "Test queued successfully", "task_id": cape_task_id}) + else: + return JsonResponse({"status": "failure", "message": "Could not queue test - 
see messages.", "task_id": None}) + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def queue_all_tests(request, session_id: int): + """ + Path to queue all unqueued tests of a session + Returns the test_run_id -> task_id mapping + """ + task_ids = [] + db_test_session = db.get_test_session(session_id) + for run in db_test_session.runs: + if run.status != TEST_UNQUEUED: + continue + task_id = inner_queue_test(request, session_id, run.id) + if task_id is not None: + task_ids.append({"run": run.id, "task": task_id}) + return JsonResponse({"task_ids": task_ids}) + +def inner_unqueue_test(testrun: TestRun) -> Optional[int]: + """ + Try to delete a CAPE task of a session which has been queued (but not started yet) + @parameter testrun: The TestRun db object of the run to clear + Returns the deleted CAPE task id, or None if failed + """ + # note: I tried to use db.delete_tasks(), with the task_id's & TASK_PENDING + # filter but couldn't get round commit/transaction errors + # there is a chance of some race conditions here + if testrun.cape_task_id is not None and testrun.status == TEST_QUEUED: + task_id = testrun.cape_task_id + cape_task = db.view_task(task_id) + if cape_task.status == TASK_PENDING: + db.delete_task(task_id) + testrun.status = TEST_UNQUEUED + testrun.cape_task_id = None + return task_id + return None + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def unqueue_test(request, session_id, testrun_id): + """ + Path to unqueue a single test of a session + """ + run = db.get_audit_session_test(session_id, testrun_id) + if not run: + return JsonResponse({"status": "failure", "message": "Could not retrieve test task", "task_id": None}) + + cape_task_id = inner_unqueue_test(run) + if cape_task_id: + return JsonResponse({"status": "success", "message": "Test unqueued successfully", "task_id": cape_task_id}) + else: + return JsonResponse({"status": "failure", "message": "Could not 
unqueue test", "task_id": None}) + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def unqueue_all_tests(request, session_id: int): + deleted_task_ids = [] + db_test_session = db.get_test_session(session_id) + if not db_test_session: + logger.warning("Request to unqueue all tests of invalid session %d",session_id) + else: + for run in db_test_session.runs: + task_id = inner_unqueue_test(run) + if task_id: + deleted_task_ids.append(task_id) + return JsonResponse({"deleted_tasks": deleted_task_ids}) + + +@require_POST +@conditional_login_required(login_required, settings.WEB_AUTHENTICATION) +def update_task_config(request, availabletest_id): + if request.method == "POST": + with db.session.session_factory() as db_session, db_session.begin(): + test = db.get_test(availabletest_id=availabletest_id, db_session=db_session) + raw_json = request.POST.get("task_config", "").strip() + + try: + # 1. Validate JSON syntax + parsed_data = json.loads(raw_json) + + # 2. Save the minified version to the DB (or keep pretty if preferred) + test.task_config = parsed_data + db_session.commit() + messages.success(request, f"Configuration for Test {test.name} (#{test.id}) updated successfully.") + return JsonResponse({"success": True, "config_pretty": _format_json_config(parsed_data)}) + + except json.JSONDecodeError as e: + messages.error(request, f"Failed to save: Invalid JSON format. 
Error: {str(e)}") + return JsonResponse({"success": False, "error": "bad json: "+str(e)}) + + except Exception as e: + messages.error(request, f"An unexpected error occurred: {str(e)}") + return JsonResponse({"success": False, "error": str(e)}) diff --git a/web/dashboard/views.py b/web/dashboard/views.py index 5a7a09787e1..b23c634658b 100644 --- a/web/dashboard/views.py +++ b/web/dashboard/views.py @@ -9,10 +9,13 @@ from django.shortcuts import render from django.views.decorators.http import require_safe +from lib.cuckoo.core.data.tasking import TasksMixIn + sys.path.append(settings.CUCKOO_PATH) from lib.cuckoo.common.web_utils import top_detections -from lib.cuckoo.core.database import ( +from lib.cuckoo.core.database import Database +from lib.cuckoo.core.data.task import ( TASK_COMPLETED, TASK_DISTRIBUTED, TASK_FAILED_ANALYSIS, @@ -21,8 +24,7 @@ TASK_PENDING, TASK_RECOVERED, TASK_REPORTED, - TASK_RUNNING, - Database, + TASK_RUNNING ) @@ -47,7 +49,7 @@ def format_number_with_space(number): @require_safe @conditional_login_required(login_required, settings.WEB_AUTHENTICATION) def index(request): - db = Database() + db: TasksMixIn = Database() report = dict( total_samples=format_number_with_space(db.count_samples()), diff --git a/web/templates/audit/index.html b/web/templates/audit/index.html new file mode 100644 index 00000000000..26850f60635 --- /dev/null +++ b/web/templates/audit/index.html @@ -0,0 +1,446 @@ +{% extends "base.html" %} + + + +{% block content %} + +{% if messages %} +
+
    + {% for message in messages %} +
  • + {{ message }} +
  • + {% endfor %} +
+
+{% endif %} + + + + + +
+
+ +
+ {% if paging.show_session_next == "show" or paging.show_session_prev == "show" %} + + {% endif %} + + +
+
+
Recent Audit Sessions
+ {{total_sessions}} session{{total_sessions|pluralize}} +
+ +
+ + + + + + + + + + + + {% for session in sessions %} + + + + + + + + + + + + {% empty %} + + + + {% endfor %} + +
Session IDCreatedTestsObjectivesCleanup
+ #{{ session.id }} + + + {{ session.added_on|date:"Y-m-d" }} + {{ session.added_on|date:"H:i" }} + + +
+ +
+
Unqueued
+
+ {{ session.stats.tests.unqueued }} +
+
+ +
+
Queued
+
+ {{ session.stats.tests.queued }} +
+
+ +
+
Running
+
+ {{ session.stats.tests.running }} +
+
+ +
+
Done
+
+ {{ session.stats.tests.complete }} +
+
+ +
+
Error
+
+ {{ session.stats.tests.failed }} +
+
+
+
+
+ +
+
Untested
+
+ {{ session.stats.objectives.untested }} +
+
+ +
+
Skipped
+
+ {{ session.stats.objectives.skipped }} +
+
+ +
+
Error
+
+ {{ session.stats.objectives.error }} +
+
+ +
+
Info
+
+ {{ session.stats.objectives.info }} +
+
+ +
+
Success
+
+ {{ session.stats.objectives.success }} +
+
+ +
+
Failure
+
+ {{ session.stats.objectives.failure }} +
+
+
+
+
+ {% csrf_token %} + +
+
+ No sessions found. +
+
+
+
+
+
+ {% csrf_token %} +
+ +
+
+
+
+
New Audit Session: Available Tests
+ {{available_tests|length}} test{{ available_tests|pluralize }} +
+ +
+ +
+ {% if available_tests %} + + + + {% endif %} +
+ +
+ {% csrf_token %} + +
+
+ +
+ + + + + + + + + + + + + {% for test in available_tests %} + + + + + + + + + + + + {% empty %} + + + + {% endfor %} + +
NamePackageTargetsDescriptionConfig
+ + {{ test.name }}{{ test.package }} + {% if "windows" in test.targets|lower %} + + {% endif %} + + {% if "linux" in test.targets|lower %} + + {% endif %} + + +
{{ test.description }}
+ {% if test.payload_notes %} +
+ + {{ test.payload_notes }} +
+ {% endif %} +
+ +
+
+
+ {% include "audit/partials/task_config.html" with test=test config_pretty=test.task_config_pretty colspan=5 %} +
+
+
+ + No tests available. Place tests in the audits location and reload. +
+
+
+
+ +
+ + {% if paging.show_test_next == "show" or paging.show_test_prev == "show" %} + {% endif %} +
+ + + +{% endblock %} diff --git a/web/templates/audit/partials/objective_item.html b/web/templates/audit/partials/objective_item.html new file mode 100644 index 00000000000..f04067ceaca --- /dev/null +++ b/web/templates/audit/partials/objective_item.html @@ -0,0 +1,56 @@ +
+
+ +
+ {{ obj.template.name }} +
{{ obj.template.requirement }}
+
+ +
+ +
+ {% if obj.state_reason %} + + {{ obj.state_reason }} + + {% endif %} +
+ +
+ {% if obj.state == "success" %} + + REQUIREMENT MET + + {% elif obj.state == "failure" %} + + REQUIREMENT UNMET + + {% elif obj.state == "untested" %} + + REQUIREMENT UNTESTED + + {% elif obj.state == "skipped" %} + + EVALUATION SKIPPED + + {% elif obj.state == "error" %} + + EVALUATION ERROR + + {% elif obj.state == "info" %} + + INFORMATIONAL RESULT + + {% endif %} +
+
+
+ + {% if obj.children %} +
+ {% for child in obj.children %} + {% include "audit/partials/objective_item.html" with obj=child %} + {% endfor %} +
+ {% endif %} +
\ No newline at end of file diff --git a/web/templates/audit/partials/session_status_header.html b/web/templates/audit/partials/session_status_header.html new file mode 100644 index 00000000000..adc6871238f --- /dev/null +++ b/web/templates/audit/partials/session_status_header.html @@ -0,0 +1,199 @@ +
+
+
+ +
+

Audit Session #{{ session.id }}

+ + {{ session.added_on|date:"Y-m-d H:i" }} + +
+ +
+
Session Tests
+ +
+
+
Unqueued
+
+ {{ stats.tests.unqueued }} +
+
+ +
+
Queued
+
+ {{ stats.tests.queued }} +
+
+ +
+
Running
+
+ {{ stats.tests.running }} +
+
+ +
+
Done
+
+ {{ stats.tests.complete }} +
+
+ +
+
Error
+
+ {{ stats.tests.failed }} +
+
+
+
+ +
+
Objective Outcomes
+
+ +
+
Untested
+
+ {{ stats.objectives.untested }} +
+
+ +
+
Skipped
+
+ {{ stats.objectives.skipped }} +
+
+ +
+
Error
+
+ {{ stats.objectives.error }} +
+
+ +
+
Info
+
+ {{ stats.objectives.info }} +
+
+ +
+
Success
+
+ {{ stats.objectives.success }} +
+
+ +
+
Failure
+
+ {{ stats.objectives.failure }} +
+
+
+
+ + +
+ +
+ + + + + +
+
+
+
+
+ + \ No newline at end of file diff --git a/web/templates/audit/partials/session_test_run.html b/web/templates/audit/partials/session_test_run.html new file mode 100644 index 00000000000..4b5cce43d2a --- /dev/null +++ b/web/templates/audit/partials/session_test_run.html @@ -0,0 +1,115 @@ +
+
+
+
+ {{ run.test_definition.name }} +
+ {{ run.test_definition.description }} | Package: {{ run.test_definition.package }} +
+ +
+ +
+
Task
+
+ {% if cape_task %} + + #{{ cape_task.id }}{{ cape_task.status|upper }} + + {% else %} + NONE + {% endif %} +
+
+ +
+
Options
+
+ +
+
+ +
+
Timing
+
+ +
+
+ +
+
Test #{{ run.id }}
+
+ {% if run.status == "unqueued" %} + Not Queued + {% elif run.status == "queued" %} + Queued + {% elif run.status == "running" %} + {% if cape_task.status == "completed" %}Awaiting Report{% else %}Running{% endif %} + {% elif run.status == "complete" %} + Complete + {% elif run.status == "failed" %} + Failed + {% else %} + {{run.status}} + {% endif %} +
+
+ + {% if run.status == "unqueued" or run.status == "queued" %} +
+ +
+ {% if run.status == "unqueued" %} + + {% else %} + + {% endif %} +
+
+ {% endif %} + +
+ + + +
+ +
+
+ {% include "audit/partials/task_config.html" with test=run.test_definition config_pretty=run.task_config_pretty colspan=5 %} +
+
+ +
+ {% include "audit/partials/timing_item.html" with cape_task=cape_task diagnostics=diagnostics colspan=5 %} +
+ +
+
+ Requirement +
+
+ Evaluation +
+
+ +
+
+ {% for obj in run.objectives %} + {% if not obj.parent_id %} +
+ {% include "audit/partials/objective_item.html" with obj=obj %} +
+ {% endif %} + {% endfor %} +
+
+
\ No newline at end of file diff --git a/web/templates/audit/partials/task_config.html b/web/templates/audit/partials/task_config.html new file mode 100644 index 00000000000..da87c8b7e2e --- /dev/null +++ b/web/templates/audit/partials/task_config.html @@ -0,0 +1,145 @@ +
+
+
+
+
+ + task_config.json + +
+ + Cannot edit while auto-refreshing + + + +
+
+ +
+
{{ config_pretty }}
+
+ +
+
+ {% csrf_token %} + +
+ + +
+
+
+
+

+ + Note: This is the currently stored configuration. + Edits will not be reflected in previous tasks. +

+
+
+
+
+
+ \ No newline at end of file diff --git a/web/templates/audit/partials/timing_item.html b/web/templates/audit/partials/timing_item.html new file mode 100644 index 00000000000..4b28ca667e9 --- /dev/null +++ b/web/templates/audit/partials/timing_item.html @@ -0,0 +1,42 @@ +
+
+
+
+
+ +
+ + Task Timeline + +
+ +
+
+
+
Added
+
{{ cape_task.added_on|default:"--" }}
+
+
+
Start Delay
+
{{ diagnostics.start_wait|default:"--" }}
+
Started
+
{{ cape_task.started_on|default:"--" }}
+
+
Run Time
+
{{ diagnostics.run_time|default:"--" }}
+
Completed
+
{{ cape_task.completed_on|default:"--" }}
+
+
+
Processing Wait
+
{{ diagnostics.report_wait|default:"--" }}
+
Reported
+
{{ cape_task.reporting_finished_on|default:"--" }}
+
+
+
+
+
+
+
+
\ No newline at end of file diff --git a/web/templates/audit/session.html b/web/templates/audit/session.html new file mode 100644 index 00000000000..f456761b87d --- /dev/null +++ b/web/templates/audit/session.html @@ -0,0 +1,371 @@ +{% extends "base.html" %} +{% block content %} + + + +{% if messages|length > 0 %} +
+
Messages
+
    + {% for message in messages %} + {% if message.tags != "success" %} +
  • + {{ message }} +
  • + {% endif %} + {% endfor %} +
+
+{% endif %} + + +
+ +
+ + +
+ +
+
+
+
+
+ + +
+ +
+
+ + +
+ {% include "audit/partials/session_status_header.html" with stats=stats session=session csrf_token=csrf_token %} +
+ +
+ {% for run in session.runs %} + {% with pagehtml=run_html|get_item:run.id %} + {{ pagehtml|safe }} + {% endwith %} + {% endfor %} +
+ + + +{% endblock %} diff --git a/web/templates/header.html b/web/templates/header.html index bf797ca81f4..5c823d17db4 100644 --- a/web/templates/header.html +++ b/web/templates/header.html @@ -17,6 +17,9 @@ + {% if settings.AUDIT_FRAMEWORK %} + + {% endif %} {% if settings.WEB_AUTHENTICATION %}