Java调用Pytorch实现以图搜图
涉及技术栈:
1、ElasticSearch环境;
2、Python运行环境(如果事先没有pytorch模型时,可以用python脚本创建模型);
1、运行效果
2、创建模型(有则可以跳过)
vi script.py
import torch
import torch.nn as nn
import torchvision.models as models
class ImageFeatureExtractor(nn.Module):
    """ResNet-50 whose final fully connected layer is swapped for a 2048 -> 1024 linear layer.

    The forward pass simply runs the modified ResNet and returns its output.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # The final output is a 1024-dimensional vector; the Elasticsearch mapping
        # further down must therefore declare dims = 1024.
        # NOTE(review): this replacement layer is newly constructed here and this script
        # loads no weights into it - confirm that is intended.
        backbone.fc = nn.Linear(2048, 1024)
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
if __name__ == '__main__':
    # Build the extractor and switch it to inference mode before tracing.
    model = ImageFeatureExtractor()
    model.eval()
    # Any tensor with the model's input shape works as the tracing example.
    # (Named example_input so the builtin input() is not shadowed.)
    example_input = torch.rand([1, 3, 224, 224])
    # Save the model as a traced TorchScript module.
    script = torch.jit.trace(model, example_input)
    script.save("model.pt")
2、java项目pom.xml
<!-- Dependencies of the image-search service. -->
<!-- NOTE(review): spring-boot-starter-web, lombok and the Elasticsearch client declare no
     version here; presumably a parent POM / BOM manages them - confirm. -->
<dependencies>
<!-- Spring web starter (used by the REST controller). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Lombok annotations; provided scope. -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>
<!-- DJL PyTorch engine. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.19.0</version>
</dependency>
<!-- Native PyTorch CPU libraries, runtime scope. -->
<!-- NOTE(review): no platform classifier is given for this artifact; verify against the
     DJL PyTorch engine documentation whether one is required for the target OS. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-native-cpu</artifactId>
<version>1.10.0</version>
<scope>runtime</scope>
</dependency>
<!-- DJL JNI bindings; the version string combines the native version (1.10.0) and the
     engine version (0.19.0) declared above. -->
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-jni</artifactId>
<version>1.10.0-0.19.0</version>
</dependency>
<!-- Elasticsearch high-level REST client. -->
<dependency>
<groupId>org.elasticsearch.client</groupId>
<artifactId>elasticsearch-rest-high-level-client</artifactId>
</dependency>
</dependencies>
3、ES创建索引(定义mapping)
PUT /isi
{
"mappings": {
"properties": {
"vector": {
"type": "dense_vector",
"dims": 1024
},
"url" : {
"type" : "keyword"
},
"user_id": {
"type": "keyword"
}
}
}
}
4、编写java代码调用模型
ORCUtil.java
package com.topprismcloud.rtm;
import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;
/**
 * Image-similarity helper: converts images into feature vectors with a TorchScript model
 * loaded through DJL, and stores / searches those vectors in the Elasticsearch index
 * {@code isi}.
 *
 * <p>NOTE(review): the Elasticsearch endpoint, the credentials and the image directory are
 * hard-coded below with the same values as before; they should be moved to external
 * configuration, and the password must not live in source control.
 */
public class ORCUtil {
    private static final String INDEX = "isi";
    private static final int IMAGE_SIZE = 224;

    private static final String ES_HOST = "14.20.30.16";
    private static final int ES_PORT = 9200;
    private static final String ES_USER = "elastic";
    private static final String ES_PASSWORD = "123456";

    /** Directory whose files are indexed by {@link #upload()}. */
    private static final String IMAGE_DIR = "D:\\001ENV\\nginx-1.24.0\\html\\resource\\new";

    /** Maximum number of hits returned by {@link #search(InputStream)}. */
    private static final int MAX_RESULTS = 100;

    /**
     * Constant added to the cosine similarity inside the Elasticsearch script and removed
     * again when results are built. Cosine similarity can be negative, and Elasticsearch
     * rejects negative scores produced by script functions.
     */
    private static final float SCORE_OFFSET = 1.0f;

    /** Guards the shared predictor, which is called from concurrent web requests. */
    private static final Object PREDICT_LOCK = new Object();

    private static Model model; // the loaded TorchScript model
    private static Predictor<Image, float[]> predictor; // predictor.predict(input) ~ model(input) in Python

    static {
        try {
            model = Model.newInstance("model");
            // model.pt must be a TorchScript file saved with torch.jit.trace(...).save(...).
            try (InputStream modelStream = ORCUtil.class.getClassLoader().getResourceAsStream("model.pt")) {
                if (modelStream == null) {
                    throw new IllegalStateException("model.pt not found on the classpath");
                }
                model.load(modelStream);
            }
            Transform resize = new Resize(IMAGE_SIZE);
            Transform toTensor = new ToTensor();
            Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },
                    new float[] { 0.229f, 0.224f, 0.225f });
            // The translator turns the input Image into a tensor and the output into float[].
            Translator<Image, float[]> translator = new Translator<Image, float[]>() {
                @Override
                public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
                    NDManager ndManager = ctx.getNDManager();
                    System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
                    NDArray transform = normalize
                            .transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
                    System.out.println(transform.getShape());
                    NDList list = new NDList();
                    list.add(transform);
                    return list;
                }

                @Override
                public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
                    return ndList.get(0).toFloatArray();
                }
            };
            predictor = new Predictor<>(model, translator, Device.cpu(), true);
        } catch (Exception e) {
            // Fail fast: continuing with a null predictor would only surface later as an
            // unrelated NullPointerException on the first request.
            throw new ExceptionInInitializerError(e);
        }
    }

    /** Creates an authenticated Elasticsearch client; the caller must close it. */
    private static RestHighLevelClient newClient() {
        CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
        credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(ES_USER, ES_PASSWORD));
        RestClientBuilder builder = RestClient
                .builder(new HttpHost(ES_HOST, ES_PORT, HttpHost.DEFAULT_SCHEME_NAME));
        builder.setHttpClientConfigCallback(http -> http.setDefaultCredentialsProvider(credentialsProvider));
        return new RestHighLevelClient(builder);
    }

    /** Decodes the image from the stream and runs it through the model, one call at a time. */
    private static float[] extractFeatures(InputStream imageStream) throws Exception {
        Image image = ImageFactory.getInstance().fromInputStream(imageStream);
        synchronized (PREDICT_LOCK) {
            return predictor.predict(image);
        }
    }

    /**
     * Computes a feature vector for every regular file in the image directory and indexes
     * all of them into Elasticsearch with a single bulk request.
     *
     * @throws Exception if the directory cannot be listed, an image cannot be read or
     *                   processed, or Elasticsearch reports a failure for any document
     */
    public static void upload() throws Exception {
        File[] imageFiles = new File(IMAGE_DIR).listFiles(File::isFile);
        if (imageFiles == null) {
            throw new IOException("Image directory is missing or unreadable: " + IMAGE_DIR);
        }
        if (imageFiles.length == 0) {
            return; // nothing to index; an empty bulk request would be rejected
        }
        try (RestHighLevelClient client = newClient()) {
            // Bulk upload request
            BulkRequest bulkRequest = new BulkRequest(INDEX);
            for (File imageFile : imageFiles) {
                float[] vector;
                try (InputStream in = new FileInputStream(imageFile)) {
                    vector = extractFeatures(in);
                }
                // Build the document
                Map<String, Object> jsonMap = new HashMap<>();
                jsonMap.put("url", "/resource/" + imageFile.getName());
                jsonMap.put("vector", vector);
                jsonMap.put("user_id", "user123");
                bulkRequest.add(new IndexRequest(INDEX).source(jsonMap, XContentType.JSON));
            }
            org.elasticsearch.action.bulk.BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
            if (response.hasFailures()) {
                throw new IOException("Bulk indexing failed: " + response.buildFailureMessage());
            }
        }
    }

    /**
     * Searches the index for the images most similar to the given one.
     *
     * @param input stream of the query image; the caller remains responsible for closing it
     * @return up to 100 hits ordered by Elasticsearch, each carrying the stored URL and the
     *         cosine similarity to the query image
     * @throws Throwable if the image cannot be decoded or the search fails
     */
    public static List<SearchResult> search(InputStream input) throws Throwable {
        Objects.requireNonNull(input, "input");
        float[] vector = extractFeatures(input);
        System.out.println(Arrays.toString(vector));

        // The "+ 1.0" in the script corresponds to SCORE_OFFSET and keeps scores non-negative.
        Script script = new Script(ScriptType.INLINE, "painless",
                "cosineSimilarity(params.queryVector, doc['vector']) + 1.0",
                Collections.singletonMap("queryVector", vector));
        FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
                .functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        searchSourceBuilder.query(functionScoreQueryBuilder)
                .fetchSource(null, "vector") // do not return the vector field: large and unused
                .size(MAX_RESULTS);
        SearchRequest searchRequest = new SearchRequest(INDEX);
        searchRequest.source(searchSourceBuilder);

        // Connect with the same credentials as upload(); try-with-resources closes the
        // client even when the search throws.
        try (RestHighLevelClient client = newClient()) {
            SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
            SearchHits hits = searchResponse.getHits();
            List<SearchResult> list = new ArrayList<>();
            for (SearchHit hit : hits) {
                System.out.println(hit.toString());
                // Remove the offset so callers keep receiving the plain cosine similarity.
                list.add(new SearchResult((String) hit.getSourceAsMap().get("url"),
                        hit.getScore() - SCORE_OFFSET));
            }
            return list;
        }
    }

    /** Command-line entry point: indexes the image directory. */
    public static void main(String[] args) throws Throwable {
        ORCUtil.upload();
        System.out.println("hao");
    }
}
SearchController.java
package com.topprismcloud.rtm;
import java.util.List;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;
/**
 * REST endpoint for image search: accepts an uploaded image and returns the most similar
 * indexed images.
 */
@RestController
@CrossOrigin
public class SearchController {
    private static final java.util.logging.Logger LOG =
            java.util.logging.Logger.getLogger(SearchController.class.getName());

    /**
     * Searches for images similar to the uploaded file.
     *
     * @param file the uploaded query image (multipart part named {@code file})
     * @return 200 with the list of hits; 400 when no file (or an empty file) was sent;
     *         500 when feature extraction or the search itself fails
     */
    @PostMapping("search")
    public ResponseEntity<List<SearchResult>> search(MultipartFile file) {
        if (file == null || file.isEmpty()) {
            return ResponseEntity.badRequest().build();
        }
        // try-with-resources: the upload stream is closed whether or not the search succeeds.
        try (java.io.InputStream in = file.getInputStream()) {
            return ResponseEntity.ok(ORCUtil.search(in));
        } catch (Throwable e) {
            // ORCUtil.search declares Throwable, so it has to be caught here; log it instead
            // of dropping it so failures can be diagnosed.
            LOG.log(java.util.logging.Level.SEVERE, "Image search failed", e);
            return ResponseEntity.status(500).build();
        }
    }
}
SearchResult.java
package com.topprismcloud.rtm;
import lombok.AllArgsConstructor;
import lombok.Data;
/**
 * A single search hit returned to the client: the image URL and its score.
 * The accessors and the all-arguments constructor come from the Lombok annotations below.
 */
@Data
@AllArgsConstructor
public class SearchResult {
// URL of the matched image.
private String url;
// Score of the match.
private Float score;
}
5、前端
index.html
<!DOCTYPE html>
<html lang="zh">
<head>
<meta charset="UTF-8">
<title>以图搜图</title>
<style>
/* Page background: fixed image stretched over the whole viewport. */
body {
background: url("/img/bg.jpg");
background-attachment: fixed;
background-size: 100% 100%;
}
/* Centered, semi-transparent 1000px container for each direct child div of body. */
body>div {
width: 1000px;
margin: 50px auto;
padding: 10px 20px;
border: 1px solid lightgray;
border-radius: 20px;
box-sizing: border-box;
background: rgba(255, 255, 255, 0.7);
}
/* Left panel: preview box plus file input. */
.upload {
display: inline-block;
width: 300px;
height: 280px;
border: 1px dashed lightcoral;
vertical-align: top;
}
/* Preview frame; position: relative so the image can be laid over the placeholder text. */
.upload .cover {
width: 200px;
height: 200px;
margin: 10px 50px;
border: 1px solid black;
box-sizing: border-box;
text-align: center;
line-height: 200px;
position: relative;
}
/* Preview image, absolutely positioned at the frame's top-left corner. */
.upload img {
width: 198px;
height: 198px;
position: absolute;
left: 0;
top: 0;
}
.upload input {
margin-left: 50px;
}
.upload button {
width: 80px;
height: 30px;
margin-left: 110px;
}
/* Right panel: container for the search results. */
.result-block {
display: inline-block;
margin-left: 40px;
border: 1px solid lightgray;
border-radius: 10px;
min-height: 500px;
width: 600px;
}
/* Status heading shown inside the result panel. */
.result-block h1 {
text-align: center;
margin-top: 100px;
}
/* One result card. */
.result {
padding: 10px;
cursor: pointer;
display: inline-block;
}
.result:hover {
background: rgb(240, 240, 240);
}
/* Single-line text, truncated with an ellipsis. */
.result p {
width: 110px;
overflow: hidden;
white-space: nowrap;
text-overflow: ellipsis;
}
.result img {
width: 160px;
height: 160px;
}
/* Score label. */
.result .prob {
color: rgb(37, 147, 60)
}
</style>
<script src="js/jquery-3.6.0.js"></script>
</head>
<body>
<!-- Main container: upload panel on the left, result panel on the right. -->
<div>
<div class="upload">
<!-- Preview frame: placeholder text with the selected image (#image) laid over it. -->
<div class="cover">
请选择图片
<img id="image" src="" />
</div>
<!-- File picker (#file); the script below binds a change handler to it. -->
<input id="file" type="file">
</div>
<!-- Result panel (.result-block): the script replaces its content with status text or result cards. -->
<div class="result-block">
<h1>请选择图片</h1>
</div>
</div>
<!-- NOTE(review): #box is not referenced by the script in this page - confirm whether it is still needed. -->
<ul id="box">
</ul>
<script>
// When a file is chosen: validate its extension, show a preview, upload it to the
// search endpoint and render the returned hits into .result-block.
var file = $('#file')
file.change(function () {
    let f = this.files[0]
    // No file selected (for example the picker was dismissed): nothing to do.
    if (!f) {
        return
    }
    let index = f.name.lastIndexOf('.')
    let fileText = f.name.substring(index, f.name.length)
    let ext = fileText.toLowerCase() // file extension
    console.log(ext)
    if (ext != '.png' && ext != '.jpg' && ext != '.jpeg') {
        alert('系统仅支持 JPG、PNG、JPEG 格式的图片,请您调整格式后重新上传')
        return
    }
    $('.result-block').empty().append($('<h1>正在识别中...</h1>'))
    $("#image").attr("src", getObjectURL(f));
    let formData = new FormData()
    formData.append('file', f)
    $.ajax({
        url: 'http://10.1.2.240:8081/search',
        method: 'post',
        data: formData,
        processData: false,
        contentType: false,
        success: res => {
            console.log('shibie', res)
            $('.result-block').empty()
            for (let item of res) {
                console.log(item)
                // Build the card with attr()/text() so server-supplied values are never
                // parsed as HTML.
                let card = $('<div class="result"></div>')
                card.append($('<img/>').attr('src', item.url))
                let info = $('<div style="display: inline-block;vertical-align: top"></div>')
                info.append($('<p class="prob"></p>').text('得分:' + item.score.toFixed(4)))
                card.append(info)
                $('.result-block').append(card)
            }
        },
        // Without this handler a failed request would leave the "recognizing" message
        // on screen forever.
        error: xhr => {
            console.log('search failed', xhr.status)
            $('.result-block').empty().append($('<h1>识别失败,请稍后重试</h1>'))
        }
    })
});
// Click handler for #button: logs the currently selected file (single-file input).
$('#button').on('click', function () {
    const selected = document.getElementById('file').files[0]
    console.log(selected)
})
/**
 * Returns an object URL for the given File/Blob so it can be previewed locally,
 * using whichever of window.createObjectURL, window.URL or window.webkitURL exists.
 * Returns null when none of them is available.
 */
function getObjectURL(file) {
    // The original checked and called misspelled names (createObjcectURL /
    // createOjcectURL); both now use the same, correctly spelled property.
    if (window.createObjectURL != undefined) {
        return window.createObjectURL(file);
    }
    if (window.URL != undefined) {
        return window.URL.createObjectURL(file);
    }
    if (window.webkitURL != undefined) {
        return window.webkitURL.createObjectURL(file);
    }
    return null;
}
// Empty placeholder; it has no body and nothing in this page calls it.
function detect() {
}
</script>
</body>
</html>
6、打包后的源代码
以图搜图Java+html源代码
相关参考文章:Java调用Pytorch模型进行图像识别