diff --git a/Tokenizer_C#/TokenizerLib/ITokenizer.cs b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
index 7674074..08ff714 100644
--- a/Tokenizer_C#/TokenizerLib/ITokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/ITokenizer.cs
@@ -22,6 +22,22 @@ public interface ITokenizer
     /// </summary>
     public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount);
 
+    /// <summary>
+    /// Encode a string with or without special tokens set through constructor.
+    /// </summary>
+    public List<int> Encode(string text, bool applySpecialTokens = true);
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
+    /// </summary>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true);
+
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
+    /// </summary>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true);
+
     /// <summary>
     /// Decode an array of integer token ids
     /// </summary>
diff --git a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
index f78ac9f..c0bba10 100644
--- a/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
+++ b/Tokenizer_C#/TokenizerLib/TikTokenizer.cs
@@ -21,6 +21,7 @@ public class TikTokenizer : ITokenizer
 {
 
     private IReadOnlyDictionary<string, int> SpecialTokensEncoder = null!;
+    private IReadOnlyCollection<string> SpecialTokens = null!;
     private Regex Regex = null!;
     private IReadOnlyDictionary<byte[], int> Encoder = null!;
     private IReadOnlyDictionary<int, byte[]> Decoder = null!;
@@ -76,6 +77,7 @@ private void Init(IReadOnlyDictionary<byte[], int> encoder, IReadOnlyDictionary<
         Regex = new Regex(pattern, RegexOptions.Compiled);
         SpecialTokensRegex = new Regex(string.Join("|", specialTokensEncoder.Keys.Select(s => Regex.Escape(s))), RegexOptions.Compiled);
         SpecialTokensEncoder = specialTokensEncoder;
+        SpecialTokens = specialTokensEncoder.Keys.ToList();
 
         Decoder = Encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
@@ -136,13 +138,7 @@ private Dictionary<byte[], int> LoadTikTokenBpe(Stream tikTokenBpeFileStream)
         return bpeDict;
     }
 
-    /// <summary>
-    /// Encode a string with a set of allowed special tokens that are not broken apart.
-    /// </summary>
-    /// <param name="text">String to be encoded</param>
-    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
-    /// <returns>List of token ids</returns>
-    public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
+    private List<int> EncodeInternal(string text, IReadOnlyCollection<string> allowedSpecial)
     {
         var tokenIds = new List<int>();
         int start = 0;
@@ -173,6 +169,43 @@ public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
         return tokenIds;
     }
 
+    /// <summary>
+    /// Encode a string with a set of allowed special tokens that are not broken apart.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
+    /// <returns>List of token ids</returns>
+    public List<int> Encode(string text, IReadOnlyCollection<string> allowedSpecial)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return Encode(text, false);
+        }
+        return EncodeInternal(text, allowedSpecial);
+    }
+
+    /// <summary>
+    /// Encode a string with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public List<int> Encode(string text, bool applySpecialTokens = true)
+    {
+
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeInternal(text, SpecialTokens);
+        }
+
+        var tokenIds = new List<int>();
+        int start = 0;
+        Encode(text, tokenIds, start, text.Length);
+
+        return tokenIds;
+
+    }
+
     /// <summary>
     /// Encode a special token matched through regex.
     /// </summary>
@@ -194,7 +227,7 @@ private int EncodeSpecialToken(List<int> tokenIds, Match nextSpecial)
     /// <param name="start">Start search index in the string</param>
     /// <param name="nextSpecial">The regex match of a special token</param>
     /// <param name="end">The index of the special token matched or the end of the text</param>
-    private void FindNextSpecialToken(string text, IReadOnlyCollection<string> allowedSpecial, int start, out Match nextSpecial, out int end)
+    private void FindNextSpecialToken(string text, IReadOnlyCollection<string>? allowedSpecial, int start, out Match nextSpecial, out int end)
     {
         int startFind = start;
         while (true)
@@ -308,14 +341,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
         return (tokenCount, encodeLength);
     }
 
-    /// <summary>
-    /// Encode a piece of text limited by max token count through trimming suffix
-    /// </summary>
-    /// <param name="text">Text to be encoded</param>
-    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
-    /// <param name="maxTokenCount">The max token count</param>
-    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
-    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    private (List<int> TokenIds, string Text) EncodeTrimSuffixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
     {
         var tokenIds = new List<int>();
 
@@ -367,21 +393,58 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
     }
 
     /// <summary>
-    /// Encode a piece of text limited by max token count through trimming prefix
+    /// Encode a piece of text limited by max token count through trimming suffix
     /// </summary>
     /// <param name="text">Text to be encoded</param>
     /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
     /// <param name="maxTokenCount">The max token count</param>
-    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
-    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after suffix truncation based on max token count</returns>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return EncodeTrimSuffix(text, maxTokenCount, false);
+        }
+
+        return EncodeTrimSuffixInternal(text, allowedSpecial, maxTokenCount);
+
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming suffix, with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">String to be encoded</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public (List<int> TokenIds, string Text) EncodeTrimSuffix(string text, int maxTokenCount, bool applySpecialTokens = true)
+    {
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeTrimSuffixInternal(text, SpecialTokens, maxTokenCount);
+        }
+
+        var tokenIds = new List<int>();
+        int start = 0;
+        int tokenCount = 0;
+        var encodeLength = 0;
+        (_, encodeLength) = EncodeTrimSuffix(text, tokenIds, start, text.Length, maxTokenCount, tokenCount, encodeLength);
+        var encodedText = encodeLength == text.Length ? text : text[..encodeLength];
+
+        return (tokenIds, encodedText);
+    }
+
+    private (List<int> TokenIds, string Text) EncodeTrimPrefixInternal(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
     {
         var tokenIds = new List<int>();
 
         int start = 0;
         int tokenCount = 0;
         var encodeLength = 0;
-        var tokenCountMap = new SortedDictionary<int, int>();
-        tokenCountMap.Add(tokenCount, encodeLength);
+        var tokenCountMap = new SortedDictionary<int, int>
+        {
+            { tokenCount, encodeLength }
+        };
         while (true)
         {
             Match nextSpecial;
@@ -390,39 +453,7 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
 
             if (end > start)
             {
-                foreach (Match match in Regex.Matches(text[start..end]))
-                {
-                    var piece = match.Value;
-
-                    if (this.Cache.Lookup(match.Value, out int[] tokens))
-                    {
-                        tokenCount += tokens.Length;
-                        encodeLength += piece.Length;
-                        tokenIds.AddRange(tokens);
-                        tokenCountMap[tokenCount] = encodeLength;
-                    }
-                    else
-                    {
-                        var bytes = Encoding.UTF8.GetBytes(piece);
-                        if (Encoder.TryGetValue(bytes, out int token))
-                        {
-                            tokenCount++;
-                            encodeLength += piece.Length;
-                            tokenIds.Add(token);
-                            tokenCountMap[tokenCount] = encodeLength;
-
-                        }
-                        else
-                        {
-                            var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
-                            this.Cache.Add(piece, encodedTokens.ToArray());
-                            tokenCount += encodedTokens.Count;
-                            encodeLength += piece.Length;
-                            tokenIds.AddRange(encodedTokens);
-                            tokenCountMap[tokenCount] = encodeLength;
-                        }
-                    }
-                }
+                Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, end);
             }
 
             if (nextSpecial.Success)
@@ -442,6 +473,11 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
             }
         }
 
+        return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
+    }
+
+    private static (List<int> TokenIds, string Text) TrimPrefix(string text, int maxTokenCount, List<int> tokenIds, int tokenCount, SortedDictionary<int, int> tokenCountMap)
+    {
         if (tokenCount <= maxTokenCount)
         {
             return (tokenIds, text);
@@ -463,6 +499,85 @@ private void Encode(string text, List<int> tokenIds, int start, int end)
         return (tokenIds.Skip(actualPrefixTokenCount).ToList(), text[actualPrefixStrLength..]);
     }
 
+    private void Encode(string text, List<int> tokenIds, int start, ref int tokenCount, ref int encodeLength, SortedDictionary<int, int> tokenCountMap, int end)
+    {
+        foreach (Match match in Regex.Matches(text[start..end]))
+        {
+            var piece = match.Value;
+
+            if (this.Cache.Lookup(match.Value, out int[] tokens))
+            {
+                tokenCount += tokens.Length;
+                encodeLength += piece.Length;
+                tokenIds.AddRange(tokens);
+                tokenCountMap[tokenCount] = encodeLength;
+            }
+            else
+            {
+                var bytes = Encoding.UTF8.GetBytes(piece);
+                if (Encoder.TryGetValue(bytes, out int token))
+                {
+                    tokenCount++;
+                    encodeLength += piece.Length;
+                    tokenIds.Add(token);
+                    tokenCountMap[tokenCount] = encodeLength;
+
+                }
+                else
+                {
+                    var encodedTokens = BytePairEncoder.BytePairEncode(bytes, Encoder);
+                    this.Cache.Add(piece, encodedTokens.ToArray());
+                    tokenCount += encodedTokens.Count;
+                    encodeLength += piece.Length;
+                    tokenIds.AddRange(encodedTokens);
+                    tokenCountMap[tokenCount] = encodeLength;
+                }
+            }
+        }
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix
+    /// </summary>
+    /// <param name="text">Text to be encoded</param>
+    /// <param name="allowedSpecial">A set of special tokens could appear in the text</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <returns>(List<int> TokenIds, string Text) - Token ids and text after prefix truncation based on max token count</returns>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, IReadOnlyCollection<string> allowedSpecial, int maxTokenCount)
+    {
+        if (allowedSpecial is null || allowedSpecial.Count == 0)
+        {
+            return EncodeTrimPrefix(text, maxTokenCount, false);
+        }
+        return EncodeTrimPrefixInternal(text, allowedSpecial, maxTokenCount);
+    }
+
+    /// <summary>
+    /// Encode a piece of text limited by max token count through trimming prefix, with or without special tokens set through constructor.
+    /// </summary>
+    /// <param name="text">Text to be encoded</param>
+    /// <param name="maxTokenCount">The max token count</param>
+    /// <param name="applySpecialTokens">Whether to apply special token processing</param>
+    /// <returns></returns>
+    public (List<int> TokenIds, string Text) EncodeTrimPrefix(string text, int maxTokenCount, bool applySpecialTokens = true)
+    {
+        if (applySpecialTokens && SpecialTokens.Count > 0)
+        {
+            return EncodeTrimPrefixInternal(text, SpecialTokens, maxTokenCount);
+        }
+        var tokenIds = new List<int>();
+
+        int start = 0;
+        int tokenCount = 0;
+        var encodeLength = 0;
+        var tokenCountMap = new SortedDictionary<int, int>
+        {
+            { tokenCount, encodeLength }
+        };
+        Encode(text, tokenIds, start, ref tokenCount, ref encodeLength, tokenCountMap, text.Length);
+        return TrimPrefix(text, maxTokenCount, tokenIds, tokenCount, tokenCountMap);
+    }
+
     /// <summary>
     /// Decode an array of integer token ids
     /// </summary>
diff --git a/Tokenizer_C#/TokenizerLib/TokenizerLib.csproj b/Tokenizer_C#/TokenizerLib/TokenizerLib.csproj
index 1ef64d3..efbf501 100644
--- a/Tokenizer_C#/TokenizerLib/TokenizerLib.csproj
+++ b/Tokenizer_C#/TokenizerLib/TokenizerLib.csproj
@@ -8,7 +8,7 @@
     <AssemblyName>Tokenizer</AssemblyName>
     <Description>Tokenizer for OpenAI large language models.</Description>
     <LangVersion>8.0</LangVersion>
-    <AssemblyVersion>1.3.2</AssemblyVersion>
+    <AssemblyVersion>1.3.3</AssemblyVersion>
     <FileVersion>$(AssemblyVersion)</FileVersion>
     <InformationalVersion>$(AssemblyVersion)</InformationalVersion>
     <Authors>Microsoft</Authors>
diff --git a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
index 8196b07..43b7d2e 100644
--- a/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
+++ b/Tokenizer_C#/TokenizerTest/TikTokenizerUnitTest.cs
@@ -53,7 +53,7 @@ public void TestEncode0()
     public void TestEncode1()
     {
         var text = "<|im_start|>Hello World<|im_end|>";
-        var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
+        var encoded = Tokenizer.Encode(text);
         Assert.AreEqual(4, encoded.Count);
         Assert.AreEqual(100264, encoded[0]);
         Assert.AreEqual(9906, encoded[1]);
@@ -70,6 +70,9 @@ public void TestEncode2()
         var encoded = Tokenizer.Encode(text, new HashSet<string>(SpecialTokens.Keys));
         Assert.AreEqual(5584, encoded.Count);
 
+        encoded = Tokenizer.Encode(text, false);
+        Assert.AreEqual(5584, encoded.Count);
+
         string json = File.ReadAllText("./testData/tokens.json");
         var expected = JsonConvert.DeserializeObject<int[]>(json);
@@ -131,6 +134,14 @@ public void TestEncodeTrimSuffix()
         Assert.AreEqual(4, encoded.TokenIds.Count);
         Assert.AreEqual(text, encoded.Text);
 
+        encoded = Tokenizer.EncodeTrimSuffix(text, 4, false);
+        Assert.AreEqual(4, encoded.TokenIds.Count);
+        Assert.AreEqual("<|im_start", encoded.Text);
+
+        encoded = Tokenizer.EncodeTrimSuffix(text, 4);
+        Assert.AreEqual(4, encoded.TokenIds.Count);
+        Assert.AreEqual(text, encoded.Text);
+
         encoded = Tokenizer.EncodeTrimSuffix(text, new HashSet<string>(SpecialTokens.Keys), 5);
         Assert.AreEqual(4, encoded.TokenIds.Count);
         Assert.AreEqual(text, encoded.Text);
@@ -173,6 +184,14 @@ public void TestEncodeTrimPrefix()
         Assert.AreEqual(4, encoded.TokenIds.Count);
         Assert.AreEqual(text, encoded.Text);
 
+        encoded = Tokenizer.EncodeTrimPrefix(text, 4, false);
+        Assert.AreEqual(4, encoded.TokenIds.Count);
+        Assert.AreEqual("im_end|>", encoded.Text);
+
+        encoded = Tokenizer.EncodeTrimPrefix(text, 4);
+        Assert.AreEqual(4, encoded.TokenIds.Count);
+        Assert.AreEqual(text, encoded.Text);
+
         encoded = Tokenizer.EncodeTrimPrefix(text, new HashSet<string>(SpecialTokens.Keys), 5);
         Assert.AreEqual(4, encoded.TokenIds.Count);
         Assert.AreEqual(text, encoded.Text);
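
Note (not part of the patch): a minimal usage sketch of the applySpecialTokens overloads added above. The TokenizerBuilder.CreateByModelNameAsync call, the Microsoft.DeepDev namespace, and the <|im_end|> token id are assumptions taken from the repo README rather than from this diff; only the Encode/EncodeTrimSuffix/EncodeTrimPrefix overload shapes come from the change itself.

// Hedged sketch: exercises the new applySpecialTokens overloads from this diff.
// The builder call and special-token ids below are assumed from the README.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.DeepDev;

public static class ApplySpecialTokensDemo
{
    public static async Task Main()
    {
        var specialTokens = new Dictionary<string, int>
        {
            { "<|im_start|>", 100264 },  // id matches the unit test above
            { "<|im_end|>", 100265 },    // assumed id, per README
        };

        // Assumed construction helper; any ITokenizer built with special tokens works here.
        ITokenizer tokenizer = await TokenizerBuilder.CreateByModelNameAsync("gpt-3.5-turbo", specialTokens);

        var text = "<|im_start|>Hello World<|im_end|>";

        // Default: special tokens registered at construction are applied.
        List<int> withSpecial = tokenizer.Encode(text);

        // applySpecialTokens: false tokenizes the markers as ordinary text.
        List<int> withoutSpecial = tokenizer.Encode(text, applySpecialTokens: false);

        // The trimming overloads take the max token count plus the same flag.
        var (ids, trimmedText) = tokenizer.EncodeTrimSuffix(text, maxTokenCount: 4, applySpecialTokens: false);

        Console.WriteLine($"{withSpecial.Count} vs {withoutSpecial.Count} tokens; kept \"{trimmedText}\" ({ids.Count} tokens)");
    }
}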