位置: IT常识 - 正文
推荐整理分享【记录】torch.nn.CrossEntropyLoss报错及解决(torch.nn.function),希望有所帮助,仅作参考,欢迎阅读内容。
文章相关热门搜索词:torch.nn.functional.linear,torch.nn.utils.clip_grad_norm,torch.nn.functional.linear,torch.nn.embedding,torch.nn.utils.clip_grad_norm,torch.nn.utils.clip_grad_norm,torch.nn.lstm,torch.nn.functional.grid_sample,内容如对您有帮助,希望把文章链接给更多的朋友!
在多分类语义分割问题中使用torch.nn.CrossEntropyLoss的时候,遇到的报错有:
1. Assertion `t >= 0 && t < n_classes` failed.2. RuntimeError: Expected floating point type for target with class probabilities, got Long通过官方文档了解到,torch.nn.CrossEntropyLoss分为两种情况:
直接使用class进行分类,此时的label为0,1,2…的整数。对于这类情况,torch.nn.CrossEntropyLoss中添加了LogSoftmax以及 NLLLoss,因此不用在网络的最后添加 softmax和argmax 将输出结果转换为整型。使用每一类的概率。这种标签通常情况下效果比直接使用class进行分类要好一些,但在少样本 && 在每一类上使用标签过于严格 的时候,才推荐使用概率作为标签。解决假设传入torch.nn.CrossEntropyLoss的参数为torch.nn.CrossEntropyLoss(pred, label),其中pred为模型预测的输出,label为标签。 这两个报错都是因为pred输入的维度错误导致的 根据官网文档,如果直接使用class进行分类,pred的维度应该是[batchsize, class, dim 1, dim 2, ... dim K],label的维度应该是[batchsize, dim 1, dim 2, ... dim K]。注意在网络输出的channel中加入class number的维度。不然softmax无法计算,及model的output channel = class number。 另,如果想直接使用class进行分类,需要讲label的type转换成long格式:labels = labels.to(device, dtype=torch.long)
上一篇:【YOLOv5】LabVIEW+OpenVINO让你的YOLOv5在CPU上飞起来(labview oop)
下一篇:React--》超详细教程——React脚手架的搭建与使用(reactz)
友情链接: 武汉网站建设