From 4920172e66f176aec8eb59938dd381b1fe54a6a1 Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 26 Apr 2026 20:51:58 +0000 Subject: [PATCH] Remove testify dependency in favor of stdlib testing Replace github.com/stretchr/testify/require calls with the equivalent stdlib testing patterns (if-cond t.Fatal/t.Fatalf). Use reflect.DeepEqual for value comparisons that aren't directly comparable with ==. https://claude.ai/code/session_01SzZiuen17CDNi8REpAGw1k --- ast/base_test.go | 55 +- ast/ddl_test.go | 11 +- ast/dml_test.go | 51 +- ast/expressions_test.go | 11 +- ast/flag_test.go | 18 +- ast/format_test.go | 13 +- ast/functions_test.go | 37 +- ast/misc_test.go | 46 +- ast/model_test.go | 50 +- ast/procedure_test.go | 31 +- ast/sem_test.go | 14 +- ast/stats_test.go | 77 +- ast/util_test.go | 117 +- auth/caching_sha2_test.go | 59 +- auth/mysql_native_password_test.go | 30 +- auth/tidb_sm3_test.go | 63 +- charset/charset_test.go | 66 +- charset/encoding_test.go | 123 +- duration/duration_test.go | 10 +- format/format_test.go | 50 +- generate_keyword/genkeyword_test.go | 10 +- go.mod | 5 - go.sum | 11 +- mysql/const_test.go | 55 +- mysql/error_test.go | 18 +- mysql/privs_test.go | 98 +- mysql/type_test.go | 50 +- parser/consistent_test.go | 34 +- parser/digester_test.go | 75 +- parser/hintparser_test.go | 23 +- parser/keywords_test.go | 23 +- parser/lateral_test.go | 61 +- parser/lexer_test.go | 211 +- parser/parser_test.go | 3290 ++++++++++++++++++++------- parser/reserved_words_test.go | 52 +- terror/terror_test.go | 175 +- types/etc_test.go | 15 +- types/field_type_test.go | 286 ++- util/escape_test.go | 7 +- 39 files changed, 4038 insertions(+), 1393 deletions(-) diff --git a/ast/base_test.go b/ast/base_test.go index e76ab71..c480610 100644 --- a/ast/base_test.go +++ b/ast/base_test.go @@ -20,7 +20,8 @@ import ( "testing" "github.com/sqlc-dev/marino/charset" - "github.com/stretchr/testify/require" + + "reflect" ) func TestNodeSetText(t *testing.T) { @@ -37,8 +38,12 @@ func TestNodeSetText(t *testing.T) { } for _, tt := range tests { n.SetText(tt.enc, tt.text) - require.Equal(t, tt.expectUTF8Text, n.Text()) - require.Equal(t, tt.expectText, n.OriginalText()) + if !reflect.DeepEqual(tt.expectUTF8Text, n.Text()) { + t.Fatalf("got %v, want %v", n.Text(), tt.expectUTF8Text) + } + if !reflect.DeepEqual(tt.expectText, n.OriginalText()) { + t.Fatalf("got %v, want %v", n.OriginalText(), tt.expectText) + } } } @@ -66,7 +71,9 @@ func TestBinaryStringLiteralConversion(t *testing.T) { } for _, tt := range printableTests { n.SetText(charset.EncodingUTF8Impl, tt.text) - require.Equal(t, tt.want, n.Text(), tt.name) + if !reflect.DeepEqual(tt.want, n.Text()) { + t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want) + } } // Binary (non-printable) strings — should convert to 0x hex literals @@ -98,7 +105,9 @@ func TestBinaryStringLiteralConversion(t *testing.T) { } for _, tt := range binaryTests { n.SetText(charset.EncodingUTF8Impl, tt.text) - require.Equal(t, tt.want, n.Text(), tt.name) + if !reflect.DeepEqual(tt.want, n.Text()) { + t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want) + } } } @@ -206,7 +215,9 @@ func TestBinaryStringLiteralSkipsComments(t *testing.T) { } for _, tt := range tests { n.SetText(charset.EncodingUTF8Impl, tt.text) - require.Equal(t, tt.want, n.Text(), tt.name) + if !reflect.DeepEqual(tt.want, n.Text()) { + t.Fatalf("%v: got %v, want %v", tt.name, n.Text(), tt.want) + } } } @@ -215,15 +226,21 @@ func TestBinaryStringLiteralNoBackslashEscapes(t *testing.T) { n.SetText(charset.EncodingUTF8Impl, "SELECT '\\n'") n.SetNoBackslashEscapes(true) - require.Equal(t, "SELECT '\\n'", n.Text(), "NO_BACKSLASH_ESCAPES literal \\n") + if !reflect.DeepEqual("SELECT '\\n'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES literal \\n", n.Text(), "SELECT '\\n'") + } n.SetText(charset.EncodingUTF8Impl, "SELECT '\\' , 'after'") n.SetNoBackslashEscapes(true) - require.Equal(t, "SELECT '\\' , 'after'", n.Text(), "NO_BACKSLASH_ESCAPES quote boundary") + if !reflect.DeepEqual("SELECT '\\' , 'after'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES quote boundary", n.Text(), "SELECT '\\' , 'after'") + } n.SetText(charset.EncodingUTF8Impl, "SELECT '\xd2\xe4'") n.SetNoBackslashEscapes(true) - require.Equal(t, "SELECT 0xd2e4", n.Text(), "NO_BACKSLASH_ESCAPES binary") + if !reflect.DeepEqual("SELECT 0xd2e4", n.Text()) { + t.Fatalf("%v: got %v, want %v", "NO_BACKSLASH_ESCAPES binary", n.Text(), "SELECT 0xd2e4") + } } func TestBinaryStringLiteralGBK(t *testing.T) { @@ -233,23 +250,33 @@ func TestBinaryStringLiteralGBK(t *testing.T) { // This should be decoded as valid GBK and left as a printable string, // not converted to a hex literal. n.SetText(charset.EncodingGBKImpl, "select '\xb1\xed\x31'") - require.Equal(t, "select '表1'", n.Text(), "GBK printable") + if !reflect.DeepEqual("select '表1'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "GBK printable", n.Text(), "select '表1'") + } // GBK with actual invalid bytes should still convert to hex n.SetText(charset.EncodingGBKImpl, "select '\x80\xff'") - require.Equal(t, "select 0x80ff", n.Text(), "GBK binary") + if !reflect.DeepEqual("select 0x80ff", n.Text()) { + t.Fatalf("%v: got %v, want %v", "GBK binary", n.Text(), "select 0x80ff") + } // 筡 = \xb9\x5c in GBK; trail byte 0x5c must not be mistaken for backslash n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c'") - require.Equal(t, "select '筡'", n.Text(), "GBK 0x5c trail byte") + if !reflect.DeepEqual("select '筡'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "GBK 0x5c trail byte", n.Text(), "select '筡'") + } // Multiple GBK chars with 0x5c trail bytes: 筡 = \xb9\x5c, 臷 = \xc5\x5c n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c\xc5\x5c'") - require.Equal(t, "select '筡臷'", n.Text(), "GBK multiple 0x5c trail bytes") + if !reflect.DeepEqual("select '筡臷'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "GBK multiple 0x5c trail bytes", n.Text(), "select '筡臷'") + } // 0x5c trail byte right before closing quote must not escape the quote n.SetText(charset.EncodingGBKImpl, "select '\xb9\x5c', 'after'") - require.Equal(t, "select '筡', 'after'", n.Text(), "GBK 0x5c before quote") + if !reflect.DeepEqual("select '筡', 'after'", n.Text()) { + t.Fatalf("%v: got %v, want %v", "GBK 0x5c before quote", n.Text(), "select '筡', 'after'") + } } func buildBinaryClause() string { diff --git a/ast/ddl_test.go b/ast/ddl_test.go index 45cc6da..6caf3c5 100644 --- a/ast/ddl_test.go +++ b/ast/ddl_test.go @@ -18,7 +18,8 @@ import ( . "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" - "github.com/stretchr/testify/require" + + "reflect" ) func TestDDLVisitorCover(t *testing.T) { @@ -59,8 +60,12 @@ func TestDDLVisitorCover(t *testing.T) { for _, v := range stmts { ce.reset() v.node.Accept(checkVisitor{}) - require.Equal(t, v.expectedEnterCnt, ce.enterCnt) - require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt) + if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) { + t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt) + } + if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) { + t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt) + } v.node.Accept(visitor1{}) } } diff --git a/ast/dml_test.go b/ast/dml_test.go index 66db8cd..10da64c 100644 --- a/ast/dml_test.go +++ b/ast/dml_test.go @@ -17,10 +17,13 @@ import ( "fmt" "testing" - "github.com/sqlc-dev/marino/parser" . "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "reflect" + "regexp" + "strings" ) func TestDMLVisitorCover(t *testing.T) { @@ -68,8 +71,12 @@ func TestDMLVisitorCover(t *testing.T) { for _, v := range stmts { ce.reset() v.node.Accept(checkVisitor{}) - require.Equal(t, v.expectedEnterCnt, ce.enterCnt) - require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt) + if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) { + t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt) + } + if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) { + t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt) + } v.node.Accept(visitor1{}) } } @@ -630,9 +637,15 @@ func TestImportIntoRestore(t *testing.T) { } func TestFulltextSearchModifier(t *testing.T) { - require.False(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsBooleanMode()) - require.True(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsNaturalLanguageMode()) - require.False(t, FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).WithQueryExpansion()) + if FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsBooleanMode() { + t.Fatal("expected false") + } + if !(FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).IsNaturalLanguageMode()) { + t.Fatal("expected true") + } + if FulltextSearchModifier(FulltextSearchModifierNaturalLanguageMode).WithQueryExpansion() { + t.Fatal("expected false") + } } func TestImportIntoSecureText(t *testing.T) { @@ -658,19 +671,31 @@ func TestImportIntoSecureText(t *testing.T) { for _, tc := range testCases { comment := fmt.Sprintf("input = %s", tc.input) node, err := p.ParseOneStmt(tc.input, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } n, ok := node.(SensitiveStmtNode) - require.True(t, ok, comment) - require.Regexp(t, tc.secured, n.SecureText(), comment) + if !(ok) { + t.Fatal(comment) + } + if !regexp.MustCompile(tc.secured).MatchString(n.SecureText()) { + t.Fatalf("%v: expected %q to match %q", comment, n.SecureText(), tc.secured) + } } } func TestImportIntoFromSelectInvalidStmt(t *testing.T) { p := parser.New() _, err := p.ParseOneStmt("IMPORT INTO t1(a, @1) FROM select * from t2;", "", "") - require.ErrorContains(t, err, "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement") + if err == nil || !strings.Contains(err.Error(), "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement") { + t.Fatalf("expected error containing %q, got %v", "Cannot use user variable(1) in IMPORT INTO FROM SELECT statement", err) + } _, err = p.ParseOneStmt("IMPORT INTO t1(a, @b) FROM select * from t2;", "", "") - require.ErrorContains(t, err, "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement") + if err == nil || !strings.Contains(err.Error(), "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement") { + t.Fatalf("expected error containing %q, got %v", "Cannot use user variable(b) in IMPORT INTO FROM SELECT statement", err) + } _, err = p.ParseOneStmt("IMPORT INTO t1(a) set a=1 FROM select a from t2;", "", "") - require.ErrorContains(t, err, "Cannot use SET clause in IMPORT INTO FROM SELECT statement.") + if err == nil || !strings.Contains(err.Error(), "Cannot use SET clause in IMPORT INTO FROM SELECT statement.") { + t.Fatalf("expected error containing %q, got %v", "Cannot use SET clause in IMPORT INTO FROM SELECT statement.", err) + } } diff --git a/ast/expressions_test.go b/ast/expressions_test.go index 31a3571..b153220 100644 --- a/ast/expressions_test.go +++ b/ast/expressions_test.go @@ -19,7 +19,8 @@ import ( . "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + + "reflect" ) type checkVisitor struct{} @@ -94,8 +95,12 @@ func TestExpresionsVisitorCover(t *testing.T) { for _, v := range stmts { ce.reset() v.node.Accept(checkVisitor{}) - require.Equal(t, v.expectedEnterCnt, ce.enterCnt) - require.Equal(t, v.expectedLeaveCnt, ce.leaveCnt) + if !reflect.DeepEqual(v.expectedEnterCnt, ce.enterCnt) { + t.Fatalf("got %v, want %v", ce.enterCnt, v.expectedEnterCnt) + } + if !reflect.DeepEqual(v.expectedLeaveCnt, ce.leaveCnt) { + t.Fatalf("got %v, want %v", ce.leaveCnt, v.expectedLeaveCnt) + } v.node.Accept(visitor1{}) } } diff --git a/ast/flag_test.go b/ast/flag_test.go index f17cee2..748de82 100644 --- a/ast/flag_test.go +++ b/ast/flag_test.go @@ -14,11 +14,13 @@ package ast_test import ( + "fmt" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "reflect" ) func TestHasAggFlag(t *testing.T) { @@ -33,7 +35,9 @@ func TestHasAggFlag(t *testing.T) { } for _, tt := range flagTests { expr.SetFlag(tt.flag) - require.Equal(t, tt.hasAgg, ast.HasAggFlag(expr)) + if !reflect.DeepEqual(tt.hasAgg, ast.HasAggFlag(expr)) { + t.Fatalf("got %v, want %v", ast.HasAggFlag(expr), tt.hasAgg) + } } } @@ -130,10 +134,14 @@ func TestFlag(t *testing.T) { p := parser.New() for _, tt := range flagTests { stmt, err := p.ParseOneStmt("select "+tt.expr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt := stmt.(*ast.SelectStmt) ast.SetFlag(selectStmt) expr := selectStmt.Fields.Fields[0].Expr - require.Equalf(t, tt.flag, expr.GetFlag(), "For %s", tt.expr) + if !reflect.DeepEqual(tt.flag, expr.GetFlag()) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("For %s", tt.expr), expr.GetFlag(), tt.flag) + } } } diff --git a/ast/format_test.go b/ast/format_test.go index 2177a57..acae234 100644 --- a/ast/format_test.go +++ b/ast/format_test.go @@ -5,9 +5,10 @@ import ( "fmt" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "reflect" ) func getDefaultCharsetAndCollate() (string, string) { @@ -89,10 +90,14 @@ func TestAstFormat(t *testing.T) { charset, collation := getDefaultCharsetAndCollate() stmts, _, err := parser.New().Parse(expr, charset, collation) node := stmts[0].(*ast.SelectStmt).Fields.Fields[0].Expr - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } writer := bytes.NewBufferString("") node.Format(writer) - require.Equal(t, tt.output, writer.String()) + if !reflect.DeepEqual(tt.output, writer.String()) { + t.Fatalf("got %v, want %v", writer.String(), tt.output) + } } } diff --git a/ast/functions_test.go b/ast/functions_test.go index 291e79e..19737cb 100644 --- a/ast/functions_test.go +++ b/ast/functions_test.go @@ -17,12 +17,13 @@ import ( "strings" "testing" - "github.com/sqlc-dev/marino/parser" . "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" "github.com/sqlc-dev/marino/mysql" + "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/test_driver" - "github.com/stretchr/testify/require" + + "reflect" ) func TestFunctionsVisitorCover(t *testing.T) { @@ -170,15 +171,21 @@ func TestConvert(t *testing.T) { for _, testCase := range cases { stmt, err := parser.New().ParseOneStmt(testCase.SQL, "", "") if testCase.ErrorMessage != "" { - require.EqualError(t, err, testCase.ErrorMessage) + if err == nil || err.Error() != testCase.ErrorMessage { + t.Fatalf("expected error %q, got %v", testCase.ErrorMessage, err) + } continue } - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } st := stmt.(*SelectStmt) expr := st.Fields.Fields[0].Expr.(*FuncCallExpr) charsetArg := expr.Args[1].(*test_driver.ValueExpr) - require.Equal(t, testCase.CharsetName, charsetArg.GetString()) + if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) { + t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName) + } } } @@ -199,15 +206,21 @@ func TestChar(t *testing.T) { for _, testCase := range cases { stmt, err := parser.New().ParseOneStmt(testCase.SQL, "", "") if testCase.ErrorMessage != "" { - require.EqualError(t, err, testCase.ErrorMessage) + if err == nil || err.Error() != testCase.ErrorMessage { + t.Fatalf("expected error %q, got %v", testCase.ErrorMessage, err) + } continue } - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } st := stmt.(*SelectStmt) expr := st.Fields.Fields[0].Expr.(*FuncCallExpr) charsetArg := expr.Args[1].(*test_driver.ValueExpr) - require.Equal(t, testCase.CharsetName, charsetArg.GetString()) + if !reflect.DeepEqual(testCase.CharsetName, charsetArg.GetString()) { + t.Fatalf("got %v, want %v", charsetArg.GetString(), testCase.CharsetName) + } } } @@ -259,8 +272,12 @@ func TestRestoreWithError(t *testing.T) { sql := "select " + c p := parser.New() stmt, err := p.ParseOneStmt(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } var sb strings.Builder - require.Error(t, extractNodeFunc(stmt).Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb))) + if extractNodeFunc(stmt).Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) == nil { + t.Fatal("expected error") + } } } diff --git a/ast/misc_test.go b/ast/misc_test.go index 2e9174f..d54c14f 100644 --- a/ast/misc_test.go +++ b/ast/misc_test.go @@ -17,10 +17,12 @@ import ( "fmt" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "reflect" + "regexp" ) type visitor struct{} @@ -106,7 +108,9 @@ constraint foreign key (jobabbr) references ffxi_jobtype (jobabbr) on delete cas ` parse := parser.New() stmts, _, err := parse.Parse(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } for _, stmt := range stmts { stmt.Accept(visitor{}) stmt.Accept(visitor1{}) @@ -126,7 +130,9 @@ import into t from '/file.csv'` p := parser.New() stmts, _, err := p.Parse(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } for _, stmt := range stmts { stmt.Accept(visitor{}) stmt.Accept(visitor1{}) @@ -142,7 +148,9 @@ func TestSensitiveStatement(t *testing.T) { } for i, stmt := range positive { _, ok := stmt.(ast.SensitiveStmtNode) - require.Truef(t, ok, "%d, %#v fail", i, stmt) + if !(ok) { + t.Fatalf("%d, %#v fail", i, stmt) + } } negative := []ast.StmtNode{ @@ -160,7 +168,9 @@ func TestSensitiveStatement(t *testing.T) { } for _, stmt := range negative { _, ok := stmt.(ast.SensitiveStmtNode) - require.False(t, ok) + if ok { + t.Fatal("expected false") + } } } @@ -312,10 +322,16 @@ func TestBRIESecureText(t *testing.T) { for _, tc := range testCases { comment := fmt.Sprintf("input = %s", tc.input) node, err := p.ParseOneStmt(tc.input, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } n, ok := node.(ast.SensitiveStmtNode) - require.True(t, ok, comment) - require.Regexp(t, tc.secured, n.SecureText(), comment) + if !(ok) { + t.Fatal(comment) + } + if !regexp.MustCompile(tc.secured).MatchString(n.SecureText()) { + t.Fatalf("%v: expected %q to match %q", comment, n.SecureText(), tc.secured) + } } } @@ -440,9 +456,15 @@ func TestRedactTrafficStmt(t *testing.T) { p := parser.New() for _, tc := range testCases { node, err := p.ParseOneStmt(tc.input, "", "") - require.NoError(t, err, tc.input) + if err != nil { + t.Fatalf("%v: %v", tc.input, err) + } n, ok := node.(ast.SensitiveStmtNode) - require.True(t, ok, tc.input) - require.Equal(t, tc.secured, n.SecureText(), tc.input) + if !(ok) { + t.Fatal(tc.input) + } + if !reflect.DeepEqual(tc.secured, n.SecureText()) { + t.Fatalf("%v: got %v, want %v", tc.input, n.SecureText(), tc.secured) + } } } diff --git a/ast/model_test.go b/ast/model_test.go index 67aa931..caa5e71 100644 --- a/ast/model_test.go +++ b/ast/model_test.go @@ -17,14 +17,20 @@ import ( "encoding/json" "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestT(t *testing.T) { abc := NewCIStr("aBC") - require.Equal(t, "aBC", abc.O) - require.Equal(t, "abc", abc.L) - require.Equal(t, "aBC", abc.String()) + if !reflect.DeepEqual("aBC", abc.O) { + t.Fatalf("got %v, want %v", abc.O, "aBC") + } + if !reflect.DeepEqual("abc", abc.L) { + t.Fatalf("got %v, want %v", abc.L, "abc") + } + if !reflect.DeepEqual("aBC", abc.String()) { + t.Fatalf("got %v, want %v", abc.String(), "aBC") + } } func TestUnmarshalCIStr(t *testing.T) { @@ -33,15 +39,33 @@ func TestUnmarshalCIStr(t *testing.T) { // Test unmarshal CIStr from a single string. str := "aaBB" buf, err := json.Marshal(str) - require.NoError(t, err) - require.NoError(t, ci.UnmarshalJSON(buf)) - require.Equal(t, str, ci.O) - require.Equal(t, "aabb", ci.L) + if err != nil { + t.Fatal(err) + } + if ci.UnmarshalJSON(buf) != nil { + t.Fatal(ci.UnmarshalJSON(buf)) + } + if !reflect.DeepEqual(str, ci.O) { + t.Fatalf("got %v, want %v", ci.O, str) + } + if !reflect.DeepEqual("aabb", ci.L) { + t.Fatalf("got %v, want %v", ci.L, "aabb") + } buf, err = json.Marshal(ci) - require.NoError(t, err) - require.Equal(t, `{"O":"aaBB","L":"aabb"}`, string(buf)) - require.NoError(t, ci.UnmarshalJSON(buf)) - require.Equal(t, str, ci.O) - require.Equal(t, "aabb", ci.L) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(`{"O":"aaBB","L":"aabb"}`, string(buf)) { + t.Fatalf("got %v, want %v", string(buf), `{"O":"aaBB","L":"aabb"}`) + } + if ci.UnmarshalJSON(buf) != nil { + t.Fatal(ci.UnmarshalJSON(buf)) + } + if !reflect.DeepEqual(str, ci.O) { + t.Fatalf("got %v, want %v", ci.O, str) + } + if !reflect.DeepEqual("aabb", ci.L) { + t.Fatalf("got %v, want %v", ci.L, "aabb") + } } diff --git a/ast/procedure_test.go b/ast/procedure_test.go index 1d0e27f..b056fdd 100644 --- a/ast/procedure_test.go +++ b/ast/procedure_test.go @@ -17,9 +17,8 @@ import ( "fmt" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" ) func TestProcedureVisitorCover(t *testing.T) { @@ -105,22 +104,34 @@ func TestProcedure(t *testing.T) { if err != nil { fmt.Println(testcase) } - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, ok := stmt[0].(*ast.ProcedureInfo) - require.True(t, ok, testcase) + if !(ok) { + t.Fatal(testcase) + } } } func TestShowCreateProcedure(t *testing.T) { p := parser.New() stmt, _, err := p.Parse("show create procedure proc_2", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, ok := stmt[0].(*ast.ShowStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } stmt, _, err = p.Parse("drop procedure proc_2", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, ok = stmt[0].(*ast.DropProcedureStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } } func TestProcedureVisitor(t *testing.T) { @@ -132,7 +143,9 @@ func TestProcedureVisitor(t *testing.T) { parse := parser.New() for _, sql := range sqls { stmts, _, err := parse.Parse(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } for _, stmt := range stmts { stmt.Accept(visitor{}) stmt.Accept(visitor1{}) diff --git a/ast/sem_test.go b/ast/sem_test.go index 6c93c08..5953cc5 100644 --- a/ast/sem_test.go +++ b/ast/sem_test.go @@ -16,7 +16,7 @@ package ast import ( "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestShowCommand(t *testing.T) { @@ -25,7 +25,9 @@ func TestShowCommand(t *testing.T) { Tp: ShowStmtType(i), } - require.NotEqual(t, stmt.SEMCommand(), UnknownCommand, "SEMCommand should not be UnknownCommand for ShowStmtType %d", i) + if reflect.DeepEqual(stmt.SEMCommand(), UnknownCommand) { + t.Fatalf("expected values to differ, both are %v", UnknownCommand) + } } } @@ -35,7 +37,9 @@ func TestAdminCommand(t *testing.T) { Tp: AdminStmtType(i), } - require.NotEqual(t, stmt.SEMCommand(), UnknownCommand, "SEMCommand should not be UnknownCommand for AdminStmtType %d", i) + if reflect.DeepEqual(stmt.SEMCommand(), UnknownCommand) { + t.Fatalf("expected values to differ, both are %v", UnknownCommand) + } } } @@ -45,6 +49,8 @@ func TestBRIECommand(t *testing.T) { Kind: i, } - require.NotEqual(t, stmt.SEMCommand(), UnknownCommand, "SEMCommand should not be UnknownCommand for BRIEKind %s", i) + if reflect.DeepEqual(stmt.SEMCommand(), UnknownCommand) { + t.Fatalf("expected values to differ, both are %v", UnknownCommand) + } } } diff --git a/ast/stats_test.go b/ast/stats_test.go index b31145b..b3dab49 100644 --- a/ast/stats_test.go +++ b/ast/stats_test.go @@ -17,10 +17,11 @@ import ( "strings" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "reflect" ) func TestRefreshStatsStmt(t *testing.T) { @@ -79,18 +80,30 @@ func TestRefreshStatsStmt(t *testing.T) { p := parser.New() for _, test := range tests { stmt, err := p.ParseOneStmt(test.sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } rs := stmt.(*ast.RefreshStatsStmt) if test.modeSet { - require.NotNil(t, rs.RefreshMode) - require.Equal(t, test.mode, *rs.RefreshMode) + if rs.RefreshMode == nil { + t.Fatal("expected non-nil") + } + if !reflect.DeepEqual(test.mode, *rs.RefreshMode) { + t.Fatalf("got %v, want %v", *rs.RefreshMode, test.mode) + } } else { - require.Nil(t, rs.RefreshMode) + if rs.RefreshMode != nil { + t.Fatalf("expected nil, got %v", rs.RefreshMode) + } } var sb strings.Builder err = stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - require.NoError(t, err) - require.Equal(t, test.want, sb.String()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(test.want, sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), test.want) + } } } @@ -150,15 +163,27 @@ func TestFlushStatsDeltaScoped(t *testing.T) { for _, test := range tests { t.Run(test.sql, func(t *testing.T) { stmt, err := p.ParseOneStmt(test.sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } fs := stmt.(*ast.FlushStmt) - require.Equal(t, ast.FlushStatsDelta, fs.Tp) - require.Len(t, fs.FlushObjects, test.objects) - require.Equal(t, test.cluster, fs.IsCluster) + if !reflect.DeepEqual(ast.FlushStatsDelta, fs.Tp) { + t.Fatalf("got %v, want %v", fs.Tp, ast.FlushStatsDelta) + } + if got := len(fs.FlushObjects); got != test.objects { + t.Fatalf("expected length %d, got %d", test.objects, got) + } + if !reflect.DeepEqual(test.cluster, fs.IsCluster) { + t.Fatalf("got %v, want %v", fs.IsCluster, test.cluster) + } var sb strings.Builder err = stmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - require.NoError(t, err) - require.Equal(t, test.want, sb.String()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(test.want, sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), test.want) + } }) } @@ -186,13 +211,19 @@ func TestFlushStatsDeltaScoped(t *testing.T) { for _, test := range dedupTests { t.Run(test.name, func(t *testing.T) { stmt, err := p.ParseOneStmt(test.sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } fs := stmt.(*ast.FlushStmt) fs.DedupFlushObjects() var sb strings.Builder err = fs.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - require.NoError(t, err) - require.Equal(t, test.want, sb.String()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(test.want, sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), test.want) + } }) } } @@ -234,13 +265,19 @@ func TestRefreshStatsStmtDedup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { stmt, err := p.ParseOneStmt(test.sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } rs := stmt.(*ast.RefreshStatsStmt) rs.Dedup() var sb strings.Builder err = rs.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - require.NoError(t, err) - require.Equal(t, test.want, sb.String()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(test.want, sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), test.want) + } }) } } diff --git a/ast/util_test.go b/ast/util_test.go index 342db78..7b902e3 100644 --- a/ast/util_test.go +++ b/ast/util_test.go @@ -18,71 +18,100 @@ import ( "strings" "testing" - "github.com/sqlc-dev/marino/parser" . "github.com/sqlc-dev/marino/ast" . "github.com/sqlc-dev/marino/format" "github.com/sqlc-dev/marino/mysql" + "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/test_driver" - "github.com/stretchr/testify/require" + + "reflect" ) func TestCacheable(t *testing.T) { // test non-SelectStmt var stmt Node = &DeleteStmt{} - require.False(t, IsReadOnly(stmt, true)) + if IsReadOnly(stmt, true) { + t.Fatal("expected false") + } stmt = &InsertStmt{} - require.False(t, IsReadOnly(stmt, true)) + if IsReadOnly(stmt, true) { + t.Fatal("expected false") + } stmt = &UpdateStmt{} - require.False(t, IsReadOnly(stmt, true)) + if IsReadOnly(stmt, true) { + t.Fatal("expected false") + } stmt = &ExplainStmt{} - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ExplainStmt{} - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &DoStmt{} - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ExplainStmt{ Stmt: &InsertStmt{}, } - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ExplainStmt{ Analyze: true, Stmt: &InsertStmt{}, } - require.False(t, IsReadOnly(stmt, true)) + if IsReadOnly(stmt, true) { + t.Fatal("expected false") + } stmt = &ExplainStmt{ Stmt: &SelectStmt{}, } - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ExplainStmt{ Analyze: true, Stmt: &SelectStmt{}, } - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ShowStmt{} - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &ShowStmt{} - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &TraceStmt{ Stmt: &SelectStmt{}, } - require.True(t, IsReadOnly(stmt, true)) + if !(IsReadOnly(stmt, true)) { + t.Fatal("expected true") + } stmt = &TraceStmt{ Stmt: &DeleteStmt{}, } - require.False(t, IsReadOnly(stmt, true)) + if IsReadOnly(stmt, true) { + t.Fatal("expected false") + } } func TestUnionReadOnly(t *testing.T) { @@ -99,22 +128,34 @@ func TestUnionReadOnly(t *testing.T) { Selects: []Node{selectReadOnly, selectReadOnly}, }, } - require.True(t, IsReadOnly(setOprStmt, true)) + if !(IsReadOnly(setOprStmt, true)) { + t.Fatal("expected true") + } setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectReadOnly, selectReadOnly} - require.True(t, IsReadOnly(setOprStmt, true)) + if !(IsReadOnly(setOprStmt, true)) { + t.Fatal("expected true") + } setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate} - require.False(t, IsReadOnly(setOprStmt, true)) + if IsReadOnly(setOprStmt, true) { + t.Fatal("expected false") + } setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdateNoWait} - require.False(t, IsReadOnly(setOprStmt, true)) + if IsReadOnly(setOprStmt, true) { + t.Fatal("expected false") + } setOprStmt.SelectList.Selects = []Node{selectForUpdate, selectForUpdateNoWait} - require.False(t, IsReadOnly(setOprStmt, true)) + if IsReadOnly(setOprStmt, true) { + t.Fatal("expected false") + } setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate, selectForUpdateNoWait} - require.False(t, IsReadOnly(setOprStmt, true)) + if IsReadOnly(setOprStmt, true) { + t.Fatal("expected false") + } } // CleanNodeText set the text of node and all child node empty. @@ -195,18 +236,28 @@ func runNodeRestoreTestWithFlags(t *testing.T, nodeTestCases []NodeRestoreTestCa expectSQL := fmt.Sprintf(template, testCase.expectSQL) stmt, err := p.ParseOneStmt(sourceSQL, "", "") comment := fmt.Sprintf("source %#v", testCase) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } var sb strings.Builder err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSql := fmt.Sprintf(template, sb.String()) comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql) - require.Equal(t, expectSQL, restoreSql, comment) + if !reflect.DeepEqual(expectSQL, restoreSql) { + t.Fatalf("%v: got %v, want %v", comment, restoreSql, expectSQL) + } stmt2, err := p.ParseOneStmt(restoreSql, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } CleanNodeText(stmt) CleanNodeText(stmt2) - require.Equal(t, stmt, stmt2, comment) + if !reflect.DeepEqual(stmt, stmt2) { + t.Fatalf("%v: got %v, want %v", comment, stmt2, stmt) + } } } @@ -220,12 +271,18 @@ func runNodeRestoreTestWithFlagsStmtChange(t *testing.T, nodeTestCases []NodeRes expectSQL := fmt.Sprintf(template, testCase.expectSQL) stmt, err := p.ParseOneStmt(sourceSQL, "", "") comment := fmt.Sprintf("source %#v", testCase) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } var sb strings.Builder err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSql := fmt.Sprintf(template, sb.String()) comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql) - require.Equal(t, expectSQL, restoreSql, comment) + if !reflect.DeepEqual(expectSQL, restoreSql) { + t.Fatalf("%v: got %v, want %v", comment, restoreSql, expectSQL) + } } } diff --git a/auth/caching_sha2_test.go b/auth/caching_sha2_test.go index 2a3551f..8e98f1e 100644 --- a/auth/caching_sha2_test.go +++ b/auth/caching_sha2_test.go @@ -18,7 +18,8 @@ import ( "testing" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + + "reflect" ) var foobarPwdSHA2Hash, _ = hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") @@ -26,37 +27,51 @@ var foobarPwdSHA2Hash, _ = hex.DecodeString("24412430303524031A69251C34295C4B351 func TestCheckShaPasswordGood(t *testing.T) { pwd := "foobar" r, err := CheckHashingPassword(foobarPwdSHA2Hash, pwd, mysql.AuthCachingSha2Password) - require.NoError(t, err) - require.True(t, r) + if err != nil { + t.Fatal(err) + } + if !(r) { + t.Fatal("expected true") + } } func TestCheckShaPasswordBad(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") r, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) - require.NoError(t, err) - require.False(t, r) + if err != nil { + t.Fatal(err) + } + if r { + t.Fatal("expected false") + } } func TestCheckShaPasswordShort(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("aaaaaaaa") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestCheckShaPasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24422430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestCheckShaPasswordIterationsInvalid(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430304724031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthCachingSha2Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } // The output from NewHashPassword is not stable as the hash is based on the generated salt. @@ -65,20 +80,34 @@ func TestNewSha2Password(t *testing.T) { pwd := "testpwd" pwhash := NewHashPassword(pwd, mysql.AuthCachingSha2Password) r, err := CheckHashingPassword([]byte(pwhash), pwd, mysql.AuthCachingSha2Password) - require.NoError(t, err) - require.True(t, r) + if err != nil { + t.Fatal(err) + } + if !(r) { + t.Fatal("expected true") + } for r := range pwhash { - require.Less(t, pwhash[r], uint8(128)) - require.NotEqual(t, pwhash[r], 0) // NUL - require.NotEqual(t, pwhash[r], 36) // '$' + if !(pwhash[r] < uint8(128)) { + t.Fatalf("expected %v < %v", pwhash[r], uint8(128)) + } + if reflect.DeepEqual(pwhash[r], 0) { + t.Fatalf("expected values to differ, both are %v", 0) + } // NUL + if reflect.DeepEqual(pwhash[r], 36) { + t.Fatalf("expected values to differ, both are %v", 36) + } // '$' } } func BenchmarkShaPassword(b *testing.B) { for i := 0; i < b.N; i++ { m, err := CheckHashingPassword(foobarPwdSHA2Hash, "foobar", mysql.AuthCachingSha2Password) - require.Nil(b, err) - require.True(b, m) + if err != nil { + b.Fatalf("expected nil, got %v", err) + } + if !(m) { + b.Fatal("expected true") + } } } diff --git a/auth/mysql_native_password_test.go b/auth/mysql_native_password_test.go index d5ca759..67b7e69 100644 --- a/auth/mysql_native_password_test.go +++ b/auth/mysql_native_password_test.go @@ -16,19 +16,27 @@ package auth import ( "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestEncodePassword(t *testing.T) { pwd := "123" - require.Equal(t, "*23AE809DDACAF96AF0FD78ED04B6A265E05AA257", EncodePassword(pwd)) - require.Equal(t, EncodePasswordBytes([]byte(pwd)), EncodePassword(pwd)) + if !reflect.DeepEqual("*23AE809DDACAF96AF0FD78ED04B6A265E05AA257", EncodePassword(pwd)) { + t.Fatalf("got %v, want %v", EncodePassword(pwd), "*23AE809DDACAF96AF0FD78ED04B6A265E05AA257") + } + if !reflect.DeepEqual(EncodePasswordBytes([]byte(pwd)), EncodePassword(pwd)) { + t.Fatalf("got %v, want %v", EncodePassword(pwd), EncodePasswordBytes([]byte(pwd))) + } } func TestDecodePassword(t *testing.T) { x, err := DecodePassword(EncodePassword("123")) - require.NoError(t, err) - require.Equal(t, Sha1Hash(Sha1Hash([]byte("123"))), x) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(Sha1Hash(Sha1Hash([]byte("123"))), x) { + t.Fatalf("got %v, want %v", x, Sha1Hash(Sha1Hash([]byte("123")))) + } } func TestCheckScramble(t *testing.T) { @@ -37,12 +45,18 @@ func TestCheckScramble(t *testing.T) { auth := []byte{24, 180, 183, 225, 166, 6, 81, 102, 70, 248, 199, 143, 91, 204, 169, 9, 161, 171, 203, 33} encodepwd := EncodePassword(pwd) hpwd, err := DecodePassword(encodepwd) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } res := CheckScrambledPassword(salt, hpwd, auth) - require.True(t, res) + if !(res) { + t.Fatal("expected true") + } // Do not panic for invalid input. res = CheckScrambledPassword(salt, hpwd, []byte("xxyyzz")) - require.False(t, res) + if res { + t.Fatal("expected false") + } } diff --git a/auth/tidb_sm3_test.go b/auth/tidb_sm3_test.go index 6ecfc7d..c1f417c 100644 --- a/auth/tidb_sm3_test.go +++ b/auth/tidb_sm3_test.go @@ -18,7 +18,8 @@ import ( "testing" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + + "reflect" ) var foobarPwdSM3Hash, _ = hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b63091349536c72627066426a635061762e556e6c63533159414d7762317261324a5a3047756b4244664177434e3043") @@ -34,64 +35,94 @@ func TestSM3(t *testing.T) { text := testCase[0] expect, _ = hex.DecodeString(testCase[1]) result := Sm3Hash([]byte(text)) - require.Equal(t, expect, result) + if !reflect.DeepEqual(expect, result) { + t.Fatalf("got %v, want %v", result, expect) + } } } func TestCheckSM3PasswordGood(t *testing.T) { pwd := "foobar" r, err := CheckHashingPassword(foobarPwdSM3Hash, pwd, mysql.AuthTiDBSM3Password) - require.NoError(t, err) - require.True(t, r) + if err != nil { + t.Fatal(err) + } + if !(r) { + t.Fatal("expected true") + } } func TestCheckSM3PasswordBad(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430303524031a69251c34295c4b35167c7f1e5a7b6309134956387565426743446d3643446176712f6c4b63323667346e48624872776f39512e4342416a693656676f2f") r, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) - require.NoError(t, err) - require.False(t, r) + if err != nil { + t.Fatal(err) + } + if r { + t.Fatal("expected false") + } } func TestCheckSM3PasswordShort(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("aaaaaaaa") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestCheckSM3PasswordDigestTypeIncompatible(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24432430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestCheckSM3PasswordIterationsInvalid(t *testing.T) { pwd := "not_foobar" pwhash, _ := hex.DecodeString("24412430304724031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537") _, err := CheckHashingPassword(pwhash, pwd, mysql.AuthTiDBSM3Password) - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestNewSM3Password(t *testing.T) { pwd := "testpwd" pwhash := NewHashPassword(pwd, mysql.AuthTiDBSM3Password) r, err := CheckHashingPassword([]byte(pwhash), pwd, mysql.AuthTiDBSM3Password) - require.NoError(t, err) - require.True(t, r) + if err != nil { + t.Fatal(err) + } + if !(r) { + t.Fatal("expected true") + } for r := range pwhash { - require.Less(t, pwhash[r], uint8(128)) - require.NotEqual(t, pwhash[r], 0) // NUL - require.NotEqual(t, pwhash[r], 36) // '$' + if !(pwhash[r] < uint8(128)) { + t.Fatalf("expected %v < %v", pwhash[r], uint8(128)) + } + if reflect.DeepEqual(pwhash[r], 0) { + t.Fatalf("expected values to differ, both are %v", 0) + } // NUL + if reflect.DeepEqual(pwhash[r], 36) { + t.Fatalf("expected values to differ, both are %v", 36) + } // '$' } } func BenchmarkSM3Password(b *testing.B) { for i := 0; i < b.N; i++ { m, err := CheckHashingPassword(foobarPwdSM3Hash, "foobar", mysql.AuthTiDBSM3Password) - require.Nil(b, err) - require.True(b, m) + if err != nil { + b.Fatalf("expected nil, got %v", err) + } + if !(m) { + b.Fatal("expected true") + } } } diff --git a/charset/charset_test.go b/charset/charset_test.go index 3675e09..71b9289 100644 --- a/charset/charset_test.go +++ b/charset/charset_test.go @@ -17,12 +17,14 @@ import ( "math/rand" "testing" - "github.com/stretchr/testify/require" + "reflect" ) func testValidCharset(t *testing.T, charset string, collation string, expect bool) { b := ValidCharsetAndCollation(charset, collation) - require.Equal(t, expect, b) + if !reflect.DeepEqual(expect, b) { + t.Fatalf("got %v, want %v", b, expect) + } } func TestValidCharset(t *testing.T) { @@ -57,10 +59,14 @@ func TestValidCharset(t *testing.T) { func testGetDefaultCollation(t *testing.T, charset string, expectCollation string, succ bool) { b, err := GetDefaultCollation(charset) if !succ { - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } return } - require.Equal(t, expectCollation, b) + if !reflect.DeepEqual(expectCollation, b) { + t.Fatalf("got %v, want %v", b, expectCollation) + } } func TestGetDefaultCollation(t *testing.T) { @@ -87,12 +93,16 @@ func TestGetDefaultCollation(t *testing.T) { for _, collate := range collations { if collate.IsDefault { if desc, ok := CharacterSetInfos[collate.CharsetName]; ok { - require.Equal(t, desc.DefaultCollation, collate.Name) + if !reflect.DeepEqual(desc.DefaultCollation, collate.Name) { + t.Fatalf("got %v, want %v", collate.Name, desc.DefaultCollation) + } charsetNum++ } } } - require.Equal(t, len(CharacterSetInfos), charsetNum) + if !reflect.DeepEqual(len(CharacterSetInfos), charsetNum) { + t.Fatalf("got %v, want %v", charsetNum, len(CharacterSetInfos)) + } } func TestGetCharsetDesc(t *testing.T) { @@ -113,9 +123,13 @@ func TestGetCharsetDesc(t *testing.T) { for _, tt := range tests { desc, err := GetCharsetInfo(tt.cs) if !tt.succ { - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } else { - require.Equal(t, tt.result, desc.Name) + if !reflect.DeepEqual(tt.result, desc.Name) { + t.Fatalf("got %v, want %v", desc.Name, tt.result) + } } } } @@ -123,12 +137,18 @@ func TestGetCharsetDesc(t *testing.T) { func TestGetCollationByName(t *testing.T) { for _, collation := range collations { coll, err := GetCollationByName(collation.Name) - require.NoError(t, err) - require.Equal(t, collation, coll) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(collation, coll) { + t.Fatalf("got %v, want %v", coll, collation) + } } _, err := GetCollationByName("non_exist") - require.EqualError(t, err, "[ddl:1273]Unknown collation: 'non_exist'") + if err == nil || err.Error() != "[ddl:1273]Unknown collation: 'non_exist'" { + t.Fatalf("expected error %q, got %v", "[ddl:1273]Unknown collation: 'non_exist'", err) + } } func TestValidCustomCharset(t *testing.T) { @@ -153,12 +173,20 @@ func TestValidCustomCharset(t *testing.T) { func TestUTF8MB3(t *testing.T) { colname, err := GetDefaultCollationLegacy("utf8mb3") - require.NoError(t, err) - require.Equal(t, colname, "utf8_bin") + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(colname, "utf8_bin") { + t.Fatalf("got %v, want %v", "utf8_bin", colname) + } csinfo, err := GetCharsetInfo("utf8mb3") - require.NoError(t, err) - require.Equal(t, csinfo.Name, "utf8") + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(csinfo.Name, "utf8") { + t.Fatalf("got %v, want %v", "utf8", csinfo.Name) + } tests := []struct { cs string @@ -170,8 +198,12 @@ func TestUTF8MB3(t *testing.T) { } for _, tt := range tests { col, err := GetCollationByName(tt.cs) - require.NoError(t, err) - require.Equal(t, col.Name, tt.alias) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(col.Name, tt.alias) { + t.Fatalf("got %v, want %v", tt.alias, col.Name) + } } } diff --git a/charset/encoding_test.go b/charset/encoding_test.go index a57504b..0b93f71 100644 --- a/charset/encoding_test.go +++ b/charset/encoding_test.go @@ -19,28 +19,45 @@ import ( "unicode/utf8" "github.com/sqlc-dev/marino/charset" - "github.com/stretchr/testify/require" "golang.org/x/text/transform" + + "reflect" ) func TestEncoding(t *testing.T) { enc := charset.FindEncoding(charset.CharsetGBK) - require.Equal(t, charset.CharsetGBK, enc.Name()) + if !reflect.DeepEqual(charset.CharsetGBK, enc.Name()) { + t.Fatalf("got %v, want %v", enc.Name(), charset.CharsetGBK) + } txt := []byte("一二三四") e, _ := charset.Lookup("gbk") gbkEncodedTxt, _, err := transform.Bytes(e.NewEncoder(), txt) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } result, err := enc.Transform(nil, gbkEncodedTxt, charset.OpDecode) - require.NoError(t, err) - require.Equal(t, txt, result) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(txt, result) { + t.Fatalf("got %v, want %v", result, txt) + } gbkEncodedTxt2, err := enc.Transform(nil, txt, charset.OpEncode) - require.NoError(t, err) - require.Equal(t, gbkEncodedTxt2, gbkEncodedTxt) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gbkEncodedTxt2, gbkEncodedTxt) { + t.Fatalf("got %v, want %v", gbkEncodedTxt, gbkEncodedTxt2) + } result, err = enc.Transform(nil, gbkEncodedTxt2, charset.OpDecode) - require.NoError(t, err) - require.Equal(t, txt, result) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(txt, result) { + t.Fatalf("got %v, want %v", result, txt) + } GBKCases := []struct { utf8Str string @@ -67,11 +84,17 @@ func TestEncoding(t *testing.T) { cmt := fmt.Sprintf("%v", tc) result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpDecodeReplace) if tc.isValid { - require.NoError(t, err, cmt) + if err != nil { + t.Fatalf("%v: %v", cmt, err) + } } else { - require.Error(t, err, cmt) + if err == nil { + t.Fatal(cmt) + } + } + if !reflect.DeepEqual(tc.result, string(result)) { + t.Fatalf("%v: got %v, want %v", cmt, string(result), tc.result) } - require.Equal(t, tc.result, string(result), cmt) } utf8Cases := []struct { @@ -91,11 +114,17 @@ func TestEncoding(t *testing.T) { cmt := fmt.Sprintf("%v", tc) result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpEncodeReplace) if tc.isValid { - require.NoError(t, err, cmt) + if err != nil { + t.Fatalf("%v: %v", cmt, err) + } } else { - require.Error(t, err, cmt) + if err == nil { + t.Fatal(cmt) + } + } + if !reflect.DeepEqual(tc.result, string(result)) { + t.Fatalf("%v: got %v, want %v", cmt, string(result), tc.result) } - require.Equal(t, tc.result, string(result), cmt) } } @@ -151,30 +180,50 @@ func TestEncodingValidate(t *testing.T) { enc = charset.EncodingUTF8MB3StrictImpl } strBytes := []byte(tc.str) - require.Equal(t, tc.ok, enc.IsValid(strBytes), msg) + if !reflect.DeepEqual(tc.ok, enc.IsValid(strBytes)) { + t.Fatalf("%v: got %v, want %v", msg, enc.IsValid(strBytes), tc.ok) + } replace, _ := enc.Transform(nil, strBytes, charset.OpReplaceNoErr) - require.Equal(t, tc.expected, string(replace), msg) + if !reflect.DeepEqual(tc.expected, string(replace)) { + t.Fatalf("%v: got %v, want %v", msg, string(replace), tc.expected) + } } } func TestEncodingGB18030(t *testing.T) { enc := charset.FindEncoding(charset.CharsetGB18030) - require.Equal(t, charset.CharsetGB18030, enc.Name()) + if !reflect.DeepEqual(charset.CharsetGB18030, enc.Name()) { + t.Fatalf("got %v, want %v", enc.Name(), charset.CharsetGB18030) + } txt := []byte("一二三四") e, _ := charset.Lookup("gb18030") gb18030EncodedTxt, _, err := transform.Bytes(e.NewEncoder(), txt) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } result, err := enc.Transform(nil, gb18030EncodedTxt, charset.OpDecode) - require.NoError(t, err) - require.Equal(t, txt, result) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(txt, result) { + t.Fatalf("got %v, want %v", result, txt) + } gb18030EncodedTxt2, err := enc.Transform(nil, txt, charset.OpEncode) - require.NoError(t, err) - require.Equal(t, gb18030EncodedTxt2, gb18030EncodedTxt) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gb18030EncodedTxt2, gb18030EncodedTxt) { + t.Fatalf("got %v, want %v", gb18030EncodedTxt, gb18030EncodedTxt2) + } result, err = enc.Transform(nil, gb18030EncodedTxt2, charset.OpDecode) - require.NoError(t, err) - require.Equal(t, txt, result) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(txt, result) { + t.Fatalf("got %v, want %v", result, txt) + } GB18030Cases := []struct { utf8Str string @@ -211,11 +260,17 @@ func TestEncodingGB18030(t *testing.T) { cmt := fmt.Sprintf("utf8Str: %s, result: %s, isValid: %t", tc.utf8Str, tc.result, tc.isValid) result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpDecodeReplace) if tc.isValid { - require.NoError(t, err, cmt) + if err != nil { + t.Fatalf("%v: %v", cmt, err) + } } else { - require.Error(t, err, cmt) + if err == nil { + t.Fatal(cmt) + } + } + if !reflect.DeepEqual(tc.result, string(result)) { + t.Fatalf("%v: got %v, want %v", cmt, string(result), tc.result) } - require.Equal(t, tc.result, string(result), cmt) } utf8Cases := []struct { @@ -235,10 +290,16 @@ func TestEncodingGB18030(t *testing.T) { cmt := fmt.Sprintf("%v", tc) result, err := enc.Transform(nil, []byte(tc.utf8Str), charset.OpEncodeReplace) if tc.isValid { - require.NoError(t, err, cmt) + if err != nil { + t.Fatalf("%v: %v", cmt, err) + } } else { - require.Error(t, err, cmt) + if err == nil { + t.Fatal(cmt) + } + } + if !reflect.DeepEqual(tc.result, string(result)) { + t.Fatalf("%v: got %v, want %v", cmt, string(result), tc.result) } - require.Equal(t, tc.result, string(result), cmt) } } diff --git a/duration/duration_test.go b/duration/duration_test.go index 4a1555d..2f8f507 100644 --- a/duration/duration_test.go +++ b/duration/duration_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" + "reflect" ) func TestParseDuration(t *testing.T) { @@ -58,8 +58,12 @@ func TestParseDuration(t *testing.T) { for _, c := range cases { t.Run(c.str, func(t *testing.T) { d, err := ParseDuration(c.str) - require.NoError(t, err) - require.Equal(t, c.duration, d) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(c.duration, d) { + t.Fatalf("got %v, want %v", d, c.duration) + } }) } } diff --git a/format/format_test.go b/format/format_test.go index bc04033..67cd5f7 100644 --- a/format/format_test.go +++ b/format/format_test.go @@ -15,20 +15,28 @@ package format import ( "bytes" + "fmt" "io" "strings" "testing" "github.com/pingcap/errors" - "github.com/stretchr/testify/require" + + "reflect" ) func checkFormat(t *testing.T, f Formatter, buf *bytes.Buffer, str, expect string) { _, err := f.Format(str, 3) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } b, err := io.ReadAll(buf) - require.NoError(t, err) - require.Equal(t, expect, string(b)) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(expect, string(b)) { + t.Fatalf("got %v, want %v", string(b), expect) + } } func TestFormat(t *testing.T) { @@ -79,7 +87,9 @@ func TestRestoreCtx(t *testing.T) { ctx.WriteString("str`.'\"ing\\") ctx.WritePlain(" ") ctx.WriteName("na`.'\"Me\\") - require.Equalf(t, testCase.expect, sb.String(), "case: %#v", testCase) + if !reflect.DeepEqual(testCase.expect, sb.String()) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("case: %#v", testCase), sb.String(), testCase.expect) + } } } @@ -87,23 +97,39 @@ func TestRestoreSpecialComment(t *testing.T) { var sb strings.Builder sb.Reset() ctx := NewRestoreCtx(RestoreTiDBSpecialComment, &sb) - require.NoError(t, ctx.WriteWithSpecialComments("fea_id", func() error { + if ctx.WriteWithSpecialComments("fea_id", func() error { ctx.WritePlain("content") return nil - })) - require.Equal(t, "/*T![fea_id] content */", sb.String()) + }) != nil { + t.Fatal(ctx.WriteWithSpecialComments("fea_id", func() error { + ctx.WritePlain("content") + return nil + })) + } + if !reflect.DeepEqual("/*T![fea_id] content */", sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), "/*T![fea_id] content */") + } sb.Reset() - require.NoError(t, ctx.WriteWithSpecialComments("", func() error { + if ctx.WriteWithSpecialComments("", func() error { ctx.WritePlain("shard_row_id_bits") return nil - })) - require.Equal(t, "/*T! shard_row_id_bits */", sb.String()) + }) != nil { + t.Fatal(ctx.WriteWithSpecialComments("", func() error { + ctx.WritePlain("shard_row_id_bits") + return nil + })) + } + if !reflect.DeepEqual("/*T! shard_row_id_bits */", sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), "/*T! shard_row_id_bits */") + } sb.Reset() err := errors.New("xxxx") got := ctx.WriteWithSpecialComments("", func() error { return err }) - require.Same(t, err, got) + if err != got { + t.Fatalf("expected pointer equality, got %p vs %p", err, got) + } } diff --git a/generate_keyword/genkeyword_test.go b/generate_keyword/genkeyword_test.go index 8e3f49f..698e142 100644 --- a/generate_keyword/genkeyword_test.go +++ b/generate_keyword/genkeyword_test.go @@ -3,13 +3,17 @@ package main import ( "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestParseLine(t *testing.T) { add := parseLine(" add \"ADD\"") - require.Equal(t, add, "ADD") + if !reflect.DeepEqual(add, "ADD") { + t.Fatalf("got %v, want %v", "ADD", add) + } tso := parseLine(" tidbCurrentTSO \"TIDB_CURRENT_TSO\"") - require.Equal(t, tso, "TIDB_CURRENT_TSO") + if !reflect.DeepEqual(tso, "TIDB_CURRENT_TSO") { + t.Fatalf("got %v, want %v", "TIDB_CURRENT_TSO", tso) + } } diff --git a/go.mod b/go.mod index 4d16663..c032eb1 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 - github.com/stretchr/testify v1.8.4 go.uber.org/goleak v1.3.0 golang.org/x/text v0.19.0 modernc.org/mathutil v1.6.0 @@ -18,11 +17,7 @@ require ( ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/kr/text v0.2.0 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect go.uber.org/atomic v1.11.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/golex v1.1.0 // indirect ) diff --git a/go.sum b/go.sum index 48cfd95..1a930f1 100644 --- a/go.sum +++ b/go.sum @@ -1,15 +1,10 @@ github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee h1:/IDPbpzkzA97t1/Z1+C3KlxbevjMeaI6BQYxvivu4u8= github.com/pingcap/errors v0.11.5-0.20250523034308-74f78ae071ee/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= github.com/pingcap/failpoint v0.0.0-20240528011301-b51a646c7c86 h1:tdMsjOqUR7YXHoBitzdebTvOjs/swniBTOLy5XiMtuE= @@ -22,8 +17,8 @@ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qq github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -43,8 +38,6 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/mysql/const_test.go b/mysql/const_test.go index 5df5ac8..e9fa27d 100644 --- a/mysql/const_test.go +++ b/mysql/const_test.go @@ -16,7 +16,8 @@ package mysql import ( "testing" - "github.com/stretchr/testify/require" + "reflect" + "strings" ) func TestSQLMode(t *testing.T) { @@ -91,39 +92,65 @@ func TestSQLMode(t *testing.T) { }} for _, ca := range hardCode { - require.Equal(t, ca.value, int(ca.code)) + if !reflect.DeepEqual(ca.value, int(ca.code)) { + t.Fatalf("got %v, want %v", int(ca.code), ca.value) + } } } func TestVersionSeparator(t *testing.T) { // DO NOT change the value of VersionSeparator. - require.Equal(t, "-TiDB-", VersionSeparator) + if !reflect.DeepEqual("-TiDB-", VersionSeparator) { + t.Fatalf("got %v, want %v", VersionSeparator, "-TiDB-") + } } func TestBuildTiDBXReleaseVersion(t *testing.T) { tidbXVersion, err := BuildTiDBXReleaseVersion("v26.3.0") - require.NoError(t, err) - require.Equal(t, "CLOUD.202603.0", tidbXVersion) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual("CLOUD.202603.0", tidbXVersion) { + t.Fatalf("got %v, want %v", tidbXVersion, "CLOUD.202603.0") + } tidbXVersion, err = BuildTiDBXReleaseVersion("v26.3.0-xxx") - require.NoError(t, err) - require.Equal(t, "CLOUD.202603.0-xxx", tidbXVersion) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual("CLOUD.202603.0-xxx", tidbXVersion) { + t.Fatalf("got %v, want %v", tidbXVersion, "CLOUD.202603.0-xxx") + } serverVersion, err := BuildTiDBXServerVersion("v26.3.0") - require.NoError(t, err) - require.Equal(t, "8.0.11-TiDB-CLOUD.202603.0", serverVersion) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual("8.0.11-TiDB-CLOUD.202603.0", serverVersion) { + t.Fatalf("got %v, want %v", serverVersion, "8.0.11-TiDB-CLOUD.202603.0") + } serverVersion, err = BuildTiDBXServerVersion("v26.3.0-xxx") - require.NoError(t, err) - require.Equal(t, "8.0.11-TiDB-CLOUD.202603.0-xxx", serverVersion) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual("8.0.11-TiDB-CLOUD.202603.0-xxx", serverVersion) { + t.Fatalf("got %v, want %v", serverVersion, "8.0.11-TiDB-CLOUD.202603.0-xxx") + } for _, ver := range []string{"26.1.1", "v26xxxx", "v24.1.1", "v26.0.1", "v26.13.1"} { _, err = BuildTiDBXReleaseVersion(ver) - require.ErrorContains(t, err, "invalid TiDB release version") + if err == nil || !strings.Contains(err.Error(), "invalid TiDB release version") { + t.Fatalf("expected error containing %q, got %v", "invalid TiDB release version", err) + } } } func TestNormalizeTiDBReleaseVersionForNextGen(t *testing.T) { - require.Equal(t, tidbXPlaceholderReleaseVersion, NormalizeTiDBReleaseVersionForNextGen(legacyTiDBReleaseVersionPlaceholder)) - require.Equal(t, "v26.3.0", NormalizeTiDBReleaseVersionForNextGen("v26.3.0")) + if !reflect.DeepEqual(tidbXPlaceholderReleaseVersion, NormalizeTiDBReleaseVersionForNextGen(legacyTiDBReleaseVersionPlaceholder)) { + t.Fatalf("got %v, want %v", NormalizeTiDBReleaseVersionForNextGen(legacyTiDBReleaseVersionPlaceholder), tidbXPlaceholderReleaseVersion) + } + if !reflect.DeepEqual("v26.3.0", NormalizeTiDBReleaseVersionForNextGen("v26.3.0")) { + t.Fatalf("got %v, want %v", NormalizeTiDBReleaseVersionForNextGen("v26.3.0"), "v26.3.0") + } } diff --git a/mysql/error_test.go b/mysql/error_test.go index 8468390..0e9fc2d 100644 --- a/mysql/error_test.go +++ b/mysql/error_test.go @@ -15,20 +15,26 @@ package mysql import ( "testing" - - "github.com/stretchr/testify/require" ) func TestSQLError(t *testing.T) { e := NewErrf(ErrNoDB, "no db error", nil) - require.Greater(t, len(e.Error()), 0) + if !(len(e.Error()) > 0) { + t.Fatalf("expected %v > %v", len(e.Error()), 0) + } e = NewErrf(0, "customized error", nil) - require.Greater(t, len(e.Error()), 0) + if !(len(e.Error()) > 0) { + t.Fatalf("expected %v > %v", len(e.Error()), 0) + } e = NewErr(ErrNoDB) - require.Greater(t, len(e.Error()), 0) + if !(len(e.Error()) > 0) { + t.Fatalf("expected %v > %v", len(e.Error()), 0) + } e = NewErr(0, "customized error", nil) - require.Greater(t, len(e.Error()), 0) + if !(len(e.Error()) > 0) { + t.Fatalf("expected %v > %v", len(e.Error()), 0) + } } diff --git a/mysql/privs_test.go b/mysql/privs_test.go index 03d4352..4770152 100644 --- a/mysql/privs_test.go +++ b/mysql/privs_test.go @@ -16,7 +16,7 @@ package mysql import ( "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestPrivString(t *testing.T) { @@ -25,70 +25,118 @@ func TestPrivString(t *testing.T) { if p > AllPriv { break } - require.NotEqualf(t, "", p.String(), "%d-th", i) + if reflect.DeepEqual("", p.String()) { + t.Fatalf("expected values to differ, both are %v", p.String()) + } } } func TestPrivColumn(t *testing.T) { for _, p := range AllGlobalPrivs { - require.NotEmptyf(t, p.ColumnString(), "%s", p) + if len(p.ColumnString()) == 0 { + t.Fatalf("%s", p) + } np, ok := NewPrivFromColumn(p.ColumnString()) - require.Truef(t, ok, "%s", p) - require.Equal(t, p, np) + if !(ok) { + t.Fatalf("%s", p) + } + if !reflect.DeepEqual(p, np) { + t.Fatalf("got %v, want %v", np, p) + } } for _, p := range StaticGlobalOnlyPrivs { - require.NotEmptyf(t, p.ColumnString(), "%s", p) + if len(p.ColumnString()) == 0 { + t.Fatalf("%s", p) + } np, ok := NewPrivFromColumn(p.ColumnString()) - require.Truef(t, ok, "%s", p) - require.Equal(t, p, np) + if !(ok) { + t.Fatalf("%s", p) + } + if !reflect.DeepEqual(p, np) { + t.Fatalf("got %v, want %v", np, p) + } } for _, p := range AllDBPrivs { - require.NotEmptyf(t, p.ColumnString(), "%s", p) + if len(p.ColumnString()) == 0 { + t.Fatalf("%s", p) + } np, ok := NewPrivFromColumn(p.ColumnString()) - require.Truef(t, ok, "%s", p) - require.Equal(t, p, np) + if !(ok) { + t.Fatalf("%s", p) + } + if !reflect.DeepEqual(p, np) { + t.Fatalf("got %v, want %v", np, p) + } } } func TestPrivSetString(t *testing.T) { for _, p := range AllTablePrivs { - require.NotEmptyf(t, p.SetString(), "%s", p) + if len(p.SetString()) == 0 { + t.Fatalf("%s", p) + } np, ok := NewPrivFromSetEnum(p.SetString()) - require.Truef(t, ok, "%s", p) - require.Equal(t, p, np) + if !(ok) { + t.Fatalf("%s", p) + } + if !reflect.DeepEqual(p, np) { + t.Fatalf("got %v, want %v", np, p) + } } for _, p := range AllColumnPrivs { - require.NotEmptyf(t, p.SetString(), "%s", p) + if len(p.SetString()) == 0 { + t.Fatalf("%s", p) + } np, ok := NewPrivFromSetEnum(p.SetString()) - require.Truef(t, ok, "%s", p) - require.Equal(t, p, np) + if !(ok) { + t.Fatalf("%s", p) + } + if !reflect.DeepEqual(p, np) { + t.Fatalf("got %v, want %v", np, p) + } } } func TestPrivsHas(t *testing.T) { // it is a simple helper, does not handle all&dynamic privs privs := Privileges{AllPriv} - require.True(t, privs.Has(AllPriv)) - require.False(t, privs.Has(InsertPriv)) + if !(privs.Has(AllPriv)) { + t.Fatal("expected true") + } + if privs.Has(InsertPriv) { + t.Fatal("expected false") + } // multiple privs privs = Privileges{InsertPriv, SelectPriv} - require.True(t, privs.Has(SelectPriv)) - require.True(t, privs.Has(InsertPriv)) - require.False(t, privs.Has(DropPriv)) + if !(privs.Has(SelectPriv)) { + t.Fatal("expected true") + } + if !(privs.Has(InsertPriv)) { + t.Fatal("expected true") + } + if privs.Has(DropPriv) { + t.Fatal("expected false") + } } func TestPrivAllConsistency(t *testing.T) { // AllPriv in mysql.user columns. for priv := CreatePriv; priv != AllPriv; priv = priv << 1 { _, ok := Priv2UserCol[priv] - require.Truef(t, ok, "priv fail %d", priv) + if !(ok) { + t.Fatalf("priv fail %d", priv) + } } - require.Equal(t, len(AllGlobalPrivs)+1, len(Priv2UserCol)) + if !reflect.DeepEqual(len(AllGlobalPrivs)+1, len(Priv2UserCol)) { + t.Fatalf("got %v, want %v", len(Priv2UserCol), len(AllGlobalPrivs)+1) + } // USAGE privilege doesn't have a column in Priv2UserCol // ALL privilege doesn't have a column in Priv2UserCol // so it's +2 - require.Equal(t, len(Priv2UserCol)+2, len(Priv2Str)) + if !reflect.DeepEqual(len(Priv2UserCol)+2, len(Priv2Str)) { + t.Fatalf("got %v, want %v", len(Priv2Str), len(Priv2UserCol)+2) + } } diff --git a/mysql/type_test.go b/mysql/type_test.go index 644f96b..f353ce7 100644 --- a/mysql/type_test.go +++ b/mysql/type_test.go @@ -15,21 +15,43 @@ package mysql import ( "testing" - - "github.com/stretchr/testify/require" ) func TestFlags(t *testing.T) { - require.True(t, HasNotNullFlag(NotNullFlag)) - require.True(t, HasUniKeyFlag(UniqueKeyFlag)) - require.True(t, HasNotNullFlag(NotNullFlag)) - require.True(t, HasNoDefaultValueFlag(NoDefaultValueFlag)) - require.True(t, HasAutoIncrementFlag(AutoIncrementFlag)) - require.True(t, HasUnsignedFlag(UnsignedFlag)) - require.True(t, HasZerofillFlag(ZerofillFlag)) - require.True(t, HasBinaryFlag(BinaryFlag)) - require.True(t, HasPriKeyFlag(PriKeyFlag)) - require.True(t, HasMultipleKeyFlag(MultipleKeyFlag)) - require.True(t, HasTimestampFlag(TimestampFlag)) - require.True(t, HasOnUpdateNowFlag(OnUpdateNowFlag)) + if !(HasNotNullFlag(NotNullFlag)) { + t.Fatal("expected true") + } + if !(HasUniKeyFlag(UniqueKeyFlag)) { + t.Fatal("expected true") + } + if !(HasNotNullFlag(NotNullFlag)) { + t.Fatal("expected true") + } + if !(HasNoDefaultValueFlag(NoDefaultValueFlag)) { + t.Fatal("expected true") + } + if !(HasAutoIncrementFlag(AutoIncrementFlag)) { + t.Fatal("expected true") + } + if !(HasUnsignedFlag(UnsignedFlag)) { + t.Fatal("expected true") + } + if !(HasZerofillFlag(ZerofillFlag)) { + t.Fatal("expected true") + } + if !(HasBinaryFlag(BinaryFlag)) { + t.Fatal("expected true") + } + if !(HasPriKeyFlag(PriKeyFlag)) { + t.Fatal("expected true") + } + if !(HasMultipleKeyFlag(MultipleKeyFlag)) { + t.Fatal("expected true") + } + if !(HasTimestampFlag(TimestampFlag)) { + t.Fatal("expected true") + } + if !(HasOnUpdateNowFlag(OnUpdateNowFlag)) { + t.Fatal("expected true") + } } diff --git a/parser/consistent_test.go b/parser/consistent_test.go index c8f00f0..121b755 100644 --- a/parser/consistent_test.go +++ b/parser/consistent_test.go @@ -20,15 +20,19 @@ import ( "strings" "testing" - requires "github.com/stretchr/testify/require" + "reflect" ) func TestKeywordConsistent(t *testing.T) { parserFilename := "parser.y" parserFile, err := os.Open(parserFilename) - requires.NoError(t, err) + if err != nil { + t.Fatal(err) + } data, err := gio.ReadAll(parserFile) - requires.NoError(t, err) + if err != nil { + t.Fatal(err) + } content := string(data) reservedKeywordStartMarker := "\t/* The following tokens belong to ReservedKeyword. Notice: make sure these tokens are contained in ReservedKeyword. */" @@ -43,20 +47,32 @@ func TestKeywordConsistent(t *testing.T) { tidbKeywords := extractKeywords(content, tidbKeywordStartMarker, identTokenEndMarker) for k, v := range aliases { - requires.NotEqual(t, k, v) - requires.Equal(t, tokenMap[v], tokenMap[k]) + if reflect.DeepEqual(k, v) { + t.Fatalf("expected values to differ, both are %v", v) + } + if !reflect.DeepEqual(tokenMap[v], tokenMap[k]) { + t.Fatalf("got %v, want %v", tokenMap[k], tokenMap[v]) + } } keywordCount := len(reservedKeywords) + len(unreservedKeywords) + len(notKeywordTokens) + len(tidbKeywords) - requires.Equal(t, keywordCount-len(windowFuncTokenMap), len(tokenMap)-len(aliases)) + if !reflect.DeepEqual(keywordCount-len(windowFuncTokenMap), len(tokenMap)-len(aliases)) { + t.Fatalf("got %v, want %v", len(tokenMap)-len(aliases), keywordCount-len(windowFuncTokenMap)) + } unreservedCollectionDef := extractKeywordsFromCollectionDef(content, "\nUnReservedKeyword:") - requires.Equal(t, unreservedCollectionDef, unreservedKeywords, "UnReservedKeyword") + if !reflect.DeepEqual(unreservedCollectionDef, unreservedKeywords) { + t.Fatalf("%v: got %v, want %v", "UnReservedKeyword", unreservedKeywords, unreservedCollectionDef) + } notKeywordTokensCollectionDef := extractKeywordsFromCollectionDef(content, "\nNotKeywordToken:") - requires.Equal(t, notKeywordTokensCollectionDef, notKeywordTokens, "NotKeywordToken") + if !reflect.DeepEqual(notKeywordTokensCollectionDef, notKeywordTokens) { + t.Fatalf("%v: got %v, want %v", "NotKeywordToken", notKeywordTokens, notKeywordTokensCollectionDef) + } tidbKeywordsCollectionDef := extractKeywordsFromCollectionDef(content, "\nTiDBKeyword:") - requires.Equal(t, tidbKeywordsCollectionDef, tidbKeywords, "TiDBKeyword") + if !reflect.DeepEqual(tidbKeywordsCollectionDef, tidbKeywords) { + t.Fatalf("%v: got %v, want %v", "TiDBKeyword", tidbKeywords, tidbKeywordsCollectionDef) + } } func extractMiddle(str, startMarker, endMarker string) string { diff --git a/parser/digester_test.go b/parser/digester_test.go index 10b3992..219caa9 100644 --- a/parser/digester_test.go +++ b/parser/digester_test.go @@ -20,7 +20,8 @@ import ( "testing" "github.com/sqlc-dev/marino/parser" - "github.com/stretchr/testify/require" + + "reflect" ) func TestNormalize(t *testing.T) { @@ -75,11 +76,17 @@ func TestNormalize(t *testing.T) { for _, test := range tests_for_generic_normalization_rules { normalized := parser.Normalize(test.input, "ON") digest := parser.DigestNormalized(normalized) - require.Equal(t, test.expect, normalized) + if !reflect.DeepEqual(test.expect, normalized) { + t.Fatalf("got %v, want %v", normalized, test.expect) + } normalized2, digest2 := parser.NormalizeDigest(test.input) - require.Equal(t, normalized, normalized2) - require.Equalf(t, digest.String(), digest2.String(), "%+v", test) + if !reflect.DeepEqual(normalized, normalized2) { + t.Fatalf("got %v, want %v", normalized2, normalized) + } + if !reflect.DeepEqual(digest.String(), digest2.String()) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("%+v", test), digest2.String(), digest.String()) + } } tests_for_binding_specific_rules := []struct { @@ -97,11 +104,17 @@ func TestNormalize(t *testing.T) { for _, test := range tests_for_binding_specific_rules { normalized := parser.NormalizeForBinding(test.input, false) digest := parser.DigestNormalized(normalized) - require.Equal(t, test.expect, normalized) + if !reflect.DeepEqual(test.expect, normalized) { + t.Fatalf("got %v, want %v", normalized, test.expect) + } normalized2, digest2 := parser.NormalizeDigestForBinding(test.input) - require.Equal(t, normalized, normalized2) - require.Equalf(t, digest.String(), digest2.String(), "%+v", test) + if !reflect.DeepEqual(normalized, normalized2) { + t.Fatalf("got %v, want %v", normalized2, normalized) + } + if !reflect.DeepEqual(digest.String(), digest2.String()) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("%+v", test), digest2.String(), digest.String()) + } } } @@ -122,7 +135,9 @@ func TestNormalizeRedact(t *testing.T) { for _, c := range cases { normalized := parser.Normalize(c.input, "MARKER") - require.Equal(t, c.expect, normalized) + if !reflect.DeepEqual(c.expect, normalized) { + t.Fatalf("got %v, want %v", normalized, c.expect) + } } } @@ -169,7 +184,9 @@ func TestNormalizeKeepHint(t *testing.T) { } for _, test := range tests { normalized := parser.NormalizeKeepHint(test.input) - require.Equal(t, test.expect, normalized) + if !reflect.DeepEqual(test.expect, normalized) { + t.Fatalf("got %v, want %v", normalized, test.expect) + } } } @@ -183,13 +200,21 @@ func TestNormalizeDigest(t *testing.T) { } for _, test := range tests { normalized, digest := parser.NormalizeDigest(test.sql) - require.Equal(t, test.normalized, normalized) - require.Equal(t, test.digest, digest.String()) + if !reflect.DeepEqual(test.normalized, normalized) { + t.Fatalf("got %v, want %v", normalized, test.normalized) + } + if !reflect.DeepEqual(test.digest, digest.String()) { + t.Fatalf("got %v, want %v", digest.String(), test.digest) + } normalized = parser.Normalize(test.sql, "ON") digest = parser.DigestNormalized(normalized) - require.Equal(t, test.normalized, normalized) - require.Equal(t, test.digest, digest.String()) + if !reflect.DeepEqual(test.normalized, normalized) { + t.Fatalf("got %v, want %v", normalized, test.normalized) + } + if !reflect.DeepEqual(test.digest, digest.String()) { + t.Fatalf("got %v, want %v", digest.String(), test.digest) + } } } @@ -209,7 +234,9 @@ func TestDigestHashEqForSimpleSQL(t *testing.T) { d = dig.String() continue } - require.Equal(t, dig.String(), d) + if !reflect.DeepEqual(dig.String(), d) { + t.Fatalf("got %v, want %v", d, dig.String()) + } } } } @@ -226,7 +253,9 @@ func TestDigestHashNotEqForSimpleSQL(t *testing.T) { d = dig.String() continue } - require.NotEqual(t, dig.String(), d) + if reflect.DeepEqual(dig.String(), d) { + t.Fatalf("expected values to differ, both are %v", d) + } } } } @@ -234,11 +263,19 @@ func TestDigestHashNotEqForSimpleSQL(t *testing.T) { func TestGenDigest(t *testing.T) { hash := genRandDigest("abc") digest := parser.NewDigest(hash) - require.Equal(t, fmt.Sprintf("%x", hash), digest.String()) - require.Equal(t, hash, digest.Bytes()) + if !reflect.DeepEqual(fmt.Sprintf("%x", hash), digest.String()) { + t.Fatalf("got %v, want %v", digest.String(), fmt.Sprintf("%x", hash)) + } + if !reflect.DeepEqual(hash, digest.Bytes()) { + t.Fatalf("got %v, want %v", digest.Bytes(), hash) + } digest = parser.NewDigest(nil) - require.Equal(t, "", digest.String()) - require.Nil(t, digest.Bytes()) + if !reflect.DeepEqual("", digest.String()) { + t.Fatalf("got %v, want %v", digest.String(), "") + } + if digest.Bytes() != nil { + t.Fatalf("expected nil, got %v", digest.Bytes()) + } } func genRandDigest(str string) []byte { diff --git a/parser/hintparser_test.go b/parser/hintparser_test.go index 0312a0f..0d93317 100644 --- a/parser/hintparser_test.go +++ b/parser/hintparser_test.go @@ -16,10 +16,13 @@ package parser_test import ( "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "fmt" + "reflect" + "strings" ) func TestParseHint(t *testing.T) { @@ -475,11 +478,19 @@ func TestParseHint(t *testing.T) { for _, tc := range testCases { output, errs := parser.ParseHint("/*+"+tc.input+"*/", tc.mode, parser.Pos{Line: 1}) - require.Lenf(t, errs, len(tc.errs), "input = %s,\n... errs = %q", tc.input, errs) + if got := len(errs); got != len(tc.errs) { + t.Fatalf("%s: expected length %d, got %d", fmt.Sprintf("input = %s,\n... errs = %q", tc.input, errs), len(tc.errs), got) + } for i, err := range errs { - require.Errorf(t, err, "input = %s, i = %d", tc.input, i) - require.Containsf(t, err.Error(), tc.errs[i], "input = %s, i = %d", tc.input, i) + if err == nil { + t.Fatalf("input = %s, i = %d", tc.input, i) + } + if !strings.Contains(err.Error(), tc.errs[i]) { + t.Fatalf("%s: expected %q to contain %q", fmt.Sprintf("input = %s, i = %d", tc.input, i), err.Error(), tc.errs[i]) + } + } + if !reflect.DeepEqual(tc.output, output) { + t.Fatalf("input = %s,\n... got %v, want %v", tc.input, output, tc.output) } - require.Equalf(t, tc.output, output, "input = %s,\n... output = %q", tc.input, output) } } diff --git a/parser/keywords_test.go b/parser/keywords_test.go index baee8ea..5660835 100644 --- a/parser/keywords_test.go +++ b/parser/keywords_test.go @@ -17,13 +17,18 @@ import ( "testing" "github.com/sqlc-dev/marino/parser" - "github.com/stretchr/testify/require" + + "reflect" ) func TestKeywords(t *testing.T) { // Test for the first keyword - require.Equal(t, "ADD", parser.Keywords[0].Word) - require.Equal(t, true, parser.Keywords[0].Reserved) + if !reflect.DeepEqual("ADD", parser.Keywords[0].Word) { + t.Fatalf("got %v, want %v", parser.Keywords[0].Word, "ADD") + } + if !reflect.DeepEqual(true, parser.Keywords[0].Reserved) { + t.Fatalf("got %v, want %v", parser.Keywords[0].Reserved, true) + } // Make sure TiDBKeywords are included. found := false @@ -32,11 +37,15 @@ func TestKeywords(t *testing.T) { found = true } } - require.Equal(t, found, true, "TiDBKeyword ADMIN is part of the list") + if !reflect.DeepEqual(found, true) { + t.Fatalf("%v: got %v, want %v", "TiDBKeyword ADMIN is part of the list", true, found) + } } func TestKeywordsLength(t *testing.T) { - require.Equal(t, 679, len(parser.Keywords)) + if !reflect.DeepEqual(679, len(parser.Keywords)) { + t.Fatalf("got %v, want %v", len(parser.Keywords), 679) + } reservedNr := 0 for _, kw := range parser.Keywords { @@ -44,7 +53,9 @@ func TestKeywordsLength(t *testing.T) { reservedNr += 1 } } - require.Equal(t, 233, reservedNr) + if !reflect.DeepEqual(233, reservedNr) { + t.Fatalf("got %v, want %v", reservedNr, 233) + } } func TestKeywordsSorting(t *testing.T) { diff --git a/parser/lateral_test.go b/parser/lateral_test.go index 2203980..9e32691 100644 --- a/parser/lateral_test.go +++ b/parser/lateral_test.go @@ -18,10 +18,12 @@ import ( "strings" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/format" - "github.com/stretchr/testify/require" + "github.com/sqlc-dev/marino/parser" + + "fmt" + "reflect" ) func TestLateralParsing(t *testing.T) { @@ -127,29 +129,43 @@ func TestLateralParsing(t *testing.T) { stmt, err := p.ParseOneStmt(tc.sql, "", "") if tc.expectError { - require.Error(t, err, "Expected parsing to fail for: %s", tc.sql) + if err == nil { + t.Fatalf("Expected parsing to fail for: %s", tc.sql) + } return } - require.NoError(t, err, "Failed to parse: %s", tc.sql) - require.NotNil(t, stmt) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("Failed to parse: %s", tc.sql), err) + } + if stmt == nil { + t.Fatal("expected non-nil") + } // Test round-trip: parse -> restore -> parse again var sb strings.Builder restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes, &sb) err = stmt.Restore(restoreCtx) - require.NoError(t, err, "Failed to restore statement") + if err != nil { + t.Fatalf("%v: %v", "Failed to restore statement", err) + } restored := sb.String() if tc.checkLateral { // Verify LATERAL keyword is preserved in restoration - require.Contains(t, restored, "LATERAL", "LATERAL keyword missing in restored SQL: %s", restored) + if !strings.Contains(restored, "LATERAL") { + t.Fatalf("%s: expected %q to contain %q", fmt.Sprintf("LATERAL keyword missing in restored SQL: %s", restored), restored, "LATERAL") + } } // Parse the restored SQL to ensure it's valid (round-trip test) stmt2, err := p.ParseOneStmt(restored, "", "") - require.NoError(t, err, "Failed to parse restored SQL: %s", restored) - require.NotNil(t, stmt2) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("Failed to parse restored SQL: %s", restored), err) + } + if stmt2 == nil { + t.Fatal("expected non-nil") + } // Verify AST flags on both original and round-tripped statements. for _, stmtToCheck := range []struct { @@ -160,25 +176,34 @@ func TestLateralParsing(t *testing.T) { {"round-trip", stmt2}, } { selectStmt, ok := stmtToCheck.node.(*ast.SelectStmt) - require.True(t, ok, "[%s] Statement should be SelectStmt", stmtToCheck.label) - require.NotNil(t, selectStmt.From, "[%s] FROM clause should not be nil", stmtToCheck.label) + if !(ok) { + t.Fatalf("[%s] Statement should be SelectStmt", stmtToCheck.label) + } + if selectStmt.From == nil { + t.Fatalf("[%s] FROM clause should not be nil", stmtToCheck.label) + } if tc.checkLateral { lateralTS := findLateralTableSource(selectStmt.From.TableRefs) - require.NotNil(t, lateralTS, "[%s] LATERAL TableSource not found for: %s", stmtToCheck.label, tc.sql) + if lateralTS == nil { + t.Fatalf("[%s] LATERAL TableSource not found for: %s", stmtToCheck.label, tc.sql) + } if len(tc.columnNames) > 0 { - require.Len(t, lateralTS.ColumnNames, len(tc.columnNames), - "[%s] column name count mismatch", stmtToCheck.label) + if got := len(lateralTS.ColumnNames); got != len(tc.columnNames) { + t.Fatalf("%s: expected length %d, got %d", fmt.Sprintf("[%s] column name count mismatch", stmtToCheck.label), len(tc.columnNames), got) + } for i, expected := range tc.columnNames { - require.Equal(t, expected, lateralTS.ColumnNames[i].L, - "[%s] column name mismatch at index %d", stmtToCheck.label, i) + if !reflect.DeepEqual(expected, lateralTS.ColumnNames[i].L) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("[%s] column name mismatch at index %d", stmtToCheck.label, i), lateralTS.ColumnNames[i].L, expected) + } } } } else { lateralTS := findLateralTableSource(selectStmt.From.TableRefs) - require.Nil(t, lateralTS, "[%s] Lateral should be false for non-LATERAL query: %s", - stmtToCheck.label, tc.sql) + if lateralTS != nil { + t.Fatalf("[%s] Lateral should be false for non-LATERAL query: %s", stmtToCheck.label, tc.sql) + } } } }) diff --git a/parser/lexer_test.go b/parser/lexer_test.go index a14fa76..7086b6f 100644 --- a/parser/lexer_test.go +++ b/parser/lexer_test.go @@ -19,7 +19,8 @@ import ( "unicode" "github.com/sqlc-dev/marino/mysql" - requires "github.com/stretchr/testify/require" + + "reflect" ) func TestTokenID(t *testing.T) { @@ -27,7 +28,9 @@ func TestTokenID(t *testing.T) { l := NewScanner(str) var v yySymType tok1 := l.Lex(&v) - requires.Equal(t, tok1, tok) + if !reflect.DeepEqual(tok1, tok) { + t.Fatalf("got %v, want %v", tok, tok1) + } } } @@ -37,7 +40,9 @@ func TestSingleChar(t *testing.T) { l := NewScanner(string(tok)) var v yySymType tok1 := l.Lex(&v) - requires.Equal(t, tok1, int(tok)) + if !reflect.DeepEqual(tok1, int(tok)) { + t.Fatalf("got %v, want %v", int(tok), tok1) + } } } @@ -91,15 +96,23 @@ func TestUnderscoreCS(t *testing.T) { var v yySymType scanner := NewScanner(`_utf8"string"`) tok := scanner.Lex(&v) - requires.Equal(t, underscoreCS, tok) + if !reflect.DeepEqual(underscoreCS, tok) { + t.Fatalf("got %v, want %v", tok, underscoreCS) + } tok = scanner.Lex(&v) - requires.Equal(t, stringLit, tok) + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } scanner.reset("N'string'") tok = scanner.Lex(&v) - requires.Equal(t, underscoreCS, tok) + if !reflect.DeepEqual(underscoreCS, tok) { + t.Fatalf("got %v, want %v", tok, underscoreCS) + } tok = scanner.Lex(&v) - requires.Equal(t, stringLit, tok) + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } } func TestLiteral(t *testing.T) { @@ -203,7 +216,9 @@ func runTest(t *testing.T, table []testCaseItem) { for _, v := range table { l := NewScanner(v.str) tok := l.Lex(&val) - requires.Equal(t, v.tok, tok, v.str) + if !reflect.DeepEqual(v.tok, tok) { + t.Fatalf("%v: got %v, want %v", v.str, tok, v.tok) + } } } @@ -213,13 +228,21 @@ func runLiteralTest(t *testing.T, table []testLiteralValue) { val := l.LexLiteral() switch val.(type) { case int64: - requires.Equal(t, v.val, val, v.str) + if !reflect.DeepEqual(v.val, val) { + t.Fatalf("%v: got %v, want %v", v.str, val, v.val) + } case float64: - requires.Equal(t, v.val, val, v.str) + if !reflect.DeepEqual(v.val, val) { + t.Fatalf("%v: got %v, want %v", v.str, val, v.val) + } case string: - requires.Equal(t, v.val, val, v.str) + if !reflect.DeepEqual(v.val, val) { + t.Fatalf("%v: got %v, want %v", v.str, val, v.val) + } default: - requires.Equal(t, v.val, fmt.Sprint(val), v.str) + if !reflect.DeepEqual(v.val, fmt.Sprint(val)) { + t.Fatalf("%v: got %v, want %v", v.str, fmt.Sprint(val), v.val) + } } } } @@ -253,9 +276,15 @@ func TestScanQuotedIdent(t *testing.T) { l := NewScanner("`fk`") l.r.peek() tok, pos, lit := scanQuotedIdent(l) - requires.Zero(t, pos.Offset) - requires.Equal(t, quotedIdentifier, tok) - requires.Equal(t, "fk", lit) + if pos.Offset != 0 { + t.Fatalf("expected zero, got %v", pos.Offset) + } + if !reflect.DeepEqual(quotedIdentifier, tok) { + t.Fatalf("got %v, want %v", tok, quotedIdentifier) + } + if !reflect.DeepEqual("fk", lit) { + t.Fatalf("got %v, want %v", lit, "fk") + } } func TestScanString(t *testing.T) { @@ -287,9 +316,15 @@ func TestScanString(t *testing.T) { for _, v := range table { l := NewScanner(v.raw) tok, pos, lit := l.scan() - requires.Zero(t, pos.Offset) - requires.Equal(t, stringLit, tok) - requires.Equal(t, v.expect, lit) + if pos.Offset != 0 { + t.Fatalf("expected zero, got %v", pos.Offset) + } + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } + if !reflect.DeepEqual(v.expect, lit) { + t.Fatalf("got %v, want %v", lit, v.expect) + } } } @@ -320,9 +355,15 @@ func TestScanStringWithNoBackslashEscapesMode(t *testing.T) { for _, v := range table { l.reset(v.raw) tok, pos, lit := l.scan() - requires.Zero(t, pos.Offset) - requires.Equal(t, stringLit, tok) - requires.Equal(t, v.expect, lit) + if pos.Offset != 0 { + t.Fatalf("expected zero, got %v", pos.Offset) + } + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } + if !reflect.DeepEqual(v.expect, lit) { + t.Fatalf("got %v, want %v", lit, v.expect) + } } } @@ -351,41 +392,73 @@ func TestIdentifier(t *testing.T) { l.reset(item[0]) var v yySymType tok := l.Lex(&v) - requires.Equal(t, identifier, tok, item) - requires.Equal(t, item[1], v.ident, item) + if !reflect.DeepEqual(identifier, tok) { + t.Fatalf("%v: got %v, want %v", item, tok, identifier) + } + if !reflect.DeepEqual(item[1], v.ident) { + t.Fatalf("%v: got %v, want %v", item, v.ident, item[1]) + } } } func TestSpecialComment(t *testing.T) { l := NewScanner("/*!40101 select\n5*/") tok, pos, lit := l.scan() - requires.Equal(t, identifier, tok) - requires.Equal(t, "select", lit) - requires.Equal(t, Pos{1, 9, 9}, pos) + if !reflect.DeepEqual(identifier, tok) { + t.Fatalf("got %v, want %v", tok, identifier) + } + if !reflect.DeepEqual("select", lit) { + t.Fatalf("got %v, want %v", lit, "select") + } + if !reflect.DeepEqual(Pos{1, 9, 9}, pos) { + t.Fatalf("got %v, want %v", pos, Pos{1, 9, 9}) + } tok, pos, lit = l.scan() - requires.Equal(t, intLit, tok) - requires.Equal(t, "5", lit) - requires.Equal(t, Pos{2, 1, 16}, pos) + if !reflect.DeepEqual(intLit, tok) { + t.Fatalf("got %v, want %v", tok, intLit) + } + if !reflect.DeepEqual("5", lit) { + t.Fatalf("got %v, want %v", lit, "5") + } + if !reflect.DeepEqual(Pos{2, 1, 16}, pos) { + t.Fatalf("got %v, want %v", pos, Pos{2, 1, 16}) + } } func TestFeatureIDsComment(t *testing.T) { l := NewScanner("/*T![auto_rand] auto_random(5) */") tok, pos, lit := l.scan() - requires.Equal(t, identifier, tok) - requires.Equal(t, "auto_random", lit) - requires.Equal(t, Pos{1, 16, 16}, pos) + if !reflect.DeepEqual(identifier, tok) { + t.Fatalf("got %v, want %v", tok, identifier) + } + if !reflect.DeepEqual("auto_random", lit) { + t.Fatalf("got %v, want %v", lit, "auto_random") + } + if !reflect.DeepEqual(Pos{1, 16, 16}, pos) { + t.Fatalf("got %v, want %v", pos, Pos{1, 16, 16}) + } tok, _, _ = l.scan() - requires.Equal(t, int('('), tok) + if !reflect.DeepEqual(int('('), tok) { + t.Fatalf("got %v, want %v", tok, int('(')) + } _, pos, lit = l.scan() - requires.Equal(t, "5", lit) - requires.Equal(t, Pos{1, 28, 28}, pos) + if !reflect.DeepEqual("5", lit) { + t.Fatalf("got %v, want %v", lit, "5") + } + if !reflect.DeepEqual(Pos{1, 28, 28}, pos) { + t.Fatalf("got %v, want %v", pos, Pos{1, 28, 28}) + } tok, _, _ = l.scan() - requires.Equal(t, int(')'), tok) + if !reflect.DeepEqual(int(')'), tok) { + t.Fatalf("got %v, want %v", tok, int(')')) + } l = NewScanner("/*T![unsupported_feature] unsupported(123) */") tok, _, _ = l.scan() - requires.Equal(t, 0, tok) + if !reflect.DeepEqual(0, tok) { + t.Fatalf("got %v, want %v", tok, 0) + } } func TestOptimizerHint(t *testing.T) { @@ -406,9 +479,15 @@ func TestOptimizerHint(t *testing.T) { if tok == 0 { return } - requires.Equal(t, tokens[i].tok, tok, i) - requires.Equal(t, tokens[i].ident, sym.ident, i) - requires.Equal(t, tokens[i].pos, sym.offset, i) + if !reflect.DeepEqual(tokens[i].tok, tok) { + t.Fatalf("%v: got %v, want %v", i, tok, tokens[i].tok) + } + if !reflect.DeepEqual(tokens[i].ident, sym.ident) { + t.Fatalf("%v: got %v, want %v", i, sym.ident, tokens[i].ident) + } + if !reflect.DeepEqual(tokens[i].pos, sym.offset) { + t.Fatalf("%v: got %v, want %v", i, sym.offset, tokens[i].pos) + } } } @@ -484,7 +563,9 @@ func TestOptimizerHintAfterCertainKeywordOnly(t *testing.T) { var sym yySymType for i := 0; ; i++ { tok := scanner.Lex(&sym) - requires.Equalf(t, tc.tokens[i], tok, "input = [%s], i = %d", tc.input, i) + if !reflect.DeepEqual(tc.tokens[i], tok) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("input = [%s], i = %d", tc.input, i), tok, tc.tokens[i]) + } if tok == 0 { break } @@ -509,12 +590,18 @@ func TestInt(t *testing.T) { var v yySymType scanner.reset(test.input) tok := scanner.Lex(&v) - requires.Equal(t, intLit, tok) + if !reflect.DeepEqual(intLit, tok) { + t.Fatalf("got %v, want %v", tok, intLit) + } switch i := v.item.(type) { case int64: - requires.Equal(t, test.expect, uint64(i)) + if !reflect.DeepEqual(test.expect, uint64(i)) { + t.Fatalf("got %v, want %v", uint64(i), test.expect) + } case uint64: - requires.Equal(t, test.expect, i) + if !reflect.DeepEqual(test.expect, i) { + t.Fatalf("got %v, want %v", i, test.expect) + } default: t.Fail() } @@ -540,17 +627,29 @@ func TestSQLModeANSIQuotes(t *testing.T) { var v yySymType scanner.reset(test.input) tok := scanner.Lex(&v) - requires.Equal(t, test.tok, tok) - requires.Equal(t, test.ident, v.ident) + if !reflect.DeepEqual(test.tok, tok) { + t.Fatalf("got %v, want %v", tok, test.tok) + } + if !reflect.DeepEqual(test.ident, v.ident) { + t.Fatalf("got %v, want %v", v.ident, test.ident) + } } scanner.reset(`'string' 'string'`) var v yySymType tok := scanner.Lex(&v) - requires.Equal(t, stringLit, tok) - requires.Equal(t, "string", v.ident) + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } + if !reflect.DeepEqual("string", v.ident) { + t.Fatalf("got %v, want %v", v.ident, "string") + } tok = scanner.Lex(&v) - requires.Equal(t, stringLit, tok) - requires.Equal(t, "string", v.ident) + if !reflect.DeepEqual(stringLit, tok) { + t.Fatalf("got %v, want %v", tok, stringLit) + } + if !reflect.DeepEqual("string", v.ident) { + t.Fatalf("got %v, want %v", v.ident, "string") + } } func TestIllegal(t *testing.T) { @@ -645,7 +744,9 @@ func TestVersionDigits(t *testing.T) { scanner.reset(test.input) scanner.scanVersionDigits(test.min, test.max) nextChar := scanner.r.readByte() - requires.Equalf(t, test.nextChar, nextChar, "input = %s", test.input) + if !reflect.DeepEqual(test.nextChar, nextChar) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("input = %s", test.input), nextChar, test.nextChar) + } } } @@ -715,8 +816,12 @@ func TestFeatureIDs(t *testing.T) { for _, test := range tests { scanner.reset(test.input) featureIDs := scanner.scanFeatureIDs() - requires.Equalf(t, test.featureIDs, featureIDs, "input = %s", test.input) + if !reflect.DeepEqual(test.featureIDs, featureIDs) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("input = %s", test.input), featureIDs, test.featureIDs) + } nextChar := scanner.r.readByte() - requires.Equalf(t, test.nextChar, nextChar, "input = %s", test.input) + if !reflect.DeepEqual(test.nextChar, nextChar) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("input = %s", test.input), nextChar, test.nextChar) + } } } diff --git a/parser/parser_test.go b/parser/parser_test.go index 6e7d10f..b94d1d4 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -16,21 +16,23 @@ package parser_test import ( "bytes" "fmt" + "regexp" "runtime" "slices" "strings" "testing" "github.com/pingcap/errors" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/charset" . "github.com/sqlc-dev/marino/format" "github.com/sqlc-dev/marino/mysql" "github.com/sqlc-dev/marino/opcode" + "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/terror" "github.com/sqlc-dev/marino/test_driver" - "github.com/stretchr/testify/require" + + "reflect" ) func TestSimple(t *testing.T) { @@ -68,15 +70,21 @@ func TestSimple(t *testing.T) { src := fmt.Sprintf("SELECT * FROM db.%s;", kw) _, err := p.ParseOneStmt(src, "", "") - require.NoErrorf(t, err, "source %s", src) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %s", src), err) + } src = fmt.Sprintf("SELECT * FROM %s.desc", kw) _, err = p.ParseOneStmt(src, "", "") - require.NoErrorf(t, err, "source %s", src) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %s", src), err) + } src = fmt.Sprintf("SELECT t.%s FROM t", kw) _, err = p.ParseOneStmt(src, "", "") - require.NoErrorf(t, err, "source %s", src) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %s", src), err) + } } // Testcase for unreserved keywords @@ -104,51 +112,85 @@ func TestSimple(t *testing.T) { for _, kw := range unreservedKws { src := fmt.Sprintf("SELECT %s FROM tbl;", kw) _, err := p.ParseOneStmt(src, "", "") - require.NoErrorf(t, err, "source %s", src) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %s", src), err) + } } // Testcase for prepared statement src := "SELECT id+?, id+? from t;" _, err := p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // Testcase for -- Comment and unary -- operator src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;" stmts, _, err := p.Parse(src, "", "") - require.NoError(t, err) - require.Len(t, stmts, 2) + if err != nil { + t.Fatal(err) + } + if got := len(stmts); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } // Testcase for /*! xx */ // See http://dev.mysql.com/doc/refman/5.7/en/comments.html // Fix: https://github.com/pingcap/tidb/issues/971 src = "/*!40101 SET character_set_client = utf8 */;" stmts, _, err = p.Parse(src, "", "") - require.NoError(t, err) - require.Len(t, stmts, 1) + if err != nil { + t.Fatal(err) + } + if got := len(stmts); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } stmt := stmts[0] _, ok := stmt.(*ast.SetStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } // for issue #2017 src = "insert into blobtable (a) values ('/*! truncated */');" stmt, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } is, ok := stmt.(*ast.InsertStmt) - require.True(t, ok) - require.Len(t, is.Lists, 1) - require.Len(t, is.Lists[0], 1) - require.Equal(t, "/*! truncated */", is.Lists[0][0].(ast.ValueExpr).GetDatumString()) + if !(ok) { + t.Fatal("expected true") + } + if got := len(is.Lists); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if got := len(is.Lists[0]); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("/*! truncated */", is.Lists[0][0].(ast.ValueExpr).GetDatumString()) { + t.Fatalf("got %v, want %v", is.Lists[0][0].(ast.ValueExpr).GetDatumString(), "/*! truncated */") + } // Testcase for CONVERT(expr,type) src = "SELECT CONVERT('111', SIGNED);" st, err := p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ss, ok := st.(*ast.SelectStmt) - require.True(t, ok) - require.Len(t, ss.Fields.Fields, 1) + if !(ok) { + t.Fatal("expected true") + } + if got := len(ss.Fields.Fields); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) - require.True(t, ok) - require.Equal(t, ast.CastConvertFunction, cv.FunctionType) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(ast.CastConvertFunction, cv.FunctionType) { + t.Fatalf("got %v, want %v", cv.FunctionType, ast.CastConvertFunction) + } // for query start with comment srcs := []string{ @@ -160,54 +202,80 @@ func TestSimple(t *testing.T) { } for _, src := range srcs { st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, ok = st.(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } } // for issue #961 src = "create table t (c int key);" st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } cs, ok := st.(*ast.CreateTableStmt) - require.True(t, ok) - require.Len(t, cs.Cols, 1) - require.Len(t, cs.Cols[0].Options, 1) - require.Equal(t, ast.ColumnOptionPrimaryKey, cs.Cols[0].Options[0].Tp) + if !(ok) { + t.Fatal("expected true") + } + if got := len(cs.Cols); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if got := len(cs.Cols[0].Options); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual(ast.ColumnOptionPrimaryKey, cs.Cols[0].Options[0].Tp) { + t.Fatalf("got %v, want %v", cs.Cols[0].Options[0].Tp, ast.ColumnOptionPrimaryKey) + } // for issue #4497 src = "create table t1(a NVARCHAR(100));" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // for issue 2803 src = "use quote;" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // issue #4354 src = "select b'';" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } src = "select B'';" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // src = "select 0b'';" // _, err = p.ParseOneStmt(src, "", "") - // require.Error(t, err) + // if err == nil { t.Fatal("expected error") } // for #4909, support numericType `signed` filedOpt. src = "CREATE TABLE t(_sms smallint signed, _smu smallint unsigned);" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // for #7371, support NATIONAL CHARACTER // reference link: https://dev.mysql.com/doc/refman/5.7/en/charset-national.html src = "CREATE TABLE t(c1 NATIONAL CHARACTER(10));" _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } src = `CREATE TABLE t(a tinyint signed, b smallint signed, @@ -225,95 +293,177 @@ func TestSimple(t *testing.T) { );` st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ct, ok := st.(*ast.CreateTableStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } for _, col := range ct.Cols { - require.Equal(t, uint(0), col.Tp.GetFlag()&mysql.UnsignedFlag) + if !reflect.DeepEqual(uint(0), col.Tp.GetFlag()&mysql.UnsignedFlag) { + t.Fatalf("got %v, want %v", col.Tp.GetFlag()&mysql.UnsignedFlag, uint(0)) + } } // for issue #4006 src = `insert into tb(v) (select v from tb);` _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // for issue #34642 src = `SELECT a as c having c = a;` _, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // for issue #9823 src = "SELECT 9223372036854775807;" st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok := st.(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } expr := sel.Fields.Fields[0] vExpr := expr.Expr.(*test_driver.ValueExpr) - require.Equal(t, test_driver.KindInt64, vExpr.Kind()) + if !reflect.DeepEqual(test_driver.KindInt64, vExpr.Kind()) { + t.Fatalf("got %v, want %v", vExpr.Kind(), test_driver.KindInt64) + } src = "SELECT 9223372036854775808;" st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok = st.(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } expr = sel.Fields.Fields[0] vExpr = expr.Expr.(*test_driver.ValueExpr) - require.Equal(t, test_driver.KindUint64, vExpr.Kind()) + if !reflect.DeepEqual(test_driver.KindUint64, vExpr.Kind()) { + t.Fatalf("got %v, want %v", vExpr.Kind(), test_driver.KindUint64) + } src = `select 99e+r10 from t1;` st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok = st.(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } bExpr, ok := sel.Fields.Fields[0].Expr.(*ast.BinaryOperationExpr) - require.True(t, ok) - require.Equal(t, opcode.Plus, bExpr.Op) - require.Equal(t, "99e", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) - require.Equal(t, "r10", bExpr.R.(*ast.ColumnNameExpr).Name.Name.O) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(opcode.Plus, bExpr.Op) { + t.Fatalf("got %v, want %v", bExpr.Op, opcode.Plus) + } + if !reflect.DeepEqual("99e", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) { + t.Fatalf("got %v, want %v", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O, "99e") + } + if !reflect.DeepEqual("r10", bExpr.R.(*ast.ColumnNameExpr).Name.Name.O) { + t.Fatalf("got %v, want %v", bExpr.R.(*ast.ColumnNameExpr).Name.Name.O, "r10") + } src = `select t./*123*/*,@c3:=0 from t order by t.c1;` st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok = st.(*ast.SelectStmt) - require.True(t, ok) - require.Equal(t, "t", sel.Fields.Fields[0].WildCard.Table.O) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t", sel.Fields.Fields[0].WildCard.Table.O) { + t.Fatalf("got %v, want %v", sel.Fields.Fields[0].WildCard.Table.O, "t") + } varExpr, ok := sel.Fields.Fields[1].Expr.(*ast.VariableExpr) - require.True(t, ok) - require.Equal(t, "c3", varExpr.Name) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("c3", varExpr.Name) { + t.Fatalf("got %v, want %v", varExpr.Name, "c3") + } src = `select t.1e from test.t;` st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok = st.(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } colExpr, ok := sel.Fields.Fields[0].Expr.(*ast.ColumnNameExpr) - require.True(t, ok) - require.Equal(t, "t", colExpr.Name.Table.O) - require.Equal(t, "1e", colExpr.Name.Name.O) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t", colExpr.Name.Table.O) { + t.Fatalf("got %v, want %v", colExpr.Name.Table.O, "t") + } + if !reflect.DeepEqual("1e", colExpr.Name.Name.O) { + t.Fatalf("got %v, want %v", colExpr.Name.Name.O, "1e") + } tName := sel.From.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) - require.Equal(t, "test", tName.Schema.O) - require.Equal(t, "t", tName.Name.O) + if !reflect.DeepEqual("test", tName.Schema.O) { + t.Fatalf("got %v, want %v", tName.Schema.O, "test") + } + if !reflect.DeepEqual("t", tName.Name.O) { + t.Fatalf("got %v, want %v", tName.Name.O, "t") + } src = "select t. `a` > 10 from t;" st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } bExpr, ok = st.(*ast.SelectStmt).Fields.Fields[0].Expr.(*ast.BinaryOperationExpr) - require.True(t, ok) - require.Equal(t, opcode.GT, bExpr.Op) - require.Equal(t, "a", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) - require.Equal(t, "t", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O) - require.Equal(t, int64(10), bExpr.R.(ast.ValueExpr).GetValue().(int64)) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(opcode.GT, bExpr.Op) { + t.Fatalf("got %v, want %v", bExpr.Op, opcode.GT) + } + if !reflect.DeepEqual("a", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) { + t.Fatalf("got %v, want %v", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O, "a") + } + if !reflect.DeepEqual("t", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O) { + t.Fatalf("got %v, want %v", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O, "t") + } + if !reflect.DeepEqual(int64(10), bExpr.R.(ast.ValueExpr).GetValue().(int64)) { + t.Fatalf("got %v, want %v", bExpr.R.(ast.ValueExpr).GetValue().(int64), int64(10)) + } p.SetSQLMode(mysql.ModeANSIQuotes) src = `select t."dot"=10 from t;` st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } bExpr, ok = st.(*ast.SelectStmt).Fields.Fields[0].Expr.(*ast.BinaryOperationExpr) - require.True(t, ok) - require.Equal(t, opcode.EQ, bExpr.Op) - require.Equal(t, "dot", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) - require.Equal(t, "t", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O) - require.Equal(t, int64(10), bExpr.R.(ast.ValueExpr).GetValue().(int64)) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(opcode.EQ, bExpr.Op) { + t.Fatalf("got %v, want %v", bExpr.Op, opcode.EQ) + } + if !reflect.DeepEqual("dot", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O) { + t.Fatalf("got %v, want %v", bExpr.L.(*ast.ColumnNameExpr).Name.Name.O, "dot") + } + if !reflect.DeepEqual("t", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O) { + t.Fatalf("got %v, want %v", bExpr.L.(*ast.ColumnNameExpr).Name.Table.O, "t") + } + if !reflect.DeepEqual(int64(10), bExpr.R.(ast.ValueExpr).GetValue().(int64)) { + t.Fatalf("got %v, want %v", bExpr.R.(ast.ValueExpr).GetValue().(int64), int64(10)) + } } func TestSpecialComments(t *testing.T) { @@ -321,31 +471,55 @@ func TestSpecialComments(t *testing.T) { // 1. Make sure /*! ... */ respects the same SQL mode. _, err := p.ParseOneStmt(`SELECT /*! '\' */;`, "", "") - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } p.SetSQLMode(mysql.ModeNoBackslashEscapes) st, err := p.ParseOneStmt(`SELECT /*! '\' */;`, "", "") - require.NoError(t, err) - require.IsType(t, &ast.SelectStmt{}, st) + if err != nil { + t.Fatal(err) + } + if _, ok := st.(*ast.SelectStmt); !ok { + t.Fatalf("expected type %T, got %T", &ast.SelectStmt{}, st) + } // 2. Make sure multiple statements inside /*! ... */ will not crash // (this is issue #330) stmts, _, err := p.Parse("/*! SET x = 1; SELECT 2 */", "", "") - require.NoError(t, err) - require.Len(t, stmts, 2) - require.IsType(t, &ast.SetStmt{}, stmts[0]) - require.Equal(t, "/*! SET x = 1;", stmts[0].Text()) - require.IsType(t, &ast.SelectStmt{}, stmts[1]) - require.Equal(t, " SELECT 2 */", stmts[1].Text()) + if err != nil { + t.Fatal(err) + } + if got := len(stmts); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if _, ok := stmts[0].(*ast.SetStmt); !ok { + t.Fatalf("expected type %T, got %T", &ast.SetStmt{}, stmts[0]) + } + if !reflect.DeepEqual("/*! SET x = 1;", stmts[0].Text()) { + t.Fatalf("got %v, want %v", stmts[0].Text(), "/*! SET x = 1;") + } + if _, ok := stmts[1].(*ast.SelectStmt); !ok { + t.Fatalf("expected type %T, got %T", &ast.SelectStmt{}, stmts[1]) + } + if !reflect.DeepEqual(" SELECT 2 */", stmts[1].Text()) { + t.Fatalf("got %v, want %v", stmts[1].Text(), " SELECT 2 */") + } // ^ not sure if correct approach; having multiple statements in MySQL is a syntax error. // 3. Make sure invalid text won't cause infinite loop // (this is issue #336) st, err = p.ParseOneStmt("SELECT /*+ 😅 */ SLEEP(1);", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel, ok := st.(*ast.SelectStmt) - require.True(t, ok) - require.Len(t, sel.TableHints, 0) + if !(ok) { + t.Fatal("expected true") + } + if got := len(sel.TableHints); got != 0 { + t.Fatalf("expected length %d, got %d", 0, got) + } } type testCase struct { @@ -366,10 +540,14 @@ func RunTest(t *testing.T, table []testCase, enableWindowFunc bool, MariaDB bool for _, tbl := range table { _, _, err := p.Parse(tbl.src, "", "") if !tbl.ok { - require.Errorf(t, err, "source %v, error %v", tbl.src, errors.Trace(err)) + if err == nil { + t.Fatalf("source %v, error %v", tbl.src, errors.Trace(err)) + } continue } - require.NoErrorf(t, err, "source:\n%v\nerror:\n%v", tbl.src, errors.Trace(err)) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source:\n%v\nerror:\n%v", tbl.src, errors.Trace(err)), err) + } // restore correctness test if tbl.ok { RunRestoreTest(t, tbl.src, tbl.restore, enableWindowFunc, MariaDB) @@ -384,25 +562,35 @@ func RunRestoreTest(t *testing.T, sourceSQLs, expectSQLs string, enableWindowFun p.SetMariaDB(MariaDB) comment := fmt.Sprintf("source %v", sourceSQLs) stmts, _, err := p.Parse(sourceSQLs, "", "") - require.NoErrorf(t, err, "source %v", sourceSQLs) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %v", sourceSQLs), err) + } restoreSQLs := "" for _, stmt := range stmts { sb.Reset() err = stmt.Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQL := sb.String() comment = fmt.Sprintf("source %v; restore %v", sourceSQLs, restoreSQL) restoreStmt, err := p.ParseOneStmt(restoreSQL, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } CleanNodeText(stmt) CleanNodeText(restoreStmt) - require.Equal(t, stmt, restoreStmt, comment) + if !reflect.DeepEqual(stmt, restoreStmt) { + t.Fatalf("%v: got %v, want %v", comment, restoreStmt, stmt) + } if restoreSQLs != "" { restoreSQLs += "; " } restoreSQLs += restoreSQL } - require.Equalf(t, expectSQLs, restoreSQLs, "restore %v; expect %v", restoreSQLs, expectSQLs) + if !reflect.DeepEqual(expectSQLs, restoreSQLs) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("restore %v; expect %v", restoreSQLs, expectSQLs), restoreSQLs, expectSQLs) + } } func RunTestInRealAsFloatMode(t *testing.T, table []testCase, enableWindowFunc bool) { @@ -413,10 +601,14 @@ func RunTestInRealAsFloatMode(t *testing.T, table []testCase, enableWindowFunc b _, _, err := p.Parse(tbl.src, "", "") comment := fmt.Sprintf("source %v", tbl.src) if !tbl.ok { - require.Error(t, err, comment) + if err == nil { + t.Fatal(comment) + } continue } - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } // restore correctness test if tbl.ok { RunRestoreTestInRealAsFloatMode(t, tbl.src, tbl.restore, enableWindowFunc) @@ -431,25 +623,35 @@ func RunRestoreTestInRealAsFloatMode(t *testing.T, sourceSQLs, expectSQLs string p.SetSQLMode(mysql.ModeRealAsFloat) comment := fmt.Sprintf("source %v", sourceSQLs) stmts, _, err := p.Parse(sourceSQLs, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQLs := "" for _, stmt := range stmts { sb.Reset() err = stmt.Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQL := sb.String() comment = fmt.Sprintf("source %v; restore %v", sourceSQLs, restoreSQL) restoreStmt, err := p.ParseOneStmt(restoreSQL, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } CleanNodeText(stmt) CleanNodeText(restoreStmt) - require.Equal(t, stmt, restoreStmt, comment) + if !reflect.DeepEqual(stmt, restoreStmt) { + t.Fatalf("%v: got %v, want %v", comment, restoreStmt, stmt) + } if restoreSQLs != "" { restoreSQLs += "; " } restoreSQLs += restoreSQL } - require.Equal(t, expectSQLs, restoreSQLs, "restore %v; expect %v", restoreSQLs, expectSQLs) + if !reflect.DeepEqual(expectSQLs, restoreSQLs) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("restore %v; expect %v", restoreSQLs, expectSQLs), restoreSQLs, expectSQLs) + } } func RunErrMsgTest(t *testing.T, table []testErrMsgCase) { @@ -458,9 +660,13 @@ func RunErrMsgTest(t *testing.T, table []testErrMsgCase) { _, _, err := p.Parse(tbl.src, "", "") comment := fmt.Sprintf("source %v", tbl.src) if tbl.err != nil { - require.True(t, terror.ErrorEqual(err, tbl.err), comment) + if !(terror.ErrorEqual(err, tbl.err)) { + t.Fatal(comment) + } } else { - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } } } } @@ -1523,41 +1729,73 @@ func TestSetVariable(t *testing.T) { p := parser.New() for _, tbl := range table { stmt, err := p.ParseOneStmt(tbl.Input, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } setStmt, ok := stmt.(*ast.SetStmt) - require.True(t, ok) - require.Len(t, setStmt.Variables, 1) + if !(ok) { + t.Fatal("expected true") + } + if got := len(setStmt.Variables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } v := setStmt.Variables[0] - require.Equal(t, tbl.Name, v.Name) - require.Equal(t, tbl.IsGlobal, v.IsGlobal) - require.Equal(t, tbl.IsInstance, v.IsInstance) - require.Equal(t, tbl.IsSystem, v.IsSystem) + if !reflect.DeepEqual(tbl.Name, v.Name) { + t.Fatalf("got %v, want %v", v.Name, tbl.Name) + } + if !reflect.DeepEqual(tbl.IsGlobal, v.IsGlobal) { + t.Fatalf("got %v, want %v", v.IsGlobal, tbl.IsGlobal) + } + if !reflect.DeepEqual(tbl.IsInstance, v.IsInstance) { + t.Fatalf("got %v, want %v", v.IsInstance, tbl.IsInstance) + } + if !reflect.DeepEqual(tbl.IsSystem, v.IsSystem) { + t.Fatalf("got %v, want %v", v.IsSystem, tbl.IsSystem) + } } _, err := p.ParseOneStmt("set xx.xx.xx = 666", "", "") - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } func TestFlushTable(t *testing.T) { p := parser.New() stmt, _, err := p.Parse("flush local tables tbl1,tbl2 with read lock", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } flushTable := stmt[0].(*ast.FlushStmt) - require.Equal(t, ast.FlushTables, flushTable.Tp) - require.Equal(t, "tbl1", flushTable.Tables[0].Name.L) - require.Equal(t, "tbl2", flushTable.Tables[1].Name.L) - require.True(t, flushTable.NoWriteToBinLog) - require.True(t, flushTable.ReadLock) + if !reflect.DeepEqual(ast.FlushTables, flushTable.Tp) { + t.Fatalf("got %v, want %v", flushTable.Tp, ast.FlushTables) + } + if !reflect.DeepEqual("tbl1", flushTable.Tables[0].Name.L) { + t.Fatalf("got %v, want %v", flushTable.Tables[0].Name.L, "tbl1") + } + if !reflect.DeepEqual("tbl2", flushTable.Tables[1].Name.L) { + t.Fatalf("got %v, want %v", flushTable.Tables[1].Name.L, "tbl2") + } + if !(flushTable.NoWriteToBinLog) { + t.Fatal("expected true") + } + if !(flushTable.ReadLock) { + t.Fatal("expected true") + } } func TestFlushPrivileges(t *testing.T) { p := parser.New() stmt, _, err := p.Parse("flush privileges", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } flushPrivilege := stmt[0].(*ast.FlushStmt) - require.Equal(t, ast.FlushPrivileges, flushPrivilege.Tp) + if !reflect.DeepEqual(ast.FlushPrivileges, flushPrivilege.Tp) { + t.Fatalf("got %v, want %v", flushPrivilege.Tp, ast.FlushPrivileges) + } } func TestExpression(t *testing.T) { @@ -2474,10 +2712,14 @@ func TestBuiltinFuncAsIdentifier(t *testing.T) { for _, c := range testcases { _, _, err := p.Parse(c.src, "", "") if !c.ok { - require.Errorf(t, err, "source %v", c.src) + if err == nil { + t.Fatalf("source %v", c.src) + } continue } - require.NoErrorf(t, err, "source %v", c.src) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %v", c.src), err) + } if c.ok && !ignoreSpace { RunRestoreTest(t, c.src, c.restore, false, false) } @@ -4235,457 +4477,947 @@ func TestDDL(t *testing.T) { func TestHintError(t *testing.T) { p := parser.New() stmt, warns, err := p.Parse("select /*+ tidb_unknown(T1,t2) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) - require.Len(t, warns, 1) - require.Equal(t, `[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`, warns[0].Error()) - require.Len(t, stmt[0].(*ast.SelectStmt).TableHints, 0) + if err != nil { + t.Fatal(err) + } + if got := len(warns); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual(`[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`, warns[0].Error()) { + t.Fatalf("got %v, want %v", warns[0].Error(), `[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`) + } + if got := len(stmt[0].(*ast.SelectStmt).TableHints); got != 0 { + t.Fatalf("expected length %d, got %d", 0, got) + } stmt, warns, err = p.Parse("select /*+ TIDB_INLJ(t1, T2) tidb_unknown(T1,t2, 1) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.Len(t, stmt[0].(*ast.SelectStmt).TableHints, 1) - require.NoError(t, err) - require.Len(t, warns, 1) - require.Equal(t, `[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`, warns[0].Error()) + if got := len(stmt[0].(*ast.SelectStmt).TableHints); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if err != nil { + t.Fatal(err) + } + if got := len(warns); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual(`[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`, warns[0].Error()) { + t.Fatalf("got %v, want %v", warns[0].Error(), `[parser:8061]Optimizer hint tidb_unknown is not supported by TiDB and is ignored`) + } _, _, err = p.Parse("select c1, c2 from /*+ tidb_unknow(T1,t2) */ t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) // Hints are ignored after the "FROM" keyword! + if err != nil { + t.Fatal(err) + } // Hints are ignored after the "FROM" keyword! _, _, err = p.Parse("select1 /*+ TIDB_INLJ(t1, T2) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.EqualError(t, err, "line 1 column 7 near \"select1 /*+ TIDB_INLJ(t1, T2) */ c1, c2 from t1, t2 where t1.c1 = t2.c1\" ") + if err == nil || err.Error() != "line 1 column 7 near \"select1 /*+ TIDB_INLJ(t1, T2) */ c1, c2 from t1, t2 where t1.c1 = t2.c1\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 7 near \"select1 /*+ TIDB_INLJ(t1, T2) */ c1, c2 from t1, t2 where t1.c1 = t2.c1\" ", err) + } _, _, err = p.Parse("select /*+ TIDB_INLJ(t1, T2) */ c1, c2 fromt t1, t2 where t1.c1 = t2.c1", "", "") - require.EqualError(t, err, "line 1 column 47 near \"t1, t2 where t1.c1 = t2.c1\" ") + if err == nil || err.Error() != "line 1 column 47 near \"t1, t2 where t1.c1 = t2.c1\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 47 near \"t1, t2 where t1.c1 = t2.c1\" ", err) + } _, _, err = p.Parse("SELECT 1 FROM DUAL WHERE 1 IN (SELECT /*+ DEBUG_HINT3 */ 1)", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt, _, err = p.Parse("insert into t select /*+ memory_quota(1 MB) */ * from t;", "", "") - require.NoError(t, err) - require.Len(t, stmt[0].(*ast.InsertStmt).TableHints, 0) - require.Len(t, stmt[0].(*ast.InsertStmt).Select.(*ast.SelectStmt).TableHints, 1) + if err != nil { + t.Fatal(err) + } + if got := len(stmt[0].(*ast.InsertStmt).TableHints); got != 0 { + t.Fatalf("expected length %d, got %d", 0, got) + } + if got := len(stmt[0].(*ast.InsertStmt).Select.(*ast.SelectStmt).TableHints); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } stmt, _, err = p.Parse("insert /*+ memory_quota(1 MB) */ into t select * from t;", "", "") - require.NoError(t, err) - require.Len(t, stmt[0].(*ast.InsertStmt).TableHints, 1) + if err != nil { + t.Fatal(err) + } + if got := len(stmt[0].(*ast.InsertStmt).TableHints); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } _, warns, err = p.Parse("SELECT id FROM tbl WHERE id = 0 FOR UPDATE /*+ xyz */", "", "") - require.NoError(t, err) - require.Len(t, warns, 1) - require.Regexp(t, `near '/\*\+' at line 1$`, warns[0].Error()) + if err != nil { + t.Fatal(err) + } + if got := len(warns); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !regexp.MustCompile(`near '/\*\+' at line 1$`).MatchString(warns[0].Error()) { + t.Fatalf("expected %q to match %q", warns[0].Error(), `near '/\*\+' at line 1$`) + } _, warns, err = p.Parse("create global binding for select /*+ max_execution_time(1) */ 1 using select /*+ max_execution_time(1) */ 1;\n", "", "") - require.NoError(t, err) - require.Len(t, warns, 0) + if err != nil { + t.Fatal(err) + } + if got := len(warns); got != 0 { + t.Fatalf("expected length %d, got %d", 0, got) + } } func TestErrorMsg(t *testing.T) { p := parser.New() _, _, err := p.Parse("select1 1", "", "") - require.EqualError(t, err, "line 1 column 7 near \"select1 1\" ") + if err == nil || err.Error() != "line 1 column 7 near \"select1 1\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 7 near \"select1 1\" ", err) + } _, _, err = p.Parse("select 1 from1 dual", "", "") - require.EqualError(t, err, "line 1 column 19 near \"dual\" ") + if err == nil || err.Error() != "line 1 column 19 near \"dual\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 19 near \"dual\" ", err) + } _, _, err = p.Parse("select * from t1 join t2 from t1.a = t2.a;", "", "") - require.EqualError(t, err, "line 1 column 29 near \"from t1.a = t2.a;\" ") + if err == nil || err.Error() != "line 1 column 29 near \"from t1.a = t2.a;\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 29 near \"from t1.a = t2.a;\" ", err) + } _, _, err = p.Parse("select * from t1 join t2 one t1.a = t2.a;", "", "") - require.EqualError(t, err, "line 1 column 31 near \"t1.a = t2.a;\" ") + if err == nil || err.Error() != "line 1 column 31 near \"t1.a = t2.a;\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 31 near \"t1.a = t2.a;\" ", err) + } _, _, err = p.Parse("select * from t1 join t2 on t1.a >>> t2.a;", "", "") - require.EqualError(t, err, "line 1 column 36 near \"> t2.a;\" ") + if err == nil || err.Error() != "line 1 column 36 near \"> t2.a;\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 36 near \"> t2.a;\" ", err) + } _, _, err = p.Parse("create table t(f_year year(5))ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin;", "", "") - require.EqualError(t, err, "[parser:1818]Supports only YEAR or YEAR(4) column") + if err == nil || err.Error() != "[parser:1818]Supports only YEAR or YEAR(4) column" { + t.Fatalf("expected error %q, got %v", "[parser:1818]Supports only YEAR or YEAR(4) column", err) + } _, _, err = p.Parse("create table ``.t (id int);", "", "") - require.EqualError(t, err, "[parser:1102]Incorrect database name ''") + if err == nil || err.Error() != "[parser:1102]Incorrect database name ''" { + t.Fatalf("expected error %q, got %v", "[parser:1102]Incorrect database name ''", err) + } _, _, err = p.Parse("create table ` `.t (id int);", "", "") - require.EqualError(t, err, "[parser:1102]Incorrect database name ' '") + if err == nil || err.Error() != "[parser:1102]Incorrect database name ' '" { + t.Fatalf("expected error %q, got %v", "[parser:1102]Incorrect database name ' '", err) + } _, _, err = p.Parse("select ifnull(a,0) & ifnull(a,0) like '55' ESCAPE '\\\\a' from t;", "", "") - require.EqualError(t, err, "[parser:1210]Incorrect arguments to ESCAPE") + if err == nil || err.Error() != "[parser:1210]Incorrect arguments to ESCAPE" { + t.Fatalf("expected error %q, got %v", "[parser:1210]Incorrect arguments to ESCAPE", err) + } _, _, err = p.Parse("load data infile 'aaa' into table aaa FIELDS Enclosed by '\\\\b';", "", "") - require.EqualError(t, err, "[parser:1083]Field separator argument is not what is expected; check the manual") + if err == nil || err.Error() != "[parser:1083]Field separator argument is not what is expected; check the manual" { + t.Fatalf("expected error %q, got %v", "[parser:1083]Field separator argument is not what is expected; check the manual", err) + } _, _, err = p.Parse("load data infile 'aaa' into table aaa FIELDS Escaped by '\\\\b';", "", "") - require.EqualError(t, err, "[parser:1083]Field separator argument is not what is expected; check the manual") + if err == nil || err.Error() != "[parser:1083]Field separator argument is not what is expected; check the manual" { + t.Fatalf("expected error %q, got %v", "[parser:1083]Field separator argument is not what is expected; check the manual", err) + } _, _, err = p.Parse("load data infile 'aaa' into table aaa FIELDS Enclosed by '\\\\b' Escaped by '\\\\b' ;", "", "") - require.EqualError(t, err, "[parser:1083]Field separator argument is not what is expected; check the manual") + if err == nil || err.Error() != "[parser:1083]Field separator argument is not what is expected; check the manual" { + t.Fatalf("expected error %q, got %v", "[parser:1083]Field separator argument is not what is expected; check the manual", err) + } _, _, err = p.Parse("ALTER DATABASE `` CHARACTER SET = ''", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: ''") + if err == nil || err.Error() != "[parser:1115]Unknown character set: ''" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: ''", err) + } _, _, err = p.Parse("ALTER DATABASE t CHARACTER SET = ''", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: ''") + if err == nil || err.Error() != "[parser:1115]Unknown character set: ''" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: ''", err) + } _, _, err = p.Parse("ALTER SCHEMA t CHARACTER SET = 'SOME_INVALID_CHARSET'", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: 'SOME_INVALID_CHARSET'") + if err == nil || err.Error() != "[parser:1115]Unknown character set: 'SOME_INVALID_CHARSET'" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: 'SOME_INVALID_CHARSET'", err) + } _, _, err = p.Parse("ALTER DATABASE t COLLATE = ''", "", "") - require.EqualError(t, err, "[ddl:1273]Unknown collation: ''") + if err == nil || err.Error() != "[ddl:1273]Unknown collation: ''" { + t.Fatalf("expected error %q, got %v", "[ddl:1273]Unknown collation: ''", err) + } _, _, err = p.Parse("ALTER SCHEMA t COLLATE = 'SOME_INVALID_COLLATION'", "", "") - require.EqualError(t, err, "[ddl:1273]Unknown collation: 'SOME_INVALID_COLLATION'") + if err == nil || err.Error() != "[ddl:1273]Unknown collation: 'SOME_INVALID_COLLATION'" { + t.Fatalf("expected error %q, got %v", "[ddl:1273]Unknown collation: 'SOME_INVALID_COLLATION'", err) + } _, _, err = p.Parse("ALTER DATABASE CHARSET = 'utf8mb4' COLLATE = 'utf8_bin'", "", "") - require.EqualError(t, err, "line 1 column 24 near \"= 'utf8mb4' COLLATE = 'utf8_bin'\" ") + if err == nil || err.Error() != "line 1 column 24 near \"= 'utf8mb4' COLLATE = 'utf8_bin'\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 24 near \"= 'utf8mb4' COLLATE = 'utf8_bin'\" ", err) + } _, _, err = p.Parse("ALTER DATABASE t ENCRYPTION = ''", "", "") - require.EqualError(t, err, "[parser:1525]Incorrect argument (should be Y or N) value: ''") + if err == nil || err.Error() != "[parser:1525]Incorrect argument (should be Y or N) value: ''" { + t.Fatalf("expected error %q, got %v", "[parser:1525]Incorrect argument (should be Y or N) value: ''", err) + } _, _, err = p.Parse("ALTER DATABASE", "", "") - require.EqualError(t, err, "line 1 column 14 near \"\" ") + if err == nil || err.Error() != "line 1 column 14 near \"\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 14 near \"\" ", err) + } _, _, err = p.Parse("ALTER SCHEMA `ANY_DB_NAME`", "", "") - require.EqualError(t, err, "line 1 column 26 near \"\" ") + if err == nil || err.Error() != "line 1 column 26 near \"\" " { + t.Fatalf("expected error %q, got %v", "line 1 column 26 near \"\" ", err) + } _, _, err = p.Parse("alter table t partition by range FIELDS(a)", "", "") - require.EqualError(t, err, "[ddl:1492]For RANGE partitions each partition must be defined") + if err == nil || err.Error() != "[ddl:1492]For RANGE partitions each partition must be defined" { + t.Fatalf("expected error %q, got %v", "[ddl:1492]For RANGE partitions each partition must be defined", err) + } _, _, err = p.Parse("alter table t partition by list FIELDS(a)", "", "") - require.EqualError(t, err, "[ddl:1492]For LIST partitions each partition must be defined") + if err == nil || err.Error() != "[ddl:1492]For LIST partitions each partition must be defined" { + t.Fatalf("expected error %q, got %v", "[ddl:1492]For LIST partitions each partition must be defined", err) + } _, _, err = p.Parse("alter table t partition by list FIELDS(a)", "", "") - require.EqualError(t, err, "[ddl:1492]For LIST partitions each partition must be defined") + if err == nil || err.Error() != "[ddl:1492]For LIST partitions each partition must be defined" { + t.Fatalf("expected error %q, got %v", "[ddl:1492]For LIST partitions each partition must be defined", err) + } _, _, err = p.Parse("alter table t partition by list FIELDS(a,b,c)", "", "") - require.EqualError(t, err, "[ddl:1492]For LIST partitions each partition must be defined") + if err == nil || err.Error() != "[ddl:1492]For LIST partitions each partition must be defined" { + t.Fatalf("expected error %q, got %v", "[ddl:1492]For LIST partitions each partition must be defined", err) + } _, _, err = p.Parse("alter table t lock = first", "", "") - require.EqualError(t, err, "[parser:1801]Unknown LOCK type 'first'") + if err == nil || err.Error() != "[parser:1801]Unknown LOCK type 'first'" { + t.Fatalf("expected error %q, got %v", "[parser:1801]Unknown LOCK type 'first'", err) + } _, _, err = p.Parse("alter table t lock = start", "", "") - require.EqualError(t, err, "[parser:1801]Unknown LOCK type 'start'") + if err == nil || err.Error() != "[parser:1801]Unknown LOCK type 'start'" { + t.Fatalf("expected error %q, got %v", "[parser:1801]Unknown LOCK type 'start'", err) + } _, _, err = p.Parse("alter table t lock = commit", "", "") - require.EqualError(t, err, "[parser:1801]Unknown LOCK type 'commit'") + if err == nil || err.Error() != "[parser:1801]Unknown LOCK type 'commit'" { + t.Fatalf("expected error %q, got %v", "[parser:1801]Unknown LOCK type 'commit'", err) + } _, _, err = p.Parse("alter table t lock = binlog", "", "") - require.EqualError(t, err, "[parser:1801]Unknown LOCK type 'binlog'") + if err == nil || err.Error() != "[parser:1801]Unknown LOCK type 'binlog'" { + t.Fatalf("expected error %q, got %v", "[parser:1801]Unknown LOCK type 'binlog'", err) + } _, _, err = p.Parse("alter table t lock = randomStr123", "", "") - require.EqualError(t, err, "[parser:1801]Unknown LOCK type 'randomStr123'") + if err == nil || err.Error() != "[parser:1801]Unknown LOCK type 'randomStr123'" { + t.Fatalf("expected error %q, got %v", "[parser:1801]Unknown LOCK type 'randomStr123'", err) + } _, _, err = p.Parse("create table t (a longtext unicode)", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: 'ucs2'") + if err == nil || err.Error() != "[parser:1115]Unknown character set: 'ucs2'" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: 'ucs2'", err) + } _, _, err = p.Parse("create table t (a long byte, b text unicode)", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: 'ucs2'") + if err == nil || err.Error() != "[parser:1115]Unknown character set: 'ucs2'" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: 'ucs2'", err) + } _, _, err = p.Parse("create table t (a long ascii, b long unicode)", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: 'ucs2'") + if err == nil || err.Error() != "[parser:1115]Unknown character set: 'ucs2'" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: 'ucs2'", err) + } _, _, err = p.Parse("create table t (a text unicode, b mediumtext ascii, c int)", "", "") - require.EqualError(t, err, "[parser:1115]Unknown character set: 'ucs2'") + if err == nil || err.Error() != "[parser:1115]Unknown character set: 'ucs2'" { + t.Fatalf("expected error %q, got %v", "[parser:1115]Unknown character set: 'ucs2'", err) + } _, _, err = p.Parse("select 1 collate some_unknown_collation", "", "") - require.EqualError(t, err, "[ddl:1273]Unknown collation: 'some_unknown_collation'") + if err == nil || err.Error() != "[ddl:1273]Unknown collation: 'some_unknown_collation'" { + t.Fatalf("expected error %q, got %v", "[ddl:1273]Unknown collation: 'some_unknown_collation'", err) + } } func TestOptimizerHints(t *testing.T) { p := parser.New() // Test USE_INDEX stmt, _, err := p.Parse("select /*+ USE_INDEX(T1,T2), use_index(t3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt := stmt[0].(*ast.SelectStmt) hints := selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "use_index", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "use_index", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("use_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "use_index") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } + + if !reflect.DeepEqual("use_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "use_index") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } // Test FORCE_INDEX stmt, _, err = p.Parse("select /*+ FORCE_INDEX(T1,T2), force_index(t3,t4) RESOURCE_GROUP(rg1)*/ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 3) - require.Equal(t, "force_index", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "force_index", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) - - require.Equal(t, "resource_group", hints[2].HintName.L) - require.Equal(t, hints[2].HintData, "rg1") + if got := len(hints); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } + if !reflect.DeepEqual("force_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "force_index") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } + + if !reflect.DeepEqual("force_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "force_index") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } + + if !reflect.DeepEqual("resource_group", hints[2].HintName.L) { + t.Fatalf("got %v, want %v", hints[2].HintName.L, "resource_group") + } + if !reflect.DeepEqual(hints[2].HintData, "rg1") { + t.Fatalf("got %v, want %v", "rg1", hints[2].HintData) + } // Test IGNORE_INDEX stmt, _, err = p.Parse("select /*+ IGNORE_INDEX(T1,T2), ignore_index(t3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "ignore_index", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "ignore_index", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("ignore_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "ignore_index") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } + + if !reflect.DeepEqual("ignore_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "ignore_index") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } // Test ORDER_INDEX stmt, _, err = p.Parse("select /*+ ORDER_INDEX(T1,T2), order_index(t3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "order_index", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "order_index", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("order_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "order_index") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } + + if !reflect.DeepEqual("order_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "order_index") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } // Test NO_ORDER_INDEX stmt, _, err = p.Parse("select /*+ NO_ORDER_INDEX(T1,T2), no_order_index(t3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_order_index", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "no_order_index", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_order_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_order_index") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } + + if !reflect.DeepEqual("no_order_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_order_index") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } // Test INDEX_LOOKUP_PUSHDOWN stmt, _, err = p.Parse("select /*+ INDEX_LOOKUP_PUSHDOWN(T1,T2), index_lookup_pushdown(t3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "index_lookup_pushdown", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "t2", hints[0].Indexes[0].L) - - require.Equal(t, "index_lookup_pushdown", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "t4", hints[1].Indexes[0].L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("index_lookup_pushdown", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "index_lookup_pushdown") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "t2") + } - // Test TIDB_SMJ - stmt, _, err = p.Parse("select /*+ TIDB_SMJ(T1,t2), tidb_smj(T3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if !reflect.DeepEqual("index_lookup_pushdown", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "index_lookup_pushdown") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "t4") + } + + // Test TIDB_SMJ + stmt, _, err = p.Parse("select /*+ TIDB_SMJ(T1,t2), tidb_smj(T3,t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "tidb_smj", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("tidb_smj", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "tidb_smj") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "tidb_smj", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("tidb_smj", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "tidb_smj") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test MERGE_JOIN stmt, _, err = p.Parse("select /*+ MERGE_JOIN(t1, T2), merge_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "merge_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("merge_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "merge_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "merge_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("merge_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "merge_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // TEST BROADCAST_JOIN stmt, _, err = p.Parse("select /*+ BROADCAST_JOIN(t1, T2), broadcast_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "broadcast_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("broadcast_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "broadcast_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "broadcast_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("broadcast_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "broadcast_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test TIDB_INLJ stmt, _, err = p.Parse("select /*+ TIDB_INLJ(t1, T2), tidb_inlj(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "tidb_inlj", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("tidb_inlj", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "tidb_inlj") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "tidb_inlj", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("tidb_inlj", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "tidb_inlj") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test INL_JOIN stmt, _, err = p.Parse("select /*+ INL_JOIN(t1, T2), inl_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "inl_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("inl_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "inl_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "inl_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("inl_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "inl_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test INL_HASH_JOIN stmt, _, err = p.Parse("select /*+ INL_HASH_JOIN(t1, T2), inl_hash_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "inl_hash_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("inl_hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "inl_hash_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "inl_hash_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("inl_hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "inl_hash_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test INL_MERGE_JOIN stmt, _, err = p.Parse("select /*+ INL_MERGE_JOIN(t1, T2), inl_merge_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "inl_merge_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("inl_merge_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "inl_merge_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "inl_merge_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("inl_merge_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "inl_merge_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test TIDB_HJ stmt, _, err = p.Parse("select /*+ TIDB_HJ(t1, T2), tidb_hj(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "tidb_hj", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("tidb_hj", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "tidb_hj") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "tidb_hj", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("tidb_hj", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "tidb_hj") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test HASH_JOIN stmt, _, err = p.Parse("select /*+ HASH_JOIN(t1, T2), hash_join(t3, t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "hash_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "hash_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "hash_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "hash_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } // Test HASH_JOIN_BUILD and HASH_JOIN_PROBE stmt, _, err = p.Parse("select /*+ hash_join_build(t1), hash_join_probe(t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "hash_join_build", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("hash_join_build", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "hash_join_build") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } - require.Equal(t, "hash_join_probe", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t4", hints[1].Tables[0].TableName.L) + if !reflect.DeepEqual("hash_join_probe", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "hash_join_probe") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t4") + } // Test HASH_JOIN with SWAP_JOIN_INPUTS/NO_SWAP_JOIN_INPUTS // t1 for build, t4 for probe stmt, _, err = p.Parse("select /*+ HASH_JOIN(t1, T2), hash_join(t3, t4), SWAP_JOIN_INPUTS(t1), NO_SWAP_JOIN_INPUTS(t4) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 4) - require.Equal(t, "hash_join", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) + if got := len(hints); got != 4 { + t.Fatalf("expected length %d, got %d", 4, got) + } + if !reflect.DeepEqual("hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "hash_join") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } - require.Equal(t, "hash_join", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 2) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) - require.Equal(t, "t4", hints[1].Tables[1].TableName.L) + if !reflect.DeepEqual("hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "hash_join") + } + if got := len(hints[1].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } + if !reflect.DeepEqual("t4", hints[1].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[1].TableName.L, "t4") + } - require.Equal(t, "swap_join_inputs", hints[2].HintName.L) - require.Len(t, hints[2].Tables, 1) - require.Equal(t, "t1", hints[2].Tables[0].TableName.L) + if !reflect.DeepEqual("swap_join_inputs", hints[2].HintName.L) { + t.Fatalf("got %v, want %v", hints[2].HintName.L, "swap_join_inputs") + } + if got := len(hints[2].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[2].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[2].Tables[0].TableName.L, "t1") + } - require.Equal(t, "no_swap_join_inputs", hints[3].HintName.L) - require.Len(t, hints[3].Tables, 1) - require.Equal(t, "t4", hints[3].Tables[0].TableName.L) + if !reflect.DeepEqual("no_swap_join_inputs", hints[3].HintName.L) { + t.Fatalf("got %v, want %v", hints[3].HintName.L, "no_swap_join_inputs") + } + if got := len(hints[3].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t4", hints[3].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[3].Tables[0].TableName.L, "t4") + } // Test MAX_EXECUTION_TIME queries := []string{ @@ -4696,12 +5428,20 @@ func TestOptimizerHints(t *testing.T) { } for i, query := range queries { stmt, _, err = p.Parse(query, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 1) - require.Equal(t, "max_execution_time", hints[0].HintName.L, "case", i) - require.Equal(t, uint64(1000), hints[0].HintData.(uint64)) + if got := len(hints); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("max_execution_time", hints[0].HintName.L) { + t.Fatalf("case %d: got %v, want %v", i, hints[0].HintName.L, "max_execution_time") + } + if !reflect.DeepEqual(uint64(1000), hints[0].HintData.(uint64)) { + t.Fatalf("got %v, want %v", hints[0].HintData.(uint64), uint64(1000)) + } } // Test NTH_PLAN @@ -4713,427 +5453,833 @@ func TestOptimizerHints(t *testing.T) { } for i, query := range queries { stmt, _, err = p.Parse(query, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 1) - require.Equal(t, "nth_plan", hints[0].HintName.L, "case", i) - require.Equal(t, int64(10), hints[0].HintData.(int64)) + if got := len(hints); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("nth_plan", hints[0].HintName.L) { + t.Fatalf("case %d: got %v, want %v", i, hints[0].HintName.L, "nth_plan") + } + if !reflect.DeepEqual(int64(10), hints[0].HintData.(int64)) { + t.Fatalf("got %v, want %v", hints[0].HintData.(int64), int64(10)) + } } // Test USE_INDEX_MERGE stmt, _, err = p.Parse("select /*+ USE_INDEX_MERGE(t1, c1), use_index_merge(t2, c1), use_index_merge(t3, c1, primary, c2) */ c1, c2 from t1, t2, t3 where t1.c1 = t2.c1 and t3.c2 = t1.c2", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 3) - require.Equal(t, "use_index_merge", hints[0].HintName.L) - require.Len(t, hints[0].Tables, 1) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Len(t, hints[0].Indexes, 1) - require.Equal(t, "c1", hints[0].Indexes[0].L) - - require.Equal(t, "use_index_merge", hints[1].HintName.L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t2", hints[1].Tables[0].TableName.L) - require.Len(t, hints[1].Indexes, 1) - require.Equal(t, "c1", hints[1].Indexes[0].L) - - require.Equal(t, "use_index_merge", hints[2].HintName.L) - require.Len(t, hints[2].Tables, 1) - require.Equal(t, "t3", hints[2].Tables[0].TableName.L) - require.Len(t, hints[2].Indexes, 3) - require.Equal(t, "c1", hints[2].Indexes[0].L) - require.Equal(t, "primary", hints[2].Indexes[1].L) - require.Equal(t, "c2", hints[2].Indexes[2].L) + if got := len(hints); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } + if !reflect.DeepEqual("use_index_merge", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "use_index_merge") + } + if got := len(hints[0].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if got := len(hints[0].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("c1", hints[0].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[0].Indexes[0].L, "c1") + } + + if !reflect.DeepEqual("use_index_merge", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "use_index_merge") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t2", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t2") + } + if got := len(hints[1].Indexes); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("c1", hints[1].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[1].Indexes[0].L, "c1") + } + + if !reflect.DeepEqual("use_index_merge", hints[2].HintName.L) { + t.Fatalf("got %v, want %v", hints[2].HintName.L, "use_index_merge") + } + if got := len(hints[2].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[2].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[2].Tables[0].TableName.L, "t3") + } + if got := len(hints[2].Indexes); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } + if !reflect.DeepEqual("c1", hints[2].Indexes[0].L) { + t.Fatalf("got %v, want %v", hints[2].Indexes[0].L, "c1") + } + if !reflect.DeepEqual("primary", hints[2].Indexes[1].L) { + t.Fatalf("got %v, want %v", hints[2].Indexes[1].L, "primary") + } + if !reflect.DeepEqual("c2", hints[2].Indexes[2].L) { + t.Fatalf("got %v, want %v", hints[2].Indexes[2].L, "c2") + } // Test READ_FROM_STORAGE stmt, _, err = p.Parse("select /*+ READ_FROM_STORAGE(tiflash[t1, t2], tikv[t3]) */ c1, c2 from t1, t2, t1 t3 where t1.c1 = t2.c1 and t2.c1 = t3.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "read_from_storage", hints[0].HintName.L) - require.Equal(t, "tiflash", hints[0].HintData.(ast.CIStr).L) - require.Len(t, hints[0].Tables, 2) - require.Equal(t, "t1", hints[0].Tables[0].TableName.L) - require.Equal(t, "t2", hints[0].Tables[1].TableName.L) - require.Equal(t, "read_from_storage", hints[1].HintName.L) - require.Equal(t, "tikv", hints[1].HintData.(ast.CIStr).L) - require.Len(t, hints[1].Tables, 1) - require.Equal(t, "t3", hints[1].Tables[0].TableName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("read_from_storage", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "read_from_storage") + } + if !reflect.DeepEqual("tiflash", hints[0].HintData.(ast.CIStr).L) { + t.Fatalf("got %v, want %v", hints[0].HintData.(ast.CIStr).L, "tiflash") + } + if got := len(hints[0].Tables); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("t1", hints[0].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[0].TableName.L, "t1") + } + if !reflect.DeepEqual("t2", hints[0].Tables[1].TableName.L) { + t.Fatalf("got %v, want %v", hints[0].Tables[1].TableName.L, "t2") + } + if !reflect.DeepEqual("read_from_storage", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "read_from_storage") + } + if !reflect.DeepEqual("tikv", hints[1].HintData.(ast.CIStr).L) { + t.Fatalf("got %v, want %v", hints[1].HintData.(ast.CIStr).L, "tikv") + } + if got := len(hints[1].Tables); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if !reflect.DeepEqual("t3", hints[1].Tables[0].TableName.L) { + t.Fatalf("got %v, want %v", hints[1].Tables[0].TableName.L, "t3") + } // Test USE_TOJA stmt, _, err = p.Parse("select /*+ USE_TOJA(true), use_toja(false) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "use_toja", hints[0].HintName.L) - require.True(t, hints[0].HintData.(bool)) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("use_toja", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "use_toja") + } + if !(hints[0].HintData.(bool)) { + t.Fatal("expected true") + } - require.Equal(t, "use_toja", hints[1].HintName.L) - require.False(t, hints[1].HintData.(bool)) + if !reflect.DeepEqual("use_toja", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "use_toja") + } + if hints[1].HintData.(bool) { + t.Fatal("expected false") + } // Test IGNORE_PLAN_CACHE stmt, _, err = p.Parse("select /*+ IGNORE_PLAN_CACHE(), ignore_plan_cache() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "ignore_plan_cache", hints[0].HintName.L) - require.Equal(t, "ignore_plan_cache", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("ignore_plan_cache", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "ignore_plan_cache") + } + if !reflect.DeepEqual("ignore_plan_cache", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "ignore_plan_cache") + } stmt, _, err = p.Parse("delete /*+ IGNORE_PLAN_CACHE(), ignore_plan_cache() */ from t where a = 1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } deleteStmt := stmt[0].(*ast.DeleteStmt) hints = deleteStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "ignore_plan_cache", hints[0].HintName.L) - require.Equal(t, "ignore_plan_cache", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("ignore_plan_cache", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "ignore_plan_cache") + } + if !reflect.DeepEqual("ignore_plan_cache", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "ignore_plan_cache") + } stmt, _, err = p.Parse("update /*+ IGNORE_PLAN_CACHE(), ignore_plan_cache() */ t set a = 1 where a = 10", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } updateStmt := stmt[0].(*ast.UpdateStmt) hints = updateStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "ignore_plan_cache", hints[0].HintName.L) - require.Equal(t, "ignore_plan_cache", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("ignore_plan_cache", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "ignore_plan_cache") + } + if !reflect.DeepEqual("ignore_plan_cache", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "ignore_plan_cache") + } // Test WRITE_SLOW_LOG stmt, _, err = p.Parse("select /*+ WRITE_SLOW_LOG(), write_slow_log() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 0) + if got := len(hints); got != 0 { + t.Fatalf("expected length %d, got %d", 0, got) + } stmt, _, err = p.Parse("select /*+ WRITE_SLOW_LOG, write_slow_log*/ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "write_slow_log", hints[0].HintName.L) - require.Equal(t, "write_slow_log", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("write_slow_log", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "write_slow_log") + } + if !reflect.DeepEqual("write_slow_log", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "write_slow_log") + } // Test USE_CASCADES stmt, _, err = p.Parse("select /*+ USE_CASCADES(true), use_cascades(false) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "use_cascades", hints[0].HintName.L) - require.True(t, hints[0].HintData.(bool)) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("use_cascades", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "use_cascades") + } + if !(hints[0].HintData.(bool)) { + t.Fatal("expected true") + } - require.Equal(t, "use_cascades", hints[1].HintName.L) - require.False(t, hints[1].HintData.(bool)) + if !reflect.DeepEqual("use_cascades", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "use_cascades") + } + if hints[1].HintData.(bool) { + t.Fatal("expected false") + } // Test USE_PLAN_CACHE stmt, _, err = p.Parse("select /*+ USE_PLAN_CACHE(), use_plan_cache() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "use_plan_cache", hints[0].HintName.L) - require.Equal(t, "use_plan_cache", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("use_plan_cache", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "use_plan_cache") + } + if !reflect.DeepEqual("use_plan_cache", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "use_plan_cache") + } // Test QUERY_TYPE stmt, _, err = p.Parse("select /*+ QUERY_TYPE(OLAP), query_type(OLTP) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "query_type", hints[0].HintName.L) - require.Equal(t, "olap", hints[0].HintData.(ast.CIStr).L) - require.Equal(t, "query_type", hints[1].HintName.L) - require.Equal(t, "oltp", hints[1].HintData.(ast.CIStr).L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("query_type", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "query_type") + } + if !reflect.DeepEqual("olap", hints[0].HintData.(ast.CIStr).L) { + t.Fatalf("got %v, want %v", hints[0].HintData.(ast.CIStr).L, "olap") + } + if !reflect.DeepEqual("query_type", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "query_type") + } + if !reflect.DeepEqual("oltp", hints[1].HintData.(ast.CIStr).L) { + t.Fatalf("got %v, want %v", hints[1].HintData.(ast.CIStr).L, "oltp") + } // Test MEMORY_QUOTA stmt, _, err = p.Parse("select /*+ MEMORY_QUOTA(1 MB), memory_quota(1 GB) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "memory_quota", hints[0].HintName.L) - require.Equal(t, int64(1024*1024), hints[0].HintData.(int64)) - require.Equal(t, "memory_quota", hints[1].HintName.L) - require.Equal(t, int64(1024*1024*1024), hints[1].HintData.(int64)) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("memory_quota", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "memory_quota") + } + if !reflect.DeepEqual(int64(1024*1024), hints[0].HintData.(int64)) { + t.Fatalf("got %v, want %v", hints[0].HintData.(int64), int64(1024*1024)) + } + if !reflect.DeepEqual("memory_quota", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "memory_quota") + } + if !reflect.DeepEqual(int64(1024*1024*1024), hints[1].HintData.(int64)) { + t.Fatalf("got %v, want %v", hints[1].HintData.(int64), int64(1024*1024*1024)) + } _, _, err = p.Parse("select /*+ MEMORY_QUOTA(18446744073709551612 MB), memory_quota(8689934592 GB) */ 1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // Test HASH_AGG stmt, _, err = p.Parse("select /*+ HASH_AGG(), hash_agg() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "hash_agg", hints[0].HintName.L) - require.Equal(t, "hash_agg", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("hash_agg", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "hash_agg") + } + if !reflect.DeepEqual("hash_agg", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "hash_agg") + } // Test MPPAgg stmt, _, err = p.Parse("select /*+ MPP_1PHASE_AGG(), mpp_1phase_agg() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "mpp_1phase_agg", hints[0].HintName.L) - require.Equal(t, "mpp_1phase_agg", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("mpp_1phase_agg", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "mpp_1phase_agg") + } + if !reflect.DeepEqual("mpp_1phase_agg", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "mpp_1phase_agg") + } stmt, _, err = p.Parse("select /*+ MPP_2PHASE_AGG(), mpp_2phase_agg() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "mpp_2phase_agg", hints[0].HintName.L) - require.Equal(t, "mpp_2phase_agg", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("mpp_2phase_agg", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "mpp_2phase_agg") + } + if !reflect.DeepEqual("mpp_2phase_agg", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "mpp_2phase_agg") + } // Test ShuffleJoin stmt, _, err = p.Parse("select /*+ SHUFFLE_JOIN(t1, t2), shuffle_join(t1, t2) */ * from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "shuffle_join", hints[0].HintName.L) - require.Equal(t, "shuffle_join", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("shuffle_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "shuffle_join") + } + if !reflect.DeepEqual("shuffle_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "shuffle_join") + } // Test STREAM_AGG stmt, _, err = p.Parse("select /*+ STREAM_AGG(), stream_agg() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "stream_agg", hints[0].HintName.L) - require.Equal(t, "stream_agg", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("stream_agg", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "stream_agg") + } + if !reflect.DeepEqual("stream_agg", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "stream_agg") + } // Test AGG_TO_COP stmt, _, err = p.Parse("select /*+ AGG_TO_COP(), agg_to_cop() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "agg_to_cop", hints[0].HintName.L) - require.Equal(t, "agg_to_cop", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("agg_to_cop", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "agg_to_cop") + } + if !reflect.DeepEqual("agg_to_cop", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "agg_to_cop") + } // Test NO_INDEX_MERGE stmt, _, err = p.Parse("select /*+ NO_INDEX_MERGE(), no_index_merge() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_index_merge", hints[0].HintName.L) - require.Equal(t, "no_index_merge", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_index_merge", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_index_merge") + } + if !reflect.DeepEqual("no_index_merge", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_index_merge") + } // Test READ_CONSISTENT_REPLICA stmt, _, err = p.Parse("select /*+ READ_CONSISTENT_REPLICA(), read_consistent_replica() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "read_consistent_replica", hints[0].HintName.L) - require.Equal(t, "read_consistent_replica", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("read_consistent_replica", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "read_consistent_replica") + } + if !reflect.DeepEqual("read_consistent_replica", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "read_consistent_replica") + } // Test LIMIT_TO_COP stmt, _, err = p.Parse("select /*+ LIMIT_TO_COP(), limit_to_cop() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "limit_to_cop", hints[0].HintName.L) - require.Equal(t, "limit_to_cop", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("limit_to_cop", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "limit_to_cop") + } + if !reflect.DeepEqual("limit_to_cop", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "limit_to_cop") + } // Test CTE MERGE stmt, _, err = p.Parse("with cte(x) as (select * from t1) select /*+ MERGE(), merge() */ * from cte;", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "merge", hints[0].HintName.L) - require.Equal(t, "merge", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("merge", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "merge") + } + if !reflect.DeepEqual("merge", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "merge") + } // Test STRAIGHT_JOIN stmt, _, err = p.Parse("select /*+ STRAIGHT_JOIN(), straight_join() */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "straight_join", hints[0].HintName.L) - require.Equal(t, "straight_join", hints[1].HintName.L) + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("straight_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "straight_join") + } + if !reflect.DeepEqual("straight_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "straight_join") + } // Test LEADING stmt, _, err = p.Parse("select /*+ LEADING(T1), LEADING(t2, t3), LEADING(T4, t5, t6) */ c1, c2 from t1, t2 where t1.c1 = t2.c1", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 3) - require.Equal(t, "leading", hints[0].HintName.L) + if got := len(hints); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } + if !reflect.DeepEqual("leading", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "leading") + } leadingList1, ok := hints[0].HintData.(*ast.LeadingList) - require.True(t, ok) - require.Len(t, leadingList1.Items, 1) + if !(ok) { + t.Fatal("expected true") + } + if got := len(leadingList1.Items); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } hintTable1, ok := leadingList1.Items[0].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t1", hintTable1.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t1", hintTable1.TableName.L) { + t.Fatalf("got %v, want %v", hintTable1.TableName.L, "t1") + } - require.Equal(t, "leading", hints[1].HintName.L) + if !reflect.DeepEqual("leading", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "leading") + } leadingList2, ok := hints[1].HintData.(*ast.LeadingList) - require.True(t, ok) - require.Len(t, leadingList2.Items, 2) + if !(ok) { + t.Fatal("expected true") + } + if got := len(leadingList2.Items); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } hintTable2, ok := leadingList2.Items[0].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t2", hintTable2.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t2", hintTable2.TableName.L) { + t.Fatalf("got %v, want %v", hintTable2.TableName.L, "t2") + } hintTable3, ok := leadingList2.Items[1].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t3", hintTable3.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t3", hintTable3.TableName.L) { + t.Fatalf("got %v, want %v", hintTable3.TableName.L, "t3") + } - require.Equal(t, "leading", hints[2].HintName.L) + if !reflect.DeepEqual("leading", hints[2].HintName.L) { + t.Fatalf("got %v, want %v", hints[2].HintName.L, "leading") + } leadingList3, ok := hints[2].HintData.(*ast.LeadingList) - require.True(t, ok) - require.Len(t, leadingList3.Items, 3) + if !(ok) { + t.Fatal("expected true") + } + if got := len(leadingList3.Items); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } hintTable4, ok := leadingList3.Items[0].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t4", hintTable4.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t4", hintTable4.TableName.L) { + t.Fatalf("got %v, want %v", hintTable4.TableName.L, "t4") + } hintTable5, ok := leadingList3.Items[1].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t5", hintTable5.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t5", hintTable5.TableName.L) { + t.Fatalf("got %v, want %v", hintTable5.TableName.L, "t5") + } hintTable6, ok := leadingList3.Items[2].(*ast.HintTable) - require.True(t, ok) - require.Equal(t, "t6", hintTable6.TableName.L) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("t6", hintTable6.TableName.L) { + t.Fatalf("got %v, want %v", hintTable6.TableName.L, "t6") + } // Test NO_HASH_JOIN stmt, _, err = p.Parse("select /*+ NO_HASH_JOIN(t1, t2), NO_HASH_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_hash_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") - require.Equal(t, hints[0].Tables[1].TableName.L, "t2") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_hash_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } + if !reflect.DeepEqual(hints[0].Tables[1].TableName.L, "t2") { + t.Fatalf("got %v, want %v", "t2", hints[0].Tables[1].TableName.L) + } - require.Equal(t, "no_hash_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("no_hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_hash_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test NO_MERGE_JOIN stmt, _, err = p.Parse("select /*+ NO_MERGE_JOIN(t1), NO_MERGE_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_merge_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_merge_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_merge_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "no_merge_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("no_merge_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_merge_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test INDEX_JOIN stmt, _, err = p.Parse("select /*+ INDEX_JOIN(t1), INDEX_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "index_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("index_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "index_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "index_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("index_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "index_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test NO_INDEX_JOIN stmt, _, err = p.Parse("select /*+ NO_INDEX_JOIN(t1), NO_INDEX_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_index_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_index_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_index_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "no_index_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("no_index_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_index_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test INDEX_HASH_JOIN stmt, _, err = p.Parse("select /*+ INDEX_HASH_JOIN(t1), INDEX_HASH_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "index_hash_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("index_hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "index_hash_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "index_hash_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("index_hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "index_hash_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test NO_INDEX_HASH_JOIN stmt, _, err = p.Parse("select /*+ NO_INDEX_HASH_JOIN(t1), NO_INDEX_HASH_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_index_hash_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_index_hash_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_index_hash_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "no_index_hash_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("no_index_hash_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_index_hash_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test INDEX_MERGE_JOIN stmt, _, err = p.Parse("select /*+ INDEX_MERGE_JOIN(t1), INDEX_MERGE_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "index_merge_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("index_merge_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "index_merge_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "index_merge_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("index_merge_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "index_merge_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test NO_INDEX_MERGE_JOIN stmt, _, err = p.Parse("select /*+ NO_INDEX_MERGE_JOIN(t1), NO_INDEX_MERGE_JOIN(t3) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "no_index_merge_join", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("no_index_merge_join", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "no_index_merge_join") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "no_index_merge_join", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("no_index_merge_join", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "no_index_merge_join") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } // Test HYPO_INDEX stmt, _, err = p.Parse("select /*+ HYPO_INDEX(t1, a), HYPO_INDEX(t3, a, b, c) */ * from t1, t2, t3", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } selectStmt = stmt[0].(*ast.SelectStmt) hints = selectStmt.TableHints - require.Len(t, hints, 2) - require.Equal(t, "hypo_index", hints[0].HintName.L) - require.Equal(t, hints[0].Tables[0].TableName.L, "t1") + if got := len(hints); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual("hypo_index", hints[0].HintName.L) { + t.Fatalf("got %v, want %v", hints[0].HintName.L, "hypo_index") + } + if !reflect.DeepEqual(hints[0].Tables[0].TableName.L, "t1") { + t.Fatalf("got %v, want %v", "t1", hints[0].Tables[0].TableName.L) + } - require.Equal(t, "hypo_index", hints[1].HintName.L) - require.Equal(t, hints[1].Tables[0].TableName.L, "t3") + if !reflect.DeepEqual("hypo_index", hints[1].HintName.L) { + t.Fatalf("got %v, want %v", hints[1].HintName.L, "hypo_index") + } + if !reflect.DeepEqual(hints[1].Tables[0].TableName.L, "t3") { + t.Fatalf("got %v, want %v", "t3", hints[1].Tables[0].TableName.L) + } } func TestType(t *testing.T) { @@ -5400,7 +6546,9 @@ type subqueryChecker struct { // Enter implements ast.Visitor interface. func (sc *subqueryChecker) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { if expr, ok := inNode.(*ast.SubqueryExpr); ok { - require.Equal(sc.t, sc.text, expr.Query.Text()) + if !reflect.DeepEqual(sc.text, expr.Query.Text()) { + sc.t.Fatalf("got %v, want %v", expr.Query.Text(), sc.text) + } return inNode, true } return inNode, false @@ -5455,7 +6603,9 @@ func TestSubquery(t *testing.T) { p := parser.New() for _, tbl := range tests { stmt, err := p.ParseOneStmt(tbl.input, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt.Accept(&subqueryChecker{ text: tbl.text, t: t, @@ -5587,7 +6737,9 @@ func TestSetOperator(t *testing.T) { func checkOrderBy(t *testing.T, s ast.Node, hasOrderBy []bool, i int) int { switch x := s.(type) { case *ast.SelectStmt: - require.Equal(t, hasOrderBy[i], x.OrderBy != nil) + if !reflect.DeepEqual(hasOrderBy[i], x.OrderBy != nil) { + t.Fatalf("got %v, want %v", x.OrderBy != nil, hasOrderBy[i]) + } return i + 1 case *ast.SetOprSelectList: for _, sel := range x.Selects { @@ -5615,18 +6767,24 @@ func TestUnionOrderBy(t *testing.T) { for _, tbl := range tests { stmt, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } us, ok := stmt[0].(*ast.SetOprStmt) if ok { var i int for _, s := range us.SelectList.Selects { i = checkOrderBy(t, s, tbl.hasOrderBy, i) } - require.Equal(t, tbl.hasOrderBy[i], us.OrderBy != nil) + if !reflect.DeepEqual(tbl.hasOrderBy[i], us.OrderBy != nil) { + t.Fatalf("got %v, want %v", us.OrderBy != nil, tbl.hasOrderBy[i]) + } } ss, ok := stmt[0].(*ast.SelectStmt) if ok { - require.Equal(t, tbl.hasOrderBy[0], ss.OrderBy != nil) + if !reflect.DeepEqual(tbl.hasOrderBy[0], ss.OrderBy != nil) { + t.Fatalf("got %v, want %v", ss.OrderBy != nil, tbl.hasOrderBy[0]) + } } } } @@ -5727,9 +6885,13 @@ func TestPriority(t *testing.T) { p := parser.New() stmt, _, err := p.Parse("select HIGH_PRIORITY * from t", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel := stmt[0].(*ast.SelectStmt) - require.Equal(t, mysql.HighPriority, sel.SelectStmtOpts.Priority) + if !reflect.DeepEqual(mysql.HighPriority, sel.SelectStmtOpts.Priority) { + t.Fatalf("got %v, want %v", sel.SelectStmtOpts.Priority, mysql.HighPriority) + } } func TestSQLResult(t *testing.T) { @@ -5755,10 +6917,14 @@ func TestSQLNoCache(t *testing.T) { p := parser.New() for _, tbl := range table { stmt, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel := stmt[0].(*ast.SelectStmt) - require.Equal(t, tbl.ok, sel.SelectStmtOpts.SQLCache) + if !reflect.DeepEqual(tbl.ok, sel.SelectStmtOpts.SQLCache) { + t.Fatalf("got %v, want %v", sel.SelectStmtOpts.SQLCache, tbl.ok) + } } } @@ -5968,12 +7134,22 @@ func TestBinding(t *testing.T) { p := parser.New() sms, _, err := p.Parse("create global binding for select * from t using select * from t use index(a)", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok := sms[0].(*ast.CreateBindingStmt) - require.True(t, ok) - require.Equal(t, "select * from t", v.OriginNode.Text()) - require.Equal(t, "select * from t use index(a)", v.HintedNode.Text()) - require.True(t, v.GlobalScope) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("select * from t", v.OriginNode.Text()) { + t.Fatalf("got %v, want %v", v.OriginNode.Text(), "select * from t") + } + if !reflect.DeepEqual("select * from t use index(a)", v.HintedNode.Text()) { + t.Fatalf("got %v, want %v", v.HintedNode.Text(), "select * from t use index(a)") + } + if !(v.GlobalScope) { + t.Fatal("expected true") + } } func TestView(t *testing.T) { @@ -6052,13 +7228,25 @@ func TestView(t *testing.T) { // Test case for the text of the select statement in create view statement. p := parser.New() sms, _, err := p.Parse("create view v as select * from t", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok := sms[0].(*ast.CreateViewStmt) - require.True(t, ok) - require.Equal(t, ast.AlgorithmUndefined, v.Algorithm) - require.Equal(t, "select * from t", v.Select.Text()) - require.Equal(t, ast.SecurityDefiner, v.Security) - require.Equal(t, ast.CheckOptionCascaded, v.CheckOption) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(ast.AlgorithmUndefined, v.Algorithm) { + t.Fatalf("got %v, want %v", v.Algorithm, ast.AlgorithmUndefined) + } + if !reflect.DeepEqual("select * from t", v.Select.Text()) { + t.Fatalf("got %v, want %v", v.Select.Text(), "select * from t") + } + if !reflect.DeepEqual(ast.SecurityDefiner, v.Security) { + t.Fatalf("got %v, want %v", v.Security, ast.SecurityDefiner) + } + if !reflect.DeepEqual(ast.CheckOptionCascaded, v.CheckOption) { + t.Fatalf("got %v, want %v", v.CheckOption, ast.CheckOptionCascaded) + } src := `CREATE OR REPLACE ALGORITHM = UNDEFINED DEFINER = root@localhost SQL SECURITY DEFINER @@ -6067,29 +7255,61 @@ func TestView(t *testing.T) { var st ast.StmtNode st, err = p.ParseOneStmt(src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = st.(*ast.CreateViewStmt) - require.True(t, ok) - require.True(t, v.OrReplace) - require.Equal(t, ast.AlgorithmUndefined, v.Algorithm) - require.Equal(t, "root", v.Definer.Username) - require.Equal(t, "localhost", v.Definer.Hostname) - require.Equal(t, ast.NewCIStr("a"), v.Cols[0]) - require.Equal(t, ast.NewCIStr("b"), v.Cols[1]) - require.Equal(t, ast.NewCIStr("c"), v.Cols[2]) - require.Equal(t, "select c,d,e from t", v.Select.Text()) - require.Equal(t, ast.SecurityDefiner, v.Security) - require.Equal(t, ast.CheckOptionCascaded, v.CheckOption) + if !(ok) { + t.Fatal("expected true") + } + if !(v.OrReplace) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(ast.AlgorithmUndefined, v.Algorithm) { + t.Fatalf("got %v, want %v", v.Algorithm, ast.AlgorithmUndefined) + } + if !reflect.DeepEqual("root", v.Definer.Username) { + t.Fatalf("got %v, want %v", v.Definer.Username, "root") + } + if !reflect.DeepEqual("localhost", v.Definer.Hostname) { + t.Fatalf("got %v, want %v", v.Definer.Hostname, "localhost") + } + if !reflect.DeepEqual(ast.NewCIStr("a"), v.Cols[0]) { + t.Fatalf("got %v, want %v", v.Cols[0], ast.NewCIStr("a")) + } + if !reflect.DeepEqual(ast.NewCIStr("b"), v.Cols[1]) { + t.Fatalf("got %v, want %v", v.Cols[1], ast.NewCIStr("b")) + } + if !reflect.DeepEqual(ast.NewCIStr("c"), v.Cols[2]) { + t.Fatalf("got %v, want %v", v.Cols[2], ast.NewCIStr("c")) + } + if !reflect.DeepEqual("select c,d,e from t", v.Select.Text()) { + t.Fatalf("got %v, want %v", v.Select.Text(), "select c,d,e from t") + } + if !reflect.DeepEqual(ast.SecurityDefiner, v.Security) { + t.Fatalf("got %v, want %v", v.Security, ast.SecurityDefiner) + } + if !reflect.DeepEqual(ast.CheckOptionCascaded, v.CheckOption) { + t.Fatalf("got %v, want %v", v.CheckOption, ast.CheckOptionCascaded) + } src = ` CREATE VIEW v1 AS SELECT * FROM t; CREATE VIEW v2 AS SELECT 123123123123123; ` nodes, _, err := p.Parse(src, "", "") - require.NoError(t, err) - require.Len(t, nodes, 2) - require.Equal(t, nodes[0].(*ast.CreateViewStmt).Select.Text(), "SELECT * FROM t") - require.Equal(t, nodes[1].(*ast.CreateViewStmt).Select.Text(), "SELECT 123123123123123") + if err != nil { + t.Fatal(err) + } + if got := len(nodes); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual(nodes[0].(*ast.CreateViewStmt).Select.Text(), "SELECT * FROM t") { + t.Fatalf("got %v, want %v", "SELECT * FROM t", nodes[0].(*ast.CreateViewStmt).Select.Text()) + } + if !reflect.DeepEqual(nodes[1].(*ast.CreateViewStmt).Select.Text(), "SELECT 123123123123123") { + t.Fatalf("got %v, want %v", "SELECT 123123123123123", nodes[1].(*ast.CreateViewStmt).Select.Text()) + } } func TestTimestampDiffUnit(t *testing.T) { @@ -6097,19 +7317,31 @@ func TestTimestampDiffUnit(t *testing.T) { // TimeUnit should be unified to upper case. p := parser.New() stmt, _, err := p.Parse("SELECT TIMESTAMPDIFF(MONTH,'2003-02-01','2003-05-01'), TIMESTAMPDIFF(month,'2003-02-01','2003-05-01');", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ss := stmt[0].(*ast.SelectStmt) fields := ss.Fields.Fields - require.Len(t, fields, 2) + if got := len(fields); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } expr := fields[0].Expr f, ok := expr.(*ast.FuncCallExpr) - require.True(t, ok) - require.Equal(t, ast.TimeUnitMonth, f.Args[0].(*ast.TimeUnitExpr).Unit) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(ast.TimeUnitMonth, f.Args[0].(*ast.TimeUnitExpr).Unit) { + t.Fatalf("got %v, want %v", f.Args[0].(*ast.TimeUnitExpr).Unit, ast.TimeUnitMonth) + } expr = fields[1].Expr f, ok = expr.(*ast.FuncCallExpr) - require.True(t, ok) - require.Equal(t, ast.TimeUnitMonth, f.Args[0].(*ast.TimeUnitExpr).Unit) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(ast.TimeUnitMonth, f.Args[0].(*ast.TimeUnitExpr).Unit) { + t.Fatalf("got %v, want %v", f.Args[0].(*ast.TimeUnitExpr).Unit, ast.TimeUnitMonth) + } // Test Illegal TimeUnit for TimestampDiff table := []testCase{ @@ -6132,25 +7364,37 @@ func TestFuncCallExprOffset(t *testing.T) { // Test case for offset field on func call expr. p := parser.New() stmt, _, err := p.Parse("SELECT s.a(), b();", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ss := stmt[0].(*ast.SelectStmt) fields := ss.Fields.Fields - require.Len(t, fields, 2) + if got := len(fields); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } { // s.a() expr := fields[0].Expr f, ok := expr.(*ast.FuncCallExpr) - require.True(t, ok) - require.Equal(t, 7, f.OriginTextPosition()) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(7, f.OriginTextPosition()) { + t.Fatalf("got %v, want %v", f.OriginTextPosition(), 7) + } } { // b() expr := fields[1].Expr f, ok := expr.(*ast.FuncCallExpr) - require.True(t, ok) - require.Equal(t, 14, f.OriginTextPosition()) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(14, f.OriginTextPosition()) { + t.Fatalf("got %v, want %v", f.OriginTextPosition(), 14) + } } } @@ -6191,7 +7435,9 @@ func TestSQLModeANSIQuotes(t *testing.T) { } for _, test := range tests { _, _, err := p.Parse(test, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } } } @@ -6203,18 +7449,28 @@ func TestDDLStatements(t *testing.T) { b char(10) charset utf8 collate utf8_general_ci, c text charset latin1) ENGINE=innoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin` stmts, _, err := p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt := stmts[0].(*ast.CreateTableStmt) - require.True(t, mysql.HasBinaryFlag(stmt.Cols[0].Tp.GetFlag())) + if !(mysql.HasBinaryFlag(stmt.Cols[0].Tp.GetFlag())) { + t.Fatal("expected true") + } for _, colDef := range stmt.Cols[1:] { - require.False(t, mysql.HasBinaryFlag(colDef.Tp.GetFlag())) + if mysql.HasBinaryFlag(colDef.Tp.GetFlag()) { + t.Fatal("expected false") + } } for _, tblOpt := range stmt.Options { switch tblOpt.Tp { case ast.TableOptionCharset: - require.Equal(t, "utf8", tblOpt.StrValue) + if !reflect.DeepEqual("utf8", tblOpt.StrValue) { + t.Fatalf("got %v, want %v", tblOpt.StrValue, "utf8") + } case ast.TableOptionCollate: - require.Equal(t, "utf8_bin", tblOpt.StrValue) + if !reflect.DeepEqual("utf8_bin", tblOpt.StrValue) { + t.Fatalf("got %v, want %v", tblOpt.StrValue, "utf8_bin") + } } } createTableStr = `CREATE TABLE t ( @@ -6222,12 +7478,20 @@ func TestDDLStatements(t *testing.T) { b binary(10), c blob)` stmts, _, err = p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt = stmts[0].(*ast.CreateTableStmt) for _, colDef := range stmt.Cols { - require.Equal(t, charset.CharsetBin, colDef.Tp.GetCharset()) - require.Equal(t, charset.CollationBin, colDef.Tp.GetCollate()) - require.True(t, mysql.HasBinaryFlag(colDef.Tp.GetFlag())) + if !reflect.DeepEqual(charset.CharsetBin, colDef.Tp.GetCharset()) { + t.Fatalf("got %v, want %v", colDef.Tp.GetCharset(), charset.CharsetBin) + } + if !reflect.DeepEqual(charset.CollationBin, colDef.Tp.GetCollate()) { + t.Fatalf("got %v, want %v", colDef.Tp.GetCollate(), charset.CollationBin) + } + if !(mysql.HasBinaryFlag(colDef.Tp.GetFlag())) { + t.Fatal("expected true") + } } // Test set collate for all column types createTableStr = `CREATE TABLE t ( @@ -6259,27 +7523,39 @@ func TestDDLStatements(t *testing.T) { c_set set('1') collate utf8_bin, c_json json collate utf8_bin)` _, _, err = p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } createTableStr = `CREATE TABLE t (c_double double(10))` _, _, err = p.Parse(createTableStr, "", "") - require.EqualError(t, err, "[parser:1149]You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use") + if err == nil || err.Error() != "[parser:1149]You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use" { + t.Fatalf("expected error %q, got %v", "[parser:1149]You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use", err) + } p.SetStrictDoubleTypeCheck(false) _, _, err = p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } p.SetStrictDoubleTypeCheck(true) createTableStr = `CREATE TABLE t (c_double double(10, 2))` _, _, err = p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } createTableStr = `create global temporary table t010(local_01 int, local_03 varchar(20))` _, _, err = p.Parse(createTableStr, "", "") - require.EqualError(t, err, "line 1 column 70 near \"\"GLOBAL TEMPORARY and ON COMMIT DELETE ROWS must appear together ") + if err == nil || err.Error() != "line 1 column 70 near \"\"GLOBAL TEMPORARY and ON COMMIT DELETE ROWS must appear together " { + t.Fatalf("expected error %q, got %v", "line 1 column 70 near \"\"GLOBAL TEMPORARY and ON COMMIT DELETE ROWS must appear together ", err) + } createTableStr = `create global temporary table t010(local_01 int, local_03 varchar(20)) on commit preserve rows` _, _, err = p.Parse(createTableStr, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } } func TestAnalyze(t *testing.T) { @@ -6389,7 +7665,9 @@ func TestTableSample(t *testing.T) { } for _, sql := range cases { _, err := p.ParseOneStmt(sql, "", "") - require.NoErrorf(t, err, "source %v", sql) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("source %v", sql), err) + } } } @@ -6407,26 +7685,38 @@ func TestGeneratedColumn(t *testing.T) { for _, tbl := range tests { stmtNodes, _, err := p.Parse(tbl.input, "", "") if tbl.ok { - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmtNode := stmtNodes[0] for _, col := range stmtNode.(*ast.CreateTableStmt).Cols { for _, opt := range col.Options { if opt.Tp == ast.ColumnOptionGenerated { - require.Equal(t, tbl.expr, opt.Expr.Text()) + if !reflect.DeepEqual(tbl.expr, opt.Expr.Text()) { + t.Fatalf("got %v, want %v", opt.Expr.Text(), tbl.expr) + } } } } } else { - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } } } _, _, err := p.Parse("create table t1 (a int, b int as (a + 1) default 10);", "", "") - require.Equal(t, err.Error(), "[ddl:1221]Incorrect usage of DEFAULT and generated column") + if !reflect.DeepEqual(err.Error(), "[ddl:1221]Incorrect usage of DEFAULT and generated column") { + t.Fatalf("got %v, want %v", "[ddl:1221]Incorrect usage of DEFAULT and generated column", err.Error()) + } _, _, err = p.Parse("create table t1 (a int, b int as (a + 1) on update now());", "", "") - require.Equal(t, err.Error(), "[ddl:1221]Incorrect usage of ON UPDATE and generated column") + if !reflect.DeepEqual(err.Error(), "[ddl:1221]Incorrect usage of ON UPDATE and generated column") { + t.Fatalf("got %v, want %v", "[ddl:1221]Incorrect usage of ON UPDATE and generated column", err.Error()) + } _, _, err = p.Parse("create table t1 (a int, b int as (a + 1) auto_increment);", "", "") - require.Equal(t, err.Error(), "[ddl:1221]Incorrect usage of AUTO_INCREMENT and generated column") + if !reflect.DeepEqual(err.Error(), "[ddl:1221]Incorrect usage of AUTO_INCREMENT and generated column") { + t.Fatalf("got %v, want %v", "[ddl:1221]Incorrect usage of AUTO_INCREMENT and generated column", err.Error()) + } } func TestSetTransaction(t *testing.T) { @@ -6451,13 +7741,23 @@ func TestSetTransaction(t *testing.T) { p := parser.New() for _, tbl := range tests { stmt1, err := p.ParseOneStmt(tbl.input, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } setStmt := stmt1.(*ast.SetStmt) vars := setStmt.Variables[0] - require.Equal(t, "tx_isolation", vars.Name) - require.Equal(t, tbl.isGlobal, vars.IsGlobal) - require.Equal(t, true, vars.IsSystem) - require.Equal(t, tbl.value, vars.Value.(ast.ValueExpr).GetValue()) + if !reflect.DeepEqual("tx_isolation", vars.Name) { + t.Fatalf("got %v, want %v", vars.Name, "tx_isolation") + } + if !reflect.DeepEqual(tbl.isGlobal, vars.IsGlobal) { + t.Fatalf("got %v, want %v", vars.IsGlobal, tbl.isGlobal) + } + if !reflect.DeepEqual(true, vars.IsSystem) { + t.Fatalf("got %v, want %v", vars.IsSystem, true) + } + if !reflect.DeepEqual(tbl.value, vars.Value.(ast.ValueExpr).GetValue()) { + t.Fatalf("got %v, want %v", vars.Value.(ast.ValueExpr).GetValue(), tbl.value) + } } } @@ -6466,10 +7766,14 @@ func TestSideEffect(t *testing.T) { // clean state, cause the following SQL parse fail. p := parser.New() _, err := p.ParseOneStmt("create table t /*!50100 'abc', 'abc' */;", "", "") - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } _, err = p.ParseOneStmt("show tables;", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } } func TestTablePartition(t *testing.T) { @@ -6671,11 +7975,17 @@ ENGINE=INNODB PARTITION BY LINEAR HASH (a) PARTITIONS 1;`, true, "CREATE TABLE ` // Check comment content. p := parser.New() stmt, err := p.ParseOneStmt("create table t (id int) partition by range (id) (partition p0 values less than (10) comment 'check')", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } createTable := stmt.(*ast.CreateTableStmt) comment, ok := createTable.Partition.Definitions[0].Comment() - require.True(t, ok) - require.Equal(t, "check", comment) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("check", comment) { + t.Fatalf("got %v, want %v", comment, "check") + } } func TestTablePartitionNameList(t *testing.T) { @@ -6686,16 +7996,28 @@ func TestTablePartitionNameList(t *testing.T) { p := parser.New() for _, tbl := range table { stmt, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel := stmt[0].(*ast.SelectStmt) source, ok := sel.From.TableRefs.Left.(*ast.TableSource) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } tableName, ok := source.Source.(*ast.TableName) - require.True(t, ok) - require.Len(t, tableName.PartitionNames, 2) - require.Equal(t, ast.CIStr{O: "p0", L: "p0"}, tableName.PartitionNames[0]) - require.Equal(t, ast.CIStr{O: "p1", L: "p1"}, tableName.PartitionNames[1]) + if !(ok) { + t.Fatal("expected true") + } + if got := len(tableName.PartitionNames); got != 2 { + t.Fatalf("expected length %d, got %d", 2, got) + } + if !reflect.DeepEqual(ast.CIStr{O: "p0", L: "p0"}, tableName.PartitionNames[0]) { + t.Fatalf("got %v, want %v", tableName.PartitionNames[0], ast.CIStr{O: "p0", L: "p0"}) + } + if !reflect.DeepEqual(ast.CIStr{O: "p1", L: "p1"}, tableName.PartitionNames[1]) { + t.Fatalf("got %v, want %v", tableName.PartitionNames[1], ast.CIStr{O: "p1", L: "p1"}) + } } } @@ -6707,12 +8029,18 @@ func TestNotExistsSubquery(t *testing.T) { p := parser.New() for _, tbl := range table { stmt, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } sel := stmt[0].(*ast.SelectStmt) exists, ok := sel.Where.(*ast.ExistsSubqueryExpr) - require.True(t, ok) - require.Equal(t, tbl.ok, exists.Not) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(tbl.ok, exists.Not) { + t.Fatalf("got %v, want %v", exists.Not, tbl.ok) + } } } @@ -6816,7 +8144,9 @@ func (wfc *windowFrameBoundChecker) Enter(inNode ast.Node) (outNode ast.Node, sk wfc.fb = inNode.(*ast.FrameBound) if wfc.fb.Unit != ast.TimeUnitInvalid { _, ok := wfc.fb.Expr.(ast.ValueExpr) - require.False(wfc.t, ok) + if ok { + wfc.t.Fatal("expected false") + } } } return inNode, false @@ -6852,20 +8182,30 @@ func TestVisitFrameBound(t *testing.T) { } for _, tbl := range table { stmt, err := p.ParseOneStmt(tbl.s, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } checker := windowFrameBoundChecker{t: t} stmt.Accept(&checker) - require.Equal(t, tbl.exprRc, checker.exprRc) - require.Equal(t, tbl.unit, checker.unit) + if !reflect.DeepEqual(tbl.exprRc, checker.exprRc) { + t.Fatalf("got %v, want %v", checker.exprRc, tbl.exprRc) + } + if !reflect.DeepEqual(tbl.unit, checker.unit) { + t.Fatalf("got %v, want %v", checker.unit, tbl.unit) + } } } func TestFieldText(t *testing.T) { p := parser.New() stmts, _, err := p.Parse("select a from t", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } tmp := stmts[0].(*ast.SelectStmt) - require.Equal(t, "a", tmp.Fields.Fields[0].Text()) + if !reflect.DeepEqual("a", tmp.Fields.Fields[0].Text()) { + t.Fatalf("got %v, want %v", tmp.Fields.Fields[0].Text(), "a") + } sqls := []string{ "trace select a from t", @@ -6874,10 +8214,16 @@ func TestFieldText(t *testing.T) { } for _, sql := range sqls { stmts, _, err = p.Parse(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } traceStmt := stmts[0].(*ast.TraceStmt) - require.Equal(t, sql, traceStmt.Text()) - require.Equal(t, "select a from t", traceStmt.Stmt.Text()) + if !reflect.DeepEqual(sql, traceStmt.Text()) { + t.Fatalf("got %v, want %v", traceStmt.Text(), sql) + } + if !reflect.DeepEqual("select a from t", traceStmt.Stmt.Text()) { + t.Fatalf("got %v, want %v", traceStmt.Stmt.Text(), "select a from t") + } } } @@ -6890,7 +8236,9 @@ func TestQuotedSystemVariables(t *testing.T) { "", "", ) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ss := st.(*ast.SelectStmt) expected := []*ast.VariableExpr{ { @@ -6938,15 +8286,27 @@ func TestQuotedSystemVariables(t *testing.T) { }, } - require.Len(t, ss.Fields.Fields, len(expected)) + if got := len(ss.Fields.Fields); got != len(expected) { + t.Fatalf("expected length %d, got %d", len(expected), got) + } for i, field := range ss.Fields.Fields { ve := field.Expr.(*ast.VariableExpr) comment := fmt.Sprintf("field %d, ve = %v", i, ve) - require.Equal(t, expected[i].Name, ve.Name, comment) - require.Equal(t, expected[i].IsGlobal, ve.IsGlobal, comment) - require.Equal(t, expected[i].IsInstance, ve.IsInstance, comment) - require.Equal(t, expected[i].IsSystem, ve.IsSystem, comment) - require.Equal(t, expected[i].ExplicitScope, ve.ExplicitScope, comment) + if !reflect.DeepEqual(expected[i].Name, ve.Name) { + t.Fatalf("%v: got %v, want %v", comment, ve.Name, expected[i].Name) + } + if !reflect.DeepEqual(expected[i].IsGlobal, ve.IsGlobal) { + t.Fatalf("%v: got %v, want %v", comment, ve.IsGlobal, expected[i].IsGlobal) + } + if !reflect.DeepEqual(expected[i].IsInstance, ve.IsInstance) { + t.Fatalf("%v: got %v, want %v", comment, ve.IsInstance, expected[i].IsInstance) + } + if !reflect.DeepEqual(expected[i].IsSystem, ve.IsSystem) { + t.Fatalf("%v: got %v, want %v", comment, ve.IsSystem, expected[i].IsSystem) + } + if !reflect.DeepEqual(expected[i].ExplicitScope, ve.ExplicitScope) { + t.Fatalf("%v: got %v, want %v", comment, ve.ExplicitScope, expected[i].ExplicitScope) + } } } @@ -6959,7 +8319,9 @@ func TestQuotedVariableColumnName(t *testing.T) { "", "", ) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } ss := st.(*ast.SelectStmt) expected := []string{ "@abc", @@ -6975,9 +8337,13 @@ func TestQuotedVariableColumnName(t *testing.T) { "@", } - require.Len(t, ss.Fields.Fields, len(expected)) + if got := len(ss.Fields.Fields); got != len(expected) { + t.Fatalf("expected length %d, got %d", len(expected), got) + } for i, field := range ss.Fields.Fields { - require.Equal(t, expected[i], field.Text()) + if !reflect.DeepEqual(expected[i], field.Text()) { + t.Fatalf("got %v, want %v", field.Text(), expected[i]) + } } } @@ -6985,14 +8351,26 @@ func TestCharset(t *testing.T) { p := parser.New() st, err := p.ParseOneStmt("ALTER SCHEMA GLOBAL DEFAULT CHAR SET utf8mb4", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.AlterDatabaseStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.AlterDatabaseStmt) == nil { + t.Fatal("expected non-nil") + } st, err = p.ParseOneStmt("ALTER DATABASE CHAR SET = utf8mb4", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.AlterDatabaseStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.AlterDatabaseStmt) == nil { + t.Fatal("expected non-nil") + } st, err = p.ParseOneStmt("ALTER DATABASE DEFAULT CHAR SET = utf8mb4", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.AlterDatabaseStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.AlterDatabaseStmt) == nil { + t.Fatal("expected non-nil") + } } func TestUnderscoreCharset(t *testing.T) { @@ -7012,11 +8390,17 @@ func TestUnderscoreCharset(t *testing.T) { sql := fmt.Sprintf("select hex(_%s '3F')", tt.cs) _, err := p.ParseOneStmt(sql, "", "") if tt.parseFail { - require.EqualError(t, err, fmt.Sprintf("line 1 column %d near \"'3F')\" ", len(tt.cs)+17)) + if err == nil || err.Error() != fmt.Sprintf("line 1 column %d near \"'3F')\" ", len(tt.cs)+17) { + t.Fatalf("expected error %q, got %v", fmt.Sprintf("line 1 column %d near \"'3F')\" ", len(tt.cs)+17), err) + } } else if tt.unSupport { - require.EqualError(t, err, ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", tt.cs).Error()) + if err == nil || err.Error() != ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", tt.cs).Error() { + t.Fatalf("expected error %q, got %v", ast.ErrUnknownCharacterSet.GenWithStack("Unsupported character introducer: '%-.64s'", tt.cs).Error(), err) + } } else { - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } } } } @@ -7025,45 +8409,83 @@ func TestFulltextSearch(t *testing.T) { p := parser.New() st, err := p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(content) AGAINST('search')", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.SelectStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.SelectStmt) == nil { + t.Fatal("expected non-nil") + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH() AGAINST('search')", "", "") - require.Error(t, err) - require.Nil(t, st) + if err == nil { + t.Fatal("expected error") + } + if st != nil { + t.Fatalf("expected nil, got %v", st) + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(content) AGAINST()", "", "") - require.Error(t, err) - require.Nil(t, st) + if err == nil { + t.Fatal("expected error") + } + if st != nil { + t.Fatalf("expected nil, got %v", st) + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(content) AGAINST('search' IN)", "", "") - require.Error(t, err) - require.Nil(t, st) + if err == nil { + t.Fatal("expected error") + } + if st != nil { + t.Fatalf("expected nil, got %v", st) + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(content) AGAINST('search' IN BOOLEAN MODE WITH QUERY EXPANSION)", "", "") - require.Error(t, err) - require.Nil(t, st) + if err == nil { + t.Fatal("expected error") + } + if st != nil { + t.Fatalf("expected nil, got %v", st) + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(title,content) AGAINST('search' IN NATURAL LANGUAGE MODE)", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.SelectStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.SelectStmt) == nil { + t.Fatal("expected non-nil") + } writer := bytes.NewBufferString("") st.(*ast.SelectStmt).Where.Format(writer) - require.Equal(t, "MATCH(title,content) AGAINST(\"search\")", writer.String()) + if !reflect.DeepEqual("MATCH(title,content) AGAINST(\"search\")", writer.String()) { + t.Fatalf("got %v, want %v", writer.String(), "MATCH(title,content) AGAINST(\"search\")") + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(title,content) AGAINST('search' IN BOOLEAN MODE)", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.SelectStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.SelectStmt) == nil { + t.Fatal("expected non-nil") + } writer.Reset() st.(*ast.SelectStmt).Where.Format(writer) - require.Equal(t, "MATCH(title,content) AGAINST(\"search\" IN BOOLEAN MODE)", writer.String()) + if !reflect.DeepEqual("MATCH(title,content) AGAINST(\"search\" IN BOOLEAN MODE)", writer.String()) { + t.Fatalf("got %v, want %v", writer.String(), "MATCH(title,content) AGAINST(\"search\" IN BOOLEAN MODE)") + } st, err = p.ParseOneStmt("SELECT * FROM fulltext_test WHERE MATCH(title,content) AGAINST('search' WITH QUERY EXPANSION)", "", "") - require.NoError(t, err) - require.NotNil(t, st.(*ast.SelectStmt)) + if err != nil { + t.Fatal(err) + } + if st.(*ast.SelectStmt) == nil { + t.Fatal("expected non-nil") + } writer.Reset() st.(*ast.SelectStmt).Where.Format(writer) - require.Equal(t, "MATCH(title,content) AGAINST(\"search\" WITH QUERY EXPANSION)", writer.String()) + if !reflect.DeepEqual("MATCH(title,content) AGAINST(\"search\" WITH QUERY EXPANSION)", writer.String()) { + t.Fatalf("got %v, want %v", writer.String(), "MATCH(title,content) AGAINST(\"search\" WITH QUERY EXPANSION)") + } } func TestStartTransaction(t *testing.T) { @@ -7094,8 +8516,12 @@ func TestSignedInt64OutOfRange(t *testing.T) { for _, s := range cases { _, err := p.ParseOneStmt(s, "", "") - require.Error(t, err) - require.Contains(t, err.Error(), "out of range") + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "out of range") { + t.Fatalf("expected %q to contain %q", err.Error(), "out of range") + } } } @@ -7303,17 +8729,37 @@ func TestStatisticsOps(t *testing.T) { p := parser.New() sms, _, err := p.Parse("create statistics if not exists stats1 (cardinality) on t(a,b,c)", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok := sms[0].(*ast.CreateStatisticsStmt) - require.True(t, ok) - require.True(t, v.IfNotExists) - require.Equal(t, "stats1", v.StatsName) - require.Equal(t, ast.StatsTypeCardinality, v.StatsType) - require.Equal(t, ast.CIStr{O: "t", L: "t"}, v.Table.Name) - require.Len(t, v.Columns, 3) - require.Equal(t, ast.CIStr{O: "a", L: "a"}, v.Columns[0].Name) - require.Equal(t, ast.CIStr{O: "b", L: "b"}, v.Columns[1].Name) - require.Equal(t, ast.CIStr{O: "c", L: "c"}, v.Columns[2].Name) + if !(ok) { + t.Fatal("expected true") + } + if !(v.IfNotExists) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("stats1", v.StatsName) { + t.Fatalf("got %v, want %v", v.StatsName, "stats1") + } + if !reflect.DeepEqual(ast.StatsTypeCardinality, v.StatsType) { + t.Fatalf("got %v, want %v", v.StatsType, ast.StatsTypeCardinality) + } + if !reflect.DeepEqual(ast.CIStr{O: "t", L: "t"}, v.Table.Name) { + t.Fatalf("got %v, want %v", v.Table.Name, ast.CIStr{O: "t", L: "t"}) + } + if got := len(v.Columns); got != 3 { + t.Fatalf("expected length %d, got %d", 3, got) + } + if !reflect.DeepEqual(ast.CIStr{O: "a", L: "a"}, v.Columns[0].Name) { + t.Fatalf("got %v, want %v", v.Columns[0].Name, ast.CIStr{O: "a", L: "a"}) + } + if !reflect.DeepEqual(ast.CIStr{O: "b", L: "b"}, v.Columns[1].Name) { + t.Fatalf("got %v, want %v", v.Columns[1].Name, ast.CIStr{O: "b", L: "b"}) + } + if !reflect.DeepEqual(ast.CIStr{O: "c", L: "c"}, v.Columns[2].Name) { + t.Fatalf("got %v, want %v", v.Columns[2].Name, ast.CIStr{O: "c", L: "c"}) + } } func TestHighNotPrecedenceMode(t *testing.T) { @@ -7321,42 +8767,74 @@ func TestHighNotPrecedenceMode(t *testing.T) { var sb strings.Builder sms, _, err := p.Parse("SELECT NOT 1 BETWEEN -5 AND 5", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok := sms[0].(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } v1, ok := v.Fields.Fields[0].Expr.(*ast.UnaryOperationExpr) - require.True(t, ok) - require.Equal(t, opcode.Not, v1.Op) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual(opcode.Not, v1.Op) { + t.Fatalf("got %v, want %v", v1.Op, opcode.Not) + } err = sms[0].Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } restoreSQL := sb.String() - require.Equal(t, "SELECT NOT 1 BETWEEN -5 AND 5", restoreSQL) + if !reflect.DeepEqual("SELECT NOT 1 BETWEEN -5 AND 5", restoreSQL) { + t.Fatalf("got %v, want %v", restoreSQL, "SELECT NOT 1 BETWEEN -5 AND 5") + } sb.Reset() sms, _, err = p.Parse("SELECT !1 BETWEEN -5 AND 5", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = sms[0].(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } _, ok = v.Fields.Fields[0].Expr.(*ast.BetweenExpr) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } err = sms[0].Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } restoreSQL = sb.String() - require.Equal(t, "SELECT !1 BETWEEN -5 AND 5", restoreSQL) + if !reflect.DeepEqual("SELECT !1 BETWEEN -5 AND 5", restoreSQL) { + t.Fatalf("got %v, want %v", restoreSQL, "SELECT !1 BETWEEN -5 AND 5") + } sb.Reset() p = parser.New() p.SetSQLMode(mysql.ModeHighNotPrecedence) sms, _, err = p.Parse("SELECT NOT 1 BETWEEN -5 AND 5", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = sms[0].(*ast.SelectStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } _, ok = v.Fields.Fields[0].Expr.(*ast.BetweenExpr) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } err = sms[0].Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } restoreSQL = sb.String() - require.Equal(t, "SELECT !1 BETWEEN -5 AND 5", restoreSQL) + if !reflect.DeepEqual("SELECT !1 BETWEEN -5 AND 5", restoreSQL) { + t.Fatalf("got %v, want %v", restoreSQL, "SELECT !1 BETWEEN -5 AND 5") + } } // For CTE @@ -7471,10 +8949,14 @@ func TestWithoutCharsetFlags(t *testing.T) { for _, tbl := range cases { stmts, _, err := p.Parse(tbl.src, "", "") if !tbl.ok { - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } continue } - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } // restore correctness test var sb strings.Builder restoreSQLs := "" @@ -7483,14 +8965,18 @@ func TestWithoutCharsetFlags(t *testing.T) { ctx := NewRestoreCtx(tbl.flag, &sb) ctx.DefaultDB = "test" err = stmt.Restore(ctx) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } restoreSQL := sb.String() if restoreSQLs != "" { restoreSQLs += "; " } restoreSQLs += restoreSQL } - require.Equal(t, tbl.restore, restoreSQLs) + if !reflect.DeepEqual(tbl.restore, restoreSQLs) { + t.Fatalf("got %v, want %v", restoreSQLs, tbl.restore) + } } } @@ -7507,23 +8993,31 @@ func TestRestoreBinOpWithBrackets(t *testing.T) { _, _, err := p.Parse(tbl.src, "", "") comment := fmt.Sprintf("source %v", tbl.src) if !tbl.ok { - require.Error(t, err, comment) + if err == nil { + t.Fatal(comment) + } continue } - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } // restore correctness test if tbl.ok { var sb strings.Builder comment := fmt.Sprintf("source %v", tbl.src) stmts, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQLs := "" for _, stmt := range stmts { sb.Reset() ctx := NewRestoreCtx(RestoreStringSingleQuotes|RestoreSpacesAroundBinaryOperation|RestoreBracketAroundBinaryOperation|RestoreStringWithoutCharset|RestoreNameBackQuotes, &sb) ctx.DefaultDB = "test" err = stmt.Restore(ctx) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQL := sb.String() comment = fmt.Sprintf("source %v; restore %v", tbl.src, restoreSQL) if restoreSQLs != "" { @@ -7532,7 +9026,9 @@ func TestRestoreBinOpWithBrackets(t *testing.T) { restoreSQLs += restoreSQL } comment = fmt.Sprintf("restore %v; expect %v", restoreSQLs, tbl.restore) - require.Equal(t, tbl.restore, restoreSQLs, comment) + if !reflect.DeepEqual(tbl.restore, restoreSQLs) { + t.Fatalf("%v: got %v, want %v", comment, restoreSQLs, tbl.restore) + } } } } @@ -7565,23 +9061,31 @@ func TestCTEBindings(t *testing.T) { _, _, err := p.Parse(tbl.src, "", "") comment := fmt.Sprintf("source %v", tbl.src) if !tbl.ok { - require.Error(t, err, comment) + if err == nil { + t.Fatal(comment) + } continue } - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } // restore correctness test if tbl.ok { var sb strings.Builder comment := fmt.Sprintf("source %v", tbl.src) stmts, _, err := p.Parse(tbl.src, "", "") - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQLs := "" for _, stmt := range stmts { sb.Reset() ctx := NewRestoreCtx(RestoreStringSingleQuotes|RestoreSpacesAroundBinaryOperation|RestoreStringWithoutCharset|RestoreNameBackQuotes, &sb) ctx.DefaultDB = "test" err = stmt.Restore(ctx) - require.NoError(t, err, comment) + if err != nil { + t.Fatalf("%v: %v", comment, err) + } restoreSQL := sb.String() comment = fmt.Sprintf("source %v; restore %v", tbl.src, restoreSQL) if restoreSQLs != "" { @@ -7590,7 +9094,9 @@ func TestCTEBindings(t *testing.T) { restoreSQLs += restoreSQL } comment = fmt.Sprintf("restore %v; expect %v", restoreSQLs, tbl.restore) - require.Equal(t, tbl.restore, restoreSQLs, comment) + if !reflect.DeepEqual(tbl.restore, restoreSQLs) { + t.Fatalf("%v: got %v, want %v", comment, restoreSQLs, tbl.restore) + } } } } @@ -7621,35 +9127,71 @@ func TestPlanReplayer(t *testing.T) { p := parser.New() sms, _, err := p.Parse("PLAN REPLAYER DUMP EXPLAIN SELECT a FROM t", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok := sms[0].(*ast.PlanReplayerStmt) - require.True(t, ok) - require.Equal(t, "SELECT a FROM t", v.Stmt.Text()) - require.False(t, v.Analyze) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("SELECT a FROM t", v.Stmt.Text()) { + t.Fatalf("got %v, want %v", v.Stmt.Text(), "SELECT a FROM t") + } + if v.Analyze { + t.Fatal("expected false") + } sms, _, err = p.Parse("PLAN REPLAYER DUMP EXPLAIN ANALYZE SELECT a FROM t", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = sms[0].(*ast.PlanReplayerStmt) - require.True(t, ok) - require.Equal(t, "SELECT a FROM t", v.Stmt.Text()) - require.True(t, v.Analyze) + if !(ok) { + t.Fatal("expected true") + } + if !reflect.DeepEqual("SELECT a FROM t", v.Stmt.Text()) { + t.Fatalf("got %v, want %v", v.Stmt.Text(), "SELECT a FROM t") + } + if !(v.Analyze) { + t.Fatal("expected true") + } // Multiple SQL records: EXPLAIN ( "sql1", "sql2", ... ) sms, _, err = p.Parse("PLAN REPLAYER DUMP EXPLAIN ('SELECT * FROM t1', 'SELECT * FROM t2')", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = sms[0].(*ast.PlanReplayerStmt) - require.True(t, ok) - require.Nil(t, v.Stmt) - require.False(t, v.Analyze) - require.Equal(t, []string{"SELECT * FROM t1", "SELECT * FROM t2"}, v.StmtList) + if !(ok) { + t.Fatal("expected true") + } + if v.Stmt != nil { + t.Fatalf("expected nil, got %v", v.Stmt) + } + if v.Analyze { + t.Fatal("expected false") + } + if !reflect.DeepEqual([]string{"SELECT * FROM t1", "SELECT * FROM t2"}, v.StmtList) { + t.Fatalf("got %v, want %v", v.StmtList, []string{"SELECT * FROM t1", "SELECT * FROM t2"}) + } sms, _, err = p.Parse("PLAN REPLAYER DUMP EXPLAIN ANALYZE ('SELECT * FROM t1')", "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } v, ok = sms[0].(*ast.PlanReplayerStmt) - require.True(t, ok) - require.Nil(t, v.Stmt) - require.True(t, v.Analyze) - require.Equal(t, []string{"SELECT * FROM t1"}, v.StmtList) + if !(ok) { + t.Fatal("expected true") + } + if v.Stmt != nil { + t.Fatalf("expected nil, got %v", v.Stmt) + } + if !(v.Analyze) { + t.Fatal("expected true") + } + if !reflect.DeepEqual([]string{"SELECT * FROM t1"}, v.StmtList) { + t.Fatalf("got %v, want %v", v.StmtList, []string{"SELECT * FROM t1"}) + } } func TestTrafficStmt(t *testing.T) { @@ -7686,22 +9228,36 @@ func TestTrafficStmt(t *testing.T) { for _, tbl := range table { stmts, _, err := p.Parse(tbl.src, "", "") if !tbl.ok { - require.Error(t, err, tbl.src) + if err == nil { + t.Fatal(tbl.src) + } continue } - require.NoError(t, err, tbl.src) - require.Len(t, stmts, 1) + if err != nil { + t.Fatalf("%v: %v", tbl.src, err) + } + if got := len(stmts); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } v, ok := stmts[0].(*ast.TrafficStmt) - require.True(t, ok) + if !(ok) { + t.Fatal("expected true") + } switch v.OpType { case ast.TrafficOpCapture, ast.TrafficOpReplay: - require.Equal(t, "/tmp", v.Dir) + if !reflect.DeepEqual("/tmp", v.Dir) { + t.Fatalf("got %v, want %v", v.Dir, "/tmp") + } } sb.Reset() ctx := NewRestoreCtx(RestoreStringSingleQuotes|RestoreSpacesAroundBinaryOperation|RestoreStringWithoutCharset|RestoreNameBackQuotes, &sb) err = v.Restore(ctx) - require.NoError(t, err) - require.Equal(t, tbl.restore, sb.String()) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(tbl.restore, sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), tbl.restore) + } } } @@ -7710,25 +9266,43 @@ func TestGBKEncoding(t *testing.T) { gbkEncoding, _ := charset.Lookup("gbk") encoder := gbkEncoding.NewEncoder() sql, err := encoder.String("create table 测试表 (测试列 varchar(255) default 'GBK测试用例');") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt, _, err := p.ParseSQL(sql) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } checker := &gbkEncodingChecker{} _, _ = stmt[0].Accept(checker) - require.NotEqual(t, "测试表", checker.tblName) - require.NotEqual(t, "测试列", checker.colName) + if reflect.DeepEqual("测试表", checker.tblName) { + t.Fatalf("expected values to differ, both are %v", checker.tblName) + } + if reflect.DeepEqual("测试列", checker.colName) { + t.Fatalf("expected values to differ, both are %v", checker.colName) + } gbkOpt := parser.CharsetClient("gbk") stmt, _, err = p.ParseSQL(sql, gbkOpt) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, _ = stmt[0].Accept(checker) - require.Equal(t, "测试表", checker.tblName) - require.Equal(t, "测试列", checker.colName) - require.Equal(t, "GBK测试用例", checker.expr) + if !reflect.DeepEqual("测试表", checker.tblName) { + t.Fatalf("got %v, want %v", checker.tblName, "测试表") + } + if !reflect.DeepEqual("测试列", checker.colName) { + t.Fatalf("got %v, want %v", checker.colName, "测试列") + } + if !reflect.DeepEqual("GBK测试用例", checker.expr) { + t.Fatalf("got %v, want %v", checker.expr, "GBK测试用例") + } _, _, err = p.ParseSQL("select _gbk '\xc6\x5c' from dual;") - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } for _, test := range []struct { sql string @@ -7746,9 +9320,13 @@ func TestGBKEncoding(t *testing.T) { } { _, _, err = p.ParseSQL(test.sql, gbkOpt) if test.err { - require.Error(t, err, test.sql) + if err == nil { + t.Fatal(test.sql) + } } else { - require.NoError(t, err, test.sql) + if err != nil { + t.Fatalf("%v: %v", test.sql, err) + } } } } @@ -7758,25 +9336,43 @@ func TestGB18030Encoding(t *testing.T) { gb18030Encoding, _ := charset.Lookup("gb18030") encoder := gb18030Encoding.NewEncoder() sql, err := encoder.String("create table 测试表 (测试列 varchar(255) default 'GB18030测试用例');") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } stmt, _, err := p.ParseSQL(sql) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } checker := &gbkEncodingChecker{} _, _ = stmt[0].Accept(checker) - require.NotEqual(t, "测试表", checker.tblName) - require.NotEqual(t, "测试列", checker.colName) + if reflect.DeepEqual("测试表", checker.tblName) { + t.Fatalf("expected values to differ, both are %v", checker.tblName) + } + if reflect.DeepEqual("测试列", checker.colName) { + t.Fatalf("expected values to differ, both are %v", checker.colName) + } gb18030Opt := parser.CharsetClient("gb18030") stmt, _, err = p.ParseSQL(sql, gb18030Opt) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } _, _ = stmt[0].Accept(checker) - require.Equal(t, "测试表", checker.tblName) - require.Equal(t, "测试列", checker.colName) - require.Equal(t, "GB18030测试用例", checker.expr) + if !reflect.DeepEqual("测试表", checker.tblName) { + t.Fatalf("got %v, want %v", checker.tblName, "测试表") + } + if !reflect.DeepEqual("测试列", checker.colName) { + t.Fatalf("got %v, want %v", checker.colName, "测试列") + } + if !reflect.DeepEqual("GB18030测试用例", checker.expr) { + t.Fatalf("got %v, want %v", checker.expr, "GB18030测试用例") + } _, _, err = p.ParseSQL("select _gbk '\xc6\x5c' from dual;") - require.Error(t, err) + if err == nil { + t.Fatal("expected error") + } for _, test := range []struct { sql string @@ -7794,9 +9390,13 @@ func TestGB18030Encoding(t *testing.T) { } { _, _, err = p.ParseSQL(test.sql, gb18030Opt) if test.err { - require.Error(t, err, test.sql) + if err == nil { + t.Fatal(test.sql) + } } else { - require.NoError(t, err, test.sql) + if err != nil { + t.Fatalf("%v: %v", test.sql, err) + } } } } @@ -7834,9 +9434,13 @@ func TestInsertStatementMemoryAllocation(t *testing.T) { var oldStats, newStats runtime.MemStats runtime.ReadMemStats(&oldStats) _, err := parser.New().ParseOneStmt(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } runtime.ReadMemStats(&newStats) - require.Less(t, int(newStats.TotalAlloc-oldStats.TotalAlloc), 1024*500) + if !(int(newStats.TotalAlloc-oldStats.TotalAlloc) < 1024*500) { + t.Fatalf("expected %v < %v", int(newStats.TotalAlloc-oldStats.TotalAlloc), 1024*500) + } } func TestCharsetIntroducer(t *testing.T) { @@ -7844,11 +9448,17 @@ func TestCharsetIntroducer(t *testing.T) { defer charset.RemoveCharset("gbk") // `_gbk` is treated as a character set. _, _, err := p.Parse("select _gbk 'a';", "", "") - require.EqualError(t, err, "[ddl:1115]Unsupported character introducer: 'gbk'") + if err == nil || err.Error() != "[ddl:1115]Unsupported character introducer: 'gbk'" { + t.Fatalf("expected error %q, got %v", "[ddl:1115]Unsupported character introducer: 'gbk'", err) + } _, _, err = p.Parse("select _gbk 0x1234;", "", "") - require.EqualError(t, err, "[ddl:1115]Unsupported character introducer: 'gbk'") + if err == nil || err.Error() != "[ddl:1115]Unsupported character introducer: 'gbk'" { + t.Fatalf("expected error %q, got %v", "[ddl:1115]Unsupported character introducer: 'gbk'", err) + } _, _, err = p.Parse("select _gbk 0b101001;", "", "") - require.EqualError(t, err, "[ddl:1115]Unsupported character introducer: 'gbk'") + if err == nil || err.Error() != "[ddl:1115]Unsupported character introducer: 'gbk'" { + t.Fatalf("expected error %q, got %v", "[ddl:1115]Unsupported character introducer: 'gbk'", err) + } } func TestNonTransactionalDML(t *testing.T) { @@ -7972,30 +9582,52 @@ func TestIssue45898(t *testing.T) { p := parser.New() p.ParseSQL("a.") stmts, _, err := p.ParseSQL("select count(1) from t") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } var sb strings.Builder restoreCtx := NewRestoreCtx(DefaultRestoreFlags, &sb) sb.Reset() stmts[0].Restore(restoreCtx) - require.Equal(t, "SELECT COUNT(1) FROM `t`", sb.String()) + if !reflect.DeepEqual("SELECT COUNT(1) FROM `t`", sb.String()) { + t.Fatalf("got %v, want %v", sb.String(), "SELECT COUNT(1) FROM `t`") + } } func TestMultiStmt(t *testing.T) { p := parser.New() stmts, _, err := p.Parse("SELECT 'foo'; SELECT 'foo;bar','baz'; select 'foo' , 'bar' , 'baz' ;select 1", "", "") - require.NoError(t, err) - require.Equal(t, len(stmts), 4) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(len(stmts), 4) { + t.Fatalf("got %v, want %v", 4, len(stmts)) + } stmt1 := stmts[0].(*ast.SelectStmt) stmt2 := stmts[1].(*ast.SelectStmt) stmt3 := stmts[2].(*ast.SelectStmt) stmt4 := stmts[3].(*ast.SelectStmt) - require.Equal(t, "'foo'", stmt1.Fields.Fields[0].Text()) - require.Equal(t, "'foo;bar'", stmt2.Fields.Fields[0].Text()) - require.Equal(t, "'baz'", stmt2.Fields.Fields[1].Text()) - require.Equal(t, "'foo'", stmt3.Fields.Fields[0].Text()) - require.Equal(t, "'bar'", stmt3.Fields.Fields[1].Text()) - require.Equal(t, "'baz'", stmt3.Fields.Fields[2].Text()) - require.Equal(t, "1", stmt4.Fields.Fields[0].Text()) + if !reflect.DeepEqual("'foo'", stmt1.Fields.Fields[0].Text()) { + t.Fatalf("got %v, want %v", stmt1.Fields.Fields[0].Text(), "'foo'") + } + if !reflect.DeepEqual("'foo;bar'", stmt2.Fields.Fields[0].Text()) { + t.Fatalf("got %v, want %v", stmt2.Fields.Fields[0].Text(), "'foo;bar'") + } + if !reflect.DeepEqual("'baz'", stmt2.Fields.Fields[1].Text()) { + t.Fatalf("got %v, want %v", stmt2.Fields.Fields[1].Text(), "'baz'") + } + if !reflect.DeepEqual("'foo'", stmt3.Fields.Fields[0].Text()) { + t.Fatalf("got %v, want %v", stmt3.Fields.Fields[0].Text(), "'foo'") + } + if !reflect.DeepEqual("'bar'", stmt3.Fields.Fields[1].Text()) { + t.Fatalf("got %v, want %v", stmt3.Fields.Fields[1].Text(), "'bar'") + } + if !reflect.DeepEqual("'baz'", stmt3.Fields.Fields[2].Text()) { + t.Fatalf("got %v, want %v", stmt3.Fields.Fields[2].Text(), "'baz'") + } + if !reflect.DeepEqual("1", stmt4.Fields.Fields[0].Text()) { + t.Fatalf("got %v, want %v", stmt4.Fields.Fields[0].Text(), "1") + } } // https://dev.mysql.com/doc/refman/8.1/en/other-vendor-data-types.html diff --git a/parser/reserved_words_test.go b/parser/reserved_words_test.go index 91091be..3377df1 100644 --- a/parser/reserved_words_test.go +++ b/parser/reserved_words_test.go @@ -33,15 +33,21 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/sqlc-dev/marino/ast" - requires "github.com/stretchr/testify/require" + + "fmt" + "regexp" ) func TestCompareReservedWordsWithMySQL(t *testing.T) { parserFilename := "parser.y" parserFile, err := os.Open(parserFilename) - requires.NoError(t, err) + if err != nil { + t.Fatal(err) + } data, err := gio.ReadAll(parserFile) - requires.NoError(t, err) + if err != nil { + t.Fatal(err) + } content := string(data) reservedKeywordStartMarker := "\t/* The following tokens belong to ReservedKeyword. Notice: make sure these tokens are contained in ReservedKeyword. */" @@ -57,9 +63,13 @@ func TestCompareReservedWordsWithMySQL(t *testing.T) { p := New() db, err := dbsql.Open("mysql", "root@tcp(127.0.0.1:3306)/") - requires.NoError(t, err) + if err != nil { + t.Fatal(err) + } defer func() { - requires.NoError(t, db.Close()) + if db.Close() != nil { + t.Fatal(db.Close()) + } }() for _, kw := range reservedKeywords { @@ -84,12 +94,20 @@ func TestCompareReservedWordsWithMySQL(t *testing.T) { if _, ok := windowFuncTokenMap[kw]; !ok { // for some reason the query does parse even then the keyword is reserved in TiDB. _, _, err = p.Parse(query, "", "") - requires.Error(t, err) - requires.Regexp(t, errRegexp, err.Error()) + if err == nil { + t.Fatal("expected error") + } + if !regexp.MustCompile(errRegexp).MatchString(err.Error()) { + t.Fatalf("expected %q to match %q", err.Error(), errRegexp) + } } _, err = db.Exec(query) - requires.Error(t, err, query) - requires.Regexp(t, errRegexp, err.Error(), "MySQL suggests that '%s' should *not* be reserved!", kw) + if err == nil { + t.Fatal(query) + } + if !regexp.MustCompile(errRegexp).MatchString(err.Error()) { + t.Fatalf("%s: expected %q to match %q", fmt.Sprintf("MySQL suggests that '%s' should *not* be reserved!", kw), err.Error(), errRegexp) + } } for _, kws := range [][]string{unreservedKeywords, notKeywordTokens, tidbKeywords} { @@ -106,12 +124,20 @@ func TestCompareReservedWordsWithMySQL(t *testing.T) { query := "do (select 1 as " + kw + ")" stmts, _, err := p.Parse(query, "", "") - requires.NoError(t, err) - requires.Len(t, stmts, 1) - requires.IsType(t, &ast.DoStmt{}, stmts[0]) + if err != nil { + t.Fatal(err) + } + if got := len(stmts); got != 1 { + t.Fatalf("expected length %d, got %d", 1, got) + } + if _, ok := stmts[0].(*ast.DoStmt); !ok { + t.Fatalf("expected type %T, got %T", &ast.DoStmt{}, stmts[0]) + } _, err = db.Exec(query) - requires.NoErrorf(t, err, "MySQL suggests that '%s' should be reserved!", kw) + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf("MySQL suggests that '%s' should be reserved!", kw), err) + } } } } diff --git a/terror/terror_test.go b/terror/terror_test.go index 29910c3..f7faea4 100644 --- a/terror/terror_test.go +++ b/terror/terror_test.go @@ -22,59 +22,110 @@ import ( "testing" "github.com/pingcap/errors" - "github.com/stretchr/testify/require" + + "reflect" ) func TestErrCode(t *testing.T) { - require.Equal(t, ErrCode(1), CodeMissConnectionID) - require.Equal(t, ErrCode(2), CodeResultUndetermined) + if !reflect.DeepEqual(ErrCode(1), CodeMissConnectionID) { + t.Fatalf("got %v, want %v", CodeMissConnectionID, ErrCode(1)) + } + if !reflect.DeepEqual(ErrCode(2), CodeResultUndetermined) { + t.Fatalf("got %v, want %v", CodeResultUndetermined, ErrCode(2)) + } } func TestTError(t *testing.T) { - require.NotEmpty(t, ClassParser.String()) - require.NotEmpty(t, ClassOptimizer.String()) - require.NotEmpty(t, ClassKV.String()) - require.NotEmpty(t, ClassServer.String()) + if len(ClassParser.String()) == 0 { + t.Fatal("expected non-empty") + } + if len(ClassOptimizer.String()) == 0 { + t.Fatal("expected non-empty") + } + if len(ClassKV.String()) == 0 { + t.Fatal("expected non-empty") + } + if len(ClassServer.String()) == 0 { + t.Fatal("expected non-empty") + } parserErr := ClassParser.New(ErrCode(100), "error 100") - require.NotEmpty(t, parserErr.Error()) - require.True(t, ClassParser.EqualClass(parserErr)) - require.False(t, ClassParser.NotEqualClass(parserErr)) + if len(parserErr.Error()) == 0 { + t.Fatal("expected non-empty") + } + if !(ClassParser.EqualClass(parserErr)) { + t.Fatal("expected true") + } + if ClassParser.NotEqualClass(parserErr) { + t.Fatal("expected false") + } - require.False(t, ClassOptimizer.EqualClass(parserErr)) + if ClassOptimizer.EqualClass(parserErr) { + t.Fatal("expected false") + } optimizerErr := ClassOptimizer.New(ErrCode(2), "abc") - require.False(t, ClassOptimizer.EqualClass(errors.New("abc"))) - require.False(t, ClassOptimizer.EqualClass(nil)) - require.True(t, optimizerErr.Equal(optimizerErr.GenWithStack("def"))) - require.False(t, optimizerErr.Equal(nil)) - require.False(t, optimizerErr.Equal(errors.New("abc"))) + if ClassOptimizer.EqualClass(errors.New("abc")) { + t.Fatal("expected false") + } + if ClassOptimizer.EqualClass(nil) { + t.Fatal("expected false") + } + if !(optimizerErr.Equal(optimizerErr.GenWithStack("def"))) { + t.Fatal("expected true") + } + if optimizerErr.Equal(nil) { + t.Fatal("expected false") + } + if optimizerErr.Equal(errors.New("abc")) { + t.Fatal("expected false") + } // Test case for FastGen. - require.True(t, optimizerErr.Equal(optimizerErr.FastGen("def"))) - require.True(t, optimizerErr.Equal(optimizerErr.FastGen("def: %s", "def"))) + if !(optimizerErr.Equal(optimizerErr.FastGen("def"))) { + t.Fatal("expected true") + } + if !(optimizerErr.Equal(optimizerErr.FastGen("def: %s", "def"))) { + t.Fatal("expected true") + } kvErr := ClassKV.New(1062, "key already exist") e := kvErr.FastGen("Duplicate entry '%d' for key 'PRIMARY'", 1) - require.Equal(t, "[kv:1062]Duplicate entry '1' for key 'PRIMARY'", e.Error()) + if !reflect.DeepEqual("[kv:1062]Duplicate entry '1' for key 'PRIMARY'", e.Error()) { + t.Fatalf("got %v, want %v", e.Error(), "[kv:1062]Duplicate entry '1' for key 'PRIMARY'") + } sqlErr := ToSQLError(errors.Cause(e).(*Error)) - require.Equal(t, "Duplicate entry '1' for key 'PRIMARY'", sqlErr.Message) - require.Equal(t, uint16(1062), sqlErr.Code) + if !reflect.DeepEqual("Duplicate entry '1' for key 'PRIMARY'", sqlErr.Message) { + t.Fatalf("got %v, want %v", sqlErr.Message, "Duplicate entry '1' for key 'PRIMARY'") + } + if !reflect.DeepEqual(uint16(1062), sqlErr.Code) { + t.Fatalf("got %v, want %v", sqlErr.Code, uint16(1062)) + } err := errors.Trace(ErrCritical.GenWithStackByArgs("test")) - require.True(t, ErrCritical.Equal(err)) + if !(ErrCritical.Equal(err)) { + t.Fatal("expected true") + } err = errors.Trace(ErrCritical) - require.True(t, ErrCritical.Equal(err)) + if !(ErrCritical.Equal(err)) { + t.Fatal("expected true") + } } func TestJson(t *testing.T) { prevTErr := errors.Normalize("json test", errors.MySQLErrorCode(int(CodeExecResultIsEmpty))) buf, err := json.Marshal(prevTErr) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } var curTErr errors.Error err = json.Unmarshal(buf, &curTErr) - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } isEqual := prevTErr.Equal(&curTErr) - require.True(t, isEqual) + if !(isEqual) { + t.Fatal("expected true") + } } var predefinedErr = ClassExecutor.New(ErrCode(123), "predefiend error") @@ -90,40 +141,72 @@ func call() error { func TestErrorEqual(t *testing.T) { e1 := errors.New("test error") - require.NotNil(t, e1) + if e1 == nil { + t.Fatal("expected non-nil") + } e2 := errors.Trace(e1) - require.NotNil(t, e2) + if e2 == nil { + t.Fatal("expected non-nil") + } e3 := errors.Trace(e2) - require.NotNil(t, e3) + if e3 == nil { + t.Fatal("expected non-nil") + } - require.Equal(t, e1, errors.Cause(e2)) - require.Equal(t, e1, errors.Cause(e3)) - require.Equal(t, errors.Cause(e3), errors.Cause(e2)) + if !reflect.DeepEqual(e1, errors.Cause(e2)) { + t.Fatalf("got %v, want %v", errors.Cause(e2), e1) + } + if !reflect.DeepEqual(e1, errors.Cause(e3)) { + t.Fatalf("got %v, want %v", errors.Cause(e3), e1) + } + if !reflect.DeepEqual(errors.Cause(e3), errors.Cause(e2)) { + t.Fatalf("got %v, want %v", errors.Cause(e2), errors.Cause(e3)) + } e4 := errors.New("test error") - require.NotEqual(t, e1, errors.Cause(e4)) + if reflect.DeepEqual(e1, errors.Cause(e4)) { + t.Fatalf("expected values to differ, both are %v", errors.Cause(e4)) + } e5 := errors.Errorf("test error") - require.NotEqual(t, e1, errors.Cause(e5)) + if reflect.DeepEqual(e1, errors.Cause(e5)) { + t.Fatalf("expected values to differ, both are %v", errors.Cause(e5)) + } - require.True(t, ErrorEqual(e1, e2)) - require.True(t, ErrorEqual(e1, e3)) - require.True(t, ErrorEqual(e1, e4)) - require.True(t, ErrorEqual(e1, e5)) + if !(ErrorEqual(e1, e2)) { + t.Fatal("expected true") + } + if !(ErrorEqual(e1, e3)) { + t.Fatal("expected true") + } + if !(ErrorEqual(e1, e4)) { + t.Fatal("expected true") + } + if !(ErrorEqual(e1, e5)) { + t.Fatal("expected true") + } var e6 error - require.True(t, ErrorEqual(nil, nil)) - require.True(t, ErrorNotEqual(e1, e6)) + if !(ErrorEqual(nil, nil)) { + t.Fatal("expected true") + } + if !(ErrorNotEqual(e1, e6)) { + t.Fatal("expected true") + } code1 := ErrCode(9001) code2 := ErrCode(9002) te1 := ClassParser.Synthesize(code1, "abc") te3 := ClassKV.New(code1, "abc") te4 := ClassKV.New(code2, "abc") - require.False(t, ErrorEqual(te1, te3)) - require.False(t, ErrorEqual(te3, te4)) + if ErrorEqual(te1, te3) { + t.Fatal("expected false") + } + if ErrorEqual(te3, te4) { + t.Fatal("expected false") + } } func TestLog(t *testing.T) { @@ -161,7 +244,9 @@ func TestTraceAndLocation(t *testing.T) { sysStack++ } } - require.Equalf(t, 9, len(lines)-(2*sysStack), "stack =\n%s", stack) + if !reflect.DeepEqual(9, len(lines)-(2*sysStack)) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("stack =\n%s", stack), len(lines)-(2*sysStack), 9) + } var containTerr bool for _, v := range lines { if strings.Contains(v, "terror_test.go") { @@ -169,5 +254,7 @@ func TestTraceAndLocation(t *testing.T) { break } } - require.True(t, containTerr) + if !(containTerr) { + t.Fatal("expected true") + } } diff --git a/types/etc_test.go b/types/etc_test.go index c6e9bc7..d1234e3 100644 --- a/types/etc_test.go +++ b/types/etc_test.go @@ -17,18 +17,25 @@ import ( "testing" "github.com/sqlc-dev/marino/mysql" - "github.com/stretchr/testify/require" + + "reflect" ) func TestStrToType(t *testing.T) { for tp, str := range type2Str { a := StrToType(str) - require.Equal(t, tp, a) + if !reflect.DeepEqual(tp, a) { + t.Fatalf("got %v, want %v", a, tp) + } } tp := StrToType("blob") - require.Equal(t, tp, mysql.TypeBlob) + if !reflect.DeepEqual(tp, mysql.TypeBlob) { + t.Fatalf("got %v, want %v", mysql.TypeBlob, tp) + } tp = StrToType("binary") - require.Equal(t, tp, mysql.TypeString) + if !reflect.DeepEqual(tp, mysql.TypeString) { + t.Fatalf("got %v, want %v", mysql.TypeString, tp) + } } diff --git a/types/field_type_test.go b/types/field_type_test.go index 62225c9..f5dd819 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -17,181 +17,301 @@ import ( "fmt" "testing" - "github.com/sqlc-dev/marino/parser" "github.com/sqlc-dev/marino/ast" "github.com/sqlc-dev/marino/charset" "github.com/sqlc-dev/marino/mysql" + "github.com/sqlc-dev/marino/parser" + // import parser_driver _ "github.com/sqlc-dev/marino/test_driver" . "github.com/sqlc-dev/marino/types" - "github.com/stretchr/testify/require" + + "reflect" ) func TestFieldType(t *testing.T) { ft := NewFieldType(mysql.TypeDuration) - require.Equal(t, UnspecifiedLength, ft.GetFlen()) - require.Equal(t, UnspecifiedLength, ft.GetDecimal()) + if !reflect.DeepEqual(UnspecifiedLength, ft.GetFlen()) { + t.Fatalf("got %v, want %v", ft.GetFlen(), UnspecifiedLength) + } + if !reflect.DeepEqual(UnspecifiedLength, ft.GetDecimal()) { + t.Fatalf("got %v, want %v", ft.GetDecimal(), UnspecifiedLength) + } ft.SetDecimal(5) - require.Equal(t, "time(5)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("time(5)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "time(5)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeLong) ft.SetFlen(5) ft.SetFlag(mysql.UnsignedFlag | mysql.ZerofillFlag) - require.Equal(t, "int(5) UNSIGNED ZEROFILL", ft.String()) - require.Equal(t, "int(5) unsigned", ft.InfoSchemaStr()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("int(5) UNSIGNED ZEROFILL", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "int(5) UNSIGNED ZEROFILL") + } + if !reflect.DeepEqual("int(5) unsigned", ft.InfoSchemaStr()) { + t.Fatalf("got %v, want %v", ft.InfoSchemaStr(), "int(5) unsigned") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeFloat) ft.SetFlen(12) // Default ft.SetDecimal(3) // Not Default - require.Equal(t, "float(12,3)", ft.String()) + if !reflect.DeepEqual("float(12,3)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "float(12,3)") + } ft = NewFieldType(mysql.TypeFloat) ft.SetFlen(12) // Default ft.SetDecimal(-1) // Default - require.Equal(t, "float", ft.String()) + if !reflect.DeepEqual("float", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "float") + } ft = NewFieldType(mysql.TypeFloat) ft.SetFlen(5) // Not Default ft.SetDecimal(-1) // Default - require.Equal(t, "float", ft.String()) + if !reflect.DeepEqual("float", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "float") + } ft = NewFieldType(mysql.TypeFloat) ft.SetFlen(7) // Not Default ft.SetDecimal(3) // Not Default - require.Equal(t, "float(7,3)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("float(7,3)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "float(7,3)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeDouble) ft.SetFlen(22) // Default ft.SetDecimal(3) // Not Default - require.Equal(t, "double(22,3)", ft.String()) + if !reflect.DeepEqual("double(22,3)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "double(22,3)") + } ft = NewFieldType(mysql.TypeDouble) ft.SetFlen(22) // Default ft.SetDecimal(-1) // Default - require.Equal(t, "double", ft.String()) + if !reflect.DeepEqual("double", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "double") + } ft = NewFieldType(mysql.TypeDouble) ft.SetFlen(5) // Not Default ft.SetDecimal(-1) // Default - require.Equal(t, "double", ft.String()) + if !reflect.DeepEqual("double", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "double") + } ft = NewFieldType(mysql.TypeDouble) ft.SetFlen(7) // Not Default ft.SetDecimal(3) // Not Default - require.Equal(t, "double(7,3)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("double(7,3)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "double(7,3)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeBlob) ft.SetFlen(10) ft.SetCharset("UTF8") ft.SetCollate("UTF8_UNICODE_GI") - require.Equal(t, "text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "text CHARACTER SET UTF8 COLLATE UTF8_UNICODE_GI") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeVarchar) ft.SetFlen(10) ft.AddFlag(mysql.BinaryFlag) - require.Equal(t, "varchar(10) BINARY", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("varchar(10) BINARY", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "varchar(10) BINARY") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeString) ft.SetCharset(charset.CharsetBin) ft.AddFlag(mysql.BinaryFlag) - require.Equal(t, "binary(1)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("binary(1)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "binary(1)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeEnum) ft.SetElems([]string{"a", "b"}) - require.Equal(t, "enum('a','b')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("enum('a','b')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "enum('a','b')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeEnum) ft.SetElems([]string{"'a'", "'b'"}) - require.Equal(t, "enum('''a''','''b''')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("enum('''a''','''b''')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "enum('''a''','''b''')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeEnum) ft.SetElems([]string{"a\nb", "a\tb", "a\rb"}) - require.Equal(t, "enum('a\\nb','a\tb','a\\rb')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("enum('a\\nb','a\tb','a\\rb')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "enum('a\\nb','a\tb','a\\rb')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeEnum) ft.SetElems([]string{"a\nb", "a'\t\r\nb", "a\rb"}) - require.Equal(t, "enum('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("enum('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "enum('a\\nb','a'' \\r\\nb','a\\rb')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeSet) ft.SetElems([]string{"a", "b"}) - require.Equal(t, "set('a','b')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("set('a','b')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "set('a','b')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeSet) ft.SetElems([]string{"'a'", "'b'"}) - require.Equal(t, "set('''a''','''b''')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("set('''a''','''b''')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "set('''a''','''b''')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeSet) ft.SetElems([]string{"a\nb", "a'\t\r\nb", "a\rb"}) - require.Equal(t, "set('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("set('a\\nb','a'' \\r\\nb','a\\rb')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "set('a\\nb','a'' \\r\\nb','a\\rb')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeSet) ft.SetElems([]string{"a'\nb", "a'b\tc"}) - require.Equal(t, "set('a''\\nb','a''b c')", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("set('a''\\nb','a''b c')", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "set('a''\\nb','a''b c')") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeTimestamp) ft.SetFlen(8) ft.SetDecimal(2) - require.Equal(t, "timestamp(2)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("timestamp(2)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "timestamp(2)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeTimestamp) ft.SetFlen(8) ft.SetDecimal(0) - require.Equal(t, "timestamp", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("timestamp", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "timestamp") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeDatetime) ft.SetFlen(8) ft.SetDecimal(2) - require.Equal(t, "datetime(2)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("datetime(2)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "datetime(2)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeDatetime) ft.SetFlen(8) ft.SetDecimal(0) - require.Equal(t, "datetime", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("datetime", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "datetime") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeDate) ft.SetFlen(8) ft.SetDecimal(2) - require.Equal(t, "date", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("date", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "date") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeDate) ft.SetFlen(8) ft.SetDecimal(0) - require.Equal(t, "date", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("date", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "date") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeYear) ft.SetFlen(4) ft.SetDecimal(0) - require.Equal(t, "year(4)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("year(4)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "year(4)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeYear) ft.SetFlen(2) ft.SetDecimal(2) - require.Equal(t, "year(2)", ft.String()) - require.False(t, HasCharset(ft)) + if !reflect.DeepEqual("year(2)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "year(2)") + } + if HasCharset(ft) { + t.Fatal("expected false") + } ft = NewFieldType(mysql.TypeVarchar) ft.SetFlen(0) ft.SetDecimal(0) - require.Equal(t, "varchar(0)", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("varchar(0)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "varchar(0)") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } ft = NewFieldType(mysql.TypeString) ft.SetFlen(0) ft.SetDecimal(0) - require.Equal(t, "char(0)", ft.String()) - require.True(t, HasCharset(ft)) + if !reflect.DeepEqual("char(0)", ft.String()) { + t.Fatalf("got %v, want %v", ft.String(), "char(0)") + } + if !(HasCharset(ft)) { + t.Fatal("expected true") + } } func TestHasCharsetFromStmt(t *testing.T) { @@ -235,10 +355,14 @@ func TestHasCharsetFromStmt(t *testing.T) { for _, typ := range types { sql := fmt.Sprintf(template, typ.strType) stmt, err := p.ParseOneStmt(sql, "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } col := stmt.(*ast.CreateTableStmt).Cols[0] - require.Equal(t, typ.hasCharset, HasCharset(col.Tp)) + if !reflect.DeepEqual(typ.hasCharset, HasCharset(col.Tp)) { + t.Fatalf("got %v, want %v", HasCharset(col.Tp), typ.hasCharset) + } } } @@ -267,9 +391,13 @@ func TestEnumSetFlen(t *testing.T) { for _, ca := range cases { stmt, err := p.ParseOneStmt(fmt.Sprintf("create table t (e %v)", ca.sql), "", "") - require.NoError(t, err) + if err != nil { + t.Fatal(err) + } col := stmt.(*ast.CreateTableStmt).Cols[0] - require.Equal(t, ca.ex, col.Tp.GetFlen()) + if !reflect.DeepEqual(ca.ex, col.Tp.GetFlen()) { + t.Fatalf("got %v, want %v", col.Tp.GetFlen(), ca.ex) + } } } @@ -277,27 +405,37 @@ func TestFieldTypeEqual(t *testing.T) { // tp not equal ft1 := NewFieldType(mysql.TypeDouble) ft2 := NewFieldType(mysql.TypeFloat) - require.Equal(t, false, ft1.Equal(ft2)) + if !reflect.DeepEqual(false, ft1.Equal(ft2)) { + t.Fatalf("got %v, want %v", ft1.Equal(ft2), false) + } // decimal not equal ft2 = NewFieldType(mysql.TypeDouble) ft2.SetDecimal(5) - require.Equal(t, false, ft1.Equal(ft2)) + if !reflect.DeepEqual(false, ft1.Equal(ft2)) { + t.Fatalf("got %v, want %v", ft1.Equal(ft2), false) + } // flen not equal and decimal not -1 ft1.SetDecimal(5) ft1.SetFlen(22) - require.Equal(t, false, ft1.Equal(ft2)) + if !reflect.DeepEqual(false, ft1.Equal(ft2)) { + t.Fatalf("got %v, want %v", ft1.Equal(ft2), false) + } // flen equal ft2.SetFlen(22) - require.Equal(t, true, ft1.Equal(ft2)) + if !reflect.DeepEqual(true, ft1.Equal(ft2)) { + t.Fatalf("got %v, want %v", ft1.Equal(ft2), true) + } // decimal is -1 ft1.SetDecimal(-1) ft2.SetDecimal(-1) ft1.SetFlen(23) - require.Equal(t, true, ft1.Equal(ft2)) + if !reflect.DeepEqual(true, ft1.Equal(ft2)) { + t.Fatalf("got %v, want %v", ft1.Equal(ft2), true) + } } func TestCompactStr(t *testing.T) { @@ -323,9 +461,13 @@ func TestCompactStr(t *testing.T) { ft.SetFlag(cc.flags) TiDBStrictIntegerDisplayWidth = false - require.Equal(t, cc.e1, ft.CompactStr()) + if !reflect.DeepEqual(cc.e1, ft.CompactStr()) { + t.Fatalf("got %v, want %v", ft.CompactStr(), cc.e1) + } TiDBStrictIntegerDisplayWidth = true - require.Equal(t, cc.e2, ft.CompactStr()) + if !reflect.DeepEqual(cc.e2, ft.CompactStr()) { + t.Fatalf("got %v, want %v", ft.CompactStr(), cc.e2) + } } } diff --git a/util/escape_test.go b/util/escape_test.go index 599c4d0..be33942 100644 --- a/util/escape_test.go +++ b/util/escape_test.go @@ -14,9 +14,10 @@ package util import ( + "fmt" "testing" - "github.com/stretchr/testify/require" + "reflect" ) func TestUnescapeChar(t *testing.T) { @@ -49,6 +50,8 @@ func TestUnescapeChar(t *testing.T) { } for _, tt := range tests { got := UnescapeChar(tt.input) - require.Equal(t, tt.want, got, "UnescapeChar(%q)", tt.input) + if !reflect.DeepEqual(tt.want, got) { + t.Fatalf("%s: got %v, want %v", fmt.Sprintf("UnescapeChar(%q)", tt.input), got, tt.want) + } } }