/*
 * Decompiled with CFR 0.152.
 */
package tratz.parse.ml;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import tratz.parse.ml.AbstractParseModel;
import tratz.parse.ml.FinalizedParseModel;
import tratz.types.ChecksumMap;
import tratz.types.FloatArrayList;
import tratz.types.IntArrayList;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trainable perceptron parse model with lazily-maintained weight sums.
 *
 * <p>Each feature index owns an {@link Entry}. For every (feature, action) weight the entry keeps:
 * <ul>
 *   <li>{@code w}: the current weight;</li>
 *   <li>{@code w2}: the running sum of that weight over time;</li>
 *   <li>{@code c}: the value of {@link #count} when that sum was last brought up to date.</li>
 * </ul>
 * {@link #scoreIntermediate} scores with the current weights. {@link #score} scores with the
 * time-summed weights {@code w2 + w * (count - c)}, i.e. the un-normalized average used by an
 * averaged perceptron.
 */
public class TrainablePerceptron
extends AbstractParseModel {
    public static final long serialVersionUID = 1L;

    /** A progress line is printed to stderr each time the entry table size reaches a multiple of this. */
    private static final int ENTRY_LOG_INTERVAL = 100000;

    /**
     * Logical clock for lazy averaging.
     *
     * <p>It is advanced by {@link #incrementCount()} and stamped into {@link Entry#c} whenever a
     * weight slot is updated.
     */
    private int count;

    /** Per-feature weight storage, indexed by feature index (see {@link #getIndex}). */
    private List<Entry> mEntries = new ArrayList<Entry>();

    /** Maps feature strings to feature indices; {@code get} returns {@code Integer.MIN_VALUE} when absent. */
    private ChecksumMap<String> mFeatToInd;

    public TrainablePerceptron() {
    }

    /**
     * Creates an empty model over the given action inventory.
     *
     * @param actions action names; the list is copied
     */
    public TrainablePerceptron(List<String> actions) {
        this.mActions = new ArrayList<String>(actions);
        this.mActionToIndex = new HashMap();
        this.mFeatToInd = new ChecksumMap<String>();
        // Reserve feature index 0 with a placeholder key and a matching empty entry.
        // NOTE(review): presumably so that real features start at index 1 — confirm intent.
        this.mEntries.add(new Entry());
        this.mFeatToInd.put("blahblah**blah", 0);
    }

    /**
     * Looks up (and optionally registers) the index of a feature string.
     *
     * @param feat the feature string
     * @param add  whether to assign a fresh index if the feature is unknown
     * @return the feature's index, or {@code Integer.MIN_VALUE} if it is unknown and {@code add} is false
     */
    @Override
    public int getIndex(String feat, boolean add) {
        int index = this.mFeatToInd.get(feat);
        if (index == Integer.MIN_VALUE && add) {
            // New indices are dense: the next index equals the current map size.
            index = this.mFeatToInd.size();
            this.mFeatToInd.put(feat, index);
        }
        return index;
    }

    public ChecksumMap<String> getFeatToInd() {
        return this.mFeatToInd;
    }

    /** Advances the averaging clock by one step. */
    @Override
    public void incrementCount() {
        ++this.count;
    }

    /**
     * Adds {@code change} to the weight of ({@code feat}, {@code actionIndex}).
     *
     * @param actionIndex index of the action whose weight is updated
     * @param feat        feature index; the entry table grows as needed to hold it
     * @param change      amount to add to the weight
     */
    @Override
    public void updateFeature(int actionIndex, int feat, double change) {
        this.applyUpdate(actionIndex, feat, change);
    }

    /**
     * Adds {@code change} to the weight of {@code action} for every feature in {@code feats}.
     *
     * @param action action name; it is registered if not yet known
     * @param feats  feature indices to update
     * @param change amount to add to each weight
     */
    @Override
    public void update(String action, IntArrayList feats, double change) {
        int actionIndex = this.getActionIndex(action, true);
        int numFeats = feats.size();
        for (int i = 0; i < numFeats; ++i) {
            this.applyUpdate(actionIndex, feats.get(i), change);
        }
    }

    /**
     * Shared implementation of {@link #updateFeature} and {@link #update}.
     *
     * <p>It is private so that subclass overrides of {@code updateFeature} cannot alter {@code update}.
     */
    private void applyUpdate(int actionIndex, int feat, double change) {
        this.ensureEntryExists(feat);
        Entry entry = this.mEntries.get(feat);
        if (entry.w == null) {
            // First update for this feature: start in sparse (single-action) form.
            entry.w = new FloatArrayList(1);
            entry.w2 = new FloatArrayList(1);
            entry.c = new IntArrayList(1);
            entry.w2.add(0.0f);
            entry.w.add((float)change);
            entry.c.add(this.count);
            entry.classOne = actionIndex;
        } else if (entry.w.size() == 1) {
            if (entry.classOne == actionIndex) {
                this.accumulate(entry, 0, change);
            } else {
                // A second action touches this feature: switch to dense, action-indexed form.
                convertToDense(entry, actionIndex);
                this.accumulate(entry, actionIndex, change);
            }
        } else {
            ensureSlot(entry, actionIndex);
            this.accumulate(entry, actionIndex, change);
        }
    }

    /**
     * Grows {@link #mEntries} so that {@code feat} is a valid index.
     *
     * <p>The size is re-read on every iteration. This avoids the over-allocation caused by
     * caching a stale size across several features.
     */
    private void ensureEntryExists(int feat) {
        while (this.mEntries.size() <= feat) {
            this.mEntries.add(new Entry());
            if (this.mEntries.size() % ENTRY_LOG_INTERVAL == 0) {
                System.err.println("Entries: " + this.mEntries.size());
            }
        }
    }

    /** Pads the entry's three parallel lists with zeros so that {@code index} is a valid slot. */
    private static void ensureSlot(Entry entry, int index) {
        while (entry.w.size() <= index) {
            entry.w2.add(0.0f);
            entry.w.add(0.0f);
            entry.c.add(0);
        }
    }

    /**
     * Converts a sparse entry, whose single slot 0 belongs to {@code entry.classOne}, into a dense
     * entry indexed by action index.
     *
     * <p>After conversion the entry is large enough to hold both the old class and {@code actionIndex}.
     */
    private static void convertToDense(Entry entry, int actionIndex) {
        int oldClassOne = entry.classOne;
        double w = entry.w.get(0);
        double w2 = entry.w2.get(0);
        int c = entry.c.get(0);
        entry.classOne = -1;
        ensureSlot(entry, Math.max(actionIndex, oldClassOne));
        // Slot 0 held the sparse values: clear it, then move them to the old class's slot.
        // When oldClassOne == 0 this simply restores them.
        entry.w2.set(0, 0.0f);
        entry.w.set(0, 0.0f);
        entry.c.set(0, 0);
        entry.w2.set(oldClassOne, (float)w2);
        entry.w.set(oldClassOne, (float)w);
        entry.c.set(oldClassOne, c);
    }

    /**
     * Brings the slot's running sum up to the current clock, then adds {@code change} to its weight.
     *
     * <p>The sum is {@code w2 += w * (count - c)}. Arithmetic is done in double and narrowed once,
     * for both sparse and dense slots.
     */
    private void accumulate(Entry entry, int slot, double change) {
        double oldW = entry.w.get(slot);
        double elapsed = (double)(this.count - entry.c.get(slot));
        entry.w2.set(slot, (float)((double)entry.w2.get(slot) + oldW * elapsed));
        entry.w.set(slot, (float)(oldW + change));
        entry.c.set(slot, this.count);
    }

    /**
     * Scores each action by summing its current (non-summed) weights over {@code feats}.
     *
     * @param actions candidate action names
     * @param feats   feature indices; negative or unknown indices are ignored
     * @param indices output: action index per candidate, or -1 if the action is unknown
     * @param scores  output: score per candidate
     */
    @Override
    public final void scoreIntermediate(List<String> actions, IntArrayList feats, int[] indices, double[] scores) {
        this.scoreFeatures(actions, feats, indices, scores, false);
    }

    /**
     * Scores each action by summing its time-summed weights {@code w2 + w * (count - c)} over {@code feats}.
     *
     * @param actions candidate action names
     * @param feats   feature indices; negative or unknown indices are ignored
     * @param indices output: action index per candidate, or -1 if the action is unknown
     * @param scores  output: score per candidate
     */
    @Override
    public final void score(List<String> actions, IntArrayList feats, int[] indices, double[] scores) {
        this.scoreFeatures(actions, feats, indices, scores, true);
    }

    /** Shared scoring loop; {@code averaged} selects time-summed weights instead of current ones. */
    private void scoreFeatures(List<String> actions, IntArrayList feats, int[] indices, double[] scores, boolean averaged) {
        int numActions = actions.size();
        for (int a = 0; a < numActions; ++a) {
            Integer val = (Integer)this.mActionToIndex.get(actions.get(a));
            indices[a] = val == null ? -1 : val;
            scores[a] = 0.0;
        }
        int numFeats = feats.size();
        int numEntries = this.mEntries.size();
        for (int i = 0; i < numFeats; ++i) {
            int feat = feats.get(i);
            // Negative covers the Integer.MIN_VALUE "unknown feature" result of getIndex(feat, false).
            if (feat < 0 || feat >= numEntries) {
                continue;
            }
            Entry entry = this.mEntries.get(feat);
            int entrySize = entry.c == null ? 0 : entry.c.size();
            if (entrySize > 1) {
                // Dense entry: slot == action index.
                for (int a = 0; a < numActions; ++a) {
                    int actionIndex = indices[a];
                    if (actionIndex != -1 && actionIndex < entrySize) {
                        scores[a] += this.slotWeight(entry, actionIndex, averaged);
                    }
                }
            } else if (entrySize == 1) {
                // Sparse entry: slot 0 belongs to classOne; only the first matching candidate is credited.
                for (int a = 0; a < numActions; ++a) {
                    if (indices[a] == entry.classOne) {
                        scores[a] += this.slotWeight(entry, 0, averaged);
                        break;
                    }
                }
            }
        }
    }

    /** Returns the slot's current weight, or its time-summed weight when {@code averaged} is true. */
    private double slotWeight(Entry entry, int slot, boolean averaged) {
        double w = entry.w.get(slot);
        if (!averaged) {
            return w;
        }
        double w2 = entry.w2.get(slot);
        int c = entry.c.get(slot);
        return w2 + w * (double)(this.count - c);
    }

    /**
     * Builds a finalized model from the current state.
     *
     * @param amountToKeep passed through to {@link FinalizedParseModel}; its meaning is defined there
     * @return the finalized model
     */
    public FinalizedParseModel createFinal(double amountToKeep) {
        return new FinalizedParseModel(this.mActions, this.mActionToIndex, this.mPosPosActs, this.mFeatToInd, this.count, (ArrayList)this.mEntries, amountToKeep);
    }

    /**
     * Weight storage for one feature.
     *
     * <p>There are three states:
     * <ul>
     *   <li><b>Empty:</b> {@code w == null}; no update has been applied yet.</li>
     *   <li><b>Sparse:</b> lists of size 1; slot 0 holds the weight for action {@link #classOne}.</li>
     *   <li><b>Dense:</b> lists of size &gt; 1, indexed by action index; {@code classOne == -1}.</li>
     * </ul>
     */
    public static class Entry
    implements Serializable {
        public static final long serialVersionUID = 1L;
        /** Action owning slot 0 while sparse; -1 once dense. */
        public int classOne;
        /** Current weights. */
        public FloatArrayList w = null;
        /** Running sums of the weights over time, up to the clock value in {@link #c}. */
        public FloatArrayList w2 = null;
        /** Clock value at which each slot's running sum was last updated. */
        public IntArrayList c = null;
    }
}

