定义PyTorch模型时候遇到问题
1 2 3 4 5 6
| TypeError: new() received an invalid combination of arguments - got (float, int, int, int), but expected one of: * (torch.device device) * (torch.Storage storage) * (Tensor other) * (tuple of ints size, torch.device device) * (object data, torch.device device)
|
出错代码为
1
| self.conv1 = nn.Conv2d(input_channels, output_channels/4, 1, 1, bias = False)
|
由于环境是python3,整除会出现float型,造成参数类型错误,这里很有可能就是output_channels/4
出了问题。
在python3中,1024/4=256.0
1 2 3 4 5
| Python 3.7.4 (default, Aug 13 2019, 20:35:49) [GCC 7.3.0] :: Anaconda, Inc. on linux Type "help", "copyright", "credits" or "license" for more information. >>> 1024/4 256.0
|
而在python2中,1024/4=256
1 2 3 4 5
| Python 2.7.18rc1 (default, Apr 7 2020, 12:05:55) [GCC 9.3.0] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> 1024/4 256
|
将/
改为//
,表示强制转换为int型
,问题解决。
1
| self.conv1 = nn.Conv2d(input_channels, output_channels//4, 1, 1, bias = False)
|