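1. Project Overview
This tutorial walks through building an AI assistant with conversation memory using Spring Boot 3 and Spring AI. Redis is the storage backend, with per-user session isolation and 30-day retention of chat history.
Tech stack
- Spring Boot 3.3.0
- Java 17
- Spring AI 1.0.0
- Redis 6.0+
- MyBatis-Plus
- MySQL 8.0
- Sa-Token (user authentication)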
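2. Environment Setup
2.1 Install the required software
- JDK 17+: Oracle JDK or OpenJDK
- Maven 3.9+: from the Maven website
- Redis 6.0+: from the Redis website, or via Docker
- MySQL 8.0+: from the MySQL website, or via Docker
- IDE: IntelliJ IDEA or Eclipse
2.2 Configure environment variables
Make sure JAVA_HOME and MAVEN_HOME are set correctly.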
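3. Project Initialization
3.1 Create the Spring Boot project
Create the project with Spring Initializr:
- Open Spring Initializr (https://start.spring.io)
- Select Spring Boot 3.3.0
- Select Java 17
- Add the dependencies Spring Web, Spring Data Redis, MySQL Driver and Spring Boot DevTools. MyBatis-Plus is not offered by Initializr and is added to pom.xml by hand in 3.2.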
3.2 Configure Maven dependencies
Add the following dependencies to pom.xml:
```xml
<dependencies>
    <!-- Spring Boot web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- Spring Data Redis -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>
    <!-- MyBatis-Plus: Spring Boot 3 requires the boot3 starter -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
        <version>3.5.5</version>
    </dependency>
    <!-- MySQL driver -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>
    <!-- Spring AI OpenAI starter: auto-configures OpenAiChatModel from spring.ai.openai.* -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
        <version>1.0.0</version>
    </dependency>
    <!-- Sa-Token authentication (Spring Boot 3 starter) -->
    <dependency>
        <groupId>cn.dev33</groupId>
        <artifactId>sa-token-spring-boot3-starter</artifactId>
        <version>1.38.1</version>
    </dependency>
    <!-- Jackson (already pulled in by spring-boot-starter-web; listed for clarity) -->
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
    </dependency>
    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
    <!-- Spring Boot test -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
```
4. Core Configuration
4.1 Application configuration (application.yml)
Create src/main/resources/application.yml with the application settings:
```yaml
server:
  port: 9527
  servlet:
    context-path: /api
spring:
  application:
    name: smart-pic-community-backend
  # Redis
  data:
    redis:
      database: 2
      host: localhost
      port: 6379
      timeout: 5000
  # Database
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://localhost:3306/smart_pic_community
    username: root
    password: your_password
  # Spring AI
  ai:
    openai:
      base-url: https://api.deepseek.com   # DeepSeek exposes an OpenAI-compatible API
      api-key: your_api_key
      chat:
        options:
          model: deepseek-chat
```
4.2 Redis configuration (RedisConfiguration.java)
Create a Redis configuration class so that objects are serialized correctly:
```java
package com.spc.smartpiccommunitybackend.config;

import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;

@Configuration
@Slf4j
public class RedisConfiguration {

    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        log.info("Creating RedisTemplate...");
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        // Connection factory auto-configured from spring.data.redis.*
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        // Keys and hash keys are stored as plain strings
        StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
        redisTemplate.setKeySerializer(stringRedisSerializer);
        redisTemplate.setHashKeySerializer(stringRedisSerializer);

        // Values are stored as JSON; default typing writes the class name into the JSON
        // so values can be read back as their original type
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        objectMapper.activateDefaultTyping(LaissezFaireSubTypeValidator.instance,
                ObjectMapper.DefaultTyping.NON_FINAL);
        Jackson2JsonRedisSerializer<Object> jsonSerializer =
                new Jackson2JsonRedisSerializer<>(objectMapper, Object.class);
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
}
```
Notes:
- Using Jackson2JsonRedisSerializer rather than StringRedisSerializer for values avoids a ClassCastException. StringRedisSerializer only accepts String, but a RedisTemplate<String, Object> also writes Long, Map and custom objects.
- Enabling default typing embeds the class name in the stored JSON. Deserialization then restores the original type instead of a generic LinkedHashMap.
- activateDefaultTyping replaces the deprecated enableDefaultTyping. The ObjectMapper is passed to the Jackson2JsonRedisSerializer constructor instead of the deprecated setObjectMapper.
- LaissezFaireSubTypeValidator accepts any class named in the data. That is acceptable for a Redis instance only this application writes to, and it can be tightened with a BasicPolymorphicTypeValidator.
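The ClassCastException can be reproduced without Redis. A RedisTemplate<String, Object> hands every value to its serializer as a plain Object. A serializer declared for String casts that value in a compiler-generated bridge method, so the first Long or Map value throws. The following is a minimal JDK-only sketch of that mechanism. ValueSerializer, StringOnlySerializer and ObjectValueTemplate are hypothetical stand-ins for RedisSerializer, StringRedisSerializer and RedisTemplate, not the Spring classes themselves:

```java
import java.nio.charset.StandardCharsets;

// Hypothetical stand-in for Spring Data Redis' RedisSerializer<T>
interface ValueSerializer<T> {
    byte[] serialize(T value);
}

// Hypothetical stand-in for StringRedisSerializer: it only knows how to write Strings
final class StringOnlySerializer implements ValueSerializer<String> {
    @Override
    public byte[] serialize(String value) {
        return value.getBytes(StandardCharsets.UTF_8);
    }
}

// Hypothetical stand-in for RedisTemplate<String, Object>: every value arrives as Object
final class ObjectValueTemplate {
    private final ValueSerializer<?> serializer;

    ObjectValueTemplate(ValueSerializer<?> serializer) {
        this.serializer = serializer;
    }

    // The value goes through a raw serializer reference; for StringOnlySerializer the
    // bridge method serialize(Object) casts to String, which fails for non-String values
    @SuppressWarnings({"unchecked", "rawtypes"})
    byte[] write(Object value) {
        return ((ValueSerializer) serializer).serialize(value);
    }
}
```

Here write("hello") succeeds and write(42L) throws ClassCastException. A JSON serializer typed for Object accepts both, which is why the configuration uses Jackson2JsonRedisSerializer<Object> for values.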
5. Core Feature Implementation
5.1 Create the model classes
5.1.1 Create MessageVO.java
Create a message view object for the frontend:
```java
package com.spc.smartpiccommunitybackend.model.vo.ai;

import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.messages.Message;

@NoArgsConstructor
@Data
public class MessageVO {
    private String role;
    private String content;

    public MessageVO(Message message) {
        // Map the Spring AI message type to a role string the frontend understands
        this.role = switch (message.getMessageType()) {
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case SYSTEM -> "system";
            default -> "";
        };
        this.content = message.getText();
    }
}
```
What it does:
- Converts a Spring AI Message into a format the frontend can render
- Sets the role according to the message type
- Extracts the message text for display
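A history endpoint usually applies this conversion to a whole conversation. The following is a JDK-only sketch of that list mapping. MsgType, ChatMessage and MessageView are hypothetical stand-ins for Spring AI's MessageType, Message and the MessageVO class above. The real enum also has a TOOL constant, which falls into the default branch:

```java
import java.util.List;

// Hypothetical stand-ins for Spring AI's MessageType and Message
enum MsgType { USER, ASSISTANT, SYSTEM, TOOL }

record ChatMessage(MsgType type, String text) {}

// Hypothetical stand-in for MessageVO: a role string plus display text
record MessageView(String role, String content) {

    static MessageView of(ChatMessage message) {
        // Same mapping as the MessageVO constructor; other types become ""
        String role = switch (message.type()) {
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case SYSTEM -> "system";
            default -> "";
        };
        return new MessageView(role, message.text());
    }

    // Convert a whole conversation for the frontend, keeping the original order
    static List<MessageView> ofAll(List<ChatMessage> messages) {
        return messages.stream().map(MessageView::of).toList();
    }
}
```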
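5.1.2 Create SerializableMessage.java
Create a plain message object for storage in Redis:
```java
package com.spc.smartpiccommunitybackend.model.entity.ai;

import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@NoArgsConstructor
public class SerializableMessage implements Serializable {
    private static final long serialVersionUID = 1L;

    private String role;
    private String content;
    private String messageType;
    private long timestamp;

    public SerializableMessage(String role, String content, String messageType) {
        this.role = role;
        this.content = content;
        this.messageType = messageType;
        this.timestamp = System.currentTimeMillis();
    }

    public SerializableMessage(String role, String content) {
        // The message type defaults to the role (an assistant message is not typed as "user")
        this(role, content, role);
    }
}
```
What it does:
- Holds the role, content, message type and timestamp as plain fields. The JSON serializer from 4.2 can write it and rebuild it through the Lombok-generated no-arg constructor and setters.
- Spring AI's own message classes (UserMessage, for instance) have no no-arg constructor, so they are hard to rebuild with that serializer. Converting to this class before writing to Redis avoids the problem.
- Implementing Serializable only matters if you switch to JDK serialization. With the JSON serializer it is harmless.
- Provides several constructors for convenience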
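The no-arg constructor and setters are what make a plain class like this easy to restore from stored data. The following sketch imitates, in simplified form, how a bean-oriented JSON reader such as Jackson rebuilds an object: it calls the no-arg constructor, then one setter per field. It also shows why a class with only an argument constructor cannot be restored that way. The sketch is JDK-only, and StoredMessage, ImmutableMessage and BeanRebuilder are hypothetical. Real Jackson can also use annotated creators or, with the field visibility set in 4.2, write fields directly:

```java
import java.lang.reflect.Method;
import java.util.Map;

// Hypothetical bean-style class, roughly what Lombok generates for SerializableMessage
class StoredMessage {
    private String role;
    private String content;
    private long timestamp;

    public StoredMessage() {}

    public String getRole() { return role; }
    public void setRole(String role) { this.role = role; }
    public String getContent() { return content; }
    public void setContent(String content) { this.content = content; }
    public long getTimestamp() { return timestamp; }
    public void setTimestamp(long timestamp) { this.timestamp = timestamp; }
}

// Hypothetical class with only an argument constructor, like UserMessage
class ImmutableMessage {
    private final String text;
    public ImmutableMessage(String text) { this.text = text; }
    public String getText() { return text; }
}

final class BeanRebuilder {
    // Simplified bean-style reading: no-arg constructor first, then one setter per field
    static <T> T rebuild(Class<T> type, Map<String, Object> fields) {
        try {
            T instance = type.getDeclaredConstructor().newInstance();
            for (Map.Entry<String, Object> field : fields.entrySet()) {
                String name = field.getKey();
                String setter = "set" + Character.toUpperCase(name.charAt(0)) + name.substring(1);
                for (Method method : type.getMethods()) {
                    if (method.getName().equals(setter) && method.getParameterCount() == 1) {
                        method.invoke(instance, field.getValue()); // Long unboxes to long here
                        break;
                    }
                }
            }
            return instance;
        } catch (ReflectiveOperationException e) {
            // e.g. NoSuchMethodException when there is no no-arg constructor
            throw new IllegalArgumentException("cannot rebuild " + type.getSimpleName(), e);
        }
    }
}
```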
5.2 Define the repository interface (ChatHistoryRepository.java)
Create the chat history repository interface:
```java
package com.spc.smartpiccommunitybackend.repository;

import java.util.List;
import java.util.Map;

public interface ChatHistoryRepository {

    /**
     * Save a session record
     */
    void save(String type, String chatId, Long userId);

    /**
     * Get a user's session IDs
     */
    List<String> getChatIds(Long userId, String type);

    /**
     * Save a chat message
     */
    void saveMessage(String chatId, String message, String sender);

    /**
     * Get the message history of a session
     */
    List<String> getMessages(String chatId);

    /**
     * Delete a session
     */
    void deleteChat(Long userId, String type, String chatId);

    /**
     * Get session metadata
     */
    Map<Object, Object> getSessionInfo(String chatId);
}
```
5.3 Implement the Redis chat history repository (RedisChatHistoryRepository.java)
Implement the chat history repository on top of Redis:
```java
package com.spc.smartpiccommunitybackend.repository;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

@Slf4j
@Component
@RequiredArgsConstructor
public class RedisChatHistoryRepository implements ChatHistoryRepository {

    // Redis key prefixes
    private static final String CHAT_HISTORY_PREFIX = "chat:history:";
    private static final String CHAT_SESSION_PREFIX = "chat:session:";
    private static final String CHAT_MESSAGES_PREFIX = "chat:messages:";
    private static final long EXPIRATION_DAYS = 30;
    // ObjectMapper is thread-safe, so one shared instance is enough
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final RedisTemplate<String, Object> redisTemplate;

    /**
     * Save a session record
     */
    @Override
    public void save(String type, String chatId, Long userId) {
        String sessionKey = CHAT_SESSION_PREFIX + chatId;
        long now = System.currentTimeMillis();
        // save() runs on every chat request, so only the first call sets createTime
        redisTemplate.opsForHash().putIfAbsent(sessionKey, "createTime", now);
        Map<String, Object> sessionInfo = new HashMap<>();
        sessionInfo.put("userId", String.valueOf(userId));
        sessionInfo.put("type", type);
        sessionInfo.put("lastUpdateTime", now);
        redisTemplate.opsForHash().putAll(sessionKey, sessionInfo);
        redisTemplate.expire(sessionKey, EXPIRATION_DAYS, TimeUnit.DAYS);

        // Add the chatId to the user's history set
        String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type;
        redisTemplate.opsForSet().add(historyKey, chatId);
        redisTemplate.expire(historyKey, EXPIRATION_DAYS, TimeUnit.DAYS);
    }

    /**
     * Get a user's session IDs
     */
    @Override
    public List<String> getChatIds(Long userId, String type) {
        String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type;
        Set<Object> chatIds = redisTemplate.opsForSet().members(historyKey);
        if (chatIds == null || chatIds.isEmpty()) {
            return Collections.emptyList();
        }
        return chatIds.stream()
                .map(Object::toString)
                .collect(Collectors.toList());
    }

    /**
     * Save a chat message
     */
    @Override
    public void saveMessage(String chatId, String message, String sender) {
        String messagesKey = CHAT_MESSAGES_PREFIX + chatId;
        Map<String, Object> messageInfo = new HashMap<>();
        messageInfo.put("content", message);
        messageInfo.put("sender", sender);
        messageInfo.put("timestamp", System.currentTimeMillis());
        try {
            // Store each message as a JSON string
            redisTemplate.opsForList().rightPush(messagesKey, OBJECT_MAPPER.writeValueAsString(messageInfo));
        } catch (JsonProcessingException e) {
            // Fall back to the raw text rather than losing the message
            log.warn("Failed to serialize chat message for {}", chatId, e);
            redisTemplate.opsForList().rightPush(messagesKey, message);
        }
        redisTemplate.expire(messagesKey, EXPIRATION_DAYS, TimeUnit.DAYS);

        // Refresh the session's last update time and its TTL
        String sessionKey = CHAT_SESSION_PREFIX + chatId;
        redisTemplate.opsForHash().put(sessionKey, "lastUpdateTime", System.currentTimeMillis());
        redisTemplate.expire(sessionKey, EXPIRATION_DAYS, TimeUnit.DAYS);
    }

    /**
     * Get the message history of a session
     */
    @Override
    public List<String> getMessages(String chatId) {
        String messagesKey = CHAT_MESSAGES_PREFIX + chatId;
        List<Object> messages = redisTemplate.opsForList().range(messagesKey, 0, -1);
        if (messages == null || messages.isEmpty()) {
            return Collections.emptyList();
        }
        return messages.stream()
                .map(Object::toString)
                .collect(Collectors.toList());
    }

    /**
     * Delete a session
     */
    @Override
    public void deleteChat(Long userId, String type, String chatId) {
        // Remove the chatId from this user's own history set first
        String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type;
        Long removed = redisTemplate.opsForSet().remove(historyKey, chatId);
        // Not in this user's set: leave the session data untouched
        if (removed == null || removed == 0) {
            return;
        }
        redisTemplate.delete(CHAT_SESSION_PREFIX + chatId);
        redisTemplate.delete(CHAT_MESSAGES_PREFIX + chatId);
    }

    /**
     * Get session metadata
     */
    @Override
    public Map<Object, Object> getSessionInfo(String chatId) {
        return redisTemplate.opsForHash().entries(CHAT_SESSION_PREFIX + chatId);
    }
}
```
Notes:
- Separate key prefixes distinguish the data types: user history sets (chat:history:{userId}:{type}), session hashes (chat:session:{chatId}) and message lists (chat:messages:{chatId}).
- Every key gets a 30-day TTL that is refreshed on each write. Inactive sessions therefore expire instead of accumulating in memory.
- If JSON serialization fails, the raw message text is stored rather than dropping the message.
- deleteChat only removes session data when the chatId was actually in the caller's history set. One user therefore cannot delete another user's chat.
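The key layout and the sliding 30-day TTL can be checked without a Redis server. The following JDK-only sketch builds the same key formats as the prefixes above. TtlStore is a hypothetical in-memory stand-in that mimics what a write followed by EXPIRE does. Each touch pushes the deadline 30 days past the write time, and the clock is passed in explicitly so the behaviour is testable:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;

// Key formats matching the repository's prefixes
final class ChatKeys {
    static final long TTL_MILLIS = TimeUnit.DAYS.toMillis(30);

    static String history(long userId, String type) { return "chat:history:" + userId + ":" + type; }
    static String session(String chatId) { return "chat:session:" + chatId; }
    static String messages(String chatId) { return "chat:messages:" + chatId; }
}

// Hypothetical stand-in for Redis expiry: each write (plus EXPIRE) resets the deadline
final class TtlStore {
    private final Map<String, Long> expiresAt = new HashMap<>();

    void touch(String key, long nowMillis) {
        expiresAt.put(key, nowMillis + ChatKeys.TTL_MILLIS);
    }

    boolean exists(String key, long nowMillis) {
        Long deadline = expiresAt.get(key);
        return deadline != null && nowMillis < deadline;
    }
}
```

A session written on day 0 and again on day 20 is therefore still present on day 45 and gone on day 50.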
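5.4 Implement chat memory (RedisChatMemory.java)
Implement a Redis-backed chat memory for Spring AI's ChatMemory interface:
```java
package com.spc.smartpiccommunitybackend.config;

import com.spc.smartpiccommunitybackend.model.entity.ai.SerializableMessage;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

@Component
public class RedisChatMemory implements ChatMemory {

    private static final String MEMORY_KEY_PREFIX = "chat:memory:";
    private static final long EXPIRATION_DAYS = 30;
    // Assumed window size: only the most recent messages are sent back to the model
    private static final int MAX_MESSAGES = 20;

    private final RedisTemplate<String, Object> redisTemplate;

    public RedisChatMemory(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public void add(String conversationId, List<Message> messages) {
        String key = MEMORY_KEY_PREFIX + conversationId;
        // Store plain SerializableMessage objects that the JSON serializer can rebuild
        for (Message message : messages) {
            redisTemplate.opsForList().rightPush(key, toStored(message));
        }
        redisTemplate.expire(key, EXPIRATION_DAYS, TimeUnit.DAYS);
    }

    @Override
    public List<Message> get(String conversationId) {
        String key = MEMORY_KEY_PREFIX + conversationId;
        // Negative indexes count from the tail: the last MAX_MESSAGES entries
        List<Object> objects = redisTemplate.opsForList().range(key, -MAX_MESSAGES, -1);
        List<Message> messages = new ArrayList<>();
        if (objects != null) {
            for (Object obj : objects) {
                if (obj instanceof SerializableMessage stored) {
                    Message message = toMessage(stored);
                    if (message != null) {
                        messages.add(message);
                    }
                }
            }
        }
        return messages;
    }

    @Override
    public void clear(String conversationId) {
        // Clears the memory of one conversation only
        redisTemplate.delete(MEMORY_KEY_PREFIX + conversationId);
    }

    private static SerializableMessage toStored(Message message) {
        String role = switch (message.getMessageType()) {
            case USER -> "user";
            case ASSISTANT -> "assistant";
            case SYSTEM -> "system";
            default -> "tool";
        };
        return new SerializableMessage(role, message.getText(), role);
    }

    private static Message toMessage(SerializableMessage stored) {
        if (stored.getRole() == null) {
            return null;
        }
        String content = stored.getContent() == null ? "" : stored.getContent();
        return switch (stored.getRole()) {
            case "user" -> new UserMessage(content);
            case "assistant" -> new AssistantMessage(content);
            case "system" -> new SystemMessage(content);
            default -> null; // tool messages are not restored
        };
    }
}
```
Notes:
- The class implements Spring AI's ChatMemory interface so that MessageChatMemoryAdvisor can use it. In Spring AI 1.0.0 that interface is add, get and clear, each keyed by a conversation ID.
- Messages are converted to SerializableMessage before being written, because Spring AI's message classes are not designed to be rebuilt by the JSON serializer from 4.2. Entries that cannot be converted back are skipped.
- The key's 30-day TTL is refreshed on every write. get returns only the last MAX_MESSAGES entries, so the prompt does not grow without bound.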
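Trimming memory to the most recent N messages can be pushed down to Redis. LRANGE accepts negative indexes that count from the tail, so LRANGE key -N -1 returns the last N entries, or the whole list if it is shorter. The following JDK-only sketch reproduces LRANGE's index rules for checking that window arithmetic. ListRange is a hypothetical helper, not a Redis client call:

```java
import java.util.List;

// Hypothetical helper mirroring Redis LRANGE index rules on an in-memory list
final class ListRange {
    // Negative indexes count from the end, out-of-range indexes are clamped,
    // stop is inclusive, and an empty range yields an empty list
    static <T> List<T> lrange(List<T> list, long start, long stop) {
        int size = list.size();
        if (start < 0) start = size + start;
        if (stop < 0) stop = size + stop;
        if (start < 0) start = 0;
        if (stop >= size) stop = size - 1;
        if (start > stop || start >= size) return List.of();
        return list.subList((int) start, (int) stop + 1);
    }

    // The last maxCount entries, as LRANGE key -maxCount -1 would return them
    static <T> List<T> lastN(List<T> list, int maxCount) {
        return lrange(list, -maxCount, -1);
    }
}
```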
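5.5 Configure the ChatClient (CommonConfiguration.java)
Configure Spring AI's ChatClient:
```java
package com.spc.smartpiccommunitybackend.config;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class CommonConfiguration {

    @Bean
    public ChatClient chatClient(OpenAiChatModel openAiChatModel, ChatMemory chatMemory) {
        return ChatClient.builder(openAiChatModel)
                .defaultAdvisors(
                        // Logs requests and responses for debugging
                        new SimpleLoggerAdvisor(),
                        // Loads history from ChatMemory before each call and stores the new turn after it
                        MessageChatMemoryAdvisor.builder(chatMemory).build()
                )
                .build();
    }
}
```
Notes:
- MessageChatMemoryAdvisor enables chat memory. In Spring AI 1.0.0 it is created with MessageChatMemoryAdvisor.builder(...). Because RedisChatMemory is a ChatMemory bean, it is the memory injected here.
- The advisor uses a default conversation ID unless a request passes its own through ChatMemory.CONVERSATION_ID. Pass a per-chat ID on each request to keep conversations apart.
- SimpleLoggerAdvisor records chat interactions for debugging. It logs at DEBUG level, so enable it, for example with logging.level.org.springframework.ai.chat.client.advisor=DEBUG.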
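The work a chat-memory advisor does around each call can be modelled in plain Java. It loads the conversation's stored messages, puts them in front of the new user message, calls the model, and then appends both the user message and the reply to memory. The following sketch illustrates that pattern only; it is not Spring AI's actual MessageChatMemoryAdvisor code. MemoryAdvisorSketch, Turn and the model function are hypothetical. The sketch also shows why requests sharing one conversation ID share context:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical message record: role plus text
record Turn(String role, String text) {}

// Hypothetical model of a memory advisor wrapped around a model call
final class MemoryAdvisorSketch {
    private final Map<String, List<Turn>> memory = new HashMap<>();
    private final Function<List<Turn>, String> model;

    MemoryAdvisorSketch(Function<List<Turn>, String> model) {
        this.model = model;
    }

    String call(String conversationId, String userText) {
        List<Turn> stored = memory.computeIfAbsent(conversationId, id -> new ArrayList<>());
        // Before the call: stored history first, then the new user message
        List<Turn> prompt = new ArrayList<>(stored);
        Turn user = new Turn("user", userText);
        prompt.add(user);
        String reply = model.apply(prompt);
        // After the call: persist the exchange under the same conversation ID
        stored.add(user);
        stored.add(new Turn("assistant", reply));
        return reply;
    }
}
```

With a fake model that returns the prompt size, the second call in one conversation sees three messages, while a different conversation ID starts from one.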
5.6 Implement the chat controller (ChatController.java)
Implement the chat controller that handles AI conversation requests:
```java
package com.spc.smartpiccommunitybackend.controller;

import com.spc.smartpiccommunitybackend.pojo.User;
import com.spc.smartpiccommunitybackend.repository.RedisChatHistoryRepository;
import com.spc.smartpiccommunitybackend.service.UserService;
import com.spc.smartpiccommunitybackend.utils.ErrorCode;
import com.spc.smartpiccommunitybackend.utils.ThrowUtils;
import jakarta.annotation.Resource;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;

@RestController
@RequestMapping("/ai")
public class ChatController {

    // System prompt (Chinese): "You are Hong Xiaozhi, the AI assistant of a smart picture
    // community. Answer users in a friendly, professional tone and help with the community."
    private static final String SYSTEM_PROMPT =
            "你是一个智能图片社区的ai助手,名为虹小智。请用友好、专业的语气回答用户问题," +
            "提供关于图片社区的相关信息和帮助。";

    private final ChatClient chatClient;
    private final RedisChatHistoryRepository chatHistoryRepository;

    @Resource
    private UserService userService;

    public ChatController(ChatClient chatClient, RedisChatHistoryRepository chatHistoryRepository) {
        this.chatClient = chatClient;
        this.chatHistoryRepository = chatHistoryRepository;
    }

    @RequestMapping(value = "/chat", produces = "text/html;charset=utf-8")
    public Flux<String> chat(@RequestParam(defaultValue = "讲个笑话") String prompt, // default: "tell a joke"
                             @RequestParam String chatId,
                             HttpServletRequest request) {
        User loginUser = userService.getLoginUser(request);
        // Reject requests from users who are not logged in
        ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR);
        Long userId = loginUser.getId();

        // Record the session and the user's message for the history endpoints
        chatHistoryRepository.save("chat", chatId, userId);
        chatHistoryRepository.saveMessage(chatId, prompt, "user");

        // Prefix the memory key with the user ID so model context is isolated per user
        String conversationId = userId + ":" + chatId;
        // Streamed chunks are accumulated and saved once, when the stream completes
        StringBuilder reply = new StringBuilder();
        return chatClient.prompt()
                .system(SYSTEM_PROMPT)
                .user(prompt)
                // MessageChatMemoryAdvisor loads and stores context under this ID
                .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, conversationId))
                .stream()
                .content()
                .doOnNext(reply::append)
                .doOnComplete(() -> chatHistoryRepository.saveMessage(chatId, reply.toString(), "ai"));
    }
}
```
Notes:
- The login check ties every chat to an authenticated user.
- ChatClient is called through its fluent API (prompt() ... stream().content()), which returns the reply as a Flux<String>.
- The conversation ID given to MessageChatMemoryAdvisor combines the user ID and chatId, so model context is isolated per user. The advisor already injects the stored history, so the controller does not rebuild it by hand; doing both would send every earlier turn twice.
- The user's message and the AI reply are also written to RedisChatHistoryRepository for the history endpoints. The reply is saved as one message when the stream completes rather than one entry per streamed chunk.
- Returning a Flux streams the reply to the client as it is generated, which improves the user experience.
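When a reply is streamed, each element of the Flux is a fragment of text rather than a complete message. Persisting inside doOnNext would therefore store one history entry per fragment. The following JDK-only sketch shows the accumulate-then-save pattern with java.util.concurrent.Flow. ReplyRecorder and its saver callback are hypothetical names, standing in for the doOnNext/doOnComplete pair and saveMessage:

```java
import java.util.concurrent.Flow;
import java.util.function.Consumer;

// Hypothetical subscriber: collects streamed chunks and hands the full reply to `saver` once
final class ReplyRecorder implements Flow.Subscriber<String> {
    private final StringBuilder reply = new StringBuilder();
    private final Consumer<String> saver;

    ReplyRecorder(Consumer<String> saver) {
        this.saver = saver;
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        subscription.request(Long.MAX_VALUE); // accept the whole reply
    }

    @Override
    public void onNext(String chunk) {
        reply.append(chunk); // accumulate only; nothing is saved per chunk
    }

    @Override
    public void onError(Throwable throwable) {
        // A failed stream saves nothing in this sketch; the partial text is dropped
    }

    @Override
    public void onComplete() {
        saver.accept(reply.toString());
    }
}
```

In practice the recorder would be subscribed to a publisher (for example a java.util.concurrent.SubmissionPublisher). Calling onNext and onComplete directly, as a publisher would, is enough to see that exactly one message is saved.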
5.7 Implement the chat history controller (ChatHistoryController.java)
Implement the chat history controller for listing and deleting chat histories:
```java
package com.spc.smartpiccommunitybackend.controller;

import com.spc.smartpiccommunitybackend.pojo.User;
import com.spc.smartpiccommunitybackend.repository.RedisChatHistoryRepository;
import com.spc.smartpiccommunitybackend.service.UserService;
import com.spc.smartpiccommunitybackend.utils.ErrorCode;
import com.spc.smartpiccommunitybackend.utils.ThrowUtils;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.web.bind.annotation.*;

import java.util.List;

@RestController
@RequestMapping("/ai/history")
public class ChatHistoryController {

    private final RedisChatHistoryRepository chatHistoryRepository;
    private final UserService userService;

    public ChatHistoryController(RedisChatHistoryRepository chatHistoryRepository, UserService userService) {
        this.chatHistoryRepository = chatHistoryRepository;
        this.userService = userService;
    }

    /**
     * List the current user's chat IDs
     */
    @GetMapping("/{type}")
    public List<String> getChatHistory(@PathVariable String type, HttpServletRequest request) {
        User loginUser = userService.getLoginUser(request);
        ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR);
        Long userId = loginUser.getId();
        return chatHistoryRepository.getChatIds(userId, type);
    }

    /**
     * Delete one chat history
     */
    @DeleteMapping("/{type}/{chatId}")
    public boolean deleteChatHistory(@PathVariable String type,
                                     @PathVariable String chatId,
                                     HttpServletRequest request) {
        User loginUser = userService.getLoginUser(request);
        ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR);
        Long userId = loginUser.getId();
        chatHistoryRepository.deleteChat(userId, type, chatId);
        return true;
    }
}
```
Notes:
- Both endpoints require a logged-in user and pass that user's ID to the repository. The history set they read or modify is therefore the caller's own (chat:history:{userId}:{type}).
- The endpoints let the frontend list and delete chat histories.
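Scoping the history key by user ID keeps one user's list separate from another's. The same idea can guard deletion: a chat's data is removed only when its chatId is found in the caller's own set, which is what SREM's returned count reveals. The following is a JDK-only sketch with an in-memory stand-in for the Redis keys. OwnedChats and its methods are hypothetical:

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Hypothetical in-memory stand-in for the chat:history:* sets and the per-chat data
final class OwnedChats {
    private final Map<String, Set<String>> historySets = new HashMap<>();
    private final Map<String, String> chatData = new HashMap<>();

    private static String historyKey(long userId, String type) {
        return "chat:history:" + userId + ":" + type;
    }

    void create(long userId, String type, String chatId) {
        historySets.computeIfAbsent(historyKey(userId, type), k -> new HashSet<>()).add(chatId);
        chatData.put(chatId, "messages of " + chatId);
    }

    // Like SREM followed by DEL: chat data is deleted only if the chat was in the caller's set
    boolean delete(long userId, String type, String chatId) {
        Set<String> owned = historySets.get(historyKey(userId, type));
        if (owned == null || !owned.remove(chatId)) {
            return false;
        }
        chatData.remove(chatId);
        return true;
    }

    boolean hasData(String chatId) {
        return chatData.containsKey(chatId);
    }
}
```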
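6. Testing and Verification
6.1 Start the service
- Make sure Redis and MySQL are running
- Run the Spring Boot application
- Log in first, because both AI controllers reject requests without a logged-in user. Then open http://localhost:9527/api/ai/chat?prompt=你好&chatId=test123 to test the AI response.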
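6.2 Test the chat memory
- Send the first message ("Hi, my name is Zhang San"): http://localhost:9527/api/ai/chat?prompt=你好,我叫张三&chatId=test123
- Send the second message ("Do you know my name?"): http://localhost:9527/api/ai/chat?prompt=你知道我叫什么名字吗?&chatId=test123
- Check that the AI answers with your name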
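The prompts in these URLs contain Chinese text and punctuation. A browser percent-encodes them automatically, but a scripted client has to do it itself. The following JDK-only sketch uses java.net.URLEncoder and java.net.http.HttpClient. ChatSmokeTest and its methods are hypothetical. The sketch sends no login credentials, so add whatever cookie or header your login flow issues (for example via HttpRequest.Builder.header):

```java
import java.net.URI;
import java.net.URLEncoder;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;

// Hypothetical helper for scripting the memory test
final class ChatSmokeTest {
    static final String BASE = "http://localhost:9527/api/ai/chat";

    // Percent-encode query values so non-ASCII prompts survive the request line
    static URI chatUri(String prompt, String chatId) {
        return URI.create(BASE
                + "?prompt=" + URLEncoder.encode(prompt, StandardCharsets.UTF_8)
                + "&chatId=" + URLEncoder.encode(chatId, StandardCharsets.UTF_8));
    }

    // Sends one prompt and returns the whole streamed reply body as a String
    static String send(HttpClient client, String prompt, String chatId) throws Exception {
        HttpRequest request = HttpRequest.newBuilder(chatUri(prompt, chatId)).GET().build();
        return client.send(request, HttpResponse.BodyHandlers.ofString()).body();
    }
}
```

Calling send twice with the same chatId, first with the name and then with the question, reproduces the manual test above.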
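6.3 Test user isolation
- Log in as different users
- Check that each user's chat history is kept separate
6.4 Test history persistence
- Send several messages
- Restart the application while Redis keeps running
- Check that the chat history is still there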
6.5 Test history management
- List the user's chat histories: GET http://localhost:9527/api/ai/history/chat
- Delete a chat history: DELETE http://localhost:9527/api/ai/history/chat/test123
- Check that the chat history has been deleted
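A browser address bar can only issue GET, so the DELETE call needs a client such as java.net.http. The following JDK-only sketch builds both history requests. HistoryRequests is a hypothetical helper. As with the chat endpoint, the login credential your application expects still has to be added to each request:

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical helper that builds the two history-management requests
final class HistoryRequests {
    static final String BASE = "http://localhost:9527/api/ai/history/";

    // GET /api/ai/history/{type}: list the caller's chat IDs
    static HttpRequest list(String type) {
        return HttpRequest.newBuilder(URI.create(BASE + type)).GET().build();
    }

    // DELETE /api/ai/history/{type}/{chatId}: remove one chat
    static HttpRequest delete(String type, String chatId) {
        return HttpRequest.newBuilder(URI.create(BASE + type + "/" + chatId)).DELETE().build();
    }
}
```

Each request can be sent with HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()). After the DELETE, listing again should no longer include test123.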