位置: IT常识 - 正文

Huggingface之transformers零基础使用指南

编辑:rootadmin
前几篇博文中介绍了Transformer,由于其优越的性能表现,在工业界使用的越来越广泛,同时,配合迁移学习理论,越来越多的Transformer预训练模型和源码库逐渐开源,Huggingface就是其中做的最为出色的一家机构。Huggingface是一家在NLP社区做出杰出贡献的纽约创业公司,其所... ...

推荐整理分享Huggingface之transformers零基础使用指南,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

前几篇博文中介绍了Transformer,由于其优越的性能表现,在工业界使用的越来越广泛,同时,配合迁移学习理论,越来越多的Transformer预训练模型和源码库逐渐开源,Huggingface就是其中做的最为出色的一家机构。Huggingface是一家在NLP社区做出杰出贡献的纽约创业公司,其所提供的大量预训练模型和代码等资源被广泛的应用于学术研究当中。Huggingface所开源的Transformers提供了数以千计针对于各种任务的预训练模型模型,开发者可以根据自身的需要,选择模型进行训练或微调,也可阅读api文档和源码, 快速开发新模型。

本篇博文,我们对Huggingface所开源的Transformers进行介绍。在此之前,请通过下行命令安装transformers库:

pip install transformers1 从AutoClass说起¶

transformers库中提供了上百个算法模型的实现,有BERT模型对应的BertModel类,有BART对应的BartModel类……,每当我们使用对应的预训练模型时,都必须先找到对应类名,然后进行实例化,麻烦吗?非常麻烦!

所以,transformers库中提供统一的入口,也就是我们这里说到的“AutoClass”系列的高级对象,通过在调用“AutoClass”的from_pretrained()方法时指定预训练模型的名称或预训练模型所在目录,即可快速、便捷得完成预训练模型创建。有了“AutoClass”,只需要知道预训练模型的名称,或者将预训练模型下载好,程序将根据预训练模型配置文件中model_type或者预训练模型名称、路径进行模式匹配,自动决定实例化哪一个模型类,不再需要再到该模型在transfors库中对应的类名。“AutoClass”所有类都不能够通过init()方法进行实例化,只能通过from_pretrained()方法实例化指定的类。

如下所示,我们到Huggingface官网下载好一个中文BERT预训练模型,模型所有文件存放在当前目录下的“model/bert-base-chinese”路径下。创建预训练模型时,我们将这一路径传递到from_pretrained()方法,即可完成模型创建,创建好的模型为BertModel类的实例。

In[1]:from transformers import AutoModelIn[4]:model = AutoModel.from_pretrained("./models/bert-base-chinese")print(type(model))Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).<class 'transformers.models.bert.modeling_bert.BertModel'>

可以看到,。这一过程中,之所以会有提示信息显示,是因为有些权重参数并未使用上,这是正常的。

之所以说“AutoClass”是一个系列,是因为“AutoClass”不仅包括便捷创建预训练模型的类对象AutoModel,还包括预训练模型对应的分词器类对象AutoTokenizer,预训练模型配置管理类AutoTokenizer,以及其他各种特色功能的类对象,这里不一一列举,看一参考Huggingface官方文档对“AutoClass”的说明。

2 词的向量表示——AutoTokenizer¶

几乎所有的自然语言处理任务,都是从分词和词的向量表示开始的,Transformer算法模型也不例外,所以,在Huggingface的transformers库中提供了高级API对象——AutoTokenizer,用以加载预训练的分词器实现这一过程。

AutoTokenizer是Huggingface提供的“AutoClass”系列的高级对象,可以便捷的调用tokenizers库(Huggingface提供的专门用于分词等操作的代码库)实现加载预训练的分词器。

通过在AutoTokenizer中定义的from_pretrained方法指定需要加载的分词器名称,即可从网络上自动加载分词器,并实例化tokenizers库中分词器。tokenizers中定义的分词器对象提供非常丰富的功能,例如定义词库、加载词库、截断、填充、指定特殊标记等。

`这里需要注意,大多数情况下,我们都是同时使用预定义的分词器和预训练模型,或者说是配套使用的,例如,我们使用的预训练模型是“bert-base-chinese”,那么,加载分词器是,也必须使用“bert-base-chinese”对应的词库,否则,使用预训练模型就效果将大大降低。`

In[5]:from transformers import AutoTokenizerIn[7]:tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")In[8]:sentence = "床前明月光"tokenizer(sentence)Out[8]:{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}指定目录加载分词器

当然,有时候因为网络原因,也可以先手动从Huggingface官网下载模型,然后在from_pretrained方法中指定本地目录方式进行加载。

In[9]:tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")In[10]:sentence = "床前明月光"tokenizer(sentence)Out[10]:{'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}同时转化多个句子In[11]:sentence = ["床前明月光", "床前明月光,疑是地上霜。"]tokenizer(sentence)Out[11]:{'input_ids': [[101, 2414, 1184, 3209, 3299, 1045, 102], [101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}其他参数功能

tokenizer内部还提供其他丰富的参数用于实现多种多样功能:

In[12]:tokenizer( ["床前明月光", "床前明月光,疑是地上霜。"], padding=True, # 长度不足max_length时是否进行填充 truncation=True, # 长度超过max_length时是否进行截断 max_length=10, return_tensors="pt", # 指定返回数据类型,pt:pytorch的张量,tf:TensorFlow的张量)Out[12]:{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 102, 0, 0, 0], [ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}3 配置——AutoConfig¶

每一类的算法模型的框架结构都是不一样的,所以超参数配置也不一样,如果每次加载预训练模型,都要用户手动去找到对应的配置项、配置类,就是非不便捷了,所以,在“AutoClass”中也提供有专门的配置管理入口——AutoConfig。

Huggingface之transformers零基础使用指南

一般来说,就算同一个算法的预训练模型,也可能有不同的网络结构,所以,我们下载的预训练模型本身就提供有一个配置文件,例如在Huggingface官网下载的预训练模型,提供有一个config.json文件,AutoConfig将从里面加载当前预训练模型的特定配置项信息进行覆盖。

以BERT模型为例,我们下来看看默认的配置项:

In[13]:from transformers import BertConfigconfig = BertConfig()In[14]:configOut[14]:BertConfig { "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 30522}In[15]:from transformers import AutoConfigconfig = AutoConfig.from_pretrained("./models/bert-base-chinese")In[16]:configOut[16]:BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 12, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128}

可以看到,从预训练模型加载出来的配置项与之前的默认配置项略有不同。而且,这个配置实例就是BertConfig类的实例,如下所示:

In[17]:type(config)Out[17]:transformers.models.bert.configuration_bert.BertConfig

通过config实例,我们可以对配置项进行修改,例如,上述配置中,编码器结构为12层编码器层,我们将其修改为5层,如下所示,经过修改后,最终创建的模型编码器只包含5层结构,也只有前5层会加载预训练结构,其他权重将会被舍弃。

In[18]:config.num_hidden_layers=5print(config)BertConfig { "_name_or_path": "./models/bert-base-chinese", "attention_probs_dropout_prob": 0.1, "classifier_dropout": null, "directionality": "bidi", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "lstm_dropout_prob": 0.5, "lstm_embedding_size": 768, "max_position_embeddings": 512, "model_type": "bert", "num_attention_heads": 12, "num_hidden_layers": 5, "pad_token_id": 0, "pooler_fc_size": 768, "pooler_num_attention_heads": 12, "pooler_num_fc_layers": 3, "pooler_size_per_head": 128, "pooler_type": "first_token_transform", "position_embedding_type": "absolute", "transformers_version": "4.24.0", "type_vocab_size": 2, "use_cache": true, "vocab_size": 21128}

修改之后的参数,如果后续需要再次使用,可以保存到本地,传入保存路径,将在指定目录保存为config.json文件:

In[75]:config.save_pretrained("./models/bert-base-chinese")4 创建预训练模型——AutoModel¶

Huggingface官方提供了很多的预训练模型,可以在Huggingface官网很容易找到。通过AutoModel类,创建预训练模型最简单的方法就是直接传入预训练模型名称或者本地路径,因为国内网络环境原因,建议先去将预训练模型下载到本地,通过指定目录的方式进行加载:

In[19]:from transformers import AutoModelmodel = AutoModel.from_pretrained("./models/bert-base-chinese")Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

通过这种方法,模型将直接加载预训练模型config.json的配置项。也可以在加载模型时,指定配置类实例,这样就可以实现对预训练模型的自定义,例如,传入我们上一小节中修改后的config实例:

In[20]:model = AutoModel.from_pretrained("./models/bert-base-chinese", config=config)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

因为在上一章节,我们将编码器结构层数改为5,所以,这里提示很多权重参数并未使用。

同时,我们也可以通过在from_pretrained()方法中直接传参的方式,传入配置项,例如,我们将编码器层数改为3层。注意,这种方式在指定了config参数时不在生效。

In[21]:model = AutoModel.from_pretrained("./models/bert-base-chinese", num_hidden_layers=3)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['bert.encoder.layer.4.attention.self.value.bias', 'cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

我们尝试一下将tokenizer编码后的张量在model中进行前向传播:

In[22]:tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))In[23]:tens.last_hidden_state.shapeOut[23]:torch.Size([1, 14, 768])

当模型修改或者重新训练后,可以通过model.save_pretrained()方法再次保存,保存后,在指定目录中将生成两个文件:配置文件(config.json),权重文件(pytorch_model.bin)。

In[82]:model.save_pretrained("./new_model/bert-base-chinese")5 使用现成的任务模型¶

在transformers库中,Huggingface还提供有许多完整网络模型用于各式各样的AI任务,例如图像分类、文本分类、语音分类、翻译、问答等,这类API大多以AutoModelFor*开头,我们打印输出看看:

In[31]:import transformersIn[33]:for api in dir(transformers): if api.startswith('AutoModelFor'): print(api)AutoModelForAudioClassificationAutoModelForAudioFrameClassificationAutoModelForAudioXVectorAutoModelForCTCAutoModelForCausalLMAutoModelForDepthEstimationAutoModelForDocumentQuestionAnsweringAutoModelForImageClassificationAutoModelForImageSegmentationAutoModelForInstanceSegmentationAutoModelForMaskedImageModelingAutoModelForMaskedLMAutoModelForMultipleChoiceAutoModelForNextSentencePredictionAutoModelForObjectDetectionAutoModelForPreTrainingAutoModelForQuestionAnsweringAutoModelForSemanticSegmentationAutoModelForSeq2SeqLMAutoModelForSequenceClassificationAutoModelForSpeechSeq2SeqAutoModelForTableQuestionAnsweringAutoModelForTokenClassificationAutoModelForVideoClassificationAutoModelForVision2SeqAutoModelForVisualQuestionAnsweringAutoModelForZeroShotObjectDetection

以其中的AutoModelForSequenceClassification为例,介绍怎么使用:

In[34]:from transformers import AutoModelForSequenceClassificationIn[35]:model = AutoModelForSequenceClassification.from_pretrained("./models/bert-base-chinese", num_labels=2)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

num_labels意思是,我们需要进行的任务最终标签有两类,即这是一个二分类模型。我们查看一下模型结构:

In[36]:modelOut[36]:BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (4): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (8): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=768, out_features=2, bias=True))In[37]:tokenizer("床前明月光,疑是地上霜。", return_tensors="pt")Out[37]:{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}In[38]:tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))In[39]:tensOut[39]:SequenceClassifierOutput(loss=None, logits=tensor([[-0.4371, -0.1223]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

最后输出了两个值,分贝对应于两个类别。

6 自定义模型¶In[1]:from chb import *import pandas as pdfrom tqdm import tqdmfrom collections import defaultdictimport torchfrom torch import nnfrom torch.utils.data.dataloader import DataLoaderfrom torch.utils.data.dataset import Datasetfrom torch.optim import AdamWfrom sklearn.model_selection import train_test_splitfrom transformers import AutoConfig,AutoModel,AutoTokenizer,get_linear_schedule_with_warmup,loggingimport warningswarnings.filterwarnings('ignore')In[2]:RANDOM_SEED = 1000MAX_LEN = 50BATCH_SIZE = 646.1 自定义数据集¶In[3]:x_lst = []y_lst = []with open('./data/中文文本-新闻分类数据集','r') as f: # 获得训练数据的总行数 for _ in tqdm(f,desc='load dataset'): try: line = f.readline().replace('\u3000\u3000', '').replace('\n', '') x, y = line.split('\t') if y == 'label': continue x_lst.append(x) y_lst.append(y) except: passload dataset: 5902it [00:00, 55378.83it/s]In[4]:len(x_lst), len(y_lst)Out[4]:(5900, 5900)In[5]:x_lst[0], y_lst[0]Out[5]:('昌平京基鹭府10月29日推别墅1200万套起享97折新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。京基鹭府坐守定泗路,立汤路交汇处。连接京昌、八达岭、机场高速,南至5环,北上6环,紧邻立汤路,一路向南,直抵鸟巢、水立方、长安街,距北京唯一不堵车的京承高速出口仅1公里,节约出行时间成本,形成了三横、三纵的立体式交通网络项目周边多为别墅项目,人口密度低,交通出行舒适度高。>>报名参加“乐动银十”10月22日大型抄底看房团以上信息仅供参考,最终以开发商公布为准。订阅会员置业刊我们将直接把最新的热盘动向发送到您的邮箱更多热盘推荐:新锐白领淘低价1-2居网罗2万内轨道精装房手握20万咋买板楼2居网罗城南沿轨优质盘关注娃娃教育网罗学区房不足百万元住上通透2居', '房产')In[6]:label2id = dict()id2label = dict()for i, label in enumerate(set(y_lst)): label2id[label] = i id2label[i] = labelIn[7]:tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")

先把所有的文本都转化为编码,而不是在后续数据集中转化,这样可以避免在后续训练过程中,每一个epoch都要进行转化,提升效率:

In[8]:token_lens = []for txt in tqdm(x_lst): tokens = tokenizer.encode(txt, max_length=512) token_lens.append(len(tokens)) 0%| | 0/5900 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.100%|██████████| 5900/5900 [00:07<00:00, 739.64it/s]In[9]:class NewsDataset(Dataset): def __init__(self,x_lst,y_lst,tokenizer,max_len): self.x_lst=x_lst self.y_lst=y_lst self.tokenizer=tokenizer self.max_len=max_len def __len__(self): return len(self.x_lst) def __getitem__(self,index): """ index 为数据索引,迭代取第index条数据 """ text=str(self.x_lst[index]) label=label2id[self.y_lst[index]] encoding=self.tokenizer.encode_plus( text, add_special_tokens=True, max_length=self.max_len, return_token_type_ids=True, pad_to_max_length=True, return_attention_mask=True, return_tensors='pt', ) return { 'texts':text, 'input_ids':encoding['input_ids'].flatten(), 'attention_mask':encoding['attention_mask'].flatten(), 'labels':torch.tensor(label,dtype=torch.long) }In[10]:x_train, x_val, y_train, y_val = train_test_split(x_lst, y_lst, test_size=0.15, random_state=RANDOM_SEED) # 划分训练集 测试集In[11]:# datasettrain_dataset = NewsDataset(x_train, y_train, tokenizer, MAX_LEN)val_dataset = NewsDataset(x_val, y_val, tokenizer, MAX_LEN)In[12]:# dataloadertrain_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)6.2 自定义网络¶

这里我们使用BERT预训练模型,同时接Dropout层和一层线形层,构成自定义网络:

In[13]:class CustomBERTModel(nn.Module): def __init__(self, n_classes): super(CustomBERTModel, self).__init__() self.bert = AutoModel.from_pretrained("./models/bert-base-chinese") self.drop = nn.Dropout(p=0.3) self.out = nn.Linear(self.bert.config.hidden_size, n_classes) def forward(self, input_ids, attention_mask): _, pooled_output = self.bert( input_ids=input_ids, attention_mask=attention_mask, return_dict = False ) output = self.drop(pooled_output) # dropout return self.out(output)In[14]:device = set_device(cuda_index=1)2022-12-20 16:12:39 set_device line 11 out: cuda:1In[15]:n_classes = len(label2id)model = CustomBERTModel(n_classes)model = model.to(device)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

自定义数据集:

6.3 训练¶In[16]:def train_epoch(model, data_loader,loss_fn,optimizer,device,scheduler, n_examples): model.train() losses = [] correct_predictions = 0 for i, d in bar(data_loader): input_ids = d["input_ids"].to(device) attention_mask = d["attention_mask"].to(device) targets = d["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask ) _, preds = torch.max(outputs, dim=1) loss = loss_fn(outputs, targets) correct_predictions += torch.sum(preds == targets) losses.append(loss.item()) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() scheduler.step() optimizer.zero_grad() return correct_predictions.double() / n_examples, np.mean(losses)In[17]:def eval_model(model, data_loader, loss_fn, device, n_examples): model.eval() # 验证预测模式 losses = [] correct_predictions = 0 with torch.no_grad(): for d in data_loader: input_ids = d["input_ids"].to(device) attention_mask = d["attention_mask"].to(device) targets = d["labels"].to(device) outputs = model( input_ids=input_ids, attention_mask=attention_mask ) _, preds = torch.max(outputs, dim=1) loss = loss_fn(outputs, targets) correct_predictions += torch.sum(preds == targets) losses.append(loss.item()) return correct_predictions.double() / n_examples, np.mean(losses)In[18]:EPOCHS = 5 # 训练轮数optimizer = AdamW(model.parameters(), lr=2e-5)total_steps = len(train_dataloader) * EPOCHSscheduler = get_linear_schedule_with_warmup( optimizer, num_warmup_steps=0, num_training_steps=total_steps)loss_fn = nn.CrossEntropyLoss().to(device)In[19]:best_accuracy = 0is_best = Falset = Tableprint(['epoch', 'train_accuracy', 'train_loss', 'test_accuracy', 'test_loss', 'is_best'])t.print_header()for epoch in range(EPOCHS): train_acc, train_loss = train_epoch(model,train_dataloader,loss_fn,optimizer,device,scheduler,len(x_train)) val_acc, val_loss = eval_model( model, val_dataloader, loss_fn, device, len(x_val) ) if val_acc > best_accuracy: is_best = True torch.save(model.state_dict(), './models/news_classification/best_model_state.bin') best_accuracy = val_acc else: is_best = False t.print_row(epoch, f"{train_acc:.4f}", f"{train_loss:.4f}", f"{val_acc:.4f}", f"{val_loss:.4f}", is_best)+======+===========+====================+================+===================+===============+=============+| | epoch | train_accuracy | train_loss | test_accuracy | test_loss | is_best |+======+===========+====================+================+===================+===============+=============+| 1 | 0 | 0.6080 | 1.4608 | 0.8893 | 0.5278 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 2 | 1 | 0.9196 | 0.3766 | 0.9096 | 0.3583 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 3 | 2 | 0.9589 | 0.2015 | 0.9153 | 0.3413 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 4 | 3 | 0.9765 | 0.1272 | 0.9153 | 0.3286 | False |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+| 5 | 4 | 0.9836 | 0.0919 | 0.9220 | 0.3239 | True |+------+-----------+--------------------+----------------+-------------------+---------------+-------------+

使用BERT预训练模型+自定义网络,模型初始时就具有了较高的准确率。

本文链接地址:https://www.jiuchutong.com/zhishi/304715.html 转载请保留说明!

上一篇:phpcms v9怎么安装(phpcms v9官网)

下一篇:301重定向到https 并且不带www跳转到带www

  • 同一商品税收分类编码不一样
  • 印花税办理流程
  • 工会经费的应税项是什么
  • 航天金税财务软件使用说明
  • 清算资产处置
  • 小规模纳税人收到专票后如何处理
  • 包工包料怎么开税票合适
  • 装载机折旧年限是几年
  • 最近的火车票代售点
  • 一张发票开不足一台设备的金额怎么办
  • 关联方交易容易产生哪些弊端
  • 中小企业的管理者角色和技能有什么要求
  • 非上市公司自然股权转让
  • 车险代缴费
  • 收取外部客户电费如何结转成本
  • 商场销售化妆品应当缴纳增值税和消费税对吗
  • 研发样机是什么
  • 销售额没有达到要求企业采取措施
  • 装修改造空调尾板多少钱
  • 定期存款利息收入现金流
  • 餐饮业收入的会计分录及摘要
  • 收到对方公司的货款怎么记账
  • 账面价值,账面净值,账面余额
  • 主营业务收入平均增长率计算例题
  • 有关预提费用如何冲销
  • 空头支票怎么办
  • 投标保证金利息怎么做账
  • mac auto tune
  • 内置管理员无法打开此应用
  • 公积金提取条件和标准
  • 预收外汇如何结汇
  • 民办非企业单位是什么企业类型
  • u盘格式化技巧
  • 补缴企业所得税和滞纳金如何入账
  • Docker部署nginx
  • php floor
  • 外汇申报是什么意思
  • 发票章需要注销吗
  • web后端开发框架有哪些
  • php内涵
  • php获取随机数
  • 嵌入式开关安装效果图
  • 功能强大的词语
  • 税号一般多少位数字
  • 企业注销前的账务处理
  • 企业销售旧车如何开票
  • element级联选择器动态获取数据
  • 出口货物赠品如何申报
  • 公司向股东借的钱怎么还
  • 租赁合同的印花税怎么交
  • mysql误删数据
  • 企业不需要交残保金吗
  • 房产税应纳税额计算例题
  • 合并设立是什么意思
  • 邮寄快递费用计算
  • 委托加工的相关法律规定
  • 来料加工账务处理流程
  • 职工宿舍怎么入账
  • 投入的资金如何做账
  • 担保公司预计负债
  • 软件测试费用明细
  • 挂靠企业电费如何处理?
  • 企业购买加油卡出售怎么做账
  • 建账的注意事项
  • 如何设置自动删除安装包
  • macos使用方法
  • centos永久修改主机名
  • windows更新后一直在欢迎界面
  • javascript判断语句
  • css写文字
  • 模拟新浪微博用户注册程序设计
  • python从入门到精通第三版pdf下载
  • jquery怎么打开
  • jquery的用法
  • dom,ran
  • android studio 安装好后怎么在桌面找到
  • python如何编程
  • 如何知道公司所有账户
  • 学费报销找学校哪个部门
  • 深圳税务忘记密码
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设