
A Look at langchain4j's AiServices


This article takes a closer look at langchain4j's AiServices.

Example

Plain version

public interface Assistant {
    String chat(String userMessage);
}

Build and call it:

Assistant assistant = AiServices.create(Assistant.class, chatLanguageModel);
String resp = assistant.chat(userMessage);

spring-boot version

@AiService
public interface AssistantV2 {

    @SystemMessage("You are a polite assistant")
    String chat(String userMessage);
}
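
For @AiService to be picked up, a langchain4j Spring Boot starter must be on the classpath along with a configured chat model. A minimal sketch, assuming the langchain4j-open-ai-spring-boot-starter is used (the property names follow that starter's convention):

langchain4j.open-ai.chat-model.api-key=${OPENAI_API_KEY}
langchain4j.open-ai.chat-model.model-name=gpt-4o-mini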

After that, simply inject it and use it like any other managed bean:

    @Autowired
    AssistantV2 assistantV2;

    @GetMapping("/ai-service")
    public String aiService(@RequestParam("prompt") String prompt) {
        return assistantV2.chat(prompt);
    }    

Source code

AiServices

dev/langchain4j/service/AiServices.java

public abstract class AiServices<T> {

    protected static final String DEFAULT = "default";

    protected final AiServiceContext context;

    private boolean retrieverSet = false;
    private boolean contentRetrieverSet = false;
    private boolean retrievalAugmentorSet = false;

    protected AiServices(AiServiceContext context) {
        this.context = context;
    }

    /**
     * Creates an AI Service (an implementation of the provided interface), that is backed by the provided chat model.
     * This convenience method can be used to create simple AI Services.
     * For more complex cases, please use {@link #builder}.
     *
     * @param aiService         The class of the interface to be implemented.
     * @param chatLanguageModel The chat model to be used under the hood.
     * @return An instance of the provided interface, implementing all its defined methods.
     */
    public static <T> T create(Class<T> aiService, ChatLanguageModel chatLanguageModel) {
        return builder(aiService).chatLanguageModel(chatLanguageModel).build();
    }

    /**
     * Creates an AI Service (an implementation of the provided interface), that is backed by the provided streaming chat model.
     * This convenience method can be used to create simple AI Services.
     * For more complex cases, please use {@link #builder}.
     *
     * @param aiService                  The class of the interface to be implemented.
     * @param streamingChatLanguageModel The streaming chat model to be used under the hood.
     *                                   The return type of all methods should be {@link TokenStream}.
     * @return An instance of the provided interface, implementing all its defined methods.
     */
    public static <T> T create(Class<T> aiService, StreamingChatLanguageModel streamingChatLanguageModel) {
        return builder(aiService)
                .streamingChatLanguageModel(streamingChatLanguageModel)
                .build();
    }

    /**
     * Begins the construction of an AI Service.
     *
     * @param aiService The class of the interface to be implemented.
     * @return builder
     */
    public static <T> AiServices<T> builder(Class<T> aiService) {
        AiServiceContext context = new AiServiceContext(aiService);
        for (AiServicesFactory factory : loadFactories(AiServicesFactory.class)) {
            return factory.create(context);
        }
        return new DefaultAiServices<>(context);
    }

    /**
     * Configures chat model that will be used under the hood of the AI Service.
     * <p>
     * Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
     * but not both at the same time.
     *
     * @param chatLanguageModel Chat model that will be used under the hood of the AI Service.
     * @return builder
     */
    public AiServices<T> chatLanguageModel(ChatLanguageModel chatLanguageModel) {
        context.chatModel = chatLanguageModel;
        return this;
    }

    /**
     * Configures streaming chat model that will be used under the hood of the AI Service.
     * The methods of the AI Service must return a {@link TokenStream} type.
     * <p>
     * Either {@link ChatLanguageModel} or {@link StreamingChatLanguageModel} should be configured,
     * but not both at the same time.
     *
     * @param streamingChatLanguageModel Streaming chat model that will be used under the hood of the AI Service.
     * @return builder
     */
    public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel streamingChatLanguageModel) {
        context.streamingChatModel = streamingChatLanguageModel;
        return this;
    }

    /**
     * Configures the system message provider, which provides a system message to be used each time an AI service is invoked.
     * <p>
     * When both {@code @SystemMessage} and the system message provider are configured,
     * {@code @SystemMessage} takes precedence.
     *
     * @param systemMessageProvider A {@link Function} that accepts a chat memory ID
     *                              (a value of a method parameter annotated with @{@link MemoryId})
     *                              and returns a system message to be used.
     *                              If there is no parameter annotated with {@code @MemoryId},
     *                              the value of memory ID is "default".
     *                              The returned {@link String} can be either a complete system message
     *                              or a system message template containing unresolved template variables (e.g. "{{name}}"),
     *                              which will be resolved using the values of method parameters annotated with @{@link V}.
     * @return builder
     */
    public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
        context.systemMessageProvider = systemMessageProvider.andThen(Optional::ofNullable);
        return this;
    }

    /**
     * Configures the chat memory that will be used to preserve conversation history between method calls.
     * <p>
     * Unless a {@link ChatMemory} or {@link ChatMemoryProvider} is configured, all method calls will be independent of each other.
     * In other words, the LLM will not remember the conversation from the previous method calls.
     * <p>
     * The same {@link ChatMemory} instance will be used for every method call.
     * <p>
     * If you want to have a separate {@link ChatMemory} for each user/conversation, configure {@link #chatMemoryProvider} instead.
     * <p>
     * Either a {@link ChatMemory} or a {@link ChatMemoryProvider} can be configured, but not both simultaneously.
     *
     * @param chatMemory An instance of chat memory to be used by the AI Service.
     * @return builder
     */
    public AiServices<T> chatMemory(ChatMemory chatMemory) {
        context.chatMemories = new ConcurrentHashMap<>();
        context.chatMemories.put(DEFAULT, chatMemory);
        return this;
    }

    /**
     * Configures the chat memory provider, which provides a dedicated instance of {@link ChatMemory} for each user/conversation.
     * To distinguish between users/conversations, one of the method's arguments should be a memory ID (of any data type)
     * annotated with {@link MemoryId}.
     * For each new (previously unseen) memoryId, an instance of {@link ChatMemory} will be automatically obtained
     * by invoking {@link ChatMemoryProvider#get(Object id)}.
     * Example:
     * <pre>
     * interface Assistant {
     *
     *     String chat(@MemoryId int memoryId, @UserMessage String message);
     * }
     * </pre>
     * If you prefer to use the same (shared) {@link ChatMemory} for all users/conversations, configure a {@link #chatMemory} instead.
     * <p>
     * Either a {@link ChatMemory} or a {@link ChatMemoryProvider} can be configured, but not both simultaneously.
     *
     * @param chatMemoryProvider The provider of a {@link ChatMemory} for each new user/conversation.
     * @return builder
     */
    public AiServices<T> chatMemoryProvider(ChatMemoryProvider chatMemoryProvider) {
        context.chatMemories = new ConcurrentHashMap<>();
        context.chatMemoryProvider = chatMemoryProvider;
        return this;
    }

    /**
     * Configures a moderation model to be used for automatic content moderation.
     * If a method in the AI Service is annotated with {@link Moderate}, the moderation model will be invoked
     * to check the user content for any inappropriate or harmful material.
     *
     * @param moderationModel The moderation model to be used for content moderation.
     * @return builder
     * @see Moderate
     */
    public AiServices<T> moderationModel(ModerationModel moderationModel) {
        context.moderationModel = moderationModel;
        return this;
    }

    /**
     * Configures the tools that the LLM can use.
     *
     * @param objectsWithTools One or more objects whose methods are annotated with {@link Tool}.
     *                         All these tools (methods annotated with {@link Tool}) will be accessible to the LLM.
     *                         Note that inherited methods are ignored.
     * @return builder
     * @see Tool
     */
    public AiServices<T> tools(Object... objectsWithTools) {
        return tools(asList(objectsWithTools));
    }

    /**
     * Configures the tools that the LLM can use.
     *
     * @param objectsWithTools A list of objects whose methods are annotated with {@link Tool}.
     *                         All these tools (methods annotated with {@link Tool}) are accessible to the LLM.
     *                         Note that inherited methods are ignored.
     * @return builder
     * @see Tool
     */
    public AiServices<T> tools(Collection<Object> objectsWithTools) {
        context.toolService.tools(objectsWithTools);
        return this;
    }

    /**
     * Configures the tool provider that the LLM can use
     *
     * @param toolProvider Decides which tools the LLM could use to handle the request
     * @return builder
     */
    public AiServices<T> toolProvider(ToolProvider toolProvider) {
        context.toolService.toolProvider(toolProvider);
        return this;
    }

    /**
     * Configures the tools that the LLM can use.
     *
     * @param tools A map of {@link ToolSpecification} to {@link ToolExecutor} entries.
     *              This method of configuring tools is useful when tools must be configured programmatically.
     *              Otherwise, it is recommended to use the {@link Tool}-annotated java methods
     *              and configure tools with the {@link #tools(Object...)} and {@link #tools(Collection)} methods.
     * @return builder
     */
    public AiServices<T> tools(Map<ToolSpecification, ToolExecutor> tools) {
        context.toolService.tools(tools);
        return this;
    }

    /**
     * Configures the strategy to be used when the LLM hallucinates a tool name (i.e., attempts to call a nonexistent tool).
     *
     * @param hallucinatedToolNameStrategy A Function from {@link ToolExecutionRequest} to {@link ToolExecutionResultMessage} defining
     *                                     the response provided to the LLM when it hallucinates a tool name.
     * @return builder
     */
    public AiServices<T> hallucinatedToolNameStrategy(
            Function<ToolExecutionRequest, ToolExecutionResultMessage> hallucinatedToolNameStrategy) {
        context.toolService.hallucinatedToolNameStrategy(hallucinatedToolNameStrategy);
        return this;
    }

    /**
     * @param retriever The retriever to be used by the AI Service.
     * @return builder
     * @deprecated Use {@link #contentRetriever(ContentRetriever)}
     * (e.g. {@link EmbeddingStoreContentRetriever}) instead.
     * <p>
     * Configures a retriever that will be invoked on every method call to fetch relevant information
     * related to the current user message from an underlying source (e.g., embedding store).
     * This relevant information is automatically injected into the message sent to the LLM.
     */
    @Deprecated(forRemoval = true)
    public AiServices<T> retriever(Retriever<TextSegment> retriever) {
        if (contentRetrieverSet || retrievalAugmentorSet) {
            throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
        }
        if (retriever != null) {
            AiServices<T> withContentRetriever = contentRetriever(retriever.toContentRetriever());
            retrieverSet = true;
            return withContentRetriever;
        }
        return this;
    }

    /**
     * Configures a content retriever to be invoked on every method call for retrieving relevant content
     * related to the user's message from an underlying data source
     * (e.g., an embedding store in the case of an {@link EmbeddingStoreContentRetriever}).
     * The retrieved relevant content is then automatically incorporated into the message sent to the LLM.
     * <p>
     * This method provides a straightforward approach for those who do not require
     * a customized {@link RetrievalAugmentor}.
     * It configures a {@link DefaultRetrievalAugmentor} with the provided {@link ContentRetriever}.
     *
     * @param contentRetriever The content retriever to be used by the AI Service.
     * @return builder
     */
    public AiServices<T> contentRetriever(ContentRetriever contentRetriever) {
        if (retrieverSet || retrievalAugmentorSet) {
            throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
        }
        contentRetrieverSet = true;
        context.retrievalAugmentor = DefaultRetrievalAugmentor.builder()
                .contentRetriever(ensureNotNull(contentRetriever, "contentRetriever"))
                .build();
        return this;
    }

    /**
     * Configures a retrieval augmentor to be invoked on every method call.
     *
     * @param retrievalAugmentor The retrieval augmentor to be used by the AI Service.
     * @return builder
     */
    public AiServices<T> retrievalAugmentor(RetrievalAugmentor retrievalAugmentor) {
        if (retrieverSet || contentRetrieverSet) {
            throw illegalConfiguration("Only one out of [retriever, contentRetriever, retrievalAugmentor] can be set");
        }
        retrievalAugmentorSet = true;
        context.retrievalAugmentor = ensureNotNull(retrievalAugmentor, "retrievalAugmentor");
        return this;
    }

    /**
     * Constructs and returns the AI Service.
     *
     * @return An instance of the AI Service implementing the specified interface.
     */
    public abstract T build();

    //......
}

AiServices is an abstract class. Its static builder method constructs the builder, falling back to DefaultAiServices when no AiServicesFactory is registered via SPI. It exposes configuration methods for chatLanguageModel, streamingChatLanguageModel, systemMessageProvider, chatMemory, chatMemoryProvider, moderationModel, tools, toolProvider, contentRetriever, and retrievalAugmentor, and declares the abstract build method for subclasses to implement.
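
For anything beyond the one-liner AiServices.create, the builder is the entry point. A minimal sketch, assuming a chatLanguageModel has already been constructed (MessageWindowChatMemory is langchain4j's bundled sliding-window memory implementation):

Assistant assistant = AiServices.builder(Assistant.class)
        .chatLanguageModel(chatLanguageModel)
        .systemMessageProvider(memoryId -> "You are a polite assistant")
        .chatMemory(MessageWindowChatMemory.withMaxMessages(10))
        .build();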

DefaultAiServices

dev/langchain4j/service/DefaultAiServices.java

class DefaultAiServices<T> extends AiServices<T> {

    private final ServiceOutputParser serviceOutputParser = new ServiceOutputParser();
    private final Collection<TokenStreamAdapter> tokenStreamAdapters = loadFactories(TokenStreamAdapter.class);

    DefaultAiServices(AiServiceContext context) {
        super(context);
    }

    //......

    public T build() {

        performBasicValidation();

        for (Method method : context.aiServiceClass.getMethods()) {
            if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
                throw illegalConfiguration(
                        "The @Moderate annotation is present, but the moderationModel is not set up. "
                                + "Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
            }
            if (method.getReturnType() == Result.class
                    || method.getReturnType() == List.class
                    || method.getReturnType() == Set.class) {
                TypeUtils.validateReturnTypesAreProperlyParametrized(method.getName(), method.getGenericReturnType());
            }

            if (context.chatMemoryProvider == null) {
                for (Parameter parameter : method.getParameters()) {
                    if (parameter.isAnnotationPresent(MemoryId.class)) {
                        throw illegalConfiguration(
                                "In order to use @MemoryId, please configure the ChatMemoryProvider on the '%s'.",
                                context.aiServiceClass.getName());
                    }
                }
            }
        }

        Object proxyInstance = Proxy.newProxyInstance(
                context.aiServiceClass.getClassLoader(),
                new Class<?>[] {context.aiServiceClass},
                new InvocationHandler() {

                    private final ExecutorService executor = Executors.newCachedThreadPool();

                    @Override
                    public Object invoke(Object proxy, Method method, Object[] args) throws Exception {

                        if (method.getDeclaringClass() == Object.class) {
                            // methods like equals(), hashCode() and toString() should not be handled by this proxy
                            return method.invoke(this, args);
                        }

                        validateParameters(method);

                        Object memoryId = findMemoryId(method, args).orElse(DEFAULT);

                        Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
                        UserMessage userMessage = prepareUserMessage(method, args);
                        AugmentationResult augmentationResult = null;
                        if (context.retrievalAugmentor != null) {
                            List<ChatMessage> chatMemory = context.hasChatMemory()
                                    ? context.chatMemory(memoryId).messages()
                                    : null;
                            Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
                            AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
                            augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
                            userMessage = (UserMessage) augmentationResult.chatMessage();
                        }

                        // TODO give user ability to provide custom OutputParser
                        Type returnType = method.getGenericReturnType();

                        boolean streaming = returnType == TokenStream.class || canAdaptTokenStreamTo(returnType);

                        boolean supportsJsonSchema =
                                supportsJsonSchema(); // TODO should it be called for returnType==String?
                        Optional<JsonSchema> jsonSchema = Optional.empty();
                        if (supportsJsonSchema && !streaming) {
                            jsonSchema = jsonSchemaFrom(returnType);
                        }

                        if ((!supportsJsonSchema || jsonSchema.isEmpty()) && !streaming) {
                            // TODO append after storing in the memory?
                            userMessage = appendOutputFormatInstructions(returnType, userMessage);
                        }

                        if (context.hasChatMemory()) {
                            ChatMemory chatMemory = context.chatMemory(memoryId);
                            systemMessage.ifPresent(chatMemory::add);
                            chatMemory.add(userMessage);
                        }

                        List<ChatMessage> messages;
                        if (context.hasChatMemory()) {
                            messages = context.chatMemory(memoryId).messages();
                        } else {
                            messages = new ArrayList<>();
                            systemMessage.ifPresent(messages::add);
                            messages.add(userMessage);
                        }

                        Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);

                        ToolExecutionContext toolExecutionContext =
                                context.toolService.executionContext(memoryId, userMessage);

                        if (streaming) {
                            TokenStream tokenStream = new AiServiceTokenStream(
                                    messages,
                                    toolExecutionContext.toolSpecifications(),
                                    toolExecutionContext.toolExecutors(),
                                    augmentationResult != null ? augmentationResult.contents() : null,
                                    context,
                                    memoryId);
                            // TODO moderation
                            if (returnType == TokenStream.class) {
                                return tokenStream;
                            } else {
                                return adapt(tokenStream, returnType);
                            }
                        }

                        ResponseFormat responseFormat = null;
                        if (supportsJsonSchema && jsonSchema.isPresent()) {
                            responseFormat = ResponseFormat.builder()
                                    .type(JSON)
                                    .jsonSchema(jsonSchema.get())
                                    .build();
                        }

                        ChatRequestParameters parameters = ChatRequestParameters.builder()
                                .toolSpecifications(toolExecutionContext.toolSpecifications())
                                .responseFormat(responseFormat)
                                .build();

                        ChatRequest chatRequest = ChatRequest.builder()
                                .messages(messages)
                                .parameters(parameters)
                                .build();

                        ChatResponse chatResponse = context.chatModel.chat(chatRequest);

                        verifyModerationIfNeeded(moderationFuture);

                        ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
                                chatResponse,
                                parameters,
                                messages,
                                context.chatModel,
                                context.hasChatMemory() ? context.chatMemory(memoryId) : null,
                                memoryId,
                                toolExecutionContext.toolExecutors());

                        chatResponse = toolExecutionResult.chatResponse();
                        FinishReason finishReason = chatResponse.metadata().finishReason();
                        Response<AiMessage> response = Response.from(
                                chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);

                        Object parsedResponse = serviceOutputParser.parse(response, returnType);
                        if (typeHasRawClass(returnType, Result.class)) {
                            return Result.builder()
                                    .content(parsedResponse)
                                    .tokenUsage(toolExecutionResult.tokenUsageAccumulator())
                                    .sources(augmentationResult == null ? null : augmentationResult.contents())
                                    .finishReason(finishReason)
                                    .toolExecutions(toolExecutionResult.toolExecutions())
                                    .build();
                        } else {
                            return parsedResponse;
                        }
                    }

                    private boolean canAdaptTokenStreamTo(Type returnType) {
                        for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
                            if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
                                return true;
                            }
                        }
                        return false;
                    }

                    private Object adapt(TokenStream tokenStream, Type returnType) {
                        for (TokenStreamAdapter tokenStreamAdapter : tokenStreamAdapters) {
                            if (tokenStreamAdapter.canAdaptTokenStreamTo(returnType)) {
                                return tokenStreamAdapter.adapt(tokenStream);
                            }
                        }
                        throw new IllegalStateException("Can't find suitable TokenStreamAdapter");
                    }

                    private boolean supportsJsonSchema() {
                        return context.chatModel != null
                                && context.chatModel.supportedCapabilities().contains(RESPONSE_FORMAT_JSON_SCHEMA);
                    }

                    private UserMessage appendOutputFormatInstructions(Type returnType, UserMessage userMessage) {
                        String outputFormatInstructions = serviceOutputParser.outputFormatInstructions(returnType);
                        String text = userMessage.singleText() + outputFormatInstructions;
                        if (isNotNullOrBlank(userMessage.name())) {
                            userMessage = UserMessage.from(userMessage.name(), text);
                        } else {
                            userMessage = UserMessage.from(text);
                        }
                        return userMessage;
                    }

                    private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
                        if (method.isAnnotationPresent(Moderate.class)) {
                            return executor.submit(() -> {
                                List<ChatMessage> messagesToModerate = removeToolMessages(messages);
                                return context.moderationModel
                                        .moderate(messagesToModerate)
                                        .content();
                            });
                        }
                        return null;
                    }
                });

        return (T) proxyInstance;
    }

    //......
}

DefaultAiServices extends AiServices. Its build method creates the implementation class via Proxy.newProxyInstance. The InvocationHandler prepares the systemMessage and userMessage, applies retrieval augmentation if configured, maintains the chatMemory, and builds the toolExecutionContext; it then assembles a ChatRequest, executes it via context.chatModel.chat(chatRequest), runs the tool-execution loop, and finally parses and adapts the output to the method's return type.
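
To see the mechanism in isolation, here is a minimal, self-contained sketch of the same JDK dynamic-proxy pattern; the EchoService interface and the echo behavior are made up for illustration, standing in for the user-declared AI service and the LLM call:

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

public class ProxyDemo {

    // hypothetical interface, standing in for the user-declared AI service
    interface EchoService {
        String chat(String userMessage);
    }

    public static void main(String[] args) {
        EchoService service = (EchoService) Proxy.newProxyInstance(
                EchoService.class.getClassLoader(),
                new Class<?>[] {EchoService.class},
                new InvocationHandler() {
                    @Override
                    public Object invoke(Object proxy, Method method, Object[] methodArgs) throws Throwable {
                        // like DefaultAiServices, delegate equals()/hashCode()/toString() to the handler itself
                        if (method.getDeclaringClass() == Object.class) {
                            return method.invoke(this, methodArgs);
                        }
                        // DefaultAiServices assembles messages and calls the LLM here; we just echo
                        return "echo: " + methodArgs[0];
                    }
                });

        System.out.println(service.chat("hello")); // prints "echo: hello"
    }
}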

Summary

langchain4j provides low-level components such as ChatLanguageModel, ChatMessage, and ChatMemory, as well as high-level components such as Chains and AI Services, which orchestrate multiple components (prompt templates, ChatMemory, the LLM, output parsing, and RAG components such as embedding models and scoring models) to work together. Chains was ported from Python's LangChain, but it is hard to customize, so no new functionality is being added to it. langchain4j offers AI Services as the replacement for Chains. It is somewhat like JPA or Retrofit: you simply declare an interface and a proxy implementation is generated automatically. It takes care of formatting the input to the LLM, parsing the LLM's output, ChatMemory, Tools, and RAG.
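
As a quick illustration of the output-parsing side: an AI Service method can declare a POJO return type, and langchain4j appends format instructions (or uses a JSON schema when the model supports it) and parses the response into that type. A minimal sketch; the Person record and the prompt template are illustrative ({{it}} is langchain4j's placeholder for a single method parameter):

record Person(String name, int age) {}

interface PersonExtractor {

    @UserMessage("Extract a person from the following text: {{it}}")
    Person extractPerson(String text);
}

PersonExtractor extractor = AiServices.create(PersonExtractor.class, chatLanguageModel);
Person person = extractor.extractPerson("John is 42 years old");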

langchain4j's AiServices builder creates a DefaultAiServices by default, which in turn builds the implementation class via the JDK's Proxy.newProxyInstance.

doc

  • ai-services
