加入收藏 | 设为首页 | 会员中心 | 我要投稿 应用网_丽江站长网 (http://www.0888zz.com/)- 科技、建站、数据工具、云上网络、机器学习!
当前位置: 首页 > 运营中心 > 建站资源 > 经验 > 正文

PyTorch最佳实践,怎样才能写出一手风格优美的代码

发布时间:2019-05-07 10:55:08 所属栏目:经验 来源:机器之心编译
导读:副标题#e# 虽然这是一个非官方的 PyTorch 指南,但本文总结了一年多使用 PyTorch 框架的经验,尤其是用它开发深度学习相关工作的最优解决方案。请注意,我们分享的经验大多是从研究和实践角度出发的。 这是一个开发的项目,欢迎其它读者改进该文档: https:

2. PyTorch 环境下的简单残差网络

  1. class ResnetBlock(nn.Module): 
  2.     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 
  3.         super(ResnetBlock, self).__init__() 
  4.         selfself.conv_block = self.build_conv_block(...) 
  5.  
  6.     def build_conv_block(self, ...): 
  7.         conv_block = [] 
  8.  
  9.         conv_block += [nn.Conv2d(...), 
  10.                        norm_layer(...), 
  11.                        nn.ReLU()] 
  12.         if use_dropout: 
  13.             conv_block += [nn.Dropout(...)] 
  14.  
  15.         conv_block += [nn.Conv2d(...), 
  16.                        norm_layer(...)] 
  17.  
  18.         return nn.Sequential(*conv_block) 
  19.  
  20.     def forward(self, x): 
  21.         out = x + self.conv_block(x) 
  22.         return ou 

在这里,ResNet 模块的跳跃连接直接在前向传导过程中实现了,PyTorch 允许在前向传导过程中进行动态操作。

3. PyTorch 环境下的带多个输出的网络

对于有多个输出的网络(例如使用一个预训练好的 VGG 网络构建感知损失),我们使用以下模式:

  1. class Vgg19(torch.nn.Module): 
  2.   def __init__(self, requires_grad=False): 
  3.     super(Vgg19, self).__init__() 
  4.     vgg_pretrained_features = models.vgg19(pretrained=True).features 
  5.     self.slice1 = torch.nn.Sequential() 
  6.     self.slice2 = torch.nn.Sequential() 
  7.     self.slice3 = torch.nn.Sequential() 
  8.  
  9.     for x in range(7): 
  10.         self.slice1.add_module(str(x), vgg_pretrained_features[x]) 
  11.     for x in range(7, 21): 
  12.         self.slice2.add_module(str(x), vgg_pretrained_features[x]) 
  13.     for x in range(21, 30): 
  14.         self.slice3.add_module(str(x), vgg_pretrained_features[x]) 
  15.     if not requires_grad: 
  16.         for param in self.parameters(): 
  17.             param.requires_grad = False 
  18.  
  19.   def forward(self, x): 
  20.     h_relu1 = self.slice1(x) 
  21.     h_relu2 = self.slice2(h_relu1)         
  22.     h_relu3 = self.slice3(h_relu2)         
  23.     out = [h_relu1, h_relu2, h_relu3] 
  24.     return out 

(编辑:应用网_丽江站长网)

【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容!

热点阅读