Skip to content

Commit

Permalink
Merge pull request #269 from adwhit/alex/memory-leak
Browse files Browse the repository at this point in the history
Fix memory leak
  • Loading branch information
charypar authored Oct 8, 2024
2 parents 18c1e36 + dab170a commit 9e249e5
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 49 deletions.
295 changes: 249 additions & 46 deletions crux_core/src/capability/executor.rs
Original file line number Diff line number Diff line change
@@ -1,96 +1,299 @@
use std::{
sync::{Arc, Mutex},
task::Context,
task::{Context, Wake},
};

use crossbeam_channel::{Receiver, Sender};
use futures::{
future,
task::{waker_ref, ArcWake},
Future, FutureExt,
};
use futures::{future, Future, FutureExt};
use slab::Slab;

type BoxFuture = future::BoxFuture<'static, ()>;

// used in docs/internals/runtime.md
// ANCHOR: executor
pub(crate) struct QueuingExecutor {
ready_queue: Receiver<Arc<Task>>,
spawn_queue: Receiver<BoxFuture>,
ready_queue: Receiver<TaskId>,
ready_sender: Sender<TaskId>,
tasks: Mutex<Slab<Option<BoxFuture>>>,
}
// ANCHOR_END: executor

// used in docs/internals/runtime.md
// ANCHOR: spawner
#[derive(Clone)]
pub struct Spawner {
task_sender: Sender<Arc<Task>>,
future_sender: Sender<BoxFuture>,
}
// ANCHOR_END: spawner

// used in docs/internals/runtime.md
// ANCHOR: task
struct Task {
future: Mutex<Option<future::BoxFuture<'static, ()>>>,
/// Identifier of a spawned task.
///
/// The wrapped `u32` is the task's index in the executor's task slab: it is
/// produced from the key returned by the slab on insertion and converted back
/// to `usize` whenever the slot is looked up.
#[derive(Clone, Copy, Debug)]
struct TaskId(u32);

impl std::ops::Deref for TaskId {
type Target = u32;

task_sender: Sender<Arc<Task>>,
fn deref(&self) -> &Self::Target {
&self.0
}
}
// ANCHOR_END: task

pub(crate) fn executor_and_spawner() -> (QueuingExecutor, Spawner) {
let (task_sender, ready_queue) = crossbeam_channel::unbounded();
let (future_sender, spawn_queue) = crossbeam_channel::unbounded();
let (ready_sender, ready_queue) = crossbeam_channel::unbounded();

(QueuingExecutor { ready_queue }, Spawner { task_sender })
(
QueuingExecutor {
ready_queue,
spawn_queue,
ready_sender,
tasks: Mutex::new(Slab::new()),
},
Spawner { future_sender },
)
}

// used in docs/internals/runtime.md
// ANCHOR: spawning
impl Spawner {
pub fn spawn(&self, future: impl Future<Output = ()> + 'static + Send) {
let future = future.boxed();
let task = Arc::new(Task {
future: Mutex::new(Some(future)),
task_sender: self.task_sender.clone(),
});

self.task_sender
.send(task)
self.future_sender
.send(future)
.expect("unable to spawn an async task, task sender channel is disconnected.")
}
}
// ANCHOR_END: spawning

/// Waker state for a single task: waking sends `task_id` over `sender`.
#[derive(Clone)]
struct TaskWaker {
/// Id of the task this waker belongs to.
task_id: TaskId,
/// Channel the id is sent on when the task is woken. `run_task` fills this
/// with a clone of the executor's `ready_sender`.
sender: Sender<TaskId>,
}

// used in docs/internals/runtime.md
// ANCHOR: arc_wake
impl ArcWake for Task {
fn wake_by_ref(arc_self: &Arc<Self>) {
let cloned = arc_self.clone();
arc_self
.task_sender
.send(cloned)
.expect("unable to wake an async task, task sender channel is disconnected.")
// ANCHOR: wake
impl Wake for TaskWaker {
    /// Consuming wake-up; does exactly what `wake_by_ref` does.
    fn wake(self: Arc<Self>) {
        Self::wake_by_ref(&self);
    }

    /// Puts this waker's task id onto the ready channel.
    fn wake_by_ref(self: &Arc<Self>) {
        // Sending fails only when the receiving end is gone, i.e. the
        // executor has been dropped. There is nothing left to wake in that
        // case, so the error is deliberately discarded.
        self.sender.send(self.task_id).ok();
    }
}
// ANCHOR_END: arc_wake
// ANCHOR_END: wake

// used in docs/internals/runtime.md
// ANCHOR: run_all
impl QueuingExecutor {
pub fn run_all(&self) {
// While there are tasks to be processed
while let Ok(task) = self.ready_queue.try_recv() {
// Unlock the future in the Task
let mut future_slot = task.future.lock().unwrap();

// Take it, replace with None, ...
if let Some(mut future) = future_slot.take() {
let waker = waker_ref(&task);
let context = &mut Context::from_waker(&waker);

// ...and poll it
if future.as_mut().poll(context).is_pending() {
// If it's still pending, put it back
*future_slot = Some(future)
// We read from both queues and execute the tasks we receive.
// Since either queue can generate work for the other queue,
// we keep reading from them in a loop until we are sure both
// queues are exhausted.
let mut did_some_work = true;

while did_some_work {
did_some_work = false;
while let Ok(task) = self.spawn_queue.try_recv() {
let task_id = self
.tasks
.lock()
.expect("Task slab poisoned")
.insert(Some(task));
self.run_task(TaskId(task_id.try_into().expect("TaskId overflow")));
did_some_work = true;
}
while let Ok(task_id) = self.ready_queue.try_recv() {
match self.run_task(task_id) {
RunTask::Unavailable => {
// We were unable to run the task as it is (presumably) being run on
// another thread. We re-queue the task for 'later' and do NOT set
// `did_some_work = true`. That way we will keep looping and doing work
// until all remaining work is 'unavailable', at which point we will bail
// out of the loop, leaving the queued work to be finished by another thread.
// This strategy should avoid dropping work or busy-looping
self.ready_sender.send(task_id).expect("could not requeue");
}
RunTask::Missing => {
// This is possible if a naughty future sends a wake notification while
// still running, then runs to completion and is evicted from the slab.
// Nothing to be done.
}
RunTask::Suspended | RunTask::Completed => did_some_work = true,
}
}
}
}

/// Polls the task stored under `task_id` exactly once and reports the outcome.
///
/// The future is taken out of its slab slot before polling and the slab mutex
/// is released while the poll runs, so other threads can keep using the slab.
/// The emptied slot stays allocated during the poll; that is how a concurrent
/// caller can tell "currently being polled" apart from "no such task".
///
/// Returns:
/// * `RunTask::Missing` - the slab has no slot for `task_id`.
/// * `RunTask::Unavailable` - the slot exists but holds no future right now.
/// * `RunTask::Suspended` - the future returned `Pending` and was put back.
/// * `RunTask::Completed` - the future finished and its slot was freed.
///
/// # Panics
///
/// Panics if the task slab mutex is poisoned, or if the slot has disappeared
/// by the time a still-pending future needs to be stored back.
fn run_task(&self, task_id: TaskId) -> RunTask {
    let slot_index = *task_id as usize;

    let mut tasks = self.tasks.lock().expect("Task slab poisoned");
    let Some(slot) = tasks.get_mut(slot_index) else {
        return RunTask::Missing;
    };
    let Some(mut future) = slot.take() else {
        // the slot exists but the task is missing - presumably it
        // is being executed on another thread
        return RunTask::Unavailable;
    };

    // free the mutex so other threads can make progress while we poll
    drop(tasks);

    // Waking this task simply sends its id back onto the ready queue.
    let waker = Arc::new(TaskWaker {
        task_id,
        sender: self.ready_sender.clone(),
    })
    .into();
    let context = &mut Context::from_waker(&waker);

    if future.as_mut().poll(context).is_pending() {
        // Still pending: return the future to its (still allocated) slot.
        self.tasks
            .lock()
            .expect("Task slab poisoned")
            .get_mut(slot_index)
            .expect("Task slot is missing")
            .replace(future);
        RunTask::Suspended
    } else {
        // Finished: free the slot so the slab can reuse it. Poisoning is
        // reported with the same message as every other lock in this file.
        self.tasks
            .lock()
            .expect("Task slab poisoned")
            .remove(slot_index);
        RunTask::Completed
    }
}
}

/// Outcome of a single `run_task` call.
enum RunTask {
/// The slab has no slot for the requested task id.
Missing,
/// The slot exists but its future has been taken out - presumably because
/// the task is being polled on another thread.
Unavailable,
/// The future was polled, returned `Pending`, and was put back in its slot.
Suspended,
/// The future was polled to completion and its slot was removed.
Completed,
}

// ANCHOR_END: run_all

#[cfg(test)]
mod tests {

use rand::Rng;
use std::{
sync::atomic::{AtomicI32, Ordering},
task::Poll,
};

use super::*;
use crate::capability::shell_request::ShellRequest;

#[test]
fn test_task_does_not_leak() {
    // The strong count of an `Arc` is the number of live owners, so a clone
    // moved into the spawned future works as a liveness probe for the task.
    let probe = Arc::new(());
    assert_eq!(Arc::strong_count(&probe), 1);

    let (executor, spawner) = executor_and_spawner();

    let held_by_task = probe.clone();
    spawner.spawn(async move {
        assert_eq!(Arc::strong_count(&held_by_task), 2);
        // `new()` builds a request with no result set, so the task is
        // expected to park on this await.
        ShellRequest::<()>::new().await;
    });

    executor.run_all();

    // Once both halves are gone, nothing may keep the future (and with it
    // the second owner of the probe) alive.
    drop(executor);
    drop(spawner);
    assert_eq!(Arc::strong_count(&probe), 1);
}

// Stress test: many randomly self-waking futures are drained by several
// threads calling `run_all` on one shared executor. Afterwards both queues
// must be empty and every future must have resolved.
#[test]
fn test_multithreaded_executor() {
// We define a future which chaotically sends notifications to wake up the task.
// The future has a random chance to suspend or to defer to its children, which
// may also suspend. However, it will ultimately resolve to `Ready` and once it
// has done so will stay finished.
struct Chaotic {
// set the first time this node returns `Ready`; later polls short-circuit
ready_once: bool,
children: Vec<Chaotic>,
}

// Number of `Chaotic` nodes created and not yet resolved: incremented in the
// constructor, decremented when a node returns `Ready` for the first time.
static CHAOS_COUNT: AtomicI32 = AtomicI32::new(0);

impl Chaotic {
// Builds a node with `num_children` children, each of which is built with
// `num_children - 1` children of its own; the recursion ends at 0.
fn new_with_children(num_children: usize) -> Self {
CHAOS_COUNT.fetch_add(1, Ordering::SeqCst);
Self {
ready_once: false,
children: (0..num_children)
.map(|_| Chaotic::new_with_children(num_children - 1))
.collect(),
}
}
}

impl Future for Chaotic {
type Output = ();

fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// once we're done, we're done
if self.ready_once {
return Poll::Ready(());
}
// With probability 0.1, request a wake-up and suspend without polling
// any children.
if rand::thread_rng().gen_bool(0.1) {
cx.waker().wake_by_ref();
return Poll::Pending;
} else {
// Otherwise poll every child (no early exit) and become ready only
// when all of them are ready.
let mut ready = true;
let this = self.get_mut();
for child in &mut this.children {
if let Poll::Pending = child.poll_unpin(cx) {
ready = false;
}
}
if ready {
this.ready_once = true;
// throw a wake in for extra chaos
cx.waker().wake_by_ref();
CHAOS_COUNT.fetch_sub(1, Ordering::SeqCst);
Poll::Ready(())
} else {
// A pending child always traces back to a node that took the
// random branch above, which has already requested a wake-up.
Poll::Pending
}
}
}
}

let (executor, spawner) = executor_and_spawner();
// 100 futures, each a tree of 1957 nodes (root included), equals lots of chaos
for _ in 0..100 {
let future = Chaotic::new_with_children(6);
spawner.spawn(future);
}
// 100 trees * 1957 nodes
assert_eq!(CHAOS_COUNT.load(Ordering::SeqCst), 195700);
let executor = Arc::new(executor);
// nothing has run yet, so every spawned future is still in the spawn queue
assert_eq!(executor.spawn_queue.len(), 100);

// Spawn 10 threads and run all
let handles = (0..10)
.map(|_| {
let executor = executor.clone();
std::thread::spawn(move || {
executor.run_all();
})
})
.collect::<Vec<_>>();
for handle in handles {
handle.join().unwrap();
}
// nothing left in queue, all futures resolved
assert_eq!(executor.spawn_queue.len(), 0);
assert_eq!(executor.ready_queue.len(), 0);
assert_eq!(CHAOS_COUNT.load(Ordering::SeqCst), 0);
}
}
16 changes: 15 additions & 1 deletion crux_core/src/capability/shell_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@ pub struct ShellRequest<T> {
shared_state: Arc<Mutex<SharedState<T>>>,
}

#[cfg(test)]
impl ShellRequest<()> {
    /// Test-only constructor: a request whose shared state starts out empty
    /// (no `result`, no `waker`, no `send_request`).
    pub(crate) fn new() -> Self {
        let empty_state = SharedState {
            result: None,
            waker: None,
            send_request: None,
        };

        Self {
            shared_state: Arc::new(Mutex::new(empty_state)),
        }
    }
}

struct SharedState<T> {
result: Option<T>,
waker: Option<Waker>,
Expand All @@ -38,7 +51,8 @@ impl<T> Future for ShellRequest<T> {
match shared_state.result.take() {
Some(result) => Poll::Ready(result),
None => {
shared_state.waker = Some(cx.waker().clone());
let cloned_waker = cx.waker().clone();
shared_state.waker = Some(cloned_waker);
Poll::Pending
}
}
Expand Down
4 changes: 2 additions & 2 deletions docs/src/internals/runtime.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ The final piece of the puzzle is the executor itself:

This is the receiving end of the channel from the spawner.

The executor has a single method, `run_all`:
The executor has a single public method, `run_all`:

```rust,no_run,noplayground
{{#include ../../../crux_core/src/capability/executor.rs:run_all}}
Expand All @@ -180,7 +180,7 @@ call. The `waker_ref` creates a waker which, when asked to wake up, will call
this method on the task:

```rust,no_run,noplayground
{{#include ../../../crux_core/src/capability/executor.rs:arc_wake}}
{{#include ../../../crux_core/src/capability/executor.rs:wake}}
```

this is where the task resubmits itself for processing on the next run.
Expand Down

0 comments on commit 9e249e5

Please sign in to comment.