forked from CogStack/MedCATtrainer
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #200 from CogStack/meta-cat-model-preds
CU-2e77aae MetaCAT model Predictions
- Loading branch information
Showing
18 changed files
with
607 additions
and
186 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
98 changes: 98 additions & 0 deletions
98
webapp/api/api/migrations/0082_remove_metacatmodel_meta_task_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Generated by Django 5.0.6 on 2024-08-28 10:56 | ||
|
||
import django.db.models.deletion | ||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration):
    """Schema migration that follows ``0081_alter_metatask_name``.

    Removes ``MetaCATModel.meta_task``, adds the prediction / model pack
    related fields on the meta annotation, meta task, project and project
    group models, and alters a number of existing field definitions.
    """

    dependencies = [
        ('api', '0081_alter_metatask_name'),
    ]

    operations = [
        # --- Field removed -------------------------------------------------
        migrations.RemoveField(
            model_name='metacatmodel',
            name='meta_task',
        ),

        # --- Fields added --------------------------------------------------
        migrations.AddField(
            model_name='metaannotation',
            name='predicted_meta_task_value',
            field=models.ForeignKey(
                blank=True,
                help_text='meta annotation predicted by a MetaAnnotationModel',
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name='predicted_value',
                to='api.metataskvalue',
            ),
        ),
        migrations.AddField(
            model_name='metatask',
            name='prediction_model',
            field=models.ForeignKey(
                blank=True,
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.metacatmodel',
            ),
        ),
        migrations.AddField(
            model_name='project',
            name='meta_cat_predictions',
            field=models.BooleanField(
                default=False,
                help_text='If MetaTasks are setup on the project and there are associated MetaCATModel instances, display these predictions in the interface to be validated / corrected',
            ),
        ),
        migrations.AddField(
            model_name='projectannotateentities',
            name='model_pack',
            field=models.ForeignKey(
                blank=True,
                default=None,
                help_text='A MedCAT model pack. This will raise an exception if both the CDB and Vocab and ModelPack fields are set',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.modelpack',
            ),
        ),
        migrations.AddField(
            model_name='projectgroup',
            name='meta_cat_predictions',
            field=models.BooleanField(
                default=False,
                help_text='If MetaTasks are setup on the project and there are associated MetaCATModel instances, display these predictions in the interface to be validated / corrected',
            ),
        ),
        migrations.AddField(
            model_name='projectgroup',
            name='model_pack',
            field=models.ForeignKey(
                blank=True,
                default=None,
                help_text='A MedCAT model pack. This will raise an exception if both the CDB and Vocab and ModelPack fields are set',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.modelpack',
            ),
        ),

        # --- Existing fields altered ---------------------------------------
        migrations.AlterField(
            model_name='metaannotation',
            name='validated',
            field=models.BooleanField(
                default=False,
                help_text='If an annotation is not ',
            ),
        ),
        migrations.AlterField(
            model_name='metacatmodel',
            name='meta_cat_dir',
            field=models.FilePathField(
                allow_folders=True,
                editable=False,
                help_text='The zip or dir for a MetaCAT model, not editable, is set via a model pack .zip upload',
            ),
        ),
        migrations.AlterField(
            model_name='metacatmodel',
            name='name',
            field=models.CharField(
                help_text='The task name followed by the underlying model impl',
                max_length=100,
            ),
        ),
        migrations.AlterField(
            model_name='projectannotateentities',
            name='concept_db',
            field=models.ForeignKey(
                blank=True,
                help_text='The MedCAT CDB used to annotate / validate',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.conceptdb',
            ),
        ),
        migrations.AlterField(
            model_name='projectannotateentities',
            name='tasks',
            field=models.ManyToManyField(
                blank=True,
                default=None,
                help_text='The set of MetaAnnotation tasks configured for this project, this will default to the set of Tasks configured in a ModelPack if a model pack is used for the project',
                to='api.metatask',
            ),
        ),
        migrations.AlterField(
            model_name='projectannotateentities',
            name='vocab',
            field=models.ForeignKey(
                blank=True,
                help_text='The MedCAT Vocab used to annotate / validate',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.vocabulary',
            ),
        ),
        migrations.AlterField(
            model_name='projectgroup',
            name='cdb_search_filter',
            field=models.ManyToManyField(
                blank=True,
                help_text='The CDB that will be used for concept lookup. This specific CDB should have been "imported" via the CDB admin screen',
                related_name='project_group_concept_source',
                to='api.conceptdb',
            ),
        ),
        migrations.AlterField(
            model_name='projectgroup',
            name='concept_db',
            field=models.ForeignKey(
                blank=True,
                help_text='The MedCAT CDB used to annotate / validate',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.conceptdb',
            ),
        ),
        migrations.AlterField(
            model_name='projectgroup',
            name='tasks',
            field=models.ManyToManyField(
                blank=True,
                default=None,
                help_text='The set of MetaAnnotation tasks configured for this project, this will default to the set of Tasks configured in a ModelPack if a model pack is used for the project',
                to='api.metatask',
            ),
        ),
        migrations.AlterField(
            model_name='projectgroup',
            name='vocab',
            field=models.ForeignKey(
                blank=True,
                help_text='The MedCAT Vocab used to annotate / validate',
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to='api.vocabulary',
            ),
        ),
    ]
23 changes: 23 additions & 0 deletions
23
webapp/api/api/migrations/0083_project_prepared_documents_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Generated by Django 5.0.6 on 2024-08-29 11:25 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration):
    """Schema migration that follows ``0082_remove_metacatmodel_meta_task_and_more``.

    Adds ``Project.prepared_documents`` and alters the definition of
    ``Project.validated_documents``.
    """

    dependencies = [
        ('api', '0082_remove_metacatmodel_meta_task_and_more'),
    ]

    operations = [
        # New many-to-many relation to the document model.
        migrations.AddField(
            model_name='project',
            name='prepared_documents',
            field=models.ManyToManyField(
                blank=True,
                default=None,
                help_text='Set automatically on each prep of a document',
                related_name='prepared_documents',
                to='api.document',
            ),
        ),
        # Existing relation, redefined with the options below.
        migrations.AlterField(
            model_name='project',
            name='validated_documents',
            field=models.ManyToManyField(
                blank=True,
                default=None,
                help_text='Set automatically on each doc submission',
                to='api.document',
            ),
        ),
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import logging | ||
import os | ||
from typing import Dict | ||
|
||
import pkg_resources | ||
from medcat.cat import CAT | ||
from medcat.cdb import CDB | ||
from medcat.vocab import Vocab | ||
|
||
from api.models import ConceptDB | ||
|
||
logger = logging.getLogger(__name__)

# Module level caches for CDBs, Vocabs and CAT instances.
# Each one maps between IDs and the loaded objects; the functions in this
# module receive them as default arguments, so entries persist across calls.
CAT_MAP = {}
VOCAB_MAP = {}
CDB_MAP = {}
|
||
|
||
def get_medcat_from_cdb_vocab(project,
                              cdb_map: Dict[str, CDB]=CDB_MAP,
                              vocab_map: Dict[str, Vocab]=VOCAB_MAP,
                              cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
    """Return a CAT built from the project's ConceptDB and Vocab.

    The CAT is cached in ``cat_map`` under ``"<cdb id>-<vocab id>"``; the CDB
    and Vocab are cached in ``cdb_map`` / ``vocab_map`` under their own ids.
    Anything already cached is reused instead of being loaded again.

    Args:
        project: object exposing ``concept_db`` (with ``id`` and
            ``cdb_file.path``) and ``vocab`` (with ``id`` and
            ``vocab_file.path``).
        cdb_map: cache of loaded CDBs, updated in place.
        vocab_map: cache of loaded Vocabs, updated in place.
        cat_map: cache of assembled CAT instances, updated in place.

    Returns:
        The cached or newly created CAT instance.

    Raises:
        Exception: if ``CDB.load`` raises ``KeyError`` while the installed
            ``medcat`` distribution has a major version greater than 0.
        KeyError: re-raised unchanged from ``CDB.load`` otherwise.
    """
    cdb_id = project.concept_db.id
    vocab_id = project.vocab.id
    cat_key = "-".join((str(cdb_id), str(vocab_id)))

    # Fast path: a CAT for this CDB / Vocab pair was assembled earlier.
    if cat_key in cat_map:
        return cat_map[cat_key]

    if cdb_id not in cdb_map:
        cdb_file_path = project.concept_db.cdb_file.path
        try:
            loaded_cdb = CDB.load(cdb_file_path)
        except KeyError as ke:
            installed_version = pkg_resources.get_distribution('medcat').version
            major = int(installed_version.split('.')[0])
            if major <= 0:
                # Not the v0.x-model-on-v1.x case: surface the original error.
                raise
            logger.error('Attempted to load MedCAT v0.x model with MCTrainer v1.x')
            raise Exception('Attempted to load MedCAT v0.x model with MCTrainer v1.x',
                            'Please re-configure this project to use a MedCAT v1.x CDB or consult the '
                            'MedCATTrainer Dev team if you believe this should work') from ke

        # Optionally overlay a config file named by the environment.
        config_override = os.getenv("MEDCAT_CONFIG_FILE")
        if config_override is None or not os.path.exists(config_override):
            logger.info("No MEDCAT_CONFIG_FILE env var set to valid path, using default config available on CDB")
        else:
            loaded_cdb.config.parse_config_file(path=config_override)
        cdb_map[cdb_id] = loaded_cdb
    cdb = cdb_map[cdb_id]

    if vocab_id not in vocab_map:
        vocab_file_path = project.vocab.vocab_file.path
        vocab_map[vocab_id] = Vocab.load(vocab_file_path)
    vocab = vocab_map[vocab_id]

    cat = CAT(cdb=cdb, config=cdb.config, vocab=vocab)
    cat_map[cat_key] = cat
    return cat
|
||
|
||
def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
    """Return the CAT for the project's model pack, loading it at most once.

    The instance is cached in ``cat_map`` under ``'mp<model pack id>'``, the
    same key that ``is_model_loaded`` tests for.

    Args:
        project: object exposing ``model_pack`` with ``id`` and
            ``model_pack.path``.
        cat_map: cache of CAT instances, updated in place.

    Returns:
        The cached or newly loaded CAT instance.
    """
    model_pack_obj = project.model_pack
    cat_id = 'mp' + str(model_pack_obj.id)
    # Previously the cache was written but never read, so every call
    # re-loaded the model pack from disk.
    if cat_id in cat_map:
        return cat_map[cat_id]
    logger.info('Loading model pack from:%s', model_pack_obj.model_pack.path)
    cat = CAT.load_model_pack(model_pack_obj.model_pack.path)
    cat_map[cat_id] = cat
    return cat
|
||
|
||
def get_medcat(project,
               cdb_map: Dict[str, CDB]=CDB_MAP,
               vocab_map: Dict[str, Vocab]=VOCAB_MAP,
               cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
    """Return the CAT instance configured for ``project``.

    Uses the project's model pack when ``project.model_pack`` is set,
    otherwise its ConceptDB / Vocab pair.

    Args:
        project: project object exposing ``model_pack``, ``concept_db`` and
            ``vocab``.
        cdb_map: cache of loaded CDBs, passed to the CDB / Vocab loader.
        vocab_map: cache of loaded Vocabs, passed to the CDB / Vocab loader.
        cat_map: cache of CAT instances, passed to both loaders.

    Returns:
        The CAT instance produced by the selected loader.

    Raises:
        Exception: if any ``AttributeError`` is raised while loading, e.g.
            because a required project field is ``None``. The original error
            is attached as ``__cause__``.
    """
    try:
        if project.model_pack is None:
            cat = get_medcat_from_cdb_vocab(project, cdb_map, vocab_map, cat_map)
        else:
            cat = get_medcat_from_model_pack(project, cat_map)
        return cat
    except AttributeError as err:
        # Chain explicitly so the traceback shows the AttributeError as the
        # direct cause instead of an incidental error during handling.
        raise Exception('Failure loading Project ConceptDB, Vocab or Model Pack. Are these set correctly?') from err
|
||
|
||
def get_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP):
    """Return the cached CAT for ``project`` without loading anything.

    Looks first for a model pack entry (key ``'mp<model pack id>'``, as
    written by ``get_medcat_from_model_pack``), then for a ConceptDB / Vocab
    entry (key ``"<cdb id>-<vocab id>"``).

    Args:
        project: project object exposing ``concept_db`` and ``vocab`` and,
            optionally, ``model_pack``.
        cat_map: cache of CAT instances to look in.

    Returns:
        The cached CAT instance, or ``None`` when nothing is cached for the
        project or the project has neither a model pack nor a complete
        ConceptDB / Vocab pair.
    """
    # Model pack projects were previously never found, because only the
    # CDB / Vocab key was consulted.
    model_pack = getattr(project, 'model_pack', None)
    if model_pack is not None:
        cat = cat_map.get('mp' + str(model_pack.id))
        if cat is not None:
            return cat
    if project.concept_db is None or project.vocab is None:
        return None
    cdb_id = project.concept_db.id
    vocab_id = project.vocab.id
    cat_id = str(cdb_id) + "-" + str(vocab_id)
    return cat_map.get(cat_id)
|
||
|
||
def clear_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP):
    """Remove any cached CAT instance belonging to ``project``.

    Evicts both the model pack entry (key ``'mp<model pack id>'``) and the
    ConceptDB / Vocab entry (key ``"<cdb id>-<vocab id>"``) when the
    corresponding project fields are set. Missing entries are ignored.

    Args:
        project: project object exposing ``concept_db`` and ``vocab`` and,
            optionally, ``model_pack``.
        cat_map: cache of CAT instances, updated in place.
    """
    stale_keys = []
    # Model pack projects have no concept_db / vocab; previously this
    # function raised AttributeError for them and never evicted their entry.
    model_pack = getattr(project, 'model_pack', None)
    if model_pack is not None:
        stale_keys.append('mp' + str(model_pack.id))
    if project.concept_db is not None and project.vocab is not None:
        stale_keys.append(str(project.concept_db.id) + "-" + str(project.vocab.id))
    for key in stale_keys:
        cat_map.pop(key, None)
|
||
|
||
def get_cached_cdb(cdb_id: str, cdb_map: Dict[str, CDB]=CDB_MAP) -> CDB:
    """Return the CDB for ``cdb_id``, loading and caching it on first use.

    Args:
        cdb_id: id of the ``ConceptDB`` record; also used as the cache key.
        cdb_map: cache of loaded CDBs, updated in place.

    Returns:
        The CDB stored in ``cdb_map`` under ``cdb_id``.
    """
    already_loaded = cdb_id in cdb_map
    if not already_loaded:
        # Cache miss: resolve the record and load the CDB from its file.
        record = ConceptDB.objects.get(id=cdb_id)
        cdb_map[cdb_id] = CDB.load(record.cdb_file.path)
    return cdb_map[cdb_id]
|
||
|
||
def clear_cached_cdb(cdb_id, cdb_map: Dict[str, CDB]=CDB_MAP):
    """Drop the cached CDB stored under ``cdb_id``; do nothing if it is absent.

    Args:
        cdb_id: cache key of the CDB to evict.
        cdb_map: cache of loaded CDBs, updated in place.
    """
    cdb_map.pop(cdb_id, None)
|
||
|
||
def is_model_loaded(project,
                    cdb_map: Dict[str, CDB]=CDB_MAP,
                    cat_map: Dict[str, CAT]=CAT_MAP):
    """Report whether a model for ``project`` is present in the caches.

    When the project has no ConceptDB, the model pack key
    ``'mp<model pack id>'`` is looked up in ``cat_map``; otherwise the
    ConceptDB id is looked up in ``cdb_map``.

    Args:
        project: project object exposing ``concept_db`` and ``model_pack``.
        cdb_map: cache of loaded CDBs.
        cat_map: cache of CAT instances.

    Returns:
        ``True`` if the relevant cache entry exists, ``False`` otherwise.
    """
    concept_db = project.concept_db
    if concept_db is None:
        # model pack is used.
        if not project.model_pack:
            return False
        return f'mp{project.model_pack.id}' in cat_map

    if not project.concept_db:
        return False
    return project.concept_db.id in cdb_map
Oops, something went wrong.