import os
import time
from datetime import datetime

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from allennlp.modules.elmo import Elmo
from tensorboardX import SummaryWriter
from transformers import Adafactor, AdamW, BertModel, get_linear_schedule_with_warmup

from config import config
from models.pretrain_model import PretrainModel
from utils.data_utils import E2EABSA_dataset, Tokenizer
from utils.metrics import F1, FocalLoss, compute_kl_loss, ContrastiveLoss
from utils.result_helper import init_logger

logger = init_logger(logging_folder=config.working_path.joinpath('checkout'),
                     logging_file=config.working_path.joinpath("checkout\\training_log.txt"))
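

# Trainer wraps the full fine-tuning loop for the E2E-ABSA sequence-tagging task:
# it builds the loss (weighted cross-entropy or focal), optimizer, scheduler and
# dataloaders, then alternates training steps with periodic dev evaluation and
# checkpointing.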
class Trainer(object):
    def __init__(self, model, tokenizer, args):
        self.model = model
        self.tokenizer = tokenizer
        self.args = args
        if args.load_model:
            self.model.load_state_dict(torch.load(args.state_dict_path))
        self.model = self.model.to(args.device)
        self.model_name = f"{args.model_name}-{args.downstream}"
        self.max_val_step = 0
        self.max_val_acc = 0
        self.train_metric = 0
        self.train_loss = 0
        self.step = 0
        self.time = 0
        self.min_metrics = 0.50  # minimum dev F1 required before a checkpoint is saved

        if args.loss == "CE":
            self.weight = [0.1, 0.8, 1., 1., 1.2, 1.2, 1.2, 1., 1., 1.]
            # self.weight = [0.07, 1.0, 1.0, 1.0, 1.0, 1.0, 1.1, 1.0, 1.0, 1.0]
            criterion_weight = torch.tensor(self.weight).to(self.args.device)
            self.criterion = nn.CrossEntropyLoss(ignore_index=self.tokenizer.target_pad_token_id,
                                                 weight=criterion_weight)
        elif args.loss == "focal":
            self.weight = [args.alpha for _ in range(self.args.num_classes)]
            self.criterion = FocalLoss(class_num=args.num_classes,
                                       alpha=self.weight,
                                       gamma=args.gamma,
                                       ignore_index=self.tokenizer.target_pad_token_id,
                                       device=args.device)
        else:
            raise NotImplementedError(f"loss function '{args.loss}' not supported; only 'CE' and 'focal' are implemented")

        self.optimizer = self.args.optimizer(self.model.parameters(),
                                             **self.args.optimizer_kwargs)
        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
                                                         num_warmup_steps=args.warmup_steps,
                                                         num_training_steps=args.max_steps)
        if args.metrics == "f1":  # future work: support additional metrics
            self.metrics = F1(args.num_classes, downstream=args.downstream)
        else:
            raise NotImplementedError("--metrics only implements 'f1'")

        self.dev_dataloader = DataLoader(E2EABSA_dataset(file_path=self.args.file_path['dev'],
                                                         tokenizer=self.tokenizer),
                                         batch_size=self.args.batch_size,
                                         shuffle=self.args.shuffle,
                                         drop_last=True)
        self.train_dataloader = DataLoader(E2EABSA_dataset(file_path=self.args.file_path['train'],
                                                           tokenizer=self.tokenizer),
                                           batch_size=self.args.batch_size,
                                           shuffle=self.args.shuffle,
                                           drop_last=True)
        if self.args.contrastive:
            self.contrastive_loss = ContrastiveLoss(temp=args.temp)
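
    # Move a batch's token ids, target tag ids and attention mask to the configured device.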
    def _gen_inputs(self, data):
        inputs = data["text_ids"].to(self.args.device)
        target = data["pred_ids"].to(self.args.device)
        attention_mask = data["att_mask"].to(self.args.device)
        return inputs, target, attention_mask
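
    # One pass over the training set. The CRF head returns its own negative
    # log-likelihood and decodes with Viterbi; the plain tagging heads use the
    # configured criterion, optionally with a second forward pass for the
    # R-Drop KL term and/or the contrastive loss when --augument is set.
    # Every args.step steps the running F1 is logged and _checkout() runs a
    # dev evaluation.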
    def _train_epoch(self, epoch):
        self.model.train()
        TP, FP, FN = 0, 0, 0
        for data in self.train_dataloader:
            self.optimizer.zero_grad()
            inputs, target, attention_mask = self._gen_inputs(data)
            if self.model_name.endswith('crf'):
                loss, logits = self.model(inputs, attention_mask=attention_mask, labels=target)
                loss = loss / self.args.batch_size
                output = self.model.crf.viterbi_tags(logits=logits, mask=attention_mask)
                # pad the variable-length Viterbi paths back to max_seq_len
                output = [x + [self.tokenizer.target_pad_token_id] * (self.args.max_seq_len - len(x))
                          for x in output]
                output = torch.tensor(output, dtype=torch.long, device=self.args.device)
            else:
                if self.args.augument:
                    addition_loss = 0
                    output, q = self.model(inputs, attention_mask=attention_mask)
                    output2, p = self.model(inputs, attention_mask=attention_mask)
                    if self.args.rdrop:
                        addition_loss = self.args.rdrop_alpha * compute_kl_loss(q, p, pad_mask=attention_mask) + addition_loss
                    if self.args.contrastive:
                        addition_loss = self.args.contrastive_alpha * self.contrastive_loss(p, q, attention_mask) + addition_loss
                    loss = self.criterion(output.view(-1, self.args.num_classes), target.view(-1))
                    loss2 = self.criterion(output2.view(-1, self.args.num_classes), target.view(-1))
                    loss = (loss + loss2) / 2 + addition_loss
                else:
                    output = self.model(inputs, attention_mask=attention_mask)
                    loss = self.criterion(output.view(-1, self.args.num_classes), target.view(-1))
            loss.backward()
            dTP, dFP, dFN = self.metrics(output, target, attention_mask)
            TP += dTP
            FP += dFP
            FN += dFN
            if self.args.clip_large_grad:
                clip_grad_norm_(self.model.parameters(),
                                max_norm=self.args.max_grad_norm,
                                norm_type=2.0)
            self.optimizer.step()
            self.scheduler.step()
            self.train_loss += loss.item()  # accumulate a float, not the graph-holding tensor
            self.step += 1
            if self.step % self.args.step == 0:
                self.train_metric = self.metrics.get_f1(TP, FP, FN)
                self._checkout(epoch)
                self.time = time.time()
                self.train_loss, self.train_metric = 0, 0
                TP, FP, FN = 0, 0, 0
                self.model.train()
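
    # Evaluate on the dev split under torch.no_grad(); returns the mean dev loss
    # and the dev F1.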
    def _dev_epoch(self):
        self.model.eval()
        dev_losses = 0
        TP, FP, FN = 0, 0, 0
        count = 0
        with torch.no_grad():
            for data in self.dev_dataloader:
                count += 1
                inputs, target, attention_mask = self._gen_inputs(data)
                if self.model_name.endswith('crf'):
                    loss, logits = self.model(inputs, attention_mask=attention_mask, labels=target)
                    loss = loss / self.args.batch_size
                    output = self.model.crf.viterbi_tags(logits=logits, mask=attention_mask)
                    # padding outputs
                    output = [x + [self.tokenizer.target_pad_token_id] * (self.args.max_seq_len - len(x))
                              for x in output]
                    output = torch.tensor(output, dtype=torch.long, device=self.args.device)
                else:
                    output = self.model(inputs, attention_mask=attention_mask)
                    if self.args.augument:
                        output = output[0]
                    loss = self.criterion(output.view(-1, self.args.num_classes), target.view(-1))
                dTP, dFP, dFN = self.metrics(output, target, attention_mask)
                TP += dTP
                FP += dFP
                FN += dFN
                dev_losses += loss
        return dev_losses / count, self.metrics.get_f1(TP, FP, FN, verbose=self.args.verbose)
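
    # Main entry point: sets up the TensorBoard writer and checkpoint folder,
    # logs all arguments, then trains until args.max_steps is reached.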
    def run(self):
        self.update_loss = 0
        # format the timestamp so the run directory name contains no characters
        # that are illegal in Windows paths (e.g. ':')
        times = datetime.now().strftime("%Y%m%d-%H%M%S")
        self.writer = SummaryWriter(rf'runs\{self.args.mode}_{self.args.seed}_{self.model_name}_{times}')
        if not os.path.exists(r'checkout\state_dict'):
            os.mkdir(r'checkout\state_dict')
        logger.info(f"\n>>>>>>>>>>>>>>>>>>>>>{datetime.now()}>>>>>>>>>>>>>>>>>>>>>>>>")
        for arg in vars(self.args):
            logger.info(f'>>> {arg}: {getattr(self.args, arg)}')
        logger.info(f">>> class weight (or alpha): {self.weight}")
        self.time = time.time()
        for epoch in range(self.args.epochs + 1):
            self._train_epoch(epoch)
            if self.step > self.args.max_steps:
                break
        self.writer.close()
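
    # Log train/dev loss and F1 to the logger and TensorBoard, and save the
    # model whenever the dev F1 reaches a new best above self.min_metrics.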
    def _checkout(self, epoch):
        train_loss = self.train_loss / self.args.step
        dev_loss, dev_metrics = self._dev_epoch()
        logger.info(f"> Epoch: {epoch} Step: {self.step}, "
                    f"train loss: {train_loss:.4f} "
                    f"{self.metrics.name}: {self.train_metric * 100:.2f}% "
                    f"dev loss: {dev_loss:.4f} "
                    f"{self.metrics.name}: {dev_metrics * 100:.2f}% "
                    f"{(time.time() - self.time) / 60:.2f} min")
        self.writer.add_scalar(tag="Train_loss", scalar_value=train_loss, global_step=self.step)
        self.writer.add_scalar(tag="Train_F1", scalar_value=self.train_metric * 100, global_step=self.step)
        self.writer.add_scalar("Dev_loss", dev_loss, self.step)
        self.writer.add_scalar("Dev_F1", dev_metrics * 100, self.step)
        if dev_metrics > self.max_val_acc:
            self.max_val_acc = dev_metrics
            if dev_metrics > self.min_metrics:
                path = f'checkout/state_dict/{self.model_name}_' \
                       f'{self.args.mode}_seed{self.args.seed}.pth'
                torch.save(self.model.state_dict(), path)
                print(f'>> saved: {path}')
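

# Build the optimizer table, instantiate the pretrained backbone
# (BERT / ELMo / GloVe) wrapped in PretrainModel, and launch training.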
def main(args):
    optimizers = {
        'adadelta': torch.optim.Adadelta,  # default lr=1.0
        'adagrad': torch.optim.Adagrad,    # default lr=0.01
        'adam': torch.optim.Adam,          # default lr=0.001
        'adamax': torch.optim.Adamax,      # default lr=0.002
        'asgd': torch.optim.ASGD,          # default lr=0.01
        'rmsprop': torch.optim.RMSprop,    # default lr=0.01
        'sgd': torch.optim.SGD,
        'adamw': AdamW,
        'Adafactor': Adafactor
    }
    default_optim_kwargs = {'lr': args.lr, 'weight_decay': args.weight_decay}
    optimizers_kwargs = {
        'adadelta': {},
        'adagrad': {},
        'adam': {"betas": (args.adam_beta1, args.adam_beta2),
                 "eps": args.adam_epsilon,
                 "amsgrad": args.adam_amsgrad},
        'adamax': {},
        'asgd': {},
        'rmsprop': {},
        'sgd': {},
        'adamw': {"betas": (args.adam_beta1, args.adam_beta2),
                  "eps": args.adam_epsilon},
        'Adafactor': {"scale_parameter": False, "relative_step": False}
    }
    assert args.optimizer in optimizers, \
        f"Optimizer only supports {list(optimizers.keys())}"
    if args.augument:
        assert args.model_name == "bert" and not args.downstream.endswith('crf'), \
            "--augument does not support crf or elmo"
    args.optimizer_kwargs = optimizers_kwargs[args.optimizer]
    args.optimizer_kwargs.update(default_optim_kwargs)
    args.optimizer = optimizers[args.optimizer]
    args.device = torch.device('cuda:0')
    model = None
    tokenizer = Tokenizer(args=args)
    if args.model_name.startswith("bert"):
        print(f"> Loading bert model {args.pretrained_bert_name}")
        bert = BertModel.from_pretrained(args.pretrained_bert_name)
        model = PretrainModel(pretrain_model=bert, args=args)
    elif args.model_name.startswith("elmo"):
        print("> Loading elmo model")
        elmo = Elmo(options_file=config.options_file,
                    weight_file=config.weight_file,
                    num_output_representations=1,
                    dropout=0,
                    requires_grad=args.finetune_elmo)
        model = PretrainModel(pretrain_model=elmo, args=args)
    elif args.model_name.startswith("glove"):
        print("> Loading glove model")
        model = PretrainModel(pretrain_model='glove', args=args)
    else:
        raise NotImplementedError(f"model {args.model_name} not implemented")
    trainer = Trainer(model=model, tokenizer=tokenizer, args=args)
    trainer.run()


if __name__ == "__main__":
    main(config.args)