package AceJet;

/**
 *   determines EDT class and genericity of words from training corpus.
 *   Main program writes EDT type dictionary and generic dictionary.
 *   <p>
 *   (July 04) New version supports training from both data with
 *   subtypes (2004 data) and data without subtypes (earlier data).
 *   Counts without subtypes are represented by a subtype "*".
 *   "total" subtype is total of all subtypes but does NOT include
 *   subtype "*".
 */
 
import java.util.*;
import java.io.*;

import Jet.*;
import Jet.Lisp.FeatureSet;
import Jet.Pat.Pat;
import Jet.Tipster.*;
import Jet.Parser.SynFun;
import Jet.Refres.Resolve;
import Jet.Lex.EnglishLex;
import Jet.Chunk.Chunker;
 
public class EDTtype {

	/** document currently being processed by trainFromFileList. */
	static ExternalDocument doc;
	static final String ACEdir =
	    "C:/Documents and Settings/Ralph Grishman/My Documents/ACE/";
	static final String typeDictFile = ACEdir + "new EDT type dict.txt";
	static final String genericFile = ACEdir + "new generic dict.txt";
	/** if true, heads are lower-cased before being counted during training. */
	static final boolean monocase = true;

	/** token annotations of the current document (set by collectTokens). */
	static Vector tokens;
	/** maps start offset (Integer) of a token to its index in 'tokens'. */
	static HashMap tokenStartMap;
	static PrintStream writer, gwriter;
	/** maps a head word to its EDTtypeData record (sorted by word). */
	static TreeMap typeDataMap = new TreeMap();
	/** heads judged generic by readGenericDict;  null until a dictionary is read. */
	static TreeSet genericHeads;
	static int trainingMentions = 0, correct = 0, incorrect = 0, unknown = 0;

	/**
	 *  main uses the files on file list (the text files and corresponding
	 *  APF files) to gather statistics on EDT type as a function of NP
	 *  head, as well as statistics on generics.  It then writes the EDT type
	 *  dictionary and the generic dictionary and prints summary counts.
	 *
	 *  @throws IOException  if a file list or an output file cannot be opened
	 *                       or read
	 */

	public static void main (String[] args) throws IOException {
		// initialize Jet
		System.out.println("Starting ACE EDT Type / Generic Training ...");
		JetTest.initializeFromConfig("props/train EDT.properties");
		// new Jet.Console();
		Chunker.loadModel();
		Pat.trace = false;

		try {
			writer = new PrintStream (new FileOutputStream (typeDictFile));
			gwriter = new PrintStream (new FileOutputStream (genericFile));

			// train on old data
			AceDocument.ace2004 = false;
			trainFromFileList (ACEdir + "training all.txt");
			trainFromFileList (ACEdir + "feb02 all.txt");
			trainFromFileList (ACEdir + "sep02 all.txt");
			trainFromFileList (ACEdir + "aug03 all.txt");
			// train on new (2004) data
			AceDocument.ace2004 = true;
			trainFromFileList (ACEdir + "training04 nwire 21andup.txt");
			trainFromFileList (ACEdir + "training04 bnews 21andup.txt");
			trainFromFileList (ACEdir + "training04 chinese.txt");

			writeTypeDict(writer);
			writeGenericDict(gwriter);
		} finally {
			// make sure the output files are released even if training fails
			// (closing an already closed PrintStream is harmless)
			if (writer != null) writer.close();
			if (gwriter != null) gwriter.close();
		}
		EDTtypeData.reportSubtypeTotals ();
		System.out.println (trainingMentions + " training mentions");
		System.out.println (correct + " correct predictions, " + incorrect + " incorrect");
		System.out.println (unknown + " unknown");
	}

	/**
	 *  processes every document named in file 'fileList' (one document name
	 *  per line, relative to ACEdir and without extension):  the text is
	 *  run through Jet, the key (APF) file is read, and the mentions are
	 *  passed to processMentions.  Blank lines in the list are ignored.
	 *
	 *  @throws IOException  if the file list cannot be opened or read
	 */

	static void trainFromFileList (String fileList) throws IOException {
		// open list of files
		BufferedReader reader = new BufferedReader (new FileReader(fileList));
		try {
			String currentDoc;
			int docCount = 0;
			while ((currentDoc = reader.readLine()) != null) {
				// a blank line does not name a document
				if (currentDoc.trim().length() == 0)
					continue;
				// process file 'currentDoc'
				docCount++;
				System.out.println ("\nProcessing document " + docCount + ": " + currentDoc);
				// read document
				String textFileName = ACEdir + currentDoc + ".sgm";
				doc = new ExternalDocument("sgml", textFileName);
				doc.setAllTags(true);
				doc.open();
				// check document case
				Ace.monocase = Ace.allLowerCase(doc);
				// process document
				Control.processDocument (doc, null, false, docCount);
				// collect all tokens
				collectTokens();
				// read key file with mention information
				// (the name of the key file depends on the data set)
				String suffix = (currentDoc.startsWith("aug03") || currentDoc.startsWith("training04")) ?
				                ".apf.xml" : ".sgm.tmx.rdc.xml";
				String apfFileName = ACEdir + currentDoc + suffix;
				AceDocument aceDoc = new AceDocument(textFileName, apfFileName);
				LearnRelations.findEntityMentions (aceDoc);
				// gather statistics from the mentions of this document
				processMentions(doc);
			}
		} finally {
			reader.close();
		}
	}

	/**
	 *  records the token annotations of the current document in 'tokens',
	 *  and builds 'tokenStartMap', from token start offset to token index.
	 */

	static void collectTokens () {
		tokenStartMap = new HashMap();
		tokens = doc.annotationsOfType ("token");
		// NOTE(review): guards against annotationsOfType returning null for a
		// document without tokens -- confirm against the Jet Document API
		if (tokens == null)
			return;
		for (int i=0; i<tokens.size(); i++) {
			Annotation token = (Annotation) tokens.get(i);
			int start = token.start();
			tokenStartMap.put(new Integer(start), new Integer(i));
		}
	}

	/**
	 *  for each mention found by Resolve.gatherMentions whose head is not a
	 *  pronoun, determiner or name:  look up the key mention starting at the
	 *  same position as the head (type OTHER if there is none) and either
	 *  add it to the statistics in typeDataMap (the first 200000 mentions)
	 *  or use it to score the prediction made by getTypeSubtype.
	 */

	static void processMentions (ExternalDocument doc) {
		// gather all mentions in document using Resolve.gatherMentions
		Vector mentions = Resolve.gatherMentions(doc, new Span(0, doc.length()));
		// for each mention
		for (int imention=0; imention<mentions.size(); imention++) {
			//    look up in APF.mentionSet
			Annotation mention = (Annotation) mentions.get(imention);
			Annotation head = Resolve.getHeadC(mention);
			String cat = (String) head.get("cat");
			// (constant-first comparison tolerates a head with no 'cat' feature)
			if ("pro".equals(cat) || "det".equals(cat) || "name".equals(cat))
				continue;
			String headString = Resolve.normalizeName(doc.text(head));
			if (monocase)
				headString = headString.toLowerCase();
			Mention apfMention =
				(Mention) LearnRelations.mentionStartMap.get(new Integer(head.start()));
			//    if in, classify, else "other"
			String EDTtype = "OTHER";
			String EDTsubtype = "";
			if (apfMention != null) {
				EDTtype = apfMention.type;
				EDTsubtype = apfMention.subtype;
			}
			// data without subtypes is counted under subtype "*"
			if (!AceDocument.ace2004)
				EDTsubtype = "*";
			boolean training = trainingMentions < 200000; // true;
			if (training) {
				trainingMentions++;
				// retrieve / create WordType record
				EDTtypeData wt = (EDTtypeData) typeDataMap.get(headString);
				if (wt == null) {
					wt = new EDTtypeData(headString);
					typeDataMap.put(headString, wt);
				}
				wt.incrementTypeCount(EDTtype, EDTsubtype, 1);
				if (apfMention != null)
					wt.incrementGenericCount(apfMention.generic);
			} else {
				String prediction = bareType(getTypeSubtype(doc, null, mention));
				if (prediction.equals(EDTtype)) {
					correct++;
				} else {
					incorrect++;
					System.out.print   ("Word: " + headString);
					System.out.println (" predict " + prediction + ", should be " + EDTtype);
				}
			}
		}
	}

	/**
	 *  writes one line for each word in typeDataMap to 'writer', then
	 *  closes 'writer'.
	 */

	static void writeTypeDict (PrintStream writer) {
		Iterator it = typeDataMap.values().iterator();
		while (it.hasNext()) {
			((EDTtypeData) it.next()).write(writer);
		}
		writer.close();
	}

	/**
	 *  loads the type dictionary from the file named by configuration
	 *  property "Ace.EDTtype.fileName";  reports if no name is specified.
	 */

	public static void readTypeDict () {
		String fileName = JetTest.getConfigFile("Ace.EDTtype.fileName");
		if (fileName != null) {
			readTypeDict(fileName);
		} else {
			System.out.println ("EDTtype.readTypeDict:  no file name specified in config file");
		}
	}

	/**
	 *  replaces typeDataMap by the contents of file 'dictFile'.  Lines
	 *  rejected by EDTtypeData.readLine are skipped.  An I/O error is
	 *  reported on standard output and leaves the entries read so far.
	 */

	public static void readTypeDict (String dictFile) {
		System.out.println ("Loading type dictionary " + dictFile);
		typeDataMap = new TreeMap();
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(dictFile));
			String line;
			while ((line = reader.readLine()) != null) {
				EDTtypeData data = EDTtypeData.readLine(line);
				if (data != null)
					typeDataMap.put(data.word, data);
			}
			System.out.println ("Type dictionary loaded.");
		} catch (IOException e) {
			System.out.print("Unable to load dictionary due to exception: ");
			System.out.println(e);
		} finally {
			closeQuietly(reader);
		}
	}

	/**
	 *  writes a line "word | genericCount nonGenericCount" to 'writer' for
	 *  each word in typeDataMap which has at least one such count, then
	 *  closes 'writer'.
	 */

	static void writeGenericDict (PrintStream writer) {
		Iterator it = typeDataMap.values().iterator();
		while (it.hasNext()) {
			EDTtypeData td = (EDTtypeData) it.next();
			if (td.genericCount > 0 || td.nonGenericCount > 0)
				writer.println (td.word + " | " + td.genericCount +
				                " " + td.nonGenericCount);
		}
		writer.close();
	}

	/**
	 *  reads the generic dictionary 'dictFile' (lines of the form
	 *  "word | genericCount nonGenericCount") and sets genericHeads to the
	 *  words which are generic more often than not and which occur more
	 *  than twice in all.  A malformed line is reported and skipped, so
	 *  that one bad line does not discard the rest of the dictionary.
	 *  An I/O error is reported on standard output.
	 */

	public static void readGenericDict (String dictFile) {
		System.out.println ("Loading generic dictionary.");
		genericHeads = new TreeSet();
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(dictFile));
			String line;
			while ((line = reader.readLine()) != null) {
				int split = line.indexOf('|');
				// the tokenizer skips the blank which follows the '|'
				StringTokenizer st = (split <= 1) ? null :
				                     new StringTokenizer(line.substring(split+1));
				if (st == null || st.countTokens() < 2) {
					System.out.println ("** error in generic dict: " + line);
					continue;
				}
				// the term is followed by a blank, which is dropped
				String term = line.substring(0,split-1);
				int genericCount, nonGenericCount;
				try {
					genericCount = Integer.parseInt(st.nextToken());
					nonGenericCount = Integer.parseInt(st.nextToken());
				} catch (NumberFormatException e) {
					System.out.println ("** error in generic dict: " + line);
					continue;
				}
				if (genericCount > nonGenericCount && (genericCount + nonGenericCount) > 2)
					genericHeads.add(term);
			}
			System.out.println ("Generic dictionary loaded.");
		} catch (IOException e) {
			System.out.print("Unable to load dictionary due to exception: ");
			System.out.println(e);
		} finally {
			closeQuietly(reader);
		}
	}

	/**
	 *  closes 'reader' if it is not null, ignoring any error on close
	 *  (the data has already been read or its failure already reported).
	 */

	private static void closeQuietly (Reader reader) {
		if (reader == null) return;
		try {
			reader.close();
		} catch (IOException ignored) {
			// nothing useful can be done about a failure to close
		}
	}

	/**
	 *  returns true if the (normalized) head of 'mention' is in the set of
	 *  generic heads loaded by readGenericDict;  returns false if no
	 *  generic dictionary has been loaded.
	 */

	public static boolean hasGenericHead (ExternalDocument doc, Annotation mention) {
		if (genericHeads == null)
			return false;
		Annotation headC = Resolve.getHeadC (mention);
		String headWord = Resolve.normalizeName(doc.text(headC).trim());
		// NOTE(review): training lower-cases heads when 'monocase' is set, but
		// this lookup is case-sensitive -- confirm that this is intended
		return genericHeads.contains(headWord);
	}


	/** heads of partitive phrases ("group of X"), which take the type of X. */
	private static final String[] partitives = {"group", "part", "member", "portion",
	    "center", "bunch", "couple", "remainder", "rest", "lot", "percent", "%",
	    "dozen", "hundred", "thousand", /* also, 'a number of' but not 'the number of' */
	    "some", "either", "neither", "any", "each", "all", "both", "none",
	    "most", "many", "afew", "one", "q"};

	private static final String[] governmentTitles = {"Vice-President",
	    "Vice-Premier", "Prime-Minister", "Foreign-Minister",
	    "Foreign-Secretary", "Secretary-of-State", "Attorney-General",
	    "Justice-Minister", "Secretary-General"};

	/**
	 *  returns the EDT type of a mention:  PERSON, GPE, ORGANIZATION,
	 *  LOCATION, FACILITY, or OTHER (where OTHER indicates that it is not
	 *  an EDT mention).  Where a subtype is known, the value has the form
	 *  type:subtype.  Sources are tried in this order:  perfect mentions
	 *  (if enabled), the gazetteer, the name tagger, the 'of' complement of
	 *  a partitive, the hand-coded determiner-dependent nouns, and the EDT
	 *  type dictionary (inflected head, regularized head, plural).
	 *
	 *  @param doc      the document containing the mention
	 *  @param entity   the entity to which the mention belongs, or null
	 *  @param mention  the mention to be classified
	 */

	public static String getTypeSubtype (ExternalDocument doc, Annotation entity,
								  Annotation mention) {
		// the name-tagger branch below allows for a null head, so do not
		// dereference the head unconditionally here
		String rawHead = SynFun.getHead(doc, mention);
		String paHead = (rawHead == null) ? null : rawHead.toLowerCase();
		String det = SynFun.getDet(mention);
		boolean isHumanMention = SynFun.getHuman(mention);
		Annotation headC = Resolve.getHeadC (mention);
		// for perfect mentions
		if (Ace.perfectMentions) {
			String tsubt = PerfectAce.getTypeSubtype(headC);
			if (tsubt != null && !tsubt.equals(""))
				return tsubt;
			else
				System.out.println ("*** no type info for " + doc.text(headC));
		}
		// look up in gazetteer
		String gazetteerType = getGazetteerTypeSubtype (doc, mention);
		if (gazetteerType != null)
			return gazetteerType;
		String headWord = Resolve.normalizeName(doc.text(headC).trim());
		String name = SynFun.getName(doc, mention);
		String cat = (String) headC.get("cat");
		// for named mentions, use type assigned by name tagger
		if (name != null) {
			if (paHead != null && !paHead.equalsIgnoreCase("otherName")) {
				String type = paHead.toUpperCase();
				// return typeAndSubtype (type, EDTtypeData.bestSubtype(type));
				String subtype = NameSubtyper.classify(name, type);
				return typeAndSubtype (type, subtype);
			} else {
				return "OTHER";
			}
		}
		// for phrases such as "group of X", "part of X", "all of X",
		// use the type of phrase X.
		if (in(paHead, partitives) || "q".equals(cat)) {
			// 'of' complement may be at lower level of tree, so go down tree
			// until we find such a complement
			Annotation x = mention;
			while (x != null && x.get("of") == null)
				x = (Annotation) x.get("headC");
			if (x != null) {
				Annotation of = (Annotation) x.get("of");
				System.out.println ("Using computed type for " + paHead);
				String type = getTypeSubtype(doc, null, of);
				// special case:  parts of a GPE are a LOCATION
				if (bareType(type).equals("GPE") &&
				    ("part".equals(paHead) || "portion".equals(paHead)))
				   type = "LOCATION:Region-Subnational";
				// Ace.partitiveMap.put(new Integer(mention.span().start()), new Integer(of.span().start()));
				return type;
			}
		}
		// for pronouns, not in partitives, return "OTHER"
		// (this suppresses entities whose first mention is a pronoun)
		if ("pro".equals(cat) || "det".equals(cat) || "q".equals(cat))
			return "OTHER";
		// for some nouns, EDTtype depends on whether they appear with a
		// determiner;  handle these separately
		String type = handCodedEDTtype (det, headWord);
		if (type != null)
			return type;
		// for all other nouns, look head up in EDT type dictionary
		//    first use actual (inflected) head
		type = lookUpEDTtype(headWord.toLowerCase());
		if (type != null)
			return type.intern();
		//    then try with regularized head from PA structure
		type = lookUpEDTtype(paHead);
		if (type != null)
			return type.intern();
		// if there is no entry for singular form, check if plural form has entry
		if (paHead != null) {
			String[] singular = new String[1];
			singular[0] = paHead;
			String[] plural = EnglishLex.nounPlural(singular);
			type = lookUpEDTtype(plural[0]);
			if (type != null)
				return type.intern();
		}
		// if no entries at all, and entity has feature 'human' from Comlex,
		// treat as a person
		if (Ace.preferRelations) {
			if (isHumanMention || (entity != null && "t".equals(entity.get("human"))))
				return "PERSON";
		}
		unknown++;
		return "OTHER";
	}

	/**
	 *  returns a type for 'mention' based on the gazetteer:  PERSON for an
	 *  np whose head is a national or nationals term, GPE:Nation for any
	 *  other constituent whose head is a nationality, and null otherwise.
	 */

	private static String getGazetteerTypeSubtype (Document doc, Annotation mention) {
		String[] headTokens = Resolve.getHeadTokens(doc, mention);
		if ("np".equals(mention.get("cat"))) {
			// if (Resolve.isGenericNationals(doc, mention))
			// 	return "GPE:Nation";
			// else
			if (Ace.gazetteer.isNational(headTokens) || Ace.gazetteer.isNationals(headTokens))
				return "PERSON";
		} else {
			if (Ace.gazetteer.isNationality(headTokens))
				return "GPE:Nation";
		}
		return null;
	}

	/** maps nouns whose markability depends on a determiner to their type. */
	static HashMap specifiedEDTtype = new HashMap();
	static { // the following nouns are markable if they appear with
			 // a determiner
			 // (cf. "by force" and "by the American force")
			 specifiedEDTtype.put("force", "ORGANIZATION:Other");
			 // ("on board" vs. "on the board")
			 specifiedEDTtype.put("board", "ORGANIZATION:Commercial");
			 // ("in prison" vs. "in the prison")
	         specifiedEDTtype.put("prison", "FACILITY:Building");
	         specifiedEDTtype.put("room", "FACILITY:Subarea-Building");
	         specifiedEDTtype.put("home", "FACILITY:Building");
	         specifiedEDTtype.put("state", "GPE:State-or-Province");
	         specifiedEDTtype.put("land", "LOCATION:Region-National");
	         // 'minister' as a noun is markable
	         // 'Minister' as part of a title is not
	         specifiedEDTtype.put("minister", "PERSON");
	       }

	/**
	 *  for a noun in specifiedEDTtype, returns its listed type if it has a
	 *  determiner and "OTHER" if it does not;  returns null for any other
	 *  noun.  The lookup of 'head' is case-sensitive.
	 */

	static String handCodedEDTtype (String determiner, String head) {
		String type = (String) specifiedEDTtype.get(head);
		if (type == null) return null;
		if (determiner == null) return "OTHER";
		return type;
	}

	/**
	 *  returns the best type:subtype recorded for 'word' (lower-cased) in
	 *  the EDT type dictionary, or null if 'word' is null or has no entry.
	 */

	static String lookUpEDTtype (String word) {
		if (word == null) return null;
		EDTtypeData data = (EDTtypeData) typeDataMap.get(word.toLowerCase());
		if (data == null) return null;
		return data.getBestTypeSubtype();
	}

	/**
	 *  returns true if some non-null element of 'array' equals 'o'.
	 */

	private static boolean in (Object o, Object[] array) {
		for (int i=0; i<array.length; i++)
			if (array[i] != null && array[i].equals(o)) return true;
		return false;
	}

	/**
	 *  returns the type part of a "type:subtype" string (the entire string
	 *  if it contains no ':' after its first character).
	 */

	static String bareType (String typeSubtype) {
		int p = typeSubtype.indexOf(':');
		if (p > 0)
			return typeSubtype.substring(0,p);
		else
			return typeSubtype;
	}

	/**
	 *  returns the subtype part of a "type:subtype" string (the empty
	 *  string if it contains no ':' after its first character).
	 */

	static String subtype (String typeSubtype) {
		int p = typeSubtype.indexOf(':');
		if (p > 0)
			return typeSubtype.substring(p+1);
		else
			return "";
	}

	/**
	 *  returns the string "TYPE:subtype", with 'type' converted to upper case.
	 */

	static String typeAndSubtype (String type, String subtype) {
		return type.toUpperCase() + ":" + subtype;
	}

}

/**
 *  information about a word:  how frequently it appears as each EDT type, and
 *  how frequently it appears as generic or non-generic.
 *  <p>
 *  Counts are kept per type and subtype.  Counts from data without subtypes
 *  are kept under subtype "*";  the pseudo-subtype "total" holds the sum of
 *  the counts of all real subtypes of a type (not including "*").
 */

class EDTtypeData {

	static final String[] EDT_TYPES
		= {"OTHER", "PERSON", "ORGANIZATION", "GPE", "LOCATION", "FACILITY"};
	static final int numTypes = 6;
	/** counts over all words:  type -> (subtype -> Integer count). */
	static HashMap subtypeTotals = new HashMap();

	String word;
	String type = null;     // most frequent type (cached by determineBestType)
	String subtype = null;  // most frequent subtype (cached with 'type')
	int count = 0;
	HashMap typeCount;   // type -> (subtype -> Integer count) for this word
	int genericCount = 0, nonGenericCount = 0;

	EDTtypeData (String word) {
		this.word = word;
		typeCount = new HashMap();
	}

	/**
	 *  creates an EDTtypeData record from one line of the type dictionary.
	 *  A line has the form "word | label count label count ...", where each
	 *  label is either "type:subtype" or (for counts without subtype) "type".
	 *
	 *  @return  the record, or null (after reporting the line) if the line
	 *           has no '|' separator, has a label without a count, or has
	 *           a count which is not an integer
	 */

	static EDTtypeData readLine(String line) {
		int split = line.indexOf('|');
		if (split <= 1) {
			System.out.println ("** error in ace dict: " + line);
			return null;
		}
		// the term is followed by a blank, which is dropped
		String term = line.substring(0,split-1);
		// (the tokenizer skips the blank which follows the '|', so starting
		//  right after the '|' is safe even for a line which ends there)
		StringTokenizer st = new StringTokenizer(line.substring(split+1));
		int tokenCount = st.countTokens();
		if (tokenCount % 2 != 0) {
			System.out.println ("** error in ace dict: " + line);
			return null;
		}
		// parse the whole line before updating any count, so that a bad
		// line leaves no partial counts in the shared subtype totals
		int pairs = tokenCount / 2;
		String[] labels = new String[pairs];
		int[] counts = new int[pairs];
		try {
			for (int i=0; i<pairs; i++) {
				labels[i] = st.nextToken();
				counts[i] = Integer.parseInt(st.nextToken());
			}
		} catch (NumberFormatException e) {
			System.out.println ("** error in ace dict: " + line);
			return null;
		}
		EDTtypeData data = new EDTtypeData(term);
		for (int i=0; i<pairs; i++) {
			String typeSubtype = labels[i];
			int p = typeSubtype.indexOf(':');
			String type, subtype;
			if (p > 0) {
				type = typeSubtype.substring(0,p);
				subtype = typeSubtype.substring(p+1);
			} else {
				type = typeSubtype;
				subtype = "*";
			}
			data.incrementTypeCount (type, subtype, counts[i]);
		}
		return data;
	}

	/**
	 *  returns the most frequent type of the word, or null if it cannot
	 *  be determined.
	 */

	String getBestType () {
		if (type == null) determineBestType();
		return type;
	}

	/**
	 *  returns the subtype chosen along with the most frequent type.
	 */

	String getBestSubtype () {
		if (type == null) determineBestType();
		return subtype;
	}

	/**
	 *  returns the best type and subtype as a string "TYPE:subtype" (the
	 *  format produced by EDTtype.typeAndSubtype), or null if no best type
	 *  can be determined for the word.
	 */

	String getBestTypeSubtype () {
		String bestType = getBestType();
		if (bestType == null)
			return null;
		return bestType.toUpperCase() + ":" + getBestSubtype();
	}

	/**
	 *  writes the counts for the word as one line of the type dictionary
	 *  (the format read by readLine).  The "total" entries are not written,
	 *  since they are recomputed when the line is read.
	 */

	void write (PrintStream writer) {
		writer.print(word + " |");
		Iterator it = typeCount.keySet().iterator();
		while (it.hasNext()) {
			String type = (String) it.next();
			HashMap subMap = (HashMap) typeCount.get(type);
			Iterator it2 = subMap.keySet().iterator();
			while (it2.hasNext()) {
				String subtype = (String) it2.next();
				if ("total".equals(subtype)) continue;
				Integer cc = (Integer) subMap.get(subtype);
				if (subtype.equals("*")) {
					writer.print(" " + type + " " + cc.intValue());
				} else {
					writer.print(" " + type + ":" + subtype + " " + cc.intValue());
				}
			}
		}
		writer.println();
	}

	/**
	 *  adds 'incr' to the count of (type, subtype) for this word.  Unless
	 *  the subtype is "*", it also adds 'incr' to the "total" count of the
	 *  type for this word and to the count of (type, subtype) over all words.
	 */

	void incrementTypeCount (String type, String subtype, int incr) {
		incrementCount (typeCount, type, subtype, incr);
		if (!subtype.equals("*")) {
			incrementCount (typeCount, type, "total", incr);
			incrementCount (subtypeTotals, type, subtype, incr);
		}
		// the counts have changed, so a previously determined best type
		// and subtype may no longer be right:  force recomputation
		this.type = null;
		this.subtype = null;
	}

	/**
	 *  increments a counter in a two-level hash map
	 */

	static void incrementCount (HashMap map, String a, String b, int incr) {
		HashMap aMap = (HashMap) map.get(a);
		if (aMap == null) {
			aMap = new HashMap();
			map.put(a, aMap);
		}
		Integer cc = (Integer) aMap.get(b);
		int c = (cc == null) ? 0 : cc.intValue();
		c += incr;
		aMap.put(b, new Integer(c));
	}

	void incrementGenericCount (boolean generic) {
		if (generic) {
			genericCount++;
		} else {
			nonGenericCount++;
		}
	}

	/**
	 *  determines the best type and subtype for a word.  If there are
	 *  counts with subtypes, the type with the largest "total" is chosen,
	 *  with its most frequent subtype.  Otherwise the type with the largest
	 *  "*" count is chosen, with the subtype given by bestSubtype.  If the
	 *  word has no positive count at all, a message is printed and type
	 *  and subtype remain null.
	 */

	void determineBestType () {
		if (type != null) return;
		type = mostFrequentType("total");
		if (type == null) {
			// if there is no new (2004) data, use counts from old data
			type = mostFrequentType("*");
			if (type == null) {
				System.out.println ("EDTtypeData.determineBestType failed for " + word);
			} else {
				subtype = bestSubtype(type);
			}
			return;
		}
		// starting below zero ensures that some subtype is selected
		int bestCount = -1;
		HashMap subtypeCount = (HashMap) typeCount.get(type);
		Iterator it = subtypeCount.keySet().iterator();
		while (it.hasNext()) {
			String subtp = (String) it.next();
			if ("total".equals(subtp)) continue;
			if (subtp.equals("*")) continue;
			Integer cc = (Integer) subtypeCount.get(subtp);
			int c = (cc == null) ? 0 : cc.intValue();
			if (c > bestCount) {
				subtype = subtp;
				bestCount = c;
			}
		}
	}

	/**
	 *  returns the type of this word with the largest positive count
	 *  under subtype key 'countKey', or null if there is no such count.
	 */

	private String mostFrequentType (String countKey) {
		String best = null;
		int bestCount = 0;
		Iterator it = typeCount.keySet().iterator();
		while (it.hasNext()) {
			String tp = (String) it.next();
			Integer cc = (Integer) ((HashMap) typeCount.get(tp)).get(countKey);
			int c = (cc == null) ? 0 : cc.intValue();
			if (c > bestCount) {
				best = tp;
				bestCount = c;
			}
		}
		return best;
	}

	/**
	 *  prints, for each type, the count of each subtype over all words.
	 */

	static void reportSubtypeTotals () {
		Iterator it = subtypeTotals.keySet().iterator();
		while (it.hasNext()) {
			String type = (String) it.next();
			System.out.println ("For type: " + type);
			HashMap subMap = (HashMap) subtypeTotals.get(type);
			Iterator it2 = subMap.keySet().iterator();
			while (it2.hasNext()) {
				String subtype = (String) it2.next();
				Integer cc = (Integer) subMap.get(subtype);
				System.out.print(subtype + " " + cc.intValue() + " ");
			}
			System.out.println();
		}
	}

	/**
	 *  returns the subtype of 'type' which is most frequent over all
	 *  words.  If no subtype of 'type' has a positive count (in particular,
	 *  if no subtype of 'type' has been counted at all), returns "" for
	 *  types PERSON and OTHER and "Other" for any other type.
	 */

	static String bestSubtype (String type) {
		String best = null;
		int bestCount = 0;
		HashMap subMap = (HashMap) subtypeTotals.get(type);
		// there is no entry for a type seen only in data without subtypes
		if (subMap != null) {
			Iterator it2 = subMap.keySet().iterator();
			while (it2.hasNext()) {
				String subtype = (String) it2.next();
				Integer cc = (Integer) subMap.get(subtype);
				int c = cc.intValue();
				if (c > bestCount) {
					best = subtype;
					bestCount = c;
				}
			}
		}
		if (best != null) {
			return best;
		} else if (type.equals("PERSON") || type.equals("OTHER")) {
			return "";
		} else {
			return "Other";
		}
	}

}