Adding new APIs to avoid passing in allowed special tokens (#27)
* Adding new APIs to avoid passing in allowed special tokens

* Update version to 1.3.3
shengyfu authored Jan 11, 2024
1 parent 512d432 commit 2c9ba5d
Showing 4 changed files with 206 additions and 56 deletions.
16 changes: 16 additions & 0 deletions Tokenizer_C#/TokenizerLib/ITokenizer.cs
@@ -22,6 +22,22 @@ public interface ITokenizer
/// </summary>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount);

/// <summary>
/// Encode a string with or without the special tokens set through the constructor.
/// </summary>
public List<int> Encode(string text, bool applySpecialTokens = true);

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the suffix, with or without the special tokens set through the constructor.
/// </summary>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true);


/// <summary>
/// Encode a piece of text limited to a max token count by trimming the prefix, with or without the special tokens set through the constructor.
/// </summary>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true);


/// <summary>
/// Decode an array of integer token ids
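Taken together, the three new overloads let a caller rely on the special tokens registered at construction time instead of re-supplying them on every call. A minimal usage sketch (hedged: how the tokenizer instance is obtained is not part of this diff, so it is taken as a parameter; the marker strings come from the unit tests below):

```csharp
using System.Collections.Generic;

static class UsageSketch
{
    // The tokenizer is passed in because its construction (model/encoder
    // setup) is outside this change.
    static void Demo(ITokenizer tokenizer, string text)
    {
        // Old style: re-supply the allowed special tokens on every call.
        List<int> explicitIds = tokenizer.Encode(
            text, new HashSet<string> { "<|im_start|>", "<|im_end|>" });

        // New style: apply the special tokens set through the constructor.
        List<int> defaultIds = tokenizer.Encode(text);

        // Or skip special-token handling entirely for this call.
        List<int> plainIds = tokenizer.Encode(text, applySpecialTokens: false);
    }
}
```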
223 changes: 169 additions & 54 deletions Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -21,6 +21,7 @@ public class TikTokenizer : ITokenizer
{

private IReadOnlyDictionary<string, int> SpecialTokensEncoder = null!;
private IReadOnlyCollection<string> SpecialTokens = null!;
private Regex Regex = null!;
private IReadOnlyDictionary<byte[], int> Encoder = null!;
private IReadOnlyDictionary<int, byte[]> Decoder = null!;
@@ -76,6 +77,7 @@ private void Init(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<
Regex = new Regex(pattern, RegexOptions.Compiled);
SpecialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
SpecialTokensEncoder = specialTokensEncoder;
SpecialTokens = specialTokensEncoder.Keys.ToList();

Decoder = Encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);

@@ -136,13 +138,7 @@ private Dictionary<byte[], int> LoadTikTokenBpe(Stream tikTokenBpeFileStream)
return bpeDict;
}

/// <summary>
/// Encode a string with a set of allowed special tokens that are not broken apart.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <returns>List of token ids</returns>
public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
private List<int> EncodeInternal(string text, IReadOnlyCollection<string> allowedSpecial)
{
var tokenIds = new List<int>();
int start = 0;
@@ -173,6 +169,43 @@ public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
return tokenIds;
}

/// <summary>
/// Encode a string with a set of allowed special tokens that are not broken apart.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <returns>List of token ids</returns>
public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
{
if (allowedSpecial is null || allowedSpecial.Count == 0)
{
return Encode(text, false);
}
return EncodeInternal(text, allowedSpecial);
}

/// <summary>
/// Encode a string with or without the special tokens set through the constructor.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>List of token ids</returns>
public List<int> Encode(string text, bool applySpecialTokens = true)
{

if (applySpecialTokens && SpecialTokens.Count > 0)
{
return EncodeInternal(text, SpecialTokens);
}

var tokenIds = new List<int>();
int start = 0;
Encode(text, tokenIds, start, text.Length);

return tokenIds;

}
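This overload is the crux of the change: with applySpecialTokens left at its default, the special tokens captured from the constructor are routed through EncodeInternal; with false, the input is byte-pair encoded as plain text. A fragment illustrating the difference, assuming a tokenizer built with the chat markers registered as in the test fixture (the 4-token result is asserted in TestEncode1 below; the plain-text behavior is described, not asserted, for this string):

```csharp
var text = "<|im_start|>Hello World<|im_end|>";

// Default: constructor special tokens apply, so each marker maps to a
// single id and the string encodes to 4 tokens (per TestEncode1).
var withSpecial = tokenizer.Encode(text);

// Disabled: the markers are treated as ordinary text and split into
// smaller byte-pair pieces, so the result is longer.
var asPlainText = tokenizer.Encode(text, applySpecialTokens: false);
```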

/// <summary>
/// Encode a special token matched through regex.
/// </summary>
@@ -194,7 +227,7 @@ private int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
/// <param name="start">Start search index in the string</param>
/// <param name="nextSpecial">The regex match of a special token</param>
/// <param name="end">The index of the special token matched or the end of the text</param>
private void FindNextSpecialToken(string text, IReadOnlyCollection<string> allowedSpecial, int start, out Match nextSpecial, out int end)
private void FindNextSpecialToken(string text, IReadOnlyCollection<string>? allowedSpecial, int start, out Match nextSpecial, out int end)
{
int startFind = start;
while (true)
@@ -308,14 +341,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
return (tokenCount, encodeLength);
}

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the suffix
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <param name="maxTokenCount">The max token count</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
private (List<int> TokenIds, string Text) EncodeTrimSuffixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
var tokenIds = new List<int>();

@@ -367,21 +393,58 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
}

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the prefix
/// Encode a piece of text limited to a max token count by trimming the suffix
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <param name="maxTokenCount">The max token count</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
if (allowedSpecial is null || allowedSpecial.Count == 0)
{
return EncodeTrimSuffix(text, maxTokenCount, false);
}

return EncodeTrimSuffixInternal(text, allowedSpecial, maxTokenCount);

}

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the suffix, with or without the special tokens set through the constructor.
/// </summary>
/// <param name="text">String to be encoded</param>
/// <param name="maxTokenCount">The max token count</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true)
{
if (applySpecialTokens && SpecialTokens.Count > 0)
{
return EncodeTrimSuffixInternal(text, SpecialTokens, maxTokenCount);
}

var tokenIds = new List<int>();
int start = 0;
int tokenCount = 0;
var encodeLength = 0;
(_, encodeLength) = EncodeTrimSuffix(text, tokenIds, start, text.Length, maxTokenCount, tokenCount, encodeLength);
var encodedText = encodeLength == text.Length ? text : text[..encodeLength];

return (tokenIds, encodedText);
}
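The suffix-trimming overload follows the same pattern, and the new assertions in TestEncodeTrimSuffix below pin the behavior down. In sketch form (same assumed fixture):

```csharp
var text = "<|im_start|>Hello World<|im_end|>";

// Special tokens applied: the whole string costs 4 tokens, so a
// 4-token budget keeps all of it.
var (ids, kept) = tokenizer.EncodeTrimSuffix(text, maxTokenCount: 4);
// kept == text

// Special tokens disabled: "<|im_start|>" is plain text now, and only
// "<|im_start" fits within 4 tokens.
var (plainIds, keptPlain) = tokenizer.EncodeTrimSuffix(text, 4, applySpecialTokens: false);
// keptPlain == "<|im_start"
```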

private (List<int> TokenIds, string Text) EncodeTrimPrefixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
var tokenIds = new List<int>();

int start = 0;
int tokenCount = 0;
var encodeLength = 0;
var tokenCountMap = new SortedDictionary<int, int>();
tokenCountMap.Add(tokenCount, encodeLength);
var tokenCountMap = new SortedDictionary<int, int>
{
{ tokenCount, encodeLength }
};
while (true)
{
Match nextSpecial;
@@ -390,39 +453,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)

if (end > start)
{
foreach (Match match in Regex.Matches(text[start..end]))
{
var piece = match.Value;

if (this.Cache.Lookup(match.Value, out int[] tokens))
{
tokenCount += tokens.Length;
encodeLength += piece.Length;
tokenIds.AddRange(tokens);
tokenCountMap[tokenCount] = encodeLength;
}
else
{
var bytes = Encoding.UTF8.GetBytes(piece);
if (Encoder.TryGetValue(bytes, out int token))
{
tokenCount++;
encodeLength += piece.Length;
tokenIds.Add(token);
tokenCountMap[tokenCount] = encodeLength;

}
else
{
var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
this.Cache.Add(piece, encodedTokens.ToArray());
tokenCount += encodedTokens.Count;
encodeLength += piece.Length;
tokenIds.AddRange(encodedTokens);
tokenCountMap[tokenCount] = encodeLength;
}
}
}
Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, end);
}

if (nextSpecial.Success)
@@ -442,6 +473,11 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
}
}

return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
}

private static (List<int> TokenIds, string Text) TrimPrefix(string text, int maxTokenCount, List<int> tokenIds, int tokenCount, SortedDictionary<int, int> tokenCountMap)
{
if (tokenCount <= maxTokenCount)
{
return (tokenIds, text);
@@ -463,6 +499,85 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
return (tokenIds.Skip(actualPrefixTokenCount).ToList(), text[actualPrefixStrLength..]);
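TrimPrefix is where the tokenCountMap pays off: each entry maps a cumulative token count to the number of characters consumed at that point, so the method can cut the token list and the string at a consistent boundary. The middle of the method is collapsed in this view; a hedged reconstruction of what the lookup presumably does, consistent with the visible inputs and the return statement above:

```csharp
// Hypothetical reconstruction of the collapsed lookup, not the verbatim
// source. At least (tokenCount - maxTokenCount) tokens must be dropped
// from the front; tokenCountMap is ordered by cumulative token count,
// so the first entry at or past that threshold is a valid cut point.
int prefixTokenCount = tokenCount - maxTokenCount;
int actualPrefixTokenCount = 0;
int actualPrefixStrLength = 0;
foreach (var pair in tokenCountMap)
{
    if (pair.Key >= prefixTokenCount)
    {
        actualPrefixTokenCount = pair.Key;   // tokens to skip
        actualPrefixStrLength = pair.Value;  // chars to cut from the text
        break;
    }
}
```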
}

private void Encode(string text, List<int> tokenIds, int start, ref int tokenCount, ref int encodeLength, SortedDictionary<int, int> tokenCountMap, int end)
{
foreach (Match match in Regex.Matches(text[start..end]))
{
var piece = match.Value;

if (this.Cache.Lookup(match.Value, out int[] tokens))
{
tokenCount += tokens.Length;
encodeLength += piece.Length;
tokenIds.AddRange(tokens);
tokenCountMap[tokenCount] = encodeLength;
}
else
{
var bytes = Encoding.UTF8.GetBytes(piece);
if (Encoder.TryGetValue(bytes, out int token))
{
tokenCount++;
encodeLength += piece.Length;
tokenIds.Add(token);
tokenCountMap[tokenCount] = encodeLength;

}
else
{
var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
this.Cache.Add(piece, encodedTokens.ToArray());
tokenCount += encodedTokens.Count;
encodeLength += piece.Length;
tokenIds.AddRange(encodedTokens);
tokenCountMap[tokenCount] = encodeLength;
}
}
}
}

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the prefix
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
/// <param name="maxTokenCount">The max token count</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
{
if (allowedSpecial is null || allowedSpecial.Count == 0)
{
return EncodeTrimPrefix(text, maxTokenCount, false);
}
return EncodeTrimPrefixInternal(text, allowedSpecial, maxTokenCount);
}

/// <summary>
/// Encode a piece of text limited to a max token count by trimming the prefix, with or without the special tokens set through the constructor.
/// </summary>
/// <param name="text">Text to be encoded</param>
/// <param name="maxTokenCount">The max token count</param>
/// <param name="applySpecialTokens">Whether to apply special token processing</param>
/// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true)
{
if (applySpecialTokens && SpecialTokens.Count > 0)
{
return EncodeTrimPrefixInternal(text, SpecialTokens, maxTokenCount);
}
var tokenIds = new List<int>();

int start = 0;
int tokenCount = 0;
var encodeLength = 0;
var tokenCountMap = new SortedDictionary<int, int>
{
{ tokenCount, encodeLength }
};
Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, text.Length);
return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
}
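As with the suffix case, the new assertions in TestEncodeTrimPrefix below fix the expected results; in sketch form (same assumed fixture):

```csharp
var text = "<|im_start|>Hello World<|im_end|>";

// Default: special tokens apply and the whole string fits in 4 tokens.
var (ids, kept) = tokenizer.EncodeTrimPrefix(text, maxTokenCount: 4);
// kept == text

// Special tokens disabled: only the trailing "im_end|>" survives the
// 4-token budget.
var (plainIds, keptPlain) = tokenizer.EncodeTrimPrefix(text, 4, applySpecialTokens: false);
// keptPlain == "im_end|>"
```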

/// <summary>
/// Decode an array of integer token ids
/// </summary>
2 changes: 1 addition & 1 deletion Tokenizer_C#/TokenizerLib/TokenizerLib.csproj
@@ -8,7 +8,7 @@
<Title>Tokenizer</Title>
<Description>Tokenizer for OpenAI large language models.</Description>
<LangVersion>8.0</LangVersion>
<AssemblyVersion>1.3.2</AssemblyVersion>
<AssemblyVersion>1.3.3</AssemblyVersion>
<FileVersion>$(AssemblyVersion)</FileVersion>
<Version>$(AssemblyVersion)</Version>
<Authors>Microsoft</Authors>
21 changes: 20 additions & 1 deletion Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
@@ -53,7 +53,7 @@ public void TestEncode0()
public void TestEncode1()
{
var text = "<|im_start|>Hello World<|im_end|>";
var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
var encoded = Tokenizer.Encode(text);
Assert.AreEqual(4, encoded.Count);
Assert.AreEqual(100264, encoded[0]);
Assert.AreEqual(9906, encoded[1]);
@@ -70,6 +70,9 @@ public void TestEncode2()
var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
Assert.AreEqual(5584, encoded.Count);

encoded = Tokenizer.Encode(text, false);
Assert.AreEqual(5584, encoded.Count);

string json = File.ReadAllText("./testData/tokens.json");
var expected = JsonConvert.DeserializeObject<int[]>(json);

@@ -131,6 +134,14 @@ public void TestEncodeTrimSuffix()
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);

encoded = Tokenizer.EncodeTrimSuffix(text, 4, false);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual("<|im_start", encoded.Text);

encoded = Tokenizer.EncodeTrimSuffix(text, 4);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);

encoded = Tokenizer.EncodeTrimSuffix(text, new HashSet<string>(SpecialTokens.Keys), 5);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);
@@ -173,6 +184,14 @@ public void TestEncodeTrimPrefix()
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);

encoded = Tokenizer.EncodeTrimPrefix(text, 4, false);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual("im_end|>", encoded.Text);

encoded = Tokenizer.EncodeTrimPrefix(text, 4);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);

encoded = Tokenizer.EncodeTrimPrefix(text, new HashSet<string>(SpecialTokens.Keys), 5);
Assert.AreEqual(4, encoded.TokenIds.Count);
Assert.AreEqual(text, encoded.Text);
