Commit
Fix encodeTrim* on special strings with repeat tokens (#26)
* Fix tests

* Fix unused variables
lramos15 authored Dec 20, 2023
1 parent 3c6fcb9 commit 512d432
Showing 4 changed files with 48 additions and 14 deletions.
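
In short, encodeTrimSuffix and encodeTrimPrefix previously updated tokenCount before checking it against maxTokenCount, so a single long piece of repeated tokens (for example a 4000-character run of "t") could blow past the budget and be dropped wholesale. A minimal usage sketch of the intended post-fix behavior, assuming a TikTokenizer instance named tokenizer built as in the test suite below:

// Sketch only: `tokenizer` is assumed to be a TikTokenizer built as in the tests.
const str = "t".repeat(4000);        // one long run that BPE-encodes into many repeat tokens
const full = tokenizer.encode(str);

const suffix = tokenizer.encodeTrimSuffix(str, 5, []);
// After the fix: exactly 5 tokens, equal to full.slice(0, 5).
// Before the fix: an oversized piece could be dropped entirely, returning fewer tokens.

const prefix = tokenizer.encodeTrimPrefix(str, 5, []);
// After the fix: exactly 5 tokens, equal to full.slice(full.length - 5).

console.log(suffix.tokenIds.length, prefix.tokenIds.length); // 5 5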
10 changes: 5 additions & 5 deletions tokenizer_ts/package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion tokenizer_ts/package.json
@@ -2,7 +2,7 @@
   "name": "@microsoft/tiktokenizer",
   "displayName": "tiktokenizer",
   "description": "Tokenizer for OpenAI large language models.",
-  "version": "1.0.3",
+  "version": "1.0.4",
   "author": {
     "name": "Microsoft Corporation"
   },
30 changes: 24 additions & 6 deletions tokenizer_ts/src/tikTokenizer.ts
@@ -241,11 +241,15 @@ export class TikTokenizer {
       const piece = match[0];
       if (this.cache.has(piece)) {
         let tokens = this.cache.get(piece);
-        tokenCount += tokens!.length;
-        if (tokenCount <= maxTokenCount) {
+        if (tokenCount + tokens!.length <= maxTokenCount) {
+          tokenCount += tokens!.length;
           encodeLength += piece.length;
           tokenIds.push(...tokens!);
         } else {
+          let remainingTokens = maxTokenCount - tokenCount;
+          tokenCount += remainingTokens;
+          encodeLength += piece.length;
+          tokenIds.push(...tokens!.slice(0, remainingTokens));
           break;
         }
       } else {
@@ -254,8 +258,8 @@
           const token = this.encoder!.get(uint8ArrayToString(bytes));
           if (token !== undefined) {
             this.cache.set(piece, [token]);
-            tokenCount++;
-            if (tokenCount <= maxTokenCount) {
+            if (tokenCount + 1 <= maxTokenCount) {
+              tokenCount++;
               encodeLength += piece.length;
               tokenIds.push(token);
             } else {
@@ -264,11 +268,15 @@
           } else {
             const encodedTokens = bytePairEncode(bytes, this.encoder!);
             this.cache.set(piece, encodedTokens);
-            tokenCount += encodedTokens.length;
-            if (tokenCount <= maxTokenCount) {
+            if (tokenCount + encodedTokens.length <= maxTokenCount) {
+              tokenCount += encodedTokens.length;
               encodeLength += piece.length;
               tokenIds.push(...encodedTokens);
             } else {
+              let remainingTokens = maxTokenCount - tokenCount;
+              tokenCount += remainingTokens;
+              encodeLength += piece.length;
+              tokenIds.push(...encodedTokens.slice(0, remainingTokens));
               break;
             }
           }
@@ -443,6 +451,16 @@ export class TikTokenizer {
       }
     }

+    // Naive approach if chunks are incorrect
+    if (actualPrefixTokenCount > maxTokenCount) {
+      const encodedTokens = this.encode(text, allowedSpecial);
+      const slicedTokens = encodedTokens.slice(encodedTokens.length - maxTokenCount);
+      return {
+        tokenIds: slicedTokens,
+        text: this.decode(slicedTokens)
+      };
+    }
+
     return {
       tokenIds: tokenIds.slice(actualPrefixTokenCount),
       text: text.slice(actualPrefixStrLength)
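
The substantive change in encodeTrimSuffix is the same in all three branches above: compare the budget before committing the piece, and when the piece would overflow, take only the remaining tokens instead of discarding the whole piece. Distilled into a standalone helper (a sketch only; appendUpToBudget and its shape are illustrative, not part of the library):

// Illustrative helper, not part of the library: append a piece's tokens
// without exceeding the budget, slicing the piece when it would overflow.
function appendUpToBudget(
  tokenIds: number[],
  pieceTokens: number[],
  tokenCount: number,
  maxTokenCount: number
): { tokenCount: number; done: boolean } {
  if (tokenCount + pieceTokens.length <= maxTokenCount) {
    tokenIds.push(...pieceTokens);
    return { tokenCount: tokenCount + pieceTokens.length, done: false };
  }
  const remaining = maxTokenCount - tokenCount;       // may be 0
  tokenIds.push(...pieceTokens.slice(0, remaining));  // take a partial piece
  return { tokenCount: maxTokenCount, done: true };
}

Each branch in the diff above inlines this pattern, additionally advancing encodeLength by piece.length so the caller knows how far into the input string encoding progressed.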
20 changes: 18 additions & 2 deletions tokenizer_ts/test/tikTokenizer.test.ts
@@ -91,7 +91,7 @@ suite("TikTokenizer Test Suite", function() {

   test("encode trim suffix - 2", () => {
     const str = "<|im_start|>Hello TempWorld<|im_end|>";
-    const encodedStr = "<|im_start|>Hello";
+    const encodedStr = "<|im_start|>Hello TempWorld";
     let encoded = tokenizer.encodeTrimSuffix(
       str,
       5,
@@ -125,10 +125,18 @@
       3,
       Array.from(specialTokens.keys())
     );
-    assert.deepStrictEqual(encoded.tokenIds, [100264, 9906]);
+    assert.deepStrictEqual(encoded.tokenIds, [100264, 9906, 20539]);
     assert.deepStrictEqual(encoded.text, encodedStr);
   });

+  test("encode trim suffix - 3", () => {
+    const str = "t".repeat(4000);
+    const encodedStr = tokenizer.encode(str);
+    let encodedTrimSuffix = tokenizer.encodeTrimSuffix(str, 5, []);
+    assert.deepStrictEqual(encodedTrimSuffix.tokenIds.length, 5);
+    assert.deepStrictEqual(encodedTrimSuffix.tokenIds, encodedStr.slice(0, 5));
+  });
+
   test("encode trim prefix", () => {
     const str = "<|im_start|>Hello World<|im_end|>";
     const encodedStr = "Hello World<|im_end|>";
@@ -197,6 +205,14 @@ suite("TikTokenizer Test Suite", function() {
     assert.deepStrictEqual(encoded.text, encodedStr);
   });

+  test("encode trim prefix - 3", () => {
+    const str = "t".repeat(4000);
+    const encodedStr = tokenizer.encode(str);
+    let encodedTrimSuffix = tokenizer.encodeTrimPrefix(str, 5, []);
+    assert.deepStrictEqual(encodedTrimSuffix.tokenIds.length, 5);
+    assert.deepStrictEqual(encodedTrimSuffix.tokenIds, encodedStr.slice(encodedStr.length - 5));
+  });
+
   test("tokenize source code - gpt-3.5", done => {
     const source = fs.readFileSync("test/testdata/lib.rs.txt", "utf8");
     const filePath = "test/testdata/tokens_gpt_3.5_turbo.json";
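
The new encodeTrimPrefix fallback re-encodes the whole string, keeps only the last maxTokenCount tokens, and decodes those tokens rather than slicing the original text, since the cut may not fall on a clean character boundary. A minimal sketch of that strategy as a standalone function (the TrimTokenizer type and trimPrefixNaive name are illustrative, not part of the package):

// Illustrative only — mirrors the fallback added in encodeTrimPrefix above.
type TrimTokenizer = {
  encode(text: string, allowedSpecial?: readonly string[]): number[];
  decode(tokens: number[]): string;
};

function trimPrefixNaive(
  tokenizer: TrimTokenizer,
  text: string,
  maxTokenCount: number,
  allowedSpecial: readonly string[]
): { tokenIds: number[]; text: string } {
  const all = tokenizer.encode(text, allowedSpecial);
  const kept = all.slice(Math.max(0, all.length - maxTokenCount));
  // Decode the kept tokens instead of slicing `text`: the token boundary
  // need not line up with a character boundary in the original string.
  return { tokenIds: kept, text: tokenizer.decode(kept) };
}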
