package com.github.llmjava.cohere4j;

import com.github.llmjava.cohere4j.callback.AsyncCallback;
import com.github.llmjava.cohere4j.callback.StreamingCallback;
import com.github.llmjava.cohere4j.exception.CohereException;
import com.github.llmjava.cohere4j.request.ClassifyRequest;
import com.github.llmjava.cohere4j.request.DetectLanguageRequest;
import com.github.llmjava.cohere4j.request.DetokenizeRequest;
import com.github.llmjava.cohere4j.request.EmbedRequest;
import com.github.llmjava.cohere4j.request.GenerateRequest;
import com.github.llmjava.cohere4j.request.RerankRequest;
import com.github.llmjava.cohere4j.request.SummarizeRequest;
import com.github.llmjava.cohere4j.request.TokenizeRequest;
import com.github.llmjava.cohere4j.response.ClassifyResponse;
import com.github.llmjava.cohere4j.response.DetectLanguageResponse;
import com.github.llmjava.cohere4j.response.DetokenizeResponse;
import com.github.llmjava.cohere4j.response.EmbedResponse;
import com.github.llmjava.cohere4j.response.GenerateResponse;
import com.github.llmjava.cohere4j.response.RerankResponse;
import com.github.llmjava.cohere4j.response.SummarizeResponse;
import com.github.llmjava.cohere4j.response.TokenizeResponse;
import com.github.llmjava.cohere4j.response.streaming.ResponseConverter;
import com.github.llmjava.cohere4j.response.streaming.StreamGenerateResponse;
import com.google.gson.Gson;
import java.io.IOException;
import java.io.UncheckedIOException;
import retrofit2.Call;
import retrofit2.Callback;
import retrofit2.Response;

/**
 * Client for the Cohere REST API.
 *
 * <p>Instances are created through {@link Builder}. Each endpoint has two variants:
 * <ul>
 *   <li>A blocking method (e.g. {@link #generate}) that returns the deserialized response body.
 *       It throws the exception built by {@code CohereException.fromResponse} when Retrofit reports
 *       the HTTP response as unsuccessful. Transport {@link IOException}s are rethrown as
 *       {@link UncheckedIOException}.</li>
 *   <li>A callback-based method (e.g. {@link #generateAsync}) that enqueues the call. It reports
 *       either the body or a failure to the supplied callback.</li>
 * </ul>
 * {@link #generateStream} additionally splits a streamed generation into individual parts.
 */
public class CohereClient {
    private final CohereApi api;
    // Retained from the builder; not read anywhere inside this class.
    private final CohereConfig config;
    private final Gson gson;

    /** Builder for {@link CohereClient}; {@link #withConfig} must be called before {@link #build}. */
    public static class Builder {
        private CohereApi api;
        private CohereConfig config;
        private Gson gson;

        /**
         * Stores the configuration and creates the API proxy and the {@link Gson} instance from it.
         *
         * @param cohereConfig configuration handed to the HTTP client factory
         * @return this builder
         */
        public Builder withConfig(CohereConfig cohereConfig) {
            this.config = cohereConfig;
            CohereApiFactory cohereApiFactory = new CohereApiFactory();
            this.api = cohereApiFactory.createGson().createHttpClient(cohereConfig).build();
            this.gson = cohereApiFactory.gson;
            return this;
        }

        /**
         * Creates the client from the state captured by {@link #withConfig}.
         *
         * @return a new client
         * @throws IllegalStateException if {@link #withConfig} has not been called
         */
        public CohereClient build() {
            // Fail fast: without an API proxy every client method would later die with an NPE.
            if (api == null) {
                throw new IllegalStateException("withConfig(...) must be called before build()");
            }
            return new CohereClient(this);
        }
    }

    CohereClient(Builder builder) {
        this.api = builder.api;
        this.config = builder.config;
        this.gson = builder.gson;
    }

    /** Calls the generate endpoint and blocks until the response is available. */
    public GenerateResponse generate(GenerateRequest generateRequest) {
        return execute(this.api.generate(generateRequest));
    }

    /** Calls the generate endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void generateAsync(GenerateRequest generateRequest, AsyncCallback<GenerateResponse> asyncCallback) {
        execute(this.api.generate(generateRequest), asyncCallback);
    }

    /**
     * Calls the streaming generate endpoint asynchronously.
     *
     * <p>On a successful response, the body is converted into parts. Each part whose
     * {@code isFinished()} is {@code true} goes to {@code onComplete}; every other part goes to
     * {@code onPart}. An unsuccessful HTTP response, a transport failure, or a RuntimeException
     * raised while converting or dispatching parts is reported through {@code onFailure}.
     *
     * @param generateRequest request whose {@code isStreaming()} must be {@code true}
     * @param streamingCallback receiver of parts, completion and failures
     * @throws IllegalArgumentException if the request is not flagged as streaming (including
     *     when the flag is {@code null})
     */
    public void generateStream(GenerateRequest generateRequest, final StreamingCallback<StreamGenerateResponse> streamingCallback) {
        // isStreaming() is a Boolean: treat null as "not streaming" instead of NPE-ing on unboxing.
        if (!Boolean.TRUE.equals(generateRequest.isStreaming())) {
            throw new IllegalArgumentException("Expected a streaming request");
        }
        final ResponseConverter responseConverter = new ResponseConverter(this.gson);
        this.api.generateStream(generateRequest).enqueue(new Callback<String>() {
            @Override
            public void onResponse(Call<String> call, Response<String> response) {
                if (!response.isSuccessful()) {
                    streamingCallback.onFailure(CohereException.fromResponse(response));
                    return;
                }
                try {
                    for (StreamGenerateResponse part : responseConverter.toStreamingGenerationResponse(response.body())) {
                        // null-safe: a part without a finished flag is treated as an intermediate part.
                        if (Boolean.TRUE.equals(part.isFinished())) {
                            streamingCallback.onComplete(part);
                        } else {
                            streamingCallback.onPart(part);
                        }
                    }
                } catch (RuntimeException e) {
                    // Without this, a conversion error would escape the Retrofit callback and
                    // streamingCallback would never be told that the stream failed.
                    streamingCallback.onFailure(e);
                }
            }

            @Override
            public void onFailure(Call<String> call, Throwable th) {
                streamingCallback.onFailure(th);
            }
        });
    }

    /** Calls the embed endpoint and blocks until the response is available. */
    public EmbedResponse embed(EmbedRequest embedRequest) {
        return execute(this.api.embed(embedRequest));
    }

    /** Calls the embed endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void embedAsync(EmbedRequest embedRequest, AsyncCallback<EmbedResponse> asyncCallback) {
        execute(this.api.embed(embedRequest), asyncCallback);
    }

    /** Calls the classify endpoint and blocks until the response is available. */
    public ClassifyResponse classify(ClassifyRequest classifyRequest) {
        return execute(this.api.classify(classifyRequest));
    }

    /** Calls the classify endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void classifyAsync(ClassifyRequest classifyRequest, AsyncCallback<ClassifyResponse> asyncCallback) {
        execute(this.api.classify(classifyRequest), asyncCallback);
    }

    /** Calls the tokenize endpoint and blocks until the response is available. */
    public TokenizeResponse tokenize(TokenizeRequest tokenizeRequest) {
        return execute(this.api.tokenize(tokenizeRequest));
    }

    /** Calls the tokenize endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void tokenizeAsync(TokenizeRequest tokenizeRequest, AsyncCallback<TokenizeResponse> asyncCallback) {
        execute(this.api.tokenize(tokenizeRequest), asyncCallback);
    }

    /** Calls the detokenize endpoint and blocks until the response is available. */
    public DetokenizeResponse detokenize(DetokenizeRequest detokenizeRequest) {
        return execute(this.api.detokenize(detokenizeRequest));
    }

    /** Calls the detokenize endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void detokenizeAsync(DetokenizeRequest detokenizeRequest, AsyncCallback<DetokenizeResponse> asyncCallback) {
        execute(this.api.detokenize(detokenizeRequest), asyncCallback);
    }

    /** Calls the language-detection endpoint and blocks until the response is available. */
    public DetectLanguageResponse detectLanguage(DetectLanguageRequest detectLanguageRequest) {
        return execute(this.api.detectLanguage(detectLanguageRequest));
    }

    /** Calls the language-detection endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void detectLanguageAsync(DetectLanguageRequest detectLanguageRequest, AsyncCallback<DetectLanguageResponse> asyncCallback) {
        execute(this.api.detectLanguage(detectLanguageRequest), asyncCallback);
    }

    /** Calls the summarize endpoint and blocks until the response is available. */
    public SummarizeResponse summarize(SummarizeRequest summarizeRequest) {
        return execute(this.api.summarize(summarizeRequest));
    }

    /** Calls the summarize endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void summarizeAsync(SummarizeRequest summarizeRequest, AsyncCallback<SummarizeResponse> asyncCallback) {
        execute(this.api.summarize(summarizeRequest), asyncCallback);
    }

    /** Calls the rerank endpoint and blocks until the response is available. */
    public RerankResponse rerank(RerankRequest rerankRequest) {
        return execute(this.api.rerank(rerankRequest));
    }

    /** Calls the rerank endpoint asynchronously, reporting the outcome to {@code asyncCallback}. */
    public void rerankAsync(RerankRequest rerankRequest, AsyncCallback<RerankResponse> asyncCallback) {
        execute(this.api.rerank(rerankRequest), asyncCallback);
    }

    /**
     * Executes {@code call} synchronously on the calling thread.
     *
     * @return the deserialized body of a successful response (may be {@code null} if Retrofit
     *     yields no body)
     * @throws UncheckedIOException wrapping any {@link IOException} raised while executing
     */
    private <T> T execute(Call<T> call) {
        try {
            Response<T> response = call.execute();
            if (response.isSuccessful()) {
                return response.body();
            }
            // Kept inside the try so behavior is unchanged regardless of CohereException's
            // supertype. NOTE(review): if it is an IOException subclass, it gets wrapped below.
            throw CohereException.fromResponse(response);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Enqueues {@code call} and forwards the outcome to {@code asyncCallback}.
     *
     * <p>{@code onSuccess} receives the body of a successful response. {@code onFailure} receives
     * either the exception built by {@code CohereException.fromResponse} for an unsuccessful
     * response, or the transport failure reported by Retrofit.
     */
    private <T> void execute(Call<T> call, final AsyncCallback<T> asyncCallback) {
        call.enqueue(new Callback<T>() {
            @Override
            public void onResponse(Call<T> enqueuedCall, Response<T> response) {
                if (response.isSuccessful()) {
                    asyncCallback.onSuccess(response.body());
                } else {
                    asyncCallback.onFailure(CohereException.fromResponse(response));
                }
            }

            @Override
            public void onFailure(Call<T> enqueuedCall, Throwable th) {
                asyncCallback.onFailure(th);
            }
        });
    }
}
