当前位置: 代码网 > it编程>编程语言>Java > 使用SpringBoot3整合Spring AI实现具有记忆功能的AI助手

使用SpringBoot3整合Spring AI实现具有记忆功能的AI助手

2026年04月12日 Java 我要评论
1. 项目概述本教程详细介绍如何使用 spring boot 3 整合 spring ai 实现一个具有记忆功能的 ai 助手。该实现使用 redis 作为存储介质,支持用户级别的会话隔离和 30 天

1. 项目概述

本教程详细介绍如何使用 spring boot 3 整合 spring ai 实现一个具有记忆功能的 ai 助手。该实现使用 redis 作为存储介质,支持用户级别的会话隔离和 30 天的对话历史持久化。

技术栈

  • spring boot 3.3.0
  • java 17
  • spring ai
  • redis 6.0+
  • mybatis plus
  • mysql 8.0
  • sa-token(用户认证)

2. 环境准备

2.1 安装必要软件

2.2 配置环境变量

确保 java_homemaven_home 已正确配置。

3. 项目初始化

3.1 创建 spring boot 项目

使用 spring initializr 创建项目:

  • 访问 spring initializr
  • 选择 spring boot 3.3.0
  • 选择 java 17
  • 添加依赖:spring web, spring data redis, mybatis plus, mysql driver, spring boot devtools

3.2 配置 maven 依赖

pom.xml 文件中添加以下依赖:

<dependencies>
    <!-- spring boot 核心依赖 -->
    <dependency>
        <groupid>org.springframework.boot</groupid>
        <artifactid>spring-boot-starter-web</artifactid>
    </dependency>
    <!-- spring data redis -->
    <dependency>
        <groupid>org.springframework.boot</groupid>
        <artifactid>spring-boot-starter-data-redis</artifactid>
    </dependency>
    <!-- mybatis plus -->
    <dependency>
        <groupid>com.baomidou</groupid>
        <artifactid>mybatis-plus-boot-starter</artifactid>
        <version>3.5.5</version>
    </dependency>
    <!-- mysql 驱动 -->
    <dependency>
        <groupid>com.mysql</groupid>
        <artifactid>mysql-connector-j</artifactid>
        <scope>runtime</scope>
    </dependency>
    <!-- spring ai -->
    <dependency>
        <groupid>org.springframework.ai</groupid>
        <artifactid>spring-ai-openai</artifactid>
        <version>1.0.0</version>
    </dependency>
    <!-- sa-token 认证 -->
    <dependency>
        <groupid>cn.dev33</groupid>
        <artifactid>sa-token-spring-boot-starter</artifactid>
        <version>1.38.1</version>
    </dependency>
    <!-- jackson 序列化 -->
    <dependency>
        <groupid>com.fasterxml.jackson.core</groupid>
        <artifactid>jackson-databind</artifactid>
    </dependency>
    <!-- lombok -->
    <dependency>
        <groupid>org.projectlombok</groupid>
        <artifactid>lombok</artifactid>
        <optional>true</optional>
    </dependency>
    <!-- spring boot 测试 -->
    <dependency>
        <groupid>org.springframework.boot</groupid>
        <artifactid>spring-boot-starter-test</artifactid>
        <scope>test</scope>
    </dependency>
</dependencies>

4. 核心配置

4.1 应用配置文件 (application.yml)

创建 src/main/resources/application.yml 文件,配置应用信息:

server:
  port: 9527
  servlet:
    context-path: /api
spring:
  application:
    name: smart-pic-community-backend
  # redis 配置
  data:
    redis:
      database: 2
      host: localhost
      port: 6379
      timeout: 5000
  # 数据库配置
  datasource:
    driver-class-name: com.mysql.cj.jdbc.driver
    url: jdbc:mysql://localhost:3306/smart_pic_community
    username: root
    password: your_password
  # spring ai 配置
  ai:
    openai:
      base-url: https://api.deepseek.com/  # 使用 deepseek api
      api-key: your_api_key
      chat:
        options:
          model: deepseek-chat

4.2 redis 配置 (redisconfiguration.java)

创建 redis 配置类,确保正确序列化对象:

package com.spc.smartpiccommunitybackend.config;
import com.fasterxml.jackson.annotation.jsonautodetect;
import com.fasterxml.jackson.annotation.propertyaccessor;
import com.fasterxml.jackson.databind.objectmapper;
import lombok.extern.slf4j.slf4j;
import org.springframework.context.annotation.bean;
import org.springframework.context.annotation.configuration;
import org.springframework.data.redis.connection.redisconnectionfactory;
import org.springframework.data.redis.core.redistemplate;
import org.springframework.data.redis.serializer.jackson2jsonredisserializer;
import org.springframework.data.redis.serializer.stringredisserializer;
@configuration
@slf4j
public class redisconfiguration {
    @bean
    public redistemplate<string, object> redistemplate(redisconnectionfactory redisconnectionfactory) {
        log.info("开始创建redis模板对象...");
        redistemplate<string, object> redistemplate = new redistemplate<>();
        // 设置连接工厂
        redistemplate.setconnectionfactory(redisconnectionfactory);
        // 使用 stringredisserializer 来序列化和反序列化 redis 的 key
        stringredisserializer stringredisserializer = new stringredisserializer();
        // key 采用 string 的序列化方式
        redistemplate.setkeyserializer(stringredisserializer);
        redistemplate.sethashkeyserializer(stringredisserializer);
        // 使用 jackson2jsonredisserializer 来序列化和反序列化 redis 的 value
        jackson2jsonredisserializer<object> jackson2jsonredisserializer = new jackson2jsonredisserializer<>(object.class);
        objectmapper objectmapper = new objectmapper();
        objectmapper.setvisibility(propertyaccessor.all, jsonautodetect.visibility.any);
        objectmapper.enabledefaulttyping(objectmapper.defaulttyping.non_final);
        jackson2jsonredisserializer.setobjectmapper(objectmapper);
        // value 采用 json 的序列化方式
        redistemplate.setvalueserializer(jackson2jsonredisserializer);
        redistemplate.sethashvalueserializer(jackson2jsonredisserializer);
        redistemplate.afterpropertiesset();
        return redistemplate;
    }
}

注意事项

  • 使用 jackson2jsonredisserializer 而不是 stringredisserializer 可以避免 classcastexception
  • 启用默认类型可以确保反序列化时能正确识别对象类型

5. 核心功能实现

5.1 创建模型类

5.1.1 创建 messagevo.java

创建消息视图对象,用于前端展示:

package com.spc.smartpiccommunitybackend.model.vo.ai;

import lombok.data;
import lombok.noargsconstructor;
import org.springframework.ai.chat.messages.message;

@noargsconstructor
@data
public class messagevo {
    private string role;
    private string content;

    public messagevo(message message) {
        this.role = switch (message.getmessagetype()) {
            case user -> "user";
            case assistant -> "assistant";
            case system -> "system";
            default -> "";
        };
        this.content = message.gettext();
    }
}

功能说明

  • 将 spring ai 的 message 对象转换为前端可识别的格式
  • 根据消息类型设置不同的角色
  • 提取消息内容用于展示

5.1.2 创建 serializablemessage.java

创建可序列化的消息对象,用于 redis 存储:

package com.spc.smartpiccommunitybackend.model.entity.ai;

import lombok.data;
import lombok.noargsconstructor;
import java.io.serializable;

@data
@noargsconstructor
public class serializablemessage implements serializable {
    private static final long serialversionuid = 1l;

    private string role;
    private string content;
    private string messagetype;
    private long timestamp;

    public serializablemessage(string role, string content, string messagetype) {
        this.role = role;
        this.content = content;
        this.messagetype = messagetype;
        this.timestamp = system.currenttimemillis();
    }

    public serializablemessage(string role, string content) {
        this(role, content, "user");
    }
}

功能说明

  • 实现 serializable 接口,支持 redis 序列化
  • 包含角色、内容、消息类型和时间戳字段
  • 提供多个构造方法,方便使用

5.2 定义仓库接口 (chathistoryrepository.java)

创建聊天历史仓库接口:

package com.spc.smartpiccommunitybackend.repository;

import java.util.list;
import java.util.map;

public interface chathistoryrepository {
    /**
     * 保存会话记录
     */
    void save(string type, string chatid, long userid);
    
    /**
     * 获取用户的会话id列表
     */
    list<string> getchatids(long userid, string type);
    
    /**
     * 保存聊天消息
     */
    void savemessage(string chatid, string message, string sender);
    
    /**
     * 获取聊天消息历史
     */
    list<string> getmessages(string chatid);
    
    /**
     * 删除会话
     */
    void deletechat(long userid, string type, string chatid);
    
    /**
     * 获取会话信息
     */
    map<object, object> getsessioninfo(string chatid);
}

5.3 实现 redis 聊天历史仓库 (redischathistoryrepository.java)

实现基于 redis 的聊天历史仓库:

package com.spc.smartpiccommunitybackend.repository;

import com.fasterxml.jackson.databind.objectmapper;
import com.fasterxml.jackson.core.jsonprocessingexception;
import lombok.requiredargsconstructor;
import org.springframework.data.redis.core.redistemplate;
import org.springframework.stereotype.component;

import java.util.*;
import java.util.concurrent.timeunit;
import java.util.stream.collectors;

@component
@requiredargsconstructor
public class redischathistoryrepository implements chathistoryrepository {

    private final redistemplate<string, object> redistemplate;

    // redis key前缀
    private static final string chat_history_prefix = "chat:history:";
    private static final string chat_session_prefix = "chat:session:";
    private static final string chat_messages_prefix = "chat:messages:";

    /**
     * 保存会话记录
     */
    @override
    public void save(string type, string chatid, long userid) {
        // 保存会话信息
        string sessionkey = chat_session_prefix + chatid;
        map<string, object> sessioninfo = new hashmap<>();
        sessioninfo.put("userid", string.valueof(userid));
        sessioninfo.put("type", type);
        sessioninfo.put("createtime", system.currenttimemillis());
        sessioninfo.put("lastupdatetime", system.currenttimemillis());
        redistemplate.opsforhash().putall(sessionkey, sessioninfo);
        // 设置过期时间为30天
        redistemplate.expire(sessionkey, 30, timeunit.days);

        // 将chatid添加到用户的聊天历史列表中
        string historykey = chat_history_prefix + userid + ":" + type;
        redistemplate.opsforset().add(historykey, chatid);
        // 设置过期时间为30天
        redistemplate.expire(historykey, 30, timeunit.days);
    }

    /**
     * 获取用户的会话id列表
     */
    @override
    public list<string> getchatids(long userid, string type) {
        string historykey = chat_history_prefix + userid + ":" + type;
        set<object> chatids = redistemplate.opsforset().members(historykey);
        if (chatids == null || chatids.isempty()) {
            return collections.emptylist();
        }
        return chatids.stream()
                .map(object::tostring)
                .collect(collectors.tolist());
    }

    /**
     * 保存聊天消息
     */
    @override
    public void savemessage(string chatid, string message, string sender) {
        string messageskey = chat_messages_prefix + chatid;
        // 创建消息对象
        map<string, object> messageinfo = new hashmap<>();
        messageinfo.put("content", message);
        messageinfo.put("sender", sender);
        messageinfo.put("timestamp", system.currenttimemillis());
        // 使用json格式保存消息
        objectmapper objectmapper = new objectmapper();
        try {
            string jsonmessage = objectmapper.writevalueasstring(messageinfo);
            redistemplate.opsforlist().rightpush(messageskey, jsonmessage);
        } catch (jsonprocessingexception e) {
            e.printstacktrace();
            // 如果json序列化失败,使用原始消息
            redistemplate.opsforlist().rightpush(messageskey, message);
        }
        // 设置过期时间为30天
        redistemplate.expire(messageskey, 30, timeunit.days);

        // 更新会话的最后更新时间
        string sessionkey = chat_session_prefix + chatid;
        redistemplate.opsforhash().put(sessionkey, "lastupdatetime", system.currenttimemillis());
        // 确保会话信息也有过期时间
        redistemplate.expire(sessionkey, 30, timeunit.days);
    }

    /**
     * 获取聊天消息历史
     */
    @override
    public list<string> getmessages(string chatid) {
        string messageskey = chat_messages_prefix + chatid;
        list<object> messages = redistemplate.opsforlist().range(messageskey, 0, -1);
        if (messages == null || messages.isempty()) {
            return collections.emptylist();
        }
        return messages.stream()
                .map(object::tostring)
                .collect(collectors.tolist());
    }

    /**
     * 删除会话
     */
    @override
    public void deletechat(long userid, string type, string chatid) {
        // 从用户的聊天历史列表中删除
        string historykey = chat_history_prefix + userid + ":" + type;
        redistemplate.opsforset().remove(historykey, chatid);

        // 删除会话信息
        string sessionkey = chat_session_prefix + chatid;
        redistemplate.delete(sessionkey);

        // 删除聊天消息
        string messageskey = chat_messages_prefix + chatid;
        redistemplate.delete(messageskey);
    }

    /**
     * 获取会话信息
     */
    @override
    public map<object, object> getsessioninfo(string chatid) {
        string sessionkey = chat_session_prefix + chatid;
        return redistemplate.opsforhash().entries(sessionkey);
    }
}

注意事项

  • 使用不同的 redis key 前缀区分不同类型的数据
  • 为所有 redis 键设置过期时间,避免内存泄漏
  • 处理 json 序列化失败的情况,提高系统健壮性

5.4 实现聊天记忆 (redischatmemory.java)

实现基于 redis 的聊天记忆,支持 spring ai 的 chatmemory 接口:

package com.spc.smartpiccommunitybackend.config;

import org.springframework.ai.chat.messages.message;
import org.springframework.ai.chat.memory.chatmemory;
import org.springframework.data.redis.core.redistemplate;
import org.springframework.stereotype.component;

import java.util.arraylist;
import java.util.list;
import java.util.concurrent.timeunit;

@component
public class redischatmemory implements chatmemory {

    private final redistemplate<string, object> redistemplate;
    private static final string memory_key_prefix = "chat:memory:";
    private static final long expiration_days = 30;

    public redischatmemory(redistemplate<string, object> redistemplate) {
        this.redistemplate = redistemplate;
    }

    @override
    public void add(string key, list<message> messages) {
        // 为特定会话添加多条消息
        for (message message : messages) {
            addmessage(key, message);
        }
    }

    @override
    public list<message> get(string key, int maxcount) {
        // 实现 get 方法,根据 key 获取消息
        list<message> messages = getmessages(key);
        // 如果指定了最大数量,返回不超过该数量的消息
        if (maxcount > 0 && messages.size() > maxcount) {
            return messages.sublist(messages.size() - maxcount, messages.size());
        }
        return messages;
    }

    @override
    public void clear() {
        // 清理所有会话记忆
        // 注意:这个操作会删除所有聊天记忆,谨慎使用
    }

    /**
     * 为特定会话添加消息
     */
    public void addmessage(string chatid, message message) {
        string key = memory_key_prefix + chatid;
        redistemplate.opsforlist().rightpush(key, message);
        redistemplate.expire(key, expiration_days, timeunit.days);
    }

    /**
     * 获取特定会话的消息
     */
    public list<message> getmessages(string chatid) {
        string key = memory_key_prefix + chatid;
        list<object> objects = redistemplate.opsforlist().range(key, 0, -1);
        list<message> messages = new arraylist<>();
        if (objects != null) {
            for (object obj : objects) {
                if (obj instanceof message) {
                    messages.add((message) obj);
                }
            }
        }
        return messages;
    }

    /**
     * 清理特定会话的记忆
     */
    public void clear(string chatid) {
        string key = memory_key_prefix + chatid;
        redistemplate.delete(key);
    }
}

注意事项

  • 实现 chatmemory 接口以支持 spring ai 的消息记忆功能
  • 为每条消息设置过期时间,确保内存使用合理
  • 提供批量添加和获取消息的方法,提高性能

5.5 配置 chatclient (commonconfiguration.java)

配置 spring ai 的 chatclient

package com.spc.smartpiccommunitybackend.config;

import org.springframework.ai.chat.client.chatclient;
import org.springframework.ai.chat.client.advisor.messagechatmemoryadvisor;
import org.springframework.ai.chat.client.advisor.simpleloggeradvisor;
import org.springframework.ai.chat.memory.chatmemory;
import org.springframework.ai.openai.openaichatmodel;
import org.springframework.context.annotation.bean;
import org.springframework.context.annotation.configuration;

@configuration
public class commonconfiguration {

    @bean
    public chatclient chatclient(openaichatmodel openaichatmodel, chatmemory chatmemory) {
        return chatclient.builder(openaichatmodel)
                .defaultadvisors(
                        new simpleloggeradvisor(),
                        new messagechatmemoryadvisor(chatmemory)
                )
                .build();
    }
}

注意事项

  • 使用 messagechatmemoryadvisor 来启用聊天记忆功能
  • 使用 simpleloggeradvisor 来记录聊天交互,方便调试

5.6 实现聊天控制器 (chatcontroller.java)

实现聊天控制器,处理 ai 对话请求:

package com.spc.smartpiccommunitybackend.controller;

import com.fasterxml.jackson.databind.jsonnode;
import com.fasterxml.jackson.databind.objectmapper;
import com.spc.smartpiccommunitybackend.repository.redischathistoryrepository;
import com.spc.smartpiccommunitybackend.service.userservice;
import com.spc.smartpiccommunitybackend.utils.errorcode;
import com.spc.smartpiccommunitybackend.utils.throwutils;
import com.spc.smartpiccommunitybackend.pojo.user;
import org.springframework.ai.chat.client.chatclient;
import org.springframework.ai.chat.messages.message;
import org.springframework.ai.chat.messages.systemmessage;
import org.springframework.ai.chat.messages.usermessage;
import org.springframework.ai.chat.messages.assistantmessage;
import org.springframework.beans.factory.annotation.resource;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.flux;

import javax.servlet.http.httpservletrequest;
import java.io.ioexception;
import java.util.arraylist;
import java.util.list;

@restcontroller
@requestmapping("/ai")
public class chatcontroller {

    private final chatclient chatclient;
    private final redischathistoryrepository chathistoryrepository;
    @resource
    private userservice userservice;

    public chatcontroller(chatclient chatclient, redischathistoryrepository chathistoryrepository) {
        this.chatclient = chatclient;
        this.chathistoryrepository = chathistoryrepository;
    }

    @requestmapping(value = "/chat", produces = "text/html;charset=utf-8")
    public flux<string> chat(@requestparam(defaultvalue = "讲个笑话") string prompt,
                             string chatid,
                             httpservletrequest request) {
        user loginuser = userservice.getloginuser(request);
        // 校验登录用户是否为空
        throwutils.throwif(loginuser == null, errorcode.not_login_error);
        long userid = loginuser.getid();
        
        // 保存会话信息
        chathistoryrepository.save("chat", chatid, userid);

        // 获取历史对话消息作为上下文
        list<message> messages = new arraylist<>();
        
        // 添加系统消息
        systemmessage systemmessage = new systemmessage(
            "你是一个智能图片社区的ai助手,名为虹小智。请用友好、专业的语气回答用户问题," +
            "提供关于图片社区的相关信息和帮助。"
        );
        messages.add(systemmessage);
        
        // 获取并解析历史消息
        list<string> historymessages = chathistoryrepository.getmessages(chatid);
        objectmapper objectmapper = new objectmapper();
        
        for (string messagestr : historymessages) {
            try {
                jsonnode node = objectmapper.readtree(messagestr);
                string sender = node.get("sender").astext();
                string content = node.get("content").astext();
                
                if ("user".equals(sender)) {
                    messages.add(new usermessage(content));
                } else if ("ai".equals(sender)) {
                    messages.add(new assistantmessage(content));
                }
            } catch (ioexception e) {
                e.printstacktrace();
            }
        }
        
        // 添加用户当前消息
        messages.add(new usermessage(prompt));
        
        // 保存用户消息到历史记录
        chathistoryrepository.savemessage(chatid, prompt, "user");
        
        // 调用ai模型获取响应
        return chatclient.stream(messages)
                .doonnext(response -> {
                    // 保存ai响应到历史记录
                    chathistoryrepository.savemessage(chatid, response, "ai");
                });
    }
}

注意事项

  • 验证用户登录状态,确保会话隔离
  • 保存用户消息和 ai 响应到历史记录
  • 使用 flux 实现流式响应,提高用户体验
  • 处理历史消息解析异常,提高系统健壮性

5.7 实现聊天历史控制器 (chathistorycontroller.java)

实现聊天历史控制器,处理聊天历史的获取和删除:

package com.spc.smartpiccommunitybackend.controller;

import com.spc.smartpiccommunitybackend.repository.redischathistoryrepository;
import com.spc.smartpiccommunitybackend.service.userservice;
import com.spc.smartpiccommunitybackend.utils.errorcode;
import com.spc.smartpiccommunitybackend.utils.throwutils;
import com.spc.smartpiccommunitybackend.pojo.user;
import org.springframework.web.bind.annotation.*;

import javax.servlet.http.httpservletrequest;
import java.util.list;

@restcontroller
@requestmapping("/ai/history")
public class chathistorycontroller {

    private final redischathistoryrepository chathistoryrepository;
    private final userservice userservice;

    public chathistorycontroller(redischathistoryrepository chathistoryrepository, userservice userservice) {
        this.chathistoryrepository = chathistoryrepository;
        this.userservice = userservice;
    }

    /**
     * 获取用户的聊天历史id列表
     */
    @getmapping("/{type}")
    public list<string> getchathistory(@pathvariable string type, httpservletrequest request) {
        user loginuser = userservice.getloginuser(request);
        throwutils.throwif(loginuser == null, errorcode.not_login_error);
        long userid = loginuser.getid();
        
        return chathistoryrepository.getchatids(userid, type);
    }

    /**
     * 删除指定聊天历史
     */
    @deletemapping("/{type}/{chatid}")
    public boolean deletechathistory(@pathvariable string type, 
                                    @pathvariable string chatid, 
                                    httpservletrequest request) {
        user loginuser = userservice.getloginuser(request);
        throwutils.throwif(loginuser == null, errorcode.not_login_error);
        long userid = loginuser.getid();
        
        chathistoryrepository.deletechat(userid, type, chatid);
        return true;
    }
}

注意事项

  • 验证用户登录状态,确保只能操作自己的聊天历史
  • 提供获取和删除聊天历史的接口,方便前端管理

6. 测试和验证

6.1 启动服务

  1. 确保 redis 和 mysql 服务已启动
  2. 运行 spring boot 应用
  3. 访问 http://localhost:9527/api/ai/chat?prompt=你好&chatid=test123 测试 ai 响应

6.2 测试聊天记忆功能

  1. 发送第一条消息:http://localhost:9527/api/ai/chat?prompt=你好,我叫张三&chatid=test123
  2. 发送第二条消息:http://localhost:9527/api/ai/chat?prompt=你知道我叫什么名字吗?&chatid=test123
  3. 验证 ai 能否正确回答你的名字

6.3 测试用户隔离功能

  1. 使用不同用户登录
  2. 验证不同用户的聊天历史是否相互隔离

6.4 测试历史记录持久化

  1. 发送多条消息
  2. 重启服务
  3. 验证聊天历史是否仍然存在

6.5 测试历史记录管理

  1. 获取用户的聊天历史列表:get http://localhost:9527/api/ai/history/chat
  2. 删除指定聊天历史:delete http://localhost:9527/api/ai/history/chat/test123
  3. 验证聊天历史是否已删除

以上就是使用springboot3整合spring ai实现具有记忆功能的ai助手的详细内容,更多关于springboot3 spring ai实现ai助手的资料请关注代码网其它相关文章!

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2026  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com