-
Notifications
You must be signed in to change notification settings - Fork 0
/
fit_panel.py
executable file
·103 lines (90 loc) · 3.03 KB
/
fit_panel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3
from pathlib import Path
import pandas as pd
import os
import stats
from mgs import Enrichment, MGSData, target_bioprojects
from pathogens import predictors_by_taxid
# Directory (relative to the current working directory) under which this
# script writes its TSV outputs and the "model_panel_fig" figure subdirectory.
MODEL_OUTPUT_DIR = "model_output"
def summarize_output(coeffs: pd.DataFrame) -> pd.DataFrame:
    """Return descriptive statistics of ``ra_at_1in100`` per fitted group.

    Args:
        coeffs: Coefficient table that must contain the columns
            ``pathogen``, ``tidy_name``, ``taxids``, ``predictor_type``,
            ``study``, ``location`` and ``ra_at_1in100``.

    Returns:
        A DataFrame indexed by the six grouping columns whose columns are the
        ``describe()`` statistics of ``ra_at_1in100`` (count, mean, std, min,
        the 5/25/50/75/95 percentiles, and max).
    """
    group_columns = [
        "pathogen",
        "tidy_name",
        "taxids",
        "predictor_type",
        "study",
        "location",
    ]
    grouped = coeffs.groupby(group_columns)
    # Select the column by key rather than attribute access so the lookup
    # cannot be confused with a GroupBy attribute or method.
    return grouped["ra_at_1in100"].describe(
        percentiles=[0.05, 0.25, 0.5, 0.75, 0.95]
    )
def start(num_samples: int, plot: bool) -> None:
    """Fit a panel-enrichment model per predictor and study, then save results.

    For every tuple yielded by ``predictors_by_taxid()`` and every entry of
    ``target_bioprojects`` (except the skipped studies), a model is built with
    ``stats.build_model`` and fitted. The model inputs, the fitted
    coefficients and a per-group summary are written as TSV files under
    ``MODEL_OUTPUT_DIR``, and the R-hat value of each fit is printed.

    Args:
        num_samples: Forwarded to ``model.fit_model(num_samples=...)``.
        plot: When true, ``model.plot_figures`` is called for every fitted
            model, writing into ``MODEL_OUTPUT_DIR/model_panel_fig``.

    Raises:
        ValueError: From ``pd.concat`` if no model was built at all (every
            ``stats.build_model`` call returned ``None`` or every study was
            skipped), because there is nothing to concatenate.
    """
    figdir = os.path.join(MODEL_OUTPUT_DIR, "model_panel_fig")
    # Create the output directory up front. Previously it only came into
    # existence as a side effect of plot=True, so a plot=False run on a fresh
    # checkout failed at the first to_csv() call, after all fitting was done.
    os.makedirs(MODEL_OUTPUT_DIR, exist_ok=True)
    if plot:
        os.makedirs(figdir, exist_ok=True)

    mgs_data = MGSData.from_repo()
    input_data = []
    output_data = []
    study_pathogen_rhats = {}
    for (
        pathogen_name,
        tidy_name,
        predictor_type,
        taxids,
        predictors,
    ) in predictors_by_taxid():
        # Underscore-joined form is stored in the output tables.
        taxids_str = "_".join(str(t) for t in taxids)
        for study, bioprojects in target_bioprojects.items():
            # NOTE(review): the reason these two studies are excluded is not
            # stated in this file -- presumably they have no panel-enriched
            # samples; confirm.
            if study in ["brinch", "spurbeck"]:
                print(f"Skipping {study} for {pathogen_name}")
                continue
            model = stats.build_model(
                mgs_data,
                bioprojects,
                predictors,
                taxids,
                # Seed derived from the taxids so reruns use the same seed
                # for the same pathogen.
                random_seed=sum(taxids),
                enrichment=Enrichment.PANEL,
            )
            if model is None:
                continue
            model.fit_model(num_samples=num_samples)
            # NOTE(review): the key omits predictor_type, so if the same
            # (study, tidy_name) pair is fitted more than once only the last
            # R-hat is reported.
            study_pathogen_rhats[f"{study}, {tidy_name}"] = model.get_rhat()
            if plot:
                # Figure file names use a dash-joined taxid list, unlike the
                # underscore-joined value stored in the tables.
                taxid_str = "-".join(str(tid) for tid in taxids)
                model.plot_figures(
                    path=figdir,
                    prefix=f"{pathogen_name}-{taxid_str}-{predictor_type}-{study}",
                )
            # Columns identifying this fit, attached to both tables.
            metadata = dict(
                pathogen=pathogen_name,
                tidy_name=tidy_name,
                taxids=taxids_str,
                predictor_type=predictor_type,
                study=study,
            )
            input_data.append(model.input_df.assign(**metadata))
            output_data.append(model.get_coefficients().assign(**metadata))

    # Named input_df rather than input to avoid shadowing the builtin.
    input_df = pd.concat(input_data)
    input_df.to_csv(
        os.path.join(MODEL_OUTPUT_DIR, "panel_input.tsv"),
        sep="\t",
        index=False,
    )
    coeffs = pd.concat(output_data)
    coeffs.to_csv(
        os.path.join(MODEL_OUTPUT_DIR, "panel_fits.tsv"), sep="\t", index=False
    )
    # The summary keeps its index: the grouping columns live there.
    summary = summarize_output(coeffs)
    summary.to_csv(
        os.path.join(MODEL_OUTPUT_DIR, "panel_fits_summary.tsv"), sep="\t"
    )
    print(
        "Model fitting of panel-amplified samples complete\nR-hat statistics:"
    )
    for pathogen_and_study, rhat in study_pathogen_rhats.items():
        print(f"{pathogen_and_study}: rhat={rhat}")
if __name__ == "__main__":
    # Imported here because argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description="Fit models to panel-enriched samples and write the results."
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=8000,
        help="value passed to start() as num_samples (default: %(default)s)",
    )
    parser.add_argument(
        "--no-plot",
        dest="plot",
        action="store_false",
        help="skip writing per-model figures",
    )
    args = parser.parse_args()
    # With no arguments this is start(num_samples=8000, plot=True), i.e. the
    # previous hard-coded behaviour.
    start(num_samples=args.num_samples, plot=args.plot)