/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.genai.vector.providers;

import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.core.ObjectCodec;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.text.StringSubstitutor;
import org.eclipse.collections.api.IntIterable;
import org.eclipse.collections.api.factory.Maps;
import org.eclipse.collections.api.map.MapIterable;
import org.eclipse.collections.api.map.MutableMap;
import org.eclipse.collections.api.multimap.Multimap;
import org.eclipse.collections.api.multimap.list.MutableListMultimap;
import org.eclipse.collections.impl.factory.Multimaps;
import org.eclipse.collections.impl.factory.primitive.IntSets;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.JsonUtils;
import org.neo4j.genai.util.MalformedGenAIResponseException;
import org.neo4j.genai.util.aws.AwsSignatureV4HeaderGenerator;
import org.neo4j.genai.vector.VectorEncoding;

public final class Bedrock
implements VectorEncoding.Provider<Parameters> {
    public static final String NAME = "Bedrock";
    private static final String ENDPOINT_TEMPLATE = "https://bedrock-runtime.${region}.amazonaws.com/model/${model}/invoke";
    static final String DEFAULT_REGION = "us-east-1";
    static final Set<String> SUPPORTED_REGIONS = Set.of("ap-northeast-1", "ap-northeast-2", "ap-northeast-3", "ap-south-1", "ap-south-2", "ap-southeast-1", "ap-southeast-2", "ca-central-1", "eu-central-1", "eu-central-2", "eu-north-1", "eu-south-1", "eu-south-2", "eu-west-1", "eu-west-2", "eu-west-3", "sa-east-1", "us-east-1", "us-east-2", "us-gov-east-1", "us-gov-west-1", "us-west-2");
    private static final String STRINGIFIED_SUPPORTED_REGIONS = SUPPORTED_REGIONS.stream().map(s -> "'" + s + "'").collect(Collectors.joining(", ", "[", "]"));
    static final MapIterable<String, Model> KNOWN_MODELS = Maps.immutable.of((Object)"amazon.titan-embed-text-v1", (Object)TitanEmbedTextG1Model.INSTANCE, (Object)"amazon.titan-embed-image-v1", (Object)TitanEmbedImageG1Model.INSTANCE, (Object)"amazon.titan-embed-text-v2:0", (Object)TitanEmbedTextV2Model.INSTANCE);
    static final String DEFAULT_MODEL = "amazon.titan-embed-text-v1";

    @Override
    public Class<Parameters> parameterDeclarations() {
        return Parameters.class;
    }

    @Override
    public String name() {
        return NAME;
    }

    @Override
    public VectorEncoding.Provider.Encoder configure(Parameters configuration) {
        if (!SUPPORTED_REGIONS.contains(configuration.region)) {
            throw new IllegalArgumentException("Provided region '%s' is not supported. Supported regions: %s".formatted(configuration.region, STRINGIFIED_SUPPORTED_REGIONS));
        }
        URI endpoint = URI.create(StringSubstitutor.replace((Object)ENDPOINT_TEMPLATE, Map.of("region", configuration.region, "model", configuration.model)));
        Model model = (Model)KNOWN_MODELS.getOrDefault((Object)configuration.model, (Object)FallbackModel.INSTANCE);
        model.validateConfiguration(configuration);
        return model.encoder(endpoint, configuration);
    }

    private static class TitanEmbedTextG1Model
    implements Model {
        private static final TitanEmbedTextG1Model INSTANCE = new TitanEmbedTextG1Model();
        private static final String NAME = "amazon.titan-embed-text-v1";

        private TitanEmbedTextG1Model() {
        }

        @Override
        public Encoder encoder(URI endpoint, Parameters configuration) {
            return new TitanEmbedTextG1Encoder(endpoint, configuration);
        }
    }

    public static class Parameters {
        public String accessKeyId;
        public String secretAccessKey;
        public String model = "amazon.titan-embed-text-v1";
        public String region = "us-east-1";
        public OptionalLong dimensions;
        public Optional<Boolean> normalize;
    }

    private static class FallbackModel
    implements Model {
        private static final FallbackModel INSTANCE = new FallbackModel();

        private FallbackModel() {
        }

        @Override
        public Encoder encoder(URI endpoint, Parameters configuration) {
            return new FallbackEncoder(endpoint, configuration);
        }
    }

    private static interface Model {
        default public void validateConfiguration(Parameters configuration) {
        }

        public Encoder encoder(URI var1, Parameters var2);
    }

    static abstract class Encoder
    implements VectorEncoding.Provider.Encoder {
        private final URI endpoint;
        protected final Parameters configuration;

        protected Encoder(URI endpoint, Parameters configuration) {
            this.endpoint = endpoint;
            this.configuration = configuration;
        }

        protected abstract Object buildPayload(String var1);

        @Override
        public float[] encode(HttpService httpService, String resource) {
            try {
                String body = this.createRequestBody(resource);
                return httpService.request(this.endpoint, builder -> {
                    HttpRequest intermediate = builder.build();
                    MutableListMultimap requestProperties = Multimaps.mutable.list.with((Object)"Host", (Object)this.endpoint.getHost());
                    intermediate.headers().map().forEach((arg_0, arg_1) -> ((MutableListMultimap)requestProperties).putAll(arg_0, arg_1));
                    Multimap<String, String> finalHeaders = new AwsSignatureV4HeaderGenerator(this.configuration.region, this.endpoint, body, (Multimap<String, String>)requestProperties).generate(this.configuration.accessKeyId, this.configuration.secretAccessKey);
                    HttpRequest.Builder newBuilder = HttpRequest.newBuilder(intermediate, (k, v) -> !finalHeaders.containsKey(k));
                    finalHeaders.forEachKeyValue(newBuilder::header);
                    return newBuilder.POST(HttpRequest.BodyPublishers.ofString(body)).build();
                }, Encoder::parseResponse);
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        static float[] parseResponse(InputStream inputStream) throws MalformedGenAIResponseException {
            float[] fArray;
            block11: {
                JsonNode tree;
                ObjectMapper objectMapper = JsonUtils.getObjectMapper();
                try {
                    tree = objectMapper.readTree(inputStream);
                }
                catch (IOException e) {
                    throw new MalformedGenAIResponseException("Unexpected error occurred while parsing the API response", e);
                }
                JsonNode embedding = Encoder.getExpectedFrom(tree, "embedding");
                if (!embedding.isArray()) {
                    throw new MalformedGenAIResponseException("Expected embedding to be an array");
                }
                JsonParser parser = embedding.traverse((ObjectCodec)objectMapper);
                try {
                    fArray = (float[])parser.readValueAs(JsonUtils.TYPE_REF_FLOAT_VECTOR);
                    if (parser == null) break block11;
                }
                catch (Throwable throwable) {
                    try {
                        if (parser != null) {
                            try {
                                parser.close();
                            }
                            catch (Throwable throwable2) {
                                throwable.addSuppressed(throwable2);
                            }
                        }
                        throw throwable;
                    }
                    catch (IOException e) {
                        throw new MalformedGenAIResponseException("Unexpected error occurred while parsing the embedding", e);
                    }
                }
                parser.close();
            }
            return fArray;
        }

        private static JsonNode getExpectedFrom(JsonNode json, String property) throws MalformedGenAIResponseException {
            return JsonUtils.getExpectedFrom(Bedrock.NAME, json, property);
        }

        private String createRequestBody(String resource) throws IOException {
            return JsonUtils.getObjectMapper().writeValueAsString(this.buildPayload(resource));
        }
    }

    private static class TitanEmbedImageG1Model
    implements Model {
        private static final TitanEmbedImageG1Model INSTANCE = new TitanEmbedImageG1Model();
        private static final String NAME = "amazon.titan-embed-image-v1";
        private static final IntIterable VALID_DIMENSIONS = IntSets.immutable.of(new int[]{256, 384, 1024});
        private static final String STRINGIFIED_VALID_DIMENSIONS = VALID_DIMENSIONS.makeString("[", ", ", "]");

        private TitanEmbedImageG1Model() {
        }

        @Override
        public void validateConfiguration(Parameters configuration) {
            long dimensions;
            if (configuration.dimensions.isPresent() && !VALID_DIMENSIONS.contains((int)(dimensions = configuration.dimensions.getAsLong()))) {
                throw new IllegalArgumentException("Provided dimensions '%d' is not supported. Supported dimensions: %s".formatted(dimensions, STRINGIFIED_VALID_DIMENSIONS));
            }
        }

        @Override
        public Encoder encoder(URI endpoint, Parameters configuration) {
            return new TitanEmbedImageG1Encoder(endpoint, configuration);
        }
    }

    private static class TitanEmbedTextV2Model
    implements Model {
        private static final TitanEmbedTextV2Model INSTANCE = new TitanEmbedTextV2Model();
        private static final String NAME = "amazon.titan-embed-text-v2:0";
        private static final IntIterable VALID_DIMENSIONS = IntSets.immutable.of(new int[]{256, 512, 1024});
        private static final String STRINGIFIED_VALID_DIMENSIONS = VALID_DIMENSIONS.makeString("[", ", ", "]");

        private TitanEmbedTextV2Model() {
        }

        @Override
        public void validateConfiguration(Parameters configuration) {
            long dimensions;
            if (configuration.dimensions.isPresent() && !VALID_DIMENSIONS.contains((int)(dimensions = configuration.dimensions.getAsLong()))) {
                throw new IllegalArgumentException("Provided dimensions '%d' is not supported. Supported dimensions: %s".formatted(dimensions, STRINGIFIED_VALID_DIMENSIONS));
            }
        }

        @Override
        public Encoder encoder(URI endpoint, Parameters configuration) {
            return new TitanEmbedTextV2Encoder(endpoint, configuration);
        }
    }

    private static class TitanEmbedTextV2Encoder
    extends Encoder {
        private TitanEmbedTextV2Encoder(URI endpoint, Parameters configuration) {
            super(endpoint, configuration);
        }

        @Override
        protected Object buildPayload(String resource) {
            MutableMap payload = Maps.mutable.of((Object)"inputText", (Object)resource);
            this.configuration.dimensions.ifPresent(arg_0 -> TitanEmbedTextV2Encoder.lambda$buildPayload$0((Map)payload, arg_0));
            this.configuration.normalize.ifPresent(arg_0 -> TitanEmbedTextV2Encoder.lambda$buildPayload$1((Map)payload, arg_0));
            return payload;
        }

        private static /* synthetic */ void lambda$buildPayload$1(Map payload, Boolean normalize) {
            payload.put("normalize", normalize);
        }

        private static /* synthetic */ void lambda$buildPayload$0(Map payload, long dimensions) {
            payload.put("dimensions", dimensions);
        }
    }

    private static class TitanEmbedImageG1Encoder
    extends Encoder {
        private TitanEmbedImageG1Encoder(URI endpoint, Parameters configuration) {
            super(endpoint, configuration);
        }

        @Override
        protected Object buildPayload(String resource) {
            MutableMap payload = Maps.mutable.of((Object)"inputText", (Object)resource);
            this.configuration.dimensions.ifPresent(arg_0 -> TitanEmbedImageG1Encoder.lambda$buildPayload$0((Map)payload, arg_0));
            return payload;
        }

        private static /* synthetic */ void lambda$buildPayload$0(Map payload, long dimensions) {
            payload.put("embeddingConfig", Maps.mutable.of((Object)"outputEmbeddingLength", (Object)dimensions));
        }
    }

    private static class TitanEmbedTextG1Encoder
    extends Encoder {
        private TitanEmbedTextG1Encoder(URI endpoint, Parameters configuration) {
            super(endpoint, configuration);
        }

        @Override
        protected Object buildPayload(String resource) {
            return Maps.mutable.of((Object)"inputText", (Object)resource);
        }
    }

    private static class FallbackEncoder
    extends Encoder {
        private FallbackEncoder(URI endpoint, Parameters configuration) {
            super(endpoint, configuration);
        }

        @Override
        protected Object buildPayload(String resource) {
            return Maps.mutable.of((Object)"inputText", (Object)resource);
        }
    }
}

