package edu.berkeley.compbio.ml.cluster.bayesian;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.dsutils.collections.ConcurrentHashWeightedSet;
import com.davidsoergel.dsutils.collections.MutableWeightedSet;
import com.davidsoergel.dsutils.collections.WeightedSet;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.ProbabilisticDissimilarityMeasure;
import com.google.common.collect.TreeMultimap;
import edu.berkeley.compbio.ml.cluster.AbstractSupervisedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.BasicCentroidCluster;
import edu.berkeley.compbio.ml.cluster.CentroidCluster;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterRuntimeException;
import edu.berkeley.compbio.ml.cluster.Clusterable;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.PointClusterFilter;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.Nullable;

/* loaded from: input_file:lib/ml-0.9.jar:edu/berkeley/compbio/ml/cluster/bayesian/MultiNeighborClustering.class */
/**
 * Supervised clustering method that labels a query sample by letting up to {@link #maxNeighbors}
 * of the nearest training clusters vote on the label.
 *
 * <p>Training turns every training sample into its own cluster (see
 * {@link #trainWithKnownTrainingLabels}). At query time, clusters whose distance to the query is
 * not strictly below {@link #unknownDistanceThreshold} are discarded. The label probabilities of
 * the remaining nearest clusters are then accumulated in a {@link VotingResults}.
 *
 * @param <T> the type of the clustered samples
 */
public abstract class MultiNeighborClustering<T extends AdditiveClusterable<T>>
        extends AbstractSupervisedOnlineClusteringMethod<T, CentroidCluster<T>> {
    private static final Logger logger = Logger.getLogger(MultiNeighborClustering.class);

    /** Maximum number of nearest clusters that contribute votes for a single query. */
    protected final int maxNeighbors;

    /** Clusters at a distance greater than or equal to this value are ignored at query time. */
    protected final double unknownDistanceThreshold;

    /**
     * The two highest-ranked labels produced by a vote. The second label is {@code null} when
     * only one candidate label was available.
     */
    public static class BestLabelPair {
        final String bestLabel;
        final String secondBestLabel;

        private BestLabelPair(String bestLabel, String secondBestLabel) {
            this.bestLabel = bestLabel;
            this.secondBestLabel = secondBestLabel;
        }

        /** @return the top-ranked label; never {@code null} when built by {@link VotingResults} */
        public String getBestLabel() {
            return this.bestLabel;
        }

        /** @return the runner-up label, or {@code null} if there was none */
        public String getSecondBestLabel() {
            return this.secondBestLabel;
        }

        /** @return {@code true} if a runner-up label is present */
        public boolean hasSecondBestLabel() {
            return this.secondBestLabel != null;
        }
    }

    /**
     * Accumulates, for a single query, the weighted label votes cast by neighbor clusters. For
     * each label it also records which cluster moves contributed to it, and with what weight.
     *
     * <p>Not safe for concurrent {@link #addContribution} calls: the per-label contribution map
     * is a plain {@link HashMap}.
     */
    public class VotingResults {
        private final Map<String, MutableWeightedSet<ClusterMove<T, CentroidCluster<T>>>> labelContributions =
                new HashMap<>();
        private final MutableWeightedSet<String> labelVotes = new ConcurrentHashWeightedSet();

        protected VotingResults() {
        }

        /**
         * Records that {@code clusterMove} supports {@code label} with the given weight.
         *
         * @param clusterMove the neighbor cluster move that voted
         * @param label the label being supported
         * @param weight the contribution weight (unboxed, so it must not be {@code null})
         */
        public void addContribution(ClusterMove<T, CentroidCluster<T>> clusterMove, String label, Double weight) {
            MutableWeightedSet<ClusterMove<T, CentroidCluster<T>>> contributions = this.labelContributions.get(label);
            if (contributions == null) {
                contributions = new ConcurrentHashWeightedSet();
                this.labelContributions.put(label, contributions);
            }
            contributions.add(clusterMove, weight.doubleValue(), 1);
        }

        /** Adds every label in {@code votes} to the running tally. */
        public void addVotes(WeightedSet<String> votes) {
            this.labelVotes.addAll(votes);
        }

        /** Adds every label in {@code votes} to the running tally, scaled by {@code weight}. */
        public void addVotes(WeightedSet<String> votes, double weight) {
            this.labelVotes.addAll(votes, weight);
        }

        /**
         * Computes the contribution-weighted mean distance of the cluster moves that supported
         * {@code label}.
         *
         * <p>NOTE(review): throws {@link NullPointerException} if no contribution was ever
         * recorded for {@code label}.
         */
        public double computeWeightedDistance(String label) {
            return computeWeightedDistance(this.labelContributions.get(label));
        }

        /**
         * Picks the two best labels among {@code candidateLabels}.
         *
         * <p>Labels are ordered by {@code keysInDecreasingWeightOrder} on the vote tally. That
         * call also receives a comparator that prefers the label with the smaller weighted
         * distance (presumably used as a tie-breaker; TODO confirm against {@code WeightedSet}).
         *
         * @param candidateLabels the labels allowed in the result
         * @return the best and (if present) second-best label
         * @throws NoGoodClusterException if none of the candidate labels received any vote
         */
        public BestLabelPair getSubResults(Set<String> candidateLabels) throws NoGoodClusterException {
            Comparator<String> closerFirst = new Comparator<String>() {
                // A weighted distance is computed at most once per label during one sort.
                private final Map<String, Double> cache = new HashMap<>();

                private double weightedDistance(String label) {
                    Double distance = this.cache.get(label);
                    if (distance == null) {
                        distance = VotingResults.this.computeWeightedDistance(label);
                        this.cache.put(label, distance);
                    }
                    return distance;
                }

                @Override
                public int compare(String a, String b) {
                    return Double.compare(weightedDistance(a), weightedDistance(b));
                }
            };

            Iterator<String> ranked =
                    this.labelVotes.extractWithKeys(candidateLabels).keysInDecreasingWeightOrder(closerFirst).iterator();
            if (!ranked.hasNext()) {
                throw new NoGoodClusterException("No votes for any of the requested labels");
            }
            String best = ranked.next();
            String secondBest = ranked.hasNext() ? ranked.next() : null;
            return new BestLabelPair(best, secondBest);
        }

        /**
         * Sums each cluster move's {@code bestDistance} weighted by its normalized weight in
         * {@code contributions}.
         */
        public double computeWeightedDistance(WeightedSet<ClusterMove<T, CentroidCluster<T>>> contributions) {
            double total = 0.0d;
            for (Map.Entry<ClusterMove<T, CentroidCluster<T>>, Double> entry : contributions.getItemNormalizedMap().entrySet()) {
                total += entry.getValue().doubleValue() * entry.getKey().bestDistance;
            }
            return total;
        }

        /** @return the live vote tally (not a copy) */
        public WeightedSet<String> getLabelVotes() {
            return this.labelVotes;
        }
    }

    /**
     * @param dissimilarityMeasure distance between a sample and a cluster centroid
     * @param unknownDistanceThreshold clusters at or beyond this distance are ignored
     * @param set passed through to the superclass unchanged
     * @param map passed through to the superclass unchanged
     * @param prohibitionModel optional model that may prohibit clusters for a given sample
     * @param set2 passed through to the superclass unchanged
     * @param maxNeighbors maximum number of neighbors that vote for one query
     */
    public MultiNeighborClustering(DissimilarityMeasure<T> dissimilarityMeasure, double unknownDistanceThreshold,
                                   Set<String> set, Map<String, Set<String>> map, ProhibitionModel<T> prohibitionModel,
                                   Set<String> set2, int maxNeighbors) {
        super(dissimilarityMeasure, set, map, prohibitionModel, set2);
        this.maxNeighbors = maxNeighbors;
        this.unknownDistanceThreshold = unknownDistanceThreshold;
    }

    /**
     * Returns the best-voted label for {@code sample} among {@code predictLabels}.
     *
     * @throws NoGoodClusterException if no cluster passes the distance threshold, or no allowed
     *     label received a vote
     */
    @Override
    public String bestLabel(T sample, Set<String> predictLabels) throws NoGoodClusterException {
        return addUpNeighborVotes(scoredClusterMoves(sample)).getSubResults(predictLabels).getBestLabel();
    }

    /**
     * Trains by adding each training sample as its own cluster, processing samples via
     * {@link Parallel#forEach}. A progress message is logged every 1000 samples.
     */
    @Override
    public void trainWithKnownTrainingLabels(ClusterableIterator<T> clusterableIterator) {
        final AtomicInteger trainedCount = new AtomicInteger(0);
        Parallel.forEach(clusterableIterator, new Function<T, Void>() {
            @Override
            public Void apply(@Nullable T sample) {
                // The running count doubles as the new cluster's id.
                int id = trainedCount.incrementAndGet();
                MultiNeighborClustering.this.addCluster(new BasicCentroidCluster(id, sample));
                if (id % 1000 == 0) {
                    logger.info("Trained " + id + " samples");
                }
                return null;
            }
        });
        logger.info("Done training " + getNumClusters() + " samples");
    }

    /**
     * Tallies votes from at most {@link #maxNeighbors} cluster moves, taken in the iteration
     * order of {@code scoredMoves.values()} (ascending distance key for a {@link TreeMultimap}).
     *
     * <p>For each move, the cluster's derived label probabilities are added as votes scaled by
     * the move's {@code voteWeight}. Each label's normalized probability is also recorded as that
     * move's contribution to the label.
     */
    public MultiNeighborClustering<T>.VotingResults addUpNeighborVotes(
            TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> scoredMoves) {
        MultiNeighborClustering<T>.VotingResults votingResults = new VotingResults();
        int neighborsUsed = 0;
        double previousDistance = 0.0d;
        for (ClusterMove<T, CentroidCluster<T>> clusterMove : scoredMoves.values()) {
            if (neighborsUsed >= this.maxNeighbors) {
                break;
            }
            assert clusterMove.bestDistance >= previousDistance
                    : "cluster moves must be visited in non-decreasing distance order";
            previousDistance = clusterMove.bestDistance;

            WeightedSet<String> labelProbabilities = clusterMove.bestCluster.getDerivedLabelProbabilities();
            votingResults.addVotes(labelProbabilities, clusterMove.voteWeight);
            for (Map.Entry<String, Double> entry : labelProbabilities.getItemNormalizedMap().entrySet()) {
                votingResults.addContribution(clusterMove, entry.getKey(), entry.getValue());
            }
            neighborsUsed++;
        }
        return votingResults;
    }

    /**
     * Unsupported: a multi-neighbor method has no single best cluster.
     *
     * @throws ClusterRuntimeException always
     */
    @Override
    public ClusterMove<T, CentroidCluster<T>> bestClusterMove(T t) throws NoGoodClusterException {
        throw new ClusterRuntimeException("It doesn't make sense to get the best clustermove with a multi-neighbor clustering; look for the best label instead using scoredClusterMoves");
    }

    /**
     * Scores every non-prohibited cluster against {@code sample}, keeping only those strictly
     * closer than {@link #unknownDistanceThreshold}.
     *
     * <p>When the measure is a {@link ProbabilisticDissimilarityMeasure}, the cluster's prior from
     * {@code clusterPriors} is passed to it as well.
     *
     * @return the surviving cluster moves, keyed by distance
     * @throws NoGoodClusterException if no cluster passes the threshold
     */
    public TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> scoredClusterMoves(T sample) throws NoGoodClusterException {
        TreeMultimap<Double, ClusterMove<T, CentroidCluster<T>>> scored = TreeMultimap.create();
        PointClusterFilter filter = this.prohibitionModel == null ? null : this.prohibitionModel.getFilter(sample);

        // The measure does not change during the scan, so perform the type check only once.
        ProbabilisticDissimilarityMeasure probabilisticMeasure = this.measure instanceof ProbabilisticDissimilarityMeasure
                ? (ProbabilisticDissimilarityMeasure) this.measure
                : null;

        for (CentroidCluster<T> cluster : getClusters()) {
            if (filter != null && filter.isProhibited(cluster)) {
                continue;
            }
            double distance = probabilisticMeasure != null
                    ? probabilisticMeasure.distanceFromTo(sample, cluster.getCentroid(), this.clusterPriors.get(cluster).doubleValue())
                    : this.measure.distanceFromTo(sample, cluster.getCentroid());
            ClusterMove<T, CentroidCluster<T>> move = makeClusterMove(cluster, distance);
            if (move.bestDistance < this.unknownDistanceThreshold) {
                scored.put(Double.valueOf(move.bestDistance), move);
            }
        }
        if (scored.isEmpty()) {
            throw new NoGoodClusterException("No clusters passed the unknown threshold");
        }
        return scored;
    }

    /** Creates a cluster move to {@code centroidCluster} at the given distance. */
    protected ClusterMove<T, CentroidCluster<T>> makeClusterMove(CentroidCluster<T> centroidCluster, double distance) {
        ClusterMove<T, CentroidCluster<T>> clusterMove = new ClusterMove<>();
        clusterMove.bestCluster = centroidCluster;
        clusterMove.bestDistance = distance;
        return clusterMove;
    }
}
