Introdução
Baseado no post anterior abaixo, neste post irei mostrar como implementar uma aplicação console que utiliza o framework langchain4j para a construir aplicações de chat com uso de LLMs (Large Language Models) comerciais em conjunto com a técnica RAG (Retrieval Augmented Generation).
A aplicação de exemplo está disponível no seguinte repositório do github https://github.com/prbpedro/java-langchain4j-googlevertexai-rag.
Tecnologias utilizadas
Serão utilizadas as seguintes tecnologias:
Java 21
Gradle 8.8
langchain4j 0.33.0
Postgres 16.3
Qdrant Vector Database
Google Vertex AI - Modelos text-multilingual-embedding-002 e gemini-1.5-flash-001
Arquitetura básica e funcionamento da aplicação
Os recursos necessários para a aplicação, postgres (base de dados relacional para o armazenamento dos chats) e Qdrant (Base de dados vetorial para armazenamento e busca dos vetores da técnica RAG), podem ser criados com o uso da ferramenta docker compose.
O seguinte arquivo docker-compose.yaml está configurado para inicializar a base de dados postgres com a estrutura de tabela necessária para armazenamento dos chats através do script sql init.sql e criar a base de dados vetorial Qdrant.
services:
qdrant:
image: mirror.gcr.io/qdrant/qdrant
container_name: qdrant
ports:
- 6333:6333
- 6334:6334
postgres:
image: mirror.gcr.io/postgres:16.3-bullseye
container_name: postgres
restart: always
shm_size: 128mb
environment:
POSTGRES_PASSWORD: postgres
POSTGRES_USER: postgres
POSTGRES_DB: postgres
ports:
- 5432:5432
volumes:
- ./postgres:/docker-entrypoint-initdb.d
healthcheck:
test: [ "CMD-SHELL", "pg_isready -d postgres -U postgres" ]
interval: 15s
timeout: 5s
retries: 5
A estrutura da tabela para armazenamento dos chats no postgres pode ser representada da seguinte forma:
CREATE TABLE IF NOT EXISTS public.chats
(
id integer NOT NULL,
chat_messages jsonb,
CONSTRAINT chats_pkey PRIMARY KEY (id)
);
Na inicialização da aplicação a mesma deverá?
Criar uma instância que implemente a interface ChatMemoryStore do framework langchain4j que irá orquestrar as chamadas aos métodos implementados pela classe para gerenciar a persistência dos chats.
A classe PersistentChatMemoryStore é um exemplo que pode ser utilizado:
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import org.postgresql.util.PGobject;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageDeserializer;
import dev.langchain4j.data.message.ChatMessageSerializer;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
public class PersistentChatMemoryStore implements ChatMemoryStore, AutoCloseable {
private static final String SELECT = "SELECT chat_messages FROM chats WHERE id = ?";
private static final String SELECT_IDS = "SELECT id FROM chats";
private static final String DELETE = "delete from chats where id = ?";
private static final String UPSERT = """
INSERT INTO chats (id, chat_messages)
VALUES (?, ?)
ON CONFLICT(id)
DO UPDATE SET
chat_messages = EXCLUDED.chat_messages;
""";
private static final String DB_URL = "jdbc:postgresql://localhost/postgres?user=postgres&password=postgres&ssl=false";
private final Connection conn;
public PersistentChatMemoryStore() throws SQLException {
conn = DriverManager.getConnection(DB_URL);
}
public List<Integer> getChatIds() {
try (Statement statement = conn.createStatement()) {
try (ResultSet rs = statement.executeQuery(SELECT_IDS)) {
List<Integer> ids = new ArrayList<>();
while (rs.next()) {
ids.add(rs.getInt(1));
}
return ids;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public List<ChatMessage> getMessages(Object memoryId) {
try (PreparedStatement preparedStatement = conn.prepareStatement(SELECT)) {
preparedStatement.setInt(1, (int) memoryId);
try (ResultSet rs = preparedStatement.executeQuery()) {
while (rs.next()) {
final String chatMessagesString = rs.getString(1);
return ChatMessageDeserializer.messagesFromJson(chatMessagesString);
}
return List.of();
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void updateMessages(Object memoryId, List<ChatMessage> messages) {
try (PreparedStatement preparedStatement = conn.prepareStatement(UPSERT)) {
final String chatMessages = ChatMessageSerializer.messagesToJson(messages);
preparedStatement.setInt(1, (int) memoryId);
final PGobject jsonObject = new PGobject();
jsonObject.setType("jsonb");
jsonObject.setValue(chatMessages);
preparedStatement.setObject(2, jsonObject);
preparedStatement.execute();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void deleteMessages(Object memoryId) {
try (PreparedStatement preparedStatement = conn.prepareStatement(DELETE)) {
preparedStatement.setInt(1, (int) memoryId);
preparedStatement.execute();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void close() throws SQLException {
conn.close();
}
}
Instanciar o modelo de embedding, responsável por transformar as partes dos documentos em vetores, referenciando o modelo text-multilingual-embedding-002, disposto pelo Google Cloud Platform,através da interface EmbeddingModel e classe VertexAiEmbeddingModel do framework langchain4j.
final String googleVertexAiEndpoint = System.getenv(GOOGLE_VERTEX_AI_LOCATION) + "-aiplatform.googleapis.com:443";
EmbeddingModel embeddingModel = VertexAiEmbeddingModel
.builder()
.endpoint(googleVertexAiEndpoint)
.project(System.getenv(GOOGLE_VERTEX_AI_PROJECT_ID))
.location(System.getenv(GOOGLE_VERTEX_AI_LOCATION))
.publisher("google")
.modelName("text-multilingual-embedding-002")
.maxRetries(3)
.build();
Instanciar o modelo de chat, responsável por interagir com a LLM comercial, referenciando o modelo gemini-1.5-flash-001, disposto pelo Google Cloud Platform,através da interface ChatLanguageModel e classe VertexAiGeminiChatModel do framework langchain4j.
ChatLanguageModel chatLanguageModel = VertexAiGeminiChatModel
.builder()
.project(System.getenv(GOOGLE_VERTEX_AI_PROJECT_ID))
.location(System.getenv(GOOGLE_VERTEX_AI_LOCATION))
.modelName("gemini-1.5-flash-001")
.maxOutputTokens(1000)
.temperature(0f)
.build();
Instanciar a classe EmbeddingStore do framework langchain4j, inicializando a coleção de vetores na base de dados vetorial Qdrant, obtendo os documentos dispostos na pasta documents, vetorizando as partes dos documentos e inserindo os mesmos na base de dados vetorial. A classe DocumentIngestor utiliza classes do framework langchain4j e classes dispostas pelo cliente Java do Qdrant para ingerir os documentos. Na criação da coleção do Qdrant são utilizadas as configurações tamanho do vetor igual a 768 (conforme modelo text-multilingual-embedding-002) e métrica de cálculo de similaridade de vetores igual a Cosseno por ser a indicada para textos.
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.Stream;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.parser.apache.pdfbox.ApachePdfBoxDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import io.qdrant.client.ConditionFactory;
import io.qdrant.client.QdrantClient;
import io.qdrant.client.QdrantGrpcClient;
import io.qdrant.client.grpc.Collections.Distance;
import io.qdrant.client.grpc.Collections.VectorParams;
import io.qdrant.client.grpc.Points.Filter;
import io.qdrant.client.grpc.Points.ScrollPoints;
import io.qdrant.client.grpc.Points.ScrollResponse;
import io.qdrant.client.grpc.Points.WithPayloadSelector;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public class DocumentIngestor {
private static final String DOCUMENTS_FOLDER = "documents";
private static final String PDF_EXTENSION = ".pdf";
private static final String FILE_NAME = "file_name";
public static void ingestDocuments(
final EmbeddingModel embeddingModel,
final EmbeddingStore<TextSegment> embeddingStore)
throws URISyntaxException, IOException, InterruptedException, ExecutionException {
final List<Document> documents = getDocuments();
final EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
.documentSplitter(DocumentSplitters.recursive(26, 0))
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
try (QdrantClient qdrantClient = new QdrantClient(
QdrantGrpcClient.newBuilder(
QDrantConstants.HOST,
QDrantConstants.PORT,
QDrantConstants.USE_TLS)
.build())) {
createCollectionIfNotExists(qdrantClient);
ingestDocuments(documents, ingestor, qdrantClient);
}
}
private static List<Document> getDocuments()
throws URISyntaxException, FileNotFoundException, IOException {
final List<Document> documents = new ArrayList<>();
final ApachePdfBoxDocumentParser pdfParser = new ApachePdfBoxDocumentParser();
final TextDocumentParser textParser = new TextDocumentParser();
final Path path = Paths.get(
App.class
.getClassLoader()
.getResource(DOCUMENTS_FOLDER)
.toURI());
try (Stream<Path> walk = Files.walk(path)) {
final Iterator<Path> walkIterator = walk.iterator();
while (walkIterator.hasNext()) {
final Path documentPath = walkIterator.next();
final File file = documentPath.toFile();
if (file.isFile()) {
final FileInputStream inputStream = new FileInputStream(file);
final Document document = documentPath.endsWith(PDF_EXTENSION)
? pdfParser.parse(inputStream)
: textParser.parse(inputStream);
document.metadata().put(FILE_NAME, file.getName());
documents.add(document);
}
}
}
return documents;
}
private static void createCollectionIfNotExists(final QdrantClient qdrantClient)
throws InterruptedException, ExecutionException {
final Optional<String> collection = qdrantClient
.listCollectionsAsync()
.get()
.stream()
.filter(coll -> coll.equals(QDrantConstants.COLLECTION_NAME))
.findFirst();
if (collection.isPresent()) {
log.info("Coleção já existe: {}", collection.get());
return;
}
qdrantClient.createCollectionAsync(
QDrantConstants.COLLECTION_NAME,
VectorParams
.newBuilder()
.setDistance(Distance.Cosine)
.setSize(768)
.build())
.get();
}
private static void ingestDocuments(
final List<Document> documents,
final EmbeddingStoreIngestor ingestor,
final QdrantClient qdrantClient)
throws InterruptedException, ExecutionException {
for (Document document : documents) {
ScrollResponse scrollResponse = qdrantClient.scrollAsync(
ScrollPoints
.newBuilder()
.setCollectionName(QDrantConstants.COLLECTION_NAME)
.setFilter(
Filter
.newBuilder()
.addMust(ConditionFactory
.matchKeyword(
FILE_NAME,
document.metadata().getString(FILE_NAME)))
.build())
.setLimit(1)
.setWithPayload(WithPayloadSelector.newBuilder().setEnable(true).build())
.build())
.get();
if (scrollResponse.getResultCount() > 0) {
log.info("Documento já inserido: {}", document.metadata().getString(FILE_NAME));
} else {
ingestor.ingest(document);
}
}
}
}
Instanciar a classe CompressingQueryTransformer do framework langchain4j referenciando o modelo de chat já instanciado. Esta classe será responsável por condensar uma determinada consulta junto com uma memória de chat em uma consulta concisa.
QueryTransformer queryTransformer = new CompressingQueryTransformer(chatLanguageModel);
Instanciar a classe EmbeddingStoreContentRetriever do framework langchain4j referenciando a instância da classe EmbeddingStore e do modelo de embedding já citados. Esta classe será responsável por orquestrar as chamadas ao EmbeddingStore para obtenção de similaridades de vetores. Foram customizadas as configurações de resultados máximos igual a 2 e score mínimo igual a 0.82.
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.contentRetriever(contentRetriever)
.build();
Instanciar a classe DefaultRetrievalAugmentor do framework langchain4j referenciando a instância da classe EmbeddingStore e do modelo de embedding já citados. Esta classe será responsável por orquestrar orquestrar o fluxo entre os componentes básicos QueryTransformer, QueryRouter, ContentRetriever, ContentAggregator e ContentInjector. O fluxo pode ser descrito da seguinte forma:
Uma consulta é transformada usando um QueryTransformer em uma ou várias consultas.
Cada consulta é roteada para o ContentRetriever apropriado usando um QueryRouter. Cada ContentRetriever recupera um ou vários conteúdos usando uma consulta.
Todos os conteúdos recuperados por todos os ContentRetrievers usando todas as consultas são agregados (fundidos/reclassificados/filtrados/etc.) em uma lista final de conteúdos usando um ContentAggregator.
Por último, uma lista final de conteúdos é injetada no prompt usando um ContentInjector
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.contentRetriever(contentRetriever)
.build();
Instanciar a classe MessageWindowChatMemory do framework langchain4j referenciando a instância da classe PersistentChatMemoryStore. Esta classe será responsável por orquestrar as chamadas à PersistentChatMemoryStore para obter as mensagens históricas dos chats. A configuração do número máximo de mensagens históricas carregadas foi customizada para 5.
ChatMemoryProvider chatMemoryProvider = memoryId -> MessageWindowChatMemory.builder()
.id(memoryId)
.maxMessages(15)
.chatMemoryStore(store)
.build();
Instanciar a interface Assistant, através da classe AiServices do framework langchain4j, referenciando as instâncias das classes DefaultRetrievalAugmentor, do modelo de chat e MessageWindowChatMemory.
import dev.langchain4j.service.MemoryId;
import dev.langchain4j.service.Result;
import dev.langchain4j.service.UserMessage;
public interface Assistant {
Result<String> chat(@MemoryId int memoryId, @UserMessage String userMessage);
}
AiServices.builder(Assistant.class)
.chatLanguageModel(chatLanguageModel)
.retrievalAugmentor(retrievalAugmentor)
.chatMemoryProvider(chatMemoryProvider)
.build();
Após a inicialização da aplicação deverá ser exibida uma mensagem com todos os identificadores de chats armazenados na base de dados Postgres.
A aplicação então solicitará que o usuário informe o id do chat para continuar um chat existente ou a entrada de um novo id para iniciar um novo chat. O id deve ser um número inteiro.
Caso seja optado por continuar um chat existente o framework langchain4j irá carregar as últimas 15 mensagens do chat e enviar os mesmos para o LLM, caso contrário irá persistir um novo chat. Após isto o framework irá solicitar a entrada de uma nova mensagem (prompt) do usuário para ser enviado ao LLM.
Após informar a mensagem o framework langchain4j irá converter a mesma em um vetor através do modelo text-multilingual-embedding-002 e buscar na base de dados de vetores Qdrant os vetores similares.
Caso sejam encontradas similaridades o framework langchain4j irá alterar o prompt feito para incluir o texto original referente aos vetores similares, caso contrário o mesmo não irá alterar o prompt, e enviar o prompt para o LLM gemini-1.5-flash-001
Com a resposta do LLM a aplicação exibirá então a mensagem enviada e a resposta da mesma.
Após isto é solicitada novamente a entrada de uma nova mensagem (prompt) do usuário para ser enviado ao LLM até que o usuário insira a mensagem sair.
import java.util.List;
import java.util.Scanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import dev.langchain4j.service.Result;
public class Utils {
public static void startConversationWith(PersistentChatMemoryStore store, Assistant assistant) {
Logger log = LoggerFactory.getLogger(Assistant.class);
try (Scanner scanner = new Scanner(System.in)) {
List<Integer> chatIds = store.getChatIds();
log.info("==================================================");
log.info("Identificadores de chats: ");
chatIds.forEach(id -> log.info(id.toString()));
log.info(
"Digite um identificador de chat para continuar um chat existente ou um identificador não existente para iniciar um novo chat: ");
int chatId = Integer.parseInt(scanner.nextLine());
while (true) {
log.info("==================================================");
log.info("User: ");
String userQuery = scanner.nextLine();
log.info("==================================================");
if ("exit".equalsIgnoreCase(userQuery)) {
break;
}
Result<String> result = assistant.chat(chatId, userQuery);
log.info("==================================================");
log.info("Assistant: " + result.content());
log.info("Sources: ");
result.sources().forEach(content -> log.info(content.toString()));
}
}
}
}
Configurações necessárias para a execução da aplicação
Para a execução da aplicação é necessário que um projeto do Google Cloud Platform esteja preparado para receber chamadas para a API do Vertex AI.
Para ativar as APIs siga os seguintes passos dispostos no link https://cloud.google.com/vertex-ai/docs/start/cloud-environment?hl=pt-br#enable_vertexai_apis .
Após a ativação crie uma conta de serviço seguindo os passos dispostos no link https://cloud.google.com/iam/docs/service-accounts-create?hl=pt-br#iam-service-accounts-create-console .
Após isto faça download do arquivo de credenciais seguindo os passos dispostos no link https://developers.google.com/workspace/guides/create-credentials?hl=pt-br#create_credentials_for_a_service_account .
Agora será necessário confugirar as seguintes variáveis de ambiente:
GOOGLE_VERTEX_AI_PROJECT_ID: Id do projeto do Google Cloud Platform
GOOGLE_VERTEX_AI_LOCATION: Região do Vertex AI a ser utilizada (normalmente us-central1)
GOOGLE_APPLICATION_CREDENTIALS: Caminho para o arquivo de credenciais baixado