diff --git a/go.mod b/go.mod index 3b07877..231db6d 100644 --- a/go.mod +++ b/go.mod @@ -1,15 +1,27 @@ module github.com/houqp/sqlvet -go 1.13 +go 1.19 require ( github.com/houqp/gtest v1.0.0 - github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23 - github.com/pelletier/go-toml v1.6.0 - github.com/pganalyze/pg_query_go v1.0.3 + github.com/logrusorgru/aurora v2.0.3+incompatible + github.com/pelletier/go-toml v1.9.5 + github.com/pganalyze/pg_query_go/v2 v2.2.0 github.com/sirupsen/logrus v1.9.0 - github.com/spf13/cobra v0.0.5 + github.com/spf13/cobra v1.6.1 + github.com/stretchr/testify v1.8.1 + golang.org/x/tools v0.3.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fatih/structtag v1.2.0 // indirect + github.com/golang/protobuf v1.4.2 // indirect + github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/stretchr/testify v1.7.0 - golang.org/x/tools v0.0.0-20191108193012-7d206e10da11 + golang.org/x/mod v0.7.0 // indirect + golang.org/x/sys v0.2.0 // indirect + google.golang.org/protobuf v1.23.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index ec97fba..045b5a3 100644 --- a/go.sum +++ b/go.sum @@ -1,81 +1,88 @@ -github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= +github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/houqp/gtest v1.0.0 h1:BEA3kuePC0Q9bDoksoZXCnyYIZGUbMXurjwyv/ql0cA= github.com/houqp/gtest v1.0.0/go.mod h1:oxg4BHzN6nRAQZWTc5qO90uK9voXKmb5kg4/XE6lhKw= -github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= +github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23 h1:Wp7NjqGKGN9te9N/rvXYRhlVcrulGdxnz8zadXWs7fc= -github.com/logrusorgru/aurora v0.0.0-20191116043053-66b7ad493a23/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pelletier/go-toml v1.6.0 h1:aetoXYr0Tv7xRU/V4B4IZJ2QcbtMUFoNb3ORp7TzIK4= -github.com/pelletier/go-toml v1.6.0/go.mod h1:5N711Q9dKgbdkxHL+MEfF31hpT7l0S0s/t2kKREewys= -github.com/pganalyze/pg_query_go v1.0.3 h1:cur7WhCeA63mUD3Y/hZCl4QbU8NudQr1tIZV/ctsXCQ= -github.com/pganalyze/pg_query_go v1.0.3/go.mod h1:tR53lU3ddnExxb0XeLyYuQIK3dkR03FjQ9sj8AV/up8= +github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8= +github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= +github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= +github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pganalyze/pg_query_go/v2 v2.2.0 h1:OW+reH+ZY7jdEuPyuLGlf1m7dLbE+fDudKXhLs0Ttpk= +github.com/pganalyze/pg_query_go/v2 v2.2.0/go.mod h1:XAxmVqz1tEGqizcQ3YSdN90vCOHBWjJi8URL1er5+cA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5 h1:f0B+LkLX6DtmRH1isoNA9VTtNUK9K8xYd28JNNfOv/s= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= -github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= +github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.uber.org/goleak v0.10.1-0.20191111212139-7380c5a9fa84 h1:DSZ6nQuvDK2fSSOX15dEhAYgXJAfaFwhpaxEAnGtAwU= go.uber.org/goleak v0.10.1-0.20191111212139-7380c5a9fa84/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.7.0 h1:LapD9S96VoQRhi/GrNTqeBJFrUjs5UHCAtTlgwA5oZA= +golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20191108193012-7d206e10da11 h1:Yq9t9jnGoR+dBuitxdo9l6Q7xh/zOyNnYUtDKaQ3x0E= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.3.0 h1:SrNbZl6ECOS1qFzgTdQfWXZM9XBkiA6tkFrH9YSTPHM= +golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= -gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/parseutil/sqlx_test.go b/pkg/parseutil/sqlx_test.go new file mode 100644 index 0000000..f8e7027 --- /dev/null +++ b/pkg/parseutil/sqlx_test.go @@ -0,0 +1,96 @@ +package parseutil + +import "testing" + +func TestCompileQuery(t *testing.T) { + table := []struct { + Q, R, D, T, N string + V []string + }{ + // basic test for named parameters, invalid char ',' terminating + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + T: `INSERT INTO foo (a,b,c,d) VALUES (@p1, @p2, @p3, @p4)`, + N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`, + V: []string{"name", "age", "first", "last"}, + }, + // This query tests a named parameter ending the string as well as numbers + { + Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT * FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`, + T: `SELECT * FROM a WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`, + D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`, + T: `SELECT ":foo" FROM a WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`, + V: []string{"name1", "name2"}, + }, + { + Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`, + D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`, + T: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=@p1 AND last_name=@p2`, + N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`, + V: []string{"first_name", "last_name"}, + }, + { + Q: `SELECT @name := "name", :age, :first, :last`, + R: `SELECT @name := "name", ?, ?, ?`, + D: `SELECT @name := "name", $1, $2, $3`, + N: `SELECT @name := "name", :age, :first, :last`, + T: `SELECT @name := "name", @p1, @p2, @p3`, + V: []string{"age", "first", "last"}, + }, + /* This unicode awareness test sadly fails, because of our byte-wise worldview. + * We could certainly iterate by Rune instead, though it's a great deal slower, + * it's probably the RightWay(tm) + { + Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`, + R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`, + D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`, + N: []string{"name", "age", "first", "last"}, + }, + */ + } + + for _, test := range table { + qr, names, err := CompileNamedQuery([]byte(test.Q), QUESTION) + if err != nil { + t.Error(err) + } + if qr != test.R { + t.Errorf("expected %s, got %s", test.R, qr) + } + if len(names) != len(test.V) { + t.Errorf("expected %#v, got %#v", test.V, names) + } else { + for i, name := range names { + if name != test.V[i] { + t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name) + } + } + } + qd, _, _ := CompileNamedQuery([]byte(test.Q), DOLLAR) + if qd != test.D { + t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd) + } + + qt, _, _ := CompileNamedQuery([]byte(test.Q), AT) + if qt != test.T { + t.Errorf("\nexpected: `%s`\ngot: `%s`", test.T, qt) + } + + qq, _, _ := CompileNamedQuery([]byte(test.Q), NAMED) + if qq != test.N { + t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq)) + } + } +} diff --git a/pkg/schema/postgres.go b/pkg/schema/postgres.go index 0078e6d..5e4d50b 100644 --- a/pkg/schema/postgres.go +++ b/pkg/schema/postgres.go @@ -1,68 +1,52 @@ package schema import ( - "io/ioutil" + "os" "strings" - pg_query "github.com/pganalyze/pg_query_go" - nodes "github.com/pganalyze/pg_query_go/nodes" + pg_query "github.com/pganalyze/pg_query_go/v2" ) -// func debugNode(n nodes.Node) { -// b, e := n.MarshalJSON() -// if e != nil { -// fmt.Println("Node decode error:", e) -// } else { -// fmt.Println(string(b)) -// } -// } - func (s *Db) LoadPostgres(schemaPath string) error { - schemaBytes, err := ioutil.ReadFile(schemaPath) + schemaBytes, err := os.ReadFile(schemaPath) if err != nil { return err } - // func() { - // tree, _ := pg_query.ParseToJSON(string(schemaBytes)) - // fmt.Println("????????????", tree) - // }() - tree, err := pg_query.Parse(string(schemaBytes)) if err != nil { return err } - for _, stmt := range tree.Statements { - raw, ok := stmt.(nodes.RawStmt) - if !ok { + for _, stmt := range tree.Stmts { + if stmt.Stmt == nil { continue } - switch stmt := raw.Stmt.(type) { - case nodes.CreateStmt: - tableName := *stmt.Relation.Relname + switch { + case stmt.Stmt.GetCreateStmt() != nil: + tableName := stmt.Stmt.GetCreateStmt().Relation.Relname table := Table{ Name: tableName, Columns: map[string]Column{}, } - for _, colElem := range stmt.TableElts.Items { - colDef, ok := colElem.(nodes.ColumnDef) - if !ok { + for _, colElem := range stmt.Stmt.GetCreateStmt().TableElts { + if colElem.GetColumnDef() == nil { continue } + colDef := colElem.GetColumnDef() typeParts := []string{} - for _, typNode := range colDef.TypeName.Names.Items { - tStr, ok := typNode.(nodes.String) - if !ok { + for _, typNode := range colDef.TypeName.Names { + if typNode.GetString_() == nil { continue } + tStr := typNode.GetString_() typeParts = append(typeParts, tStr.Str) } - colName := *colDef.Colname + colName := colDef.Colname table.Columns[colName] = Column{ Name: colName, Type: strings.Join(typeParts, "."), diff --git a/pkg/vet/vet.go b/pkg/vet/vet.go index f2664cc..714e63f 100644 --- a/pkg/vet/vet.go +++ b/pkg/vet/vet.go @@ -6,8 +6,7 @@ import ( "fmt" "reflect" - pg_query "github.com/pganalyze/pg_query_go" - nodes "github.com/pganalyze/pg_query_go/nodes" + pg_query "github.com/pganalyze/pg_query_go/v2" "github.com/houqp/sqlvet/pkg/schema" ) @@ -24,11 +23,11 @@ type TableUsed struct { type ColumnUsed struct { Column string Table string - Location int + Location int32 } type QueryParam struct { - Number int + Number int32 // TODO: also store related column type info for analysis } @@ -69,39 +68,44 @@ func DebugQuery(q string) { var pretty bytes.Buffer json.Indent(&pretty, []byte(b), "\t", " ") fmt.Println("query: " + q) - fmt.Println("parsed query: " + string(pretty.Bytes())) + fmt.Println("parsed query: " + pretty.String()) } -func rangeVarToTableUsed(r nodes.RangeVar) TableUsed { +func rangeVarToTableUsed(r *pg_query.RangeVar) TableUsed { t := TableUsed{ - Name: *r.Relname, + Name: r.Relname, } if r.Alias != nil { - t.Alias = *r.Alias.Aliasname + t.Alias = r.Alias.Aliasname } return t } +/* ========================================================================================================= + this is where the rewrite to nil checks rather than type refleciton starts + ========================================================================================================= */ + // return nil if no specific column is being referenced -func columnRefToColumnUsed(colRef nodes.ColumnRef) *ColumnUsed { +func columnRefToColumnUsed(colRef *pg_query.ColumnRef) *ColumnUsed { cu := ColumnUsed{ Location: colRef.Location, } - var colField nodes.Node - if len(colRef.Fields.Items) > 1 { + var colField *pg_query.Node + if len(colRef.Fields) > 1 { // in the form of SELECT table.column FROM table - cu.Table = colRef.Fields.Items[0].(nodes.String).Str - colField = colRef.Fields.Items[1] + cu.Table = colRef.Fields[0].GetString_().Str + // fmt.Printf("table: %s\n", cu.Table) + colField = colRef.Fields[1] } else { // in the form of SELECT column FROM table - colField = colRef.Fields.Items[0] + colField = colRef.Fields[0] } - switch refField := colField.(type) { - case nodes.String: - cu.Column = refField.Str - case nodes.A_Star: + switch { + case colField.GetString_() != nil: + cu.Column = colField.GetString_().GetStr() + case colField.GetAStar() != nil: // SELECT * return nil default: @@ -112,58 +116,56 @@ func columnRefToColumnUsed(colRef nodes.ColumnRef) *ColumnUsed { return &cu } -func getUsedTablesFromJoinArg(arg nodes.Node) []TableUsed { - switch n := arg.(type) { - case nodes.RangeVar: - return []TableUsed{rangeVarToTableUsed(n)} - case nodes.JoinExpr: +func getUsedTablesFromJoinArg(arg *pg_query.Node) []TableUsed { + switch { + case arg.GetRangeVar() != nil: + return []TableUsed{rangeVarToTableUsed(arg.GetRangeVar())} + case arg.GetJoinExpr() != nil: return append( - getUsedTablesFromJoinArg(n.Larg), - getUsedTablesFromJoinArg(n.Rarg)...) + getUsedTablesFromJoinArg(arg.GetJoinExpr().GetLarg()), + getUsedTablesFromJoinArg(arg.GetJoinExpr().GetRarg())...) default: return []TableUsed{} } } // extract used tables from FROM clause and JOIN clauses -func getUsedTablesFromSelectStmt(fromClauseList nodes.List) []TableUsed { +func getUsedTablesFromSelectStmt(fromClauseList []*pg_query.Node) []TableUsed { usedTables := []TableUsed{} - if len(fromClauseList.Items) <= 0 { + if len(fromClauseList) <= 0 { // skip because no table is referenced in the query return usedTables } - for _, fromItem := range fromClauseList.Items { - switch fromExpr := fromItem.(type) { - case nodes.RangeVar: + for _, fromItem := range fromClauseList { + switch { + case fromItem.GetRangeVar() != nil: // SELECT without JOIN - usedTables = append(usedTables, rangeVarToTableUsed(fromExpr)) - case nodes.JoinExpr: + usedTables = append(usedTables, rangeVarToTableUsed(fromItem.GetRangeVar())) + case fromItem.GetJoinExpr() != nil: // SELECT with one or more JOINs - usedTables = append(usedTables, getUsedTablesFromJoinArg(fromExpr.Larg)...) - usedTables = append(usedTables, getUsedTablesFromJoinArg(fromExpr.Rarg)...) + usedTables = append(usedTables, getUsedTablesFromJoinArg(fromItem.GetJoinExpr().GetLarg())...) + usedTables = append(usedTables, getUsedTablesFromJoinArg(fromItem.GetJoinExpr().GetRarg())...) } } return usedTables } -func getUsedColumnsFromJoinQuals(quals nodes.Node) []ColumnUsed { +func getUsedColumnsFromJoinQuals(quals *pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - switch joinCond := quals.(type) { - case nodes.A_Expr: - lcolRef, ok := joinCond.Lexpr.(nodes.ColumnRef) - if ok { - cu := columnRefToColumnUsed(lcolRef) + if quals.GetAExpr() != nil { + joinCond := quals.GetAExpr() + if lColRef := joinCond.GetLexpr().GetColumnRef(); lColRef != nil { + cu := columnRefToColumnUsed(lColRef) if cu != nil { usedCols = append(usedCols, *cu) } } - rcolRef, ok := joinCond.Rexpr.(nodes.ColumnRef) - if ok { - cu := columnRefToColumnUsed(rcolRef) + if rColRef := joinCond.GetRexpr().GetColumnRef(); rColRef != nil { + cu := columnRefToColumnUsed(rColRef) if cu != nil { usedCols = append(usedCols, *cu) } @@ -173,55 +175,60 @@ func getUsedColumnsFromJoinQuals(quals nodes.Node) []ColumnUsed { return usedCols } -func getUsedColumnsFromJoinExpr(expr nodes.JoinExpr) []ColumnUsed { +// todo this rewrite seems especially dubious +func getUsedColumnsFromJoinExpr(expr *pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - - if larg, ok := expr.Larg.(nodes.JoinExpr); ok { + if expr.GetJoinExpr() == nil { + return usedCols + } + joinExpr := expr.GetJoinExpr() + if larg := joinExpr.Larg; larg != nil { usedCols = append(usedCols, getUsedColumnsFromJoinExpr(larg)...) } - if rarg, ok := expr.Rarg.(nodes.JoinExpr); ok { + if rarg := joinExpr.Rarg; rarg != nil { usedCols = append(usedCols, getUsedColumnsFromJoinExpr(rarg)...) } - usedCols = append(usedCols, getUsedColumnsFromJoinQuals(expr.Quals)...) + usedCols = append(usedCols, getUsedColumnsFromJoinQuals(joinExpr.Quals)...) return usedCols } -func getUsedColumnsFromJoinClauses(fromClauseList nodes.List) []ColumnUsed { +func getUsedColumnsFromJoinClauses(fromClauseList []*pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - if len(fromClauseList.Items) <= 0 { + if len(fromClauseList) <= 0 { // skip because no table is referenced in the query, which means there // is no Join clause return usedCols } - for _, fromItem := range fromClauseList.Items { - switch fromExpr := fromItem.(type) { - case nodes.RangeVar: + for _, fromItem := range fromClauseList { + switch { + case fromItem.GetRangeVar() != nil: // SELECT without JOIN continue - case nodes.JoinExpr: + case fromItem.GetJoinExpr() != nil: // SELECT with one or more JOINs - usedCols = append(usedCols, getUsedColumnsFromJoinExpr(fromExpr)...) + usedCols = append(usedCols, getUsedColumnsFromJoinExpr(fromItem)...) } } return usedCols } -func getUsedColumnsFromReturningList(returningList nodes.List) []ColumnUsed { +func getUsedColumnsFromReturningList(returningList []*pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - for _, node := range returningList.Items { - target, ok := node.(nodes.ResTarget) - if !ok { + for _, node := range returningList { + target := node.GetResTarget() + if target == nil { continue } - switch targetVal := target.Val.(type) { - case nodes.ColumnRef: - cu := columnRefToColumnUsed(targetVal) + switch { + // case pg_query.ColumnRef: + case target.Val.GetColumnRef() != nil: + cu := columnRefToColumnUsed(target.Val.GetColumnRef()) if cu == nil { continue } @@ -302,111 +309,126 @@ func validateTableColumns(ctx VetContext, tables []TableUsed, cols []ColumnUsed) return nil } -func validateInsertValues(ctx VetContext, cols []ColumnUsed, vals []nodes.Node) error { +func validateInsertValues(ctx VetContext, cols []ColumnUsed, vals []*pg_query.Node) error { colCnt := len(cols) // val could be nodes.ParamRef valCnt := len(vals) if colCnt != valCnt { - return fmt.Errorf("Column count %d doesn't match value count %d.", colCnt, valCnt) + return fmt.Errorf("column count %d doesn't match value count %d", colCnt, valCnt) } return nil } -func parseWindowDef(ctx VetContext, winDef *nodes.WindowDef, parseRe *ParseResult) error { - if len(winDef.PartitionClause.Items) > 0 { - if err := parseExpression(ctx, winDef.PartitionClause, parseRe); err != nil { +func parseWindowDef(ctx VetContext, winDef *pg_query.WindowDef, parseRe *ParseResult) error { + if len(winDef.PartitionClause) > 0 { + // TODO should this be [0] + if err := parseExpression(ctx, winDef.GetPartitionClause()[0], parseRe); err != nil { return err } } - if len(winDef.OrderClause.Items) > 0 { - if err := parseExpression(ctx, winDef.OrderClause, parseRe); err != nil { + if len(winDef.OrderClause) > 0 { + // TODO should this be [0] + if err := parseExpression(ctx, winDef.OrderClause[0], parseRe); err != nil { return err } } return nil } -func parseExpression(ctx VetContext, clause nodes.Node, parseRe *ParseResult) error { - switch expr := clause.(type) { - case nodes.A_Expr: - if expr.Lexpr != nil { - err := parseExpression(ctx, expr.Lexpr, parseRe) +// recursive function to parse expressions including nested expressions +func parseExpression(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult) error { + switch { + case clause.GetAExpr() != nil: + if clause.GetAExpr().GetLexpr() != nil { + err := parseExpression(ctx, clause.GetAExpr().GetLexpr(), parseRe) if err != nil { return err } } - if expr.Rexpr != nil { - err := parseExpression(ctx, expr.Rexpr, parseRe) + if clause.GetAExpr().GetRexpr() != nil { + err := parseExpression(ctx, clause.GetAExpr().GetRexpr(), parseRe) if err != nil { return err } } - case nodes.BoolExpr: - return parseExpression(ctx, expr.Args, parseRe) - case nodes.NullTest: - return parseExpression(ctx, expr.Arg, parseRe) - case nodes.ColumnRef: - cu := columnRefToColumnUsed(expr) + case clause.GetBoolExpr() != nil: + // TODO should this be args or args[0]? + return parseExpression(ctx, clause.GetBoolExpr().Args[0], parseRe) + case clause.GetNullTest() != nil: + nullTest := clause.GetNullTest() + if arg := nullTest.GetArg(); arg != nil { + return parseExpression(ctx, arg, parseRe) + } + case clause.GetColumnRef() != nil: + // cu := columnRefToColumnUsed(expr) + cu := columnRefToColumnUsed(clause.GetColumnRef()) if cu == nil { return nil } parseRe.Columns = append(parseRe.Columns, *cu) - case nodes.ParamRef: + case clause.GetParamRef() != nil: // WHERE id=$1 - AddQueryParam(&parseRe.Params, QueryParam{Number: expr.Number}) - case nodes.A_Const: + AddQueryParam(&parseRe.Params, QueryParam{Number: clause.GetParamRef().GetNumber()}) + case clause.GetAConst() != nil: // WHERE 1 - case nodes.FuncCall: + case clause.GetFuncCall() != nil: // WHERE date=NOW() // WHERE MAX(id) > 1 - if err := parseExpression(ctx, expr.Args, parseRe); err != nil { - return err + // TODO should this be args or args[0]? + funcCall := clause.GetFuncCall() + if len(funcCall.Args) > 0 { + if err := parseExpression(ctx, funcCall.Args[0], parseRe); err != nil { + return err + } } // SELECT ROW_NUMBER() OVER (PARTITION BY id) - if expr.Over != nil { - err := parseExpression(ctx, expr.Over, parseRe) - if err != nil { - return err + if clause.GetFuncCall().GetOver() != nil { + // TODO dubious rewrite and should this be [0] + over := clause.GetFuncCall().GetOver() + if len(over.PartitionClause) > 0 { + err := parseExpression(ctx, over.PartitionClause[0], parseRe) + if err != nil { + return err + } } } - case nodes.TypeCast: + case clause.GetTypeCast() != nil: // WHERE foo=True - return parseExpression(ctx, expr.Arg, parseRe) - case nodes.List: + return parseExpression(ctx, clause.GetTypeCast().Arg, parseRe) + case clause.GetList() != nil: // WHERE id IN (1, 2, 3) - for _, item := range expr.Items { + for _, item := range clause.GetList().Items { err := parseExpression(ctx, item, parseRe) if err != nil { return err } } - case nodes.SubLink: + case clause.GetSubLink() != nil: // WHERE id IN (SELECT id FROM foo) - selectStmt, ok := expr.Subselect.(nodes.SelectStmt) - if !ok { + subselect := clause.GetSubLink().GetSubselect() + if subselect.GetSelectStmt() == nil { return fmt.Errorf( - "Unsupported subquery type: %s", reflect.TypeOf(expr.Subselect)) + "unsupported subquery type: %v", subselect) } - queryParams, err := validateSelectStmt(ctx, selectStmt) + queryParams, err := validateSelectStmt(ctx, subselect.GetSelectStmt()) if err != nil { return err } if len(queryParams) > 0 { AddQueryParams(&parseRe.Params, queryParams) } - case nodes.CoalesceExpr: - return parseExpression(ctx, expr.Args, parseRe) - case *nodes.WindowDef: - return parseWindowDef(ctx, expr, parseRe) - case nodes.WindowDef: - return parseWindowDef(ctx, &expr, parseRe) - case nodes.SortBy: - return parseExpression(ctx, expr.Node, parseRe) + case clause.GetCoalesceExpr() != nil: + // TODO should this be [0]? + return parseExpression(ctx, clause.GetCoalesceExpr().GetArgs()[0], parseRe) + case clause.GetWindowDef() != nil: + return parseWindowDef(ctx, clause.GetWindowDef(), parseRe) + case clause.GetSortBy() != nil: + return parseExpression(ctx, clause.GetSortBy().Node, parseRe) default: return fmt.Errorf( - "Unsupported expression, found node of type: %v", + "unsupported expression, found node of type: %v", reflect.TypeOf(clause), ) } @@ -415,20 +437,19 @@ func parseExpression(ctx VetContext, clause nodes.Node, parseRe *ParseResult) er } // find used column names from where clause -func parseWhereClause(ctx VetContext, clause nodes.Node, parseRe *ParseResult) error { +func parseWhereClause(ctx VetContext, clause *pg_query.Node, parseRe *ParseResult) error { err := parseExpression(ctx, clause, parseRe) if err != nil { - err = fmt.Errorf("Invalid WHERE clause: %w", err) + err = fmt.Errorf("invalid WHERE clause: %w", err) } return err } -func getUsedColumnsFromNodeList(nodelist nodes.List) []ColumnUsed { +func getUsedColumnsFromNodeList(nodelist []*pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - for _, item := range nodelist.Items { - switch clause := item.(type) { - case nodes.ColumnRef: - cu := columnRefToColumnUsed(clause) + for _, item := range nodelist { + if item.GetColumnRef() != nil { + cu := columnRefToColumnUsed(item.GetColumnRef()) if cu != nil { usedCols = append(usedCols, *cu) } @@ -437,12 +458,11 @@ func getUsedColumnsFromNodeList(nodelist nodes.List) []ColumnUsed { return usedCols } -func getUsedColumnsFromSortClause(sortList nodes.List) []ColumnUsed { +func getUsedColumnsFromSortClause(sortList []*pg_query.Node) []ColumnUsed { usedCols := []ColumnUsed{} - for _, item := range sortList.Items { - switch sortClause := item.(type) { - case nodes.SortBy: - if colRef, ok := sortClause.Node.(nodes.ColumnRef); ok { + for _, item := range sortList { + if item.GetSortBy() != nil { + if colRef := item.GetSortBy().GetNode().GetColumnRef(); colRef != nil { cu := columnRefToColumnUsed(colRef) if cu != nil { usedCols = append(usedCols, *cu) @@ -453,20 +473,19 @@ func getUsedColumnsFromSortClause(sortList nodes.List) []ColumnUsed { return usedCols } -func validateSelectStmt(ctx VetContext, stmt nodes.SelectStmt) ([]QueryParam, error) { +func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) ([]QueryParam, error) { usedTables := getUsedTablesFromSelectStmt(stmt.FromClause) usedCols := []ColumnUsed{} queryParams := []QueryParam{} - for _, item := range stmt.TargetList.Items { - target, ok := item.(nodes.ResTarget) - if !ok { + for _, item := range stmt.TargetList { + if item.GetResTarget() == nil { continue } re := &ParseResult{} - err := parseExpression(ctx, target.Val, re) + err := parseExpression(ctx, item.GetResTarget().Val, re) if err != nil { return nil, err } @@ -494,7 +513,7 @@ func validateSelectStmt(ctx VetContext, stmt nodes.SelectStmt) ([]QueryParam, er } } - if len(stmt.GroupClause.Items) > 0 { + if len(stmt.GroupClause) > 0 { usedCols = append(usedCols, getUsedColumnsFromNodeList(stmt.GroupClause)...) } @@ -512,9 +531,10 @@ func validateSelectStmt(ctx VetContext, stmt nodes.SelectStmt) ([]QueryParam, er } } - if len(stmt.WindowClause.Items) > 0 { + if len(stmt.WindowClause) > 0 { re := &ParseResult{} - err := parseExpression(ctx, stmt.WindowClause, re) + // TODO: should this be [0]? + err := parseExpression(ctx, stmt.WindowClause[0], re) if err != nil { return nil, err } @@ -522,39 +542,37 @@ func validateSelectStmt(ctx VetContext, stmt nodes.SelectStmt) ([]QueryParam, er AddQueryParams(&queryParams, re.Params) } - if len(stmt.SortClause.Items) > 0 { + if len(stmt.SortClause) > 0 { usedCols = append(usedCols, getUsedColumnsFromSortClause(stmt.SortClause)...) } return queryParams, validateTableColumns(ctx, usedTables, usedCols) } -func validateUpdateStmt(ctx VetContext, stmt nodes.UpdateStmt) ([]QueryParam, error) { - tableName := *stmt.Relation.Relname +func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, error) { + tableName := stmt.Relation.Relname usedTables := []TableUsed{{Name: tableName}} usedTables = append(usedTables, getUsedTablesFromSelectStmt(stmt.FromClause)...) usedCols := []ColumnUsed{} queryParams := []QueryParam{} - for _, item := range stmt.TargetList.Items { - target := item.(nodes.ResTarget) + for _, item := range stmt.TargetList { + target := item.GetResTarget() usedCols = append(usedCols, ColumnUsed{ Table: tableName, - Column: *target.Name, + Column: target.Name, Location: target.Location, }) - // 'val' is the expression to assign. - switch expr := target.Val.(type) { - case nodes.ColumnRef: - // UPDATE table1 SET table1.foo=table2.bar FROM table2 - cu := columnRefToColumnUsed(expr) + switch { + case target.Val != nil && target.Val.GetColumnRef() != nil: + cu := columnRefToColumnUsed(target.Val.GetColumnRef()) if cu != nil { usedCols = append(usedCols, *cu) } - case nodes.ParamRef: - AddQueryParam(&queryParams, QueryParam{Number: expr.Number}) + case target.Val != nil && target.Val.GetParamRef() != nil: + AddQueryParam(&queryParams, QueryParam{Number: target.Val.GetParamRef().Number}) } } @@ -568,34 +586,34 @@ func validateUpdateStmt(ctx VetContext, stmt nodes.UpdateStmt) ([]QueryParam, er AddQueryParams(&queryParams, re.Params) } - if len(stmt.ReturningList.Items) > 0 { + if len(stmt.ReturningList) > 0 { usedCols = append(usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...) } return queryParams, validateTableColumns(ctx, usedTables, usedCols) } -func validateInsertStmt(ctx VetContext, stmt nodes.InsertStmt) ([]QueryParam, error) { - tableName := *stmt.Relation.Relname +func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, error) { + tableName := stmt.Relation.Relname usedTables := []TableUsed{{Name: tableName}} targetCols := []ColumnUsed{} - for _, item := range stmt.Cols.Items { - target := item.(nodes.ResTarget) + for _, item := range stmt.Cols { + target := item.GetResTarget() targetCols = append(targetCols, ColumnUsed{ Table: tableName, - Column: *target.Name, + Column: target.Name, Location: target.Location, }) } - values := []nodes.Node{} + values := []*pg_query.Node{} // make a copy of targetCols because we need it to do value count // validation separately usedCols := append([]ColumnUsed{}, targetCols...) queryParams := []QueryParam{} - selectStmt := stmt.SelectStmt.(nodes.SelectStmt) + selectStmt := stmt.GetSelectStmt().GetSelectStmt() if selectStmt.ValuesLists != nil { /* * In the form of: @@ -608,19 +626,21 @@ func validateInsertStmt(ctx VetContext, stmt nodes.InsertStmt) ([]QueryParam, er * node), regardless of the context of the VALUES list. It's up to parse * analysis to reject that where not valid. */ - for _, node := range selectStmt.ValuesLists[0] { - re := &ParseResult{} - err := parseExpression(ctx, node, re) - if err != nil { - return nil, fmt.Errorf("Invalid value list: %w", err) - } - if len(re.Columns) > 0 { - usedCols = append(usedCols, re.Columns...) - } - if len(re.Params) > 0 { - AddQueryParams(&queryParams, re.Params) + for _, list := range selectStmt.GetValuesLists() { + for _, node := range list.GetList().Items { + re := &ParseResult{} + err := parseExpression(ctx, node, re) + if err != nil { + return nil, fmt.Errorf("invalid value list: %w", err) + } + if len(re.Columns) > 0 { + usedCols = append(usedCols, re.Columns...) + } + if len(re.Params) > 0 { + AddQueryParams(&queryParams, re.Params) + } + values = append(values, node) } - values = append(values, node) } } else { /* @@ -646,26 +666,26 @@ func validateInsertStmt(ctx VetContext, stmt nodes.InsertStmt) ([]QueryParam, er } } - for _, item := range selectStmt.TargetList.Items { - target := item.(nodes.ResTarget) + for _, item := range selectStmt.TargetList { + target := item.GetResTarget().Val values = append(values, target) - switch targetVal := target.Val.(type) { - case nodes.ColumnRef: - cu := columnRefToColumnUsed(targetVal) + switch { + case target.GetColumnRef() != nil: + cu := columnRefToColumnUsed(target.GetColumnRef()) if cu == nil { continue } usedCols = append(usedCols, *cu) - case nodes.SubLink: - subquery, ok := targetVal.Subselect.(nodes.SelectStmt) - if !ok { + case target.GetSubLink() != nil: + tv := target.GetSubLink().Subselect + if tv.GetSelectStmt() == nil { return nil, fmt.Errorf( - "Unsupported subquery type in value list: %s", reflect.TypeOf(targetVal.Subselect)) + "unsupported subquery type in value list: %s", reflect.TypeOf(tv)) } - qparams, err := validateSelectStmt(ctx, subquery) + qparams, err := validateSelectStmt(ctx, tv.GetSelectStmt()) if err != nil { - return nil, fmt.Errorf("Invalid SELECT query in value list: %w", err) + return nil, fmt.Errorf("invalid SELECT query in value list: %w", err) } if len(qparams) > 0 { AddQueryParams(&queryParams, qparams) @@ -674,7 +694,7 @@ func validateInsertStmt(ctx VetContext, stmt nodes.InsertStmt) ([]QueryParam, er } } - if len(stmt.ReturningList.Items) > 0 { + if len(stmt.ReturningList) > 0 { usedCols = append(usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...) } @@ -689,8 +709,8 @@ func validateInsertStmt(ctx VetContext, stmt nodes.InsertStmt) ([]QueryParam, er return queryParams, nil } -func validateDeleteStmt(ctx VetContext, stmt nodes.DeleteStmt) ([]QueryParam, error) { - tableName := *stmt.Relation.Relname +func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, error) { + tableName := stmt.Relation.Relname if err := validateTable(ctx, tableName); err != nil { return nil, err } @@ -712,7 +732,7 @@ func validateDeleteStmt(ctx VetContext, stmt nodes.DeleteStmt) ([]QueryParam, er } } - if len(stmt.ReturningList.Items) > 0 { + if len(stmt.ReturningList) > 0 { usedCols = append( usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...) } @@ -733,32 +753,28 @@ func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) { return nil, err } - if len(tree.Statements) == 0 || len(tree.Statements) > 1 { - return nil, fmt.Errorf("query contained more than one statement.") - } - - raw, ok := tree.Statements[0].(nodes.RawStmt) - if !ok { - return nil, fmt.Errorf("query contained invalid statement.") - } - - switch stmt := raw.Stmt.(type) { - case nodes.SelectStmt: - return validateSelectStmt(ctx, stmt) - case nodes.UpdateStmt: - return validateUpdateStmt(ctx, stmt) - case nodes.InsertStmt: - return validateInsertStmt(ctx, stmt) - case nodes.DeleteStmt: - return validateDeleteStmt(ctx, stmt) - case nodes.DropStmt: - case nodes.TruncateStmt: - case nodes.AlterTableStmt: - case nodes.CreateSchemaStmt: - case nodes.VariableSetStmt: - // TODO: check for invalid pg variables + if len(tree.Stmts) == 0 || len(tree.Stmts) > 1 { + return nil, fmt.Errorf("query contained more than one statement") + } + + var raw *pg_query.RawStmt = tree.Stmts[0] + switch { + case raw.Stmt.GetSelectStmt() != nil: + return validateSelectStmt(ctx, raw.Stmt.GetSelectStmt()) + case raw.Stmt.GetUpdateStmt() != nil: + return validateUpdateStmt(ctx, raw.Stmt.GetUpdateStmt()) + case raw.Stmt.GetInsertStmt() != nil: + return validateInsertStmt(ctx, raw.Stmt.GetInsertStmt()) + case raw.Stmt.GetDeleteStmt() != nil: + return validateDeleteStmt(ctx, raw.Stmt.GetDeleteStmt()) + case raw.Stmt.GetDropStmt() != nil: + case raw.Stmt.GetTruncateStmt() != nil: + case raw.Stmt.GetAlterTableStmt() != nil: + case raw.Stmt.GetCreateSchemaStmt() != nil: + case raw.Stmt.GetVariableSetStmt() != nil: + // TODO: check for invalid pg variables default: - return nil, fmt.Errorf("unsupported statement: %v.", reflect.TypeOf(raw.Stmt)) + return nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(raw.Stmt)) } return nil, nil diff --git a/pkg/vet/vet_test.go b/pkg/vet/vet_test.go index 9a6b05e..89b487b 100644 --- a/pkg/vet/vet_test.go +++ b/pkg/vet/vet_test.go @@ -129,12 +129,12 @@ func TestInvalidInsert(t *testing.T) { { "not enough values", `INSERT INTO foo (id, value) VALUES ($1)`, - errors.New("Column count 2 doesn't match value count 1."), + errors.New("column count 2 doesn't match value count 1"), }, { "too many values", `INSERT INTO foo (id, value) VALUES ($1, $2, $3)`, - errors.New("Column count 2 doesn't match value count 3."), + errors.New("column count 2 doesn't match value count 3"), }, { "invalid column in value list", @@ -177,7 +177,7 @@ func TestInvalidInsert(t *testing.T) { FROM bar WHERE bar.id=2`, fmt.Errorf( - "Invalid SELECT query in value list: %w", + "invalid SELECT query in value list: %w", errors.New("column `ida` is not defined in table `bar`")), }, { @@ -210,7 +210,7 @@ func TestInvalidInsert(t *testing.T) { 'test' )`, fmt.Errorf( - "Invalid value list: %w", + "invalid value list: %w", errors.New("column `ida` is not defined in table `bar`")), }, { @@ -578,14 +578,14 @@ func TestInvalidDelete(t *testing.T) { "invalid column in where subquery", `DELETE FROM foo WHERE id = (SELECT id FROM foo WHERE date=NOW())`, fmt.Errorf( - "Invalid WHERE clause: %w", + "invalid WHERE clause: %w", errors.New("column `date` is not defined in table `foo`")), }, { "invalid table in where subquery", `DELETE FROM foo WHERE id = (SELECT id FROM foononexist WHERE id=1)`, fmt.Errorf( - "Invalid WHERE clause: %w", + "invalid WHERE clause: %w", errors.New("invalid table name: foononexist")), }, {