通过PyTorch来创建一个文本分类的Bert模型

开发 后端
在本文中,介绍了一种称为BERT(带转换器Transformers的双向编码Encoder 器表示)的语言模型,该模型在问答、自然语言推理、分类和通用语言理解评估或 (GLUE)等任务中取得了最先进的性能.

[[420285]]

2018 年,谷歌发表了一篇题为《Pre-training of deep bidirectional Transformers for Language Understanding》的论文。

在本文中,介绍了一种称为BERT(带转换器Transformers的双向编码Encoder 器表示)的语言模型,该模型在问答、自然语言推理、分类和通用语言理解评估或 (GLUE)等任务中取得了最先进的性能.

BERT全称为Bidirectional Encoder Representation from Transformers[1],是一种用于语言表征的预训练模型。

它基于谷歌2017年发布的Transformer架构,通常的Transformer使用一组编码器和解码器网络,而BERT只需要一个额外的输出层,对预训练进行fine-tune,就可以满足各种任务,根本没有必要针对特定任务对模型进行修改。

BERT将多个Transformer编码器堆叠在一起。Transformer基于著名的多头注意力(Multi-head Attention)模块,该模块在视觉和语言任务方面都取得了巨大成功。

在本文中,我们将使用 PyTorch来创建一个文本分类的Bert模型。

笔者介今天绍一个python库 --- simpletransformers,可以很好的解决高级预训练语言模型使用困难的问题。

simpletransformers使得高级预训练模型(BERT、RoBERTa、XLNet、XLM、DistilBERT、ALBERT、CamemBERT、XLM-RoBERTa、FlauBERT)的训练、评估和预测变得简单,每条只需3行即可初始化模型。

数据集来源:https://www.kaggle.com/jrobischon/wikipedia-movie-plots

该数据集包含对来自世界各地的 34,886 部电影的描述。列描述如下:

  • 发行年份:电影发行的年份
  • 标题:电影标题
  • 起源:电影的起源(即美国、宝莱坞、泰米尔等)
  • 剧情:主要演员
  • 类型:电影类型
  • 维基页面- 从中抓取情节描述的维基百科页面的 URL
  • 情节:电影情节的长篇描述
  1. import numpy as np 
  2. import pandas as pd 
  3. import os, json, gc, re, random 
  4. from tqdm.notebook import tqdm 
  5. import torch, transformers, tokenizers 
  6. movies_df = pd.read_csv("wiki_movie_plots_deduped.csv"
  7. from sklearn.preprocessing import LabelEncoder 
  8.  
  9. movies_df = movies_df[(movies_df["Origin/Ethnicity"]=="American") | (movies_df["Origin/Ethnicity"]=="British")] 
  10. movies_df = movies_df[["Plot""Genre"]] 
  11. drop_indices = movies_df[movies_df["Genre"] == "unknown" ].index 
  12. movies_df.drop(drop_indices, inplace=True
  13.  
  14. # Combine genres: 1) "sci-fi" with "science fiction" &  2) "romantic comedy" with "romance" 
  15. movies_df["Genre"].replace({"sci-fi""science fiction""romantic comedy""romance"}, inplace=True
  16.  
  17. # 根据频率选择电影类型 
  18. shortlisted_genres = movies_df["Genre"].value_counts().reset_index(name="count").query("count > 200")["index"].tolist() 
  19. movies_df = movies_df[movies_df["Genre"].isin(shortlisted_genres)].reset_index(drop=True
  20.  
  21. # Shuffle  
  22. movies_df = movies_df.sample(frac=1).reset_index(drop=True
  23.  
  24. #从不同类型中抽取大致相同数量的电影情节样本(以减少阶级不平衡问题) 
  25. movies_df = movies_df.groupby("Genre").head(400).reset_index(drop=True
  26. label_encoder = LabelEncoder() 
  27. movies_df["genre_encoded"] = label_encoder.fit_transform(movies_df["Genre"].tolist()) 
  28. movies_df = movies_df[["Plot""Genre""genre_encoded"]] 
  29. movies_df 

使用 torch 加载 BERT 模型,最简单的方法是使用 Simple Transformers 库,以便只需 3 行代码即可初始化、在给定数据集上训练和在给定数据集上评估 Transformer 模型。

  1. from simpletransformers.classification import ClassificationModel 
  2.  
  3. # 模型参数 
  4. model_args = { 
  5.     "reprocess_input_data"True
  6.     "overwrite_output_dir"True
  7.     "save_model_every_epoch"False
  8.     "save_eval_checkpoints"False
  9.     "max_seq_length": 512, 
  10.     "train_batch_size": 16, 
  11.     "num_train_epochs": 4, 
  12.  
  13. Create a ClassificationModel 
  14. model = ClassificationModel('bert''bert-base-cased', num_labels=len(shortlisted_genres), args=model_args) 

训练模型

  1. train_df, eval_df = train_test_split(movies_df, test_size=0.2, stratify=movies_df["Genre"], random_state=42) 
  2.  
  3. # Train the model 
  4. model.train_model(train_df[["Plot""genre_encoded"]]) 
  5.  
  6. # Evaluate the model 
  7. result, model_outputs, wrong_predictions = model.eval_model(eval_df[["Plot""genre_encoded"]]) 
  8. print(result) 
  9.  
  10. {'mcc': 0.5299659404649717, 'eval_loss': 1.4970421879083518} 
  11. CPU times: user 19min 1s, sys: 4.95 s, total: 19min 6s 
  12. Wall time: 20min 14s 

关于simpletransformers的官方文档:https://simpletransformers.ai/docs

Github链接:https://github.com/ThilinaRajapakse/simpletransformers

 

责任编辑:姜华 来源: Python之王
相关推荐

2020-09-25 09:58:37

谷歌Android开发者

2022-10-09 08:00:00

机器学习文本分类算法

2021-03-06 07:00:00

awk文本分析工具Linux

2018-07-04 15:17:07

CNNNLP模型

2020-06-04 12:55:44

PyTorch分类器神经网络

2020-09-22 15:17:59

谷歌Android技术

2017-08-04 14:23:04

机器学习神经网络TensorFlow

2018-12-17 09:10:52

机器学习TensorFlow容器

2020-03-23 08:00:00

开源数据集文本分类

2017-06-20 11:00:13

大数据自然语言文本分类器

2010-09-25 15:46:58

帐户管理旧账户

2023-12-31 16:35:31

Pytorch函数深度学习

2023-02-27 09:31:00

streamlitst.sidebar菜单

2022-09-28 15:34:06

机器学习语音识别Pytorch

2017-08-25 14:23:44

TensorFlow神经网络文本分类

2020-11-30 09:30:00

数据模型架构

2023-11-28 09:00:00

机器学习少样本学习SetFit

2020-03-13 15:33:54

Google 开源技术

2023-12-18 08:00:42

JavaScrip框架Lit
点赞
收藏

51CTO技术栈公众号