/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.genai.dbs.providers;

import com.fasterxml.jackson.core.JsonProcessingException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.neo4j.genai.dbs.CollectionNotFoundException;
import org.neo4j.genai.dbs.RowMappingConfig;
import org.neo4j.genai.dbs.VectorDatabaseProvider;
import org.neo4j.genai.dbs.VectorDatabaseRequest;
import org.neo4j.genai.dbs.VectorDatabases;
import org.neo4j.genai.util.GenAIProcedureException;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.JsonUtils;

/**
 * {@link VectorDatabaseProvider} for the Chroma vector database, targeting its
 * {@code <host>/api/v1/collections} REST endpoints.
 *
 * <p>Every request builder is passed through {@code addAuthorizationHeader} and pinned to
 * HTTP/1.1. JSON request bodies are serialized lazily, when the request customizer is applied to
 * a builder, so argument errors surface when the request is built rather than when it is
 * described.
 */
public class ChromaDb implements VectorDatabaseProvider {
    /** {@code <host>/api/v1/collections}. */
    private static final UnaryOperator<String> CREATE_BASE_URI = host -> "%s/api/v1/collections".formatted(host);
    /** {@code <host>/api/v1/collections/<collection>}. */
    private static final BinaryOperator<String> CREATE_COLLECTION_BASE_URI = (host, collection) -> CREATE_BASE_URI.apply(host) + "/" + collection;
    private static final BiFunction<String, String, String> CREATE_GET_POINTS_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(uri -> uri + "/get");
    private static final BiFunction<String, String, String> CREATE_UPSERT_POINTS_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(uri -> uri + "/upsert");
    private static final BiFunction<String, String, String> CREATE_QUERY_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(uri -> uri + "/query");
    private static final BiFunction<String, String, String> CREATE_DELETE_POINTS_BASE_URI = CREATE_COLLECTION_BASE_URI.andThen(uri -> uri + "/delete");
    private static final String IDS_KEY = "ids";

    /**
     * Builds the HTTP request description and response transformer for one Chroma operation.
     *
     * @param command the operation to perform
     * @param host base URL of the Chroma server
     * @param collection name of the target collection
     * @param configuration provider configuration, handed to {@code addAuthorizationHeader}
     * @param additionalArguments command-specific arguments; read keys include
     *     {@code rowMappingConfig}, {@code procedureArguments}, {@code ids}, {@code vectors},
     *     {@code vector}, {@code limit}, {@code filter}, {@code size} and {@code similarity}
     * @return the request description for {@code command}
     * @throws UnsupportedOperationException if {@code command} is not handled by this provider
     */
    @Override
    public <T> VectorDatabaseRequest<T> createRequestFor(VectorDatabaseProvider.Command command, String host, String collection, Map<String, Object> configuration, Map<String, Object> additionalArguments) {
        RowMappingConfig rowMappingConfig = (RowMappingConfig) additionalArguments.get("rowMappingConfig");
        VectorDatabases.ProcedureArguments procedureArguments = (VectorDatabases.ProcedureArguments) additionalArguments.get("procedureArguments");
        // Applied to every request: authorization first, then force HTTP/1.1.
        Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder = builder -> addHttpVersion(this.addAuthorizationHeader(configuration, builder));
        // An if-chain (rather than a switch) keeps a null command mapped to UnsupportedOperationException.
        if (command == VectorDatabaseProvider.Command.GET_COLLECTION_METADATA) {
            return createGetCollectionMetadataRequest(host, collection, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.GET) {
            return createGetRequest(host, collection, additionalArguments, procedureArguments, rowMappingConfig, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.QUERY) {
            return createQueryRequest(host, collection, additionalArguments, procedureArguments, rowMappingConfig, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.UPSERT) {
            return createUpsertRequest(host, collection, additionalArguments, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.DELETE_COLLECTION) {
            return createDeleteCollectionRequest(host, collection, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.CREATE_COLLECTION) {
            return createCreateCollectionRequest(host, collection, additionalArguments, commonRequestBuilder);
        }
        if (command == VectorDatabaseProvider.Command.DELETE) {
            return createDeleteRequest(host, collection, additionalArguments, commonRequestBuilder);
        }
        throw new UnsupportedOperationException("Unsupported ChromaDB command: " + command);
    }

    /**
     * POSTs {@code {"ids": [...]}} to {@code .../<collection>/delete}. The response body is not
     * read; the result is always {@code StatusDTO.ok(null)}.
     */
    private static <T> VectorDatabaseRequest<T> createDeleteRequest(String host, String collection, Map<String, Object> additionalArguments, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_DELETE_POINTS_BASE_URI.apply(host, collection));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer =
                builder -> postJson(builder, Map.of(IDS_KEY, additionalArguments.get(IDS_KEY)));
        return uncheckedRequest(new VectorDatabaseRequest<VectorDatabases.StatusDTO>(
                target, commonRequestBuilder.andThen(requestCustomizer), in -> VectorDatabases.StatusDTO.ok(null)));
    }

    /**
     * POSTs a new collection definition to the collections endpoint. The body carries the
     * collection name plus {@code size} and {@code hnsw:space} metadata; the response body is not
     * read.
     */
    private static <T> VectorDatabaseRequest<T> createCreateCollectionRequest(String host, String collection, Map<String, Object> additionalArguments, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_BASE_URI.apply(host));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer = builder -> {
            // Locale.ROOT: with e.g. a Turkish default locale, "COSINE".toLowerCase() yields "cosıne".
            String space = additionalArguments.get("similarity").toString().toLowerCase(Locale.ROOT);
            Map<String, Object> metadata = Map.of("size", additionalArguments.get("size"), "hnsw:space", space);
            return postJson(builder, Map.of("name", collection, "metadata", metadata));
        };
        return uncheckedRequest(new VectorDatabaseRequest<VectorDatabases.StatusDTO>(
                target, commonRequestBuilder.andThen(requestCustomizer), in -> VectorDatabases.StatusDTO.ok(null)));
    }

    /**
     * POSTs the {@code vectors} argument to {@code .../<collection>/upsert}. Each entry is expected
     * to be a map with {@code id}, {@code vector} and {@code metadata} keys; they are regrouped into
     * the parallel {@code ids}/{@code embeddings}/{@code metadatas} lists Chroma's payload uses.
     */
    @SuppressWarnings("unchecked")
    private static <T> VectorDatabaseRequest<T> createUpsertRequest(String host, String collection, Map<String, Object> additionalArguments, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_UPSERT_POINTS_BASE_URI.apply(host, collection));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer = builder -> {
            List<?> vectors = (List<?>) additionalArguments.get("vectors");
            List<List<Double>> embeddings = new ArrayList<>(vectors.size());
            List<Map<String, Object>> metadatas = new ArrayList<>(vectors.size());
            List<Object> ids = new ArrayList<>(vectors.size());
            for (Object entry : vectors) {
                Map<?, ?> vector = (Map<?, ?>) entry;
                embeddings.add((List<Double>) vector.get("vector"));
                ids.add(vector.get("id"));
                metadatas.add((Map<String, Object>) vector.get("metadata"));
            }
            return postJson(builder, new UpsertPayload(embeddings, metadatas, ids));
        };
        return uncheckedRequest(new VectorDatabaseRequest<VectorDatabases.StatusDTO>(
                target, commonRequestBuilder.andThen(requestCustomizer), in -> VectorDatabases.StatusDTO.ok(null)));
    }

    /**
     * Maps Chroma error responses that mention a missing collection to procedure exceptions.
     *
     * @param collection the collection the failed request targeted
     * @return a handler from (HTTP status, response message) to an exception, or to
     *     {@link Optional#empty()} when generic status handling should apply
     */
    @Override
    public BiFunction<Integer, String, Optional<GenAIProcedureException>> getProviderSpecificStatusHandler(String collection) {
        String notFoundMessage = "Collection " + collection + " does not exist.";
        return (statusCode, message) -> {
            boolean mentionsMissingCollection = message.contains(notFoundMessage);
            if (statusCode == 400 && mentionsMissingCollection) {
                return Optional.of(new CollectionNotFoundException(collection));
            }
            // NOTE(review): a 500 mentioning a missing collection is reported as a credentials
            // problem (the message says 403 while the exception carries 500); presumably Chroma
            // surfaces authorization failures this way — confirm.
            if (statusCode == 500 && mentionsMissingCollection) {
                return Optional.of(new GenAIProcedureException("API request forbidden (HTTP response code: 403); check your credentials.", 500));
            }
            return Optional.empty();
        };
    }

    /**
     * POSTs a nearest-neighbour query for the {@code vector} argument to
     * {@code .../<collection>/query}, returning at most {@code limit} results.
     *
     * <p>The response is transformed into a stream of rows keyed by {@code rowMappingConfig}. The
     * raw Chroma {@code distances} value is stored under the score key. The vector is added only
     * when {@code procedureArguments.allResults()} is set and an embedding is present.
     */
    private static <T> VectorDatabaseRequest<T> createQueryRequest(String host, String collection, Map<String, Object> additionalArguments, VectorDatabases.ProcedureArguments procedureArguments, RowMappingConfig rowMappingConfig, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_QUERY_BASE_URI.apply(host, collection));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer = builder -> {
            Map<String, Object> payload = new HashMap<>(Map.of(
                    "query_embeddings", List.of(additionalArguments.get("vector")),
                    "n_results", additionalArguments.get("limit"),
                    "include", List.of("metadatas", "embeddings", "distances")));
            // An absent filter is sent as an empty "where" object.
            Optional<?> filter = (Optional<?>) additionalArguments.get("filter");
            payload.put("where", filter.map(Object.class::cast).orElseGet(Map::of));
            return postJson(builder, payload);
        };
        Function<InputStream, Stream<Map<String, Object>>> responseTransformer = inputStream -> {
            Map<String, Object> response = readResponse(inputStream);
            List<?> ids = firstResult(response, IDS_KEY);
            List<?> embeddings = firstResult(response, "embeddings");
            List<?> metadatas = firstResult(response, "metadatas");
            List<?> distances = firstResult(response, "distances");
            return IntStream.range(0, ids.size()).mapToObj(i -> {
                Map<String, Object> row = toRow(rowMappingConfig, procedureArguments, ids, metadatas, embeddings, i);
                row.put(rowMappingConfig.scoreKey(), distances.get(i));
                return row;
            });
        };
        return uncheckedRequest(new VectorDatabaseRequest<Stream<Map<String, Object>>>(
                target, commonRequestBuilder.andThen(requestCustomizer), responseTransformer));
    }

    /**
     * POSTs the {@code ids} argument to {@code .../<collection>/get} and transforms the response
     * into a stream of rows keyed by {@code rowMappingConfig}. The vector is added only when
     * {@code procedureArguments.allResults()} is set and an embedding is present.
     */
    private static <T> VectorDatabaseRequest<T> createGetRequest(String host, String collection, Map<String, Object> additionalArguments, VectorDatabases.ProcedureArguments procedureArguments, RowMappingConfig rowMappingConfig, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_GET_POINTS_BASE_URI.apply(host, collection));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer = builder -> postJson(builder,
                Map.of(IDS_KEY, additionalArguments.get(IDS_KEY), "include", List.of("metadatas", "embeddings")));
        Function<InputStream, Stream<Map<String, Object>>> responseTransformer = inputStream -> {
            Map<String, Object> response = readResponse(inputStream);
            // Unlike /query, /get answers with flat lists: one entry per requested id.
            List<?> ids = (List<?>) response.get(IDS_KEY);
            List<?> embeddings = (List<?>) response.get("embeddings");
            List<?> metadatas = (List<?>) response.get("metadatas");
            return IntStream.range(0, ids.size())
                    .mapToObj(i -> toRow(rowMappingConfig, procedureArguments, ids, metadatas, embeddings, i));
        };
        return uncheckedRequest(new VectorDatabaseRequest<Stream<Map<String, Object>>>(
                target, commonRequestBuilder.andThen(requestCustomizer), responseTransformer));
    }

    /** GETs {@code .../<collection>} and converts the JSON response with {@code InfoDTO::of}. */
    private static <T> VectorDatabaseRequest<T> createGetCollectionMetadataRequest(String host, String collection, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_COLLECTION_BASE_URI.apply(host, collection));
        Function<InputStream, VectorDatabases.InfoDTO> responseTransformer = HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(VectorDatabases.InfoDTO::of);
        return uncheckedRequest(new VectorDatabaseRequest<VectorDatabases.InfoDTO>(
                target, commonRequestBuilder.andThen(HttpRequest.Builder::build), responseTransformer));
    }

    /**
     * Sends DELETE to {@code .../<collection>}. The response is still parsed by the default map
     * transformer, but the result is always an {@code "ok"} status with empty details.
     */
    private static <T> VectorDatabaseRequest<T> createDeleteCollectionRequest(String host, String collection, Function<HttpRequest.Builder, HttpRequest.Builder> commonRequestBuilder) {
        URI target = URI.create(CREATE_COLLECTION_BASE_URI.apply(host, collection));
        Function<HttpRequest.Builder, HttpRequest> requestCustomizer = builder -> builder.DELETE().build();
        Function<InputStream, VectorDatabases.StatusDTO> responseTransformer = HttpService.DEFAULT_RESPONSE_TO_MAP_TRANSFORMER.andThen(ignored -> new VectorDatabases.StatusDTO("ok", Map.of()));
        return uncheckedRequest(new VectorDatabaseRequest<VectorDatabases.StatusDTO>(
                target, commonRequestBuilder.andThen(requestCustomizer), responseTransformer));
    }

    /** Pins the request to HTTP/1.1. */
    private static HttpRequest.Builder addHttpVersion(HttpRequest.Builder httpRequestBuilder) {
        return httpRequestBuilder.version(HttpClient.Version.HTTP_1_1);
    }

    /** Finishes {@code builder} as a POST whose body is {@code payload} serialized to JSON. */
    private static HttpRequest postJson(HttpRequest.Builder builder, Object payload) {
        String body;
        try {
            body = JsonUtils.getObjectMapper().writeValueAsString(payload);
        } catch (JsonProcessingException e) {
            throw new UncheckedIOException(e);
        }
        return builder.POST(HttpRequest.BodyPublishers.ofString(body)).build();
    }

    /** Parses a JSON response body into a map, rethrowing I/O failures unchecked. */
    private static Map<String, Object> readResponse(InputStream inputStream) {
        try {
            return JsonUtils.getObjectMapper().readValue(inputStream, JsonUtils.TYPE_REF_MAP_STRING_OBJECT);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Returns the first inner list of a /query response field. Chroma nests results per query
     * embedding, and exactly one embedding is sent. Returns {@code null} when the field is absent
     * or null.
     */
    private static List<?> firstResult(Map<String, Object> response, String key) {
        List<?> perQuery = (List<?>) response.get(key);
        return perQuery == null ? null : (List<?>) perQuery.get(0);
    }

    /**
     * Builds the id/metadata part of result row {@code i}. The vector is added only when all
     * results were requested and {@code embeddings} has a non-null entry at {@code i}.
     */
    private static Map<String, Object> toRow(RowMappingConfig rowMappingConfig, VectorDatabases.ProcedureArguments procedureArguments, List<?> ids, List<?> metadatas, List<?> embeddings, int i) {
        Map<String, Object> row = new HashMap<>();
        row.put(rowMappingConfig.idKey(), ids.get(i));
        row.put(rowMappingConfig.metadataKey(), metadatas.get(i));
        if (procedureArguments.allResults() && embeddings != null && embeddings.get(i) != null) {
            row.put(rowMappingConfig.vectorKey(), embeddings.get(i));
        }
        return row;
    }

    /**
     * Re-types a request to the caller-chosen result type {@code T}. The cast is unchecked, so
     * callers must request the type the command's transformer actually produces.
     */
    @SuppressWarnings("unchecked")
    private static <T> VectorDatabaseRequest<T> uncheckedRequest(VectorDatabaseRequest<?> request) {
        return (VectorDatabaseRequest<T>) request;
    }

    /** JSON body of an upsert: parallel lists, where index {@code i} describes one point. */
    private record UpsertPayload(List<List<Double>> embeddings, List<Map<String, Object>> metadatas, List<Object> ids) {
    }
}

