From 36691d91385daf9d0e996cf466d37dbc2e356f55 Mon Sep 17 00:00:00 2001 From: lrdossan Date: Wed, 4 Sep 2024 16:22:28 +0200 Subject: [PATCH] polished cern caimira frontned model related methods --- .../api/controller/virus_report_controller.py | 3 +- .../calculator/report/virus_report_data.py | 43 +++++------- .../cern_caimira/apps/calculator/__init__.py | 3 +- .../apps/calculator/report/virus_report.py | 69 ++++++++++++++----- cern_caimira/tests/test_report_generator.py | 10 ++- 5 files changed, 72 insertions(+), 56 deletions(-) diff --git a/caimira/src/caimira/api/controller/virus_report_controller.py b/caimira/src/caimira/api/controller/virus_report_controller.py index 41bf4d72..9f070757 100644 --- a/caimira/src/caimira/api/controller/virus_report_controller.py +++ b/caimira/src/caimira/api/controller/virus_report_controller.py @@ -18,10 +18,9 @@ def generate_model(form_obj, data_registry): return form_obj.build_model(sample_size=sample_size) -def generate_report_results(form_obj, model): +def generate_report_results(form_obj): return rg.calculate_report_data( form=form_obj, - model=model, executor_factory=functools.partial( concurrent.futures.ThreadPoolExecutor, None, # TODO define report_parallelism ), diff --git a/caimira/src/caimira/calculator/report/virus_report_data.py b/caimira/src/caimira/calculator/report/virus_report_data.py index 8b6e3eca..bf4f5733 100644 --- a/caimira/src/caimira/calculator/report/virus_report_data.py +++ b/caimira/src/caimira/calculator/report/virus_report_data.py @@ -5,8 +5,6 @@ import typing import numpy as np import matplotlib.pyplot as plt -import urllib -import zlib from caimira.calculator.models import models, dataclass_utils, profiler, monte_carlo as mc from caimira.calculator.models.enums import ViralLoads @@ -123,7 +121,9 @@ def _calculate_co2_concentration(CO2_model, time, fn_name=None): @profiler.profile -def calculate_report_data(form: VirusFormData, model: models.ExposureModel, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]: +def calculate_report_data(form: VirusFormData, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]: + model: models.ExposureModel = form.build_model() + times = interesting_times(model) short_range_intervals = [interaction.presence.boundaries()[0] for interaction in model.short_range] @@ -191,7 +191,7 @@ def calculate_report_data(form: VirusFormData, model: models.ExposureModel, exec uncertainties_plot(prob, conditional_probability_data))) return { - "model_repr": repr(model), + "model": model, "times": list(times), "exposed_presence_intervals": exposed_presence_intervals, "short_range_intervals": short_range_intervals, @@ -330,26 +330,7 @@ def img2base64(img_data) -> str: return f'data:image/png;base64,{pic_hash}' -def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData): - form_dict = VirusFormData.to_dict(form, strip_defaults=True) - - # Generate the calculator URL arguments that would be needed to re-create this - # form. - args = urllib.parse.urlencode(form_dict) - - # Then zlib compress + base64 encode the string. To be inverted by the - # /_c/ endpoint. - compressed_args = base64.b64encode(zlib.compress(args.encode())).decode() - qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}" - url = f"{base_url}{get_root_calculator_url()}?{args}" - - return { - 'link': url, - 'shortened': qr_url, - } - - -def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]: +def calculate_vl_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]: viral_load = model.concentration_model.infected.virus.viral_load_in_sputum scenarios = {} for percentil in (0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99): @@ -359,7 +340,9 @@ def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typ ) scenarios[str(vl)] = np.mean( specific_vl_scenario.infection_probability()) - return scenarios + return { + 'alternative_viral_load': scenarios, + } def manufacture_alternative_scenarios(form: VirusFormData) -> typing.Dict[str, mc.ExposureModel]: @@ -451,7 +434,6 @@ def comparison_report( form: VirusFormData, report_data: typing.Dict[str, typing.Any], scenarios: typing.Dict[str, mc.ExposureModel], - sample_times: typing.List[float], executor_factory: typing.Callable[[], concurrent.futures.Executor], ): if (form.short_range_option == "short_range_no"): @@ -474,7 +456,7 @@ def comparison_report( results = executor.map( scenario_statistics, scenarios.values(), - [sample_times] * len(scenarios), + [report_data['times']] * len(scenarios), [compute_prob_exposure] * len(scenarios), timeout=60, ) @@ -485,3 +467,10 @@ def comparison_report( return { 'stats': statistics, } + + +def alternative_scenarios_data(form: VirusFormData, report_data: typing.Dict[str, typing.Any], executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]: + alternative_scenarios: typing.Dict[str, typing.Any] = manufacture_alternative_scenarios(form=form) + return { + 'alternative_scenarios': comparison_report(form=form, report_data=report_data, scenarios=alternative_scenarios, executor_factory=executor_factory) + } diff --git a/cern_caimira/src/cern_caimira/apps/calculator/__init__.py b/cern_caimira/src/cern_caimira/apps/calculator/__init__.py index 167bacd8..6eba145e 100644 --- a/cern_caimira/src/cern_caimira/apps/calculator/__init__.py +++ b/cern_caimira/src/cern_caimira/apps/calculator/__init__.py @@ -246,8 +246,7 @@ async def post(self) -> None: max_workers=self.settings['handler_worker_pool_size'], timeout=300, ) - model = virus_report_controller.generate_model(form, data_registry) - report_data_task = executor.submit(calculate_report_data, form, model, + report_data_task = executor.submit(calculate_report_data, form, executor_factory=functools.partial( concurrent.futures.ThreadPoolExecutor, self.settings['report_generation_parallelism'], diff --git a/cern_caimira/src/cern_caimira/apps/calculator/report/virus_report.py b/cern_caimira/src/cern_caimira/apps/calculator/report/virus_report.py index 7b292ef9..e6ed1200 100644 --- a/cern_caimira/src/cern_caimira/apps/calculator/report/virus_report.py +++ b/cern_caimira/src/cern_caimira/apps/calculator/report/virus_report.py @@ -5,13 +5,16 @@ import json import typing import jinja2 +import urllib +import zlib +import base64 import numpy as np from .. import markdown_tools from caimira.calculator.models import models from caimira.calculator.validators.virus.virus_validator import VirusFormData -from caimira.calculator.report.virus_report_data import calculate_report_data, interesting_times, manufacture_alternative_scenarios, manufacture_viral_load_scenarios_percentiles, comparison_report, generate_permalink +from caimira.calculator.report.virus_report_data import alternative_scenarios_data, calculate_report_data, calculate_vl_scenarios_percentiles def minutes_to_time(minutes: int) -> str: @@ -62,6 +65,25 @@ def non_zero_percentage(percentage: int) -> str: return "{:0.1f}%".format(percentage) +def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData): + form_dict = VirusFormData.to_dict(form, strip_defaults=True) + + # Generate the calculator URL arguments that would be needed to re-create this + # form. + args = urllib.parse.urlencode(form_dict) + + # Then zlib compress + base64 encode the string. To be inverted by the + # /_c/ endpoint. + compressed_args = base64.b64encode(zlib.compress(args.encode())).decode() + qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}" + url = f"{base_url}{get_root_calculator_url()}?{args}" + + return { + 'link': url, + 'shortened': qr_url, + } + + @dataclasses.dataclass class VirusReportGenerator: jinja_loader: jinja2.BaseLoader @@ -74,44 +96,53 @@ def build_report( form: VirusFormData, executor_factory: typing.Callable[[], concurrent.futures.Executor], ) -> str: - model = form.build_model() context = self.prepare_context( - base_url, model, form, executor_factory=executor_factory) + base_url, form, executor_factory=executor_factory) return self.render(context) def prepare_context( self, base_url: str, - model: models.ExposureModel, form: VirusFormData, executor_factory: typing.Callable[[], concurrent.futures.Executor], ) -> dict: now = datetime.utcnow().astimezone() time = now.strftime("%Y-%m-%d %H:%M:%S UTC") - data_registry_version = f"v{model.data_registry.version}" if model.data_registry.version else None context = { - 'model': model, 'form': form, 'creation_date': time, - 'data_registry_version': data_registry_version, } - scenario_sample_times = interesting_times(model) - report_data = calculate_report_data( - form, model, executor_factory=executor_factory) + # Main report data + report_data = calculate_report_data(form, executor_factory) context.update(report_data) - alternative_scenarios = manufacture_alternative_scenarios(form) - context['alternative_viral_load'] = manufacture_viral_load_scenarios_percentiles( - model) if form.conditional_probability_viral_loads else None - context['alternative_scenarios'] = comparison_report( - form, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory, - ) - context['permalink'] = generate_permalink( + # Model and Data Registry + model: models.ExposureModel = report_data['model'] + data_registry_version: typing.Optional[str] = f"v{model.data_registry.version}" if model.data_registry.version else None + + # Alternative scenarios data + alternative_scenarios: typing.Dict[str,typing.Any] = alternative_scenarios_data(form, report_data, executor_factory) + context.update(alternative_scenarios) + + # Alternative viral load data + if form.conditional_probability_viral_loads: + alternative_viral_load: typing.Dict[str,typing.Any] = calculate_vl_scenarios_percentiles(model) + context.update(alternative_viral_load) + + # Permalink + permalink: typing.Dict[str, str] = generate_permalink( base_url, self.get_root_url, self.get_root_calculator_url, form) - context['get_url'] = self.get_root_url - context['get_calculator_url'] = self.get_root_calculator_url + + # URLs (root, calculator and permalink) + context.update({ + 'model_repr': repr(model), + 'data_registry_version': data_registry_version, + 'permalink': permalink, + 'get_url': self.get_root_url, + 'get_calculator_url': self.get_root_calculator_url, + }) return context diff --git a/cern_caimira/tests/test_report_generator.py b/cern_caimira/tests/test_report_generator.py index 372cad77..6e257f61 100644 --- a/cern_caimira/tests/test_report_generator.py +++ b/cern_caimira/tests/test_report_generator.py @@ -104,23 +104,21 @@ def test_interesting_times_w_temp(exposure_model_w_outside_temp_changes): np.testing.assert_allclose(result, expected) -def test_expected_new_cases(baseline_form_with_sr: VirusFormData): - model = baseline_form_with_sr.build_model() - +def test_expected_new_cases(baseline_form_with_sr: VirusFormData): executor_factory = partial( concurrent.futures.ThreadPoolExecutor, 1, ) # Short- and Long-range contributions - report_data = rep_gen.calculate_report_data(baseline_form_with_sr, model, executor_factory) + report_data = rep_gen.calculate_report_data(baseline_form_with_sr, executor_factory) sr_lr_expected_new_cases = report_data['expected_new_cases'] sr_lr_prob_inf = report_data['prob_inf']/100 # Long-range contributions alone - scenario_sample_times = rep_gen.interesting_times(model) + scenario_sample_times = report_data['times'] alternative_scenarios = rep_gen.manufacture_alternative_scenarios(baseline_form_with_sr) alternative_statistics = rep_gen.comparison_report( - baseline_form_with_sr, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory, + baseline_form_with_sr, report_data, alternative_scenarios, executor_factory=executor_factory, ) lr_expected_new_cases = alternative_statistics['stats']['Base scenario without short-range interactions']['expected_new_cases']