利用forward实现对输出张量的处理

Apr 25 2020

在python代码中,输出张量还需要经过一定的处理,如argmax、维度处理等操作,但在pytorch mobile中并没有相关的API,使用java进行实现将会非常麻烦。如在java中实现以下方法

1
2
3
pred = torch.argmax(outputs[0], 1)
pred = pred.cpu().data.numpy()
predict = pred.squeeze(0)

java中的outputTensor

1
2
3
final IValue[] outputTuple = module.forward(IValue.from(inputTensor)).toTuple();
final Tensor outputTensor = outputTuple[0].toTensor();
Log.d(TAG, "onCreate: " + outputTensor.getDataAsFloatArray());

其中java不支持numpy库,方法都得自己重写,且getDataAsFloatArray获取的是一维的序列,处理也特别麻烦。

解决

在对模型进行trace的时候,修改model中的forward函数,加入以上对输出tensor的操作,即可将方法集成在模型中,直接输出处理后的结果了。

1
2
3
4
5
6
7
8
def forward(self, x):
...
pred = torch.argmax(outputs[0], 1)
pred = pred.cpu().data
predict = pred.squeeze(0)
predict = predict.type(torch.FloatTensor)

return predict