Skip to content

Commit

Permalink
implement vector type, start tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mitchwadair committed Dec 16, 2024
1 parent fe986e6 commit 286e96f
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 1 deletion.
11 changes: 11 additions & 0 deletions drizzle-kit/src/introspect-singlestore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const singlestoreImportsList = new Set([
'tinyint',
'varbinary',
'varchar',
'vector',
'year',
'enum',
]);
Expand Down Expand Up @@ -789,6 +790,16 @@ const column = (
return out;
}

if (lowered.startsWith('vector')) {
	// The type string has the shape `vector(<dimensions>[, <elementType>])`; take the text between
	// the parentheses and trim each part, because the database may put a space after the comma.
	const [dimensions, rawElementType] = lowered
		.substring('vector'.length + 1, lowered.length - 1)
		.split(',')
		.map((part) => part.trim());

	// The generated schema must contain a quoted, upper-case string literal (e.g. 'I8', 'F32');
	// emitting the raw fragment would produce an undefined identifier such as ` i8`.
	const elementTypeOption = rawElementType ? `, elementType: '${rawElementType.toUpperCase()}'` : '';

	let out = `${casing(name)}: vector(${
		dbColumnName({ name, casing: rawCasing, withMode: true })
	}{ dimensions: ${dimensions}${elementTypeOption} })`;

	out += defaultValue ? `.default(${mapColumnDefault(defaultValue, isExpression)})` : '';
	return out;
}

console.log('uknown', type);
return `// Warning: Can't parse ${type} from database\n\t// ${type}Type: ${type}("${name}")`;
};
Expand Down
2 changes: 1 addition & 1 deletion drizzle-kit/src/serializer/singlestoreSerializer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ export const generateSingleStoreSnapshot = (
if (typeof column.default === 'string') {
columnToSet.default = `'${column.default}'`;
} else {
if (sqlTypeLowered === 'json') {
if (sqlTypeLowered === 'json' || Array.isArray(column.default)) {
columnToSet.default = `'${JSON.stringify(column.default)}'`;
} else if (column.default instanceof Date) {
if (sqlTypeLowered === 'date') {
Expand Down
8 changes: 8 additions & 0 deletions drizzle-kit/tests/push/singlestore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
tinyint,
varbinary,
varchar,
vector,
year,
} from 'drizzle-orm/singlestore-core';
import getPort from 'get-port';
Expand Down Expand Up @@ -249,6 +250,13 @@ const singlestoreSuite: DialectSuite = {
columnNotNull: binary('column_not_null', { length: 1 }).notNull(),
columnDefault: binary('column_default', { length: 12 }),
}),

allVectors: singlestoreTable('all_vectors', {
vectorSimple: vector('vector_simple', { dimensions: 1 }),
vectorElementType: vector('vector_element_type', { dimensions: 1, elementType: 'I8' }),
vectorNotNull: vector('vector_not_null', { dimensions: 1 }).notNull(),
vectorDefault: vector('vector_default', { dimensions: 1 }).default([1]),
}),
};

const { statements } = await diffTestSchemasPushSingleStore(
Expand Down
2 changes: 2 additions & 0 deletions drizzle-orm/src/singlestore-core/columns/all.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { timestamp } from './timestamp.ts';
import { tinyint } from './tinyint.ts';
import { varbinary } from './varbinary.ts';
import { varchar } from './varchar.ts';
import { vector } from './vector.ts';
import { year } from './year.ts';

export function getSingleStoreColumnBuilders() {
Expand Down Expand Up @@ -48,6 +49,7 @@ export function getSingleStoreColumnBuilders() {
tinyint,
varbinary,
varchar,
vector,
year,
};
}
Expand Down
1 change: 1 addition & 0 deletions drizzle-orm/src/singlestore-core/columns/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ export * from './timestamp.ts';
export * from './tinyint.ts';
export * from './varbinary.ts';
export * from './varchar.ts';
export * from './vector.ts';
export * from './year.ts';
80 changes: 80 additions & 0 deletions drizzle-orm/src/singlestore-core/columns/vector.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import type { ColumnBaseConfig } from '~/column';
import type { ColumnBuilderBaseConfig, ColumnBuilderRuntimeConfig, MakeColumnConfig } from '~/column-builder';
import { entityKind } from '~/entity.ts';
import type { AnySingleStoreTable } from '~/singlestore-core/table.ts';
import { getColumnNameAndConfig } from '~/utils.ts';
import { SingleStoreColumn, SingleStoreColumnBuilder } from './common.ts';

/**
 * Shape of the builder returned by `vector()` before any modifiers are chained.
 *
 * `data` is the application-facing value type. `driverParam` is `string` because the column
 * serializes values with `JSON.stringify` in `mapToDriverValue` and reads them back with
 * `JSON.parse` from a string in `mapFromDriverValue`.
 */
export type SingleStoreVectorBuilderInitial<TName extends string> = SingleStoreVectorBuilder<{
	name: TName;
	dataType: 'array';
	columnType: 'SingleStoreVector';
	data: Array<number>;
	driverParam: string;
	enumValues: undefined;
	generated: undefined;
}>;

/**
 * Column builder for SingleStore `vector` columns.
 *
 * Copies `dimensions` and `elementType` from the user-supplied config onto the builder's
 * runtime config so that the built column can read them.
 */
export class SingleStoreVectorBuilder<T extends ColumnBuilderBaseConfig<'array', 'SingleStoreVector'>>
	extends SingleStoreColumnBuilder<T, SingleStoreVectorConfig>
{
	static override readonly [entityKind]: string = 'SingleStoreVectorBuilder';

	constructor(name: T['name'], config: SingleStoreVectorConfig) {
		super(name, 'array', 'SingleStoreVector');
		const { dimensions, elementType } = config;
		this.config.dimensions = dimensions;
		this.config.elementType = elementType;
	}

	/** @internal */
	override build<TTableName extends string>(
		table: AnySingleStoreTable<{ name: TTableName }>,
	): SingleStoreVector<MakeColumnConfig<T, TTableName>> {
		return new SingleStoreVector<MakeColumnConfig<T, TTableName>>(
			table,
			this.config as ColumnBuilderRuntimeConfig<any, any>,
		);
	}
}

/**
 * SingleStore `vector` column.
 *
 * Values are `Array<number>` in application code and travel to and from the driver as JSON text.
 */
export class SingleStoreVector<T extends ColumnBaseConfig<'array', 'SingleStoreVector'>> extends SingleStoreColumn<T> {
	static override readonly [entityKind]: string = 'SingleStoreVector';

	/** Number of elements, emitted as the first argument of the SQL type. */
	readonly dimensions: number;
	/** Optional element type, emitted as the second argument of the SQL type when set. */
	readonly elementType: ElementType | undefined;

	constructor(table: AnySingleStoreTable<{ name: T['tableName'] }>, config: SingleStoreVectorBuilder<T>['config']) {
		super(table, config);
		const { dimensions, elementType } = config;
		this.dimensions = dimensions;
		this.elementType = elementType;
	}

	/** Returns `vector(<dimensions>)`, or `vector(<dimensions>, <elementType>)` when an element type is set. */
	getSQLType(): string {
		let args = `${this.dimensions}`;
		if (this.elementType !== undefined) {
			args += `, ${this.elementType}`;
		}
		return `vector(${args})`;
	}

	/** Serializes the array to its JSON text form before it is handed to the driver. */
	override mapToDriverValue(value: Array<number>) {
		const serialized = JSON.stringify(value);
		return serialized;
	}

	/** Parses the JSON text received from the driver back into an array of numbers. */
	override mapFromDriverValue(value: string): Array<number> {
		const parsed = JSON.parse(value);
		return parsed;
	}
}

/**
 * Identifiers accepted for the `elementType` option of a vector column.
 * NOTE(review): presumably 8/16/32/64-bit integers and 32/64-bit floats — confirm against the
 * SingleStore VECTOR type documentation.
 */
type ElementType = 'I8' | 'I16' | 'I32' | 'I64' | 'F32' | 'F64';

/** Options accepted by `vector()`. */
export interface SingleStoreVectorConfig {
/** Number of elements in the vector; emitted as the first argument of the `vector(...)` SQL type. */
dimensions: number;
/** Optional element type; when omitted, the generated SQL type contains only the dimensions. */
elementType?: ElementType;
}

/**
 * Creates a builder for a SingleStore `vector` column.
 *
 * Accepts either `(config)` or `(name, config)`; the arguments are normalised through
 * `getColumnNameAndConfig` before the builder is constructed.
 *
 * NOTE(review): the `U` type parameter on the config-only overload is not referenced anywhere
 * in its signature — consider removing it in a follow-up.
 */
export function vector<U extends string>(
	config: SingleStoreVectorConfig,
): SingleStoreVectorBuilderInitial<''>;
export function vector<TName extends string>(
	name: TName,
	config: SingleStoreVectorConfig,
): SingleStoreVectorBuilderInitial<TName>;
export function vector(a: string | SingleStoreVectorConfig, b?: SingleStoreVectorConfig) {
	const resolved = getColumnNameAndConfig<SingleStoreVectorConfig>(a, b);
	return new SingleStoreVectorBuilder(resolved.name, resolved.config);
}
9 changes: 9 additions & 0 deletions drizzle-orm/src/singlestore-core/expressions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,12 @@ export function substring(
chunks.push(sql`)`);
return sql.join(chunks);
}

// Vectors
/**
 * Builds the SQL fragment `<column> <*> <value>`.
 *
 * @param column Column or aliased expression used as the left operand.
 * @param value Vector used as the right operand; it is passed as its JSON text form.
 */
export function dotProduct(column: SingleStoreColumn | SQL.Aliased, value: Array<number>) {
	const serialized = JSON.stringify(value);
	return sql`${column} <*> ${serialized}`;
}

/**
 * Builds the SQL fragment `<column> <-> <value>`.
 *
 * @param column Column or aliased expression used as the left operand.
 * @param value Vector used as the right operand; it is passed as its JSON text form.
 */
export function euclideanDistance(column: SingleStoreColumn | SQL.Aliased, value: Array<number>) {
	const serialized = JSON.stringify(value);
	return sql`${column} <-> ${serialized}`;
}
5 changes: 5 additions & 0 deletions drizzle-orm/type-tests/singlestore/tables.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import {
uniqueIndex,
varbinary,
varchar,
vector,
year,
} from '~/singlestore-core/index.ts';
import { singlestoreSchema } from '~/singlestore-core/schema.ts';
Expand Down Expand Up @@ -917,6 +918,8 @@ Expect<
varchar: varchar('varchar', { length: 1 }),
varchar2: varchar('varchar2', { length: 1, enum: ['a', 'b', 'c'] }),
varchardef: varchar('varchardef', { length: 1 }).default(''),
vector: vector('vector', { dimensions: 1 }),
vector2: vector('vector2', { dimensions: 1, elementType: 'I8' }),
year: year('year'),
yeardef: year('yeardef').default(0),
});
Expand Down Expand Up @@ -1015,6 +1018,8 @@ Expect<
varchar: varchar({ length: 1 }),
varchar2: varchar({ length: 1, enum: ['a', 'b', 'c'] }),
varchardef: varchar({ length: 1 }).default(''),
vector: vector({ dimensions: 1 }),
vector2: vector({ dimensions: 1, elementType: 'I8' }),
year: year(),
yeardef: year().default(0),
});
Expand Down
43 changes: 43 additions & 0 deletions integration-tests/tests/singlestore/singlestore-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ import {
uniqueIndex,
uniqueKeyName,
varchar,
vector,
year,
} from 'drizzle-orm/singlestore-core';
import { euclideanDistance, dotProduct } from 'drizzle-orm/singlestore-core/expressions';
import { migrate } from 'drizzle-orm/singlestore/migrator';
import getPort from 'get-port';
import { v4 as uuid } from 'uuid';
Expand Down Expand Up @@ -156,6 +158,12 @@ const aggregateTable = singlestoreTable('aggregate_table', {
nullOnly: int('null_only'),
});

// Table definition used by the 'simple vector search' test; the physical table is created with
// raw SQL in setupVectorSearchTest.
// NOTE(review): that raw DDL declares `embedding` as `not null`, while this definition leaves it
// nullable — confirm whether the mismatch is intentional.
const vectorSearchTable = singlestoreTable('vector_search', {
id: serial('id').notNull(),
text: text('text').notNull(),
embedding: vector('embedding', { dimensions: 10 }),
});

// To test another schema and multischema
const mySchema = singlestoreSchema(`mySchema`);

Expand Down Expand Up @@ -366,6 +374,23 @@ export function tests(driver?: string) {
]);
}

/** Drops and recreates the `vector_search` table, then seeds it with two rows of 10-dimension embeddings. */
async function setupVectorSearchTest(db: TestSingleStoreDB) {
	await db.execute(sql`drop table if exists \`vector_search\``);

	// The template text is kept flush-left so the emitted SQL string is unchanged.
	const createTable = sql`
create table \`vector_search\` (
\`id\` integer primary key auto_increment not null,
\`text\` text not null,
\`embedding\` vector(10) not null
)
`;
	await db.execute(createTable);

	const seedRows = [
		{
			id: 1,
			text: "I like dogs",
			embedding: [0.6119, 0.1395, 0.2921, 0.3664, 0.4561, 0.7852, 0.1997, 0.5142, 0.5924, 0.0465],
		},
		{
			id: 2,
			text: "I like cats",
			embedding: [0.6075, 0.1705, 0.0651, 0.9489, 0.9656, 0.8084, 0.3046, 0.0977, 0.6842, 0.4402],
		},
	];
	await db.insert(vectorSearchTable).values(seedRows);
}

test('table config: unsigned ints', async () => {
const unsignedInts = singlestoreTable('cities1', {
bigint: bigint('bigint', { mode: 'number', unsigned: true }),
Expand Down Expand Up @@ -2907,6 +2932,24 @@ export function tests(driver?: string) {
expect(result2[0]?.value).toBe(null);
});

// Ranks the seeded rows with row_number() ordered by each vector operator and checks that the
// rank-1 row is "I like dogs" in both cases.
test('simple vector search', async (ctx) => {
	const { db } = ctx.singlestore;
	const searchTable = vectorSearchTable;
	// ChatGPT's 10 dimension embedding for "dogs are cool"
	const queryEmbedding = [0.42, 0.93, 0.88, 0.57, 0.32, 0.64, 0.76, 0.52, 0.19, 0.81];
	await setupVectorSearchTest(db);

	const withRankEuclidean = db
		.select({
			id: searchTable.id,
			text: searchTable.text,
			rank: sql`row_number() over (order by ${euclideanDistance(searchTable.embedding, queryEmbedding)})`.as('rank'),
		})
		.from(searchTable)
		.as('with_rank');
	const withRankDotProduct = db
		.select({
			id: searchTable.id,
			text: searchTable.text,
			rank: sql`row_number() over (order by ${dotProduct(searchTable.embedding, queryEmbedding)})`.as('rank'),
		})
		.from(searchTable)
		.as('with_rank');

	const euclideanTop = await db
		.select({ id: withRankEuclidean.id, text: withRankEuclidean.text })
		.from(withRankEuclidean)
		.where(eq(withRankEuclidean.rank, 1));
	const dotProductTop = await db
		.select({ id: withRankDotProduct.id, text: withRankDotProduct.text })
		.from(withRankDotProduct)
		.where(eq(withRankDotProduct.rank, 1));

	expect(euclideanTop.length).toEqual(1);
	expect(euclideanTop[0]).toEqual({ id: 1, text: "I like dogs" });

	expect(dotProductTop.length).toEqual(1);
	expect(dotProductTop[0]).toEqual({ id: 1, text: "I like dogs" });
});

test('test $onUpdateFn and $onUpdate works as $default', async (ctx) => {
const { db } = ctx.singlestore;

Expand Down

0 comments on commit 286e96f

Please sign in to comment.