
A Beginner's Guide to Huggingface Transformers




Previous posts introduced the Transformer. Thanks to its excellent performance it is used more and more widely in industry, and, combined with transfer learning, a growing number of pretrained Transformer models and code libraries have been open-sourced. Huggingface is one of the organizations that has done this best: a New York startup that has made outstanding contributions to the NLP community, whose pretrained models and code resources are widely used in academic research. Its open-source Transformers library offers thousands of pretrained models for all kinds of tasks. Developers can pick a model and train or fine-tune it for their own needs, or read the API documentation and source code to build new models quickly.

In this post we take a closer look at the Transformers library open-sourced by Huggingface. Before going further, install the library with the following command:

pip install transformers

1 Starting with AutoClass

The transformers library implements hundreds of model architectures: a BertModel class for BERT, a BartModel class for BART, and so on. Every time we want to use a pretrained model we would first have to look up the matching class name and then instantiate it. Inconvenient? Very much so.

For this reason the library provides a unified entry point: the family of "AutoClass" high-level objects. By passing the name of a pretrained model, or the directory where it is stored, to an AutoClass's from_pretrained() method, the pretrained model is created quickly and conveniently. With AutoClass you only need to know the model's name (or have the model downloaded locally); the library matches the model_type field in the model's configuration file, or the model name/path, and automatically decides which concrete model class to instantiate, so you no longer need to look up the corresponding class name inside the transformers library. Note that AutoClass objects cannot be instantiated through __init__(); they can only be created with from_pretrained().
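As a minimal sketch (assuming the bert-base-chinese checkpoint is available locally or can be downloaded), the two routes below yield the same kind of object; the first requires knowing the concrete class name, the second does not:

from transformers import BertModel, AutoModel

# Explicit class: you must already know that bert-base-chinese maps to BertModel.
model_a = BertModel.from_pretrained("bert-base-chinese")
# AutoClass: the concrete class is resolved automatically from the checkpoint's config.
model_b = AutoModel.from_pretrained("bert-base-chinese")
print(type(model_a) is type(model_b))  # True: both are BertModel instances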

As shown below, we download a Chinese BERT pretrained model from the Huggingface website and put all of its files under "models/bert-base-chinese" in the current directory. When creating the model we pass this path to from_pretrained(); the resulting object is an instance of BertModel.

In [1]: from transformers import AutoModel

In [4]: model = AutoModel.from_pretrained("./models/bert-base-chinese")
        print(type(model))

Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
<class 'transformers.models.bert.modeling_bert.BertModel'>

As you can see, the created model is indeed an instance of BertModel. The warning printed during loading simply tells us that some of the checkpoint weights were not used, which is normal here.

"AutoClass" is called a family because it covers more than the AutoModel object used to create pretrained models: there is also AutoTokenizer for the tokenizer that goes with a pretrained model, AutoConfig for managing a pretrained model's configuration, and several other special-purpose classes that we will not list one by one; see the Huggingface documentation on "AutoClass" for details.

2 Word representations: AutoTokenizer

Almost every natural language processing task starts with tokenization and turning words into vector representations, and Transformer models are no exception. The transformers library therefore provides a high-level API object, AutoTokenizer, for loading a pretrained tokenizer to carry out this step.

AutoTokenizer belongs to Huggingface's "AutoClass" family; it is a convenient wrapper around the tokenizers library (Huggingface's dedicated library for tokenization and related operations) for loading pretrained tokenizers.

Calling AutoTokenizer.from_pretrained() with the name of the desired tokenizer downloads it automatically and instantiates the corresponding tokenizer from the tokenizers library. These tokenizer objects offer a rich set of features, such as defining and loading vocabularies, truncation, padding, and special tokens.

Note: in most cases the pretrained tokenizer and the pretrained model are used as a matched pair. If the pretrained model is "bert-base-chinese", the tokenizer must also be loaded with the vocabulary of "bert-base-chinese"; otherwise the pretrained model will perform much worse.

In [5]: from transformers import AutoTokenizer

In [7]: tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

In [8]: sentence = "床前明月光"
        tokenizer(sentence)
Out[8]: {'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Loading a tokenizer from a local directory

Of course, network access can be an issue, so you can also download the model files from the Huggingface website manually and pass the local directory to from_pretrained():
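As a hedged sketch (assuming a recent version of the huggingface_hub package, which ships alongside transformers), the download itself can also be scripted instead of done through the browser; the target path below is just the one used throughout this post:

from huggingface_hub import snapshot_download

# Download every file of the checkpoint into a local folder (local_dir requires a recent huggingface_hub).
snapshot_download(repo_id="bert-base-chinese", local_dir="./models/bert-base-chinese")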

In [9]: tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")

In [10]: sentence = "床前明月光"
         tokenizer(sentence)
Out[10]: {'input_ids': [101, 2414, 1184, 3209, 3299, 1045, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Tokenizing several sentences at once

In [11]: sentence = ["床前明月光", "床前明月光,疑是地上霜。"]
         tokenizer(sentence)
Out[11]: {'input_ids': [[101, 2414, 1184, 3209, 3299, 1045, 102], [101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

Other tokenizer parameters

The tokenizer also accepts a number of other parameters that enable all kinds of behaviour:

In [12]: tokenizer(
             ["床前明月光", "床前明月光,疑是地上霜。"],
             padding=True,         # pad sequences shorter than max_length
             truncation=True,      # truncate sequences longer than max_length
             max_length=10,
             return_tensors="pt",  # return type: "pt" for PyTorch tensors, "tf" for TensorFlow tensors
         )
Out[12]: {'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045,  102,    0,    0,    0],
         [ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

3 Configuration: AutoConfig

Every model architecture has its own structure, and therefore its own set of hyperparameters. If users had to locate the right configuration items and configuration class by hand every time they loaded a pretrained model, it would be anything but convenient, so the "AutoClass" family also provides a dedicated entry point for configuration management: AutoConfig.


In general, even pretrained models of the same architecture can differ in network structure, so every downloaded pretrained model ships with its own configuration file; for models downloaded from the Huggingface website this is config.json. AutoConfig loads the model-specific settings from this file and uses them to override the defaults.
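A minimal sketch of peeking into this file (assuming the local path used throughout this post); the model_type field is also what the AutoClass machinery matches on when deciding which concrete class to instantiate:

import json

with open("./models/bert-base-chinese/config.json", encoding="utf-8") as f:
    cfg = json.load(f)
print(cfg["model_type"])         # 'bert' -> AutoModel resolves to BertModel
print(cfg["num_hidden_layers"])  # 12 for this checkpoint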

Taking BERT as an example, let us first look at the default configuration, and then at the configuration loaded from our local checkpoint:

In [13]: from transformers import BertConfig
         config = BertConfig()

In [14]: config
Out[14]:
BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [15]: from transformers import AutoConfig
         config = AutoConfig.from_pretrained("./models/bert-base-chinese")

In [16]: config
Out[16]:
BertConfig {
  "_name_or_path": "./models/bert-base-chinese",
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "lstm_dropout_prob": 0.5,
  "lstm_embedding_size": 768,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

As you can see, the configuration loaded from the pretrained model differs slightly from the defaults. Moreover, the configuration object is an instance of BertConfig, as shown below:

In [17]: type(config)
Out[17]: transformers.models.bert.configuration_bert.BertConfig

Through the config instance we can modify individual configuration items. For example, the configuration above specifies an encoder with 12 layers; below we change it to 5. After this change the created model will contain only 5 encoder layers, only the first 5 layers will be initialized from the pretrained weights, and the remaining weights will be discarded.

In [18]: config.num_hidden_layers = 5
         print(config)
BertConfig {
  "_name_or_path": "./models/bert-base-chinese",
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "lstm_dropout_prob": 0.5,
  "lstm_embedding_size": 768,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 5,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

If the modified configuration is needed again later, it can be saved locally: pass a directory to save_pretrained() and a config.json file will be written there:

In [75]: config.save_pretrained("./models/bert-base-chinese")

4 Creating a pretrained model: AutoModel

Huggingface hosts a large number of pretrained models that are easy to find on its website. With the AutoModel class, the simplest way to create a pretrained model is to pass the model name or a local path directly. Because of network conditions in China, it is recommended to download the pretrained model first and load it from a local directory:

In [19]: from transformers import AutoModel
         model = AutoModel.from_pretrained("./models/bert-base-chinese")
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Loaded this way, the model takes its configuration directly from the checkpoint's config.json. You can also pass a configuration instance when loading, which lets you customize the pretrained model. For example, we pass the config instance modified in the previous section:

In[20]:model = AutoModel.from_pretrained("./models/bert-base-chinese", config=config)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 
'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Because we reduced the number of encoder layers to 5 in the previous section, the warning lists a large number of weights that were not used.

We can also pass configuration items directly as keyword arguments to from_pretrained(); for example, here we set the number of encoder layers to 3. Note that this approach does not take effect when the config parameter is also specified.

In[21]:model = AutoModel.from_pretrained("./models/bert-base-chinese", num_hidden_layers=3)Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['bert.encoder.layer.4.attention.self.value.bias', 'cls.seq_relationship.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 
'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'cls.predictions.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.predictions.decoder.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.self.key.bias', 
'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.5.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.weight']- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

Let us run a forward pass through the model with the tensors produced by the tokenizer:

In [22]: tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))

In [23]: tens.last_hidden_state.shape
Out[23]: torch.Size([1, 14, 768])
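As a small, hedged sketch building on the output above: last_hidden_state holds one 768-dimensional vector per token, and a common way to obtain a single sentence vector is to take the vector at position 0, i.e. the [CLS] token:

cls_vector = tens.last_hidden_state[:, 0, :]  # shape: torch.Size([1, 768])
print(cls_vector.shape)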

After the model has been modified or retrained, it can be saved again with model.save_pretrained(). Two files are written to the given directory: the configuration file (config.json) and the weights file (pytorch_model.bin).
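A directory saved this way (the cell below writes to ./new_model/bert-base-chinese) can later be reloaded exactly like an official checkpoint; a minimal sketch:

from transformers import AutoModel

# Reload config.json and pytorch_model.bin from the directory written by save_pretrained().
restored = AutoModel.from_pretrained("./new_model/bert-base-chinese")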

In [82]: model.save_pretrained("./new_model/bert-base-chinese")

5 Using ready-made task models

Besides the bare models, the transformers library also ships complete networks for a wide range of AI tasks, such as image classification, text classification, audio classification, translation, and question answering. Most of these APIs start with AutoModelFor*; let us print them out:

In [31]: import transformers

In [33]: for api in dir(transformers):
             if api.startswith('AutoModelFor'):
                 print(api)
AutoModelForAudioClassification
AutoModelForAudioFrameClassification
AutoModelForAudioXVector
AutoModelForCTC
AutoModelForCausalLM
AutoModelForDepthEstimation
AutoModelForDocumentQuestionAnswering
AutoModelForImageClassification
AutoModelForImageSegmentation
AutoModelForInstanceSegmentation
AutoModelForMaskedImageModeling
AutoModelForMaskedLM
AutoModelForMultipleChoice
AutoModelForNextSentencePrediction
AutoModelForObjectDetection
AutoModelForPreTraining
AutoModelForQuestionAnswering
AutoModelForSemanticSegmentation
AutoModelForSeq2SeqLM
AutoModelForSequenceClassification
AutoModelForSpeechSeq2Seq
AutoModelForTableQuestionAnswering
AutoModelForTokenClassification
AutoModelForVideoClassification
AutoModelForVision2Seq
AutoModelForVisualQuestionAnswering
AutoModelForZeroShotObjectDetection

Taking AutoModelForSequenceClassification as an example, here is how it is used:

In [34]: from transformers import AutoModelForSequenceClassification

In [35]: model = AutoModelForSequenceClassification.from_pretrained("./models/bert-base-chinese", num_labels=2)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ./models/bert-base-chinese and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

num_labels=2 means the task has two target labels, i.e. this is a binary classification model. Let us look at the model structure:

In[36]:modelOut[36]:BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(21128, 768, padding_idx=0) (position_embeddings): Embedding(512, 768) (token_type_embeddings): Embedding(2, 768) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (1): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (2): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (3): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): 
BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (4): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (5): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (6): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (7): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (8): BertLayer( (attention): 
BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (9): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (10): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (11): BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=768, out_features=768, bias=True) (key): Linear(in_features=768, out_features=768, bias=True) (value): Linear(in_features=768, out_features=768, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=768, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=768, out_features=3072, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=3072, out_features=768, bias=True) (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=768, out_features=768, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=768, out_features=2, 
bias=True))In[37]:tokenizer("床前明月光,疑是地上霜。", return_tensors="pt")Out[37]:{'input_ids': tensor([[ 101, 2414, 1184, 3209, 3299, 1045, 8024, 4542, 3221, 1765, 677, 7458, 511, 102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}In[38]:tens = model(**tokenizer("床前明月光,疑是地上霜。", return_tensors="pt"))In[39]:tensOut[39]:SequenceClassifierOutput(loss=None, logits=tensor([[-0.4371, -0.1223]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

The model outputs two values, one for each of the two classes.
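A minimal sketch of turning these logits into class probabilities, using the tens output from the cell above:

import torch

probs = torch.softmax(tens.logits, dim=-1)  # roughly tensor([[0.42, 0.58]]) for the logits above
pred_class = probs.argmax(dim=-1).item()    # index of the most likely class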

6 A custom model

In [1]: from chb import *  # the author's personal helper package (set_device, bar, Tableprint, ...)
        import pandas as pd
        from tqdm import tqdm
        from collections import defaultdict
        import torch
        from torch import nn
        from torch.utils.data.dataloader import DataLoader
        from torch.utils.data.dataset import Dataset
        from torch.optim import AdamW
        from sklearn.model_selection import train_test_split
        from transformers import AutoConfig, AutoModel, AutoTokenizer, get_linear_schedule_with_warmup, logging
        import warnings
        warnings.filterwarnings('ignore')

In [2]: RANDOM_SEED = 1000
        MAX_LEN = 50
        BATCH_SIZE = 64

6.1 Custom dataset

In [3]: x_lst = []
        y_lst = []
        with open('./data/中文文本-新闻分类数据集', 'r') as f:
            # walk over every line of the training file
            for _ in tqdm(f, desc='load dataset'):
                try:
                    line = f.readline().replace('\u3000\u3000', '').replace('\n', '')
                    x, y = line.split('\t')
                    if y == 'label':
                        continue
                    x_lst.append(x)
                    y_lst.append(y)
                except:
                    pass
load dataset: 5902it [00:00, 55378.83it/s]

In [4]: len(x_lst), len(y_lst)
Out[4]: (5900, 5900)

In [5]: x_lst[0], y_lst[0]
Out[5]: ('昌平京基鹭府10月29日推别墅1200万套起享97折新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。京基鹭府坐守定泗路,立汤路交汇处。连接京昌、八达岭、机场高速,南至5环,北上6环,紧邻立汤路,一路向南,直抵鸟巢、水立方、长安街,距北京唯一不堵车的京承高速出口仅1公里,节约出行时间成本,形成了三横、三纵的立体式交通网络项目周边多为别墅项目,人口密度低,交通出行舒适度高。>>报名参加“乐动银十”10月22日大型抄底看房团以上信息仅供参考,最终以开发商公布为准。订阅会员置业刊我们将直接把最新的热盘动向发送到您的邮箱更多热盘推荐:新锐白领淘低价1-2居网罗2万内轨道精装房手握20万咋买板楼2居网罗城南沿轨优质盘关注娃娃教育网罗学区房不足百万元住上通透2居', '房产')

In [6]: label2id = dict()
        id2label = dict()
        for i, label in enumerate(set(y_lst)):
            label2id[label] = i
            id2label[i] = label

In [7]: tokenizer = AutoTokenizer.from_pretrained("./models/bert-base-chinese")

We first run the tokenizer over all texts up front rather than converting them later inside the dataset; this avoids re-encoding in every epoch during training and improves efficiency:

In [8]: token_lens = []
        for txt in tqdm(x_lst):
            tokens = tokenizer.encode(txt, max_length=512)
            token_lens.append(len(tokens))
  0%|          | 0/5900 [00:00<?, ?it/s]
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████| 5900/5900 [00:07<00:00, 739.64it/s]

In [9]: class NewsDataset(Dataset):
            def __init__(self, x_lst, y_lst, tokenizer, max_len):
                self.x_lst = x_lst
                self.y_lst = y_lst
                self.tokenizer = tokenizer
                self.max_len = max_len

            def __len__(self):
                return len(self.x_lst)

            def __getitem__(self, index):
                """index is the position of the record to fetch"""
                text = str(self.x_lst[index])
                label = label2id[self.y_lst[index]]
                encoding = self.tokenizer.encode_plus(
                    text,
                    add_special_tokens=True,
                    max_length=self.max_len,
                    return_token_type_ids=True,
                    pad_to_max_length=True,
                    return_attention_mask=True,
                    return_tensors='pt',
                )
                return {
                    'texts': text,
                    'input_ids': encoding['input_ids'].flatten(),
                    'attention_mask': encoding['attention_mask'].flatten(),
                    'labels': torch.tensor(label, dtype=torch.long)
                }

In [10]: x_train, x_val, y_train, y_val = train_test_split(x_lst, y_lst, test_size=0.15, random_state=RANDOM_SEED)  # split into train / validation sets

In [11]: # dataset
         train_dataset = NewsDataset(x_train, y_train, tokenizer, MAX_LEN)
         val_dataset = NewsDataset(x_val, y_val, tokenizer, MAX_LEN)

In [12]: # dataloader
         train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
         val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

6.2 Custom network

Here we use the pretrained BERT model followed by a Dropout layer and a single linear layer to form our custom network:

In [13]: class CustomBERTModel(nn.Module):
             def __init__(self, n_classes):
                 super(CustomBERTModel, self).__init__()
                 self.bert = AutoModel.from_pretrained("./models/bert-base-chinese")
                 self.drop = nn.Dropout(p=0.3)
                 self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

             def forward(self, input_ids, attention_mask):
                 _, pooled_output = self.bert(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
                     return_dict=False
                 )
                 output = self.drop(pooled_output)  # dropout
                 return self.out(output)

In [14]: device = set_device(cuda_index=1)
2022-12-20 16:12:39 set_device line 11 out: cuda:1

In [15]: n_classes = len(label2id)
         model = CustomBERTModel(n_classes)
         model = model.to(device)
Some weights of the model checkpoint at ./models/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


6.3 Training

In [16]: def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
             model.train()
             losses = []
             correct_predictions = 0
             for i, d in bar(data_loader):
                 input_ids = d["input_ids"].to(device)
                 attention_mask = d["attention_mask"].to(device)
                 targets = d["labels"].to(device)
                 outputs = model(
                     input_ids=input_ids,
                     attention_mask=attention_mask
                 )
                 _, preds = torch.max(outputs, dim=1)
                 loss = loss_fn(outputs, targets)
                 correct_predictions += torch.sum(preds == targets)
                 losses.append(loss.item())
                 loss.backward()
                 nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                 optimizer.step()
                 scheduler.step()
                 optimizer.zero_grad()
             return correct_predictions.double() / n_examples, np.mean(losses)

In [17]: def eval_model(model, data_loader, loss_fn, device, n_examples):
             model.eval()  # evaluation mode
             losses = []
             correct_predictions = 0
             with torch.no_grad():
                 for d in data_loader:
                     input_ids = d["input_ids"].to(device)
                     attention_mask = d["attention_mask"].to(device)
                     targets = d["labels"].to(device)
                     outputs = model(
                         input_ids=input_ids,
                         attention_mask=attention_mask
                     )
                     _, preds = torch.max(outputs, dim=1)
                     loss = loss_fn(outputs, targets)
                     correct_predictions += torch.sum(preds == targets)
                     losses.append(loss.item())
             return correct_predictions.double() / n_examples, np.mean(losses)

In [18]: EPOCHS = 5  # number of training epochs
         optimizer = AdamW(model.parameters(), lr=2e-5)
         total_steps = len(train_dataloader) * EPOCHS
         scheduler = get_linear_schedule_with_warmup(
             optimizer,
             num_warmup_steps=0,
             num_training_steps=total_steps
         )
         loss_fn = nn.CrossEntropyLoss().to(device)

In [19]: best_accuracy = 0
         is_best = False
         t = Tableprint(['epoch', 'train_accuracy', 'train_loss', 'test_accuracy', 'test_loss', 'is_best'])
         t.print_header()
         for epoch in range(EPOCHS):
             train_acc, train_loss = train_epoch(model, train_dataloader, loss_fn, optimizer, device, scheduler, len(x_train))
             val_acc, val_loss = eval_model(model, val_dataloader, loss_fn, device, len(x_val))
             if val_acc > best_accuracy:
                 is_best = True
                 torch.save(model.state_dict(), './models/news_classification/best_model_state.bin')
                 best_accuracy = val_acc
             else:
                 is_best = False
             t.print_row(epoch, f"{train_acc:.4f}", f"{train_loss:.4f}", f"{val_acc:.4f}", f"{val_loss:.4f}", is_best)

+---+-------+----------------+------------+---------------+-----------+---------+
|   | epoch | train_accuracy | train_loss | test_accuracy | test_loss | is_best |
+---+-------+----------------+------------+---------------+-----------+---------+
| 1 |   0   |     0.6080     |   1.4608   |    0.8893     |  0.5278   |  True   |
| 2 |   1   |     0.9196     |   0.3766   |    0.9096     |  0.3583   |  True   |
| 3 |   2   |     0.9589     |   0.2015   |    0.9153     |  0.3413   |  True   |
| 4 |   3   |     0.9765     |   0.1272   |    0.9153     |  0.3286   |  False  |
| 5 |   4   |     0.9836     |   0.0919   |    0.9220     |  0.3239   |  True   |
+---+-------+----------------+------------+---------------+-----------+---------+

With a pretrained BERT model plus a custom head, the model already reaches a fairly high accuracy from the very first epoch.
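To close the loop, here is a minimal inference sketch, not part of the original notebook, that reuses the objects defined above (CustomBERTModel, tokenizer, id2label, n_classes, MAX_LEN, device) and the checkpoint written by the training loop:

# Hypothetical inference sketch; paths and behaviour assume the training cells above were run.
model = CustomBERTModel(n_classes)
model.load_state_dict(torch.load('./models/news_classification/best_model_state.bin', map_location=device))
model = model.to(device)
model.eval()

def predict(text):
    enc = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
    return id2label[logits.argmax(dim=1).item()]

print(predict("昌平京基鹭府10月29日推别墅1200万套起"))  # e.g. '房产' if training succeeded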
