Java调用Pytorch实现以图搜图(附源码)

Java调用Pytorch实现以图搜图

设计技术栈:
1、ElasticSearch环境;
2、Python运行环境(如果事先没有pytorch模型时,可以用python脚本创建模型);

1、运行效果

在这里插入图片描述

2、创建模型(有则可以跳过)

vi script.py

import torch
import torch.nn as nn
import torchvision.models as models
 
class ImageFeatureExtractor(nn.Module):
    def __init__(self):
        super(ImageFeatureExtractor, self).__init__()
        self.resnet = models.resnet50(pretrained=True)
        #最终输出维度1024的向量,下文elastic search要设置dims为1024
        self.resnet.fc = nn.Linear(2048, 1024)
 
    def forward(self, x):
        x = self.resnet(x)
        return x
 
if __name__ == '__main__':
    model = ImageFeatureExtractor()
    model.eval()
    #根据模型随便创建一个输入
    input = torch.rand([1, 3, 224, 224])
    output = model(input)
    #以这种方式保存
    script = torch.jit.trace(model, input)
    script.save("model.pt")

2、java项目pom.xml

<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<scope>provided</scope>
		</dependency>
		<dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.19.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <version>1.10.0</version>
            <scope>runtime</scope>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-jni</artifactId>
            <version>1.10.0-0.19.0</version>
        </dependency>
        <dependency>
            <groupId>org.elasticsearch.client</groupId>
            <artifactId>elasticsearch-rest-high-level-client</artifactId>
        </dependency>
	</dependencies>

3、ES创建文档

PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url" : {
        "type" : "keyword"
      },
      "user_id": {
          "type": "keyword"
      }
    }
  }
}

4、编写java代码调用模型

ORCUtil.java

package com.topprismcloud.rtm;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;

public class ORCUtil {

	private static final String INDEX = "isi";
	private static final int IMAGE_SIZE = 224;
	private static Model model; // 模型
	private static Predictor<Image, float[]> predictor; // predictor.predict(input)相当于python中model(input)
	static {
		try {
			model = Model.newInstance("model");
			// 这里的model.pt是上面代码展示的那种方式保存的
			model.load(ORCUtil.class.getClassLoader().getResourceAsStream("model.pt"));
			Transform resize = new Resize(IMAGE_SIZE);
			Transform toTensor = new ToTensor();
			Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },
					new float[] { 0.229f, 0.224f, 0.225f });
			// Translator处理输入Image转为tensor、输出转为float[]
			Translator<Image, float[]> translator = new Translator<Image, float[]>() {
				@Override
				public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
					NDManager ndManager = ctx.getNDManager();
					System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
					NDArray transform = normalize
							.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
					System.out.println(transform.getShape());
					NDList list = new NDList();
					list.add(transform);
					return list;
				}

				@Override
				public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
					return ndList.get(0).toFloatArray();
				}
			};
			predictor = new Predictor<>(model, translator, Device.cpu(), true);
		} catch (Exception e) {
			e.printStackTrace();
		}
	}

	public static void upload() throws Exception {
		HttpHost host=new HttpHost("14.20.30.16", 9200, HttpHost.DEFAULT_SCHEME_NAME);
		RestClientBuilder builder=RestClient.builder(host);
		CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
		credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("elastic", "123456"));
		builder.setHttpClientConfigCallback(f -> f.setDefaultCredentialsProvider(credentialsProvider));
		RestHighLevelClient client = new RestHighLevelClient( builder);
		// 批量上传请求
		BulkRequest bulkRequest = new BulkRequest(INDEX);
		File file = new File("D:\\001ENV\\nginx-1.24.0\\html\\resource\\new");
		for (File listFile : file.listFiles()) {
//			float[] vector = predictor.predict(ImageFactory.getInstance()
//					.fromInputStream(Test.class.getClassLoader().getResourceAsStream("new/" + listFile.getName())));
			
			float[] vector = predictor.predict(ImageFactory.getInstance()
					.fromInputStream(new FileInputStream(listFile)));
			// 构建文档
			Map<String, Object> jsonMap = new HashMap<>();
			jsonMap.put("url", "/resource/"+listFile.getName());
			jsonMap.put("vector", vector);
			jsonMap.put("user_id", "user123");
			IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
			bulkRequest.add(request);
		}
		client.bulk(bulkRequest, RequestOptions.DEFAULT);
		client.close();
	}

	// 接收待搜索图片的inputstream,搜索与其相似的图片
	public static List<SearchResult> search(InputStream input) throws Throwable {
		float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));
		System.out.println(Arrays.toString(vector));

		// 展示k个结果
		int k = 100;
		// 连接Elasticsearch服务器
		RestHighLevelClient client = new RestHighLevelClient(
				RestClient.builder(new HttpHost("14.20.30.16", 9200, "http")));

		SearchRequest searchRequest = new SearchRequest(INDEX);
		Script script = new Script(ScriptType.INLINE, "painless", "cosineSimilarity(params.queryVector, doc['vector'])",
				Collections.singletonMap("queryVector", vector));

		FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
				.functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));

		SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
		searchSourceBuilder.query(functionScoreQueryBuilder).fetchSource(null, "vector") // 不返回vector字段,太多了没用还耗时
				.size(k);

		searchRequest.source(searchSourceBuilder);

		SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);

		SearchHits hits = searchResponse.getHits();

		List<SearchResult> list = new ArrayList<>();
		for (SearchHit hit : hits) {
			// 处理搜索结果
			System.out.println(hit.toString());
			SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());
			list.add(result);
		}

		client.close();
		return list;
	}

	public static void main(String[] args) throws Throwable {
		ORCUtil.upload();
		System.out.println("hao");
	}
}

SearchController.java

package com.topprismcloud.rtm;

import java.util.List;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

@RestController
@CrossOrigin
public class SearchController {

	@PostMapping("search")
	public ResponseEntity search(MultipartFile file) {
		try {
			List<SearchResult> list = ORCUtil.search(file.getInputStream());
			return ResponseEntity.ok(list);
		} catch (Throwable e) {
			return ResponseEntity.status(400).body(null);
		}
	}
}

SearchResult.java

package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class SearchResult {
    private String url;
    private Float score;
}

5、前端

index.html

<!DOCTYPE html>
<html lang="zh">

<head>
    <meta charset="UTF-8">
    <title>以图搜图</title>
    <style>
        body {
            background: url("/img/bg.jpg");
            background-attachment: fixed;
            background-size: 100% 100%;
        }

        body>div {
            width: 1000px;
            margin: 50px auto;
            padding: 10px 20px;
            border: 1px solid lightgray;
            border-radius: 20px;
            box-sizing: border-box;
            background: rgba(255, 255, 255, 0.7);
        }

        .upload {
            display: inline-block;
            width: 300px;
            height: 280px;
            border: 1px dashed lightcoral;
            vertical-align: top;
        }

        .upload .cover {
            width: 200px;
            height: 200px;
            margin: 10px 50px;
            border: 1px solid black;
            box-sizing: border-box;
            text-align: center;
            line-height: 200px;
            position: relative;
        }

        .upload img {
            width: 198px;
            height: 198px;
            position: absolute;
            left: 0;
            top: 0;
        }

        .upload input {
            margin-left: 50px;
        }

        .upload button {
            width: 80px;
            height: 30px;
            margin-left: 110px;
        }

        .result-block {
            display: inline-block;
            margin-left: 40px;
            border: 1px solid lightgray;
            border-radius: 10px;
            min-height: 500px;
            width: 600px;
        }

        .result-block h1 {
            text-align: center;
            margin-top: 100px;
        }

        .result {
            padding: 10px;
            cursor: pointer;
            display: inline-block;
        }

        .result:hover {
            background: rgb(240, 240, 240);
        }

        .result p {
            width: 110px;
            overflow: hidden;
            white-space: nowrap;
            text-overflow: ellipsis;
        }

        .result img {
            width: 160px;
            height: 160px;
        }

        .result .prob {
            color: rgb(37, 147, 60)
        }
    </style>
    <script src="js/jquery-3.6.0.js"></script>
</head>

<body>
    <div>
        <div class="upload">
            <div class="cover">
                请选择图片
                <img id="image" src="" />
            </div>
            <input id="file" type="file">
        </div>
        <div class="result-block">
            <h1>请选择图片</h1>
        </div>
    </div>
    <ul id="box">

    </ul>
    <script>
        var file = $('#file')
        file.change(function () {
            let f = this.files[0]
            let index = f.name.lastIndexOf('.')
            let fileText = f.name.substring(index, f.name.length)
            let ext = fileText.toLowerCase() //文件类型
            console.log(ext)
            if (ext != '.png' && ext != '.jpg' && ext != '.jpeg') {
                alert('系统仅支持 JPG、PNG、JPEG 格式的图片,请您调整格式后重新上传')
                return
            }
            $('.result-block').empty().append($('<h1>正在识别中...</h1>'))
            $("#image").attr("src", getObjectURL(f));
            let formData = new FormData()
            formData.append('file', f)
            $.ajax({
                url: 'http://10.1.2.240:8081/search',
                method: 'post',
                data: formData,
                processData: false,
                contentType: false,
                success: res => {
                    console.log('shibie', res)
                    $('.result-block').empty()
                    for (let item of res) {
                        console.log(item)
                        let html = `<div class="result">
                                    <img src="${item.url}"/>
                                    <div style="display: inline-block;vertical-align: top">
                                        <p class="prob">得分:${item.score.toFixed(4)}</p>
                                    </div>
                                </div>`
                        $('.result-block').append($(html))
                    }

                }
            })
        });
        $('#button').click(function (e) {
            var file = $('#file')[0].files[0] //单个
            console.log(file)
        })
        function getObjectURL(file) {
            var url = null;
            if (window.createObjcectURL != undefined) {
                url = window.createOjcectURL(file);
            } else if (window.URL != undefined) {
                url = window.URL.createObjectURL(file);
            } else if (window.webkitURL != undefined) {
                url = window.webkitURL.createObjectURL(file);
            }
            return url;
        }
        function detect() {

        }
    </script>
</body>

</html>

6、打包后的源代码

以图搜图Java+html源代码

相关参考文章:Java调用Pytorch模型进行图像识别

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/27414.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI实战营:目标检测与MMDetection

目录 目标检测的基本范式 什么是目标检测 目标检测 vs 图像分类 目标检测 in 人脸识别 目标检测 in 智慧城市 ​编辑​编辑 目标检测 in 自动驾驶 目标检测 in 下游视觉任务 目标检测技术的演进 基础知识 框、边界框&#xff08;Bounding Box&#xff09; 交并比…

计算机网络填空题

我会写下自己的答案和理解 希望自己可用在学习中体会到快乐&#xff0c;而不是麻木。 1. 网络协议三要素中语义是指 需要发出何种控制信息&#xff0c;完成何种动作以及做出何种响应 1.在计算机网络中要做到有条不紊的交换数据&#xff0c;就必须遵守一些事…

组件更新的底层逻辑

第一种更新&#xff1a;组件更新的逻辑&#xff0c;当修改了相关状态&#xff0c;组件会更新 1.触发shouldComponentUpdate 周期函数:是否允许更新 shouldComponentUpdate(nextProps, nextState) { // nextState: 存储要修改的最新状态 // this. state:存储的还是修改前的状态…

闲置APP小程序开发 你不喜欢的可能正是别人需要的

生活中我们常常会产生各种闲置物品&#xff0c;尤其是对于有宝宝的家庭来说&#xff0c;孩子小的时候可能会添置各种玩具、婴儿车或者是别的用品&#xff0c;随着孩子渐渐长大&#xff0c;这些东西都用不上了&#xff0c;但是扔了又觉得很可惜&#xff0c;留着又占地方&#xf…

RocketMQ5.x版本延迟消息被重放问题调查

一、问题 由于目标计划是将集群从4.9.x逐步升级至5.x&#xff0c;故目前先对一些不重要的集群进行升级测试。 但是在4.x的broker陆续升级至5.x的过程中&#xff0c;发现了延迟消息被重放的问题。 具体如下: 在升级时刷新后台监控&#xff0c;发现竟然有写入量&#xff1a; 即…

基于组件化开发思想的微信小程序开发框架

跨端框架的出现为小程序应用的开发带来了巨大的便利性和灵活性。它们提供了统一的开发方式、代码复用的能力&#xff0c;并且与小程序容器技术紧密结合&#xff0c;实现了一次编码、多端运行的目标。开发者可以根据项目需求和团队技术栈选择合适的跨端框架&#xff0c;从而在不…

11 GMM——高斯混合模型

文章目录 11 GMM——高斯混合模型11.1 模型介绍11.2 通过MLE估计参数11.3 EM求解 11 GMM——高斯混合模型 11.1 模型介绍 从几何角度来说&#xff1a; 高斯混合模型表示&#xff1a;加权平均——由多个高斯分布混合叠加而成&#xff0c;如图 公式可以表达为&#xff1a; p…

基于XC7Z100的PCIe采集卡(GMSL FMC采集卡)

GMSL 图像采集卡 特性 ● PCIe Gen2.0 X8 总线&#xff1b; ● 支持V4L2调用&#xff1b; ● 1路CAN接口&#xff1b; ● 6路/12路 GMSL1/2摄像头输入&#xff0c;最高可达8MP&#xff1b; ● 2路可定义相机同步触发输入/输出&#xff1b; 优势 ● 采用PCIe主卡与FMC子…

安卓大作业 书籍列表APP

系列文章 安卓大作业 书籍列表APP 文章目录 系列文章1&#xff0e;背景2&#xff0e;功能3. 源代码获取 1&#xff0e;背景 我做的项目是一个可以查看到书籍列表以及详情效果的内容&#xff0c;主要使用到的技术有Intent数据传递以及数据库存储的应用&#xff0c;其次使用的组…

Qt线程的几种使用方法

目录 引言使用方法重写QThread::run()moveToThreadQRunnable使用QtConcurrent使用 完整代码 引言 多线程不应该是一个复杂而令人生畏的东西&#xff0c;它应该只是程序员的一个工具&#xff0c;不应该是调用者过多记忆相关概念&#xff0c;而应该是被调用方应该尽可能的简化调…

Linux教程——常见Linux发行版本有哪些?

新手往往会被 Linux 众多的发行版本搞得一头雾水&#xff0c;我们首先来解释一下这个问题。 从技术上来说&#xff0c;李纳斯•托瓦兹开发的 Linux 只是一个内核。内核指的是一个提供设备驱动、文件系统、进程管理、网络通信等功能的系统软件&#xff0c;内核并不是一套完整的…

网络安全从业人员2023年后真的会被AI取代吗?

随着ChatGPT的火爆&#xff0c;很多人开始担心网络安全从业人员会被AI取代。如果说网络安全挖洞的话&#xff0c;AI可能真的能取代。但是网络安全不仅仅只是挖洞&#xff0c;所以AI只是能缓解网络安全人员不足的情况&#xff0c;但是是不会取代人类的作用的。 就拿最近很火的C…

【线性代数】

求解线性方程组 右乘向量/矩阵 把左边的矩阵拆成一个个列向量&#xff0c;右边的向量表示对左边列向量组的线性组合。 [ c o l 1 c o l 2 c o l 3 ] [ 3 4 5 ] [ 3 c o l 1 4 c o l 2 5 c o l 3 ] \left[\begin{array}{c} col_{1} & col_{2} & col_{3} \end{array}\…

WPS表格处理

wps表格中公式出来的内容如何转为纯文本 选中公式算出的结果区域&#xff0c;复制&#xff0c;在原区域上右键&#xff0c;选择性粘贴为数值&#xff0c;就转成文本了&#xff0c;当然公式也就消除了。 wps表格如何设置整列公式&#xff1f; 1、先来看看下面这个例子需做出商…

Git、Github、Gitee的区别

⭐作者主页&#xff1a;逐梦苍穹 ⭐所属专栏&#xff1a;Git 目录 1、Git2、Gitee3、GitHub 什么是版本管理&#xff1f;   版本管理是管理各个不同的版本&#xff0c;出了问题可以及时回滚。 1、Git Git是一个分布式版本控制系统&#xff0c;用于跟踪和管理代码的变化。它是…

【Ubuntu系统内核更新与卸载】

【Ubuntu系统内核更新与卸载】 1. 前言2. 内核安装2.1 系统更新2.2 官网下载 3. 内核卸载3.1 需求分析3.2 卸载方法 1. 前言 我们在搭建环境时常常遇到内核版本不匹配的问题&#xff0c;需要我们安装新的内核版本&#xff1b;有时又会遇到在安装软件时报错boot空间已满无法安装…

2021年国赛高教杯数学建模B题乙醇偶合制备C4烯烃解题全过程文档及程序

2021年国赛高教杯数学建模 B题 乙醇偶合制备C4烯烃 原题再现 C4 烯烃广泛应用于化工产品及医药的生产&#xff0c;乙醇是生产制备 C4 烯烃的原料。在制备过程中&#xff0c;催化剂组合&#xff08;即&#xff1a;Co 负载量、Co/SiO2 和 HAP 装料比、乙醇浓度的组合&#xff0…

(六)CSharp-CSharp图解教程版-委托

一、委托概述 1、什么是委托 委托和类一样&#xff0c;是一种用户定义类型&#xff08;即是一种类&#xff0c;所以也是一个引用类型&#xff09;。在它们组成的结构方面区别是&#xff0c;类表示的是数据和方法的集合&#xff0c;而委托则持有一个或多个方法。 可以把 deleg…

HNU-操作系统OS-作业1(4-9章)

这份文件是OS_homework_1 by计科2102 wolf 202108010XXX 文档设置了目录,可以通过目录快速跳转至答案部分。 第四章 4.1用以下标志运行程序:./process-run.py -l 5:100,5:100。CPU 利用率(CPU 使用时间的百分比)应该是多少?为什么你知道这一点?利用 -c 标记查看你…

[230604] 听力TPO66汇总·上篇| C1 L1 C2|10:20~12:00

目录​​​​​​​ Science Fiction And Sci-fi-C1 错题分析 C1-3 细节双选题 C1 精听练习 做题笔记 Financial Advice-C2 全对 C2 精听练习 Sleep-L1 错题分析 L1-4 细节题 L1-5 细节双选题 L1 精听练习 做题笔记 词汇&#xff1a;http://t.csdn.cn/Zhuws 两篇对…