1.什么是tensorflow?
tensorflow名字的由来就是张量(tensor)在计算图(computational graph)里的流动(flow),如图。它的基础就是前面介绍的基于计算图的自动微分,除了自动帮你求梯度之外,它也提供了各种常见的操作(op,也就是计算图的节点),常见的损失函数,优化算法。

tensorflow 是一个开放源代码软件库,用于进行高性能数值计算。借助其灵活的架构,用户可以轻松地将计算工作部署到多种平台(cpu、gpu、tpu)和设备(桌面设备、服务器集群、移动设备、边缘设备等)。
tensorflow 是一个用于研究和生产的开放源代码机器学习库。tensorflow 提供了各种 api,可供初学者和专家在桌面、移动、网络和云端环境下进行开发。
tensorflow是采用数据流图(data flow graphs)来计算,所以首先我们得创建一个数据流流图,然后再将我们的数据(数据以张量(tensor)的形式存在)放在数据流图中计算. 节点(nodes)在图中表示数学操作,图中的边(edges)则表示在节点间相互联系的多维数据数组, 即张量(tensor)。训练模型时tensor会不断的从数据流图中的一个节点flow到另一节点, 这就是tensorflow名字的由来。 张量(tensor):张量有多种. 零阶张量为 纯量或标量 (scalar) 也就是一个数值. 比如 [1],一阶张量为 向量 (vector), 比如 一维的 [1, 2, 3],二阶张量为 矩阵 (matrix), 比如 二维的 [[1, 2, 3],[4, 5, 6],[7, 8, 9]],以此类推, 还有 三阶 三维的 … 张量从流图的一端流动到另一端的计算过程。它生动形象地描述了复杂数据结构在人工神经网中的流动、传输、分析和处理模式。
在机器学习中,数值通常由4种类型构成: (1)标量(scalar):即一个数值,它是计算的最小单元,如“1”或“3.2”等。 (2)向量(vector):由一些标量构成的一维数组,如[1, 3.2, 4.6]等。 (3)矩阵(matrix):是由标量构成的二维数组。 (4)张量(tensor):由多维(通常)数组构成的数据集合,可理解为高维矩阵。
tensorflow的基本概念
- 图:描述了计算过程,tensorflow用图来表示计算过程
- 张量:tensorflow 使用tensor表示数据,每一个tensor是一个多维化的数组
- 操作:图中的节点为op,一个op获得/输入0个或者多个tensor,执行并计算,产生0个或多个tensor
- 会话:session tensorflow的运行需要再绘话里面运行
tensorflow写代码流程
- 定义变量占位符
- 根据数学原理写方程
- 定义损失函数cost
- 定义优化梯度下降 gradientdescentoptimizer
- session 进行训练,for循环
- 保存saver
2.环境准备
整合步骤
- 模型构建:首先,我们需要在tensorflow中定义并训练深度学习模型。这可能涉及选择合适的网络结构、优化器和损失函数等。
- 训练数据准备:接下来,我们需要准备用于训练和验证模型的数据。这可能包括数据清洗、标注和预处理等步骤。
- rest api设计:为了与tensorflow模型进行交互,我们需要在springboot中创建一个rest api。这可以使用springboot的内置功能来实现,例如使用spring mvc或spring webflux。
- 模型部署:在模型训练完成后,我们需要将其部署到springboot应用中。为此,我们可以使用tensorflow的java api将模型导出为onnx或savedmodel格式,然后在springboot应用中加载并使用。
在整合过程中,有几个关键点需要注意。首先,防火墙设置可能会影响tensorflow训练过程中的网络通信。确保你的防火墙允许tensorflow访问其所需的网络资源,以免出现训练中断或模型性能下降的问题。其次,要关注版本兼容性。springboot和tensorflow都有各自的版本更新周期,确保在整合时使用兼容的版本可以避免很多不必要的麻烦。
3.代码工程
实验目的
实现图片检测
pom.xml
<?xml version="1.0" encoding="utf-8"?>
<project xmlns="http://maven.apache.org/pom/4.0.0"
xmlns:xsi="http://www.w3.org/2001/xmlschema-instance"
xsi:schemalocation="http://maven.apache.org/pom/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactid>springboot-demo</artifactid>
<groupid>com.et</groupid>
<version>1.0-snapshot</version>
</parent>
<modelversion>4.0.0</modelversion>
<artifactid>tensorflow</artifactid>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupid>org.springframework.boot</groupid>
<artifactid>spring-boot-starter-web</artifactid>
</dependency>
<dependency>
<groupid>org.springframework.boot</groupid>
<artifactid>spring-boot-autoconfigure</artifactid>
</dependency>
<dependency>
<groupid>org.springframework.boot</groupid>
<artifactid>spring-boot-starter-test</artifactid>
<scope>test</scope>
</dependency>
<dependency>
<groupid>org.tensorflow</groupid>
<artifactid>tensorflow-core-platform</artifactid>
<version>0.5.0</version>
</dependency>
<dependency>
<groupid>org.projectlombok</groupid>
<artifactid>lombok</artifactid>
</dependency>
<dependency>
<groupid>jmimemagic</groupid>
<artifactid>jmimemagic</artifactid>
<version>0.1.2</version>
</dependency>
<dependency>
<groupid>jakarta.platform</groupid>
<artifactid>jakarta.jakartaee-api</artifactid>
<version>9.0.0</version>
</dependency>
<dependency>
<groupid>commons-io</groupid>
<artifactid>commons-io</artifactid>
<version>2.16.1</version>
</dependency>
<dependency>
<groupid>org.springframework.restdocs</groupid>
<artifactid>spring-restdocs-mockmvc</artifactid>
<scope>test</scope>
</dependency>
</dependencies>
</project>
controller
package com.et.tf.api;
import java.io.ioexception;
import com.et.tf.service.classifyimageservice;
import net.sf.jmimemagic.magic;
import net.sf.jmimemagic.magicmatch;
import org.springframework.beans.factory.annotation.autowired;
import org.springframework.web.bind.annotation.crossorigin;
import org.springframework.web.bind.annotation.postmapping;
import org.springframework.web.bind.annotation.requestmapping;
import org.springframework.web.bind.annotation.requestparam;
import org.springframework.web.bind.annotation.restcontroller;
import org.springframework.web.multipart.multipartfile;
@restcontroller
@requestmapping("/api")
public class appcontroller {
@autowired
classifyimageservice classifyimageservice;
@postmapping(value = "/classify")
@crossorigin(origins = "*")
public classifyimageservice.labelwithprobability classifyimage(@requestparam multipartfile file) throws ioexception {
checkimagecontents(file);
return classifyimageservice.classifyimage(file.getbytes());
}
@requestmapping(value = "/")
public string index() {
return "index";
}
private void checkimagecontents(multipartfile file) {
magicmatch match;
try {
match = magic.getmagicmatch(file.getbytes());
} catch (exception e) {
throw new runtimeexception(e);
}
string mimetype = match.getmimetype();
if (!mimetype.startswith("image")) {
throw new illegalargumentexception("not an image type: " + mimetype);
}
}
}
service
package com.et.tf.service;
import jakarta.annotation.predestroy;
import java.util.arrays;
import java.util.list;
import lombok.allargsconstructor;
import lombok.data;
import lombok.noargsconstructor;
import lombok.extern.slf4j.slf4j;
import org.springframework.beans.factory.annotation.value;
import org.springframework.stereotype.service;
import org.tensorflow.graph;
import org.tensorflow.output;
import org.tensorflow.session;
import org.tensorflow.tensor;
import org.tensorflow.ndarray.ndarrays;
import org.tensorflow.ndarray.shape;
import org.tensorflow.ndarray.buffer.floatdatabuffer;
import org.tensorflow.op.opscope;
import org.tensorflow.op.scope;
import org.tensorflow.proto.framework.datatype;
import org.tensorflow.types.tfloat32;
import org.tensorflow.types.tint32;
import org.tensorflow.types.tstring;
import org.tensorflow.types.family.ttype;
//inspired from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/labelimage.java
@service
@slf4j
public class classifyimageservice {
private final session session;
private final list<string> labels;
private final string outputlayer;
private final int w;
private final int h;
private final float mean;
private final float scale;
public classifyimageservice(
graph inceptiongraph, list<string> labels, @value("${tf.outputlayer}") string outputlayer,
@value("${tf.image.width}") int imagew, @value("${tf.image.height}") int imageh,
@value("${tf.image.mean}") float mean, @value("${tf.image.scale}") float scale
) {
this.labels = labels;
this.outputlayer = outputlayer;
this.h = imageh;
this.w = imagew;
this.mean = mean;
this.scale = scale;
this.session = new session(inceptiongraph);
}
public labelwithprobability classifyimage(byte[] imagebytes) {
long start = system.currenttimemillis();
try (tensor image = normalizedimagetotensor(imagebytes)) {
float[] labelprobabilities = classifyimageprobabilities(image);
int bestlabelidx = maxindex(labelprobabilities);
labelwithprobability labelwithprobability =
new labelwithprobability(labels.get(bestlabelidx), labelprobabilities[bestlabelidx] * 100f, system.currenttimemillis() - start);
log.debug(string.format(
"image classification [%s %.2f%%] took %d ms",
labelwithprobability.getlabel(),
labelwithprobability.getprobability(),
labelwithprobability.getelapsed()
)
);
return labelwithprobability;
}
}
private float[] classifyimageprobabilities(tensor image) {
try (tensor result = session.runner().feed("input", image).fetch(outputlayer).run().get(0)) {
final shape resultshape = result.shape();
final long[] rshape = resultshape.asarray();
if (resultshape.numdimensions() != 2 || rshape[0] != 1) {
throw new runtimeexception(
string.format(
"expected model to produce a [1 n] shaped tensor where n is the number of labels, instead it produced one with shape %s",
arrays.tostring(rshape)
));
}
int nlabels = (int) rshape[1];
floatdatabuffer resultfloatbuffer = result.asrawtensor().data().asfloats();
float[] dst = new float[nlabels];
resultfloatbuffer.read(dst);
return dst;
}
}
private int maxindex(float[] probabilities) {
int best = 0;
for (int i = 1; i < probabilities.length; ++i) {
if (probabilities[i] > probabilities[best]) {
best = i;
}
}
return best;
}
private tensor normalizedimagetotensor(byte[] imagebytes) {
try (graph g = new graph();
tint32 batchtensor = tint32.scalarof(0);
tint32 sizetensor = tint32.vectorof(h, w);
tfloat32 meantensor = tfloat32.scalarof(mean);
tfloat32 scaletensor = tfloat32.scalarof(scale);
) {
graphbuilder b = new graphbuilder(g);
//tutorial python here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/label_image
// some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz
//
// - the model was trained with images scaled to 299x299 pixels.
// - the colors, represented as r, g, b in 1-byte each were converted to
// float using (value - mean)/scale.
// since the graph is being constructed once per execution here, we can use a constant for the
// input image. if the graph were to be re-used for multiple input images, a placeholder would
// have been more appropriate.
final output input = b.constant("input", tstring.tensorofbytes(ndarrays.scalarofobject(imagebytes)));
final output output =
b.div(
b.sub(
b.resizebilinear(
b.expanddims(
b.cast(b.decodejpeg(input, 3), datatype.dt_float),
b.constant("make_batch", batchtensor)
),
b.constant("size", sizetensor)
),
b.constant("mean", meantensor)
),
b.constant("scale", scaletensor)
);
try (session s = new session(g)) {
return s.runner().fetch(output.op().name()).run().get(0);
}
}
}
static class graphbuilder {
final scope scope;
graphbuilder(graph g) {
this.g = g;
this.scope = new opscope(g);
}
output div(output x, output y) {
return binaryop("div", x, y);
}
output sub(output x, output y) {
return binaryop("sub", x, y);
}
output resizebilinear(output images, output size) {
return binaryop("resizebilinear", images, size);
}
output expanddims(output input, output dim) {
return binaryop("expanddims", input, dim);
}
output cast(output value, datatype dtype) {
return g.opbuilder("cast", "cast", scope).addinput(value).setattr("dstt", dtype).build().output(0);
}
output decodejpeg(output contents, long channels) {
return g.opbuilder("decodejpeg", "decodejpeg", scope)
.addinput(contents)
.setattr("channels", channels)
.build()
.output(0);
}
output<? extends ttype> constant(string name, tensor t) {
return g.opbuilder("const", name, scope)
.setattr("dtype", t.datatype())
.setattr("value", t)
.build()
.output(0);
}
private output binaryop(string type, output in1, output in2) {
return g.opbuilder(type, type, scope).addinput(in1).addinput(in2).build().output(0);
}
private final graph g;
}
@predestroy
public void close() {
session.close();
}
@data
@noargsconstructor
@allargsconstructor
public static class labelwithprobability {
private string label;
private float probability;
private long elapsed;
}
}
application.yaml
tf:
frozenmodelpath: inception-v3/inception_v3_2016_08_28_frozen.pb
labelspath: inception-v3/imagenet_slim_labels.txt
outputlayer: inceptionv3/predictions/reshape_1
image:
width: 299
height: 299
mean: 0
scale: 255
logging.level.net.sf.jmimemagic: warn
spring:
servlet:
multipart:
max-file-size: 5mb
application.java
package com.et.tf;
import java.io.ioexception;
import java.nio.charset.standardcharsets;
import java.util.list;
import java.util.stream.collectors;
import lombok.extern.slf4j.slf4j;
import org.apache.commons.io.ioutils;
import org.springframework.beans.factory.annotation.value;
import org.springframework.boot.springapplication;
import org.springframework.boot.autoconfigure.springbootapplication;
import org.springframework.context.annotation.bean;
import org.springframework.core.io.classpathresource;
import org.springframework.core.io.filesystemresource;
import org.springframework.core.io.resource;
import org.tensorflow.graph;
import org.tensorflow.proto.framework.graphdef;
@springbootapplication
@slf4j
public class application {
public static void main(string[] args) {
springapplication.run(application.class, args);
}
@bean
public graph tfmodelgraph(@value("${tf.frozenmodelpath}") string tffrozenmodelpath) throws ioexception {
resource graphresource = getresource(tffrozenmodelpath);
graph graph = new graph();
graph.importgraphdef(graphdef.parsefrom(graphresource.getinputstream()));
log.info("loaded tensorflow model");
return graph;
}
private resource getresource(@value("${tf.frozenmodelpath}") string tffrozenmodelpath) {
resource graphresource = new filesystemresource(tffrozenmodelpath);
if (!graphresource.exists()) {
graphresource = new classpathresource(tffrozenmodelpath);
}
if (!graphresource.exists()) {
throw new illegalargumentexception(string.format("file %s does not exist", tffrozenmodelpath));
}
return graphresource;
}
@bean
public list<string> tfmodellabels(@value("${tf.labelspath}") string labelspath) throws ioexception {
resource labelsres = getresource(labelspath);
log.info("loaded model labels");
return ioutils.readlines(labelsres.getinputstream(), standardcharsets.utf_8).stream()
.map(label -> label.substring(label.contains(":") ? label.indexof(":") + 1 : 0)).collect(collectors.tolist());
}
}
以上只是一些关键代码,所有代码请参见下面代码仓库
代码仓库
https://github.com/harries/springboot-demo
4.测试
启动 spring boot应用程序
测试图片分类
访问http://127.0.0.1:8080/,上传一张图片,点击分类

5.总结
以上就是springboot集成tensorflow实现图片检测功能的详细内容,更多关于springboot tensorflow图片检测的资料请关注代码网其它相关文章!
发表评论