package trace import ( "context" "fmt" "runtime" "strings" pg_query "github.com/pganalyze/pg_query_go/v2" "github.com/samber/lo" "go.opencensus.io/trace" "gorm.io/gorm" ) // RegisterCallbacks registers the necessary callbacks in Gorm's hook system for instrumentation. func RegisterGormCallbacks(db *gorm.DB) error { if err := db.Callback().Create().Before("gorm:create").Register("instrumentation:before_create", beforeCreate); err != nil { return err } if err := db.Callback().Create().After("gorm:create").Register("instrumentation:after_create", afterCreate); err != nil { return err } if err := db.Callback().Query().Before("gorm:query").Register("instrumentation:before_query", beforeQuery); err != nil { return err } if err := db.Callback().Query().After("gorm:query").Register("instrumentation:after_query", afterQuery); err != nil { return err } if err := db.Callback().Update().Before("gorm:update").Register("instrumentation:before_update", beforeUpdate); err != nil { return err } if err := db.Callback().Update().After("gorm:update").Register("instrumentation:after_update", afterUpdate); err != nil { return err } if err := db.Callback().Delete().Before("gorm:delete").Register("instrumentation:before_delete", beforeDelete); err != nil { return err } if err := db.Callback().Delete().After("gorm:delete").Register("instrumentation:after_delete", afterDelete); err != nil { return err } return nil } func before(db *gorm.DB, operation string) { db.Statement.Context = startTrace(db.Statement.Context, db, operation) } func startTrace(ctx context.Context, db *gorm.DB, operation string) context.Context { // Don't trace queries if they don't have a parent span. if span := trace.FromContext(ctx); span == nil { return ctx } var ( file string line int ) // walk up the call stack looking for the line of code that called us. but // give up if it's more than 20 steps, and skip the first 5 as they're all // gorm anyway for n := 5; n < 20; n++ { _, file, line, _ = runtime.Caller(n) if strings.Contains(file, "/gorm.io/") { // skip any helper code and go further up the call stack continue } break } ctx, span := trace.StartSpan(ctx, fmt.Sprintf("gorm.%s.%s", operation, db.Statement.Table)) span.AddAttributes(trace.StringAttribute("gorm.table", db.Statement.Table)) span.AddAttributes(trace.StringAttribute("caller", fmt.Sprintf("%s:%v", file, line))) return ctx } func after(scope *gorm.DB, operation string) { endTrace(scope, operation) } func endTrace(db *gorm.DB, operation string) { span := trace.FromContext(db.Statement.Context) if span == nil || !span.IsRecordingEvents() { return } var status trace.Status if db.Error != nil { err := db.Error if err == gorm.ErrRecordNotFound { status.Code = trace.StatusCodeNotFound } else { status.Code = trace.StatusCodeUnknown } status.Message = err.Error() } span.AddAttributes( trace.Int64Attribute("gorm.rows_affected", db.Statement.RowsAffected), trace.StringAttribute("gorm.query", db.Statement.SQL.String()), ) fingerprint, err := pg_query.Fingerprint(db.Statement.SQL.String()) if err != nil { fingerprint = "unknown" } span.SetName(fmt.Sprintf("gorm.%s.%s.%s", operation, db.Statement.Table, fingerprint)) span.SetStatus(status) span.End() } func beforeCreate(scope *gorm.DB) { before(scope, "create") } func afterCreate(scope *gorm.DB) { after(scope, "create") } func beforeQuery(scope *gorm.DB) { before(scope, "query") } func afterQuery(scope *gorm.DB) { fieldStrings := []string{} if scope.Statement != nil { fieldStrings = lo.Map(scope.Statement.Vars, func(v interface{}, i int) string { return fmt.Sprintf("($%v = %v)", i+1, v) }) } span := trace.FromContext(scope.Statement.Context) if span != nil && span.IsRecordingEvents() { span.AddAttributes( trace.StringAttribute("gorm.query.vars", strings.Join(fieldStrings, ", ")), ) } after(scope, "query") } func beforeUpdate(scope *gorm.DB) { before(scope, "update") } func afterUpdate(scope *gorm.DB) { after(scope, "update") } func beforeDelete(scope *gorm.DB) { before(scope, "delete") } func afterDelete(scope *gorm.DB) { after(scope, "delete") }