diff --git a/bin/sozo/src/commands/execute.rs b/bin/sozo/src/commands/execute.rs index c8510d1d16..0305fa2186 100644 --- a/bin/sozo/src/commands/execute.rs +++ b/bin/sozo/src/commands/execute.rs @@ -1,3 +1,8 @@ +use super::options::account::AccountOptions; +use super::options::starknet::StarknetOptions; +use super::options::transaction::TransactionOptions; +use super::options::world::WorldOptions; +use crate::utils; use anyhow::{anyhow, Result}; use clap::Args; use dojo_utils::{Invoker, TxnConfig}; @@ -8,35 +13,19 @@ use sozo_scarbext::WorkspaceExt; use sozo_walnut::WalnutDebugger; use starknet::core::types::Call; use starknet::core::utils as snutils; -use tracing::trace; -use super::options::account::AccountOptions; -use super::options::starknet::StarknetOptions; -use super::options::transaction::TransactionOptions; -use super::options::world::WorldOptions; -use crate::utils; +use dojo_world::diff::WorldDiff; +use scarb::core::Workspace; +use starknet::core::types::Address; +use tracing::trace; #[derive(Debug, Args)] #[command(about = "Execute a system with the given calldata.")] pub struct ExecuteArgs { #[arg( - help = "The address or the tag (ex: dojo_examples:actions) of the contract to be executed." + help = "List of calls to execute. Each call should be in format: ,,,,... (ex: dojo_examples:actions,execute,1,2)" )] - pub tag_or_address: ResourceDescriptor, - - #[arg(help = "The name of the entrypoint to be executed.")] - pub entrypoint: String, - - #[arg(short, long)] - #[arg(help = "The calldata to be passed to the system. Comma separated values e.g., \ - 0x12345,128,u256:9999999999. Sozo supports some prefixes that you can use to \ - automatically parse some types. The supported prefixes are: - - u256: A 256-bit unsigned integer. - - sstr: A cairo short string. - - str: A cairo string (ByteArray). - - int: A signed integer. - - no prefix: A cairo felt or any type that fit into one felt.")] - pub calldata: Option, + pub calls: Vec, #[arg(long)] #[arg(help = "If true, sozo will compute the diff of the world from the chain to translate \ @@ -56,6 +45,69 @@ pub struct ExecuteArgs { pub transaction: TransactionOptions, } +#[derive(Debug)] +pub struct CallArgs { + pub tag_or_address: ResourceDescriptor, // Contract address or tag + pub entrypoint: String, // Entrypoint to call + pub calldata: Option, // Calldata to pass to the entrypoint +} + +impl std::str::FromStr for CallArgs { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let s = s.trim(); + + let parts: Vec<&str> = s.split(',').collect(); + if parts.len() < 2 { + return Err(anyhow!("Invalid call format. Expected format: ,,,,...")); + } + + let entrypoint = parts[1].trim(); + if entrypoint.is_empty() { + return Err(anyhow!("Empty entrypoint")); + } + if !entrypoint.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { + return Err(anyhow!("Invalid entrypoint format. Must contain only alphanumeric characters and underscores")); + } + + Ok(CallArgs { + tag_or_address: parts[0].parse()?, + entrypoint: entrypoint.to_string(), + calldata: if parts.len() > 2 { Some(parts[2..].join(",")) } else { None }, + }) + } +} + +async fn resolve_contract_address( + descriptor: &ResourceDescriptor, + world_diff: &WorldDiff, + options: &ExecuteArgs, + ws: &Workspace, +) -> Result
{ + match descriptor { + ResourceDescriptor::Address(address) => Ok(*address), + ResourceDescriptor::Tag(tag) => { + let contracts = utils::contracts_from_manifest_or_diff( + options.account.clone(), + options.starknet.clone(), + options.world, + &ws, + options.diff, + ) + .await?; + + contracts + .get(tag) + .map(|c| c.address) + .ok_or_else(|| anyhow!("Contract {descriptor} not found in the world diff.")) + } + ResourceDescriptor::Name(_) => { + unimplemented!("Expected to be a resolved tag with default namespace.") + } + } +} + impl ExecuteArgs { pub fn run(self, config: &Config) -> Result<()> { trace!(args = ?self); @@ -64,8 +116,6 @@ impl ExecuteArgs { let profile_config = ws.load_profile_config()?; - let descriptor = self.tag_or_address.ensure_namespace(&profile_config.namespace.default); - #[cfg(feature = "walnut")] let _walnut_debugger = WalnutDebugger::new_from_flag( self.transaction.walnut, @@ -75,66 +125,57 @@ impl ExecuteArgs { let txn_config: TxnConfig = self.transaction.try_into()?; config.tokio_handle().block_on(async { - let (contract_address, contracts) = match &descriptor { - ResourceDescriptor::Address(address) => (Some(*address), Default::default()), - ResourceDescriptor::Tag(tag) => { - let contracts = utils::contracts_from_manifest_or_diff( - self.account.clone(), - self.starknet.clone(), - self.world, - &ws, - self.diff, - ) - .await?; - - (contracts.get(tag).map(|c| c.address), contracts) - } - ResourceDescriptor::Name(_) => { - unimplemented!("Expected to be a resolved tag with default namespace.") - } - }; - - let contract_address = contract_address.ok_or_else(|| { - let mut message = format!("Contract {descriptor} not found in the manifest."); - if self.diff { - message.push_str( - " Run the command again with `--diff` to force the fetch of data from the \ - chain.", - ); - } - anyhow!(message) - })?; - - trace!( - contract=?descriptor, - entrypoint=self.entrypoint, - calldata=?self.calldata, - "Executing Execute command." - ); - - let calldata = if let Some(cd) = self.calldata { - calldata_decoder::decode_calldata(&cd)? - } else { - vec![] - }; - - let call = Call { - calldata, - to: contract_address, - selector: snutils::get_selector_from_name(&self.entrypoint)?, - }; - - let (provider, _) = self.starknet.provider(profile_config.env.as_ref())?; - - let account = self - .account - .account(provider, profile_config.env.as_ref(), &self.starknet, &contracts) - .await?; - - let invoker = Invoker::new(&account, txn_config); - // TODO: add walnut back, perhaps at the invoker level. - let tx_result = invoker.invoke(call).await?; - + // We could save the world diff computation extracting the account directly from the options. + let (world_diff, account, _) = utils::get_world_diff_and_account( + self.account, + self.starknet.clone(), + self.world, + &ws, + &mut None, + ) + .await?; + + let mut invoker = Invoker::new(&account, txn_config); + + // Parse the Vec into Vec using FromStr + let call_args_list: Vec = + self.calls.iter().map(|s| s.parse()).collect::>>()?; + + for call_args in call_args_list { + let descriptor = + call_args.tag_or_address.ensure_namespace(&profile_config.namespace.default); + + // Checking the contract tag in local manifest + let contract_address = + if let Some(local_address) = ws.get_contract_address(&descriptor) { + local_address + } else { + resolve_contract_address(&descriptor, &world_diff, &ws).await?; + }; + + trace!( + contract=?descriptor, + entrypoint=call_args.entrypoint, + calldata=?call_args.calldata, + "Executing Execute command." + ); + + let calldata = if let Some(cd) = call_args.calldata { + calldata_decoder::decode_calldata(&cd)? + } else { + vec![] + }; + + let call = Call { + calldata, + to: contract_address, + selector: snutils::get_selector_from_name(&call_args.entrypoint)?, + }; + + invoker.add_call(call); // Adding each call to the Invoker + } + + let tx_result = invoker.invoke(call).await?; // Invoking the multi-call println!("{}", tx_result); Ok(()) })