/*
 * Decompiled with CFR 0.152.
 */
package tratz.parse.train;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import tratz.parse.NLParser;
import tratz.parse.ParseAction;
import tratz.parse.featgen.ParseFeatureGenerator;
import tratz.parse.ml.ParseModel;
import tratz.parse.train.DefaultPenaltyFunction;
import tratz.parse.train.PenaltyFunction;
import tratz.parse.train.PerSentenceTrainer;
import tratz.parse.types.Arc;
import tratz.parse.types.Token;
import tratz.parse.types.TokenPointer;
import tratz.types.IntArrayList;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Trains a {@link ParseModel} on one sentence at a time.
 *
 * <p>Each iteration scores every candidate action of every token still present in the
 * working token list. It then splits the candidates into valid and invalid actions using
 * the configured {@link PenaltyFunction} plus the swap rule described below. After that it
 * does one of two things:
 * <ul>
 *   <li>applies the highest-scoring valid action through {@link NLParser#performAction}. This
 *       happens when the lowest-scoring valid action already outscores every invalid action,
 *       or when more than {@code maxIterations} iterations have passed since the last
 *       applied action;</li>
 *   <li>otherwise, performs an additive weight update that raises the lowest-scoring valid
 *       action and lowers the highest-scoring invalid action.</li>
 * </ul>
 *
 * <p>Swap rule: a {@code SWAPRIGHT}/{@code SWAPLEFT} action is "unjustified" unless the token
 * two positions away in the swap direction is its gold head and the token already has as many
 * attached arcs as gold dependents. When a valid action of a different kind exists, an
 * unjustified swap is treated as invalid.
 */
public class StandardPerSentenceTrainer
implements PerSentenceTrainer {
    /** Default upper bound on the magnitude of a single weight update. */
    public static final double DEFAULT_MAX_UPDATE = 0.1;
    /** Default number of iterations without an applied action before one is forced. */
    public static final int DEFAULT_MAX_ITERATIONS = 10;

    /** Margin term added to the score difference when sizing an update (see performUpdate). */
    private static final double UPDATE_MARGIN = 1.0;

    private static final String SWAP_PREFIX = "SWAP";
    private static final String SWAP_RIGHT = "SWAPRIGHT";
    private static final String SWAP_LEFT = "SWAPLEFT";

    private final double mMaxUpdate;
    private final int mMaxIterations;
    private final PenaltyFunction mPenaltyFunction;

    /** Creates a trainer with the default limits and a {@link DefaultPenaltyFunction}. */
    public StandardPerSentenceTrainer() {
        this(DEFAULT_MAX_UPDATE, DEFAULT_MAX_ITERATIONS, new DefaultPenaltyFunction());
    }

    /**
     * Creates a trainer with explicit settings.
     *
     * @param maxUpdate upper bound on the size of a single weight update; must be positive
     * @param maxIterations number of iterations without an applied action that are tolerated
     *        before the best valid action is forced; must be non-negative
     * @param penaltyFunction decides which candidate actions are invalid; must not be null
     * @throws IllegalArgumentException if any argument is out of range or null
     */
    public StandardPerSentenceTrainer(double maxUpdate, int maxIterations, PenaltyFunction penaltyFunction) {
        if (!(maxUpdate > 0.0)) {
            throw new IllegalArgumentException("maxUpdate must be positive: " + maxUpdate);
        }
        if (maxIterations < 0) {
            throw new IllegalArgumentException("maxIterations must be non-negative: " + maxIterations);
        }
        if (penaltyFunction == null) {
            throw new IllegalArgumentException("penaltyFunction must not be null");
        }
        this.mMaxUpdate = maxUpdate;
        this.mMaxIterations = maxIterations;
        this.mPenaltyFunction = penaltyFunction;
    }

    /**
     * Maps feature strings to model feature indices.
     *
     * @param model the model that owns the feature index
     * @param fts the feature strings to look up
     * @param addFeats passed through to {@link ParseModel#getIndex}; presumably controls whether
     *        unseen features are registered (confirm against ParseModel)
     * @return the indices of all features whose lookup did not return {@code Integer.MIN_VALUE}
     */
    public IntArrayList getValues(ParseModel model, Set<String> fts, boolean addFeats) {
        IntArrayList values = new IntArrayList(fts.size());
        for (String feature : fts) {
            int index = model.getIndex(feature, addFeats);
            // Integer.MIN_VALUE is the "no index" result and is skipped.
            if (index != Integer.MIN_VALUE) {
                values.add(index);
            }
        }
        return values;
    }

    /**
     * Compares only the sizes of two arc lists (null counts as empty).
     *
     * <p>{@code topOfStack} is not consulted; it is kept for API compatibility.
     *
     * @return true when both lists contain the same number of arcs
     */
    public static boolean hasAllItsDependents(Token topOfStack, List<Arc> arcListFull, List<Arc> arcListWorking) {
        int numFull = arcListFull == null ? 0 : arcListFull.size();
        int numWorking = arcListWorking == null ? 0 : arcListWorking.size();
        return numFull == numWorking;
    }

    @Override
    public PerSentenceTrainer.TrainingResult train(List<Token> sentence, List[] goldArcs, Arc[] goldTokenToHead, ParseModel w, ParseFeatureGenerator featGen, Token[] tokenToSubcomponentHead, int[] projectiveIndices) throws Exception {
        boolean maxIterationsReached = false;
        boolean fatalError = false;
        int numModelActions = w.getActions().size();
        // Scratch buffers reused by every scoreIntermediate call.
        int[] indicesHolder = new int[numModelActions];
        double[] scores = new double[numModelActions];
        w.incrementCount();
        int numTokens = sentence.size();
        int numInvalids = 0;
        List[] currentArcs = new List[numTokens + 1];
        TokenPointer[] tokenToPtr = new TokenPointer[numTokens + 1];
        TokenPointer first = linkTokens(sentence, tokenToPtr);
        // Every token starts with no cached features or actions.
        boolean[] actionListStale = new boolean[numTokens + 1];
        for (int i = 0; i <= numTokens; ++i) {
            actionListStale[i] = true;
        }
        IntArrayList[] featureCache = new IntArrayList[numTokens + 1];
        HashMap<Token, List<ParseAction>> actionCache = new HashMap<Token, List<ParseAction>>();
        HashSet<String> feats = new HashSet<String>();
        int numIterations = 0;
        // "first != null" guards against an empty sentence (no pointer list at all).
        while (first != null && first.next != null) {
            ++numIterations;
            ParseAction highestScoredValidAction = null;
            ParseAction highestScoredInvalidAction = null;
            ParseAction lowestScoredValidAction = null;
            double lowestScoredValidActionScore = Double.POSITIVE_INFINITY;
            double highestScoredInvalidActionScore = Double.NEGATIVE_INFINITY;
            double highestScoredValidActionScore = Double.NEGATIVE_INFINITY;
            for (TokenPointer ptr = first; ptr != null; ptr = ptr.next) {
                Token token = ptr.tok;
                int index = token.getIndex();
                List<ParseAction> actions;
                if (actionListStale[index]) {
                    actions = rescoreActions(ptr, sentence, currentArcs, goldTokenToHead, w, featGen, feats, featureCache, indicesHolder, scores);
                    actionCache.put(token, actions);
                    actionListStale[index] = false;
                } else {
                    actions = actionCache.get(token);
                }
                for (ParseAction action : actions) {
                    int tokenIndex = action.token.getIndex();
                    double penalty = this.mPenaltyFunction.calculatePenalty(tokenToPtr[tokenIndex], action, goldTokenToHead, goldArcs, currentArcs, tokenToSubcomponentHead, projectiveIndices);
                    // An unjustified swap is invalid when a valid action of another kind exists.
                    boolean invalid = penalty > 0.0
                            || (lowestScoredValidAction != null
                                && isUnjustifiedSwap(action, tokenToPtr, goldTokenToHead, goldArcs, currentArcs)
                                && !lowestScoredValidAction.actionName.equals(action.actionName));
                    if (invalid) {
                        if (action.score > highestScoredInvalidActionScore) {
                            highestScoredInvalidAction = action;
                            highestScoredInvalidActionScore = action.score;
                        }
                        continue;
                    }
                    // A valid non-swap action demotes a currently-preferred unjustified swap.
                    if (lowestScoredValidAction != null
                            && !action.actionName.startsWith(SWAP_PREFIX)
                            && isUnjustifiedSwap(lowestScoredValidAction, tokenToPtr, goldTokenToHead, goldArcs, currentArcs)) {
                        if (lowestScoredValidAction.score > highestScoredInvalidActionScore) {
                            highestScoredInvalidAction = lowestScoredValidAction;
                            highestScoredInvalidActionScore = lowestScoredValidAction.score;
                        }
                        // Always forget the demoted action so it cannot linger as "valid".
                        lowestScoredValidAction = null;
                        highestScoredValidAction = null;
                        lowestScoredValidActionScore = Double.POSITIVE_INFINITY;
                        highestScoredValidActionScore = Double.NEGATIVE_INFINITY;
                    }
                    if (lowestScoredValidAction == null || action.score < lowestScoredValidActionScore) {
                        lowestScoredValidAction = action;
                        lowestScoredValidActionScore = action.score;
                    }
                    if (highestScoredValidAction == null || action.score > highestScoredValidActionScore) {
                        highestScoredValidAction = action;
                        highestScoredValidActionScore = action.score;
                    }
                }
            }
            if (lowestScoredValidAction == null) {
                System.err.println("ERROR: No Valid Action Found! Do cycles or multi-headed tokens exist? Moving on to next sentence...");
                printActionScores(first, featureCache, goldTokenToHead, w, indicesHolder, scores);
                fatalError = true;
                break;
            }
            boolean forced = numIterations > this.mMaxIterations;
            if (forced
                    || highestScoredInvalidAction == null
                    || highestScoredInvalidActionScore < lowestScoredValidAction.score) {
                if (forced) {
                    maxIterationsReached = true;
                }
                numIterations = 0;
                first = NLParser.performAction(sentence, first, tokenToPtr, highestScoredValidAction, actionListStale, featureCache, currentArcs, -1, featGen.getContextWidth());
                continue;
            }
            ++numInvalids;
            this.performUpdate(lowestScoredValidAction, highestScoredInvalidAction, first, actionListStale, featureCache, w);
        }
        return new PerSentenceTrainer.TrainingResult(maxIterationsReached, fatalError, numInvalids);
    }

    /** Builds the doubly linked TokenPointer list for the sentence and returns its head (null if empty). */
    private static TokenPointer linkTokens(List<Token> sentence, TokenPointer[] tokenToPtr) {
        TokenPointer first = null;
        TokenPointer prev = null;
        for (Token t : sentence) {
            TokenPointer ptr = new TokenPointer(t, null, prev);
            if (prev == null) {
                first = ptr;
            } else {
                prev.next = ptr;
            }
            tokenToPtr[t.getIndex()] = ptr;
            prev = ptr;
        }
        return first;
    }

    /** Regenerates features for ptr's token, caches them, and returns freshly scored candidate actions. */
    private List<ParseAction> rescoreActions(TokenPointer ptr, List<Token> sentence, List[] currentArcs, Arc[] goldTokenToHead, ParseModel w, ParseFeatureGenerator featGen, HashSet<String> feats, IntArrayList[] featureCache, int[] indicesHolder, double[] scores) {
        Token token = ptr.tok;
        featGen.genFeats(feats, w, sentence, ptr, currentArcs);
        IntArrayList tokenFeatures = this.getValues(w, feats, true);
        feats.clear();
        featureCache[token.getIndex()] = tokenFeatures;
        List<String> actionNames = w.getActions(token, ptr.next == null ? null : ptr.next.tok, goldTokenToHead);
        w.scoreIntermediate(actionNames, tokenFeatures, indicesHolder, scores);
        int numActions = actionNames.size();
        List<ParseAction> actions = new ArrayList<ParseAction>(numActions);
        for (int i = 0; i < numActions; ++i) {
            actions.add(new ParseAction(token, ptr, actionNames.get(i), scores[i]));
        }
        return actions;
    }

    /**
     * Returns true when action is a SWAPRIGHT/SWAPLEFT that is not justified. A swap is
     * justified when the token two positions away in the swap direction exists, is the gold
     * head of the moving token, and the moving token already has all its gold dependents.
     * Non-swap actions are never "unjustified swaps".
     */
    @SuppressWarnings("unchecked")
    private boolean isUnjustifiedSwap(ParseAction action, TokenPointer[] tokenToPtr, Arc[] goldTokenToHead, List[] goldArcs, List[] currentArcs) {
        boolean right = SWAP_RIGHT.equals(action.actionName);
        boolean left = SWAP_LEFT.equals(action.actionName);
        if (!right && !left) {
            return false;
        }
        int tokenIndex = action.token.getIndex();
        TokenPointer adjacent = right ? tokenToPtr[tokenIndex].next : tokenToPtr[tokenIndex].prev;
        TokenPointer newNeighbor = adjacent == null ? null : (right ? adjacent.next : adjacent.prev);
        return newNeighbor == null
                || !this.hasAllItsDependentsAndIsAMatch(action.token, newNeighbor.tok, goldTokenToHead[tokenIndex], goldTokenToHead, goldArcs[tokenIndex], currentArcs[tokenIndex]);
    }

    /** Dumps every remaining token's candidate actions and scores to stderr (diagnostics only). */
    private static void printActionScores(TokenPointer first, IntArrayList[] featureCache, Arc[] goldTokenToHead, ParseModel w, int[] indicesHolder, double[] scores) {
        for (TokenPointer p = first; p != null && p.tok != null; p = p.next) {
            Token tok = p.tok;
            IntArrayList tokenFeatures = featureCache[tok.getIndex()];
            List<String> actionNames = w.getActions(tok, p.next == null ? null : p.next.tok, goldTokenToHead);
            int numActions = actionNames.size();
            System.err.print(tok.getText() + " " + tok.getPos() + " ");
            w.scoreIntermediate(actionNames, tokenFeatures, indicesHolder, scores);
            for (int i = 0; i < numActions; ++i) {
                System.err.print(actionNames.get(i) + ":" + scores[i] + ", ");
            }
            System.err.println();
        }
    }

    /**
     * Marks every remaining token stale and applies an additive update. The step size is
     * (bad - good + margin) divided by the combined feature count of both actions, capped at
     * mMaxUpdate. It is added to goodAction's features and subtracted from badAction's.
     */
    private void performUpdate(ParseAction goodAction, ParseAction badAction, TokenPointer first, boolean[] actionListStale, IntArrayList[] featureCache, ParseModel w) {
        for (TokenPointer p = first; p != null; p = p.next) {
            actionListStale[p.tok.getIndex()] = true;
        }
        IntArrayList badFeatures = featureCache[badAction.token.getIndex()];
        IntArrayList goodFeatures = featureCache[goodAction.token.getIndex()];
        double denominator = badFeatures.size() + goodFeatures.size();
        double update = Math.min(this.mMaxUpdate, (badAction.score - goodAction.score + UPDATE_MARGIN) / denominator);
        w.update(badAction.actionName, badFeatures, -update);
        w.update(goodAction.actionName, goodFeatures, update);
    }

    /** True when tokenToMove's gold head is newNeighbor and its arc count matches its gold dependent count. */
    private boolean hasAllItsDependentsAndIsAMatch(Token tokenToMove, Token newNeighbor, Arc goldHeadArc, Arc[] goldTokenToHead, List<Arc> goldArcs, List<Arc> currentArcs) {
        return goldHeadArc != null && goldHeadArc.getHead() == newNeighbor && StandardPerSentenceTrainer.hasAllItsDependents(tokenToMove, goldArcs, currentArcs);
    }
}

