import matplotlib.pyplot as plt
import numpy as np

"""
Sample code for the K-Nearest Neighbors (KNN) algorithm. The sample uses three-dimensional points;
however, it works with more dimensions as well: change the data array (a 10-by-3 matrix) to have
any number of rows (points, instead of 10) and any number of columns (features per point).

The KNN function takes the array of training points, uses the first point as the test point,
and uses an arbitrary k=3. All of these may be changed to suit other cases; try, for example, k=4.

The end result is the k training points closest to the test point.
"""



def euclidean_distance(p1, p2):
    """Return the Euclidean (L2) distance between two points.

    Coordinates are taken position by position over the length of ``p1``;
    any extra coordinates in ``p2`` are ignored, and a ``p2`` shorter than
    ``p1`` raises ``IndexError``. Each coordinate is converted with
    ``float`` before subtracting. The result comes from ``np.sqrt``.
    """
    squared_sum = 0.0
    for idx, coord in enumerate(p1):
        diff = float(coord) - float(p2[idx])
        squared_sum += np.power(diff, 2)
    return np.sqrt(squared_sum)

def KNN(train, test, K):
    """Return the K training points closest to ``test``.

    :param train: The training points to search; iterated row by row
        (e.g. the rows of a 2-D numpy array).
    :param test:  The point we are finding neighbors for.
    :param K: The number of closest neighbors wanted.
    :return: A list of up to K training points, ordered from nearest to
        farthest; points at equal distance keep their order from ``train``.

    The Euclidean distance from ``test`` to every training point is computed
    and the points are sorted by that distance. If ``test`` itself occurs in
    ``train`` (a zero-distance match), one such occurrence is dropped so the
    query is not reported as its own neighbor. If fewer than K candidates
    remain, all of them are returned; a K of 0 or less yields an empty list.
    """
    # sorted() is stable, so ties keep their original training order.
    ranked = sorted(
        ((point, euclidean_distance(test, point)) for point in train),
        key=lambda pair: pair[1],
    )
    # Only discard the nearest entry when it is the query itself (distance 0).
    # Otherwise it is a genuine neighbor and must be kept, e.g. when ``test``
    # is not part of ``train``.
    if ranked and ranked[0][1] == 0:
        ranked = ranked[1:]
    # Slicing (rather than indexing K times) tolerates K > len(ranked).
    return [point for point, _ in ranked[:max(K, 0)]]

def main():
    """Run KNN on a small 3-D sample, using its first row as the query point.

    Prints each of the 3 nearest neighbors returned by ``KNN``, one per line.
    """
    sample_points = np.array([
        [6.0, 7.0, 1.0],
        [2.0, 3.0, 2.0],
        [3.0, 7.0, 3.0],
        [4.0, 4.0, 4.0],
        [5.0, 8.0, 5.0],
        [6.0, 5.0, 6.0],
        [7.0, 9.0, 7.0],
        [8.0, 5.0, 8.0],
        [8.0, 2.0, 9.0],
        [10.0, 2.0, 10.0],
    ])
    query = sample_points[0]
    k = 3
    closest = KNN(sample_points, query, k)
    for point in closest:
        print(point)


if __name__ == "__main__":
    main()