Skip to content

Commit

Permalink
discojs*: rename .unbatch() to .flatten()
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Dec 3, 2024
1 parent 7361a95 commit c60eab6
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 24 deletions.
2 changes: 1 addition & 1 deletion cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
task.trainingInformation.maxSequenceLength = contextLength
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flatten()
.batch(config.blockSize + 1, 1)

const preprocessedDataset = dataset
Expand Down
2 changes: 1 addition & 1 deletion cli/src/train_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async function main(): Promise<void> {

const tokenDataset = new Dataset([data])
.map((text: string) => processing.tokenize(tokenizer, text))
.unbatch()
.flatten()
.batch(config.blockSize + 1, 1)
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.repeat()
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/dataset/dataset.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ describe("dataset", () => {
expect(await arrayFromAsync(right)).to.have.length(1);
});

it("batches in samed sized chunks", async () => {
it("batches in same sized chunks", async () => {
const dataset = new Dataset([1, 2, 3, 4]);

const batched = dataset.batch(2);
Expand Down Expand Up @@ -155,7 +155,7 @@ describe("dataset", () => {
const blockSize = 4

const parsed = new Dataset([expectedTokens])
.unbatch()
.flatten()
.batch(blockSize + 1, 1)

// -1 because the last sequence is dropped as there is no next token label
Expand Down
40 changes: 24 additions & 16 deletions discojs/src/dataset/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,22 @@ type DatasetLike<T> =
| (() => AsyncIterator<T, void>)
| (() => Iterator<T, void>);

/**
 * Convert a DatasetLike object to an async generator.
 *
 * The source may be an async iterable, a sync iterable, or a factory
 * function returning either kind of iterator. Values are yielded in the
 * order the underlying iterator produces them.
 *
 * If the consumer stops early (e.g. `break` in a `for await` loop, or by
 * calling `.return()` / `.throw()` on the generator), the underlying
 * iterator's optional `return()` method is invoked so that the source can
 * release whatever it holds, mirroring what `for await...of` does.
 *
 * @param content the source to read elements from
 * @returns an async generator yielding every element of `content`
 */
async function* datasetLikeToGenerator<U>(content: DatasetLike<U>):
  AsyncGenerator<U, void, undefined> {
  let iter: AsyncIterator<U, void> | Iterator<U, void>;
  if (typeof content === "function") iter = content();
  else if (Symbol.asyncIterator in content)
    iter = content[Symbol.asyncIterator]();
  else iter = content[Symbol.iterator]();

  while (true) {
    // `await` is a no-op for sync iterators and unwraps async ones
    const result = await iter.next();
    if (result.done === true) return;

    // Only close the source when the consumer abandons us at the `yield`;
    // a source that finished or threw by itself must not be closed again.
    let resumed = false;
    try {
      yield result.value;
      resumed = true;
    } finally {
      if (!resumed) await iter.return?.();
    }
  }
}

/** Immutable series of data */
export class Dataset<T> implements AsyncIterable<T> {
readonly #content: () => AsyncIterator<T, void, undefined>;
Expand All @@ -23,18 +39,8 @@ export class Dataset<T> implements AsyncIterable<T> {
*/
constructor(content: DatasetLike<T>) {
this.#content = async function* () {
let iter: AsyncIterator<T, void> | Iterator<T, void>;
if (typeof content === "function") iter = content();
else if (Symbol.asyncIterator in content)
iter = content[Symbol.asyncIterator]();
else iter = content[Symbol.iterator]();

while (true) {
const result = await iter.next();
if (result.done === true) break;
yield result.value;
}
};
yield* datasetLikeToGenerator(content);
}
}

[Symbol.asyncIterator](): AsyncIterator<T> {
Expand Down Expand Up @@ -160,11 +166,13 @@ export class Dataset<T> implements AsyncIterable<T> {
);
}

/** Flatten chunks */
unbatch<U>(this: Dataset<Batched<U>>): Dataset<U> {
/** Flatten batches/arrays of elements */
flatten<U>(this: Dataset<DatasetLike<U>>): Dataset<U> {
return new Dataset(
async function* (this: Dataset<Batched<U>>) {
for await (const batch of this) yield* batch;
async function* (this: Dataset<DatasetLike<U>>) {
for await (const batch of this) {
yield* datasetLikeToGenerator(batch);
}
}.bind(this),
);
}
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/processing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export async function preprocess<D extends DataType>(

const tokenizer = await models.getTaskTokenizer(t);
return d.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flatten()
.batch(blockSize + 1, 1)
.map((tokens) => [tokens.pop(), tokens.last()]) as
Dataset<DataFormat.ModelEncoded[D]>;
Expand Down Expand Up @@ -101,7 +101,7 @@ export async function preprocessWithoutLabel<D extends DataType>(
const tokenizer = await models.getTaskTokenizer(t);

return d.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flatten()
.batch(blockSize)
}
}
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export class Validator<D extends DataType> {
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => inferred === truth),
)
.unbatch();
.flatten();

for await (const e of results) yield e;
}
Expand All @@ -36,7 +36,7 @@ export class Validator<D extends DataType> {
)
.batch(this.task.trainingInformation.batchSize)
.map((batch) => this.#model.predict(batch))
.unbatch();
.flatten();

const predictions = await processing.postprocess(
this.task,
Expand Down

0 comments on commit c60eab6

Please sign in to comment.