package marytts.machinelearning;

import marytts.signalproc.analysis.distance.DistanceComputer;
import marytts.util.math.MathUtils;

/* loaded from: input_file:lib/marytts-signalproc-5.1-SNAPSHOT.jar:marytts/machinelearning/KMeansClusteringTrainer.class */
/**
 * K-means clustering using the normalized Euclidean distance, with farthest-point seeding and
 * per-cluster covariance estimation.
 *
 * <p>After {@link #train} returns:
 * <ul>
 *   <li>{@link #clusters} holds one mean vector and one covariance (plus its inverse) per cluster;
 *   <li>{@link #clusterIndices} and {@link #totalObservationsInClusters} describe the final
 *       assignment of the training data;
 *   <li>{@link #covMatrixGlobal} and {@link #invCovMatrixGlobal} hold the covariance of the whole
 *       training set.
 * </ul>
 *
 * <p>{@code train} overwrites all public fields without synchronization, so an instance must not
 * be trained from several threads at once.
 */
public class KMeansClusteringTrainer {
    /** Trained clusters (mean vector, covariance and inverse covariance); filled by {@link #train}. */
    public Cluster[] clusters;
    /** Number of training observations assigned to each cluster in the last iteration. */
    public int[] totalObservationsInClusters;
    /** Zero-based cluster index of every training observation in the last iteration. */
    public int[] clusterIndices;
    /**
     * Covariance of the whole training set: a single row holding the diagonal when diagonal
     * covariances are requested, otherwise the full matrix.
     */
    public double[][] covMatrixGlobal;
    /** Inverse of {@link #covMatrixGlobal}, stored in the same layout. */
    public double[][] invCovMatrixGlobal;

    /**
     * Clusters the observations and estimates per-cluster and global covariances.
     *
     * <p>If {@code params.globalVariances} is {@code null}, it is computed from {@code x} and stored
     * back into {@code params}; those variances normalize every distance computation. Clusters
     * whose population falls below {@code params.minSamplesInOneCluster} are re-seeded (with random
     * jitter) during the iterations. After training, they are overwritten with a copy of the most
     * populated cluster.
     *
     * @param x observations, one row per observation; every row must have the same dimension
     * @param params clustering parameters (number of clusters, stopping criteria, covariance type)
     * @throws IllegalArgumentException if {@code x} is null or empty, or if
     *     {@code params.numClusters} is not positive
     */
    public void train(double[][] x, KMeansClusteringTrainerParams params) {
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("training data must contain at least one observation");
        }
        if (params.numClusters < 1) {
            throw new IllegalArgumentException("numClusters must be positive, got " + params.numClusters);
        }
        if (params.globalVariances == null) {
            params.globalVariances = MathUtils.variance(x, MathUtils.mean(x, true), true);
        }
        int dimension = x[0].length;

        clusters = new Cluster[params.numClusters];
        for (int k = 0; k < params.numClusters; k++) {
            clusters[k] = new Cluster(dimension, params.isDiagonalOutputCovariance);
        }

        initializeCenters(x, params, dimension);
        iterateUntilConverged(x, params, dimension);
        estimateClusterCovariances(x, params);
        replaceTinyClusters(params, dimension);
        estimateGlobalCovariance(x, params);
    }

    /**
     * Seeds every cluster mean with a copy of one observation. For each cluster, it picks the
     * observation with the largest average distance to the global mean and the already chosen
     * centres, keeping the first one on ties.
     */
    private void initializeCenters(double[][] x, KMeansClusteringTrainerParams params, int dimension) {
        double[] globalMean = MathUtils.mean(x, true);
        double[] scores = new double[x.length];
        // Slots 0..k-1 hold distances to the centres chosen so far; slot k holds the distance
        // to the global mean.
        double[] distances = new double[params.numClusters + 1];

        for (int k = 0; k < params.numClusters; k++) {
            for (int t = 0; t < x.length; t++) {
                if (k > 0) {
                    for (int c = 0; c < k; c++) {
                        distances[c] = DistanceComputer.getNormalizedEuclideanDistance(
                                clusters[c].meanVector, x[t], params.globalVariances);
                    }
                    distances[k] = DistanceComputer.getNormalizedEuclideanDistance(
                            globalMean, x[t], params.globalVariances);
                    scores[t] = MathUtils.mean(distances, 0, k);
                } else {
                    scores[t] = DistanceComputer.getNormalizedEuclideanDistance(
                            globalMean, x[t], params.globalVariances);
                }
            }

            int best = 0;
            for (int t = 1; t < x.length; t++) {
                if (scores[t] > scores[best]) {
                    best = t;
                }
            }
            System.arraycopy(x[best], 0, clusters[k].meanVector, 0, dimension);
        }
    }

    /**
     * Runs Lloyd iterations: assign every observation to its nearest mean, then recompute the
     * means. Iteration stops once {@code params.maxIterations} is reached or, from the second
     * iteration on, once fewer than {@code params.minClusterChangePercent} percent of the
     * observations changed cluster.
     */
    private void iterateUntilConverged(double[][] x, KMeansClusteringTrainerParams params, int dimension) {
        int numObservations = x.length;
        int numClusters = params.numClusters;
        double[][] sums = new double[numClusters][dimension];
        double[] sortKeys = new double[numClusters];
        int[] previousIndices = new int[numObservations];
        totalObservationsInClusters = new int[numClusters];
        clusterIndices = new int[numObservations];

        int iteration = 0;
        boolean keepGoing = true;
        while (keepGoing) {
            for (int k = 0; k < numClusters; k++) {
                totalObservationsInClusters[k] = 0;
                for (int d = 0; d < dimension; d++) {
                    sums[k][d] = 0.0;
                }
            }

            // Assignment step. Means stay fixed during this pass, so the per-cluster sums can be
            // accumulated in the same loop (still in ascending observation order).
            for (int t = 0; t < numObservations; t++) {
                int nearest = 0;
                double nearestDistance = 0.0;
                for (int k = 0; k < numClusters; k++) {
                    double distance = DistanceComputer.getNormalizedEuclideanDistance(
                            clusters[k].meanVector, x[t], params.globalVariances);
                    if (k == 0 || distance < nearestDistance) {
                        nearestDistance = distance;
                        nearest = k;
                    }
                }
                clusterIndices[t] = nearest;
                totalObservationsInClusters[nearest]++;
                for (int d = 0; d < dimension; d++) {
                    sums[nearest][d] += x[t][d];
                }
            }

            // Update step. Under-populated clusters are re-seeded near the means of the most
            // populated clusters (largest first) instead of being averaged over too few samples.
            for (int k = 0; k < numClusters; k++) {
                sortKeys[k] = totalObservationsInClusters[k];
            }
            int[] order = MathUtils.quickSort(sortKeys, 0, numClusters - 1);
            int reseeded = 0;
            for (int k = 0; k < numClusters; k++) {
                double[] mean = clusters[k].meanVector;
                if (totalObservationsInClusters[k] >= params.minSamplesInOneCluster) {
                    for (int d = 0; d < dimension; d++) {
                        mean[d] = sums[k][d] / totalObservationsInClusters[k];
                    }
                } else {
                    double[] donor = clusters[order[numClusters - reseeded - 1]].meanVector;
                    for (int d = 0; d < dimension; d++) {
                        mean[d] = donor[d] + Math.random() * Math.abs(donor[d]) * 0.01;
                    }
                    reseeded++;
                }
            }

            iteration++;
            if (iteration > 1) {
                if (iteration >= params.maxIterations) {
                    keepGoing = false;
                }
                int changed = 0;
                for (int t = 0; t < numObservations; t++) {
                    if (previousIndices[t] != clusterIndices[t]) {
                        changed++;
                    }
                }
                // Floating-point division: with int/int the ratio truncates to 0 and the test
                // would stop training after the second iteration regardless of convergence.
                if ((double) changed / numObservations * 100.0 < params.minClusterChangePercent) {
                    keepGoing = false;
                }
            }
            System.arraycopy(clusterIndices, 0, previousIndices, 0, numObservations);
        }
    }

    /**
     * Estimates each non-empty cluster's covariance from its member observations around the
     * cluster mean. Entries are floored at {@code params.minCovarianceAllowed}, and the inverse is
     * stored as well.
     */
    private void estimateClusterCovariances(double[][] x, KMeansClusteringTrainerParams params) {
        for (int k = 0; k < params.numClusters; k++) {
            int count = totalObservationsInClusters[k];
            if (count <= 0) {
                continue;
            }
            int[] members = new int[count];
            int next = 0;
            for (int t = 0; t < x.length; t++) {
                if (clusterIndices[t] == k) {
                    members[next++] = t;
                }
            }

            Cluster cluster = clusters[k];
            if (params.isDiagonalOutputCovariance) {
                double[] variances = MathUtils.diagonal(
                        MathUtils.covariance(x, cluster.meanVector, true, members));
                for (int i = 0; i < variances.length; i++) {
                    variances[i] = Math.max(variances[i], params.minCovarianceAllowed);
                }
                System.arraycopy(variances, 0, cluster.covMatrix[0], 0, variances.length);
                cluster.invCovMatrix[0] = MathUtils.inverse(cluster.covMatrix[0]);
            } else {
                cluster.covMatrix = MathUtils.covariance(x, cluster.meanVector, true, members);
                // NOTE(review): this floors off-diagonal entries too, which turns negative
                // covariances positive; flooring only the diagonal may be the real intent.
                clampEntries(cluster.covMatrix, params.minCovarianceAllowed);
                cluster.invCovMatrix = MathUtils.inverse(cluster.covMatrix);
            }
        }
    }

    /**
     * Overwrites every cluster with fewer than {@code params.minSamplesInOneCluster} observations
     * with a copy of the most populated cluster's mean, covariance and inverse covariance.
     */
    private void replaceTinyClusters(KMeansClusteringTrainerParams params, int dimension) {
        int numClusters = params.numClusters;
        double[] sortKeys = new double[numClusters];
        for (int k = 0; k < numClusters; k++) {
            sortKeys[k] = totalObservationsInClusters[k];
        }
        Cluster largest = clusters[MathUtils.quickSort(sortKeys, 0, numClusters - 1)[numClusters - 1]];

        for (int k = 0; k < numClusters; k++) {
            if (totalObservationsInClusters[k] >= params.minSamplesInOneCluster) {
                continue;
            }
            Cluster target = clusters[k];
            System.arraycopy(largest.meanVector, 0, target.meanVector, 0, dimension);
            if (params.isDiagonalOutputCovariance) {
                System.arraycopy(largest.covMatrix[0], 0, target.covMatrix[0], 0, dimension);
                System.arraycopy(largest.invCovMatrix[0], 0, target.invCovMatrix[0], 0, dimension);
            } else {
                for (int row = 0; row < dimension; row++) {
                    System.arraycopy(largest.covMatrix[row], 0, target.covMatrix[row], 0, dimension);
                    System.arraycopy(largest.invCovMatrix[row], 0, target.invCovMatrix[row], 0, dimension);
                }
            }
        }
    }

    /**
     * Computes the covariance of the whole training set (diagonal-only or full, following
     * {@code params.isDiagonalOutputCovariance}), floors it at
     * {@code params.minCovarianceAllowed} and stores it with its inverse.
     */
    private void estimateGlobalCovariance(double[][] x, KMeansClusteringTrainerParams params) {
        if (params.isDiagonalOutputCovariance) {
            double[] variances = MathUtils.diagonal(MathUtils.covariance(x, true));
            for (int i = 0; i < variances.length; i++) {
                variances[i] = Math.max(variances[i], params.minCovarianceAllowed);
            }
            covMatrixGlobal = new double[][] {variances};
            invCovMatrixGlobal = new double[][] {MathUtils.inverse(variances)};
            return;
        }
        covMatrixGlobal = MathUtils.covariance(x);
        // NOTE(review): floors off-diagonal entries as well, as in the per-cluster estimate.
        clampEntries(covMatrixGlobal, params.minCovarianceAllowed);
        invCovMatrixGlobal = MathUtils.inverse(covMatrixGlobal);
    }

    /** Replaces every entry of {@code matrix} that is below {@code floor} by {@code floor}, in place. */
    private static void clampEntries(double[][] matrix, double floor) {
        for (double[] row : matrix) {
            for (int j = 0; j < row.length; j++) {
                row[j] = Math.max(row[j], floor);
            }
        }
    }

    /**
     * Returns the dimension of the cluster mean vectors.
     *
     * @return the feature dimension, or 0 if the trainer holds no clusters yet
     */
    public int getFeatureDimension() {
        if (clusters == null || clusters.length == 0 || clusters[0] == null
                || clusters[0].meanVector == null) {
            return 0;
        }
        return clusters[0].meanVector.length;
    }

    /**
     * Returns the number of trained clusters.
     *
     * @return the cluster count, or 0 before {@link #train} has been called
     */
    public int getTotalClusters() {
        return clusters == null ? 0 : clusters.length;
    }

    /**
     * Reports whether the clusters store diagonal-only covariances.
     *
     * @return the first cluster's covariance type, or {@code false} if there are no clusters
     */
    public boolean isDiagonalCovariance() {
        return clusters != null && clusters.length > 0 && clusters[0] != null
                && clusters[0].isDiagonalCovariance;
    }
}
