詹学伟
詹学伟
Published on 2025-11-23 / 8 Visits
0
0

saa-会话持久化

一、说明

本章节主要介绍 SAA(Spring AI Alibaba)的会话持久化功能,工程沿用统一的父工程(见 SAA 系列第一篇)。

其实这与之前 LangChain4j 中的持久化思路一样,都是将用户的会话信息存储到指定的位置:之前使用 LangChain4j 时存储到了 MongoDB,本章节则使用 Redis。

二、代码

结构:

POM文件:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>com.saa</groupId>
		<artifactId>saa-parent</artifactId>
		<version>0.0.1-SNAPSHOT</version>
	</parent>
	<groupId>com.saa</groupId>
	<artifactId>saa-persistent</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>saa-persistent</name>
	<description>saa-persistent</description>
	<properties>
		<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
		<maven.compiler.source>17</maven.compiler.source>
		<maven.compiler.target>17</maven.compiler.target>
	</properties>
	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter-dashscope</artifactId>
		</dependency>

		<!-- Provides RedisChatMemoryRepository and the MessageDeserializer used for chat memory. -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter-memory-redis</artifactId>
		</dependency>

		<!-- Jedis client used by the Redis chat memory repositories (version managed by the parent). -->
		<dependency>
			<groupId>redis.clients</groupId>
			<artifactId>jedis</artifactId>
		</dependency>

		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<version>1.18.34</version>
			<optional>true</optional>
		</dependency>

		<dependency>
			<groupId>cn.hutool</groupId>
			<artifactId>hutool-all</artifactId>
			<version>5.8.16</version>
		</dependency>
		<dependency>
			<groupId>junit</groupId>
			<artifactId>junit</artifactId>
			<scope>test</scope>
		</dependency>

		<!-- Test-only: without the test scope this starter would be packaged into the application. -->
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>


	<build>
		<plugins>
			<plugin>
				<groupId>org.apache.maven.plugins</groupId>
				<artifactId>maven-compiler-plugin</artifactId>
				<version>3.8.1</version>
				<configuration>
					<source>17</source>
					<target>17</target>
					<encoding>UTF-8</encoding>
				</configuration>
			</plugin>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
				<!-- NOTE(review): spring-boot.version is expected to be defined by saa-parent; confirm. -->
				<version>${spring-boot.version}</version>
			</plugin>
		</plugins>
	</build>

</project>

YML配置文件:

spring:
  application:
    name: saa-persistent
  data:
    redis:
      host: 127.0.0.1
      port: 6379
      password: 123456
      # Read timeout, in milliseconds.
      timeout: 5000
      # Spring Boot's property name is connect-timeout; the old "connectionTimeout" key bound to nothing.
      connect-timeout: 5000
      # NOTE(review): the Jedis-based chat memory repositories are not given a database index.
      database: 0

  ai:
    dashscope:
      # Supply the key via the environment instead of committing the secret to source control.
      api-key: ${AI_DASHSCOPE_API_KEY}
      # NOTE(review): SsaLLMConfig builds its own DashScopeApi from api-key only, so base-url and
      # chat.options.model below do not apply to the deepseek/qwen beans defined there.
      base-url: https://dashscope.aliyuncs.com/compatible-mode/v1
      chat:
        options:
          model: qwen3-max

server:
  port: 8086
  servlet:
    encoding:
      enabled: true
      force: true
      charset: UTF-8

配置类:

SsaLLMConfig

package com.saa.persistent.config;

import com.alibaba.cloud.ai.dashscope.api.DashScopeApi;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatModel;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Registers the DashScope chat models and the {@link ChatClient}s built on top of them.
 *
 * <p>Two models are exposed as beans, {@code deepseek} and {@code qwen}. Only
 * {@code deepseekChatClient} is given conversation memory, which it persists to Redis through
 * {@link CustomerRedisChatMemoryRepository}.
 *
 * @author zhanxuewei
 */
@Configuration
public class SsaLLMConfig {

    /** DashScope model id used by the {@code deepseek} beans. */
    private static final String MODEL_DEEPSEEK = "deepseek-v3";

    /** DashScope model id used by the {@code qwen} beans. */
    private static final String MODEL_QWEN = "qwen-max";

    /** Number of most recent messages the memory window keeps per conversation. */
    private static final int MEMORY_MAX_MESSAGES = 10;

    @Value("${spring.ai.dashscope.api-key}")
    private String apiKey;

    /**
     * DeepSeek model served through DashScope.
     *
     * @return a chat model defaulting to {@value #MODEL_DEEPSEEK}
     */
    @Bean(name = "deepseek")
    public ChatModel deepseek() {
        return dashScopeModel(MODEL_DEEPSEEK);
    }

    /**
     * Qwen model served through DashScope.
     *
     * @return a chat model defaulting to {@value #MODEL_QWEN}
     */
    @Bean(name = "qwen")
    public ChatModel qwen() {
        return dashScopeModel(MODEL_QWEN);
    }

    /**
     * DeepSeek chat client whose conversation memory is persisted to Redis through the custom
     * {@link CustomerRedisChatMemoryRepository}.
     *
     * <p>The method name is kept as-is; the bean itself is registered as {@code deepseekChatClient}.
     *
     * @param deepseek                  the DeepSeek chat model
     * @param redisChatMemoryRepository repository that stores each conversation's messages in Redis
     * @return a chat client with a sliding window of {@value #MEMORY_MAX_MESSAGES} messages of memory
     */
    @Bean(name = "deepseekChatClient")
    public ChatClient deepeekChatClient(@Qualifier("deepseek") ChatModel deepseek,
                                        CustomerRedisChatMemoryRepository redisChatMemoryRepository) {
        // To use Spring AI Alibaba's built-in RedisChatMemoryRepository bean instead, change the
        // parameter type above to RedisChatMemoryRepository (and import it); nothing else changes.
        MessageWindowChatMemory windowChatMemory = MessageWindowChatMemory.builder()
                .chatMemoryRepository(redisChatMemoryRepository)
                .maxMessages(MEMORY_MAX_MESSAGES)
                .build();
        return ChatClient.builder(deepseek)
                .defaultOptions(ChatOptions.builder().model(MODEL_DEEPSEEK).build())
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(windowChatMemory).build())
                .build();
    }

    /**
     * Qwen chat client without conversation memory.
     *
     * @param qwen the Qwen chat model
     * @return a stateless chat client
     */
    @Bean(name = "qwenChatClient")
    public ChatClient qwenChatClient(@Qualifier("qwen") ChatModel qwen) {
        return ChatClient.builder(qwen)
                .defaultOptions(ChatOptions.builder().model(MODEL_QWEN).build())
                .build();
    }

    /** Builds a DashScope chat model for the given model id using the configured API key. */
    private ChatModel dashScopeModel(String model) {
        return DashScopeChatModel.builder()
                .dashScopeApi(DashScopeApi.builder().apiKey(apiKey).build())
                .defaultOptions(DashScopeChatOptions.builder().withModel(model).build())
                .build();
    }

}

RedisMemoryConfig

package com.saa.persistent.config;

import com.alibaba.cloud.ai.memory.redis.RedisChatMemoryRepository;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Registers the Redis-backed chat memory repositories, configured from {@code spring.data.redis.*}.
 *
 * <p>Each repository builds its own Jedis connection pool. {@link CustomerRedisChatMemoryRepository}
 * exposes a public {@code close()}, which Spring's inferred destroy-method detection invokes for
 * {@code @Bean} instances at shutdown.
 */
@Configuration
public class RedisMemoryConfig {

    @Value("${spring.data.redis.host}")
    private String host;

    @Value("${spring.data.redis.port}")
    private int port;

    /** Optional: when the property is absent this is null, so no AUTH is attempted. */
    @Value("${spring.data.redis.password:#{null}}")
    private String password;

    /**
     * Jedis timeout in milliseconds. It must be a plain number here (not a Duration such as 5s)
     * because it is bound to an int. The fallback of 2000 matches the builder default.
     */
    @Value("${spring.data.redis.timeout:2000}")
    private int timeout;

    /**
     * Spring AI Alibaba's stock repository (its own key prefix).
     *
     * @return a repository connected to the configured Redis
     */
    @Bean
    public RedisChatMemoryRepository redisChatMemoryRepository() {
        return RedisChatMemoryRepository.builder()
                .host(host)
                .port(port)
                .password(password)
                .timeout(timeout)
                .build();
    }

    /**
     * Custom repository that stores conversations under the {@code customer_chat_memory:} prefix.
     *
     * @return a repository connected to the configured Redis
     */
    @Bean
    public CustomerRedisChatMemoryRepository myRedisChatMemoryRepository() {
        return CustomerRedisChatMemoryRepository.builder()
                .host(host)
                .port(port)
                .password(password)
                .timeout(timeout)
                .build();
    }

}

CustomerRedisChatMemoryRepository

/*
 * Copyright 2024-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.saa.persistent.config;

import com.alibaba.cloud.ai.memory.redis.serializer.MessageDeserializer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.List;

/**
 * Redis-backed {@link ChatMemoryRepository} that stores each conversation as a Redis list.
 *
 * <p>The list lives under the key {@code customer_chat_memory:<conversationId>} and holds one
 * JSON-serialized {@link Message} per element, in conversation order.
 *
 * <p>Adapted from Spring AI Alibaba's {@code RedisChatMemoryRepository} with a different key prefix.
 * Unlike the original, the write paths:
 * <ul>
 * <li>replace or trim a conversation in place on the single pooled connection they already hold,
 * rather than borrowing a second one;</li>
 * <li>never leave a half-written history behind.</li>
 * </ul>
 *
 * <p>The underlying {@link JedisPool} is released by {@link #close()}.
 */
public class CustomerRedisChatMemoryRepository implements ChatMemoryRepository, AutoCloseable {

	private static final Logger logger = LoggerFactory.getLogger(CustomerRedisChatMemoryRepository.class);

	/** Prefix of every Redis key written by this repository; the conversation id follows it. */
	private static final String DEFAULT_KEY_PREFIX = "customer_chat_memory:";

	private final JedisPool jedisPool;

	private final ObjectMapper objectMapper;

	private CustomerRedisChatMemoryRepository(JedisPool jedisPool) {
		Assert.notNull(jedisPool, "jedisPool cannot be null");
		this.jedisPool = jedisPool;
		this.objectMapper = new ObjectMapper();
		// Message is an interface; the deserializer picks the concrete message type from the JSON.
		SimpleModule module = new SimpleModule();
		module.addDeserializer(Message.class, new MessageDeserializer());
		this.objectMapper.registerModule(module);
	}

	/**
	 * Starts building a repository with default connection settings (127.0.0.1:6379, 2000 ms timeout).
	 * @return a new builder
	 */
	public static RedisBuilder builder() {
		return new RedisBuilder();
	}

	/** Fluent builder for {@link CustomerRedisChatMemoryRepository}. */
	public static class RedisBuilder {

		private String host = "127.0.0.1";

		private int port = 6379;

		/** Redis password; null means no AUTH is sent. */
		private String password;

		/** Connection/socket timeout in milliseconds passed to {@link JedisPool}. */
		private int timeout = 2000;

		/** Pool settings; a default {@link JedisPoolConfig} is used when not set. */
		private JedisPoolConfig poolConfig;

		public RedisBuilder host(String host) {
			this.host = host;
			return this;
		}

		public RedisBuilder port(int port) {
			this.port = port;
			return this;
		}

		public RedisBuilder password(String password) {
			this.password = password;
			return this;
		}

		public RedisBuilder timeout(int timeout) {
			this.timeout = timeout;
			return this;
		}

		public RedisBuilder poolConfig(JedisPoolConfig poolConfig) {
			this.poolConfig = poolConfig;
			return this;
		}

		/**
		 * Creates the connection pool and the repository that owns it.
		 * @return a repository that must be closed to release its pool
		 */
		public CustomerRedisChatMemoryRepository build() {
			if (poolConfig == null) {
				poolConfig = new JedisPoolConfig();
			}
			JedisPool jedisPool = new JedisPool(poolConfig, host, port, timeout, password);
			return new CustomerRedisChatMemoryRepository(jedisPool);
		}

	}

	/** Builds the Redis key holding the message list of a conversation. */
	private static String keyOf(String conversationId) {
		return DEFAULT_KEY_PREFIX + conversationId;
	}

	/**
	 * Lists the ids of all conversations stored by this repository.
	 * <p>NOTE(review): this uses the Redis {@code KEYS} command, which walks the whole keyspace and
	 * blocks the server while doing so; consider {@code SCAN} for large production datasets.
	 * @return conversation ids with the key prefix stripped
	 */
	@Override
	public List<String> findConversationIds() {
		try (Jedis jedis = jedisPool.getResource()) {
			return jedis.keys(DEFAULT_KEY_PREFIX + "*")
				.stream()
				.map(key -> key.substring(DEFAULT_KEY_PREFIX.length()))
				.toList();
		}
	}

	/**
	 * Loads a conversation's messages in stored order.
	 * @param conversationId the conversation id; must not be blank
	 * @return the messages, or an empty list if the conversation does not exist
	 * @throws RuntimeException if a stored entry cannot be deserialized
	 */
	@Override
	public List<Message> findByConversationId(String conversationId) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		try (Jedis jedis = jedisPool.getResource()) {
			List<String> messageStrings = jedis.lrange(keyOf(conversationId), 0, -1);
			List<Message> messages = new ArrayList<>(messageStrings.size());
			for (String messageString : messageStrings) {
				try {
					messages.add(objectMapper.readValue(messageString, Message.class));
				}
				catch (JsonProcessingException e) {
					throw new RuntimeException("Error deserializing message", e);
				}
			}
			return messages;
		}
	}

	/**
	 * Replaces a conversation's stored messages with {@code messages}.
	 * <p>All messages are serialized before Redis is touched, so a serialization failure leaves the
	 * existing history intact. The DEL and RPUSH are then sent as one MULTI/EXEC transaction on a
	 * single pooled connection, so no reader ever sees a half-written list.
	 * @param conversationId the conversation id; must not be blank
	 * @param messages the full new history; must not be null or contain nulls (may be empty)
	 * @throws RuntimeException if a message cannot be serialized
	 */
	@Override
	public void saveAll(String conversationId, List<Message> messages) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		Assert.notNull(messages, "messages cannot be null");
		Assert.noNullElements(messages, "messages cannot contain null elements");

		String[] serialized = new String[messages.size()];
		for (int i = 0; i < serialized.length; i++) {
			try {
				serialized[i] = objectMapper.writeValueAsString(messages.get(i));
			}
			catch (JsonProcessingException e) {
				throw new RuntimeException("Error serializing message", e);
			}
		}

		String key = keyOf(conversationId);
		try (Jedis jedis = jedisPool.getResource(); var tx = jedis.multi()) {
			tx.del(key);
			// RPUSH requires at least one value; an empty history is just the DEL.
			if (serialized.length > 0) {
				tx.rpush(key, serialized);
			}
			tx.exec();
		}
	}

	/**
	 * Deletes a conversation and all of its messages.
	 * @param conversationId the conversation id; must not be blank
	 */
	@Override
	public void deleteByConversationId(String conversationId) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		try (Jedis jedis = jedisPool.getResource()) {
			jedis.del(keyOf(conversationId));
		}
	}

	/**
	 * Drops the oldest messages of a conversation once it has reached a size limit.
	 * <p>When the list holds at least {@code maxLimit} entries, its first {@code deleteSize} entries
	 * are removed in place with LTRIM. Trimming past the end empties the list, and Redis then removes
	 * the key.
	 * @param conversationId the conversation ID
	 * @param maxLimit size at or above which trimming happens
	 * @param deleteSize number of oldest messages to remove (negative values remove none)
	 */
	public void clearOverLimit(String conversationId, int maxLimit, int deleteSize) {
		Assert.hasText(conversationId, "conversationId cannot be null or empty");
		String key = keyOf(conversationId);
		try (Jedis jedis = jedisPool.getResource()) {
			if (jedis.llen(key) >= maxLimit) {
				jedis.ltrim(key, Math.max(0, deleteSize), -1);
			}
		}
	}

	/** Closes the underlying connection pool. */
	@Override
	public void close() {
		if (jedisPool != null) {
			jedisPool.close();
			logger.info("Redis connection pool closed");
		}
	}

}

控制器PersistentController

package com.saa.persistent.controller;

import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import reactor.core.publisher.Flux;

import java.util.function.Consumer;

import static org.springframework.ai.chat.memory.ChatMemory.CONVERSATION_ID;

/**
 * REST endpoints for chatting with the DeepSeek chat client, whose conversation memory is
 * persisted in Redis.
 *
 * <p>The caller-supplied {@code userId} is passed to the memory advisor as the conversation id, so
 * each user's history is loaded and saved separately.
 */
@Slf4j
@RequestMapping("persistent")
@RestController
public class PersistentController {

    @Resource(name = "deepseek")
    private ChatModel deepseekChatModel;

    @Resource(name = "qwen")
    private ChatModel qwenChatModel;

    @Resource(name = "deepseekChatClient")
    private ChatClient deepseekChatClient;

    @Resource(name = "qwenChatClient")
    private ChatClient qwenChatClient;

    /**
     * Sends a question and returns the complete answer, using {@code userId}'s conversation memory.
     *
     * @param question the user's question
     * @param userId   conversation id under which the chat history is loaded and saved
     * @return the model's reply text
     */
    @GetMapping("/chat")
    public String chat(@RequestParam(name = "question") String question,
                       @RequestParam(name = "userId") String userId) {
        return deepseekChatClient.prompt(question)
                .advisors(conversation(userId))
                .call()
                .content();
    }

    /**
     * Streams the answer of a legal-questions-only assistant.
     *
     * <p>When {@code userId} is given it selects the conversation memory, as in {@link #chat}.
     * Without it, no conversation id is set and the memory advisor falls back to its default
     * conversation. That conversation is presumably shared by every such caller, which is the
     * previous behavior of this endpoint.
     *
     * @param question the user's question
     * @param userId   optional conversation id
     * @return the reply as a stream of text fragments
     */
    @GetMapping("/chat2")
    public Flux<String> streamChat(@RequestParam(name = "question", defaultValue = "你是谁?") String question,
                                   @RequestParam(name = "userId", required = false) String userId) {
        ChatClient.ChatClientRequestSpec request = deepseekChatClient.prompt()
                .system("你是一个法律问题专家,只回答法律相关的任何问题,与法律不相关的问题一律回答:我只能回答法律相关的问题")
                .user(question);
        if (userId != null && !userId.isBlank()) {
            request = request.advisors(conversation(userId));
        }
        // The Flux is lazy, so nothing has been produced yet; log the request, not a "result".
        log.info("stream chat started, userId: {}", userId);
        return request.stream().content();
    }

    /** Advisor customizer that binds the request to the given conversation id. */
    private static Consumer<ChatClient.AdvisorSpec> conversation(String conversationId) {
        return spec -> spec.param(CONVERSATION_ID, conversationId);
    }
}

三、测试

测试前先清空 Redis:

发起请求(例如:GET http://localhost:8086/persistent/chat?question=你好&userId=1):

Redis 中新增的记录(key 为前缀 customer_chat_memory: 加上请求中的 userId,值为按顺序保存的消息列表):


Comment