0.前置准备
1.安装ollama
然后安装如下三个模型
c:\users\gt-jyw-3>ollama list
NAME                    ID              SIZE      MODIFIED
qwen3-embedding:0.6b    ac6da0dfba84    639 MB    8 hours ago
qwen3.5:2b              324d162be6ca    2.7 GB    25 hours ago
deepseek-r1:1.5b        e0979632db5a    1.1 GB    25 hours ago
第一个是用于检索(向量化)的嵌入模型,后面两个用于对话;也可以根据自己的需要安装其他模型。
对话模型需要支持 tools(模型页面上带有 tools 标签),检索模型需要支持 embedding(带有 embedding 标签)。
2.下载qdrant
qdrant是一个本地轻型的向量数据库
下载地址: https://github.com/qdrant/qdrant/releases
Windows 下直接双击 qdrant.exe 启动即可(REST 端口为 6333,本项目通过 gRPC 端口 6334 连接)。
3.环境准备
需要 JDK 17。下面是完整的代码,直接复制粘贴即可使用(应用默认端口为 12000)。
普通对话直接访问 ip:port/chat?msg=你的问题;流式返回则访问 ip:port/chats?msg=你的问题。
再加上 &sessionid=会话ID 参数,即可基于 Redis 存储会话历史(使用 Redis 的 db 6)。
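文件检索的目录是写死的,可以在 application.yaml 的 knowledge-base.default-path 中修改(KnowledgeBaseConfig.java 中的 defaultPath 只是默认值,会被 yaml 中的配置覆盖),扫描时也可以通过 directory 参数临时指定。
然后先访问 ip:port/file/scan 创建索引,再访问 ip:port/file/chat?msg=你的问题&sessionid=会话ID 即可(该接口的 sessionid 是必填参数)。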
代码 99% 都是 AI 写的,本人亲测可以检索文件夹里 .pptx、.md 等文件的内容(按当前代码,.ppt、.docx 等其他格式只会索引文件名、路径、大小等元信息,不会提取正文)。
只是这里用来测试的都是很小的模型,回答效果一般,可以换成更大的模型。
1.整体结构

2.pom文件
<?xml version="1.0" encoding="UTF-8"?>
<!--
  Maven build for the Spring AI demo.
  Maven element names, namespace URIs and version qualifiers are case-sensitive:
  keep <groupId>/<artifactId>, "POM", "XMLSchema-instance", "-SNAPSHOT" and "-M6" exactly as written.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.3.0</version>
        <relativePath/>
    </parent>
    <groupId>com.cxl</groupId>
    <artifactId>springai-demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springai-demo</name>
    <description>springai-demo</description>
    <url/>
    <licenses>
        <license/>
    </licenses>
    <developers>
        <developer/>
    </developers>
    <scm>
        <connection/>
        <developerConnection/>
        <tag/>
        <url/>
    </scm>
    <properties>
        <java.version>17</java.version>
        <!-- Milestone qualifier is upper-case in the repository path -->
        <spring-ai.version>1.0.0-M6</spring-ai.version>
    </properties>
    <repositories>
        <!-- Spring AI milestones are published here -->
        <repository>
            <id>spring-milestones</id>
            <name>Spring Milestones</name>
            <url>https://repo.spring.io/milestone</url>
            <snapshots>
                <enabled>false</enabled>
            </snapshots>
        </repository>
    </repositories>
    <dependencyManagement>
        <dependencies>
            <!-- Aligns the versions of all Spring AI artifacts below -->
            <dependency>
                <groupId>org.springframework.ai</groupId>
                <artifactId>spring-ai-bom</artifactId>
                <version>${spring-ai.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <!-- Qdrant vector store (version managed by spring-ai-bom) -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-qdrant-store-spring-boot-starter</artifactId>
        </dependency>
        <!-- Spring Boot Web -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- Spring AI core -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-core</artifactId>
        </dependency>
        <!-- Spring AI Ollama starter -->
        <dependency>
            <groupId>org.springframework.ai</groupId>
            <artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
        </dependency>
        <!-- Spring Boot DevTools -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-devtools</artifactId>
            <scope>runtime</scope>
            <optional>true</optional>
        </dependency>
        <!-- Lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
        <!-- Apache POI for .pptx text extraction -->
        <dependency>
            <groupId>org.apache.poi</groupId>
            <artifactId>poi-ooxml</artifactId>
            <version>5.2.5</version>
        </dependency>
        <!-- Spring Data Redis (chat history) -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
        </dependency>
        <!-- Spring AOP (chat-history aspect) -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-aop</artifactId>
        </dependency>
        <!-- Test -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <annotationProcessorPaths>
                        <path>
                            <groupId>org.springframework.boot</groupId>
                            <artifactId>spring-boot-configuration-processor</artifactId>
                        </path>
                        <path>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </path>
                    </annotationProcessorPaths>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
3.application.yaml文件
# ===================================================================
# Spring AI demo configuration
# ===================================================================
# Server
server:
  port: 12000 # application port
# Spring
spring:
  application:
    name: springai-demo # application name
  # Redis - stores the chat history
  data:
    redis:
      host: localhost # Redis host
      port: 6379 # Redis port
      database: 6 # Redis database index (0-15)
  # AI models
  ai:
    # Local Ollama service
    ollama:
      base-url: http://localhost:11434 # Ollama API address
      embedding:
        model: qwen3-embedding:0.6b # embedding model (RAG vectorization)
        enabled: true
    # Embedding options
    embedding:
      options:
        model: qwen3-embedding:0.6b # embedding model name
    # Vector store (Qdrant)
    vectorstore:
      qdrant:
        host: localhost # Qdrant host
        port: 6334 # Qdrant gRPC port (the REST API uses 6333)
        collection-name: my_docs # collection name
# ===================================================================
# Knowledge base
# Used by FileSearchService for RAG retrieval
# ===================================================================
knowledge-base:
  default-path: "c:\\users\\gt-jyw-3\\documents\\doct" # default directory of knowledge files
  collection-name: my_docs # Qdrant collection name
  host: localhost # Qdrant host
  port: 6334 # Qdrant port
  top-k: 5 # number of documents returned per retrieval
# ===================================================================
# Chat history
# Used by ChatHistoryService + ChatHistoryAspect to store session context
# ===================================================================
chat-history:
  enabled: true # enable the chat-history feature
  key-prefix: "chat:history:" # Redis key prefix
  max-size: 20 # maximum number of messages kept per session
  expire-days: 7 # days until a session's history expires
# ===================================================================
# System prompts (assistant persona)
# Used by ChatService when calling the different models
# ===================================================================
system:
  settings:
    # Fallback prompt for models without a dedicated entry
    default-prompt: "你是一个有帮助的ai助手,请用简洁专业的语言回答用户问题。"
    # Per-model prompts. The keys contain ':', which Spring Boot removes from map keys
    # that are not wrapped in "[...]", so the bracket notation is required for the
    # lookup by model name to match.
    prompts:
      # qwen model
      "[qwen3.5:2b]": "你是qwen3.5模型,一个由阿里云开发的大语言模型。你擅长中文理解和生成,请用专业、准确的语言回答用户问题。"
      # deepseek model
      "[deepseek-r1:1.5b]": "你是deepseek-r1模型,一个专注于推理和思考的ai助手。请在回答问题时展示清晰的思考过程,提供深入的分析。"
# ===================================================================
# Application-wide settings
# ===================================================================
app:
  context-aware-enabled: true # enable context-aware answering
  default-top-k: 3 # default number of retrieved documents
  session-timeout-minutes: 60 # session timeout (minutes)
4.启用对话记忆注解
package com.cxl.annotation;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Marks a request-handling method whose user/assistant exchange should be recorded
 * in the chat history.
 *
 * <p>It is read at runtime by the chat-history aspect. That aspect locates the
 * session id and the user message through parameters named {@code sessionid} and
 * {@code msg} on the annotated method.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface EnableChatHistory {

    /**
     * Whether history recording is active for this method. When {@code false}, the
     * aspect invokes the method without saving anything.
     */
    boolean enabled() default true;
}
5.对话记忆注解切面
package com.cxl.aspect;

import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatHistoryService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.stereotype.Component;
import reactor.core.publisher.Flux;

import java.lang.reflect.Method;

/**
 * Records the user/assistant exchange of {@link EnableChatHistory}-annotated methods
 * through {@link ChatHistoryService}.
 *
 * <p>Arguments are located by parameter name ({@code sessionid} and {@code msg}), so the
 * code must be compiled with {@code -parameters}; spring-boot-starter-parent enables
 * this. Calls without a session id pass through unchanged.
 */
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class ChatHistoryAspect {

    /** Name of the advised method's parameter that carries the session id. */
    private static final String SESSION_ID_PARAM = "sessionid";

    /** Name of the advised method's parameter that carries the user's message. */
    private static final String MESSAGE_PARAM = "msg";

    private final ChatHistoryService chatHistoryService;

    /**
     * Invokes the advised method and saves the exchange when a session id is present.
     *
     * @param joinPoint the intercepted invocation
     * @return the method's result. A {@link Flux} is passed through chunk by chunk, and
     *         the concatenated reply is saved once the stream completes.
     * @throws Throwable whatever the advised method throws
     */
    @Around("@annotation(com.cxl.annotation.EnableChatHistory)")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        MethodSignature signature = (MethodSignature) joinPoint.getSignature();
        Method method = signature.getMethod();
        EnableChatHistory annotation = method.getAnnotation(EnableChatHistory.class);
        // The pointcut matched the annotation; if it cannot be read back, keep the default (enabled).
        if (annotation != null && !annotation.enabled()) {
            return joinPoint.proceed();
        }

        String[] paramNames = signature.getParameterNames();
        if (paramNames == null) {
            // Compiled without -parameters: the session id cannot be located by name.
            return joinPoint.proceed();
        }

        Object[] args = joinPoint.getArgs();
        String sessionId = null;
        String userMessage = null;
        for (int i = 0; i < args.length; i++) {
            if (!(args[i] instanceof String value)) {
                continue;
            }
            if (SESSION_ID_PARAM.equals(paramNames[i])) {
                sessionId = value;
            } else if (MESSAGE_PARAM.equals(paramNames[i])) {
                userMessage = value;
            }
        }

        if (sessionId == null || sessionId.isEmpty()) {
            return joinPoint.proceed();
        }

        log.info("========== chathistoryaspect ==========");
        log.info("sessionid: {}", sessionId);
        log.info("usermessage: {}", userMessage);

        Object result = joinPoint.proceed();
        if (result instanceof Flux<?> stream) {
            @SuppressWarnings("unchecked") // element type is irrelevant: chunks are only appended as text
            Flux<Object> chunks = (Flux<Object>) stream;
            return recordWhenComplete(chunks, sessionId, userMessage);
        }
        if (result instanceof String reply) {
            saveExchange(sessionId, userMessage, reply);
            log.info("saved chat history to redis");
        }
        return result;
    }

    /**
     * Passes each chunk to the subscriber as soon as it arrives and saves the full reply
     * when the stream completes.
     *
     * <p>The previous {@code collectList()} approach withheld every chunk until the model
     * finished, which turned the SSE endpoint into a single delayed response.
     */
    private Flux<Object> recordWhenComplete(Flux<Object> chunks, String sessionId, String userMessage) {
        // defer: a fresh buffer per subscription, so re-subscribing cannot mix two replies.
        return Flux.defer(() -> {
            StringBuilder fullResponse = new StringBuilder();
            return chunks
                    .doOnNext(chunk -> fullResponse.append(chunk))
                    .doOnComplete(() -> {
                        saveExchange(sessionId, userMessage, fullResponse.toString());
                        log.info("saved stream chat history to redis");
                    });
        });
    }

    /** Appends the user message and then the assistant reply to the session's history. */
    private void saveExchange(String sessionId, String userMessage, String reply) {
        chatHistoryService.addMessage(sessionId, "user", userMessage);
        chatHistoryService.addMessage(sessionId, "assistant", reply);
    }
}
6.配置类
1.aiconfig
package com.cxl.config;

import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.OllamaEmbeddingModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.ollama.management.ModelManagementOptions;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

import java.util.List;

/**
 * Declares the Ollama client, the two chat models and the embedding model used by the demo.
 */
@Configuration
public class AiConfig {

    /**
     * Ollama HTTP client.
     *
     * @param baseUrl value of {@code spring.ai.ollama.base-url}. Previously hard-coded;
     *                it falls back to the same {@code http://localhost:11434} when unset.
     */
    @Bean
    public OllamaApi ollamaApi(@Value("${spring.ai.ollama.base-url:http://localhost:11434}") String baseUrl) {
        return new OllamaApi(baseUrl);
    }

    /** No-op observation registry: model calls are not observed. */
    @Bean
    public ObservationRegistry observationRegistry() {
        return ObservationRegistry.NOOP;
    }

    /** Default chat model ({@code qwen3.5:2b}); marked primary so unqualified injection picks it. */
    @Bean
    @Primary
    public ChatModel qwenChatModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
        return ollamaChatModel(ollamaApi, observationRegistry, "qwen3.5:2b");
    }

    /** Secondary chat model ({@code deepseek-r1:1.5b}), injected by bean name. */
    @Bean("deepseekchatmodel")
    public ChatModel deepseekChatModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
        return ollamaChatModel(ollamaApi, observationRegistry, "deepseek-r1:1.5b");
    }

    /** Embedding model used to vectorize documents ({@code qwen3-embedding:0.6b}). */
    @Bean
    public EmbeddingModel embeddingModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry) {
        return new OllamaEmbeddingModel(
                ollamaApi,
                OllamaOptions.builder().model("qwen3-embedding:0.6b").build(),
                observationRegistry,
                ModelManagementOptions.builder().build());
    }

    /** Builds an Ollama chat model for {@code model} with the same arguments both chat beans used. */
    private static ChatModel ollamaChatModel(OllamaApi ollamaApi, ObservationRegistry observationRegistry,
                                             String model) {
        return new OllamaChatModel(
                ollamaApi,
                OllamaOptions.builder().model(model).build(),
                null,      // passed as null, as in the original configuration
                List.of(), // no extra callbacks registered
                observationRegistry,
                ModelManagementOptions.builder().build());
    }
}
2.appconfig
package com.cxl.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Global application switches and settings, bound from the {@code app.*} properties.
 */
@Data
@Component
@ConfigurationProperties(prefix = "app")
public class AppConfig {

    /**
     * Global switch for context-aware answering ({@code app.context-aware-enabled}).
     * {@code true} enables it; {@code false} disables it.
     */
    private boolean contextAwareEnabled = true;

    /** Default number of documents returned by file search ({@code app.default-top-k}). */
    private int defaultTopK = 3;

    /** Session timeout in minutes ({@code app.session-timeout-minutes}). */
    private int sessionTimeoutMinutes = 60;

    /** Maximum size, in bytes, of a file considered during scanning (10 MB). */
    private long maxFileSize = 10L * 1024 * 1024;

    /** Supported file types (extensions without the dot). */
    private String[] supportedFileTypes = {
            "txt", "md", "json", "xml", "html", "css", "js", "java",
            "pdf", "pptx", "ppt",
            "jpg", "jpeg", "png", "gif", "webp",
            "mp4", "avi", "mov", "wmv"
    };
}
3.对话客户端配置
package com.cxl.config;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;

/**
 * Exposes one {@link ChatClient} per configured chat model.
 */
@Configuration
public class ChatClientConfig {

    /**
     * Default client. The unqualified {@link ChatModel} parameter resolves to the
     * {@code @Primary} chat model bean.
     */
    @Bean
    @Primary
    public ChatClient qwenChatClient(ChatModel chatModel) {
        return ChatClient.builder(chatModel).build();
    }

    /** Client for the chat model bean named {@code deepseekchatmodel}. */
    @Bean("deepseekchatclient")
    public ChatClient deepseekChatClient(@Qualifier("deepseekchatmodel") ChatModel chatModel) {
        return ChatClient.builder(chatModel).build();
    }
}
4.知识库配置(向量数据库)
package com.cxl.config;

import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

/**
 * Knowledge-base settings, bound from {@code knowledge-base.*}.
 *
 * <p>The field initializers are fallbacks only. Values present in application.yaml
 * (for example {@code knowledge-base.default-path}) override them, so change the
 * scan directory there.
 */
@Configuration
@ConfigurationProperties(prefix = "knowledge-base")
public class KnowledgeBaseConfig {

    /** Directory scanned when no directory is passed explicitly. */
    private String defaultPath = "c:\\users\\gt-jyw-3\\documents\\doct";

    /** Qdrant collection name. */
    private String collectionName = "my_docs";

    /** Qdrant host. */
    private String host = "localhost";

    /** Qdrant port. */
    private int port = 6334;

    /** Number of related documents returned per retrieval. */
    private int topK = 5;

    public String getDefaultPath() {
        return defaultPath;
    }

    public void setDefaultPath(String defaultPath) {
        this.defaultPath = defaultPath;
    }

    public String getCollectionName() {
        return collectionName;
    }

    public void setCollectionName(String collectionName) {
        this.collectionName = collectionName;
    }

    public String getHost() {
        return host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    public int getPort() {
        return port;
    }

    public void setPort(int port) {
        this.port = port;
    }

    public int getTopK() {
        return topK;
    }

    public void setTopK(int topK) {
        this.topK = topK;
    }
}
5.系统配置
package com.cxl.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

/**
 * System prompts per model, bound from {@code system.settings.*}.
 */
@Data
@Component
@ConfigurationProperties(prefix = "system.settings")
public class SystemSettings {

    /**
     * Characters that Spring Boot drops from map keys not wrapped in {@code [...]}
     * (everything except letters, digits, '-' and '.').
     */
    private static final Pattern UNBRACKETED_KEY_DROPPED_CHARS = Pattern.compile("[^A-Za-z0-9.\\-]");

    /** System prompt per model: key is the model name, value the prompt text. */
    private Map<String, String> prompts = new HashMap<>();

    /** Prompt used when a model has no dedicated entry. */
    private String defaultPrompt = "你是一个有帮助的ai助手。";

    /**
     * Returns the system prompt configured for {@code modelName}, or the default prompt.
     *
     * <p>If there is no exact match, the key is also tried in the form Spring binds it to
     * when written without brackets. For example, {@code qwen3.5:2b} ends up as
     * {@code qwen3.52b} unless the YAML key is {@code "[qwen3.5:2b]"}.
     *
     * @param modelName model name such as {@code qwen3.5:2b}; may be {@code null}
     * @return the matching prompt, never {@code null} unless the default prompt is null
     */
    public String getSystemPrompt(String modelName) {
        String prompt = prompts.get(modelName);
        if (prompt == null && modelName != null) {
            prompt = prompts.get(UNBRACKETED_KEY_DROPPED_CHARS.matcher(modelName).replaceAll(""));
        }
        return prompt != null ? prompt : defaultPrompt;
    }
}
7.接口
1.基本对话&流式返回
package com.cxl.controller;

import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

/**
 * Plain chat endpoints: {@code /chat} (blocking) and {@code /chats} (server-sent events).
 *
 * <p>Parameter names are deliberately lower-case. They are the HTTP query parameter
 * names, and the chat-history aspect matches {@code sessionid} and {@code msg} by name.
 */
@RestController
@RequestMapping
@RequiredArgsConstructor
public class ChatController {

    private final ChatService chatService;

    /**
     * Blocking chat.
     *
     * @param msg       the user's question
     * @param sessionid optional session id; when present, the history is used and saved
     * @param model     model selector, {@code qwen} by default
     * @return the model's complete reply
     */
    @EnableChatHistory
    @GetMapping("/chat")
    public String chat(
            @RequestParam String msg,
            @RequestParam(required = false) String sessionid,
            @RequestParam(defaultValue = "qwen") String model) {
        return chatService.chat(msg, model, sessionid);
    }

    /**
     * Streaming chat as server-sent events. Takes the same parameters as the blocking
     * {@code /chat} endpoint.
     *
     * @return the reply as a stream of text chunks
     */
    @EnableChatHistory
    @GetMapping(value = "/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> chatStream(
            @RequestParam String msg,
            @RequestParam(required = false) String sessionid,
            @RequestParam(defaultValue = "qwen") String model) {
        return chatService.chatStreamRaw(msg, model, sessionid);
    }
}
2.基于文件目录对话
package com.cxl.controller;

import com.cxl.annotation.EnableChatHistory;
import com.cxl.service.ChatHistoryService;
import com.cxl.service.ChatService;
import com.cxl.service.FileSearchService;
import com.cxl.service.FileSearchService.DocumentInfo;
import com.cxl.service.SessionManagerService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Knowledge-base endpoints. They index a directory, chat over the indexed files,
 * clear the index, and toggle context-aware answering per session.
 *
 * <p>Java parameter names double as HTTP query parameter names, and the chat-history
 * aspect matches {@code sessionid} and {@code msg} by name, so they stay lower-case.
 *
 * <p>NOTE(review): SessionManagerService and FileSearchService#search, #clearAll and
 * DocumentInfo are not shown in full. Their member names here were restored from a
 * lower-cased listing; confirm them against their definitions.
 */
@RestController
@RequestMapping("/file")
@RequiredArgsConstructor
public class FileSearchController {

    /** Maximum number of characters of each matched document included in the prompt. */
    private static final int PREVIEW_CHARS = 500;

    private final FileSearchService fileSearchService;
    private final ChatService chatService;
    private final SessionManagerService sessionManagerService;
    private final ChatHistoryService chatHistoryService;

    /**
     * Indexes the files under {@code directory}, or the configured default directory
     * when it is absent.
     *
     * <p>NOTE(review): {@code directory} comes straight from the query string, so any
     * caller can get any readable server directory indexed. Restrict it before exposing
     * this endpoint beyond localhost.
     */
    @GetMapping("/scan")
    public Map<String, Object> scanDirectory(
            @RequestParam(required = false) String directory,
            @RequestParam(required = false) String sessionid) {
        if (sessionid != null) {
            sessionManagerService.getOrCreateSession(sessionid).updateLastActivity();
        }
        int fileCount = fileSearchService.scanDirectory(directory);
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("message", "文件夹扫描完成");
        result.put("filecount", fileCount);
        return result;
    }

    /**
     * Blocking chat over the indexed files.
     *
     * @param contextaware overrides, and stores on the session, whether retrieved file
     *                     content is added to the prompt; {@code null} keeps the session's setting
     * @param topk         number of documents to retrieve
     * @param filetype     optional {@code filetype} metadata filter
     */
    @EnableChatHistory
    @GetMapping("/chat")
    public String fileChat(
            @RequestParam String msg,
            @RequestParam String sessionid,
            @RequestParam(defaultValue = "qwen") String model,
            @RequestParam(required = false) Boolean contextaware,
            @RequestParam(defaultValue = "3") int topk,
            @RequestParam(required = false) String filetype) {
        String prompt = buildPrompt(msg, sessionid, contextaware, topk, filetype);
        return chatService.chat(prompt, model, sessionid);
    }

    /** Streaming variant of {@code /file/chat} (server-sent events); same parameters. */
    @EnableChatHistory
    @GetMapping(value = "/chats", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> fileChatStream(
            @RequestParam String msg,
            @RequestParam String sessionid,
            @RequestParam(defaultValue = "qwen") String model,
            @RequestParam(required = false) Boolean contextaware,
            @RequestParam(defaultValue = "3") int topk,
            @RequestParam(required = false) String filetype) {
        String prompt = buildPrompt(msg, sessionid, contextaware, topk, filetype);
        return chatService.chatStreamRaw(prompt, model, sessionid);
    }

    /** Clears the vector store and, when {@code sessionid} is given, that session's history. */
    @GetMapping("/clear")
    public Map<String, Object> clearAll(@RequestParam(required = false) String sessionid) {
        fileSearchService.clearAll();
        if (sessionid != null) {
            chatHistoryService.clearHistory(sessionid);
        }
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("message", sessionid != null
                ? "已清除向量数据库和会话历史记录"
                : "已清除向量数据库中的所有文档");
        return result;
    }

    /** Stores the context-aware flag on the session. */
    @GetMapping("/context")
    public Map<String, Object> setContextAware(
            @RequestParam String sessionid,
            @RequestParam boolean contextaware) {
        sessionManagerService.getOrCreateSession(sessionid).setContextAware(contextaware);
        Map<String, Object> result = new HashMap<>();
        result.put("status", "success");
        result.put("contextaware", contextaware);
        result.put("sessionid", sessionid);
        return result;
    }

    /**
     * Touches the session, resolves the context-aware flag, and prefixes the question with
     * previews of the retrieved documents when context-awareness is on.
     */
    private String buildPrompt(String msg, String sessionId, Boolean contextAware, int topK, String fileType) {
        SessionManagerService.SessionInfo session = sessionManagerService.getOrCreateSession(sessionId);
        session.updateLastActivity();
        boolean useContextAware = session.isContextAware();
        if (contextAware != null) {
            useContextAware = contextAware;
            session.setContextAware(useContextAware);
        }

        StringBuilder context = new StringBuilder();
        // The retrieved documents are only used in context-aware mode, so skip the search otherwise.
        if (useContextAware) {
            Map<String, Object> filters = new HashMap<>();
            if (fileType != null && !fileType.isEmpty()) {
                filters.put("filetype", fileType);
            }
            List<DocumentInfo> documents = fileSearchService.search(msg, topK, filters);
            if (!documents.isEmpty()) {
                context.append("根据以下文件内容回答问题:\n\n");
                for (DocumentInfo doc : documents) {
                    String content = doc.content == null ? "" : doc.content;
                    context.append("文件: ").append(doc.filename)
                            .append("\n路径: ").append(doc.path)
                            .append("\n内容: ").append(content, 0, Math.min(PREVIEW_CHARS, content.length()))
                            .append("...\n\n");
                }
            }
        }
        return context + "用户问题: " + msg;
    }
}
8.service
1.chathistoryservice
package com.cxl.service;

import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

/**
 * Stores per-session chat messages in a Redis list keyed {@code keyPrefix + sessionId}.
 *
 * <p>Each entry is {@code timestamp|role|content}. The list is trimmed to the newest
 * {@code maxSize} entries, and its expiry is reset to {@code expireDays} days on every
 * write. Settings are bound from {@code chat-history.*}.
 */
@Slf4j
@Service
@ConfigurationProperties(prefix = "chat-history")
public class ChatHistoryService {

    private String keyPrefix = "chat:history:";
    private int maxSize = 20;
    private int expireDays = 7;
    private boolean enabled = true;

    private final StringRedisTemplate redisTemplate;

    public ChatHistoryService(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Appends one message to the session's history.
     *
     * <p>Does nothing when history is disabled or the session id is blank. Redis
     * failures are logged, not thrown.
     *
     * @param sessionId session the message belongs to
     * @param role      speaker, e.g. {@code user} or {@code assistant}
     * @param content   message text; {@code null} is stored as an empty string rather than "null"
     */
    public void addMessage(String sessionId, String role, String content) {
        if (!enabled) {
            log.debug("chat history is disabled, skipping save");
            return;
        }
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionid is null or empty, skipping save");
            return;
        }
        String key = keyPrefix + sessionId;
        // getHistory splits with limit 3, so a '|' inside content survives (assumes role has none).
        String message = String.format("%s|%s|%s",
                LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME),
                role,
                Objects.toString(content, ""));
        try {
            redisTemplate.opsForList().rightPush(key, message);
            // Negative indexes count from the tail: keep only the newest maxSize entries.
            redisTemplate.opsForList().trim(key, -maxSize, -1);
            redisTemplate.expire(key, expireDays, TimeUnit.DAYS);
            log.info("saved message to redis - sessionid: {}, role: {}, key: {}", sessionId, role, key);
        } catch (Exception e) {
            log.error("failed to save message to redis: {}", e.getMessage(), e);
        }
    }

    /**
     * Returns the session's messages, oldest first.
     *
     * <p>Returns an empty list when history is disabled, the session id is blank, nothing
     * is stored, or Redis fails. Malformed entries are skipped.
     */
    public List<ChatMessage> getHistory(String sessionId) {
        if (!enabled) {
            log.debug("chat history is disabled, returning empty history");
            return new ArrayList<>();
        }
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionid is null or empty, returning empty history");
            return new ArrayList<>();
        }
        String key = keyPrefix + sessionId;
        log.info("getting history from redis - sessionid: {}, key: {}", sessionId, key);
        try {
            List<String> rawMessages = redisTemplate.opsForList().range(key, 0, -1);
            if (rawMessages == null || rawMessages.isEmpty()) {
                log.info("no history found in redis for sessionid: {}", sessionId);
                return new ArrayList<>();
            }
            log.info("found {} messages in redis for sessionid: {}", rawMessages.size(), sessionId);
            List<ChatMessage> messages = new ArrayList<>();
            for (String raw : rawMessages) {
                String[] parts = raw.split("\\|", 3);
                if (parts.length >= 3) {
                    ChatMessage msg = new ChatMessage();
                    msg.setTimestamp(parts[0]);
                    msg.setRole(parts[1]);
                    msg.setContent(parts[2]);
                    messages.add(msg);
                }
            }
            return messages;
        } catch (Exception e) {
            log.error("failed to get history from redis: {}", e.getMessage(), e);
            return new ArrayList<>();
        }
    }

    /** Returns the most recent {@code user} message of the session, or {@code null} if there is none. */
    public String getLastUserMessage(String sessionId) {
        List<ChatMessage> history = getHistory(sessionId);
        for (int i = history.size() - 1; i >= 0; i--) {
            if ("user".equals(history.get(i).getRole())) {
                return history.get(i).getContent();
            }
        }
        return null;
    }

    /** Deletes the session's history; a blank session id is ignored instead of deleting "...null". */
    public void clearHistory(String sessionId) {
        if (sessionId == null || sessionId.isEmpty()) {
            log.warn("sessionid is null or empty, skipping clear");
            return;
        }
        redisTemplate.delete(keyPrefix + sessionId);
    }

    public String getKeyPrefix() { return keyPrefix; }
    public void setKeyPrefix(String keyPrefix) { this.keyPrefix = keyPrefix; }
    public int getMaxSize() { return maxSize; }
    public void setMaxSize(int maxSize) { this.maxSize = maxSize; }
    public int getExpireDays() { return expireDays; }
    public void setExpireDays(int expireDays) { this.expireDays = expireDays; }
    public boolean isEnabled() { return enabled; }
    public void setEnabled(boolean enabled) { this.enabled = enabled; }

    /** One stored history entry. */
    public static class ChatMessage {
        private String timestamp;
        private String role;
        private String content;

        public String getTimestamp() { return timestamp; }
        public void setTimestamp(String timestamp) { this.timestamp = timestamp; }
        public String getRole() { return role; }
        public void setRole(String role) { this.role = role; }
        public String getContent() { return content; }
        public void setContent(String content) { this.content = content; }
    }
}
2.chatservice
package com.cxl.service;

import com.cxl.config.SystemSettings;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;

import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;

/**
 * Sends a user message to the selected Ollama chat model.
 *
 * <p>When a session id is given, the stored history is appended to the model's system
 * prompt.
 */
@Slf4j
@Service
public class ChatService {

    /**
     * Timestamp format for log lines.
     *
     * <p>Upper-case {@code MM} (month) and {@code HH} (24-hour clock) are required. The
     * lower-cased {@code yyyy-mm-dd hh:mm:ss} printed minutes in place of the month and
     * used a 12-hour clock.
     */
    private static final DateTimeFormatter FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    private final ChatClient qwenChatClient;
    private final ChatClient deepseekChatClient;
    private final SystemSettings systemSettings;
    private final ChatHistoryService chatHistoryService;

    public ChatService(
            ChatClient qwenChatClient,
            @Qualifier("deepseekchatclient") ChatClient deepseekChatClient,
            SystemSettings systemSettings,
            ChatHistoryService chatHistoryService) {
        this.qwenChatClient = qwenChatClient;
        this.deepseekChatClient = deepseekChatClient;
        this.systemSettings = systemSettings;
        this.chatHistoryService = chatHistoryService;
    }

    /** Blocking chat without session history. */
    public String chat(String msg, String model) {
        return chatWithHistory(msg, model, null);
    }

    /** Blocking chat; pass a {@code null} {@code sessionId} to skip the history. */
    public String chat(String msg, String model, String sessionId) {
        return chatWithHistory(msg, model, sessionId);
    }

    /** Streaming chat without session history. */
    public Flux<String> chatStreamRaw(String msg, String model) {
        return chatStreamWithHistory(msg, model, null);
    }

    /** Streaming chat; pass a {@code null} {@code sessionId} to skip the history. */
    public Flux<String> chatStreamRaw(String msg, String model, String sessionId) {
        return chatStreamWithHistory(msg, model, sessionId);
    }

    private String chatWithHistory(String msg, String model, String sessionId) {
        logRequest("========== 对话开始 ==========", msg, model, sessionId);
        long startMillis = System.currentTimeMillis();
        String content = buildRequest(msg, model, sessionId).call().content();
        log.info("响应耗时: {} ms", System.currentTimeMillis() - startMillis);
        log.info("ai回复: {}", content);
        log.info("========== 对话结束 ==========\n");
        return content;
    }

    private Flux<String> chatStreamWithHistory(String msg, String model, String sessionId) {
        logRequest("========== 流式对话开始 ==========", msg, model, sessionId);
        long startMillis = System.currentTimeMillis();
        return buildRequest(msg, model, sessionId)
                .stream()
                .content()
                .doOnComplete(() -> {
                    log.info("响应耗时: {} ms", System.currentTimeMillis() - startMillis);
                    log.info("========== 流式对话结束 ==========\n");
                })
                .doOnError(error -> log.error("流式对话发生错误: {}", error.getMessage(), error));
    }

    /** Logs the request banner, start time, model, message and (if present) session id. */
    private void logRequest(String banner, String msg, String model, String sessionId) {
        log.info(banner);
        log.info("开始时间: {}", LocalDateTime.now().format(FORMATTER));
        log.info("使用模型: {}", model);
        log.info("用户消息: {}", msg);
        if (sessionId != null) {
            log.info("会话id: {}", sessionId);
        }
    }

    /** Builds the request: the model's system prompt plus the formatted history, then the user message. */
    private ChatClient.ChatClientRequestSpec buildRequest(String msg, String model, String sessionId) {
        String systemPrompt = systemSettings.getSystemPrompt(getModelName(model));
        log.info("system预制词: {}", systemPrompt);
        return getChatClient(model).prompt()
                .system(systemPrompt + formatHistory(sessionId))
                .user(msg);
    }

    /** Renders the stored history as {@code role: content} lines, or returns "" when there is none. */
    private String formatHistory(String sessionId) {
        if (sessionId == null) {
            return "";
        }
        List<ChatHistoryService.ChatMessage> history = chatHistoryService.getHistory(sessionId);
        if (history.isEmpty()) {
            return "";
        }
        StringBuilder historyContext = new StringBuilder("\n以下是之前的对话历史:\n");
        for (ChatHistoryService.ChatMessage chatMessage : history) {
            historyContext.append(chatMessage.getRole()).append(": ").append(chatMessage.getContent()).append("\n");
        }
        return historyContext.toString();
    }

    /** {@code deepseek} (case-insensitive) selects the DeepSeek client; anything else uses Qwen. */
    private ChatClient getChatClient(String model) {
        return "deepseek".equalsIgnoreCase(model) ? deepseekChatClient : qwenChatClient;
    }

    /** Maps the model selector to the Ollama model name used as the system-prompt key. */
    private String getModelName(String model) {
        return "deepseek".equalsIgnoreCase(model) ? "deepseek-r1:1.5b" : "qwen3.5:2b";
    }
}
3.filesearchservice
package com.cxl.service;
import com.cxl.config.knowledgebaseconfig;
import lombok.extern.slf4j.slf4j;
import org.apache.poi.xslf.usermodel.xmlslideshow;
import org.apache.poi.xslf.usermodel.xslfslide;
import org.springframework.ai.document.document;
import org.springframework.ai.vectorstore.searchrequest;
import org.springframework.ai.vectorstore.vectorstore;
import org.springframework.ai.vectorstore.filter.filter.expression;
import org.springframework.ai.vectorstore.filter.filterexpressionbuilder;
import org.springframework.stereotype.service;
import java.io.file;
import java.io.fileinputstream;
import java.io.ioexception;
import java.nio.file.files;
import java.nio.file.path;
import java.nio.file.paths;
import java.util.*;
import java.util.stream.collectors;
@slf4j
@service
public class filesearchservice {
private final vectorstore vectorstore;
private final knowledgebaseconfig knowledgebaseconfig;
/**
 * Creates the service.
 *
 * @param vectorstore         store that scanned documents are added to
 * @param knowledgebaseconfig supplies the default directory to scan
 *
 * NOTE(review): identifiers in this listing were lower-cased by extraction
 * (e.g. string, list, vectorstore). Java is case-sensitive, so restore the
 * original casing before compiling.
 */
public filesearchservice(vectorstore vectorstore, knowledgebaseconfig knowledgebaseconfig) {
this.vectorstore = vectorstore;
this.knowledgebaseconfig = knowledgebaseconfig;
}
/**
 * Recursively scans a directory and adds every file with extractable content to the
 * vector store in a single add call.
 *
 * @param directorypath directory to scan; null or empty falls back to the configured default path
 * @return number of documents added; 0 if the path is not an existing directory or if
 *         scanning / the vector-store insert throws (the error is logged)
 *
 * NOTE(review): every document gets a random UUID, so re-scanning the same directory
 * adds duplicate entries instead of replacing the old ones.
 */
public int scandirectory(string directorypath) {
string actualpath = (directorypath == null || directorypath.isempty())
? knowledgebaseconfig.getdefaultpath()
: directorypath;
log.info("开始扫描文件夹: {}", actualpath);
list<document> documents = new arraylist<>();
file directory = new file(actualpath);
if (!directory.exists() || !directory.isdirectory()) {
log.error("文件夹不存在或不是目录: {}", actualpath);
return 0;
}
try {
scanfile(directory, documents);
if (!documents.isempty()) {
// One batch insert for the whole directory tree.
vectorstore.add(documents);
log.info("成功扫描 {} 个文件并添加到向量数据库", documents.size());
}
} catch (exception e) {
log.error("扫描文件夹失败: {}", e.getmessage(), e);
return 0;
}
return documents.size();
}
/**
 * Depth-first walk: recurses into directories. For each regular file whose extracted
 * content is non-blank, appends a document with metadata keys filename, path,
 * filetype, extension and size.
 *
 * @param file      file or directory to visit
 * @param documents accumulator the new documents are appended to
 */
private void scanfile(file file, list<document> documents) throws ioexception {
if (file.isdirectory()) {
// listfiles() returns null when the directory cannot be read; such directories are skipped.
file[] files = file.listfiles();
if (files != null) {
for (file f : files) {
scanfile(f, documents);
}
}
} else {
string content = extractcontent(file);
if (content != null && !content.trim().isempty()) {
map<string, object> metadata = new hashmap<>();
metadata.put("filename", file.getname());
metadata.put("path", file.getabsolutepath());
metadata.put("filetype", getfiletype(file.getname()));
metadata.put("extension", getfileextension(file.getname()));
// NOTE(review): the (int) cast overflows for files larger than 2 GB; a long would be safer.
metadata.put("size", (int) file.length());
document doc = document.builder()
.id(uuid.randomuuid().tostring())
.text(content)
.metadata(metadata)
.build();
documents.add(doc);
}
}
}
/**
 * Returns the indexable text for a single file.
 * <p>
 * Plain-text and source-code extensions are read whole with files.readstring. .ppt/.pptx are
 * delegated to extractpptcontent. Every other extension yields a short metadata summary (name,
 * path, size, type) instead of the file's content.
 *
 * @param file the file to read
 * @return the extracted text, or null when reading fails (scanfile then skips the file)
 */
private string extractcontent(file file) {
    try {
        string ext = getfileextension(file.getname()).tolowercase();
        path target = paths.get(file.getabsolutepath());
        return switch (ext) {
            case "txt", "md", "json", "xml", "html", "css", "js", "java",
                 "py", "go", "rs", "c", "cpp", "h", "sql" -> files.readstring(target);
            case "pptx", "ppt" -> extractpptcontent(file);
            default -> string.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: %s",
                    file.getname(), file.getabsolutepath(), file.length(), getfiletype(file.getname()));
        };
    } catch (exception e) {
        log.warn("提取文件内容失败: {} - {}", file.getabsolutepath(), e.getmessage());
        return null;
    }
}
/**
 * Extracts the text of a PowerPoint file, slide by slide.
 * <p>
 * Only .pptx (OOXML) is parsed. Any other extension gets a metadata-only summary noting that
 * only .pptx is supported. Text is collected from plain text shapes, from shapes nested inside
 * group shapes, and from every table cell. Blank texts are skipped.
 *
 * @param file the presentation to read
 * @return a header (name, path) followed by one "=== 第 N 页 ===" section per slide, or a
 *         metadata-only summary when the format is unsupported or parsing fails
 */
private string extractpptcontent(file file) {
    string extension = getfileextension(file.getname()).tolowercase();
    // Decide on the format first so no header is built for files we will not parse.
    if (!extension.equals("pptx")) {
        return string.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: ppt(仅支持.pptx格式)",
                file.getname(), file.getabsolutepath(), file.length());
    }
    stringbuilder content = new stringbuilder();
    content.append("文件名: ").append(file.getname()).append("\n");
    content.append("路径: ").append(file.getabsolutepath()).append("\n\n");
    try (fileinputstream fis = new fileinputstream(file);
         xmlslideshow ppt = new xmlslideshow(fis)) {
        list<xslfslide> slides = ppt.getslides();
        for (int i = 0; i < slides.size(); i++) {
            content.append("=== 第 ").append(i + 1).append(" 页 ===\n");
            for (org.apache.poi.xslf.usermodel.xslfshape shape : slides.get(i).getshapes()) {
                appendshapetext(shape, content);
            }
            content.append("\n");
        }
    } catch (exception e) {
        log.warn("提取ppt内容失败: {} - {}", file.getabsolutepath(), e.getmessage());
        return string.format("文件名: %s\n路径: %s\n文件大小: %d bytes\n文件类型: ppt",
                file.getname(), file.getabsolutepath(), file.length());
    }
    return content.tostring();
}

/**
 * Appends the non-blank text of {@code shape} to {@code out}, one line per text.
 * Recurses into group shapes and visits every cell of a table. Other shape kinds (pictures,
 * connectors, ...) contribute nothing.
 */
private void appendshapetext(org.apache.poi.xslf.usermodel.xslfshape shape, stringbuilder out) {
    if (shape instanceof org.apache.poi.xslf.usermodel.xslfgroupshape group) {
        for (org.apache.poi.xslf.usermodel.xslfshape child : group.getshapes()) {
            appendshapetext(child, out);
        }
    } else if (shape instanceof org.apache.poi.xslf.usermodel.xslftable table) {
        for (org.apache.poi.xslf.usermodel.xslftablerow row : table.getrows()) {
            for (org.apache.poi.xslf.usermodel.xslftablecell cell : row.getcells()) {
                appendnonblank(cell.gettext(), out);
            }
        }
    } else if (shape instanceof org.apache.poi.xslf.usermodel.xslftextshape textshape) {
        appendnonblank(textshape.gettext(), out);
    }
}

/** Appends {@code text} followed by a newline when it is non-null and not blank. */
private static void appendnonblank(string text, stringbuilder out) {
    if (text != null && !text.trim().isempty()) {
        out.append(text).append("\n");
    }
}
/**
 * Maps a file name to a coarse category by its (case-insensitive) extension.
 *
 * @param filename file name, with or without directory-free extension
 * @return one of "text", "document", "presentation", "image", "video", or "other"
 */
private string getfiletype(string filename) {
    return switch (getfileextension(filename).tolowercase()) {
        case "txt", "md", "json", "xml", "html", "css", "js", "java",
             "py", "go", "rs", "c", "cpp", "h", "sql" -> "text";
        case "pdf", "doc", "docx" -> "document";
        case "pptx", "ppt" -> "presentation";
        case "jpg", "jpeg", "png", "gif", "webp" -> "image";
        case "mp4", "avi", "mov", "wmv" -> "video";
        default -> "other";
    };
}
/**
 * Returns the text after the last '.' of {@code filename}, with its case preserved.
 * Returns "" when the name has no dot, or when the last dot is the first character
 * (e.g. ".gitignore").
 */
private string getfileextension(string filename) {
    int dot = filename.lastindexof('.');
    if (dot <= 0) {
        return "";
    }
    return filename.substring(dot + 1);
}
/**
 * Runs a similarity search against the vector store and maps the hits to documentinfo.
 *
 * @param query   query text to embed and search with
 * @param topk    maximum number of hits; values <= 0 fall back to the configured default
 * @param filters optional filters; only the "filetype" entry is honoured. A null or blank value
 *                means "no filter", and non-String values are compared by their string form.
 * @return the matching documents, or an empty list when the search fails (the error is logged)
 */
public list<documentinfo> search(string query, int topk, map<string, object> filters) {
    log.info("开始检索: {}", query);
    try {
        int actualtopk = (topk > 0) ? topk : knowledgebaseconfig.gettopk();
        searchrequest.builder builder = searchrequest.builder()
                .query(query)
                .topk(actualtopk);

        // A (string) cast here threw ClassCastException for non-String values, and a null value
        // produced an invalid filter. Both ended in an empty result.
        object rawfiletype = (filters != null) ? filters.get("filetype") : null;
        string filetype = (rawfiletype != null) ? string.valueof(rawfiletype) : "";
        // getfiletype never stores an empty category, so a blank filter could only match nothing.
        if (!filetype.isblank()) {
            expression filter = new filterexpressionbuilder()
                    .eq("filetype", filetype)
                    .build();
            builder.filterexpression(filter);
        }

        list<org.springframework.ai.document.document> results = vectorstore.similaritysearch(builder.build());
        return results.stream()
                .map(this::todocumentinfo)
                .collect(collectors.tolist());
    } catch (exception e) {
        log.error("检索失败: {}", e.getmessage(), e);
        return new arraylist<>();
    }
}
/**
 * Converts one search hit into a documentinfo.
 * <p>
 * Metadata values are converted with objects.tostring, so a missing key or null value becomes ""
 * and a non-String value becomes its string form. A direct (string) cast would throw
 * ClassCastException, and search() would then return an empty list for the whole query.
 * "size" accepts any number type and defaults to 0.
 *
 * @param doc a document returned by the vector store
 * @return a populated documentinfo (never null)
 */
private documentinfo todocumentinfo(org.springframework.ai.document.document doc) {
    map<string, object> metadata = doc.getmetadata();
    documentinfo info = new documentinfo();
    info.id = doc.getid();
    info.content = doc.gettext();
    info.filename = objects.tostring(metadata.get("filename"), "");
    info.path = objects.tostring(metadata.get("path"), "");
    info.filetype = objects.tostring(metadata.get("filetype"), "");
    info.extension = objects.tostring(metadata.get("extension"), "");
    info.size = (metadata.get("size") instanceof number n) ? n.longvalue() : 0l;
    return info;
}
/**
 * Deletes every document that a similarity search for "*" can reach, working in batches of 1000.
 * <p>
 * Documents are discovered with similaritysearch and removed by id. The loop repeats until one of
 * these holds:
 * <ul>
 *   <li>a search returns nothing;</li>
 *   <li>a search returns less than a full batch;</li>
 *   <li>a search returns only the ids that were just deleted (no progress).</li>
 * </ul>
 * A single pass would stop after the first 1000 hits while still logging success.
 * NOTE(review): documents the search never returns (e.g. scored below the store's default
 * similarity threshold) are not removed — dropping/recreating the collection is the only
 * guaranteed wipe. Failures are logged, not rethrown.
 */
public void clearall() {
    final int batchsize = 1000;
    int deleted = 0;
    try {
        set<string> lastbatch = set.of();
        while (true) {
            searchrequest searchrequest = searchrequest.builder()
                    .query("*")
                    .topk(batchsize)
                    .build();
            list<string> ids = vectorstore.similaritysearch(searchrequest).stream()
                    .map(org.springframework.ai.document.document::getid)
                    .collect(collectors.tolist());
            // Stop when nothing is left, or when the store keeps returning what we just deleted.
            if (ids.isempty() || lastbatch.containsall(ids)) {
                break;
            }
            vectorstore.delete(ids);
            deleted += ids.size();
            if (ids.size() < batchsize) {
                break;
            }
            lastbatch = new hashset<>(ids);
        }
        log.info("已清除向量数据库中的所有文档,共 {} 条", deleted);
    } catch (exception e) {
        log.error("清除向量数据库失败: {}", e.getmessage(), e);
    }
}
/**
 * Search-result DTO built by todocumentinfo(): one instance per hit, with public mutable fields.
 */
public static class documentinfo {
// Vector-store id of the document.
public string id;
// "filename" metadata: the file name without its directory.
public string filename;
// "path" metadata: the absolute path recorded at scan time.
public string path;
// Indexed text of the document (file content or metadata summary).
public string content;
// "filetype" metadata: category from getfiletype() (text/document/presentation/image/video/other).
public string filetype;
// "extension" metadata, with the case used in the original file name.
public string extension;
// "size" metadata in bytes; 0 when absent.
public long size;
}
}
4.SessionManagerService
package com.cxl.service;
import lombok.data;
import org.springframework.stereotype.service;
import java.util.map;
import java.util.concurrent.concurrenthashmap;
import java.util.concurrent.executors;
import java.util.concurrent.scheduledexecutorservice;
import java.util.concurrent.timeunit;
@service
public class sessionmanagerservice {
/** Active sessions keyed by session id; concurrent because the cleanup task mutates it too. */
private final map<string, sessioninfo> sessions = new concurrenthashmap<>();
/**
 * Single thread that periodically evicts idle sessions. It is a daemon, so it can never keep
 * the JVM alive on its own, and it is shut down explicitly when the bean is destroyed.
 * With the default factory it was non-daemon and leaked on context close.
 */
private final scheduledexecutorservice executorservice = executors.newscheduledthreadpool(1, task -> {
    thread worker = new thread(task, "session-cleanup");
    worker.setdaemon(true);
    return worker;
});
/** Idle time after which a session is evicted; also used as the cleanup period. */
private int sessiontimeoutminutes = 60;
/** Schedules cleanupexpiredsessions every sessiontimeoutminutes, starting after one period. */
public sessionmanagerservice() {
    executorservice.scheduleatfixedrate(this::cleanupexpiredsessions,
            sessiontimeoutminutes,
            sessiontimeoutminutes,
            timeunit.minutes);
}
/** Stops the cleanup thread when the Spring context closes (pending periodic runs are cancelled). */
@jakarta.annotation.predestroy
public void shutdown() {
    executorservice.shutdown();
}
/**
 * Returns the session registered under {@code sessionid}, creating and registering a new one on
 * first use. Creation is atomic per id (computeifabsent). The last-activity time of an existing
 * session is not touched here.
 */
public sessioninfo getorcreatesession(string sessionid) {
    return sessions.computeifabsent(sessionid, sessioninfo::new);
}
/** Forgets the session with the given id; does nothing if no such session is registered. */
public void removesession(string sessionid) {
    sessions.remove(sessionid);
}
/**
 * Removes every session whose last activity is older than sessiontimeoutminutes.
 * Runs on the scheduled cleanup thread.
 */
private void cleanupexpiredsessions() {
    long now = system.currenttimemillis();
    // tomillis computes in long arithmetic. The int expression minutes * 60 * 1000 overflows
    // once the timeout exceeds ~35,791 minutes.
    long timeoutms = timeunit.minutes.tomillis(sessiontimeoutminutes);
    sessions.values().removeif(session -> now - session.getlastactivitytime() > timeoutms);
}
/**
 * Mutable per-session state; Lombok's @data generates the getters/setters.
 */
@data
public static class sessioninfo {
    private final string sessionid;
    // volatile: read by the scheduled cleanup thread (cleanupexpiredsessions) while being
    // updated through updatelastactivity(), presumably from request threads.
    private volatile long lastactivitytime;
    private boolean contextaware = true;

    /** Creates a session whose last activity is the current time. */
    public sessioninfo(string sessionid) {
        this.sessionid = sessionid;
        this.lastactivitytime = system.currenttimemillis();
    }

    /** Marks the session as active now, which postpones its eviction by the cleanup task. */
    public void updatelastactivity() {
        this.lastactivitytime = system.currenttimemillis();
    }
}
}
9.向量数据库初始化
package com.cxl;
import java.net.uri;
import java.net.http.httpclient;
import java.net.http.httprequest;
import java.net.http.httpresponse;
public class qdrantcollectioncreator {
/** Host of the Qdrant HTTP API. */
private static final string host = "localhost";
/** Port of the Qdrant HTTP API. */
private static final int port = 6333;
/** Name of the collection this tool checks for and creates. */
private static final string collection_name = "my_docs";
/**
 * Vector size of the collection. Presumably this must equal the embedding model's output
 * dimension (qwen3-embedding:0.6b in the setup notes) — verify before changing models.
 */
private static final int vector_dimension = 1024;
/**
 * Ensures the Qdrant collection {@code collection_name} exists, creating it when missing.
 * <ol>
 *   <li>Checks that Qdrant answers on host:port; aborts if it does not.</li>
 *   <li>Lists the existing collections (informational only).</li>
 *   <li>GETs the collection and stops if it already exists (HTTP 200).</li>
 *   <li>Otherwise creates it with PUT /collections/{name}, using vector_dimension and cosine
 *       distance.</li>
 * </ol>
 * Creation uses only PUT /collections/{name}, the creation endpoint in Qdrant's REST API.
 * There is no POST fallback.
 */
public static void main(string[] args) {
    system.out.println("=== qdrant collection creator ===");
    httpclient client = httpclient.newhttpclient();
    string baseurl = "http://" + host + ":" + port;

    system.out.println("\n1. checking qdrant service availability...");
    try {
        httpresponse<string> healthresponse = get(client, baseurl + "/");
        system.out.println("qdrant service status: " + healthresponse.statuscode());
        system.out.println("response: " + healthresponse.body());
    } catch (exception e) {
        system.err.println("error: cannot connect to qdrant at " + baseurl);
        system.err.println("cause: " + e);
        return;
    }

    system.out.println("\n2. listing available collections...");
    try {
        httpresponse<string> listresponse = get(client, baseurl + "/collections");
        system.out.println("status: " + listresponse.statuscode());
        system.out.println("response: " + listresponse.body());
    } catch (exception e) {
        system.out.println("error: " + e.getmessage());
    }

    system.out.println("\n3. checking if collection exists...");
    string collectionurl = baseurl + "/collections/" + collection_name;
    try {
        httpresponse<string> checkresponse = get(client, collectionurl);
        system.out.println("status: " + checkresponse.statuscode());
        if (checkresponse.statuscode() == 200) {
            system.out.println("collection '" + collection_name + "' already exists!");
            return;
        }
        // Any non-200 answer (normally 404 for an unknown collection) means we should create it.
        system.out.println("collection does not exist. will create it...");
    } catch (exception e) {
        system.out.println("error checking collection: " + e.getmessage() + ". will try to create it...");
    }

    system.out.println("\n4. creating collection with put...");
    // Qdrant documents the distance values as Cosine / Euclid / Dot / Manhattan (capitalized);
    // the lowercase form may be rejected as an unknown variant.
    string requestbody = string.format("""
            {
              "vectors": {
                "size": %d,
                "distance": "Cosine"
              }
            }
            """, vector_dimension);
    try {
        httprequest createrequest = httprequest.newbuilder()
                .uri(uri.create(collectionurl))
                .header("content-type", "application/json")
                .put(httprequest.bodypublishers.ofstring(requestbody))
                .build();
        httpresponse<string> createresponse = client.send(createrequest, httpresponse.bodyhandlers.ofstring());
        system.out.println("status: " + createresponse.statuscode());
        system.out.println("response: " + createresponse.body());
        if (createresponse.statuscode() == 200 || createresponse.statuscode() == 201) {
            system.out.println("\n✓ collection '" + collection_name + "' created successfully!");
        } else {
            system.err.println("\n✗ failed to create collection");
        }
    } catch (exception e) {
        system.err.println("error with put: " + e.getmessage());
    }
}

/** Sends a GET to {@code url} and returns the response with its body read as a string. */
private static httpresponse<string> get(httpclient client, string url)
        throws java.io.ioexception, interruptedexception {
    httprequest request = httprequest.newbuilder()
            .uri(uri.create(url))
            .get()
            .build();
    return client.send(request, httpresponse.bodyhandlers.ofstring());
}
}
10.启动类
package com.cxl;
import com.cxl.config.appconfig;
import com.cxl.config.systemsettings;
import org.springframework.boot.springapplication;
import org.springframework.boot.autoconfigure.springbootapplication;
import org.springframework.boot.context.properties.enableconfigurationproperties;
@springbootapplication
@enableconfigurationproperties({systemsettings.class, appconfig.class})
public class springaidemoapplication {
/**
 * Boots the Spring application context with this class as the primary source.
 *
 * @param args command-line arguments forwarded to Spring Boot
 */
public static void main(string[] args) {
    springapplication application = new springapplication(springaidemoapplication.class);
    application.run(args);
}
}
到此这篇关于基于Spring AI + Qdrant + Ollama本地模型和向量数据库开发问答和RAG检索(完整代码)的文章就介绍到这了。更多相关Spring AI向量数据库开发问答和RAG检索内容请搜索代码网以前的文章或继续浏览下面的相关文章,希望大家以后多多支持代码网!
发表评论