From 286e96f1627111033c8b64a2425facac5a7de2d0 Mon Sep 17 00:00:00 2001 From: Mitchell Adair Date: Mon, 16 Dec 2024 18:38:49 -0500 Subject: [PATCH] implement vector type, start tests --- drizzle-kit/src/introspect-singlestore.ts | 11 +++ .../src/serializer/singlestoreSerializer.ts | 2 +- drizzle-kit/tests/push/singlestore.test.ts | 8 ++ .../src/singlestore-core/columns/all.ts | 2 + .../src/singlestore-core/columns/index.ts | 1 + .../src/singlestore-core/columns/vector.ts | 80 +++++++++++++++++++ .../src/singlestore-core/expressions.ts | 9 +++ drizzle-orm/type-tests/singlestore/tables.ts | 5 ++ .../tests/singlestore/singlestore-common.ts | 43 ++++++++++ 9 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 drizzle-orm/src/singlestore-core/columns/vector.ts diff --git a/drizzle-kit/src/introspect-singlestore.ts b/drizzle-kit/src/introspect-singlestore.ts index 09c2feec0..8f9c98acd 100644 --- a/drizzle-kit/src/introspect-singlestore.ts +++ b/drizzle-kit/src/introspect-singlestore.ts @@ -49,6 +49,7 @@ const singlestoreImportsList = new Set([ 'tinyint', 'varbinary', 'varchar', + 'vector', 'year', 'enum', ]); @@ -789,6 +790,16 @@ const column = ( return out; } + if (lowered.startsWith('vector')) { + const [dimensions, elementType] = lowered.substring('vector'.length + 1, lowered.length - 1).split(','); + let out = `${casing(name)}: vector(${ + dbColumnName({ name, casing: rawCasing, withMode: true }) + }{ dimensions: ${dimensions}${elementType ? `, elementType: ${elementType}` : ''} })`; + + out += defaultValue ? `.default(${mapColumnDefault(defaultValue, isExpression)})` : ''; + return out; + } + console.log('uknown', type); return `// Warning: Can't parse ${type} from database\n\t// ${type}Type: ${type}("${name}")`; }; diff --git a/drizzle-kit/src/serializer/singlestoreSerializer.ts b/drizzle-kit/src/serializer/singlestoreSerializer.ts index e8c89f1d1..e65f53d25 100644 --- a/drizzle-kit/src/serializer/singlestoreSerializer.ts +++ b/drizzle-kit/src/serializer/singlestoreSerializer.ts @@ -130,7 +130,7 @@ export const generateSingleStoreSnapshot = ( if (typeof column.default === 'string') { columnToSet.default = `'${column.default}'`; } else { - if (sqlTypeLowered === 'json') { + if (sqlTypeLowered === 'json' || Array.isArray(column.default)) { columnToSet.default = `'${JSON.stringify(column.default)}'`; } else if (column.default instanceof Date) { if (sqlTypeLowered === 'date') { diff --git a/drizzle-kit/tests/push/singlestore.test.ts b/drizzle-kit/tests/push/singlestore.test.ts index 82c72063c..dea28759c 100644 --- a/drizzle-kit/tests/push/singlestore.test.ts +++ b/drizzle-kit/tests/push/singlestore.test.ts @@ -23,6 +23,7 @@ import { tinyint, varbinary, varchar, + vector, year, } from 'drizzle-orm/singlestore-core'; import getPort from 'get-port'; @@ -249,6 +250,13 @@ const singlestoreSuite: DialectSuite = { columnNotNull: binary('column_not_null', { length: 1 }).notNull(), columnDefault: binary('column_default', { length: 12 }), }), + + allVectors: singlestoreTable('all_vectors', { + vectorSimple: vector('vector_simple', { dimensions: 1 }), + vectorElementType: vector('vector_element_type', { dimensions: 1, elementType: 'I8' }), + vectorNotNull: vector('vector_not_null', { dimensions: 1 }).notNull(), + vectorDefault: vector('vector_default', { dimensions: 1 }).default([1]), + }), }; const { statements } = await diffTestSchemasPushSingleStore( diff --git a/drizzle-orm/src/singlestore-core/columns/all.ts b/drizzle-orm/src/singlestore-core/columns/all.ts index 66d289e3f..dc8ae407a 100644 --- a/drizzle-orm/src/singlestore-core/columns/all.ts +++ b/drizzle-orm/src/singlestore-core/columns/all.ts @@ -21,6 +21,7 @@ import { timestamp } from './timestamp.ts'; import { tinyint } from './tinyint.ts'; import { varbinary } from './varbinary.ts'; import { varchar } from './varchar.ts'; +import { vector } from './vector.ts'; import { year } from './year.ts'; export function getSingleStoreColumnBuilders() { @@ -48,6 +49,7 @@ export function getSingleStoreColumnBuilders() { tinyint, varbinary, varchar, + vector, year, }; } diff --git a/drizzle-orm/src/singlestore-core/columns/index.ts b/drizzle-orm/src/singlestore-core/columns/index.ts index b51f0fac4..ec17fa21a 100644 --- a/drizzle-orm/src/singlestore-core/columns/index.ts +++ b/drizzle-orm/src/singlestore-core/columns/index.ts @@ -22,4 +22,5 @@ export * from './timestamp.ts'; export * from './tinyint.ts'; export * from './varbinary.ts'; export * from './varchar.ts'; +export * from './vector.ts'; export * from './year.ts'; diff --git a/drizzle-orm/src/singlestore-core/columns/vector.ts b/drizzle-orm/src/singlestore-core/columns/vector.ts new file mode 100644 index 000000000..dee6ba9e9 --- /dev/null +++ b/drizzle-orm/src/singlestore-core/columns/vector.ts @@ -0,0 +1,80 @@ +import type { ColumnBaseConfig } from '~/column'; +import type { ColumnBuilderBaseConfig, ColumnBuilderRuntimeConfig, MakeColumnConfig } from '~/column-builder'; +import { entityKind } from '~/entity.ts'; +import type { AnySingleStoreTable } from '~/singlestore-core/table.ts'; +import { getColumnNameAndConfig } from '~/utils.ts'; +import { SingleStoreColumn, SingleStoreColumnBuilder } from './common.ts'; + +export type SingleStoreVectorBuilderInitial = SingleStoreVectorBuilder<{ + name: TName; + dataType: 'array'; + columnType: 'SingleStoreVector'; + data: Array; + driverParam: Array; + enumValues: undefined; + generated: undefined; +}>; + +export class SingleStoreVectorBuilder> + extends SingleStoreColumnBuilder +{ + static override readonly [entityKind]: string = 'SingleStoreVectorBuilder'; + + constructor(name: T['name'], config: SingleStoreVectorConfig) { + super(name, 'array', 'SingleStoreVector'); + this.config.dimensions = config.dimensions; + this.config.elementType = config.elementType; + } + + /** @internal */ + override build( + table: AnySingleStoreTable<{ name: TTableName }>, + ): SingleStoreVector> { + return new SingleStoreVector(table, this.config as ColumnBuilderRuntimeConfig); + } +} + +export class SingleStoreVector> extends SingleStoreColumn { + static override readonly [entityKind]: string = 'SingleStoreVector'; + + readonly dimensions: number; + readonly elementType: ElementType | undefined; + + constructor(table: AnySingleStoreTable<{ name: T['tableName'] }>, config: SingleStoreVectorBuilder['config']) { + super(table, config); + this.dimensions = config.dimensions; + this.elementType = config.elementType; + } + + getSQLType(): string { + const et = this.elementType === undefined ? '' : `, ${this.elementType}`; + return `vector(${this.dimensions}${et})`; + } + + override mapToDriverValue(value: Array) { + return JSON.stringify(value); + } + + override mapFromDriverValue(value: string): Array { + return JSON.parse(value); + } +} + +type ElementType = 'I8' | 'I16' | 'I32' | 'I64' | 'F32' | 'F64'; + +export interface SingleStoreVectorConfig { + dimensions: number; + elementType?: ElementType; +} + +export function vector( + config: SingleStoreVectorConfig, +): SingleStoreVectorBuilderInitial<''>; +export function vector( + name: TName, + config: SingleStoreVectorConfig, +): SingleStoreVectorBuilderInitial; +export function vector(a: string | SingleStoreVectorConfig, b?: SingleStoreVectorConfig) { + const { name, config } = getColumnNameAndConfig(a, b); + return new SingleStoreVectorBuilder(name, config); +} diff --git a/drizzle-orm/src/singlestore-core/expressions.ts b/drizzle-orm/src/singlestore-core/expressions.ts index 6d4284d18..397e87392 100644 --- a/drizzle-orm/src/singlestore-core/expressions.ts +++ b/drizzle-orm/src/singlestore-core/expressions.ts @@ -23,3 +23,12 @@ export function substring( chunks.push(sql`)`); return sql.join(chunks); } + +// Vectors +export function dotProduct(column: SingleStoreColumn | SQL.Aliased, value: Array) { + return sql`${column} <*> ${JSON.stringify(value)}`; +} + +export function euclideanDistance(column: SingleStoreColumn | SQL.Aliased, value: Array) { + return sql`${column} <-> ${JSON.stringify(value)}`; +} diff --git a/drizzle-orm/type-tests/singlestore/tables.ts b/drizzle-orm/type-tests/singlestore/tables.ts index 73d9c6993..fb02eb774 100644 --- a/drizzle-orm/type-tests/singlestore/tables.ts +++ b/drizzle-orm/type-tests/singlestore/tables.ts @@ -34,6 +34,7 @@ import { uniqueIndex, varbinary, varchar, + vector, year, } from '~/singlestore-core/index.ts'; import { singlestoreSchema } from '~/singlestore-core/schema.ts'; @@ -917,6 +918,8 @@ Expect< varchar: varchar('varchar', { length: 1 }), varchar2: varchar('varchar2', { length: 1, enum: ['a', 'b', 'c'] }), varchardef: varchar('varchardef', { length: 1 }).default(''), + vector: vector('vector', { dimensions: 1 }), + vector2: vector('vector2', { dimensions: 1, elementType: 'I8' }), year: year('year'), yeardef: year('yeardef').default(0), }); @@ -1015,6 +1018,8 @@ Expect< varchar: varchar({ length: 1 }), varchar2: varchar({ length: 1, enum: ['a', 'b', 'c'] }), varchardef: varchar({ length: 1 }).default(''), + vector: vector({ dimensions: 1 }), + vector2: vector({ dimensions: 1, elementType: 'I8' }), year: year(), yeardef: year().default(0), }); diff --git a/integration-tests/tests/singlestore/singlestore-common.ts b/integration-tests/tests/singlestore/singlestore-common.ts index fe7c2afb4..5c5d357bf 100644 --- a/integration-tests/tests/singlestore/singlestore-common.ts +++ b/integration-tests/tests/singlestore/singlestore-common.ts @@ -58,8 +58,10 @@ import { uniqueIndex, uniqueKeyName, varchar, + vector, year, } from 'drizzle-orm/singlestore-core'; +import { euclideanDistance, dotProduct } from 'drizzle-orm/singlestore-core/expressions'; import { migrate } from 'drizzle-orm/singlestore/migrator'; import getPort from 'get-port'; import { v4 as uuid } from 'uuid'; @@ -156,6 +158,12 @@ const aggregateTable = singlestoreTable('aggregate_table', { nullOnly: int('null_only'), }); +const vectorSearchTable = singlestoreTable('vector_search', { + id: serial('id').notNull(), + text: text('text').notNull(), + embedding: vector('embedding', { dimensions: 10 }), +}); + // To test another schema and multischema const mySchema = singlestoreSchema(`mySchema`); @@ -366,6 +374,23 @@ export function tests(driver?: string) { ]); } + async function setupVectorSearchTest(db: TestSingleStoreDB) { + await db.execute(sql`drop table if exists \`vector_search\``); + await db.execute( + sql` + create table \`vector_search\` ( + \`id\` integer primary key auto_increment not null, + \`text\` text not null, + \`embedding\` vector(10) not null + ) + `, + ) + await db.insert(vectorSearchTable).values([ + { id: 1, text: "I like dogs", embedding: [0.6119,0.1395,0.2921,0.3664,0.4561,0.7852,0.1997,0.5142,0.5924,0.0465] }, + { id: 2, text: "I like cats", embedding: [0.6075,0.1705,0.0651,0.9489,0.9656,0.8084,0.3046,0.0977,0.6842,0.4402] } + ]) + } + test('table config: unsigned ints', async () => { const unsignedInts = singlestoreTable('cities1', { bigint: bigint('bigint', { mode: 'number', unsigned: true }), @@ -2907,6 +2932,24 @@ export function tests(driver?: string) { expect(result2[0]?.value).toBe(null); }); + test('simple vector search', async (ctx) => { + const { db } = ctx.singlestore; + const table = vectorSearchTable; + const embedding = [0.42,0.93,0.88,0.57,0.32,0.64,0.76,0.52,0.19,0.81]; // ChatGPT's 10 dimension embedding for "dogs are cool" + await setupVectorSearchTest(db); + + const withRankEuclidean = db.select({ id: table.id, text: table.text, rank: sql`row_number() over (order by ${euclideanDistance(table.embedding, embedding)})`.as('rank') }).from(table).as('with_rank') + const withRankDotProduct = db.select({ id: table.id, text: table.text, rank: sql`row_number() over (order by ${dotProduct(table.embedding, embedding)})`.as('rank') }).from(table).as('with_rank') + const result1 = await db.select({ id: withRankEuclidean.id, text: withRankEuclidean.text }).from(withRankEuclidean).where(eq(withRankEuclidean.rank, 1)); + const result2 = await db.select({ id: withRankDotProduct.id, text: withRankDotProduct.text }).from(withRankDotProduct).where(eq(withRankDotProduct.rank, 1)); + + expect(result1.length).toEqual(1) + expect(result1[0]).toEqual({ id: 1, text: "I like dogs" }); + + expect(result2.length).toEqual(1) + expect(result2[0]).toEqual({ id: 1, text: "I like dogs" }); + }) + test('test $onUpdateFn and $onUpdate works as $default', async (ctx) => { const { db } = ctx.singlestore;