Skip to content

Commit

Permalink
create dataloaders for the other 2 tasks and called evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
sago2693 committed Aug 21, 2023
1 parent 8e6bf92 commit 35636f9
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions multitask_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datasets import SentenceClassificationDataset, SentencePairDataset, \
load_multitask_data, load_multitask_test_data

from evaluation import model_eval_sst, test_model_multitask
from evaluation import model_eval_sst, test_model_multitask, model_eval_multitask


TQDM_DISABLE=True
Expand Down Expand Up @@ -82,6 +82,7 @@ def predict_paraphrase(self,
Note that your output should be unnormalized (a logit); it will be passed to the sigmoid function
during evaluation, and handled as a logit by the appropriate loss function.
'''

### TODO
raise NotImplementedError

Expand Down Expand Up @@ -121,14 +122,33 @@ def train_multitask(args):
# Create the data and its corresponding datasets and dataloader
sst_train_data, num_labels,para_train_data, sts_train_data = load_multitask_data(args.sst_train,args.para_train,args.sts_train, split ='train')
sst_dev_data, num_labels,para_dev_data, sts_dev_data = load_multitask_data(args.sst_dev,args.para_dev,args.sts_dev, split ='train')


#Sentiment analysis
sst_train_data = SentenceClassificationDataset(sst_train_data, args)
sst_dev_data = SentenceClassificationDataset(sst_dev_data, args)

sst_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=args.batch_size,
collate_fn=sst_train_data.collate_fn)
sst_dev_dataloader = DataLoader(sst_dev_data, shuffle=False, batch_size=args.batch_size,
collate_fn=sst_dev_data.collate_fn)

#Paraphrasing
paraphrase_train_data = SentencePairDataset(para_train_data, args, isRegression =False)
paraphrase_dev_data = SentencePairDataset(para_dev_data, args, isRegression =False)

paraphrase_train_dataloader = DataLoader(paraphrase_train_data, shuffle=True, batch_size=args.batch_size,
collate_fn=paraphrase_train_data.collate_fn)
paraphrase_dev_dataloader = DataLoader(paraphrase_dev_data, shuffle=True, batch_size=args.batch_size,
collate_fn=paraphrase_dev_data.collate_fn)

#sts
sts_train_data = SentencePairDataset(sts_train_data, args, isRegression =True)
sts_dev_data = SentencePairDataset(sst_dev_data, args, isRegression =True)

sts_train_dataloader = DataLoader(sst_train_data, shuffle=True, batch_size=256,
collate_fn=sst_train_data.collate_fn)
sts_dev_dataloader = DataLoader(sst_dev_data, shuffle=True, batch_size=256,
collate_fn=sst_dev_data.collate_fn)

# Init model
config = {'hidden_dropout_prob': args.hidden_dropout_prob,
Expand Down Expand Up @@ -175,6 +195,22 @@ def train_multitask(args):
train_acc, train_f1, *_ = model_eval_sst(sst_train_dataloader, model, device)
dev_acc, dev_f1, *_ = model_eval_sst(sst_dev_dataloader, model, device)


## Adding multitask evaluation
#train
(train_paraphrase_accuracy, train_para_y_pred, train_para_sent_ids,
train_sentiment_accuracy,train_sst_y_pred, train_sst_sent_ids,
train_sts_corr, train_sts_y_pred, train_sts_sent_ids) = model_eval_multitask(sst_train_dataloader,
paraphrase_train_dataloader,sts_train_dataloader,model, model.device )

#dev
(dev_paraphrase_accuracy, dev_para_y_pred, dev_para_sent_ids,
dev_sentiment_accuracy,dev_sst_y_pred, dev_sst_sent_ids,
dev_sts_corr, dev_sts_y_pred, dev__sent_ids) = model_eval_multitask(sst_dev_dataloader,
paraphrase_dev_dataloader,sts_dev_dataloader,model, model.device )

#We have to weight or average the three sores to save the best model.
# In the diven code only sst is used
if dev_acc > best_dev_acc:
best_dev_acc = dev_acc
save_model(model, optimizer, args, config, args.filepath)
Expand Down

0 comments on commit 35636f9

Please sign in to comment.