/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.collector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FieldComparator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.FieldValueHitQueue;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.LeafFieldComparator;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.apache.lucene.search.grouping.GroupSelector;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.neuralsearch.query.HybridSubQueryScorer;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.collector.CollapseDocSourceGroupSelector;
import org.opensearch.neuralsearch.search.collector.HybridLeafCollector;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.lucene.MultiLeafFieldComparator;

/**
 * Collector for hybrid queries that collapses hits on the value of a single field.
 *
 * <p>For every distinct group value produced by the {@link GroupSelector} and for every sub-query of the
 * hybrid query, the collector keeps a {@link FieldValueHitQueue} ordered by the supplied {@link Sort}.
 * {@link #topDocs()} then selects, independently per sub-query, at most {@code topNGroups} groups and
 * returns their retained hits as {@link CollapseTopFieldDocs}.
 *
 * @param <T> type of the collapse group value ({@link BytesRef} for keyword fields, {@link Long} for
 *            numeric fields, see {@link #createKeyword} and {@link #createNumeric})
 */
public class HybridCollapsingTopDocsCollector<T> implements HybridSearchCollector, Collector {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridCollapsingTopDocsCollector.class);

    /** Name of the collapse field; passed through to every {@link CollapseTopFieldDocs} built by {@link #topDocs()}. */
    protected final String collapseField;
    /** Number of documents seen by {@code collect}, across all leaves. */
    private int totalHitCount;
    /** Highest sub-query score among hits that were placed into a queue. */
    private float maxScore = 0.0f;
    private final Sort sort;
    private final GroupSelector<T> groupSelector;
    /** Doc id offset of the leaf currently being collected. */
    private int docBase;

    // All maps below are keyed by a copy of the group value (GroupSelector#copyValue) and hold one array slot
    // per sub-query. The arrays are mutated in place, so they are put into the maps exactly once per group.
    private final Map<T, FieldValueHitQueue<FieldValueHitQueue.Entry>[]> groupQueueMap;
    /** Number of hits collected per group and sub-query; also the next free comparator slot while a queue fills. */
    private final Map<T, int[]> collectedHitsPerSubQueryMap;
    /** Current bottom (weakest) queue entry per group and sub-query. */
    private final Map<T, FieldValueHitQueue.Entry[]> fieldValueLeafTrackersMap;
    /** Leaf comparators for the current segment per group and sub-query. */
    private final Map<T, LeafFieldComparator[]> comparatorsMap;
    /** Multiplier applied to {@code compareBottom} results for a group. */
    private final Map<T, Integer> reverseMulMap;
    private final Map<T, boolean[]> queueFullMap;

    /** Maximum number of groups returned per sub-query. */
    private final int numHits;
    /** Capacity of every per-group, per-sub-query queue. */
    private final int docsPerGroupPerSubQuery;
    TotalHits.Relation totalHitsRelation = TotalHits.Relation.EQUAL_TO;
    private final HitsThresholdChecker hitsThresholdChecker;

    /**
     * @param groupSelector           yields the collapse group value of each collected document
     * @param collapseField           name of the collapse field
     * @param groupSort               sort applied to hits inside a group and to the groups themselves; must not be null
     * @param topNGroups              maximum number of groups returned per sub-query
     * @param hitsThresholdChecker    counts hits and signals when collection has to stop early
     * @param docsPerGroupPerSubQuery number of hits kept per group and sub-query; a non-positive value falls
     *                                back to {@code topNGroups}
     */
    HybridCollapsingTopDocsCollector(GroupSelector<T> groupSelector, String collapseField, @NonNull Sort groupSort, int topNGroups, HitsThresholdChecker hitsThresholdChecker, int docsPerGroupPerSubQuery) {
        Objects.requireNonNull(groupSort, "groupSort is marked non-null but is null");
        this.groupSelector = groupSelector;
        this.collapseField = collapseField;
        this.sort = groupSort;
        this.groupQueueMap = new HashMap<>();
        this.collectedHitsPerSubQueryMap = new HashMap<>();
        this.fieldValueLeafTrackersMap = new HashMap<>();
        this.comparatorsMap = new HashMap<>();
        this.reverseMulMap = new HashMap<>();
        this.queueFullMap = new HashMap<>();
        this.numHits = topNGroups;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.docsPerGroupPerSubQuery = docsPerGroupPerSubQuery > 0 ? docsPerGroupPerSubQuery : topNGroups;
    }

    /**
     * Creates a collector whose group values are selected by {@link CollapseDocSourceGroupSelector.Keyword}.
     * Parameters have the same meaning as in the constructor.
     */
    public static HybridCollapsingTopDocsCollector<?> createKeyword(String collapseField, MappedFieldType fieldType, Sort sort, int topNGroups, HitsThresholdChecker hitsThresholdChecker, int docsPerGroupPerSubQuery) {
        return new HybridCollapsingTopDocsCollector<BytesRef>(new CollapseDocSourceGroupSelector.Keyword(fieldType), collapseField, sort, topNGroups, hitsThresholdChecker, docsPerGroupPerSubQuery);
    }

    /**
     * Creates a collector whose group values are selected by {@link CollapseDocSourceGroupSelector.Numeric}.
     * Parameters have the same meaning as in the constructor.
     */
    public static HybridCollapsingTopDocsCollector<?> createNumeric(String collapseField, MappedFieldType fieldType, Sort sort, int topNGroups, HitsThresholdChecker hitsThresholdChecker, int docsPerGroupPerSubQuery) {
        return new HybridCollapsingTopDocsCollector<Long>(new CollapseDocSourceGroupSelector.Numeric(fieldType), collapseField, sort, topNGroups, hitsThresholdChecker, docsPerGroupPerSubQuery);
    }

    /**
     * Builds one {@link CollapseTopFieldDocs} per sub-query, in sub-query order.
     *
     * <p>The total hit count of each result is the number of hits collected for that sub-query over all
     * groups. This method drains the per-group queues it reads, so it is meant to be called once after
     * collection has finished.
     *
     * @return the per-sub-query results; empty when nothing was collected
     */
    public List<CollapseTopFieldDocs> topDocs() throws IOException {
        List<CollapseTopFieldDocs> topDocsList = new ArrayList<>();
        if (this.collectedHitsPerSubQueryMap.isEmpty()) {
            return topDocsList;
        }
        // every group holds one counter per sub-query, so any entry tells the number of sub-queries
        int numSubQueries = this.collectedHitsPerSubQueryMap.values().iterator().next().length;
        for (int subQueryNumber = 0; subQueryNumber < numSubQueries; ++subQueryNumber) {
            topDocsList.add(topDocsForSubQuery(subQueryNumber));
        }
        return topDocsList;
    }

    /** Selects the best groups of one sub-query and flattens their retained hits, most competitive first. */
    private CollapseTopFieldDocs topDocsForSubQuery(int subQueryNumber) {
        int totalHitsForSubQuery = 0;
        for (int[] hitsPerSubQuery : this.collectedHitsPerSubQueryMap.values()) {
            totalHitsForSubQuery += hitsPerSubQuery[subQueryNumber];
        }

        GroupPriorityQueue<T> topGroupsQueue = new GroupPriorityQueue<>(this.numHits);
        for (Map.Entry<T, FieldValueHitQueue<FieldValueHitQueue.Entry>[]> groupQueues : this.groupQueueMap.entrySet()) {
            FieldValueHitQueue<FieldValueHitQueue.Entry> queue = groupQueues.getValue()[subQueryNumber];
            if (queue.size() > 0) {
                topGroupsQueue.insertWithOverflow(new GroupEntry<>(groupQueues.getKey(), queue));
            }
        }

        // pop() hands out the least competitive element first, hence the arrays are filled back to front
        @SuppressWarnings("unchecked")
        GroupEntry<T>[] topGroups = new GroupEntry[topGroupsQueue.size()];
        for (int j = topGroups.length - 1; j >= 0; --j) {
            topGroups[j] = topGroupsQueue.pop();
        }

        List<FieldDoc> fieldDocs = new ArrayList<>();
        List<Object> collapseValues = new ArrayList<>();
        for (GroupEntry<T> group : topGroups) {
            FieldValueHitQueue<FieldValueHitQueue.Entry> queue = group.queue;
            FieldComparator<?>[] queueComparators = queue.getComparators();
            FieldValueHitQueue.Entry[] entries = new FieldValueHitQueue.Entry[queue.size()];
            for (int i = entries.length - 1; i >= 0; --i) {
                entries[i] = queue.pop();
            }
            for (FieldValueHitQueue.Entry queueEntry : entries) {
                Object[] fields = new Object[queueComparators.length];
                for (int k = 0; k < queueComparators.length; ++k) {
                    fields[k] = queueComparators[k].value(queueEntry.slot);
                }
                fieldDocs.add(new FieldDoc(queueEntry.doc, queueEntry.score, fields));
                T groupValue = group.groupValue;
                collapseValues.add(groupValue instanceof BytesRef ? BytesRef.deepCopyOf((BytesRef) groupValue) : groupValue);
            }
        }

        ScoreDoc[] scoreDocs = fieldDocs.toArray(new FieldDoc[0]);
        return new CollapseTopFieldDocs(this.collapseField, new TotalHits((long) totalHitsForSubQuery, this.totalHitsRelation), scoreDocs, this.sort.getSort(), collapseValues.toArray(new Object[0]));
    }

    @Override
    public int getTotalHits() {
        return this.totalHitCount;
    }

    @Override
    public float getMaxScore() {
        return this.maxScore;
    }

    @Override
    public ScoreMode scoreMode() {
        return ScoreMode.COMPLETE;
    }

    /**
     * Returns the collector for one segment. Positions the group selector on the segment and remembers the
     * segment's doc base so that queue entries carry index-wide doc ids.
     */
    @Override
    public LeafCollector getLeafCollector(final LeafReaderContext context) throws IOException {
        this.docBase = context.docBase;
        this.groupSelector.setNextReader(context);
        return new HybridLeafCollector() {
            // groups whose leaf comparators were already bound to this segment
            private final Map<T, Boolean> leafComparatorsInitialized = new HashMap<T, Boolean>();

            @Override
            public void collect(int doc) throws IOException {
                HybridSubQueryScorer compoundQueryScorer = this.getCompoundQueryScorer();
                if (Objects.isNull(compoundQueryScorer)) {
                    return;
                }
                groupSelector.advanceTo(doc);
                T groupValue = groupSelector.currentValue();
                assert groupValue != null;
                float[] subScoresByQuery = compoundQueryScorer.getSubQueryScores();
                this.initializeQueueIfNeeded(groupValue, subScoresByQuery.length);
                this.initializeLeafComparatorsIfNeeded(groupValue, subScoresByQuery.length, compoundQueryScorer);
                this.updateHitCount();

                int[] collectedHits = collectedHitsPerSubQueryMap.get(groupValue);
                for (int subQueryNumber = 0; subQueryNumber < subScoresByQuery.length; ++subQueryNumber) {
                    // NOTE(review): every sub-query is counted and queued for the document, whatever its score.
                    // If HybridSubQueryScorer reports "no match" as a particular score this inflates the
                    // per-sub-query hit counts - confirm against HybridSubQueryScorer.
                    float score = subScoresByQuery[subQueryNumber];
                    int slot = collectedHits[subQueryNumber]++;
                    if (docsPerGroupPerSubQuery <= 0) {
                        // zero-capacity queues cannot hold anything; the hit is counted only
                        continue;
                    }
                    if (this.isQueueFull(groupValue, subQueryNumber)) {
                        this.updateExistingEntry(groupValue, subQueryNumber, doc, score);
                    } else {
                        this.addNewEntry(groupValue, subQueryNumber, doc, score, slot);
                    }
                }
            }

            private void initializeQueueIfNeeded(T groupValue, int subQueryCount) throws IOException {
                if (groupQueueMap.get(groupValue) == null) {
                    this.initializeQueue(subQueryCount);
                }
            }

            private void initializeLeafComparatorsIfNeeded(T groupValue, int numSubQueries, Scorable scorer) throws IOException {
                if (this.leafComparatorsInitialized.get(groupValue) == null) {
                    this.initializeLeafComparators(groupValue, numSubQueries, scorer);
                }
            }

            /** Binds the comparators of every queue of the group to the current segment and scorer. */
            private void initializeLeafComparators(T groupValue, int numSubQueries, Scorable scorer) throws IOException {
                LeafFieldComparator[] comparators = comparatorsMap.get(groupValue);
                FieldValueHitQueue<FieldValueHitQueue.Entry>[] queues = groupQueueMap.get(groupValue);
                T key = groupSelector.copyValue();
                for (int subQueryNumber = 0; subQueryNumber < numSubQueries; ++subQueryNumber) {
                    LeafFieldComparator[] leafFieldComparators = queues[subQueryNumber].getComparators(context);
                    int[] reverseMuls = queues[subQueryNumber].getReverseMul();
                    if (leafFieldComparators.length == 1) {
                        reverseMulMap.put(key, reverseMuls[0]);
                        comparators[subQueryNumber] = leafFieldComparators[0];
                    } else {
                        // the multi comparator applies the per-field multipliers itself
                        reverseMulMap.put(key, 1);
                        comparators[subQueryNumber] = new MultiLeafFieldComparator(leafFieldComparators, reverseMuls);
                    }
                    comparators[subQueryNumber].setScorer(scorer);
                }
                this.leafComparatorsInitialized.put(key, false);
            }

            /** Counts the document and aborts collection of this leaf once the hits threshold is reached. */
            private void updateHitCount() throws CollectionTerminatedException {
                ++totalHitCount;
                hitsThresholdChecker.incrementHitCount();
                if (hitsThresholdChecker.isThresholdReached()) {
                    setTotalHitsRelation(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
                    log.info("Terminating collection early as specified hits threshold is reached");
                    throw new CollectionTerminatedException();
                }
            }

            private boolean isQueueFull(T groupValue, int index) {
                return queueFullMap.get(groupValue)[index];
            }

            /** Replaces the bottom entry of a full queue when the document is more competitive than it. */
            private void updateExistingEntry(T groupValue, int index, int doc, float score) throws IOException {
                LeafFieldComparator comparator = comparatorsMap.get(groupValue)[index];
                if (reverseMulMap.get(groupValue) * comparator.compareBottom(doc) <= 0) {
                    return;
                }
                FieldValueHitQueue.Entry[] bottoms = fieldValueLeafTrackersMap.get(groupValue);
                FieldValueHitQueue<FieldValueHitQueue.Entry> queue = groupQueueMap.get(groupValue)[index];
                FieldValueHitQueue.Entry replaced = bottoms[index];
                comparator.copy(replaced.slot, doc);
                replaced.doc = docBase + doc;
                // the entry now stands for another document, so its score has to follow
                replaced.score = score;
                maxScore = Math.max(score, maxScore);
                bottoms[index] = queue.updateTop();
                comparator.setBottom(bottoms[index].slot);
            }

            /** Adds the document to a queue that still has free capacity. */
            private void addNewEntry(T groupValue, int subQueryNumber, int doc, float score, int slot) throws IOException {
                FieldValueHitQueue<FieldValueHitQueue.Entry> queue = groupQueueMap.get(groupValue)[subQueryNumber];
                FieldValueHitQueue.Entry[] bottoms = fieldValueLeafTrackersMap.get(groupValue);
                LeafFieldComparator comparator = comparatorsMap.get(groupValue)[subQueryNumber];
                maxScore = Math.max(score, maxScore);
                comparator.copy(slot, doc);
                FieldValueHitQueue.Entry entry = new FieldValueHitQueue.Entry(slot, docBase + doc);
                entry.score = score;
                bottoms[subQueryNumber] = queue.add(entry);
                // the queue was created with docsPerGroupPerSubQuery slots, so that is when it is full
                if (slot == docsPerGroupPerSubQuery - 1) {
                    queueFullMap.get(groupValue)[subQueryNumber] = true;
                    // compareBottom is only meaningful once the comparator knows the bottom slot
                    comparator.setBottom(bottoms[subQueryNumber].slot);
                }
            }

            /** Registers the current group value with empty per-sub-query state. */
            @SuppressWarnings("unchecked")
            private void initializeQueue(int numSubQueries) throws IOException {
                SortField[] sortFields = sort.getSort();
                FieldValueHitQueue<FieldValueHitQueue.Entry>[] queues = new FieldValueHitQueue[numSubQueries];
                for (int i = 0; i < numSubQueries; ++i) {
                    queues[i] = FieldValueHitQueue.create(sortFields, docsPerGroupPerSubQuery);
                }
                // one copy of the group value serves as key in all maps
                T key = groupSelector.copyValue();
                groupQueueMap.put(key, queues);
                collectedHitsPerSubQueryMap.put(key, new int[numSubQueries]);
                comparatorsMap.put(key, new LeafFieldComparator[numSubQueries]);
                fieldValueLeafTrackersMap.put(key, new FieldValueHitQueue.Entry[numSubQueries]);
                queueFullMap.put(key, new boolean[numSubQueries]);
            }
        };
    }

    @Generated
    public void setTotalHitsRelation(TotalHits.Relation totalHitsRelation) {
        this.totalHitsRelation = totalHitsRelation;
    }

    /**
     * Bounded queue of groups for one sub-query. The least element is the least competitive group, so
     * {@code insertWithOverflow} drops the weakest group when the queue is full and {@code pop} returns
     * groups weakest first.
     *
     * <p>Groups are compared by the top entry of their hit queue, which this collector treats as the
     * bottom (weakest) retained hit of the group.
     */
    private static class GroupPriorityQueue<T> extends PriorityQueue<GroupEntry<T>> {
        GroupPriorityQueue(int maxSize) {
            super(maxSize);
        }

        @Override
        protected boolean lessThan(GroupEntry<T> a, GroupEntry<T> b) {
            FieldValueHitQueue<FieldValueHitQueue.Entry> queueA = a.queue;
            FieldValueHitQueue<FieldValueHitQueue.Entry> queueB = b.queue;
            // a group without hits is weaker than any group with hits
            if (queueB.size() == 0) {
                return false;
            }
            if (queueA.size() == 0) {
                return true;
            }
            FieldValueHitQueue.Entry entryA = queueA.top();
            FieldValueHitQueue.Entry entryB = queueB.top();
            // slot values live in the comparators of the queue that owns the entry
            FieldComparator<?>[] comparatorsA = queueA.getComparators();
            FieldComparator<?>[] comparatorsB = queueB.getComparators();
            int[] reverseMul = queueA.getReverseMul();
            for (int i = 0; i < comparatorsA.length; ++i) {
                int comparison = compareSlotValues(comparatorsA[i], entryA.slot, comparatorsB[i], entryB.slot);
                if (comparison != 0) {
                    // a positive result means a sorts after b, i.e. a is the weaker group
                    return reverseMul[i] * comparison > 0;
                }
            }
            return entryA.score < entryB.score;
        }

        /** Compares the value stored at {@code slotA} of {@code comparatorA} with the one at {@code slotB} of {@code comparatorB}. */
        @SuppressWarnings("unchecked")
        private static <V> int compareSlotValues(FieldComparator<V> comparatorA, int slotA, FieldComparator<?> comparatorB, int slotB) {
            // both comparators were created for the same sort field, so their value types agree
            return comparatorA.compareValues(comparatorA.value(slotA), (V) comparatorB.value(slotB));
        }
    }

    /** A group value together with the hit queue of that group for one sub-query. */
    private static class GroupEntry<T> {
        final T groupValue;
        final FieldValueHitQueue<FieldValueHitQueue.Entry> queue;

        GroupEntry(T groupValue, FieldValueHitQueue<FieldValueHitQueue.Entry> queue) {
            this.groupValue = groupValue;
            this.queue = queue;
        }
    }
}

