/*
 * Decompiled with CFR 0.152.
 */
package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureSequence;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Multinomial;
import edu.umass.cs.mallet.base.types.TokenSequence;
import edu.umass.cs.mallet.base.util.Random;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * A Dirichlet distribution parameterized by one concentration value
 * ({@code alphas[i]}) per dimension. An optional {@link Alphabet} may label
 * the dimensions.
 */
public class Dirichlet {
    /** Concentration parameter for each dimension. Arrays passed to constructors are stored, not copied. */
    double[] alphas;
    /** Alphabet labeling each dimension, or {@code null} when dimensions are unlabeled. */
    Alphabet dict;

    /**
     * Creates a Dirichlet with the given parameters and alphabet.
     *
     * <p>The {@code alphas} array is stored directly, not copied. When
     * {@code dict} is non-null its growth is stopped (presumably so it cannot
     * drift out of sync with {@code alphas}).
     *
     * @param alphas concentration parameters, one per dimension
     * @param dict   alphabet labeling the dimensions, or {@code null}
     * @throws IllegalArgumentException if {@code dict} is non-null and
     *         {@code dict.size() != alphas.length}
     */
    public Dirichlet(double[] alphas, Alphabet dict) {
        if (dict != null && alphas.length != dict.size()) {
            throw new IllegalArgumentException("alphas and dict sizes do not match.");
        }
        this.alphas = alphas;
        this.dict = dict;
        if (dict != null) {
            dict.stopGrowth();
        }
    }

    /**
     * Creates an unlabeled Dirichlet with the given parameters (stored, not copied).
     *
     * @param alphas concentration parameters, one per dimension
     */
    public Dirichlet(double[] alphas) {
        this(alphas, null);
    }

    /**
     * Creates a symmetric Dirichlet over {@code dict} with every alpha equal to 1.0.
     *
     * @param dict alphabet labeling the dimensions; its growth is stopped
     */
    public Dirichlet(Alphabet dict) {
        this(dict, 1.0);
    }

    /**
     * Creates a symmetric Dirichlet over {@code dict} with every alpha equal to {@code alpha}.
     *
     * @param dict  alphabet labeling the dimensions; its growth is stopped
     * @param alpha value assigned to every dimension
     */
    public Dirichlet(Alphabet dict, double alpha) {
        this(dict.size(), alpha);
        this.dict = dict;
        dict.stopGrowth();
    }

    /**
     * Creates an unlabeled symmetric Dirichlet of the given size with every alpha equal to 1.0.
     *
     * @param size number of dimensions
     */
    public Dirichlet(int size) {
        this(size, 1.0);
    }

    /**
     * Creates an unlabeled symmetric Dirichlet of the given size.
     *
     * @param size  number of dimensions
     * @param alpha value assigned to every dimension
     */
    public Dirichlet(int size, double alpha) {
        this.alphas = new double[size];
        Arrays.fill(this.alphas, alpha);
    }

    /** @return the alphabet labeling the dimensions, or {@code null} if none */
    public Alphabet getAlphabet() {
        return this.dict;
    }

    /** @return the number of dimensions */
    public int size() {
        return this.alphas.length;
    }

    /**
     * @param featureIndex dimension index
     * @return the concentration parameter for that dimension
     * @throws ArrayIndexOutOfBoundsException if the index is out of range
     */
    public double alpha(int featureIndex) {
        return this.alphas[featureIndex];
    }

    /**
     * Prints a "Dirichlet:" header followed by one {@code label=alpha} line per
     * dimension to standard output. The label is the alphabet entry when an
     * alphabet is present, otherwise the index.
     *
     * @throws IllegalStateException if {@code alphas} is {@code null}
     */
    public void print() {
        // Printing is only impossible when there is no parameter vector at all.
        if (this.alphas == null) {
            throw new IllegalStateException("alphas is null");
        }
        System.out.println("Dirichlet:");
        for (int j = 0; j < this.alphas.length; j++) {
            // Choose the label first so "=alpha" is appended in both cases.
            String label = this.dict != null ? this.dict.lookupObject(j).toString() : String.valueOf(j);
            System.out.println(label + "=" + this.alphas[j]);
        }
    }

    /**
     * Draws one probability vector from this Dirichlet. It samples
     * {@code r.nextGamma(alphas[i])} for every dimension and divides each
     * sample by their total.
     *
     * @param r source of randomness
     * @return a newly allocated vector of length {@link #size()}
     */
    protected double[] randomRawMultinomial(Random r) {
        double[] pr = new double[this.alphas.length];
        double sum = 0.0;
        for (int i = 0; i < pr.length; i++) {
            pr[i] = r.nextGamma(this.alphas[i]);
            sum += pr[i];
        }
        for (int i = 0; i < pr.length; i++) {
            pr[i] /= sum;
        }
        return pr;
    }

    /**
     * Draws one probability vector from this Dirichlet and wraps it in a
     * {@link Multinomial} that shares this distribution's alphabet.
     *
     * @param r source of randomness
     * @return the sampled multinomial
     */
    public Multinomial randomMultinomial(Random r) {
        return new Multinomial(this.randomRawMultinomial(r), this.dict, this.alphas.length, false, false);
    }

    /**
     * Draws a new Dirichlet whose parameters are a sampled probability vector
     * scaled so the alphas sum to {@code size() * averageAlpha}.
     *
     * @param r            source of randomness
     * @param averageAlpha desired mean of the new alphas
     * @return a new Dirichlet sharing this one's alphabet
     */
    public Dirichlet randomDirichlet(Random r, double averageAlpha) {
        double[] pr = this.randomRawMultinomial(r);
        double alphaSum = pr.length * averageAlpha;
        for (int i = 0; i < pr.length; i++) {
            pr[i] *= alphaSum;
        }
        return new Dirichlet(pr, this.dict);
    }

    /**
     * Samples a multinomial from this Dirichlet, then samples {@code length}
     * features from that multinomial.
     *
     * @param r      source of randomness
     * @param length number of features to draw
     * @return the sampled feature sequence
     */
    public FeatureSequence randomFeatureSequence(Random r, int length) {
        Multinomial m = this.randomMultinomial(r);
        return m.randomFeatureSequence(r, length);
    }

    /**
     * Builds a {@link FeatureVector} from a randomly sampled feature sequence.
     *
     * @param r    source of randomness
     * @param size number of features drawn for the underlying sequence
     * @return the resulting feature vector
     */
    public FeatureVector randomFeatureVector(Random r, int size) {
        return new FeatureVector(this.randomFeatureSequence(r, size));
    }

    /**
     * Samples a feature sequence and converts it to a {@link TokenSequence}. It
     * adds the object at each position in order.
     *
     * @param r      source of randomness
     * @param length number of tokens to draw
     * @return the sampled token sequence
     */
    public TokenSequence randomTokenSequence(Random r, int length) {
        FeatureSequence fs = this.randomFeatureSequence(r, length);
        TokenSequence ts = new TokenSequence(length);
        for (int i = 0; i < length; i++) {
            ts.add(fs.getObjectAtPosition(i));
        }
        return ts;
    }

    /**
     * Draws one probability vector from this Dirichlet as a raw array.
     *
     * @param r source of randomness
     * @return a newly allocated vector of length {@link #size()}
     */
    public double[] randomVector(Random r) {
        return this.randomRawMultinomial(r);
    }

    /**
     * Base class for estimating a Dirichlet from a collection of multinomials.
     * All multinomials are expected to have the same size and the same
     * Alphabet instance.
     */
    public static abstract class Estimator {
        /** The observed multinomials (raw list of {@link Multinomial}). */
        ArrayList multinomials;

        /** Creates an estimator with an empty list of multinomials. */
        public Estimator() {
            this.multinomials = new ArrayList();
        }

        /**
         * Creates an estimator over the given list, which is stored, not copied.
         *
         * @param multinomials list of {@link Multinomial}
         * @throws IllegalArgumentException if any two consecutive entries differ
         *         in size or do not share the same Alphabet instance
         */
        public Estimator(ArrayList multinomials) {
            this.multinomials = multinomials;
            for (int i = 1; i < multinomials.size(); i++) {
                Multinomial prev = (Multinomial) multinomials.get(i - 1);
                Multinomial curr = (Multinomial) multinomials.get(i);
                // Alphabets are compared by reference: all must share one instance.
                if (prev.size() != curr.size() || prev.getAlphabet() != curr.getAlphabet()) {
                    throw new IllegalArgumentException("All multinomials must have same size and Alphabet.");
                }
            }
        }

        /**
         * Appends a multinomial to the observations.
         * NOTE(review): unlike the list constructor, this does not check
         * size/Alphabet consistency.
         *
         * @param m multinomial to add
         */
        public void addMultinomial(Multinomial m) {
            this.multinomials.add(m);
        }

        /** @return the estimated Dirichlet */
        public abstract Dirichlet estimate();
    }

    /** Method-of-moments estimator (incomplete; see {@link #estimate()}). */
    public static class MethodOfMomentsEstimator
    extends Estimator {
        /**
         * Not yet implemented: always throws.
         *
         * <p>It sums every multinomial's probabilities (via
         * {@code addProbabilitiesTo}) into a zero-initialized vector and
         * normalizes that vector. The step that turns these normalized means
         * into alphas is missing. TODO: implement and return {@code d}.
         *
         * @throws IndexOutOfBoundsException if no multinomials have been added
         * @throws UnsupportedOperationException otherwise, always
         */
        public Dirichlet estimate() {
            int size = ((Multinomial) this.multinomials.get(0)).size();
            // Start from zero and include every multinomial, including the first.
            Dirichlet d = new Dirichlet(size, 0.0);
            for (int i = 0; i < this.multinomials.size(); i++) {
                ((Multinomial) this.multinomials.get(i)).addProbabilitiesTo(d.alphas);
            }
            double alphaSum = 0.0;
            for (int i = 0; i < d.alphas.length; i++) {
                alphaSum += d.alphas[i];
            }
            for (int i = 0; i < d.alphas.length; i++) {
                d.alphas[i] /= alphaSum;
            }
            throw new UnsupportedOperationException("Not yet implemented.");
        }
    }
}

