diff --git a/api/api.go b/api/api.go index b059350..45a6a02 100644 --- a/api/api.go +++ b/api/api.go @@ -11,11 +11,12 @@ import ( // A StateResponse returns information about the current state of the walletd // daemon. type StateResponse struct { - Version string `json:"version"` - Commit string `json:"commit"` - OS string `json:"os"` - BuildTime time.Time `json:"buildTime"` - StartTime time.Time `json:"startTime"` + Version string `json:"version"` + Commit string `json:"commit"` + OS string `json:"os"` + BuildTime time.Time `json:"buildTime"` + StartTime time.Time `json:"startTime"` + IndexMode wallet.IndexMode `json:"indexMode"` } // A GatewayPeer is a currently-connected peer. diff --git a/api/api_test.go b/api/api_test.go index d6e1de0..a7222b8 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -88,7 +88,10 @@ func TestWalletAdd(t *testing.T) { } defer ws.Close() - wm := wallet.NewManager(cm, ws, log.Named("wallet")) + wm, err := wallet.NewManager(cm, ws, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() c, shutdown := runServer(cm, nil, wm) @@ -273,7 +276,10 @@ func TestWallet(t *testing.T) { }) // create the wallet manager - wm := wallet.NewManager(cm, ws, log.Named("wallet")) + wm, err := wallet.NewManager(cm, ws, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() // create seed address vault @@ -492,7 +498,10 @@ func TestAddresses(t *testing.T) { } defer ws.Close() - wm := wallet.NewManager(cm, ws, log.Named("wallet")) + wm, err := wallet.NewManager(cm, ws, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() sav := wallet.NewSeedAddressVault(wallet.NewSeed(), 0, 20) @@ -686,7 +695,10 @@ func TestV2(t *testing.T) { t.Fatal(err) } defer ws.Close() - wm := wallet.NewManager(cm, ws, log.Named("wallet")) + wm, err := wallet.NewManager(cm, ws, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() c, shutdown := runServer(cm, nil, wm) @@ -909,7 +921,10 @@ func TestP2P(t *testing.T) { t.Fatal(err) } - wm1 := wallet.NewManager(cm1, store1, log1.Named("wallet")) + wm1, err := wallet.NewManager(cm1, store1, wallet.WithLogger(log1.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm1.Close() l1, err := net.Listen("tcp", ":0") @@ -949,7 +964,10 @@ func TestP2P(t *testing.T) { t.Fatal(err) } defer store2.Close() - wm2 := wallet.NewManager(cm2, store2, log2.Named("wallet")) + wm2, err := wallet.NewManager(cm2, store2, wallet.WithLogger(log2.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm2.Close() l2, err := net.Listen("tcp", ":0") diff --git a/api/server.go b/api/server.go index 3778993..23b5603 100644 --- a/api/server.go +++ b/api/server.go @@ -48,6 +48,7 @@ type ( // A WalletManager manages wallets, keyed by name. WalletManager interface { + IndexMode() wallet.IndexMode Tip() (types.ChainIndex, error) Scan(_ context.Context, index types.ChainIndex) error @@ -97,6 +98,7 @@ func (s *server) stateHandler(jc jape.Context) { OS: runtime.GOOS, BuildTime: build.Time(), StartTime: s.startTime, + IndexMode: s.wm.IndexMode(), }) } diff --git a/cmd/walletd/main.go b/cmd/walletd/main.go index cab41ed..ac75bb4 100644 --- a/cmd/walletd/main.go +++ b/cmd/walletd/main.go @@ -11,6 +11,7 @@ import ( cwallet "go.sia.tech/coreutils/wallet" "go.sia.tech/walletd/api" "go.sia.tech/walletd/build" + "go.sia.tech/walletd/wallet" "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/term" @@ -68,7 +69,7 @@ Runs a CPU miner. Not intended for production use. func main() { log.SetFlags(0) - var gatewayAddr, apiAddr, dir, network, seed string + var gatewayAddr, apiAddr, dir, network, seed, indexModeStr string var upnp, bootstrap bool var minerAddrStr string @@ -83,6 +84,7 @@ func main() { rootCmd.BoolVar(&upnp, "upnp", true, "attempt to forward ports and discover IP with UPnP") rootCmd.BoolVar(&bootstrap, "bootstrap", true, "attempt to bootstrap the network") rootCmd.StringVar(&seed, "seed", "", "testnet seed") + rootCmd.StringVar(&indexModeStr, "index", "full", "address index mode (full, partial, off)") versionCmd := flagg.New("version", versionUsage) seedCmd := flagg.New("seed", seedUsage) mineCmd := flagg.New("mine", mineUsage) @@ -135,7 +137,17 @@ func main() { // redirect stdlib log to zap zap.RedirectStdLog(logger.Named("stdlib")) - n, err := newNode(gatewayAddr, dir, network, upnp, bootstrap, logger) + var indexMode wallet.IndexMode + switch indexModeStr { + case "full": + indexMode = wallet.IndexModeFull + case "partial": + indexMode = wallet.IndexModePartial + case "off": + indexMode = wallet.IndexModeNone + } + + n, err := newNode(gatewayAddr, dir, network, upnp, bootstrap, indexMode, logger) if err != nil { log.Fatal(err) } diff --git a/cmd/walletd/node.go b/cmd/walletd/node.go index 9eb9c7f..cc55770 100644 --- a/cmd/walletd/node.go +++ b/cmd/walletd/node.go @@ -100,7 +100,7 @@ func (n *node) Close() error { return n.store.Close() } -func newNode(addr, dir string, chainNetwork string, useUPNP, useBootstrap bool, log *zap.Logger) (*node, error) { +func newNode(addr, dir string, chainNetwork string, useUPNP, useBootstrap bool, indexMode wallet.IndexMode, log *zap.Logger) (*node, error) { var network *consensus.Network var genesisBlock types.Block var bootstrapPeers []string @@ -187,8 +187,10 @@ func newNode(addr, dir string, chainNetwork string, useUPNP, useBootstrap bool, } s := syncer.New(l, cm, ps, header, syncer.WithLogger(log.Named("syncer"))) - wm := wallet.NewManager(cm, store, log.Named("wallet")) - + wm, err := wallet.NewManager(cm, store, wallet.WithLogger(log.Named("wallet")), wallet.WithIndexMode(indexMode)) + if err != nil { + return nil, fmt.Errorf("failed to create wallet manager: %w", err) + } return &node{ chainStore: bdb, cm: cm, diff --git a/persist/sqlite/addresses.go b/persist/sqlite/addresses.go index b3fadd3..e2da36b 100644 --- a/persist/sqlite/addresses.go +++ b/persist/sqlite/addresses.go @@ -1,6 +1,8 @@ package sqlite import ( + "database/sql" + "errors" "fmt" "go.sia.tech/core/types" @@ -11,7 +13,12 @@ import ( func (s *Store) AddressBalance(address types.Address) (balance wallet.Balance, err error) { err = s.transaction(func(tx *txn) error { const query = `SELECT siacoin_balance, immature_siacoin_balance, siafund_balance FROM sia_addresses WHERE sia_address=$1` - return tx.QueryRow(query, encode(address)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) + err := tx.QueryRow(query, encode(address)).Scan(decode(&balance.Siacoins), decode(&balance.ImmatureSiacoins), &balance.Siafunds) + if errors.Is(err, sql.ErrNoRows) { + balance = wallet.Balance{} + return nil + } + return err }) return } @@ -70,7 +77,25 @@ func (s *Store) AddressSiacoinOutputs(address types.Address, offset, limit int) siacoins = append(siacoins, siacoin) } - return rows.Err() + if err := rows.Err(); err != nil { + return err + } + + // retrieve the merkle proofs for the siacoin elements + if s.indexMode == wallet.IndexModeFull { + indices := make([]uint64, len(siacoins)) + for i, se := range siacoins { + indices[i] = se.LeafIndex + } + proofs, err := fillElementProofs(tx, indices) + if err != nil { + return fmt.Errorf("failed to fill element proofs: %w", err) + } + for i, proof := range proofs { + siacoins[i].MerkleProof = proof + } + } + return nil }) return } @@ -97,7 +122,25 @@ func (s *Store) AddressSiafundOutputs(address types.Address, offset, limit int) } siafunds = append(siafunds, siafund) } - return rows.Err() + if err := rows.Err(); err != nil { + return err + } + + // retrieve the merkle proofs for the siafund elements + if s.indexMode == wallet.IndexModeFull { + indices := make([]uint64, len(siafunds)) + for i, se := range siafunds { + indices[i] = se.LeafIndex + } + proofs, err := fillElementProofs(tx, indices) + if err != nil { + return fmt.Errorf("failed to fill element proofs: %w", err) + } + for i, proof := range proofs { + siafunds[i].MerkleProof = proof + } + } + return nil }) return } diff --git a/persist/sqlite/consensus.go b/persist/sqlite/consensus.go index c529dd7..92761a6 100644 --- a/persist/sqlite/consensus.go +++ b/persist/sqlite/consensus.go @@ -14,6 +14,8 @@ import ( ) type updateTx struct { + indexMode wallet.IndexMode + tx *txn relevantAddresses map[types.Address]bool } @@ -24,6 +26,10 @@ type addressRef struct { } func (ut *updateTx) SiacoinStateElements() ([]types.StateElement, error) { + if ut.indexMode == wallet.IndexModeFull { + panic("SiacoinStateElements called in full index mode") + } + const query = `SELECT id, leaf_index, merkle_proof FROM siacoin_elements` rows, err := ut.tx.Query(query) if err != nil { @@ -43,6 +49,10 @@ func (ut *updateTx) SiacoinStateElements() ([]types.StateElement, error) { } func (ut *updateTx) UpdateSiacoinStateElements(elements []types.StateElement) error { + if ut.indexMode == wallet.IndexModeFull { + panic("UpdateSiacoinStateElements called in full index mode") + } + log := ut.tx.log.Named("UpdateSiacoinStateElements") log.Debug("updating siacoin state elements", zap.Int("count", len(elements))) @@ -65,6 +75,10 @@ func (ut *updateTx) UpdateSiacoinStateElements(elements []types.StateElement) er } func (ut *updateTx) SiafundStateElements() ([]types.StateElement, error) { + if ut.indexMode == wallet.IndexModeFull { + panic("SiafundStateElements called in full index mode") + } + const query = `SELECT id, leaf_index, merkle_proof FROM siafund_elements` rows, err := ut.tx.Query(query) if err != nil { @@ -84,6 +98,10 @@ func (ut *updateTx) SiafundStateElements() ([]types.StateElement, error) { } func (ut *updateTx) UpdateSiafundStateElements(elements []types.StateElement) error { + if ut.indexMode == wallet.IndexModeFull { + panic("UpdateSiafundStateElements called in full index mode") + } + const query = `UPDATE siafund_elements SET merkle_proof=$1, leaf_index=$2 WHERE id=$3 RETURNING id` stmt, err := ut.tx.Prepare(query) if err != nil { @@ -101,7 +119,31 @@ func (ut *updateTx) UpdateSiafundStateElements(elements []types.StateElement) er return nil } +func (ut *updateTx) UpdateStateTree(changes []wallet.TreeNodeUpdate) error { + if ut.indexMode != wallet.IndexModeFull { + panic("UpdateStateTree called in partial index mode") + } + + stmt, err := ut.tx.Prepare(`INSERT INTO state_tree (row, column, value) VALUES ($1, $2, $3) ON CONFLICT (row, column) DO UPDATE SET value=EXCLUDED.value`) + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + for _, change := range changes { + _, err := stmt.Exec(change.Row, change.Column, encode(change.Hash)) + if err != nil { + return fmt.Errorf("failed to execute statement: %w", err) + } + } + return nil +} + func (ut *updateTx) AddressRelevant(addr types.Address) (bool, error) { + if ut.indexMode == wallet.IndexModeFull { + return true, nil + } + if relevant, ok := ut.relevantAddresses[addr]; ok { return relevant, nil } @@ -142,13 +184,13 @@ func (ut *updateTx) ApplyIndex(index types.ChainIndex, state wallet.AppliedState if err := spendSiacoinElements(tx, state.SpentSiacoinElements, indexID); err != nil { return fmt.Errorf("failed to spend siacoin elements: %w", err) - } else if err := addSiacoinElements(tx, state.CreatedSiacoinElements, indexID, log.Named("addSiacoinElements")); err != nil { + } else if err := addSiacoinElements(tx, state.CreatedSiacoinElements, indexID, ut.indexMode, log.Named("addSiacoinElements")); err != nil { return fmt.Errorf("failed to add siacoin elements: %w", err) } if err := spendSiafundElements(tx, state.SpentSiafundElements, indexID); err != nil { return fmt.Errorf("failed to spend siafund elements: %w", err) - } else if err := addSiafundElements(tx, state.CreatedSiafundElements, indexID, log.Named("addSiafundElements")); err != nil { + } else if err := addSiafundElements(tx, state.CreatedSiafundElements, indexID, ut.indexMode, log.Named("addSiafundElements")); err != nil { return fmt.Errorf("failed to add siafund elements: %w", err) } @@ -186,20 +228,22 @@ func (s *Store) UpdateChainState(reverted []chain.RevertUpdate, applied []chain. log := s.log.Named("UpdateChainState").With(zap.Int("reverted", len(reverted)), zap.Int("applied", len(applied))) return s.transaction(func(tx *txn) error { utx := &updateTx{ + indexMode: s.indexMode, + tx: tx, relevantAddresses: make(map[types.Address]bool), } - if err := wallet.UpdateChainState(utx, reverted, applied, log); err != nil { - return fmt.Errorf("failed to update chain state: %w", err) - } else if err := setLastCommittedIndex(tx, applied[len(applied)-1].State.Index); err != nil { + state := applied[len(applied)-1].State + + if err := wallet.UpdateChainState(utx, reverted, applied, s.indexMode, log); err != nil { + return err + } else if err := setGlobalState(tx, state.Index, state.Elements.NumLeaves); err != nil { return fmt.Errorf("failed to set last committed index: %w", err) } - height := applied[len(applied)-1].State.Index.Height - - if height > spentElementRetentionBlocks { - pruneHeight := height - spentElementRetentionBlocks + if state.Index.Height > spentElementRetentionBlocks { + pruneHeight := state.Index.Height - spentElementRetentionBlocks siacoins, err := pruneSpentSiacoinElements(tx, pruneHeight) if err != nil { @@ -210,10 +254,7 @@ func (s *Store) UpdateChainState(reverted []chain.RevertUpdate, applied []chain. if err != nil { return fmt.Errorf("failed to cleanup siafund elements: %w", err) } - - if len(siacoins) > 0 || len(siafunds) > 0 { - log.Debug("pruned elements", zap.Stringers("siacoins", siacoins), zap.Stringers("siafunds", siafunds), zap.Uint64("pruneHeight", pruneHeight)) - } + log.Debug("pruned elements", zap.Int64("siacoins", siacoins), zap.Int64("siafunds", siafunds), zap.Uint64("pruneHeight", pruneHeight)) } return nil }) @@ -231,6 +272,35 @@ func (s *Store) ResetLastIndex() error { return err } +// IndexMode returns the current index mode. +func (s *Store) IndexMode() (wallet.IndexMode, error) { + var mode wallet.IndexMode + err := s.db.QueryRow(`SELECT index_mode FROM global_settings`).Scan(&mode) + return mode, err +} + +// SetIndexMode sets the index mode. If the index mode is already set, this +// function will return an error. +func (s *Store) SetIndexMode(mode wallet.IndexMode) error { + return s.transaction(func(tx *txn) error { + _, err := tx.Exec(`UPDATE global_settings SET index_mode=$1 WHERE index_mode IS NULL`, mode) + if err != nil { + return fmt.Errorf("failed to set index mode: %w", err) + } + + // check that the index mode was set + var existingMode wallet.IndexMode + err = tx.QueryRow(`SELECT index_mode FROM global_settings`).Scan(&existingMode) + if err != nil { + return fmt.Errorf("failed to query index mode: %w", err) + } else if existingMode != mode { + return fmt.Errorf("cannot change index mode from %v to %v", existingMode, mode) + } + s.indexMode = mode // this is a bit annoying + return nil + }) +} + func scanStateElement(s scanner) (se types.StateElement, err error) { err = s.Scan(decode(&se.ID), &se.LeafIndex, decodeSlice(&se.MerkleProof)) return @@ -401,16 +471,16 @@ func revertMatureSiacoinBalance(tx *txn, index types.ChainIndex) error { return nil } -func addSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int64, log *zap.Logger) error { +func addSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int64, indexMode wallet.IndexMode, log *zap.Logger) error { if len(elements) == 0 { return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() existsStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM siacoin_elements WHERE id=$1)`) if err != nil { @@ -427,7 +497,7 @@ func addSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int64, balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiacoinOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -440,6 +510,12 @@ func addSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int64, return fmt.Errorf("failed to check if siacoin element exists: %w", err) } + // in full index mode, Merkle proofs are stored in the state tree table + // rather than per element. + if indexMode == wallet.IndexModeFull { + se.MerkleProof = nil + } + _, err = insertStmt.Exec(encode(se.ID), encode(se.SiacoinOutput.Value), encodeSlice(se.MerkleProof), se.LeafIndex, se.MaturityHeight, addrRef.ID, se.MaturityHeight == 0, indexID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) @@ -489,11 +565,11 @@ func removeSiacoinElements(tx *txn, elements []types.SiacoinElement) error { return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`DELETE FROM siacoin_elements WHERE id=$1 RETURNING id, matured`) if err != nil { @@ -503,7 +579,7 @@ func removeSiacoinElements(tx *txn, elements []types.SiacoinElement) error { balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiacoinOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -554,11 +630,11 @@ func revertSpentSiacoinElements(tx *txn, elements []types.SiacoinElement) error return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`UPDATE siacoin_elements SET spent_index_id=NULL WHERE id=$1 AND spent_index_id IS NOT NULL RETURNING id`) if err != nil { @@ -568,7 +644,7 @@ func revertSpentSiacoinElements(tx *txn, elements []types.SiacoinElement) error balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiacoinOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -615,11 +691,11 @@ func spendSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int6 return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`UPDATE siacoin_elements SET spent_index_id=$1 WHERE id=$2 AND spent_index_id IS NULL RETURNING id`) if err != nil { @@ -629,7 +705,7 @@ func spendSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int6 balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiacoinOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiacoinOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -671,16 +747,16 @@ func spendSiacoinElements(tx *txn, elements []types.SiacoinElement, indexID int6 return nil } -func addSiafundElements(tx *txn, elements []types.SiafundElement, indexID int64, log *zap.Logger) error { +func addSiafundElements(tx *txn, elements []types.SiafundElement, indexID int64, indexMode wallet.IndexMode, log *zap.Logger) error { if len(elements) == 0 { return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() existsStmt, err := tx.Prepare(`SELECT EXISTS(SELECT 1 FROM siafund_elements WHERE id=$1)`) if err != nil { @@ -696,7 +772,7 @@ func addSiafundElements(tx *txn, elements []types.SiafundElement, indexID int64, balanceChanges := make(map[int64]uint64) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiafundOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -708,6 +784,12 @@ func addSiafundElements(tx *txn, elements []types.SiafundElement, indexID int64, return fmt.Errorf("failed to check if siafund element exists: %w", err) } + // in full index mode, Merkle proofs are stored in the state tree table + // rather than per element. + if indexMode == wallet.IndexModeFull { + se.MerkleProof = nil + } + _, err = insertStmt.Exec(encode(se.ID), se.SiafundOutput.Value, encodeSlice(se.MerkleProof), se.LeafIndex, encode(se.ClaimStart), addrRef.ID, indexID) if err != nil { return fmt.Errorf("failed to execute statement: %w", err) @@ -747,11 +829,11 @@ func removeSiafundElements(tx *txn, elements []types.SiafundElement) error { return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`DELETE FROM siafund_elements WHERE id=$1 RETURNING id`) if err != nil { @@ -761,7 +843,7 @@ func removeSiafundElements(tx *txn, elements []types.SiafundElement) error { balanceChanges := make(map[int64]uint64) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiafundOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -808,11 +890,11 @@ func spendSiafundElements(tx *txn, elements []types.SiafundElement, indexID int6 return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`UPDATE siafund_elements SET spent_index_id=$1 WHERE id=$2 AND spent_index_id IS NULL RETURNING id`) if err != nil { @@ -822,7 +904,7 @@ func spendSiafundElements(tx *txn, elements []types.SiafundElement, indexID int6 balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiafundOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -873,11 +955,11 @@ func revertSpentSiafundElements(tx *txn, elements []types.SiafundElement) error return nil } - addrStmt, err := insertAddressStatement(tx) + addressRefStmt, done, err := addressRefStmt(tx) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } - defer addrStmt.Close() + defer done() stmt, err := tx.Prepare(`UPDATE siafund_elements SET spent_index_id=NULL WHERE id=$1 AND spent_index_id IS NOT NULL RETURNING id`) if err != nil { @@ -887,7 +969,7 @@ func revertSpentSiafundElements(tx *txn, elements []types.SiafundElement) error balanceChanges := make(map[int64]wallet.Balance) for _, se := range elements { - addrRef, err := scanAddress(addrStmt.QueryRow(encode(se.SiafundOutput.Address), encode(types.ZeroCurrency), 0)) + addrRef, err := addressRefStmt(se.SiafundOutput.Address) if err != nil { return fmt.Errorf("failed to query address: %w", err) } else if _, ok := balanceChanges[addrRef.ID]; !ok { @@ -940,7 +1022,7 @@ func addEvents(tx *txn, events []wallet.Event, indexID int64) error { } defer insertEventStmt.Close() - addrStmt, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $3, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) + addrStmt, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id`) if err != nil { return fmt.Errorf("failed to prepare address statement: %w", err) } @@ -975,7 +1057,7 @@ func addEvents(tx *txn, events []wallet.Event, indexID int64) error { } var addressID int64 - err = addrStmt.QueryRow(encode(addr), encode(types.ZeroCurrency), 0).Scan(&addressID) + err = addrStmt.QueryRow(encode(addr), encode(types.ZeroCurrency)).Scan(&addressID) if err != nil { return fmt.Errorf("failed to get address: %w", err) } @@ -1206,48 +1288,40 @@ func revertOrphans(tx *txn, index types.ChainIndex, log *zap.Logger) error { return err } -func pruneSpentSiacoinElements(tx *txn, height uint64) (removed []types.SiacoinOutputID, err error) { - const query = `DELETE FROM siacoin_elements WHERE spent_index_id IN (SELECT id FROM chain_indices WHERE height <= $1) RETURNING id` - rows, err := tx.Query(query, height) +func pruneSpentSiacoinElements(tx *txn, height uint64) (removed int64, err error) { + const query = `DELETE FROM siacoin_elements WHERE spent_index_id IN (SELECT id FROM chain_indices WHERE height <= $1)` + res, err := tx.Exec(query, height) if err != nil { - return nil, fmt.Errorf("failed to query siacoin elements: %w", err) + return 0, fmt.Errorf("failed to query siacoin elements: %w", err) } - defer rows.Close() - - for rows.Next() { - var id types.SiacoinOutputID - if err := rows.Scan(decode(&id)); err != nil { - return nil, fmt.Errorf("failed to scan siacoin element: %w", err) - } - removed = append(removed, id) - } - return removed, rows.Err() + return res.RowsAffected() } -func pruneSpentSiafundElements(tx *txn, height uint64) (removed []types.SiafundOutputID, err error) { - const query = `DELETE FROM siafund_elements WHERE spent_index_id IN (SELECT id FROM chain_indices WHERE height <= $1) RETURNING id` - rows, err := tx.Query(query, height) +func pruneSpentSiafundElements(tx *txn, height uint64) (removed int64, err error) { + const query = `DELETE FROM siafund_elements WHERE spent_index_id IN (SELECT id FROM chain_indices WHERE height <= $1)` + res, err := tx.Exec(query, height) if err != nil { - return nil, fmt.Errorf("failed to query siafund elements: %w", err) - } - defer rows.Close() - - for rows.Next() { - var id types.SiafundOutputID - if err := rows.Scan(decode(&id)); err != nil { - return nil, fmt.Errorf("failed to scan siafund element: %w", err) - } - removed = append(removed, id) + return 0, fmt.Errorf("failed to query siacoin elements: %w", err) } - return removed, rows.Err() + return res.RowsAffected() } -func setLastCommittedIndex(tx *txn, index types.ChainIndex) error { - _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1`, encode(index)) +func setGlobalState(tx *txn, index types.ChainIndex, numLeaves uint64) error { + _, err := tx.Exec(`UPDATE global_settings SET last_indexed_tip=$1, element_num_leaves=$2`, encode(index), numLeaves) return err } -func insertAddressStatement(tx *txn) (*stmt, error) { +func addressRefStmt(tx *txn) (func(types.Address) (addressRef, error), func() error, error) { + stmt, err := tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $3, $4) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id, siacoin_balance, immature_siacoin_balance, siafund_balance`) + if err != nil { + return nil, nil, fmt.Errorf("failed to prepare address statement: %w", err) + } // the on conflict is effectively a no-op, but enables us to return the id of the existing address - return tx.Prepare(`INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) VALUES ($1, $2, $2, $3) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address RETURNING id, siacoin_balance, immature_siacoin_balance, siafund_balance`) + return func(addr types.Address) (addressRef, error) { + ref, err := scanAddress(stmt.QueryRow(encode(addr), encode(types.ZeroCurrency), encode(types.ZeroCurrency), 0)) + if err != nil { + return addressRef{}, fmt.Errorf("failed to get address %q: %w", addr, err) + } + return ref, nil + }, stmt.Close, nil } diff --git a/persist/sqlite/init.go b/persist/sqlite/init.go index a29a00c..329e9a8 100644 --- a/persist/sqlite/init.go +++ b/persist/sqlite/init.go @@ -18,7 +18,7 @@ import ( var initDatabase string func initializeSettings(tx *txn, target int64) error { - _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip) VALUES (0, ?, ?)`, target, encode(types.ChainIndex{})) + _, err := tx.Exec(`INSERT INTO global_settings (id, db_version, last_indexed_tip, element_num_leaves) VALUES (0, ?, ?, ?)`, target, encode(types.ChainIndex{}), 0) return err } diff --git a/persist/sqlite/init.sql b/persist/sqlite/init.sql index c4c3537..133cad1 100644 --- a/persist/sqlite/init.sql +++ b/persist/sqlite/init.sql @@ -43,6 +43,13 @@ CREATE INDEX siafund_elements_address_id ON siafund_elements (address_id); CREATE INDEX siafund_elements_chain_index_id ON siafund_elements (chain_index_id); CREATE INDEX siafund_elements_spent_index_id ON siafund_elements (spent_index_id); +CREATE TABLE state_tree ( + row INTEGER, + column INTEGER, + value BLOB NOT NULL, + PRIMARY KEY (row, column) +); + CREATE TABLE events ( id INTEGER PRIMARY KEY, chain_index_id INTEGER NOT NULL REFERENCES chain_indices (id), @@ -97,5 +104,7 @@ CREATE INDEX syncer_bans_expiration_index ON syncer_bans (expiration); CREATE TABLE global_settings ( id INTEGER PRIMARY KEY NOT NULL DEFAULT 0 CHECK (id = 0), -- enforce a single row db_version INTEGER NOT NULL, -- used for migrations - last_indexed_tip BLOB NOT NULL -- the last chain index that was processed + index_mode INTEGER, -- the mode of the data store + last_indexed_tip BLOB NOT NULL, -- the last chain index that was processed + element_num_leaves INTEGER NOT NULL -- the number of leaves in the state tree ); diff --git a/persist/sqlite/peers_test.go b/persist/sqlite/peers_test.go index 135ad26..eca3a77 100644 --- a/persist/sqlite/peers_test.go +++ b/persist/sqlite/peers_test.go @@ -12,7 +12,7 @@ import ( func TestAddPeer(t *testing.T) { log := zaptest.NewLogger(t) - db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log.Named("sqlite3")) if err != nil { t.Fatal(err) } @@ -77,7 +77,7 @@ func TestAddPeer(t *testing.T) { func TestBanPeer(t *testing.T) { log := zaptest.NewLogger(t) - db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log) + db, err := OpenDatabase(filepath.Join(t.TempDir(), "test.db"), log.Named("sqlite3")) if err != nil { t.Fatal(err) } diff --git a/persist/sqlite/store.go b/persist/sqlite/store.go index 417ee19..1237413 100644 --- a/persist/sqlite/store.go +++ b/persist/sqlite/store.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "go.sia.tech/walletd/wallet" "go.uber.org/zap" "lukechampine.com/frand" ) @@ -16,6 +17,8 @@ import ( type ( // A Store is a persistent store that uses a SQL database as its backend. Store struct { + indexMode wallet.IndexMode + db *sql.DB log *zap.Logger } @@ -75,11 +78,11 @@ func sqliteFilepath(fp string) string { // an error, the transaction is rolled back. Otherwise, the transaction is // committed. func doTransaction(db *sql.DB, log *zap.Logger, fn func(tx *txn) error) error { - start := time.Now() dbtx, err := db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) } + start := time.Now() defer func() { if err := dbtx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { log.Error("failed to rollback transaction", zap.Error(err)) diff --git a/persist/sqlite/wallet.go b/persist/sqlite/wallet.go index b812e7a..8e01e99 100644 --- a/persist/sqlite/wallet.go +++ b/persist/sqlite/wallet.go @@ -5,100 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "math/bits" "time" "go.sia.tech/core/types" "go.sia.tech/walletd/wallet" ) -func scanSiacoinElement(s scanner) (se types.SiacoinElement, err error) { - err = s.Scan(decode(&se.ID), decode(&se.SiacoinOutput.Value), decodeSlice(&se.MerkleProof), &se.LeafIndex, &se.MaturityHeight, decode(&se.SiacoinOutput.Address)) - return -} - -func scanSiafundElement(s scanner) (se types.SiafundElement, err error) { - err = s.Scan(decode(&se.ID), &se.LeafIndex, decodeSlice(&se.MerkleProof), &se.SiafundOutput.Value, decode(&se.ClaimStart), decode(&se.SiafundOutput.Address)) - return -} - -func insertAddress(tx *txn, addr types.Address) (id int64, err error) { - const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) -VALUES ($1, $2, $2, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address -RETURNING id` - - err = tx.QueryRow(query, encode(addr), encode(types.ZeroCurrency)).Scan(&id) - return -} - -func scanEvent(s scanner) (ev wallet.Event, eventID int64, err error) { - var eventType string - var eventBuf []byte - - err = s.Scan(&eventID, decode(&ev.ID), &ev.MaturityHeight, decode(&ev.Timestamp), &ev.Index.Height, decode(&ev.Index.ID), &eventType, &eventBuf) - if err != nil { - return - } - - switch eventType { - case wallet.EventTypeTransaction: - var tx wallet.EventTransaction - if err = json.Unmarshal(eventBuf, &tx); err != nil { - return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal transaction event: %w", err) - } - ev.Data = &tx - case wallet.EventTypeContractPayout: - var m wallet.EventContractPayout - if err = json.Unmarshal(eventBuf, &m); err != nil { - return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal missed file contract event: %w", err) - } - ev.Data = &m - case wallet.EventTypeMinerPayout: - var m wallet.EventMinerPayout - if err = json.Unmarshal(eventBuf, &m); err != nil { - return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal payout event: %w", err) - } - ev.Data = &m - case wallet.EventTypeFoundationSubsidy: - var m wallet.EventFoundationSubsidy - if err = json.Unmarshal(eventBuf, &m); err != nil { - return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal foundation subsidy event: %w", err) - } - ev.Data = &m - default: - return wallet.Event{}, 0, fmt.Errorf("unknown event type: %s", eventType) - } - return -} - -func getWalletEvents(tx *txn, id wallet.ID, offset, limit int) (events []wallet.Event, eventIDs []int64, err error) { - const query = `SELECT ev.id, ev.event_id, ev.maturity_height, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data - FROM events ev - INNER JOIN chain_indices ci ON (ev.chain_index_id = ci.id) - WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) - ORDER BY ev.maturity_height DESC, ev.id DESC - LIMIT $2 OFFSET $3` - - rows, err := tx.Query(query, id, limit, offset) - if err != nil { - return nil, nil, err - } - defer rows.Close() - - for rows.Next() { - event, eventID, err := scanEvent(rows) - if err != nil { - return nil, nil, fmt.Errorf("failed to scan event: %w", err) - } - - events = append(events, event) - eventIDs = append(eventIDs, eventID) - } - if err := rows.Err(); err != nil { - return nil, nil, err - } - return -} - func (s *Store) getWalletEventRelevantAddresses(tx *txn, id wallet.ID, eventIDs []int64) (map[int64][]types.Address, error) { query := `SELECT ea.event_id, sa.sia_address FROM event_addresses ea @@ -322,7 +235,26 @@ func (s *Store) WalletSiacoinOutputs(id wallet.ID, offset, limit int) (siacoins siacoins = append(siacoins, siacoin) } - return rows.Err() + + if err := rows.Err(); err != nil { + return err + } + + // retrieve the merkle proofs for the siacoin elements + if s.indexMode == wallet.IndexModeFull { + indices := make([]uint64, len(siacoins)) + for i, se := range siacoins { + indices[i] = se.LeafIndex + } + proofs, err := fillElementProofs(tx, indices) + if err != nil { + return fmt.Errorf("failed to fill element proofs: %w", err) + } + for i, proof := range proofs { + siacoins[i].MerkleProof = proof + } + } + return nil }) return } @@ -353,7 +285,25 @@ func (s *Store) WalletSiafundOutputs(id wallet.ID, offset, limit int) (siafunds } siafunds = append(siafunds, siafund) } - return rows.Err() + if err := rows.Err(); err != nil { + return err + } + + // retrieve the merkle proofs for the siacoin elements + if s.indexMode == wallet.IndexModeFull { + indices := make([]uint64, len(siafunds)) + for i, se := range siafunds { + indices[i] = se.LeafIndex + } + proofs, err := fillElementProofs(tx, indices) + if err != nil { + return fmt.Errorf("failed to fill element proofs: %w", err) + } + for i, proof := range proofs { + siafunds[i].MerkleProof = proof + } + } + return nil }) return } @@ -434,10 +384,142 @@ WHERE wa.wallet_id=$1 AND sa.sia_address=$2 LIMIT 1` return } +func scanSiacoinElement(s scanner) (se types.SiacoinElement, err error) { + err = s.Scan(decode(&se.ID), decode(&se.SiacoinOutput.Value), decodeSlice(&se.MerkleProof), &se.LeafIndex, &se.MaturityHeight, decode(&se.SiacoinOutput.Address)) + return +} + +func scanSiafundElement(s scanner) (se types.SiafundElement, err error) { + err = s.Scan(decode(&se.ID), &se.LeafIndex, decodeSlice(&se.MerkleProof), &se.SiafundOutput.Value, decode(&se.ClaimStart), decode(&se.SiafundOutput.Address)) + return +} + +func insertAddress(tx *txn, addr types.Address) (id int64, err error) { + const query = `INSERT INTO sia_addresses (sia_address, siacoin_balance, immature_siacoin_balance, siafund_balance) +VALUES ($1, $2, $3, 0) ON CONFLICT (sia_address) DO UPDATE SET sia_address=EXCLUDED.sia_address +RETURNING id` + + err = tx.QueryRow(query, encode(addr), encode(types.ZeroCurrency), encode(types.ZeroCurrency)).Scan(&id) + return +} + +func fillElementProofs(tx *txn, indices []uint64) (proofs [][]types.Hash256, _ error) { + if len(indices) == 0 { + return nil, nil + } + + var numLeaves uint64 + if err := tx.QueryRow(`SELECT element_num_leaves FROM global_settings LIMIT 1`).Scan(&numLeaves); err != nil { + return nil, fmt.Errorf("failed to query state tree leaves: %w", err) + } + + stmt, err := tx.Prepare(`SELECT value FROM state_tree WHERE row=? AND column=?`) + if err != nil { + return nil, fmt.Errorf("failed to prepare statement: %w", err) + } + defer stmt.Close() + + data := make(map[uint64]map[uint64]types.Hash256) + for _, leafIndex := range indices { + proof := make([]types.Hash256, bits.Len64(leafIndex^numLeaves)-1) + for j := range proof { + row, col := uint64(j), (leafIndex>>j)^1 + + // check if the hash is already in the cache + if h, ok := data[row][col]; ok { + proof[j] = h + continue + } + + // query the hash from the database + if err := stmt.QueryRow(row, col).Scan(decode(&proof[j])); err != nil { + return nil, fmt.Errorf("failed to query state element (%d,%d): %w", row, col, err) + } + + // cache the hash + if _, ok := data[row]; !ok { + data[row] = make(map[uint64]types.Hash256) + } + data[row][col] = proof[j] + } + proofs = append(proofs, proof) + } + return +} + +func scanEvent(s scanner) (ev wallet.Event, eventID int64, err error) { + var eventType string + var eventBuf []byte + + err = s.Scan(&eventID, decode(&ev.ID), &ev.MaturityHeight, decode(&ev.Timestamp), &ev.Index.Height, decode(&ev.Index.ID), &eventType, &eventBuf) + if err != nil { + return + } + + switch eventType { + case wallet.EventTypeTransaction: + var tx wallet.EventTransaction + if err = json.Unmarshal(eventBuf, &tx); err != nil { + return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal transaction event: %w", err) + } + ev.Data = &tx + case wallet.EventTypeContractPayout: + var m wallet.EventContractPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal missed file contract event: %w", err) + } + ev.Data = &m + case wallet.EventTypeMinerPayout: + var m wallet.EventMinerPayout + if err = json.Unmarshal(eventBuf, &m); err != nil { + return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal payout event: %w", err) + } + ev.Data = &m + case wallet.EventTypeFoundationSubsidy: + var m wallet.EventFoundationSubsidy + if err = json.Unmarshal(eventBuf, &m); err != nil { + return wallet.Event{}, 0, fmt.Errorf("failed to unmarshal foundation subsidy event: %w", err) + } + ev.Data = &m + default: + return wallet.Event{}, 0, fmt.Errorf("unknown event type: %s", eventType) + } + return +} + +func getWalletEvents(tx *txn, id wallet.ID, offset, limit int) (events []wallet.Event, eventIDs []int64, err error) { + const query = `SELECT ev.id, ev.event_id, ev.maturity_height, ev.date_created, ci.height, ci.block_id, ev.event_type, ev.event_data + FROM events ev + INNER JOIN chain_indices ci ON (ev.chain_index_id = ci.id) + WHERE ev.id IN (SELECT event_id FROM event_addresses WHERE address_id IN (SELECT address_id FROM wallet_addresses WHERE wallet_id=$1)) + ORDER BY ev.maturity_height DESC, ev.id DESC + LIMIT $2 OFFSET $3` + + rows, err := tx.Query(query, id, limit, offset) + if err != nil { + return nil, nil, err + } + defer rows.Close() + + for rows.Next() { + event, eventID, err := scanEvent(rows) + if err != nil { + return nil, nil, fmt.Errorf("failed to scan event: %w", err) + } + + events = append(events, event) + eventIDs = append(eventIDs, eventID) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return +} + func walletExists(tx *txn, id wallet.ID) error { - const query = `SELECT id FROM wallets WHERE id=$1` - var dummyID int64 - err := tx.QueryRow(query, id).Scan(&dummyID) + const query = `SELECT 1 FROM wallets WHERE id=$1` + var dummy int + err := tx.QueryRow(query, id).Scan(&dummy) if errors.Is(err, sql.ErrNoRows) { return wallet.ErrNotFound } diff --git a/wallet/manager.go b/wallet/manager.go index c75aefd..94f342d 100644 --- a/wallet/manager.go +++ b/wallet/manager.go @@ -2,6 +2,7 @@ package wallet import ( "context" + "errors" "fmt" "sync" "time" @@ -12,7 +13,32 @@ import ( "go.uber.org/zap" ) +// IndexMode represents the index mode of the wallet manager. The index mode +// determines how the wallet manager stores the consensus state. +// +// IndexModePartial - The wallet manager scans the blockchain starting at +// genesis. Only state from addresses that are registered with a +// wallet will be stored. If an address is added to a wallet after the +// scan completes, the manager will need to rescan. +// +// IndexModeFull - The wallet manager scans the blockchain starting at genesis +// and stores the state of all addresses. +// +// IndexModeNone - The wallet manager does not scan the blockchain. This is +// useful for multiple nodes sharing the same database. None should only be used +// when connecting to a database that is in "Full" mode. +const ( + IndexModePartial IndexMode = iota + IndexModeFull + IndexModeNone +) + +const syncBatchSize = 250 + type ( + // An IndexMode determines the chain state that the wallet manager stores. + IndexMode uint8 + // A ChainManager manages the consensus state ChainManager interface { Tip() types.ChainIndex @@ -46,11 +72,14 @@ type ( AddressSiacoinOutputs(address types.Address, offset, limit int) (siacoins []types.SiacoinElement, err error) AddressSiafundOutputs(address types.Address, offset, limit int) (siafunds []types.SiafundElement, err error) + SetIndexMode(IndexMode) error LastCommittedIndex() (types.ChainIndex, error) } // A Manager manages wallets. Manager struct { + indexMode IndexMode + chain ChainManager store Store log *zap.Logger @@ -61,6 +90,25 @@ type ( } ) +// String returns the string representation of the index mode. +func (i IndexMode) String() string { + switch i { + case IndexModePartial: + return "partial" + case IndexModeFull: + return "full" + case IndexModeNone: + return "none" + default: + return "unknown" + } +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (i IndexMode) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + // Tip returns the last scanned chain index of the manager. func (m *Manager) Tip() (types.ChainIndex, error) { return m.store.LastCommittedIndex() @@ -159,6 +207,10 @@ func (m *Manager) Reserve(ids []types.Hash256, duration time.Duration) error { // Scan rescans the chain starting from the given index. The scan will complete // when the chain manager reaches the current tip or the context is canceled. func (m *Manager) Scan(ctx context.Context, index types.ChainIndex) error { + if m.indexMode != IndexModePartial { + return fmt.Errorf("scans are disabled in index mode %s", m.indexMode) + } + ctx, cancel, err := m.tg.AddWithContext(ctx) if err != nil { return err @@ -170,6 +222,11 @@ func (m *Manager) Scan(ctx context.Context, index types.ChainIndex) error { return syncStore(ctx, m.store, m.chain, index) } +// IndexMode returns the index mode of the wallet manager. +func (m *Manager) IndexMode() IndexMode { + return m.indexMode +} + // Close closes the wallet manager. func (m *Manager) Close() error { m.tg.Stop() @@ -183,7 +240,7 @@ func syncStore(ctx context.Context, store Store, cm ChainManager, index types.Ch return ctx.Err() default: } - crus, caus, err := cm.UpdatesSince(index, 1000) + crus, caus, err := cm.UpdatesSince(index, syncBatchSize) if err != nil { return fmt.Errorf("failed to subscribe to chain manager: %w", err) } else if err := store.UpdateChainState(crus, caus); err != nil { @@ -195,14 +252,29 @@ func syncStore(ctx context.Context, store Store, cm ChainManager, index types.Ch } // NewManager creates a new wallet manager. -func NewManager(cm ChainManager, store Store, log *zap.Logger) *Manager { +func NewManager(cm ChainManager, store Store, opts ...Option) (*Manager, error) { m := &Manager{ + indexMode: IndexModePartial, + chain: cm, store: store, - log: log, + log: zap.NewNop(), tg: threadgroup.New(), } + for _, opt := range opts { + opt(m) + } + + // if the index mode is none, skip setting the index mode in the store + // and return the manager + if m.indexMode == IndexModeNone { + return m, nil + } else if err := store.SetIndexMode(m.indexMode); err != nil { + return nil, err + } + + // start a goroutine to sync the store with the chain manager reorgChan := make(chan struct{}, 1) reorgChan <- struct{}{} unsubscribe := cm.OnReorg(func(index types.ChainIndex) { @@ -215,6 +287,7 @@ func NewManager(cm ChainManager, store Store, log *zap.Logger) *Manager { go func() { defer unsubscribe() + log := m.log.Named("sync") ctx, cancel, err := m.tg.AddWithContext(context.Background()) if err != nil { log.Panic("failed to add to threadgroup", zap.Error(err)) @@ -232,12 +305,12 @@ func NewManager(cm ChainManager, store Store, log *zap.Logger) *Manager { // update the store lastTip, err := store.LastCommittedIndex() if err != nil { - log.Error("failed to get last committed index", zap.Error(err)) - } else if err := syncStore(ctx, store, cm, lastTip); err != nil { - log.Error("failed to sync store", zap.Error(err)) + log.Panic("failed to get last committed index", zap.Error(err)) + } else if err := syncStore(ctx, store, cm, lastTip); err != nil && !errors.Is(err, context.Canceled) { + log.Panic("failed to sync store", zap.Error(err)) } m.mu.Unlock() } }() - return m + return m, nil } diff --git a/wallet/options.go b/wallet/options.go new file mode 100644 index 0000000..7033321 --- /dev/null +++ b/wallet/options.go @@ -0,0 +1,20 @@ +package wallet + +import "go.uber.org/zap" + +// An Option configures a wallet Manager. +type Option func(*Manager) + +// WithLogger sets the logger used by the manager. +func WithLogger(log *zap.Logger) Option { + return func(m *Manager) { + m.log = log + } +} + +// WithIndexMode sets the index mode used by the manager. +func WithIndexMode(mode IndexMode) Option { + return func(m *Manager) { + m.indexMode = mode + } +} diff --git a/wallet/update.go b/wallet/update.go index 8a42c6a..83852b0 100644 --- a/wallet/update.go +++ b/wallet/update.go @@ -9,6 +9,13 @@ import ( ) type ( + // A stateTreeUpdater is an interface for applying and reverting + // Merkle tree updates. + stateTreeUpdater interface { + UpdateElementProof(e *types.StateElement) + ForEachTreeNode(fn func(row uint64, col uint64, h types.Hash256)) + } + // AddressBalance pairs an address with its balance. AddressBalance struct { Address types.Address `json:"address"` @@ -18,6 +25,7 @@ type ( // AppliedState contains all state changes made to a store after applying a chain // update. AppliedState struct { + NumLeaves uint64 Events []Event CreatedSiacoinElements []types.SiacoinElement SpentSiacoinElements []types.SiacoinElement @@ -28,12 +36,21 @@ type ( // RevertedState contains all state changes made to a store after reverting // a chain update. RevertedState struct { + NumLeaves uint64 UnspentSiacoinElements []types.SiacoinElement DeletedSiacoinElements []types.SiacoinElement UnspentSiafundElements []types.SiafundElement DeletedSiafundElements []types.SiafundElement } + // A TreeNodeUpdate contains the hash of a Merkle tree node and its row and + // column indices. + TreeNodeUpdate struct { + Hash types.Hash256 + Row int + Column int + } + // An UpdateTx atomically updates the state of a store. UpdateTx interface { SiacoinStateElements() ([]types.StateElement, error) @@ -42,6 +59,8 @@ type ( SiafundStateElements() ([]types.StateElement, error) UpdateSiafundStateElements([]types.StateElement) error + UpdateStateTree([]TreeNodeUpdate) error + AddressRelevant(types.Address) (bool, error) ApplyIndex(types.ChainIndex, AppliedState) error @@ -49,9 +68,53 @@ type ( } ) +// updateStateElements updates the state elements in a store according to the +// changes made by a chain update. +func updateStateElements(tx UpdateTx, update stateTreeUpdater, indexMode IndexMode) error { + if indexMode == IndexModeNone { + panic("updateStateElements called with IndexModeNone") // developer error + } + + if indexMode == IndexModeFull { + var updates []TreeNodeUpdate + update.ForEachTreeNode(func(row, col uint64, h types.Hash256) { + updates = append(updates, TreeNodeUpdate{h, int(row), int(col)}) + }) + return tx.UpdateStateTree(updates) + } else { + // fetch all siacoin and siafund state elements + siacoinStateElements, err := tx.SiacoinStateElements() + if err != nil { + return fmt.Errorf("failed to get siacoin state elements: %w", err) + } + + // update siacoin element proofs + for i := range siacoinStateElements { + update.UpdateElementProof(&siacoinStateElements[i]) + } + + if err := tx.UpdateSiacoinStateElements(siacoinStateElements); err != nil { + return fmt.Errorf("failed to update siacoin state elements: %w", err) + } + + siafundStateElements, err := tx.SiafundStateElements() + if err != nil { + return fmt.Errorf("failed to get siafund state elements: %w", err) + } + + // update siafund element proofs + for i := range siafundStateElements { + update.UpdateElementProof(&siafundStateElements[i]) + } + return tx.UpdateSiafundStateElements(siafundStateElements) + } +} + // applyChainUpdate atomically applies a chain update to a store -func applyChainUpdate(tx UpdateTx, cau chain.ApplyUpdate) error { - var applied AppliedState +func applyChainUpdate(tx UpdateTx, cau chain.ApplyUpdate, indexMode IndexMode) error { + applied := AppliedState{ + NumLeaves: cau.State.Elements.NumLeaves, + } // determine which siacoin and siafund elements are ephemeral // @@ -123,44 +186,19 @@ func applyChainUpdate(tx UpdateTx, cau chain.ApplyUpdate) error { } applied.Events = AppliedEvents(cau.State, cau.Block, cau, relevant) - // fetch all siacoin and siafund state elements - siacoinStateElements, err := tx.SiacoinStateElements() - if err != nil { - return fmt.Errorf("failed to get siacoin state elements: %w", err) - } - - // update siacoin element proofs - for i := range siacoinStateElements { - cau.UpdateElementProof(&siacoinStateElements[i]) - } - - if err := tx.UpdateSiacoinStateElements(siacoinStateElements); err != nil { - return fmt.Errorf("failed to update siacoin state elements: %w", err) - } - - siafundStateElements, err := tx.SiafundStateElements() - if err != nil { - return fmt.Errorf("failed to get siafund state elements: %w", err) - } - - // update siafund element proofs - for i := range siafundStateElements { - cau.UpdateElementProof(&siafundStateElements[i]) - } - - if err := tx.UpdateSiafundStateElements(siafundStateElements); err != nil { - return fmt.Errorf("failed to update siacoin state elements: %w", err) - } - - if err := tx.ApplyIndex(cau.State.Index, applied); err != nil { - return fmt.Errorf("failed to apply chain update %q: %w", cau.State.Index, err) + if err := updateStateElements(tx, cau, indexMode); err != nil { + return fmt.Errorf("failed to update state elements: %w", err) + } else if err := tx.ApplyIndex(cau.State.Index, applied); err != nil { + return fmt.Errorf("failed to apply index: %w", err) } return nil } // revertChainUpdate atomically reverts a chain update from a store -func revertChainUpdate(tx UpdateTx, cru chain.RevertUpdate, revertedIndex types.ChainIndex) error { - var reverted RevertedState +func revertChainUpdate(tx UpdateTx, cru chain.RevertUpdate, revertedIndex types.ChainIndex, indexMode IndexMode) error { + reverted := RevertedState{ + NumLeaves: cru.State.Elements.NumLeaves, + } // determine which siacoin and siafund elements are ephemeral // @@ -226,43 +264,20 @@ func revertChainUpdate(tx UpdateTx, cru chain.RevertUpdate, revertedIndex types. }) if err := tx.RevertIndex(revertedIndex, reverted); err != nil { - return fmt.Errorf("failed to revert index %q: %w", revertedIndex, err) - } - - siacoinElements, err := tx.SiacoinStateElements() - if err != nil { - return fmt.Errorf("failed to get siacoin state elements: %w", err) - } - for i := range siacoinElements { - cru.UpdateElementProof(&siacoinElements[i]) - } - if err := tx.UpdateSiacoinStateElements(siacoinElements); err != nil { - return fmt.Errorf("failed to update siacoin state elements: %w", err) + return fmt.Errorf("failed to revert index: %w", err) } - - // update siafund element proofs - siafundElements, err := tx.SiafundStateElements() - if err != nil { - return fmt.Errorf("failed to get siafund state elements: %w", err) - } - for i := range siafundElements { - cru.UpdateElementProof(&siafundElements[i]) - } - if err := tx.UpdateSiafundStateElements(siafundElements); err != nil { - return fmt.Errorf("failed to update siafund state elements: %w", err) - } - return nil + return updateStateElements(tx, cru, indexMode) } // UpdateChainState atomically updates the state of a store with a set of // updates from the chain manager. -func UpdateChainState(tx UpdateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate, log *zap.Logger) error { +func UpdateChainState(tx UpdateTx, reverted []chain.RevertUpdate, applied []chain.ApplyUpdate, indexMode IndexMode, log *zap.Logger) error { for _, cru := range reverted { revertedIndex := types.ChainIndex{ ID: cru.Block.ID(), Height: cru.State.Index.Height + 1, } - if err := revertChainUpdate(tx, cru, revertedIndex); err != nil { + if err := revertChainUpdate(tx, cru, revertedIndex, indexMode); err != nil { return fmt.Errorf("failed to revert chain update %q: %w", revertedIndex, err) } log.Debug("reverted chain update", zap.Stringer("blockID", revertedIndex.ID), zap.Uint64("height", revertedIndex.Height)) @@ -270,7 +285,7 @@ func UpdateChainState(tx UpdateTx, reverted []chain.RevertUpdate, applied []chai for _, cau := range applied { // apply the chain update - if err := applyChainUpdate(tx, cau); err != nil { + if err := applyChainUpdate(tx, cau, indexMode); err != nil { return fmt.Errorf("failed to apply chain update %q: %w", cau.State.Index, err) } log.Debug("applied chain update", zap.Stringer("blockID", cau.State.Index.ID), zap.Uint64("height", cau.State.Index.Height)) diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 3fbe6f4..76fbfb0 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -98,200 +98,220 @@ func TestReorg(t *testing.T) { pk := types.GeneratePrivateKey() addr := types.StandardUnlockHash(pk.PublicKey()) - log := zaptest.NewLogger(t) - dir := t.TempDir() - db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) - if err != nil { - t.Fatal(err) - } - defer db.Close() - - bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) - if err != nil { - t.Fatal(err) - } - defer bdb.Close() + setupNode := func(t *testing.T, mode wallet.IndexMode) (consensus.State, *sqlite.Store, *chain.Manager, *wallet.Manager) { + t.Helper() - network, genesisBlock := testV1Network(types.VoidAddress) // don't care about siafunds + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { db.Close() }) - store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) - if err != nil { - t.Fatal(err) - } - cm := chain.NewManager(store, genesisState) + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { bdb.Close() }) - wm := wallet.NewManager(cm, db, log.Named("wallet")) - defer wm.Close() + network, genesisBlock := testV1Network(types.VoidAddress) // don't care about siafunds - w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) - if err != nil { - t.Fatal(err) - } else if err := wm.AddAddress(w.ID, wallet.Address{Address: addr}); err != nil { - t.Fatal(err) - } + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + cm := chain.NewManager(store, genesisState) - expectedPayout := cm.TipState().BlockReward() - maturityHeight := cm.TipState().MaturityHeight() - // mine a block sending the payout to the wallet - if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { - t.Fatal(err) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet")), wallet.WithIndexMode(mode)) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { wm.Close() }) + return genesisState, db, cm, wm } - waitForBlock(t, cm, db) - assertBalance := func(siacoin, immature types.Currency) error { - b, err := wm.WalletBalance(w.ID) + testReorg := func(t *testing.T, genesisState consensus.State, db *sqlite.Store, cm *chain.Manager, wm *wallet.Manager) { + w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) if err != nil { - return fmt.Errorf("failed to check balance: %w", err) - } else if !b.Siacoins.Equals(siacoin) { - return fmt.Errorf("expected siacoin balance %v, got %v", siacoin, b.Siacoins) - } else if !b.ImmatureSiacoins.Equals(immature) { - return fmt.Errorf("expected immature siacoin balance %v, got %v", immature, b.ImmatureSiacoins) + t.Fatal(err) + } else if err := wm.AddAddress(w.ID, wallet.Address{Address: addr}); err != nil { + t.Fatal(err) } - return nil - } - if err := assertBalance(types.ZeroCurrency, expectedPayout); err != nil { - t.Fatal(err) - } + expectedPayout := cm.TipState().BlockReward() + maturityHeight := cm.TipState().MaturityHeight() + // mine a block sending the payout to the wallet + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) - // check that a payout event was recorded - events, err := wm.Events(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(events) != 1 { - t.Fatalf("expected 1 event, got %v", len(events)) - } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { - t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) - } + assertBalance := func(siacoin, immature types.Currency) error { + b, err := wm.WalletBalance(w.ID) + if err != nil { + return fmt.Errorf("failed to check balance: %w", err) + } else if !b.Siacoins.Equals(siacoin) { + return fmt.Errorf("expected siacoin balance %v, got %v", siacoin, b.Siacoins) + } else if !b.ImmatureSiacoins.Equals(immature) { + return fmt.Errorf("expected immature siacoin balance %v, got %v", immature, b.ImmatureSiacoins) + } + return nil + } - // check that the utxo was created - utxos, err := wm.UnspentSiacoinOutputs(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(utxos) != 1 { - t.Fatalf("expected 1 output, got %v", len(utxos)) - } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { - t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) - } else if utxos[0].MaturityHeight != maturityHeight { - t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) - } + if err := assertBalance(types.ZeroCurrency, expectedPayout); err != nil { + t.Fatal(err) + } - // mine to trigger a reorg - var blocks []types.Block - state := genesisState - for i := 0; i < 10; i++ { - block := mineBlock(state, nil, types.VoidAddress) - blocks = append(blocks, block) - state.Index.ID = block.ID() - state.Index.Height++ - } - if err := cm.AddBlocks(blocks); err != nil { - t.Fatal(err) - } - waitForBlock(t, cm, db) + // check that a payout event was recorded + events, err := wm.Events(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) + } - // check that the balance was reverted - if err := assertBalance(types.ZeroCurrency, types.ZeroCurrency); err != nil { - t.Fatal(err) - } + // check that the utxo was created + utxos, err := wm.UnspentSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } - // check that the payout event was reverted - events, err = wm.Events(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(events) != 0 { - t.Fatalf("expected 0 events, got %v", len(events)) - } + // mine to trigger a reorg + var blocks []types.Block + state := genesisState + for i := 0; i < 10; i++ { + block := mineBlock(state, nil, types.VoidAddress) + blocks = append(blocks, block) + state.Index.ID = block.ID() + state.Index.Height++ + } + if err := cm.AddBlocks(blocks); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) - // check that the utxo was removed - utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(utxos) != 0 { - t.Fatalf("expected 0 outputs, got %v", len(utxos)) - } + // check that the balance was reverted + if err := assertBalance(types.ZeroCurrency, types.ZeroCurrency); err != nil { + t.Fatal(err) + } - // mine a new payout - expectedPayout = cm.TipState().BlockReward() - maturityHeight = cm.TipState().MaturityHeight() - if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { - t.Fatal(err) - } - waitForBlock(t, cm, db) + // check that the payout event was reverted + events, err = wm.Events(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(events) != 0 { + t.Fatalf("expected 0 events, got %v", len(events)) + } - // check that the payout was received - if err := assertBalance(types.ZeroCurrency, expectedPayout); err != nil { - t.Fatal(err) - } + // check that the utxo was removed + utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 0 { + t.Fatalf("expected 0 outputs, got %v", len(utxos)) + } - // check that a payout event was recorded - events, err = wm.Events(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(events) != 1 { - t.Fatalf("expected 1 event, got %v", len(events)) - } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { - t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) - } + // mine a new payout + expectedPayout = cm.TipState().BlockReward() + maturityHeight = cm.TipState().MaturityHeight() + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) - // check that the utxo was created - utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(utxos) != 1 { - t.Fatalf("expected 1 output, got %v", len(utxos)) - } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { - t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) - } else if utxos[0].MaturityHeight != maturityHeight { - t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) - } + // check that the payout was received + if err := assertBalance(types.ZeroCurrency, expectedPayout); err != nil { + t.Fatal(err) + } - // mine until the payout matures - var prevState consensus.State - for i := cm.TipState().Index.Height; i < maturityHeight+1; i++ { - if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + // check that a payout event was recorded + events, err = wm.Events(w.ID, 0, 100) + if err != nil { t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected payout event, got %v", events[0].Data.EventType()) } - if i == maturityHeight-5 { - prevState = cm.TipState() + + // check that the utxo was created + utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) } - } - waitForBlock(t, cm, db) - // check that the balance was updated - if err := assertBalance(expectedPayout, types.ZeroCurrency); err != nil { - t.Fatal(err) - } + // mine until the payout matures + var prevState consensus.State + for i := cm.TipState().Index.Height; i < maturityHeight+1; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + if i == maturityHeight-5 { + prevState = cm.TipState() + } + } + waitForBlock(t, cm, db) - // reorg the last few blocks to re-mature the payout - blocks = nil - state = prevState - for i := 0; i < 10; i++ { - blocks = append(blocks, mineBlock(state, nil, types.VoidAddress)) - state.Index.ID = blocks[len(blocks)-1].ID() - state.Index.Height++ - } - if err := cm.AddBlocks(blocks); err != nil { - t.Fatal(err) - } - waitForBlock(t, cm, db) + // check that the balance was updated + if err := assertBalance(expectedPayout, types.ZeroCurrency); err != nil { + t.Fatal(err) + } - // check that the balance is correct - if err := assertBalance(expectedPayout, types.ZeroCurrency); err != nil { - t.Fatal(err) - } + // reorg the last few blocks to re-mature the payout + blocks = nil + state = prevState + for i := 0; i < 10; i++ { + blocks = append(blocks, mineBlock(state, nil, types.VoidAddress)) + state.Index.ID = blocks[len(blocks)-1].ID() + state.Index.Height++ + } + if err := cm.AddBlocks(blocks); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) - // check that only the single utxo still exists - utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) - if err != nil { - t.Fatal(err) - } else if len(utxos) != 1 { - t.Fatalf("expected 1 output, got %v", len(utxos)) - } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { - t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) - } else if utxos[0].MaturityHeight != maturityHeight { - t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + // check that the balance is correct + if err := assertBalance(expectedPayout, types.ZeroCurrency); err != nil { + t.Fatal(err) + } + + // check that only the single utxo still exists + utxos, err = wm.UnspentSiacoinOutputs(w.ID, 0, 100) + if err != nil { + t.Fatal(err) + } else if len(utxos) != 1 { + t.Fatalf("expected 1 output, got %v", len(utxos)) + } else if utxos[0].SiacoinOutput.Value.Cmp(expectedPayout) != 0 { + t.Fatalf("expected %v, got %v", expectedPayout, utxos[0].SiacoinOutput.Value) + } else if utxos[0].MaturityHeight != maturityHeight { + t.Fatalf("expected %v, got %v", maturityHeight, utxos[0].MaturityHeight) + } } + + t.Run("IndexModePartial", func(t *testing.T) { + state, db, cm, w := setupNode(t, wallet.IndexModePartial) + testReorg(t, state, db, cm, w) + }) + + t.Run("IndexModeFull", func(t *testing.T) { + state, db, cm, w := setupNode(t, wallet.IndexModeFull) + testReorg(t, state, db, cm, w) + }) } func TestEphemeralBalance(t *testing.T) { @@ -320,7 +340,10 @@ func TestEphemeralBalance(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) @@ -513,7 +536,10 @@ func TestWalletAddresses(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() // Add a wallet @@ -644,7 +670,10 @@ func TestScan(t *testing.T) { cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() pk2 := types.GeneratePrivateKey() @@ -795,7 +824,10 @@ func TestSiafunds(t *testing.T) { cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() pk2 := types.GeneratePrivateKey() @@ -931,7 +963,10 @@ func TestOrphans(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) @@ -1062,7 +1097,10 @@ func TestOrphans(t *testing.T) { t.Fatal(err) } - wm = wallet.NewManager(cm, db, log.Named("wallet")) + wm, err = wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() waitForBlock(t, cm, db) @@ -1091,6 +1129,217 @@ func TestOrphans(t *testing.T) { } } +func TestFullIndex(t *testing.T) { + pk := types.GeneratePrivateKey() + addr := types.StandardUnlockHash(pk.PublicKey()) + + pk2 := types.GeneratePrivateKey() + addr2 := types.StandardUnlockHash(pk2.PublicKey()) + + log := zaptest.NewLogger(t) + dir := t.TempDir() + db, err := sqlite.OpenDatabase(filepath.Join(dir, "walletd.sqlite3"), log.Named("sqlite3")) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + bdb, err := coreutils.OpenBoltChainDB(filepath.Join(dir, "consensus.db")) + if err != nil { + t.Fatal(err) + } + defer bdb.Close() + + network, genesisBlock := testV2Network(addr2) + store, genesisState, err := chain.NewDBStore(bdb, network, genesisBlock) + if err != nil { + t.Fatal(err) + } + cm := chain.NewManager(store, genesisState) + + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet")), wallet.WithIndexMode(wallet.IndexModeFull)) + if err != nil { + t.Fatal(err) + } + defer wm.Close() + + waitForBlock(t, cm, db) + + assertBalance := func(t *testing.T, address types.Address, siacoin, immature types.Currency, siafund uint64) { + t.Helper() + + b, err := wm.AddressBalance(address) + if err != nil { + t.Fatal(err) + } else if !b.ImmatureSiacoins.Equals(immature) { + t.Fatalf("expected immature siacoin balance %v, got %v", immature, b.ImmatureSiacoins) + } else if !b.Siacoins.Equals(siacoin) { + t.Fatalf("expected siacoin balance %v, got %v", siacoin, b.Siacoins) + } else if b.Siafunds != siafund { + t.Fatalf("expected siafund balance %v, got %v", siafund, b.Siafunds) + } + } + + // check the events are empty for the first address + if events, err := wm.AddressEvents(addr, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 0 { + t.Fatalf("expected 0 events, got %v", len(events)) + } + + // assert that the airdropped siafunds are on the second address + assertBalance(t, addr2, types.ZeroCurrency, types.ZeroCurrency, cm.TipState().SiafundCount()) + // check the events for the air dropped siafunds + if events, err := wm.AddressEvents(addr2, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } + + // mine a block and send the payout to the first address + expectedBalance1 := cm.TipState().BlockReward() + maturityHeight := cm.TipState().MaturityHeight() + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, addr)}); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) + + // check the payout was received + if events, err := wm.AddressEvents(addr, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected miner payout event, got %v", events[0].Data.EventType()) + } + + assertBalance(t, addr, types.ZeroCurrency, expectedBalance1, 0) + + // mine until the payout matures + for i := cm.TipState().Index.Height; i < maturityHeight; i++ { + if err := cm.AddBlocks([]types.Block{mineBlock(cm.TipState(), nil, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + } + waitForBlock(t, cm, db) + + // check that the events did not change + if events, err := wm.AddressEvents(addr, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 1 { + t.Fatalf("expected 1 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeMinerPayout { + t.Fatalf("expected miner payout event, got %v", events[0].Data.EventType()) + } + + assertBalance(t, addr, expectedBalance1, types.ZeroCurrency, 0) + assertBalance(t, addr2, types.ZeroCurrency, types.ZeroCurrency, cm.TipState().SiafundCount()) + + // send half siacoins to the second address + utxos, err := wm.AddressSiacoinOutputs(addr, 0, 100) + if err != nil { + t.Fatal(err) + } + + policy := types.PolicyTypeUnlockConditions(types.StandardUnlockConditions(pk.PublicKey())) + txn := types.V2Transaction{ + SiacoinInputs: []types.V2SiacoinInput{ + { + Parent: utxos[0], + SatisfiedPolicy: types.SatisfiedPolicy{ + Policy: types.SpendPolicy{ + Type: policy, + }, + }, + }, + }, + SiacoinOutputs: []types.SiacoinOutput{ + {Address: addr2, Value: utxos[0].SiacoinOutput.Value.Div64(2)}, + {Address: addr, Value: utxos[0].SiacoinOutput.Value.Div64(2)}, + }, + } + txn.SiacoinInputs[0].SatisfiedPolicy.Signatures = []types.Signature{pk.SignHash(cm.TipState().InputSigHash(txn))} + + if err := cm.AddBlocks([]types.Block{mineV2Block(cm.TipState(), []types.V2Transaction{txn}, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) + + assertBalance(t, addr, expectedBalance1.Div64(2), types.ZeroCurrency, 0) + assertBalance(t, addr2, expectedBalance1.Div64(2), types.ZeroCurrency, cm.TipState().SiafundCount()) + + // check the events for the transaction + if events, err := wm.AddressEvents(addr, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 2 { + t.Fatalf("expected 2 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } + + // check the events for the second address + if events, err := wm.AddressEvents(addr2, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 2 { + t.Fatalf("expected 2 event, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } + + sf, err := wm.AddressSiafundOutputs(addr2, 0, 100) + if err != nil { + t.Fatal(err) + } + + // send the siafunds to the first address + policy = types.PolicyTypeUnlockConditions(types.StandardUnlockConditions(pk2.PublicKey())) + txn = types.V2Transaction{ + SiafundInputs: []types.V2SiafundInput{ + { + Parent: sf[0], + SatisfiedPolicy: types.SatisfiedPolicy{ + Policy: types.SpendPolicy{ + Type: policy, + }, + }, + ClaimAddress: addr2, // claim address shouldn't create an event since the value is 0 + }, + }, + SiafundOutputs: []types.SiafundOutput{ + {Address: addr, Value: sf[0].SiafundOutput.Value}, + }, + } + txn.SiafundInputs[0].SatisfiedPolicy.Signatures = []types.Signature{pk2.SignHash(cm.TipState().InputSigHash(txn))} + + if err := cm.AddBlocks([]types.Block{mineV2Block(cm.TipState(), []types.V2Transaction{txn}, types.VoidAddress)}); err != nil { + t.Fatal(err) + } + waitForBlock(t, cm, db) + + assertBalance(t, addr, expectedBalance1.Div64(2), types.ZeroCurrency, cm.TipState().SiafundCount()) + assertBalance(t, addr2, expectedBalance1.Div64(2), types.ZeroCurrency, 0) + + // check the events for the transaction + if events, err := wm.AddressEvents(addr2, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 3 { + t.Fatalf("expected 3 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } + + // check the events for the first address + if events, err := wm.AddressEvents(addr, 0, 100); err != nil { + t.Fatal(err) + } else if len(events) != 3 { + t.Fatalf("expected 3 events, got %v", len(events)) + } else if events[0].Data.EventType() != wallet.EventTypeTransaction { + t.Fatalf("expected transaction event, got %v", events[0].Data.EventType()) + } +} + func TestV2(t *testing.T) { pk := types.GeneratePrivateKey() addr := types.StandardUnlockHash(pk.PublicKey()) @@ -1117,7 +1366,10 @@ func TestV2(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) @@ -1236,7 +1488,10 @@ func TestScanV2(t *testing.T) { cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() pk2 := types.GeneratePrivateKey() @@ -1414,7 +1669,10 @@ func TestReorgV2(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) @@ -1647,7 +1905,10 @@ func TestOrphansV2(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"}) @@ -1769,7 +2030,10 @@ func TestOrphansV2(t *testing.T) { t.Fatal(err) } - wm = wallet.NewManager(cm, db, log.Named("wallet")) + wm, err = wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() waitForBlock(t, cm, db) @@ -1856,7 +2120,10 @@ func TestDeleteWallet(t *testing.T) { } cm := chain.NewManager(store, genesisState) - wm := wallet.NewManager(cm, db, log.Named("wallet")) + wm, err := wallet.NewManager(cm, db, wallet.WithLogger(log.Named("wallet"))) + if err != nil { + t.Fatal(err) + } defer wm.Close() w, err := wm.AddWallet(wallet.Wallet{Name: "test"})