/*
 * Decompiled with CFR 0.152.
 */
package net.mehvahdjukaar.moonlight.api.util.math.kmeans;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Random;
import net.mehvahdjukaar.moonlight.api.resources.textures.Palette;
import net.mehvahdjukaar.moonlight.api.resources.textures.PaletteColor;
import net.mehvahdjukaar.moonlight.api.util.math.colors.BaseColor;
import net.mehvahdjukaar.moonlight.api.util.math.colors.LABColor;
import net.mehvahdjukaar.moonlight.api.util.math.colors.RGBColor;
import net.mehvahdjukaar.moonlight.api.util.math.kmeans.IDataEntry;

/**
 * A set of k-means data points together with the bookkeeping needed to seed and update cluster centroids.
 *
 * @param <A> the concrete entry type, as exposed by {@link IDataEntry#cast()}
 */
public class DataSet<A> {
    // ArrayList (not LinkedList): entries are accessed by position (get(i)) throughout this class,
    // which is O(n) per call on a LinkedList.
    private final List<IDataEntry<A>> colorPoints = new ArrayList<IDataEntry<A>>();
    private final List<IDataEntry<A>> lastCentroids = new ArrayList<IDataEntry<A>>();
    // Indices (into colorPoints) of the entries already picked by calculateWeighedCentroid().
    private final List<Integer> indicesOfCentroids = new ArrayList<Integer>();
    private final Random random;

    /**
     * Creates a data set holding a copy of the given entries.
     * <p>
     * The random generator is seeded from the distance between the first and the last entry, so repeated runs over
     * the same entries make the same random choices (assuming {@code distTo} is deterministic). An empty list gets a
     * fixed seed instead of failing.
     *
     * @param colors the entries to cluster
     */
    public <T extends IDataEntry<A>> DataSet(List<T> colors) {
        this.colorPoints.addAll(colors);
        long seed = 0L;
        if (!this.colorPoints.isEmpty()) {
            IDataEntry<A> first = this.colorPoints.get(0);
            IDataEntry<A> last = this.colorPoints.get(this.colorPoints.size() - 1);
            seed = Objects.hash(Float.valueOf(first.distTo(last)));
        }
        this.random = new Random(seed);
    }

    /**
     * Builds a data set containing one {@link ColorPoint} per value of the given palette.
     *
     * @param palette the palette whose colors are to be clustered
     * @return a new data set over the palette's colors
     */
    public static DataSet<ColorPoint> fromPalette(Palette palette) {
        return new DataSet<ColorPoint>(palette.getValues().stream().map(ColorPoint::new).toList());
    }

    /**
     * Computes the centroid of the entries currently assigned to {@code clusterNo}.
     * <p>
     * The centroid is produced by calling {@link IDataEntry#average} on the first member of the cluster with the full
     * member list.
     *
     * @param clusterNo the cluster to average
     * @return the cluster centroid; when no entry belongs to the cluster, a placeholder {@link ColorPoint} built from
     *         {@code new RGBColor(0)}
     */
    public IDataEntry<A> calculateCentroid(int clusterNo) {
        List<IDataEntry<A>> members = new ArrayList<IDataEntry<A>>();
        for (IDataEntry<A> point : this.colorPoints) {
            if (point.getClusterNo() == clusterNo) {
                members.add(point);
            }
        }
        if (members.isEmpty()) {
            // NOTE(review): the placeholder is a ColorPoint whatever A is, so it is only type-correct for
            // DataSet<ColorPoint>; for other entry types a later distTo()/cast() on it would likely fail.
            @SuppressWarnings("unchecked")
            IDataEntry<A> placeholder =
                    (IDataEntry<A>) (IDataEntry<?>) new ColorPoint(new PaletteColor(new RGBColor(0)));
            return placeholder;
        }
        return members.get(0).average(members);
    }

    /**
     * Recomputes the centroid of every cluster {@code 0 .. clusterSize - 1}.
     *
     * @param clusterSize the number of clusters
     * @return the internal centroid list, indexed by cluster number. The same list is returned by
     *         {@link #getLastCentroids()} and is cleared and refilled on the next call.
     */
    public List<IDataEntry<A>> recomputeCentroids(int clusterSize) {
        this.lastCentroids.clear();
        for (int i = 0; i < clusterSize; ++i) {
            this.lastCentroids.add(this.calculateCentroid(i));
        }
        return this.lastCentroids;
    }

    /**
     * Returns an entry chosen uniformly at random.
     * <p>
     * Unlike {@link #calculateWeighedCentroid()}, the pick is NOT recorded as a centroid index.
     *
     * @return a random entry
     * @throws IllegalArgumentException if the data set is empty (from {@link Random#nextInt(int)})
     */
    public IDataEntry<A> randomFromDataSet() {
        int index = this.random.nextInt(this.colorPoints.size());
        return this.colorPoints.get(index);
    }

    /**
     * Sums the squared distances from {@code centroid} to every entry assigned to {@code clusterNo}.
     *
     * @param centroid  the cluster centre
     * @param clusterNo the cluster whose members are measured
     * @return the sum of squared errors of that cluster
     */
    public Double calculateClusterSSE(IDataEntry<A> centroid, int clusterNo) {
        double sse = 0.0;
        for (IDataEntry<A> point : this.colorPoints) {
            if (point.getClusterNo() != clusterNo) {
                continue;
            }
            // widen before squaring so the product is computed in double, not float
            double dist = centroid.distTo(point);
            sse += dist * dist;
        }
        return sse;
    }

    /**
     * Sums {@link #calculateClusterSSE} over all clusters.
     * <p>
     * Cluster number {@code i} is the position of its centroid in {@code centroids}.
     *
     * @param centroids the centroids, indexed by cluster number
     * @return the total sum of squared errors
     */
    public Double calculateTotalSSE(List<IDataEntry<A>> centroids) {
        double total = 0.0;
        int clusterNo = 0;
        // iterate rather than get(i): the caller may pass a LinkedList
        for (IDataEntry<A> centroid : centroids) {
            total += this.calculateClusterSSE(centroid, clusterNo++);
        }
        return total;
    }

    /**
     * Picks a new initial centroid among the entries that have not been picked yet (k-means++-style seeding) and
     * records its index.
     * <p>
     * The candidate is chosen as follows:
     * <ul>
     *   <li>If no entry has been picked through this method yet, the candidate is chosen uniformly at random.</li>
     *   <li>Otherwise each candidate is chosen with probability proportional to its distance to the nearest picked
     *       entry.</li>
     *   <li>If all those distances are zero (only duplicates remain), the choice falls back to uniform.</li>
     * </ul>
     * Classic k-means++ weights by the squared distance; this implementation weights by the plain distance.
     * <p>
     * NOTE(review): {@link #randomFromDataSet()} does not record its pick here. If a caller seeds the first centroid
     * with it, that centroid is unknown to this method and may be picked again. Confirm against the caller.
     *
     * @return the chosen entry
     * @throws UnsupportedOperationException if every entry has already been picked
     */
    public IDataEntry<A> calculateWeighedCentroid() {
        int size = this.colorPoints.size();
        boolean[] taken = new boolean[size];
        for (int index : this.indicesOfCentroids) {
            taken[index] = true;
        }
        boolean hasCentroids = !this.indicesOfCentroids.isEmpty();
        double[] weights = new double[size];
        double total = 0.0;
        int freeCount = 0;
        for (int i = 0; i < size; ++i) {
            if (taken[i]) {
                continue;
            }
            ++freeCount;
            if (hasCentroids) {
                weights[i] = this.distanceToNearestCentroid(this.colorPoints.get(i));
                total += weights[i];
            }
        }
        if (freeCount == 0) {
            throw new UnsupportedOperationException("Something bad happened");
        }
        // total is 0 when nothing is picked yet or only duplicates remain; NaN/infinite totals are unusable too
        int chosen = (total > 0.0 && !Double.isInfinite(total))
                ? this.pickWeighted(taken, weights, total)
                : this.pickUniform(taken, freeCount);
        this.indicesOfCentroids.add(chosen);
        return this.colorPoints.get(chosen);
    }

    /** Returns the distance from {@code point} to the closest entry whose index is in indicesOfCentroids. */
    private double distanceToNearestCentroid(IDataEntry<A> point) {
        double minDist = Double.MAX_VALUE;
        for (int index : this.indicesOfCentroids) {
            double dist = point.distTo(this.colorPoints.get(index));
            if (dist < minDist) {
                minDist = dist;
            }
        }
        return minDist;
    }

    /** Picks a non-taken index with probability proportional to its weight; {@code total} must be positive. */
    private int pickWeighted(boolean[] taken, double[] weights, double total) {
        double threshold = total * this.random.nextDouble();
        // running sum starts at 0 for this pass (the original code reused the full total here)
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < weights.length; ++i) {
            if (taken[i]) {
                continue;
            }
            if (weights[i] > 0.0) {
                lastPositive = i;
            }
            cumulative += weights[i];
            if (cumulative > threshold) {
                return i;
            }
        }
        // only reachable if total * nextDouble() rounded up to total; total > 0 guarantees lastPositive >= 0
        return lastPositive;
    }

    /** Picks one of the {@code freeCount} non-taken indices uniformly at random. */
    private int pickUniform(boolean[] taken, int freeCount) {
        int remaining = this.random.nextInt(freeCount);
        for (int i = 0; i < taken.length; ++i) {
            if (!taken[i] && remaining-- == 0) {
                return i;
            }
        }
        throw new IllegalStateException("freeCount does not match the number of free entries");
    }

    /** Returns the live internal list of entries (not a copy). */
    public List<IDataEntry<A>> getColorPoints() {
        return this.colorPoints;
    }

    /** Returns the live internal list filled by the last {@link #recomputeCentroids(int)} call (not a copy). */
    public List<IDataEntry<A>> getLastCentroids() {
        return this.lastCentroids;
    }

    /**
     * A palette color used as a k-means data point, weighted by the color's occurrence count.
     */
    public static class ColorPoint implements IDataEntry<ColorPoint> {
        private final int weight;
        private final PaletteColor color;
        private int clusterNo;

        /**
         * Wraps a palette color.
         *
         * @param color the palette color; its {@code getOccurrence()} becomes this point's weight
         */
        public ColorPoint(PaletteColor color) {
            this.color = color;
            this.weight = color.getOccurrence();
        }

        /**
         * Mixes this point and the points in {@code others} with {@code BaseColor.mixColors}.
         * <p>
         * Each point contributes its LAB color once per occurrence (presumably yielding an occurrence-weighted mean;
         * confirm against {@code mixColors}). This point is always included; an occurrence of this same instance
         * in {@code others} is skipped so it is not counted twice. Note that this allocates one list slot per
         * occurrence.
         *
         * @param others the points to mix with this one
         * @return a new point holding the mixed color
         */
        @Override
        public IDataEntry<ColorPoint> average(List<IDataEntry<ColorPoint>> others) {
            ArrayList<LABColor> pixels = new ArrayList<LABColor>();
            LABColor ownLab = this.color.lab();
            for (int i = 0; i < this.weight; ++i) {
                pixels.add(ownLab);
            }
            for (IDataEntry<ColorPoint> entry : others) {
                if (entry == this) {
                    continue;
                }
                ColorPoint other = entry.cast();
                LABColor otherLab = other.color.lab();
                for (int i = 0; i < other.weight; ++i) {
                    pixels.add(otherLab);
                }
            }
            return new ColorPoint(new PaletteColor((BaseColor<?>) BaseColor.mixColors(pixels)));
        }

        @Override
        public void setClusterNo(int clusterNo) {
            this.clusterNo = clusterNo;
        }

        @Override
        public int getClusterNo() {
            return this.clusterNo;
        }

        /** Returns the distance between the LAB representations of the two colors. */
        @Override
        public float distTo(IDataEntry<ColorPoint> a) {
            return this.color.lab().distTo(a.cast().color.lab());
        }

        @Override
        public ColorPoint cast() {
            return this;
        }

        /** Returns the wrapped palette color. */
        public PaletteColor getColor() {
            return this.color;
        }
    }
}

