位置: IT常识 - 正文

HuggingFace简明教程

编辑:rootadmin
HuggingFace简明教程

推荐整理分享HuggingFace简明教程,希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:,内容如对您有帮助,希望把文章链接给更多的朋友!

视频链接:HuggingFace简明教程,BERT中文模型实战示例.NLP预训练模型,Transformers类库,datasets类库快速入门._哔哩哔哩_bilibili

1.huggingface简介与安装

什么是huggingface?huggingface是一个开源社区,它提供了先进的NLP模型,数据集,以及其他便利的工具。

数据集:Hugging Face – The AI community building the future. 

这些数据集可以根据任务、语言等来分类

模型:Models - Hugging Face 

官方文档: Hugging Face - Documentation 

 主要的模型:

        自回归:GPT2、Transformer-XL、XLNet

        自编码:BERT、ALBERT、RoBERTa、ELECTRA

        Seq2Seq:BART、Pegasus、T5

安装环境:

        前置环境:python、pytorch安装

        安装transformers、datasets包:

#安装transformers#pip安装pip install transformers#conda安装conda install -c huggingface transformers#安装datasets#pip安装pip install datasets#conda安装conda install -c huggingface -c conda-forge datasets

推荐使用pip进行安装

2.使用字典和分词工具加载tokenizer,准备语料

        在加载tokenizer的时候要传一个name,这个name与模型的name相一致,所以一个模型对应一个tokenizer

from transformers import BertTokenizer#加载预训练字典和分词方法tokenizer = BertTokenizer.from_pretrained( pretrained_model_name_or_path='bert-base-chinese', cache_dir=None, force_download=False,)sents = [ '选择珠江花园的原因就是方便。', '笔记本的键盘确实爽。', '房间太小。其他的都一般。', '今天才知道这书还有第6卷,真有点郁闷.', '机器背面似乎被撕了张什么标签,残胶还在。',]tokenizer, sents简单的编码

一次编码两个句子,text_pair是可以不传的,如果不传的话就是一次编码一个句子

#编码两个句子out = tokenizer.encode( text=sents[0], text_pair=sents[1], #当句子长度大于max_length时,截断 truncation=True, #一律补pad到max_length长度 padding='max_length', add_special_tokens=True, max_length=30, return_tensors=None,# 默认返回list)print(out)tokenizer.decode(out)增强的编码函数#增强的编码函数out = tokenizer.encode_plus( text=sents[0], text_pair=sents[1], #当句子长度大于max_length时,截断 truncation=True, #一律补零到max_length长度 padding='max_length', max_length=30, add_special_tokens=True, #可取值tf,pt,np,默认为返回list return_tensors=None, #返回token_type_ids return_token_type_ids=True, #返回attention_mask return_attention_mask=True, #返回special_tokens_mask 特殊符号标识 return_special_tokens_mask=True, #返回offset_mapping 标识每个词的起止位置,这个参数只能BertTokenizerFast使用 #return_offsets_mapping=True, #返回length 标识长度 return_length=True,)

增强编码的结果:

#input_ids 就是编码后的词#token_type_ids 第一个句子和特殊符号的位置是0,第二个句子的位置是1#special_tokens_mask 特殊符号的位置是1,其他位置是0#attention_mask pad的位置是0,其他位置是1#length 返回句子长度for k, v in out.items(): print(k, ':', v)tokenizer.decode(out['input_ids'])

 

 批量编码句子

上述方式是一次编码一个或者一对句子,但是实际操作中需要批量编码句子。这里编码的是一个一个的句子,而不是一对一对的句子

#批量编码句子out = tokenizer.batch_encode_plus( batch_text_or_text_pairs=[sents[0], sents[1]], add_special_tokens=True, #当句子长度大于max_length时,截断 truncation=True, #一律补零到max_length长度 padding='max_length', max_length=15, #可取值tf,pt,np,默认为返回list return_tensors=None, #返回token_type_ids return_token_type_ids=True, #返回attention_mask return_attention_mask=True, #返回special_tokens_mask 特殊符号标识 return_special_tokens_mask=True, #返回offset_mapping 标识每个词的起止位置,这个参数只能BertTokenizerFast使用 #return_offsets_mapping=True, #返回length 标识长度 return_length=True,)

批量编码的结果:

#input_ids 就是编码后的词#token_type_ids 第一个句子和特殊符号的位置是0,第二个句子的位置是1#special_tokens_mask 特殊符号的位置是1,其他位置是0#attention_mask pad的位置是0,其他位置是1#length 返回句子长度for k, v in out.items(): print(k, ':', v)tokenizer.decode(out['input_ids'][0]), tokenizer.decode(out['input_ids'][1])HuggingFace简明教程

批量成对编码

传入的list中是一个一个的tuple,tuple中是一对句子

#批量编码成对的句子out = tokenizer.batch_encode_plus( batch_text_or_text_pairs=[(sents[0], sents[1]), (sents[2], sents[3])], add_special_tokens=True, #当句子长度大于max_length时,截断 truncation=True, #一律补零到max_length长度 padding='max_length', max_length=30, #可取值tf,pt,np,默认为返回list return_tensors=None, #返回token_type_ids return_token_type_ids=True, #返回attention_mask return_attention_mask=True, #返回special_tokens_mask 特殊符号标识 return_special_tokens_mask=True, #返回offset_mapping 标识每个词的起止位置,这个参数只能BertTokenizerFast使用 #return_offsets_mapping=True, #返回length 标识长度 return_length=True,)

批量成对编码结果:

#input_ids 就是编码后的词#token_type_ids 第一个句子和特殊符号的位置是0,第二个句子的位置是1#special_tokens_mask 特殊符号的位置是1,其他位置是0#attention_mask pad的位置是0,其他位置是1#length 返回句子长度for k, v in out.items(): print(k, ':', v)tokenizer.decode(out['input_ids'][0]) 字典操作

操作tokenizer中的字典,当前的字典以一个字为一个词

#获取字典zidian = tokenizer.get_vocab()type(zidian), len(zidian), '月光' in zidian,

 

#添加新词tokenizer.add_tokens(new_tokens=['月光', '希望'])#添加新符号tokenizer.add_special_tokens({'eos_token': '[EOS]'})zidian = tokenizer.get_vocab()type(zidian), len(zidian), zidian['月光'], zidian['[EOS]']

 编码新词:

#编码新添加的词out = tokenizer.encode( text='月光的新希望[EOS]', text_pair=None, #当句子长度大于max_length时,截断 truncation=True, #一律补pad到max_length长度 padding='max_length', add_special_tokens=True, max_length=8, return_tensors=None,)print(out)tokenizer.decode(out)3.数据集操作加载数据集

以情感分类数据集为例

from datasets import load_dataset#加载数据dataset = load_dataset(path='seamew/ChnSentiCorp')dataset

#查看一个数据dataset[0]  排序和打乱#sort#未排序的label是乱序的print(dataset['label'][:10])#排序之后label有序了sorted_dataset = dataset.sort('label')print(sorted_dataset['label'][:10])print(sorted_dataset['label'][-10:])

#shuffle#打乱顺序shuffled_dataset = sorted_dataset.shuffle(seed=42)shuffled_dataset['label'][:10]

选择和过滤#selectdataset.select([0, 10, 20, 30, 40, 50])

 

#filterdef f(data): return data['text'].startswith('选择')start_with_ar = dataset.filter(f)len(start_with_ar), start_with_ar['text']

切分和分桶#train_test_split, 切分训练集和测试集dataset.train_test_split(test_size=0.1)

#shard#把数据切分到4个桶中,均匀分配dataset.shard(num_shards=4, index=0)

列操作和类型转换#rename_columndataset.rename_column('text', 'textA')

#remove_columnsdataset.remove_columns(['text'])

#set_formatdataset.set_format(type='torch', columns=['label'])dataset[0]

 map函数

对数据集中的每一条数据都做函数f操作

#mapdef f(data): data['text'] = 'My sentence: ' + data['text'] return datadatatset_map = dataset.map(f)datatset_map['text'][:5]

保存和加载#保存数据集到磁盘dataset.save_to_disk(dataset_dict_path='./data/ChnSentiCorp')#从磁盘加载数据from datasets import load_from_diskdataset = load_from_disk('./data/ChnSentiCorp')

导出和保存为其他格式

#导出为csv格式dataset = load_dataset(path='seamew/ChnSentiCorp', split='train')dataset.to_csv(path_or_buf='./data/ChnSentiCorp.csv')#加载csv格式数据csv_dataset = load_dataset(path='csv', data_files='./data/ChnSentiCorp.csv', split='train')#导出为json格式dataset = load_dataset(path='seamew/ChnSentiCorp', split='train')dataset.to_json(path_or_buf='./data/ChnSentiCorp.json')#加载json格式数据json_dataset = load_dataset(path='json', data_files='./data/ChnSentiCorp.json', split='train')4.使用评价函数查看可用的评价指标from datasets import list_metrics#列出评价指标metrics_list = list_metrics()len(metrics_list), metrics_list 查看该指标的说明文档

可以按照评价指标的说明文档中的示例代码来使用该指标

from datasets import load_metric#加载一个评价指标metric = load_metric('glue', 'mrpc')print(metric.inputs_description)计算一个评价指标#计算一个评价指标predictions = [0, 1, 0]references = [0, 1, 1]final_score = metric.compute(predictions=predictions, references=references)final_score

5.使用pipline函数

pipeline提供了一些不需要训练就可以执行一些nlp任务的模型,实用价值不高

情感分类from transformers import pipeline#文本分类classifier = pipeline("sentiment-analysis")result = classifier("I hate you")[0]print(result)result = classifier("I love you")[0]print(result) 阅读理解from transformers import pipeline#阅读理解question_answerer = pipeline("question-answering")context = r"""Extractive Question Answering is the task of extracting an answer from a text given a question. An example of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would like to fine-tune a model on a SQuAD task, you may leverage the examples/pytorch/question-answering/run_squad.py script."""result = question_answerer(question="What is extractive question answering?", context=context)print(result)result = question_answerer( question="What is a good example of a question answering dataset?", context=context)print(result)完形填空from transformers import pipeline#完形填空unmasker = pipeline("fill-mask")from pprint import pprintsentence = 'HuggingFace is creating a <mask> that the community uses to solve NLP tasks.'unmasker(sentence)文本生成from transformers import pipeline#文本生成text_generator = pipeline("text-generation")text_generator("As far as I am concerned, I will", max_length=50, do_sample=False)命名实体识别from transformers import pipeline#命名实体识别ner_pipe = pipeline("ner")sequence = """Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO,therefore very close to the Manhattan Bridge which is visible from the window."""for entity in ner_pipe(sequence): print(entity)文本摘要from transformers import pipeline#文本总结summarizer = pipeline("summarization")ARTICLE = """ New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York.A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband.Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other.In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage.Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the2010 marriage license application, according to court documents.Prosecutors said the marriages were part of an immigration scam.On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further.After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said DetectiveAnnette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002.All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say.Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages.Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted.The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'sInvestigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali.Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force.If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18."""summarizer(ARTICLE, max_length=130, min_length=30, do_sample=False)翻译from transformers import pipeline#翻译translator = pipeline("translation_en_to_de")sentence = "Hugging Face is a technology company based in New York and Paris"translator(sentence, max_length=40)trainer API加载分词工具from transformers import AutoTokenizer#加载分词工具tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')定义数据集from datasets import load_datasetfrom datasets import load_from_disk#加载数据集#从网络加载#datasets = load_dataset(path='glue', name='sst2')#从本地磁盘加载数据datasets = load_from_disk('./data/glue_sst2')#分词def f(data): return tokenizer( data['sentence'], padding='max_length', truncation=True, max_length=30, )datasets = datasets.map(f, batched=True, batch_size=1000, num_proc=4)#取数据子集,否则数据太多跑不动dataset_train = datasets['train'].shuffle().select(range(1000))dataset_test = datasets['validation'].shuffle().select(range(200))del datasetsdataset_train加载模型from transformers import AutoModelForSequenceClassification#加载模型model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)print(sum([i.nelement() for i in model.parameters()]) / 10000) # 查看模型参数数量定义评价函数import numpy as npfrom datasets import load_metricfrom transformers.trainer_utils import EvalPrediction#加载评价函数metric = load_metric('accuracy')#定义评价函数def compute_metrics(eval_pred): logits, labels = eval_pred logits = logits.argmax(axis=1) return metric.compute(predictions=logits, references=labels)#模拟测试输出eval_pred = EvalPrediction( predictions=np.array([[0, 1], [2, 3], [4, 5], [6, 7]]), label_ids=np.array([1, 1, 1, 1]),)compute_metrics(eval_pred)定义训练器并测试from transformers import TrainingArguments, Trainer#初始化训练参数args = TrainingArguments(output_dir='./output_dir', evaluation_strategy='epoch')args.num_train_epochs = 1args.learning_rate = 1e-4args.weight_decay = 1e-2args.per_device_eval_batch_size = 32args.per_device_train_batch_size = 16#初始化训练器trainer = Trainer( model=model, args=args, train_dataset=dataset_train, eval_dataset=dataset_test, compute_metrics=compute_metrics,)#评价模型trainer.evaluate()

模型未训练前的准确率是0.49

#训练trainer.train()

 

 训练一个epoch之后的准确率为0.8

保存模型参数#保存模型trainer.save_model(output_dir='./output_dir')使用保存的模型参数定义测试数据集import torchdef collate_fn(data): label = [i['label'] for i in data] input_ids = [i['input_ids'] for i in data] token_type_ids = [i['token_type_ids'] for i in data] attention_mask = [i['attention_mask'] for i in data] label = torch.LongTensor(label) input_ids = torch.LongTensor(input_ids) token_type_ids = torch.LongTensor(token_type_ids) attention_mask = torch.LongTensor(attention_mask) return label, input_ids, token_type_ids, attention_mask#数据加载器loader_test = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=4, collate_fn=collate_fn, shuffle=True, drop_last=True)for i, (label, input_ids, token_type_ids, attention_mask) in enumerate(loader_test): breaklabel, input_ids, token_type_ids, attention_mask测试import torch#测试def test(): #加载参数 model.load_state_dict(torch.load('./output_dir/pytorch_model.bin')) model.eval() #运算 out = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) #[4, 2] -> [4] out = out['logits'].argmax(dim=1) correct = (out == label).sum().item() return correct / len(label)test()
本文链接地址:https://www.jiuchutong.com/zhishi/285062.html 转载请保留说明!

上一篇:大城历史公园中的帕喜善佩寺,泰国大城府 (© travelstock44/Alamy)(大城遗址公园)

下一篇:SU-03T语音模块的使用(小智语音控制LED灯)(语音模块作用)

  • 拍照水印怎么设置时间地点(拍照水印怎么设置经纬度)

    拍照水印怎么设置时间地点(拍照水印怎么设置经纬度)

  • 抖音只能实名认证一个账号吗(抖音只能实名认证可以取消吗)

    抖音只能实名认证一个账号吗(抖音只能实名认证可以取消吗)

  • 在powerpoint的页面设置中,能够设置(powerpoint页面设置在哪个选项卡)

    在powerpoint的页面设置中,能够设置(powerpoint页面设置在哪个选项卡)

  • 华为手机提取文字怎么弄(华为手机提取文字的功能在哪里)

    华为手机提取文字怎么弄(华为手机提取文字的功能在哪里)

  • 华为mate30无线快充怎么使用(华为mate30epro无线)

    华为mate30无线快充怎么使用(华为mate30epro无线)

  • 果壳网为什么不能访问了(果壳网为什么叫果壳)

    果壳网为什么不能访问了(果壳网为什么叫果壳)

  • 音频信号的频率范围(音频信号的频率如何计算)

    音频信号的频率范围(音频信号的频率如何计算)

  • 面容id显示稍后设置(面容id不可以用 稍后尝试)

    面容id显示稍后设置(面容id不可以用 稍后尝试)

  • 支付宝怎么看自动续费的东西(支付宝怎么看自动续费项目)

    支付宝怎么看自动续费的东西(支付宝怎么看自动续费项目)

  • 怎么鉴别airpods2真假(怎么鉴别airpods)

    怎么鉴别airpods2真假(怎么鉴别airpods)

  • 快手小黄车收费吗(快手小黄车收费标准)

    快手小黄车收费吗(快手小黄车收费标准)

  • 华为mate20怎么解锁(华为mate20怎么解除关联华为帐号)

    华为mate20怎么解锁(华为mate20怎么解除关联华为帐号)

  • 苹果8有红外线遥控功能吗(苹果手机自带红外线吗)

    苹果8有红外线遥控功能吗(苹果手机自带红外线吗)

  • 苹果手机怎么弄天气预报(苹果手机怎么弄门禁卡开门)

    苹果手机怎么弄天气预报(苹果手机怎么弄门禁卡开门)

  • 微信可以看访客记录吗(微信可以看访客吗?)

    微信可以看访客记录吗(微信可以看访客吗?)

  • 抖音昵称一天能改几次(抖音昵称一天能改多少次)

    抖音昵称一天能改几次(抖音昵称一天能改多少次)

  • oppogps在哪里(oppoa11gps在哪里)

    oppogps在哪里(oppoa11gps在哪里)

  • 表格递减是升序还是降序(excel中递减是升序还是降序)

    表格递减是升序还是降序(excel中递减是升序还是降序)

  • m923q是什么型号(m928q是什么型号)

    m923q是什么型号(m928q是什么型号)

  • 快手直播公屏怎么@别人(快手直播公屏怎么变小)

    快手直播公屏怎么@别人(快手直播公屏怎么变小)

  • vivox27怎么开微信美颜(vivox27怎么开微信视频美颜)

    vivox27怎么开微信美颜(vivox27怎么开微信视频美颜)

  • 手机用户体验怎么关闭(手机用户体验怎么删除)

    手机用户体验怎么关闭(手机用户体验怎么删除)

  • word截图快捷键ctrl加什么(word截图快捷键ctrl加什么保存)

    word截图快捷键ctrl加什么(word截图快捷键ctrl加什么保存)

  • 华为手机怎么截取长图(华为手机怎么截图长屏幕截图)

    华为手机怎么截取长图(华为手机怎么截图长屏幕截图)

  • 安装win7系统前在BIOS中设置硬盘模式的方法(安装win7前需要手动格式化c盘吗)

    安装win7系统前在BIOS中设置硬盘模式的方法(安装win7前需要手动格式化c盘吗)

  • 30岁了,说几句大实话(30多岁应该怎么说)

    30岁了,说几句大实话(30多岁应该怎么说)

  • 递延所得税资产会计处理全过程
  • 如何确定关联方及关联关系
  • 出口退税免税政策
  • 计提附加税会计凭证怎么做会计分录
  • 兼职人员工资需要申报个税吗
  • 个人所得税修改赡养信息
  • 每个月结转损益都有什么科目
  • 如何判断会计政策变动
  • 合同负债里面含增值税吗
  • 研发费用资本化支出
  • 3项经费计提比例2015
  • 财产租赁个人所得税
  • 发放股票股利的账务处理
  • 偿还不起债务大约能判多少年
  • 公司进行债务重构的原因可能包括
  • 公司临时聘用人员法规
  • 纳税调整后所得怎么算
  • 汽油费能计入办公费吗
  • 已抵扣发票红冲后发票还给对方公司
  • 增值税普通发票有什么用
  • 长期股权投资的投资收益怎么算
  • 培训费和差旅费可以一起报嘛
  • 成品油生产企业身份归类管理办法
  • 离职有补偿金的可不可以领取失业金
  • 员工两处取得工资收入
  • 出售单位车辆如何处理业务?
  • win11进入黑屏
  • 解放双手神器说说
  • 应交税费和应交税金的区别属于什么科目
  • 出口退税率和进项一致吗
  • 公司购买的财产保险服务可以抵扣进项吗
  • win8快捷键大全
  • 电脑进不去系统一直在启动界面
  • PHP:session_is_registered()的用法_Session函数
  • 公司logo设计费入什么科目
  • 海关专用缴款书认证的步骤和说明
  • win10系统的安装
  • vue清空form数据再重新赋值
  • 增值税进项和销项怎么抵扣
  • 社保台账显示未托收
  • 金税盘怎么设置字体
  • php前台模板
  • 应收账款和应付账款属于什么科目
  • 计提加计抵减额在财务报表里哪里体现
  • 所得税申报表收入包含营业外收入吗
  • 增值税专用发票和普通发票的区别
  • 材料成本差异的借方表示什么
  • 商誉需要交税吗
  • mysql配置怎么调出来
  • 公司收取保证金合法吗
  • 计提个税和缴纳个税金额不符的原因
  • 建信融通e信通怎样转让
  • 银行承兑汇票和银行汇票的区别
  • 经营租赁方式租入再转租的建筑物
  • 残保金缴纳额计算公式
  • 公司应付款是什么意思
  • 科目余额表如何看
  • 递延所得税会计处理全过程
  • 招标付款条件及比例
  • 酒店财务帐务处理方案
  • mysql存储过程的语句块以什么开始以什么结束
  • MySQL 5.7.14 net start mysql 服务无法启动-“NET HELPMSG 3534” 的奇怪问题
  • 微软9月30日将发行股票
  • 扫清落叶堆怎么扫
  • windows下键盘不能用
  • 电脑开机绿
  • 旅游软件页面
  • jqueryform表单提交
  • shell echo-e
  • linux中mysql备份shell脚本代码
  • python爬虫详解
  • bootstrap要学多久
  • jquery将文本框设置为只读
  • 浙江省网上税务局申报
  • 重庆市国家税务局电子税务局官网
  • 稽查局是税务局的派出机构还是内设机构
  • 入职培训结束寄语
  • 税务稽查团队
  • 税务打虚打骗
  • 新郑市税务局
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设