Skip to content

Commit

Permalink
service data is not a form parameter but passed as an argument to the…
Browse files Browse the repository at this point in the history
… model constructor
  • Loading branch information
lrdossan committed Jul 28, 2023
1 parent 91cba7c commit 41205d4
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 39 deletions.
5 changes: 2 additions & 3 deletions caimira/apps/calculator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ async def post(self) -> None:
self.send_error(500, reason=error_message)
try:
# As an example
fetched_service_data = {'data': {'foo': 'bar', 'viral_load': {'num': 40, 'scale_factor': 7.01, 'shape_factor': 3.47, 'start': 0.01, 'stop': 0.99}}, 'version': 1}
requested_model_config['fetched_service_data'] = json.dumps(fetched_service_data['data'])
form = model_generator.FormData.from_dict(requested_model_config)
# fetched_service_data = {'evaporation_factor': 0.9, 'foo': 'bar', 'viral_load': {'num': 40, 'scale_factor': 7.01, 'shape_factor': 3.47, 'start': 0.01, 'stop': 0.99}}
form = model_generator.FormData.from_dict(requested_model_config, fetched_service_data)

except Exception as err:
if self.settings.get("debug", False):
Expand Down
2 changes: 1 addition & 1 deletion caimira/apps/calculator/data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ async def fetch(self):
),
raise_error=True)

return json.loads(response.body)
return json.loads(response.body)['data']

1 change: 0 additions & 1 deletion caimira/apps/calculator/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
'sensor_in_use': '',
'short_range_option': 'short_range_no',
'short_range_interactions': '[]',
'fetched_service_data': '{}'
}

# ------------------ Activities ----------------------
Expand Down
37 changes: 22 additions & 15 deletions caimira/apps/calculator/model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,19 @@

@dataclasses.dataclass
class ServiceData:
evaporation_factor: float
virus_distributions: dict


_evaporation_factor: float
_virus_distributions: dict

def get_evaporation_factor(self) -> float:
return self._evaporation_factor

def get_virus_distribution(self, virus_type: str) -> float:
return self._virus_distributions[virus_type]

def get_virus_distributions(self) -> dict:
return self._virus_distributions


@dataclasses.dataclass
class FormData:
activity_type: str
Expand Down Expand Up @@ -98,13 +107,12 @@ class FormData:
sensor_in_use: str
short_range_option: str
short_range_interactions: list
fetched_service_data: dict

_SERVICE_DATA: ServiceData
_DEFAULTS: typing.ClassVar[typing.Dict[str, typing.Any]] = DEFAULTS
_SERVICE_DATA: ServiceData = DataGenerator().generate_data_from_parameters()

@classmethod
def from_dict(cls, form_data: typing.Dict) -> "FormData":
def from_dict(cls, form_data: typing.Dict, service_data: typing.Optional[dict] = None) -> "FormData":
# Take a copy of the form data so that we can mutate it.
form_data = form_data.copy()
form_data.pop('_xsrf', None)
Expand All @@ -127,10 +135,10 @@ def from_dict(cls, form_data: typing.Dict) -> "FormData":
if key not in cls._DEFAULTS:
raise ValueError(f'Invalid argument "{html.escape(key)}" given')

# Populate Service Data with data that comes from the form data
cls._SERVICE_DATA = DataGenerator(form_data['fetched_service_data']).generate_data_from_parameters()
instance = cls(**form_data)
# Populate Service Data with data that comes from the database
_sd = DataGenerator(service_data).generate_data_from_parameters()

instance = cls(**form_data, _SERVICE_DATA=_sd)
instance.validate()
return instance

Expand All @@ -142,7 +150,7 @@ def to_dict(cls, form: "FormData", strip_defaults: bool = False) -> dict:
}

for attr, value in form_dict.items():
if attr in _CAST_RULES_NATIVE_TO_FORM_ARG and attr != '_SERVICE_DATA':
if attr in _CAST_RULES_NATIVE_TO_FORM_ARG:
form_dict[attr] = _CAST_RULES_NATIVE_TO_FORM_ARG[attr](value)

if strip_defaults:
Expand Down Expand Up @@ -356,7 +364,7 @@ def build_mc_model(self) -> mc.ExposureModel:
room=room,
ventilation=self.ventilation(),
infected=infected_population,
evaporation_factor=self._SERVICE_DATA.evaporation_factor,
evaporation_factor=self._SERVICE_DATA.get_evaporation_factor(),
),
short_range = tuple(short_range),
exposed=self.exposed_population(),
Expand Down Expand Up @@ -515,7 +523,7 @@ def generate_precise_activity_expiration(self) -> typing.Tuple[typing.Any, ...]:

def infected_population(self) -> mc.InfectedPopulation:
# Initializes the virus
virus = self._SERVICE_DATA.virus_distributions[self.virus_type]
virus = self._SERVICE_DATA.get_virus_distribution(self.virus_type)

activity_index = ACTIVITY_TYPES.index(self.activity_type)
activity_defn = ACTIVITIES[activity_index]['activity']
Expand Down Expand Up @@ -869,7 +877,6 @@ def baseline_raw_form_data() -> typing.Dict[str, typing.Union[str, float]]:
'window_opening_regime': 'windows_open_permanently',
'short_range_option': 'short_range_no',
'short_range_interactions': '[]',
'fetched_service_data': '{}',
}


Expand Down
13 changes: 7 additions & 6 deletions caimira/monte_carlo/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,39 +206,40 @@ def generate_viral_load_distribution(self):

def generate_data_from_parameters(self):
# From https://doi.org/10.1101/2021.10.14.21264988 and refererences therein
_viral_load = self.generate_viral_load_distribution()
virus_distributions = {
'SARS_CoV_2': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=1.,
),
'SARS_CoV_2_ALPHA': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=0.78,
),
'SARS_CoV_2_BETA': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=0.8,
),
'SARS_CoV_2_GAMMA': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=0.72,
),
'SARS_CoV_2_DELTA': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=0.51,
),
'SARS_CoV_2_OMICRON': mc.SARSCoV2(
viral_load_in_sputum=self.generate_viral_load_distribution(),
viral_load_in_sputum=_viral_load,
infectious_dose=infectious_dose_distribution,
viable_to_RNA_ratio=viable_to_RNA_ratio_distribution,
transmissibility_factor=0.2,
Expand Down
2 changes: 1 addition & 1 deletion caimira/tests/apps/calculator/test_model_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from caimira.apps.calculator.model_generator import _hours2timestring
from caimira.apps.calculator.model_generator import minutes_since_midnight
from caimira import models
from caimira.monte_carlo.data import expiration_distributions, DataGenerator
from caimira.monte_carlo.data import expiration_distributions


def test_model_from_dict(baseline_form_data):
Expand Down
11 changes: 4 additions & 7 deletions caimira/tests/test_full_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
from caimira import models,data
from caimira.utils import method_cache
from caimira.models import _VectorisedFloat,Interval,SpecificInterval
from caimira.monte_carlo.sampleable import LogNormal
from caimira.monte_carlo.data import (expiration_distributions,
expiration_BLO_factors,short_range_expiration_distributions,
short_range_distances,activity_distributions, DataGenerator)
short_range_distances,activity_distributions)
from caimira.monte_carlo.data import DataGenerator

SAMPLE_SIZE = 1_000_000
TOLERANCE = 0.04

sqrt2pi = np.sqrt(2.*np.pi)
sqrt2 = np.sqrt(2.)
ln2 = np.log(2)

_sd = DataGenerator().generate_data_from_parameters()
virus_distributions: dict = _sd.get_virus_distributions()

@dataclass(frozen=True)
class SimpleConcentrationModel:
Expand Down Expand Up @@ -484,7 +485,6 @@ def c_model() -> mc.ConcentrationModel:

@pytest.fixture
def c_model_distr() -> mc.ConcentrationModel:
virus_distributions = DataGenerator().generate_data_from_parameters().virus_distributions
return mc.ConcentrationModel(
room=models.Room(volume=50, humidity=0.3),
ventilation=models.AirChange(active=models.PeriodicInterval(
Expand Down Expand Up @@ -614,7 +614,6 @@ def expo_sr_model_distr(c_model_distr) -> mc.ExposureModel:

@pytest.fixture
def simple_expo_sr_model_distr() -> SimpleExposureModel:
virus_distributions = DataGenerator().generate_data_from_parameters().virus_distributions
return SimpleExposureModel(
infected_presence = presence,
viral_load = virus_distributions['SARS_CoV_2_DELTA'
Expand Down Expand Up @@ -717,7 +716,6 @@ def test_longrange_exposure(c_model):
"time", [11., 12.5, 17.]
)
def test_longrange_concentration_with_distributions(c_model_distr,time):
virus_distributions = DataGenerator().generate_data_from_parameters().virus_distributions
simple_expo_model = SimpleConcentrationModel(
infected_presence = presence,
viral_load = virus_distributions['SARS_CoV_2_DELTA'
Expand All @@ -735,7 +733,6 @@ def test_longrange_concentration_with_distributions(c_model_distr,time):


def test_longrange_exposure_with_distributions(c_model_distr):
virus_distributions = DataGenerator().generate_data_from_parameters().virus_distributions
simple_expo_model = SimpleExposureModel(
infected_presence = presence,
viral_load = virus_distributions['SARS_CoV_2_DELTA'
Expand Down
7 changes: 4 additions & 3 deletions caimira/tests/test_monte_carlo_full_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import caimira.monte_carlo as mc
from caimira import models,data
from caimira.monte_carlo.data import activity_distributions, expiration_distributions, infectious_dose_distribution, viable_to_RNA_ratio_distribution, DataGenerator
from caimira.monte_carlo.data import activity_distributions, expiration_distributions, infectious_dose_distribution, viable_to_RNA_ratio_distribution
from caimira.monte_carlo.data import DataGenerator
from caimira.apps.calculator.model_generator import build_expiration

SAMPLE_SIZE = 500_000
Expand Down Expand Up @@ -33,8 +34,8 @@
for month, temperatures in toronto_hourly_temperatures_celsius_per_hour.items()
}

virus_distributions = DataGenerator().generate_data_from_parameters().virus_distributions

_sd = DataGenerator().generate_data_from_parameters()
virus_distributions = _sd.get_virus_distributions()

# References values for infection_probability and expected new cases
# in the following tests, were obtained from the feature/mc branch
Expand Down
6 changes: 4 additions & 2 deletions caimira/tests/test_predefined_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy.testing as npt
import pytest

from caimira.monte_carlo.data import activity_distributions, DataGenerator
from caimira.monte_carlo.data import activity_distributions
from caimira.monte_carlo.data import DataGenerator


# Mean & std deviations from https://doi.org/10.1101/2021.10.14.21264988 (Table 3)
Expand Down Expand Up @@ -39,6 +40,7 @@ def test_activity_distributions(distribution, mean, std):
]
)
def test_viral_load_logdistribution(distribution, mean, std):
virus = DataGenerator().generate_data_from_parameters().virus_distributions[distribution].build_model(size=1000000)
_sd = DataGenerator().generate_data_from_parameters()
virus = _sd.get_virus_distribution(distribution).build_model(size=1000000)
npt.assert_allclose(np.log10(virus.viral_load_in_sputum).mean(), mean, atol=0.01)
npt.assert_allclose(np.log10(virus.viral_load_in_sputum).std(), std, atol=0.01)

0 comments on commit 41205d4

Please sign in to comment.