From d94e7e1c98196061e76b7cd9f1e9cc3ebe94098a Mon Sep 17 00:00:00 2001 From: xushiwei Date: Mon, 5 Feb 2024 23:52:13 +0800 Subject: [PATCH] ydb.Class: insert structVal/structSlice --- ydb/class.go | 91 ++++++++++++++++++++++++++++++++++++++++++++-------- ydb/table.go | 47 ++++++++++++++++++++++++--- 2 files changed, 119 insertions(+), 19 deletions(-) diff --git a/ydb/class.go b/ydb/class.go index 360a85c..1f2f16e 100644 --- a/ydb/class.go +++ b/ydb/class.go @@ -36,6 +36,7 @@ var ( type Class struct { name string tbl string + tobj *Table sql *Sql db *sql.DB apis map[string]*api @@ -65,10 +66,12 @@ func (p *Class) gen(ctx context.Context) { // Use sets the default table used in following sql operations. func (p *Class) Use(table string, src ...ast.Node) { - if _, ok := p.sql.tables[table]; !ok { + tblobj, ok := p.sql.tables[table] + if !ok { log.Panicln("table not found:", table) } p.tbl = table + p.tobj = tblobj } // Ret checks a query or call result. @@ -99,7 +102,7 @@ func (p *Class) Ret__1(args ...any) { // - insert , , , , ... // - insert , , , , ... // - insert -// - insert +// - insert func (p *Class) Insert__0(src ast.Node, args ...any) { /* if p.tbl == "" { TODO: @@ -114,8 +117,56 @@ func (p *Class) Insert__0(src ast.Node, args ...any) { // Insert inserts a new row. // - insert -// - insert +// - insert func (p *Class) insertStruc(arg any) { + vArg := reflect.ValueOf(arg) + switch vArg.Kind() { + case reflect.Slice: + p.insertStrucRows(vArg) + case reflect.Pointer: + vArg = vArg.Elem() + fallthrough + default: + p.insertStrucRow(vArg) + } +} + +func (p *Class) insertStrucRows(vSlice reflect.Value) { + rows := vSlice.Len() + if rows == 0 { + return + } + hasPtr := false + elem := vSlice.Type().Elem() + kind := elem.Kind() + if kind == reflect.Pointer { + elem, hasPtr = elem.Elem(), true + kind = elem.Kind() + } + if kind != reflect.Struct { + log.Panicln("usage: insert ") + } + n := elem.NumField() + names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0) + vals := make([]any, 0, len(names)*rows) + for row := 0; row < rows; row++ { + vElem := vSlice.Index(row) + if hasPtr { + vElem = vElem.Elem() + } + vals = getVals(vals, vElem, cols) + } + p.insertRowsVals(names, vals, rows) +} + +func (p *Class) insertStrucRow(vArg reflect.Value) { + if vArg.Kind() != reflect.Struct { + log.Panicln("usage: insert ") + } + n := vArg.NumField() + names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0) + vals := getVals(make([]any, 0, len(cols)), vArg, cols) + p.insertRow(names, vals) } const ( @@ -158,33 +209,38 @@ func (p *Class) insertKvPair(kvPair ...any) { case valFlagNormal: p.insertRow(names, vals) case valFlagSlice: - p.insertRows(names, vals, rows) + p.insertSliceRows(names, vals, rows) default: log.Panicln("can't insert mix slice and normal value") } } -func (p *Class) insertRows(names []string, args []any, rows int) { - n := len(args) - valparam := valParam(n) - valparams := strings.Repeat(valparam+",", rows) - valparams = valparams[:len(valparams)-1] - - query := insertQuery(p.tbl, names) - query = append(query, valparams...) - - vals := make([]any, 0, n*rows) +// NOTE: len(args) == len(names) +func (p *Class) insertSliceRows(names []string, args []any, rows int) { + vals := make([]any, 0, len(names)*rows) for i := 0; i < rows; i++ { for _, arg := range args { v := arg.(reflect.Value) vals = append(vals, v.Index(i).Interface()) } } + p.insertRowsVals(names, vals, rows) +} + +// NOTE: len(vals) == len(names) * rows +func (p *Class) insertRowsVals(names []string, vals []any, rows int) { + n := len(names) + query := insertQuery(p.tbl, names) + query = append(query, valParams(n, rows)...) + result, err := p.db.ExecContext(context.TODO(), string(query), vals...) insertRet(result, err) } func (p *Class) insertRow(names []string, vals []any) { + if len(names) == 0 { + log.Panicln("insert: nothing to insert") + } query := insertQuery(p.tbl, names) query = append(query, valParam(len(vals))...) result, err := p.db.ExecContext(context.TODO(), string(query), vals...) @@ -207,6 +263,13 @@ func insertQuery(tbl string, names []string) []byte { return query } +func valParams(n, rows int) string { + valparam := valParam(n) + valparams := strings.Repeat(valparam+",", rows) + valparams = valparams[:len(valparams)-1] + return valparams +} + func valParam(n int) string { valparam := strings.Repeat("?,", n) valparam = "(" + valparam[:len(valparam)-1] + ")" diff --git a/ydb/table.go b/ydb/table.go index 6f8494e..57179eb 100644 --- a/ydb/table.go +++ b/ydb/table.go @@ -23,6 +23,7 @@ import ( "reflect" "strings" "time" + "unsafe" ) type dbType = reflect.Type @@ -45,27 +46,63 @@ type column struct { type field struct { typ dbType // field type offset uintptr // offset within struct, in bytes - index []int // index sequence for Type.FieldByIndex } func newTable(name, ver string, schema dbType) *Table { n := schema.NumField() cols := make([]*column, 0, n) p := &Table{name: name, ver: ver, schema: schema, cols: cols} - p.defineCols(n, schema) + p.defineCols(n, schema, 0) return p } -func (p *Table) defineCols(n int, t dbType) { +func getVals(vals []any, v reflect.Value, cols []field) []any { + this := uintptr(v.Addr().UnsafePointer()) + for _, col := range cols { + val := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset)).Interface() + vals = append(vals, val) + } + return vals +} + +func getCols(names []string, cols []field, n int, t dbType, base uintptr) ([]string, []field) { + for i := 0; i < n; i++ { + fld := t.Field(i) + if fld.Anonymous { + fldType := fld.Type + names, cols = getCols(names, cols, fldType.NumField(), fldType, base+fld.Offset) + continue + } + if fld.IsExported() { + name := "" + if tag := string(fld.Tag); tag != "" { + if c := tag[0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case + if pos := strings.IndexByte(tag, ' '); pos > 0 { + tag = tag[:pos] + } + name = tag + } + } + if name == "" { + name = dbName(fld.Name) + } + names = append(names, name) + cols = append(cols, field{fld.Type, base + fld.Offset}) + } + } + return names, cols +} + +func (p *Table) defineCols(n int, t dbType, base uintptr) { for i := 0; i < n; i++ { fld := t.Field(i) if fld.Anonymous { fldType := fld.Type - p.defineCols(fldType.NumField(), fldType) + p.defineCols(fldType.NumField(), fldType, base+fld.Offset) continue } if fld.IsExported() { - col := &column{fld: field{fld.Type, fld.Offset, fld.Index}} + col := &column{fld: field{fld.Type, base + fld.Offset}} if tag := string(fld.Tag); tag != "" { if parts := strings.Fields(tag); len(parts) > 0 { if c := parts[0][0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case