Skip to content

Commit

Permalink
Addition binop support
Browse files Browse the repository at this point in the history
  • Loading branch information
averyanalex committed Dec 18, 2023
1 parent 6fcdaff commit 6903aa0
Show file tree
Hide file tree
Showing 9 changed files with 274 additions and 168 deletions.
22 changes: 11 additions & 11 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions-rust-lang/setup-rust-toolchain@v1
with:
toolchain: nightly
components: rustfmt
- name: Run tests
run: cargo test --release --verbose
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
- uses: actions/checkout@v4
- uses: actions-rust-lang/setup-rust-toolchain@v1
- uses: Swatinem/rust-cache@v2
- name: Install native libraries
run: sudo apt-get install coinor-cbc coinor-libcbc-dev
- name: Run tests
run: cargo test --release
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
4 changes: 2 additions & 2 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
system: let
overlays = [(import rust-overlay)];
pkgs = import nixpkgs {inherit system overlays;};
rustNightly = pkgs.rust-bin.nightly.latest.default;
rust = pkgs.rust-bin.fromRustupToolchainFile ./rust-toolchain.toml;
in {
devShells = {
default = pkgs.mkShell {
buildInputs = [rustNightly pkgs.cbc pkgs.cmake];
buildInputs = [rust pkgs.cbc pkgs.cmake];
LIBCLANG_PATH = "${pkgs.llvmPackages.libclang.lib}/lib";
shellHook = with pkgs; ''
export BINDGEN_EXTRA_CLANG_ARGS="$(< ${stdenv.cc}/nix-support/libc-crt1-cflags) \
Expand Down
2 changes: 2 additions & 0 deletions rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"
2 changes: 0 additions & 2 deletions src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ define_language! {
"/" = Div([Id; 2]),
"==" = Eq([Id; 2]),
"?" = Ternary([Id; 3]),
"i" = Index([Id; 2]),
"ir" = IndexRange([Id; 3]),
Constant(BigUint),
Argument(ArgumentInfo),
}
Expand Down
117 changes: 0 additions & 117 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::fmt::Display;

use rustc_hash::{FxHashMap, FxHashSet};

#[derive(Debug)]
Expand All @@ -8,34 +6,6 @@ pub struct GateX {
target: Qubit,
}

impl GateX {
pub fn format_qasm(
&self,
_f: &mut std::fmt::Formatter<'_>,
_map: &FxHashMap<u32, QubitDesc>,
) -> std::fmt::Result {
Ok(())
// match self {
// QGate::X(arg) => f.write_fmt(format_args!("x {}", map[arg])),
// QGate::CX { control, target } => {
// f.write_fmt(format_args!("cx {}, {}", map[control], map[target]))
// }
// QGate::Toffoli {
// control_0,
// control_1,
// target,
// } => f.write_fmt(format_args!(
// "ccx {}, {}, {}",
// map[control_0], map[control_1], map[target]
// )),
// QGate::MultiCX {
// controls: _,
// target: _,
// } => f.write_fmt(format_args!("mcx CRINGE FIXME")),
// }
}
}

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub struct Qubit(pub u32);

Expand All @@ -52,32 +22,15 @@ pub enum QubitRegister {
Argument(String),
}

impl Display for QubitRegister {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
QubitRegister::Ancillary => f.write_str("ancilla"),
QubitRegister::Result => f.write_str("result"),
QubitRegister::Argument(arg) => write!(f, "{arg}"),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QubitDesc {
pub reg: QubitRegister,
pub index: u32,
}

impl Display for QubitDesc {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}[{}]", self.reg, self.index))
}
}

#[derive(Debug, Default)]
pub struct Circuit {
pub qubits_count: u32,
// free_ancillas: Vec<Qubit>,
pub gates: Vec<GateX>,
qubits_map: FxHashMap<Qubit, FxHashSet<QubitDesc>>,
}
Expand All @@ -87,11 +40,6 @@ impl Circuit {
self.qubits_map
.entry(qubit)
.and_modify(|set| {
// if description.reg == QubitRegister::Result {
// assert!(set
// .iter()
// .all(|desc| matches!(desc.reg, QubitRegister::Argument(_))));
// }
set.insert(description.clone());
})
.or_insert_with(|| {
Expand Down Expand Up @@ -128,9 +76,6 @@ impl Circuit {
pub fn execute(&self, args: &FxHashMap<String, Vec<bool>>) -> Vec<bool> {
let mut qubits = FxHashMap::default();

// dbg!(&self.qubits_map);

// let qubit_map = self.fill_qubit_map();
for (qubit, values) in &self.qubits_map {
for value in values {
if let QubitRegister::Argument(arg) = value.reg.clone() {
Expand All @@ -139,15 +84,11 @@ impl Circuit {
}
}

// dbg!(&qubits);
// dbg!(qubits.len());

for gate in &self.gates {
qubits.insert(
gate.target,
qubits.get(&gate.target).unwrap_or(&false)
^ gate.controls.iter().fold(true, |acc, (qubit, inverted)| {
// dbg!(&qubit);
acc & (qubits[qubit] ^ inverted)
}),
);
Expand Down Expand Up @@ -180,62 +121,4 @@ impl Circuit {

result
}

// fn fill_qubit_map(&self) -> FxHashMap<Qubit, QubitDesc> {
// let mut map = self.qubits_map.clone();
// let mut anc_index = 0;
// for id in 0..self.qubits_count {
// map.entry(Qubit(id)).or_insert_with(|| {
// let desc = QubitDesc {
// reg: QubitRegister::Ancillary,
// index: anc_index,
// };
// anc_index += 1;
// desc
// });
// }
// map
// }
}

// impl Display for Circuit {
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// let mut map = self.qubits_map.clone();
// let mut anc_index = 0;
// for id in 0..self.next_id {
// map.entry(id).or_insert(QubitDesc {
// reg: QubitRegister::Ancillary,
// index: anc_index,
// });
// anc_index += 1;
// }

// f.write_str("OPENQASM 3.0;\n")?;
// f.write_str("include \"stdgates.inc\";\n\n")?;

// f.write_str("def app")?;

// f.write_fmt(format_args!(" qubit[{anc_index}] ancilla"))?;
// f.write_fmt(format_args!(
// ", qubit[{}] result",
// map.iter()
// .filter(|(_k, v)| v.reg == QubitRegister::Result)
// .count()
// ))?;
// for (name, len) in &self.arguments {
// f.write_fmt(format_args!(", qubit[{}] {}", len, name))?;
// }
// f.write_str(" {\n")?;

// for gate in &self.gates {
// f.write_str(" ")?;
// gate.format_qasm(f, &map)?;
// f.write_str(";\n")?;
// }

// // f.write_str(" reset ancillas;\n")?;

// f.write_str("}\n")?;
// Ok(())
// }
// }
25 changes: 18 additions & 7 deletions src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use egg::{Id, Language, RecExpr};
use itertools::Itertools;
use petgraph::prelude::*;
use petgraph::{dot, prelude::*};
use rustc_hash::{FxHashMap, FxHashSet};

use crate::{
logic::Logic,
circuit::{Circuit, Qubit, QubitDesc, QubitRegister},
logic::Logic,
};

pub struct Compiler {
Expand Down Expand Up @@ -79,6 +79,8 @@ impl Compiler {
}
}

dbg!(dot::Dot::with_config(&graph, &[dot::Config::EdgeNoLabel]));

Self {
circuit,
graph,
Expand Down Expand Up @@ -160,7 +162,7 @@ impl Compiler {
mcx_sources.insert((source_qubit, false));
} else {
match graph[source].kind {
LogicNodeKind::And => collect_sources_of_and(and, graph, mcx_sources),
LogicNodeKind::And => collect_sources_of_and(source, graph, mcx_sources),
LogicNodeKind::Not => {
let arg_of_not = graph
.neighbors_directed(source, Direction::Incoming)
Expand Down Expand Up @@ -236,10 +238,19 @@ impl Compiler {
LogicNodeKind::And => {
self.construct_mcx(target, target_qubit);
}
LogicNodeKind::Not => {
self.circuit
.cx(self.graph[source].qubit.unwrap(), true, target_qubit);
}
LogicNodeKind::Not => match self.graph[source].qubit {
Some(qubit) => {
self.circuit.cx(qubit, true, target_qubit);
}
None => {
if matches!(self.graph[source].kind, LogicNodeKind::And) {
self.construct_mcx(target, target_qubit);
self.circuit.x(target_qubit);
} else {
todo!();
}
}
},
LogicNodeKind::Arg => todo!(),
LogicNodeKind::Register => todo!(),
LogicNodeKind::Constant(_) => todo!(),
Expand Down
57 changes: 48 additions & 9 deletions src/executor.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
// use std::{ops::Sub, rc::Rc};

use egg::{Id, Language, RecExpr};
// use num::{One, Zero};
use itertools::Itertools;
use num::{BigUint, FromPrimitive, One, Zero};
use rustc_hash::FxHashMap;

use crate::logic::Logic;
use crate::{builder::Op, logic::Logic};

pub fn execute_gates(logic: &RecExpr<Logic>, args: &FxHashMap<String, Vec<bool>>) -> Vec<bool>
// where
// T: Zero + One + Sub<Output = T>,
// T: Clone,
{
pub fn execute_logic(logic: &RecExpr<Logic>, args: &FxHashMap<String, Vec<bool>>) -> Vec<bool> {
let mut done: FxHashMap<Id, bool> = FxHashMap::default();

for (idx, op) in logic.as_ref().iter().enumerate() {
Expand All @@ -37,3 +32,47 @@ pub fn execute_gates(logic: &RecExpr<Logic>, args: &FxHashMap<String, Vec<bool>>
.map(|c| done[c])
.collect()
}

pub fn execute_op(op: &RecExpr<Op>, args: &FxHashMap<String, BigUint>) -> BigUint {
let mut done: FxHashMap<Id, BigUint> = FxHashMap::default();

for (idx, op) in op.as_ref().iter().enumerate() {
let result = match op {
Op::Not(arg) => {
let digits = done[arg].iter_u64_digits().map(|d| !d).collect_vec();
assert_eq!(digits.len(), 1);
BigUint::from_u64(digits[0]).unwrap()
}
Op::Xor([a, b]) => done[a].clone() ^ done[b].clone(),
Op::Or([a, b]) => done[a].clone() | done[b].clone(),
Op::And([a, b]) => done[a].clone() & done[b].clone(),
Op::Shr([a, b]) => done[a].clone() >> u128::try_from(done[b].clone()).unwrap(),
Op::Shl([a, b]) => done[a].clone() << u128::try_from(done[b].clone()).unwrap(),
Op::Add([a, b]) => done[a].clone() + done[b].clone(),
Op::Sub([a, b]) => done[a].clone() - done[b].clone(),
Op::Mul([a, b]) => done[a].clone() * done[b].clone(),
Op::Div([a, b]) => done[a].clone() / done[b].clone(),
Op::Eq([a, b]) => {
if done[a] == done[b] {
BigUint::one()
} else {
BigUint::zero()
}
}
Op::Ternary([cond, then, or]) => {
if done[cond].is_one() {
done[then].clone()
} else if done[cond].is_zero() {
done[or].clone()
} else {
panic!("expected condition to be 1 or 0, got {}", done[cond])
}
}
Op::Constant(c) => c.clone(),
Op::Argument(arg) => args[&arg.name].clone(),
};
done.insert(Id::from(idx), result);
}

done[&Id::from(op.as_ref().len() - 1)].clone()
}
Loading

0 comments on commit 6903aa0

Please sign in to comment.