Skip to content

Commit

Permalink
Merge pull request #3169 from AnnoyingTechnology/241203-bedrock-reran…
Browse files Browse the repository at this point in the history
…king

Brings support for Bedrock's reranking models, Fixes #3152
  • Loading branch information
sestinj authored Dec 6, 2024
2 parents 7d30c7a + 5e46f83 commit f3d0399
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 37 deletions.
99 changes: 99 additions & 0 deletions core/context/rerankers/bedrock.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import {
BedrockRuntimeClient,
InvokeModelCommand,
} from "@aws-sdk/client-bedrock-runtime";
import { fromIni } from "@aws-sdk/credential-providers";
import { Chunk, Reranker, RerankerName } from "../../index.js";

/**
 * Reranker backed by AWS Bedrock reranking models, invoked through the
 * Bedrock Runtime `InvokeModel` command.
 *
 * Credentials are read from the shared AWS credentials file using the
 * configured profile (see {@link BedrockReranker.defaultOptions}).
 */
export class BedrockReranker implements Reranker {
  name: RerankerName = "bedrock";

  /** Fallback values used for any option not supplied to the constructor. */
  static defaultOptions = {
    region: "us-east-1",
    model: "amazon.rerank-v1:0",
    profile: "bedrock",
  };

  /** Model ids this reranker knows how to build a request payload for. */
  private supportedModels = ["amazon.rerank-v1:0", "cohere.rerank-v3-5:0"];

  /**
   * @param params Optional overrides for region, model id and AWS profile.
   * @throws Error if `params.model` is set to a model id that is not in the
   *   supported list.
   */
  constructor(
    private readonly params: {
      region?: string;
      model?: string;
      profile?: string;
    } = {},
  ) {
    if (params.model && !this.supportedModels.includes(params.model)) {
      throw new Error(
        `Unsupported model: ${params.model}. Supported models are: ${this.supportedModels.join(", ")}`,
      );
    }
  }

  /**
   * Scores every chunk against the query.
   *
   * @param query Text the chunks are ranked against; must be non-empty.
   * @param chunks Chunks whose `content` is sent as the documents to rank;
   *   must be non-empty.
   * @returns The relevance scores from the response, ordered by the `index`
   *   field of each result (i.e. the position of the document in the request).
   * @throws Error on empty input, an empty response body, or a response that
   *   does not contain a `results` array. Errors from the AWS client are
   *   logged and rethrown unchanged.
   */
  async rerank(query: string, chunks: Chunk[]): Promise<number[]> {
    if (!query || !chunks.length) {
      throw new Error("Query and chunks must not be empty");
    }

    try {
      const credentials = await this._getCredentials();
      const client = new BedrockRuntimeClient({
        region: this.params.region ?? BedrockReranker.defaultOptions.region,
        credentials,
      });

      const model = this.params.model ?? BedrockReranker.defaultOptions.model;

      // Base payload shared by both models; the Cohere model additionally
      // gets an `api_version` field.
      const payload: Record<string, unknown> = {
        query,
        documents: chunks.map((chunk) => chunk.content),
        // Ask for a score for every document, not just the top few.
        top_n: chunks.length,
        ...(model.startsWith("cohere.rerank") ? { api_version: 2 } : {}),
      };

      const command = new InvokeModelCommand({
        body: JSON.stringify(payload),
        modelId: model,
        accept: "*/*",
        contentType: "application/json",
      });
      const response = await client.send(command);

      if (!response.body) {
        throw new Error("Empty response received from Bedrock");
      }

      const responseBody: unknown = JSON.parse(
        new TextDecoder().decode(response.body),
      );

      // Validate the shape before touching it so that an unexpected payload
      // surfaces as a descriptive error instead of an opaque TypeError.
      const results = (responseBody as { results?: unknown } | null)?.results;
      if (!Array.isArray(results)) {
        throw new Error(
          "Unexpected response format from Bedrock: missing `results` array",
        );
      }

      // Sort a copy by index so the scores line up with the input order.
      return (results as Array<{ index: number; relevance_score: number }>)
        .slice()
        .sort((a, b) => a.index - b.index)
        .map((result) => result.relevance_score);
    } catch (error) {
      console.error("Error in BedrockReranker.rerank:", error);
      throw error;
    }
  }

  /**
   * Resolves AWS credentials for the configured profile, falling back to the
   * default profile when that lookup fails for any reason.
   */
  private async _getCredentials() {
    const profile =
      this.params.profile ?? BedrockReranker.defaultOptions.profile;
    try {
      return await fromIni({
        profile,
        ignoreCache: true,
      })();
    } catch (e) {
      // Any failure lands here, not only a missing profile, so the underlying
      // error is logged alongside the message to keep the cause visible.
      console.warn(
        `AWS profile with name ${profile} not found in ~/.aws/credentials, using default profile`,
        e,
      );
      return await fromIni()();
    }
  }
}
4 changes: 3 additions & 1 deletion core/context/rerankers/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { RerankerName } from "../../index.js";

import { BedrockReranker } from "./bedrock.js";
import { CohereReranker } from "./cohere.js";
import { ContinueProxyReranker } from "./ContinueProxyReranker.js";
import { FreeTrialReranker } from "./freeTrial.js";
Expand All @@ -9,9 +10,10 @@ import { VoyageReranker } from "./voyage.js";

// Registry mapping each `RerankerName` to the reranker class registered for it.
// The mapped type over `RerankerName` requires an entry for every name in the
// union, so adding a name without registering a class here fails to type-check.
export const AllRerankers: { [key in RerankerName]: any } = {
cohere: CohereReranker,
bedrock: BedrockReranker,
llm: LLMReranker,
voyage: VoyageReranker,
"free-trial": FreeTrialReranker,
"huggingface-tei": HuggingFaceTEIReranker,
"continue-proxy": ContinueProxyReranker,
};
};
1 change: 1 addition & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ export interface EmbeddingsProvider {

export type RerankerName =
| "cohere"
| "bedrock"
| "voyage"
| "llm"
| "free-trial"
Expand Down
82 changes: 49 additions & 33 deletions core/indexing/docs/article.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,59 +25,75 @@ function breakdownArticleComponent(
max_chunk_size: number,
): Chunk[] {
const chunks: Chunk[] = [];

const lines = article.body.split("\n");
let startLine = 0;
let endLine = 0;
let content = "";
let index = 0;

const createChunk = (
chunkContent: string,
chunkStartLine: number,
chunkEndLine: number,
) => {
chunks.push({
content: chunkContent.trim(),
startLine: chunkStartLine,
endLine: chunkEndLine,
otherMetadata: {
title: cleanHeader(article.title),
},
index: index++,
filepath: new URL(
`${subpath}#${cleanFragment(article.title)}`,
url,
).toString(),
digest: subpath,
});
};

for (let i = 0; i < lines.length; i++) {
const line = lines[i];
if (content.length + line.length <= max_chunk_size) {

// Handle oversized lines by splitting them
if (line.length > max_chunk_size) {
// First push any accumulated content
if (content.trim().length > 0) {
createChunk(content, startLine, endLine);
content = "";
}

// Split the long line into chunks
let remainingLine = line;
let subLineStart = i;
while (remainingLine.length > 0) {
const chunk = remainingLine.slice(0, max_chunk_size);
createChunk(chunk, subLineStart, i);
remainingLine = remainingLine.slice(max_chunk_size);
}
startLine = i + 1;
continue;
}

// Normal line handling
if (content.length + line.length + 1 <= max_chunk_size) {
content += `${line}\n`;
endLine = i;
} else {
chunks.push({
content: content.trim(),
startLine: startLine,
endLine: endLine,
otherMetadata: {
title: cleanHeader(article.title),
},
index: index,
filepath: new URL(
`${subpath}#${cleanFragment(article.title)}`,
url,
).toString(),
digest: subpath,
});
if (content.trim().length > 0) {
createChunk(content, startLine, endLine);
}
content = `${line}\n`;
startLine = i;
endLine = i;
index += 1;
}
}

// Push the last chunk
if (content) {
chunks.push({
content: content.trim(),
startLine: startLine,
endLine: endLine,
otherMetadata: {
title: cleanHeader(article.title),
},
index: index,
filepath: new URL(
`${subpath}#${cleanFragment(article.title)}`,
url,
).toString(),
digest: subpath,
});
if (content.trim().length > 0) {
createChunk(content, startLine, endLine);
}

// Don't use small chunks. Probably they're a mistake. Definitely they'll confuse the embeddings model.
return chunks.filter((c) => c.content.trim().length > 20);
}

Expand Down
14 changes: 12 additions & 2 deletions docs/docs/customize/model-providers/top-level/bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,19 @@ We recommend configuring [`amazon.titan-embed-text-v2:0`](https://docs.aws.amazo

## Reranking model

Bedrock currently does not offer any reranking models.
We recommend configuring `cohere.rerank-v3-5:0` as your reranking model. You may also use `amazon.rerank-v1:0`.

[Click here](../../model-types/reranking.md) to see a list of reranking model providers.
```json title="~/.continue/config.json"
{
"reranker": {
"name": "bedrock",
"params": {
"model": "cohere.rerank-v3-5:0",
"region": "us-west-2"
}
}
}
```

## Authentication

Expand Down
3 changes: 2 additions & 1 deletion docs/docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ Configuration for the reranker model used in response ranking.

**Properties:**

- `name` (**required**): Reranker name, e.g., `cohere`, `voyage`, `llm`, `free-trial`, `huggingface-tei`
- `name` (**required**): Reranker name, e.g., `cohere`, `voyage`, `llm`, `free-trial`, `huggingface-tei`, `bedrock`
- `params`:
- `model`: Model name
- `apiKey`: API key
- `region`: Region (for Bedrock only)

Example

Expand Down
27 changes: 27 additions & 0 deletions extensions/vscode/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2642,6 +2642,7 @@
"properties": {
"name": {
"enum": [
"bedrock",
"cohere",
"voyage",
"llm",
Expand Down Expand Up @@ -2743,6 +2744,32 @@
}
}
},
{
"if": {
"properties": {
"name": {
"enum": ["bedrock"]
}
},
"required": ["name"]
},
"then": {
"properties": {
"params": {
"type": "object",
"properties": {
"model": {
"enum": [
"cohere.rerank-v3-5:0",
"amazon.rerank-v1:0"
]
}
},
"required": ["model"]
}
}
}
},
{
"if": {
"properties": {
Expand Down

0 comments on commit f3d0399

Please sign in to comment.