Skip to content

Commit

Permalink
Merge pull request #59 from xushiwei/db
Browse files Browse the repository at this point in the history
ydb.Class: insert structVal/structSlice
  • Loading branch information
xushiwei authored Feb 5, 2024
2 parents 2581e0d + d94e7e1 commit c0f9cec
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 19 deletions.
91 changes: 77 additions & 14 deletions ydb/class.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var (
type Class struct {
name string
tbl string
tobj *Table
sql *Sql
db *sql.DB
apis map[string]*api
Expand Down Expand Up @@ -65,10 +66,12 @@ func (p *Class) gen(ctx context.Context) {

// Use sets the default table used in following sql operations.
func (p *Class) Use(table string, src ...ast.Node) {
if _, ok := p.sql.tables[table]; !ok {
tblobj, ok := p.sql.tables[table]
if !ok {
log.Panicln("table not found:", table)
}
p.tbl = table
p.tobj = tblobj
}

// Ret checks a query or call result.
Expand Down Expand Up @@ -99,7 +102,7 @@ func (p *Class) Ret__1(args ...any) {
// - insert <colName1>, <val1>, <colName2>, <val2>, ...
// - insert <colName1>, <valSlice1>, <colName2>, <valSlice2>, ...
// - insert <structValOrPtr>
// - insert <structSlice>
// - insert <structOrPtrSlice>
func (p *Class) Insert__0(src ast.Node, args ...any) {
/* if p.tbl == "" {
TODO:
Expand All @@ -114,8 +117,56 @@ func (p *Class) Insert__0(src ast.Node, args ...any) {

// Insert inserts a new row.
// - insert <structValOrPtr>
// - insert <structSlice>
// - insert <structOrPtrSlice>
func (p *Class) insertStruc(arg any) {
vArg := reflect.ValueOf(arg)
switch vArg.Kind() {
case reflect.Slice:
p.insertStrucRows(vArg)
case reflect.Pointer:
vArg = vArg.Elem()
fallthrough
default:
p.insertStrucRow(vArg)
}
}

func (p *Class) insertStrucRows(vSlice reflect.Value) {
rows := vSlice.Len()
if rows == 0 {
return
}
hasPtr := false
elem := vSlice.Type().Elem()
kind := elem.Kind()
if kind == reflect.Pointer {
elem, hasPtr = elem.Elem(), true
kind = elem.Kind()
}
if kind != reflect.Struct {
log.Panicln("usage: insert <structOrPtrSlice>")
}
n := elem.NumField()
names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, elem, 0)
vals := make([]any, 0, len(names)*rows)
for row := 0; row < rows; row++ {
vElem := vSlice.Index(row)
if hasPtr {
vElem = vElem.Elem()
}
vals = getVals(vals, vElem, cols)
}
p.insertRowsVals(names, vals, rows)
}

func (p *Class) insertStrucRow(vArg reflect.Value) {
if vArg.Kind() != reflect.Struct {
log.Panicln("usage: insert <structValOrPtr>")
}
n := vArg.NumField()
names, cols := getCols(make([]string, 0, n), make([]field, 0, n), n, vArg.Type(), 0)
vals := getVals(make([]any, 0, len(cols)), vArg, cols)
p.insertRow(names, vals)
}

const (
Expand Down Expand Up @@ -158,33 +209,38 @@ func (p *Class) insertKvPair(kvPair ...any) {
case valFlagNormal:
p.insertRow(names, vals)
case valFlagSlice:
p.insertRows(names, vals, rows)
p.insertSliceRows(names, vals, rows)
default:
log.Panicln("can't insert mix slice and normal value")
}
}

func (p *Class) insertRows(names []string, args []any, rows int) {
n := len(args)
valparam := valParam(n)
valparams := strings.Repeat(valparam+",", rows)
valparams = valparams[:len(valparams)-1]

query := insertQuery(p.tbl, names)
query = append(query, valparams...)

vals := make([]any, 0, n*rows)
// NOTE: len(args) == len(names)
func (p *Class) insertSliceRows(names []string, args []any, rows int) {
vals := make([]any, 0, len(names)*rows)
for i := 0; i < rows; i++ {
for _, arg := range args {
v := arg.(reflect.Value)
vals = append(vals, v.Index(i).Interface())
}
}
p.insertRowsVals(names, vals, rows)
}

// NOTE: len(vals) == len(names) * rows
func (p *Class) insertRowsVals(names []string, vals []any, rows int) {
n := len(names)
query := insertQuery(p.tbl, names)
query = append(query, valParams(n, rows)...)

result, err := p.db.ExecContext(context.TODO(), string(query), vals...)
insertRet(result, err)
}

func (p *Class) insertRow(names []string, vals []any) {
if len(names) == 0 {
log.Panicln("insert: nothing to insert")
}
query := insertQuery(p.tbl, names)
query = append(query, valParam(len(vals))...)
result, err := p.db.ExecContext(context.TODO(), string(query), vals...)
Expand All @@ -207,6 +263,13 @@ func insertQuery(tbl string, names []string) []byte {
return query
}

func valParams(n, rows int) string {
valparam := valParam(n)
valparams := strings.Repeat(valparam+",", rows)
valparams = valparams[:len(valparams)-1]
return valparams
}

func valParam(n int) string {
valparam := strings.Repeat("?,", n)
valparam = "(" + valparam[:len(valparam)-1] + ")"
Expand Down
47 changes: 42 additions & 5 deletions ydb/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"reflect"
"strings"
"time"
"unsafe"
)

type dbType = reflect.Type
Expand All @@ -45,27 +46,63 @@ type column struct {
type field struct {
typ dbType // field type
offset uintptr // offset within struct, in bytes
index []int // index sequence for Type.FieldByIndex
}

func newTable(name, ver string, schema dbType) *Table {
n := schema.NumField()
cols := make([]*column, 0, n)
p := &Table{name: name, ver: ver, schema: schema, cols: cols}
p.defineCols(n, schema)
p.defineCols(n, schema, 0)
return p
}

func (p *Table) defineCols(n int, t dbType) {
func getVals(vals []any, v reflect.Value, cols []field) []any {
this := uintptr(v.Addr().UnsafePointer())
for _, col := range cols {
val := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset)).Interface()
vals = append(vals, val)
}
return vals
}

func getCols(names []string, cols []field, n int, t dbType, base uintptr) ([]string, []field) {
for i := 0; i < n; i++ {
fld := t.Field(i)
if fld.Anonymous {
fldType := fld.Type
names, cols = getCols(names, cols, fldType.NumField(), fldType, base+fld.Offset)
continue
}
if fld.IsExported() {
name := ""
if tag := string(fld.Tag); tag != "" {
if c := tag[0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case
if pos := strings.IndexByte(tag, ' '); pos > 0 {
tag = tag[:pos]
}
name = tag
}
}
if name == "" {
name = dbName(fld.Name)
}
names = append(names, name)
cols = append(cols, field{fld.Type, base + fld.Offset})
}
}
return names, cols
}

func (p *Table) defineCols(n int, t dbType, base uintptr) {
for i := 0; i < n; i++ {
fld := t.Field(i)
if fld.Anonymous {
fldType := fld.Type
p.defineCols(fldType.NumField(), fldType)
p.defineCols(fldType.NumField(), fldType, base+fld.Offset)
continue
}
if fld.IsExported() {
col := &column{fld: field{fld.Type, fld.Offset, fld.Index}}
col := &column{fld: field{fld.Type, base + fld.Offset}}
if tag := string(fld.Tag); tag != "" {
if parts := strings.Fields(tag); len(parts) > 0 {
if c := parts[0][0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case
Expand Down

0 comments on commit c0f9cec

Please sign in to comment.