Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a few state-related cc ops #2354

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,68 @@ def cc_AddressOfOp : CCOp<"address_of", [Pure,
}];
}

def cc_CreateStateOp : CCOp<"create_state", [Pure] > {
let summary = "Create state from data";
let description = [{
This operation takes a pointer to state data and creates a quantum state.
The operation can be optimized away in DeleteStates pass, or replaced
by an intrinsic runtime call on simulators.

```mlir
%0 = cc.create_state %data, %length : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example only has 1 argument, but below it lists 2 arguments.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to update, thanks!

```
}];

let arguments = (ins
AnyPointerType:$data,
AnySignlessInteger:$length
);
let results = (outs AnyPointerType:$result);
let assemblyFormat = [{
$data `,` $length `:` functional-type(operands, results) attr-dict
}];
}

def cc_GetNumberOfQubitsOp : CCOp<"get_number_of_qubits", [Pure] > {
let summary = "Get number of qubits from a quantum state";
let description = [{
This operation takes a state pointer argument and returns the number of
qubits in the state. The operation can be optimized away in some passes
like ReplaceStateByKernel or DeleteStates, or replaced by an intrinsic
runtime call on simulators.

```mlir
%0 = cc.get_number_of_qubits %state : (!cc.ptr<!cc.state>) -> i64
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type on this example doesn't correspond to the assembly below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to update, thanks!

```
}];

let arguments = (ins cc_PointerType:$state);
let results = (outs AnySignlessInteger:$result);
let assemblyFormat = [{
$state `:` functional-type(operands, results) attr-dict
}];
}

def cc_GetStateOp : CCOp<"get_state", [Pure] > {
let summary = "Get state from kernel with the provided name.";
let description = [{
This operation is created by argument synthesis of state pointer arguments
for quantum devices. It takes a kernel name as ASCIIZ string literal value
and returns the kernel's quantum state. The operation is replaced by a call
to the kernel with the provided name in ReplaceStateByKernel pass.

```mlir
%0 = cc.get_state "callee" : !cc.ptr<!cc.state>
```
}];

let arguments = (ins StrAttr:$calleeName);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably take a !cc.ptr<?> argument and make use of the cc.literal_string operation. That would be more flexible in the long run.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good, will do

let results = (outs cc_PointerType:$result);
let assemblyFormat = [{
$calleeName `:` qualified(type(results)) attr-dict
}];
}

def cc_GlobalOp : CCOp<"global", [IsolatedFromAbove, Symbol]> {
let summary = "Create a global constant or variable";
let description = [{
Expand Down
5 changes: 2 additions & 3 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,8 @@ def DeleteStates : Pass<"delete-states", "mlir::ModuleOp"> {
func.func @foo() attributes {"cudaq-entrypoint", "cudaq-kernel", no_this} {
%c8_i64 = arith.constant 8 : i64
%0 = cc.address_of @foo.rodata_synth_0 : !cc.ptr<!cc.array<complex<f32> x 8>>
%3 = cc.cast %0 : (!cc.ptr<!cc.array<complex<f32> x 8>>) -> !cc.ptr<i8>
%4 = call @__nvqpp_cudaq_state_createFromData_fp32(%3, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
%5 = call @__nvqpp_cudaq_state_numberOfQubits(%4) : (!cc.ptr<!cc.state>) -> i64
%4 = cc.create_state %3, %c8_i64 : (!cc.ptr<!cc.array<complex<f32> x 8>>, i64) -> !cc.ptr<!cc.state>
%5 = cc.get_number_of_qubits %4 : (!cc.ptr<!cc.state>) -> i64
%6 = quake.alloca !quake.veq<?>[%5 : i64]
%7 = quake.init_state %6, %4 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>

Expand Down
13 changes: 3 additions & 10 deletions lib/Frontend/nvqpp/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2694,19 +2694,12 @@ bool QuakeBridgeVisitor::VisitCXXConstructExpr(clang::CXXConstructExpr *x) {
initials = load.getPtrvalue();
}
if (isStateType(initials.getType())) {
IRBuilder irBuilder(builder.getContext());
auto mod =
builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
auto result =
irBuilder.loadIntrinsic(mod, getNumQubitsFromCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");
Value state = initials;
auto i64Ty = builder.getI64Type();
auto numQubits = builder.create<func::CallOp>(
loc, i64Ty, getNumQubitsFromCudaqState, ValueRange{state});
auto numQubits =
builder.create<cudaq::cc::GetNumberOfQubitsOp>(loc, i64Ty, state);
auto veqTy = quake::VeqType::getUnsized(ctx);
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy,
numQubits.getResult(0));
Value alloc = builder.create<quake::AllocaOp>(loc, veqTy, numQubits);
return pushValue(builder.create<quake::InitializeStateOp>(
loc, veqTy, alloc, state));
}
Expand Down
68 changes: 67 additions & 1 deletion lib/Optimizer/CodeGen/QuakeToCodegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include "QuakeToCodegen.h"
#include "CodeGenOps.h"
#include "cudaq/Optimizer/Builder/Intrinsics.h"
#include "cudaq/Optimizer/CodeGen/Passes.h"
#include "cudaq/Optimizer/CodeGen/QIRFunctionNames.h"
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
Expand Down Expand Up @@ -62,10 +65,73 @@ class ExpandComplexCast : public OpRewritePattern<cudaq::cc::CastOp> {
return success();
}
};

/// Lower `cc.create_state` to a call to the runtime intrinsic that creates a
/// state from a data buffer.
///
/// The buffer operand is cast to `!cc.ptr<i8>` and passed, together with the
/// length operand, to the FP64 or FP32 variant of the intrinsic. The variant is
/// chosen from the element type of the array the buffer points to: `f64` or
/// `complex<f64>` selects FP64, every other element type selects FP32.
///
/// The pattern fails to match (leaving the op untouched) when the buffer is not
/// a `!cc.ptr<!cc.array<...>>`, since the element precision cannot be read off
/// any other type.
class CreateStateOpPattern : public OpRewritePattern<cudaq::cc::CreateStateOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(cudaq::cc::CreateStateOp createStateOp,
                                PatternRewriter &rewriter) const override {
    auto module = createStateOp->getParentOfType<ModuleOp>();
    auto loc = createStateOp.getLoc();
    auto ctx = createStateOp.getContext();
    auto buffer = createStateOp.getOperand(0);
    auto size = createStateOp.getOperand(1);

    // The op only constrains `$data` to be some pointer type, so validate the
    // shape this lowering relies on instead of asserting through `cast<>`.
    // All checks happen before any IR is touched.
    auto ptrTy = dyn_cast<cudaq::cc::PointerType>(buffer.getType());
    if (!ptrTy)
      return rewriter.notifyMatchFailure(createStateOp,
                                         "state data must be a cc pointer");
    auto arrTy = dyn_cast<cudaq::cc::ArrayType>(ptrTy.getElementType());
    if (!arrTy)
      return rewriter.notifyMatchFailure(
          createStateOp, "state data must point to a cc array");

    // Pick the intrinsic by precision: f64 or complex<f64> is 64-bit, anything
    // else falls back to the 32-bit variant.
    auto eleTy = arrTy.getElementType();
    bool is64Bit = isa<Float64Type>(eleTy);
    if (auto complexTy = dyn_cast<ComplexType>(eleTy))
      is64Bit = isa<Float64Type>(complexTy.getElementType());

    auto createStateFunc = is64Bit ? cudaq::createCudaqStateFromDataFP64
                                   : cudaq::createCudaqStateFromDataFP32;

    // Make sure the intrinsic is declared in the module before calling it.
    // `loaded` is only inspected by the assert, hence `maybe_unused` for
    // release builds.
    cudaq::IRBuilder irBuilder(ctx);
    [[maybe_unused]] auto loaded =
        irBuilder.loadIntrinsic(module, createStateFunc);
    assert(succeeded(loaded) && "loading intrinsic should never fail");

    auto stateTy = cudaq::cc::StateType::get(ctx);
    auto statePtrTy = cudaq::cc::PointerType::get(stateTy);
    auto i8PtrTy = cudaq::cc::PointerType::get(rewriter.getI8Type());

    // The intrinsic takes the data as an opaque byte pointer.
    Value rawData = rewriter.create<cudaq::cc::CastOp>(loc, i8PtrTy, buffer);

    rewriter.replaceOpWithNewOp<func::CallOp>(
        createStateOp, statePtrTy, createStateFunc, ValueRange{rawData, size});
    return success();
  }
};

/// Lower `cc.get_number_of_qubits` to a call to the runtime intrinsic named by
/// `cudaq::getNumQubitsFromCudaqState`, forwarding the state operand.
///
/// The generated call returns `i64`. The pattern fails to match (leaving the op
/// untouched) when the op's result is some other integer type, because the
/// call result could not replace it without a type mismatch.
class GetNumberOfQubitsOpPattern
    : public OpRewritePattern<cudaq::cc::GetNumberOfQubitsOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(cudaq::cc::GetNumberOfQubitsOp getNumQubitsOp,
                                PatternRewriter &rewriter) const override {
    // The op accepts any signless integer result, but the call below always
    // yields i64; check this before touching any IR.
    Type i64Ty = rewriter.getI64Type();
    if (getNumQubitsOp->getResult(0).getType() != i64Ty)
      return rewriter.notifyMatchFailure(
          getNumQubitsOp, "expected an i64 result for the number of qubits");

    auto module = getNumQubitsOp->getParentOfType<ModuleOp>();
    auto ctx = getNumQubitsOp.getContext();
    auto state = getNumQubitsOp.getOperand();

    // Make sure the intrinsic is declared in the module before calling it.
    // `loaded` is only inspected by the assert, hence `maybe_unused` for
    // release builds.
    cudaq::IRBuilder irBuilder(ctx);
    [[maybe_unused]] auto loaded =
        irBuilder.loadIntrinsic(module, cudaq::getNumQubitsFromCudaqState);
    assert(succeeded(loaded) && "loading intrinsic should never fail");

    rewriter.replaceOpWithNewOp<func::CallOp>(
        getNumQubitsOp, i64Ty, cudaq::getNumQubitsFromCudaqState, state);
    return success();
  }
};

} // namespace

void cudaq::codegen::populateQuakeToCodegenPatterns(
mlir::RewritePatternSet &patterns) {
auto *ctx = patterns.getContext();
patterns.insert<CodeGenRAIIPattern, ExpandComplexCast>(ctx);
patterns.insert<CodeGenRAIIPattern, ExpandComplexCast, CreateStateOpPattern,
GetNumberOfQubitsOpPattern>(ctx);
}
112 changes: 43 additions & 69 deletions lib/Optimizer/Transforms/DeleteStates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,104 +29,79 @@ namespace cudaq::opt {
using namespace mlir;

namespace {

static bool isCall(Operation *callOp, std::vector<const char *> &&names) {
if (callOp) {
if (auto createStateCall = dyn_cast<func::CallOp>(callOp)) {
if (auto calleeAttr = createStateCall.getCalleeAttr()) {
auto funcName = calleeAttr.getValue().str();
if (std::find(names.begin(), names.end(), funcName) != names.end())
return true;
}
}
}
return false;
}

static bool isCreateStateCall(Operation *callOp) {
return isCall(callOp, {cudaq::createCudaqStateFromDataFP64,
cudaq::createCudaqStateFromDataFP32});
}

static bool isNumberOfQubitsCall(Operation *callOp) {
return isCall(callOp, {cudaq::getNumQubitsFromCudaqState});
}

/// For a call to `__nvqpp_cudaq_state_createFromData_fpXX`, get the number of
/// qubits allocated.
static std::size_t getStateSize(Operation *callOp) {
if (isCreateStateCall(callOp)) {
if (auto createStateCall = dyn_cast<func::CallOp>(callOp)) {
auto sizeOperand = createStateCall.getOperand(1);
auto defOp = sizeOperand.getDefiningOp();
while (defOp && !dyn_cast<arith::ConstantIntOp>(defOp))
defOp = defOp->getOperand(0).getDefiningOp();
if (auto constOp = dyn_cast<arith::ConstantIntOp>(defOp))
return constOp.getValue().cast<IntegerAttr>().getInt();
}
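/// For a `cc::CreateStateOp`, get the size of the state data, i.e. the
/// constant value of its length operand. Emits an error and returns 0 when
/// `op` is not a `cc::CreateStateOp` with a constant length.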
static std::size_t getStateSize(Operation *op) {
if (auto createStateOp = dyn_cast<cudaq::cc::CreateStateOp>(op)) {
auto sizeOperand = createStateOp.getOperand(1);
auto defOp = sizeOperand.getDefiningOp();
while (defOp && !dyn_cast<arith::ConstantIntOp>(defOp))
defOp = defOp->getOperand(0).getDefiningOp();
if (auto constOp = dyn_cast<arith::ConstantIntOp>(defOp))
return constOp.getValue().cast<IntegerAttr>().getInt();
}
callOp->emitError("Cannot compute number of qubits");
op->emitError("Cannot compute number of qubits from createStateOp");
return 0;
}

// clang-format off
/// Remove `__nvqpp_cudaq_state_numberOfQubits` calls.
/// Replace `cc.get_number_of_qubits` by a constant.
/// ```
/// %1 = arith.constant 8 : i64
/// %2 = call @__nvqpp_cudaq_state_createFromData_fp32(%0, %1) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %3 = call @__nvqpp_cudaq_state_numberOfQubits(%2) : (!cc.ptr<!cc.state>) -> i64
/// %c8_i64 = arith.constant 8 : i64
/// %2 = cc.create_state %0, %c8_i64 : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %3 = cc.get_number_of_qubits %2 : (!cc.ptr<!cc.state>) -> i64
/// ...
/// ───────────────────────────────────────────
/// %1 = arith.constant 8 : i64
/// %2 = call @__nvqpp_cudaq_state_createFromData_fp32(%0, %1) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %5 = arith.constant 3 : i64
/// %c8_i64 = arith.constant 8 : i64
/// %2 = cc.create_state %0, %c8_i64 : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %3 = arith.constant 3 : i64
/// ```
// clang-format on
class NumberOfQubitsPattern : public OpRewritePattern<func::CallOp> {
class NumberOfQubitsPattern
: public OpRewritePattern<cudaq::cc::GetNumberOfQubitsOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(func::CallOp callOp,
LogicalResult matchAndRewrite(cudaq::cc::GetNumberOfQubitsOp op,
PatternRewriter &rewriter) const override {
if (isNumberOfQubitsCall(callOp)) {
auto createStateOp = callOp.getOperand(0).getDefiningOp();
if (isCreateStateCall(createStateOp)) {
auto size = getStateSize(createStateOp);
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
callOp, std::countr_zero(size), rewriter.getI64Type());
return success();
}
auto stateOp = op.getOperand();
if (auto createStateOp =
stateOp.getDefiningOp<cudaq::cc::CreateStateOp>()) {
auto size = getStateSize(createStateOp);
rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
op, std::countr_zero(size), rewriter.getI64Type());
return success();
}
return failure();
}
};

// clang-format off
/// Replace calls to `__nvqpp_cudaq_state_numberOfQubits` by a constant.
/// Remove `cc.create_state` instructions and pass their data directly to
/// the `quake.init_state` instruction instead.
/// ```
/// %2 = cc.cast %1 : (!cc.ptr<!cc.array<complex<f32> x 8>>) -> !cc.ptr<i8>
/// %3 = call @__nvqpp_cudaq_state_createFromData_fp32(%2, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %3 = cc.create_state %2, %c8_i64 : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %4 = quake.alloca !quake.veq<?>[%0 : i64]
/// %5 = quake.init_state %4, %3 : (!quake.veq<?>, !cc.ptr<!cc.state>) -> !quake.veq<?>
/// ───────────────────────────────────────────
/// ...
/// %3 = call @__nvqpp_cudaq_state_createFromData_fp32(%2, %c8_i64) : (!cc.ptr<i8>, i64) -> !cc.ptr<!cc.state>
/// %4 = quake.alloca !quake.veq<?>[%0 : i64]
/// %5 = quake.init_state %4, %1 : (!quake.veq<?>, !cc.ptr<!cc.array<complex<f32> x 8>>) -> !quake.veq<?>
/// ```
// clang-format on

class StateToDataPattern : public OpRewritePattern<quake::InitializeStateOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(quake::InitializeStateOp initState,
PatternRewriter &rewriter) const override {
auto stateOp = initState.getOperand(1).getDefiningOp();
auto state = initState.getOperand(1);
auto targets = initState.getTargets();

if (isCreateStateCall(stateOp)) {
auto dataOp = stateOp->getOperand(0);
if (auto cast = dyn_cast<cudaq::cc::CastOp>(dataOp.getDefiningOp()))
if (auto createStateOp = state.getDefiningOp<cudaq::cc::CreateStateOp>()) {
auto dataOp = createStateOp->getOperand(0);
if (auto cast = dataOp.getDefiningOp<cudaq::cc::CastOp>())
dataOp = cast.getOperand();
rewriter.replaceOpWithNewOp<quake::InitializeStateOp>(
initState, targets.getType(), targets, dataOp);
Expand Down Expand Up @@ -163,10 +138,8 @@ class DeleteStatesPass
llvm::SmallVector<Operation *> usedStates;

func.walk([&](Operation *op) {
if (isCreateStateCall(op)) {
if (op->getUses().empty())
op->erase();
else
if (isa<cudaq::cc::CreateStateOp>(op)) {
if (!op->getUses().empty())
usedStates.push_back(op);
}
});
Expand All @@ -178,15 +151,16 @@ class DeleteStatesPass
func.walk([&](Operation *op) {
if (isa<func::ReturnOp>(op)) {
auto loc = op->getLoc();
auto deleteState = cudaq::deleteCudaqState;
auto result = irBuilder.loadIntrinsic(module, deleteState);
auto result =
irBuilder.loadIntrinsic(module, cudaq::deleteCudaqState);
assert(succeeded(result) && "loading intrinsic should never fail");

builder.setInsertionPoint(op);
for (auto createStateOp : usedStates) {
auto results = cast<func::CallOp>(createStateOp).getResults();
builder.create<func::CallOp>(loc, std::nullopt, deleteState,
results);
auto result = cast<cudaq::cc::CreateStateOp>(createStateOp);
builder.create<func::CallOp>(loc, std::nullopt,
cudaq::deleteCudaqState,
mlir::ValueRange{result});
}
}
});
Expand Down
9 changes: 5 additions & 4 deletions python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,11 +2246,9 @@ def bodyBuilder(iterVal):
# handle `cudaq.qvector(state)`
statePtr = self.ifNotPointerThenStore(valueOrPtr)

symName = '__nvqpp_cudaq_state_numberOfQubits'
load_intrinsic(self.module, symName)
i64Ty = self.getIntegerType()
numQubits = func.CallOp([i64Ty], symName,
[statePtr]).result
numQubits = cc.GetNumberOfQubitsOp(i64Ty,
statePtr).result

veqTy = quake.VeqType.get(self.ctx)
qubits = quake.AllocaOp(veqTy, size=numQubits).result
Expand Down Expand Up @@ -3831,6 +3829,9 @@ def visit_Name(self, node):
if cc.StdvecType.isinstance(eleTy):
self.pushValue(value)
return
if cc.StateType.isinstance(eleTy):
self.pushValue(value)
return
loaded = cc.LoadOp(value).result
self.pushValue(loaded)
elif cc.CallableType.isinstance(
Expand Down
Loading
Loading