/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.genai.vector;

import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.OptionalLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.mutable.MutableInt;
import org.eclipse.collections.api.RichIterable;
import org.eclipse.collections.api.factory.Lists;
import org.eclipse.collections.api.factory.Maps;
import org.eclipse.collections.api.factory.primitive.IntLists;
import org.eclipse.collections.api.list.ImmutableList;
import org.eclipse.collections.api.list.MutableList;
import org.eclipse.collections.api.list.primitive.MutableIntList;
import org.eclipse.collections.api.map.MutableMap;
import org.eclipse.collections.impl.collection.mutable.CollectionAdapter;
import org.neo4j.annotations.service.Service;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.Monitors;
import org.neo4j.genai.util.Parameters;
import org.neo4j.genai.vector.VectorEncodingCallCountersMonitor;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;
import org.neo4j.procedure.Sensitive;
import org.neo4j.procedure.UserFunction;
import org.neo4j.service.Services;
import org.neo4j.util.Preconditions;
import org.neo4j.values.AnyValue;
import org.neo4j.values.storable.Value;
import org.neo4j.values.storable.Values;
import org.neo4j.values.virtual.MapValue;

/**
 * Cypher procedures and functions that turn text into vector embeddings using the {@link Provider}s
 * discovered through {@link Services#loadAll(Class)}.
 */
public class VectorEncoding {
    public static final String VERSION = "1.2.0";

    /** Every discovered provider, sorted case-insensitively by {@link Provider#name()}. */
    private static final ImmutableList<Provider> PROVIDERS = Lists.immutable.withAllSorted(
            Comparator.comparing(Provider::name, String.CASE_INSENSITIVE_ORDER),
            (RichIterable) CollectionAdapter.adapt((Collection) Services.loadAll(Provider.class)));

    @Context
    public GraphDatabaseService graphDatabaseService;

    /** Single {@link HttpService} instance handed to every encoder call made by this class. */
    private static final HttpService httpService = new HttpService();

    /**
     * Lists every available provider together with its configuration signature and defaults.
     *
     * @return one {@link ProviderRow} per provider, in case-insensitive name order
     */
    @Procedure(name = "genai.vector.listEncodingProviders")
    @Description(value = "Lists the available vector embedding providers.")
    public Stream<ProviderRow> listEncodingProviders() {
        return PROVIDERS.stream().map(ProviderRow::from);
    }

    /**
     * Encodes a single resource as a float vector using the named provider.
     *
     * @param resource the text to encode; {@code null} yields {@link Values#NO_VALUE}
     * @param providerName provider name, matched case-insensitively
     * @param configuration provider-specific settings; must be a non-null map
     * @return the embedding as a float array value, or {@link Values#NO_VALUE} for a null resource
     * @throws NullPointerException if {@code providerName} is null
     * @throws IllegalArgumentException if {@code configuration} is null or not a map
     * @throws RuntimeException if no provider with that name exists
     */
    @UserFunction(name = "genai.vector.encode")
    @Description(value = "Encode a given resource as a vector using the named provider.")
    public Value encode(
            @Name(value = "resource", description = "The object to transform into an embedding.") String resource,
            @Name(value = "provider", description = "The identifier of the provider: (\"VertexAI\", \"OpenAI\", \"AzureOpenAI\", \"Bedrock\").") String providerName,
            @Sensitive @Name(value = "configuration", defaultValue = "{}", description = "VertexAI: {token :: STRING, projectId :: STRING, model :: STRING, region :: STRING, taskType :: STRING, title :: STRING }\n\nOpenAI: {token :: STRING, model :: STRING, dimensions :: INTEGER}\n\nAzureOpenAI: {token :: STRING, resource :: STRING, deployment :: STRING, dimensions :: INTEGER}\n\nAmazonBedrock: {accessKeyId :: STRING, secretAccessKey :: STRING, model :: STRING, region :: STRING}") AnyValue configuration) {
        Objects.requireNonNull(providerName, "'provider' must not be null");
        MapValue configurationMap = requireNonNullMap(configuration);
        Provider<?> provider = getProvider(providerName);
        // The call is counted before the null check, so calls with a null resource are recorded too.
        getMonitor().encodeFunctionCalled(provider.name());
        if (resource == null) {
            return Values.NO_VALUE;
        }
        float[] vector = provider.configure(configurationMap).encode(httpService, resource);
        return Values.floatArray(vector);
    }

    /**
     * Encodes a batch of resources using the named provider.
     *
     * <p>Null elements are removed before the batch reaches the provider's encoder. Their original
     * positions are passed along in ascending order so that rows can be emitted for them as well.
     *
     * @param resources the texts to encode; must be non-null and non-empty, but may contain nulls
     * @param providerName provider name, matched case-insensitively
     * @param configuration provider-specific settings; must be a non-null map
     * @return one row per produced {@link BatchRow}
     * @throws NullPointerException if {@code resources} or {@code providerName} is null
     * @throws IllegalArgumentException if {@code resources} is empty or {@code configuration} is null
     *     or not a map
     * @throws RuntimeException if no provider with that name exists
     */
    @Procedure(name = "genai.vector.encodeBatch")
    @Description(value = "Encode a given batch of resources as vectors using the named provider.\nFor each element in the given resource LIST this returns:\n    * the corresponding 'index' within that LIST,\n    * the original 'resource' element itself,\n    * and the encoded 'vector'.\n")
    public Stream<InternalBatchRow> encode(
            @Name(value = "resources", description = "The object to transform into an embedding.") List<String> resources,
            @Name(value = "provider", description = "The GenAI provider to use.") String providerName,
            @Sensitive @Name(value = "configuration", defaultValue = "{}", description = "The provider specific settings.") AnyValue configuration) {
        Objects.requireNonNull(resources, "'resources' must not be null");
        Preconditions.checkArgument(!resources.isEmpty(), "'resources' must not be empty");
        Objects.requireNonNull(providerName, "'provider' must not be null");
        MapValue configurationMap = requireNonNullMap(configuration);
        Provider<?> provider = getProvider(providerName);
        getMonitor().encodeBatchProcedureCalled(provider.name());

        // Split the input into the non-null resources and the (ascending) indexes of the null ones.
        MutableIntList removedIndexes = IntLists.mutable.empty();
        MutableList<String> cleanedResources = Lists.mutable.withInitialCapacity(resources.size());
        ListIterator<String> it = resources.listIterator();
        while (it.hasNext()) {
            int index = it.nextIndex();
            String resource = it.next();
            if (resource == null) {
                removedIndexes.add(index);
            } else {
                cleanedResources.add(resource);
            }
        }
        return provider.configure(configurationMap)
                .encode(httpService, cleanedResources, removedIndexes.toArray())
                .map(InternalBatchRow::new);
    }

    /**
     * Validates that the configuration argument is a non-null map.
     *
     * @throws IllegalArgumentException if it is Java {@code null}, {@link Values#NO_VALUE}, or not a map
     */
    private static MapValue requireNonNullMap(AnyValue configuration) {
        if (configuration == null || configuration == Values.NO_VALUE) {
            throw new IllegalArgumentException("'configuration' must not be null");
        }
        if (configuration instanceof MapValue map) {
            return map;
        }
        throw new IllegalArgumentException("'configuration' must be a map");
    }

    private VectorEncodingCallCountersMonitor getMonitor() {
        return Monitors.getMonitor(this.graphDatabaseService, VectorEncodingCallCountersMonitor.class);
    }

    /**
     * Looks up a provider by name, ignoring case.
     *
     * @throws RuntimeException if no provider has that name
     */
    static Provider<?> getProvider(String name) {
        for (Provider<?> provider : PROVIDERS) {
            if (String.CASE_INSENSITIVE_ORDER.compare(provider.name(), name) == 0) {
                return provider;
            }
        }
        throw new RuntimeException("Vector encoding provider not supported: %s".formatted(name));
    }

    /**
     * A named embedding provider, loaded as a service.
     *
     * @param <PARAMETERS> the type whose declarations describe this provider's configuration
     */
    @Service
    public interface Provider<PARAMETERS> {
        /** Returns the parameter declarations that {@link Parameters#parse} uses to read the configuration. */
        Class<PARAMETERS> parameterDeclarations();

        /** Returns the provider name; lookups by this name are case-insensitive. */
        String name();

        /** Parses the raw Cypher configuration map and creates an encoder from the result. */
        default Encoder configure(MapValue configuration) {
            return this.configure(Parameters.parse(this.parameterDeclarations(), configuration));
        }

        /** Creates an encoder for an already parsed configuration. */
        Encoder configure(PARAMETERS parameters);

        /** Turns resources into embedding vectors for one configuration. */
        interface Encoder {
            /** Encodes one resource; {@link VectorEncoding} only calls this with non-null resources. */
            float[] encode(HttpService httpService, String resource);

            /**
             * Encodes a batch by calling {@link #encode(HttpService, String)} once per resource.
             *
             * <p>Emits {@code resources.size() + nullIndexes.length} rows indexed by position in the
             * original input. Rows at {@code nullIndexes} carry a null resource and vector.
             *
             * @param resources the non-null resources, in original input order
             * @param nullIndexes original positions of the removed nulls; must be sorted ascending
             */
            default Stream<BatchRow> encode(HttpService httpService, List<String> resources, int[] nullIndexes) {
                int total = resources.size() + nullIndexes.length;
                return IntStream.range(0, total).mapToObj(index -> {
                    int position = Arrays.binarySearch(nullIndexes, index);
                    if (position >= 0) {
                        return new BatchRow(index, null, null);
                    }
                    // For a missing key, binarySearch returns -(insertionPoint) - 1. The insertion point is
                    // the number of null slots before `index`, which is exactly how far `index` is shifted
                    // relative to `resources`. Deriving it here keeps the mapper stateless.
                    int nullsBefore = -position - 1;
                    String resource = resources.get(index - nullsBefore);
                    return new BatchRow(index, resource, this.encode(httpService, resource));
                });
            }
        }
    }

    /** Procedure output row: a {@link BatchRow} with its vector converted to a Cypher value. */
    public record InternalBatchRow(@Description(value = "The index of the corresponding element in the input list.") long index, @Description(value = "The name of the input resource.") String resource, @Description(value = "The generated vector embedding for the resource.") Value vector) {
        InternalBatchRow(BatchRow row) {
            // NOTE(review): Values.of is expected to map a null vector (null input element) to NO_VALUE.
            this(row.index(), row.resource(), Values.of(row.vector()));
        }
    }

    /** Encoder-level batch row; {@code resource} and {@code vector} are null for null input elements. */
    public record BatchRow(long index, String resource, float[] vector) {
    }

    /** Procedure output row describing one provider and its configuration map signatures. */
    public record ProviderRow(@Description(value = "The name of the GenAI provider.") String name, @Description(value = "The signature of the required config map.") String requiredConfigType, @Description(value = "The signature of the optional config map.") String optionalConfigType, @Description(value = "The default values for the GenAI provider.") Map<String, Object> defaultConfig) {
        /** Builds a row from the provider's declared parameters. */
        public static ProviderRow from(Provider<?> provider) {
            List<Parameters.Parameter> parameters = Parameters.getParameters(provider.parameterDeclarations());
            return new ProviderRow(provider.name(), requiredConfigType(parameters), optionalConfigType(parameters), defaultConfig(parameters));
        }

        private static String requiredConfigType(List<Parameters.Parameter> parameters) {
            return cypherMapType(parameters.stream().filter(Parameters.Parameter::isRequired));
        }

        private static String optionalConfigType(List<Parameters.Parameter> parameters) {
            return cypherMapType(parameters.stream().filter(Parameters.Parameter::isOptional));
        }

        /** Renders the parameters as a Cypher map type, e.g. {@code { token :: STRING, model :: STRING }}. */
        private static String cypherMapType(Stream<Parameters.Parameter> parameters) {
            return parameters
                    .map(p -> "%s :: %s".formatted(p.name(), p.type().cypherName()))
                    .collect(Collectors.joining(", ", "{ ", " }"));
        }

        /**
         * Collects the parameters that have a default value into an unmodifiable map.
         * Empty {@link Optional}, {@link OptionalLong} and {@link OptionalDouble} defaults are
         * omitted, and present ones are unwrapped.
         */
        private static Map<String, Object> defaultConfig(List<Parameters.Parameter> parameters) {
            MutableMap<String, Object> defaults = Maps.mutable.empty();
            for (Parameters.Parameter parameter : parameters) {
                Object defaultValue = parameter.defaultValue();
                if (defaultValue == null) {
                    continue;
                }
                if (defaultValue instanceof Optional<?> optional) {
                    optional.ifPresent(value -> defaults.put(parameter.name(), value));
                } else if (defaultValue instanceof OptionalLong optionalLong) {
                    optionalLong.ifPresent(value -> defaults.put(parameter.name(), value));
                } else if (defaultValue instanceof OptionalDouble optionalDouble) {
                    optionalDouble.ifPresent(value -> defaults.put(parameter.name(), value));
                } else {
                    defaults.put(parameter.name(), defaultValue);
                }
            }
            return defaults.asUnmodifiable();
        }
    }
}

