位置: IT常识 - 正文

Huggingface之transformers零基础使用指南

发布时间:2024-01-23
前几篇博文中介绍了Transformer,由于其优越的性能表现,在工业界使用的越来越广泛,同时,配合迁移学习理论,越来越多的Transformer预训练模型和源码库逐渐开源,Huggingface就是其中做的最为出色的一家机构。Huggingface是一家在NLP社区做出杰出贡献的纽约创业公司,其所... ...

推荐整理分享Huggingface之transformers零基础使用指南,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

前几篇博文中介绍了Transformer,由于其优越的性能表现,在工业界使用的越来越广泛,同时,配合迁移学习理论,越来越多的Transformer预训练模型和源码库逐渐开源,Huggingface就是其中做的最为出色的一家机构。Huggingface是一家在NLP社区做出杰出贡献的纽约创业公司,其所提供的大量预训练模型和代码等资源被广泛的应用于学术研究当中。Huggingface所开源的Transformers提供了数以千计针对于各种任务的预训练模型模型,开发者可以根据自身的需要,选择模型进行训练或微调,也可阅读api文档和源码, 快速开发新模型。

本篇博文,我们对Huggingface所开源的Transformers进行介绍。在此之前,请通过下行命令安装transformers库:

pip install transformers1 从AutoClass说起¶

transformers库中提供了上百个算法模型的实现,有BERT模型对应的BertModel类,有BART对应的BartModel类……,每当我们使用对应的预训练模型时,都必须先找到对应类名,然后进行实例化,麻烦吗?非常麻烦!

所以,transformers库中提供统一的入口,也就是我们这里说到的“AutoClass”系列的高级对象,通过在调用“AutoClass”的from_pretrained()方法时指定预训练模型的名称或预训练模型所在目录,即可快速、便捷得完成预训练模型创建。有了“AutoClass”,只需要知道预训练模型的名称,或者将预训练模型下载好,程序将根据预训练模型配置文件中model_type或者预训练模型名称、路径进行模式匹配,自动决定实例化哪一个模型类,不再需要再到该模型在transfors库中对应的类名。“AutoClass”所有类都不能够通过init()方法进行实例化,只能通过from_pretrained()方法实例化指定的类。

如下所示,我们到Huggingface官网下载好一个中文BERT预训练模型,模型所有文件存放在当前目录下的“model/bert-base-chinese”路径下。创建预训练模型时,我们将这一路径传递到from_pretrained()方法,即可完成模型创建,创建好的模型为BertModel类的实例。

In[1]:from transformers import AutoModelIn[4]:model = AutoModel.from_pretrained("./models/bert-base-chinese")print(type(model))Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).<class 'transformers.models.bert.modeling_bert.BertModel'>

可以看到,。这一过程中,之所以会有提示信息显示,是因为有些权重参数并未使用上,这是正常的。

之所以说“AutoClass”是一个系列,是因为“AutoClass”不仅包括便捷创建预训练模型的类对象AutoModel,还包括预训练模型对应的分词器类对象AutoTokenizer,预训练模型配置管理类AutoTokenizer,以及其他各种特色功能的类对象,这里不一一列举,看一参考Huggingface官方文档对“AutoClass”的说明。

2 词的向量表示——AutoTokenizer¶

几乎所有的自然语言处理任务,都是从分词和词的向量表示开始的,Transformer算法模型也不例外,所以,在Huggingface的transformers库中提供了高级API对象——AutoTokenizer,用以加载预训练的分词器实现这一过程。

AutoTokenizer是Huggingface提供的“AutoClass”系列的高级对象,可以便捷的调用tokenizers库(Huggingface提供的专门用于分词等操作的代码库)实现加载预训练的分词器。

通过在AutoTokenizer中定义的from_pretrained方法指定需要加载的分词器名称,即可从网络上自动加载分词器,并实例化tokenizers库中分词器。tokenizers中定义的分词器对象提供非常丰富的功能,例如定义词库、加载词库、截断、填充、指定特殊标记等。

`这里需要注意,大多数情况下,我们都是同时使用预定义的分词器和预训练模型,或者说是配套使用的,例如,我们使用的预训练模型是“bert-base-chinese”,那么,加载分词器是,也必须使用“bert-base-chinese”对应的词库,否则,使用预训练模型就效果将大大降低。`

In[5]:from transformers import AutoTokenizerIn[7]:tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")In[8]:sentence = "床前明月光"tokenizer(sentence)Out[8]:{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}指定目录加载分词器

当然,有时候因为网络原因,也可以先手动从Huggingface官网下载模型,然后在from_pretrained方法中指定本地目录方式进行加载。

In[9]:tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")In[10]:sentence = "床前明月光"tokenizer(sentence)Out[10]:{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}同时转化多个句子In[11]:sentence = ["床前明月光", "床前明月光,疑是地上霜。"]tokenizer(sentence)Out[11]:{'input_ids': [[101, 2414, 1184, 3209, 3299, 1045, 102], [101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}其他参数功能

tokenizer内部还提供其他丰富的参数用于实现多种多样功能:

In[12]:tokenizer( ["床前明月光", "床前明月光,疑是地上霜。"], padding=True, # 长度不足max_length时是否进行填充 truncation=True, # 长度超过max_length时是否进行截断 max_length=10, return_tensors="pt", # 指定返回数据类型,pt:pytorch的张量,tf:TensorFlow的张量)Out[12]:{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 102, 0, 0, 0], [ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}3 配置——AutoConfig¶

每一类的算法模型的框架结构都是不一样的,所以超参数配置也不一样,如果每次加载预训练模型,都要用户手动去找到对应的配置项、配置类,就是非不便捷了,所以,在“AutoClass”中也提供有专门的配置管理入口——AutoConfig。

Huggingface之transformers零基础使用指南

一般来说,就算同一个算法的预训练模型,也可能有不同的网络结构,所以,我们下载的预训练模型本身就提供有一个配置文件,例如在Huggingface官网下载的预训练模型,提供有一个config.json文件,AutoConfig将从里面加载当前预训练模型的特定配置项信息进行覆盖。

以BERT模型为例,我们下来看看默认的配置项:

In[13]:from transformers import BertConfigconfig = BertConfig()In[14]:configOut[14]:BertConfig { "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 30522}In[15]:from transformers import AutoConfigconfig = AutoConfig.from_pretrained("./models/bert-base-chinese")In[16]:configOut[16]:BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128}

可以看到,从预训练模型加载出来的配置项与之前的默认配置项略有不同。而且,这个配置实例就是BertConfig类的实例,如下所示:

In[17]:type(config)Out[17]:transformers.models.bert.configuration_bert.BertConfig

通过config实例,我们可以对配置项进行修改,例如,上述配置中,编码器结构为12层编码器层,我们将其修改为5层,如下所示,经过修改后,最终创建的模型编码器只包含5层结构,也只有前5层会加载预训练结构,其他权重将会被舍弃。

In[18]:config.num_hidden_layers=5print(config)BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 5, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128}

修改之后的参数,如果后续需要再次使用,可以保存到本地,传入保存路径,将在指定目录保存为config.json文件:

In[75]:config.save_pretrained("./models/bert-base-chinese")4 创建预训练模型——AutoModel¶

Huggingface官方提供了很多的预训练模型,可以在Huggingface官网很容易找到。通过AutoModel类,创建预训练模型最简单的方法就是直接传入预训练模型名称或者本地路径,因为国内网络环境原因,建议先去将预训练模型下载到本地,通过指定目录的方式进行加载:

In[19]:from transformers import AutoModelmodel = AutoModel.from_pretrained("./models/bert-base-chinese")Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

通过这种方法,模型将直接加载预训练模型config.json的配置项。也可以在加载模型时,指定配置类实例,这样就可以实现对预训练模型的自定义,例如,传入我们上一小节中修改后的config实例:

In[20]:model = AutoModel.from_pretrained("./models/bert-base-chinese", config=config)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

因为在上一章节,我们将编码器结构层数改为5,所以,这里提示很多权重参数并未使用。

同时,我们也可以通过在from_pretrained()方法中直接传参的方式,传入配置项,例如,我们将编码器层数改为3层。注意,这种方式在指定了config参数时不在生效。

In[21]:model = AutoModel.from_pretrained("./models/bert-base-chinese", num_hidden_layers=3)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['bert.encoder.layer.4.attention.self.value.bias', 'cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

我们尝试一下将tokenizer编码后的张量在model中进行前向传播:

In[22]:tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))In[23]:tens.last_hidden_state.shapeOut[23]:torch.Size([1, 14, 768])

当模型修改或者重新训练后,可以通过model.save_pretrained()方法再次保存,保存后,在指定目录中将生成两个文件:配置文件(config.json),权重文件(pytorch_model.bin)。

In[82]:model.save_pretrained("./new_model/bert-base-chinese")5 使用现成的任务模型¶

在transformers库中,Huggingface还提供有许多完整网络模型用于各式各样的AI任务,例如图像分类、文本分类、语音分类、翻译、问答等,这类API大多以AutoModelFor*开头,我们打印输出看看:

In[31]:import transformersIn[33]:for api in dir(transformers): if api.startswith('AutoModelFor'): print(api)AutoModelForAudioClassificationAutoModelForAudioFrameClassificationAutoModelForAudioXVectorAutoModelForCTCAutoModelForCausalLMAutoModelForDepthEstimationAutoModelForDocumentQuestionAnsweringAutoModelForImageClassificationAutoModelForImageSegmentationAutoModelForInstanceSegmentationAutoModelForMaskedImageModelingAutoModelForMaskedLMAutoModelForMultipleChoiceAutoModelForNextSentencePredictionAutoModelForObjectDetectionAutoModelForPreTrainingAutoModelForQuestionAnsweringAutoModelForSemanticSegmentationAutoModelForSeq2SeqLMAutoModelForSequenceClassificationAutoModelForSpeechSeq2SeqAutoModelForTableQuestionAnsweringAutoModelForTokenClassificationAutoModelForVideoClassificationAutoModelForVision2SeqAutoModelForVisualQuestionAnsweringAutoModelForZeroShotObjectDetection

以其中的AutoModelForSequenceClassification为例,介绍怎么使用:

In[34]:from transformers import AutoModelForSequenceClassificationIn[35]:model = AutoModelForSequenceClassification.from_pretrained("./models/bert-base-chinese", num_labels=2)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

num_labels意思是,我们需要进行的任务最终标签有两类,即这是一个二分类模型。我们查看一下模型结构:

In[36]:modelOut[36]:BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (4): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (8): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=768, out_features=2, bias=True))In[37]:tokenizer("床前明月光,疑是地上霜。", return_tensors="pt")Out[37]:{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}In[38]:tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))In[39]:tensOut[39]:SequenceClassifierOutput(loss=None, logits=tensor([[-0.4371, -0.1223]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

最后输出了两个值,分贝对应于两个类别。

6 自定义模型¶In[1]:from chb import *import pandas as pdfrom tqdm import tqdmfrom collections import defaultdictimport torchfrom torch import nnfrom torch.utils.data.dataloader import DataLoaderfrom torch.utils.data.dataset import Datasetfrom torch.optim import AdamWfrom sklearn.model_selection import train_test_splitfrom transformers import AutoConfig,AutoModel,AutoTokenizer,get_linear_schedule_with_warmup,loggingimport warningswarnings.filterwarnings('ignore')In[2]:RANDOM_SEED = 1000MAX_LEN = 50BATCH_SIZE = 646.1 自定义数据集¶In[3]:x_lst = []y_lst = []with open('./data/中文文本-新闻分类数据集','r') as f: # 获得训练数据的总行数 for _ in tqdm(f,desc='load dataset'): try: line = f.readline().replace('\u3000\u3000', '').replace('\n', '') x, y = line.split('\t') if y == 'label': continue x_lst.append(x) y_lst.append(y) except: passload dataset: 5902it [00:00, 55378.83it/s]In[4]:len(x_lst), len(y_lst)Out[4]:(5900, 5900)In[5]:x_lst[0], y_lst[0]Out[5]:('昌平京基鹭府10月29日推别墅1200万套起享97折新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。京基鹭府坐守定泗路,立汤路交汇处。连接京昌、八达岭、机场高速,南至5环,北上6环,紧邻立汤路,一路向南,直抵鸟巢、水立方、长安街,距北京唯一不堵车的京承高速出口仅1公里,节约出行时间成本,形成了三横、三纵的立体式交通网络项目周边多为别墅项目,人口密度低,交通出行舒适度高。>>报名参加“乐动银十”10月22日大型抄底看房团以上信息仅供参考,最终以开发商公布为准。订阅会员置业刊我们将直接把最新的热盘动向发送到您的邮箱更多热盘推荐:新锐白领淘低价1-2居网罗2万内轨道精装房手握20万咋买板楼2居网罗城南沿轨优质盘关注娃娃教育网罗学区房不足百万元住上通透2居', '房产')In[6]:label2id = dict()id2label = dict()for i, label in enumerate(set(y_lst)): label2id[label] = i id2label[i] = labelIn[7]:tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")

先把所有的文本都转化为编码,而不是在后续数据集中转化,这样可以避免在后续训练过程中,每一个epoch都要进行转化,提升效率:

In[8]:token_lens = []for txt in tqdm(x_lst): tokens = tokenizer.encode(txt, max_length=512) token_lens.append(len(tokens)) 0%| | 0/5900 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.100%|██████████| 5900/5900 [00:07<00:00, 739.64it/s]In[9]:class NewsDataset(Dataset): def __init__(self,x_lst,y_lst,tokenizer,max_len): self.x_lst=x_lst self.y_lst=y_lst self.tokenizer=tokenizer self.max_len=max_len def __len__(self): return len(self.x_lst) def __getitem__(self,index): """ index 为数据索引,迭代取第index条数据 """ text=str(self.x_lst[index]) label=label2id[self.y_lst[index]] encoding=self.tokenizer.encode_plus( text, add_special_tokens=True, max_length=self.max_len, return_token_type_ids=True, pad_to_max_length=True, return_attention_mask=True, return_tensors='pt', ) return { 'texts':text, 'input_ids':encoding['input_ids'].flatten(), 'attention_mask':encoding['attention_mask'].flatten(), 'labels':torch.tensor(label,dtype=torch.long) }In[10]:x_train, x_val, y_train, y_val = train_test_split(x_lst, y_lst, test_size=0.15, random_state=RANDOM_SEED) # 划分训练集 测试集In[11]:# datasettrain_dataset = NewsDataset(x_train, y_train, tokenizer, MAX_LEN)val_dataset = NewsDataset(x_val, y_val, tokenizer, MAX_LEN)In[12]:# dataloadertrain_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)6.2 自定义网络¶

这里我们使用BERT预训练模型,同时接Dropout层和一层线形层,构成自定义网络:

In[13]:class CustomBERTModel(nn.Module): def __init__(self, n_classes): super(CustomBERTModel, self).__init__() self.bert = AutoModel.from_pretrained("./models/bert-base-chinese") self.drop = nn.Dropout(p=0.3) self.out = nn.Linear(self.bert.config.hidden_size, n_classes) def forward(self, input_ids, attention_mask): _, pooled_output = self.bert( input_ids=input_ids, attention_mask=attention_mask, return_dict = False ) output = self.drop(pooled_output) # dropout return self.out(output)In[14]:device = set_device(cuda_index=1)2022-12-20 16:12:39 set_device line 11 out: cuda:1In[15]:n_classes = len(label2id)model = CustomBERTModel(n_classes)model = model.to(device)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

自定义数据集:

6.3 训练¶In[16]:def train_epoch(model, data_loader,loss_fn,optimizer,device,scheduler, n_examples): model.train() losses = [] correct_predictions = 0 for i, d in bar(data_loader): input_ids = d["input_ids"].to(device) attention_mask = d["attention_mask"].to(device) targets = d["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask ) _, preds = torch.max(outputs, dim=1) loss = loss_fn(outputs, targets) correct_predictions += torch.sum(preds == targets) losses.append(loss.item()) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() scheduler.step() optimizer.zero_grad() return correct_predictions.double() / n_examples, np.mean(losses)In[17]:def eval_model(model, data_loader, loss_fn, device, n_examples): model.eval() # 验证预测模式 losses = [] correct_predictions = 0 with torch.no_grad(): for d in data_loader: input_ids = d["input_ids"].to(device) attention_mask = d["attention_mask"].to(device) targets = d["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask ) _, preds = torch.max(outputs, dim=1) loss = loss_fn(outputs, targets) correct_predictions += torch.sum(preds == targets) losses.append(loss.item()) return correct_predictions.double() / n_examples, np.mean(losses)In[18]:EPOCHS = 5 # 训练轮数optimizer = AdamW(model.parameters(), lr=2e-5)total_steps = len(train_dataloader) * EPOCHSscheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps)loss_fn = nn.CrossEntropyLoss().to(device)In[19]:best_accuracy = 0is_best = Falset = Tableprint(['epoch', 'train_accuracy', 'train_loss', 'test_accuracy', 'test_loss', 'is_best'])t.print_header()for epoch in range(EPOCHS): train_acc, train_loss = train_epoch(model,train_dataloader,loss_fn,optimizer,device,scheduler,len(x_train)) val_acc, val_loss = eval_model( model, val_dataloader, loss_fn, device, len(x_val) ) if val_acc > best_accuracy: is_best = True torch.save(model.state_dict(), './models/news_classification/best_model_state.bin') best_accuracy = val_acc else: is_best = False t.print_row(epoch, f"{train_acc:.4f}", f"{train_loss:.4f}", f"{val_acc:.4f}", f"{val_loss:.4f}", is_best)+======+===========+====================+================+===================+===============+=============+| | epoch | train_accuracy | train_loss | test_accuracy | test_loss | is_best |+======+===========+====================+================+===================+===============+=============+| 1 | 0 | 0.6080 | 1.4608 | 0.8893 | 0.5278 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 2 | 1 | 0.9196 | 0.3766 | 0.9096 | 0.3583 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 3 | 2 | 0.9589 | 0.2015 | 0.9153 | 0.3413 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 4 | 3 | 0.9765 | 0.1272 | 0.9153 | 0.3286 | False |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 5 | 4 | 0.9836 | 0.0919 | 0.9220 | 0.3239 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+

使用BERT预训练模型+自定义网络,模型初始时就具有了较高的准确率。

本文链接地址:https://www.jiuchutong.com/zhishi/304715.html 转载请保留说明!

上一篇:phpcms v9怎么安装(phpcms v9官网)

下一篇:301重定向到https 并且不带www跳转到带www

  • 支付增值税会计科目
  • 经营所得汇算在哪里查
  • 税务问答网站
  • 小规模纳税人认定的最新标准2022
  • 翻唱歌曲发行时怎么把原唱名字改了
  • 防伪开票系统技术维护费怎么做分录
  • 租金营改增
  • 备用金可以不要发票吗
  • 网上申购到发行多长时间
  • 报税利润表的第二季度本期金额本年累计金额是什么
  • 企业的哪些活动对企业有长期影响呢
  • 超范围经营如何举报
  • 银行漏记账会造成什么后果
  • 由收款人签发,经付款人承兑
  • 税前扣除的职工教育经费
  • 发票开出来对方不走账会怎么样?
  • 出口退税进项发票
  • 商品进销差价是资产类账户,其抵减的账户是
  • 去年税收滞纳金是个人缴的,用做账嘛
  • 生产企业出口退税流程怎么操作
  • 销售额增加10%什么概念
  • 旅行社差额征税增值税申报表怎么填
  • 银行转账结算凭证按照填制手续和内容
  • 出纳微信收款之后怎么做
  • 出售固定资产属于收入
  • 因腐败因素形成的损失企业所得税如何处理?
  • 应收账款未计提坏账,但是确实收不回来
  • 付出去的款项退回怎么做账
  • win11怎么下载手机应用
  • 上月有留底税额,这个月有销销没有进项怎么做账
  • 职工教育经费怎么做凭证
  • 废品损失科目的期末余额在借方表示什么?
  • 深入php:面向对象、模式与实践
  • dtft与dfs
  • 失控增值税专用发票
  • vue setstate
  • tabstat命令怎么用
  • 小规模纳税人免增值税的账务处理
  • 帝国cms图片显示不了
  • 员工迟到扣款怎么处理
  • 销售旧货和销售使用过的固定资产
  • mongodb分区分片
  • 通行费电子发票的发票代码为多少位
  • 转账支票背书盖章图位置
  • 申报缴纳上月税费的会计分录
  • db2自增函数
  • 如何安装sql server2022
  • MySQL提示The InnoDB feature is disabled需要开启InnoDB的解决方法
  • 什么是指企业的市场营销活动发生影响的各种因素的总和
  • 资产负债表调整事项
  • 个人所得税的计算标准
  • 冲减应收账款该怎么处理
  • 负数发票开错了是否可以作废?
  • 装修费用摊销的会计分录怎么写
  • 企业注销在建工程怎么处理
  • 餐饮开票税率
  • 购买电脑增值税税率是多少
  • 发票抬头注意事项
  • 有存货的公司
  • 浅谈基于comsol的锂离子电池仿真
  • sqlcipher c#
  • mysql闪退怎么回事
  • mac识别文字软件
  • nwtray.exe - nwtray是什么进程 作用是什么
  • linux yum安装软件命令
  • linux安全工具
  • sendmail邮件服务器的配置
  • centos6.9
  • javascript基础编程
  • 问题少年特训学校
  • nginx日志切割原理
  • node用mongodb还是mysql好
  • npm安装nodemodules
  • javascript高级编程
  • android通信机制
  • python flask
  • 上海电子发票试点
  • 辽宁税务遴选
  • 2021年京东养鸡如何合作
  • 个人股权转让是否增值了怎么判断
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号