From f25115c8da0abe35a7d0b7f054f7f7544bb61111 Mon Sep 17 00:00:00 2001 From: Don Browne Date: Thu, 11 Jul 2024 13:04:45 +0100 Subject: [PATCH] Fix issues with mapping of `TIMESTAMPZ[]` postgres type Prior to this PR, sqlc was unable to map `TIMESTAMPZ[]` to `[]time.Time` due a limitation in the `pq` driver. This PR leverages a suggested fix from a Github issue in the `pq` repo. --- internal/db/eval_history.sql.go | 6 +-- internal/db/models.go | 2 +- internal/db/types.go | 76 ++++++++++++++++++++++++++++++++ internal/history/service.go | 4 +- internal/history/service_test.go | 6 +-- sqlc.yaml | 4 ++ 6 files changed, 89 insertions(+), 9 deletions(-) create mode 100644 internal/db/types.go diff --git a/internal/db/eval_history.sql.go b/internal/db/eval_history.sql.go index 431fd59d5b..fedddf14ef 100644 --- a/internal/db/eval_history.sql.go +++ b/internal/db/eval_history.sql.go @@ -62,7 +62,7 @@ func (q *Queries) GetLatestEvalStateForRuleEntity(ctx context.Context, arg GetLa &i.RuleEntityID, &i.Status, &i.Details, - pq.Array(&i.EvaluationTimes), + &i.EvaluationTimes, &i.MostRecentEvaluation, ) return i, err @@ -375,12 +375,12 @@ WHERE id = $2 ` type UpdateEvaluationTimesParams struct { - EvaluationTimes []time.Time `json:"evaluation_times"` + EvaluationTimes PgTimeArray `json:"evaluation_times"` ID uuid.UUID `json:"id"` } func (q *Queries) UpdateEvaluationTimes(ctx context.Context, arg UpdateEvaluationTimesParams) error { - _, err := q.db.ExecContext(ctx, updateEvaluationTimes, pq.Array(arg.EvaluationTimes), arg.ID) + _, err := q.db.ExecContext(ctx, updateEvaluationTimes, arg.EvaluationTimes, arg.ID) return err } diff --git a/internal/db/models.go b/internal/db/models.go index d916984c7f..f0b99ed79e 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -494,7 +494,7 @@ type EvaluationStatus struct { RuleEntityID uuid.UUID `json:"rule_entity_id"` Status EvalStatusTypes `json:"status"` Details string `json:"details"` - EvaluationTimes []time.Time `json:"evaluation_times"` + EvaluationTimes PgTimeArray `json:"evaluation_times"` MostRecentEvaluation time.Time `json:"most_recent_evaluation"` } diff --git a/internal/db/types.go b/internal/db/types.go new file mode 100644 index 0000000000..4012e4114f --- /dev/null +++ b/internal/db/types.go @@ -0,0 +1,76 @@ +// Copyright 2024 Stacklok, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// PgTime/PgTimeArray is used to work around the pq driver's inability to map +// TIMESTAMPZ[] to []time.Time without some hand holding. +// Code taken mostly as-is from: https://github.com/lib/pq/issues/536#issuecomment-397849980 + +package db + +import ( + "database/sql/driver" + "errors" + "time" + + "github.com/lib/pq" +) + +// ErrParseData signifies that an error occurred while parsing SQL data +var ErrParseData = errors.New("unable to parse SQL data") + +// PgTime wraps a time.Time +type PgTime struct{ time.Time } + +// Scan implements the sql.Scanner interface +func (t *PgTime) Scan(val interface{}) error { + switch v := val.(type) { + case time.Time: + t.Time = v + return nil + case []uint8: // byte is the same as uint8: https://golang.org/pkg/builtin/#byte + _t, err := pq.ParseTimestamp(nil, string(v)) + if err != nil { + return ErrParseData + } + t.Time = _t + return nil + case string: + _t, err := pq.ParseTimestamp(nil, v) + if err != nil { + return ErrParseData + } + t.Time = _t + return nil + } + return ErrParseData +} + +// Value implements the driver.Valuer interface +func (t *PgTime) Value() (driver.Value, error) { return pq.FormatTimestamp(t.Time), nil } + +// PgTimeArray wraps a time.Time slice to be used as a Postgres array +// type PgTimeArray []time.Time +type PgTimeArray []PgTime + +//type PgTimeArray []pq.NullTime + +// Scan implements the sql.Scanner interface +func (a *PgTimeArray) Scan(src interface{}) error { + return pq.GenericArray{A: a}.Scan(src) +} + +// Value implements the driver.Valuer interface +func (a *PgTimeArray) Value() (driver.Value, error) { + return pq.GenericArray{A: a}.Value() +} diff --git a/internal/history/service.go b/internal/history/service.go index 8c3982bcc3..5a48e86a1f 100644 --- a/internal/history/service.go +++ b/internal/history/service.go @@ -162,10 +162,10 @@ func (_ *evaluationHistoryService) updateExistingStatus( ctx context.Context, qtx db.Querier, evaluationID uuid.UUID, - times []time.Time, + times db.PgTimeArray, ) error { // if the status is repeated, then just append the current timestamp to it - times = append(times, time.Now()) + times = append(times, db.PgTime{Time: time.Now()}) return qtx.UpdateEvaluationTimes(ctx, db.UpdateEvaluationTimesParams{ EvaluationTimes: times, ID: evaluationID, diff --git a/internal/history/service_test.go b/internal/history/service_test.go index bd575159e2..60fc9909bb 100644 --- a/internal/history/service_test.go +++ b/internal/history/service_test.go @@ -747,21 +747,21 @@ var ( RuleEntityID: ruleEntityID, Status: db.EvalStatusTypesError, Details: errTest.Error(), - EvaluationTimes: []time.Time{time.Now()}, + EvaluationTimes: db.PgTimeArray{db.PgTime{Time: time.Now()}}, } differentDetails = db.EvaluationStatus{ ID: evaluationID, RuleEntityID: ruleEntityID, Status: db.EvalStatusTypesError, Details: "something went wrong", - EvaluationTimes: []time.Time{time.Now()}, + EvaluationTimes: db.PgTimeArray{db.PgTime{Time: time.Now()}}, } differentState = db.EvaluationStatus{ ID: evaluationID, RuleEntityID: ruleEntityID, Status: db.EvalStatusTypesSkipped, Details: engerr.ErrEvaluationSkipped.Error(), - EvaluationTimes: []time.Time{time.Now()}, + EvaluationTimes: db.PgTimeArray{db.PgTime{Time: time.Now()}}, } errTest = errors.New("oh no") ) diff --git a/sqlc.yaml b/sqlc.yaml index b06c009266..bbf74cde10 100644 --- a/sqlc.yaml +++ b/sqlc.yaml @@ -31,3 +31,7 @@ sql: - db_type: profile_selector go_type: type: "ProfileSelector" + - column: "evaluation_statuses.evaluation_times" + go_type: + type: "PgTimeArray" +