package main import ( "fmt" "log" "strings" "github.com/pingcap/parser" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/format" _ "github.com/pingcap/tidb/types/parser_driver" ) // Rewrite sql Rewrite type Rewrite struct { SQL string NewSQL string Stmt ast.StmtNode } // NewRewrite Func func NewRewrite(sql, charset, collation string) *Rewrite { p := parser.New() stmtNode, err := p.ParseOneStmt(sql, charset, collation) if err != nil { log.Fatal("error...", err) } return &Rewrite{ SQL: sql, Stmt: stmtNode, } } func newLimit(val int) *ast.Limit { limit := ast.Limit{ Count: ast.NewValueExpr(val), } return &limit } type checkLimitVisitor struct{} func (clv *checkLimitVisitor) Enter(in ast.Node) (out ast.Node, skipChildren bool) { switch node := in.(type) { case *ast.Limit: count := ast.NewValueExpr(1) node.Count = count return node, false case *ast.SelectStmt: node.Limit = newLimit(1) } return in, true } func (clv *checkLimitVisitor) Leave(in ast.Node) (out ast.Node, ok bool) { return in, true } func (rw *Rewrite) forceSelectLimit1() *Rewrite { if rw.Stmt == nil { return rw } foundSelect := false switch stmt := rw.Stmt.(type) { case *ast.SelectStmt: v := checkLimitVisitor{} stmt.Accept(&v) foundSelect = true } if foundSelect { var sb strings.Builder ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) rw.Stmt.Restore(ctx) rw.NewSQL = sb.String() } return rw } func main() { sql1 := "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100 limit 100;" rw := NewRewrite(sql1, "", "") rw.forceSelectLimit1() fmt.Println(rw.NewSQL) // OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1 sql2 := "SELECT t1.a, t2.b FROM t1 JOIN t2 ON t1.id = t2.fid WHERE t1.c>100;" rw = NewRewrite(sql2, "", "") rw.forceSelectLimit1() fmt.Println(rw.NewSQL) // OUT: SELECT `t1`.`a`,`t2`.`b` FROM `t1` JOIN `t2` ON `t1`.`id`=`t2`.`fid` WHERE `t1`.`c`>100 LIMIT 1 sql3 := "DELETE FROM Customers WHERE CustomerName='Alfreds Futterkiste';" rw = NewRewrite(sql3, "", "") rw.forceSelectLimit1() fmt.Println(rw.NewSQL) // OUT: "" }