/*
 * Decompiled with CFR 0.152.
 */
package tratz.ml;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import tratz.ml.ClassDictionary;
import tratz.ml.ClassScoreTuple;
import tratz.ml.FeatureDictionary;
import tratz.types.ChecksumMap;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A linear classifier holding one weight vector per class.
 *
 * <p>A set of feature names is scored against each class by summing the
 * weights stored at the indices the feature dictionary assigns to those
 * features. Field names are part of the serialized form and must not be
 * renamed.
 */
public class LinearClassificationModel
implements Serializable {
    public static final long serialVersionUID = 1L;
    /** Weight vectors, one per class; position {@code i} pairs with {@code mModelLabelOrder[i]}. */
    private final List<float[]> mModel;
    /** For each weight vector, the label index passed to {@code mLabelAlphabet.lookupLabel}. */
    private final int[] mModelLabelOrder;
    /** Dictionary used to turn label indices back into labels. */
    private final ClassDictionary mLabelAlphabet;
    /** Dictionary mapping features to their index inside each weight vector. */
    private final FeatureDictionary mAlphabet;

    /**
     * Creates a model over the given weight vectors and dictionaries.
     * The arguments are stored by reference, not copied.
     *
     * @param model           one weight vector per class
     * @param modelLabelOrder label index for each entry of {@code model}
     * @param labelAlphabet   dictionary resolving label indices to labels
     * @param alphabet        dictionary resolving features to weight-vector indices
     */
    public LinearClassificationModel(List<float[]> model, int[] modelLabelOrder, ClassDictionary labelAlphabet, FeatureDictionary alphabet) {
        this.mModel = model;
        this.mModelLabelOrder = modelLabelOrder;
        this.mLabelAlphabet = labelAlphabet;
        this.mAlphabet = alphabet;
    }

    /**
     * Builds a smaller model that keeps only the highest-weighted features.
     *
     * <p>Each feature is ranked by the sum of the absolute values of its
     * weights across all classes. Features whose sum is not greater than zero
     * are dropped outright; of the remaining ones, the top
     * {@code (1 - amountToRemove)} fraction (rounded down) is kept and
     * re-indexed in a fresh {@link FeatureDictionary}. The label order array
     * is shared with this model; the label dictionary is passed through
     * {@code ClassDictionary}'s copy constructor.
     *
     * @param amountToRemove fraction of the non-zero features to discard;
     *                       checked by an {@code assert} to lie strictly
     *                       between 0 and 1
     * @return the trimmed model
     */
    public LinearClassificationModel createTrimmedModel(double amountToRemove) {
        assert (amountToRemove < 1.0 && amountToRemove > 0.0);
        ClassDictionary labelAlphabet = new ClassDictionary(this.mLabelAlphabet);
        Map<ChecksumMap.TwoPartKey, Integer> featset = this.mAlphabet.getKeySet();
        int numClasses = this.mModel.size();

        // Total absolute weight of every feature that has any non-zero weight.
        final Map<ChecksumMap.TwoPartKey, Float> featToAbsoluteValue = new HashMap<ChecksumMap.TwoPartKey, Float>();
        List<ChecksumMap.TwoPartKey> featList = new ArrayList<ChecksumMap.TwoPartKey>();
        for (Map.Entry<ChecksumMap.TwoPartKey, Integer> entry : featset.entrySet()) {
            int featIndex = entry.getValue();
            float abs = 0.0f;
            for (int i = 0; i < numClasses; ++i) {
                abs += Math.abs(this.mModel.get(i)[featIndex]);
            }
            // Written as !(abs > 0) so that a NaN sum is skipped as well.
            if (!(abs > 0.0f)) {
                continue;
            }
            featToAbsoluteValue.put(entry.getKey(), Float.valueOf(abs));
            featList.add(entry.getKey());
        }

        // Largest total weight first. Every stored value is > 0 and non-NaN,
        // so Float.compare yields a consistent total order.
        Collections.sort(featList, new Comparator<ChecksumMap.TwoPartKey>() {

            @Override
            public int compare(ChecksumMap.TwoPartKey s1, ChecksumMap.TwoPartKey s2) {
                return Float.compare(featToAbsoluteValue.get(s2), featToAbsoluteValue.get(s1));
            }
        });

        FeatureDictionary alphabet = new FeatureDictionary();
        int featsToKeep = (int)((1.0 - amountToRemove) * (double)featList.size());

        // The old/new index of a feature does not depend on the class, so
        // resolve each pair once instead of once per class.
        int[] oldIndices = new int[featsToKeep];
        int[] newIndices = new int[featsToKeep];
        for (int j = 0; j < featsToKeep; ++j) {
            ChecksumMap.TwoPartKey feat = featList.get(j);
            newIndices[j] = alphabet.lookupIndex(feat, true);
            oldIndices[j] = this.mAlphabet.lookupIndex(feat, false);
        }

        List<float[]> model = new ArrayList<float[]>(numClasses);
        for (int i = 0; i < numClasses; ++i) {
            float[] oldvector = this.mModel.get(i);
            // One slot more than the number of kept features, as in the
            // original sizing. NOTE(review): presumably leaves room for
            // dictionary indices that do not start at 0 -- confirm against
            // FeatureDictionary.
            float[] vector = new float[featsToKeep + 1];
            for (int j = 0; j < featsToKeep; ++j) {
                vector[newIndices[j]] = oldvector[oldIndices[j]];
            }
            model.add(vector);
        }
        return new LinearClassificationModel(model, this.mModelLabelOrder, labelAlphabet, alphabet);
    }

    /**
     * Scores the given features against every class.
     *
     * <p>Features for which the dictionary lookup returns a negative index
     * are ignored. Each class's score is the sum of that class's weights at
     * the indices of the remaining features.
     *
     * @param features names of the features that are present
     * @return one tuple per class, sorted with {@link Arrays#sort(Object[])},
     *         i.e. in the order defined by {@code ClassScoreTuple} itself
     */
    public ClassScoreTuple[] getDecision(Set<String> features) {
        // Resolve feature names to weight indices, dropping unknown features.
        int[] featList = new int[features.size()];
        int numFeats = 0;
        for (String feat : features) {
            int index = this.mAlphabet.lookupIndex(feat, false);
            if (index < 0) {
                continue;
            }
            featList[numFeats++] = index;
        }

        int numClasses = this.mModel.size();
        ClassScoreTuple[] ranks = new ClassScoreTuple[numClasses];
        for (int i = 0; i < numClasses; ++i) {
            ClassScoreTuple tuple = new ClassScoreTuple(this.mLabelAlphabet.lookupLabel(this.mModelLabelOrder[i]), 0.0);
            tuple.score = calculateScore(featList, this.mModel.get(i), numFeats);
            ranks[i] = tuple;
        }
        Arrays.sort(ranks);
        return ranks;
    }

    /**
     * Sums the weights of the first {@code numFeats} feature indices.
     *
     * @param features feature indices; only the first {@code numFeats} are read
     * @param vector   weight vector of one class
     * @param numFeats number of valid entries in {@code features}
     * @return the accumulated score
     */
    private static float calculateScore(int[] features, float[] vector, int numFeats) {
        float score = 0.0f;
        for (int i = 0; i < numFeats; ++i) {
            score += vector[features[i]];
        }
        return score;
    }
}

