位置: IT常识 - 正文

【pytorch】Vision Transformer实现图像分类+可视化+训练数据保存(pytorch x.view)

编辑:rootadmin
【pytorch】Vision Transformer实现图像分类+可视化+训练数据保存 一、Vision Transformer介绍

推荐整理分享【pytorch】Vision Transformer实现图像分类+可视化+训练数据保存(pytorch x.view),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch x.view,pytorch的view,pytorchview,pytorch vision transformer,pytorch tensor view,pytorchview,pytorch tensor view,pytorch x.view,内容如对您有帮助,希望把文章链接给更多的朋友!

Transformer的核心是 “自注意力” 机制。

论文地址:https://arxiv.org/pdf/2010.11929.pdf

自注意力(self-attention)相比 卷积神经网络 和 循环神经网络 同时具有并行计算和最短的最大路径⻓度这两个优势。因此,使用自注意力来设计深度架构是很有吸引力的。对比之前仍然依赖循环神经网络实现输入表示的自注意力模型 [Cheng et al., 2016,Lin et al., 2017b, Paulus et al., 2017],transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层 [Vaswani et al., 2017]。尽管transformer最初是应用于在文本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语言、视觉、语音和强化学习领域。

17年发布时主要应用于不同语言之间翻译功能的实现。而在后来,有关研究发现Transformer应用于计算机视觉CV方面有着不输于卷积神经网络的强劲性能,一定程度上甚至比卷积神经网络更强。于是,初代Vision Transformer诞生了, 简称Vit。

Vision Transformer和Transformer区别是什么?用最最最简单的理解方式来看,Transformer的工作就是把一句话从一种语言翻译成另一种语言。主要是通过是将待翻译的一句话拆分为 多个单词 或者 多个模块,进行编码和解码训练,再评估那个单词对应的意思得分高就是相应的翻译结果。

而Vision Transformer则是将一个图片抽象地看做翻译中一个句子,通过图像分割将其拆分为多个模块,再进行编码和解码训练,评估中得分高的选项便是预测的结果。(纯属个人理解,如有错误,欢迎批评指正)

二、数据集

我的数据集为植物叶片病害的无标注数据集,共有三种类型。

{ "0": "Huanglong_disease", "1": "Magnesium_deficiency", "2": "Normal"}

其中train : val : test  =  8 : 1 : 1,种类都是三种,只是数量不一样。

train├── Huanglong_disease│ ├── 000000.jpg│ ├── 000001.jpg│ ├── 000002.jpg│ ├── .............│ ├── 000607.jpg├── Magnesium_deficiency└── Normal

大概长这样:

三、实战代码1.vit_model.py"""original code from rwightman:https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py"""from functools import partialfrom collections import OrderedDictimport torchimport torch.nn as nndef drop_path(x, drop_prob: float = 0., training: bool = False): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() # binarize output = x.div(keep_prob) * random_tensor return outputclass DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None): super(DropPath, self).__init__() self.drop_prob = drop_prob def forward(self, x): return drop_path(x, self.drop_prob, self.training)class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None): super().__init__() img_size = (img_size, img_size) patch_size = (patch_size, patch_size) self.img_size = img_size self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() def forward(self, x): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." # flatten: [B, C, H, W] -> [B, C, HW] # transpose: [B, C, HW] -> [B, HW, C] x = self.proj(x).flatten(2).transpose(1, 2) x = self.norm(x) return xclass Attention(nn.Module): def __init__(self, dim, # 输入token的dim num_heads=8, qkv_bias=False, qk_scale=None, attn_drop_ratio=0., proj_drop_ratio=0.): super(Attention, self).__init__() self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop_ratio) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop_ratio) def forward(self, x): # [batch_size, num_patches + 1, total_embed_dim] B, N, C = x.shape # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim] # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head] # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head] qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # [batch_size, num_heads, num_patches + 1, embed_dim_per_head] q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1] # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1] attn = (q @ k.transpose(-2, -1)) * self.scale attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head] # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head] # reshape: -> [batch_size, num_patches + 1, total_embed_dim] x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return xclass Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return xclass Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super(Block, self).__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return xclass VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None): """ Args: img_size (int, tuple): input image size patch_size (int, tuple): patch size in_c (int): number of input channels num_classes (int): number of classes for classification head embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set distilled (bool): model includes a distillation token and head as in DeiT models drop_ratio (float): dropout rate attn_drop_ratio (float): attention dropout rate drop_path_ratio (float): stochastic depth rate embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer """ super(VisionTransformer, self).__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_ratio) dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth) ]) self.norm = norm_layer(embed_dim) # Representation layer if representation_size and not distilled: self.has_logits = True self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ("fc", nn.Linear(embed_dim, representation_size)), ("act", nn.Tanh()) ])) else: self.has_logits = False self.pre_logits = nn.Identity() # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.head_dist = None if distilled: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() # Weight init nn.init.trunc_normal_(self.pos_embed, std=0.02) if self.dist_token is not None: nn.init.trunc_normal_(self.dist_token, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) self.apply(_init_vit_weights) def forward_features(self, x): # [B, C, H, W] -> [B, num_patches, embed_dim] x = self.patch_embed(x) # [B, 196, 768] # [1, 1, 768] -> [B, 1, 768] cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.dist_token is None: x = torch.cat((cls_token, x), dim=1) # [B, 197, 768] else: x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) if self.dist_token is None: return self.pre_logits(x[:, 0]) else: return x[:, 0], x[:, 1] def forward(self, x): x = self.forward_features(x) if self.head_dist is not None: x, x_dist = self.head(x[0]), self.head_dist(x[1]) if self.training and not torch.jit.is_scripting(): # during inference, return the average of both classifier predictions return x, x_dist else: return (x + x_dist) / 2 else: x = self.head(x) return xdef _init_vit_weights(m): """ ViT weight initialization :param m: module """ if isinstance(m, nn.Linear): nn.init.trunc_normal_(m.weight, std=.01) if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.LayerNorm): nn.init.zeros_(m.bias) nn.init.ones_(m.weight)def vit_base_patch16_224(num_classes: int = 1000): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: 链接: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA 密码: eu9f """ model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=None, num_classes=num_classes) return modeldef vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth """ model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768 if has_logits else None, num_classes=num_classes) return modeldef vit_base_patch32_224(num_classes: int = 1000): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: 链接: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg 密码: s5hl """ model = VisionTransformer(img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=None, num_classes=num_classes) return modeldef vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth """ model = VisionTransformer(img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768 if has_logits else None, num_classes=num_classes) return modeldef vit_large_patch16_224(num_classes: int = 1000): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: 链接: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ 密码: qqt8 """ model = VisionTransformer(img_size=224, patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=None, num_classes=num_classes) return modeldef vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth """ model = VisionTransformer(img_size=224, patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024 if has_logits else None, num_classes=num_classes) return modeldef vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. weights ported from official Google JAX impl: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth """ model = VisionTransformer(img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024 if has_logits else None, num_classes=num_classes) return modeldef vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: converted weights not currently available, too large for github release hosting. """ model = VisionTransformer(img_size=224, patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280 if has_logits else None, num_classes=num_classes) return model2.utils.pyimport osimport sysimport jsonimport pickleimport randomimport torchfrom tqdm import tqdmimport matplotlib.pyplot as pltdef read_split_data(root: str, val_rate: float = 0.2): random.seed(0) # 保证随机结果可复现 assert os.path.exists(root), "dataset root: {} does not exist.".format(root) # 遍历文件夹,一个文件夹对应一个类别 flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] # 排序,保证顺序一致 flower_class.sort() # 生成类别名称以及对应的数字索引 class_indices = dict((k, v) for v, k in enumerate(flower_class)) json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) train_images_path = [] # 存储训练集的所有图片路径 train_images_label = [] # 存储训练集图片对应索引信息 val_images_path = [] # 存储验证集的所有图片路径 val_images_label = [] # 存储验证集图片对应索引信息 every_class_num = [] # 存储每个类别的样本总数 supported = [".jpg", ".JPG", ".png", ".PNG"] # 支持的文件后缀类型 # 遍历每个文件夹下的文件 for cla in flower_class: cla_path = os.path.join(root, cla) # 遍历获取supported支持的所有文件路径 images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) if os.path.splitext(i)[-1] in supported] # 获取该类别对应的索引 image_class = class_indices[cla] # 记录该类别的样本数量 every_class_num.append(len(images)) # 按比例随机采样验证样本 val_path = random.sample(images, k=int(len(images) * val_rate)) for img_path in images: if img_path in val_path: # 如果该路径在采样的验证集样本中则存入验证集 val_images_path.append(img_path) val_images_label.append(image_class) else: # 否则存入训练集 train_images_path.append(img_path) train_images_label.append(image_class) print("{} images were found in the dataset.".format(sum(every_class_num))) print("{} images for training.".format(len(train_images_path))) print("{} images for validation.".format(len(val_images_path))) plot_image = False if plot_image: # 绘制每种类别个数柱状图 plt.bar(range(len(flower_class)), every_class_num, align='center') # 将横坐标0,1,2,3,4替换为相应的类别名称 plt.xticks(range(len(flower_class)), flower_class) # 在柱状图上添加数值标签 for i, v in enumerate(every_class_num): plt.text(x=i, y=v + 5, s=str(v), ha='center') # 设置x坐标 plt.xlabel('image class') # 设置y坐标 plt.ylabel('number of images') # 设置柱状图的标题 plt.title('flower class distribution') plt.show() return train_images_path, train_images_label, val_images_path, val_images_labeldef plot_data_loader_image(data_loader): batch_size = data_loader.batch_size plot_num = min(batch_size, 4) json_path = './class_indices.json' assert os.path.exists(json_path), json_path + " does not exist." json_file = open(json_path, 'r') class_indices = json.load(json_file) for data in data_loader: images, labels = data for i in range(plot_num): # [C, H, W] -> [H, W, C] img = images[i].numpy().transpose(1, 2, 0) # 反Normalize操作 img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 label = labels[i].item() plt.subplot(1, plot_num, i+1) plt.xlabel(class_indices[str(label)]) plt.xticks([]) # 去掉x轴的刻度 plt.yticks([]) # 去掉y轴的刻度 plt.imshow(img.astype('uint8')) plt.show()def write_pickle(list_info: list, file_name: str): with open(file_name, 'wb') as f: pickle.dump(list_info, f)def read_pickle(file_name: str) -> list: with open(file_name, 'rb') as f: info_list = pickle.load(f) return info_listdef train_one_epoch(model, optimizer, data_loader, device, epoch): model.train() loss_function = torch.nn.CrossEntropyLoss() accu_loss = torch.zeros(1).to(device) # 累计损失 accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 optimizer.zero_grad() sample_num = 0 data_loader = tqdm(data_loader, file=sys.stdout) for step, data in enumerate(data_loader): images, labels = data sample_num += images.shape[0] pred = model(images.to(device)) pred_classes = torch.max(pred, dim=1)[1] accu_num += torch.eq(pred_classes, labels.to(device)).sum() loss = loss_function(pred, labels.to(device)) loss.backward() accu_loss += loss.detach() data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, accu_loss.item() / (step + 1), accu_num.item() / sample_num) if not torch.isfinite(loss): print('WARNING: non-finite loss, ending training ', loss) sys.exit(1) optimizer.step() optimizer.zero_grad() return accu_loss.item() / (step + 1), accu_num.item() / sample_num@torch.no_grad()def evaluate(model, data_loader, device, epoch): loss_function = torch.nn.CrossEntropyLoss() model.eval() accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数 accu_loss = torch.zeros(1).to(device) # 累计损失 sample_num = 0 data_loader = tqdm(data_loader, file=sys.stdout) for step, data in enumerate(data_loader): images, labels = data sample_num += images.shape[0] pred = model(images.to(device)) pred_classes = torch.max(pred, dim=1)[1] accu_num += torch.eq(pred_classes, labels.to(device)).sum() loss = loss_function(pred, labels.to(device)) accu_loss += loss data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch, accu_loss.item() / (step + 1), accu_num.item() / sample_num) return accu_loss.item() / (step + 1), accu_num.item() / sample_num3.my_dataset.pyfrom PIL import Imageimport torchfrom torch.utils.data import Datasetclass MyDataSet(Dataset): """自定义数据集""" def __init__(self, images_path: list, images_class: list, transform=None): self.images_path = images_path self.images_class = images_class self.transform = transform def __len__(self): return len(self.images_path) def __getitem__(self, item): img = Image.open(self.images_path[item]) # RGB为彩色图片,L为灰度图片 if img.mode != 'RGB': raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item])) label = self.images_class[item] if self.transform is not None: img = self.transform(img) return img, label @staticmethod def collate_fn(batch): # 官方实现的default_collate可以参考 # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py images, labels = tuple(zip(*batch)) images = torch.stack(images, dim=0) labels = torch.as_tensor(labels) return images, labels4.train.py

其中若使用预训练模型需要提前下载,下载地址在 utils.py 处有标明,代码默认是使用预训练模型的。下载后,预训练模型放入项目的根目录即可。我训练的数据集种类有三种,于是我将网络的全连接层的输出改成了 3 ,各位需要依据自己数据集不同来进行调整。

【pytorch】Vision Transformer实现图像分类+可视化+训练数据保存(pytorch x.view)

若下载不方便,也可以下载我上传的资源:

vit_base_patch16_224_in21k.zip-深度学习文档类资源-CSDN下载

import osimport mathimport argparseimport torchimport torch.optim as optimimport torch.optim.lr_scheduler as lr_schedulerfrom torch.utils.tensorboard import SummaryWriterfrom torchvision import transformsfrom my_dataset import MyDataSetfrom vit_model import vit_base_patch16_224_in21k as create_modelfrom utils import read_split_data, train_one_epoch, evaluateimport xlwtbook = xlwt.Workbook(encoding='utf-8') #创建Workbook,相当于创建Excel# 创建sheet,Sheet1为表的名字,cell_overwrite_ok为是否覆盖单元格sheet1 = book.add_sheet(u'Train_data', cell_overwrite_ok=True)# 向表中添加数据sheet1.write(0, 0, 'epoch')sheet1.write(0, 1, 'Train_Loss')sheet1.write(0, 2, 'Train_Acc')sheet1.write(0, 3, 'Val_Loss')sheet1.write(0, 4, 'Val_Acc')sheet1.write(0, 5, 'lr')sheet1.write(0, 6, 'Best val Acc')def main(args): best_acc = 0 device = torch.device(args.device if torch.cuda.is_available() else "cpu") if os.path.exists("./weights") is False: os.makedirs("./weights") tb_writer = SummaryWriter() train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path) data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]), "val": transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])} # 实例化训练数据集 train_dataset = MyDataSet(images_path=train_images_path, images_class=train_images_label, transform=data_transform["train"]) # 实例化验证数据集 val_dataset = MyDataSet(images_path=val_images_path, images_class=val_images_label, transform=data_transform["val"]) batch_size = args.batch_size nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataloader workers every process'.format(nw)) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=nw, collate_fn=train_dataset.collate_fn) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=nw, collate_fn=val_dataset.collate_fn) model = create_model(num_classes=3, has_logits=False).to(device) images = torch.zeros(1, 3, 224, 224).to(device)#要求大小与输入图片的大小一致 tb_writer.add_graph(model, images, verbose=False) if args.weights != "": assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights) weights_dict = torch.load(args.weights, map_location=device) # 删除不需要的权重 del_keys = ['head.weight', 'head.bias'] if model.has_logits \ else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias'] for k in del_keys: del weights_dict[k] print(model.load_state_dict(weights_dict, strict=False)) if args.freeze_layers: for name, para in model.named_parameters(): # 除head, pre_logits外,其他权重全部冻结 if "head" not in name and "pre_logits" not in name: para.requires_grad_(False) else: print("training {}".format(name)) pg = [p for p in model.parameters() if p.requires_grad] optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5) # Scheduler https://arxiv.org/pdf/1812.01187.pdf lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf # cosine scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) for epoch in range(args.epochs): sheet1.write(epoch+1, 0, epoch+1) sheet1.write(epoch + 1, 5, str(optimizer.state_dict()['param_groups'][0]['lr'])) # train train_loss, train_acc = train_one_epoch(model=model, optimizer=optimizer, data_loader=train_loader, device=device, epoch=epoch) scheduler.step() sheet1.write(epoch + 1, 1, str(train_loss)) sheet1.write(epoch + 1, 2, str(train_acc)) # validate val_loss, val_acc = evaluate(model=model, data_loader=val_loader, device=device, epoch=epoch) sheet1.write(epoch + 1, 3, str(val_loss)) sheet1.write(epoch + 1, 4, str(val_acc)) tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] tb_writer.add_scalar(tags[0], train_loss, epoch) tb_writer.add_scalar(tags[1], train_acc, epoch) tb_writer.add_scalar(tags[2], val_loss, epoch) tb_writer.add_scalar(tags[3], val_acc, epoch) tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) if val_acc > best_acc: best_acc = val_acc torch.save(model.state_dict(), "./weights/best_model.pth") #torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch)) sheet1.write(1, 6, str(best_acc)) book.save('.\Train_data.xlsx') print("The Best Acc = : {:.4f}".format(best_acc))if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--num_classes', type=int, default=3) parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--lr', type=float, default=0.001) parser.add_argument('--lrf', type=float, default=0.01) # 数据集所在根目录 parser.add_argument('--data-path', type=str, default=r"D:\pyCharmdata\resnet50_plant_3\datasets\train") parser.add_argument('--model-name', default='', help='create model name') # 预训练权重路径,如果不想载入就设置为空字符 parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth', help='initial weights path') # 是否冻结权重 parser.add_argument('--freeze-layers', type=bool, default=False) parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') opt = parser.parse_args() main(opt)5.predict.py

可以实现单张图片的种类预测,得分最高的便是模型预测种类。

import osimport jsonimport torchfrom PIL import Imagefrom torchvision import transformsimport matplotlib.pyplot as pltfrom vit_model import vit_base_patch16_224_in21k as create_modeldef main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_transform = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) # load image img_path = r"D:\pyCharmdata\resnet50_plant_3\datasets\test\Huanglong_disease\000000.jpg" assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) img = Image.open(img_path) plt.imshow(img) # [N, C, H, W] img = data_transform(img) # expand batch dimension img = torch.unsqueeze(img, dim=0) # read class_indict json_path = './class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) with open(json_path, "r") as f: class_indict = json.load(f) # create model model = create_model(num_classes=3, has_logits=False).to(device) # load model weights model_weight_path = "./weights/best_model.pth" model.load_state_dict(torch.load(model_weight_path, map_location=device)) model.eval() with torch.no_grad(): # predict class output = torch.squeeze(model(img.to(device))).cpu() predict = torch.softmax(output, dim=0) predict_cla = torch.argmax(predict).numpy() print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy()) plt.title(print_res) for i in range(len(predict)): print("class: {:10} prob: {:.3}".format(class_indict[str(i)], predict[i].numpy())) plt.show()if __name__ == '__main__': main()

预测结果展示:

四、训练数据

在配置好环境和数据集、预训练模型的路径后,即可运行 train.py 开始训练,默认是训练100轮。

训练使用的是SGDM优化器,初始学习率为0.001,使用LambdaLR自定义学习率调整策略,导入预训练模型但不冻结网络层和参数。

 训练过程中可以在项目路径下的终端 输入:

tensorboard --logdir=runs/

进行实时监控训练进程,也可以查看 Vision Transformer 的网络可视化结构。

Vision Transformer 的网络可视化 :

我简单训练了100轮后,最高 val_acc 准确率为 0.9976。

 训练结束后,会在项目根目录生成一个Excel文件,里面记载了训练全过程的数据,你也可以在通过 Matlab 来获得高度自定义化的可视化对比图片,堪称 论文人 的福音。

我这里只展示前10轮的训练数据。

我的完整项目框架,有需要的自取:

Vit_myself.zip-深度学习文档类资源-CSDN下载

 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~如果本文对你有帮助,欢迎一键三连!!!
本文链接地址:https://www.jiuchutong.com/zhishi/299104.html 转载请保留说明!

上一篇:vue中的map()快速使用(vue-mapvgl)

下一篇:pytorch 多GPU并行训练代码讲解(pytorch多块gpu)

  • 加油站的成品油是石油公司配送吗
  • 医院这么开发票
  • 一个季度30万是不含税吗
  • 获取发票信息异常
  • 公司现金支票取钱需要带什么资料
  • 直接支付和授权支付方式的区别与联系
  • 收到转账支票怎么去银行处理
  • 汽车销售公司购进车辆怎么做账
  • 外币报表折算差额在会计报表中应作为
  • 农产品收购发票图片
  • 设计服务发票怎么入账
  • 增值税发票地址开错了有影响吗
  • 可以通过哪些渠道获得就业信息
  • 投资收益科目应用
  • 供应商给客户员工回扣有罪吗
  • 抵扣联复印件可以做账吗
  • 增值税小规模纳税人减免增值税政策
  • 房地产行业的增值税是多少
  • 公司账户里的钱有利息吗
  • 如何升级mac系统到10.12
  • 负债率是什么指标
  • 个人所得税计提和发放分录
  • 期间费用计入什么科目
  • 一般纳税人差额征税申报表怎么填
  • wordpress进行商城开发
  • 分配股利会稀释股权吗
  • 营改增后土地增值税如何计算
  • java webflux
  • php解析xml文件
  • 管理费用税金怎么算
  • 新建利润表
  • phpcms二次开发教程
  • php安装不了
  • 帝国cms首页调用其他网站数据
  • 残保金都要申报吗
  • 生产车间发放工资
  • 合同补充协议印花税怎么交
  • 应交税费需要结转到本年利润吗
  • 小规模加工企业加工费会计分录
  • ms sql 2012
  • sql语句批量添加数据
  • 上级拨付的债券怎么做账
  • 月未转出未交增值税
  • 汇算清缴如何调报表
  • 资本公积含义
  • 长期股权投资在现金流量表哪里体现
  • 房地产公司车位出租会计分录
  • 固定资产开普票还是专票
  • 水电费收据可以入公司帐吗
  • 企业搬迁安置费一般怎么赔
  • 实收资本可以大过注册资本吗
  • 确认递延所得税资产账务处理
  • 不是企业职工能否挂靠企业交社保
  • 计提税金及附加的金额如何算
  • 公司代个人缴纳社保
  • mysql参数表
  • innodb.trx
  • win7一直弹广告怎么办
  • xp系统如何打印文件
  • linux怎么用命令
  • freebsd操作命令
  • win10的ghost
  • windowxp音频驱动
  • win7虚拟xp系统怎么安装
  • ubuntu21.04中文
  • win8系统文件
  • linux装完显卡驱动分辨率低
  • 铁嘴王指什么动物
  • 简述android多线程编程的实现方式
  • css display none之后怎么显示回来
  • css样式表可以兼容所有浏览器吗
  • python3 创建字典
  • 网络游戏数据包
  • easyui控件
  • JavaScript中length属性的使用方法
  • 置顶怎么折叠起来
  • JavaScript设置字体颜色
  • 关联企业之间借款的税收处理
  • 核定征收一般纳税人
  • 会计专业有必要读博士吗
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设