前言
我在前一段时间突发奇想,就使用java来调用chatgpt的接口,然后写了一个简单小程序,也上了热榜第一,,事实上,这个程序毛病挺多的,最不能让人接受的一点就是返回速度非常缓慢(即使使用非常好的外网服务器)。
现在,我改进了一下程序,使用异步请求的方式,基本可以实现秒回复。并且还基于websocket编写了一个微信小程序来进行交互,可以直接使用微信小程序来进行体验。
现在我将所有代码都上传了github(链接在文章结尾),大家可以clone下来,部署到服务器上,真正实现自己的聊天机器人!!!
效果展示
部分截图如下
原理说明
在 我说明了java调用chatgpt的基本原理,这里的代码就是对这个代码的改进,使用异步请求的方式来进行。
注意看官方文档,我们在请求时可以提供一个参数stream,然后就可以实现按照流的形式进行返回,这种方式基本可以做到没有延迟就给出答案。
由于这次改进的思路主要就是将请求改为了异步,其他的基本一样,所以就不做解释,直接给出代码了,代码上面都有注释
/**
* 这个方法用于测试的,可以在控制台打印输出结果
*
* @param chatgptrequestparameter 请求的参数
* @param question 问题
*/
public void printanswer(chatrequestparameter chatgptrequestparameter, string question) {
asyncclient.start();
// 创建一个post请求
asyncrequestbuilder asyncrequest = asyncrequestbuilder.post(url);
// 设置请求参数
chatgptrequestparameter.addmessages(new chatmessage("user", question));
// 请求的参数转换为字符串
string valueasstring = null;
try {
valueasstring = objectmapper.writevalueasstring(chatgptrequestparameter);
} catch (jsonprocessingexception e) {
e.printstacktrace();
}
// 设置编码和请求参数
contenttype contenttype = contenttype.create("text/plain", charset);
asyncrequest.setentity(valueasstring, contenttype);
asyncrequest.setcharset(charset);
// 设置请求头
asyncrequest.setheader(httpheaders.content_type, "application/json");
// 设置登录凭证
asyncrequest.setheader(httpheaders.authorization, "bearer " + apikey);
// 下面就是生产者消费者模型
countdownlatch latch = new countdownlatch(1);
// 用于记录返回的答案
stringbuilder sb = new stringbuilder();
// 消费者
abstractcharresponseconsumer<httpresponse> consumer = new abstractcharresponseconsumer<httpresponse>() {
httpresponse response;
@override
protected void start(httpresponse response, contenttype contenttype) throws httpexception, ioexception {
setcharset(charset);
this.response = response;
}
@override
protected int capacityincrement() {
return integer.max_value;
}
@override
protected void data(charbuffer src, boolean endofstream) throws ioexception {
// 收到一个请求就进行处理
string ss = src.tostring();
// 通过data:进行分割,如果不进行此步,可能返回的答案会少一些内容
for (string s : ss.split("data:")) {
// 去除掉data:
if (s.startswith("data:")) {
s = s.substring(5);
}
// 返回的数据可能是(done)
if (s.length() > 8) {
// 转换为对象
chatresponseparameter responseparameter = objectmapper.readvalue(s, chatresponseparameter.class);
// 处理结果
for (choice choice : responseparameter.getchoices()) {
string content = choice.getdelta().getcontent();
if (content != null && !"".equals(content)) {
// 保存结果
sb.append(content);
// 将结果使用websocket传送过去
system.out.print(content);
}
}
}
}
}
@override
protected httpresponse buildresult() throws ioexception {
return response;
}
@override
public void releaseresources() {
}
};
// 执行请求
asyncclient.execute(asyncrequest.build(), consumer, new futurecallback<httpresponse>() {
@override
public void completed(httpresponse response) {
latch.countdown();
chatgptrequestparameter.addmessages(new chatmessage("assistant", sb.tostring()));
system.out.println("回答结束!!!");
}
@override
public void failed(exception ex) {
latch.countdown();
system.out.println("failed");
ex.printstacktrace();
}
@override
public void cancelled() {
latch.countdown();
system.out.println("cancelled");
}
});
try {
latch.await();
} catch (interruptedexception e) {
e.printstacktrace();
}
}
服务器端代码说明
我使用java搭建了一个简单的服务器端程序,提供最基础的用户登录校验功能,以及提供了websocket通信。
用户校验的代码
package com.ttpfx.controller;
import com.ttpfx.entity.user;
import com.ttpfx.service.userservice;
import com.ttpfx.utils.r;
import org.springframework.web.bind.annotation.requestmapping;
import org.springframework.web.bind.annotation.restcontroller;
import javax.annotation.resource;
import java.util.objects;
import java.util.concurrent.concurrenthashmap;
/**
* @author ttpfx
* @date 2023/3/29
*/
@restcontroller
@requestmapping("/user")
public class usercontroller {
@resource
private userservice userservice;
public static concurrenthashmap<string, user> loginuser = new concurrenthashmap<>();
public static concurrenthashmap<string, long> loginuserkey = new concurrenthashmap<>();
@requestmapping("/login")
public r login(string username, string password) {
if (username == null) return r.fail("必须填写用户名");
user user = userservice.querybyname(username);
if (user == null) return r.fail("用户名不存在");
string targetpassword = user.getpassword();
if (targetpassword == null) return r.fail("用户密码异常");
if (!targetpassword.equals(password)) return r.fail("密码错误");
loginuser.put(username, user);
loginuserkey.put(username, system.currenttimemillis());
return r.ok(string.valueof(loginuserkey.get(username)));
}
@requestmapping("/logout")
public r logout(string username) {
loginuser.remove(username);
loginuserkey.remove(username);
return r.ok();
}
@requestmapping("/checkuserkey")
public r checkuserkey(string username, long key){
if (username==null || key == null)return r.fail("用户校验异常");
if (!objects.equals(loginuserkey.get(username), key)){
return r.fail("用户在其他地方登录!!!");
}
return r.ok();
}
@requestmapping("/loginuser")
public r loginuser(){
return r.ok("success",loginuser.keyset());
}
}
基于websocket通信的代码
package com.ttpfx.server;
import com.fasterxml.jackson.databind.objectmapper;
import com.ttpfx.entity.userlog;
import com.ttpfx.model.chatmodel;
import com.ttpfx.service.userlogservice;
import com.ttpfx.service.userservice;
import com.ttpfx.vo.chat.chatrequestparameter;
import org.springframework.stereotype.component;
import javax.annotation.resource;
import javax.websocket.*;
import javax.websocket.server.pathparam;
import javax.websocket.server.serverendpoint;
import java.io.ioexception;
import java.time.localdatetime;
import java.util.concurrent.concurrenthashmap;
/**
* @author ttpfx
* @date 2023/3/28
*/
@component
@serverendpoint("/chatwebsocket/{username}")
public class chatwebsocketserver {
/**
* 静态变量,用来记录当前在线连接数。应该把它设计成线程安全的。
*/
private static int onlinecount = 0;
/**
* concurrent包的线程安全map,用来存放每个客户端对应的mywebsocket对象。
*/
private static concurrenthashmap<string, chatwebsocketserver> chatwebsocketmap = new concurrenthashmap<>();
/**
* 与某个客户端的连接会话,需要通过它来给客户端发送数据
*/
private session session;
/**
* 接收的username
*/
private string username = "";
private userlog userlog;
private static userservice userservice;
private static userlogservice userlogservice;
@resource
public void setuserservice(userservice userservice) {
chatwebsocketserver.userservice = userservice;
}
@resource
public void setuserlogservice(userlogservice userlogservice) {
chatwebsocketserver.userlogservice = userlogservice;
}
private objectmapper objectmapper = new objectmapper();
private static chatmodel chatmodel;
@resource
public void setchatmodel(chatmodel chatmodel) {
chatwebsocketserver.chatmodel = chatmodel;
}
chatrequestparameter chatrequestparameter = new chatrequestparameter();
/**
* 建立连接
* @param session 会话
* @param username 连接用户名称
*/
@onopen
public void onopen(session session, @pathparam("username") string username) {
this.session = session;
this.username = username;
this.userlog = new userlog();
// 这里的用户id不可能为null,出现null,那么就是非法请求
try {
this.userlog.setuserid(userservice.querybyname(username).getid());
} catch (exception e) {
e.printstacktrace();
try {
session.close();
} catch (ioexception ex) {
ex.printstacktrace();
}
}
this.userlog.setusername(username);
chatwebsocketmap.put(username, this);
onlinecount++;
system.out.println(username + "--open");
}
@onclose
public void onclose() {
chatwebsocketmap.remove(username);
system.out.println(username + "--close");
}
@onmessage
public void onmessage(string message, session session) {
system.out.println(username + "--" + message);
// 记录日志
this.userlog.setdatetime(localdatetime.now());
this.userlog.setprelogid(this.userlog.getlogid() == null ? -1 : this.userlog.getlogid());
this.userlog.setlogid(null);
this.userlog.setquestion(message);
long start = system.currenttimemillis();
// 这里就会返回结果
string answer = chatmodel.getanswer(session, chatrequestparameter, message);
long end = system.currenttimemillis();
this.userlog.setconsumetime(end - start);
this.userlog.setanswer(answer);
userlogservice.save(userlog);
}
@onerror
public void onerror(session session, throwable error) {
error.printstacktrace();
}
public void sendmessage(string message) throws ioexception {
this.session.getbasicremote().sendtext(message);
}
public static void sendinfo(string message, string touserid) throws ioexception {
chatwebsocketmap.get(touserid).sendmessage(message);
}
}
微信小程序代码说明
我写了一个简单微信小程序来和后端进行通信,界面如下
大家只需要下载源代码,然将程序中的ip改为自己服务器的ip即可
代码链接
github的地址为 https://github.com/c-ttpfx/chatgpt-java-wx
可以直接使用 git clone https://github.com/c-ttpfx/chatgpt-java-wx.git 下载代码到本地
我在github里面说明了安装使用的基本步骤,大家按照步骤使用即可
总结
上面聊天小程序就是我花2天写出来的,可能会有一些bug,我自己测试的时候倒是没有怎么遇到bug,聊天和登录功能都能正常使用。
对于微信小程序,由于我不是专业搞前端的,就只东拼西凑实现了最基本的功能(登录、聊天),大家可以自己写一个,反正后端接口都提供好了嘛,也不是很难,不想写也可以将就使用我的。
最后,也是最重要的,大家帮我的代码star一下!!! 感谢大家了(≥▽≤)/(≥▽≤)/
更新日志
2023/5/13 14:42更新
对代码进行了重构,最新的代码已经支持代理,通过在application.yaml里面进行简单配置即可使用
gpt:
proxy:
host: 127.0.0.1
port: 7890
发表评论