关注公众号,回复“工具调用”获取代码。
背景
博主之前研究的是ChatGLM3模型,该模型提供了Openai方式调用工具的代码。但后续转到Qwen1.5模型后,好像不可以直接用Openai接口的方式调用工具了。
然后,博主研究了Qwen-Agent框架,实现了自定义工具:
https://blog.csdn.net/weixin_44455388/article/details/136524354?spm=1001.2014.3001.5501
研究了LangChain框架,实现了自定义工具:
https://blog.csdn.net/weixin_44455388/article/details/136536875?spm=1001.2014.3001.5501
虽然,使用以上框架实现了自定义工具,但是调用工具时,均需要依赖于python环境和以上框架,觉得还是有一定的限制。再加上,博主之前的基于大模型的所有功能(本地知识库、Text2SQL等)均是使用Java调用OpenAI接口实现,没有使用类似langChain这样的python框架。作为倔强的Java程序员,还是想用Java去实现自定义工具。
思路
首先需要封装OpenAIChat对象,对象应包括以下变量:
/**
* 使用的模型名称.
*/
@Builder.Default
private String model = "gpt-3.5-turbo";
/**
* 模型的API地址.
*/
@Builder.Default
private String endpointUrl = "http://127.0.0.1:8000/";
/**
* 模型的最大token数.
*/
@Builder.Default
private int maxToken = 20000;
/**
* 模型的temperature.
*/
@Builder.Default
private float temperature = 0.9f;
/**
* 模型的TopP
*/
@Builder.Default
private float topP = 0.78f;
/**
* 是否使用历史记录.
*/
private boolean withHistory;
/**
* 历史记录
*/
@Builder.Default
private List<List<?>> history = new ArrayList<>();
OpenAIChat对象中构建工具调用流式问答的方法实现:
/**
* 工具流式问答
* @param sessionId 会话ID
* @param prompt 提示词
* @param baseTools 工具
* @return
*/
@Override
public Flux<String> streamChatWithTools(String sessionId, String prompt, List<BaseTool> baseTools) {
String class2Json = buildClass2Json(new BaseTool());
String finalPrompt = String.format("你是一个AI助手,我会给你一个工具对象集合,工具对象包括name(工具名)、description(工具描述)、params(工具参数)。" +
"你可以结合工具对象,从用户的问句中提取到关键词,确定要实现用户的任务应该喧杂哪个工具对象。" +
"用户的任务是:%S。" +
"工具对象集合是:%s。" +
"您的响应结果必须为JSON格式,并且不要返回任何不必要的解释,只提供遵循此格式的符合RFC8259的JSON响应。以下是输出必须遵守的JSON Schema实例:```%s```",
prompt, JSON.toJSONString(baseTools), class2Json);
String funcParams = chat(finalPrompt);
funcParams = JSON.parseObject(funcParams, OpenAIChatResponse.class).getChoices().get(0).getMessage().getContent();
funcParams = funcParams.substring(funcParams.indexOf("{"), funcParams.lastIndexOf("}")+1);
String toolResult = LoadFunctions.load(JSON.parseObject(funcParams, BaseTool.class));
LOG.info("工具调用结果为:"+toolResult);
String promptFormat = String.format("基于以下内容:{%s}。请回答:{%s}", toolResult, prompt);
Flux<String> stringFlux = streamChat(sessionId, promptFormat);
return stringFlux;
}
创建一个自定义工具类,所有的工具都在这里实现,我这里创建了三个工具(查询天气、返回一个UUID、查询手机号的归属地):
public class FunctionCaller {
/**
* 查询天气API
*
* @param params
* @return
*/
public String getWeather(Map<String,Object> params) {
String cityName = params.get("cityName").toString();
if (cityName == null || cityName.isEmpty()) {
throw new IllegalArgumentException("City name must not be null or empty");
}
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(60, TimeUnit.SECONDS)
.writeTimeout(60, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.build();
try {
Map<String, String> headers = new HashMap<>(16);
headers.put("Content-Type", "application/json");
Request.Builder builder = new Request.Builder()
.url("https://wttr.in/" + cityName + "?format=j1");
builder.headers(Headers.of(headers));
builder.method("GET", null);
Request request = builder.build();
Response response = client.newCall(request).execute();
if (response.isSuccessful()) {
ResponseBody responseBody = response.body();
JSONObject jsonObject = JSONObject.parseObject(responseBody.string());
JSONObject weatherData = new JSONObject();
// Extract the desired weather data from the JSON response
JSONArray currentCondition = jsonObject.getJSONArray("current_condition");
JSONObject firstEntry = currentCondition.getJSONObject(0);
weatherData.put("temp_C", firstEntry.getString("temp_C"));
weatherData.put("FeelsLikeC", firstEntry.getString("FeelsLikeC"));
weatherData.put("humidity", firstEntry.getString("humidity"));
weatherData.put("weatherDesc", firstEntry.getString("weatherDesc"));
weatherData.put("observation_time", firstEntry.getString("observation_time"));
weatherData.put("cityName", params.get("cityName").toString());
return weatherData.toString();
} else {
throw new HttpRuntimeException("Failed.接口访问失败");
}
} catch (IOException e) {
e.printStackTrace();
return "Error encountered while fetching weather data!";
}
}
/**
* 随机获取一个uuid
* @return
*/
public String getUuid(Map<String,Object> params){
return UUID.randomUUID().toString();
}
/**
* 查询手机号的归属地
* @param params
* @return
*/
public String getMobileAddress(Map<String,Object> params) {
String mobileNum = params.get("mobileNum").toString();
try {
OkHttpClient client = new OkHttpClient.Builder()
.connectTimeout(60, TimeUnit.SECONDS)
.writeTimeout(60, TimeUnit.SECONDS)
.readTimeout(60, TimeUnit.SECONDS)
.build();
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
RequestBody body = RequestBody.create(mediaType, "mobile="+mobileNum);
Request request = new Request.Builder()
.url("https://eolink.o.apispace.com/teladress/teladress")
.method("POST",body)
.addHeader("X-APISpace-Token","v1a524e7ctm4h87ilxxxxxxxxxxxxx")
.addHeader("Content-Type","")
.build();
Response response = client.newCall(request).execute();
if (response.isSuccessful()) {
ResponseBody responseBody = response.body();
return responseBody.string();
} else {
throw new HttpRuntimeException("Failed.接口访问失败");
}
} catch (Exception e) {
e.printStackTrace();
return "Error fetching mobile address: " + e.getMessage();
}
}
}
然后,采用反射的方式调用这些工具:
/**
* 通过反射机制调用函数
* @param methodName
* @param jsonNode
* @return
*/
public static Object reflect(String methodName, Map<String,String> jsonNode){
FunctionCaller functionCaller = new FunctionCaller();
Method method = ReflectUtil.getMethod(FunctionCaller.class, methodName, new Class[]{Map.class});
try {
Object invoke = method.invoke(functionCaller, jsonNode);
return invoke;
} catch (IllegalAccessException e) {
LOG.error("FunctionReflect reflect occur illegal access exception,method name = {},jsonNode = {}",methodName,jsonNode,e);
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
LOG.error("FunctionReflect reflect occur invocation target exception,method name = {},jsonNode = {}",methodName,jsonNode,e);
throw new RuntimeException(e);
}
}
调用的时候,将所有的工具集合作为参数,传入OpenAIChat的streamChatWithTools方法:
public static void test2(){
BaseTool baseTool = new BaseTool();
baseTool.setName("getWeather");
Map<String,String> map = new HashMap<>(16);
map.put("cityName","城市");
baseTool.setParams(map);
baseTool.setDescription("查询天气工具");
BaseTool baseTool1 = new BaseTool();
baseTool1.setName("getUuid");
baseTool1.setDescription("获取UUID");
baseTool1.setParams(null);
BaseTool baseTool2 = new BaseTool();
baseTool2.setName("getMobileAddress");
baseTool2.setDescription("查询手机号归属地");
Map<String,String> map2 = new HashMap<>(16);
map2.put("mobileNum","手机号");
baseTool2.setParams(map2);
List<BaseTool> baseTools = Arrays.asList(baseTool,baseTool1,baseTool2);
OpenAIChat openAIChat = OpenAIChat.builder()
.endpointUrl("http://127.0.0.1:11434/v1/chat/completions")
.model("qwen:7b-chat-v1.5-q5_K_M")
.build().init();
Flux<String> stringFlux = openAIChat.streamChatWithTools("112233","查询北京的天气", baseTools);
stringFlux.subscribe();
//System.out.println(s);
}
然后就可以实现工具调用了: