/*
 * Decompiled with CFR 0.152.
 */
package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.RankedFeatureVector;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.logging.Logger;

/**
 * Ranks the features of an {@link InstanceList} by a per-feature gain measured against a fixed set
 * of model label distributions (one {@link LabelVector} per instance).
 *
 * <p>For every (class {@code c}, feature {@code f}) pair a single weight {@code alpha} is fitted by
 * safeguarded Newton's method to maximize
 * {@code alpha * p[c][f] - sum_i log(m_ic * exp(alpha) + 1 - m_ic) - alpha^2 / (2 * gaussianPriorVariance)}.
 * <ul>
 * <li>{@code p[c][f]} is the summed true-label weight of class {@code c} over instances containing
 * {@code f}.</li>
 * <li>{@code m_ic} is the model's weight for {@code c} on instance {@code i}.</li>
 * <li>The sum runs over the instances containing {@code f}.</li>
 * </ul>
 * A feature's score is the sum, over classes, of the non-negative values of that objective at the
 * fitted {@code alpha}.
 */
public class ExpGain extends RankedFeatureVector {
    private static final Logger logger = MalletLogger.getLogger(ExpGain.class.getName());

    /** Newton iterations continue while the largest observed |gradient| exceeds this value. */
    private static final double NEWTON_GRADIENT_TOLERANCE = 1.0E-8;

    /** Hard cap on the number of Newton iterations. */
    private static final int MAX_NEWTON_STEPS = 50;

    // NOTE(review): the three hyperbolic-prior fields below are not read by any code in this class;
    // they are kept because code elsewhere in the package may reference them.
    boolean usingHyperbolicPrior = false;
    double hyperbolicSlope = 0.2;
    double hyperbolicSharpness = 10.0;

    /**
     * Computes one gain per feature of {@code ilist}'s data alphabet.
     *
     * @param ilist instances whose data are {@link FeatureVector}s with every value equal to 1.0
     *     (checked by assertion) and whose targets are {@link Labeling}s
     * @param classifications model label distributions, indexed like {@code ilist}; each is expected
     *     (by assertion) to use {@code ilist}'s target alphabet and to sum to 1
     * @param gaussianPriorVariance variance of the Gaussian penalty on each fitted alpha
     * @return an array of length {@code ilist.getDataAlphabet().size()} holding each feature's gain
     */
    private static double[] calcExpGains(InstanceList ilist, LabelVector[] classifications, double gaussianPriorVariance) {
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        assert (numInstances > 0);

        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        logger.info("Starting klgains, #instances=" + numInstances);
        accumulateLabelWeights(ilist, classifications, p, q);

        double[][] alphas = fitAlphas(ilist, classifications, p, q, numFeatures, gaussianPriorVariance);
        double[][] qeag = computeQeag(ilist, classifications, alphas, numFeatures);
        return sumNonNegativeGains(alphas, p, qeag, numFeatures, gaussianPriorVariance);
    }

    /** Adds each instance's true-label weight into p and model-label weight into q at every feature it contains. */
    private static void accumulateLabelWeights(InstanceList ilist, LabelVector[] classifications, double[][] p, double[][] q) {
        int numInstances = ilist.size();
        int numClasses = p.length;
        double trueLabelWeightSum = 0.0;
        double modelLabelWeightSum = 0.0;
        for (int i = 0; i < numInstances; i++) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = ilist.getInstance(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector) inst.getData();
            double perInstanceModelLabelWeight = 0.0;
            for (int li = 0; li < numClasses; li++) {
                double trueLabelWeight = labeling.value(li);
                double modelLabelWeight = classifications[i].value(li);
                trueLabelWeightSum += trueLabelWeight;
                modelLabelWeightSum += modelLabelWeight;
                perInstanceModelLabelWeight += modelLabelWeight;
                if (trueLabelWeight == 0.0 && modelLabelWeight == 0.0) {
                    continue; // both contributions would be zero
                }
                for (int fl = 0; fl < fv.numLocations(); fl++) {
                    int fli = fv.indexAtLocation(fl);
                    // The weighting below only makes sense for binary (0/1) features.
                    assert (fv.valueAtLocation(fl) == 1.0);
                    p[li][fli] += trueLabelWeight;
                    q[li][fli] += modelLabelWeight;
                }
            }
            assert (Math.abs(perInstanceModelLabelWeight - 1.0) < 0.001);
        }
        // Each instance should contribute a total weight of 1, so the per-instance averages should be 1.
        assert (Math.abs(trueLabelWeightSum / (double) numInstances - 1.0) < 0.001) : "trueLabelWeightSum/numInstances should be 1.0, it was " + (trueLabelWeightSum / (double) numInstances);
        assert (Math.abs(modelLabelWeightSum / (double) numInstances - 1.0) < 0.001) : "modelLabelWeightSum/numInstances should be 1.0, it was " + (modelLabelWeightSum / (double) numInstances);
    }

    /**
     * Fits one alpha per (class, feature) cell by Newton's method.
     *
     * <p>The method is safeguarded by a per-cell bracket, with bisection when a step leaves it and a
     * half step when consecutive steps nearly cancel. Cells with {@code p == 0 && q == 0} are never
     * updated and stay 0.
     */
    private static double[][] fitAlphas(InstanceList ilist, LabelVector[] classifications, double[][] p, double[][] q,
                                        int numFeatures, double gaussianPriorVariance) {
        int numClasses = p.length;
        double[][] alphas = new double[numClasses][numFeatures];
        double[][] dalphas = new double[numClasses][numFeatures];
        double[][] ddalphas = new double[numClasses][numFeatures];
        double[][] alphaChangeOld = new double[numClasses][numFeatures];
        // Per-cell bracket [alphaMin, alphaMax] for the maximizer; starts unbounded.
        double[][] alphaMax = filledMatrix(numClasses, numFeatures, Double.POSITIVE_INFINITY);
        double[][] alphaMin = filledMatrix(numClasses, numFeatures, Double.NEGATIVE_INFINITY);

        double maxDalpha = Double.POSITIVE_INFINITY; // forces at least one iteration
        for (int newton = 0; maxDalpha > NEWTON_GRADIENT_TOLERANCE && newton < MAX_NEWTON_STEPS; newton++) {
            // Gaussian-prior part of the gradient (dalphas) and second derivative (ddalphas).
            for (int c = 0; c < numClasses; c++) {
                for (int f = 0; f < numFeatures; f++) {
                    dalphas[c][f] = p[c][f] - alphas[c][f] / gaussianPriorVariance;
                    ddalphas[c][f] = -1.0 / gaussianPriorVariance;
                }
            }
            // Data part: each occurrence of feature f in instance i, with r = m*e^a / (m*e^a + 1 - m),
            // subtracts r from the gradient and adds r^2 - r to the second derivative.
            for (int i = 0; i < ilist.size(); i++) {
                assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
                FeatureVector fv = (FeatureVector) ilist.getInstance(i).getData();
                for (int fl = 0; fl < fv.numLocations(); fl++) {
                    int fli = fv.indexAtLocation(fl);
                    for (int li = 0; li < numClasses; li++) {
                        double modelLabelWeight = classifications[i].value(li);
                        double numerator = modelLabelWeight * Math.exp(alphas[li][fli]);
                        double denominator = numerator + (1.0 - modelLabelWeight);
                        dalphas[li][fli] -= numerator / denominator;
                        ddalphas[li][fli] += numerator * numerator / (denominator * denominator) - numerator / denominator;
                    }
                }
            }

            maxDalpha = 0.0;
            double maxAlphachange = 0.0;
            for (int c = 0; c < numClasses; c++) {
                for (int f = 0; f < numFeatures; f++) {
                    if (p[c][f] == 0.0 && q[c][f] == 0.0) {
                        continue; // no data for this cell; its alpha stays 0
                    }
                    double alphachange = -(dalphas[c][f] / ddalphas[c][f]);
                    if (Double.isNaN(alphas[c][f]) || Double.isNaN(alphachange)) {
                        logger.info("alpha[" + c + "][" + f + "]=" + alphas[c][f] + " p=" + p[c][f] + " q=" + q[c][f] + " dalpha=" + dalphas[c][f] + " ddalpha=" + ddalphas[c][f] + " alphachange=" + alphachange + " min=" + alphaMin[c][f] + " max=" + alphaMax[c][f]);
                    }
                    if (isNaNOrInfinite(alphas[c][f]) || isNaNOrInfinite(dalphas[c][f]) || isNaNOrInfinite(ddalphas[c][f])) {
                        alphachange = 0.0; // refuse to step on non-finite state
                    }
                    double oldalpha = alphas[c][f];
                    double newalpha;
                    if (Math.abs(alphachange + alphaChangeOld[c][f]) / Math.abs(alphachange) < 0.01) {
                        // This step almost exactly undoes the previous one (oscillation): take half of it.
                        newalpha = oldalpha + alphachange / 2.0;
                    } else {
                        newalpha = oldalpha + alphachange;
                    }
                    // ddalphas is -1/variance plus terms r^2 - r <= 0, so (for a positive variance) a
                    // Newton step points toward the maximizer; use its sign to tighten the bracket.
                    if (alphachange < 0.0 && alphaMax[c][f] > oldalpha) {
                        alphaMax[c][f] = oldalpha;
                    }
                    if (alphachange > 0.0 && alphaMin[c][f] < oldalpha) {
                        alphaMin[c][f] = oldalpha;
                    }
                    if (newalpha <= alphaMax[c][f] && newalpha >= alphaMin[c][f]) {
                        alphas[c][f] = newalpha;
                    } else {
                        // The Newton step left the bracket: bisect it instead.
                        assert (alphaMax[c][f] != Double.POSITIVE_INFINITY);
                        assert (alphaMin[c][f] != Double.NEGATIVE_INFINITY);
                        alphas[c][f] = alphaMin[c][f] + (alphaMax[c][f] - alphaMin[c][f]) / 2.0;
                    }
                    double appliedChange = alphas[c][f] - oldalpha;
                    // Plain comparisons (not Math.max) so NaN values never become the running maximum.
                    if (Math.abs(appliedChange) > maxAlphachange) {
                        maxAlphachange = Math.abs(appliedChange);
                    }
                    if (Math.abs(dalphas[c][f]) > maxDalpha) {
                        maxDalpha = Math.abs(dalphas[c][f]);
                    }
                    alphaChangeOld[c][f] = appliedChange;
                }
            }
            logger.info("After " + newton + " Newton iterations, maximum alphachange=" + maxAlphachange + " dalpha=" + maxDalpha);
        }
        return alphas;
    }

    /** Returns qeag[c][f] = sum over occurrences of f of log(m * exp(alphas[c][f]) + 1 - m), m = model weight of class c. */
    private static double[][] computeQeag(InstanceList ilist, LabelVector[] classifications, double[][] alphas, int numFeatures) {
        int numClasses = alphas.length;
        double[][] qeag = new double[numClasses][numFeatures];
        for (int i = 0; i < ilist.size(); i++) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            FeatureVector fv = (FeatureVector) ilist.getInstance(i).getData();
            for (int li = 0; li < numClasses; li++) {
                double modelLabelWeight = classifications[i].value(li);
                for (int fl = 0; fl < fv.numLocations(); fl++) {
                    int fli = fv.indexAtLocation(fl);
                    qeag[li][fli] += Math.log(modelLabelWeight * Math.exp(alphas[li][fli]) + (1.0 - modelLabelWeight));
                }
            }
        }
        return qeag;
    }

    /** Sums per feature, over classes with nonzero alpha, the non-negative values of alpha*p - qeag - alpha^2/(2*variance). */
    private static double[] sumNonNegativeGains(double[][] alphas, double[][] p, double[][] qeag, int numFeatures,
                                                double gaussianPriorVariance) {
        double[] klgains = new double[numFeatures];
        for (int c = 0; c < alphas.length; c++) {
            for (int f = 0; f < numFeatures; f++) {
                assert (!Double.isInfinite(alphas[c][f]));
                double alpha = alphas[c][f];
                if (alpha == 0.0) {
                    continue;
                }
                double klgainIncr = alpha * p[c][f] - qeag[c][f] - alpha * alpha / (2.0 * gaussianPriorVariance);
                if (klgainIncr < 0.0) {
                    continue; // negative contributions are dropped
                }
                klgains[f] += klgainIncr;
            }
        }
        return klgains;
    }

    /** Allocates a rows-by-cols matrix with every entry set to {@code value}. */
    private static double[][] filledMatrix(int rows, int cols, double value) {
        double[][] m = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                m[r][c] = value;
            }
        }
        return m;
    }

    /** True when {@code x} is NaN or +/-Infinity. */
    private static boolean isNaNOrInfinite(double x) {
        return Double.isNaN(x) || Double.isInfinite(x);
    }

    /**
     * Ranks the features of {@code ilist} against the given model label distributions.
     *
     * @param ilist instances with binary {@link FeatureVector} data and {@link Labeling} targets
     * @param classifications one model label distribution per instance of {@code ilist}, in order
     * @param gaussianPriorVariance variance of the Gaussian penalty on each fitted alpha
     */
    public ExpGain(InstanceList ilist, LabelVector[] classifications, double gaussianPriorVariance) {
        super(ilist.getDataAlphabet(), ExpGain.calcExpGains(ilist, classifications, gaussianPriorVariance));
    }

    /** Extracts the label vector of each classification, preserving order. */
    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] c) {
        LabelVector[] ret = new LabelVector[c.length];
        for (int i = 0; i < c.length; i++) {
            ret[i] = c[i].getLabelVector();
        }
        return ret;
    }

    /**
     * Same as {@link #ExpGain(InstanceList, LabelVector[], double)}, taking the label vectors from
     * {@code classifications}.
     */
    public ExpGain(InstanceList ilist, Classification[] classifications, double gaussianPriorVariance) {
        super(ilist.getDataAlphabet(), ExpGain.calcExpGains(ilist, ExpGain.getLabelVectorsFromClassifications(classifications), gaussianPriorVariance));
    }

    /** Builds {@link ExpGain} rankings against a fixed array of model label distributions. */
    public static class Factory implements RankedFeatureVector.Factory {
        private static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 10.0;
        private static final long serialVersionUID = 1L;
        // Version 0 streams omitted gaussianPriorVariance; version 1 appends it after the label vectors.
        private static final int CURRENT_SERIAL_VERSION = 1;

        LabelVector[] classifications;
        double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;

        /** Uses the default prior variance (10.0). */
        public Factory(LabelVector[] classifications) {
            this.classifications = classifications;
        }

        public Factory(LabelVector[] classifications, double gaussianPriorVariance) {
            this.classifications = classifications;
            this.gaussianPriorVariance = gaussianPriorVariance;
        }

        public RankedFeatureVector newRankedFeatureVector(InstanceList ilist) {
            assert (ilist.getTargetAlphabet() == this.classifications[0].getAlphabet());
            return new ExpGain(ilist, this.classifications, this.gaussianPriorVariance);
        }

        /** Java-serialization hook: writes version, label vectors, then the prior variance. */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(CURRENT_SERIAL_VERSION);
            out.writeInt(this.classifications.length);
            for (int i = 0; i < this.classifications.length; i++) {
                out.writeObject(this.classifications[i]);
            }
            out.writeDouble(this.gaussianPriorVariance);
        }

        /**
         * Java-serialization hook; reads both version 0 and version 1 streams.
         *
         * <p>Field initializers do not run during deserialization, so the variance must always be
         * assigned here or it would be left at 0.0, which would make the prior terms divide by zero.
         */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            int n = in.readInt();
            this.classifications = new LabelVector[n];
            for (int i = 0; i < n; i++) {
                this.classifications[i] = (LabelVector) in.readObject();
            }
            this.gaussianPriorVariance = version >= 1 ? in.readDouble() : DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
        }
    }
}

