From b77c4a2d4503bae5e8568efb403bee9fad30cdd2 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Tue, 6 Feb 2024 03:24:48 +0800 Subject: [PATCH 1/9] ydb.Class: query - todo --- go.mod | 2 +- go.sum | 4 +- ydb/class.go | 120 +++++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 9f46690..0e57796 100644 --- a/go.mod +++ b/go.mod @@ -5,5 +5,5 @@ go 1.18 require ( github.com/golang-jwt/jwt/v5 v5.2.0 github.com/goplus/gop v1.2.0 - github.com/qiniu/x v1.13.3 + github.com/qiniu/x v1.13.4-0.20240205192036-55db357e2bdf ) diff --git a/go.sum b/go.sum index 3c24008..abb4fc8 100644 --- a/go.sum +++ b/go.sum @@ -2,5 +2,5 @@ github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1 github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/goplus/gop v1.2.0 h1:1EOUKhr4OitJs0BtVBVVbejuUkLbXMiFparS1VW7Fhg= github.com/goplus/gop v1.2.0/go.mod h1:F4xOWRMTPCKzNBaF1gZC/JsEKtYW1+Ldwp7tyFoaOWo= -github.com/qiniu/x v1.13.3 h1:NER9aJnVzjH0XapzIWrWNAn2SPwck0xGMyIIlfCMm84= -github.com/qiniu/x v1.13.3/go.mod h1:INZ2TSWSJVWO/RuELQROERcslBwVgFG7MkTfEdaQz9E= +github.com/qiniu/x v1.13.4-0.20240205192036-55db357e2bdf h1:FCmAN2cuVQWqWd5Mkko0W6Tuc5q2DT+tuz1Ln1zvd7M= +github.com/qiniu/x v1.13.4-0.20240205192036-55db357e2bdf/go.mod h1:INZ2TSWSJVWO/RuELQROERcslBwVgFG7MkTfEdaQz9E= diff --git a/ydb/class.go b/ydb/class.go index 1f2f16e..825c1b6 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -23,8 +23,10 @@ import ( "log" "reflect" "strings" + "unicode/utf8" "github.com/goplus/gop/ast" + "github.com/qiniu/x/ctype" ) var ( @@ -77,8 +79,8 @@ func (p *Class) Use(table string, src ...ast.Node) { // Ret checks a query or call result. // // For checking query result: -// - ret , &, , &, ... -// - ret , &, , &, ... +// - ret , &, , &, ... +// - ret , &, , &, ... // - ret & // - ret & // @@ -104,9 +106,9 @@ func (p *Class) Ret__1(args ...any) { // - insert // - insert func (p *Class) Insert__0(src ast.Node, args ...any) { - /* if p.tbl == "" { - TODO: - } */ + if p.tbl == "" { + log.Panicln("please call `use ` to specified current table") + } nArg := len(args) if nArg == 1 { p.insertStruc(args[0]) @@ -307,8 +309,8 @@ type query struct { } // For checking query result: -// - ret , &, , &, ... -// - ret , &, , &, ... +// - ret , &, , &, ... +// - ret , &, , &, ... // - ret & // - ret & func (p *Class) queryRet(args ...any) { @@ -329,22 +331,116 @@ func (p *Class) queryRetPtr(arg any) { } // For checking query result: -// - ret , &, , &, ... -// - ret , &, , &, ... +// - ret , &, , &, ... +// - ret , &, , &, ... func (p *Class) queryRetKvPair(kvPair ...any) { nPair := len(kvPair) if nPair < 2 || nPair&1 != 0 { - log.Panicln("usage: ret , &, , &, ...") + log.Panicln("usage: ret , &, , &, ...") } + + q := p.query + tbl := p.exprTblname(q.cond) + n := nPair >> 1 - names := make([]string, n) + exprs := make([]string, n) rets := make([]any, n) for i := 0; i < nPair; i += 2 { - names[i>>1] = kvPair[i].(string) + expr := kvPair[i].(string) + if etbl := p.exprTblname(expr); etbl != tbl { + log.Panicf( + "query currently doesn't support multiple tables: `query` use `%s` but `ret` use `%s`\n", + tbl, etbl, + ) + } + exprs[i>>1] = expr rets[i>>1] = kvPair[i+1] } } +func (p *Class) exprTblname(cond string) string { + tbls := exprTblnames(cond) + tbl := "" + switch len(tbls) { + case 0: + case 1: + tbl = tbls[0] + default: + log.Panicln("query currently doesn't support multiple tables") + } + if tbl == "" { + tbl = p.tbl + } + return tbl +} + +func exprTblnames(expr string) (tbls []string) { + for expr != "" { + pos := ctype.ScanCSymbol(expr) + if pos != 0 { + name := "" + if pos > 0 { + switch expr[pos] { + case '.': + name = expr[:pos] + expr = ctype.SkipCSymbol(expr[pos+1:]) + case '(': // function call, eg. SUM(...) + expr = expr[pos+1:] + continue + default: + expr = expr[pos:] + } + } else { + expr = "" + } + switch name { + case "AND", "OR": + default: + tbls = addTblname(tbls, name) + } + continue + } + pos = ctype.ScanTypeEx(ctype.FLOAT_FIRST_CHAT, ctype.CSYMBOL_NEXT_CHAR, expr) + if pos == 0 { + c, size := utf8.DecodeRuneInString(expr) + switch c { + case '\'': + expr = skipStringConst(expr[1:], '\'') + default: + expr = expr[size:] + } + } else if pos < 0 { + break + } else { + expr = expr[pos:] + } + } + return +} + +func skipStringConst(next string, quot rune) string { + skip := false + for i, c := range next { + if skip { + skip = false + } else if c == '\\' { + skip = true + } else if c == quot { + return next[i+1:] + } + } + return "" +} + +func addTblname(tbls []string, tbl string) []string { + for _, v := range tbls { + if v == tbl { + return tbls + } + } + return append(tbls, tbl) +} + // Query creates a new query. // - query , , , ... func (p *Class) Query__0(src ast.Node, cond string, args ...any) { From dd9408d2336c0a2880b87b90c89cd61474a0c8b0 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Tue, 6 Feb 2024 17:20:28 +0800 Subject: [PATCH 2/9] queryRetKvPair, sqlQuery, sqlRetRow --- ydb/class.go | 101 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/ydb/class.go b/ydb/class.go index 825c1b6..446f59f 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -22,6 +22,7 @@ import ( "errors" "log" "reflect" + "strconv" "strings" "unicode/utf8" @@ -30,6 +31,7 @@ import ( ) var ( + ErrNoRows = sql.ErrNoRows ErrDuplicated = errors.New("duplicated") ) @@ -172,8 +174,9 @@ func (p *Class) insertStrucRow(vArg reflect.Value) { } const ( - valFlagNormal = 1 - valFlagSlice = 2 + valFlagNormal = 1 + valFlagSlice = 2 + valFlagInvalid = valFlagNormal | valFlagSlice ) // Insert inserts a new row. @@ -330,6 +333,77 @@ func (p *Class) queryRet(args ...any) { func (p *Class) queryRetPtr(arg any) { } +func isSlice(v any) bool { + return reflect.ValueOf(v).Kind() == reflect.Slice +} + +func retKind(ret any) int { + v := reflect.ValueOf(ret) + if v.Kind() != reflect.Pointer { + log.Panicln("usage: ret , &, , &, ...") + } + if v.Elem().Kind() == reflect.Slice { + return valFlagSlice + } + return valFlagNormal +} + +func sqlRetRow(rows *sql.Rows, rets []any) { + if !rows.Next() { + err := rows.Err() + if err == nil { + err = ErrNoRows + } + log.Panicln("ret:", err) + } + err := rows.Scan(rets...) + if err != nil { + log.Panicln("ret:", err) + } +} + +func sqlRetRows(rows *sql.Rows, rets []any) { +} + +// sqlQuery NOTE: +// - one of args maybe is a slice +func sqlQuery(db *sql.DB, ctx context.Context, query string, args, rets []any, retSlice bool) { + iArgSlice := -1 + for i, arg := range args { + if isSlice(arg) { + if iArgSlice >= 0 { + log.Panicf( + "query: multiple arguments (%dth, %dth) are slices (only one can be)\n", + iArgSlice+1, i+1, + ) + } + iArgSlice = i + } + } + if iArgSlice >= 0 { + if !retSlice { + log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not") + } + sqlMultiQuery(db, ctx, query, iArgSlice, args, rets) + return + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + log.Panicln("query:", err) + } + defer rows.Close() + + if retSlice { + sqlRetRows(rows, rets) + return + } + sqlRetRow(rows, rets) +} + +func sqlMultiQuery(db *sql.DB, ctx context.Context, query string, iArgSlice int, args, rets []any) { +} + // For checking query result: // - ret , &, , &, ... // - ret , &, , &, ... @@ -345,6 +419,7 @@ func (p *Class) queryRetKvPair(kvPair ...any) { n := nPair >> 1 exprs := make([]string, n) rets := make([]any, n) + kind := 0 for i := 0; i < nPair; i += 2 { expr := kvPair[i].(string) if etbl := p.exprTblname(expr); etbl != tbl { @@ -353,9 +428,29 @@ func (p *Class) queryRetKvPair(kvPair ...any) { tbl, etbl, ) } + ret := kvPair[i+1] + kind |= retKind(ret) exprs[i>>1] = expr - rets[i>>1] = kvPair[i+1] + rets[i>>1] = ret + } + if kind == valFlagInvalid { + log.Panicln(`all ret arguments should be address of slices or address of normal variable: + ret , &, , &, ... + ret , &, , &, ...`) + } + + query := make([]byte, 0, 128) + query = append(query, "SELECT "...) + query = append(query, strings.Join(exprs, ",")...) + query = append(query, " FROM "...) + query = append(query, tbl...) + query = append(query, " WHERE "...) + query = append(query, q.cond...) + if q.limit > 0 { + query = append(query, " LIMIT "...) + query = append(query, strconv.Itoa(q.limit)...) } + sqlQuery(p.db, context.TODO(), string(query), q.args, rets, kind == valFlagSlice) } func (p *Class) exprTblname(cond string) string { From adf79884dcc2abb9bbf0569ca1ab57e12d067bfb Mon Sep 17 00:00:00 2001 From: xushiwei Date: Tue, 6 Feb 2024 17:50:30 +0800 Subject: [PATCH 3/9] sqlRetRows --- ydb/class.go | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/ydb/class.go b/ydb/class.go index 446f59f..1ea4f69 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -362,7 +362,27 @@ func sqlRetRow(rows *sql.Rows, rets []any) { } } -func sqlRetRows(rows *sql.Rows, rets []any) { +func sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bool) { + for rows.Next() { + if needInit { + for _, ret := range oneRet { + reflect.ValueOf(ret).Elem().SetZero() + } + } else { + needInit = true + } + err := rows.Scan(oneRet...) + if err != nil { + log.Panicln("ret:", err) + } + for i, vRet := range vRets { + v := reflect.ValueOf(oneRet[i]) + vRet.Set(reflect.Append(vRet, v.Elem())) + } + } + if err := rows.Err(); err != nil { + log.Panicln("ret:", err) + } } // sqlQuery NOTE: @@ -395,12 +415,27 @@ func sqlQuery(db *sql.DB, ctx context.Context, query string, args, rets []any, r defer rows.Close() if retSlice { - sqlRetRows(rows, rets) + vRets, oneRet := makeSliceRets(rets) + sqlRetRows(rows, vRets, oneRet, false) return } sqlRetRow(rows, rets) } +func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { + vRets = make([]reflect.Value, len(rets)) + oneRet = make([]any, len(rets)) + for i, ret := range rets { + slice := reflect.ValueOf(ret).Elem() + slice.SetZero() + vRets[i] = slice + + elem := slice.Type().Elem() + oneRet[i] = reflect.New(elem).Interface() + } + return +} + func sqlMultiQuery(db *sql.DB, ctx context.Context, query string, iArgSlice int, args, rets []any) { } From 790b47932afd60fb8794e6ff82134481ba509d75 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Tue, 6 Feb 2024 18:33:44 +0800 Subject: [PATCH 4/9] onErr: error processing of a sql execution --- ydb/class.go | 108 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 73 insertions(+), 35 deletions(-) diff --git a/ydb/class.go b/ydb/class.go index 1ea4f69..723e5bd 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -50,7 +50,8 @@ type Class struct { api *api result []reflect.Value // result of an api call - ret func(args ...any) + ret func(args ...any) + onErr func(err error) } func newClass(name string, sql *Sql) *Class { @@ -78,6 +79,18 @@ func (p *Class) Use(table string, src ...ast.Node) { p.tobj = tblobj } +// OnErr sets error processing of a sql execution. +func (p *Class) OnErr(onErr func(error), src ...ast.Node) { + p.onErr = onErr +} + +func (p *Class) handleErr(prompt string, err error) { + if p.onErr == nil { + log.Panicln(prompt, err) + } + p.onErr(err) +} + // Ret checks a query or call result. // // For checking query result: @@ -107,38 +120,37 @@ func (p *Class) Ret__1(args ...any) { // - insert , , , , ... // - insert // - insert -func (p *Class) Insert__0(src ast.Node, args ...any) { +func (p *Class) Insert__0(src ast.Node, args ...any) (sql.Result, error) { if p.tbl == "" { log.Panicln("please call `use ` to specified current table") } nArg := len(args) if nArg == 1 { - p.insertStruc(args[0]) - } else { - p.insertKvPair(args...) + return p.insertStruc(args[0]) } + return p.insertKvPair(args...) } // Insert inserts a new row. // - insert // - insert -func (p *Class) insertStruc(arg any) { +func (p *Class) insertStruc(arg any) (sql.Result, error) { vArg := reflect.ValueOf(arg) switch vArg.Kind() { case reflect.Slice: - p.insertStrucRows(vArg) + return p.insertStrucRows(vArg) case reflect.Pointer: vArg = vArg.Elem() fallthrough default: - p.insertStrucRow(vArg) + return p.insertStrucRow(vArg) } } -func (p *Class) insertStrucRows(vSlice reflect.Value) { +func (p *Class) insertStrucRows(vSlice reflect.Value) (sql.Result, error) { rows := vSlice.Len() if rows == 0 { - return + return nil, nil } hasPtr := false elem := vSlice.Type().Elem() @@ -160,17 +172,17 @@ func (p *Class) insertStrucRows(vSlice reflect.Value) { } vals = getVals(vals, vElem, cols) } - p.insertRowsVals(names, vals, rows) + return p.insertRowsVals(names, vals, rows) } -func (p *Class) insertStrucRow(vArg reflect.Value) { +func (p *Class) insertStrucRow(vArg reflect.Value) (sql.Result, error) { if vArg.Kind() != reflect.Struct { log.Panicln("usage: insert ") } n := vArg.NumField() names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0) vals := getVals(make([]any, 0, len(cols)), vArg, cols) - p.insertRow(names, vals) + return p.insertRow(names, vals) } const ( @@ -182,7 +194,7 @@ const ( // Insert inserts a new row. // - insert , , , , ... // - insert , , , , ... -func (p *Class) insertKvPair(kvPair ...any) { +func (p *Class) insertKvPair(kvPair ...any) (sql.Result, error) { nPair := len(kvPair) if nPair < 2 || nPair&1 != 0 { log.Panicln("usage: insert , , , , ...") @@ -212,16 +224,17 @@ func (p *Class) insertKvPair(kvPair ...any) { } switch kind { case valFlagNormal: - p.insertRow(names, vals) + return p.insertRow(names, vals) case valFlagSlice: - p.insertSliceRows(names, vals, rows) + return p.insertSliceRows(names, vals, rows) default: log.Panicln("can't insert mix slice and normal value") } + return nil, nil } // NOTE: len(args) == len(names) -func (p *Class) insertSliceRows(names []string, args []any, rows int) { +func (p *Class) insertSliceRows(names []string, args []any, rows int) (sql.Result, error) { vals := make([]any, 0, len(names)*rows) for i := 0; i < rows; i++ { for _, arg := range args { @@ -229,33 +242,34 @@ func (p *Class) insertSliceRows(names []string, args []any, rows int) { vals = append(vals, v.Index(i).Interface()) } } - p.insertRowsVals(names, vals, rows) + return p.insertRowsVals(names, vals, rows) } // NOTE: len(vals) == len(names) * rows -func (p *Class) insertRowsVals(names []string, vals []any, rows int) { +func (p *Class) insertRowsVals(names []string, vals []any, rows int) (sql.Result, error) { n := len(names) query := insertQuery(p.tbl, names) query = append(query, valParams(n, rows)...) result, err := p.db.ExecContext(context.TODO(), string(query), vals...) - insertRet(result, err) + return p.insertRet(result, err) } -func (p *Class) insertRow(names []string, vals []any) { +func (p *Class) insertRow(names []string, vals []any) (sql.Result, error) { if len(names) == 0 { log.Panicln("insert: nothing to insert") } query := insertQuery(p.tbl, names) query = append(query, valParam(len(vals))...) result, err := p.db.ExecContext(context.TODO(), string(query), vals...) - insertRet(result, err) + return p.insertRet(result, err) } -func insertRet(result sql.Result, err error) { +func (p *Class) insertRet(result sql.Result, err error) (sql.Result, error) { if err != nil { - log.Panicln("insert:", err) + p.handleErr("insert:", err) } + return result, err } func insertQuery(tbl string, names []string) []byte { @@ -282,24 +296,24 @@ func valParam(n int) string { } // Insert inserts a new row. -func (p *Class) Insert__1(kvPair ...any) { - p.Insert__0(nil, kvPair...) +func (p *Class) Insert__1(kvPair ...any) (sql.Result, error) { + return p.Insert__0(nil, kvPair...) } // Count returns rows of a query result. -func (p *Class) Count__0(src ast.Node, cond string, args ...any) (n int) { +func (p *Class) Count__0(src ast.Node, cond string, args ...any) (n int, err error) { if p.tbl == "" { log.Panicln("please call `use ` to specified a table name") } row := p.db.QueryRowContext(context.TODO(), "SELECT COUNT(*) FROM "+p.tbl+" WHERE "+cond, args...) - if err := row.Scan(&n); err != nil { - log.Panicln("count:", err) + if err = row.Scan(&n); err != nil { + p.handleErr("count:", err) } return } // Count returns rows of a query result. -func (p *Class) Count__1(cond string, args ...any) (n int) { +func (p *Class) Count__1(cond string, args ...any) (n int, err error) { return p.Count__0(nil, cond, args...) } @@ -427,7 +441,6 @@ func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { oneRet = make([]any, len(rets)) for i, ret := range rets { slice := reflect.ValueOf(ret).Elem() - slice.SetZero() vRets[i] = slice elem := slice.Type().Elem() @@ -436,7 +449,28 @@ func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { return } +func sqlQueryOne(db *sql.DB, ctx context.Context, query string, args, oneRet []any, vRets []reflect.Value) { + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + log.Panicln("query:", err) + } + defer rows.Close() + + sqlRetRows(rows, vRets, oneRet, true) +} + func sqlMultiQuery(db *sql.DB, ctx context.Context, query string, iArgSlice int, args, rets []any) { + argSlice := args[iArgSlice] + defer func() { + args[iArgSlice] = argSlice + }() + vRets, oneRet := makeSliceRets(rets) + vArgSlice := reflect.ValueOf(argSlice) + for i, n := 0, vArgSlice.Len(); i < n; i++ { + arg := vArgSlice.Index(i).Interface() + args[iArgSlice] = arg + sqlQueryOne(db, ctx, query, args, oneRet, vRets) + } } // For checking query result: @@ -594,16 +628,20 @@ func (p *Class) Limit__0(n int, src ...ast.Node) { } // Limit checks if query result rows is < n or not. -func (p *Class) Limit__1(src ast.Node, n int, cond string, args ...any) { - ret := p.Count__0(src, cond, args...) +func (p *Class) Limit__1(src ast.Node, n int, cond string, args ...any) error { + ret, err := p.Count__0(src, cond, args...) + if err != nil { + return err + } if ret >= n { log.Panicf("limit %s: got %d, expected <%d\n", cond, ret, n) } + return nil } // Limit checks if query result rows is < n or not. -func (p *Class) Limit__2(n int, cond string, args ...any) { - p.Limit__1(nil, n, cond, args...) +func (p *Class) Limit__2(n int, cond string, args ...any) error { + return p.Limit__1(nil, n, cond, args...) } // ----------------------------------------------------------------------------- From 0da0d95ddf548b2d6bc3b4bf5127511a9bb541da Mon Sep 17 00:00:00 2001 From: xushiwei Date: Tue, 6 Feb 2024 19:32:05 +0800 Subject: [PATCH 5/9] onErr: error processing of a sql execution --- ydb/class.go | 63 +++++++++++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/ydb/class.go b/ydb/class.go index 723e5bd..31ca35c 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -33,6 +33,7 @@ import ( var ( ErrNoRows = sql.ErrNoRows ErrDuplicated = errors.New("duplicated") + ErrOutOfLimit = errors.New("out of limit") ) // ----------------------------------------------------------------------------- @@ -362,21 +363,23 @@ func retKind(ret any) int { return valFlagNormal } -func sqlRetRow(rows *sql.Rows, rets []any) { +func (p *Class) sqlRetRow(rows *sql.Rows, rets []any) error { if !rows.Next() { err := rows.Err() if err == nil { err = ErrNoRows } - log.Panicln("ret:", err) + p.handleErr("ret:", err) + return err } err := rows.Scan(rets...) if err != nil { - log.Panicln("ret:", err) + p.handleErr("ret:", err) } + return err } -func sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bool) { +func (p *Class) sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bool) error { for rows.Next() { if needInit { for _, ret := range oneRet { @@ -387,21 +390,24 @@ func sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, needInit bo } err := rows.Scan(oneRet...) if err != nil { - log.Panicln("ret:", err) + p.handleErr("ret:", err) + return err } for i, vRet := range vRets { v := reflect.ValueOf(oneRet[i]) vRet.Set(reflect.Append(vRet, v.Elem())) } } - if err := rows.Err(); err != nil { - log.Panicln("ret:", err) + err := rows.Err() + if err != nil { + p.handleErr("ret:", err) } + return err } // sqlQuery NOTE: // - one of args maybe is a slice -func sqlQuery(db *sql.DB, ctx context.Context, query string, args, rets []any, retSlice bool) { +func (p *Class) sqlQuery(ctx context.Context, query string, args, rets []any, retSlice bool) error { iArgSlice := -1 for i, arg := range args { if isSlice(arg) { @@ -418,22 +424,21 @@ func sqlQuery(db *sql.DB, ctx context.Context, query string, args, rets []any, r if !retSlice { log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not") } - sqlMultiQuery(db, ctx, query, iArgSlice, args, rets) - return + return p.sqlMultiQuery(ctx, query, iArgSlice, args, rets) } - rows, err := db.QueryContext(ctx, query, args...) + rows, err := p.db.QueryContext(ctx, query, args...) if err != nil { - log.Panicln("query:", err) + p.handleErr("query:", err) + return err } defer rows.Close() if retSlice { vRets, oneRet := makeSliceRets(rets) - sqlRetRows(rows, vRets, oneRet, false) - return + return p.sqlRetRows(rows, vRets, oneRet, false) } - sqlRetRow(rows, rets) + return p.sqlRetRow(rows, rets) } func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { @@ -449,17 +454,18 @@ func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { return } -func sqlQueryOne(db *sql.DB, ctx context.Context, query string, args, oneRet []any, vRets []reflect.Value) { - rows, err := db.QueryContext(ctx, query, args...) +func (p *Class) sqlQueryOne(ctx context.Context, query string, args, oneRet []any, vRets []reflect.Value) error { + rows, err := p.db.QueryContext(ctx, query, args...) if err != nil { - log.Panicln("query:", err) + p.handleErr("query:", err) + return err } defer rows.Close() - sqlRetRows(rows, vRets, oneRet, true) + return p.sqlRetRows(rows, vRets, oneRet, true) } -func sqlMultiQuery(db *sql.DB, ctx context.Context, query string, iArgSlice int, args, rets []any) { +func (p *Class) sqlMultiQuery(ctx context.Context, query string, iArgSlice int, args, rets []any) error { argSlice := args[iArgSlice] defer func() { args[iArgSlice] = argSlice @@ -469,14 +475,17 @@ func sqlMultiQuery(db *sql.DB, ctx context.Context, query string, iArgSlice int, for i, n := 0, vArgSlice.Len(); i < n; i++ { arg := vArgSlice.Index(i).Interface() args[iArgSlice] = arg - sqlQueryOne(db, ctx, query, args, oneRet, vRets) + if err := p.sqlQueryOne(ctx, query, args, oneRet, vRets); err != nil { + return err + } } + return nil } // For checking query result: // - ret , &, , &, ... // - ret , &, , &, ... -func (p *Class) queryRetKvPair(kvPair ...any) { +func (p *Class) queryRetKvPair(kvPair ...any) error { nPair := len(kvPair) if nPair < 2 || nPair&1 != 0 { log.Panicln("usage: ret , &, , &, ...") @@ -519,7 +528,7 @@ func (p *Class) queryRetKvPair(kvPair ...any) { query = append(query, " LIMIT "...) query = append(query, strconv.Itoa(q.limit)...) } - sqlQuery(p.db, context.TODO(), string(query), q.args, rets, kind == valFlagSlice) + return p.sqlQuery(context.TODO(), string(query), q.args, rets, kind == valFlagSlice) } func (p *Class) exprTblname(cond string) string { @@ -634,9 +643,13 @@ func (p *Class) Limit__1(src ast.Node, n int, cond string, args ...any) error { return err } if ret >= n { - log.Panicf("limit %s: got %d, expected <%d\n", cond, ret, n) + if p.onErr == nil { + log.Panicf("limit %s: got %d, expected <%d\n", cond, ret, n) + } + err = ErrOutOfLimit + p.onErr(err) } - return nil + return err } // Limit checks if query result rows is < n or not. From 36f603821f0511599285cc0f4df4deb281648d3d Mon Sep 17 00:00:00 2001 From: xushiwei Date: Wed, 7 Feb 2024 00:21:18 +0800 Subject: [PATCH 6/9] ydb.Class: queryRetPtr, queryStrucRow, queryStrucRows --- gop.mod | 2 +- ydb/class.go | 200 ++++++++++++++++++++++++++++++++++++++++----------- ydb/table.go | 8 ++- 3 files changed, 165 insertions(+), 45 deletions(-) diff --git a/gop.mod b/gop.mod index 424f077..1472e73 100644 --- a/gop.mod +++ b/gop.mod @@ -1,4 +1,4 @@ -gop 1.1 +gop 1.2 project _yap.gox App github.com/goplus/yap diff --git a/ydb/class.go b/ydb/class.go index 31ca35c..a0f82e4 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -51,7 +51,7 @@ type Class struct { api *api result []reflect.Value // result of an api call - ret func(args ...any) + ret func(args ...any) error onErr func(err error) } @@ -171,7 +171,7 @@ func (p *Class) insertStrucRows(vSlice reflect.Value) (sql.Result, error) { if hasPtr { vElem = vElem.Elem() } - vals = getVals(vals, vElem, cols) + vals = getVals(vals, vElem, cols, true) } return p.insertRowsVals(names, vals, rows) } @@ -182,7 +182,7 @@ func (p *Class) insertStrucRow(vArg reflect.Value) (sql.Result, error) { } n := vArg.NumField() names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0) - vals := getVals(make([]any, 0, len(cols)), vArg, cols) + vals := getVals(make([]any, 0, len(cols)), vArg, cols, true) return p.insertRow(names, vals) } @@ -326,32 +326,171 @@ type query struct { limit int // 0 means no limit } +func (q *query) makeSelectExpr(tbl string, exprs []string) string { + query := make([]byte, 0, 128) + query = append(query, "SELECT "...) + query = append(query, strings.Join(exprs, ",")...) + query = append(query, " FROM "...) + query = append(query, tbl...) + query = append(query, " WHERE "...) + query = append(query, q.cond...) + if q.limit > 0 { + query = append(query, " LIMIT "...) + query = append(query, strconv.Itoa(q.limit)...) + } + return string(query) +} + // For checking query result: // - ret , &, , &, ... // - ret , &, , &, ... // - ret & // - ret & -func (p *Class) queryRet(args ...any) { +func (p *Class) queryRet(args ...any) (err error) { nArg := len(args) if nArg == 1 { - p.queryRetPtr(args[0]) + err = p.queryRetPtr(args[0]) } else { - p.queryRetKvPair(args...) + err = p.queryRetKvPair(args...) } p.query = nil p.ret = nil + return } // For checking query result: // - ret & -// - ret & -func (p *Class) queryRetPtr(arg any) { +// - ret & +func (p *Class) queryRetPtr(ret any) error { + vRet := reflect.ValueOf(ret) + if vRet.Kind() != reflect.Pointer { + log.Panicln("usage: ret &") + } + + switch vRet = vRet.Elem(); vRet.Kind() { + case reflect.Slice: + return p.queryStrucRows(vRet) + default: + return p.queryStrucRow(vRet) + } +} + +// For checking query result: +// - ret & +func (p *Class) queryStrucRow(vRet reflect.Value) error { + if vRet.Kind() != reflect.Struct { + log.Panicln("usage: ret &") + } + + n := vRet.NumField() + names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vRet.Type(), 0) + rets := getVals(make([]any, 0, len(cols)), vRet, cols, false) + + q := p.query + query := q.makeSelectExpr(p.tbl, names) + return p.sqlQueryVals(context.TODO(), query, q.args, rets) +} + +func (p *Class) queryStrucOne( + ctx context.Context, query string, args []any, + vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error { + vRet := reflect.New(elem).Elem() + rets := getVals(make([]any, 0, len(cols)), vRet, cols, false) + err := p.sqlQueryVals(ctx, query, args, rets) + if err != nil { + return err + } + if hasPtr { + vRet = vRet.Addr() + } + vSlice.Set(reflect.Append(vSlice, vRet)) + return nil +} + +func (p *Class) queryStrucMulti( + ctx context.Context, query string, args []any, iArgSlice int, + vSlice reflect.Value, elem dbType, cols []field, hasPtr bool) error { + argSlice := args[iArgSlice] + defer func() { + args[iArgSlice] = argSlice + }() + vArgSlice := reflect.ValueOf(argSlice) + for i, n := 0, vArgSlice.Len(); i < n; i++ { + arg := vArgSlice.Index(i).Interface() + args[iArgSlice] = arg + if err := p.queryStrucOne(ctx, query, args, vSlice, elem, cols, hasPtr); err != nil { + return err + } + } + return nil +} + +// For checking query result: +// - ret & +func (p *Class) queryStrucRows(vSlice reflect.Value) error { + hasPtr := false + elem := vSlice.Type().Elem() + kind := elem.Kind() + if kind == reflect.Pointer { + elem, hasPtr = elem.Elem(), true + kind = elem.Kind() + } + if kind != reflect.Struct { + log.Panicln("usage: ret &") + } + + n := elem.NumField() + names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0) + + q := p.query + query := q.makeSelectExpr(p.tbl, names) + + args := q.args + iArgSlice := checkArgSlice(args) + if iArgSlice >= 0 { + return p.queryStrucMulti(context.TODO(), query, args, iArgSlice, vSlice, elem, cols, hasPtr) + } + return p.queryStrucOne(context.TODO(), query, args, vSlice, elem, cols, hasPtr) +} + +// sqlQueryVals NOTE: +// - one of args maybe is a slice +func (p *Class) sqlQueryVals(ctx context.Context, query string, args, rets []any) error { + iArgSlice := checkArgSlice(args) + if iArgSlice >= 0 { + log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not") + } + + rows, err := p.db.QueryContext(ctx, query, args...) + if err != nil { + p.handleErr("query:", err) + return err + } + defer rows.Close() + + return p.sqlRetRow(rows, rets) } func isSlice(v any) bool { return reflect.ValueOf(v).Kind() == reflect.Slice } +func checkArgSlice(args []any) int { + iArgSlice := -1 + for i, arg := range args { + if isSlice(arg) { + if iArgSlice >= 0 { + log.Panicf( + "query: multiple arguments (%dth, %dth) are slices (only one can be)\n", + iArgSlice+1, i+1, + ) + } + iArgSlice = i + } + } + return iArgSlice +} + func retKind(ret any) int { v := reflect.ValueOf(ret) if v.Kind() != reflect.Pointer { @@ -405,25 +544,11 @@ func (p *Class) sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, return err } -// sqlQuery NOTE: +// sqlQueryRows NOTE: // - one of args maybe is a slice -func (p *Class) sqlQuery(ctx context.Context, query string, args, rets []any, retSlice bool) error { - iArgSlice := -1 - for i, arg := range args { - if isSlice(arg) { - if iArgSlice >= 0 { - log.Panicf( - "query: multiple arguments (%dth, %dth) are slices (only one can be)\n", - iArgSlice+1, i+1, - ) - } - iArgSlice = i - } - } +func (p *Class) sqlQueryRows(ctx context.Context, query string, args, rets []any) error { + iArgSlice := checkArgSlice(args) if iArgSlice >= 0 { - if !retSlice { - log.Panicln("one of `query` arguments is a slice, but `ret` arguments are not") - } return p.sqlMultiQuery(ctx, query, iArgSlice, args, rets) } @@ -434,11 +559,8 @@ func (p *Class) sqlQuery(ctx context.Context, query string, args, rets []any, re } defer rows.Close() - if retSlice { - vRets, oneRet := makeSliceRets(rets) - return p.sqlRetRows(rows, vRets, oneRet, false) - } - return p.sqlRetRow(rows, rets) + vRets, oneRet := makeSliceRets(rets) + return p.sqlRetRows(rows, vRets, oneRet, false) } func makeSliceRets(rets []any) (vRets []reflect.Value, oneRet []any) { @@ -517,18 +639,11 @@ func (p *Class) queryRetKvPair(kvPair ...any) error { ret , &, , &, ...`) } - query := make([]byte, 0, 128) - query = append(query, "SELECT "...) - query = append(query, strings.Join(exprs, ",")...) - query = append(query, " FROM "...) - query = append(query, tbl...) - query = append(query, " WHERE "...) - query = append(query, q.cond...) - if q.limit > 0 { - query = append(query, " LIMIT "...) - query = append(query, strconv.Itoa(q.limit)...) + query := q.makeSelectExpr(tbl, exprs) + if kind == valFlagNormal { + return p.sqlQueryVals(context.TODO(), query, q.args, rets) } - return p.sqlQuery(context.TODO(), string(query), q.args, rets, kind == valFlagSlice) + return p.sqlQueryRows(context.TODO(), query, q.args, rets) } func (p *Class) exprTblname(cond string) string { @@ -689,8 +804,9 @@ func (p *Class) Call__1(args ...any) { p.Call__0(nil, args...) } -func (p *Class) callRet(args ...any) { +func (p *Class) callRet(args ...any) error { p.ret = nil + return nil } // ----------------------------------------------------------------------------- diff --git a/ydb/table.go b/ydb/table.go index 57179eb..66472e2 100644 --- a/ydb/table.go +++ b/ydb/table.go @@ -56,10 +56,14 @@ func newTable(name, ver string, schema dbType) *Table { return p } -func getVals(vals []any, v reflect.Value, cols []field) []any { +func getVals(vals []any, v reflect.Value, cols []field, elem bool) []any { this := uintptr(v.Addr().UnsafePointer()) for _, col := range cols { - val := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset)).Interface() + v := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset)) + if elem { + v = v.Elem() + } + val := v.Interface() vals = append(vals, val) } return vals From 449c1becf69630b6102451210064d90c88a877c1 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Wed, 7 Feb 2024 01:39:42 +0800 Subject: [PATCH 7/9] ydb.Class: callRet --- test/logt/logt.go | 171 ++++++++++++++++++++++++++++++++++++++++++++++ ydb/class.go | 25 ++++++- 2 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 test/logt/logt.go diff --git a/test/logt/logt.go b/test/logt/logt.go new file mode 100644 index 0000000..51355fb --- /dev/null +++ b/test/logt/logt.go @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logt + +import ( + "log" + "runtime" + "time" +) + +type T struct { + name string + fail bool + skipped bool +} + +func New() *T { + return &T{} +} + +func (p *T) Name() string { + return p.name +} + +// Fail marks the function as having failed but continues execution. +func (p *T) Fail() { + p.fail = true +} + +// Failed reports whether the function has failed. +func (p *T) Failed() bool { + return p.fail +} + +// FailNow marks the function as having failed and stops its execution +// by calling runtime.Goexit (which then runs all deferred calls in the +// current goroutine). +// Execution will continue at the next test or benchmark. +// FailNow must be called from the goroutine running the +// test or benchmark function, not from other goroutines +// created during the test. Calling FailNow does not stop +// those other goroutines. +func (p *T) FailNow() { + p.fail = true + runtime.Goexit() +} + +// Log formats its arguments using default formatting, analogous to Println, +// and records the text in the error log. For tests, the text will be printed only if +// the test fails or the -test.v flag is set. For benchmarks, the text is always +// printed to avoid having performance depend on the value of the -test.v flag. +func (p *T) Log(args ...any) { + log.Println(args...) +} + +// Logf formats its arguments according to the format, analogous to Printf, and +// records the text in the error log. A final newline is added if not provided. For +// tests, the text will be printed only if the test fails or the -test.v flag is +// set. For benchmarks, the text is always printed to avoid having performance +// depend on the value of the -test.v flag. +func (p *T) Logf(format string, args ...any) { + log.Printf(format, args...) +} + +// Errorln is equivalent to Log followed by Fail. +func (p *T) Errorln(args ...any) { + log.Println(args...) + p.Fail() +} + +// Errorf is equivalent to Logf followed by Fail. +func (p *T) Errorf(format string, args ...any) { + log.Printf(format, args...) + p.Fail() +} + +// Fatal is equivalent to Log followed by FailNow. +func (p *T) Fatal(args ...any) { + log.Println(args...) + p.FailNow() +} + +// Fatalf is equivalent to Logf followed by FailNow. +func (p *T) Fatalf(format string, args ...any) { + log.Printf(format, args...) + p.FailNow() +} + +// Skip is equivalent to Log followed by SkipNow. +func (p *T) Skip(args ...any) { + log.Println(args...) + p.SkipNow() +} + +// Skipf is equivalent to Logf followed by SkipNow. +func (p *T) Skipf(format string, args ...any) { + log.Printf(format, args...) + p.SkipNow() +} + +// SkipNow marks the test as having been skipped and stops its execution +// by calling runtime.Goexit. +// If a test fails (see Error, Errorf, Fail) and is then skipped, +// it is still considered to have failed. +// Execution will continue at the next test or benchmark. See also FailNow. +// SkipNow must be called from the goroutine running the test, not from +// other goroutines created during the test. Calling SkipNow does not stop +// those other goroutines. +func (p *T) SkipNow() { + p.skipped = true + runtime.Goexit() +} + +// Skipped reports whether the test was skipped. +func (p *T) Skipped() bool { + return p.skipped +} + +// Helper marks the calling function as a test helper function. +// When printing file and line information, that function will be skipped. +// Helper may be called simultaneously from multiple goroutines. +func (p *T) Helper() { +} + +// Cleanup registers a function to be called when the test (or subtest) and all its +// subtests complete. Cleanup functions will be called in last added, +// first called order. +func (p *T) Cleanup(f func()) { + // TODO: +} + +// TempDir returns a temporary directory for the test to use. +// The directory is automatically removed by Cleanup when the test and +// all its subtests complete. +// Each subsequent call to t.TempDir returns a unique directory; +// if the directory creation fails, TempDir terminates the test by calling Fatal. +func (p *T) TempDir() string { + panic("todo") +} + +// Run runs f as a subtest of t called name. +// +// Run may be called simultaneously from multiple goroutines, but all such calls +// must return before the outer test function for t returns. +func (p *T) Run(name string, f func()) bool { + p.name = name + f() + return true +} + +// Deadline reports the time at which the test binary will have +// exceeded the timeout specified by the -timeout flag. +// +// The ok result is false if the -timeout flag indicates “no timeout” (0). +func (p *T) Deadline() (deadline time.Time, ok bool) { + panic("todo") +} diff --git a/ydb/class.go b/ydb/class.go index a0f82e4..b168a29 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -27,6 +27,8 @@ import ( "unicode/utf8" "github.com/goplus/gop/ast" + "github.com/goplus/yap/test" + "github.com/goplus/yap/test/logt" "github.com/qiniu/x/ctype" ) @@ -53,6 +55,8 @@ type Class struct { ret func(args ...any) error onErr func(err error) + + test.CaseT } func newClass(name string, sql *Sql) *Class { @@ -67,6 +71,13 @@ func newClass(name string, sql *Sql) *Class { } } +func (p *Class) t() test.CaseT { + if p.CaseT == nil { + p.CaseT = logt.New() + } + return p.CaseT +} + func (p *Class) gen(ctx context.Context) { } @@ -795,7 +806,7 @@ func (p *Class) Call__0(src ast.Node, args ...any) { for i, arg := range args { vArgs[i] = reflect.ValueOf(arg) } - p.result = reflect.ValueOf(p.api).Call(vArgs) + p.result = reflect.ValueOf(p.api.spec).Call(vArgs) p.ret = p.callRet } @@ -805,6 +816,18 @@ func (p *Class) Call__1(args ...any) { } func (p *Class) callRet(args ...any) error { + t := p.t() + result := p.result + if len(result) != len(args) { + t.Fatalf( + "call ret: unmatched result parameters count - got %d, expected %d\n", + len(args), len(result), + ) + } + for i, arg := range args { + ret := result[i].Interface() + test.Gopt_Case_Match__4(t, arg, ret) + } p.ret = nil return nil } From 7ca1e307a609805c34f8b98ddf4255240563426b Mon Sep 17 00:00:00 2001 From: xushiwei Date: Wed, 7 Feb 2024 01:44:16 +0800 Subject: [PATCH 8/9] at least go1.20: because ydb uses reflect.Value.SetZero --- .github/workflows/go.yml | 2 +- go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index adbf693..731f307 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - go-version: [1.18.x, 1.21.x] + go-version: [1.20.x, 1.21.x] os: [ubuntu-latest, windows-latest,macos-11] runs-on: ${{ matrix.os }} steps: diff --git a/go.mod b/go.mod index 0e57796..b9b7226 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/goplus/yap -go 1.18 +go 1.20 require ( github.com/golang-jwt/jwt/v5 v5.2.0 From 3fe0b0f7e0209ef0a27f462607376ab63b2ff303 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Wed, 7 Feb 2024 02:01:42 +0800 Subject: [PATCH 9/9] reflectutil.SetZero --- .github/workflows/go.yml | 2 +- go.mod | 2 +- noredirect/noredirect.go | 16 ++++++++++++++++ reflectutil/reflect_go120.go | 30 ++++++++++++++++++++++++++++++ reflectutil/reflect_mock.go | 30 ++++++++++++++++++++++++++++++ reflectutil/reflect_test.go | 31 +++++++++++++++++++++++++++++++ ydb/class.go | 3 ++- 7 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 reflectutil/reflect_go120.go create mode 100644 reflectutil/reflect_mock.go create mode 100644 reflectutil/reflect_test.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 731f307..adbf693 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,7 +14,7 @@ jobs: build: strategy: matrix: - go-version: [1.20.x, 1.21.x] + go-version: [1.18.x, 1.21.x] os: [ubuntu-latest, windows-latest,macos-11] runs-on: ${{ matrix.os }} steps: diff --git a/go.mod b/go.mod index b9b7226..0e57796 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/goplus/yap -go 1.20 +go 1.18 require ( github.com/golang-jwt/jwt/v5 v5.2.0 diff --git a/noredirect/noredirect.go b/noredirect/noredirect.go index 0a250a5..1565376 100644 --- a/noredirect/noredirect.go +++ b/noredirect/noredirect.go @@ -1,3 +1,19 @@ +/* + * Copyright (c) 2023 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package noredirect import ( diff --git a/reflectutil/reflect_go120.go b/reflectutil/reflect_go120.go new file mode 100644 index 0000000..7c61baa --- /dev/null +++ b/reflectutil/reflect_go120.go @@ -0,0 +1,30 @@ +//go:build go1.20 +// +build go1.20 + +/* + * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package reflectutil + +import ( + "reflect" +) + +// SetZero sets v to be the zero value of v's type. +// It panics if CanSet returns false. +func SetZero(v reflect.Value) { + v.SetZero() +} diff --git a/reflectutil/reflect_mock.go b/reflectutil/reflect_mock.go new file mode 100644 index 0000000..9f35f96 --- /dev/null +++ b/reflectutil/reflect_mock.go @@ -0,0 +1,30 @@ +//go:build !go1.20 +// +build !go1.20 + +/* + * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package reflectutil + +import ( + "reflect" +) + +// SetZero sets v to be the zero value of v's type. +// It panics if CanSet returns false. +func SetZero(v reflect.Value) { + v.Set(reflect.Zero(v.Type())) +} diff --git a/reflectutil/reflect_test.go b/reflectutil/reflect_test.go new file mode 100644 index 0000000..b56073e --- /dev/null +++ b/reflectutil/reflect_test.go @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package reflectutil + +import ( + "reflect" + "testing" +) + +func TestSetZero(t *testing.T) { + a := 2 + v := reflect.ValueOf(&a).Elem() + SetZero(v) + if !v.IsZero() { + t.Fatal("SetZero:", v) + } +} diff --git a/ydb/class.go b/ydb/class.go index b168a29..91d9107 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -27,6 +27,7 @@ import ( "unicode/utf8" "github.com/goplus/gop/ast" + "github.com/goplus/yap/reflectutil" "github.com/goplus/yap/test" "github.com/goplus/yap/test/logt" "github.com/qiniu/x/ctype" @@ -533,7 +534,7 @@ func (p *Class) sqlRetRows(rows *sql.Rows, vRets []reflect.Value, oneRet []any, for rows.Next() { if needInit { for _, ret := range oneRet { - reflect.ValueOf(ret).Elem().SetZero() + reflectutil.SetZero(reflect.ValueOf(ret).Elem()) } } else { needInit = true