#!/usr/bin/python
##################################################
#Copyright (c) 2011, David F. Fouhey
#See License.txt
##################################################
#Main Code
##################################################

from SIFT import *
from JLinkage import *
from GlobalMerging import *
from SpatialAnalysis import *

# Image output needs the legacy OpenCV "cv" bindings; the flag is cleared
# below when they cannot be imported.
OUTPUT_IMAGES = True
try:
    import cv
except ImportError:
    # Only a missing module should disable image output. A bare except would
    # also swallow KeyboardInterrupt/SystemExit and genuine errors inside cv.
    print("No OpenCV Python bindings; can't output images")
    OUTPUT_IMAGES = False

import os
import pickle, sys

def intP(p):
    """Turn a sequence of numbers into a tuple of ints.

    Each coordinate is converted with int(), i.e. truncated toward zero, which
    is the form the cv drawing calls in this file are given.
    """
    return tuple(int(coord) for coord in p)

def getClusterImage(imageSource, clusters, usePoints2=False):
    """Return a copy of imageSource with every cluster drawn on it.

    Each cluster gets its own colour, cv.RGB(*int2Color(index)). For every
    point it draws a filled circle of radius 5, then a closed outline (thickness
    3) of the points' convex hull. The source image itself is not modified.

    imageSource: an OpenCV image; it is duplicated with copyImage.
    clusters: iterable of clusters. The entries of each cluster are sliced with
        p[:2] (default) or p[2:] (usePoints2=True). Presumably entries are
        (x, y, x', y') correspondences -- confirm against JLinkage.
    usePoints2: draw the second pair of coordinates instead of the first.
    """
    image = copyImage(imageSource)
    for clusterI, cluster in enumerate(clusters):
        # get the cluster color
        rgbColor = cv.RGB(*int2Color(clusterI))
        # Build only the coordinate list we actually draw. Building the p[:2]
        # list unconditionally wasted a pass and, for a single-pass iterable,
        # left nothing for the p[2:] list.
        if usePoints2:
            clusterPoints = [p[2:] for p in cluster]
        else:
            clusterPoints = [p[:2] for p in cluster]
        # draw the points
        for p in clusterPoints:
            cv.Circle(image, intP(p), 5, rgbColor, thickness=-1)
        # draw the convex hull: pair each vertex with its predecessor so the
        # first segment joins the last vertex back to the first (closed loop)
        CH = getConvexHull(clusterPoints)
        for p, q in zip(CH[-1:] + CH[:-1], CH):
            cv.Line(image, intP(p), intP(q), rgbColor, thickness=3)
    return image

def copyImage(image):
    """Return a duplicate of an OpenCV image.

    A new image with the same size, depth and channel count is allocated, and
    the source pixels are copied into it with cv.Copy.
    """
    size = cv.GetSize(image)
    duplicate = cv.CreateImage(size, image.depth, image.nChannels)
    cv.Copy(image, duplicate)
    return duplicate

def getConvexHull(points):
    """Return the convex hull of points as a list.

    Calls cv.ConvexHull2 with a fresh memory storage and return_points=1. The
    caller in this file treats the results as drawable point coordinates.
    """
    storage = cv.CreateMemStorage()
    hull = cv.ConvexHull2(points, storage, return_points=1)
    return list(hull)

def outputPlanes(outFile, clusters):
    """Output the planes represented by clusters in a text-file.

    For cluster number k (1-based), first writes the line
        "k:H:h0 h1 ... h8"
    These are the entries of H[0], H[1] and H[2] of solvePerspective(cluster),
    in that order. It then writes one line per point,
        "k:p:a b c d"
    with the point's four values. All numbers are formatted with %f.

    outFile: writable text file-like object (only .write is used).
    clusters: iterable of clusters. Each cluster is a sequence of 4-value points
        (presumably (x, y, x', y') correspondences -- confirm against JLinkage).
    """
    for clusterI, cluster in enumerate(clusters):
        label = clusterI + 1
        H = solvePerspective(cluster)
        # Flatten the three rows element by element. The old `H[0]+H[1]+H[2]`
        # only worked for list rows: tuple rows broke `[label] + H`, and array
        # rows would be added element-wise.
        entries = [value for row in (H[0], H[1], H[2]) for value in row]
        outFile.write("%d:H:%f %f %f %f %f %f %f %f %f\n" % tuple([label] + entries))
        for p in cluster:
            # tuple(p) accepts list as well as tuple points
            outFile.write("%d:p:%f %f %f %f\n" % ((label,) + tuple(p)))

def drawMatches(image1, image2, matches):
    """Return a new image with image1 stacked above image2 and one red line per match.

    The canvas is as wide as the wider input and as tall as both heights
    combined. It has image1's depth and channel count and is zero-filled first.
    Each match unpacks into four values (x, y, xp, yp), truncated to int. The
    line joins (x, y) in image1 to (xp, yp) in image2, shifted down by image1's
    height.
    """
    width1, height1 = cv.GetSize(image1)
    width2, height2 = cv.GetSize(image2)
    canvas = cv.CreateImage((max(width1, width2), height1 + height2),
                            image1.depth, image1.nChannels)
    cv.Zero(canvas)
    # Copy each input into its own region of interest, then clear the ROI so
    # the match lines can span the whole canvas.
    placements = ((image1, (0, 0, width1, height1)),
                  (image2, (0, height1, width2, height2)))
    for source, roi in placements:
        cv.SetImageROI(canvas, roi)
        cv.Copy(source, canvas)
    cv.ResetImageROI(canvas)
    red = cv.RGB(255, 0, 0)
    for match in matches:
        x, y, xp, yp = map(int, match)
        cv.Line(canvas, (x, y), (xp, yp + height1), red)
    return canvas

if __name__ == "__main__":
    # Usage: <script> file1 file2   (results are written under ./output/)
    if len(sys.argv) != 3:
        print("%s file1 file2" % sys.argv[0])
        sys.exit(1)
    if not os.path.exists("output"):
        os.mkdir("output")
    inputImage1, inputImage2 = sys.argv[1:]
    if not (os.path.exists(inputImage1) and os.path.exists(inputImage2)):
        print("No such images!")
        sys.exit(1)
    try:
        # 0.75 is forwarded to getImageMatches (presumably the SIFT ratio-test
        # threshold -- confirm in SIFT.py).
        matches = getImageMatches(inputImage1, inputImage2, 0.75)
        # The numeric tuning constants below are passed through unchanged; see
        # JLinkage / GlobalMerging / SpatialAnalysis for their meaning.
        clustersJL = JLinkage(matches, 200, 1.5, 25.0, 7)
        print("JLinkage Cluster Count: %d" % len(clustersJL))
        clusters = globalMerging(clustersJL, 1.5)
        filteredClusters = fragmentClusters(clusters, 7)
        print("Filtered Cluster Count: %d" % len(filteredClusters))
        # Output the plane info. The with-block closes (and flushes) the file
        # deterministically instead of leaving it to garbage collection.
        with open("output/out.txt", "w") as outFile:
            outputPlanes(outFile, filteredClusters)

        # If we have OpenCV, output the images. They are loaded from the .pgm
        # files next to each input (presumably written by getImageMatches).
        if OUTPUT_IMAGES:
            image1 = cv.LoadImage(inputImage1 + ".pgm")
            image2 = cv.LoadImage(inputImage2 + ".pgm")
            cv.SaveImage("output/matches.png", drawMatches(image1, image2, matches))
            cv.SaveImage("output/clustersJL.png", getClusterImage(image1, clustersJL))
            cv.SaveImage("output/clustersMerged.png", getClusterImage(image1, clusters))
            cv.SaveImage("output/final1.png", getClusterImage(image1, filteredClusters))
            cv.SaveImage("output/final2.png", getClusterImage(image2, filteredClusters, usePoints2=True))
    finally:
        # Clean up after ourselves even when a stage above raised. Files that
        # were never created are skipped, so cleanup cannot mask the original
        # error with an OSError.
        for imageName in (inputImage1, inputImage2):
            for suffix in (".pgm", ".keys"):
                tempPath = imageName + suffix
                if os.path.exists(tempPath):
                    os.remove(tempPath)
     

