/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchpipelines.questionanswering.generative.llm;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

/**
 * Default {@link Llm} implementation that sends chat-completion requests to a remote model through
 * {@link MachineLearningInternalClient} and converts the connector response into a
 * {@link ChatCompletionOutput}.
 *
 * <p>Request and response shapes depend on the model provider:
 * <ul>
 *   <li>OpenAI: the request carries {@code model} and {@code messages} parameters; the answer is read
 *       from {@code choices[0].message.content}.</li>
 *   <li>Bedrock: the request carries a single {@code inputs} prompt string; the answer is read from
 *       {@code completion}.</li>
 * </ul>
 * If no answer is present, the message of the response's {@code error} object is reported instead.
 * When there is no usable error message, the fallback text "Unknown error or response." is reported.
 */
public class DefaultLlmImpl implements Llm {
    @Generated
    private static final Logger log = LogManager.getLogger(DefaultLlmImpl.class);
    private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model";
    private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages";
    private static final String CONNECTOR_INPUT_PARAMETER_INPUTS = "inputs";
    private static final String CONNECTOR_OUTPUT_CHOICES = "choices";
    private static final String CONNECTOR_OUTPUT_MESSAGE = "message";
    private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
    private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
    private static final String CONNECTOR_OUTPUT_COMPLETION = "completion";
    private static final String CONNECTOR_OUTPUT_ERROR = "error";
    private static final String UNKNOWN_ERROR_MESSAGE = "Unknown error or response.";

    private final String openSearchModelId;
    private MachineLearningInternalClient mlClient;

    /**
     * Creates an LLM backed by the given ml-commons model.
     *
     * @param openSearchModelId id of the remote model to run predictions against; must not be null
     * @param client OpenSearch client used to build the internal ML client
     * @throws NullPointerException if {@code openSearchModelId} is null
     */
    public DefaultLlmImpl(String openSearchModelId, Client client) {
        this.openSearchModelId = Preconditions.checkNotNull(openSearchModelId, "openSearchModelId must not be null");
        this.mlClient = new MachineLearningInternalClient(client);
    }

    /** Replaces the ML client; intended for injecting a test double. */
    @VisibleForTesting
    protected void setMlClient(MachineLearningInternalClient mlClient) {
        this.mlClient = mlClient;
    }

    /**
     * Runs a remote prediction for the given chat input and blocks until the response arrives or
     * the input's timeout elapses.
     *
     * <p>Only the first tensor of the first model output is inspected. Exceptions thrown by
     * {@code actionGet} propagate to the caller.
     *
     * @param chatCompletionInput prompt, context, model provider and timeout (in seconds)
     * @return the parsed answers, or the errors reported by the connector
     * @throws IllegalArgumentException if the model provider is not supported
     */
    @Override
    public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {
        RemoteInferenceInputDataSet dataset = RemoteInferenceInputDataSet.builder()
            .parameters(getInputParameters(chatCompletionInput))
            .build();
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
        ActionFuture<MLOutput> future = mlClient.predict(openSearchModelId, mlInput);

        // Multiply in long arithmetic so large timeouts cannot overflow int before widening.
        long timeoutMillis = chatCompletionInput.getTimeoutInSeconds() * 1000L;
        ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(timeoutMillis);

        ModelTensors firstOutput = modelOutput.getMlModelOutputs().get(0);
        ModelTensor firstTensor = firstOutput.getMlModelTensors().get(0);
        Map<String, ?> dataAsMap = firstTensor.getDataAsMap();
        // Response may contain user data; keep it out of INFO logs and let the logger format lazily.
        log.debug("dataAsMap: {}", dataAsMap);
        return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
    }

    /**
     * Builds the connector request parameters for the input's model provider.
     *
     * @param chatCompletionInput prompt components and model provider
     * @return for OpenAI, {@code model} and {@code messages}; for Bedrock, {@code inputs}
     * @throws IllegalArgumentException if the model provider is not supported
     */
    protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
        Map<String, String> inputParameters = new HashMap<>();
        Llm.ModelProvider provider = chatCompletionInput.getModelProvider();
        if (provider == Llm.ModelProvider.OPENAI) {
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
            String messages = PromptUtil.getChatCompletionPrompt(
                chatCompletionInput.getSystemPrompt(),
                chatCompletionInput.getUserInstructions(),
                chatCompletionInput.getQuestion(),
                chatCompletionInput.getChatHistory(),
                chatCompletionInput.getContexts()
            );
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
            log.debug("Messages to LLM: {}", messages);
        } else if (provider == Llm.ModelProvider.BEDROCK) {
            String prompt = PromptUtil.buildSingleStringPrompt(
                chatCompletionInput.getSystemPrompt(),
                chatCompletionInput.getUserInstructions(),
                chatCompletionInput.getQuestion(),
                chatCompletionInput.getChatHistory(),
                chatCompletionInput.getContexts()
            );
            inputParameters.put(CONNECTOR_INPUT_PARAMETER_INPUTS, prompt);
        } else {
            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
        }
        // Parameters contain the full prompt (question, history, contexts), so log only at DEBUG.
        log.debug("LLM input parameters: {}", inputParameters);
        return inputParameters;
    }

    /**
     * Converts a connector response map into a {@link ChatCompletionOutput}.
     *
     * <p>Exactly one of answers/errors is populated. A missing or malformed answer (no
     * {@code choices}, an empty {@code choices} list, no {@code message.content}, or no
     * {@code completion}) is reported as an error. The error text comes from
     * {@code error.message} when present.
     *
     * @param provider model provider that produced the response
     * @param dataAsMap the first tensor's data map; a null map is treated as empty
     * @return the answers, or the errors
     * @throws IllegalArgumentException if the model provider is not supported
     */
    protected ChatCompletionOutput buildChatCompletionOutput(Llm.ModelProvider provider, Map<String, ?> dataAsMap) {
        Map<String, ?> data = dataAsMap != null ? dataAsMap : Map.<String, Object>of();
        List<Object> answers = null;
        List<String> errors = null;
        if (provider == Llm.ModelProvider.OPENAI) {
            List<?> choices = (List<?>) data.get(CONNECTOR_OUTPUT_CHOICES);
            if (choices == null || choices.isEmpty()) {
                errors = extractErrors(data);
            } else {
                Map<?, ?> firstChoice = (Map<?, ?>) choices.get(0);
                log.debug("Choices: {}", firstChoice);
                Map<?, ?> message = (Map<?, ?>) firstChoice.get(CONNECTOR_OUTPUT_MESSAGE);
                Object content = message == null ? null : message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT);
                if (content != null) {
                    log.debug("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), content);
                    answers = List.of(content);
                } else {
                    // List.of rejects nulls; report a missing answer as an error instead of an NPE.
                    errors = extractErrors(data);
                }
            }
        } else if (provider == Llm.ModelProvider.BEDROCK) {
            String response = (String) data.get(CONNECTOR_OUTPUT_COMPLETION);
            if (response != null) {
                answers = List.of(response);
            } else {
                errors = extractErrors(data);
            }
        } else {
            throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
        }
        return new ChatCompletionOutput(answers, errors);
    }

    /**
     * Returns the connector's {@code error.message} as a single-element list, or a generic
     * message when the response carries no usable error message.
     */
    private static List<String> extractErrors(Map<String, ?> data) {
        Object error = data.get(CONNECTOR_OUTPUT_ERROR);
        if (error instanceof Map) {
            Object message = ((Map<?, ?>) error).get(CONNECTOR_OUTPUT_MESSAGE);
            if (message != null) {
                return List.of(message.toString());
            }
        }
        return List.of(UNKNOWN_ERROR_MESSAGE);
    }
}

