This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Models trained on the Meta Learning interface do not support test functions other than accuracy #1379

Open
castelojb opened this issue Jul 4, 2022 · 3 comments
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)

Comments

@castelojb

castelojb commented Jul 4, 2022

🐛 Bug

First of all, congratulations on the high-level meta-learning interface built on learn2learn. The bug: when a model trained with a meta-learning strategy is passed to trainer.test, it ignores any additional metrics registered in test_metrics and reports only accuracy and loss. It also does not return predictions in the usual way: calling model(x) returns None.
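The difference between the expected and the observed behavior can be illustrated with a toy sketch that does not use Flash at all. ToyTask, ToyMetaLearningTask, accuracy and error_rate are hypothetical names; the "meta-learning" class only mimics the symptom reported here (extra metrics are silently dropped) and is not a reproduction of Flash's actual adapter code.

```python
from typing import Callable, Dict, List

# Hypothetical metric type: takes predicted and true labels, returns a float.
Metric = Callable[[List[int], List[int]], float]


def accuracy(preds: List[int], target: List[int]) -> float:
    """Fraction of predictions that match the target labels."""
    return sum(p == t for p, t in zip(preds, target)) / len(target)


def error_rate(preds: List[int], target: List[int]) -> float:
    """An extra metric registered after construction, like F1Score in the report."""
    return 1.0 - accuracy(preds, target)


class ToyTask:
    """Expected behavior: every entry in test_metrics is evaluated,
    including entries added after the task was created."""

    def __init__(self) -> None:
        self.test_metrics: Dict[str, Metric] = {"accuracy": accuracy}

    def test_step(self, preds: List[int], target: List[int]) -> Dict[str, float]:
        return {f"test_{name}": fn(preds, target) for name, fn in self.test_metrics.items()}


class ToyMetaLearningTask(ToyTask):
    """Mimics the reported symptom (not Flash's real code): only accuracy
    is reported, whatever else is registered in test_metrics."""

    def test_step(self, preds: List[int], target: List[int]) -> Dict[str, float]:
        return {"test_accuracy": accuracy(preds, target)}


for task in (ToyTask(), ToyMetaLearningTask()):
    task.test_metrics["error_rate"] = error_rate
    results = task.test_step(preds=[0, 1, 1, 2], target=[0, 1, 2, 2])
    print(type(task).__name__, sorted(results))
# ToyTask ['test_accuracy', 'test_error_rate']
# ToyMetaLearningTask ['test_accuracy']
```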

To Reproduce

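import torch

import flash
from flash.image import ImageClassificationData, ImageClassifier
from pytorch_lightning import seed_everything

seed_everything(42)

# train, validate and test are pandas DataFrames with "path" and "class" columns
datamodule = ImageClassificationData.from_data_frame(
    "path",
    "class",
    train_data_frame=train,
    val_data_frame=validate,
    test_data_frame=test,
    transform_kwargs=dict(image_size=(128, 128)),
    batch_size=2,
)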

model = ImageClassifier(
    backbone="resnet18",
    training_strategy="maml",
    pretrained=False,
    training_strategy_kwargs={
        "epoch_length": 50,
        "meta_batch_size": 2,
        "num_tasks": 50,
        "test_num_tasks": 50,
        "ways": datamodule.num_classes,
        "shots": 2,
        "test_ways": 2,
        "test_shots": 1,
        # "test_queries": 15,
    },
    optimizer=torch.optim.Adam,
    learning_rate=0.001,
)

trainer = flash.Trainer(
    max_epochs=50,
    precision=16,
    accelerator="ddp_shared",
    gpus=int(torch.cuda.is_available()),
)

trainer.fit(model, datamodule=datamodule)

# Save the model
trainer.save_checkpoint(path)

# Load the model saved above
model_trn = ImageClassifier.load_from_checkpoint(path)

from torchmetrics import F1Score, Precision, Recall

# Register additional test metrics on the loaded model
model_trn.test_metrics['F1-Score'] = F1Score(num_classes=6, average='macro')

model_trn.test_metrics['Precision'] = Precision(num_classes=6, average='macro')

model_trn.test_metrics['Recall'] = Recall(num_classes=6, average='macro')

trainer.test(model_trn, dataloaders=datamodule.test_dataloader())

>>> [{'test_accuracy': 0.8799999952316284, 'test_loss': 0.280563086271286}]

test_loader = datamodule.test_dataloader()

data_iter = iter(test_loader)
sample_ = next(data_iter)
input = sample_['input']

model_trn(input)
>>> None

Expected behavior

I ran the same test procedure on a resnet18 model trained without meta-learning, and it behaved as expected:

model_trn = ImageClassifier.load_from_checkpoint(path)

from torchmetrics import F1Score, Precision, Recall

# Register additional test metrics on the loaded model
model_trn.test_metrics['F1-Score'] = F1Score(num_classes=6, average='macro')

model_trn.test_metrics['Precision'] = Precision(num_classes=6, average='macro')

model_trn.test_metrics['Recall'] = Recall(num_classes=6, average='macro')

trainer.test(model_trn, dataloaders=datamodule.test_dataloader())

>>> [{'test_F1-Score': 0.5142857432365417,
  'test_Precision': 0.4933333694934845,
  'test_Recall': 0.5800000429153442,
  'test_accuracy': 0.5882353186607361,
  'test_cross_entropy': 1.0256577730178833}]

test_loader = datamodule.test_dataloader()

data_iter = iter(test_loader)
sample_ = next(data_iter)
input = sample_['input']

model_trn(input).shape
>>> (2, 6)
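For reference when reading the metric values above: with average='macro', torchmetrics computes each metric per class and then takes the unweighted mean over classes, so the reported F1 is generally not the harmonic mean of the reported precision and recall. A minimal pure-Python sketch of that averaging follows; macro_scores is a hypothetical helper, not a torchmetrics function, and it scores zero-denominator cases as 0, which may differ from how torchmetrics treats classes that never occur.

```python
from typing import Dict, List


def macro_scores(preds: List[int], target: List[int], num_classes: int) -> Dict[str, float]:
    """Per-class precision, recall and F1, averaged with equal weight per class.

    Hypothetical helper: a zero denominator yields a score of 0 for that class.
    """
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, target))
        fp = sum(p == c and t != c for p, t in zip(preds, target))
        fn = sum(p != c and t == c for p, t in zip(preds, target))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        # F1 is computed per class first, then averaged below
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    return {
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1": sum(f1s) / num_classes,
    }


scores = macro_scores(preds=[0, 0, 1, 1], target=[0, 1, 1, 1], num_classes=2)
print({k: round(v, 4) for k, v in scores.items()})
# {'precision': 0.75, 'recall': 0.8333, 'f1': 0.7333}
```

In this toy case the harmonic mean of the macro precision and recall would be about 0.79, not 0.73, which is the same kind of gap visible between the F1, precision and recall values in the report.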

Environment

Everything was tested on Google Colab with a GPU. I installed the libraries with the following commands:

!pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git'
!pip install 'git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[image]'

I don't know why, but importing and using the "maml" training strategy only works if the following packages are installed beforehand:

!pip install learn2learn
!pip install Pillow==7.1.2
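Since the "maml" strategy only worked once learn2learn was installed, a quick import check before building the model can turn a confusing failure into a clear message. A minimal sketch using only the standard library; missing_packages is a hypothetical helper, the module list is an assumption based on the packages installed above (Pillow is imported as PIL), and it checks only that the modules are present, not their versions.

```python
import importlib.util
from typing import Iterable, List


def missing_packages(modules: Iterable[str]) -> List[str]:
    """Return the names in `modules` that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_packages(["learn2learn", "PIL"])
if missing:
    print("Install these before using training_strategy='maml':", missing)
```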
@castelojb castelojb added bug / fix Something isn't working help wanted Extra attention is needed labels Jul 4, 2022
@stale stale bot added the won't fix This will not be worked on label Sep 22, 2022
@ethanwharris ethanwharris removed the won't fix This will not be worked on label Oct 1, 2022
@Lightning-Universe Lightning-Universe deleted a comment from stale bot Dec 23, 2022
@Borda
Member

Borda commented Dec 23, 2022

@castelojb would you be interested in helping us to debug/extend this case? 🦦

@castelojb
Author

Sounds good to me! But how can I help with this?

@Borda
Member

Borda commented Jan 5, 2023

> Sounds good to me! But how can I help with this?

Let's ping/pair with @krshrimali or @ethanwharris to give you some insight... :)
