Skip to content

Instantly share code, notes, and snippets.

@beginor
Created July 28, 2014 23:45
Show Gist options
  • Select an option

  • Save beginor/32dce6904a556474e7ad to your computer and use it in GitHub Desktop.

Select an option

Save beginor/32dce6904a556474e7ad to your computer and use it in GitHub Desktop.

Revisions

  1. @agjacome agjacome created this gist Mar 26, 2013.
    256 changes: 256 additions & 0 deletions KdTree.java
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,256 @@
    /****************************************************************************
    * Author: Alberto Gutiérrez Jácome <[email protected]>
    * Date: 16/09/2012
    *
    * Compilation: javac KdTree.java
    * Execution: not applicable
    * Dependencies: Point2D.java RectHV.java StdDraw.java Queue.java
    *
    * Description: A mutable data type that uses a 2d-tree to represent a set of
    * points in the unit square. A 2d-tree is a generalization of a BST to
    * two-dimensional keys. The idea is to build a BST with points in the nodes,
    * using the x and y coordinates of the points as keys in strictly alternating
    * sequence. The prime advantage of a 2d-tree over a BST is that it supports
    * efficient implementation of range search and nearest neighbor search. Each
    * node corresponds to an axis-aligned rectangle in the unit square, which
    * encloses all of the points in its subtree. The root corresponds to the unit
    * square; the left and right children of the root corresponds to the two
    * rectangles split by the x-coordinate of the point at the root; and so
    * forth.
    *
    ***************************************************************************/

    public class KdTree
    {

    // helper data type representing a node of a kd-tree
    private static class KdNode
    {
    private KdNode left;
    private KdNode right;
    private final boolean vertical;
    private final double x;
    private final double y;

    public KdNode(final double x, final double y, final KdNode l,
    final KdNode r, final boolean v)
    {
    this.x = x;
    this.y = y;
    left = l;
    right = r;
    vertical = v;
    }
    }

    private static final RectHV CONTAINER = new RectHV(0, 0, 1, 1);
    private KdNode root;
    private int size;

    // construct an empty tree of points
    public KdTree()
    {
    size = 0;
    root = null;
    }

    // does the tree contain the point p?
    public boolean contains(final Point2D p)
    {
    return contains(root, p.x(), p.y());
    }

    // helper: does the subtree rooted at node contain (x, y)?
    private boolean contains(KdNode node, double x, double y)
    {
    if (node == null) return false;
    if (node.x == x && node.y == y) return true;

    if (node.vertical && x < node.x || !node.vertical && y < node.y)
    return contains(node.left, x, y);
    else
    return contains(node.right, x, y);
    }

    // draw all of the points to standard draw
    public void draw()
    {
    StdDraw.setScale(0, 1);

    StdDraw.setPenColor(StdDraw.BLACK);
    StdDraw.setPenRadius();
    CONTAINER.draw();

    draw(root, CONTAINER);
    }

    // helper: draw node point and its division line (given by rect)
    private void draw(final KdNode node, final RectHV rect)
    {
    if (node == null) return;

    // draw the point
    StdDraw.setPenColor(StdDraw.BLACK);
    StdDraw.setPenRadius(0.01);
    new Point2D(node.x, node.y).draw();

    // get the min and max points of division line
    Point2D min, max;
    if (node.vertical) {
    StdDraw.setPenColor(StdDraw.RED);
    min = new Point2D(node.x, rect.ymin());
    max = new Point2D(node.x, rect.ymax());
    } else {
    StdDraw.setPenColor(StdDraw.BLUE);
    min = new Point2D(rect.xmin(), node.y);
    max = new Point2D(rect.xmax(), node.y);
    }

    // draw that division line
    StdDraw.setPenRadius();
    min.drawTo(max);

    // recursively draw children
    draw(node.left, leftRect(rect, node));
    draw(node.right, rightRect(rect, node));
    }

    // helper: add point p to subtree rooted at node
    private KdNode insert(final KdNode node, final Point2D p,
    final boolean vertical)
    {
    // if new node, create it
    if (node == null) {
    size++;
    return new KdNode(p.x(), p.y(), null, null, vertical);
    }

    // if already in, return it
    if (node.x == p.x() && node.y == p.y()) return node;

    // else, insert it where corresponds (left - right recursive call)
    if (node.vertical && p.x() < node.x || !node.vertical && p.y() < node.y)
    node.left = insert(node.left, p, !node.vertical);
    else
    node.right = insert(node.right, p, !node.vertical);

    return node;
    }

    // add the point p to the tree (if it is not already in the tree)
    public void insert(final Point2D p)
    {
    root = insert(root, p, true);
    }

    // is the tree empty?
    public boolean isEmpty()
    {
    return size == 0;
    }

    // helper: get the left rectangle of node inside parent's rect
    private RectHV leftRect(final RectHV rect, final KdNode node)
    {
    if (node.vertical)
    return new RectHV(rect.xmin(), rect.ymin(), node.x, rect.ymax());
    else
    return new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), node.y);
    }

    // helper: nearest neighbor of (x,y) in subtree rooted at node
    private Point2D nearest(final KdNode node, final RectHV rect,
    final double x, final double y, final Point2D candidate)
    {
    if (node == null) return candidate;

    double dqn = 0.0;
    double drq = 0.0;
    RectHV left = null;
    RectHV rigt = null;
    final Point2D query = new Point2D(x, y);
    Point2D nearest = candidate;

    if (nearest != null) {
    dqn = query.distanceSquaredTo(nearest);
    drq = rect.distanceSquaredTo(query);
    }

    if (nearest == null || dqn > drq) {
    final Point2D point = new Point2D(node.x, node.y);
    if (nearest == null || dqn > query.distanceSquaredTo(point))
    nearest = point;

    if (node.vertical) {
    left = new RectHV(rect.xmin(), rect.ymin(), node.x, rect.ymax());
    rigt = new RectHV(node.x, rect.ymin(), rect.xmax(), rect.ymax());

    if (x < node.x) {
    nearest = nearest(node.left, left, x, y, nearest);
    nearest = nearest(node.right, rigt, x, y, nearest);
    } else {
    nearest = nearest(node.right, rigt, x, y, nearest);
    nearest = nearest(node.left, left, x, y, nearest);
    }
    } else {
    left = new RectHV(rect.xmin(), rect.ymin(), rect.xmax(), node.y);
    rigt = new RectHV(rect.xmin(), node.y, rect.xmax(), rect.ymax());

    if (y < node.y) {
    nearest = nearest(node.left, left, x, y, nearest);
    nearest = nearest(node.right, rigt, x, y, nearest);
    } else {
    nearest = nearest(node.right, rigt, x, y, nearest);
    nearest = nearest(node.left, left, x, y, nearest);
    }
    }
    }

    return nearest;
    }

    // a nearest neighbor in the set to p; null if set is empty
    public Point2D nearest(final Point2D p)
    {
    return nearest(root, CONTAINER, p.x(), p.y(), null);
    }

    // helper: points in subtree rooted at node inside rect
    private void range(final KdNode node, final RectHV nrect,
    final RectHV rect, final Queue<Point2D> queue)
    {
    if (node == null) return;

    if (rect.intersects(nrect)) {
    final Point2D p = new Point2D(node.x, node.y);
    if (rect.contains(p)) queue.enqueue(p);
    range(node.left, leftRect(nrect, node), rect, queue);
    range(node.right, rightRect(nrect, node), rect, queue);
    }
    }

    // all points in the set that are inside the rectangle
    public Iterable<Point2D> range(final RectHV rect)
    {
    final Queue<Point2D> queue = new Queue<Point2D>();
    range(root, CONTAINER, rect, queue);

    return queue;
    }

    // helper: get the right rectangle of node inside parent's rect
    private RectHV rightRect(final RectHV rect, final KdNode node)
    {
    if (node.vertical)
    return new RectHV(node.x, rect.ymin(), rect.xmax(), rect.ymax());
    else
    return new RectHV(rect.xmin(), node.y, rect.xmax(), rect.ymax());
    }

    // number of points in the tree
    public int size()
    {
    return size;
    }

    }