package AceJet;

/**
 *  Takes a set of pattern files,
 *    builds a statistical model from the files, and
 *    evaluates how well the model predicts relation types.
 */

import java.util.*;
import java.io.*;

class BuildRelationModel {
	
	/** the test pattern most recently read by predict() */
	static RelationPattern relPat;
	// NOTE(review): not referenced in this file; kept in case other classes use it
	static StringBuffer features = new StringBuffer();
	
	static final String ACEdir =
	    "C:/Documents and Settings/Ralph Grishman/My Documents/ACE/";
	static final String rootDir = ACEdir + "relations/";
	// predict() tests on the first testCorpusSize lines of this file; the
	// (commented-out) pipeline in main loads 'eve' from the same file,
	// skipping those lines
	static final String testPatternFile = rootDir + "patterns.log";
	static final String patternFile = rootDir + "patterns.log";
	static final String handPatternFile = ACEdir + "lisp/" + "patterns.log";
	static final String generalPatternFile = ACEdir + "lisp/" + "generalPatterns.log";
	
	/** pattern sets consulted, in this order, by predict() */
	static RelationPatternSet adam, eve, general;
	
	/** number of leading lines of testPatternFile used as the test corpus */
	static final int testCorpusSize = 2000;  // 400;
	
	/**
	 *  Currently a small check of prepositionOfLink / matchingRelations;
	 *  the full train-and-evaluate pipeline is commented out.
	 */
	public static void main (String[] args) throws IOException {

		String prep = RelationPattern.prepositionOfLink("s-verb-in");
		System.out.println (prep);
		boolean prepMatch = RelationPattern.matchingRelations("s-verb-in", "s-werb-in");
		System.out.println (prepMatch);
		/*		
		adam = new RelationPatternSet();
		adam.load(handPatternFile, 0 );
		eve = new RelationPatternSet();
		eve.load(patternFile, testCorpusSize);
		general = new RelationPatternSet();
		general.load(generalPatternFile, 0);
		buildProbModel (eve);
		predict();
		*/
	}
	
	/**
	 *  Relation labels used by the statistical model.  Entries after the first
	 *  two hold the relation type left-justified in a field of TYPELENGTH
	 *  characters, one space, and then the subtype.  Index TOTAL collects
	 *  statistics over all patterns; index NO_RELATION is used for type "0".
	 */
	static final String[] typeSubtype =
		{"total",
		 "no relation",
		 "PHYS      Located",
		 "PHYS      Near",
		 "PHYS      Part-Whole",
		 "PER-SOC   Business",
		 "PER-SOC   Family",
		 "PER-SOC   Other",
		 "EMP-ORG   Employ-Executive",
		 "EMP-ORG   Employ-Staff",
		 "EMP-ORG   Employ-Undetermined",
		 "EMP-ORG   Member-of-Group",
		 "EMP-ORG   Subsidiary",
		 "EMP-ORG   Partner",
		 "EMP-ORG   Other",
		 "ART       User-or-Owner",
		 "ART       Inventor-Manufacturer",
		 "ART       Other",
		 "OTHER-AFF Ethnic",
		 "OTHER-AFF Ideology",
		 "OTHER-AFF Other",
		 "GPE-AFF   Citizen-or-Resident",
		 "GPE-AFF   Based-In",
		 "GPE-AFF   Other",
		 "DISC      "};
	
	/** width of the type field at the start of each typeSubtype entry */
	private static final int TYPELENGTH = 9;
	private static final int TOTAL = 0;
	static final int NO_RELATION = 1;
	private static final int N_SUBTYPES = typeSubtype.length;
	
	// statistics gathered by buildProbModel
	private static int[] subtypeCount = new int[N_SUBTYPES];    // patterns per label
	private static int relationCount = 0;                       // patterns with a real relation
	private static final int MAX_SIZE = 25;                     // linear links this long are ignored
	private static int[] relationLengthCount = new int[MAX_SIZE]; // relation patterns per link length
	private static int[] lengthCount = new int[MAX_SIZE];         // all patterns per link length
	private static HashMap[] nonFinalWordCount = new HashMap[N_SUBTYPES]; // word -> count, non-final tokens
	private static HashMap[] finalWordCount = new HashMap[N_SUBTYPES];    // word -> count, final token
	
	/**
	 *  Builds the statistical model from the distinct patterns of 'rps'.
	 *  All previously gathered statistics are discarded first, so the model
	 *  can be rebuilt.  Patterns with an unknown type/subtype are reported and
	 *  skipped; patterns whose linear link is empty, at least MAX_SIZE long,
	 *  or starts with "0" contribute nothing.
	 */
	static void buildProbModel (RelationPatternSet rps) {
		Arrays.fill(subtypeCount, 0);
		Arrays.fill(relationLengthCount, 0);
		Arrays.fill(lengthCount, 0);
		relationCount = 0;
		for (int i=0; i<N_SUBTYPES; i++) {
			nonFinalWordCount[i] = new HashMap();
			finalWordCount[i] = new HashMap();
		}
		Iterator it = rps.iterator();
		while (it.hasNext()) {
			RelationPattern pattern = (RelationPattern) it.next();
			String type = pattern.relationType;
			String subType = pattern.relationSubtype;
			// typeSubtypeToIndex maps type "0" to NO_RELATION
			int itype = typeSubtypeToIndex(type, subType);
			if (itype < 0) {
				System.out.println ("*** unknown type/subtype" + type + ":" + subType);
				continue;
			}
			ArrayList lilink = pattern.linearLink;
			int size = lilink.size();
			if (size > 0 && size < MAX_SIZE && !lilink.get(0).equals("0")) {
				if (!type.equals("0")) {
					relationCount++;
					relationLengthCount[size]++;
				}
				lengthCount[size]++;
				subtypeCount[TOTAL]++;
				subtypeCount[itype]++;
				for (int i=0; i<size-1; i++) {
					String word = (String) lilink.get(i);
					incrementHashMap (nonFinalWordCount[TOTAL], word, 1);
					incrementHashMap (nonFinalWordCount[itype], word, 1);
				}
				String last = (String) lilink.get(size-1);
				incrementHashMap (finalWordCount[TOTAL], last, 1);
				incrementHashMap (finalWordCount[itype], last, 1);
			}
		}
	}
	
	/**
	 *  Returns the index in typeSubtype of the most probable label
	 *  (NO_RELATION or a relation subtype) for the linear link of 'relpat',
	 *  or -1 if no label gets a non-negative score (the link is empty, at
	 *  least MAX_SIZE long, or starts with "0").  buildProbModel must have
	 *  been called first.
	 */
	static int mostLikelySubtype (RelationInstance relpat) {
		int bestType = -1;
		double bestProb = -1;
		// every label except TOTAL is a candidate
		for (int iType=NO_RELATION; iType<N_SUBTYPES; iType++) {
			double prob = subtypeProb (iType, relpat.linearLink);
			if (prob > bestProb) {
				bestType = iType;
				bestProb = prob;
			}
		}
		if (bestType > 0)
			System.out.println (">>> Best stat. type = " + typeSubtype[bestType]);
		return bestType;
	}
	
	static final double NO_RELATION_BIAS = 0.2;
	static final double VOCAB_SIZE = 4000.;
	static final boolean trace = false;
	/** weight of the all-label word estimate used when a word is unseen for a label */
	private static final double BETA = 0.1;
		
	/**
	 *  Score of label 'iType' for 'linearLink':  the product of the word
	 *  probabilities times a label prior conditioned on the link length.
	 *  Returns -1 for links the model does not cover.  Ratios whose
	 *  denominator is zero (a link length, or relations, never seen in
	 *  training) are taken as 0 rather than producing NaN.
	 */
	private static double subtypeProb (int iType, ArrayList linearLink) {
		int size = linearLink.size();
		if (size == 0 || size >= MAX_SIZE || linearLink.get(0).equals("0"))
			return -1;	
		double prob = 1.0;
		for (int i=0; i<size-1; i++) {
			prob *= wordSubtypeProb (iType, (String) linearLink.get(i), false);
		}
		prob *= wordSubtypeProb (iType, (String) linearLink.get(size-1), true);
		// prob(some relation | length)
		double relationGivenLength = ratio(relationLengthCount[size], lengthCount[size]);
		if (iType == NO_RELATION) {
			// prob(no relation | length)
			double f1 = 1. - relationGivenLength;
			prob *= f1 * NO_RELATION_BIAS;
			if (trace) System.out.println ("No relation:  f1 = " + f1 + ", prob = " + prob);
		} else {
			// prob(relation itype | some relation)
			double f1 = ratio(subtypeCount[iType], relationCount);
			double f2 = relationGivenLength;
			prob *= f1 * f2;
			if (trace) System.out.println ("Relation " + typeSubtype[iType] + ":  f1 = " + f1 + 
					", f2 = " + f2 + ", prob = " + prob);
		}
		return prob;
	}

	/** numerator / denominator, or 0 when the denominator is 0 */
	private static double ratio (int numerator, int denominator) {
		return denominator == 0 ? 0. : (double) numerator / denominator;
	}
		
	/**
	 *  Estimated probability of 'word' (as the final token if 'last', else as
	 *  a non-final token) for label 'iType'.  Falls back to BETA times the
	 *  all-label estimate when the word was not seen with this label, and to
	 *  1/VOCAB_SIZE when it was never seen at all.
	 */
	private static double wordSubtypeProb (int iType, String word, boolean last) {
		HashMap[] counts = last ? finalWordCount : nonFinalWordCount;
		double prob;
		int ct = countOf(counts[iType], word);
		if (ct > 0) {
			prob = (double) ct / (double) subtypeCount[iType];
		} else {
			ct = countOf(counts[TOTAL], word);
			if (ct > 0)
				prob = BETA * (double) ct / (double) subtypeCount[TOTAL];
			else
				prob = 1. / VOCAB_SIZE;
		}
		if (trace)
			System.out.println ("P(" + word + "|" + typeSubtype[iType] + ")=" + prob);
		return prob;
	}
	
	/**
	 *  converts a type : subType to an index into typeSubtype;
	 *  type "0" maps to NO_RELATION, an unknown pair to -1.
	 */
	 
	private static int typeSubtypeToIndex (String type, String subType) {
		if (type.equals("0"))
			return NO_RELATION;
		for (int i=NO_RELATION+1; i<N_SUBTYPES; i++) {
			String entry = typeSubtype[i];
			String t = entry.substring(0, TYPELENGTH).trim();
			String sT = entry.substring(TYPELENGTH + 1);
			if (type.equals(t) && subType.equals(sT))
				return i;
		}
		return -1;
	}

	/** the Integer count stored under 'key', or 0 if there is none */
	private static int countOf (HashMap map, String key) {
		Integer count = (Integer) map.get(key);
		return count == null ? 0 : count.intValue();
	}
		
	/** adds 'n' to the Integer count stored under 'key' (absent counts as 0) */
	static void incrementHashMap (HashMap map, String key, int n) {
		map.put(key, Integer.valueOf(countOf(map, key) + n));
	}
		
	/**
	 *  Evaluates type prediction on the first testCorpusSize lines of
	 *  testPatternFile.  A match from adam is preferred, then eve, then
	 *  general; if none matches, the statistical model is used.  Mismatches
	 *  are printed, followed by counts, recall and precision.  Lines that do
	 *  not parse into a pattern (relationType "") are skipped.
	 *  NOTE(review): pattern relation types may carry a "-1" suffix (reversed
	 *  arguments), which the statistical fallback never predicts.
	 */
	private static void predict () throws IOException {
		int count = 0;
		int correct = 0;
		int spurious = 0;
		int missing = 0;
		int incorrect = 0;
		BufferedReader reader = new BufferedReader(new FileReader(testPatternFile));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				count++;
				if (count > testCorpusSize) break;
				relPat = new RelationPattern (line);
				if (relPat.relationType.equals("")) continue;
				RelationPattern match1 = adam.findMatch(relPat, 5);
				RelationPattern match2 = eve.findMatch(relPat, 21);
				RelationPattern match3 = general.findMatch(relPat, 5);			
				String predictedType;
				if (match1 != null) {
					predictedType = match1.relationType;
				} else if (match2 != null) {
					predictedType = match2.relationType;
				} else if (match3 != null) {
					predictedType = match3.relationType;
				} else {
					// no pattern matched:  fall back on the statistical model
					int i = mostLikelySubtype(relPat);
					if (i < 0 || i == NO_RELATION) {
						predictedType = "0";
					} else {
						// the relation type occupies the first TYPELENGTH characters
						predictedType = typeSubtype[i].substring(0, TYPELENGTH).trim();
					}
				}
				if (!relPat.relationType.equals(predictedType)) {
					System.out.println (line);
					System.out.println 
					    ("Correct type: " + relPat.relationType + " Predicted type:  " + predictedType);
					if (match1 != null)
					    System.out.println ("Best Adam pattern = " + match1.string);
					else if (match2 != null)
					    System.out.println ("Best corpus pattern = " + match2.string);
					else if (match3 != null)
						System.out.println ("Best gen'l pattern = " + match3.string);
					else
						System.out.println ("No pattern matched.");
				}
				if (relPat.relationType.equals("0")) {
					if (!predictedType.equals("0"))
						spurious++;
				} else {
					if (relPat.relationType.equals(predictedType))
						correct++;
					else if (predictedType.equals("0"))
						missing++;
					else
						incorrect++;
				}
			}
		} finally {
			reader.close();
		}
		System.out.println (correct + " correct predictions");
		System.out.println (spurious + " spurious");
		System.out.println (missing + " missing");
		System.out.println (incorrect + " incorrect");
		System.out.println ("Recall = " + ((float) correct) / (correct + incorrect + missing));
		System.out.println ("Precision = " + ((float) correct) / (correct + incorrect + spurious));
		System.out.println ("Value = " + (correct - spurious - missing - incorrect));
	}	
}

/**
 *  State shared by every candidate relation:  the relation type and subtype
 *  assigned to it and the connectives between its two arguments.
 *  Subclasses report the types of the two arguments.
 */
abstract class RelationInstance {

	/** type of the first argument */
	abstract String getType1();

	/** type of the second argument */
	abstract String getType2();

	String relationType = "";
	String relationSubtype = "";

	/** syntactic connective (for candidate relation) */
	String syntacticLink = "";

	/** series of linear connectives (for candidate relation); null until a subclass sets it */
	ArrayList linearLink;
}

/**
 *  A pattern and the associated relation type and subtype, as learned from
 *  the training corpus.  Patterns are ordered (compareTo) by the text of
 *  the line they were built from.
 */
 
class RelationPattern extends RelationInstance implements Comparable  {
	
	/** the line from which this pattern was built */
	String string;
	String mentionType1 = "   ", mentionSubtype1 = "", mentionHead1 = "";
	String mentionType2 = "   ", mentionSubtype2 = "", mentionHead2 = "";

  /**
   *  creates a RelationPattern from the representation in 'line'.
   *  Line has the form
   *    order type1 subtype1 head1 [ syntacticLink : linearLink... ] type2 subtype2 head2 --> relationType [relationSubtype]
   *  where order is "arg1-arg2" or "arg2-arg1" (the latter appends "-1" to
   *  relationType), a subtype of "*" stands for no subtype, and noise tokens
   *  (see noiseToken) are dropped from the linear link.
   *  If the line is malformed (a delimiter is missing, or the line ends
   *  early, e.g. a blank line) a message is printed and relationType is
   *  left as "".
   */
	RelationPattern (String line) {
		string = line;
		linearLink = new ArrayList();
		try {
			parse(line);
		} catch (NoSuchElementException e) {
			// the line ran out of tokens before the pattern was complete
			System.out.println ("Incomplete pattern line: " + line);
			relationType = "";
			relationSubtype = "";
		}
	}

	/**
	 *  fills in the fields of this pattern from the tokens of 'line'.
	 *  Returns early (leaving relationType "") if a delimiter is missing;
	 *  throws NoSuchElementException if the line has too few tokens.
	 */
	private void parse (String line) {
		StringTokenizer st = new StringTokenizer(line);
		String reverseFlag = st.nextToken();
		boolean reversed = false;
		if (reverseFlag.equals("arg2-arg1"))
			reversed = true;
		else if (!reverseFlag.equals("arg1-arg2"))
			System.out.println ("Unexpected value of reverseFlag: " + reverseFlag);
		mentionType1 = st.nextToken();
		mentionSubtype1 = subtypeToken(st.nextToken());
		mentionHead1 = st.nextToken();
		if (!st.nextToken().equals("[")) {
			System.out.println ("Cannot find [ in line: " + line);
			return;
		}
		syntacticLink = st.nextToken();
		if (!st.nextToken().equals(":")) {
			System.out.println ("Cannot find : in line: " + line);
			return;
		}
		String constit;
		while (!(constit = st.nextToken()).equals("]")) {
			if (!noiseToken(constit)) {  //<<<<< added Sep. 17
				linearLink.add(constit);
			}
		}
		mentionType2 = st.nextToken();
		mentionSubtype2 = subtypeToken(st.nextToken());
		mentionHead2 = st.nextToken();
		if (!st.nextToken().equals("-->")) {
			System.out.println ("Cannot find --> in line: " + line);
			return;
		}
		relationType = st.nextToken();
		if (st.hasMoreTokens())
			relationSubtype = st.nextToken();
		if (reversed)
			relationType += "-1";
	}

	/** maps the "*" placeholder to the empty subtype */
	private static String subtypeToken (String token) {
		return token.equals("*") ? "" : token;
	}
	
	/**
	 *  true for linear-link tokens that are ignored:  adv(..), timex(..),
	 *  q(..) and quotation marks.
	 */
	static boolean noiseToken (String token) {  // <<< added Sep. 17
		return // token.equals(",") ||
		       token.startsWith("adv(") ||
		       token.startsWith("timex(") ||
		       token.startsWith("q(") ||
		       token.equals("'") ||
		       token.equals("''") ||
		       token.equals("\"");
	}
	
	String getType1 () {
		return mentionType1;
	}
	
	String getType2 () {
		return mentionType2;
	}
	
	/**
	 *  Returns a distance between this pattern and 'ri' (a RelationPattern or
	 *  a RelationMention); smaller is closer and 100 means "no match".
	 *  <ul>
	 *  <li> 3 if any of the four mention types is shorter than 3 characters
	 *  <li> 100 if the syntactic link is "of" and the first heads differ
	 *       (neither being the wildcard "0")
	 *  <li> 100 if neither the syntactic links, their prepositions, nor the
	 *       linear links match
	 *  <li> otherwise a sum of penalties:  20 per argument whose types differ
	 *       in their first 3 characters and whose heads are not an exact
	 *       (non-wildcard) match; 8 per argument whose heads differ (unless
	 *       either is "0"); 1 unless the syntactic links are equal and not "0";
	 *       1 unless the links share a preposition (see matchingRelations);
	 *       2 unless the linear links are identical and not the "0" placeholder.
	 *  </ul>
	 */
	int distance (RelationInstance ri) {
		String type1, type2, head1, head2;
		if (ri instanceof RelationPattern) {
			RelationPattern rp = (RelationPattern) ri;
			type1 = rp.mentionType1;
			type2 = rp.mentionType2;
			head1 = rp.mentionHead1;
			head2 = rp.mentionHead2;
		} else {  // ri instance of RelationMention
			RelationMention rm = (RelationMention) ri;
			type1 = rm.mention1.type;
			type2 = rm.mention2.type;
			head1 = LearnRelations.getHead(rm.mention1);
			head2 = LearnRelations.getHead(rm.mention2);
		}			
		if (mentionType1.length() < 3 || mentionType2.length() < 3 ||
		    type1.length() < 3 || type2.length() < 3) {
			// NOTE(review): types too short to compare; this yields a small
			// distance rather than "no match" (100) -- confirm intended.
			return 3;
		}
		// a head of "0" acts as a wildcard
		boolean wildCard1 = mentionHead1.equals("0") || head1.equals("0");
		boolean wildCard2 = mentionHead2.equals("0") || head2.equals("0");
		boolean exactHeadMatch1 = mentionHead1.equals(head1) && !wildCard1;
		boolean exactHeadMatch2 = mentionHead2.equals(head2) && !wildCard2;
		// either types must match or heads must match
		boolean typeMatch1 = 
		       mentionType1.substring(0,3).equals(type1.substring(0,3)) || exactHeadMatch1;
		boolean typeMatch2 =
		       mentionType2.substring(0,3).equals(type2.substring(0,3)) || exactHeadMatch2;
		boolean arg1Match = mentionHead1.equals(head1) || wildCard1;
		boolean arg2Match = mentionHead2.equals(head2) || wildCard2;
		if (syntacticLink.equals("of") && !arg1Match)
			return 100;
		boolean syntaxMatch =
		       syntacticLink.equals(ri.syntacticLink) && !syntacticLink.equals("0");
		boolean prepMatch = matchingRelations(syntacticLink, ri.syntacticLink);
		boolean linearMatch = linearLinksMatch(ri.linearLink);
		// minimal condition:  at least one kind of link must match
		if (!(syntaxMatch || prepMatch || linearMatch))
			return 100;
		int dist = 0;
		if (!typeMatch1) dist += 20;
		if (!typeMatch2) dist += 20;
		if (!syntaxMatch) dist++;
		if (!prepMatch) dist++;
		if (!linearMatch) dist += 2;
		if (!arg1Match) dist += 8;
		if (!arg2Match) dist += 8;
		return dist;
	}

	/**
	 *  true if 'other' has exactly the same tokens as this pattern's linear
	 *  link and this link does not begin with the "0" placeholder.
	 */
	private boolean linearLinksMatch (ArrayList other) {
		if (linearLink.size() != other.size())
			return false;
		if (linearLink.size() > 0 && linearLink.get(0).equals("0"))
			return false;
		return linearLink.equals(other);
	}
	
	/** orders patterns by the text of the line they were built from */
	public int compareTo (Object x) {
		RelationPattern rp = (RelationPattern) x;
		return string.compareTo(rp.string);
	}
	
	/**
	 *  For a link beginning with "s-" or "o-", returns the text after its last
	 *  '-' (e.g. "in" for "s-verb-in"); returns null for any other link.
	 *  NOTE(review): a link with no '-' after the prefix (e.g. "s-verb")
	 *  yields the text after the prefix ("verb") -- confirm intended.
	 */
	static String prepositionOfLink (String syntacticLink) {
		if (syntacticLink.startsWith("s-") || syntacticLink.startsWith("o-")) {
			// always >= 1 here, since the prefix itself contains a '-'
			int dash = syntacticLink.lastIndexOf('-');
			return syntacticLink.substring(dash+1);
		}
		return null;
	}
	
	/** true if both links have a preposition (see prepositionOfLink) and it is the same */
	static boolean matchingRelations (String syntacticLink1, String syntacticLink2) {
		String prep1 = prepositionOfLink(syntacticLink1);
		String prep2 = prepositionOfLink(syntacticLink2);
		return prep1 != null && prep1.equals(prep2);
	}

	/*
	// Alternative distance (Sep. 19), based on the edit distance between
	// linear links; would replace the penalty sum at the end of distance():
	int dist = 0;
	int linearDist;
	if ((linearLink.size() > 0 && linearLink.get(0).equals("0")) ||
	    (ri.linearLink.size() > 0 && ri.linearLink.get(0).equals("0")))
	    linearDist = 999;
	else
		linearDist = minEditDistance(linearLink, ri.linearLink);
	if (syntaxMatch)
		dist = Math.min(1, linearDist);
	else
		dist = 1 + linearDist;
	if (!arg1Match) dist+=8;
	if (!arg2Match) dist+=8;
	return dist;
	*/
			
	/*
	public static int minEditDistance (ArrayList source, ArrayList target) {
		int n = target.size();
		int m = source.size();
		int[][] distance = new int[n+1][m+1];
		distance[0][0] = 0;
		for (int i=0; i<=n; i++) {
			for (int j=0; j<=m; j++) {
				if (i == 0 && j > 0)
					distance[i][j] = distance[i][j-1] + deleteCost((String) source.get(j-1));
				else if (j == 0 && i > 0)
					distance[i][j] = distance[i-1][j] + insertCost((String) target.get(i-1));
				else if (i > 0 && j > 0)
					distance[i][j] = 
						Math.min(distance[i-1][j] + insertCost((String) target.get(i-1)),
						Math.min(distance[i-1][j-1] + substCost((String) source.get(j-1),(String)  target.get(i-1)),
						         distance[i][j-1] + deleteCost((String) source.get(j-1))));
			}
		}
		return distance[n][m];
	}
	
	private static int insertCost (String t) {
		return 10;
	}
	
	private static int substCost (String s, String t) {
		if (s.equals(t))
			return 0;
		else
			return 10;
	}
	
	private static int deleteCost (String s) {
		return 10;
	}
	*/
}

/**
 *  A set of RelationPatterns with their frequencies, indexed so that
 *  candidate matches for a relation can be retrieved quickly.
 */
class RelationPatternSet {
	
	// patternSet:   distinct pattern (compared by its line text) -> Integer frequency
	// patternIndex: key -> HashSet of patterns, where the key is the link's
	//               preposition (or the syntactic link itself), the last
	//               linear-link token, or "**" for an empty linear link
	TreeMap patternSet, patternIndex;
	
	RelationPatternSet () {
		patternSet = new TreeMap();
		patternIndex = new TreeMap();
	}
	
	/** iterates over the distinct patterns (each once, whatever its frequency) */
	Iterator iterator () {
		return patternSet.keySet().iterator();
	}
		
	/**
	 *  Adds the patterns in 'patternFile', skipping its first 'skipCount'
	 *  lines and any line that does not parse into a pattern.  A repeated
	 *  pattern increments its frequency; it is indexed only once, so each
	 *  distinct pattern contributes its frequency exactly once in findMatch.
	 */
	void load (String patternFile, int skipCount) throws IOException {
		BufferedReader reader = new BufferedReader(new FileReader(patternFile));
		int count = 0;
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				count++;
				if (count <= skipCount) continue;
				RelationPattern pattern = new RelationPattern(line);
				if (pattern.relationType.equals("")) continue;
				Integer freqI = (Integer) patternSet.get(pattern);
				if (freqI == null) {
					patternSet.put(pattern, Integer.valueOf(1));
					String key = syntacticKey(pattern);
					if (key != null)
						indexPattern(pattern, key);
					key = linearKey(pattern);
					if (key != null)
						indexPattern(pattern, key);
				} else {
					patternSet.put(pattern, Integer.valueOf(freqI.intValue() + 1));
				}
			}
		} finally {
			reader.close();
		}
		System.out.println (Math.max(0, count - skipCount) + " patterns loaded.");
	}

	/**
	 *  index key from the syntactic link of 'ri':  its preposition if it has
	 *  one, else the link itself; null if the link is "0".
	 */
	private static String syntacticKey (RelationInstance ri) {
		if (ri.syntacticLink.equals("0"))
			return null;
		String prep = RelationPattern.prepositionOfLink (ri.syntacticLink);
		return (prep != null) ? prep : ri.syntacticLink;
	}

	/**
	 *  index key from the linear link of 'ri':  its last token, "**" if the
	 *  link is empty, or null if the last token is "0".
	 */
	private static String linearKey (RelationInstance ri) {
		ArrayList link = ri.linearLink;
		if (link.size() == 0)
			return "**";
		String last = (String) link.get(link.size()-1);
		return last.equals("0") ? null : last;
	}
	
	/** records 'pattern' under 'key' in patternIndex */
	private void indexPattern (RelationPattern pattern, String key) {
		HashSet set = (HashSet) patternIndex.get(key);
		if (set == null) {
			set = new HashSet();
			patternIndex.put(key, set);
		}
		set.add(pattern);
	}

	/** adds the patterns indexed under 'key' (if any) to 'candidates' */
	private void addCandidates (HashSet candidates, String key) {
		if (key == null)
			return;
		HashSet indexed = (HashSet) patternIndex.get(key);
		if (indexed != null)
			candidates.addAll(indexed);
	}
		
	/**
	 *  Finds the candidate patterns closest to 'rp' (see
	 *  RelationPattern.distance).  If that distance is at most 'maxDistance',
	 *  the type:subtype with the largest total frequency among them wins and
	 *  one pattern with that type:subtype is returned; otherwise null.
	 */
	RelationPattern findMatch (RelationInstance rp, int maxDistance) {
		// collect into a fresh set:  the sets held by patternIndex must not change
		HashSet candidates = new HashSet();
		addCandidates(candidates, syntacticKey(rp));
		addCandidates(candidates, linearKey(rp));
		int bestDistance = 100;
		HashMap relationTypeCount = new HashMap();  // type:subtype -> total frequency
		HashMap typeRelationMap = new HashMap();    // type:subtype -> a representative pattern
		Iterator it = candidates.iterator();
		while (it.hasNext()) {
			RelationPattern pattern = (RelationPattern) it.next();
			int count = ((Integer) patternSet.get(pattern)).intValue();
			int dis = pattern.distance(rp);
			if (dis < bestDistance) {
				relationTypeCount.clear();
				typeRelationMap.clear();
				bestDistance = dis;
			}
			if (dis == bestDistance) {
				String typeSubType = pattern.relationType + ":" + pattern.relationSubtype;
				BuildRelationModel.incrementHashMap (relationTypeCount, typeSubType, count);
				typeRelationMap.put(typeSubType, pattern);
			}
		}
		if (bestDistance > maxDistance)
			return null;
		int bestCount = -1;
		String bestTypeSubType = "";
		Iterator countIt = relationTypeCount.keySet().iterator();
		while (countIt.hasNext()) {
			String typeSubType = (String) countIt.next();
			int ct = ((Integer) relationTypeCount.get(typeSubType)).intValue();
			if (ct > bestCount) {
				bestCount = ct;
				bestTypeSubType = typeSubType;
			}
		}
		return (RelationPattern) typeRelationMap.get(bestTypeSubType);
	}
}