/*
 * Decompiled with CFR 0.152.
 */
package tratz.ml;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import tratz.cmdline.CommandLineOptions;
import tratz.cmdline.CommandLineOptionsParser;
import tratz.cmdline.ParsedCommandLine;

/**
 * Command-line tool that ranks the features of an instances file (or of every regular file in an instances
 * directory) with a pluggable {@link SelectionMetric}, {@link ChiSquared} by default.
 *
 * <p>Each input line is split on the U+0018 (CAN) control character. Field 1 is the class label and fields 2
 * onward are feature strings; field 0 is not read here (presumably an instance identifier -- TODO confirm).
 * Lines with fewer than two fields are skipped. Files whose names end in {@code .gz} are gunzipped on the fly.
 *
 * <p>Each input is read twice. The first pass counts class labels and feature frequencies and keeps only the
 * {@link String#hashCode()} values of features seen at least {@code minfreq} times. The metric then makes a
 * second pass to collect per-class counts and scores for the surviving features. Because filtering is by hash
 * code, a rare feature whose hash collides with a frequent one can slip past the threshold.
 */
public class FeatureSelection {
    public static final String OPT_INPUT_DIR = "instances";
    public static final String OPT_OUTPUT_DIR = "output";
    public static final String OPT_MIN_FEATURE_FREQUENCY = "minfreq";
    public static final String OPT_NUMTHREADS = "numthreads";
    public static final String OPT_SELECTIONMETRIC = "metric";
    public static final int DEFAULT_NUMTHREADS = 1;
    public static final int DEFAULT_MINCOUNT = 0;
    public static final String DEFAULT_SELECTION_METRIC_CLASS = ChiSquared.class.getName();

    /** Separator between the fields of an instance line (the U+0018 control character). */
    private static final String FIELD_SEPARATOR = "\u0018";

    /** Number of instance lines read between progress messages on stderr. */
    private static final int PROGRESS_INTERVAL = 100000;

    /** Number of ranked features written between explicit flushes of the output. */
    private static final int FLUSH_INTERVAL = 10000;

    /** Declares the command-line options understood by {@link #main(String[])}. */
    private static CommandLineOptions createOptions() {
        CommandLineOptions cmdOptions = new CommandLineOptions();
        cmdOptions.addOption(OPT_INPUT_DIR, "file", "input file/directory containing instances file");
        cmdOptions.addOption(OPT_OUTPUT_DIR, "file", "output file/directory for the feature ranking files");
        cmdOptions.addOption(OPT_MIN_FEATURE_FREQUENCY, "integer", "minimum feature frequency");
        cmdOptions.addOption(OPT_SELECTIONMETRIC, "classname", "name of the feature selection class");
        // The second argument appears to label the value type ("file", "integer", ...), as for the options above.
        cmdOptions.addOption(OPT_NUMTHREADS, "integer", "number of threads to use (only applicable for directories)");
        return cmdOptions;
    }

    /**
     * Entry point.
     *
     * <p>If {@code instances} is a file, its ranking is written to the file {@code output}. Otherwise every
     * regular file in the {@code instances} directory is ranked into a same-named file under the {@code output}
     * directory, using {@code numthreads} worker threads (at least one).
     *
     * @param args command-line arguments; see {@link #createOptions()}
     * @throws Exception if option parsing fails, the metric class cannot be instantiated, or (single-file
     *         mode only) processing fails
     */
    public static void main(String[] args) throws Exception {
        ParsedCommandLine cmdLine = new CommandLineOptionsParser().parseOptions(FeatureSelection.createOptions(), args);
        File instances = new File(cmdLine.getStringValue(OPT_INPUT_DIR));
        final File output = new File(cmdLine.getStringValue(OPT_OUTPUT_DIR));
        final int minFeatFrequency = cmdLine.getIntegerValue(OPT_MIN_FEATURE_FREQUENCY, DEFAULT_MINCOUNT);
        String selectionMetricClass = cmdLine.getStringValue(OPT_SELECTIONMETRIC, DEFAULT_SELECTION_METRIC_CLASS);
        final SelectionMetric metric = (SelectionMetric) Class.forName(selectionMetricClass).newInstance();
        System.err.println("Instances: " + instances.getAbsolutePath());

        if (instances.isFile()) {
            if (output.isDirectory()) {
                System.err.println("Error: output must specify a file if instances specifies a file");
                System.exit(-1);
                return;
            }
            // Use the metric the user asked for; earlier versions always used ChiSquared here.
            FeatureSelection.processFile(instances, output, minFeatFrequency, metric);
            return;
        }

        File[] inputFiles = instances.listFiles();
        if (inputFiles == null) {
            // listFiles() returns null when the path does not exist or is not a readable directory.
            System.err.println("Error: instances is neither a file nor a readable directory: "
                    + instances.getAbsolutePath());
            System.exit(-1);
            return;
        }
        if (output.isFile()) {
            System.err.println("Error: output must specify a directory if instances specifies a directory");
            System.exit(-1);
            return;
        }

        // Shared work queue drained by the workers through getNext(); subdirectories are skipped.
        final Vector<File> pending = new Vector<File>();
        for (File candidate : inputFiles) {
            if (candidate.isFile()) {
                pending.add(candidate);
            }
        }
        int numThreads = Math.max(1, cmdLine.getIntegerValue(OPT_NUMTHREADS, DEFAULT_NUMTHREADS));
        for (int t = 0; t < numThreads; ++t) {
            new Thread() {
                @Override
                public void run() {
                    File infile;
                    while ((infile = FeatureSelection.getNext(pending)) != null) {
                        File outfile = new File(output, infile.getName());
                        try {
                            FeatureSelection.processFile(infile, outfile, minFeatFrequency, metric);
                        } catch (IOException ioe) {
                            System.err.println("Error with file: " + infile.getName());
                            ioe.printStackTrace();
                        } catch (Exception e) {
                            System.err.println("Error experienced during: " + infile.getName());
                            e.printStackTrace();
                        }
                    }
                }
            }.start();
        }
    }

    /**
     * Removes and returns the first file of {@code files}, or returns {@code null} when it is empty.
     *
     * <p>Safe to call from several threads. {@link Vector}'s methods synchronize on the vector itself, so
     * holding its monitor makes the emptiness check and the removal one atomic step.
     *
     * @param files the shared work queue
     * @return the next file to process, or {@code null} if none remain
     */
    public static File getNext(Vector<File> files) {
        synchronized (files) {
            return files.isEmpty() ? null : files.remove(0);
        }
    }

    /**
     * Ranks the features of one instances file and writes the ranking to {@code outfile}, creating the
     * file's parent directories if needed.
     *
     * <p>Output layout is tab-separated, and each of the first two lines ends with a tab:
     * <ul>
     *   <li>line 1 lists the class labels;</li>
     *   <li>line 2 gives the number of instances per class, in the same order;</li>
     *   <li>each following line is {@code feature, score, count_0, ..., count_k-1}, sorted by descending
     *       score and then by feature name.</li>
     * </ul>
     *
     * @param infile   instances file, gunzipped if its name ends in {@code .gz}
     * @param outfile  destination of the ranking
     * @param minCount minimum number of occurrences (counted per feature hash code) needed for a feature to be kept
     * @param metric   computes the per-class counts and the score of every kept feature
     * @throws IOException if reading {@code infile} or writing {@code outfile} fails
     */
    public static void processFile(File infile, File outfile, int minCount, SelectionMetric metric) throws IOException {
        System.err.println("Processing: " + infile.getName());
        long startTime = System.currentTimeMillis();

        Map<String, IntHolder> classToCount = new HashMap<String, IntHolder>();
        // The per-hash frequency map is local to the helper. It becomes garbage once the helper returns,
        // before the metric's second pass.
        Set<Integer> keepHashes = countClassesAndKeptHashes(infile, minCount, classToCount);

        File parent = outfile.getAbsoluteFile().getParentFile();
        if (parent != null) {
            parent.mkdirs();
        }
        PrintWriter writer = new PrintWriter(outfile);
        try {
            // Assign column indices in map iteration order and emit the two header lines in that same order.
            Map<String, Integer> classToIndex = new HashMap<String, Integer>();
            int[] classCounts = new int[classToCount.size()];
            int index = 0;
            for (Map.Entry<String, IntHolder> entry : classToCount.entrySet()) {
                classToIndex.put(entry.getKey(), index);
                classCounts[index] = entry.getValue().val;
                ++index;
                writer.print(entry.getKey());
                writer.print('\t');
            }
            writer.println();
            for (int count : classCounts) {
                writer.print(count);
                writer.print('\t');
            }
            writer.println();

            Map<String, Object> featToCounts = new HashMap<String, Object>();
            Map<String, Double> featToScore = new HashMap<String, Double>();
            metric.calculateCountsAndScores(infile, classToIndex, classCounts, keepHashes, featToCounts, featToScore);

            System.err.println("Writing counts");
            writeRanking(writer, featToCounts, featToScore, classCounts.length);
            // PrintWriter swallows IOExceptions. Surface them so a truncated ranking is not taken as complete.
            if (writer.checkError()) {
                throw new IOException("Error writing " + outfile.getAbsolutePath());
            }
        } finally {
            writer.close();
        }
        System.err.println((System.currentTimeMillis() - startTime) / 1000L + " seconds");
    }

    /**
     * First pass over {@code infile}.
     *
     * <p>Counts the instances of each class label into {@code classToCount}, and returns the hash codes of
     * features that occur at least {@code minCount} times.
     */
    private static Set<Integer> countClassesAndKeptHashes(File infile, int minCount,
            Map<String, IntHolder> classToCount) throws IOException {
        Map<Integer, IntHolder> hashToCount = new HashMap<Integer, IntHolder>();
        System.err.println("Reading and counting...");
        int l = 0;
        BufferedReader reader = openReader(infile);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] split = line.split(FIELD_SEPARATOR);
                // No class label. Length 0 occurs for separator-only lines, because split() drops trailing empties.
                if (split.length < 2) {
                    continue;
                }
                IntHolder classCount = classToCount.get(split[1]);
                if (classCount == null) {
                    classCount = new IntHolder(0);
                    classToCount.put(split[1], classCount);
                }
                ++classCount.val;
                for (int i = 2; i < split.length; ++i) {
                    Integer hash = split[i].hashCode();
                    IntHolder featCount = hashToCount.get(hash);
                    if (featCount == null) {
                        featCount = new IntHolder(0);
                        hashToCount.put(hash, featCount);
                    }
                    ++featCount.val;
                }
                if (++l % PROGRESS_INTERVAL == 0) {
                    System.err.println(l + " " + hashToCount.size());
                }
            }
        } finally {
            reader.close();
        }
        System.err.println("Number of examples: " + l);
        System.err.println("Keeping only hashes of features meeting the frequency threshold...");
        Set<Integer> keepHashes = new HashSet<Integer>();
        for (Map.Entry<Integer, IntHolder> entry : hashToCount.entrySet()) {
            if (entry.getValue().val >= minCount) {
                keepHashes.add(entry.getKey());
            }
        }
        return keepHashes;
    }

    /**
     * Writes the ranking body: one line per scored feature, giving its name, its score, then its count for
     * each of the {@code numClasses} classes. Lines are in descending score order, with ties broken by name.
     */
    private static void writeRanking(PrintWriter writer, Map<String, Object> featToCounts,
            final Map<String, Double> featToScore, int numClasses) {
        ArrayList<String> features = new ArrayList<String>(featToScore.keySet());
        Collections.sort(features, new Comparator<String>() {
            @Override
            public int compare(String s1, String s2) {
                // Double.compare is a total order even for NaN, which a pluggable metric could produce.
                int byScore = Double.compare(featToScore.get(s2), featToScore.get(s1));
                return byScore != 0 ? byScore : s1.compareTo(s2);
            }
        });
        int written = 0;
        for (String feat : features) {
            writer.print(feat);
            writer.print('\t');
            writer.print(featToScore.get(feat));
            // Counts are a primitive array (ChiSquared uses byte[], short[] or int[]). Array.getInt widens each element.
            Object counts = featToCounts.get(feat);
            for (int x = 0; x < numClasses; ++x) {
                writer.print('\t');
                writer.print(Array.getInt(counts, x));
            }
            writer.println();
            if (++written % FLUSH_INTERVAL == 0) {
                writer.flush();
            }
        }
    }

    /**
     * Opens {@code file} for line-oriented reading, gunzipping it when its name ends in {@code .gz}.
     * The caller must close the returned reader. If opening fails part-way, the underlying file is closed here.
     */
    private static BufferedReader openReader(File file) throws IOException {
        InputStream in = new BufferedInputStream(new FileInputStream(file));
        try {
            if (file.getName().endsWith(".gz")) {
                in = new GZIPInputStream(in);
            }
            return new BufferedReader(new InputStreamReader(in));
        } catch (IOException e) {
            try {
                in.close();
            } catch (IOException ignored) {
                // The original failure (e.g. a bad gzip header) is more informative than a close() failure.
            }
            throw e;
        }
    }

    /**
     * Placeholder metric.
     *
     * <p>Every feature whose hash code is in {@code keepHashes} gets a score of 1.0 and an all-zero
     * {@code int[]} of per-class counts.
     */
    public static class DummyMetric implements SelectionMetric {
        @Override
        public void calculateCountsAndScores(File infile, Map<String, Integer> classToIndex, int[] classCounts,
                Set<Integer> keepHashes, Map<String, Object> featToCounts, Map<String, Double> featToScore)
                throws IOException {
            int numClasses = classToIndex.size();
            int l = 0;
            BufferedReader reader = openReader(infile);
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(FIELD_SEPARATOR);
                    if (split.length < 2) {
                        continue;
                    }
                    for (int i = 2; i < split.length; ++i) {
                        String feat = split[i];
                        if (keepHashes.contains(feat.hashCode()) && !featToScore.containsKey(feat)) {
                            featToScore.put(feat, 1.0);
                            featToCounts.put(feat, new int[numClasses]);
                        }
                    }
                    if (++l % PROGRESS_INTERVAL == 0) {
                        System.err.println(l + "\t" + featToCounts.size());
                    }
                }
            } finally {
                reader.close();
            }
        }
    }

    /**
     * Scores each kept feature with the chi-squared statistic. The statistic compares the feature's observed
     * per-class counts with counts expected from the overall class distribution:
     * {@code sum_c (observed_c - expected_c)^2 / expected_c}, where
     * {@code expected_c = P(c) * (total occurrences of the feature)}.
     */
    public static class ChiSquared implements SelectionMetric {
        @Override
        public void calculateCountsAndScores(File infile, Map<String, Integer> classToIndex, int[] classCounts,
                Set<Integer> keepHashes, Map<String, Object> featToCounts, Map<String, Double> featToScore)
                throws IOException {
            int numClasses = classToIndex.size();
            int l = 0;
            BufferedReader reader = openReader(infile);
            try {
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] split = line.split(FIELD_SEPARATOR);
                    if (split.length < 2) {
                        continue;
                    }
                    // Assumes every label was already seen by the first pass; an unseen label throws NPE here.
                    int classIndex = classToIndex.get(split[1]);
                    for (int i = 2; i < split.length; ++i) {
                        if (keepHashes.contains(split[i].hashCode())) {
                            incrementCount(featToCounts, split[i], classIndex, numClasses);
                        }
                    }
                    if (++l % PROGRESS_INTERVAL == 0) {
                        System.err.println(l + "\t" + featToCounts.size());
                    }
                }
            } finally {
                reader.close();
            }

            double total = 0.0;
            for (int cCount : classCounts) {
                total += (double) cCount;
            }
            double[] classPercentages = new double[numClasses];
            for (int c = 0; c < numClasses; ++c) {
                classPercentages[c] = (double) classCounts[c] / total;
            }
            for (Map.Entry<String, Object> entry : featToCounts.entrySet()) {
                Object counts = entry.getValue();
                double posTotal = 0.0;
                for (int c = 0; c < numClasses; ++c) {
                    posTotal += (double) Array.getInt(counts, c);
                }
                double chiSquared = 0.0;
                for (int c = 0; c < numClasses; ++c) {
                    double observed = Array.getInt(counts, c);
                    double expected = classPercentages[c] * posTotal;
                    double diff = observed - expected;
                    chiSquared += diff * diff / expected;
                }
                featToScore.put(entry.getKey(), chiSquared);
            }
        }

        /**
         * Adds one to {@code feat}'s count for class {@code classIndex}.
         *
         * <p>A new feature starts with a {@code byte[]}. Once a counter reaches the element type's maximum, the
         * array is replaced by a {@code short[]}, and later by an {@code int[]}. This is presumably to keep
         * memory low for the many features with small counts.
         */
        private static void incrementCount(Map<String, Object> featToCounts, String feat, int classIndex,
                int numClasses) {
            Object counts = featToCounts.get(feat);
            if (counts == null) {
                counts = new byte[numClasses];
                featToCounts.put(feat, counts);
            }
            if (counts instanceof byte[]) {
                byte[] bytes = (byte[]) counts;
                if (++bytes[classIndex] == Byte.MAX_VALUE) {
                    short[] widened = new short[numClasses];
                    for (int x = 0; x < numClasses; ++x) {
                        widened[x] = bytes[x];
                    }
                    featToCounts.put(feat, widened);
                }
            } else if (counts instanceof short[]) {
                short[] shorts = (short[]) counts;
                if (++shorts[classIndex] == Short.MAX_VALUE) {
                    int[] widened = new int[numClasses];
                    for (int x = 0; x < numClasses; ++x) {
                        widened[x] = shorts[x];
                    }
                    featToCounts.put(feat, widened);
                }
            } else {
                ((int[]) counts)[classIndex]++;
            }
        }
    }

    /**
     * Strategy that collects per-class counts and scores for features during the second pass over an
     * instances file. Implementations are instantiated reflectively through a public no-argument constructor.
     */
    public static interface SelectionMetric {
        /**
         * Reads {@code infile} and fills {@code featToCounts} and {@code featToScore}.
         *
         * @param infile       the instances file (the implementations in this class gunzip names ending in {@code .gz})
         * @param classToIndex maps each class label to its position in the count arrays
         * @param classCounts  number of instances per class, indexed as in {@code classToIndex}
         * @param keepHashes   {@link String#hashCode()} values of features that met the frequency threshold
         * @param featToCounts output: for each feature, a primitive array with one count per class, readable
         *                     with {@link Array#getInt}. {@code processFile} expects an entry for every scored feature.
         * @param featToScore  output: score of each feature; higher scores are ranked first
         * @throws IOException if reading {@code infile} fails
         */
        void calculateCountsAndScores(File infile, Map<String, Integer> classToIndex, int[] classCounts,
                Set<Integer> keepHashes, Map<String, Object> featToCounts, Map<String, Double> featToScore)
                throws IOException;
    }

    /** Mutable int box, so map values can be incremented in place without re-boxing. */
    private static class IntHolder {
        int val;

        public IntHolder(int val) {
            this.val = val;
        }
    }
}

