/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.agent;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLAgentType;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.input.execute.agent.ContentBlock;
import org.opensearch.ml.common.input.execute.agent.ContentType;
import org.opensearch.ml.common.input.execute.agent.DocumentContent;
import org.opensearch.ml.common.input.execute.agent.ImageContent;
import org.opensearch.ml.common.input.execute.agent.Message;
import org.opensearch.ml.common.input.execute.agent.SourceType;
import org.opensearch.ml.common.input.execute.agent.ToolCall;
import org.opensearch.ml.common.input.execute.agent.VideoContent;
import org.opensearch.ml.common.model.ModelProvider;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;

public class BedrockConverseModelProvider
extends ModelProvider {
    private static final String DEFAULT_REGION = "us-east-1";
    private static final String REQUEST_BODY_TEMPLATE = "{\"system\": [{\"text\": \"${parameters.system_prompt}\"}], \"messages\": [${parameters._chat_history:-}${parameters.body}${parameters._interactions:-}]${parameters.tool_configs:-} }";
    private static final String TEXT_INPUT_BODY_TEMPLATE = "{\"role\":\"user\",\"content\":[{\"text\":\"${parameters.user_text}\"}]}";
    private static final String CONTENT_BLOCKS_BODY_TEMPLATE = "{\"role\":\"user\",\"content\":[${parameters.content_array}]}";
    private static final String TEXT_CONTENT_TEMPLATE = "{\"text\":\"${parameters.content_text}\"}";
    private static final String IMAGE_CONTENT_TEMPLATE = "{\"image\":{\"format\":\"${parameters.image_format}\",\"source\":{\"${parameters.image_source_type}\":\"${parameters.image_data}\"}}}";
    private static final String DOCUMENT_CONTENT_TEMPLATE = "{\"document\":{\"format\":\"${parameters.doc_format}\",\"name\":\"${parameters.doc_name}\",\"source\":{\"${parameters.doc_source_type}\":\"${parameters.doc_data}\"}}}";
    private static final String VIDEO_CONTENT_TEMPLATE = "{\"video\":{\"format\":\"${parameters.video_format}\",\"source\":{\"${parameters.video_source_type}\":\"${parameters.video_data}\"}}}";
    private static final String MESSAGE_TEMPLATE = "{\"role\":\"${parameters.msg_role}\",\"content\":[${parameters.msg_content_array}]}";

    @Override
    public Connector createConnector(String modelId, Map<String, String> credential, Map<String, String> modelParameters) {
        HashMap<String, String> parameters = new HashMap<String, String>();
        parameters.put("region", DEFAULT_REGION);
        parameters.put("service_name", "bedrock");
        parameters.put("model", modelId);
        if (modelParameters != null) {
            parameters.putAll(modelParameters);
        }
        HashMap<String, String> headers = new HashMap<String, String>();
        headers.put("content-type", "application/json");
        ConnectorAction predictAction = ConnectorAction.builder().actionType(ConnectorAction.ActionType.PREDICT).method("POST").url("https://bedrock-runtime.${parameters.region}.amazonaws.com/model/${parameters.model}/converse").headers(headers).requestBody(REQUEST_BODY_TEMPLATE).build();
        ConnectorClientConfig connectorClientConfig = new ConnectorClientConfig();
        connectorClientConfig.setMaxRetryTimes(3);
        return AwsConnector.awsConnectorBuilder().name("Auto-generated Bedrock Converse connector for Agent").description("Auto-generated connector for Bedrock Converse API").version("1").protocol("aws_sigv4").parameters(parameters).credential(credential != null ? credential : new HashMap()).actions(List.of(predictAction)).connectorClientConfig(connectorClientConfig).build();
    }

    @Override
    public MLRegisterModelInput createModelInput(String modelName, Connector connector, Map<String, String> modelParameters) {
        return MLRegisterModelInput.builder().functionName(FunctionName.REMOTE).modelName("Auto-generated model for " + modelName).description("Auto-generated model for agent").connector(connector).build();
    }

    @Override
    public String getLLMInterface() {
        return "bedrock/converse/claude";
    }

    @Override
    public Map<String, String> mapTextInput(String text, MLAgentType type) {
        HashMap<String, String> parameters = new HashMap<String, String>();
        HashMap<String, String> templateParams = new HashMap<String, String>();
        if (type == MLAgentType.PLAN_EXECUTE_AND_REFLECT) {
            templateParams.put("user_text", "${parameters.prompt}");
        } else {
            templateParams.put("user_text", StringEscapeUtils.escapeJson((String)text));
        }
        StringSubstitutor substitutor = new StringSubstitutor(templateParams, "${parameters.", "}");
        String body = substitutor.replace(TEXT_INPUT_BODY_TEMPLATE);
        parameters.put("body", body);
        return parameters;
    }

    @Override
    public Map<String, String> mapContentBlocks(List<ContentBlock> contentBlocks, MLAgentType type) {
        HashMap<String, String> parameters = new HashMap<String, String>();
        String contentArray = this.buildContentArrayFromBlocks(contentBlocks, type);
        HashMap<String, String> templateParams = new HashMap<String, String>();
        templateParams.put("content_array", contentArray);
        StringSubstitutor substitutor = new StringSubstitutor(templateParams, "${parameters.", "}");
        String body = substitutor.replace(CONTENT_BLOCKS_BODY_TEMPLATE);
        parameters.put("body", body);
        return parameters;
    }

    @Override
    public Map<String, String> mapMessages(List<Message> messages, MLAgentType type) {
        HashMap<String, String> parameters = new HashMap<String, String>();
        String messagesString = this.buildMessagesArray(messages, type);
        parameters.put("body", messagesString);
        parameters.put("no_escape_params", "_chat_history,_tools,_interactions,tool_configs,body");
        return parameters;
    }

    private String buildContentArrayFromBlocks(List<ContentBlock> blocks, MLAgentType type) {
        if (blocks == null || blocks.isEmpty()) {
            return "";
        }
        StringBuilder contentArray = new StringBuilder();
        boolean first = true;
        for (ContentBlock block : blocks) {
            if (!first) {
                contentArray.append(",");
            }
            first = false;
            switch (block.getType()) {
                case TEXT: {
                    HashMap<String, String> textParams = new HashMap<String, String>();
                    if (type == MLAgentType.PLAN_EXECUTE_AND_REFLECT) {
                        textParams.put("content_text", "${parameters.prompt}");
                    } else {
                        textParams.put("content_text", StringEscapeUtils.escapeJson((String)block.getText()));
                    }
                    StringSubstitutor textSubstitutor = new StringSubstitutor(textParams, "${parameters.", "}");
                    contentArray.append(textSubstitutor.replace(TEXT_CONTENT_TEMPLATE));
                    break;
                }
                case IMAGE: {
                    ImageContent image = block.getImage();
                    HashMap<String, String> imageParams = new HashMap<String, String>();
                    imageParams.put("image_format", image.getFormat());
                    imageParams.put("image_data", StringEscapeUtils.escapeJson((String)image.getData()));
                    String imageSourceType = this.mapSourceTypeToBedrock(image.getType(), image.getData());
                    imageParams.put("image_source_type", imageSourceType);
                    StringSubstitutor imageSubstitutor = new StringSubstitutor(imageParams, "${parameters.", "}");
                    contentArray.append(imageSubstitutor.replace(IMAGE_CONTENT_TEMPLATE));
                    break;
                }
                case DOCUMENT: {
                    DocumentContent document = block.getDocument();
                    HashMap<String, String> docParams = new HashMap<String, String>();
                    docParams.put("doc_format", document.getFormat());
                    docParams.put("doc_name", "document");
                    docParams.put("doc_data", StringEscapeUtils.escapeJson((String)document.getData()));
                    String docSourceType = this.mapSourceTypeToBedrock(document.getType(), document.getData());
                    docParams.put("doc_source_type", docSourceType);
                    StringSubstitutor docSubstitutor = new StringSubstitutor(docParams, "${parameters.", "}");
                    contentArray.append(docSubstitutor.replace(DOCUMENT_CONTENT_TEMPLATE));
                    break;
                }
                case VIDEO: {
                    VideoContent video = block.getVideo();
                    HashMap<String, String> videoParams = new HashMap<String, String>();
                    videoParams.put("video_format", video.getFormat());
                    videoParams.put("video_data", StringEscapeUtils.escapeJson((String)video.getData()));
                    String videoSourceType = this.mapSourceTypeToBedrock(video.getType(), video.getData());
                    videoParams.put("video_source_type", videoSourceType);
                    StringSubstitutor videoSubstitutor = new StringSubstitutor(videoParams, "${parameters.", "}");
                    contentArray.append(videoSubstitutor.replace(VIDEO_CONTENT_TEMPLATE));
                    break;
                }
            }
        }
        return contentArray.toString();
    }

    private String buildMessagesArray(List<Message> messages, MLAgentType type) {
        if (messages == null || messages.isEmpty()) {
            return "";
        }
        StringBuilder messagesArray = new StringBuilder();
        boolean first = true;
        for (int i = 0; i < messages.size(); ++i) {
            Message message = messages.get(i);
            StringBuilder contentArray = new StringBuilder();
            if ("tool".equalsIgnoreCase(message.getRole())) {
                int j;
                for (j = i; j < messages.size() && "tool".equalsIgnoreCase(messages.get(j).getRole()); ++j) {
                    if (contentArray.length() > 0) {
                        contentArray.append(",");
                    }
                    contentArray.append(this.buildToolResultBlock(messages.get(j)));
                }
                i = j - 1;
            } else if ("assistant".equalsIgnoreCase(message.getRole()) && message.getToolCalls() != null && !message.getToolCalls().isEmpty()) {
                if (message.getContent() != null && !message.getContent().isEmpty()) {
                    contentArray.append(this.buildContentArrayFromBlocks(message.getContent(), type));
                }
                for (ToolCall toolCall : message.getToolCalls()) {
                    if (contentArray.length() > 0) {
                        contentArray.append(",");
                    }
                    contentArray.append(this.buildToolUseBlock(toolCall));
                }
            } else if (message.getContent() != null && !message.getContent().isEmpty()) {
                contentArray.append(this.buildContentArrayFromBlocks(message.getContent(), type));
            }
            if (!first) {
                messagesArray.append(",");
            }
            first = false;
            String role = "tool".equalsIgnoreCase(message.getRole()) ? "user" : message.getRole();
            HashMap<String, String> msgParams = new HashMap<String, String>();
            msgParams.put("msg_role", role);
            msgParams.put("msg_content_array", contentArray.toString());
            StringSubstitutor msgSubstitutor = new StringSubstitutor(msgParams, "${parameters.", "}");
            messagesArray.append(msgSubstitutor.replace(MESSAGE_TEMPLATE));
        }
        return messagesArray.toString();
    }

    private String buildToolUseBlock(ToolCall toolCall) {
        HashMap<String, String> params = new HashMap<String, String>();
        params.put("tool_use_id", toolCall.getId());
        params.put("tool_name", toolCall.getFunction().getName());
        String arguments = toolCall.getFunction().getArguments();
        if (arguments == null || arguments.trim().isEmpty()) {
            arguments = "{}";
        }
        params.put("tool_input", arguments);
        String template = "{\"toolUse\":{\"toolUseId\":\"${parameters.tool_use_id}\",\"name\":\"${parameters.tool_name}\",\"input\":${parameters.tool_input}}}";
        StringSubstitutor substitutor = new StringSubstitutor(params, "${parameters.", "}");
        return substitutor.replace(template);
    }

    private String buildToolResultBlock(Message message) {
        HashMap<String, String> params = new HashMap<String, String>();
        params.put("tool_call_id", message.getToolCallId());
        String contentText = "";
        if (message.getContent() != null && !message.getContent().isEmpty()) {
            for (ContentBlock block : message.getContent()) {
                if (block.getType() != ContentType.TEXT) continue;
                contentText = StringEscapeUtils.escapeJson((String)block.getText());
                break;
            }
        }
        params.put("content_text", contentText);
        String template = "{\"toolResult\":{\"toolUseId\":\"${parameters.tool_call_id}\",\"content\":[{\"text\":\"${parameters.content_text}\"}]}}";
        StringSubstitutor substitutor = new StringSubstitutor(params, "${parameters.", "}");
        return substitutor.replace(template);
    }

    private String mapSourceTypeToBedrock(SourceType sourceType, String dataUrl) {
        if (sourceType == null) {
            String supportedTypes = Stream.of(SourceType.values()).map(Enum::name).collect(Collectors.joining(", "));
            throw new IllegalArgumentException("Source type is required. Supported types: " + supportedTypes);
        }
        return switch (sourceType) {
            case SourceType.BASE64 -> "bytes";
            case SourceType.URL -> {
                if (dataUrl == null || !dataUrl.startsWith("s3://")) {
                    throw new IllegalArgumentException("URL-based content must use S3 URIs (s3://...). Other URL schemes are not supported by Bedrock Converse API");
                }
                yield "s3Location";
            }
            default -> {
                String supportedTypes = Stream.of(SourceType.values()).map(Enum::name).collect(Collectors.joining(", "));
                throw new IllegalArgumentException("Unsupported source type. Supported types: " + supportedTypes);
            }
        };
    }
}

