package AceJet;

/**
 *   determines EDT class of words based on their context
 *   (not successful --- not used in final 03 system)
 */
 
import java.util.*;
import java.io.*;

import Jet.*;
import Jet.Lisp.FeatureSet;
import Jet.Pat.Pat;
import Jet.Tipster.*;
import Jet.Refres.Resolve;
 
import org.w3c.dom.*;
import org.xml.sax.*;
import javax.xml.parsers.*;
 
 class WordSense {

	/** the document currently being processed */
	static ExternalDocument doc;
	/** name of the current document as read from fileList (relative to ACEdir, without ".sgm") */
	static String currentDoc;
	static final String ACEdir =
	    "C:/Documents and Settings/Ralph Grishman/My Documents/ACE/";
	/** list of training documents, one document name per line */
	static final String fileList =
	    ACEdir + "training all.txt";
	/** output file for the per-word type / context statistics */
	static final String wordSenseFile = ACEdir + "word sense.txt";
	/** the first maxTrainingMentions mentions are used for training;
	 *  the remaining mentions are used to test the predictions */
	static final int maxTrainingMentions = 45000;

	/** token annotations of the current document */
	static Vector tokens;
	/** maps the start offset (Integer) of each token to its index (Integer) in tokens */
	static HashMap tokenStartMap;
	static DocumentBuilder builder;
	static PrintStream writer;
	/** maps a normalized head string to its WordSenseData record */
	static HashMap wordContextMap = new HashMap();
	static int trainingMentions = 0, correct = 0, incorrect = 0;

	/**
	 *  processes every document named in fileList:  runs Jet on the text,
	 *  reads the APF key, and trains / tests the word-sense model on the
	 *  document's mentions.  Finally writes the collected statistics to
	 *  wordSenseFile and prints the number of training mentions and the
	 *  number of correct / incorrect predictions.
	 */
	public static void main (String[] args)
	    throws IOException, ParserConfigurationException, SAXException {
		// initialize Jet
		System.out.println("Starting ACE Word Sense Training ...");
		JetTest.initializeFromConfig("word sense.properties");
		// new Jet.Console();
		Pat.trace = false;
		// initialize APF reader
		DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
		factory.setValidating(false);
		builder = factory.newDocumentBuilder();

		writer = new PrintStream (new FileOutputStream (wordSenseFile));
		BufferedReader reader = null;
		try {
			// open list of files
			reader = new BufferedReader (new FileReader(fileList));
			int docCount = 0;
			while ((currentDoc = reader.readLine()) != null) {
				// process file 'currentDoc'
				docCount++;
				System.out.println ("\nProcessing document " + docCount + ": " + currentDoc);
				// read document
				String textFile = ACEdir + currentDoc + ".sgm";
				doc = new ExternalDocument("sgml", textFile);
				doc.setAllTags(true);
				doc.open();
				// process document
				Control.processDocument (doc, null, false, docCount);
				// collect all tokens
				collectTokens();
				// read key file with mention information
				// (populate mentionSet and mentionStartMap)
				// Ace.countXmlChars(doc);
				String apfFileName = ACEdir + currentDoc + ".sgm.tmx.rdc.xml";
				AceDocument aceDoc = new AceDocument(textFile, apfFileName);
				LearnRelations.findEntityMentions (aceDoc);
				// train on / test against the document's mentions
				processMentions(doc);
			}
			writeWordContexts(writer);
		} finally {
			// release both streams even if processing fails part way
			try {
				if (reader != null)
					reader.close();
			} finally {
				// writeWordContexts already closes writer on success;  a second close is harmless
				writer.close();
			}
		}
		System.out.println (trainingMentions + " training mentions");
		System.out.println (correct + " correct predictions, " + incorrect + " incorrect");
	}

	/**
	 *  stores the token annotations of the current document in 'tokens' and
	 *  rebuilds tokenStartMap (token start offset -> index in 'tokens').
	 */
	static void collectTokens () {
		tokens = doc.annotationsOfType ("token");
		tokenStartMap = new HashMap();
		for (int i=0; i<tokens.size(); i++) {
			Annotation token = (Annotation) tokens.get(i);
			int start = token.start();
			tokenStartMap.put(new Integer(start), new Integer(i));
		}
	}
	/*
	static final int contextWindow = 1;

	static ArrayList context (Annotation w) {
		ArrayList context = new ArrayList();
		int start = w.start();
		int end = w.end();
		int len = tokens.size();
		Integer startTokenI = (Integer) tokenStartMap.get(new Integer(start));
		Integer nextTokenI = (Integer) tokenStartMap.get(new Integer(end));
		if (startTokenI == null || nextTokenI == null)
			return context;
		int startToken = startTokenI.intValue();
		int nextToken = nextTokenI.intValue();
		for (int i=Math.max(0,startToken-contextWindow); i<startToken; i++)
			context.add(doc.text((Annotation) tokens.get(i)).trim());
		for (int i=nextToken; i<Math.min(nextToken+contextWindow, len); i++)
			context.add(doc.text((Annotation) tokens.get(i)).trim());
		return context;
	}
	*/

	/**
	 *  returns the context features of mention 'w':  a single-element list,
	 *  "UNCOUNTABLE" if the mention has a "pa" feature set with no "det"
	 *  and the mention has no "poss" feature, otherwise "COUNTABLE".
	 */
	static ArrayList context (Annotation w) {
		ArrayList context = new ArrayList();
		FeatureSet pa = (FeatureSet) w.get("pa");
		boolean uncountable = pa != null &&
				pa.get("det") == null &&
			    // pa.get("number") == "singular" &&
			    w.get("poss") == null;
		context.add(uncountable?"UNCOUNTABLE":"COUNTABLE");
		return context;
	}

	/**
	 *  gathers the mentions of 'doc'.  For each mention whose head is not a
	 *  pronoun or determiner, it does one of two things:
	 *  <ul>
	 *  <li>while fewer than maxTrainingMentions have been used for training,
	 *      it records the head's EDT type and context;
	 *  <li>after that, it predicts the head's type and counts the prediction
	 *      as correct or incorrect.
	 *  </ul>
	 *  The EDT type comes from LearnRelations.mentionStartMap, looked up by
	 *  head start offset;  heads not found there are labeled "OTHER".
	 */
	static void processMentions (ExternalDocument doc) {
		// gather all mentions in document using Resolve.gatherMentions
		Vector mentions = Resolve.gatherMentions(doc, new Span(0, doc.length()));
		for (int imention=0; imention<mentions.size(); imention++) {
			Annotation mention = (Annotation) mentions.get(imention);
			Annotation head = Resolve.getHeadC(mention);
			String cat = (String) head.get("cat");
			// skip pronouns and determiners;  literal-first comparison so that a
			// head without a 'cat' feature does not throw NullPointerException
			if ("pro".equals(cat) || "det".equals(cat))
				continue;
			String headString = Resolve.normalizeName(doc.text(head));
			Mention apfMention =
				(Mention) LearnRelations.mentionStartMap.get(new Integer(head.start()));
			String EDTtype = "OTHER";
			if (apfMention != null)
				EDTtype = apfMention.type;
			// determine word context
			ArrayList contextWords = context(mention);
			boolean training = trainingMentions < maxTrainingMentions;
			if (training) {
				trainingMentions++;
				// retrieve / create the record for this word
				WordSenseData wt = (WordSenseData) wordContextMap.get(headString);
				if (wt == null) {
					wt = new WordSenseData(headString);
					wordContextMap.put(headString, wt);
				}
				wt.incrementTypeCount(EDTtype);
				for (int i=0; i<contextWords.size(); i++)
					wt.incrementContextCount (EDTtype, (String) contextWords.get(i));
			} else {
				// words never seen in training are not scored
				WordSenseData wt = (WordSenseData) wordContextMap.get(headString);
				if (wt == null) continue;
				int itype = wt.decodeType(contextWords);
				String prediction = WordSenseData.EDT_TYPES[itype];
				if (prediction.equals(EDTtype)) {
					correct++;
				} else {
					incorrect++;
					System.out.print   ("Word: " + headString);
					System.out.println (" predict " + prediction + ", should be " + EDTtype);
				}
			}
		}
	}

	/**
	 *  writes every WordSenseData record in wordContextMap to 'writer',
	 *  then closes 'writer'.
	 */
	static void writeWordContexts (PrintStream writer) {
		Iterator it = wordContextMap.values().iterator();
		while (it.hasNext()) {
			((WordSenseData) it.next()).write(writer);
		}
		writer.close();
	}

}

/**
 *  information about a word:  how frequently it appears as each EDT type, and
 *  what contexts are associated with each EDT type.
 */
 
class WordSenseData {

	static final String[] EDT_TYPES
		= {"OTHER", "PERSON", "ORGANIZATION", "GPE", "LOCATION", "FACILITY"};
	/** number of EDT types;  derived from EDT_TYPES so the two cannot diverge */
	static final int numTypes = EDT_TYPES.length;

	/** the word described by this record */
	String word;
	/** the only EDT type the word was observed with, or null if it was seen
	 *  with several (or no) types;  computed by write() */
	String uniqueType;
	/** total number of training occurrences of the word */
	int count;
	/** typeCount[i] = number of training occurrences of the word as EDT_TYPES[i] */
	int[] typeCount;
	/** typeContext[i] maps a context word to the number of times it was seen
	 *  when the word had type EDT_TYPES[i];  null until the first such context */
	HashMap[] typeContext;
	/** number of context observations equal to "COUNTABLE" (ccount) and of all
	 *  other context observations (ucount);  experimental, used by decodeType */
	int ccount = 0, ucount = 0;

	WordSenseData (String word) {
		this.word = word;
		typeCount = new int[numTypes];
		typeContext = new HashMap[numTypes];
	}

	/**
	 *  writes this record:  a WORD line and a COUNT line.  Then, if the word
	 *  was seen with a single type, a TYPE line;  otherwise, for each type, a
	 *  COUNT line followed by one CONTEXT line per context word.  Ends with END.
	 */
	void write (PrintStream writer) {
		writer.println("WORD " + word);
		writer.println("COUNT " + count);
		checkForUniqueType();
		if (uniqueType != null) {
			writer.println("TYPE " + uniqueType);
		} else {
			for (int i=0; i<numTypes; i++) {
				writer.println ("COUNT " + EDT_TYPES[i] + " " + typeCount[i]);
				if (typeContext[i] != null) {
					Iterator it = typeContext[i].entrySet().iterator();
					while (it.hasNext()) {
						Map.Entry entry = (Map.Entry) it.next();
						String contextWord = (String) entry.getKey();
						int contextCount = ((Integer) entry.getValue()).intValue();
						writer.println("CONTEXT " + EDT_TYPES[i] + " " + contextWord +
						               " " + contextCount);
					}
				}
			}
		}
		writer.println("END");
	}

	/** records one training occurrence of the word as EDT type 'type'. */
	void incrementTypeCount (String type) {
		count++;
		int i = typeIndex(type);
		typeCount[i]++;
	}

	/** records that 'contextWord' occurred with the word when it had EDT type 'type'. */
	void incrementContextCount (String type, String contextWord) {
		int i = typeIndex(type);
		if (typeContext[i] == null)
			typeContext[i] = new HashMap();
		Integer countI = (Integer) typeContext[i].get(contextWord);
		int n = (countI == null) ? 0 : countI.intValue();
		typeContext[i].put(contextWord, new Integer(n+1));
		// compare by value:  '==' only worked for interned literals
		if ("COUNTABLE".equals(contextWord)) ccount++; else ucount++;
	}

	/**
	 *  returns the index of 'type' in EDT_TYPES.  For an unknown type, prints
	 *  a message and returns 0 (OTHER).
	 */
	private static int typeIndex (String type) {
		for (int i=0; i<numTypes; i++) {
			if (type.equals(EDT_TYPES[i])) return i;
		}
		System.out.println ("typeIndex:  unexpected type " + type);
		return 0;
	}

	/** sets uniqueType to the only type with a non-zero count, if there is exactly one. */
	private void checkForUniqueType () {
		String type = null;
		for (int i=0; i<numTypes; i++) {
			if (typeCount[i] > 0) {
				if (type == null) {
					type = EDT_TYPES[i];
				} else {
					return;  // no unique type
				}
			}
		}
		uniqueType = type;
	}

	/**
	 *  predicts the EDT type of this word in a given context.
	 *  By default each type is scored by its relative frequency for this word
	 *  (typeCount / count).  The score changes when the word was seen more
	 *  than 5 times in training with the same countability as context.get(0).
	 *  In that case the score is the frequency of the context word with that
	 *  type, divided by count.  If the context holds several words, only the
	 *  last one determines that score.
	 *  @param context  list of context strings (possibly empty)
	 *  @return index into EDT_TYPES of the highest-scoring type (ties go to the
	 *          lower index);  0 (OTHER) if the word has no training data
	 */
	int decodeType (ArrayList context) {
		// without training data every score would be NaN and no type would be chosen
		if (count == 0)
			return 0;
		// k = training observations sharing this mention's countability class
		int k = 0;
		if (!context.isEmpty()) {
			String c = (String) context.get(0);
			k = "COUNTABLE".equals(c) ? ccount : ucount;
		}
		double bestProb = -1.;
		int bestType = -1;
		for (int itype=0; itype<numTypes; itype++) {
			double prob = (double) typeCount[itype] / count;

			for (int i=0; i<context.size(); i++) {
				String contextWord = (String) context.get(i);
				int contextCount = 0;
				if (typeContext[itype] != null) {
					Integer contextCountI = (Integer) typeContext[itype].get(contextWord);
					contextCount = (contextCountI == null) ? 0 : contextCountI.intValue();
				}
				// experimental:  with enough evidence, replace the type prior by the context frequency
				if (k>5) prob = (double) contextCount / count;
				// prob *= (double) (contextCount+1) / (typeCount[itype] + 1);
			}
			if (prob > bestProb) {
				bestProb = prob;
				bestType = itype;
			}
		}
		return bestType;
	}

}