Skip to content

Commit

Permalink
polished cern caimira frontend model related methods
Browse files Browse the repository at this point in the history
  • Loading branch information
lrdossan committed Sep 4, 2024
1 parent ebde0e1 commit 36691d9
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@ def generate_model(form_obj, data_registry):
return form_obj.build_model(sample_size=sample_size)


def generate_report_results(form_obj, model):
def generate_report_results(form_obj):
return rg.calculate_report_data(
form=form_obj,
model=model,
executor_factory=functools.partial(
concurrent.futures.ThreadPoolExecutor, None, # TODO define report_parallelism
),
Expand Down
43 changes: 16 additions & 27 deletions caimira/src/caimira/calculator/report/virus_report_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import typing
import numpy as np
import matplotlib.pyplot as plt
import urllib
import zlib

from caimira.calculator.models import models, dataclass_utils, profiler, monte_carlo as mc
from caimira.calculator.models.enums import ViralLoads
Expand Down Expand Up @@ -123,7 +121,9 @@ def _calculate_co2_concentration(CO2_model, time, fn_name=None):


@profiler.profile
def calculate_report_data(form: VirusFormData, model: models.ExposureModel, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]:
def calculate_report_data(form: VirusFormData, executor_factory: typing.Callable[[], concurrent.futures.Executor]) -> typing.Dict[str, typing.Any]:
model: models.ExposureModel = form.build_model()

times = interesting_times(model)
short_range_intervals = [interaction.presence.boundaries()[0]
for interaction in model.short_range]
Expand Down Expand Up @@ -191,7 +191,7 @@ def calculate_report_data(form: VirusFormData, model: models.ExposureModel, exec
uncertainties_plot(prob, conditional_probability_data)))

return {
"model_repr": repr(model),
"model": model,
"times": list(times),
"exposed_presence_intervals": exposed_presence_intervals,
"short_range_intervals": short_range_intervals,
Expand Down Expand Up @@ -330,26 +330,7 @@ def img2base64(img_data) -> str:
return f'data:image/png;base64,{pic_hash}'


def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData):
    """Build the full and the shortened permalink that re-create ``form``.

    Args:
        base_url: Prefix prepended to both generated URLs.
        get_root_url: Zero-argument callable whose result is inserted after
            ``base_url`` in the shortened URL.
        get_root_calculator_url: Zero-argument callable whose result is
            inserted after ``base_url`` in the full URL.
        form: Form whose non-default values are encoded in the links.

    Returns:
        A dict with the full URL under ``'link'`` and the compressed
        ``/_c/`` URL under ``'shortened'``.
    """
    # A bare `import urllib` does not load the `parse` submodule; import it
    # explicitly so this function does not depend on some other module
    # having imported `urllib.parse` first.
    import urllib.parse

    form_dict = VirusFormData.to_dict(form, strip_defaults=True)

    # Generate the calculator URL arguments that would be needed to re-create this
    # form.
    args = urllib.parse.urlencode(form_dict)

    # Then zlib compress + base64 encode the string. To be inverted by the
    # /_c/ endpoint.
    compressed_args = base64.b64encode(zlib.compress(args.encode())).decode()
    qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}"
    url = f"{base_url}{get_root_calculator_url()}?{args}"

    return {
        'link': url,
        'shortened': qr_url,
    }


def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]:
def calculate_vl_scenarios_percentiles(model: mc.ExposureModel) -> typing.Dict[str, mc.ExposureModel]:
viral_load = model.concentration_model.infected.virus.viral_load_in_sputum
scenarios = {}
for percentil in (0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99):
Expand All @@ -359,7 +340,9 @@ def manufacture_viral_load_scenarios_percentiles(model: mc.ExposureModel) -> typ
)
scenarios[str(vl)] = np.mean(
specific_vl_scenario.infection_probability())
return scenarios
return {
'alternative_viral_load': scenarios,
}


def manufacture_alternative_scenarios(form: VirusFormData) -> typing.Dict[str, mc.ExposureModel]:
Expand Down Expand Up @@ -451,7 +434,6 @@ def comparison_report(
form: VirusFormData,
report_data: typing.Dict[str, typing.Any],
scenarios: typing.Dict[str, mc.ExposureModel],
sample_times: typing.List[float],
executor_factory: typing.Callable[[], concurrent.futures.Executor],
):
if (form.short_range_option == "short_range_no"):
Expand All @@ -474,7 +456,7 @@ def comparison_report(
results = executor.map(
scenario_statistics,
scenarios.values(),
[sample_times] * len(scenarios),
[report_data['times']] * len(scenarios),
[compute_prob_exposure] * len(scenarios),
timeout=60,
)
Expand All @@ -485,3 +467,10 @@ def comparison_report(
return {
'stats': statistics,
}


def alternative_scenarios_data(
    form: VirusFormData,
    report_data: typing.Dict[str, typing.Any],
    executor_factory: typing.Callable[[], concurrent.futures.Executor],
) -> typing.Dict[str, typing.Any]:
    """Return the comparison report of the form's alternative scenarios.

    The scenarios are produced by ``manufacture_alternative_scenarios`` and
    then passed, together with ``report_data`` and ``executor_factory``, to
    ``comparison_report``.

    Returns:
        A single-entry dict holding the comparison report under the
        ``'alternative_scenarios'`` key.
    """
    scenarios: typing.Dict[str, typing.Any] = manufacture_alternative_scenarios(form=form)
    comparison = comparison_report(
        form=form,
        report_data=report_data,
        scenarios=scenarios,
        executor_factory=executor_factory,
    )
    return {'alternative_scenarios': comparison}
3 changes: 1 addition & 2 deletions cern_caimira/src/cern_caimira/apps/calculator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ async def post(self) -> None:
max_workers=self.settings['handler_worker_pool_size'],
timeout=300,
)
model = virus_report_controller.generate_model(form, data_registry)
report_data_task = executor.submit(calculate_report_data, form, model,
report_data_task = executor.submit(calculate_report_data, form,
executor_factory=functools.partial(
concurrent.futures.ThreadPoolExecutor,
self.settings['report_generation_parallelism'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import json
import typing
import jinja2
import urllib
import zlib
import base64
import numpy as np

from .. import markdown_tools

from caimira.calculator.models import models
from caimira.calculator.validators.virus.virus_validator import VirusFormData
from caimira.calculator.report.virus_report_data import calculate_report_data, interesting_times, manufacture_alternative_scenarios, manufacture_viral_load_scenarios_percentiles, comparison_report, generate_permalink
from caimira.calculator.report.virus_report_data import alternative_scenarios_data, calculate_report_data, calculate_vl_scenarios_percentiles


def minutes_to_time(minutes: int) -> str:
Expand Down Expand Up @@ -62,6 +65,25 @@ def non_zero_percentage(percentage: int) -> str:
return "{:0.1f}%".format(percentage)


def generate_permalink(base_url, get_root_url, get_root_calculator_url, form: VirusFormData):
    """Build the full and the shortened permalink that re-create ``form``.

    Args:
        base_url: Prefix prepended to both generated URLs.
        get_root_url: Zero-argument callable whose result is inserted after
            ``base_url`` in the shortened URL.
        get_root_calculator_url: Zero-argument callable whose result is
            inserted after ``base_url`` in the full URL.
        form: Form whose non-default values are encoded in the links.

    Returns:
        A dict with the full URL under ``'link'`` and the compressed
        ``/_c/`` URL under ``'shortened'``.
    """
    # A bare `import urllib` does not load the `parse` submodule; import it
    # explicitly so this function does not depend on some other module
    # having imported `urllib.parse` first.
    import urllib.parse

    form_dict = VirusFormData.to_dict(form, strip_defaults=True)

    # Generate the calculator URL arguments that would be needed to re-create this
    # form.
    args = urllib.parse.urlencode(form_dict)

    # Then zlib compress + base64 encode the string. To be inverted by the
    # /_c/ endpoint.
    compressed_args = base64.b64encode(zlib.compress(args.encode())).decode()
    qr_url = f"{base_url}{get_root_url()}/_c/{compressed_args}"
    url = f"{base_url}{get_root_calculator_url()}?{args}"

    return {
        'link': url,
        'shortened': qr_url,
    }


@dataclasses.dataclass
class VirusReportGenerator:
jinja_loader: jinja2.BaseLoader
Expand All @@ -74,44 +96,53 @@ def build_report(
form: VirusFormData,
executor_factory: typing.Callable[[], concurrent.futures.Executor],
) -> str:
model = form.build_model()
context = self.prepare_context(
base_url, model, form, executor_factory=executor_factory)
base_url, form, executor_factory=executor_factory)
return self.render(context)

def prepare_context(
self,
base_url: str,
model: models.ExposureModel,
form: VirusFormData,
executor_factory: typing.Callable[[], concurrent.futures.Executor],
) -> dict:
now = datetime.utcnow().astimezone()
time = now.strftime("%Y-%m-%d %H:%M:%S UTC")

data_registry_version = f"v{model.data_registry.version}" if model.data_registry.version else None
context = {
'model': model,
'form': form,
'creation_date': time,
'data_registry_version': data_registry_version,
}

scenario_sample_times = interesting_times(model)
report_data = calculate_report_data(
form, model, executor_factory=executor_factory)
# Main report data
report_data = calculate_report_data(form, executor_factory)
context.update(report_data)

alternative_scenarios = manufacture_alternative_scenarios(form)
context['alternative_viral_load'] = manufacture_viral_load_scenarios_percentiles(
model) if form.conditional_probability_viral_loads else None
context['alternative_scenarios'] = comparison_report(
form, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory,
)
context['permalink'] = generate_permalink(
# Model and Data Registry
model: models.ExposureModel = report_data['model']
data_registry_version: typing.Optional[str] = f"v{model.data_registry.version}" if model.data_registry.version else None

# Alternative scenarios data
alternative_scenarios: typing.Dict[str,typing.Any] = alternative_scenarios_data(form, report_data, executor_factory)
context.update(alternative_scenarios)

# Alternative viral load data
if form.conditional_probability_viral_loads:
alternative_viral_load: typing.Dict[str,typing.Any] = calculate_vl_scenarios_percentiles(model)
context.update(alternative_viral_load)

# Permalink
permalink: typing.Dict[str, str] = generate_permalink(
base_url, self.get_root_url, self.get_root_calculator_url, form)
context['get_url'] = self.get_root_url
context['get_calculator_url'] = self.get_root_calculator_url

# URLs (root, calculator and permalink)
context.update({
'model_repr': repr(model),
'data_registry_version': data_registry_version,
'permalink': permalink,
'get_url': self.get_root_url,
'get_calculator_url': self.get_root_calculator_url,
})

return context

Expand Down
10 changes: 4 additions & 6 deletions cern_caimira/tests/test_report_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,21 @@ def test_interesting_times_w_temp(exposure_model_w_outside_temp_changes):
np.testing.assert_allclose(result, expected)


def test_expected_new_cases(baseline_form_with_sr: VirusFormData):
model = baseline_form_with_sr.build_model()

def test_expected_new_cases(baseline_form_with_sr: VirusFormData):
executor_factory = partial(
concurrent.futures.ThreadPoolExecutor, 1,
)

# Short- and Long-range contributions
report_data = rep_gen.calculate_report_data(baseline_form_with_sr, model, executor_factory)
report_data = rep_gen.calculate_report_data(baseline_form_with_sr, executor_factory)
sr_lr_expected_new_cases = report_data['expected_new_cases']
sr_lr_prob_inf = report_data['prob_inf']/100

# Long-range contributions alone
scenario_sample_times = rep_gen.interesting_times(model)
scenario_sample_times = report_data['times']
alternative_scenarios = rep_gen.manufacture_alternative_scenarios(baseline_form_with_sr)
alternative_statistics = rep_gen.comparison_report(
baseline_form_with_sr, report_data, alternative_scenarios, scenario_sample_times, executor_factory=executor_factory,
baseline_form_with_sr, report_data, alternative_scenarios, executor_factory=executor_factory,
)

lr_expected_new_cases = alternative_statistics['stats']['Base scenario without short-range interactions']['expected_new_cases']
Expand Down

0 comments on commit 36691d9

Please sign in to comment.