import numpy as np
import math
from scipy import ndimage,spatial
import matplotlib.pyplot as plt
from skimage import feature
from skimage.feature import corner_harris, corner_subpix, corner_peaks
from sklearn.cluster import KMeans
from skimage.filters import threshold_otsu
from skimage.draw import line_aa
import math
import sys

def grayscale_luminosity(img):
    """Collapse the first three channels of *img* into one luminosity plane.

    Channels 0, 1 and 2 (presumably R, G, B) are weighted 0.21, 0.72 and
    0.07 respectively; any further channels (e.g. alpha) are ignored.
    """
    red = img[:, :, 0]
    green = img[:, :, 1]
    blue = img[:, :, 2]
    return 0.21 * red + 0.72 * green + 0.07 * blue

def kmeans_cluster(img_gray, n_clusters=5, random_state=None):
    """Cluster the pixel intensities of a grayscale image with k-means.

    Args:
        img_gray: grayscale image; every element is treated as one 1-D sample.
        n_clusters: number of clusters to fit (default 5).
        random_state: forwarded to ``KMeans``; pass an int for reproducible
            clusterings. ``None`` keeps scikit-learn's default behaviour.

    Returns:
        ``(values, labels)``: the squeezed cluster-centre intensities and the
        cluster index assigned to every pixel (in flattened order).
    """
    samples = np.asarray(img_gray).reshape((-1, 1))
    print("About to start k-means clustering")
    k = KMeans(n_clusters=n_clusters, random_state=random_state)
    print("Starting k-means clustring")
    k.fit(samples)
    print("Finished k-means clustering")
    values = k.cluster_centers_.squeeze()
    labels = k.labels_
    return (values, labels)

def corners_seg(img_labels2, min_distance=5, window_size=13):
    """Detect Harris corners in an image and refine them to sub-pixel accuracy.

    Args:
        img_labels2: image passed to ``corner_harris`` / ``corner_subpix``.
        min_distance: minimum separation between peaks kept by
            ``corner_peaks`` (default 5).
        window_size: search window used by ``corner_subpix`` (default 13).

    Returns:
        ``(coords, coords_subpix)``: integer peak coordinates from
        ``corner_peaks`` and their sub-pixel refinements.
    """
    response = corner_harris(img_labels2)
    coords = corner_peaks(response, min_distance=min_distance)
    coords_subpix = corner_subpix(img_labels2, coords, window_size=window_size)
    return (coords, coords_subpix)

def reassign_labels(labels1, values1, labels2, values2):
    """Pair each cluster of one clustering with the nearest cluster of another.

    For every index ``i`` the entry ``values2[j]`` with the smallest L1
    distance to ``values1[i]`` is chosen (the first such ``j`` wins ties) and
    the pair ``(labels1[i], labels2[j])`` is recorded.

    NOTE(review): ``labels1``/``labels2`` are indexed by cluster index here,
    while ``kmeans_cluster`` returns per-pixel labels; callers presumably pass
    per-cluster label ids -- confirm.

    Returns:
        ``(n, 2)`` array of label pairs, or ``None`` when ``values1`` and
        ``values2`` have different lengths.
    """
    values1 = np.asarray(values1)
    values2 = np.asarray(values2)
    if len(values1) != len(values2):
        return None

    n = len(values1)
    pairs = []
    for i in range(n):
        # L1 distance from values1[i] to every entry of values2 (works for
        # scalar centres as well as vector-valued ones).
        dist = np.abs(values2 - values1[i]).reshape(n, -1).sum(axis=1)
        # argmin has no sentinel limit and picks the first minimum on ties.
        j = int(np.argmin(dist))
        pairs.append((labels1[i], labels2[j]))

    return np.array(pairs).reshape(n, 2)

def compare_by_poi(img1, img2, xwindow = 5, ywindow = 5):
    """Fraction of img1's points that have a point of img2 in their neighbourhood.

    Both images are expected to be same-shaped matrices of 0s and 1s (or
    booleans). A point of ``img1`` counts as matched when ``img2`` has at
    least one non-zero pixel inside the ``xwindow`` x ``ywindow`` window
    centred on it, excluding the centre pixel itself.

    Returns:
        ``matched / total`` as a float. Raises ZeroDivisionError when
        ``img1`` contains no points.
    """
    kernel = np.ones((xwindow, ywindow))
    # Integer division: (w - 1) / 2 is a float in Python 3 and cannot index.
    kernel[(xwindow - 1) // 2, (ywindow - 1) // 2] = 0
    # NOTE(review): zeroing the centre means a point of img2 at exactly the
    # same position is not a match on its own -- confirm this is intended.

    # Treat pixels outside the image as empty instead of mirroring border
    # points back into the window (ndimage's default mode is "reflect").
    neighbour_count = ndimage.convolve(np.asarray(img2, dtype=float), kernel,
                                       mode="constant", cval=0.0)
    has_neighbour = neighbour_count > 0

    matched = np.sum(has_neighbour * np.asarray(img1))
    total = np.sum(img1)
    return float(matched) / float(total)

def find_first_marked_point(crossing_mask):
    """Return the first truthy pixel of a 2-D mask, scanning row by row.

    Returns:
        ``(x, y)`` = ``(column, row)`` of the first truthy pixel in row-major
        order, as Python ints, or ``None`` if no pixel is set.
    """
    _, n_cols = crossing_mask.shape
    # flatnonzero scans the flattened array in C (row-major) order, which is
    # the same order as a row-by-row, column-by-column search.
    hits = np.flatnonzero(crossing_mask)
    if hits.size == 0:
        return None
    row, col = divmod(int(hits[0]), n_cols)
    return (col, row)

def MSE(a1, a2):
    """Mean squared error between two arrays.

    The squared element-wise differences are averaged over every element of
    the difference array, so the result is correct for inputs of any rank
    (including 0-d scalars).
    """
    d = np.asarray(a1) - np.asarray(a2)
    return np.sum(d * d) / d.size

# Number of equally spaced sweep directions sampled per image; the script
# code passes it to calculate_shape_vector and uses it as the length of the
# resulting distance vectors.
nsampling: int = 100

# Distance stored for a sweep direction whose ray never meets an edge pixel.
# A large finite sentinel (not np.inf) keeps the final normalisation finite.
_NO_CROSSING_DISTANCE = 1e8


def _read_image(path):
    """Read the image at *path* into a NumPy array.

    ``scipy.ndimage.imread`` no longer exists in current SciPy releases, so
    fall back to ``matplotlib.pyplot.imread`` when it is unavailable.
    """
    reader = getattr(ndimage, "imread", None)
    if reader is None:
        # NOTE(review): plt.imread returns floats in [0, 1] for PNG files.
        # The Otsu levels below are computed from the data and the region
        # filters count pixels, but confirm results match the old reader.
        reader = plt.imread
    return reader(path)


def _segment_dark_regions(img_gray):
    """Two-level Otsu segmentation of a grayscale image.

    Returns ``(mask, sub_mask2)``: ``mask`` marks pixels at or below the
    global Otsu level; ``sub_mask2`` marks the pixels of ``mask`` that lie
    above a second Otsu level computed over the ``mask`` pixels only.
    """
    otsu_lvl = threshold_otsu(img_gray)
    print(otsu_lvl)
    mask = img_gray <= otsu_lvl

    otsu_sub_mask_lvl = threshold_otsu(img_gray[mask])
    mask2 = img_gray <= otsu_sub_mask_lvl
    sub_mask2 = mask & ~mask2
    return mask, sub_mask2


def _region_centers(sub_mask2, min_region_size):
    """Relative centres of mass of the large connected regions of *sub_mask2*.

    Regions with fewer than *min_region_size* pixels are ignored. Returns a
    ``(k, 2)`` array of ``(row / height, col / width)`` centres.
    """
    label_im, nb_labels = ndimage.label(sub_mask2)
    print("Nr of regions: {}".format(nb_labels))

    sizes = ndimage.sum(sub_mask2, label_im, range(nb_labels + 1))
    too_small = sizes < min_region_size
    # "Liczba usunietych" is Polish for "number removed".
    print("Liczba usunietych: {}".format(np.sum(too_small)))

    # Label 0 is the background and never a region.
    kept = [i for i in range(1, nb_labels + 1) if not too_small[i]]
    centers = np.array(
        [ndimage.center_of_mass(label_im == i) for i in kept],
        dtype=float,
    ).reshape(-1, 2)
    print("Centers: {}".format(centers))

    n_rows, n_cols = sub_mask2.shape
    # Row coordinates are scaled by the height, column coordinates by the width.
    centers2 = centers / np.array([float(n_rows), float(n_cols)])
    print("Centers 2: {}".format(centers2))
    return centers2


def _outline_edges(mask, min_outer_region_size):
    """Canny edges of *mask* after dropping components below the size threshold."""
    label_im, nb_labels = ndimage.label(mask)
    sizes = ndimage.sum(mask, label_im, range(nb_labels + 1))
    too_small = sizes < min_outer_region_size
    label_im[too_small[label_im]] = 0
    return feature.canny(label_im > 0)


def _edge_centroid(edges):
    """Mean ``(x, y)`` = ``(column, row)`` position of the edge pixels.

    Raises ZeroDivisionError when *edges* contains no edge pixel.
    """
    rows, cols = np.nonzero(edges)
    count = len(rows)
    sx = float(np.sum(cols)) / count
    sy = float(np.sum(rows)) / count
    return sx, sy


def _radial_distances(edges, sx, sy, nsampling, line_length, debug_plot):
    """Distance from ``(sx, sy)`` to an edge crossing along *nsampling* rays.

    Ray ``i`` starts at the centroid, points at angle ``2*pi*i/nsampling``
    (in image coordinates, rows growing downwards) and is *line_length*
    pixels long. Directions without a crossing keep ``_NO_CROSSING_DISTANCE``.
    """
    n, m = edges.shape
    angle_step = 2.0 * np.pi / nsampling
    dist_vector = np.full(nsampling, _NO_CROSSING_DISTANCE, dtype=np.float32)
    row0 = int(math.floor(sy))
    col0 = int(math.floor(sx))
    for i in range(nsampling):
        angle = angle_step * i
        rr, cc, val = line_aa(row0,
                              col0,
                              int(math.floor(sy + line_length * math.sin(angle))),
                              int(math.floor(sx + line_length * math.cos(angle))))
        # Keep only the part of the ray that lies inside the image.
        inside = (rr >= 0) & (rr < n) & (cc >= 0) & (cc < m)
        line_arr = np.zeros((n, m), dtype=bool)
        line_arr[rr[inside], cc[inside]] = val[inside] > 0

        # NOTE(review): find_first_marked_point returns the first crossing in
        # row-major image order, not the one nearest the centroid along the
        # ray, so which crossing is used depends on the ray direction (e.g.
        # the farthest one for rays pointing up) -- confirm this is intended.
        p = find_first_marked_point(edges & line_arr)
        if p is not None:
            if debug_plot:
                plt.plot([p[0]], [p[1]], "r+")
            dx = p[0] - sx
            dy = p[1] - sy
            dist_vector[i] = math.sqrt(dx * dx + dy * dy)
    return dist_vector


def calculate_shape_vector(
        path,
        nsampling,
        median_filter_size = 10,
        sweep_line_length = 1024.0,
        min_region_size = 500,
        min_outer_region_size = 0,
        debug_plot = False):
    """Compute the shape signature of the dark object in the image at *path*.

    The image is converted to grayscale, median filtered and segmented with
    two Otsu thresholds. The outline of the dark mask is found with Canny.
    The distance from the outline's centroid to an edge crossing is then
    sampled along *nsampling* equally spaced directions.

    Args:
        path: image file to read.
        nsampling: number of sweep directions.
        median_filter_size: ``size`` passed to ``ndimage.median_filter``.
        sweep_line_length: length, in pixels, of each sweep ray.
        min_region_size: lighter sub-regions with fewer pixels are ignored
            when computing region centres (default 500).
        min_outer_region_size: connected components of the dark mask with
            fewer pixels are dropped before edge detection (default 0 keeps
            all of them).
        debug_plot: when True, draw the edges, centroid and crossings with
            matplotlib and call ``plt.show()``.

    Returns:
        ``(mask, dist_vector / min(dist_vector), centers2)``: the boolean
        dark-region mask, the distance vector normalised by its minimum and
        the relative centres of the large lighter sub-regions.

    Raises:
        ZeroDivisionError: if no edge pixel is found.
    """
    img = _read_image(path)
    print("Image read")

    img_gray_not_filtered = grayscale_luminosity(img)
    img_gray = ndimage.median_filter(img_gray_not_filtered,
                                     size=median_filter_size)

    mask, sub_mask2 = _segment_dark_regions(img_gray)
    centers2 = _region_centers(sub_mask2, min_region_size)
    edges2 = _outline_edges(mask, min_outer_region_size)

    # n counts rows (height) and m counts columns (width).
    (n, m) = edges2.shape
    print("Width: {}, height: {}".format(m, n))

    sx, sy = _edge_centroid(edges2)
    print("Middle ({}x{})".format(sx, sy))

    if debug_plot:
        plt.figure(figsize=(15, 15))
        plt.imshow(edges2, cmap="Greys_r")
        plt.plot([sx], [sy], "r+")

    dist_vector = _radial_distances(edges2, sx, sy, nsampling,
                                    sweep_line_length, debug_plot)

    if debug_plot:
        plt.show()

    # NOTE(review): yields inf/nan entries if some crossing lies exactly on
    # the centroid (minimum distance 0).
    return (mask, dist_vector / np.min(dist_vector), centers2)

def compare_sets(c1, c2):
    """Root-mean-square nearest-neighbour distance from the points of c2 to c1.

    For every row of ``c2`` the distance to the closest row of ``c1`` is
    found with a KD-tree; the result is ``sqrt(mean(distance ** 2))``. The
    measure is not symmetric in its arguments.

    Returns ``nan`` when ``c2`` has no points, since the mean is undefined.
    """
    c2 = np.asarray(c2, dtype=float)
    (n, _) = c2.shape
    if n == 0:
        return float("nan")
    tree = spatial.KDTree(c1)
    # query() on a 2-D array answers all points in one call.
    distances, _ = tree.query(c2)
    return math.sqrt(float(np.mean(distances * distances)))


def main(argv=None):
    """Compute and compare the shape vectors of the images given on the command line.

    ``argv`` defaults to ``sys.argv``: ``argv[1]`` is the first image and
    the optional ``argv[2]`` the image to compare it with. Returns the
    process exit status.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        print("Too small arguments")
        return 1

    (_, v1, centers1) = calculate_shape_vector(argv[1], nsampling)
    if len(argv) < 3:
        print("To small arguments to compare")
        return 0

    (_, v2, centers2) = calculate_shape_vector(argv[2], nsampling)

    print("v1")
    print(v1)
    print("v2")
    print(v2)
    diff = v1 - v2
    print("Diff: {}".format(diff))
    print("MSE: {}".format(MSE(v1, v2)))

    print("Diff between sets of points: {}"
          .format(compare_sets(centers1, centers2)))
    return 0


# Run the comparison only when executed as a script, so that importing this
# module does not read images or call sys.exit.
if __name__ == "__main__":
    sys.exit(main())