携手创造,共同成长!这是我参与「日新方案 8 月更文应战」的第30天,点击查看活动详情

说明

本系列博客将记载自己学习的课程:NLP实战高手课,链接为:time.geekbang.org/course/intr…本篇为27-28节的课程笔记,首要介绍Pytorch中运用torchtext进行文本分类的示例,本篇博客将介绍怎么在上一篇博客记载的模型界说后进行练习和评价,最终导师给出了一些改善的测验方向。

评价函数的树立

IMDB数据集是一个典型的2分类数据集。为此,咱们运用准确率作为评价目标,该函数的界说如下:

def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

模型练习

接下来,咱们介绍怎么界说模型的练习,其首要包含以下几个模块:

  • 设置模型为train模式
  • 从Dataloader中逐一batch的加载数据
  • 清零优化器的梯度
  • 进行模型的forward操作,得到输出预测
  • 核算输出预测与真实值之间的loss
  • loss进行反向传达将梯度进行回传
  • 优化器进行优化

按照上面的结构,咱们将代码书写如下:

def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for batch in iterator:        
        optimizer.zero_grad()        
        text, text_lengths = batch.text        
        predictions = model(text, text_lengths).squeeze(1)        
        loss = criterion(predictions, batch.label)        
        acc = binary_accuracy(predictions, batch.label)        
        loss.backward()        
        optimizer.step()        
        epoch_loss += loss.item()
        epoch_acc += acc.item()     
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

在这里,咱们还核算了练习的loss和练习集上的预测准确率作为参阅。

模型的验证

模型的验证模块首要衡量一个练习后的模型能够在验证集上的体现怎么。其收拾代码结构与练习相仿,但必须要注意的是,进行模型验证前一定要把模型设置为验证模式,此刻,模型在核算时的梯度将不会保存。

其对应的代码如下:

def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():    
        for batch in iterator:
            text, text_lengths = batch.text            
            predictions = model(text, text_lengths).squeeze(1)      
            loss = criterion(predictions, batch.label)            
            acc = binary_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

练习与评价

完结好以上的界说后,咱们终于能够开始练习和评价了,咱们设置练习的epoch为5,在每个epoch结束时对模型进行评价。

N_EPOCHS = 5
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

打印输出如下:

Epoch: 01 | Epoch Time: 1m 50s
	Train Loss: 0.558 | Train Acc: 70.57%
	 Val. Loss: 0.444 |  Val. Acc: 79.54%
Epoch: 02 | Epoch Time: 1m 50s
	Train Loss: 0.393 | Train Acc: 82.70%
	 Val. Loss: 0.383 |  Val. Acc: 83.21%
Epoch: 03 | Epoch Time: 1m 50s
	Train Loss: 0.287 | Train Acc: 88.10%
	 Val. Loss: 0.300 |  Val. Acc: 88.08%
Epoch: 04 | Epoch Time: 1m 50s
	Train Loss: 0.161 | Train Acc: 94.26%
	 Val. Loss: 0.314 |  Val. Acc: 87.84%
Epoch: 05 | Epoch Time: 1m 50s
	Train Loss: 0.122 | Train Acc: 95.53%
	 Val. Loss: 0.367 |  Val. Acc: 87.17%

能够看到,随着练习的进行,模型在练习集上的loss安稳下降,准确性也在逐步提高;而在验证集上准确性提高到一定数值后就不在提升了,乃至有所下降,这可能是模型过拟合导致。

得到验证集上的最好体现的模型后,咱们在测验集上进行一步测验:

model.load_state_dict(torch.load('model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

输出如下:

Test Loss: 0.321 | Test Acc: 87.01%

能够看到,咱们的模型在测验集上仍旧体现良好,达到了87%的准确性。

改善与建议

以上的代码尽管能够取得较好的结果,但因为该问题比较简单,所以还有许多地方能够改善,首要的改善点如下:

  • 超参数设置:如学习率的巨细、变化方法、需不需要warmup;batch_size的巨细;
  • 初始化:Embedding、weight的初始化方法不同将会带来不同的体现;
  • 数据集清洗:咱们能够去除掉某些乱码或许HTML标签等;
  • 分词优化:能够测验其他tokenizer改善分词效果。
  • ……

总结

本篇和上一篇博客为我们介绍了一个完整的NLP中的文本分类任务所涉及的各个模块,因为篇幅所限,未能展现全部代码,有兴趣的读者能够参阅gitee.com/geektime-ge… 。