diff --git a/results.go b/results.go index a232cd5c..597def8c 100644 --- a/results.go +++ b/results.go @@ -38,16 +38,24 @@ func NewResults(targetNames []string, testModes []string) *Results { } } -// SingleResult represents the verification result from a single target, with the schema: -// SingleResult[schema][table][mode] = test output. -type SingleResult map[string]map[string]map[string]string +// DatabaseResult represents the verification result from a single target database: +// DatabaseResult[schema][table][mode] = test output. +type DatabaseResult map[string]SchemaResult + +// SchemaResult represents the verification result from a single schema: +// SchemaResult[table][mode] = test output. +type SchemaResult map[string]TableResult + +// TableResult represents the verification result from a single table: +// TableResult[mode] = test output. +type TableResult map[string]string // AddResult adds a SingleResult from a test on a specific target to the Results object. -func (r *Results) AddResult(targetName string, schemaTableHashes SingleResult) { +func (r *Results) AddResult(targetName string, databaseHashes DatabaseResult) { r.mutex.Lock() defer r.mutex.Unlock() - for schema, tables := range schemaTableHashes { + for schema, tables := range databaseHashes { if _, ok := r.content[schema]; !ok { r.content[schema] = make(map[string]map[string]map[string][]string) } diff --git a/verify.go b/verify.go index 84d32eb1..22e43813 100644 --- a/verify.go +++ b/verify.go @@ -2,6 +2,7 @@ package pgverify import ( "context" + "sync" "github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/v4" @@ -30,7 +31,7 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results // First check that we can connect to every specified target database. targetNames := make([]string, len(targets)) - conns := make(map[int]*pgx.Conn) + connConfs := make(map[int]*pgx.ConnConfig) for i, target := range targets { pgxLoggerFields := logrus.Fields{ @@ -52,29 +53,23 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results target.LogLevel = pgx.LogLevelError - conn, err := pgx.ConnectConfig(ctx, target) - if err != nil { - return finalResults, err - } - defer conn.Close(ctx) - conns[i] = conn + connConfs[i] = target } finalResults = NewResults(targetNames, c.TestModes) // Then query each target database in parallel to generate table hashes. - var doneChannels []chan struct{} + wg := &sync.WaitGroup{} - for i, conn := range conns { - done := make(chan struct{}) - go c.runTestsOnTarget(ctx, targetNames[i], conn, finalResults, done) - doneChannels = append(doneChannels, done) - } + for i, connConf := range connConfs { + wg.Add(1) - for _, done := range doneChannels { - <-done + go c.runTestsOnTarget(ctx, targetNames[i], connConf, finalResults, wg) } + // Wait for queries to complete + wg.Wait() + // Compare final results reportErrors := finalResults.CheckForErrors() if len(reportErrors) > 0 { @@ -86,26 +81,40 @@ func (c Config) Verify(ctx context.Context, targets []*pgx.ConnConfig) (*Results return finalResults, nil } -func (c Config) runTestsOnTarget(ctx context.Context, targetName string, conn *pgx.Conn, finalResults *Results, done chan struct{}) { +func (c Config) runTestsOnTarget(ctx context.Context, targetName string, connConf *pgx.ConnConfig, finalResults *Results, wg *sync.WaitGroup) { + defer wg.Done() + logger := c.Logger.WithField("target", targetName) + conn, err := pgx.ConnectConfig(ctx, connConf) + if err != nil { + logger.WithError(err).Error("failed to connect to target") + + return + } + + defer conn.Close(ctx) + schemaTableHashes, err := c.fetchTargetTableNames(ctx, conn) if err != nil { logger.WithError(err).Error("failed to fetch target tables") - close(done) return } - schemaTableHashes = c.runTestQueriesOnTarget(ctx, logger, conn, schemaTableHashes) + for schemaName, schemaHashes := range schemaTableHashes { + for tableName := range schemaHashes { + wg.Add(1) + + go c.runTestQueriesOnTable(ctx, logger, connConf, targetName, schemaName, tableName, finalResults, wg) + } + } - finalResults.AddResult(targetName, schemaTableHashes) logger.Info("Table hashes computed") - close(done) } -func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (SingleResult, error) { - schemaTableHashes := make(SingleResult) +func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (DatabaseResult, error) { + schemaTableHashes := make(DatabaseResult) rows, err := conn.Query(ctx, buildGetTablesQuery(c.IncludeSchemas, c.ExcludeSchemas, c.IncludeTables, c.ExcludeTables)) if err != nil { @@ -119,10 +128,10 @@ func (c Config) fetchTargetTableNames(ctx context.Context, conn *pgx.Conn) (Sing } if _, ok := schemaTableHashes[schema.String]; !ok { - schemaTableHashes[schema.String] = make(map[string]map[string]string) + schemaTableHashes[schema.String] = make(SchemaResult) } - schemaTableHashes[schema.String][table.String] = make(map[string]string) + schemaTableHashes[schema.String][table.String] = make(TableResult) for _, testMode := range c.TestModes { schemaTableHashes[schema.String][table.String][testMode] = defaultErrorOutput @@ -152,111 +161,133 @@ func (c Config) validColumnTarget(columnName string) bool { return false } -func (c Config) runTestQueriesOnTarget(ctx context.Context, logger *logrus.Entry, conn *pgx.Conn, schemaTableHashes SingleResult) SingleResult { - for schemaName, tables := range schemaTableHashes { - for tableName := range tables { - tableLogger := logger.WithField("table", tableName).WithField("schema", schemaName) - tableLogger.Info("Computing hash") +func (c Config) runTestQueriesOnTable(ctx context.Context, logger *logrus.Entry, connConf *pgx.ConnConfig, targetName, schemaName, tableName string, finalResults *Results, wg *sync.WaitGroup) { + defer wg.Done() - rows, err := conn.Query(ctx, buildGetColumsQuery(schemaName, tableName)) - if err != nil { - tableLogger.WithError(err).Error("Failed to query column names, data types") + tableLogger := logger.WithField("table", tableName).WithField("schema", schemaName) + tableLogger.Info("Computing hash") - continue - } + conn, err := pgx.ConnectConfig(ctx, connConf) + if err != nil { + logger.WithError(err).Error("failed to connect to target") - allTableColumns := make(map[string]column) + return + } - for rows.Next() { - var columnName, dataType, constraintName, constraintType pgtype.Text + defer conn.Close(ctx) - err := rows.Scan(&columnName, &dataType, &constraintName, &constraintType) - if err != nil { - tableLogger.WithError(err).Error("Failed to parse column names, data types from query response") + rows, err := conn.Query(ctx, buildGetColumsQuery(schemaName, tableName)) + if err != nil { + tableLogger.WithError(err).Error("Failed to query column names, data types") - continue - } + return + } - existing, ok := allTableColumns[columnName.String] - if ok { - existing.constraints = append(existing.constraints, constraintType.String) - allTableColumns[columnName.String] = existing - } else { - allTableColumns[columnName.String] = column{columnName.String, dataType.String, []string{constraintType.String}} - } - } + allTableColumns := make(map[string]column) - var tableColumns []column + for rows.Next() { + var columnName, dataType, constraintName, constraintType pgtype.Text - var primaryKeyColumnNames []string + err := rows.Scan(&columnName, &dataType, &constraintName, &constraintType) + if err != nil { + tableLogger.WithError(err).Error("Failed to parse column names, data types from query response") - for _, col := range allTableColumns { - if col.IsPrimaryKey() { - primaryKeyColumnNames = append(primaryKeyColumnNames, col.name) - } + continue + } - if c.validColumnTarget(col.name) { - tableColumns = append(tableColumns, col) - } - } + existing, ok := allTableColumns[columnName.String] + if ok { + existing.constraints = append(existing.constraints, constraintType.String) + allTableColumns[columnName.String] = existing + } else { + allTableColumns[columnName.String] = column{columnName.String, dataType.String, []string{constraintType.String}} + } + } - if len(primaryKeyColumnNames) == 0 { - tableLogger.Error("No primary keys found") + var tableColumns []column - continue - } + var primaryKeyColumnNames []string - tableLogger.WithFields(logrus.Fields{ - "primary_keys": primaryKeyColumnNames, - "columns": tableColumns, - }).Info("Determined columns to hash") + for _, col := range allTableColumns { + if col.IsPrimaryKey() { + primaryKeyColumnNames = append(primaryKeyColumnNames, col.name) + } - for _, testMode := range c.TestModes { - testLogger := tableLogger.WithField("test", testMode) + if c.validColumnTarget(col.name) { + tableColumns = append(tableColumns, col) + } + } - var query string + if len(primaryKeyColumnNames) == 0 { + tableLogger.Error("No primary keys found") - switch testMode { - case TestModeFull: - query = buildFullHashQuery(c, schemaName, tableName, tableColumns) - case TestModeBookend: - query = buildBookendHashQuery(c, schemaName, tableName, tableColumns, c.BookendLimit) - case TestModeSparse: - query = buildSparseHashQuery(c, schemaName, tableName, tableColumns, c.SparseMod) - case TestModeRowCount: - query = buildRowCountQuery(schemaName, tableName) - } + return + } - testLogger.Debugf("Generated query: %s", query) + tableLogger.WithFields(logrus.Fields{ + "primary_keys": primaryKeyColumnNames, + "columns": tableColumns, + }).Info("Determined columns to hash") + + for _, testMode := range c.TestModes { + testLogger := tableLogger.WithField("test", testMode) + + var query string + + switch testMode { + case TestModeFull: + query = buildFullHashQuery(c, schemaName, tableName, tableColumns) + case TestModeBookend: + query = buildBookendHashQuery(c, schemaName, tableName, tableColumns, c.BookendLimit) + case TestModeSparse: + query = buildSparseHashQuery(c, schemaName, tableName, tableColumns, c.SparseMod) + case TestModeRowCount: + query = buildRowCountQuery(schemaName, tableName) + } - testOutput, err := runTestOnTable(ctx, conn, query) - if err != nil { - testLogger.WithError(err).Error("Failed to compute hash") + testLogger.Debugf("Generated query: %s", query) - continue - } + wg.Add(1) - schemaTableHashes[schemaName][tableName][testMode] = testOutput - testLogger.Infof("Hash computed: %s", testOutput) - } - } + go runTestOnTable(ctx, testLogger, connConf, targetName, schemaName, tableName, testMode, query, finalResults, wg) } - - return schemaTableHashes } -func runTestOnTable(ctx context.Context, conn *pgx.Conn, query string) (string, error) { +func runTestOnTable(ctx context.Context, logger *logrus.Entry, connConf *pgx.ConnConfig, targetName, schemaName, tableName, testMode, query string, finalResults *Results, wg *sync.WaitGroup) { + defer wg.Done() + + conn, err := pgx.ConnectConfig(ctx, connConf) + if err != nil { + logger.WithError(err).Error("failed to connect to target") + + return + } + + defer conn.Close(ctx) + row := conn.QueryRow(ctx, query) + var testOutputString string + var testOutput pgtype.Text if err := row.Scan(&testOutput); err != nil { switch err { case pgx.ErrNoRows: - return "no rows", nil + testOutputString = "no rows" default: - return "", errors.Wrap(err, "failed to scan test output") + logger.WithError(err).Error("failed to scan test output") + + return } + } else { + testOutputString = testOutput.String } - return testOutput.String, nil + logger.Infof("Hash computed: %s", testOutputString) + + databaseResults := make(DatabaseResult) + databaseResults[schemaName] = make(SchemaResult) + databaseResults[schemaName][tableName] = make(TableResult) + databaseResults[schemaName][tableName][testMode] = testOutputString + finalResults.AddResult(targetName, databaseResults) }