package org.springframework.ai.anthropic;

import com.fasterxml.jackson.core.type.TypeReference;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.MessageAggregator;
import org.springframework.ai.chat.observation.ChatModelObservationContext;
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/* loaded from: input_file:org/springframework/ai/anthropic/AnthropicChatModel.class */
public class AnthropicChatModel implements ChatModel {
    public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getValue();
    public static final Integer DEFAULT_MAX_TOKENS = 500;
    public static final Double DEFAULT_TEMPERATURE = Double.valueOf(0.8d);
    private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
    private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
    private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();
    public final RetryTemplate retryTemplate;
    private final AnthropicApi anthropicApi;
    private final AnthropicChatOptions defaultOptions;
    private final ObservationRegistry observationRegistry;
    private final ToolCallingManager toolCallingManager;
    private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
    private ChatModelObservationConvention observationConvention;

    /* loaded from: input_file:org/springframework/ai/anthropic/AnthropicChatModel$Builder.class */
    public static final class Builder {
        private AnthropicApi anthropicApi;
        private ToolCallingManager toolCallingManager;
        private AnthropicChatOptions defaultOptions = AnthropicChatOptions.builder().model(AnthropicChatModel.DEFAULT_MODEL_NAME).maxTokens(AnthropicChatModel.DEFAULT_MAX_TOKENS).temperature(AnthropicChatModel.DEFAULT_TEMPERATURE).build();
        private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;
        private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
        private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();

        private Builder() {
        }

        public Builder anthropicApi(AnthropicApi anthropicApi) {
            this.anthropicApi = anthropicApi;
            return this;
        }

        public Builder defaultOptions(AnthropicChatOptions anthropicChatOptions) {
            this.defaultOptions = anthropicChatOptions;
            return this;
        }

        public Builder retryTemplate(RetryTemplate retryTemplate) {
            this.retryTemplate = retryTemplate;
            return this;
        }

        public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
            this.toolCallingManager = toolCallingManager;
            return this;
        }

        public Builder toolExecutionEligibilityPredicate(ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
            this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
            return this;
        }

        public Builder observationRegistry(ObservationRegistry observationRegistry) {
            this.observationRegistry = observationRegistry;
            return this;
        }

        public AnthropicChatModel build() {
            return this.toolCallingManager != null ? new AnthropicChatModel(this.anthropicApi, this.defaultOptions, this.toolCallingManager, this.retryTemplate, this.observationRegistry) : new AnthropicChatModel(this.anthropicApi, this.defaultOptions, AnthropicChatModel.DEFAULT_TOOL_CALLING_MANAGER, this.retryTemplate, this.observationRegistry, this.toolExecutionEligibilityPredicate);
        }
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry) {
        this(anthropicApi, anthropicChatOptions, toolCallingManager, retryTemplate, observationRegistry, new DefaultToolExecutionEligibilityPredicate());
    }

    public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions anthropicChatOptions, ToolCallingManager toolCallingManager, RetryTemplate retryTemplate, ObservationRegistry observationRegistry, ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
        this.observationConvention = DEFAULT_OBSERVATION_CONVENTION;
        Assert.notNull(anthropicApi, "anthropicApi cannot be null");
        Assert.notNull(anthropicChatOptions, "defaultOptions cannot be null");
        Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
        Assert.notNull(retryTemplate, "retryTemplate cannot be null");
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
        this.anthropicApi = anthropicApi;
        this.defaultOptions = anthropicChatOptions;
        this.toolCallingManager = toolCallingManager;
        this.retryTemplate = retryTemplate;
        this.observationRegistry = observationRegistry;
        this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
    }

    public ChatResponse call(Prompt prompt) {
        return internalCall(buildRequestPrompt(prompt), null);
    }

    public ChatResponse internalCall(Prompt prompt, ChatResponse chatResponse) {
        AnthropicApi.ChatCompletionRequest createRequest = createRequest(prompt, false);
        ChatModelObservationContext build = ChatModelObservationContext.builder().prompt(prompt).provider(AnthropicApi.PROVIDER_NAME).requestOptions(prompt.getOptions()).build();
        ChatResponse chatResponse2 = (ChatResponse) ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
            return build;
        }, this.observationRegistry).observe(() -> {
            ResponseEntity responseEntity = (ResponseEntity) this.retryTemplate.execute(retryContext -> {
                return this.anthropicApi.chatCompletionEntity(createRequest, getAdditionalHttpHeaders(prompt));
            });
            AnthropicApi.ChatCompletionResponse chatCompletionResponse = (AnthropicApi.ChatCompletionResponse) responseEntity.getBody();
            ChatResponse chatResponse3 = toChatResponse((AnthropicApi.ChatCompletionResponse) responseEntity.getBody(), UsageUtils.getCumulativeUsage(chatCompletionResponse.usage() != null ? getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage(), chatResponse));
            build.setResponse(chatResponse3);
            return chatResponse3;
        });
        if (!this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse2)) {
            return chatResponse2;
        }
        ToolExecutionResult executeToolCalls = this.toolCallingManager.executeToolCalls(prompt, chatResponse2);
        return executeToolCalls.returnDirect() ? ChatResponse.builder().from(chatResponse2).generations(ToolExecutionResult.buildGenerations(executeToolCalls)).build() : internalCall(new Prompt(executeToolCalls.conversationHistory(), prompt.getOptions()), chatResponse2);
    }

    private DefaultUsage getDefaultUsage(AnthropicApi.Usage usage) {
        return new DefaultUsage(usage.inputTokens(), usage.outputTokens(), Integer.valueOf(usage.inputTokens().intValue() + usage.outputTokens().intValue()), usage);
    }

    public Flux<ChatResponse> stream(Prompt prompt) {
        return internalStream(buildRequestPrompt(prompt), null);
    }

    public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse chatResponse) {
        return Flux.deferContextual(contextView -> {
            AnthropicApi.ChatCompletionRequest createRequest = createRequest(prompt, true);
            ChatModelObservationContext build = ChatModelObservationContext.builder().prompt(prompt).provider(AnthropicApi.PROVIDER_NAME).requestOptions(prompt.getOptions()).build();
            Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> {
                return build;
            }, this.observationRegistry);
            observation.parentObservation((Observation) contextView.getOrDefault("micrometer.observation", (Object) null)).start();
            Flux switchMap = this.anthropicApi.chatCompletionStream(createRequest, getAdditionalHttpHeaders(prompt)).switchMap(chatCompletionResponse -> {
                ChatResponse chatResponse2 = toChatResponse(chatCompletionResponse, UsageUtils.getCumulativeUsage(chatCompletionResponse.usage() != null ? getDefaultUsage(chatCompletionResponse.usage()) : new EmptyUsage(), chatResponse));
                return (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), chatResponse2) && chatResponse2.hasFinishReasons(Set.of("tool_use"))) ? Flux.defer(() -> {
                    ToolExecutionResult executeToolCalls = this.toolCallingManager.executeToolCalls(prompt, chatResponse2);
                    return executeToolCalls.returnDirect() ? Flux.just(ChatResponse.builder().from(chatResponse2).generations(ToolExecutionResult.buildGenerations(executeToolCalls)).build()) : internalStream(new Prompt(executeToolCalls.conversationHistory(), prompt.getOptions()), chatResponse2);
                }).subscribeOn(Schedulers.boundedElastic()) : Mono.just(chatResponse2);
            });
            Objects.requireNonNull(observation);
            Flux contextWrite = switchMap.doOnError(observation::error).doFinally(signalType -> {
                observation.stop();
            }).contextWrite(context -> {
                return context.put("micrometer.observation", observation);
            });
            MessageAggregator messageAggregator = new MessageAggregator();
            Objects.requireNonNull(build);
            return messageAggregator.aggregate(contextWrite, (v1) -> {
                r2.setResponse(v1);
            });
        });
    }

    private ChatResponse toChatResponse(AnthropicApi.ChatCompletionResponse chatCompletionResponse, Usage usage) {
        if (chatCompletionResponse == null) {
            logger.warn("Null chat completion returned");
            return new ChatResponse(List.of());
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (AnthropicApi.ContentBlock contentBlock : chatCompletionResponse.content()) {
            switch (contentBlock.type()) {
                case TEXT:
                case TEXT_DELTA:
                    arrayList.add(new Generation(new AssistantMessage(contentBlock.text(), Map.of()), ChatGenerationMetadata.builder().finishReason(chatCompletionResponse.stopReason()).build()));
                    break;
                case THINKING:
                case THINKING_DELTA:
                    HashMap hashMap = new HashMap();
                    hashMap.put("signature", contentBlock.signature());
                    arrayList.add(new Generation(new AssistantMessage(contentBlock.thinking(), hashMap), ChatGenerationMetadata.builder().finishReason(chatCompletionResponse.stopReason()).build()));
                    break;
                case REDACTED_THINKING:
                    HashMap hashMap2 = new HashMap();
                    hashMap2.put("data", contentBlock.data());
                    arrayList.add(new Generation(new AssistantMessage((String) null, hashMap2), ChatGenerationMetadata.builder().finishReason(chatCompletionResponse.stopReason()).build()));
                    break;
                case TOOL_USE:
                    arrayList2.add(new AssistantMessage.ToolCall(contentBlock.id(), "function", contentBlock.name(), JsonParser.toJson(contentBlock.input())));
                    break;
            }
        }
        if (chatCompletionResponse.stopReason() != null && arrayList.isEmpty()) {
            arrayList.add(new Generation(new AssistantMessage((String) null, Map.of()), ChatGenerationMetadata.builder().finishReason(chatCompletionResponse.stopReason()).build()));
        }
        if (!CollectionUtils.isEmpty(arrayList2)) {
            arrayList.add(new Generation(new AssistantMessage("", Map.of(), arrayList2), ChatGenerationMetadata.builder().finishReason(chatCompletionResponse.stopReason()).build()));
        }
        return new ChatResponse(arrayList, from(chatCompletionResponse, usage));
    }

    private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse chatCompletionResponse) {
        return from(chatCompletionResponse, getDefaultUsage(chatCompletionResponse.usage()));
    }

    private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse chatCompletionResponse, Usage usage) {
        Assert.notNull(chatCompletionResponse, "Anthropic ChatCompletionResult must not be null");
        return ChatResponseMetadata.builder().id(chatCompletionResponse.id()).model(chatCompletionResponse.model()).usage(usage).keyValue("stop-reason", chatCompletionResponse.stopReason()).keyValue("stop-sequence", chatCompletionResponse.stopSequence()).keyValue("type", chatCompletionResponse.type()).build();
    }

    private String fromMediaData(Object obj) {
        if (obj instanceof byte[]) {
            return Base64.getEncoder().encodeToString((byte[]) obj);
        }
        if (obj instanceof String) {
            return (String) obj;
        }
        throw new IllegalArgumentException("Unsupported media data type: " + obj.getClass().getSimpleName());
    }

    private AnthropicApi.ContentBlock.Type getContentBlockTypeByMedia(Media media) {
        String mimeType = media.getMimeType().toString();
        if (mimeType.startsWith("image")) {
            return AnthropicApi.ContentBlock.Type.IMAGE;
        }
        if (mimeType.contains("pdf")) {
            return AnthropicApi.ContentBlock.Type.DOCUMENT;
        }
        throw new IllegalArgumentException("Unsupported media type: " + mimeType + ". Supported types are: images (image/*) and PDF documents (application/pdf)");
    }

    private MultiValueMap<String, String> getAdditionalHttpHeaders(Prompt prompt) {
        HashMap hashMap = new HashMap(this.defaultOptions.getHttpHeaders());
        if (prompt.getOptions() != null) {
            AnthropicChatOptions options = prompt.getOptions();
            if (options instanceof AnthropicChatOptions) {
                hashMap.putAll(options.getHttpHeaders());
            }
        }
        return CollectionUtils.toMultiValueMap((Map) hashMap.entrySet().stream().collect(Collectors.toMap((v0) -> {
            return v0.getKey();
        }, entry -> {
            return List.of((String) entry.getValue());
        })));
    }

    Prompt buildRequestPrompt(Prompt prompt) {
        AnthropicChatOptions anthropicChatOptions = null;
        if (prompt.getOptions() != null) {
            ToolCallingChatOptions options = prompt.getOptions();
            anthropicChatOptions = options instanceof ToolCallingChatOptions ? (AnthropicChatOptions) ModelOptionsUtils.copyToTarget(options, ToolCallingChatOptions.class, AnthropicChatOptions.class) : (AnthropicChatOptions) ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, AnthropicChatOptions.class);
        }
        AnthropicChatOptions anthropicChatOptions2 = (AnthropicChatOptions) ModelOptionsUtils.merge(anthropicChatOptions, this.defaultOptions, AnthropicChatOptions.class);
        if (anthropicChatOptions != null) {
            anthropicChatOptions2.setHttpHeaders(mergeHttpHeaders(anthropicChatOptions.getHttpHeaders(), this.defaultOptions.getHttpHeaders()));
            anthropicChatOptions2.setInternalToolExecutionEnabled((Boolean) ModelOptionsUtils.mergeOption(anthropicChatOptions.isInternalToolExecutionEnabled(), this.defaultOptions.isInternalToolExecutionEnabled()));
            anthropicChatOptions2.setToolNames(ToolCallingChatOptions.mergeToolNames(anthropicChatOptions.getToolNames(), this.defaultOptions.getToolNames()));
            anthropicChatOptions2.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(anthropicChatOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks()));
            anthropicChatOptions2.setToolContext(ToolCallingChatOptions.mergeToolContext(anthropicChatOptions.getToolContext(), this.defaultOptions.getToolContext()));
        } else {
            anthropicChatOptions2.setHttpHeaders(this.defaultOptions.getHttpHeaders());
            anthropicChatOptions2.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());
            anthropicChatOptions2.setToolNames(this.defaultOptions.getToolNames());
            anthropicChatOptions2.setToolCallbacks(this.defaultOptions.getToolCallbacks());
            anthropicChatOptions2.setToolContext(this.defaultOptions.getToolContext());
        }
        ToolCallingChatOptions.validateToolCallbacks(anthropicChatOptions2.getToolCallbacks());
        return new Prompt(prompt.getInstructions(), anthropicChatOptions2);
    }

    private Map<String, String> mergeHttpHeaders(Map<String, String> map, Map<String, String> map2) {
        HashMap hashMap = new HashMap(map2);
        hashMap.putAll(map);
        return hashMap;
    }

    AnthropicApi.ChatCompletionRequest createRequest(Prompt prompt, boolean z) {
        AnthropicApi.ChatCompletionRequest chatCompletionRequest = new AnthropicApi.ChatCompletionRequest(this.defaultOptions.getModel(), prompt.getInstructions().stream().filter(message -> {
            return message.getMessageType() != MessageType.SYSTEM;
        }).map(message2 -> {
            if (message2.getMessageType() == MessageType.USER) {
                ArrayList arrayList = new ArrayList(List.of(new AnthropicApi.ContentBlock(message2.getText())));
                if (message2 instanceof UserMessage) {
                    UserMessage userMessage = (UserMessage) message2;
                    if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
                        arrayList.addAll(userMessage.getMedia().stream().map(media -> {
                            return new AnthropicApi.ContentBlock(getContentBlockTypeByMedia(media), new AnthropicApi.ContentBlock.Source(media.getMimeType().toString(), fromMediaData(media.getData())));
                        }).toList());
                    }
                }
                return new AnthropicApi.AnthropicMessage(arrayList, AnthropicApi.Role.valueOf(message2.getMessageType().name()));
            }
            if (message2.getMessageType() != MessageType.ASSISTANT) {
                if (message2.getMessageType() == MessageType.TOOL) {
                    return new AnthropicApi.AnthropicMessage(((ToolResponseMessage) message2).getResponses().stream().map(toolResponse -> {
                        return new AnthropicApi.ContentBlock(AnthropicApi.ContentBlock.Type.TOOL_RESULT, toolResponse.id(), toolResponse.responseData());
                    }).toList(), AnthropicApi.Role.USER);
                }
                throw new IllegalArgumentException("Unsupported message type: " + String.valueOf(message2.getMessageType()));
            }
            AssistantMessage assistantMessage = (AssistantMessage) message2;
            ArrayList arrayList2 = new ArrayList();
            if (StringUtils.hasText(message2.getText())) {
                arrayList2.add(new AnthropicApi.ContentBlock(message2.getText()));
            }
            if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
                for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                    arrayList2.add(new AnthropicApi.ContentBlock(AnthropicApi.ContentBlock.Type.TOOL_USE, toolCall.id(), toolCall.name(), (Map<String, Object>) ModelOptionsUtils.jsonToMap(toolCall.arguments())));
                }
            }
            return new AnthropicApi.AnthropicMessage(arrayList2, AnthropicApi.Role.ASSISTANT);
        }).toList(), (String) prompt.getInstructions().stream().filter(message3 -> {
            return message3.getMessageType() == MessageType.SYSTEM;
        }).map(message4 -> {
            return message4.getText();
        }).collect(Collectors.joining(System.lineSeparator())), this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), Boolean.valueOf(z));
        AnthropicChatOptions options = prompt.getOptions();
        AnthropicApi.ChatCompletionRequest chatCompletionRequest2 = (AnthropicApi.ChatCompletionRequest) ModelOptionsUtils.merge(options, chatCompletionRequest, AnthropicApi.ChatCompletionRequest.class);
        List<ToolDefinition> resolveToolDefinitions = this.toolCallingManager.resolveToolDefinitions(options);
        if (!CollectionUtils.isEmpty(resolveToolDefinitions)) {
            chatCompletionRequest2 = AnthropicApi.ChatCompletionRequest.from((AnthropicApi.ChatCompletionRequest) ModelOptionsUtils.merge(chatCompletionRequest2, this.defaultOptions, AnthropicApi.ChatCompletionRequest.class)).tools(getFunctionTools(resolveToolDefinitions)).build();
        }
        return chatCompletionRequest2;
    }

    private List<AnthropicApi.Tool> getFunctionTools(List<ToolDefinition> list) {
        return list.stream().map(toolDefinition -> {
            return new AnthropicApi.Tool(toolDefinition.name(), toolDefinition.description(), (Map) JsonParser.fromJson(toolDefinition.inputSchema(), new TypeReference<Map<String, Object>>() { // from class: org.springframework.ai.anthropic.AnthropicChatModel.1
            }));
        }).toList();
    }

    public ChatOptions getDefaultOptions() {
        return AnthropicChatOptions.fromOptions(this.defaultOptions);
    }

    public void setObservationConvention(ChatModelObservationConvention chatModelObservationConvention) {
        Assert.notNull(chatModelObservationConvention, "observationConvention cannot be null");
        this.observationConvention = chatModelObservationConvention;
    }

    public static Builder builder() {
        return new Builder();
    }
}
