本教程详细介绍如何使用 Spring Boot 3 整合 Spring AI 实现一个具有记忆功能的 AI 助手。该实现使用 Redis 作为存储介质,支持用户级别的会话隔离和 30 天的对话历史持久化。
确保已安装 JDK 17 或更高版本(Spring Boot 3 的最低要求)以及 Maven,并已正确配置 JAVA_HOME 和 MAVEN_HOME。
使用 Spring Initializr 创建项目:
在 pom.xml 文件中添加以下依赖:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
<dependencies>
    <!-- Spring Boot web stack (Spring MVC + embedded server) -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>

    <!-- Spring Data Redis -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-data-redis</artifactId>
    </dependency>

    <!-- MyBatis-Plus: Spring Boot 3 needs the dedicated boot3 starter -->
    <dependency>
        <groupId>com.baomidou</groupId>
        <artifactId>mybatis-plus-spring-boot3-starter</artifactId>
        <version>3.5.5</version>
    </dependency>

    <!-- MySQL driver -->
    <dependency>
        <groupId>com.mysql</groupId>
        <artifactId>mysql-connector-j</artifactId>
        <scope>runtime</scope>
    </dependency>

    <!--
        Spring AI: use the Boot starter so that the spring.ai.openai.* properties
        are bound and an OpenAiChatModel bean is auto-configured. The plain
        spring-ai-openai artifact ships the client only, without auto-configuration.
    -->
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-openai</artifactId>
        <version>1.0.0</version>
    </dependency>

    <!-- Sa-Token authentication: Spring Boot 3 needs the boot3 starter -->
    <dependency>
        <groupId>cn.dev33</groupId>
        <artifactId>sa-token-spring-boot3-starter</artifactId>
        <version>1.38.1</version>
    </dependency>

    <!-- Jackson serialization (version managed by Spring Boot) -->
    <dependency>
        <groupId>com.fasterxml.jackson.core</groupId>
        <artifactId>jackson-databind</artifactId>
    </dependency>

    <!-- Lombok -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>

    <!-- Spring Boot test support -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
创建 src/main/resources/application.yml 文件,配置应用信息:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
# HTTP server: every endpoint is served under http://<host>:9527/api
server:
  port: 9527
  servlet:
    context-path: /api
spring:
  application:
    name: smart-pic-community-backend
  # Redis configuration
  data:
    redis:
      database: 2
      host: localhost
      port: 6379
      timeout: 5000
  # Database configuration
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://localhost:3306/smart_pic_community
    username: root
    password: your_password
  # Spring AI configuration
  # NOTE(review): password / api-key above and below are placeholders; supply real
  # values from the environment or a secret store rather than committing them.
  ai:
    openai:
      base-url: https://api.deepseek.com/ # use the DeepSeek API
      api-key: your_api_key
      chat:
        options:
          model: deepseek-chat
创建 Redis 配置类,确保正确序列化对象:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
package com.spc.smartpiccommunitybackend.config; import com.fasterxml.jackson.annotation.JsonAutoDetect; import com.fasterxml.jackson.annotation.PropertyAccessor; import com.fasterxml.jackson.databind.ObjectMapper; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.data.redis.connection.RedisConnectionFactory; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; import org.springframework.data.redis.serializer.StringRedisSerializer; @Configuration @Slf4j public class RedisConfiguration { @Bean public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) { log.info("开始创建redis模板对象..."); RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>(); // 设置连接工厂 redisTemplate.setConnectionFactory(redisConnectionFactory); // 使用 StringRedisSerializer 来序列化和反序列化 redis 的 key StringRedisSerializer stringRedisSerializer = new StringRedisSerializer(); // key 采用 String 的序列化方式 redisTemplate.setKeySerializer(stringRedisSerializer); redisTemplate.setHashKeySerializer(stringRedisSerializer); // 使用 Jackson2JsonRedisSerializer 来序列化和反序列化 redis 的 value Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class); ObjectMapper objectMapper = new ObjectMapper(); objectMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY); objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL); jackson2JsonRedisSerializer.setObjectMapper(objectMapper); // value 采用 JSON 的序列化方式 redisTemplate.setValueSerializer(jackson2JsonRedisSerializer); redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer); redisTemplate.afterPropertiesSet(); return redisTemplate; } } |
注意事项:
创建消息视图对象,用于前端展示:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
package com.spc.smartpiccommunitybackend.model.vo.ai;
import lombok.Data; import lombok.NoArgsConstructor; import org.springframework.ai.chat.messages.Message;
@NoArgsConstructor @Data public class MessageVO { private String role; private String content;
public MessageVO(Message message) { this.role = switch (message.getMessageType()) { case USER -> "user"; case ASSISTANT -> "assistant"; case SYSTEM -> "system"; default -> ""; }; this.content = message.getText(); } } |
功能说明:
创建可序列化的消息对象,用于 Redis 存储:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
package com.spc.smartpiccommunitybackend.model.entity.ai;
import lombok.Data; import lombok.NoArgsConstructor; import java.io.Serializable;
@Data @NoArgsConstructor public class SerializableMessage implements Serializable { private static final long serialVersionUID = 1L;
private String role; private String content; private String messageType; private Long timestamp;
public SerializableMessage(String role, String content, String messageType) { this.role = role; this.content = content; this.messageType = messageType; this.timestamp = System.currentTimeMillis(); }
public SerializableMessage(String role, String content) { this(role, content, "user"); } } |
功能说明:
创建聊天历史仓库接口:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
package com.spc.smartpiccommunitybackend.repository;
import java.util.List; import java.util.Map;
/**
 * Storage contract for chat sessions and their messages.
 */
public interface ChatHistoryRepository {

    /**
     * Records a chat session for a user.
     *
     * @param type   session category
     * @param chatId session identifier
     * @param userId owner of the session
     */
    void save(String type, String chatId, Long userId);

    /**
     * Lists the session identifiers recorded for a user.
     *
     * @param userId owner of the sessions
     * @param type   session category
     * @return the session identifiers
     */
    List<String> getChatIds(Long userId, String type);

    /**
     * Stores one chat message in a session.
     *
     * @param chatId  session identifier
     * @param message message text
     * @param sender  label of the message author
     */
    void saveMessage(String chatId, String message, String sender);

    /**
     * Returns the stored messages of a session.
     *
     * @param chatId session identifier
     * @return the stored messages
     */
    List<String> getMessages(String chatId);

    /**
     * Deletes a session.
     *
     * @param userId owner of the session
     * @param type   session category
     * @param chatId session identifier
     */
    void deleteChat(Long userId, String type, String chatId);

    /**
     * Returns the metadata stored for a session.
     *
     * @param chatId session identifier
     * @return the session metadata
     */
    Map<Object, Object> getSessionInfo(String chatId);
}
实现基于 Redis 的聊天历史仓库:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
package com.spc.smartpiccommunitybackend.repository;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.core.JsonProcessingException; import lombok.RequiredArgsConstructor; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.stereotype.Component;
import java.util.*; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors;
@Component @RequiredArgsConstructor public class RedisChatHistoryRepository implements ChatHistoryRepository {
private final RedisTemplate<String, Object> redisTemplate;
// Redis key前缀 private static final String CHAT_HISTORY_PREFIX = "chat:history:"; private static final String CHAT_SESSION_PREFIX = "chat:session:"; private static final String CHAT_MESSAGES_PREFIX = "chat:messages:";
/** * 保存会话记录 */ @Override public void save(String type, String chatId, Long userId) { // 保存会话信息 String sessionKey = CHAT_SESSION_PREFIX + chatId; Map<String, Object> sessionInfo = new HashMap<>(); sessionInfo.put("userId", String.valueOf(userId)); sessionInfo.put("type", type); sessionInfo.put("createTime", System.currentTimeMillis()); sessionInfo.put("lastUpdateTime", System.currentTimeMillis()); redisTemplate.opsForHash().putAll(sessionKey, sessionInfo); // 设置过期时间为30天 redisTemplate.expire(sessionKey, 30, TimeUnit.DAYS);
// 将chatId添加到用户的聊天历史列表中 String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type; redisTemplate.opsForSet().add(historyKey, chatId); // 设置过期时间为30天 redisTemplate.expire(historyKey, 30, TimeUnit.DAYS); }
/** * 获取用户的会话ID列表 */ @Override public List<String> getChatIds(Long userId, String type) { String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type; Set<Object> chatIds = redisTemplate.opsForSet().members(historyKey); if (chatIds == null || chatIds.isEmpty()) { return Collections.emptyList(); } return chatIds.stream() .map(Object::toString) .collect(Collectors.toList()); }
/** * 保存聊天消息 */ @Override public void saveMessage(String chatId, String message, String sender) { String messagesKey = CHAT_MESSAGES_PREFIX + chatId; // 创建消息对象 Map<String, Object> messageInfo = new HashMap<>(); messageInfo.put("content", message); messageInfo.put("sender", sender); messageInfo.put("timestamp", System.currentTimeMillis()); // 使用JSON格式保存消息 ObjectMapper objectMapper = new ObjectMapper(); try { String jsonMessage = objectMapper.writeValueAsString(messageInfo); redisTemplate.opsForList().rightPush(messagesKey, jsonMessage); } catch (JsonProcessingException e) { e.printStackTrace(); // 如果JSON序列化失败,使用原始消息 redisTemplate.opsForList().rightPush(messagesKey, message); } // 设置过期时间为30天 redisTemplate.expire(messagesKey, 30, TimeUnit.DAYS);
// 更新会话的最后更新时间 String sessionKey = CHAT_SESSION_PREFIX + chatId; redisTemplate.opsForHash().put(sessionKey, "lastUpdateTime", System.currentTimeMillis()); // 确保会话信息也有过期时间 redisTemplate.expire(sessionKey, 30, TimeUnit.DAYS); }
/** * 获取聊天消息历史 */ @Override public List<String> getMessages(String chatId) { String messagesKey = CHAT_MESSAGES_PREFIX + chatId; List<Object> messages = redisTemplate.opsForList().range(messagesKey, 0, -1); if (messages == null || messages.isEmpty()) { return Collections.emptyList(); } return messages.stream() .map(Object::toString) .collect(Collectors.toList()); }
/** * 删除会话 */ @Override public void deleteChat(Long userId, String type, String chatId) { // 从用户的聊天历史列表中删除 String historyKey = CHAT_HISTORY_PREFIX + userId + ":" + type; redisTemplate.opsForSet().remove(historyKey, chatId);
// 删除会话信息 String sessionKey = CHAT_SESSION_PREFIX + chatId; redisTemplate.delete(sessionKey);
// 删除聊天消息 String messagesKey = CHAT_MESSAGES_PREFIX + chatId; redisTemplate.delete(messagesKey); }
/** * 获取会话信息 */ @Override public Map<Object, Object> getSessionInfo(String chatId) { String sessionKey = CHAT_SESSION_PREFIX + chatId; return redisTemplate.opsForHash().entries(sessionKey); } } |
注意事项:
实现基于 Redis 的聊天记忆,支持 Spring AI 的 ChatMemory 接口:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
package com.spc.smartpiccommunitybackend.config;
import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.List; import java.util.concurrent.TimeUnit;
@Component public class RedisChatMemory implements ChatMemory {
private final RedisTemplate<String, Object> redisTemplate; private static final String MEMORY_KEY_PREFIX = "chat:memory:"; private static final long EXPIRATION_DAYS = 30;
public RedisChatMemory(RedisTemplate<String, Object> redisTemplate) { this.redisTemplate = redisTemplate; }
@Override public void add(String key, List<Message> messages) { // 为特定会话添加多条消息 for (Message message : messages) { addMessage(key, message); } }
@Override public List<Message> get(String key, int maxCount) { // 实现 get 方法,根据 key 获取消息 List<Message> messages = getMessages(key); // 如果指定了最大数量,返回不超过该数量的消息 if (maxCount > 0 && messages.size() > maxCount) { return messages.subList(messages.size() - maxCount, messages.size()); } return messages; }
@Override public void clear() { // 清理所有会话记忆 // 注意:这个操作会删除所有聊天记忆,谨慎使用 }
/** * 为特定会话添加消息 */ public void addMessage(String chatId, Message message) { String key = MEMORY_KEY_PREFIX + chatId; redisTemplate.opsForList().rightPush(key, message); redisTemplate.expire(key, EXPIRATION_DAYS, TimeUnit.DAYS); }
/** * 获取特定会话的消息 */ public List<Message> getMessages(String chatId) { String key = MEMORY_KEY_PREFIX + chatId; List<Object> objects = redisTemplate.opsForList().range(key, 0, -1); List<Message> messages = new ArrayList<>(); if (objects != null) { for (Object obj : objects) { if (obj instanceof Message) { messages.add((Message) obj); } } } return messages; }
/** * 清理特定会话的记忆 */ public void clear(String chatId) { String key = MEMORY_KEY_PREFIX + chatId; redisTemplate.delete(key); } } |
注意事项:
配置 Spring AI 的 ChatClient:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
package com.spc.smartpiccommunitybackend.config;
import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration;
@Configuration public class CommonConfiguration {
@Bean public ChatClient chatClient(OpenAiChatModel openAiChatModel, ChatMemory chatMemory) { return ChatClient.builder(openAiChatModel) .defaultAdvisors( new SimpleLoggerAdvisor(), new MessageChatMemoryAdvisor(chatMemory) ) .build(); } } |
注意事项:
实现聊天控制器,处理 AI 对话请求:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
package com.spc.smartpiccommunitybackend.controller;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.spc.smartpiccommunitybackend.repository.RedisChatHistoryRepository; import com.spc.smartpiccommunitybackend.service.UserService; import com.spc.smartpiccommunitybackend.utils.ErrorCode; import com.spc.smartpiccommunitybackend.utils.ThrowUtils; import com.spc.smartpiccommunitybackend.pojo.User; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.beans.factory.annotation.Resource; import org.springframework.web.bind.annotation.*; import reactor.core.publisher.Flux;
import javax.servlet.http.HttpServletRequest; import java.io.IOException; import java.util.ArrayList; import java.util.List;
@RestController @RequestMapping("/ai") public class ChatController {
private final ChatClient chatClient; private final RedisChatHistoryRepository chatHistoryRepository; @Resource private UserService userService;
public ChatController(ChatClient chatClient, RedisChatHistoryRepository chatHistoryRepository) { this.chatClient = chatClient; this.chatHistoryRepository = chatHistoryRepository; }
@RequestMapping(value = "/chat", produces = "text/html;charset=UTF-8") public Flux<String> chat(@RequestParam(defaultValue = "讲个笑话") String prompt, String chatId, HttpServletRequest request) { User loginUser = userService.getLoginUser(request); // 校验登录用户是否为空 ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR); Long userId = loginUser.getId();
// 保存会话信息 chatHistoryRepository.save("chat", chatId, userId);
// 获取历史对话消息作为上下文 List<Message> messages = new ArrayList<>();
// 添加系统消息 SystemMessage systemMessage = new SystemMessage( "你是一个智能图片社区的AI助手,名为虹小智。请用友好、专业的语气回答用户问题," + "提供关于图片社区的相关信息和帮助。" ); messages.add(systemMessage);
// 获取并解析历史消息 List<String> historyMessages = chatHistoryRepository.getMessages(chatId); ObjectMapper objectMapper = new ObjectMapper();
for (String messageStr : historyMessages) { try { JsonNode node = objectMapper.readTree(messageStr); String sender = node.get("sender").asText(); String content = node.get("content").asText();
if ("user".equals(sender)) { messages.add(new UserMessage(content)); } else if ("ai".equals(sender)) { messages.add(new AssistantMessage(content)); } } catch (IOException e) { e.printStackTrace(); } }
// 添加用户当前消息 messages.add(new UserMessage(prompt));
// 保存用户消息到历史记录 chatHistoryRepository.saveMessage(chatId, prompt, "user");
// 调用AI模型获取响应 return chatClient.stream(messages) .doOnNext(response -> { // 保存AI响应到历史记录 chatHistoryRepository.saveMessage(chatId, response, "ai"); }); } } |
注意事项:
实现聊天历史控制器,处理聊天历史的获取和删除:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
package com.spc.smartpiccommunitybackend.controller;
import com.spc.smartpiccommunitybackend.repository.RedisChatHistoryRepository; import com.spc.smartpiccommunitybackend.service.UserService; import com.spc.smartpiccommunitybackend.utils.ErrorCode; import com.spc.smartpiccommunitybackend.utils.ThrowUtils; import com.spc.smartpiccommunitybackend.pojo.User; import org.springframework.web.bind.annotation.*;
import javax.servlet.http.HttpServletRequest; import java.util.List;
@RestController @RequestMapping("/ai/history") public class ChatHistoryController {
private final RedisChatHistoryRepository chatHistoryRepository; private final UserService userService;
public ChatHistoryController(RedisChatHistoryRepository chatHistoryRepository, UserService userService) { this.chatHistoryRepository = chatHistoryRepository; this.userService = userService; }
/** * 获取用户的聊天历史ID列表 */ @GetMapping("/{type}") public List<String> getChatHistory(@PathVariable String type, HttpServletRequest request) { User loginUser = userService.getLoginUser(request); ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR); Long userId = loginUser.getId();
return chatHistoryRepository.getChatIds(userId, type); }
/** * 删除指定聊天历史 */ @DeleteMapping("/{type}/{chatId}") public boolean deleteChatHistory(@PathVariable String type, @PathVariable String chatId, HttpServletRequest request) { User loginUser = userService.getLoginUser(request); ThrowUtils.throwIf(loginUser == null, ErrorCode.NOT_LOGIN_ERROR); Long userId = loginUser.getId();
chatHistoryRepository.deleteChat(userId, type, chatId); return true; } } |
注意事项: