/*
 * Decompiled with CFR 0.152.
 */
package edu.umass.cs.mallet.base.fst;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.fst.TransducerEvaluator;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.types.Sequence;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.PrintStream;
import java.util.logging.Logger;

/**
 * Evaluates a {@link Transducer} by token accuracy: the fraction of sequence
 * positions, pooled over every instance of a list, at which the predicted
 * output element equals the target element.
 */
public class TokenAccuracyEvaluator
extends TransducerEvaluator {
    private static final Logger logger = MalletLogger.getLogger(TokenAccuracyEvaluator.class.getName());

    /** Accuracy computed by the most recent call to {@link #test}; 0.0 until the first call. */
    private double lastAccuracy;

    /**
     * Creates an evaluator.
     *
     * @param printViterbiPath value stored in the inherited {@code viterbiOutput} flag
     */
    public TokenAccuracyEvaluator(boolean printViterbiPath) {
        this.viterbiOutput = printViterbiPath;
    }

    /** Creates an evaluator with the inherited {@code viterbiOutput} flag set to {@code false}. */
    public TokenAccuracyEvaluator() {
        this(false);
    }

    /**
     * Logs the iteration number and cost. Then, if {@code shouldDoEvaluate} allows it for this
     * iteration, runs {@link #test} on each non-null list in the order training, validation,
     * testing. Every {@code test} call overwrites the stored accuracy. Afterwards,
     * {@link #getLastAccuracy()} therefore reports the last non-null list that was evaluated.
     *
     * @param crf the transducer being trained
     * @param finishedTraining whether training has finished (passed to {@code shouldDoEvaluate})
     * @param iteration current training iteration
     * @param converged not read by this implementation
     * @param cost current training cost, only logged
     * @param training training instances, or {@code null} to skip
     * @param validation validation instances, or {@code null} to skip
     * @param testing test instances, or {@code null} to skip
     * @return always {@code true}
     */
    public boolean evaluate(Transducer crf, boolean finishedTraining, int iteration, boolean converged, double cost, InstanceList training, InstanceList validation, InstanceList testing) {
        logger.info("Iteration=" + iteration + " Cost=" + cost);
        if (this.shouldDoEvaluate(iteration, finishedTraining)) {
            InstanceList[] lists = new InstanceList[]{training, validation, testing};
            String[] listnames = new String[]{"Training", "Validation", "Testing"};
            for (int k = 0; k < lists.length; k++) {
                if (lists[k] != null) {
                    // NOTE(review): the viterbiOutput flag set by the constructor is not consulted
                    // here, and no Viterbi stream is ever passed from this method. Confirm whether
                    // the superclass produces that output elsewhere.
                    this.test(crf, lists[k], listnames[k], null);
                }
            }
        }
        return true;
    }

    /**
     * Runs {@code model} on every instance in {@code data} and counts the tokens whose predicted
     * element equals the target element. Logs the mean of {@code model.getNstatesExpl()} and the
     * pooled token accuracy, then stores that accuracy for {@link #getLastAccuracy()}.
     *
     * <p>Each instance's data and target are cast to {@link Sequence}. With assertions enabled,
     * the input, target and prediction must all have the same length. If {@code data} contains
     * no tokens, the stored accuracy is NaN (0.0 / 0).
     *
     * @param model the transducer to evaluate
     * @param data instances whose data and target are {@link Sequence}s
     * @param description label prefixed to the logged results (e.g. "Testing")
     * @param viterbiOutputStream if non-null, receives one line per token in the form
     *        {@code target/predicted  input}
     */
    public void test(Transducer model, InstanceList data, String description, PrintStream viterbiOutputStream) {
        final int numInstances = data.size();
        // One entry per instance; these per-instance means are averaged again after the loop.
        double[] meanStatesExpl = new double[numInstances];
        int numCorrectTokens = 0;
        int totalTokens = 0;
        logger.info("Results for " + description);
        for (int i = 0; i < numInstances; i++) {
            Instance instance = data.getInstance(i);
            Sequence input = (Sequence) instance.getData();
            Sequence trueOutput = (Sequence) instance.getTarget();
            assert (input.size() == trueOutput.size());
            Sequence predOutput = model.transduce(input);
            assert (predOutput.size() == trueOutput.size());
            // Read immediately after transduce. Presumably getNstatesExpl reflects the most
            // recent transduce call -- TODO confirm.
            meanStatesExpl[i] = MatrixOps.mean(model.getNstatesExpl());
            for (int j = 0; j < trueOutput.size(); j++) {
                Object trueLabel = trueOutput.get(j);
                Object predLabel = predOutput.get(j);
                totalTokens++;
                if (trueLabel.equals(predLabel)) {
                    numCorrectTokens++;
                }
                if (viterbiOutputStream != null) {
                    viterbiOutputStream.println(trueLabel.toString() + '/' + predLabel.toString()
                            + "  " + input.get(j).toString());
                }
            }
        }
        double cMean = MatrixOps.mean(meanStatesExpl);
        logger.info("Mean states explored=" + cMean);
        // Accuracy is pooled over all tokens rather than averaged per instance.
        this.lastAccuracy = (double) numCorrectTokens / (double) totalTokens;
        logger.info(description + " accuracy=" + this.lastAccuracy);
    }

    /**
     * Returns the token accuracy computed by the most recent {@link #test} call.
     *
     * @return the last accuracy; 0.0 if {@code test} has never run, NaN if the last list had no tokens
     */
    public double getLastAccuracy() {
        return this.lastAccuracy;
    }
}

