import numpy as np
from numba import njit, prange

# Numba support for object pointers is currently (Q4 2019) wonky,
# which is why plain arrays with indices are used instead.

@njit("i8(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], i8[:], i8)", cache=True, nogil=True)
def _make_tree(
    i0_inds,
    i1_inds,
    less_inds,
    more_inds,
    split_dims,
    bounds,
    split_values,
    points,
    indices,
    min_leaf_size,
):
    """Build a k-d tree over ``points`` into preallocated per-node arrays.

    ``points`` and ``indices`` are permuted in place (with identical swaps) so
    that every node owns the contiguous row range
    ``[i0_inds[node], i1_inds[node])`` of ``points``. Node 0 is the root.

    Parameters
    ----------
    i0_inds, i1_inds: int64 arrays
        Output. First (inclusive) and last (exclusive) row of each node.
    less_inds, more_inds: int64 arrays
        Output. Child node numbers of each inner node, ``-1`` for leaves.
    split_dims: int64 array
        Output. Dimension each inner node is split on, ``-1`` for leaves.
    bounds: float32 array, indexed ``[node, 0 or 1, dimension]``
        Output. Per-node lower (``0``) and upper (``1``) bounding box corners.
    split_values: float32 array
        Output. Split threshold of each inner node, ``0.0`` for leaves.
    points: float32 array, one row per point
        Input, reordered in place.
    indices: int64 array
        Reordered in place exactly like the rows of ``points``.
    min_leaf_size: int
        A node is split only if it holds more than ``2 * min_leaf_size``
        points, and each child receives at least ``min_leaf_size`` of them.
        Values below 1 are not handled here.

    Returns
    -------
    int
        Number of nodes written; 0 if ``points`` has no rows.

    Notes
    -----
    The capacity of the node arrays is not checked; the caller has to
    allocate them large enough for all nodes.
    """
    n_points = points.shape[0]
    dimension = points.shape[1]

    # Without any point there is no root node; reading points[0] below would
    # be an out-of-bounds access.
    if n_points == 0:
        return 0

    # Each pending node occupies three slots (i0, i1, i_node). The initial
    # capacity covers balanced trees; it is grown on demand further below
    # because clustered or duplicated points can produce very deep trees.
    stack = np.empty(1000, np.int64)
    stack_size = 0
    n_nodes = 0
    # min_leaf_size <= leaf_node_size <= max_leaf_size
    max_leaf_size = 2 * min_leaf_size

    # Push i0, i1, i_node of the root
    stack[stack_size] = 0
    stack_size += 1
    stack[stack_size] = n_points
    stack_size += 1
    stack[stack_size] = n_nodes
    n_nodes += 1
    stack_size += 1

    # While there are more tree nodes to process
    while stack_size > 0:
        # Pop in reverse push order
        stack_size -= 1
        i_node = stack[stack_size]
        stack_size -= 1
        i1 = stack[stack_size]
        stack_size -= 1
        i0 = stack[stack_size]

        # Views into the bounds array, writes go straight to bounds[i_node]
        lo = bounds[i_node, 0]
        hi = bounds[i_node, 1]

        for d in range(dimension):
            lo[d] = points[i0, d]
            hi[d] = points[i0, d]

        # Find lower and upper bounds of points for each dimension
        for i in range(i0 + 1, i1):
            for d in range(dimension):
                lo[d] = min(lo[d], points[i, d])
                hi[d] = max(hi[d], points[i, d])

        # Done if node is small
        if i1 - i0 <= max_leaf_size:
            i0_inds[i_node] = i0
            i1_inds[i_node] = i1
            less_inds[i_node] = -1
            more_inds[i_node] = -1
            split_dims[i_node] = -1
            split_values[i_node] = 0.0
        else:
            # Split on largest dimension, in the middle of its extent
            lengths = hi - lo
            split_dim = np.argmax(lengths)
            split_value = lo[split_dim] + 0.5 * lengths[split_dim]

            # Partition i0:i1 range into points where points[i, split_dim] < split_value
            i = i0
            j = i1 - 1

            while i < j:
                while i < j and points[i, split_dim] < split_value:
                    i += 1
                while i < j and points[j, split_dim] >= split_value:
                    j -= 1

                # Swap points
                if i < j:
                    for d in range(dimension):
                        temp = points[i, d]
                        points[i, d] = points[j, d]
                        points[j, d] = temp

                    # Keep the index array in sync with the point rows
                    temp_index = indices[i]
                    indices[i] = indices[j]
                    indices[j] = temp_index

            # i == j here; decide on which side the last point belongs
            if points[i, split_dim] < split_value:
                i += 1

            i_split = i

            # Now it holds that:
            # for i in range(i0, i_split): assert(points[i, split_dim] < split_value)
            # for i in range(i_split, i1): assert(points[i, split_dim] >= split_value)

            # Ensure that each child gets at least min_leaf_size points.
            # After clamping, the two assertions above may no longer hold.
            i_split = max(i0 + min_leaf_size, min(i1 - min_leaf_size, i_split))

            less = n_nodes
            n_nodes += 1
            more = n_nodes
            n_nodes += 1

            # Make room for the two pushes below (six slots). Numba does not
            # check array bounds by default, so an overflow would silently
            # corrupt memory.
            if stack_size + 6 > stack.shape[0]:
                grown = np.empty(2 * stack.shape[0], np.int64)
                for s in range(stack_size):
                    grown[s] = stack[s]
                stack = grown

            # push i0, i_split, less
            stack[stack_size] = i0
            stack_size += 1
            stack[stack_size] = i_split
            stack_size += 1
            stack[stack_size] = less
            stack_size += 1

            # push i_split, i1, more
            stack[stack_size] = i_split
            stack_size += 1
            stack[stack_size] = i1
            stack_size += 1
            stack[stack_size] = more
            stack_size += 1

            i0_inds[i_node] = i0
            i1_inds[i_node] = i1
            less_inds[i_node] = less
            more_inds[i_node] = more
            split_dims[i_node] = split_dim
            split_values[i_node] = split_value

    return n_nodes


@njit("void(i8[:], i8[:], i8[:], i8[:], i8[:], f4[:, :, :], f4[:], f4[:, :], f4[:, :], i8[:, :], f4[:, :], i8)", cache=True, nogil=True, parallel=True)
def _find_knn(
    i0_inds,
    i1_inds,
    less_inds,
    more_inds,
    split_dims,
    bounds,
    split_values,
    points,
    query_points,
    out_indices,
    out_distances,
    k,
):
    """Find the ``k`` nearest rows of ``points`` for every query point.

    The first seven arguments describe the tree: per node the row range
    ``[i0_inds[node], i1_inds[node])`` of ``points``, the child node numbers,
    the split dimension (``-1`` marks a leaf), the bounding box and the split
    value. Node 0 is the root. Query points are processed in a ``prange`` loop.

    Parameters
    ----------
    points: float32 array, one row per data point
        Data points in the order the tree's row ranges refer to.
    query_points: float32 array, one row per query
        Only the first ``points.shape[1]`` entries of each row are read.
    out_indices: int64 array, one row of length ``k`` per query
        Output. Row numbers into ``points`` of the neighbors found, ordered
        by increasing distance.
    out_distances: float32 array, one row of length ``k`` per query
        Output. Squared Euclidean distances matching ``out_indices``.
    k: int
        Number of neighbors to find. The code expects ``k >= 1``; with
        ``k == 0`` the pruning test would index an empty output row. If
        ``points`` has fewer than ``k`` rows, the trailing output entries are
        not written.
    """
    dimension = points.shape[1]

    # For each query point
    for i_query in prange(query_points.shape[0]):
        query_point = query_points[i_query]

        # Row views, writes go straight to the output arrays
        distances = out_distances[i_query]
        indices = out_indices[i_query]

        # Stack of node numbers still to visit. The initial capacity covers
        # balanced trees; it is grown on demand further below because the
        # tree can be very deep for clustered or duplicated points.
        stack = np.empty(1000, np.int64)

        n_neighbors = 0
        stack[0] = 0
        stack_size = 1

        # While there are nodes to visit
        while stack_size > 0:
            stack_size -= 1
            i_node = stack[stack_size]

            # Pruning is only possible once k candidates have been collected
            if n_neighbors >= k:
                # Calculate squared distance to bounding box of node
                dist = 0.0
                for d in range(dimension):
                    p = query_point[d]
                    # Clamp p into [lower bound, upper bound] of the box
                    dp = p - max(bounds[i_node, 0, d], min(bounds[i_node, 1, d], p))
                    dist += dp * dp

                # Do nothing with this node if all points we have found so far
                # are closer than the bounding box of the node.
                if dist > distances[n_neighbors - 1]:
                    continue

            # If leaf node
            if split_dims[i_node] == -1:
                # For each point in leaf node
                for i in range(i0_inds[i_node], i1_inds[i_node]):
                    # Calculate squared distance between query point and point in node
                    distance = 0.0
                    for d in range(dimension):
                        dd = query_point[d] - points[i, d]
                        distance += dd * dd

                    # Find insert position: smallest j with distances[j] > distance,
                    # or n_neighbors if there is none
                    insert_pos = n_neighbors
                    for j in range(n_neighbors - 1, -1, -1):
                        if distances[j] > distance:
                            insert_pos = j

                    # Insert found point in a sorted order
                    if insert_pos < k:
                        # Move [insert_pos:k-1] one to the right to make space,
                        # dropping the farthest candidate if k are stored already
                        for j in range(min(n_neighbors, k - 1), insert_pos, -1):
                            indices[j] = indices[j - 1]
                            distances[j] = distances[j - 1]

                        # Insert new neighbor
                        indices[insert_pos] = i
                        distances[insert_pos] = distance
                        n_neighbors = min(n_neighbors + 1, k)
            else:
                # Descend to child nodes
                less = less_inds[i_node]
                more = more_inds[i_node]
                split_dim = split_dims[i_node]

                # Make room for the two pushes below. Numba does not check
                # array bounds by default, so an overflow would silently
                # corrupt memory. Plain assignments and an explicit copy loop
                # are used to keep the prange body simple.
                if stack_size + 2 > stack.shape[0]:
                    grown = np.empty(2 * stack.shape[0], np.int64)
                    for s in range(stack_size):
                        grown[s] = stack[s]
                    stack = grown

                # The child pushed last is popped (visited) first
                if query_point[split_dim] < split_values[i_node]:
                    # Query lies on the "less" side: visit less, then more
                    stack[stack_size] = more
                    stack_size += 1
                    stack[stack_size] = less
                    stack_size += 1
                else:
                    # Query lies on the "more" side: visit more, then less
                    stack[stack_size] = less
                    stack_size += 1
                    stack[stack_size] = more
                    stack_size += 1

        # Workaround for https://github.com/numba/numba/issues/5156
        stack_size += 0

class KDTree:
    """KDTree implementation"""

    def __init__(self, data_points, min_leaf_size=8):
        """Constructs a KDTree for given data points. The implementation currently only supports data type `np.float32`.

        Parameters
        ----------
        data_points: numpy.ndarray (of type `np.float32`)
            Dataset with shape :math:`n \\times d`, where :math:`n` is the number of data points in the data set and :math:`d` is the dimension of each data point. The array is copied and not modified. An empty dataset (:math:`n = 0`) yields a tree without nodes.
        min_leaf_size: int
            Minimum number of points in a leaf, defaults to 8. Must be at least 1.

        Raises
        ------
        ValueError
            If `min_leaf_size` is smaller than 1

        Example
        -------
        >>> from pymatting import *
        >>> import numpy as np
        >>> data_set = np.random.randn(100, 2)
        >>> tree = KDTree(data_set.astype(np.float32))
        """
        assert data_points.dtype == np.float32

        # Zero would divide by zero below and could never terminate splitting
        if min_leaf_size < 1:
            raise ValueError(
                f"min_leaf_size must be at least 1, but got {min_leaf_size}."
            )

        n_data, dimension = data_points.shape

        # Upper bound on the node count: twice the number of leaves if every
        # leaf held only min_leaf_size points
        max_nodes = 2 * ((n_data + min_leaf_size - 1) // min_leaf_size)

        self.i0_inds = np.empty(max_nodes, np.int64)
        self.i1_inds = np.empty(max_nodes, np.int64)
        self.less_inds = np.empty(max_nodes, np.int64)
        self.more_inds = np.empty(max_nodes, np.int64)
        self.split_dims = np.empty(max_nodes, np.int64)
        self.bounds = np.empty((max_nodes, 2, dimension), np.float32)
        self.split_values = np.empty(max_nodes, np.float32)
        # Tree construction reorders the points, so work on a copy and keep
        # track of where each row came from
        self.shuffled_data_points = data_points.copy()
        self.shuffled_indices = np.arange(n_data, dtype=np.int64)

        if n_data == 0:
            # Nothing to build; the construction kernel would read the first
            # point of an empty array
            self.n_nodes = 0
        else:
            self.n_nodes = _make_tree(
                self.i0_inds,
                self.i1_inds,
                self.less_inds,
                self.more_inds,
                self.split_dims,
                self.bounds,
                self.split_values,
                self.shuffled_data_points,
                self.shuffled_indices,
                min_leaf_size,
            )

    def query(self, query_points, k):
        """Query the tree

        Parameters
        ----------
        query_points: numpy.ndarray (of type `np.float32`)
            Data points for which the next neighbours should be calculated, one per row. Rows must have at least as many entries as the data points; additional entries are ignored.
        k: int
            Number of neighbors to find

        Returns
        -------
        distances: numpy.ndarray
            Euclidean distances to the neighbors, in increasing order per query point
        indices: numpy.ndarray
            Indices of the k nearest neighbors in original data array

        Raises
        ------
        ValueError
            If `k` is negative or larger than the number of data points, or if the query points have fewer entries per row than the data points

        Example
        -------
        >>> from pymatting import *
        >>> import numpy as np
        >>> data_set = np.random.randn(100, 2)
        >>> tree = KDTree(data_set.astype(np.float32))
        >>> tree.query(np.array([[0.5,0.5]], dtype=np.float32), k=3)
        (array([[0.14234178, 0.15879704, 0.26760164]], dtype=float32), array([[29, 21, 20]]))
        """
        assert query_points.dtype == np.float32

        n_data, dimension = self.shuffled_data_points.shape

        if k < 0:
            raise ValueError(
                f"Number of neighbors requested must not be negative (k={k})."
            )

        if n_data < k:
            raise ValueError(
                f"Number of data points ({self.shuffled_data_points.shape[0]}) is less than the number of neighbors requested (k={k})."
                " Please provide a larger dataset or reduce the value of k."
            )

        # The search kernel reads `dimension` entries from every query row
        # without bounds checking
        if query_points.ndim == 2 and query_points.shape[1] < dimension:
            raise ValueError(
                f"Query points have {query_points.shape[1]} entries per row, but the data points have {dimension}."
            )

        n_query = query_points.shape[0]

        squared_distances = np.empty((n_query, k), np.float32)
        indices = np.empty((n_query, k), np.int64)

        # With k == 0 there is nothing to search for, and the kernel's pruning
        # test would index an empty output row
        if k > 0:
            _find_knn(
                self.i0_inds,
                self.i1_inds,
                self.less_inds,
                self.more_inds,
                self.split_dims,
                self.bounds,
                self.split_values,
                self.shuffled_data_points,
                query_points,
                indices,
                squared_distances,
                k,
            )

        # Map rows of the shuffled copy back to rows of the original data
        indices = self.shuffled_indices[indices]
        distances = np.sqrt(squared_distances)

        return distances, indices


def knn(data_points, query_points, k):
    """Find the k nearest neighbors of each query point in a data set.

    Convenience wrapper that builds a :class:`KDTree` with its default settings from `data_points` and queries it once. The implementation currently only supports data type `np.float32`.

    Parameters
    ----------
    data_points: numpy.ndarray (of type `np.float32`)
        Dataset with shape :math:`n \\times d`, where :math:`n` is the number of data points in the data set and :math:`d` is the dimension of each data point
    query_points: numpy.ndarray (of type `np.float32`)
        Data points for which the next neighbours should be calculated
    k: int
        Number of neighbors to find

    Returns
    -------
    distances: numpy.ndarray
        Distances to the neighbors
    indices: numpy.ndarray
        Indices of the k nearest neighbors in original data array

    Example
    -------
    >>> from pymatting import *
    >>> import numpy as np
    >>> data_set = np.random.randn(100, 2)
    >>> knn(data_set.astype(np.float32), np.array([[0.5,0.5]], dtype=np.float32), k=2)
    (array([[0.16233477, 0.25393516]], dtype=float32), array([[25, 17]]))
    """
    # The tree is only needed for this single query
    return KDTree(data_points).query(query_points, k)
