import math, numpy as np
import numpy as np
import sys
from scriptexit import ScriptExit
import sklearn
import copy
import datetime
from dateutil.relativedelta import relativedelta
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import AdaBoostRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsRegressor
# from sklearn.datasets import load_iris
from sklearn.cross_validation import train_test_split
from sklearn import svm
import itertools

# Flatten exactly one level of nesting: mylist must be an iterable whose
# items are themselves iterables; deeper nesting is left untouched.
def flatten(mylist):
    """Return the items of every sub-iterable of mylist, in order, as one list."""
    flat = []
    for sub in mylist:
        flat.extend(sub)
    return flat

# Anil's calibration code
def conj_calibrate(C, Xcal, ycal, eps):
    """Compute the acceptance threshold for classifier C at target error rate eps.

    NOTE: works only for classifiers that support predict_proba().

    For every calibration point the second-highest class probability is a
    candidate threshold (conj_test refuses a point when two or more classes
    exceed the threshold, i.e. when this value does).  Candidates are taken
    from largest to smallest, refusing one more calibration point each time,
    until the error rate among the remaining predictions is at most eps.
    Both the error count and the prediction count start at +1, which keeps
    the estimate conservative.

    Args:
        C: fitted classifier with predict() and predict_proba().
        Xcal: calibration features; Xcal.shape[0] is the number of points.
        ycal: true labels for Xcal.
        eps: target error rate.

    Returns:
        1.0 when the error rate already satisfies eps with nothing refused
        (conj_test then refuses nothing, since no probability exceeds 1);
        otherwise the first candidate threshold that satisfies eps; or 0.0
        when none does (conj_calibrate_test treats 0.0 as "refuse everything").
    """
    A = C.predict_proba(Xcal)
    pred_cal = C.predict(Xcal)
    n_cal = Xcal.shape[0]
    # Second-highest probability of every row (requires at least two classes).
    # astype(float) yields a contiguous float64 array so tie ordering in the
    # argsort below matches a plain float64 buffer.
    pt = np.sort(A, axis=1)[:, -2].astype(float)
    # Candidate thresholds from largest to smallest.
    a = np.flipud(np.argsort(pt))
    errors = sum(pred_cal != ycal) + 1
    print("errors  %s" % (errors,))
    predictions = n_cal + 1
    print("(errors+0.0)/predictions  %s" % ((errors + 0.0) / predictions,))
    print(" eps  %s" % (eps,))
    threshold = 1.0
    if (errors + 0.0) / predictions <= eps:
        # Already accurate enough: refuse nothing.  The old code fell through
        # to `return 0.0` here, which callers read as "refuse everything".
        return threshold
    print('entering threshold reduction loop')
    for i in a:  # each step refuses one more calibration point
        if pred_cal[i] != ycal[i]:
            errors -= 1  # refusing this point removes one error
        predictions -= 1
        print('within loop: threshold is:  %s' % (pt[i],))
        print('within loop: errors:  %s' % (errors,))
        print('within loop: predictions:  %s' % (predictions,))
        print("within loop: (errors+0.0)/predictions  %s" % ((errors + 0.0) / predictions,))
        if (errors + 0.0) / predictions <= eps:
            # NOTE(review): conj_test refuses only probabilities strictly
            # greater than the threshold, so this point (and any ties with
            # pt[i]) would be accepted by conj_test although it was counted
            # as refused here -- confirm whether >= was intended there.
            threshold = pt[i]
            print('final threshold is:  %s' % (threshold,))
            return threshold
    return 0.0

# Anil's test code given threshold
def conj_test(C, Xtest, threshold):
    """Predict labels for Xtest and flag the points to refuse.

    A point is refused when two or more classes have a predicted
    probability strictly greater than threshold.

    Returns:
        (pred_test, refused): C.predict(Xtest) and a boolean array that is
        True for every refused row.
    """
    probabilities = C.predict_proba(Xtest)
    labels = C.predict(Xtest)
    over_threshold = probabilities > threshold
    refused = over_threshold.sum(axis=1) >= 2
    return labels, refused

# Anil's calibrate and then test
def conj_calibrate_test(C, Xcal, ycal, Xtest, eps):
    """Calibrate C on (Xcal, ycal) at target rate eps, then predict Xtest.

    Returns the (predictions, refused) pair from conj_test, or the string
    "Refuse everything" when conj_calibrate returns 0.0.

    NOTE(review): conj_calibrate can also return 0.0 as a genuine threshold
    (a calibration point whose second-highest probability is 0), which is
    indistinguishable here from its "no threshold found" result.
    """
    th = conj_calibrate(C, Xcal, ycal, eps)
    print("threshold for calibration is:  %s" % (th,))
    if th != 0.0:
        return conj_test(C, Xtest, th)
    return "Refuse everything"




def _filein_row(player, col):
    """Build the per-game history row for one raw split line, minus the target.

    col maps header names (line endings stripped) to column indices.
    Row layout (Tyler can modify this; filein appends the target last):
      0 date, 1 team ('full name'), 2 opponent ('o:full name'),
      3 ou margin, 4 points, 5 opponent points, 6 rest, 7 line,
      8 ats margin, 9 two pointers attempted, 10 two pointers made,
      11 fouls, 12 attendance, 13 three pointers attempted,
      14 three pointers made, 15 turnovers,
      16 over/under line (total points minus ou margin).
    """
    def cell(key):
        return player[col[key]]

    points = int(cell('points'))
    opp_points = int(cell('o:points'))
    ou_margin = float(cell('ou margin'))
    threes_attempted = int(cell('three pointers attempted'))
    threes_made = int(cell('three pointers made'))
    return [
        int(cell('date')),
        cell('full name'),
        cell('o:full name'),
        ou_margin,
        points,
        opp_points,
        int(cell('rest')),
        float(cell('line')),
        float(cell('ats margin')),
        int(cell('field goals attempted')) - threes_attempted,
        int(cell('field goals made')) - threes_made,
        int(cell('fouls')),
        int(cell('attendance')),
        threes_attempted,
        threes_made,
        # turnovers is the last column, so its value may carry the line
        # ending; int() ignores surrounding whitespace.  (The old [:-1]
        # chopped a digit when the final line had no trailing newline.)
        int(cell('turnovers')),
        points + opp_points - ou_margin,
    ]


def _filein_target(player, col, betflag):
    """Return the target for one raw row according to betflag.

    betflag 1: sign of 'ou margin' (1 over, -1 under, 0 otherwise).
    betflag 2: total points scored by both teams.
    """
    if betflag == 1:
        ou_margin = float(player[col['ou margin']])
        if ou_margin > 0:
            return 1
        if ou_margin < 0:
            return -1
        return 0
    if betflag == 2:
        return int(player[col['points']]) + int(player[col['o:points']])
    raise ValueError("betflag must be 1 or 2, got %r" % (betflag,))


def filein(name, betflag=None):
    """Parse a tab-separated game log and group its game rows by team.

    Lines starting with '#' and lines of 5 characters or fewer (newline
    included) are skipped.  A line starting with 'date' is a header naming
    the columns; every other line is a game row.  Game rows holding '-' in
    any of ou margin, rest, attendance, ats margin, line or fouls are dropped.

    Args:
        name: path of the file to read.
        betflag: 1 -> target is the over/under sign, 2 -> target is the
            total score.  Defaults to the module-level betflag.

    Returns:
        [headers, fields, mydict, history] where headers is the last header
        line split on tabs (its final name keeps the line ending), fields
        holds the accepted rows as lists of raw strings, mydict maps each
        header name exactly as split to its column index, and history maps
        each 'full name' to its rows (see _filein_row, plus the target as
        the last field) in file order.  Only history is used by the script.

    Raises:
        ValueError: if a game row is parsed while betflag is not 1 or 2.
    """
    if betflag is None:
        betflag = globals().get('betflag')
    required = ('ou margin', 'rest', 'attendance', 'ats margin', 'line', 'fouls')
    with open(name, "r") as handle:
        text = handle.readlines()
    headers = []
    fields = []
    mydict = {}   # header name as split (line ending kept) -> column index
    col = {}      # same mapping with line endings stripped; used for lookups
    history = {}  # team name -> list of game rows
    for x in text:
        if x.startswith("#") or len(x) <= 5:
            continue
        if x.startswith("date"):
            headers = x.split("\t")
            for index, header in enumerate(headers):
                mydict[header] = index
                col[header.rstrip("\r\n")] = index
            continue
        player = x.split("\t")
        if any(player[col[key]] == '-' for key in required):
            continue
        fields.append(player)
        zz = _filein_row(player, col)
        zz.append(_filein_target(player, col, betflag))
        history.setdefault(player[col['full name']], []).append(zz)
    return [headers, fields, mydict, history]

# Order all games by time into totallist
def ordergames(homehist, awayhist):
    """Merge each team's home and away games into one date-ordered list.

    Only teams present in both homehist and awayhist are included.  Each
    team's two lists are merged on field 0 (the date) and are assumed to be
    date-sorted already.  NOTE(review): filein builds them in file order, so
    confirm the input files are sorted by date.  On equal dates the home
    game comes first.

    Returns:
        A list of ['home' | 'away', game] pairs, grouped team by team (in
        homehist iteration order) and date-ordered within each team.
    """
    totallist = []
    for team in homehist:
        if team not in awayhist:
            continue
        homelist = homehist[team]
        awaylist = awayhist[team]
        ihome = iaway = 0
        while ihome < len(homelist) and iaway < len(awaylist):
            if homelist[ihome][0] <= awaylist[iaway][0]:
                totallist.append(['home', homelist[ihome]])
                ihome += 1
            else:
                totallist.append(['away', awaylist[iaway]])
                iaway += 1
        # At most one list still has games left (typically the last few games
        # of a schedule that are all away or all home); append them in order.
        totallist.extend(['away', game] for game in awaylist[iaway:])
        totallist.extend(['home', game] for game in homelist[ihome:])
    return totallist


# totallist holds ['home' | 'away', game] pairs, date-ordered per team (as
# built by ordergames).  For each team, and separately for its home and away
# games, collect the runningstats row produced after every game.
def getprev(totallist):
    """Return {team: {'home': [stats, ...], 'away': [stats, ...]}}.

    Each stats row is runningstats(game, rows_so_far) for that team and
    venue, computed in totallist order.  A debug dump is printed whenever
    the team is 'Arkansas Little Rock Trojans'.
    """
    outdict = {}
    for pair in totallist:
        homeaway, game = pair[0], pair[1]
        teamname = game[1]
        venues = outdict.setdefault(teamname, {})
        series = venues.setdefault(homeaway, [])
        if teamname == 'Arkansas Little Rock Trojans':
            print("found:  %s" % (teamname,))
            print(venues)
        series.append(runningstats(game, series))
    return outdict

# Tyler should modify this if there are more stats to be collected.
# runningstats keeps one row per game for a team at one venue (home or away).
# Row layout:
#   0: number of games so far, including this one
#   1: date
#   2: the other team
#   3: the line
#   4: bucket of the game's over/under line (see _ou_bucket)
#   5: days of rest
# num_beforegame (= 6) counts the fields above, which are known before the
# game starts; the fields below are only known afterwards:
#   6: cumulative points scored
#   7: placeholder 0.0 (two pointer attempts, currently disabled)
#   8: placeholder 0.0 (two pointers made, currently disabled)
#   9: cumulative three pointer attempts
#  10: placeholder 0.0 (three pointers made, currently disabled)
#  11: placeholder 0.0 (ats margin, currently disabled)
#  12: the over/under bucket again
#  13, 14: dummies (0.0)
#  15: target field (whether boolean or number)
# Cumulative fields are divided by the game count in assemblestatsupto to
# obtain averages.
# Tyler note: if you change data fields collected, you will affect this.
def _ou_bucket(total):
    """Bucket an over/under line: 3.0 above 150, 2.0 in [144, 150], 1.0 in [124, 144), else 0.0."""
    if total > 150:
        return 3.0
    if 144 <= total <= 150:
        return 2.0
    if 124 <= total <= 144:  # 144 itself was already caught above
        return 1.0
    return 0.0


def runningstats(newgame, previousstats):
    """Return the running-stats row for newgame given the rows before it.

    Args:
        newgame: one history row from filein (field 0 date, 2 opponent,
            4 points, 6 rest, 7 line, 13 three pointers attempted,
            16 over/under line, 17 target).
        previousstats: earlier rows for the same team and venue, oldest
            first; only the last one is read.

    Side effect: sets the module-level num_beforegame, which
    assemblestatsupto reads.
    """
    global num_beforegame
    statlen = 16  # row length; an all-zero row stands in before the first game
    current = previousstats[-1] if previousstats else [0] * statlen
    bucket = _ou_bucket(newgame[16])
    # Fields known before the game starts.
    newcurrent = [current[0] + 1, newgame[0], newgame[2], newgame[7], bucket, newgame[6]]
    num_beforegame = len(newcurrent)
    # Fields known only after the game: running sums, placeholders, target.
    newcurrent.extend([
        current[num_beforegame] + newgame[4],       # points
        0.0,                                        # two pointer attempts (disabled)
        0.0,                                        # two pointers made (disabled)
        current[num_beforegame + 3] + newgame[13],  # three pointer attempts
        0.0,                                        # three pointers made (disabled)
        0.0,                                        # ats margin (disabled)
        bucket,
        0.0,                                        # dummy
        0.0,                                        # dummy
        newgame[17],                                # target
    ])
    return newcurrent
	
# Take this teamhistory and, for each home game inside the date cutoffs, build
# one feature row from the home team's and the opposing team's home/away
# running stats, followed by the outcome and a forbidden-team flag.
# Tyler note: if you change data collected, you will affect this.
def assemblearray(teamhistory, datecutoffstart, datecutoffend):
    """Build one feature row per home game dated within the cutoffs.

    Args:
        teamhistory: getprev() output, team -> {'home': [...], 'away': [...]}
            lists of runningstats rows.
        datecutoffstart, datecutoffend: inclusive bounds on the stats' date
            field (field 1).

    Returns:
        A list of rows:
            [date, line, over/under bucket, home rest, opponent rest,
             9 values from each of homehome, homeaway, awayhome, awayaway
             (columns 3..11 of assemblestatsupto's output),
             target, forbidden]
        where target is the home team's target for the game and forbidden
        is True when either team is in the module-level forbiddenteams.
        Games whose opponent has no history, or for which any of the four
        stat groups is empty, are skipped.
    """
    outarr = []
    for k in teamhistory:
        t = teamhistory[k]
        for h in t.get('home', []):
            gamedate, opponent = h[1], h[2]
            if not (datecutoffstart <= gamedate <= datecutoffend):
                continue
            homehomestats = assemblestatsupto('homehome', gamedate, t['home'])
            homeawaystats = assemblestatsupto('homeaway', gamedate, t.get('away', []))
            if opponent not in teamhistory:
                continue
            t2 = teamhistory[opponent]
            awayhomestats = assemblestatsupto('awayhome', gamedate, t2.get('home', []))
            awayawaystats = assemblestatsupto('awayaway', gamedate, t2.get('away', []))
            groups = (homehomestats, homeawaystats, awayhomestats, awayawaystats)
            if not all(groups):
                continue
            newel = [
                gamedate,
                homehomestats[0],  # line
                homehomestats[1],  # over/under bucket
                homehomestats[2],  # home team days of rest
                awayawaystats[2],  # opponent days of rest (from its away stats)
            ]
            # History averages and raw fields for each of the four groups.
            for stats in groups:
                newel.extend(stats[3:12])
            newel.append(homehomestats[-1])  # target
            # True means forbidden: never used as a test row.
            newel.append(k in forbiddenteams or opponent in forbiddenteams)
            outarr.append(newel)
    return outarr

# We take either the home or the away running stats of one team and capture
# the averages up to (but excluding) the game in question.
# Tyler note: if you change data collected, you will affect this.
def assemblestatsupto(status, uptodate, myhist, beforegame=None):
    """Summarise one team/venue running-stats list as of date `uptodate`.

    myhist is a list of runningstats rows (field 1 is the date), assumed to
    be date-ordered.  `current` is the row dated `uptodate` if there is one,
    otherwise the latest row before that date; `older` is the row just
    before `current`.  Averages come from `older`, so the game in question
    never contributes to its own features.

    Args:
        status: label of the stat group (unused).
        uptodate: date of the game being assembled.
        myhist: runningstats rows for one team at one venue.
        beforegame: number of pre-game fields in a row; defaults to the
            module-level num_beforegame set by runningstats.

    Returns:
        [] when `current` or `older` is missing; otherwise 13 values:
        current line, over/under bucket and rest; the six older cumulative
        fields divided by older's game count; older's three raw fields that
        follow; and current's target field.
    """
    previous, beforeprevious = [], []
    for entry in myhist:
        if entry[1] == uptodate:
            current, older = entry, previous
            break
        if entry[1] > uptodate:
            current, older = previous, beforeprevious
            break
        beforeprevious, previous = previous, entry
    else:
        # Every row predates uptodate (e.g. the team has no later game at this
        # venue): use the latest row, as in the 'later row found' case.  The
        # old code left current empty here and silently dropped the game.
        current, older = previous, beforeprevious
    if not (current and older):
        return []
    if beforegame is None:
        beforegame = num_beforegame
    num = older[0] + 0.0  # number of games summarised by `older`
    out = [current[3], current[4], current[5]]
    out.extend(older[beforegame + k] / num for k in range(6))
    out.extend(older[beforegame + k] for k in (6, 7, 8))
    out.append(current[beforegame + 9])
    print('current is:  %s' % (current,))
    print('older is:  %s' % (older,))
    print('out is:  %s' % (out,))
    return out

  
				

# DATA

# Teams whose games are never used as test (prediction) rows.
# Tyler changes
forbiddenteams = ['Southern California Trojans', 'Texas Southern Tigers']

# True: forbidden teams' games are still used for training (never for testing).
# False: their games are excluded from training as well.
train_on_forbidden = True

# 1: predict whether a game goes over (1), under (-1) or neither (0).
# 2: predict the total score of the game.
betflag = 1

# EXECUTION

# Tyler main note TYLER READ THIS:
# To change the data collected, change the zz rows built in filein, then
# runningstats, then assemblestatsupto, and then assemblearray.
# For example, to add a new stat, add a new field (say at the end) of every
# zz row, add it to newcurrent in runningstats, add it to the vector `out`
# in assemblestatsupto, and then pick it up in assemblearray for each of the
# four stat groups (homehome, homeaway, awayhome, awayaway).
# Alternatively, if zz already has what you need but you want to combine some
# values (e.g. round the number of two pointers to the nearest 10, or take the
# ratio between threes and twos), start with runningstats and work your way
# through assemblestatsupto and then assemblearray.

# NOTE(review): targetfile, targetfile2 and teststartdate are not defined in
# this file; presumably the harness that runs this script (cf. ScriptExit)
# provides them -- confirm.
homehistory = filein(targetfile)[3]  # only the per-team history is used
awayhistory = filein(targetfile2)[3]

# Merge each team's home and away games into one date-ordered list.
orderedgames = ordergames(homehistory, awayhistory)
print("length ordered games:  %s" % (len(orderedgames),))

# For each team and date find all summary statistics from previous games.
teamhistory = getprev(orderedgames)

# Before a game starts we know the line, the over/under line (sum of points
# minus ou margin), the teams, the rest, plus the home/away history averages.
startdate = teststartdate
enddate = 20160404
print("length of teamhistory:  %s" % (len(teamhistory),))
alldata = assemblearray(teamhistory, startdate, enddate)
print("length of alldata:  %s" % (len(alldata),))

# Games dated before testdate train the models; later games are the test set.
s = str(startdate)
mystartdate = datetime.date(int(s[:4]), int(s[4:6]), int(s[6:]))
# Tyler can adjust the training period (days after startdate).
mytestdate = mystartdate + relativedelta(days=60)
# Same YYYYMMDD integer format as the data.  The day used to be appended
# without zero padding (2016-01-05 became 2016015), leaving no training rows.
testdate = int(mytestdate.strftime("%Y%m%d"))


print("date, line, overunderline, home team rest, away team rest, homehome average points, homehome average two point attempts, homehome average two point made, homehome average three point attempts, homehome average 3 points made, atsmargin or dummy, ... three dummy fields, homeaway average points, homeaway average two point attempts, homeaway average two points made, homeaway average three point attempts, homeaway, average three points made, homeaway atsmargin or dummy, ... three dummy fields, ... same fields as homeaway for awayhome ... , same fields as homeaway for awayaway ..., target from homehome")
for row in alldata:
    print(row)


def _features_and_targets(rows, keep):
    """Return (X, y) for the rows of alldata accepted by keep(row).

    X holds every column except the date (first) and the target/forbidden
    pair (last two) as floats; y holds the target as ints.
    """
    chosen = [row for row in rows if keep(row)]
    X = np.array([[float(value) for value in row[1:-2]] for row in chosen])
    y = np.array([int(row[-2]) for row in chosen])
    return X, y


if train_on_forbidden:
    X_train, y_train = _features_and_targets(
        alldata, lambda row: row[0] < testdate)
else:
    X_train, y_train = _features_and_targets(
        alldata, lambda row: row[0] < testdate and not row[-1])
# Games involving a forbidden team are never tested.
X_test, y_test = _features_and_targets(
    alldata, lambda row: row[0] >= testdate and not row[-1])

print("len(y_test) is:  %s" % (len(y_test),))

if min(len(X_test), len(y_test), len(X_train), len(y_train)) == 0:
    raise ScriptExit("some file is of zero length")

# Machine Learning Starts Here

print("Decision tree:")
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
print("decision tree classifier  %s" % (clf,))

# Other models to try in place of the random forest below:
# LogisticRegression(), svm.SVC(gamma=0.001, C=100.), KNeighborsRegressor().

clf = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, oob_score=False, n_jobs=-1)
clf.fit(X_train, y_train)
importances = clf.feature_importances_
print("feature importances are:  %s" % (importances,))

mypred2 = clf.predict(X_test)
# Fraction of test games whose predicted label equals the true label.
correctnessratio = float(np.sum(mypred2 == y_test)) / len(y_test)
print("random forest correctness rate (no refusals) %s" % (correctnessratio,))

# Disabled experiment, kept as an unused string literal: calibrate a random
# forest with conj_calibrate_test (Anil's refusal method) and measure accuracy
# on the test games it does not refuse.
# NOTE(review): if re-enabled, note that conj_calibrate_test can return the
# string "Refuse everything" instead of a (predictions, refused) pair, which
# breaks the `mypred3, ref = ...` unpacking, and `correct/num` divides by zero
# when every test game is refused.  The body also uses Python 2 print
# statements.
'''
# comment this out
# Now use Anil's method

# SPLIT THE TRAINING DATA AS THE CORE TRAINING AND CALIBRATION SETS (2:1)
core_X, cal_X, core_y, cal_y = train_test_split(X_train, y_train, train_size=0.66)
n_core, n_features = core_X.shape
# TRAIN YOUR CLASSIFIER ON THE CORE SET
classifier = RandomForestClassifier(n_estimators=100, min_samples_leaf=1, oob_score=False, n_jobs=-1)
classifier.fit(core_X, core_y)
epsilon = 0.42
mypred3, ref = conj_calibrate_test(classifier, cal_X,cal_y,X_test,epsilon)
# Want to evaluate these predicitons just on the non-refused data
# err  = pred != test_y
print 'len(mypred3) = ', len(mypred3)
print 'len(ref) = ', len(ref)

# print 'random forest truth', y_test
correct = 0.0
num = 0
for i in range(len(y_test)):
 if ref[i] == False:
  num+= 1 # others don't count
  print 'mypred3[i] is: ', mypred3[i], ' and y_test[i] is: ', y_test[i]
  if mypred3[i] == y_test[i]:
    correct+= 1
correctnessratio = correct/num
print 'random forest with refusal correctness rate', correctnessratio

'''
