-
Notifications
You must be signed in to change notification settings - Fork 0
/
2024-03-05-json-prussin-rosario-human-viruses.py
250 lines (203 loc) · 7.78 KB
/
2024-03-05-json-prussin-rosario-human-viruses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#!/usr/bin/env python3
import gzip
import json
import os
import subprocess
import typing
from pathlib import Path
import matplotlib.pyplot as plt # type: ignore
import matplotlib.ticker as ticker # type: ignore
import numpy as np
import pandas as pd
import seaborn as sns # type: ignore
from matplotlib.gridspec import GridSpec # type: ignore
from collections import defaultdict
from scipy import stats
dashboard = os.path.expanduser("~/code/mgs-pipeline/dashboard/")


def _load_dashboard_json(filename):
    """Read and decode one JSON file stored in the dashboard directory."""
    with open(os.path.join(dashboard, filename)) as json_file:
        return json.load(json_file)


# Dashboard exports read by the plotting code in this script.
human_virus_sample_counts = _load_dashboard_json("human_virus_sample_counts.json")
metadata_samples = _load_dashboard_json("metadata_samples.json")
metadata_bioprojects = _load_dashboard_json("metadata_bioprojects.json")
metadata_papers = _load_dashboard_json("metadata_papers.json")
taxonomic_names = _load_dashboard_json("taxonomic_names.json")
# Study names, in the order they appear in metadata_papers.json.
studies = list(metadata_papers)
def load_taxonomic_data(
    nodes_path: typing.Optional[str] = None,
) -> dict[int, tuple[str, int]]:
    """Parse a nodes.dmp taxonomy file into a child -> (rank, parent) map.

    Each line has its trailing tab-pipe-newline terminator removed and is
    then split on the tab-pipe-tab field separator; only the first three
    fields (child taxid, parent taxid, rank) are used.

    Args:
        nodes_path: Path of the nodes.dmp file to read. Defaults to
            "nodes.dmp" inside the dashboard directory.

    Returns:
        Dict mapping each child taxid to a (rank, parent taxid) tuple.
    """
    if nodes_path is None:
        nodes_path = os.path.join(dashboard, "nodes.dmp")
    parents = {}
    with open(nodes_path, encoding="utf-8") as inf:
        for line in inf:
            child_taxid, parent_taxid, child_rank, *_ = line.replace(
                "\t|\n", ""
            ).split("\t|\t")
            parents[int(child_taxid)] = (child_rank.strip(), int(parent_taxid))
    return parents
def get_family(
    taxid: int, parents: dict[int, tuple[str, int]]
) -> typing.Optional[int]:
    """Return the taxid of the nearest family-rank node at or above *taxid*.

    Walks up the child -> (rank, parent) map built by load_taxonomic_data().
    If *taxid* is itself a family, it is returned unchanged.

    Args:
        taxid: Taxid to start the walk from.
        parents: Mapping child taxid -> (rank, parent taxid).

    Returns:
        The family-rank taxid, or None (after printing a message) when a
        taxid on the path is missing from *parents*, the root is reached
        without meeting a family, or no family is found within 16 steps
        up from *taxid*.
    """
    max_steps = 16
    current_taxid = taxid
    # Examine the starting node plus at most max_steps ancestors.
    for _ in range(max_steps + 1):
        try:
            current_rank, parent_taxid = parents[current_taxid]
        except KeyError:
            # Nothing further can be learned once the chain is broken.
            print(f"Taxid {current_taxid} not found in parents")
            return None
        if current_rank == "family":
            return current_taxid
        if parent_taxid == current_taxid:
            # A self-parented node is the root (NCBI stores taxid 1 this
            # way); climbing further would just revisit it.
            print(f"No family found above taxid {taxid}")
            return None
        current_taxid = parent_taxid
    print(f"Reached max tries for taxid {taxid}")
    return None
def test_get_family():
    """Spot-check get_family() on two known taxid -> family pairs from nodes.dmp."""
    expected_families = [(11676, 11632), (694009, 11118)]
    parents = load_taxonomic_data()
    for taxid, family_taxid in expected_families:
        assert get_family(taxid, parents) == family_taxid
    print("get_family tests passed!")
def get_taxid_name(
    target_taxid: int, taxonomic_names: dict[str, list[str]]
) -> str:
    """Return the first name recorded for *target_taxid*.

    *taxonomic_names* is keyed by the taxid's string form; a taxid that is
    not present raises KeyError.
    """
    names_for_taxid = taxonomic_names[str(target_taxid)]
    return names_for_taxid[0]
def assemble_plotting_dfs() -> pd.DataFrame:
    """Build the per-family relative-abundance table drawn by barplot().

    Collects human-virus read counts for every sample of the "Prussin 2019"
    and "Rosario 2018" studies. Prussin samples are kept only if they are
    "hvac_filter" samples from the Winter/Spring/Summer/Fall/Closed seasons
    whose sampling_range mentions neither "Control" nor "Unexposed".

    Species-level taxids are rolled up to their family with get_family().
    Each family's share of human-virus reads is then computed within every
    (study, na_type) group. The N_TOP_TAXA families with the largest summed
    relative abundance keep their names; all other families are merged into
    one "minor_taxa" row per group.

    Returns:
        DataFrame with columns "study", "na_type", "taxid", "reads",
        "relative_abundance" and "hv_family" ("taxid" is empty on the
        minor-taxa rows).
    """
    bar_plot_data = []
    for study in studies:
        if study not in ["Prussin 2019", "Rosario 2018"]:
            continue
        for bioproject in metadata_papers[study]["projects"]:
            samples = metadata_bioprojects[bioproject]
            for sample in samples:
                na_type = metadata_samples[sample]["na_type"]
                if study == "Prussin 2019":
                    sample_type = metadata_samples[sample]["sample_type"]
                    season = metadata_samples[sample]["season"]
                    sampling_range = metadata_samples[sample]["sampling_range"]
                    if sample_type != "hvac_filter":
                        continue
                    if "Control" in sampling_range or "Unexposed" in sampling_range:
                        print(f"Excluding {sample} due to it being a control")
                        continue
                    if season not in [
                        "Winter",
                        "Spring",
                        "Summer",
                        "Fall",
                        "Closed",
                    ]:
                        continue
                # Reads assigned to each human-virus taxid in this sample.
                human_virus_counts = {
                    taxid: sample_counts.get(sample, 0)
                    for taxid, sample_counts in human_virus_sample_counts.items()
                }
                bar_plot_data.append(
                    {
                        "study": study,
                        "sample": sample,
                        "na_type": na_type,
                        "hv_reads": sum(human_virus_counts.values()),
                        **human_virus_counts,
                    }
                )

    # Long format: one row per (sample, species-level taxid).
    df = pd.DataFrame(bar_plot_data).melt(
        id_vars=["study", "sample", "na_type", "hv_reads"],
        var_name="taxid",
        value_name="reads",
    )

    # Roll each taxid up to its family. Taxids whose family cannot be
    # resolved (get_family returned None) are dropped.
    parents = load_taxonomic_data()
    species_to_family = {
        taxid: get_family(int(taxid), parents) for taxid in df["taxid"].unique()
    }
    df = (
        df.assign(taxid=df["taxid"].map(species_to_family))
        .dropna(subset=["taxid"])
        .astype({"taxid": int})
    )

    # Total reads per family within each (study, na_type) group, then each
    # family's fraction of that group's human-virus reads.
    df = df.groupby(["study", "na_type", "taxid"])["reads"].sum().reset_index()
    df = df[df["reads"] != 0].copy()
    df["relative_abundance"] = df.groupby(["study", "na_type"])["reads"].transform(
        lambda x: x / x.sum()
    )

    N_TOP_TAXA = 9
    top_taxa = (
        df.groupby("taxid")["relative_abundance"].sum().nlargest(N_TOP_TAXA).index
    )
    is_top = df["taxid"].isin(top_taxa)

    # .copy() so adding hv_family does not write into a slice of df.
    top_taxa_rows = df[is_top].copy()
    top_taxa_rows["hv_family"] = top_taxa_rows["taxid"].apply(
        lambda x: get_taxid_name(x, taxonomic_names)
    )

    # Merge all remaining families into one "minor_taxa" row per group.
    # "reads" is summed too, so per-group read totals computed from the
    # returned frame (as barplot() does) include these reads.
    minor_taxa_rows = (
        df[~is_top]
        .groupby(["study", "na_type"])
        .agg({"relative_abundance": "sum", "reads": "sum"})
        .reset_index()
    )
    minor_taxa_rows["hv_family"] = "minor_taxa"

    return pd.concat([top_taxa_rows, minor_taxa_rows])
def barplot(df):
    """Draw a stacked horizontal bar chart of human-virus family abundance.

    Draws one bar per (study, na_type) pair, with one segment per hv_family
    value. Each bar is annotated on its left with the total read count for
    that pair. The figure is saved to "barplot_json.png".

    Args:
        df: DataFrame with columns "study", "na_type", "hv_family",
            "relative_abundance" and "reads" (as built by
            assemble_plotting_dfs()).

    Returns:
        The matplotlib Axes holding the plot.
    """
    ten_color_palette = [
        "#8dd3c7",
        "#f1c232",
        "#bebada",
        "#fb8072",
        "#80b1d3",
        "#fdb462",
        "#b3de69",
        "#fccde5",
        "#bc80bd",
        "#d9d9d9",
    ]
    df_pivot = df.pivot_table(
        index=["study", "na_type"],
        columns="hv_family",
        values="relative_abundance",
    )
    # Total reads per (study, na_type), explicitly put in bar order.
    reads_per_bar = (
        df.groupby(["study", "na_type"])["reads"].sum().reindex(df_pivot.index)
    )

    fig, ax = plt.subplots(figsize=(10, 5))
    df_pivot.plot(kind="barh", stacked=True, color=ten_color_palette, ax=ax)
    ax.invert_yaxis()  # first (study, na_type) pair at the top
    # Label bars "study, na_type" rather than pandas' tuple-style
    # "(study, na_type)" label, i.e. drop the brackets.
    ax.set_yticklabels([f"{study}, {na_type}" for study, na_type in df_pivot.index])
    ax.set_xlabel("Relative abundance among human-infecting virus families")
    ax.set_ylabel("")
    ax.tick_params(left=False, labelright=True, labelleft=False)
    ax.set_xlim(right=1, left=0)
    ax.legend(
        loc=(0.035, -0.72),
        ncol=4,
        fontsize=11.1,
        frameon=False,
    )
    # Read-count labels just left of the axes. x is in axes coordinates and
    # y in data coordinates, so label i sits on bar i (pandas places bar i at
    # y == i) regardless of how many bars there are.
    label_transform = ax.get_yaxis_transform()
    for bar_index, reads in enumerate(reads_per_bar):
        ax.text(
            -0.02,
            bar_index,
            f"{int(reads)} reads",
            transform=label_transform,
            fontsize=11.1,
            ha="right",
            va="center",
        )
    sns.despine(top=True, right=True, left=True, bottom=False)
    fig.tight_layout()
    fig.savefig("barplot_json.png", bbox_inches="tight", dpi=300)
    return ax
def save_plot(fig, figdir: Path, name: str) -> None:
    """Save *fig* into *figdir* as both <name>.pdf and <name>.png (900 dpi)."""
    for out_path in (figdir / f"{name}.pdf", figdir / f"{name}.png"):
        fig.savefig(out_path, bbox_inches="tight", dpi=900)
def start():
    """Script entry point: build the plotting table, spot-check get_family, plot."""
    plotting_df = assemble_plotting_dfs()
    # Sanity check of get_family() against the dashboard's nodes.dmp.
    test_get_family()
    barplot(plotting_df)


if __name__ == "__main__":
    start()