From 91aef1103913fee1f95d2a6913dad2855e6549ff Mon Sep 17 00:00:00 2001 From: Marko Mikulicic Date: Tue, 23 Apr 2024 11:27:04 +0200 Subject: [PATCH 1/2] Reformat --- src/spm_decode_main.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc index 824f2bd6..f666533b 100644 --- a/src/spm_decode_main.cc +++ b/src/spm_decode_main.cc @@ -16,13 +16,13 @@ #include #include +#include "absl/flags/flag.h" +#include "absl/strings/str_split.h" #include "common.h" #include "filesystem.h" #include "init.h" #include "sentencepiece.pb.h" #include "sentencepiece_processor.h" -#include "absl/flags/flag.h" -#include "absl/strings/str_split.h" #include "util.h" ABSL_FLAG(std::string, model, "", "model file name"); From 2ce6ea8187b8e438039afdda17668b3a6799ea44 Mon Sep 17 00:00:00 2001 From: Marko Mikulicic Date: Tue, 23 Apr 2024 11:53:53 +0200 Subject: [PATCH 2/2] Implement poolside input_format for spm_decode --- src/spm_decode_main.cc | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/spm_decode_main.cc b/src/spm_decode_main.cc index f666533b..76514495 100644 --- a/src/spm_decode_main.cc +++ b/src/spm_decode_main.cc @@ -28,11 +28,22 @@ ABSL_FLAG(std::string, model, "", "model file name"); ABSL_FLAG(std::string, input, "", "input filename"); ABSL_FLAG(std::string, output, "", "output filename"); -ABSL_FLAG(std::string, input_format, "piece", "choose from piece or id"); +ABSL_FLAG(std::string, input_format, "piece", + "choose from piece, id, poolside or poolside_no_toc"); ABSL_FLAG(std::string, output_format, "string", "choose from string or proto"); ABSL_FLAG(std::string, extra_options, "", "':' separated encoder extra options, e.g., \"reverse:bos:eos\""); +std::vector read_uint32(absl::string_view binary_data) { + std::vector result; + for (size_t i = 0; i < binary_data.size(); i += sizeof(uint32_t)) { + uint32_t value = + *reinterpret_cast(binary_data.data() + i); + result.push_back(value); + } + return result; +} + int main(int argc, char *argv[]) { sentencepiece::ScopedResourceDestructor cleaner; sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true); @@ -101,6 +112,31 @@ int main(int argc, char *argv[]) { LOG(FATAL) << "Unknown output format: " << absl::GetFlag(FLAGS_output_format); } + } else if (absl::GetFlag(FLAGS_input_format) == "poolside" || + absl::GetFlag(FLAGS_input_format) == "poolside_no_toc") { + for (const auto &filename : rest_args) { + auto input = sentencepiece::filesystem::NewReadableFile(filename, false); + CHECK_OK(input->status()); + input->ReadAll(&line); + { + auto tokens = read_uint32(line); + if (absl::GetFlag(FLAGS_input_format) == "poolside") { + // the last 8 bytes of the file contain encode the number of documents + // encoded as a little endian uint64_t which means the last 2 uint32_t + // of the tokens array contain the length; + uint64_t num_docs = (uint64_t)tokens.back() << 32; + tokens.pop_back(); + num_docs |= tokens.back(); + // the trailier of the file contains the length of each document + // encoded as an uint32_t + tokens.resize(tokens.size() - num_docs); + } + + CHECK_OK(sp.Decode(tokens, &detok)); + output->Write(detok); + } + } + return 0; } else { LOG(FATAL) << "Unknown input format: " << absl::GetFlag(FLAGS_input_format); }