在python代码中,输出张量还需要经过一定的处理,如argmax
、维度处理等操作,但在pytorch mobile中并没有相关的API,使用java进行实现将会非常麻烦。如在java中实现以下方法
1 2 3
| pred = torch.argmax(outputs[0], 1) pred = pred.cpu().data.numpy() predict = pred.squeeze(0)
|
java中的outputTensor
1 2 3
| final IValue[] outputTuple = module.forward(IValue.from(inputTensor)).toTuple(); final Tensor outputTensor = outputTuple[0].toTensor(); Log.d(TAG, "onCreate: " + outputTensor.getDataAsFloatArray());
|
其中java不支持numpy库,方法都得自己重写,且getDataAsFloatArray
获取的是一维的序列,处理也特别麻烦。
解决
在对模型进行trace的时候,修改model中的forward函数,加入以上对输出tensor的操作,即可将方法集成在模型中,直接输出处理后的结果了。
1 2 3 4 5 6 7 8
| def forward(self, x): ... pred = torch.argmax(outputs[0], 1) pred = pred.cpu().data predict = pred.squeeze(0) predict = predict.type(torch.FloatTensor)
return predict
|