Skip to content

Commit

Permalink
handle overlapping points during removal (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrhooray authored Nov 30, 2024
1 parent 3164525 commit 965a9b1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/kdtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,15 @@ impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::Pa
let mut removed = 0;
self.check_point(point.as_ref())?;
if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
while let Some(p_index) = points.iter().position(|x| x == point) {
if &bucket[p_index] == data {
points.remove(p_index);
bucket.remove(p_index);
removed += 1;
self.size -= 1;
}
while let Some(p_index) = points
.iter()
.zip(bucket.iter())
.position(|(p, d)| p == point && d == data)
{
points.remove(p_index);
bucket.remove(p_index);
removed += 1;
self.size -= 1;
}
self.points = Some(points);
self.bucket = Some(bucket);
Expand Down
18 changes: 18 additions & 0 deletions tests/kdtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,21 @@ fn handles_remove_no_match() {
vec![(16.0, &4), (36.0, &3)]
);
}

#[test]
fn handles_remove_overlapping_points() {
let a = ([0f64, 0f64], 0);
let b = ([0f64, 0f64], 1);
let mut kdtree = KdTree::new(2);

kdtree.add(a.0, a.1).unwrap();
kdtree.add(b.0, b.1).unwrap();

let num_removed = kdtree.remove(&[0f64, 0f64], &1).unwrap();
assert_eq!(kdtree.size(), 1);
assert_eq!(num_removed, 1);
assert_eq!(
kdtree.nearest(&[0f64, 0f64], 1, &squared_euclidean).unwrap(),
vec![(0.0, &0)]
);
}

0 comments on commit 965a9b1

Please sign in to comment.