#!/usr/bin/python
##################################################
#Copyright (c) 2011, David F. Fouhey
#See License.txt
##################################################
#JLinkage
##################################################


from homography import *
from math import *
from random import *

def chooseWithWeight(weights):
    """Choose an index at random, with probability proportional to its weight.

    weights -- non-empty sequence of non-negative numbers

    Returns an index i with 0 <= i < len(weights).  If the weights do not sum
    to a positive value, no index can be preferred and one is drawn uniformly.

    Raises ValueError if weights is empty.
    """
    if len(weights) == 0:
        raise ValueError("cannot choose from an empty list of weights")
    weightSum = sum(weights)
    if weightSum <= 0:
        #degenerate weights (e.g. every weight underflowed to 0.0): the scan
        #below could never succeed, so fall back to a uniform draw
        return randrange(len(weights))
    #sample a value within the sum of the weights
    selection = random()*weightSum
    #this could be improved to log n...
    soFar = 0.0
    lastPositive = 0
    for i, weight in enumerate(weights):
        if weight > 0:
            lastPositive = i
        soFar += weight
        if selection < soFar:
            return i
    #floating-point rounding can leave selection >= the running total; return
    #the last index that carried any weight rather than an out-of-range index
    return lastPositive

def gaussPDF(x, mu, sigmaSq):
    r"""Compute the Gaussian PDF P(x | \mu, \sigma^2)

    x -- point at which the density is evaluated
    mu -- mean of the distribution
    sigmaSq -- variance (sigma squared); a zero or negative value makes the
               division or the square root below fail

    Returns the density at x as a float.
    """
    #normalising constant 1/sqrt(2*pi*sigma^2)
    multiplier = 1.0 / sqrt(2*pi*sigmaSq)
    deviation = x - mu
    toExp = -0.5*deviation*deviation/sigmaSq
    return multiplier*exp(toExp)

def bivariateGauss(p, q, sigma):
    """Evaluate a 2-D Gaussian density as a product of per-axis densities

    p, q -- points; only their first two coordinates are read
    sigma -- standard deviation shared by both axes (squared before use)

    Returns gaussPDF over the first coordinates times gaussPDF over the
    second coordinates, each with variance sigma**2.
    """
    variance = sigma**2
    densityFirst = gaussPDF(p[0], q[0], variance)
    densitySecond = gaussPDF(p[1], q[1], variance)
    return densityFirst*densitySecond

def sampleWithGaussian(matches, sigma):
    """Sample 4 distinct matches from the matches with Kanazawa sampling

    The first match is drawn uniformly.  The other three are drawn without
    replacement, each candidate weighted by bivariateGauss between the first
    two entries of the first match and the first two entries of the candidate.

    matches -- sequence of matches; only matches[i][:2] is used for weighting
    sigma -- standard deviation handed to bivariateGauss

    Returns a list of 4 elements of matches.

    Raises ValueError if fewer than 4 matches are supplied.
    """
    numMatches = len(matches)
    if numMatches < 4:
        raise ValueError("need at least 4 matches to sample from, got %d" % numMatches)
    firstMatch = randint(0, numMatches-1)
    location = matches[firstMatch][:2]
    #get weights for all the points
    weights = [bivariateGauss(location, match[:2], sigma) for match in matches]
    sampledMatches = [firstMatch]
    while len(sampledMatches) != 4:
        candidateMatches = [i for i in range(numMatches) if i not in sampledMatches]
        candidateWeights = [weights[i] for i in candidateMatches]
        sampleIndexIndex = chooseWithWeight(candidateWeights)
        if not 0 <= sampleIndexIndex < len(candidateMatches):
            #the weights of far-away points can underflow to 0.0, in which
            #case the weighted choice may not yield a usable index; pick
            #uniformly among the remaining candidates instead of crashing
            sampleIndexIndex = randint(0, len(candidateMatches)-1)
        sampledMatches.append(candidateMatches[sampleIndexIndex])
    #internal invariant: sampling is without replacement
    assert len(sampledMatches) == len(set(sampledMatches))
    return [matches[i] for i in sampledMatches]

def JaccardDistance(A, B):
    """Return the Jaccard distance 1 - |A & B| / |A | B| between sets A and B

    Two empty sets are defined to be at distance 1.0.
    """
    unionSize = len(A.union(B))
    if not unionSize:
        #both sets are empty; avoid dividing by zero
        return 1.0
    sharedSize = len(A.intersection(B))
    return 1.0 - float(sharedSize) / float(unionSize)

def unpackClusters(clusters, matches):
    """Translate clusters of indices into clusters of the matches they index

    clusters -- list of lists of indices into matches
    matches -- sequence the indices refer to

    Returns a new list with one list of matches per input cluster.
    """
    return [[matches[index] for index in members] for members in clusters]

def JLinkage(matches, numModels, epsilon, sigma, minClusterSize):
    """Perform J-Linkage

    Samples numModels models, gives every match the set of models it agrees
    with (its preference set), then repeatedly merges the two clusters whose
    preference sets are closest in Jaccard distance until every remaining
    pair is at distance 1.0 (no model in common).

    matches -- sequence of matches; handed to sampleWithGaussian and getError
    numModels -- number of models to sample
    epsilon -- a match prefers a model when getError(model, match) < epsilon
    sigma -- standard deviation handed to sampleWithGaussian
    minClusterSize -- only clusters with strictly more members than this
                      are returned

    Returns a list of clusters, each a list of elements of matches.

    NOTE(review): solvePerspective and getError are not defined in this file;
    presumably they come from the homography module - confirm.
    """
    #sample the models
    models = [solvePerspective(sampleWithGaussian(matches, sigma))
              for _ in range(numModels)]

    #singleton clusters
    clusters = [[i] for i in range(len(matches))]

    #compute the preference sets: the indices of the models each match fits
    preferenceSets = {}
    for matchIndex, match in enumerate(matches):
        preferenceSets[matchIndex] = set(
            modelIndex for modelIndex, model in enumerate(models)
            if getError(model, match) < epsilon)

    #This is not the naive implementation - we don't update distances at every
    #iteration, but only as necessary.
    #Must be a real list: entries are removed as clusters merge, and on
    #Python 3 a bare range object has no remove().
    validIndices = list(range(len(matches)))
    #distances = positive if valid, else negative
    distances = {}
    for ii, i in enumerate(validIndices):
        for j in validIndices[ii+1:]:
            distances[(i,j)] = -1

    while True:
        #Jaccard distances never exceed 1.0, so 2.0 acts as "nothing found"
        minDistance = 2.0
        minPair = (-1, -1)
        for ii, i in enumerate(validIndices):
            for j in validIndices[ii+1:]:
                #if not valid, update
                if distances[(i,j)] < 0:
                    distances[(i,j)] = JaccardDistance(preferenceSets[i], preferenceSets[j])

                #update the minimum distance
                if distances[(i,j)] < minDistance:
                    minPair = i,j
                    minDistance = distances[(i,j)]

        #stop once no two clusters share a model (or fewer than two remain)
        if minDistance >= 1.0:
            break

        #merge the second cluster of the pair into the first; the merged
        #cluster prefers only the models both halves preferred
        survivor, absorbed = minPair
        clusters[survivor] += clusters[absorbed]
        preferenceSets[survivor].intersection_update(preferenceSets[absorbed])
        validIndices.remove(absorbed)

        #mark the distances as invalid; both key orders are written because
        #the survivor may come before or after i in validIndices
        for i in validIndices:
            distances[(survivor,i)] = distances[(i,survivor)] = -1

    clusters = [clusters[i] for i in validIndices if len(clusters[i]) > minClusterSize]
    return unpackClusters(clusters, matches)


