Designing a TensorFlow Object-Recognition Experiment on Android

Dec 12 2017

Contents:

1. Installing TensorFlow and training the recognition model
2. Serving the recognition API on Flask
3. Photo-based recognition on the Android client

Installing TensorFlow and Training the Recognition Model

(based on Google's Inception-v3 model)

Project directory structure

```
-training
    -training_data              // the training image set
        -10000                  // images for label 10000
        -10003
        -10004
        ……
    -model                      // the Inception-v3 model
    -bottleneck                 // cached bottleneck-layer features
        -10000                  // bottleneck values for label 10000
        -10003
        -10004
        ……
    -retrain.py                 // the training script
    -collect_training_data.py   // the sample-collection script
    -dictionary.txt             // the label-translation dictionary
    -output_graph.pb            // the generated model file
    -output_labels.txt          // the generated label file
```

Main steps

Collect the training data and generate dictionary.txt, the file used to translate recognition results into plant names.

Collecting training data

Points to note:

collect_training_data.py

```python
# coding=utf-8
# Collect plant images
from bs4 import BeautifulSoup
import requests
import os
import random

sess = requests.session()
refer = 'http://www.plantphoto.cn/sp/'
get_pic = 'http://www.plantphoto.cn/ashx/getphotopage.ashx'
data_path = 'training_data/'


def save_imgs(url, path):
    header = {
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8",
        "Connection": "keep-alive",
        "Host": "img.plantphoto.cn",
        "Referer": refer + path,
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.81 "
                      "Safari/537.36",
    }
    local_filename = url.split('/')[-1]
    print("Download Image File=", local_filename)
    try:
        # stream=True so the image is written to disk chunk by chunk
        r = requests.get(url, headers=header, stream=True)
        with open(data_path + path + '/' + local_filename, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    f.flush()
    except requests.RequestException:
        print(u'Failed to save the image')


if __name__ == '__main__':
    file = open('dictionary.txt', 'w')
    for time in range(200):  # run 200 collection rounds
        i = str(random.randint(10000, 99999))  # pick a random plant cid
        a = range(1, 3)  # collect two pages of data
        name = ''
        for page in reversed(a):
            new_url = get_pic + '?page=' + str(page) + '&n=2&group=sp&cid=' + i
            header = {
                'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
                'Accept-Encoding': 'gzip, deflate, sdch',
                'Accept-Language': 'zh-CN,zh;q=0.8,zh-TW;q=0.6',
                'Connection': 'keep-alive',
                'Host': 'www.plantphoto.cn',
                'Cookie': 'AJSTAT_ok_pages=1; AJSTAT_ok_times=4; '
                          '__tins__2318434=%7B%22sid%22%3A%201512994067786%2C%20%22vd%22%3A%201%2C%20%22expires%22%3A'
                          '%201512995867786%7D; __51cke__=; __51laig__=4',
                'Cache-Control': 'max-age=0',
                'Upgrade-Insecure-Requests': '1',
                'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) '
                              'Chrome/58.0.3029.81 Safari/537.36 '
            }
            response = requests.get(new_url, headers=header).content
            soup = BeautifulSoup(response, 'lxml')
            imgs = soup.find_all('img')
            if len(imgs) > 10:  # make sure there are enough samples (at least 20 across both pages)
                path = data_path + i
                name = imgs[0]['alt'].encode('utf-8')
                if not os.path.exists(path):
                    os.mkdir(path)
                for j in imgs:
                    save_imgs(j['src'], i)
            else:
                # fewer than 10 images on this page: skip the rest of this cid
                break
        if name:  # only record cids that actually yielded images
            file.write(i + ',' + name + '\n')
    file.close()
```
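Each successful round appends one `cid,name` line to dictionary.txt, and the recognition scripts later split each line on the comma to map a label back to a plant name. A minimal sketch of that round trip, using a made-up entry:

```python
# dictionary.txt stores one "cid,name" pair per line (this entry is made up).
line = '12345,ExamplePlant\n'

# Strip the newline and split on the first comma, as the loader scripts do.
cid, name = line.rstrip('\n').split(',', 1)
print(cid + ' -> ' + name)
```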

Installing TensorFlow on Ubuntu 16.04

Install the CPU version of TensorFlow with pip:

```shell
sudo pip install --upgrade https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl
```

The URL of the latest TensorFlow build can be found on the TensorFlow site.

Verify that TensorFlow installed correctly:

```shell
$ python

>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> print sess.run(hello)
Hello, TensorFlow!
```

Install pillow (an image-processing library, used here for operations such as rotating the images to be recognized):

```shell
sudo pip install pillow
```

Install numpy (an open-source numerical-computing library, used here for the image-rotation calculations):

```shell
sudo pip install numpy
```

Install opencv (a well-known computer-vision library, used here to blur and rotate the images to be recognized; note that the pip package is named opencv-python):

```shell
sudo pip install opencv-python
```

Starting the training

Download retrain.py from the TensorFlow retrain demo into the project root (download link).

Open a terminal and run the following command to start training:

```shell
$ python ./retrain.py --bottleneck_dir bottleneck --how_many_training_steps 4000 --model_dir model --output_graph output_graph.pb --output_labels output_labels.txt --image_dir training_data/
```

Optional parameters (prefix the name with "--" when used, e.g. "--output_graph"):

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| image_dir | str | | Directory containing the training image set |
| output_graph | str | /tmp/output_graph.pb | Where to save the trained graph |
| output_labels | str | /tmp/output_labels.txt | Where to save the generated labels (output_labels.txt in this example) |
| summaries_dir | str | /tmp/retrain_logs | Where to save the TensorBoard log summaries |
| how_many_training_steps | int | 4000 | Number of training steps |
| learning_rate | float | 0.01 | Learning rate |
| testing_percentage | int | 10 | Percentage of images used as the test set |
| validation_percentage | int | 10 | Percentage of images used as the validation set |
| eval_step_interval | int | 10 | Interval between evaluations of the training results |
| train_batch_size | int | 100 | Number of images per training batch |
| validation_batch_size | int | 100 | Number of images per validation batch |
| print_misclassified_test_images | bool | False | Print a list of all misclassified test images |
| model_dir | str | /tmp/imagenet | Where the Inception-v3 model is stored |
| bottleneck_dir | str | /tmp/bottleneck | Path for caching bottleneck-layer values |
| final_tensor_name | str | final_result | Name of the retrained output classification layer; used later with get_tensor_by_name |
| flip_left_right | bool | False | Whether to randomly flip half of the training images horizontally |
| random_crop | int | 0 | Percentage by which to randomly crop training images |
| random_scale | int | 0 | Percentage by which to randomly scale training images |
| random_brightness | int | 0 | Percentage by which to randomly adjust brightness |
| architecture | str | inception_v3 | Image classification model; Inception-v3 is the most widely used image model, but also the slowest |
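The flags in the table combine directly into the training invocation. As a sketch, the command used above can also be assembled programmatically (the values are the ones from this example):

```python
# Build the retrain.py invocation from the flags used in this example.
flags = {
    'bottleneck_dir': 'bottleneck',
    'how_many_training_steps': '4000',
    'model_dir': 'model',
    'output_graph': 'output_graph.pb',
    'output_labels': 'output_labels.txt',
    'image_dir': 'training_data/',
}
cmd = ['python', './retrain.py']
for key, value in sorted(flags.items()):
    cmd += ['--' + key, value]
print(' '.join(cmd))
```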

Once training starts, an inception-2015-12-05.tgz file of about 88 MB is downloaded automatically.

The training samples here are plants: 12 species, 472 images in total. Training took about 20 minutes.

Final training result: Final test accuracy = 72.7% (N=55)

After training finishes, output_labels.txt and output_graph.pb are generated in the project root.

Testing:

test.py

```python
#coding=utf-8
import cv2
import math
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import tensorflow as tf

# Load dictionary.txt (maps label IDs to plant names)
dictionary = {}
ftotal = open('./dictionary.txt', 'r')
line = ftotal.readline()
while line:
    totaltree = line.replace('\r\n', '').replace('\n', '').split(',')
    dictionary[totaltree[0]] = totaltree[1].decode("utf-8")
    line = ftotal.readline()
ftotal.close()

# Load the recognition labels
labels = []
for label in tf.gfile.GFile("output_labels.txt"):
    labels.append(label.rstrip())

# Load the graph
with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

sess = tf.Session()
softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')


# Recognize a plant image; returns the result as a string
def plantRecognition(plantFile):

    # Read the plant image and run it through the network
    image = tf.gfile.FastGFile(plantFile, 'rb').read()
    predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})

    # Sort class indices by descending probability
    top = predict[0].argsort()[-len(predict[0]):][::-1]

    result = ''
    count = 8  # return at most the top 8 results
    for index in top:
        human_string = dictionary[labels[index]]
        score = predict[0][index]

        # Stop once the confidence drops below 0.01
        if score < 0.0099:
            break

        # Append this result line
        temp = human_string + ':' + '%(p).2f' % {'p': score * 100} + '%'
        result += temp + '\n'

        count -= 1
        if count == 0:
            break

    print(result)
    return result


# Rotate an image about its center
def rotate_about_center(src, angle, scale=1.):
    if angle == 0:
        return src

    w = src.shape[1]
    h = src.shape[0]
    rangle = np.deg2rad(angle)  # angle in radians
    # now calculate new image width and height
    nw = (abs(np.sin(rangle)*h) + abs(np.cos(rangle)*w))*scale
    nh = (abs(np.cos(rangle)*h) + abs(np.sin(rangle)*w))*scale
    # ask OpenCV for the rotation matrix
    rot_mat = cv2.getRotationMatrix2D((nw*0.5, nh*0.5), angle, scale)
    # calculate the move from the old center to the new center combined
    # with the rotation
    rot_move = np.dot(rot_mat, np.array([(nw-w)*0.5, (nh-h)*0.5, 0]))
    # the move only affects the translation, so update the translation
    # part of the transform
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]
    return cv2.warpAffine(src, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)


# Blur the image to improve the recognition rate
def GaussianBlur(imagefile):
    rotateangle = get_Rotate_Angle(imagefile)
    img = cv2.imread(imagefile)
    kernel_size = (5, 5)
    sigma = 3
    newimg = cv2.GaussianBlur(img, kernel_size, sigma)
    res = rotate_about_center(newimg, rotateangle)
    cv2.imwrite(imagefile, res)


# Read the rotation angle from the image's EXIF data
def get_Rotate_Angle(fname):
    try:
        img = Image.open(fname)
        if hasattr(img, '_getexif'):
            exifinfo = img._getexif()
            if exifinfo is not None:
                for tag, value in exifinfo.items():
                    decoded = TAGS.get(tag, tag)
                    if decoded == 'Orientation':
                        # only EXIF orientation 6 (90 degrees CW) is handled
                        if value == 6:
                            return 270
                        else:
                            return 0
    except IOError:
        print 'IOError'
    return 0


if __name__ == '__main__':
    plantRecognition('test_image_path')
```
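The top-k selection in plantRecognition relies on numpy's argsort: argsort returns indices in ascending order of score, and `[::-1]` reverses them into descending order. A standalone illustration:

```python
import numpy as np

# Four class probabilities, as predict[0] would hold after a run.
scores = np.array([0.10, 0.60, 0.05, 0.25])

# argsort gives ascending order; [::-1] reverses it to descending.
top = scores.argsort()[-len(scores):][::-1]
print(top.tolist())  # [1, 3, 0, 2]
```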

Test image:

Test result:

Serving the Recognition API on Flask

Server: Ubuntu 16.04, 64-bit

Preparation

First install the components the project needs: pip, the dev headers, and nginx. For Python 2:

```shell
sudo apt-get update
sudo apt-get install python-pip python-dev nginx
```

For Python 3:

```shell
sudo apt-get update
sudo apt-get install python3-pip python3-dev nginx
```

Installing a Python virtual environment

virtualenv is a tool for creating isolated Python environments: it creates a folder containing all the executables needed to use the packages a Python project requires.

Again, for Python 2:

```shell
sudo pip install virtualenv
```

For Python 3:

```shell
sudo pip3 install virtualenv
```

Create the project under /home:

```shell
cd /home
mkdir tensorflow
cd tensorflow
```

Create the virtual environment tensorflowenv inside tensorflow:

```shell
virtualenv tensorflowenv
```

Activate the virtual environment:

```shell
source tensorflowenv/bin/activate
```

The prompt now reads (tensorflowenv) root@:/home/tensorflow#, showing that the virtual environment is active.

Configuring the Flask application

Install Flask and uWSGI:

```shell
pip install uwsgi flask
```

Create a Hello World app:

```shell
nano ~/tensorflow/tensorflow.py
```

tensorflow.py

```python
from flask import Flask
app = Flask(__name__)

@app.route("/")
def hello():
    return "<h1 style='color:blue'>Hello World!</h1>"

if __name__ == "__main__":
    app.run(host='0.0.0.0')
```
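Before wiring up uWSGI, the app can be exercised in-process with Flask's built-in test client. This sketch redefines the same hello-world app so it runs standalone, without starting a server:

```python
from flask import Flask

app = Flask(__name__)

@app.route("/")
def hello():
    return "<h1 style='color:blue'>Hello World!</h1>"

# The test client issues requests directly against the app object.
resp = app.test_client().get("/")
print(resp.status_code)  # 200
```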

Configuring WSGI

Create wsgi.py:

```shell
nano ~/tensorflow/wsgi.py
```

```python
from tensorflow import app

if __name__ == "__main__":
    app.run()
```

Leave the virtual environment:

```shell
deactivate
```

Create the uWSGI configuration file tensorflow.ini:

```shell
nano ~/tensorflow/tensorflow.ini
```

```ini
[uwsgi]
module = wsgi:app
daemonize = /var/log/uwsgi.log

master = true
processes = 5

socket = tensorflow.sock
chmod-socket = 666
vacuum = true

die-on-term = true
```

Configuring systemd

Create tensorflow.service:

```shell
sudo nano /etc/systemd/system/tensorflow.service
```

```ini
[Unit]
Description=uWSGI instance to serve tensorflow
After=network.target

[Service]
User=root
Group=www-data
WorkingDirectory=/home/tensorflow
Environment="PATH=/home/tensorflow/tensorflowenv/bin"
ExecStart=/home/tensorflow/tensorflowenv/bin/uwsgi --ini tensorflow.ini

[Install]
WantedBy=multi-user.target
```

Now start and enable the uWSGI service:

```shell
sudo systemctl start tensorflow
sudo systemctl enable tensorflow
```

Configuring the Nginx proxy

```shell
sudo nano /etc/nginx/sites-available/tensorflow
```

```nginx
server {
    listen 80;
    server_name your_server_IP;

    location / {
        include uwsgi_params;
        uwsgi_pass unix:/home/tensorflow/tensorflow.sock;
    }
}
```

Enable the new Nginx site:

```shell
sudo ln -s /etc/nginx/sites-available/tensorflow /etc/nginx/sites-enabled
```

Check the Nginx configuration for errors:

```shell
sudo nginx -t
```

Restart Nginx and allow it through the firewall:

```shell
sudo systemctl restart nginx
sudo ufw allow 'Nginx Full'
```

Testing the Flask deployment

Visit http://server_domain_or_IP to verify that the deployment succeeded.

Installing tensorflow, numpy, pillow, and opencv

For the installation steps, see section 1.1.2.2, "Installing TensorFlow on Ubuntu 16.04".

Note: since Flask was installed inside the virtual environment, TensorFlow must also be installed inside the virtual environment.

Activate the virtual environment:

```shell
source tensorflowenv/bin/activate
```

Installing TensorFlow inside the virtual environment requires the --ignore-installed flag; otherwise recognition may fail with a "no module named tensorflow" error.
Example:

```shell
sudo pip install --ignore-installed --upgrade https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl
```
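If the "no module named tensorflow" error still appears, it usually means pip installed into a different interpreter than the one serving the app. A quick way to check which environment is active (this helper is my addition, not part of the original setup):

```python
import sys

def environment_info():
    """Report which Python interpreter and environment are active."""
    # virtualenv sets sys.real_prefix; venv makes base_prefix differ from prefix.
    in_virtualenv = hasattr(sys, 'real_prefix') or \
        sys.prefix != getattr(sys, 'base_prefix', sys.prefix)
    return {'executable': sys.executable, 'in_virtualenv': in_virtualenv}

print(environment_info())
```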

Uploading the model

Upload the three files generated by training (dictionary.txt, output_labels.txt, and output_graph.pb) to ~/tensorflow/.
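A quick sanity check that all three files actually arrived before starting uWSGI (a hypothetical helper, not part of the original deployment):

```python
import os

def missing_model_files(directory):
    """Return which of the three required model files are absent from directory."""
    required = ['dictionary.txt', 'output_labels.txt', 'output_graph.pb']
    return [name for name in required
            if not os.path.exists(os.path.join(directory, name))]

# An empty list means the server directory is ready.
print(missing_model_files('/home/tensorflow'))
```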

Writing the Flask main program tensorflow.py

(Be aware that a module named tensorflow.py shadows the tensorflow package for anything importing from the same directory; if `import tensorflow as tf` resolves to the Flask module instead of the library, rename the app module.)

tensorflow.py

```python
#coding=utf-8

import sys
import os

# Force UTF-8 as the default encoding (Python 2)
default_encoding = 'utf-8'
if sys.getdefaultencoding() != default_encoding:
    reload(sys)
    sys.setdefaultencoding(default_encoding)

import cv2
import math
import numpy as np
from PIL import Image
from PIL.ExifTags import TAGS
import tensorflow as tf
from flask import Flask, request, jsonify

app = Flask(__name__)

# Load the label-translation dictionary (the uploaded dictionary.txt)
dictionary = {}
ftotal = open('./dictionary.txt', 'r')
line = ftotal.readline()
while line:
    totaltree = line.replace('\r\n', '').replace('\n', '').split(',')
    dictionary[totaltree[0]] = totaltree[1].decode("utf-8")
    line = ftotal.readline()
ftotal.close()

# Load the recognition labels
labels = []
for label in tf.gfile.GFile("output_labels.txt"):
    labels.append(label.rstrip())

# Load the graph
with tf.gfile.FastGFile("output_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')


@app.route('/upload', methods=['GET', 'POST'])
def upload():
    if request.method == 'POST':
        file = request.files['file']

        # Build the storage path for the uploaded image
        photofile = '/home/plant/uploadPhoto/' + file.filename

        # Save the image
        file.save(photofile)

        GaussianBlur(photofile)

        # Return the recognition result for the plant image
        return plantRecognition(photofile)


# Recognize a plant image; returns the result as JSON
def plantRecognition(plantFile):

    sess = tf.Session()
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')

    # Read the plant image and run it through the network
    image = tf.gfile.FastGFile(plantFile, 'rb').read()
    predict = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image})

    # Sort class indices by descending probability
    top = predict[0].argsort()[-len(predict[0]):][::-1]

    # Remove the file after recognition to save server storage
    if os.path.exists(plantFile):
        os.remove(plantFile)

    result = []
    item = {}
    count = 8
    for index in top:
        human_string = dictionary[labels[index]]  # already unicode from the load above
        score = predict[0][index]

        # Stop once the confidence drops below 0.01
        if score < 0.0099:
            break

        item['name'] = human_string
        item['possibility'] = '%(p).2f' % {'p': score * 100} + '%'
        result.append(item)
        item = {}

        # Return only the top 8 results
        count -= 1
        if count == 0:
            break

    # Return the recognition data as JSON
    return jsonify({'data': result})


# Rotate an image about its center
def rotate_about_center(src, angle, scale=1.):
    if angle == 0:
        return src

    w = src.shape[1]
    h = src.shape[0]
    rangle = np.deg2rad(angle)  # angle in radians
    # now calculate new image width and height
    nw = (abs(np.sin(rangle)*h) + abs(np.cos(rangle)*w))*scale
    nh = (abs(np.cos(rangle)*h) + abs(np.sin(rangle)*w))*scale
    # ask OpenCV for the rotation matrix
    rot_mat = cv2.getRotationMatrix2D((nw*0.5, nh*0.5), angle, scale)
    # calculate the move from the old center to the new center combined
    # with the rotation
    rot_move = np.dot(rot_mat, np.array([(nw-w)*0.5, (nh-h)*0.5, 0]))
    # the move only affects the translation, so update the translation
    # part of the transform
    rot_mat[0, 2] += rot_move[0]
    rot_mat[1, 2] += rot_move[1]
    return cv2.warpAffine(src, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)


# Gaussian-blur the image
def GaussianBlur(imagefile):
    rotateangle = get_Rotate_Angle(imagefile)

    img = cv2.imread(imagefile)
    kernel_size = (5, 5)
    sigma = 3
    newimg = cv2.GaussianBlur(img, kernel_size, sigma)
    res = rotate_about_center(newimg, rotateangle)

    cv2.imwrite(imagefile, res)


# Read the rotation angle from the image's EXIF data
def get_Rotate_Angle(fname):
    try:
        img = Image.open(fname)
        if hasattr(img, '_getexif'):
            exifinfo = img._getexif()
            if exifinfo is not None:
                for tag, value in exifinfo.items():
                    decoded = TAGS.get(tag, tag)
                    if decoded == 'Orientation':
                        # only EXIF orientation 6 (90 degrees CW) is handled
                        if value == 6:
                            return 270
                        else:
                            return 0
    except IOError:
        print 'IOERROR ' + fname
    return 0


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=80)
```
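The possibility strings in the response come from the `%(p).2f` mapping-key format used above; shown standalone:

```python
# Format a softmax score as a percentage string, exactly as the API does.
score = 0.5284
possibility = '%(p).2f' % {'p': score * 100} + '%'
print(possibility)  # 52.84%
```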

API test result

```json
{
    "data": [
        {
            "name": "\u82ab\u837d",
            "possibility": "52.84%"
        },
        {
            "name": "\u88c2\u53f6\u94c1\u7ebf\u83b2",
            "possibility": "26.60%"
        },
        {
            "name": "\u65e5\u672c\u84dd\u76c6\u82b1",
            "possibility": "11.06%"
        },
        {
            "name": "\u83f1\u53f6\u5d16\u722c\u85e4",
            "possibility": "4.37%"
        },
        {
            "name": "\u9e92\u9e9f\u638c",
            "possibility": "2.93%"
        },
        {
            "name": "\u4e5d\u5473\u4e00\u679d\u84bf",
            "possibility": "1.34%"
        }
    ]
}
```
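The name values look opaque because jsonify escapes non-ASCII characters as \uXXXX sequences; json.loads turns them back into the Chinese plant names. A standalone check using the first entry above:

```python
import json

# The first entry from the response above, with the \u escapes kept literal.
payload = '{"name": "\\u82ab\\u837d", "possibility": "52.84%"}'
item = json.loads(payload)

# json.loads decodes the escapes back to the original characters.
print(item['name'])
```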

Photo-based Recognition on the Android Client

Taking the photo

There are plenty of third-party Android camera libraries, many of them good; here natario1's CameraView is used (CameraView repository).

retrofit + RxJava handle the network requests.

CameraView provides a CameraListener to handle the shutter action:

```java
cameraView.addCameraListener(new CameraListener() {
    @Override
    public void onPictureTaken(byte[] picture) {
        CameraUtils.decodeBitmap(picture, new CameraUtils.BitmapCallback() {
            @Override
            public void onBitmapReady(Bitmap bitmap) {
                // save the bitmap to a local path
                File file = new File(saveBitmapToLocalDir(bitmap));
                RequestBody requestBody = RequestBody.create(MediaType.parse("multipart/form-data"), file);
                MultipartBody.Part body = MultipartBody.Part.createFormData("file", file.getName(), requestBody);
                PlantApplication plantApplication = PlantApplication.create(context);
                PlantService plantService = plantApplication.getPlantService();
                Disposable disposable = plantService.recognition(PlantFactory.UPLOAD_PLANT_PIC, body)
                        .observeOn(AndroidSchedulers.mainThread())
                        .subscribeOn(plantApplication.subscribeScheduler())
                        .subscribe(new Consumer<RecognitionResponse>() {
                            @Override
                            public void accept(RecognitionResponse recognitionResponse) throws Exception {
                                // data received successfully
                                List<RecognitionResult> resultList = recognitionResponse.getList();
                                // ...
                            }
                        }, new Consumer<Throwable>() {
                            @Override
                            public void accept(Throwable throwable) throws Exception {
                                Toast.makeText(context, throwable.toString(), Toast.LENGTH_LONG).show();
                            }
                        });
                compositeDisposable.add(disposable);
            }
        });
    }
});
```

PlantApplication.java

```java
import android.app.Application;
import android.content.Context;

import com.mikepenz.iconics.Iconics;
import com.treecute.plant.data.GoodsFactory;
import com.treecute.plant.data.GoodsService;
import com.treecute.plant.data.PlantFactory;
import com.treecute.plant.data.PlantService;
import com.treecute.plant.data.UserFactory;
import com.treecute.plant.data.UserService;

import io.reactivex.Scheduler;
import io.reactivex.schedulers.Schedulers;

/**
 * Created by mkind on 2017/11/21 0021.
 */

public class PlantApplication extends Application {
    private PlantService plantService;
    private Scheduler scheduler;


    @Override
    public void onCreate() {
        super.onCreate();
    }

    private static PlantApplication get(Context context) {
        return (PlantApplication) context.getApplicationContext();
    }

    public static PlantApplication create(Context context) {
        return PlantApplication.get(context);
    }

    public PlantService getPlantService() {
        if (plantService == null) {
            plantService = PlantFactory.create();
        }
        return plantService;
    }

    public Scheduler subscribeScheduler() {
        if (scheduler == null) {
            scheduler = Schedulers.io();
        }
        return scheduler;
    }

    public void setPlantService(PlantService plantService) {
        this.plantService = plantService;
    }

    public void setScheduler(Scheduler scheduler) {
        this.scheduler = scheduler;
    }
}
```

PlantService

```java
import io.reactivex.Observable;
import okhttp3.MultipartBody;
import retrofit2.http.Multipart;
import retrofit2.http.POST;
import retrofit2.http.Part;
import retrofit2.http.Url;

public interface PlantService {

    // POST endpoint that uploads the image for recognition
    @Multipart
    @POST
    Observable<RecognitionResponse> recognition(@Url String url, @Part MultipartBody.Part file);
}
```

PlantFactory

```java
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.gson.GsonConverterFactory;

public class PlantFactory {
    public final static String BASE_URL = "http://120.25.1.26:97/";

    public final static String UPLOAD_PLANT_PIC = "http://120.25.1.26/upload";

    public static PlantService create() {
        Retrofit retrofit = new Retrofit.Builder()
                .baseUrl(BASE_URL)
                .addConverterFactory(GsonConverterFactory.create())
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .build();
        return retrofit.create(PlantService.class);
    }
}
```

Final result: