diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml new file mode 100644 index 00000000..038b5b83 --- /dev/null +++ b/.github/workflows/pypi-release.yml @@ -0,0 +1,61 @@ +name: PyPI release +on: + workflow_dispatch: + +jobs: + testing: + uses: ./.github/workflows/testing.yml + release: + needs: testing + runs-on: ubuntu-latest + env: + TWINE_USERNAME: __token__ + + steps: + - name: Checkout Repo + uses: actions/checkout@v3 + + - name: Setup Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install -U twine build + + - name: Build the dist files + run: python -m build . + + - name: Publish to the test PyPI + env: + TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }} + run: twine upload dist/* --repository=testpypi + + - name: Test installing from test PyPI and running tests + run: | + pip install -i https://testpypi.python.org/pypi --extra-index-url https://pypi.org/simple datatrove[testing] + python -m nltk.downloader punkt + make test + + - name: Get tag name + id: get_tag_name + run: | + echo TAG_NAME=$(grep '^version' pyproject.toml | head -1 | cut -d '"' -f 2) >> $GITHUB_OUTPUT + + - name: Tag the release + uses: actions/github-script@v7 + with: + script: | + github.rest.git.createRef({ + owner: context.repo.owner, + repo: context.repo.repo, + ref: 'refs/tags/v${{ steps.get_tag_name.outputs.TAG_NAME }}', + sha: context.sha + }) + + - name: Publish to PyPI + env: + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: twine upload dist/* --repository=pypi diff --git a/.github/workflows/ci.yml b/.github/workflows/testing.yml similarity index 73% rename from .github/workflows/ci.yml rename to .github/workflows/testing.yml index 675e7897..cf52a7b4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/testing.yml @@ -1,4 +1,4 @@ -name: CI +name: Test & Check Code Quality on: pull_request: @@ -7,6 +7,7 @@ on: push: branches: - main + workflow_call: jobs: check_code_quality: @@ -19,12 +20,12 @@ jobs: python-version: "3.10" - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install .[quality] + python -m pip install uv + uv pip install --system .[quality] - name: Check quality run: | - ruff check tests src # linter - ruff format --check tests src # formatter + ruff check tests src examples # linter + ruff format --check tests src examples # formatter test: runs-on: ubuntu-latest @@ -40,8 +41,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install .[testing] + python -m pip install uv + uv pip install --system .[testing] python -m nltk.downloader punkt - name: Test with pytest run: | diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 00000000..9cbbf680 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,15 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@main diff --git a/CITATION.cff b/CITATION.cff index c7c510d3..ae268901 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -6,6 +6,8 @@ type: software authors: - given-names: Guilherme family-names: Penedo + - given-names: Hynek + family-names: Kydlíček - given-names: Alessandro family-names: Cappelli - given-names: Thomas diff --git 
a/README.md b/README.md index e46f521d..6ee00eac 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ Local, remote and other file systems are supported through [fsspec](https://file * [Filtering data](#filtering-data) * [Saving data](#saving-data) * [Deduplicating data](#deduplicating-data) + * [Summary Statistics](#summary-statistics) * [Custom blocks](#custom-blocks) + [Simple data](#simple-data) + [Custom function](#custom-function) @@ -96,6 +97,7 @@ Some options common to all executors: - `pipeline` a list consisting of the pipeline steps that should be run - `logging_dir` a datafolder where log files, statistics and more should be saved. Do not reuse folders for different pipelines/jobs as this will overwrite your stats, logs and completions. - `skip_completed` (_bool_, `True` by default) datatrove keeps track of completed tasks so that when you relaunch a job they can be skipped. Set this to `False` to disable this behaviour +- `randomize_start_duration` (_int_, `0` by default) the maximum number of seconds to delay the start of each task to prevent all tasks from starting simultaneously and potentially overloading the system. Call an executor's `run` method to execute its pipeline. @@ -223,6 +225,12 @@ For a pipeline with `logging_dir` **mylogspath/exp1**, the following folder stru ``` +### Colorization +Log messages support colorization. By default, colorization will be auto-detected for console messages and disabled for log files (logs/task_XXXXX.log). +To explicitly enable or disable colorization, you may set the following environment variables: +- `DATATROVE_COLORIZE_LOGS` set to "1" to add ANSI colors to console log messages or to "0" to disable colorization. +- `DATATROVE_COLORIZE_LOG_FILES` set to "1" to add ANSI colors to log messages saved to logs/task_XXXXX.log. + ## DataFolder / paths Datatrove supports a wide variety of input/output sources through [fsspec](https://filesystem-spec.readthedocs.io/en/latest/). @@ -279,6 +287,45 @@ JsonlWriter( ### Deduplicating data For deduplication check the examples [minhash_deduplication.py](examples/minhash_deduplication.py), [sentence_deduplication.py](examples/sentence_deduplication.py) and [exact_substrings.py](examples/exact_substrings.py). +### Summary Statistics +For summary statistics on your data, you can use the [Stats](src/datatrove/pipeline/stats/summary_stats/) blocks. These blocks provide an easy way to collect data profiles of your dataset in a distributed manner. It's a two-step process: +1) For each shard, iterate over documents and collect stats into one of the following groupings: `summary` (all docs counted under the "summary" key), `fqdn` (fully qualified domain name grouping), `suffix` (the last part of the url path grouping) or `histogram` (value based grouping). +2) Merge the stats from different shards into a single file. +See the [summary_stats.py](examples/summary_stats.py) example for more details.
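As a quick orientation, here is a minimal local sketch of these two steps, adapted from the [summary_stats.py](examples/summary_stats.py) example added in this PR (which uses the Slurm executor); the folder names and task count below are placeholders:

```python
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.stats import DocStats, LineStats, StatsMerger, TopKConfig, WordStats

# only track the 10k most frequent fqdn/suffix groups to keep the output size bounded
top_k_config = TopKConfig(top_k_groups=["fqdn", "suffix"], top_k=10_000)

# step 1: each task reads its shard of documents and writes partial stats
compute = LocalPipelineExecutor(
    pipeline=[
        JsonlReader("my_data_folder"),  # placeholder input path
        WordStats(output_folder="my_stats", top_k_config=top_k_config),
        LineStats(output_folder="my_stats", top_k_config=top_k_config),
        DocStats(output_folder="my_stats", top_k_config=top_k_config),
    ],
    tasks=4,
)

# step 2: merge the per-shard stats into a single metric.json per stat and grouping
merger = LocalPipelineExecutor(
    pipeline=[
        StatsMerger(input_folder="my_stats", output_folder="my_stats_merged", top_k_config=top_k_config),
    ],
    tasks=4,
    depends=compute,  # only start merging once the compute step has finished
)

merger.run()
```

Running the merge as a separate executor is what lets step 1 be spread across many tasks before the per-shard results are combined.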
+ +Each resulting stat is saved in a separate file with the following structure: `output_folder/{fqdn,suffix,summary,histogram}/{stat_name}/metric.json` + +Each such file is a `MetricStatsDict` object, which you can easily load using: +```python +from datatrove.pipeline.stats.summary_stats import MetricStatsDict +import json +stats = MetricStatsDict.from_dict(json.load(open("fqdn/length/metric.json"))) + +# E.g. for the total length of nytimes.com docs +stats["nytimes.com"].total + +# Or for the mean of cnn.com docs +stats["cnn.com"].mean +``` + +The following stats are available: +- `contamination_stats.py`: `word_contamination_{words[0]}`: Frequency of word contamination in the document. +- `doc_stats.py`: `length`: Length of the document, `white_space_ratio`: Ratio of whitespace characters, `non_alpha_digit_ratio`: Ratio of non-alphabetic and non-digit characters, `digit_ratio`: Ratio of digits, `uppercase_ratio`: Ratio of uppercase letters, `elipsis_ratio`: Ratio of ellipsis characters, `punctuation_ratio`: Punctuation ratio +- `lang_stats.py`: `fasttext_{language}`: Language of the document using fastText +- `line_stats.py`: `n_lines`: Number of lines per doc, `avg_line_length`: Average line length per doc, `long_line_ratio_words`: Ratio of lines with more than k chars, `short_line_ratio_chars`: Ratio of lines with fewer than k chars, `bullet_point_lines_ratio`: Ratio of bullet points, `line_duplicates`: Ratio of lines that are duplicates, `line_char_duplicates`: Ratio of chars in duplicated lines +- `paragraph_stats.py`: `n_paragraphs`: Number of paragraphs, `avg_paragraph_length`: Average paragraph length, `short_paragraph_ratio_{chars}`: Ratio of short paragraphs (<{chars} chars), `long_paragraph_ratio_{chars}`: Ratio of long paragraphs (>{chars} chars) +- `perplexity_stats.py`: `ccnet_perplexity_{model_dataset}_{language}`: Perplexity of the document using the CCNet model for {model} on {dataset} in {language} +- `sentence_stats.py`: `n_sentences`: Number of sentences, `avg_sentence_length`: Average sentence length, `short_sentence_ratio_{chars}`: Ratio of short sentences (<{chars} chars), `long_sentence_ratio_{chars}`: Ratio of long sentences (>{chars} chars) +- `token_stats.py`: `token_count`: Number of tokens in the document +- `word_stats.py`: `n_words`: Number of words in the document, `avg_word_length`: Average length of words in the document, `avg_words_per_line`: Average number of words per line in the document, `short_word_ratio_{chars}`: Ratio of words shorter than {chars} characters, `stop_word_ratio`: Ratio of stop words, `long_word_ratio_{chars}`: Ratio of words longer than {chars} characters, `type_token_ratio`: Number of unique words / Number of tokens, `capitalized_word_ratio`: Ratio of capitalized words, `uppercase_word_ratio`: Ratio of uppercase words + + + + + + + + ### Custom blocks #### Simple data @@ -405,11 +452,11 @@ pytest -sv ./tests/ ```bibtex @misc{penedo2024datatrove, - author = {Penedo, Guilherme and Cappelli, Alessandro and Wolf, Thomas and Sasko, Mario}, + author = {Penedo, Guilherme and Kydlíček, Hynek and Cappelli, Alessandro and Sasko, Mario and Wolf, Thomas}, title = {DataTrove: large scale data processing}, year = {2024}, publisher = {GitHub}, journal = {GitHub repository}, url = {https://github.com/huggingface/datatrove} } -``` \ No newline at end of file +``` diff --git a/examples/fineweb.py b/examples/fineweb.py new file mode 100644 index 00000000..d5b93188 --- /dev/null +++ b/examples/fineweb.py @@ -0,0 +1,176 @@ +""" +This file contains the code used to
process and create the +FineWeb dataset (https://huggingface.co/datasets/HuggingFaceFW/fineweb) +""" + +from datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.dedup import MinhashDedupCluster, MinhashDedupFilter, MinhashDedupSignature +from datatrove.pipeline.dedup.minhash import MinhashConfig, MinhashDedupBuckets +from datatrove.pipeline.extractors import Trafilatura +from datatrove.pipeline.filters import ( + C4QualityFilter, + FineWebQualityFilter, + GopherQualityFilter, + GopherRepetitionFilter, + LanguageFilter, + URLFilter, +) +from datatrove.pipeline.formatters import PIIFormatter +from datatrove.pipeline.readers import JsonlReader, WarcReader +from datatrove.pipeline.tokens import TokensCounter +from datatrove.pipeline.writers.jsonl import JsonlWriter + + +""" + we first ran the following pipeline for each dump +""" +DUMP_TO_PROCESS = "CC-MAIN-2023-50" # example + +MAIN_OUTPUT_PATH = "s3://some_s3_bucket" +FILTERING_OUTPUT_PATH = f"{MAIN_OUTPUT_PATH}/base_processing" + +main_processing_executor = SlurmPipelineExecutor( + job_name=f"cc_{DUMP_TO_PROCESS}", + pipeline=[ + WarcReader( + f"s3://commoncrawl/crawl-data/{DUMP_TO_PROCESS}/segments/", + glob_pattern="*/warc/*", # we want the warc files + default_metadata={"dump": DUMP_TO_PROCESS}, + ), + URLFilter(exclusion_writer=JsonlWriter(f"{FILTERING_OUTPUT_PATH}/removed/1_url/{DUMP_TO_PROCESS}")), + Trafilatura(favour_precision=True), + LanguageFilter( + exclusion_writer=JsonlWriter( + f"{FILTERING_OUTPUT_PATH}/2_non_english/", + output_filename="${language}/" + DUMP_TO_PROCESS + "/${rank}.jsonl.gz", + # folder structure: language/dump/file + ) + ), + GopherRepetitionFilter( + exclusion_writer=JsonlWriter(f"{FILTERING_OUTPUT_PATH}/removed/3_gopher_rep/{DUMP_TO_PROCESS}") + ), + GopherQualityFilter( + exclusion_writer=JsonlWriter(f"{FILTERING_OUTPUT_PATH}/removed/4_gopher_qual/{DUMP_TO_PROCESS}") + ), + C4QualityFilter( + filter_no_terminal_punct=False, + exclusion_writer=JsonlWriter(f"{FILTERING_OUTPUT_PATH}/removed/5_c4/{DUMP_TO_PROCESS}"), + ), + FineWebQualityFilter( + exclusion_writer=JsonlWriter(f"{FILTERING_OUTPUT_PATH}/removed/6_fineweb_qual/{DUMP_TO_PROCESS}") + ), + JsonlWriter(f"{FILTERING_OUTPUT_PATH}/output/{DUMP_TO_PROCESS}"), + ], + tasks=8000, + time="10:00:00", + logging_dir=f"{MAIN_OUTPUT_PATH}/logs/base_processing/{DUMP_TO_PROCESS}", + slurm_logs_folder=f"logs/base_processing/{DUMP_TO_PROCESS}/slurm_logs", # must be local + randomize_start_duration=180, # don't hit the bucket all at once with the list requests + mem_per_cpu_gb=2, + partition="hopper-cpu", +) +main_processing_executor.run() + +""" + we then applied minhash deduplication to each individual dump, +""" + +# you can also change ngrams or the number of buckets and their size here +minhash_config = MinhashConfig( + use_64bit_hashes=True, # better precision -> fewer false positives (collisions) + num_buckets=14, + hashes_per_bucket=8, + n_grams=5, +) + +S3_MINHASH_BASE_PATH = f"{MAIN_OUTPUT_PATH}/minhash" + +S3_LOGS_FOLDER = f"{MAIN_OUTPUT_PATH}/logs/minhash" +LOCAL_LOGS_FOLDER = "logs/minhash" + +TOTAL_TASKS = 1000 + +# this is the original data that we want to deduplicate +INPUT_READER = JsonlReader( + f"{FILTERING_OUTPUT_PATH}/output/{DUMP_TO_PROCESS}" +) # this is the output from the first part + +# stage 1 computes minhash signatures for each task (each task gets a set of files) +stage1 = SlurmPipelineExecutor( + job_name=f"mh1_{DUMP_TO_PROCESS}", + pipeline=[ + INPUT_READER, + MinhashDedupSignature( + 
output_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/signatures", config=minhash_config + ), + ], + tasks=TOTAL_TASKS, + time="5:00:00", + partition="hopper-cpu", + logging_dir=f"{S3_LOGS_FOLDER}/signatures", + slurm_logs_folder=f"{LOCAL_LOGS_FOLDER}/signatures/slurm_logs", + randomize_start_duration=180, + depends=main_processing_executor, # only start after the first one completes +) + +stage2 = SlurmPipelineExecutor( + job_name=f"mh2_{DUMP_TO_PROCESS}", + pipeline=[ + MinhashDedupBuckets( + input_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/signatures", + output_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/buckets", + config=MinhashConfig(use_64bit_hashes=True), + ), + ], + tasks=minhash_config.num_buckets * 50, # the code supports parallelizing each bucket. here we run 50 + # workers per bucket + randomize_start_duration=180, + logging_dir=f"{S3_LOGS_FOLDER}/buckets", + partition="hopper-cpu", + time="02:00:00", + mem_per_cpu_gb=4, + cpus_per_task=3, # you can add run more (smaller) tasks if you do not have a lot of memory + depends=stage1, +) + + +stage3 = SlurmPipelineExecutor( + job_name=f"mh3_{DUMP_TO_PROCESS}", + pipeline=[ + MinhashDedupCluster( + input_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/buckets", + output_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/remove_ids", + config=minhash_config, + ), + ], + tasks=1, # this step runs on a single task + logging_dir=f"{S3_LOGS_FOLDER}/clustering", + partition="hopper-cpu", + time="30:00:00", # and can also be quite slow. Usually not this slow though + mem_per_cpu_gb=25, + cpus_per_task=8, # if you dedup a full dump, you do need a lot of memory for this one + depends=stage2, +) + + +stage4 = SlurmPipelineExecutor( + job_name=f"mh4_{DUMP_TO_PROCESS}", + pipeline=[ + INPUT_READER, + TokensCounter(), # you can remove this one, it's just a nice way to know how many tokens we have + # before and after dedup + MinhashDedupFilter(input_folder=f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/remove_ids"), + # run the PII removal + PIIFormatter(), + JsonlWriter(f"{S3_MINHASH_BASE_PATH}/{DUMP_TO_PROCESS}/deduped_output"), + ], + tasks=TOTAL_TASKS, + logging_dir=f"{S3_LOGS_FOLDER}/filtering", + partition="hopper-cpu", + time="5:00:00", + mem_per_cpu_gb=4, + depends=stage3, +) + +# launch dedup pipelines +stage4.run() diff --git a/examples/process_common_crawl_dump.py b/examples/process_common_crawl_dump.py index 68c5d18d..8cee638b 100644 --- a/examples/process_common_crawl_dump.py +++ b/examples/process_common_crawl_dump.py @@ -6,7 +6,6 @@ GopherQualityFilter, GopherRepetitionFilter, LanguageFilter, - ListFilter, URLFilter, ) from datatrove.pipeline.readers import WarcReader @@ -39,14 +38,13 @@ ), GopherRepetitionFilter(exclusion_writer=JsonlWriter(f"{MAIN_OUTPUT_PATH}/removed/repetitive/{DUMP}")), GopherQualityFilter(exclusion_writer=JsonlWriter(f"{MAIN_OUTPUT_PATH}/removed/quality/{DUMP}")), - ListFilter(exclusion_writer=JsonlWriter(f"{MAIN_OUTPUT_PATH}/removed/list/{DUMP}")), JsonlWriter(f"{MAIN_OUTPUT_PATH}/output/{DUMP}"), ], tasks=8000, time="10:00:00", logging_dir=f"{MAIN_OUTPUT_PATH}/logs/base_processing/{DUMP}", slurm_logs_folder=f"logs/process_dump/processing/base_processing/{DUMP}/slurm_logs", - randomize_start=True, + randomize_start_duration=180, mem_per_cpu_gb=2, partition="hopper-cpu", ) diff --git a/examples/summary_stats.py b/examples/summary_stats.py new file mode 100644 index 00000000..4cff7daa --- /dev/null +++ b/examples/summary_stats.py @@ -0,0 +1,80 @@ +import argparse +import dataclasses + +from 
datatrove.executor.slurm import SlurmPipelineExecutor +from datatrove.pipeline.filters.sampler_filter import SamplerFilter +from datatrove.pipeline.readers.jsonl import JsonlReader +from datatrove.pipeline.stats import DocStats, LineStats, StatsMerger, TopKConfig, WordStats + + +TOTAL_TASKS = 500 + +parser = argparse.ArgumentParser(description="Summary Stats") +parser.add_argument("dump_path", help="Dump name sampler") +parser.add_argument("sample_rate", type=float, help="Sample rate") +parser.add_argument("--prefix", default="", help="Prefix") +parser.add_argument("--glob", help="Glob pattern") +parser.add_argument("--text_key", default="text", help="Text key") +parser.add_argument("--reader", default="jsonl", help="Reader type") + +if __name__ == "__main__": + args = parser.parse_args() + experiment_name = args.dump_path.replace("/", "_") + LOCAL_LOGS_FOLDER = f"/logs/{experiment_name}" + DATA_FOLDER = f"s3://data/{experiment_name}" + SOURCE = f"{args.prefix}/{args.dump_path}" + print(SOURCE) + + top_k_config = TopKConfig(top_k_groups=["fqdn", "suffix"], top_k=10_000) + + compute = SlurmPipelineExecutor( + pipeline=[ + JsonlReader(SOURCE, doc_progress=True, limit=-1, glob_pattern=args.glob, text_key=args.text_key), + # Sampling is fine for summary stats + SamplerFilter( + rate=args.sample_rate, + ), + WordStats( + output_folder=DATA_FOLDER, + top_k_config=top_k_config, + ), + LineStats( + output_folder=DATA_FOLDER, + top_k_config=top_k_config, + ), + DocStats( + output_folder=DATA_FOLDER, + top_k_config=top_k_config, + ), + ], + tasks=TOTAL_TASKS, + job_name=f"summary-stats-{experiment_name}", + time="24:00:00", + partition="hopper-cpu", + logging_dir=f"{LOCAL_LOGS_FOLDER}-compute", + qos="normal", + mem_per_cpu_gb=2, + cpus_per_task=1, + ) + + merger = SlurmPipelineExecutor( + pipeline=[ + StatsMerger( + input_folder=DATA_FOLDER, + output_folder=f"{DATA_FOLDER}", + remove_input=False, + top_k_config=dataclasses.replace(top_k_config, top_k=8_000), + ), + ], + tasks=TOTAL_TASKS, + job_name=f"merging-stats-{experiment_name}", + time="24:00:00", + partition="hopper-cpu", + logging_dir=f"{LOCAL_LOGS_FOLDER}-merge", + qos="normal", + mem_per_cpu_gb=2, + cpus_per_task=1, + depends=compute, + ) + + merger.run() diff --git a/examples/url_deduplication.py b/examples/url_deduplication.py new file mode 100644 index 00000000..eaf99d00 --- /dev/null +++ b/examples/url_deduplication.py @@ -0,0 +1,80 @@ +import argparse + +import numpy as np + +from datatrove.executor.base import PipelineExecutor +from datatrove.executor.local import LocalPipelineExecutor +from datatrove.pipeline.dedup.url_dedup import ( + UrlDedupConfig, + UrlDedupFilter, + UrlDedupSignature, + UrlFindDedups, +) +from datatrove.pipeline.readers import JsonlReader +from datatrove.pipeline.writers.jsonl import JsonlWriter + + +""" +Example on how to use url-deduplication. 
+To run url deduplication we need to run three different pipelines (same as sentence dedup) +""" + + +# modify url dedup hyper params here +url_dedup_config = UrlDedupConfig( + # this will keep the longest document for each url + document_priority=lambda doc: min(np.iinfo(np.uint16).max, len(doc.text) // 4), + url_normalizer=lambda url: url.lower(), +) + +FINDER_WORKERS = 4 # this will speed up/parallelize step 2 + +LIMIT = -1 # for testing + + +def run_example(args): + pipeline_1 = [ + JsonlReader(args.input_folder, limit=LIMIT, progress=True), + UrlDedupSignature( + output_folder=f"{args.sigs_dup_folder}/sigs", + config=url_dedup_config, + finder_workers=FINDER_WORKERS, + ), + ] + + pipeline_2 = [ + UrlFindDedups( + data_folder=f"{args.sigs_dup_folder}/sigs", + output_folder=f"{args.sigs_dup_folder}/dups", + config=url_dedup_config, + ) + ] + + pipeline_3 = [ + JsonlReader(data_folder=args.input_folder, limit=LIMIT, progress=True), + UrlDedupFilter( + data_folder=f"{args.sigs_dup_folder}/dups", + config=url_dedup_config, + exclusion_writer=JsonlWriter(output_folder=f"{args.base_output_folder}/removed"), + ), + JsonlWriter(output_folder=f"{args.base_output_folder}/output"), + ] + + executor_1: PipelineExecutor = LocalPipelineExecutor(pipeline=pipeline_1, tasks=4) + + executor_2: PipelineExecutor = LocalPipelineExecutor(pipeline=pipeline_2, tasks=FINDER_WORKERS) + + executor_3: PipelineExecutor = LocalPipelineExecutor(pipeline=pipeline_3, tasks=4) + + print(executor_1.run()) + print(executor_2.run()) + print(executor_3.run()) + + +parser = argparse.ArgumentParser(description="URL Deduplication") +parser.add_argument("input_folder", help="Input folder path") +parser.add_argument("base_output_folder", help="Base output folder path") +parser.add_argument("sigs_dup_folder", help="sigs-dup folder path") +if __name__ == "__main__": + args = parser.parse_args() + run_example(args) diff --git a/pyproject.toml b/pyproject.toml index 6dbfb148..c1bb7d7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "datatrove" -version = "0.0.1" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) +version = "0.2.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description = "HuggingFace library to process and filter large amounts of webdata" readme = "README.md" authors = [ @@ -28,7 +28,7 @@ dependencies = [ "humanize", "loguru>=0.7.0", "multiprocess", - "numpy>=1.25.0", + "numpy>=1.25.0,<2.0.0", "tqdm", ] @@ -41,7 +41,8 @@ io = [ "pyarrow", "python-magic", "warcio", - "datasets>=2.18.0" + "datasets>=2.18.0", + "orjson" ] s3 = [ "s3fs>=2023.12.2", @@ -50,11 +51,27 @@ processing = [ "fasttext-wheel", "nltk", "inscriptis", -# "readability-lxml @ git+https://github.com/huggingface/python-readability.git@speedup", +# "readability-lxml @ git+https://github.com/huggingface/python-readability.git@speedup", "tldextract", "trafilatura>=1.8.0", "tokenizers", - "ftfy" + "ftfy", + "fasteners", + "xxhash", + "kenlm", + "pyahocorasick" +] +decont = [ + "lighteval>=0.3.0" +] +multilingual = [ + "spacy", + "stanza", + "pyvi", + "pythainlp", + "jieba", + "indic-nlp-library", + "kiwipiepy", ] quality = [ "ruff>=0.1.5" @@ -63,7 +80,9 @@ testing = [ "datatrove[cli]", "datatrove[io]", "datatrove[processing]", + "datatrove[multilingual]", "datatrove[s3]", + "datatrove[decont]", "pytest", "pytest-timeout", "pytest-xdist", diff --git a/src/datatrove/executor/base.py b/src/datatrove/executor/base.py index a4c36356..2ebdb647 
100644 --- a/src/datatrove/executor/base.py +++ b/src/datatrove/executor/base.py @@ -1,15 +1,22 @@ import dataclasses import json +import random +import time from abc import ABC, abstractmethod from collections import deque from collections.abc import Sequence from typing import Callable -from loguru import logger - from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep -from datatrove.utils.logging import add_task_logger, close_task_logger, get_random_str, get_timestamp, log_pipeline +from datatrove.utils.logging import ( + add_task_logger, + close_task_logger, + get_random_str, + get_timestamp, + log_pipeline, + logger, +) from datatrove.utils.stats import PipelineStats @@ -22,6 +29,7 @@ class PipelineExecutor(ABC): logging_dir: where to save logs, stats, etc. Should be parsable into a datatrove.io.DataFolder skip_completed: whether to skip tasks that were completed in previous runs. default: True + randomize_start_duration: the maximum number of seconds to delay the start of each task. """ @abstractmethod @@ -30,10 +38,12 @@ def __init__( pipeline: list[PipelineStep | Callable], logging_dir: DataFolderLike = None, skip_completed: bool = True, + randomize_start_duration: int = 0, ): self.pipeline: list[PipelineStep | Callable] = pipeline self.logging_dir = get_datafolder(logging_dir if logging_dir else f"logs/{get_timestamp()}_{get_random_str()}") self.skip_completed = skip_completed + self.randomize_start_duration = randomize_start_duration @abstractmethod def run(self): @@ -69,6 +79,9 @@ def _run_for_rank(self, rank: int, local_rank: int = 0) -> PipelineStats: return PipelineStats() logfile = add_task_logger(self.logging_dir, rank, local_rank) log_pipeline(self.pipeline) + + if self.randomize_start_duration > 0: + time.sleep(random.randint(0, self.randomize_start_duration)) try: # pipe data from one step to the next pipelined_data = None diff --git a/src/datatrove/executor/local.py b/src/datatrove/executor/local.py index ef6df4ca..0d16db74 100644 --- a/src/datatrove/executor/local.py +++ b/src/datatrove/executor/local.py @@ -4,11 +4,11 @@ from typing import Callable import multiprocess -from loguru import logger from datatrove.executor.base import PipelineExecutor from datatrove.io import DataFolderLike from datatrove.pipeline.base import PipelineStep +from datatrove.utils.logging import logger from datatrove.utils.stats import PipelineStats @@ -30,6 +30,7 @@ class LocalPipelineExecutor(PipelineExecutor): Tasks [local_rank_offset, local_rank_offset + local_tasks] will be run. depends: another LocalPipelineExecutor that should run before this one + randomize_start_duration: the maximum number of seconds to delay the start of each task. 
""" def __init__( @@ -43,8 +44,9 @@ def __init__( start_method: str = "forkserver", local_tasks: int = -1, local_rank_offset: int = 0, + randomize_start_duration: int = 0, ): - super().__init__(pipeline, logging_dir, skip_completed) + super().__init__(pipeline, logging_dir, skip_completed, randomize_start_duration) self.tasks = tasks self.workers = workers if workers != -1 else tasks self.start_method = start_method diff --git a/src/datatrove/executor/slurm.py b/src/datatrove/executor/slurm.py index 17495721..b0cc9077 100644 --- a/src/datatrove/executor/slurm.py +++ b/src/datatrove/executor/slurm.py @@ -1,8 +1,8 @@ from __future__ import annotations import json +import math import os -import random import signal import subprocess import sys @@ -14,18 +14,17 @@ import dill from dill import CONTENTS_FMODE -from loguru import logger from datatrove.executor.base import PipelineExecutor from datatrove.io import DataFolderLike from datatrove.pipeline.base import PipelineStep -from datatrove.utils.logging import get_random_str, get_timestamp +from datatrove.utils.logging import get_random_str, get_timestamp, logger def requeue_handler(signum, _frame): signame = signal.Signals(signum).name logger.warning(f"Received signal {signum} ({signame}). Requeueing and exiting...") - subprocess.run(["scontrol", "requeue", "${SLURM_JOB_ID}"]) + subprocess.run(["scontrol", "requeue", os.environ.get("SLURM_JOB_ID")]) sys.exit(15) @@ -73,13 +72,13 @@ class SlurmPipelineExecutor(PipelineExecutor): stagger_max_array_jobs: when max_array_launch_parallel is True, this determines how many seconds to wait between launching each of the parallel jobs run_on_dependency_fail: start executing when a job we depend on finishes even if it has failed - randomize_start: randomize the start of each task in a job in a ~3 min window + randomize_start_duration: the maximum number of seconds to delay the start of each task. requeue_signals: requeue the job and exit when one of these signals is received. Useful for when an instance is being reclaimed and jobs must be stopped for example. Set to None to disable mail_type: see https://slurm.schedmd.com/sbatch.html. Common values are (NONE, BEGIN, END, FAIL, REQUEUE, ALL) mail_user: email address to send notifications to requeue: requeue the job if it fails - + tasks_per_job: each slurm job in the job array will run these many datatrove tasks. This reduces the total nb of slurm jobs launched. 
""" def __init__( @@ -106,18 +105,21 @@ def __init__( max_array_launch_parallel: bool = False, stagger_max_array_jobs: int = 0, run_on_dependency_fail: bool = False, - randomize_start: bool = False, + randomize_start_duration: int = 0, requeue_signals: tuple[str] | None = ("SIGUSR1",), mail_type: str = "ALL", mail_user: str = None, requeue: bool = True, + srun_args: dict = None, + tasks_per_job: int = 1, ): - super().__init__(pipeline, logging_dir, skip_completed) + super().__init__(pipeline, logging_dir, skip_completed, randomize_start_duration) self.tasks = tasks self.workers = workers self.partition = partition self.cpus_per_task = cpus_per_task self.mem_per_cpu_gb = mem_per_cpu_gb + self.tasks_per_job = tasks_per_job self.time = time self.job_name = job_name self.qos = qos @@ -131,11 +133,12 @@ def __init__( self.max_array_launch_parallel = max_array_launch_parallel self.stagger_max_array_jobs = stagger_max_array_jobs self.run_on_dependency_fail = run_on_dependency_fail - self.randomize_start = randomize_start + self.randomize_start_duration = randomize_start_duration self.job_id = None self.requeue_signals = requeue_signals self.mail_type = mail_type self.mail_user = mail_user + self.srun_args = srun_args self.slurm_logs_folder = ( slurm_logs_folder if slurm_logs_folder @@ -160,18 +163,21 @@ def run(self): slurm_rank = int(os.environ["SLURM_ARRAY_TASK_ID"]) + self.max_array_size * int( os.environ.get("RUN_OFFSET", 0) ) + ranks_to_run_range = (slurm_rank * self.tasks_per_job, (slurm_rank + 1) * self.tasks_per_job) with self.logging_dir.open("ranks_to_run.json", "r") as ranks_to_run_file: all_ranks = json.load(ranks_to_run_file) - if slurm_rank >= len(all_ranks): + if ranks_to_run_range[0] >= len(all_ranks): return - rank = all_ranks[slurm_rank] for ss in self.requeue_signals or []: signal.signal(signal.Signals[ss], requeue_handler) - if self.randomize_start: - time.sleep(random.randint(0, 60 * 3)) - self._run_for_rank(rank) + for rank_to_run in range(*ranks_to_run_range): + if rank_to_run >= len(all_ranks): + break + rank = all_ranks[rank_to_run] + + self._run_for_rank(rank) else: # we still have to launch the job self.launch_job() @@ -244,12 +250,14 @@ def launch_job(self): # we actually save this (only once) to avoid race conditions json.dump(ranks_to_run, ranks_to_run_file) - max_array = min(len(ranks_to_run), self.max_array_size) if self.max_array_size != -1 else len(ranks_to_run) + nb_jobs_to_launch = math.ceil(len(ranks_to_run) / self.tasks_per_job) + max_array = min(nb_jobs_to_launch, self.max_array_size) if self.max_array_size != -1 else nb_jobs_to_launch # create the actual sbatch script + srun_args_str = " ".join([f"--{k}={v}" for k, v in self.srun_args.items()]) if self.srun_args else "" launch_file_contents = self.get_launch_file_contents( self.get_sbatch_args(max_array), - f"srun -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}", + f"srun {srun_args_str} -l launch_pickled_pipeline {self.logging_dir.resolve_paths('executor.pik')}", ) # save it with self.logging_dir.open("launch_script.slurm", "w") as launchscript_f: @@ -261,7 +269,7 @@ def launch_job(self): # launch (possibly multiple) jobs launched_jobs = 0 - while launched_jobs * max_array < len(ranks_to_run): + while launched_jobs * max_array < nb_jobs_to_launch: if launched_jobs and self.max_array_launch_parallel and self.stagger_max_array_jobs > 0: time.sleep(self.stagger_max_array_jobs) args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"] diff --git a/src/datatrove/io.py b/src/datatrove/io.py 
index 0622b54e..62489c46 100644 --- a/src/datatrove/io.py +++ b/src/datatrove/io.py @@ -1,14 +1,16 @@ import os.path from glob import has_magic -from typing import IO, TypeAlias +from typing import IO, Callable, TypeAlias from fsspec import AbstractFileSystem from fsspec import open as fsspec_open from fsspec.callbacks import NoOpCallback, TqdmCallback -from fsspec.core import get_fs_token_paths, url_to_fs +from fsspec.core import get_fs_token_paths, strip_protocol, url_to_fs from fsspec.implementations.dirfs import DirFileSystem from fsspec.implementations.local import LocalFileSystem -from huggingface_hub import HfFileSystem +from huggingface_hub import HfFileSystem, cached_assets_path + +from datatrove.utils.logging import logger class OutputFileManager: @@ -283,6 +285,11 @@ def open_file(file: IO | str, mode="rt", **kwargs): return file +def file_exists(path: str): + fs, a, fpath = get_fs_token_paths(path) + return fs.exists(fpath[0]) + + def download_file(remote_path: str, local_path: str, progress: bool = True): fs, _, paths = get_fs_token_paths(remote_path) fs.get_file( @@ -302,4 +309,55 @@ def download_file(remote_path: str, local_path: str, progress: bool = True): ) +def safely_create_file(file_to_lock: str, do_processing: Callable): + """ + Gets a lock to download/process and create some file(s). When processing is done a ".completed" file is created. + If this file already exists, we skip the processing. Otherwise, we try to acquire a lock and when we get it if the + completed file has not been created yet, we run the processing. + + Args: + file_to_lock: str: lock will be "lock_path.lock" and completed file "lock_path.completed" + do_processing: callback with the code to run to process/create the files + """ + from fasteners import InterProcessLock + + completed_file = f"{file_to_lock}.completed" + + # if the completed file exists, we exit straight away + if os.path.exists(completed_file): + return + + # file is either being downloaded or needs to be downloaded + with InterProcessLock(f"{file_to_lock}.lock"): + if not os.path.exists(completed_file): + do_processing() + open(completed_file, "a").close() + + +def cached_asset_path_or_download( + remote_path: str, progress: bool = True, namespace: str = "default", subfolder: str = "default", desc: str = "file" +): + """ + Download a file from a remote path to a local path. + This function is process-safe and will only download the file if it hasn't been downloaded already. + Args: + namespace: will group diff blocks. example: "filters" + subfolder: relative to the specific block calling this function. 
Example: "language_filter" + remote_path: str: The remote path to the file to download + progress: bool: Whether to show a progress bar (Default value = True) + desc: description of the file being downloaded + """ + + download_dir = cached_assets_path(library_name="datatrove", namespace=namespace, subfolder=subfolder) + local_path = os.path.join(download_dir, strip_protocol(remote_path).replace("/", "_")) + + def do_download_file(): + logger.info(f'⬇️ Downloading {desc} from "{remote_path}"...') + download_file(remote_path, local_path, progress) + logger.info(f'⬇️ Downloaded {desc} to "{local_path}".') + + safely_create_file(local_path, do_download_file) + return local_path + + DataFolderLike: TypeAlias = str | tuple[str, dict] | DataFolder diff --git a/src/datatrove/pipeline/base.py b/src/datatrove/pipeline/base.py index 5ea855b1..a013766b 100644 --- a/src/datatrove/pipeline/base.py +++ b/src/datatrove/pipeline/base.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from itertools import chain -from typing import NoReturn from datatrove.data import Document, DocumentsPipeline -from datatrove.utils._import_utils import _is_package_available +from datatrove.utils._import_utils import check_required_dependencies from datatrove.utils.stats import Stats @@ -29,14 +28,7 @@ def __new__(cls, *args, **kwargs): """ required_dependencies = chain.from_iterable(getattr(t, "_requires_dependencies", []) for t in cls.mro()) if required_dependencies: - missing_dependencies: dict[str, str] = {} - for dependency in required_dependencies: - dependency = dependency if isinstance(dependency, tuple) else (dependency, dependency) - package_name, pip_name = dependency - if not _is_package_available(package_name): - missing_dependencies[package_name] = pip_name - if missing_dependencies: - _raise_error_for_missing_dependencies(cls.__name__, missing_dependencies) + check_required_dependencies(cls.__name__, required_dependencies) return super().__new__(cls) def __init__(self): @@ -125,26 +117,3 @@ def __call__(self, data: DocumentsPipeline = None, rank: int = 0, world_size: in """ return self.run(data, rank, world_size) - - -def _raise_error_for_missing_dependencies(step_name: str, dependencies: dict[str, str]) -> NoReturn: - """Helper to raise an ImportError for missing dependencies and prompt the user to install said dependencies - - Args: - step_name: str - The name of the step - dependencies: dict[str, str] - The missing dependencies - - """ - dependencies = dict(sorted(dependencies.items())) - package_names = list(dependencies) - if len(dependencies) > 1: - package_names = ( - f"{','.join('`' + package_name + '`' for package_name in package_names[:-1])} and `{package_names[-1]}`" - ) - else: - package_names = f"`{package_names[0]}`" - raise ImportError( - f"Please install {package_names} to use {step_name} (`pip install {' '.join(list(dependencies.values()))}`)." - ) diff --git a/src/datatrove/pipeline/decont/__init__.py b/src/datatrove/pipeline/decont/__init__.py new file mode 100644 index 00000000..efea634c --- /dev/null +++ b/src/datatrove/pipeline/decont/__init__.py @@ -0,0 +1 @@ +from .n_grams import NGramsDecontConfig, NGramsDecontFilter, NGramsDecontIndexer diff --git a/src/datatrove/pipeline/decont/n_grams.py b/src/datatrove/pipeline/decont/n_grams.py new file mode 100644 index 00000000..9918e8c2 --- /dev/null +++ b/src/datatrove/pipeline/decont/n_grams.py @@ -0,0 +1,227 @@ +""" +Used for n-gram decontamination. 
+First build an index using the tasks we want to use to decontaminate our training dataset. +Then read your training data and apply the filter with the index loaded. +""" + +import os +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import Tuple + +import numpy as np + +from datatrove.data import Document, DocumentsPipeline +from datatrove.io import DataFolderLike, file_exists, get_datafolder, open_file +from datatrove.pipeline.base import PipelineStep +from datatrove.pipeline.filters.base_filter import BaseFilter +from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.binaryio import read_np_from_file +from datatrove.utils.hashing import HashConfig, create_hash_func +from datatrove.utils.logging import logger +from datatrove.utils.text import TextNormConfig, ngrams, simplify_text +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer + + +@dataclass +class NGramsDecontConfig: + """ + Example for n_grams=4 + query = ['A', 'B', 'C', 'D', 'E'] (the prompt/instruction) + label = ['F', 'G', 'H', 'I', 'J'] (the answer/gold) + Will find the following N-GRAMS in the training data: + 'F G H I' + 'G H I J' + + IF find_query_ngrams: + 'A B C D' + 'B C D E' + + IF find_overlap_ngrams: + 'C D E F' + 'D E F G' + 'E F G H' + """ + + n_grams: int = 12 + find_query_ngrams: bool = False # enable to also check for matches in n-grams containing only the input/prompt + find_overlap_ngrams: bool = True # will also find matches for n-grams containing BOTH input and query + norm_config: TextNormConfig = field(default_factory=TextNormConfig) + hash_config: HashConfig = field(default_factory=HashConfig) + + +class NGramsDecontIndexer(PipelineStep): + """ + Creates a decontamination index (basically a list of uint64 hashes from ngrams) for each reference task. + Ways to provide task data: + - as input documents from the previous pipeline step with "text=label/correct answer" + and metadata={"query": query/prompt/input, "task": task name} + - as a list of strings in the format "suite|task" from the lighteval metadata table: + https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl as `lighteval_tasks` + - a path to a text file containing one such list, with one "suite|task" per line as `lighteval_tasks` + you can also define your custom tasks with `custom_lighteval_tasks`. 
See explanation for `custom_tasks` here: + https://github.com/huggingface/lighteval/tree/main?tab=readme-ov-file#evaluate-a-model-on-extended-community-or-custom-tasks + + """ + + type = "🦠 - DECONT" + name = "💥 N-grams build index" + _requires_dependencies = ["lighteval"] + + def __init__( + self, + output_folder: DataFolderLike, + lighteval_tasks: str | list[str] | None = None, # list in the format suite|task or path to one such list + custom_lighteval_tasks: str | None = None, + config: NGramsDecontConfig = None, + language: str = Languages.english, + ): + super().__init__() + self.output_folder = get_datafolder(output_folder) + # parse list of tasks + if isinstance(lighteval_tasks, str): + if file_exists(lighteval_tasks): + with open_file(lighteval_tasks, "rt") as f: + self.lighteval_tasks = f.read().strip().splitlines() + else: + self.lighteval_tasks = [lighteval_tasks] + else: + self.lighteval_tasks = lighteval_tasks + self.custom_lighteval_tasks = custom_lighteval_tasks + self.config = config or NGramsDecontConfig() + self.tokenizer = load_word_tokenizer(language) + self.hash_func = create_hash_func(self.config.hash_config) + + def compute_hashes(self, label: str, query: str | None = None) -> list[int]: + label_tokens = self.tokenizer.word_tokenize(simplify_text(label, self.config.norm_config)) + ngrams_to_compute = list(ngrams(label_tokens, self.config.n_grams)) + if query is not None: + query_tokens = self.tokenizer.word_tokenize(simplify_text(query, self.config.norm_config)) + if self.config.find_query_ngrams: + ngrams_to_compute.extend(ngrams(query_tokens, self.config.n_grams)) + if self.config.find_overlap_ngrams: + # add tokens overlapping query and label + """ + A, B, C, D, E | F, G, H, I, J + 5 grams + B, C, D, E, F (-N + 1 + i:) + (:i + 1) + ... 
+ E, F, G, H, I + """ + ngrams_to_compute.extend( + [ + query_tokens[-self.config.n_grams + 1 + i :] + label_tokens[: i + 1] + for i in range(self.config.n_grams - 1) + # make sure we actually get a list of size N + if len(query_tokens) >= self.config.n_grams - 1 - i and len(label_tokens) >= i + 1 + ] + ) + return list(map(self.hash_func, map(" ".join, ngrams_to_compute))) + + def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): + if world_size != 1: + raise ValueError("Decontamination index building requires a single worker.") + hashes = defaultdict(set) + # use whatever date is parsed in with the following format: + # doc.text -> label + # doc.metadata["input"] -> input + if data: + for doc in data: + if not self.config.find_query_ngrams and "query" not in doc.metadata: + raise ValueError( + "only_label_ngrams is False but could not find 'query' field in documents metadata" + ) + hashes[doc.metadata.get("task", "input")].update( + self.compute_hashes(doc.text, doc.metadata.get("query", None)) + ) + + # parse data from lighteval defined tasks + from lighteval.tasks.lighteval_task import LightevalTask + from lighteval.tasks.registry import Registry + + task_dict = Registry(cache_dir=os.getenv("HF_HOME")).get_task_dict( + self.lighteval_tasks, custom_tasks=self.custom_lighteval_tasks + ) + LightevalTask.load_datasets(task_dict.values()) + + for task_name, task in task_dict.items(): + for eval_doc in task.eval_docs(): + try: + golds = eval_doc.get_golds() + query = eval_doc.query + except Exception as e: + logger.warning(f"Error while fetching doc data: {e}") + continue + for gold in golds: + hashes[task_name].update(self.compute_hashes(gold, query)) + + for task_name, task_hashes in hashes.items(): + hashes_array = np.array(list(task_hashes), dtype=self.config.hash_config.np_descr) + logger.info(f"Saving {len(task_hashes)} hashes for {task_name}") + with self.output_folder.open(f"{task_name.replace(' ', '_')}.index.hashes", mode="wb") as f: + if self.output_folder.is_local(): + hashes_array.tofile(f) + else: + f.write(hashes_array.tobytes()) + + +class NGramsDecontFilter(BaseFilter): + """ + Loads list of hashes created by the Indexer step. + For each document in the block's input, we will check if any of its ngrams are part of the reference eval tasks. + If so, they will be removed. The contaminated ngram and task where it was found will be saved in the removed + document's metadata. 
+ """ + + type = "🦠 - DECONT" + name = "💥 N-grams decontaminate" + + def __init__( + self, + index_folder: DataFolderLike, + config: NGramsDecontConfig = None, + exclusion_writer: DiskWriter = None, + language: str = Languages.english, + ): + super().__init__() + self.index_folder = get_datafolder(index_folder) + self.config = config or NGramsDecontConfig() + self.exclusion_writer = exclusion_writer + self.language = language + self._index_hashes = None + self.hash_func = create_hash_func(self.config.hash_config) + self.tokenizer = load_word_tokenizer(language) + + def load_index_hashes(self): + def load_index_from_file(file): + with self.index_folder.open(file, mode="rb") as f: + return file, read_np_from_file( + f, np.dtype(self.config.hash_config.np_descr), self.index_folder.is_local() + ).tolist() + + with ThreadPoolExecutor() as pool: + hashes = pool.map(load_index_from_file, self.index_folder.list_files()) + + self._index_hashes = {} + for filename, hashlist in hashes: + taskname = filename.removesuffix(".index.hashes") + logger.info(f"Loading {len(hashlist)} hashes for {taskname}") + for hash in hashlist: + self._index_hashes[hash] = taskname + + def filter(self, doc: Document) -> bool | Tuple[bool, str]: + if self._index_hashes is None: + self.load_index_hashes() + + text_tokens = self.tokenizer.word_tokenize(simplify_text(doc.text, self.config.norm_config)) + ngrams_to_compute = list(ngrams(text_tokens, self.config.n_grams)) + for n_gram in map(" ".join, ngrams_to_compute): + task = self._index_hashes.get(self.hash_func(n_gram), None) + if task is not None: + doc.metadata["contaminated_ngram"] = n_gram + doc.metadata["contaminated_task"] = task + self.stat_update(f"contaminated_{task}") + if ":" in task: + self.stat_update(f"contaminated_tg_{task[:task.index(':')]}") + return False, "contaminated" + return True diff --git a/src/datatrove/pipeline/dedup/__init__.py b/src/datatrove/pipeline/dedup/__init__.py index de55ae40..1b997b87 100644 --- a/src/datatrove/pipeline/dedup/__init__.py +++ b/src/datatrove/pipeline/dedup/__init__.py @@ -2,9 +2,11 @@ from .exact_substrings import ESDatasetToSequence, ESMergeSequences, ESRangeRemover from .minhash import ( MinhashBuildIndex, + MinhashConfig, MinhashDedupBuckets, MinhashDedupCluster, MinhashDedupFilter, MinhashDedupSignature, ) -from .sentence_dedup import SentenceDedupFilter, SentenceDedupSignature, SentenceFindDedups +from .sentence_dedup import SentDedupConfig, SentenceDedupFilter, SentenceDedupSignature, SentenceFindDedups +from .url_dedup import UrlDedupConfig, UrlDedupFilter, UrlDedupSignature, UrlFindDedups diff --git a/src/datatrove/pipeline/dedup/bloom_filter.py b/src/datatrove/pipeline/dedup/bloom_filter.py index 086d82c4..5a1fbb87 100644 --- a/src/datatrove/pipeline/dedup/bloom_filter.py +++ b/src/datatrove/pipeline/dedup/bloom_filter.py @@ -1,15 +1,18 @@ import contextlib import math +from dataclasses import dataclass, field import numpy as np -from loguru import logger from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.writers.disk_base import DiskWriter -from datatrove.utils.text import DEF_TEXT_NORM_CONFIG, TextNormConfig, sha1_hash32, simplify_text -from datatrove.utils.typeshelper import StatHints +from datatrove.utils.hashing import HashConfig, create_hash_func +from datatrove.utils.logging import logger +from datatrove.utils.text import TextNormConfig, ngrams, simplify_text +from 
datatrove.utils.typeshelper import Languages, StatHints +from datatrove.utils.word_tokenizers import load_word_tokenizer # http://en.wikipedia.org/wiki/Mersenne_prime @@ -17,6 +20,37 @@ MAX_HASH = 1 << 32 - 1 +@dataclass +class BloomFilterConfig: + """ + m_bytes: bloom filter size in bytes (actual size x8 bigger) + k: number of hashes + expected_elements: expected number of elements, aka + shingles. + duplicate_threshold: above which documents are considered as + duplicated + n_grams: n_grams to use + seed: seed + """ + + m_bytes: int + k: int = None + expected_elements: int = None + duplicate_threshold: float = 0.8 + n_grams: int = 13 + seed: int = 0 + norm_config: TextNormConfig = field(default_factory=TextNormConfig) + hash_config: HashConfig = field(default_factory=lambda: HashConfig(precision=32)) + + @property + def m(self): # (self.m + 7) // 8 # size in bytes + return self.m_bytes * 8 + + def __post_init__(self): + if self.k is None: + self.k = get_optimal_k(self.m, expected_elements=self.expected_elements) + + def get_optimal_k(size_in_bytes: int, expected_elements: int) -> int: assert expected_elements, f"if {expected_elements=} then k must be given" m = size_in_bytes * 8 @@ -34,56 +68,39 @@ class SingleBloomFilter(PipelineStep): Args: output_folder: output folder: local or on S3 - m_bytes: bloom filter size in bytes (actual size x8 bigger) - k: number of hashes - expected_elements: expected number of elements, aka - shingles. - duplicate_threshold: above which documents are considered as - duplicated - n_grams: n_grams to use - seed: seed save_bloom_filter: if true saves bloom filter for later use exclusion_writer: saves duplicated data """ type = "🫂 - DEDUPS" name = "🪷 Bloom-filter" - _requires_dependencies = ["nltk"] def __init__( self, output_folder: DataFolderLike, - m_bytes: int, - k: int = None, - expected_elements: int = None, - duplicate_threshold: float = 0.8, - n_grams: int = 13, - seed: int = 0, - norm_config: TextNormConfig = DEF_TEXT_NORM_CONFIG, + config: BloomFilterConfig, save_bloom_filter: bool = False, exclusion_writer: DiskWriter = None, - language: str = "english", + language: str = Languages.english, ): super().__init__() self.output_folder = get_datafolder(output_folder) - self.m_bytes = m_bytes # size in bits - self.m = m_bytes * 8 # (self.m + 7) // 8 # size in bytes - self.k = k if k else get_optimal_k(self.m, expected_elements=expected_elements) - self.duplicate_threshold = duplicate_threshold - self.n_grams = n_grams - self.bit_vector = bytearray(([0] * self.m_bytes)) + self.tokenizer = load_word_tokenizer(language) + self.config = config + self.bit_vector = bytearray(([0] * self.config.m_bytes)) self.save_bloom_filter = save_bloom_filter self.exclusion_writer = exclusion_writer - self.norm_config = norm_config - assert self.m < MAX_HASH + # TODO: Add support for 64-bit + assert self.config.hash_config.precision == 32, "Bloom filter only supports 32-bit hashes" + self.hash_fc = create_hash_func(self.config.hash_config) + assert self.config.m < MAX_HASH - self.seed = seed self.total_shingles = 0 self._parameters = None - assert self.m_bytes < MAX_HASH, f"{MAX_HASH=} is smaller than {self.m_bytes=}" - if expected_elements: - fp = get_false_positive_prob(self.m_bytes, n=expected_elements, k=self.k) + assert self.config.m_bytes < MAX_HASH, f"{MAX_HASH=} is smaller than {self.config.m_bytes=}" + if self.config.expected_elements: + fp = get_false_positive_prob(self.config.m_bytes, n=self.config.expected_elements, k=self.config.k) if fp > 0.05: 
logger.warning(f"False probability = {fp:.3}") else: @@ -103,10 +120,10 @@ def parameters(self): random parameters for the hash functions. """ if not self._parameters: - gen = np.random.RandomState(self.seed) + gen = np.random.RandomState(self.config.seed) self._parameters = ( - gen.randint(1, _mersenne_prime, dtype=np.uint64, size=(1, self.k)), - gen.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, self.k)), + gen.randint(1, _mersenne_prime, dtype=np.uint64, size=(1, self.config.k)), + gen.randint(0, _mersenne_prime, dtype=np.uint64, size=(1, self.config.k)), ) return self._parameters @@ -114,12 +131,12 @@ def get_shingles(self, text: str) -> np.ndarray: """Get shingles from a string of text Shingles are created by hashing n-grams of simplified text (lower cases, whitespace normalized, no punctuation, etc). """ - from nltk import ngrams, word_tokenize - return np.fromiter( [ - sha1_hash32(" ".join(x).encode("utf-8")) - for x in ngrams(word_tokenize(simplify_text(text, self.norm_config)), self.n_grams) + self.hash_fc(" ".join(x)) + for x in ngrams( + self.tokenizer.word_tokenize(simplify_text(text, self.config.norm_config)), self.config.n_grams + ) ], dtype=np.uint64, ).reshape((-1, 1)) @@ -127,7 +144,7 @@ def get_shingles(self, text: str) -> np.ndarray: def get_indexes(self, shingles: np.ndarray) -> list[list[int]]: """Get indexes for the shingles with the k hashing functions""" a, b = self.parameters - phv = np.bitwise_and((shingles * a + b) % _mersenne_prime, self.m_bytes) + phv = np.bitwise_and((shingles * a + b) % _mersenne_prime, self.config.m_bytes) return phv.tolist() def update_bf(self, indexes: list[int]): @@ -144,7 +161,6 @@ def query(self, indexes: list[int]) -> bool: mask = 1 << bit_index if (self.bit_vector[byte_index] & mask) == 0: return False - return True def step(self, doc: Document) -> bool: @@ -166,7 +182,7 @@ def step(self, doc: Document) -> bool: indexes_to_update.extend(indexes) self.update_bf(indexes_to_update) - if duplicate_shingles / len(shingles) > self.duplicate_threshold: + if duplicate_shingles / len(shingles) > self.config.duplicate_threshold: self.stat_update(StatHints.dropped) return False return True @@ -188,5 +204,7 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): f.write(self.bit_vector) logger.info(f"{self.total_shingles=}") - logger.info(f"False probability = {get_false_positive_prob(self.m_bytes, n=self.total_shingles, k=self.k):.3}") - logger.info(f"Optimal K given total shingles = {get_optimal_k(self.m_bytes, self.total_shingles)}") + logger.info( + f"False probability = {get_false_positive_prob(self.config.m_bytes, n=self.total_shingles, k=self.config.k):.3}" + ) + logger.info(f"Optimal K given total shingles = {get_optimal_k(self.config.m_bytes, self.total_shingles)}") diff --git a/src/datatrove/pipeline/dedup/exact_substrings.py b/src/datatrove/pipeline/dedup/exact_substrings.py index e4016b80..6a0d2135 100644 --- a/src/datatrove/pipeline/dedup/exact_substrings.py +++ b/src/datatrove/pipeline/dedup/exact_substrings.py @@ -18,13 +18,15 @@ from typing import BinaryIO, Generator import numpy as np -from loguru import logger from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import DocumentsPipeline, PipelineStep +from datatrove.utils.logging import logger from ...utils.tokenization import PipelineStepWithTokenizer from ...utils.typeshelper import ExtensionHelperES as EH +from ...utils.typeshelper import Languages +from ...utils.word_tokenizers import load_word_tokenizer 
SEPARATOR_BYTES = 12 @@ -148,14 +150,13 @@ def sequence_reader(file: BinaryIO, size_file: BinaryIO) -> Generator[list, None class ESRangeRemover(PipelineStepWithTokenizer): type = "🫂 - DEDUP" name = "🪞 - exact-substrings stage 3" - _requires_dependencies = ["nltk"] def __init__( self, sequence_folder: DataFolderLike, tokenizer_name_or_path: str = "gpt2", min_doc_words: int = 50, - language: str = "english", + language: str = Languages.english, ): super().__init__() self.sequence_folder = get_datafolder(sequence_folder) @@ -168,6 +169,7 @@ def __init__( self.bytes_counter = 0 self.range_idx = 0 self.language = language + self.word_tokenizer = load_word_tokenizer(language) def reset(self): self.bytes_counter = 0 @@ -290,8 +292,6 @@ def get_duplicate_range(self, bytes_len: int): return ranges def remove_duplicate(self, doc, bytes_content): - from nltk import word_tokenize - n_bytes = len(bytes_content) duplicates_ranges = self.get_duplicate_range(n_bytes) duplicates = [] @@ -308,7 +308,7 @@ def remove_duplicate(self, doc, bytes_content): self.bytes_counter += len(bytes_content) - if len(word_tokenize(doc.text, self.language)) < self.min_doc_words: + if len(self.word_tokenizer.word_tokenize(doc.text)) < self.min_doc_words: return False return True diff --git a/src/datatrove/pipeline/dedup/minhash.py b/src/datatrove/pipeline/dedup/minhash.py index 33a298b3..66449d27 100644 --- a/src/datatrove/pipeline/dedup/minhash.py +++ b/src/datatrove/pipeline/dedup/minhash.py @@ -4,24 +4,26 @@ import re import struct from dataclasses import dataclass, field +from pathlib import Path from typing import Generator import numpy as np from fsspec.spec import AbstractBufferedFile -from loguru import logger from datatrove.data import DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.writers.disk_base import DiskWriter from datatrove.utils.binaryio import read_tuples_from_file, seek_to_start -from datatrove.utils.text import TextNormConfig, sha1_hash32, sha1_hash64, simplify_text -from datatrove.utils.typeshelper import StatHints +from datatrove.utils.hashing import HashConfig, create_hash_func +from datatrove.utils.logging import logger +from datatrove.utils.text import TextNormConfig, ngrams, simplify_text +from datatrove.utils.typeshelper import Languages, StatHints +from datatrove.utils.word_tokenizers import load_word_tokenizer # http://en.wikipedia.org/wiki/Mersenne_prime _mersenne_prime = np.uint64((1 << 61) - 1) -_max_hash_32b = np.uint64((1 << 32) - 1) """ n_grams -> roughly nr of words (this should be small enough to catch fuzzy matches but big enough to not have each shingle be too common) @@ -41,36 +43,19 @@ class MinhashConfig: n_grams: n-grams size to use num_buckets: number of buckets to use hashes_per_bucket: number of hashes per bucket - use_64bit_hashes: use 64bit hashes. Uses 32bit hashes if `False` seed: random seed used to generate the hash function parameters. 
Should be the same on all workers to ensure they all have the same parameters """ n_grams: int = 5 - num_buckets: int = 14 hashes_per_bucket: int = 8 - - use_64bit_hashes: bool = False seed: int = 1 norm_config: TextNormConfig = field(default_factory=TextNormConfig) - - @property - def hash_dtype(self): - return np.uint64 if self.use_64bit_hashes else np.uint32 - - @property - def hash_format(self): - return "Q" if self.use_64bit_hashes else "I" + hash_config: HashConfig = field(default_factory=HashConfig) def __str__(self): - return ( - f"{self.n_grams}ng_{self.num_buckets}bs_{self.hashes_per_bucket}hs_" - f"{'64' if self.use_64bit_hashes else '32'}b" - ) - - -DEFAULT_MINHASH_CONFIG = MinhashConfig() + return f"{self.n_grams}ng_{self.num_buckets}bs_{self.hashes_per_bucket}hs_{self.hash_config}" @dataclass(order=True) @@ -86,6 +71,7 @@ class HashSig: sig: tuple[int] file_id: int + file_stem: str doc_id: int reader_id: int @@ -111,12 +97,13 @@ def read_sigs( config: minhash configuration (a MinhashConfig object) index_file: is index file """ - line_format = f"{config.hashes_per_bucket}{config.hash_format}{'I' if not index_file else ''}" + line_format = f"{config.hashes_per_bucket}{config.hash_config.struct_format}{'I' if not index_file else ''}" with file as f: if f.size == 0: return - seek_to_start(f, min_hash, line_format, config.hash_format) + seek_to_start(f, min_hash, line_format, config.hash_config.struct_format) last = None + file_stem = Path(file.path).name.removesuffix(".minhash.sig") for data in read_tuples_from_file(f, line_format, lines_to_buffer=lines_to_buffer): sigdata = data if index_file else data[:-1] assert sigdata[0] >= min_hash and ( @@ -126,9 +113,9 @@ def read_sigs( break last = sigdata yield ( - HashSig(sig=sigdata, doc_id=-1, file_id=-1, reader_id=reader_id) + HashSig(sig=sigdata, doc_id=-1, file_id=-1, reader_id=reader_id, file_stem=file_stem) if index_file - else HashSig(sig=sigdata, doc_id=data[-1], file_id=reader_id, reader_id=reader_id) + else HashSig(sig=sigdata, doc_id=data[-1], file_id=reader_id, reader_id=reader_id, file_stem=file_stem) ) @@ -144,26 +131,26 @@ class MinhashDedupSignature(PipelineStep): type = "🫂 - DEDUP" name = "🎯 MinHash stage 1" - _requires_dependencies = ["nltk"] - def __init__( - self, output_folder: DataFolderLike, config: MinhashConfig = DEFAULT_MINHASH_CONFIG, language: str = "english" - ): + def __init__(self, output_folder: DataFolderLike, config: MinhashConfig = None, language: str = Languages.english): super().__init__() self.output_folder = get_datafolder(output_folder) - self.config = config + self.config = config or MinhashConfig() self.num_hashes = self.config.num_buckets * self.config.hashes_per_bucket self._parameters = None - self._hash_func = sha1_hash32 if not self.config.use_64bit_hashes else sha1_hash64 + self._hash_func = create_hash_func(self.config.hash_config) self.language = language + self.word_tokenizer = load_word_tokenizer(language) @property def parameters(self): """Minhash parameters Create parameters for a random bijective permutation function - that maps a 32-bit hash value to another 32-bit hash value. + that maps a 32/64-bit hash value to another 32/64-bit hash value. 
http://en.wikipedia.org/wiki/Universal_hashing + + Note: For 64-bit hashes the upper-bound for codomain is not [0,2**64) but [0,2**61 - 1) """ if not self._parameters: gen = np.random.RandomState(self.config.seed) @@ -184,10 +171,11 @@ def get_signature(self, shingles: np.ndarray) -> list[list[int]]: """ a, b = self.parameters phv = (shingles * a + b) % _mersenne_prime - if not self.config.use_64bit_hashes: - phv = np.bitwise_and(phv, _max_hash_32b) + if self.config.hash_config.precision == 32: + phv = np.bitwise_and(phv, self.config.hash_config.max) return [ - x.tolist() for x in np.split(np.min(phv, axis=0).astype(self.config.hash_dtype), self.config.num_buckets) + x.tolist() + for x in np.split(np.min(phv, axis=0).astype(self.config.hash_config.np_dtype), self.config.num_buckets) ] def get_shingles(self, text: str) -> np.ndarray: @@ -201,12 +189,13 @@ def get_shingles(self, text: str) -> np.ndarray: Returns: numpy array of shingles: dtype = uint64, shape = (number of n_grams in string, 1) """ - from nltk import ngrams, word_tokenize - return np.fromiter( [ - self._hash_func(" ".join(x).encode("utf-8")) - for x in ngrams(word_tokenize(simplify_text(text, self.config.norm_config)), self.config.n_grams) + self._hash_func(" ".join(x)) + for x in ngrams( + self.word_tokenizer.word_tokenize(simplify_text(text, self.config.norm_config)), + self.config.n_grams, + ) ], dtype=np.uint64, ).reshape((-1, 1)) @@ -226,7 +215,9 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): # print(f"{self.hashes_per_bucket=} {bucket_sig=}") bucket.write( struct.pack( - f"<{self.config.hashes_per_bucket}{self.config.hash_format}I", *bucket_sig, doc_idx + f"<{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I", + *bucket_sig, + doc_idx, ) ) # TODO: prevent these files from being uploaded/redownloaded in the first place @@ -249,7 +240,9 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): for sig in sigs: fo.write( struct.pack( - f"<{self.config.hashes_per_bucket}{self.config.hash_format}I", *sig.sig, sig.doc_id + f"<{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I", + *sig.sig, + sig.doc_id, ) ) @@ -276,7 +269,7 @@ def __init__( input_folder: DataFolderLike, output_folder: DataFolderLike, index_folder: DataFolderLike = None, - config: MinhashConfig = DEFAULT_MINHASH_CONFIG, + config: MinhashConfig = None, only_dedup_in_index: bool = True, create_index_name: str = None, lines_to_buffer: int = 5, @@ -285,7 +278,7 @@ def __init__( self.input_folder = get_datafolder(input_folder) self.output_folder = get_datafolder(output_folder) self.index_folder = get_datafolder(index_folder) if index_folder else None - self.config = config + self.config = config or MinhashConfig() self.only_dedup_in_index = only_dedup_in_index self.create_index_name = create_index_name self.lines_to_buffer = lines_to_buffer @@ -293,12 +286,15 @@ def __init__( def get_worker_hash_range(self, sig_files, rank, world_size): workers_per_bucket = world_size // self.config.num_buckets bucket, bucket_worker = divmod(rank, workers_per_bucket) - hash_min, hash_max = 0, _mersenne_prime if self.config.use_64bit_hashes else _max_hash_32b + hash_min, hash_max = ( + 0, + _mersenne_prime if self.config.hash_config.precision == 64 else self.config.hash_config.max, + ) if workers_per_bucket > 1 and len(sig_files): # take the first file and find bucket_worker boundaries. 
all workers in a bucket process the same set of # files, so this should be consistent across workers (and span the entire range of hashes) with self.input_folder.open(sig_files[0], mode="rb") as f: - line_size = struct.calcsize(f"{self.config.hashes_per_bucket}{self.config.hash_format}I") + line_size = struct.calcsize(f"{self.config.hashes_per_bucket}{self.config.hash_config.struct_format}I") L, rem = divmod(f.size, line_size) assert rem == 0, "file size not divisible by line size" assert L >= workers_per_bucket, f"tried to use {workers_per_bucket=} but there are only {L} lines" @@ -306,13 +302,15 @@ def get_worker_hash_range(self, sig_files, rank, world_size): # not first f.seek(line_size * (L // workers_per_bucket) * bucket_worker, os.SEEK_SET) hash_min = struct.unpack( - self.config.hash_format, f.read(struct.calcsize(self.config.hash_format)) + self.config.hash_config.struct_format, + f.read(struct.calcsize(self.config.hash_config.struct_format)), )[0] if bucket_worker + 1 < workers_per_bucket: # not last f.seek(line_size * (L // workers_per_bucket) * (bucket_worker + 1), os.SEEK_SET) hash_max = struct.unpack( - self.config.hash_format, f.read(struct.calcsize(self.config.hash_format)) + self.config.hash_config.struct_format, + f.read(struct.calcsize(self.config.hash_config.struct_format)), )[0] return hash_min, hash_max @@ -392,16 +390,21 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 # write (file_id1, doc_id1, file_id2, doc_id2) if last.is_from_index(): # we can't actually write -1, so we use SENTINEL instead - out_f.write(struct.pack("<4I", SENTINEL, SENTINEL, v.file_id, v.doc_id)) + out_f.write(struct.pack("<4I", SENTINEL, SENTINEL, int(v.file_stem), v.doc_id)) self.stat_update("index_match", "total_matches") # if there isn't an index, or we are not only deduping in relation to the index elif not index_files or not self.only_dedup_in_index: - out_f.write(struct.pack("<4I", last.file_id, last.doc_id, v.file_id, v.doc_id)) + out_f.write( + struct.pack("<4I", int(last.file_stem), last.doc_id, int(v.file_stem), v.doc_id) + ) self.stat_update("total_matches") elif out_index: # new sig that isn't part of any index, save to our new index out_index.write( - struct.pack(f"<%d{self.config.hash_format}" % self.config.hashes_per_bucket, *v.sig) + struct.pack( + f"<%d{self.config.hash_config.struct_format}" % self.config.hashes_per_bucket, + *v.sig, + ) ) last = v next_sig = next(sig_readers[v.reader_id], None) @@ -425,7 +428,7 @@ def __init__( self, input_folder: DataFolderLike, output_folder: DataFolderLike, - config: MinhashConfig = DEFAULT_MINHASH_CONFIG, + config: MinhashConfig = None, save_cluster_id: bool = False, ignore_index_matches: bool = False, lines_to_buffer: int = 5, @@ -433,7 +436,7 @@ def __init__( super().__init__() self.input_folder = get_datafolder(input_folder) self.output_folder = get_datafolder(output_folder) - self.config = config + self.config = config or MinhashConfig() self.save_cluster_id = save_cluster_id self.ignore_index_matches = ignore_index_matches self.lines_to_buffer = lines_to_buffer @@ -449,6 +452,7 @@ def run(self, data: DocumentsPipeline = None, _: int = 0, world_size: int = 1): def parent(x): if x not in union_set or union_set[x] == x: return x + # Path Compression union_set[x] = parent(union_set[x]) return union_set[x] @@ -566,13 +570,13 @@ def __init__( input_folder: DataFolderLike, output_folder: DataFolderLike, index_name: str, - config: MinhashConfig = DEFAULT_MINHASH_CONFIG, + config: MinhashConfig = None, 
lines_to_buffer: int = 5, ): super().__init__() self.input_folder = input_folder self.output_folder = output_folder - self.config = config + self.config = config or MinhashConfig() self.index_name = index_name self.lines_to_buffer = lines_to_buffer @@ -596,7 +600,11 @@ def run(self, data: DocumentsPipeline = None, bucket: int = 0, world_size: int = while pq: v: HashSig = heapq.heappop(pq) if not last or last.sig != v.sig: - out_f.write(struct.pack(f"<%d{self.config.hash_format}" % self.config.hashes_per_bucket, *v.sig)) + out_f.write( + struct.pack( + f"<%d{self.config.hash_config.struct_format}" % self.config.hashes_per_bucket, *v.sig + ) + ) last = v next_sig = next(sig_readers[v.file_id], None) if next_sig: diff --git a/src/datatrove/pipeline/dedup/sentence_dedup.py b/src/datatrove/pipeline/dedup/sentence_dedup.py index f863d144..2e096469 100644 --- a/src/datatrove/pipeline/dedup/sentence_dedup.py +++ b/src/datatrove/pipeline/dedup/sentence_dedup.py @@ -13,20 +13,23 @@ import struct from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field +from pathlib import Path from typing import BinaryIO, Generator import numpy as np from fsspec.spec import AbstractBufferedFile -from loguru import logger from tqdm import tqdm from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep -from datatrove.utils.binaryio import read_tuples_from_file -from datatrove.utils.text import TextNormConfig, sha1_hash64, simplify_text -from datatrove.utils.typeshelper import ExtensionHelperSD, StatHints +from datatrove.utils.binaryio import read_np_from_file, read_tuples_from_file +from datatrove.utils.hashing import HashConfig, create_hash_func +from datatrove.utils.logging import logger +from datatrove.utils.text import SPLIT_TEXT_SENTENCES, TextNormConfig, ngrams, simplify_text, split_into_parts +from datatrove.utils.typeshelper import ExtensionHelperSD, Languages, StatHints +from ...utils.word_tokenizers import load_word_tokenizer from ..writers.disk_base import DiskWriter @@ -36,11 +39,10 @@ class SentDedupConfig: split_sentences: bool = True # set to False to split on \n instead only_dedup_in_index: bool = True min_doc_words: int = 50 + min_num_sentences: int = 3 # remove docs that end up with fewer than 3 sentences min_words_to_remove_span: int = 0 norm_config: TextNormConfig = field(default_factory=TextNormConfig) - - -DEFAULT_SENT_DEDUP_CONFIG = SentDedupConfig() + hash_config: HashConfig = field(default_factory=HashConfig) @dataclass(order=True) @@ -52,6 +54,7 @@ class HashSig: doc_id: int file_id: int = None sent_id: int = None + file_stem: str = None def is_from_index(self): return self.doc_id == self.sent_id == -1 @@ -69,14 +72,13 @@ class SentenceDedupSignature(PipelineStep): type = "🫂 - DEDUPS" name = "💥 sentence-deduplication stage 1" - _requires_dependencies = ["nltk"] def __init__( self, output_folder: DataFolderLike, finder_workers: int = 1, - config: SentDedupConfig = DEFAULT_SENT_DEDUP_CONFIG, - language: str = "english", + config: SentDedupConfig = None, + language: str = Languages.english, ): super().__init__() self.output_folder = get_datafolder(output_folder) @@ -85,15 +87,19 @@ def __init__( elif finder_workers > 1: logger.warning(f"Remember to also set the name of tasks of the finder block to {finder_workers=}!") self.finder_workers = finder_workers - self.config = config + self.config = config or SentDedupConfig() + self.hash_fc = 
create_hash_func(self.config.hash_config) self.language = language + self.tokenizer = load_word_tokenizer(language) def save_hashes(self, rank: int, signatures): - # explicitly define little endiannes - signatures = np.array(signatures, dtype=[("hash", " list[None] | list[tuple[int, int, int]]: - from nltk import ngrams - from nltk.tokenize import sent_tokenize - - sentences = sent_tokenize(doc.text, self.language) if self.config.split_sentences else doc.text.splitlines() + sentences = self.tokenizer.sent_tokenize(doc.text) if self.config.split_sentences else doc.text.splitlines() if len(sentences) < self.config.n_sentences: return [] sentences_tokens = [simplify_text(sent, self.config.norm_config) for sent in sentences] n_sent_grams: list = [" ".join(x) for x in ngrams(sentences_tokens, self.config.n_sentences)] hashes = [ - (sha1_hash64(n_sent_gram.encode("utf-8")), doc_idx, sentence_idx) + (self.hash_fc(n_sent_gram), doc_idx, sentence_idx) for sentence_idx, n_sent_gram in enumerate(n_sent_grams) if n_sent_gram.strip() != "" # we actually do not want to remove all the \n everywhere ] @@ -156,18 +161,23 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): def read_sigs( - file: AbstractBufferedFile, file_id: int, index_file: bool = False, lines_to_buffer: int = 5 + file: AbstractBufferedFile, + file_id: int, + config: SentDedupConfig, + index_file: bool = False, + lines_to_buffer: int = 5, ) -> Generator[HashSig, None, None]: - line_format = "QIH" if not index_file else "Q" + line_format = f"{config.hash_config.struct_format}IH" if not index_file else config.hash_config.struct_format + file_stem = Path(file.path).name.removesuffix(ExtensionHelperSD.stage_1_signature) last = None with file as f: for data in read_tuples_from_file(f, line_format, lines_to_buffer=lines_to_buffer): assert last is None or data[0] >= last, f"Hash order error. 
{f.tell()=}, {data[0]=}, {last=}" last = data[0] yield ( - HashSig(hash_value=data[0], doc_id=-1, file_id=file_id, sent_id=-1) + HashSig(hash_value=data[0], doc_id=-1, file_id=file_id, sent_id=-1, file_stem=file_stem) if index_file - else HashSig(file_id=file_id, hash_value=data[0], doc_id=data[1], sent_id=data[2]) + else HashSig(file_id=file_id, hash_value=data[0], doc_id=data[1], sent_id=data[2], file_stem=file_stem) ) @@ -192,14 +202,14 @@ def __init__( data_folder: DataFolderLike, output_folder: DataFolderLike, index_folder: DataFolderLike = None, - config: SentDedupConfig = DEFAULT_SENT_DEDUP_CONFIG, + config: SentDedupConfig = None, lines_to_buffer: int = 5, ): super().__init__() self.data_folder = get_datafolder(data_folder) self.output_folder = get_datafolder(output_folder) self.index_folder = get_datafolder(index_folder) if index_folder else None - self.config = config + self.config = config or SentDedupConfig() self.lines_to_buffer = lines_to_buffer def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): @@ -216,7 +226,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 subdirectory=f"{rank:04d}", glob_pattern=ExtensionHelperSD.stage_1_signature ) sig_readers = [ - read_sigs(file, file_i, lines_to_buffer=self.lines_to_buffer) + read_sigs(file, file_i, config=self.config, lines_to_buffer=self.lines_to_buffer) for file_i, file in enumerate(self.data_folder.open_files(sig_files)) ] index_files = self.index_folder.list_files() if self.index_folder else None @@ -225,7 +235,11 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 sig_readers.extend( [ read_sigs( - file, len(sig_readers) + file_i, index_file=True, lines_to_buffer=self.lines_to_buffer + file, + len(sig_readers) + file_i, + config=self.config, + index_file=True, + lines_to_buffer=self.lines_to_buffer, ) for file_i, file in enumerate(self.data_folder.open_files(index_files)) ] @@ -255,7 +269,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 if ( last and last.hash_value == v.hash_value and not v.is_from_index() ): # we never want to match samples from the index itself - out_filename = f"{rank:04d}/{v.file_id:05d}{ExtensionHelperSD.stage_2_duplicates}" + out_filename = f"{rank:04d}/{v.file_stem}{ExtensionHelperSD.stage_2_duplicates}" # the previous one we are matching against is part of the index # OR there are no index files # OR we are also matching within the main dataset @@ -288,34 +302,26 @@ class SentenceDedupFilter(PipelineStep): def __init__( self, data_folder: DataFolderLike, - config: SentDedupConfig = DEFAULT_SENT_DEDUP_CONFIG, + config: SentDedupConfig = None, exclusion_writer: DiskWriter = None, - language: str = "english", + language: str = Languages.english, ): - from nltk import load - super().__init__() self.data_folder = get_datafolder(data_folder) - self.config = config - self._tokenizer = load(f"tokenizers/punkt/{language}.pickle") + self.config = config or SentDedupConfig() + self.tokenizer = load_word_tokenizer(language) self.exclusion_writer = exclusion_writer self.language = language def read_duplicates(self, file: BinaryIO) -> np.ndarray: """Helper function to read duplicates from a binary file storing (doc_id, sent_id) pairs as created by the second stage.""" - with file as f: - if self.data_folder.is_local(): - return np.fromfile( - f, dtype=[("doc", " tuple[str, str]: - from nltk.tokenize import word_tokenize - sentence_spans = ( - 
list(self._tokenizer.span_tokenize(doc.text)) if self.config.split_sentences else doc.text.splitlines() + list(self.tokenizer.span_tokenize(doc.text)) if self.config.split_sentences else doc.text.splitlines() ) kept_sentences = [] original_formatted = [] @@ -336,10 +342,10 @@ def remove_dup_sentences(self, doc: Document, du_lines: np.ndarray) -> tuple[str # if outside the range, we keep this line/sent if idx >= drop_until: if removed_span: - original_formatted.append("<<<\u001b[0m") + original_formatted.append("<<<") if ( self.config.min_words_to_remove_span > 0 - and len(word_tokenize("\n".join(removed_span), self.language)) + and len(self.tokenizer.word_tokenize("\n".join(removed_span))) < self.config.min_words_to_remove_span ): kept_sentences.extend(removed_span) @@ -347,15 +353,15 @@ def remove_dup_sentences(self, doc: Document, du_lines: np.ndarray) -> tuple[str kept_sentences.append(line_text) elif not removed_span: removed_span.append(line_text) - original_formatted.append("\033[91m>>>") + original_formatted.append(">>>") original_formatted.append(line_text) if self.config.split_sentences: last_s = s[1] # use this to include whitespace that is not included in the sentence spans if removed_span: - original_formatted.append("<<<\u001b[0m") + original_formatted.append("<<<") if ( self.config.min_words_to_remove_span > 0 - and len(word_tokenize("\n".join(removed_span), self.language)) < self.config.min_words_to_remove_span + and len(self.tokenizer.word_tokenize("\n".join(removed_span))) < self.config.min_words_to_remove_span ): kept_sentences.extend(removed_span) if len(kept_sentences) < len(sentence_spans): @@ -370,8 +376,6 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do SentenceDedupFilter reads a DocumentPipeline and removes duplicated sentences found at stage 2 """ - from nltk.tokenize import word_tokenize - folders = self.data_folder.list_files(include_directories=True, recursive=False) # for performance reasons when having for instance 12k*10k files files = [ @@ -412,8 +416,23 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do dups_doc_i += 1 if ( - filtered_text == doc.text - or len(word_tokenize(filtered_text, self.language)) > self.config.min_doc_words + ( + filtered_text == doc.text # no change + or ( + ( + # min doc words + self.config.min_doc_words <= 0 + or len(self.tokenizer.word_tokenize(filtered_text)) >= self.config.min_doc_words + ) + and ( + # min num sentences + self.config.min_num_sentences <= 0 + or len(split_into_parts(filtered_text, SPLIT_TEXT_SENTENCES, self.language)) + >= self.config.min_num_sentences + ) + ) + ) + and filtered_text # can not be completely empty ): # document is kept self.update_doc_stats(doc) if not filtered_text == doc.text and writer: @@ -438,20 +457,26 @@ class SentenceDedupBuildIndex(PipelineStep): name = "💥 sentence-deduplication build index" def __init__( - self, data_folder: DataFolderLike, output_folder: DataFolderLike, index_name: str, lines_to_buffer: int = 5 + self, + data_folder: DataFolderLike, + output_folder: DataFolderLike, + index_name: str, + config: SentDedupConfig = None, + lines_to_buffer: int = 5, ): super().__init__() self.data_folder = get_datafolder(data_folder) self.output_folder = get_datafolder(output_folder) self.index_name = index_name self.lines_to_buffer = lines_to_buffer + self.config = config or SentDedupConfig() def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): assert world_size == 1, 
"SentenceDedupBuildIndex can only run on a single worker." with self.stats.time_stats: sig_files = self.data_folder.list_files(glob_pattern=ExtensionHelperSD.stage_1_signature) sig_readers = [ - read_sigs(file, file_i, lines_to_buffer=self.lines_to_buffer) + read_sigs(file, file_i, self.config, lines_to_buffer=self.lines_to_buffer) for file_i, file in enumerate(self.data_folder.open_files(sig_files)) ] @@ -463,7 +488,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 while pq: v: HashSig = heapq.heappop(pq) if last != v.hash_value: - out_f.write(struct.pack(" bool: + # Ensure that highest priority is always first of the hashes + return (self.hash_value, -self.priority, self.doc_id) < ( + other.hash_value, + -other.priority, + other.doc_id, + ) + + +def get_sig_dtype(config: HashConfig) -> np.dtype: + return np.dtype([("hash", config.np_dtype), ("priority", "= 1") + elif finder_workers > 1: + logger.warning(f"Remember to also set the number of tasks of the finder block to {finder_workers=}!") + self.finder_workers = finder_workers + self.config = config or UrlDedupConfig() + self.hash_fc = create_hash_func(self.config.hash_config) + + def save_hashes(self, rank: int, signatures): + sig_dtype = get_sig_dtype(self.config.hash_config) + priority_max = np.iinfo(sig_dtype["priority"]).max + + # 0 will stay as is, so we can't use 0 as a priority + assert all( + sig[1] >= 1 and sig[1] <= priority_max for sig in signatures + ), f"priority must be between 1 and {priority_max}" + signatures = np.array(signatures, dtype=sig_dtype) + + # Ensure that the highest priority is always first + signatures["priority"] = -signatures["priority"] + signatures.sort(axis=0) + signatures["priority"] = -signatures["priority"] + + # Same code as in sentence_dedup + hashes_per_worker = self.config.hash_config.max // self.finder_workers + left_idx = 0 + for hash_i in range(self.finder_workers): + with self.output_folder.open( + f"{hash_i:04d}/{rank:05d}{ExtensionHelperSD.stage_1_signature}", + mode="wb", + ) as f: + # last bucket needs to have everything + right_hash = ( + (hash_i + 1) * hashes_per_worker if hash_i != self.finder_workers - 1 else np.iinfo(np.uint64).max + ) + # find last hash that goes in this bucket. 
This obeys the following rule: + # signatures['hash'][right_idx - 1] <= right_hash <= signatures['hash'][right_idx] + right_idx = left_idx + signatures["hash"][left_idx:].searchsorted(right_hash, side="right") + # save to file + if right_idx > left_idx: + bts = signatures[left_idx:right_idx].tobytes() + f.write(bts) + left_idx = right_idx + # we've reached the end of our data + if right_idx >= len(signatures): + break + + def get_hashes(self, doc: Document, doc_idx: int) -> list[None] | list[tuple[int, int, int]]: + normalized_url: str = ( + self.config.url_normalizer(doc.metadata["url"]) if self.config.url_normalizer else doc.metadata["url"] + ) + priority = self.config.document_priority(doc) if self.config.document_priority else 1 + hashes = [(self.hash_fc(normalized_url), priority, doc_idx)] + + return hashes + + def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): + signatures = [] + for doc_idx, doc in enumerate(data): + with self.stats.time_stats: + self.stat_update(StatHints.total) + signatures.extend(self.get_hashes(doc, doc_idx)) + self.save_hashes(rank, signatures) + + +def read_sigs( + file: AbstractBufferedFile, + file_id: int, + hash_config: HashConfig, + index_file: bool = False, + lines_to_buffer: int = 5, +) -> Generator[HashSig, None, None]: + last = None + line_format = f"{hash_config.struct_format}HI" if not index_file else hash_config.struct_format + with file as f: + file_stem = Path(f.path).name.removesuffix(ExtensionHelperSD.stage_1_signature) + for data in read_tuples_from_file(f, line_format, lines_to_buffer=lines_to_buffer): + assert last is None or data[0] >= last, f"Hash order error. {f.tell()=}, {data[0]=}, {last=}" + last = data[0] + yield ( + HashSig(hash_value=data[0], doc_id=-1, file_id=file_id, priority=-1, file_stem=file_stem) + if index_file + else HashSig( + file_id=file_id, + file_stem=file_stem, + hash_value=data[0], + priority=data[1], + doc_id=data[2], + ) + ) + + +class UrlFindDedups(PipelineStep): + """UrlDedup: Second pipeline step + UrlFindDedups reads all the signatures from the previous step and loads them + in a priority queue to check for duplicates. If a duplicate is found its document id is saved. + The document with the highest priority is the one that will be saved out of the duplicates . 
+ + Args: + data_folder: data folder where signatures are saved + output_folder: folder where duplicates are saved + index_folder: folder where index files are saved + config: configuration for the dedup + lines_to_buffer: number of lines to buffer (speed up reading) + """ + + type = "🫂 - DEDUPS" + name = "💥 url-deduplication stage 2" + + def __init__( + self, + data_folder: DataFolderLike, + output_folder: DataFolderLike, + index_folder: DataFolderLike | None = None, + config: UrlDedupConfig | None = None, + lines_to_buffer: int = 5, + ): + super().__init__() + self.data_folder = get_datafolder(data_folder) + self.output_folder = get_datafolder(output_folder) + self.index_folder = get_datafolder(index_folder) if index_folder else None + + self.config = config or UrlDedupConfig() + self.lines_to_buffer = lines_to_buffer + + def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): + with self.stats.time_stats: + if world_size == 1: + # check that there was not a mistake in setting this values + sig_files = self.data_folder.list_files(glob_pattern="*/*" + ExtensionHelperSD.stage_1_signature) + if any(not sig_file.startswith("0000/") for sig_file in sig_files): + raise ValueError( + f"{world_size=} but found sig files for different hash buckets. Set tasks=finder_workers" + ) + else: + sig_files = self.data_folder.list_files( + subdirectory=f"{rank:04d}", + glob_pattern=ExtensionHelperSD.stage_1_signature, + ) + sig_readers = [ + read_sigs( + file, + file_i, + self.config.hash_config, + lines_to_buffer=self.lines_to_buffer, + ) + for file_i, file in enumerate(self.data_folder.open_files(sig_files)) + ] + index_files = self.index_folder.list_files() if self.index_folder else None + if index_files: + logger.info(f"Found index file(s): {', '.join(index_files)}") + sig_readers.extend( + [ + read_sigs( + file, + len(sig_readers) + file_i, + self.config.hash_config, + index_file=True, + lines_to_buffer=self.lines_to_buffer, + ) + for file_i, file in enumerate(self.data_folder.open_files(index_files)) + ] + ) + + logger.info(f"Initializing pq with {len(sig_readers)} files.") + with ThreadPoolExecutor() as executor: + pq = [ + x + for x in tqdm( + executor.map(lambda x: next(x, None), sig_readers), + total=len(sig_readers), + desc="Initializing pq...", + ) + if x + ] + heapq.heapify(pq) + logger.info("PQ initialized.") + + output_mg = self.output_folder.get_output_file_manager(mode="wb") + last: HashSig | None = None + packer = struct.Struct(" np.ndarray: + """Helper function to read duplicates from a binary file storing (doc_id) as created by the second stage.""" + with file as f: + return read_np_from_file(f, dtype=dup_dtype, is_local_file=self.data_folder.is_local()) + + def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1): + folders = self.data_folder.list_files(include_directories=True, recursive=False) + # for performance reasons when having for instance 12k*10k files + files = [ + f + for f in [f"{folder}/{rank:05d}{ExtensionHelperSD.stage_2_duplicates}" for folder in folders] + if self.data_folder.exists(f) + ] + + logger.info(f"Loading duplicate indexes from {len(files)} results files.") + + dup_dtype = get_sig_dtype(self.config.hash_config)[2] + all_dups = np.array([], dtype=dup_dtype) + if files: + with ThreadPoolExecutor() as pool: + read_partial = partial(self.read_duplicates, dup_dtype=dup_dtype) + all_dups = np.concatenate( + list( + tqdm( + pool.map(read_partial, self.data_folder.open_files(files)), + total=len(files), + ) + ), + axis=0, 
+ ) + all_dups.sort() + + logger.info("Loaded duplicate indexes.") + dups_doc_i = 0 + with self.exclusion_writer if self.exclusion_writer else contextlib.nullcontext() as writer: + with self.stats.time_stats: + for doc_idx, doc in enumerate(data): + self.stat_update(StatHints.total) + with self.stats.time_stats: + if dups_doc_i < all_dups.shape[0] and all_dups[dups_doc_i] == doc_idx: + if writer: + writer.write(doc, rank=rank) + self.stat_update(StatHints.dropped) + dups_doc_i += 1 + else: + self.stat_update(StatHints.forwarded) + self.update_doc_stats(doc) + yield doc + + +class UrlDedupBuildIndex(PipelineStep): + """UrlDedup: Only build an index + Works exactly the same as SentenceDedupBuildIndex + + Args: + data_folder: data folder to get signature files. + output_folder: folder where index is saved + index_name: name of the index + """ + + type = "🫂 - DEDUP" + name = "💥 url-deduplication build index" + + def __init__( + self, + data_folder: DataFolderLike, + output_folder: DataFolderLike, + index_name: str, + config: UrlDedupConfig | None = None, + lines_to_buffer: int = 5, + ): + super().__init__() + self.data_folder = get_datafolder(data_folder) + self.output_folder = get_datafolder(output_folder) + self.index_name = index_name + self.lines_to_buffer = lines_to_buffer + self.config = config or UrlDedupConfig() + + def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1): + assert world_size == 1, "UrlDedupBuildIndex can only run on a single worker." + with self.stats.time_stats: + sig_files = self.data_folder.list_files(glob_pattern=ExtensionHelperSD.stage_1_signature) + sig_readers = [ + read_sigs(file, file_i, self.config.hash_config, lines_to_buffer=self.lines_to_buffer) + for file_i, file in enumerate(self.data_folder.open_files(sig_files)) + ] + + pq = [next(sig_reader) for sig_reader in sig_readers] + heapq.heapify(pq) + + with self.output_folder.open(f"{self.index_name}.{ExtensionHelperSD.index}", mode="wb") as out_f: + last = None + while pq: + v: HashSig = heapq.heappop(pq) + if last != v.hash_value: + out_f.write(struct.pack(f"<{self.config.hash_config.struct_format}", v.hash_value)) + last = v.hash_value + new_v = next(sig_readers[v.file_id], None) + + if new_v: + heapq.heappush(pq, new_v) diff --git a/src/datatrove/pipeline/extractors/base.py b/src/datatrove/pipeline/extractors/base.py index 62b12d5b..d3622d79 100644 --- a/src/datatrove/pipeline/extractors/base.py +++ b/src/datatrove/pipeline/extractors/base.py @@ -1,10 +1,9 @@ -import signal from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor -from loguru import logger - -from datatrove.data import Document, DocumentsPipeline +from datatrove.data import DocumentsPipeline from datatrove.pipeline.base import PipelineStep +from datatrove.utils.logging import logger from datatrove.utils.typeshelper import StatHints @@ -35,34 +34,6 @@ def extract(self, text: str) -> str: """ pass - def timeout_extract(self, doc: Document): - """Stops the extraction if it takes longer than timeout. - This is the main entrypoint for this class. - - Args: - doc: Document: - - Returns: - - """ - - def signal_handler(_signum, _frame): - raise TimeoutError - - signal.signal(signal.SIGALRM, signal_handler) - signal.setitimer(signal.ITIMER_REAL, self.timeout) - try: - return self.extract(doc.text) - - except TimeoutError: - logger.warning("⏰ Timeout while cleaning record text. Skipping record.") - - except Exception as e: - logger.warning(f'❌ Error "{e}" while cleaning record text. 
Skipping record.') - - finally: - signal.setitimer(signal.ITIMER_REAL, 0) - def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: """Iterates through each document in data and calls `timeout_extract` on it. @@ -74,13 +45,22 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do Returns: """ - for doc in data: - self.stat_update(StatHints.total) - with self.track_time(): - doc.text = self.timeout_extract(doc) - if doc.text: - self.stat_update(StatHints.forwarded) - self.update_doc_stats(doc) - yield doc - else: - self.stat_update(StatHints.dropped) + with ThreadPoolExecutor() as executor: # more reliable than using signal for timeouts + for doc in data: + self.stat_update(StatHints.total) + with self.track_time(): + future = executor.submit(self.extract, doc.text) + try: + doc.text = future.result(timeout=self.timeout) + except TimeoutError: + logger.warning("⏰ Timeout while cleaning record text. Skipping record.") + continue + except Exception as e: + logger.warning(f'❌ Error "{e}" while cleaning record text. Skipping record.') + continue + if doc.text: + self.stat_update(StatHints.forwarded) + self.update_doc_stats(doc) + yield doc + else: + self.stat_update(StatHints.dropped) diff --git a/src/datatrove/pipeline/filters/__init__.py b/src/datatrove/pipeline/filters/__init__.py index 25a0423f..065496a2 100644 --- a/src/datatrove/pipeline/filters/__init__.py +++ b/src/datatrove/pipeline/filters/__init__.py @@ -1,10 +1,10 @@ -from .c4_quality_filter import C4ParagraphFilter, C4QualityFilter +from .c4_filters import C4BadWordsFilter, C4ParagraphFilter, C4QualityFilter from .fasttext_filter import FastTextClassifierFilter +from .fineweb_quality_filter import FineWebQualityFilter from .gopher_quality_filter import GopherQualityFilter from .gopher_repetition_filter import GopherRepetitionFilter from .lambda_filter import LambdaFilter from .language_filter import LanguageFilter -from .list_filter import ListFilter from .regex_filter import RegexFilter from .sampler_filter import SamplerFilter from .unigram_log_probs import UnigramLogProbFilter diff --git a/src/datatrove/pipeline/filters/c4_quality_filter.py b/src/datatrove/pipeline/filters/c4_filters.py similarity index 58% rename from src/datatrove/pipeline/filters/c4_quality_filter.py rename to src/datatrove/pipeline/filters/c4_filters.py index b2cfd13d..b9c425f1 100644 --- a/src/datatrove/pipeline/filters/c4_quality_filter.py +++ b/src/datatrove/pipeline/filters/c4_filters.py @@ -1,9 +1,14 @@ import heapq import re +from numpy.random import default_rng + from datatrove.data import Document +from datatrove.io import cached_asset_path_or_download from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer CITATION_REGEX = re.compile(r"\[\d*]|\[edit]|\[citation needed]") @@ -51,12 +56,10 @@ class C4QualityFilter(BaseFilter): """ name = "⛰ C4 Quality" - _requires_dependencies = ["nltk"] def __init__( self, exclusion_writer: DiskWriter = None, - tokenizer_language: str = "english", split_paragraph: bool = True, # default as used on c4. 
Set to "False" to split with sent_tokenize remove_citations: bool = True, filter_no_terminal_punct: bool = True, @@ -67,9 +70,9 @@ def __init__( filter_javascript: bool = True, filter_curly_bracket: bool = True, filter_policy: bool = True, + language: str = Languages.english, ): super().__init__(exclusion_writer) - self.tokenizer_language = tokenizer_language self.split_paragraph = split_paragraph self.remove_citations = remove_citations self.filter_no_terminal_punct = filter_no_terminal_punct @@ -80,15 +83,10 @@ def __init__( self.filter_javascript = filter_javascript self.filter_curly_bracket = filter_curly_bracket self.filter_policy = filter_policy + self.tokenizer = load_word_tokenizer(language) def filter(self, doc: Document) -> bool | tuple[bool, str]: - from nltk.tokenize import sent_tokenize - - lines = ( - doc.text.splitlines() - if self.split_paragraph - else sent_tokenize(doc.text, language=self.tokenizer_language) - ) + lines = doc.text.splitlines() if self.split_paragraph else self.tokenizer.sent_tokenize(doc.text) num_sentences = 0 kept_lines = [] @@ -127,7 +125,7 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]: if self.filter_policy and any(p in line_l for p in POLICY_SUBSTRINGS): self.stat_update("line-filter-policy") continue - num_sentences += len(sent_tokenize(line, language=self.tokenizer_language)) if self.split_paragraph else 1 + num_sentences += len(self.tokenizer.sent_tokenize(line)) if self.split_paragraph else 1 kept_lines.append(line) self.stat_update("line-kept") if num_sentences < self.min_num_sentences: @@ -168,3 +166,116 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]: if not self.paragraph_filter(doc.text): return False, f"< {self.min_paragraphs} paragraphs" return True + + +_EN_BADWORDS_URL = "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/25e679f03d96baa721cde20db9944649e8d0a844/en" +_BADWORDS_URL = "https://raw.githubusercontent.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words/5faf2ba42d7b1c0977169ec3611df25a3c08eb13/" +_BADWORDS_LANGS = [ + "ar", + "cs", + "da", + "de", + "en", + "eo", + "es", + "fa", + "fi", + "fil", + "fr", + "fr-CA-u-sd-caqc", + "hi", + "hu", + "it", + "ja", + "kab", + "ko", + "nl", + "no", + "pl", + "pt", + "ru", + "sv", + "th", + "tlh", + "tr", + "zh", +] +# Words that are allowed since they are common subwords in languages without +# spaces. These each filter >10% of documents of their language when disallowed. +_BADWORDS_ALLOWLIST = {"ja": {"sm", "グロ", "女の子"}, "zh": {"性"}} + + +class C4BadWordsFilter(BaseFilter): + """ + Badwords filter from C4. 
+ Args: + keep_fraction (float): what fraction of pages containing bad words should be kept + fail_on_missing_language (bool): whether to fail when a document has an unknown language + seed (int): used for the uniform distribution generator for use with keep_fraction + default_language (str): what language to assume for samples without a language in their metadata + """ + + name = "⛰ C4 Badwords" + + def __init__( + self, + keep_fraction: float = 0.0, + fail_on_missing_language: bool = True, + seed: int = None, + default_language: str = "en", + exclusion_writer: DiskWriter = None, + ): + super().__init__(exclusion_writer) + self.keep_fraction = keep_fraction + self.fail_on_missing_language = fail_on_missing_language + self._badwords_regex: dict[str, re.Pattern] = {} + self.uniform = default_rng(seed).uniform + self.default_language = default_language + + def _get_badwords(self, lang: str): + if lang not in self._badwords_regex: + if lang not in _BADWORDS_LANGS: + if self.fail_on_missing_language: + raise ValueError( + f'There is no badwords list available for "{lang}". ' + f"Set fail_on_missing_language=False to continue anyway." + ) + else: + return None + local_path = cached_asset_path_or_download( + _BADWORDS_URL + lang if lang != "en" else _EN_BADWORDS_URL, + namespace="filters", + subfolder="c4_badwords", + ) + badwords: set[str] = set() + # load from file + with open(local_path, "rt") as f: + badwords.update(line.strip() for line in f) + for allowlist in _BADWORDS_ALLOWLIST.values(): + badwords -= allowlist + + words = [re.escape(w) for w in badwords] + self._badwords_regex[lang] = ( + # For Japanese, Thai, and Chinese, do not require word separations. + re.compile("|".join(words)) + if lang in ("ja", "th", "zh") + # For other languages, match only when flanked by non-word chars. 
+ else re.compile(r"(?:\W|^)({})(?:\W|$)".format("|".join(words))) + ) + return self._badwords_regex[lang] + + def filter(self, doc: Document) -> bool | tuple[bool, str]: + lang: str = doc.metadata.get("language", self.default_language) + badwords_regex = self._get_badwords(lang) + if badwords_regex is None: + self.stat_update("missing_badwords_lang", f"missing_badwords_lang_{lang}") + return True + badwords_found = badwords_regex.search(doc.text.lower()) + if badwords_found is not None: + self.stat_update("documents_with_badwords", f"documents_with_badwords_{lang}") + if self.keep_fraction > 0.0 and self.uniform() < self.keep_fraction: + self.stat_update("document_kept_with_badwords", f"document_kept_with_badwords_{lang}") + return True + self.stat_update(f"document_removed_with_badwords_{lang}") + return False, "document_removed_with_badwords" + return True diff --git a/src/datatrove/pipeline/filters/fasttext_filter.py b/src/datatrove/pipeline/filters/fasttext_filter.py index 431f352c..cdeb5ad0 100644 --- a/src/datatrove/pipeline/filters/fasttext_filter.py +++ b/src/datatrove/pipeline/filters/fasttext_filter.py @@ -1,14 +1,13 @@ -import os +from collections import defaultdict from typing import Tuple -from fsspec.core import strip_protocol -from huggingface_hub import cached_assets_path -from loguru import logger +import numpy as np from datatrove.data import Document -from datatrove.io import download_file +from datatrove.io import cached_asset_path_or_download from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.text import SPLIT_TEXT_DOCUMENTS, split_into_parts class FastTextClassifierFilter(BaseFilter): @@ -31,25 +30,29 @@ class FastTextClassifierFilter(BaseFilter): keep_labels: tuple of (label name without "__label__", min score) (or list of such tuples) remove_labels: tuple of (label name without "__label__", min score) (or list of such tuples) save_labels_in_metadata: whether to save all the label scores in the document metadata + newline_replacement: str to replace \n with before predicting scores + filter_mode: predict and filter on DOCUMENT, PARAGRAPH or SENTENCE level exclusion_writer: """ name = "🤖 fastText" - _requires_dependencies = [("fasttext", "fasttext-wheel")] + _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"] def __init__( self, model_url: str, - keep_labels: Tuple[str, float] | list[Tuple[str, float]] = None, - remove_labels: Tuple[str, float] | list[Tuple[str, float]] = None, + keep_labels: Tuple[str, float] | list[Tuple[str, float]] | None = None, + remove_labels: Tuple[str, float] | list[Tuple[str, float]] | None = None, save_labels_in_metadata: bool = True, - exclusion_writer: DiskWriter = None, + exclusion_writer: DiskWriter | None = None, newline_replacement="", + filter_mode: str = SPLIT_TEXT_DOCUMENTS, ): super().__init__(exclusion_writer) self.model_url = model_url self.keep_labels = keep_labels self.remove_labels = remove_labels + self.filter_mode = filter_mode if keep_labels and remove_labels: raise ValueError("You can only supply one of `keep_labels` or `remove_labels`.") self.newline_replacement = newline_replacement @@ -65,26 +68,45 @@ def model(self): if not self._model: from fasttext.FastText import _FastText - download_dir = cached_assets_path(library_name="datatrove", namespace="filters", subfolder="fasttext") - - model_file = os.path.join(download_dir, strip_protocol(self.model_url).replace("/", "_")) - if not os.path.isfile(model_file): 
- logger.info(f'⬇️ Downloading fast-text model from "{self.model_url}"...') - download_file(self.model_url, model_file) - logger.info(f'⬇️ Downloaded fast-text model to "{model_file}".') + model_file = cached_asset_path_or_download( + self.model_url, namespace="filters", subfolder="fasttext", desc="fast-text model" + ) self._model = _FastText(model_file) + # check label values + available_labels = [x.removeprefix("__label__") for x in self._model.labels] + for label, _ in self.keep_labels or [] + self.remove_labels or []: + if label not in available_labels: + raise ValueError( + f"Label '{label}' passed as keep_labels or remove_labels is not available in this " + f"FastText model. Available labels: {available_labels}" + ) return self._model def filter(self, doc: Document) -> bool: - labels, scores = self.model.predict(doc.text.replace("\n", self.newline_replacement)) - label_scores = dict(zip(labels, scores)) + def check_label_scores(unit_scores): + if self.keep_labels: + return any( + unit_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.keep_labels + ) + else: + return not self.remove_labels or not any( + unit_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.remove_labels + ) + + units = split_into_parts(doc.text, mode=self.filter_mode) + kept_spans = [] + label_scores = defaultdict(list) + for unit in units: + labels, scores = self.model.predict(unit.strip().replace("\n", self.newline_replacement), k=-1) + if self.save_labels_in_metadata: + for label, score in zip(labels, scores): + label_scores[label].append(score) + if check_label_scores(dict(zip(labels, scores))): + kept_spans.append(unit) + self.stat_update("kept_span") + else: + self.stat_update("removed_span") + doc.text = "".join(kept_spans) if self.save_labels_in_metadata: - doc.metadata.update(label_scores) - if self.keep_labels: - return any( - label_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.keep_labels - ) - else: - return not self.remove_labels or not any( - label_scores.get(f"__label__{label}", -9e9) >= min_score for label, min_score in self.remove_labels - ) + doc.metadata.update({label: np.mean(scores).item() for label, scores in label_scores.items()}) + return not not doc.text.strip() diff --git a/src/datatrove/pipeline/filters/fineweb_quality_filter.py b/src/datatrove/pipeline/filters/fineweb_quality_filter.py new file mode 100644 index 00000000..0d40f785 --- /dev/null +++ b/src/datatrove/pipeline/filters/fineweb_quality_filter.py @@ -0,0 +1,54 @@ +from datatrove.pipeline.filters.base_filter import BaseFilter +from datatrove.pipeline.filters.gopher_repetition_filter import find_duplicates +from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer + + +class FineWebQualityFilter(BaseFilter): + name = "🍷 FineWeb Quality" + + def __init__( + self, + exclusion_writer: DiskWriter = None, + line_punct_thr: float = 0.12, + line_punct_exclude_zero: bool = False, + short_line_thr: float = 0.67, + short_line_length: int = 30, + char_duplicates_ratio: float = 0.01, + new_line_ratio: float = 0.3, + language: str = Languages.english, + ): + super().__init__(exclusion_writer) + self.line_punct_thr = line_punct_thr + self.line_punct_exclude_zero = line_punct_exclude_zero + self.short_line_threshold = short_line_thr + self.short_line_length = short_line_length + self.char_duplicates_ratio = char_duplicates_ratio + 
self.new_line_ratio = new_line_ratio + self.tokenizer = load_word_tokenizer(language) + + def filter(self, doc) -> bool | tuple[bool, str]: + stop_chars = (".", "'", '"', "!", "?") + + lines = doc.text.split("\n") + ratio = sum(1 for line in lines if line.endswith(stop_chars)) / len(lines) + if ratio <= self.line_punct_thr and not (ratio == 0 and self.line_punct_exclude_zero): + return False, "line_punct_ratio" + + ratio = sum(1 for line in lines if len(line) <= self.short_line_length) / len(lines) + if ratio >= self.short_line_threshold: + return False, "short_line_ratio" + + non_empty_lines = [line for line in lines if line.strip() != ""] + ratio = find_duplicates(non_empty_lines)[1] / len(doc.text.replace("\n", "")) + + if ratio >= self.char_duplicates_ratio: + return False, "char_dup_ratio" + + words = self.tokenizer.word_tokenize(doc.text) + new_line = doc.text.count("\n") + if new_line / len(words) > self.new_line_ratio: + return False, "list_ratio" + + return True diff --git a/src/datatrove/pipeline/filters/gopher_quality_filter.py b/src/datatrove/pipeline/filters/gopher_quality_filter.py index 8362bcf9..aaa530d3 100644 --- a/src/datatrove/pipeline/filters/gopher_quality_filter.py +++ b/src/datatrove/pipeline/filters/gopher_quality_filter.py @@ -4,6 +4,8 @@ from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base import DiskWriter from datatrove.utils.text import PUNCTUATION_SET +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer STOP_WORDS = ["the", "be", "to", "of", "and", "that", "have", "with"] @@ -11,7 +13,6 @@ class GopherQualityFilter(BaseFilter): name = "🥇 Gopher Quality" - _requires_dependencies = ["nltk"] def __init__( self, @@ -26,6 +27,7 @@ def __init__( min_stop_words: int | None = 2, stop_words: list[str] | None = None, exclusion_writer: DiskWriter = None, + language: str = Languages.english, ): """ Filter to apply Gopher's quality heuristic rules. 
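To make the new `language` argument concrete (it is added to `GopherQualityFilter.__init__` in the hunk above and wired into `load_word_tokenizer` in the hunk below), here is a minimal usage sketch; only parameter names visible in this diff are used, and the snippet is illustrative rather than part of the patch.

from datatrove.pipeline.filters import GopherQualityFilter
from datatrove.utils.typeshelper import Languages

# the quality heuristics now tokenize with the per-language word tokenizer
# instead of nltk's English-only word_tokenize
quality_filter = GopherQualityFilter(
    min_stop_words=2,  # default shown in the diff above
    language=Languages.english,
)
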
@@ -55,6 +57,7 @@ def __init__( self.max_non_alpha_words_ratio = max_non_alpha_words_ratio self.min_stop_words = min_stop_words self.stop_words = set(STOP_WORDS if stop_words is None else stop_words) + self.tokenizer = load_word_tokenizer(language) def filter(self, doc: Document) -> bool | tuple[bool, str]: """ @@ -66,10 +69,8 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]: Returns: False if sample.text does not pass any of the the heuristic tests """ - from nltk.tokenize import word_tokenize - text = doc.text - words = word_tokenize(text) # TODO we should use language id filter + words = self.tokenizer.word_tokenize(text) n_words = len(words) non_symbol_words = [w for w in words if any(ch not in PUNCTUATION_SET for ch in w)] diff --git a/src/datatrove/pipeline/filters/gopher_repetition_filter.py b/src/datatrove/pipeline/filters/gopher_repetition_filter.py index 718c4d6d..318c33da 100644 --- a/src/datatrove/pipeline/filters/gopher_repetition_filter.py +++ b/src/datatrove/pipeline/filters/gopher_repetition_filter.py @@ -4,6 +4,8 @@ from datatrove.data import Document from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer """ @@ -70,7 +72,6 @@ def find_all_duplicate(words: list[str], n: int) -> int: class GopherRepetitionFilter(BaseFilter): name = "👯 Gopher Repetition" - _requires_dependencies = ["nltk"] def __init__( self, @@ -81,6 +82,7 @@ def __init__( top_n_grams: tuple[tuple[int, float]] = ((2, 0.2), (3, 0.18), (4, 0.16)), dup_n_grams: tuple[tuple[int, float]] = ((5, 0.15), (6, 0.14), (7, 0.13), (8, 0.12), (9, 0.11), (10, 0.10)), exclusion_writer: DiskWriter = None, + language: str = Languages.english, ): """ @@ -102,10 +104,10 @@ def __init__( self.top_n_grams = top_n_grams self.dup_n_grams = dup_n_grams self.paragraph_exp = re.compile(r"\n{2,}") + self._line_splitter = re.compile("\n+") + self.tokenizer = load_word_tokenizer(language) def filter(self, doc: Document) -> bool | tuple[bool, str]: - from nltk.tokenize import word_tokenize - text = doc.text paragraphs = self.paragraph_exp.split(text.strip()) @@ -115,14 +117,14 @@ def filter(self, doc: Document) -> bool | tuple[bool, str]: if self.dup_para_char_frac and char_duplicates / len(text) > self.dup_para_char_frac: return False, "dup_para_char_frac" - lines = text.splitlines() + lines = self._line_splitter.split(text) line_duplicates, char_duplicates = find_duplicates(lines) if self.dup_line_frac and line_duplicates / len(lines) > self.dup_line_frac: return False, "dup_line_frac" if self.dup_line_char_frac and char_duplicates / len(text) > self.dup_line_char_frac: return False, "dup_line_char_frac" - words = word_tokenize(text, language="english") # TODO we should use language id filter + words = self.tokenizer.word_tokenize(text) for n, n_frac in self.top_n_grams: n_grams = get_n_grams(words, n) diff --git a/src/datatrove/pipeline/filters/language_filter.py b/src/datatrove/pipeline/filters/language_filter.py index 70f6bd7c..362d6825 100644 --- a/src/datatrove/pipeline/filters/language_filter.py +++ b/src/datatrove/pipeline/filters/language_filter.py @@ -1,12 +1,7 @@ -import os - -from huggingface_hub import cached_assets_path -from loguru import logger - from datatrove.data import Document -from datatrove.io import download_file from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base 
import DiskWriter +from datatrove.utils.lid import FastTextModel from datatrove.utils.typeshelper import Languages @@ -15,13 +10,14 @@ class LanguageFilter(BaseFilter): name = "🌍 Language ID" - _requires_dependencies = [("fasttext", "fasttext-wheel")] + _requires_dependencies = [("fasttext", "fasttext-wheel"), "fasteners"] def __init__( self, languages: tuple = (Languages.english,), language_threshold: float = 0.65, exclusion_writer: DiskWriter = None, + label_only: bool = False, ): """ filters if the predicted language is not among given language or if the language score is below language @@ -31,27 +27,13 @@ def __init__( languages: list of languages to keep language_threshold: language_threshold minimum score to accept a document exclusion_writer: + label_only: if True, only the language label is added to the metadata and no documents are removed """ super().__init__(exclusion_writer) self.language_threshold = language_threshold self.languages = languages - self._model = None - - @property - def model(self): - if not self._model: - from fasttext.FastText import _FastText - - download_dir = cached_assets_path( - library_name="datatrove", namespace="filters", subfolder="language_filter" - ) - model_file = os.path.join(download_dir, "lid.176.bin") - if not os.path.isfile(model_file): - logger.info("⬇️ Downloading fast-text language identifier model...") - download_file(LANGUAGE_ID_MODEL_URL, model_file) - logger.info("⬇️ Downloaded fast-text language identifier model.") - self._model = _FastText(model_file) - return self._model + self.model = FastTextModel(list(languages)) + self.label_only = label_only def filter(self, doc: Document) -> bool: """Args: @@ -60,10 +42,7 @@ def filter(self, doc: Document) -> bool: Returns: is_filter """ - - language, score = self.model.predict(doc.text.replace("\n", "")) - # language label is given in the form __label__ - language = language[0].split("__")[2] - doc.metadata["language"] = language - doc.metadata["language_score"] = score[0] - return score > self.language_threshold and language in self.languages + best_lang_pair, lang_pairs = self.model.predict(doc) + doc.metadata["language"] = best_lang_pair[0] + doc.metadata["language_score"] = best_lang_pair[1] + return self.label_only or any(score > self.language_threshold for score in lang_pairs.values()) diff --git a/src/datatrove/pipeline/filters/list_filter.py b/src/datatrove/pipeline/filters/list_filter.py deleted file mode 100644 index 881ba906..00000000 --- a/src/datatrove/pipeline/filters/list_filter.py +++ /dev/null @@ -1,37 +0,0 @@ -from datatrove.data import Document -from datatrove.pipeline.filters.base_filter import BaseFilter -from datatrove.pipeline.writers.disk_base import DiskWriter - - -class ListFilter(BaseFilter): - """ - Checks the ratio of number of lines to number of words. 
- Equivalent to around a min of 3.333 words per line - - """ - - name = "🎅 List" - _requires_dependencies = ["nltk"] - - def __init__(self, new_line_ratio: float | None = 0.3, exclusion_writer: DiskWriter = None): # TODO better tune - """ """ - super().__init__(exclusion_writer) - self.new_line_ratio = new_line_ratio - - def filter(self, doc: Document) -> bool | tuple[bool, str]: - """Applies heuristic rules to decide if a document should be REMOVED - Args: - doc - - Returns: - False if sample.text is a list - """ - from nltk.tokenize import word_tokenize - - text = doc.text - words = word_tokenize(text) # TODO we should use language id filter - new_line = text.count("\n") - if new_line / len(words) > self.new_line_ratio: - return False, "Suspected list" - - return True diff --git a/src/datatrove/pipeline/filters/unigram_log_probs.py b/src/datatrove/pipeline/filters/unigram_log_probs.py index 481255c1..af42e096 100644 --- a/src/datatrove/pipeline/filters/unigram_log_probs.py +++ b/src/datatrove/pipeline/filters/unigram_log_probs.py @@ -4,11 +4,13 @@ import numpy as np from huggingface_hub import cached_assets_path -from loguru import logger from datatrove.data import Document from datatrove.pipeline.filters.base_filter import BaseFilter from datatrove.pipeline.writers.disk_base import DiskWriter +from datatrove.utils.logging import logger +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer UNIGRAM_DOWNLOAD = "https://ai2-s2-research-public.s3-us-west-2.amazonaws.com/lucas/google-1T-unigram/unigram_freq.csv" @@ -23,12 +25,9 @@ class UnigramLogProbFilter(BaseFilter): """ name = "🧑‍🍳 Unigram log-prob filter" - _requires_dependencies = ["nltk"] def __init__( - self, - logprobs_threshold: float = -10, - exclusion_writer: DiskWriter = None, + self, logprobs_threshold: float = -10, exclusion_writer: DiskWriter = None, language: str = Languages.english ): """ @@ -39,6 +38,7 @@ def __init__( super().__init__(exclusion_writer) self.logprobs_threshold = logprobs_threshold self.unigram_frequencies = self.get_frequencies() + self.tokenizer = load_word_tokenizer(language) def get_frequencies(self): download_dir = cached_assets_path( @@ -60,9 +60,7 @@ def get_frequencies(self): return {word: count / total_count for word, count in zip(words, counts)} def get_logprob(self, doc): - from nltk.tokenize import word_tokenize - - words = word_tokenize(doc.text) + words = self.tokenizer.word_tokenize(doc.text) freqs = [self.unigram_frequencies.get(word.lower(), 1e-9) for word in words] if len(freqs) == 0: diff --git a/src/datatrove/pipeline/filters/url_filter.py b/src/datatrove/pipeline/filters/url_filter.py index c707b75f..0675a010 100644 --- a/src/datatrove/pipeline/filters/url_filter.py +++ b/src/datatrove/pipeline/filters/url_filter.py @@ -6,7 +6,9 @@ from huggingface_hub import cached_assets_path from datatrove.data import Document +from datatrove.io import safely_create_file from datatrove.utils._import_utils import ASSETS_PATH +from datatrove.utils.logging import logger from ..writers.disk_base import DiskWriter from .base_filter import BaseFilter @@ -23,9 +25,9 @@ def parse_list(line, do_normalize=True): return {normalize(x) if do_normalize else x.strip() for x in line if x[0] != "#"} -def get_list(abs_path: str, file_name: str, extra: set = None, do_normalize: bool = True): +def get_list(abs_path: str, file_name: str, extra: set, do_normalize: bool = True): with open(os.path.join(abs_path, file_name)) as f: - return parse_list(f, 
do_normalize).union(set(parse_list(extra, do_normalize)) if extra else set()) + return parse_list(f, do_normalize).union(extra) class URLFilter(BaseFilter): @@ -41,7 +43,7 @@ class URLFilter(BaseFilter): """ name = "😈 Url-filter" - _requires_dependencies = ["tldextract"] + _requires_dependencies = ["tldextract", "fasteners", ("ahocorasick", "pyahocorasick")] def __init__( self, @@ -51,29 +53,44 @@ def __init__( banned_words: Iterable = None, banned_subwords: Iterable = None, soft_banned_words: Iterable = None, + use_integrated_lists: bool = True, exclusion_writer: DiskWriter = None, ): + import ahocorasick from tldextract import TLDExtract super().__init__(exclusion_writer) self.soft_word_threshold = soft_word_threshold - self.block_listed_domains = extra_domains - self.block_listed_url = extra_urls - self.banned_words = banned_words - self.banned_subwords = banned_subwords - self.soft_banned_words = soft_banned_words + self.block_listed_domains = parse_list(extra_domains, do_normalize=False) if extra_domains else set() + self.block_listed_url = parse_list(extra_urls, do_normalize=False) if extra_urls else set() + self.banned_words = parse_list(banned_words) if banned_words else set() + self.banned_subwords = parse_list(banned_subwords) if banned_subwords else set() + self.soft_banned_words = parse_list(soft_banned_words) if soft_banned_words else set() + self.use_integrated_lists = use_integrated_lists self._downloaded = False self.tldextractor = TLDExtract() + self.banned_subwords_automaton = ahocorasick.Automaton(ahocorasick.STORE_INTS) + for word in self.banned_subwords: + self.banned_subwords_automaton.add_word(word, len(self.banned_subwords_automaton)) + + if not self.use_integrated_lists: + self.banned_subwords_automaton.make_automaton() + def download_data(self): - if self._downloaded: + if self._downloaded or not self.use_integrated_lists: return download_dir = cached_assets_path(library_name="datatrove", namespace="filters", subfolder="url_filter") - if not os.path.isfile(os.path.join(download_dir, "adult", "domains")) or not os.path.isfile( - os.path.join(download_dir, "adult", "urls") - ): + file_to_lock = os.path.join(download_dir, "url_filterblacklists.tar.gz") + + def do_extract(): + logger.info("💥 Extracting url filter blacklists...") with tarfile.open(os.path.join(ASSETS_PATH, "url_filterblacklists.tar.gz"), "r:gz") as tar: tar.extractall(download_dir) + logger.info("💥 Extracted url filter blacklists.") + + safely_create_file(file_to_lock, do_extract) + self.block_listed_domains = get_list( download_dir, "adult/domains", self.block_listed_domains, do_normalize=False ) @@ -81,6 +98,9 @@ def download_data(self): self.banned_words = get_list(ASSETS_PATH, "banned_words.txt", self.banned_words) self.banned_subwords = get_list(ASSETS_PATH, "banned_subwords.txt", self.banned_subwords) self.soft_banned_words = get_list(ASSETS_PATH, "soft_banned_words.txt", self.soft_banned_words) + for word in self.banned_subwords: + self.banned_subwords_automaton.add_word(word, len(self.banned_subwords_automaton)) + self.banned_subwords_automaton.make_automaton() self._downloaded = True def filter(self, document: Document) -> bool | tuple[bool, str]: @@ -108,7 +128,7 @@ def filter(self, document: Document) -> bool | tuple[bool, str]: return False, "soft_blacklisted" normalized_space = normalize(url) - if any(word in normalized_space for word in self.banned_subwords): + if self.banned_subwords and next(self.banned_subwords_automaton.iter(normalized_space), False): return False, 
"blacklisted_subword" return True diff --git a/src/datatrove/pipeline/formatters/__init__.py b/src/datatrove/pipeline/formatters/__init__.py index 207c085a..68e94c76 100644 --- a/src/datatrove/pipeline/formatters/__init__.py +++ b/src/datatrove/pipeline/formatters/__init__.py @@ -1,2 +1,3 @@ from .ftfy import FTFYFormatter +from .pii import PIIFormatter from .symbol_lines_remover import SymbolLinesFormatter diff --git a/src/datatrove/pipeline/formatters/pii.py b/src/datatrove/pipeline/formatters/pii.py new file mode 100644 index 00000000..c0582459 --- /dev/null +++ b/src/datatrove/pipeline/formatters/pii.py @@ -0,0 +1,94 @@ +import ipaddress +import re +from functools import partial +from typing import Callable + +from datatrove.pipeline.formatters.base import BaseFormatter + + +class PIIReplacer: + def __init__( + self, regex: str, replacements: tuple[str, ...] | str, validator: Callable[[str], bool] | None = None + ): + self.regex: re.Pattern = re.compile(regex) + self.replacements = ( + replacements + if type(replacements) is tuple + else (tuple(replacements) if not isinstance(replacements, str) else (replacements,)) + ) + self.validator = validator # extra validation for a match + self._replace_i = 0 + + def replace(self, text: str): + def get_replacement(matchobj): + if self.validator and not self.validator(matchobj.group(0)): + # not a valid match. replace with itself + return matchobj.group(0) + replacement = self.replacements[self._replace_i] + self._replace_i = (self._replace_i + 1) % len(self.replacements) + return replacement + + return self.regex.sub(get_replacement, text) + + +def public_ip_validator(ip, public_only: bool = True) -> bool: + try: + ip = ipaddress.ip_address(ip) + return not public_only or ip.is_global + except ValueError: + return False + + +class PIIFormatter(BaseFormatter): + """ + Replaces email addresses and ip addresses in the document text. + Args: + remove_emails: Replace email addresses + remove_ips: Replace IP addresses + only_remove_public_ips: by default we only replace public (and thus PII) IPs + email_replacement: tuple of strings to use as replacement. They will be used in a circular way + ip_replacement same as email_replacement but for IP addresses + """ + + name = "📞 PII" + + def __init__( + self, + remove_emails: bool = True, + remove_ips: bool = True, + only_remove_public_ips: bool = True, + # example.com/org are actually maintained as an example + email_replacement: tuple[str, ...] | str = ("email@example.com", "firstname.lastname@example.org"), + # randomly generated list of ips. they did not respond to ping requests at the time the list was created + ip_replacement: tuple[str, ...] 
| str = ( + "22.214.171.124", + "126.96.36.199", + "188.8.131.52", + "184.108.40.206", + "220.127.116.11", + "18.104.22.168", + ), + ): + super().__init__() + self.remove_emails = remove_emails + self.remove_ips = remove_ips + + self.emails_replacer = PIIReplacer( + r"\b[A-Za-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[A-Za-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:(?:[A-Za-z0-9](?:[" + r"A-Za-z0-9-]*[A-Za-z0-9])?\.)+[A-Za-z0-9](?:[A-Za-z0-9-]*[A-Za-z0-9])?|\[(?:(?:25[0-5]|2[0-4][0-9]|[" + r"01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?|[A-Za-z0-9-]*[A-Za-z0-9]:)])", + email_replacement, + ) + + self.ip_replacer = PIIReplacer( + r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)", + validator=partial(public_ip_validator, public_only=only_remove_public_ips), + replacements=ip_replacement, + ) + + def format(self, text: str) -> str: + if self.remove_emails: + text = self.emails_replacer.replace(text) + if self.remove_ips: + text = self.ip_replacer.replace(text) + return text diff --git a/src/datatrove/pipeline/readers/base.py b/src/datatrove/pipeline/readers/base.py index bb1b023b..2cab14b5 100644 --- a/src/datatrove/pipeline/readers/base.py +++ b/src/datatrove/pipeline/readers/base.py @@ -1,14 +1,14 @@ import random from abc import abstractmethod -from contextlib import nullcontext +from types import MethodType from typing import Callable -from loguru import logger from tqdm import tqdm from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep +from datatrove.utils.logging import logger class BaseReader(PipelineStep): @@ -17,12 +17,12 @@ class BaseReader(PipelineStep): Args: limit: limit the number of documents to read. Useful for debugging - progress: show tqdm progress bar. Might be spammy in some environments adapter: function to adapt the data dict from the source to a Document. - Take as input: data: dict, path: str, id_in_file: int | str - Return: a dict with at least a "text" key - text_key: key to use for the text in the default adapter (default: "text"). Ignored if you provide your own `adapter` - id_key: key to use for the id in the default adapter (default: "id"). Ignored if you provide your own `adapter` + Takes as input: (self, data: dict, path: str, id_in_file: int | str) + self allows access to self.text_key and self.id_key + Returns: a dict with at least a "text" key + text_key: key to use for the text in the default adapter (default: "text"). + id_key: key to use for the id in the default adapter (default: "id"). default_metadata: default metadata to add to all documents """ @@ -31,28 +31,18 @@ class BaseReader(PipelineStep): def __init__( self, limit: int = -1, - progress: bool = False, + skip: int = 0, adapter: Callable = None, text_key: str = "text", id_key: str = "id", default_metadata: dict = None, ): - """ - - Args: - limit: read at most this number of documents - progress: show a tqdm progress bar - adapter: custom function that should return a dictionary with the datatrove Document format (see _default_adapter) - text_key: the key containing the text data. `text` by default - id_key: the key containing the id for each sample. 
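A short sketch of the new PIIFormatter used standalone; inside a pipeline it is simply added as a step, and the sample text plus the direct call to format() are for illustration only:

from datatrove.pipeline.formatters import PIIFormatter

formatter = PIIFormatter()  # defaults: replace emails and public IPs, keep private IPs
cleaned = formatter.format("contact me at jane.doe@mail.com or 8.8.8.8")
# replacements cycle through the built-in tuples, so this yields roughly:
# "contact me at email@example.com or 22.214.171.124"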
`id` by default - default_metadata: a dictionary with any data that should be added to all sample's metadata - """ super().__init__() self.limit = limit - self.progress = progress + self.skip = skip self.text_key = text_key self.id_key = id_key - self.adapter = adapter if adapter else self._default_adapter + self.adapter = MethodType(adapter, self) if adapter else self._default_adapter self._empty_warning = False self.default_metadata = default_metadata @@ -91,7 +81,8 @@ def get_document_from_dict(self, data: dict, source_file: str, id_in_file: int | if not self._empty_warning: self._empty_warning = True logger.warning( - f"Found document without text, skipping. " f'Is your `text_key` ("{self.text_key}") correct?' + f"Found document without text, skipping. " + f'Is your `text_key` ("{self.text_key}") correct? Available keys: {list(data.keys())}' ) return None document = Document(**parsed_data) @@ -113,7 +104,8 @@ class BaseDiskReader(BaseReader): Args: data_folder: the data folder to read from limit: limit the number of documents to read. Useful for debugging - progress: show progress bar + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data from the source to a Document text_key: key to use for the text in the default adapter (default: "text"). Ignored if you provide your own `adapter` id_key: key to use for the id in the default adapter (default: "id"). Ignored if you provide your own `adapter` @@ -128,7 +120,9 @@ def __init__( self, data_folder: DataFolderLike, limit: int = -1, - progress: bool = False, + skip: int = 0, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -142,7 +136,9 @@ def __init__( Args: data_folder: a str, tuple or DataFolder object representing a path/filesystem limit: read at most this number of documents - progress: show a tqdm progress bar + skip: skip the first n rows + file_progress: show a tqdm progress bar for files + doc_progress: show a tqdm progress bar for documents adapter: custom function that should return a dictionary with the datatrove Document format (see _default_adapter) text_key: the key containing the text data. `text` by default id_key: the key containing the id for each sample. `id` by default @@ -152,11 +148,13 @@ def __init__( shuffle_files: shuffle the files within the returned shard. Mostly used for data viz. 
purposes, do not use with dedup blocks """ - super().__init__(limit, progress, adapter, text_key, id_key, default_metadata) + super().__init__(limit, skip, adapter, text_key, id_key, default_metadata) self.data_folder = get_datafolder(data_folder) self.recursive = recursive self.glob_pattern = glob_pattern self.shuffle_files = shuffle_files + self.file_progress = file_progress + self.doc_progress = doc_progress def get_document_from_dict(self, data: dict, source_file: str, id_in_file: int): document = super().get_document_from_dict(data, source_file, id_in_file) @@ -187,18 +185,30 @@ def read_files_shard(self, shard: list[str]) -> DocumentsPipeline: """ li = 0 - with tqdm(total=self.limit if self.limit != -1 else None) if self.progress else nullcontext() as pbar: - for filepath in shard: + skipped = 0 + with ( + tqdm( + total=self.limit if self.limit != -1 else None, + desc="Document progress", + unit="doc", + disable=not self.doc_progress, + ) as doc_pbar, + tqdm(total=len(shard), desc="File progress", unit="file", disable=not self.file_progress) as file_pbar, + ): + for i, filepath in enumerate(shard): self.stat_update("input_files") - logger.info(f"Reading input file {filepath}") + logger.info(f"Reading input file {filepath}, {i+1}/{len(shard)}") di = 0 for di, document in enumerate(self.read_file(filepath)): + if skipped < self.skip: + skipped += 1 + continue if self.limit != -1 and li >= self.limit: break yield document - if self.progress: - pbar.update() + doc_pbar.update() li += 1 + file_pbar.update() self.stat_update("documents", value=di, unit="input_file") if self.limit != -1 and li >= self.limit: break diff --git a/src/datatrove/pipeline/readers/csv.py b/src/datatrove/pipeline/readers/csv.py index 446315d7..0c45d219 100644 --- a/src/datatrove/pipeline/readers/csv.py +++ b/src/datatrove/pipeline/readers/csv.py @@ -13,7 +13,9 @@ class CsvReader(BaseDiskReader): data_folder: the data folder to read from compression: the compression to use (default: "infer") limit: limit the number of CSV lines to read in each rank. Useful for debugging - progress: show progress bar + skip: skip the first n rows + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. 
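Because the adapter is now bound with MethodType, a custom adapter receives the reader instance as its first argument and can reuse self.text_key / self.id_key. A sketch with hypothetical field names and paths:

from datatrove.pipeline.readers import JsonlReader

def my_adapter(self, data: dict, path: str, id_in_file: int | str) -> dict:
    return {
        "text": data.pop("content", ""),  # hypothetical source field
        "id": f"{path}/{id_in_file}",
        "metadata": data,                 # keep the remaining fields as metadata
    }

reader = JsonlReader("data/raw/", adapter=my_adapter)  # hypothetical input folder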
Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -33,7 +35,9 @@ def __init__( data_folder: DataFolderLike, compression: Literal["infer", "gzip", "zstd"] | None = "infer", limit: int = -1, - progress: bool = False, + skip: int = 0, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -45,7 +49,9 @@ def __init__( super().__init__( data_folder, limit, - progress, + skip, + file_progress, + doc_progress, adapter, text_key, id_key, diff --git a/src/datatrove/pipeline/readers/huggingface.py b/src/datatrove/pipeline/readers/huggingface.py index 3921459b..3b45cf6e 100644 --- a/src/datatrove/pipeline/readers/huggingface.py +++ b/src/datatrove/pipeline/readers/huggingface.py @@ -1,4 +1,4 @@ -from contextlib import nullcontext +import copy from typing import Callable from tqdm import tqdm @@ -14,9 +14,11 @@ class HuggingFaceDatasetReader(BaseReader): Args: dataset: the name of the dataset to load with datasets.load_dataset dataset_options: options to pass to the load_dataset function + streaming: whether to stream the dataset limit: limit the number of rows to read + skip: skip the first n rows batch_size: the batch size to use - progress: show progress bar + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -32,18 +34,22 @@ def __init__( self, dataset: str, dataset_options: dict | None = None, + streaming: bool = False, limit: int = -1, + skip: int = 0, batch_size: int = 1000, - progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", default_metadata: dict = None, ): - super().__init__(limit, progress, adapter, text_key, id_key, default_metadata) + super().__init__(limit, skip, adapter, text_key, id_key, default_metadata) self.dataset = dataset - self.dataset_options = dataset_options + self.dataset_options = dataset_options or {} self.batch_size = batch_size + self.doc_progress = doc_progress + self.streaming = streaming def get_document_from_dict(self, data: dict, source: str, id_in_file: int | str): document = super().get_document_from_dict(data, source, id_in_file) @@ -51,16 +57,45 @@ def get_document_from_dict(self, data: dict, source: str, id_in_file: int | str) document.metadata.setdefault("dataset", source) return document + def _get_dataset_shard(self, dst, rank: int, world_size: int): + from datasets import Dataset, IterableDataset + from datasets.distributed import split_dataset_by_node + + if isinstance(dst, Dataset): + return dst.shard(world_size, rank, contiguous=True) + elif isinstance(dst, IterableDataset) and dst.n_shards > 1: + # In case we have more than 1 shard (file), we shard + # on shards/file level. 
+ ex_iterable = dst._ex_iterable.shard_data_sources(rank, world_size) + return IterableDataset( + ex_iterable=ex_iterable, + info=dst._info.copy(), + split=dst._split, + formatting=dst._formatting, + shuffling=copy.deepcopy(dst._shuffling), + distributed=copy.deepcopy(dst._distributed), + token_per_repo_id=dst._token_per_repo_id, + ) + else: + # If we have just a single shard/file, we shard inter-file + return split_dataset_by_node(dst, rank, world_size) + def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: from datasets import load_dataset # type: ignore if data: yield from data - # sadly sharding in this way with streaming is not supported by HF datasets yet, so no streaming - ds = load_dataset(self.dataset, **self.dataset_options) - shard = ds.shard(world_size, rank, contiguous=True) - with tqdm(total=self.limit if self.limit != -1 else None) if self.progress else nullcontext() as pbar: + ds = load_dataset(self.dataset, **self.dataset_options, streaming=self.streaming) + + # In case the dataset is (Iterable)?DatasetDict, raise informative error + if isinstance(ds, dict): + raise ValueError( + f"You forgot to specify the split of the dataset. Update your dataset_options to include 'split'. Available splits: {list(ds.keys())}" + ) + + shard = self._get_dataset_shard(ds, rank, world_size) + with tqdm(total=self.limit if self.limit != -1 else None, disable=not self.doc_progress) as pbar: li = 0 for batch in shard.iter(self.batch_size): if self.limit != -1 and li >= self.limit: @@ -77,6 +112,5 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 self.update_doc_stats(document) self.stat_update("documents") li += 1 - if self.progress: - pbar.update() + pbar.update() yield from documents diff --git a/src/datatrove/pipeline/readers/ipc.py b/src/datatrove/pipeline/readers/ipc.py index 0752f48a..e2333e42 100644 --- a/src/datatrove/pipeline/readers/ipc.py +++ b/src/datatrove/pipeline/readers/ipc.py @@ -10,8 +10,10 @@ class IpcReader(BaseDiskReader): Args: data_folder: the data folder to read from limit: limit the number of IPC documents to read + skip: skip the first n rows stream: if True, will read the file as a stream (default: False) - progress: show progress bar + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. 
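A usage sketch for the streaming path of HuggingFaceDatasetReader; the dataset name and options are examples, but note that a "split" is now required in dataset_options, otherwise the reader raises a ValueError:

from datatrove.pipeline.readers import HuggingFaceDatasetReader

reader = HuggingFaceDatasetReader(
    dataset="allenai/c4",                             # example dataset
    dataset_options={"name": "en", "split": "train"},
    streaming=True,      # shards on the file level when the dataset has more than one shard
    limit=1_000,
    doc_progress=True,
)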
Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -31,8 +33,10 @@ def __init__( self, data_folder: DataFolderLike, limit: int = -1, + skip: int = 0, stream: bool = False, - progress: bool = False, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -44,7 +48,9 @@ def __init__( super().__init__( data_folder, limit, - progress, + skip, + file_progress, + doc_progress, adapter, text_key, id_key, diff --git a/src/datatrove/pipeline/readers/jsonl.py b/src/datatrove/pipeline/readers/jsonl.py index 789be2ec..1e5aaad3 100644 --- a/src/datatrove/pipeline/readers/jsonl.py +++ b/src/datatrove/pipeline/readers/jsonl.py @@ -1,11 +1,8 @@ -import json -from json import JSONDecodeError from typing import Callable, Literal -from loguru import logger - from datatrove.io import DataFolderLike from datatrove.pipeline.readers.base import BaseDiskReader +from datatrove.utils.logging import logger class JsonlReader(BaseDiskReader): @@ -16,7 +13,9 @@ class JsonlReader(BaseDiskReader): data_folder: the data folder to read from compression: the compression to use (default: "infer") limit: limit the number of JSON lines to read - progress: show progress bar + skip: skip the first n rows + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -30,13 +29,16 @@ class JsonlReader(BaseDiskReader): """ name = "🐿 Jsonl" + _requires_dependencies = ["orjson"] def __init__( self, data_folder: DataFolderLike, compression: Literal["infer", "gzip", "zstd"] | None = "infer", limit: int = -1, - progress: bool = False, + skip: int = 0, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -48,7 +50,9 @@ def __init__( super().__init__( data_folder, limit, - progress, + skip, + file_progress, + doc_progress, adapter, text_key, id_key, @@ -60,12 +64,15 @@ def __init__( self.compression = compression def read_file(self, filepath: str): + import orjson + from orjson import JSONDecodeError + with self.data_folder.open(filepath, "rt", encoding='utf-8', compression=self.compression) as f: try: for li, line in enumerate(f): with self.track_time(): try: - document = self.get_document_from_dict(json.loads(line), filepath, li) + document = self.get_document_from_dict(orjson.loads(line), filepath, li) if not document: continue except (EOFError, JSONDecodeError) as e: diff --git a/src/datatrove/pipeline/readers/parquet.py b/src/datatrove/pipeline/readers/parquet.py index 7be213da..57bbabea 100644 --- a/src/datatrove/pipeline/readers/parquet.py +++ b/src/datatrove/pipeline/readers/parquet.py @@ -11,9 +11,11 @@ class ParquetReader(BaseDiskReader): Args: data_folder: the data folder to read from limit: limit the number of Parquet rows to read + skip: skip the first n rows batch_size: the batch size to use (default: 1000) read_metadata: if True, will read the metadata (default: True) - progress: show progress bar + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. 
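A sketch of the reworked reader flags: the old progress option is split into file_progress and doc_progress, and skip jumps over the first n documents read by a task (the path is illustrative):

from datatrove.pipeline.readers import JsonlReader

reader = JsonlReader(
    "data/crawl/",      # hypothetical folder of .jsonl(.gz) files
    compression="infer",
    skip=10_000,        # skip the first 10k documents read by this task
    limit=1_000,
    file_progress=True,
    doc_progress=True,
)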
Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -33,9 +35,11 @@ def __init__( self, data_folder: DataFolderLike, limit: int = -1, + skip: int = 0, batch_size: int = 1000, read_metadata: bool = True, - progress: bool = False, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -47,7 +51,9 @@ def __init__( super().__init__( data_folder, limit, - progress, + skip, + file_progress, + doc_progress, adapter, text_key, id_key, diff --git a/src/datatrove/pipeline/readers/warc.py b/src/datatrove/pipeline/readers/warc.py index 578baabc..665b9661 100644 --- a/src/datatrove/pipeline/readers/warc.py +++ b/src/datatrove/pipeline/readers/warc.py @@ -16,7 +16,9 @@ class WarcReader(BaseDiskReader): data_folder: the data folder to read from compression: the compression to use (default: "infer") limit: limit the number of WARC documents to read - progress: show progress bar + skip: skip the first n rows + file_progress: show progress bar for files + doc_progress: show progress bar for documents adapter: function to adapt the data dict from the source to a Document. Take as input: data: dict, path: str, id_in_file: int | str Return: a dict with at least a "text" key @@ -37,7 +39,9 @@ def __init__( data_folder: DataFolderLike, compression: Literal["infer", "gzip", "zstd"] | None = "infer", limit: int = -1, - progress: bool = False, + skip: int = 0, + file_progress: bool = False, + doc_progress: bool = False, adapter: Callable = None, text_key: str = "text", id_key: str = "id", @@ -50,7 +54,9 @@ def __init__( super().__init__( data_folder, limit, - progress, + skip, + file_progress, + doc_progress, adapter, text_key, id_key, diff --git a/src/datatrove/pipeline/stats/__init__.py b/src/datatrove/pipeline/stats/__init__.py index 9a04428c..ea420b4e 100644 --- a/src/datatrove/pipeline/stats/__init__.py +++ b/src/datatrove/pipeline/stats/__init__.py @@ -1,2 +1,11 @@ -from .doc_len import DocLenStats -from .urls import URLStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, STAT_TYPE, TopKConfig +from datatrove.pipeline.stats.contamination_stats import WordsContaminationStats +from datatrove.pipeline.stats.doc_stats import DocStats +from datatrove.pipeline.stats.lang_stats import LangStats +from datatrove.pipeline.stats.line_stats import LineStats +from datatrove.pipeline.stats.merger import STATS_MERGED_NAME, StatsMerger +from datatrove.pipeline.stats.paragraph_stats import ParagraphStats +from datatrove.pipeline.stats.perplexity_stats import CCNetPerplexityStats +from datatrove.pipeline.stats.sentence_stats import SentenceStats +from datatrove.pipeline.stats.token_stats import TokenStats +from datatrove.pipeline.stats.word_stats import WordStats diff --git a/src/datatrove/pipeline/stats/base.py b/src/datatrove/pipeline/stats/base.py new file mode 100644 index 00000000..36d64601 --- /dev/null +++ b/src/datatrove/pipeline/stats/base.py @@ -0,0 +1,121 @@ +import heapq +import json +from abc import abstractmethod +from collections import defaultdict +from typing import get_args + +from loguru import logger + +from datatrove.data import Document, DocumentsPipeline +from datatrove.io import DataFolderLike, get_datafolder +from datatrove.pipeline.base import PipelineStep +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, STAT_TYPE, TopKConfig +from datatrove.utils.stats import MetricStatsDict + + +class BaseStats(PipelineStep): + """ + 
Datatrove block for computing statistics of dataset. + Each stat is of type MetricStatsDict saved in output_folder/{group}/{stat_name}/{rank:05d}.json + Args: + output_folder: The folder where the statistics will be saved. + groups_to_compute: The groups of statistics to compute. + histogram_round_digits: The number of digits to round the histogram values to. + This ensures reasonable number of bins. + top_k_config: The configuration for compressing the statistics. + Each group in top_k_groups will truncate the statistics to the top k keys. + This lowers memory usage and speeds up the merging in second-stage. + """ + + type = "📊 - STATS" + name = "👑 Summary stats" + _requires_dependencies = ["tldextract"] + + def __init__( + self, + output_folder: DataFolderLike, + groups_to_compute: list[GROUP] | None = None, + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + from tldextract import TLDExtract + + super().__init__() + self.output_folder = get_datafolder(output_folder) + self.groups = groups_to_compute or list(get_args(GROUP)) + self.histogram_round_digits = histogram_round_digits + self.top_k_cfg = top_k_config + self.tld_extractor = TLDExtract() + + @abstractmethod + def extract_stats(self, doc: Document) -> dict[str, int | float]: + """ + Abstract method for extracting stats from a document. + Args: + doc: The document to extract stats from. + + Returns: + A dictionary of statistics, where the key is the stat name and the value is the stat value. + """ + raise NotImplementedError() + + def get_kv(self, doc: Document, value: STAT_TYPE, group_name: GROUP) -> tuple[str, STAT_TYPE]: + if group_name == "histogram": + # Use rounding to reduce then number of values for histogram + return str(round(value, self.histogram_round_digits)), 1 + elif group_name == "summary": + return "summary", value + elif group_name == "fqdn": + fqdn = doc.metadata.get("fqdn") + if fqdn is None: + fqdn = self.tld_extractor.extract_str(doc.metadata["url"]).fqdn + doc.metadata["fqdn"] = fqdn + return fqdn, value + elif group_name == "suffix": + suffix = doc.metadata.get("suffix") + if suffix is None: + suffix = self.tld_extractor.extract_str(doc.metadata["url"]).suffix + doc.metadata["suffix"] = suffix + return suffix, value + else: + raise ValueError(f"Unknown group name: {group_name}") + + def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: + groups_dicts: dict[GROUP, dict[str, MetricStatsDict]] = { + group: defaultdict(MetricStatsDict) for group in self.groups + } + + for doc in data: + with self.track_time(): + try: + doc_stats = self.extract_stats(doc) + except Exception as e: + logger.error(f"Error while extracting stats from document {doc.id}", exc_info=e) + raise e + + for group, counters in groups_dicts.items(): + for stat, value in doc_stats.items(): + key, value = self.get_kv(doc, value, group) + counters[stat][key] += value + + doc.metadata.update(doc_stats) + yield doc + + # save to disk + for group, stats_dict in groups_dicts.items(): + group_top_k_keys = None + + for stat_name, stat_values in stats_dict.items(): + if group in self.top_k_cfg.top_k_groups: + # We don't have to compute this for every stat in group, as stat.n will be constant + if group_top_k_keys is None: + group_top_k_keys = heapq.nlargest( + self.top_k_cfg.top_k, stat_values, key=lambda x: stat_values[x].n + ) + + stat_values = MetricStatsDict(init={s: stat_values[s] for s in group_top_k_keys}) + + with 
self.output_folder.open(f"{group}/{stat_name}/{rank:05d}.json", "wt") as f: + json.dump(stat_values.to_dict(), f) + # delete the group_dicts to save mem + del groups_dicts diff --git a/src/datatrove/pipeline/stats/config.py b/src/datatrove/pipeline/stats/config.py new file mode 100644 index 00000000..487ebba1 --- /dev/null +++ b/src/datatrove/pipeline/stats/config.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Literal + + +GROUP = Literal["summary", "histogram", "fqdn", "suffix"] + + +@dataclass(frozen=True) +class TopKConfig: + """ + Configuration for compressing the statistics. + Each group in top_k_groups will be truncated to the top k keys. + This lowers memory usage and speeds up the merging in second-stage. + + If run in distributed mode, each node will create its own top_k_keys, which + leads to inconsistent top_k_keys between nodes. To account for this, set around + 0.8*top_k as the number of top_k_keys for merging step. + """ + + top_k_groups: list[Literal["fqdn", "suffix"]] + top_k: int + + +DEFAULT_TOP_K_CONFIG = TopKConfig(top_k_groups=["fqdn", "suffix"], top_k=100_000) + +STAT_TYPE = int | float diff --git a/src/datatrove/pipeline/stats/contamination_stats.py b/src/datatrove/pipeline/stats/contamination_stats.py new file mode 100644 index 00000000..f9344267 --- /dev/null +++ b/src/datatrove/pipeline/stats/contamination_stats.py @@ -0,0 +1,50 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.text import TextNormConfig, simplify_text +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer + + +class WordsContaminationStats(BaseStats): + """ + Words contamination stats of a document. + + Available stats: + word_contamination_{words[0]}: Frequency of words contamination in the document. + + Args: + words: The words to check for contamination. + """ + + name = "😷 Words contamination" + + def __init__( + self, + output_folder: DataFolderLike, + words: list[str], + norm_config: TextNormConfig = TextNormConfig(), + language: str = Languages.english, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__(output_folder, groups_to_compute, histogram_round_digits, top_k_config=top_k_config) + if len(words) == 0: + raise ValueError("At least one word must be provided") + + self.norm_config = norm_config + self.language = language + self.words = words + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + word_tokenizer = load_word_tokenizer(self.language) + + doc_words = word_tokenizer.word_tokenize(simplify_text(doc.text, self.norm_config)) + return { + f"words_contamination_{self.words[0]}": sum([1 for word in doc_words if word in self.words]) + / len(doc_words) + } diff --git a/src/datatrove/pipeline/stats/doc_len.py b/src/datatrove/pipeline/stats/doc_len.py deleted file mode 100644 index 28c5f859..00000000 --- a/src/datatrove/pipeline/stats/doc_len.py +++ /dev/null @@ -1,15 +0,0 @@ -from datatrove.pipeline.base import DocumentsPipeline, PipelineStep - - -class DocLenStats(PipelineStep): - """Pipeline step to compute the length of each document in a pipeline. 
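A sketch of a custom summary-stats block built on the new BaseStats: only extract_stats has to be implemented, while grouping, histogram rounding, top-k truncation and the per-rank JSON output come from the base class. The block name, stat name and output folder below are hypothetical:

from datatrove.data import Document
from datatrove.pipeline.stats.base import BaseStats


class QuestionMarkStats(BaseStats):
    name = "❓ Question mark stats"

    def extract_stats(self, doc: Document) -> dict[str, int | float]:
        # ratio of question marks per character, guarded against empty docs
        return {"question_mark_ratio": doc.text.count("?") / max(len(doc.text), 1)}


# restrict to groups that do not need a "url" metadata field
stats_step = QuestionMarkStats("stats/partial/", groups_to_compute=["summary", "histogram"])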
- Will add a "length" metadata to each document with the length of the text in characters and (if available) tokens. - """ - - type = "📊 - STATS" - name = "🤓 document length" - - def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: - for doc in data: - self.update_doc_stats(doc) - yield doc diff --git a/src/datatrove/pipeline/stats/doc_stats.py b/src/datatrove/pipeline/stats/doc_stats.py new file mode 100644 index 00000000..ffcdffb0 --- /dev/null +++ b/src/datatrove/pipeline/stats/doc_stats.py @@ -0,0 +1,50 @@ +import re +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.text import PUNCTUATION + + +ELIPSIS = ["...", "…"] + + +class DocStats(BaseStats): + """ + Summary stats of document level metrics: + + Available stats: + length: Length of the document + white_space_ratio: Ratio of whitespace characters + non_alpha_digit_ratio: Ratio of non-alphabetic and non-digit characters + digit_ratio: Ratio of digits + uppercase_ratio: Ratio of uppercase letters + elipsis_ratio: Ratio of elipsis characters + punctuation_ratio: Punctuation ratio + """ + + name = "📜 Doc stats" + + def __init__( + self, + output_folder: DataFolderLike, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__(output_folder, groups_to_compute, histogram_round_digits, top_k_config) + self.elipsis_regex = re.compile("|".join([f"(?:{re.escape(elipsis)})" for elipsis in ELIPSIS])) + self.punc_regex = re.compile("|".join([f"(?:{re.escape(punc)})" for punc in PUNCTUATION])) + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + return { + "length": len(doc.text), + "white_space_ratio": sum([1 for c in doc.text if c.isspace()]) / len(doc.text), + "non_alpha_digit_ratio": sum([1 for c in doc.text if not c.isalpha() and not c.isdigit()]) / len(doc.text), + "digit_ratio": sum([1 for c in doc.text if c.isdigit()]) / len(doc.text), + "uppercase_ratio": sum([1 for c in doc.text if c.isupper()]) / len(doc.text), + "elipsis_ratio": sum(len(elipsis) for elipsis in self.elipsis_regex.findall(doc.text)) / len(doc.text), + "punctuation_ratio": sum(len(punc) for punc in self.punc_regex.findall(doc.text)) / len(doc.text), + } diff --git a/src/datatrove/pipeline/stats/lang_stats.py b/src/datatrove/pipeline/stats/lang_stats.py new file mode 100644 index 00000000..c108efcf --- /dev/null +++ b/src/datatrove/pipeline/stats/lang_stats.py @@ -0,0 +1,38 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.lid import FastTextModel + + +class LangStats(BaseStats): + """ + Summary stats of language metrics: + + Available stats: + fasttext_{language} + """ + + name = "🎤 Language stats" + + def __init__( + self, + output_folder: DataFolderLike, + language: str, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__(output_folder, groups_to_compute, histogram_round_digits, top_k_config) + self.fasttext = FastTextModel([language]) 
+ self.language = language + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + language_score = 0 + if doc.metadata.get("language") == self.language and "language_score" in doc.metadata: + language_score = doc.metadata["language_score"] + else: + language_score = self.fasttext.predict(doc)[1][self.language] + return {f"fasttext_{self.language}": language_score} diff --git a/src/datatrove/pipeline/stats/line_stats.py b/src/datatrove/pipeline/stats/line_stats.py new file mode 100644 index 00000000..0b84b619 --- /dev/null +++ b/src/datatrove/pipeline/stats/line_stats.py @@ -0,0 +1,87 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.filters.c4_filters import END_PUNCTUATION +from datatrove.pipeline.filters.gopher_repetition_filter import find_duplicates +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig + + +def get_max_chars_per_line_ratio(lines, chars: int) -> float: + return sum([1 for line in lines if len(line) <= chars]) / len(lines) + + +def get_min_chars_per_line_ratio(lines, chars: int) -> float: + return sum([1 for line in lines if len(line) >= chars]) / len(lines) + + +def is_bullet_line(line: str): + if len(line.strip()) == 0: + return False + return line.strip()[0] in "-*•" + + +class LineStats(BaseStats): + """ + Summary stats of line level metrics. + + Available stats: + n_lines: Number of lines per doc + avg_line_length: Average length of line per doc + long_line_ratio_words: Ratio of lines with more than k chars + short_line_ratio_chars: Ratio of lines with more than k chars + bullet_point_lines_ratio: Ratio of bullet points + line_duplicates: Ratio of lines that are duplicates + line_char_duplicates: Ratio of chars in duplicated lines + + Args: + max_k_chars_per_line_tresholds: List of max chars per line to compute stats for. If None, default to [10, 30] + min_k_chars_per_line_thresholds: List of min chars per line to compute stats for. 
If None, default to [2000, 10000] + """ + + name = "🎼 Line stats" + + def __init__( + self, + output_folder: DataFolderLike, + max_k_chars_per_line_tresholds: list[int] | None = None, + min_k_chars_per_line_thresholds: list[int] | None = None, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + ignore_empty_lines: bool = False, + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__(output_folder, groups_to_compute, histogram_round_digits, top_k_config) + self.short_max_chars = ( + max_k_chars_per_line_tresholds if max_k_chars_per_line_tresholds is not None else [10, 30] + ) + self.long_max_chars = ( + min_k_chars_per_line_thresholds if min_k_chars_per_line_thresholds is not None else [2000, 10000] + ) + self.ignore_empty_lines = ignore_empty_lines + + def extract_stats(self, doc: Document): + lines: list[str] = doc.metadata.get("lines") or doc.text.split("\n") + # Don't ignore empty lines for count + n_lines = len(lines) + + lines = [line for line in lines if len(line.strip()) > 0] if self.ignore_empty_lines else lines + line_dups, char_dups = find_duplicates(lines) + return { + "n_lines": n_lines, + "avg_line_length": (sum([len(line) for line in lines]) / len(lines)), + **{ + f"short_line_ratio_chars_{chars}": get_max_chars_per_line_ratio(lines, chars) + for chars in self.short_max_chars + }, + **{ + f"long_line_ratio_chars_{chars}": get_min_chars_per_line_ratio(lines, chars) + for chars in self.long_max_chars + }, + "lines_ending_with_terminal_mark_ratio": sum(1 for line in lines if line.endswith(END_PUNCTUATION)) + / len(lines), + "bullet_point_lines_ratio": sum(1 for line in lines if is_bullet_line(line)) / len(lines), + "line_duplicates": line_dups / len(lines), + "line_char_duplicates": char_dups / sum(len(line) for line in lines), + } diff --git a/src/datatrove/pipeline/stats/merger.py b/src/datatrove/pipeline/stats/merger.py new file mode 100644 index 00000000..c27ac4fb --- /dev/null +++ b/src/datatrove/pipeline/stats/merger.py @@ -0,0 +1,84 @@ +import heapq +import json +from pathlib import Path + +from loguru import logger +from tqdm import tqdm + +from datatrove.data import DocumentsPipeline +from datatrove.io import DataFolderLike, get_datafolder +from datatrove.pipeline.base import PipelineStep +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, TopKConfig +from datatrove.utils.stats import MetricStats, MetricStatsDict + + +STATS_MERGED_NAME = "metric.json" + + +class StatsMerger(PipelineStep): + """ + Datatrove block for merging partial stats files into a single file. + Each stat is of type MetricStatsDict saved in output_folder/{group}/{stat_name}/metric.json + Args: + input_folder: The folder used for saving stats files of SummaryStats block. + output_folder: The folder where the merged stats will be saved. + remove_input: Whether to remove the input files after merging. + top_k: The configuration for compressing the statistics. + Each group in top_k_groups will truncate the statistics to the top k keys. 
+ """ + + type = "📊 - STATS" + name = "🔗 Merging stats" + + def __init__( + self, + input_folder: DataFolderLike, + output_folder: DataFolderLike, + remove_input: bool = False, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__() + self.input_folder = get_datafolder(input_folder) + self.output_folder = get_datafolder(output_folder) + self.remove_input = remove_input + self.top_k_config = top_k_config + + def get_leaf_non_empty_folders(self): + return sorted([path for path, folders, files in self.input_folder.walk("") if not folders and files]) + + def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: + """ + Args: + data: DocumentsPipeline: (Default value = None) + rank: int: (Default value = 0) + world_size: int: (Default value = 1) + + Each node will read a folder with stats files and merge them into a single file + """ + folders_shard = self.get_leaf_non_empty_folders()[rank::world_size] + logger.info(f"Merging {len(folders_shard)} stat folders") + with self.track_time(): + for folder in tqdm(folders_shard): + input_files = self.input_folder.glob(f"{folder}/[0-9][0-9][0-9][0-9][0-9].json") + logger.info(f"Processing folder {folder} with {len(input_files)} files") + + stat = MetricStatsDict() + for file in tqdm(input_files): + # Use inplace add to avoid creating a new dict + with self.input_folder.open(file, "rt") as f: + for key, item in json.load(f).items(): + stat[key] += MetricStats.from_dict(item) + + with self.output_folder.open(f"{folder}/{STATS_MERGED_NAME}", "wt") as f: + group_name = Path(folder).parent.name + if group_name in self.top_k_config.top_k_groups: + top_k_keys = heapq.nlargest(self.top_k_config.top_k, stat, key=lambda x: stat.get(x).n) + stat = MetricStatsDict(init={s: stat.get(s) for s in top_k_keys}) + json.dump(stat.to_dict(), f) + + if self.remove_input: + for file in input_files: + self.input_folder.rm(file) + + if data: + yield from data diff --git a/src/datatrove/pipeline/stats/paragraph_stats.py b/src/datatrove/pipeline/stats/paragraph_stats.py new file mode 100644 index 00000000..36347e30 --- /dev/null +++ b/src/datatrove/pipeline/stats/paragraph_stats.py @@ -0,0 +1,74 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.filters.gopher_repetition_filter import find_duplicates +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig + + +def get_short_paragraph_ratio(paragraphs: list[str], threshold: int) -> float: + return sum([1 for paragraph in paragraphs if len(paragraph) <= threshold]) / len(paragraphs) + + +def get_long_paragraph_ratio(paragraphs: list[str], threshold: int) -> float: + return sum([1 for paragraph in paragraphs if len(paragraph) >= threshold]) / len(paragraphs) + + +class ParagraphStats(BaseStats): + """ + Summary stats of paragraphs in a document. 
+ + Available stats: + n_paragraphs + avg_paragraph_length + short_paragraph_ratio_{chars} + long_paragraph_ratio_{chars} + """ + + type = "📊 - STATS" + name = "📄 Paragraph stats" + + def __init__( + self, + output_folder: DataFolderLike, + short_paragraph_max_chars_threshold: list[int] | None = None, + long_paragraph_max_chars_threshold: list[int] | None = None, + ignore_empty_paragraphs: bool = False, + histogram_round_digits: int = 3, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__( + output_folder, + groups_to_compute, + histogram_round_digits, + top_k_config, + ) + + self.ignore_empty_paragraphs = ignore_empty_paragraphs + self.short_paragraph_max_chars_threshold = short_paragraph_max_chars_threshold or [100] + self.long_paragraph_max_chars_threshold = long_paragraph_max_chars_threshold or [1000] + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + paragraphs = [p for p in doc.text.split("\n\n") if p.strip()] + # Don't ignore empty paragraphs for count + n_paragraphs = len(paragraphs) + + paragraphs = [p for p in paragraphs if p.strip()] if self.ignore_empty_paragraphs else paragraphs + paragraph_dups, paragraph_char_dups = find_duplicates(paragraphs) + + return { + "n_paragraphs": n_paragraphs, + "avg_paragraph_length": sum([len(p) for p in paragraphs]) / n_paragraphs, + **{ + f"short_paragraph_ratio_{chars}": get_short_paragraph_ratio(paragraphs, chars) + for chars in self.short_paragraph_max_chars_threshold + }, + **{ + f"long_paragraph_ratio_{chars}": get_long_paragraph_ratio(paragraphs, chars) + for chars in self.long_paragraph_max_chars_threshold + }, + "paragraph_duplicates": paragraph_dups / n_paragraphs, + "paragraph_char_duplicates": paragraph_char_dups / sum(len(p) for p in paragraphs), + } diff --git a/src/datatrove/pipeline/stats/perplexity_stats.py b/src/datatrove/pipeline/stats/perplexity_stats.py new file mode 100644 index 00000000..d777b485 --- /dev/null +++ b/src/datatrove/pipeline/stats/perplexity_stats.py @@ -0,0 +1,37 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.perplexity import KenlmModel +from datatrove.utils.typeshelper import Languages + + +class CCNetPerplexityStats(BaseStats): + """ + Summary stats of perplexity metrics: + + Available stats: + ccnet_perplexity_{model_dataset}_{language} + """ + + name = "🤯 CCNet perplexity stats" + _requires_dependencies = BaseStats._requires_dependencies + ["kenlm"] + + def __init__( + self, + output_folder: DataFolderLike, + model_dataset: str, + language: str = Languages.english, + histogram_round_digits: int = 3, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__(output_folder, groups_to_compute, histogram_round_digits, top_k_config) + self.model = KenlmModel(model_dataset=model_dataset, language=language) + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + return { + f"ccnet_perplexity_{self.model.model_dataset}_{self.model.language}": self.model.get_perplexity(doc.text) + } diff --git a/src/datatrove/pipeline/stats/sentence_stats.py b/src/datatrove/pipeline/stats/sentence_stats.py new file mode 100644 index 00000000..eacac713 --- /dev/null +++ 
b/src/datatrove/pipeline/stats/sentence_stats.py @@ -0,0 +1,69 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer + + +def get_short_sentence_ratio(sentences: list[str], threshold: int) -> float: + return sum([1 for sentence in sentences if len(sentence) <= threshold]) / len(sentences) + + +def get_long_sentence_ratio(sentences: list[str], threshold: int) -> float: + return sum([1 for sentence in sentences if len(sentence) >= threshold]) / len(sentences) + + +class SentenceStats(BaseStats): + """ + Sentence level stats of a document. + + Available stats: + * n_sentences + * avg_sentence_length: + * short_sentence_ratio_{chars}: + * long_sentence_ratio_{chars}: + """ + + name = "🈂️ Sentence stats" + + def __init__( + self, + output_folder: DataFolderLike, + short_sentence_max_chars_threshold: list[int] | None = None, + long_sentence_max_chars_threshold: list[int] | None = None, + language: str = Languages.english, + histogram_round_digits: int = 3, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__( + output_folder, + groups_to_compute, + histogram_round_digits, + top_k_config, + ) + + self.short_sentence_max_chars_threshold = short_sentence_max_chars_threshold or [20] + self.long_sentence_max_chars_threshold = long_sentence_max_chars_threshold or [75] + self.language = language + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + word_tokenizer = load_word_tokenizer(self.language) + + sentences = [s for s in word_tokenizer.sent_tokenize(doc.text) if s.strip()] + + return { + "n_sentences": len(sentences), + "avg_sentence_length": sum([len(s) for s in sentences]) / len(sentences), + **{ + f"short_sentence_ratio_{chars}": get_short_sentence_ratio(sentences, chars) + for chars in self.short_sentence_max_chars_threshold + }, + **{ + f"long_sentence_ratio_{chars}": get_long_sentence_ratio(sentences, chars) + for chars in self.long_sentence_max_chars_threshold + }, + } diff --git a/src/datatrove/pipeline/stats/token_stats.py b/src/datatrove/pipeline/stats/token_stats.py new file mode 100644 index 00000000..3c0595a2 --- /dev/null +++ b/src/datatrove/pipeline/stats/token_stats.py @@ -0,0 +1,39 @@ +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.tokenization import PipelineStepWithTokenizer + + +class TokenStats(BaseStats, PipelineStepWithTokenizer): + """ + Token stats of a document. 
+ + Available metrics: + token_count: Number of tokens in the document + """ + + name = "🔗 Token counter" + + _requires_dependencies = ["tokenizers"] + BaseStats._requires_dependencies + + def __init__( + self, + output_folder: DataFolderLike, + tokenizer_name_or_path: str = "gpt2", + groups_to_compute: list[GROUP] = ["fqdn", "suffix", "summary", "histogram"], + histogram_rounding: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + BaseStats.__init__(self, output_folder, groups_to_compute, histogram_rounding, top_k_config) + PipelineStepWithTokenizer.__init__(self) + self.tokenizer_name_or_path = tokenizer_name_or_path + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + tokens_count = doc.metadata.get("token_count", None) + if tokens_count is None: + tokens_count = len(self.tokenizer.encode(doc.text).tokens) + + return { + "token_count": tokens_count, + } diff --git a/src/datatrove/pipeline/stats/urls.py b/src/datatrove/pipeline/stats/urls.py deleted file mode 100644 index 5ee6ae64..00000000 --- a/src/datatrove/pipeline/stats/urls.py +++ /dev/null @@ -1,87 +0,0 @@ -import json - -from datatrove.io import DataFolderLike, get_datafolder -from datatrove.pipeline.base import DocumentsPipeline, PipelineStep -from datatrove.utils.stats import MetricStatsDict - - -class URLStats(PipelineStep): - """Pipeline step to compute the statistics of URLs in a pipeline. - Will add a "url_stats.json" file in the output folder with the statistics. - - Args: - output_folder: the output folder to save the statistics - url_field: the field to use as URL in the Document metadata (default: "url") - input_folder: the input folder to read the statistics from (default: None). Used to merge statistics - topk: the number of top URLs to keep (default: None - keep all) - min_doc_count_to_save: the minimum number of documents per URL to save the URL (default: 1) - """ - - type = "📊 - STATS" - name = "🌐 URLs" - _requires_dependencies = ["tldextract"] - - def __init__( - self, - output_folder: DataFolderLike, - url_field: str = "url", - input_folder: DataFolderLike = None, - topk: int = None, - min_doc_count_to_save: int = 1, - ): - super().__init__() - self.url_field = url_field - self.output_folder = get_datafolder(output_folder) - self.input_folder = get_datafolder(input_folder) if input_folder else None - self.topk = topk - self.min_doc_count_to_save = min_doc_count_to_save - - def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: - doc_counter = MetricStatsDict() - tokens_counter = MetricStatsDict() - total_docs = 0 - total_tokens = 0 - if self.input_folder: - # reduce the map results - assert world_size == 1, "world_size must be 1 when getting the input from an input_folder" - for file in self.input_folder.list_files(glob_pattern="json"): - with self.input_folder.open(file, "rt") as f: - file_data = json.load(f) - doc_counter += MetricStatsDict(init=file_data["doc_counter"]) - tokens_counter += MetricStatsDict(init=file_data["tokens_counter"]) - total_docs += file_data["total_docs"] - total_tokens += file_data["total_tokens"] - if self.topk: - doc_counter = doc_counter.topk(self.topk) - tokens_counter = tokens_counter.topk(self.topk) - else: - from tldextract import tldextract - - # map and produce one output file per rank - for doc in data: - url = tldextract.extract(doc.metadata.get(self.url_field)).fqdn - doc_counter[url] += 1 - total_docs += 1 - if token_count := doc.metadata.get("token_count", None): - tokens_counter[url] += 
token_count - total_tokens += token_count - yield doc - # save to disk - if self.min_doc_count_to_save > 0: - for url in list(doc_counter.keys()): - if doc_counter[url].total < self.min_doc_count_to_save: - del doc_counter[url] - if url in tokens_counter: - del tokens_counter[url] - with self.output_folder.open( - f"{rank:05d}_url_stats.json" if not self.input_folder else "url_stats.json", "wt" - ) as f: - json.dump( - { - "total_docs": total_docs, - "total_tokens": total_tokens, - "doc_counter": doc_counter.to_dict(), - "tokens_counter": tokens_counter.to_dict(), - }, - f, - ) diff --git a/src/datatrove/pipeline/stats/word_stats.py b/src/datatrove/pipeline/stats/word_stats.py new file mode 100644 index 00000000..8fe16d52 --- /dev/null +++ b/src/datatrove/pipeline/stats/word_stats.py @@ -0,0 +1,83 @@ +from typing import get_args + +from datatrove.data import Document +from datatrove.io import DataFolderLike +from datatrove.pipeline.filters.gopher_quality_filter import STOP_WORDS +from datatrove.pipeline.stats.base import BaseStats +from datatrove.pipeline.stats.config import DEFAULT_TOP_K_CONFIG, GROUP, TopKConfig +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer + + +def get_short_word_ratio(words: list[str], threshold: int) -> float: + return sum([1 for word in words if len(word) <= threshold]) / len(words) + + +def get_long_word_ratio(words: list[str], threshold: int) -> float: + return sum([1 for word in words if len(word) >= threshold]) / len(words) + + +class WordStats(BaseStats): + """ + Word level stats of a document. + + Available stats: + n_words: Number of words in the document + avg_word_length: Average length of words in the document + avg_words_per_line: Average number of words per line in the document + short_word_ratio_{chars}: Ratio of words shorter than {chars} characters + stop_word_ratio: Ratio of stop words + long_word_ratio_{chars}: Ratio of words longer than {chars} characters + type_token_ratio: Type-Token Ratio (TTR) + capitalized_word_ratio: Ratio of capitalized words + uppercase_word_ratio: Ratio of uppercase words + """ + + name = "🈂️ Word stats" + + def __init__( + self, + output_folder: DataFolderLike, + stop_words: list[str] = STOP_WORDS, + short_word_max_chars_threshold: list[int] | None = None, + long_word_max_chars_threshold: list[int] | None = None, + language: str = Languages.english, + groups_to_compute: list[GROUP] = list(get_args(GROUP)), + histogram_round_digits: int = 3, + top_k_config: TopKConfig = DEFAULT_TOP_K_CONFIG, + ) -> None: + super().__init__( + output_folder, + groups_to_compute, + histogram_round_digits, + top_k_config, + ) + + self.short_word_max_chars_threshold = short_word_max_chars_threshold or [3] + self.long_word_max_chars_threshold = long_word_max_chars_threshold or [7] + self.language = language + self.stop_words = stop_words + + def extract_stats(self, doc: Document) -> dict[str, int | float]: + word_tokenizer = load_word_tokenizer(self.language) + + words = word_tokenizer.word_tokenize(doc.text) + lines = doc.text.splitlines() + + return { + "n_words": len(words), + "avg_word_length": sum([len(word) for word in words]) / len(words), + "avg_words_per_line": len(words) / len(lines), + **{ + f"short_word_ratio_{chars}": get_short_word_ratio(words, chars) + for chars in self.short_word_max_chars_threshold + }, + **{ + f"long_word_ratio_{chars}": get_long_word_ratio(words, chars) + for chars in self.long_word_max_chars_threshold + }, + "type_token_ratio": 
len(set(words)) / len(words), + "uppercase_word_ratio": sum([1 for word in words if word.isupper()]) / len(words), + "capitalized_word_ratio": sum([1 for word in words if word.istitle()]) / len(words), + "stop_word_ratio": sum([1 for word in words if word in self.stop_words]) / len(words), + } diff --git a/src/datatrove/pipeline/tokens/context_shuffler.py b/src/datatrove/pipeline/tokens/context_shuffler.py index 8e2cc1bc..b9d9e9fe 100644 --- a/src/datatrove/pipeline/tokens/context_shuffler.py +++ b/src/datatrove/pipeline/tokens/context_shuffler.py @@ -1,13 +1,13 @@ import mmap import numpy as np -from loguru import logger from numpy.random import default_rng from datatrove.data import DocumentsPipeline from datatrove.io import DataFolderLike, get_datafolder from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.tokens.merger import load_doc_ends +from datatrove.utils.logging import logger class DocumentTokenizerContextShuffler(PipelineStep): @@ -18,6 +18,7 @@ class DocumentTokenizerContextShuffler(PipelineStep): output_folder: the output folder to write the shuffled documents to window_size: the size of the window to shuffle (default: 2048 + 1) seed: the seed for the random number generator (default: None) + token_size (int): size of each token, in bytes """ name = "🗃 Context Shuffler" @@ -29,11 +30,13 @@ def __init__( output_folder: DataFolderLike, window_size: int = 2048 + 1, seed: int = None, + token_size: int = 2, ): super().__init__() self.input_folder = get_datafolder(input_folder) self.output_folder = get_datafolder(output_folder) self.window_size = window_size + self.token_size = token_size self.rand = default_rng(seed) def get_ordering(self, all_doc_ends): @@ -73,5 +76,8 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 with mmap.mmap(f.fileno(), 0, prot=mmap.PROT_READ) as unshuf: with self.track_time(): for windowi in ordering: - start, end = windowi * self.window_size * 2, (windowi + 1) * self.window_size * 2 + start, end = ( + windowi * self.window_size * self.token_size, + (windowi + 1) * self.window_size * self.token_size, + ) fout.write(unshuf[start:end]) diff --git a/src/datatrove/pipeline/tokens/counter.py b/src/datatrove/pipeline/tokens/counter.py index a6f70507..40ba4734 100644 --- a/src/datatrove/pipeline/tokens/counter.py +++ b/src/datatrove/pipeline/tokens/counter.py @@ -1,6 +1,6 @@ from datatrove.data import DocumentsPipeline from datatrove.pipeline.base import PipelineStep -from datatrove.utils.tokenization import PipelineStepWithTokenizer +from datatrove.utils.tokenization import PipelineStepWithTokenizer, batched class TokensCounter(PipelineStepWithTokenizer): @@ -10,7 +10,8 @@ class TokensCounter(PipelineStepWithTokenizer): Args: tokenizer_name_or_path (str): the name or path of the tokenizer to use, from the HuggingFace tokenizers library or a local file. - count_eos_token (bool): whether to count the EOS token on each document. + count_eos_token (bool): whether to count the EOS token on each document. 
(basically +1 per document) + batch_size: batch size for tokenization """ name = "📊 Counter" @@ -20,20 +21,12 @@ def __init__( self, tokenizer_name_or_path: str = "gpt2", # tokenizer to use, from HF or a local file path count_eos_token: bool = False, # whether to count the EOS token on each document - overwrite: bool = True, # re-tokenize and recompute nb of tokens even if they are already in metadata["tokens_count"] + batch_size: int = 10000, # batch size for tokenization ): - """ - Initializes the token counting pipeline step. - - Args: - tokenizer_name_or_path: Name or path of tokenizer to use (from HF or local). - count_eos_token: Whether to include the EOS token in the token count per document. (basically +1 per document) - overwrite: Whether to re-tokenize and recompute the number of tokens even if they are already stored in metadata["tokens_count"] - """ super().__init__() self.tokenizer_name_or_path = tokenizer_name_or_path self.count_eos_token = count_eos_token - self.overwrite = overwrite + self.batch_size = batch_size def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> DocumentsPipeline: """ @@ -47,17 +40,19 @@ def run(self, data: DocumentsPipeline, rank: int = 0, world_size: int = 1) -> Do DocumentsPipeline: The pipeline with updated documents, each having a new or updated `token_count` in its metadata. """ - for document in data: - if "token_count" in document.metadata and not self.overwrite: - count = document.metadata["token_count"] - else: - with self.track_time(): - count = len(self.tokenizer.encode(document.text).ids) - if self.count_eos_token: - count += 1 + from tokenizers import Encoding + + # tokenize document's text in batches to go faster + for batch in batched(data, self.batch_size): + with self.track_time(unit="batch"): + encoded_batch: list[Encoding] = self.tokenizer.encode_batch([document.text for document in batch]) + for document, encoded in zip(batch, encoded_batch): + count = len(encoded.ids) + if self.count_eos_token: + count += 1 document.metadata["token_count"] = count - self.stat_update("tokens", value=count) - yield document + self.stat_update("tokens", value=count) + yield document class LengthCounter(PipelineStep): diff --git a/src/datatrove/pipeline/tokens/merger.py b/src/datatrove/pipeline/tokens/merger.py index 0fbaadf9..9781baa4 100644 --- a/src/datatrove/pipeline/tokens/merger.py +++ b/src/datatrove/pipeline/tokens/merger.py @@ -105,9 +105,18 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 f"({len(datafiles)} vs {len(datafiles_index)} vs {len(datafiles_loss)})" ) + tokenizer_name_or_path, token_size = None, 2 + if self.save_final_metadata: + if self.input_folder.isfile(f"{datafiles[0]}.metadata"): + with self.input_folder.open(f"{datafiles[0]}.metadata", "rt") as f: + tokenizer_name_or_path = f.read().splitlines()[0] + if "|" in tokenizer_name_or_path: + tokenizer_name_or_path, token_size = tokenizer_name_or_path.split("|") + token_size = int(token_size) + doc_ends = [load_doc_ends(self.input_folder.open(file, "rb")) for file in datafiles_index] token_inputs = list( - map(partial(get_data_reader, nb_bytes=2), self.input_folder.open_files(datafiles), doc_ends) + map(partial(get_data_reader, nb_bytes=token_size), self.input_folder.open_files(datafiles), doc_ends) ) loss_inputs = ( list(map(partial(get_data_reader, nb_bytes=1), self.input_folder.open_files(datafiles_loss), doc_ends)) @@ -115,12 +124,6 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 else 
None ) - tokenizer_name_or_path = None - if self.save_final_metadata: - if self.input_folder.isfile(f"{datafiles[0]}.metadata"): - with self.input_folder.open(f"{datafiles[0]}.metadata", "rt") as f: - tokenizer_name_or_path = f.read().splitlines()[0] - ordering = self.get_ordering(doc_ends) file_ct = 0 @@ -131,6 +134,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 upload_block_size=self.upload_block_size, tokenizer_name_or_path=tokenizer_name_or_path, save_final_metadata=self.save_final_metadata, + token_size=token_size, ) for input_file_id in tqdm( ordering, desc="Merging documents", unit="documents", total=len(ordering), disable=not self.progress @@ -147,13 +151,14 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1 upload_block_size=self.upload_block_size, tokenizer_name_or_path=tokenizer_name_or_path, save_final_metadata=self.save_final_metadata, + token_size=token_size, ) # copy tokens and loss tokens = next(token_inputs[input_file_id]) output_file.write_bytes(tokens) if loss_inputs: output_file.write_loss_bytes(next(loss_inputs[input_file_id])) - self.stat_update("tokens", value=len(tokens) // 2) + self.stat_update("tokens", value=len(tokens) // token_size) # cleanup output_file.close() if self.save_final_metadata: diff --git a/src/datatrove/pipeline/tokens/tokenizer.py b/src/datatrove/pipeline/tokens/tokenizer.py index ca2c2b17..558e8ad3 100644 --- a/src/datatrove/pipeline/tokens/tokenizer.py +++ b/src/datatrove/pipeline/tokens/tokenizer.py @@ -1,15 +1,14 @@ -import itertools import struct from typing import TYPE_CHECKING import humanize import numpy as np -from loguru import logger from numpy.random import default_rng from datatrove.data import Document, DocumentsPipeline from datatrove.io import DataFolder, DataFolderLike, get_datafolder -from datatrove.utils.tokenization import PipelineStepWithTokenizer +from datatrove.utils.logging import logger +from datatrove.utils.tokenization import PipelineStepWithTokenizer, batched SHUFFLING_READ_BLOCK_SIZE = 50000 # read 50kb at a time only (~mean + 2sigmas for final filtered common crawl docs) @@ -20,26 +19,6 @@ from tokenizers import Encoding -def batched(iterable, n): - """In python 3.12+ we could use itertools.batched instead - - One difference with itertools.batched: we return a list instead of a tuple - - Args: - iterable: - n: - - Returns: - - """ - # batched('ABCDEFG', 3) --> ABC DEF G - if n < 1: - raise ValueError("n must be at least one") - it = iter(iterable) - while batch := list(itertools.islice(it, n)): - yield batch - - class TokenizedFile: """Class to write tokenized documents to local/remote folders. Handles writing the tokenized document, an index file with the document ends (in tokens), and optionally a loss file with loss masks. 
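# --- Illustrative sketch (not part of this diff): what the token_size generalization
# means for the on-disk format. Tokens are packed little-endian, 2 bytes per token
# ("H", uint16) by default and 4 bytes ("I", uint32) for vocabularies above 65k.
import struct

import numpy as np


def decode_tokens(raw: bytes, token_size: int = 2) -> np.ndarray:
    # mirrors DatatroveFileDataset: uint16 for token_size=2, uint32 for token_size=4
    return np.frombuffer(raw, dtype=np.uint16 if token_size == 2 else np.uint32)


# round-trip mirroring TokenizedFile.write() and its token_format
tokens = [17, 42, 65534]
token_size = 2
token_format = "I" if token_size == 4 else "H"
raw = struct.pack(f"<%s{token_format}" % len(tokens), *tokens)
assert decode_tokens(raw, token_size).tolist() == tokens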
@@ -51,6 +30,8 @@ class TokenizedFile: save_index (bool): whether to save the index file (document boundaries) save_loss_metadata (bool): whether to save the loss metadata (to mask some tokens during training) upload_block_size (int): the fsspec size of the upload block for remote filesystems (S3) + token_size (int): size of each token, in bytes + """ def __init__( @@ -62,6 +43,7 @@ def __init__( upload_block_size: int | None = None, tokenizer_name_or_path: str | None = None, save_final_metadata: bool = False, + token_size: int = 2, ): self.output_folder = get_datafolder(output_folder) self.filename = filename @@ -69,6 +51,8 @@ def __init__( self.save_loss_metadata = save_loss_metadata self.upload_block_size = upload_block_size self.write_idx = 0 + self.token_size = token_size + self.token_format = "I" if self.token_size == 4 else "H" self.doc_ends = [] self.tokenizer_name_or_path = tokenizer_name_or_path self.save_final_metadata = save_final_metadata @@ -121,12 +105,10 @@ def write_bytes(self, tk_bytes: bytes, doc_ends: list[int] = None): if doc_ends is not None: # We've written several documents at once self.doc_ends.extend([d + self.write_idx for d in doc_ends]) - # 1 token = 2 bytes (uint16) - self.write_idx += len(tk_bytes) // 2 + self.write_idx += len(tk_bytes) // self.token_size else: # We've written a single document - # 1 token = 2 bytes (uint16) - self.write_idx += len(tk_bytes) // 2 + self.write_idx += len(tk_bytes) // self.token_size # save each document's boundary self.doc_ends.append(self.write_idx) @@ -149,8 +131,8 @@ def write(self, tokens: list[int], loss_values: np.ndarray | None): tokens (list[int]): the tokens to write loss_values (np.ndarray | None): optional loss values to write """ - # get the bytes for uint16 (H) - self.write_bytes(struct.pack("<%sH" % len(tokens), *tokens)) + # get the bytes + self.write_bytes(struct.pack(f"<%s{self.token_format}" % len(tokens), *tokens)) if loss_values is not None: self.write_loss_bytes(struct.pack("<%s?" % len(loss_values), *loss_values)) @@ -197,6 +179,7 @@ def copy( upload_block_size=self.upload_block_size, tokenizer_name_or_path=self.tokenizer_name_or_path, save_final_metadata=self.save_final_metadata, + token_size=self.token_size, ) logger.info(f"Shuffling in {destination}...") # shuffle doc_id @@ -204,9 +187,9 @@ def copy( for doc_id in ordering: # get start and end from the boundaries start, end = self.doc_ends[doc_id - 1] if doc_id > 0 else 0, self.doc_ends[doc_id] - # copy the bytes. each token is 2 bytes - tokens_file.seek(start * 2) - new_file.write_bytes(tokens_file.read((end - start) * 2)) + # copy the bytes. 
each token is token_size bytes + tokens_file.seek(start * self.token_size) + new_file.write_bytes(tokens_file.read((end - start) * self.token_size)) # copy loss values (1 byte per token) if loss_file: loss_file.seek(start) @@ -223,6 +206,7 @@ def copy( upload_block_size=self.upload_block_size, tokenizer_name_or_path=self.tokenizer_name_or_path, save_final_metadata=self.save_final_metadata, + token_size=self.token_size, ) logger.info(f"Shuffling in {destination}...") total_tokens_written = 0 @@ -241,13 +225,21 @@ def write_final_metadata(self, token_count: int = -1, filename: str = None): """ tokenizer_name = self.tokenizer_name_or_path if not tokenizer_name: - tokenizer_name = "Unknown Tokenizer" + tokenizer_name = "Unknown Tokenizer" + "|" + str(self.token_size) if filename is None: filename = self.filename with self.output_folder.open(f"{filename}.metadata", "wt") as f: if token_count == -1: token_count = self.write_idx - f.write("\n".join([tokenizer_name, str(token_count), humanize.metric(token_count, unit="T")])) + f.write( + "\n".join( + [ + tokenizer_name + "|" + str(self.token_size), + str(token_count), + humanize.metric(token_count, unit="T"), + ] + ) + ) def get_output_filename(save_filename, rank: int, name: str, sub_rank: int = None): @@ -360,6 +352,7 @@ def write_unshuffled(self, data: DocumentsPipeline, filename: str): upload_block_size=self.upload_block_size, tokenizer_name_or_path=self.tokenizer_name_or_path, save_final_metadata=self.save_final_metadata, + token_size=self.token_size, ) # tokenize document's text in batches to go faster – we compute loss values independently if needed for batch in batched(data, self.batch_size): diff --git a/src/datatrove/pipeline/writers/disk_base.py b/src/datatrove/pipeline/writers/disk_base.py index 1dce26d2..168af37c 100644 --- a/src/datatrove/pipeline/writers/disk_base.py +++ b/src/datatrove/pipeline/writers/disk_base.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from collections import Counter from string import Template +from types import MethodType from typing import IO, Callable from datatrove.data import Document, DocumentsPipeline @@ -47,7 +48,7 @@ def __init__( raise ValueError("Can only specify `max_file_size` when writing in binary mode!") self.output_filename = Template(output_filename) self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression) - self.adapter = adapter if adapter else self._default_adapter + self.adapter = MethodType(adapter, self) if adapter else self._default_adapter self.expand_metadata = expand_metadata def _default_adapter(self, document: Document) -> dict: @@ -128,7 +129,9 @@ def _get_filename_with_file_id(self, filename): Returns: formatted filename """ - return f"{os.path.dirname(filename)}/{self.file_id_counter[filename]:03d}_{os.path.basename(filename)}" + if os.path.dirname(filename): + return f"{os.path.dirname(filename)}/{self.file_id_counter[filename]:03d}_{os.path.basename(filename)}" + return f"{self.file_id_counter[filename]:03d}_{os.path.basename(filename)}" def write(self, document: Document, rank: int = 0, **kwargs): """ diff --git a/src/datatrove/pipeline/writers/huggingface.py b/src/datatrove/pipeline/writers/huggingface.py index 7529bfc7..32bebe05 100644 --- a/src/datatrove/pipeline/writers/huggingface.py +++ b/src/datatrove/pipeline/writers/huggingface.py @@ -11,10 +11,10 @@ preupload_lfs_files, ) from huggingface_hub.utils import HfHubHTTPError -from loguru import logger from datatrove.io import DataFolderLike, get_datafolder from 
datatrove.pipeline.writers import ParquetWriter +from datatrove.utils.logging import logger MAX_RETRIES = 12 diff --git a/src/datatrove/pipeline/writers/jsonl.py b/src/datatrove/pipeline/writers/jsonl.py index 7b4bce10..555f33a6 100644 --- a/src/datatrove/pipeline/writers/jsonl.py +++ b/src/datatrove/pipeline/writers/jsonl.py @@ -1,4 +1,3 @@ -import json from typing import IO, Callable from datatrove.io import DataFolderLike @@ -17,6 +16,7 @@ class JsonlWriter(DiskWriter): default_output_filename: str = "${rank}.jsonl" name = "🐿 Jsonl" + _requires_dependencies = ["orjson"] def __init__( self, @@ -24,8 +24,18 @@ def __init__( output_filename: str = None, compression: str | None = "gzip", adapter: Callable = None, + max_file_size: int = -1, # in bytes. -1 for unlimited ): - super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter) + super().__init__( + output_folder, + output_filename=output_filename, + compression=compression, + adapter=adapter, + mode="wb", + max_file_size=max_file_size, + ) def _write(self, document: dict, file_handler: IO, _filename: str): - file_handler.write(json.dumps(document, ensure_ascii=False) + "\n") + import orjson + + file_handler.write(orjson.dumps(document, option=orjson.OPT_APPEND_NEWLINE)) diff --git a/src/datatrove/tools/failed_logs.py b/src/datatrove/tools/failed_logs.py index a7596d4e..5a4a4276 100644 --- a/src/datatrove/tools/failed_logs.py +++ b/src/datatrove/tools/failed_logs.py @@ -3,12 +3,12 @@ import os.path import re -from loguru import logger from rich.console import Console from rich.prompt import Confirm from datatrove.io import get_datafolder from datatrove.utils._import_utils import is_rich_available +from datatrove.utils.logging import logger if not is_rich_available(): diff --git a/src/datatrove/tools/jobs_status.py b/src/datatrove/tools/jobs_status.py index dc730d62..56eb69d6 100644 --- a/src/datatrove/tools/jobs_status.py +++ b/src/datatrove/tools/jobs_status.py @@ -2,11 +2,11 @@ import json import os.path -from loguru import logger from rich.console import Console from datatrove.io import get_datafolder from datatrove.utils._import_utils import is_rich_available +from datatrove.utils.logging import logger if not is_rich_available(): diff --git a/src/datatrove/tools/merge_stats.py b/src/datatrove/tools/merge_stats.py index 21d81d5f..e61a0b64 100644 --- a/src/datatrove/tools/merge_stats.py +++ b/src/datatrove/tools/merge_stats.py @@ -2,10 +2,10 @@ import json import os.path -from loguru import logger from tqdm import tqdm from datatrove.io import get_datafolder, open_file +from datatrove.utils.logging import logger from datatrove.utils.stats import PipelineStats diff --git a/src/datatrove/utils/_import_utils.py b/src/datatrove/utils/_import_utils.py index 26a2edf1..bdb875e1 100644 --- a/src/datatrove/utils/_import_utils.py +++ b/src/datatrove/utils/_import_utils.py @@ -1,11 +1,46 @@ import importlib.resources import os from functools import lru_cache +from typing import NoReturn ASSETS_PATH = os.path.join(importlib.resources.files(__package__.split(".")[0]), "assets") +def check_required_dependencies(step_name: str, required_dependencies: list[str] | list[tuple[str, str]]): + missing_dependencies: dict[str, str] = {} + for dependency in required_dependencies: + dependency = dependency if isinstance(dependency, tuple) else (dependency, dependency) + package_name, pip_name = dependency + if not _is_package_available(package_name): + missing_dependencies[package_name] = pip_name + if 
missing_dependencies: + _raise_error_for_missing_dependencies(step_name, missing_dependencies) + + +def _raise_error_for_missing_dependencies(step_name: str, dependencies: dict[str, str]) -> NoReturn: + """Helper to raise an ImportError for missing dependencies and prompt the user to install said dependencies + + Args: + step_name: str + The name of the step + dependencies: dict[str, str] + The missing dependencies + + """ + dependencies = dict(sorted(dependencies.items())) + package_names = list(dependencies) + if len(dependencies) > 1: + package_names = ( + f"{','.join('`' + package_name + '`' for package_name in package_names[:-1])} and `{package_names[-1]}`" + ) + else: + package_names = f"`{package_names[0]}`" + raise ImportError( + f"Please install {package_names} to use {step_name} (`pip install {' '.join(list(dependencies.values()))}`)." + ) + + @lru_cache def _is_package_available(package_name): """ @@ -44,3 +79,7 @@ def is_s3fs_available(): def is_moto_available(): return _is_package_available("moto") + + +def is_torch_available(): + return _is_package_available("torch") diff --git a/src/datatrove/utils/binaryio.py b/src/datatrove/utils/binaryio.py index 37ed1ccf..58b2b4da 100644 --- a/src/datatrove/utils/binaryio.py +++ b/src/datatrove/utils/binaryio.py @@ -3,6 +3,7 @@ from functools import cache from typing import BinaryIO +import numpy as np from fsspec.spec import AbstractBufferedFile @@ -29,6 +30,27 @@ def read_tuples_from_file(file: BinaryIO, *formats, lines_to_buffer: int = 5): yield from reader.iter_unpack(chunk) +def read_np_from_file( + file: BinaryIO, + dtype: np.dtype, + is_local_file: bool = False, +) -> np.ndarray: + """ + Utility which reads data from a file and returns a numpy array. + Args: + file: the file to read from + dtype: expected dtype of data + is_local_file: whether the file is a local file (enables optimizations) + Returns: + numpy array of data from the file + """ + with file: + if is_local_file: + return np.fromfile(file, dtype=dtype) + else: + return np.frombuffer(file.read(), dtype=dtype) + + def seek_to_start(f: AbstractBufferedFile, start_hash: int, line_format: str, hash_format: str): if start_hash == 0: return diff --git a/src/datatrove/utils/dataset.py b/src/datatrove/utils/dataset.py new file mode 100644 index 00000000..2e4abad7 --- /dev/null +++ b/src/datatrove/utils/dataset.py @@ -0,0 +1,139 @@ +from bisect import bisect + +import numpy as np +import torch +from fsspec import AbstractFileSystem +from fsspec.core import url_to_fs + +from datatrove.utils._import_utils import is_torch_available + + +if is_torch_available(): + from torch.utils.data import Dataset + + class DatatroveFileDataset(Dataset): + """Dataset for a single .ds file created by datatrove + We loop on the dataset if asking for an index larger than the dataset size + + Args: + file_path (str): path to file on s3, locally, or some other fsspec supported path + seq_len (int): sequence length + token_size (int): size of a single token, in bytes. 
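# --- Illustrative sketch (not part of this diff): how a custom pipeline step can rely on
# the new check_required_dependencies helper. "my custom step" and the dependency list
# are made up; tuples map import name -> pip package name when the two differ.
from datatrove.utils._import_utils import check_required_dependencies


class MyCustomStep:  # hypothetical step, for illustration only
    def __init__(self):
        # raises ImportError with a "pip install nltk indic-nlp-library" hint if anything is missing
        check_required_dependencies("my custom step", ["nltk", ("indicnlp", "indic-nlp-library")])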
Usually 2 for vocab sizes < 65k and 4 for larger + max_tokens (int): only read at most this number of tokens + """ + + def __init__( + self, + file_path: str, + seq_len: int, + token_size: int = 2, + max_tokens: int | None = None, + ): + self.file_path: str = file_path + self.seq_len = seq_len + self.token_size = token_size + + self.fs: AbstractFileSystem + self.fs, self.file_path = url_to_fs(file_path) + fsize = self.fs.size(self.file_path) + # total number of full contexts in this file + num_tokens = fsize // self.token_size + self._len = (min(max_tokens, num_tokens) if max_tokens else num_tokens) // (seq_len + 1) + self._f = None + + def __getitem__(self, item): + if not self._f: + self._f = self.fs.open(self.file_path, "rb") + chunk_size = self.token_size * (self.seq_len + 1) + self._f.seek(item * chunk_size) + return { + "input_ids": torch.as_tensor( + np.frombuffer(self._f.read(chunk_size), np.uint16 if self.token_size == 2 else np.uint32).astype( + np.int64 + ), + dtype=torch.long, + ) + } + + def __len__(self): + return self._len + + def __del__(self): + if self._f: + self._f.close() + + class DatatroveFolderDataset(Dataset): + """ + Dataset for a folder of .ds files + We loop on the dataset if asking for an index larger than the dataset size + + Args: + folder_path (str): path to folder on S3, locally, or some other fsspec supported path + seq_len (int): sequence length + filename_pattern (Union[Pattern, str], optional): filename pattern. Defaults to None. + recursive (bool, optional): search recursively. Defaults to True. + token_size (int): size of a single token, in bytes. Usually 2 for vocab sizes < 65k and 4 for larger + max_tokens (int): only read at most this number of tokens + shuffle (bool, optional): shuffle the files in the folder. Defaults to False. + seed (int, optional): seed for shuffling. Defaults to 42. 
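# --- Illustrative sketch (not part of this diff): loading tokenized output with the new
# DatatroveFolderDataset and a plain torch DataLoader. The folder path is a placeholder;
# token_size must match what the tokenization step wrote (2 bytes unless the vocab is >65k).
from torch.utils.data import DataLoader

from datatrove.utils.dataset import DatatroveFolderDataset

dataset = DatatroveFolderDataset(
    "s3://my-bucket/tokenized/",  # placeholder fsspec path (local paths work too)
    seq_len=2048,
    token_size=2,
    shuffle=True,
)
loader = DataLoader(dataset, batch_size=8)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([8, 2049]) -> seq_len + 1 tokens per sample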
+ """ + + def __init__( + self, + folder_path: str, + seq_len: int, + filename_pattern: str = None, + recursive: bool = True, + token_size: int = 2, + max_tokens: int | None = None, + shuffle: bool = False, + seed: int = 42, + ): + self.folder_path = folder_path + self.filename_pattern = filename_pattern + fs, folder_path = url_to_fs(folder_path) + matched_files = ( + fs.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + if not filename_pattern + else fs.glob(filename_pattern, maxdepth=1 if not recursive else None) + ) + if not matched_files: + raise FileNotFoundError(f'No files matching "{filename_pattern}" found in {folder_path}') + + self.files = [] + remaining_tokens = max_tokens + for path in matched_files: + file_data = DatatroveFileDataset( + fs.unstrip_protocol(path), + seq_len, + token_size=token_size, + max_tokens=remaining_tokens, + ) + self.files.append(file_data) + if remaining_tokens is not None: + remaining_tokens -= len(file_data) * (seq_len + 1) + if remaining_tokens <= 0: + break + + if shuffle: + rand = np.random.default_rng(seed) + ordering = rand.permutation(range(len(self.files))) + self.files = [self.files[i] for i in ordering] + + self.lens = np.cumsum([0] + [len(f) for f in self.files]).tolist() + + self.current_file = 0 + + def __getitem__(self, item): + # check if we are in the same file as before + if not (self.lens[self.current_file] <= item < self.lens[self.current_file + 1]): + # figure out current file + self.current_file = bisect(self.lens, item) - 1 + # subtract file starting offset + return self.files[self.current_file][item - self.lens[self.current_file]] + + def __len__(self): + return self.lens[-1] if self.lens else 0 +else: + DatatroveFileDataset = NotImplemented + DatatroveFolderDataset = NotImplemented diff --git a/src/datatrove/utils/hashes/sha1.py b/src/datatrove/utils/hashes/sha1.py new file mode 100644 index 00000000..52323c54 --- /dev/null +++ b/src/datatrove/utils/hashes/sha1.py @@ -0,0 +1,26 @@ +import hashlib +import struct + + +def sha1_hash32(data: str): + """A 32-bit hash function based on SHA1. + + Args: + data (bytes): the data to generate 32-bit integer hash from. + + Returns: + int: an integer hash value that can be encoded using 32 bits. 
+ """ + return struct.unpack(" Callable[[str], int]: + if config.hash_fc == "sha1": + return sha1_hash32 if config.precision == 32 else sha1_hash64 + elif config.hash_fc == "xxhash": + from datatrove.utils.hashes.xxhash import xxhash32, xxhash64 + + return xxhash32 if config.precision == 32 else xxhash64 + else: + raise ValueError(f"Unknown {config.hash_fc=}") diff --git a/src/datatrove/utils/lid.py b/src/datatrove/utils/lid.py new file mode 100644 index 00000000..7f166154 --- /dev/null +++ b/src/datatrove/utils/lid.py @@ -0,0 +1,57 @@ +from abc import abstractmethod + +from datatrove.data import Document +from datatrove.io import cached_asset_path_or_download + + +class LID: + def __init__(self, languages: list[str]) -> None: + self.languages = languages + + @abstractmethod + def predict(self, doc: Document) -> tuple[tuple[str, int], dict[str, float]]: + """ + Predicts the likelihood of the document being written in given languages, alongside with the most likely language + Args: + doc (Document): Document to predict languages for + Returns: + dict[str, float]: Languages and score + """ + raise NotImplementedError + + +class FastTextModel(LID): + LANGUAGE_ID_MODEL_URL = "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin" + + def __init__(self, languages: list[str], k: int = 1) -> None: + """ + Args: + languages (list[str]): Languages to predict + k (int, optional): Number of top-k languages to consider, all languages outside of k will be considered as being predicted with 0.0 + """ + super().__init__(languages) + self._model = None + self.k = k + + @property + def model(self): + if not self._model: + from fasttext.FastText import _FastText + + model_file = cached_asset_path_or_download( + self.LANGUAGE_ID_MODEL_URL, + namespace="filters", + subfolder="language_filter", + desc="fast-text language identifier model", + ) + self._model = _FastText(model_file) + return self._model + + def predict(self, doc: Document) -> tuple[tuple[str, int], dict[str, float]]: + langs, scores = self.model.predict(doc.text.replace("\n", " "), k=self.k) + lang_pairs = {lang.split("__")[2]: score for lang, score in zip(langs, scores)} + best_lang_pair = max(lang_pairs.items(), key=lambda x: x[1]) + return best_lang_pair, {lang: lang_pairs.get(lang, 0.0) for lang in self.languages} + + +# We don't support CLD3, not only it's worse than fasttext, but installation is really problematic, because of old version of protobuffers diff --git a/src/datatrove/utils/logging.py b/src/datatrove/utils/logging.py index 9a2b7855..f301b425 100644 --- a/src/datatrove/utils/logging.py +++ b/src/datatrove/utils/logging.py @@ -1,3 +1,4 @@ +import os import random import string import sys @@ -5,7 +6,14 @@ from loguru import logger -from datatrove.io import DataFolder + +def get_env_bool(name, default=None): + env_var = os.environ.get(name, None) + return default if env_var is None else (env_var.lower().strip() in ("yes", "true", "t", "1")) + + +DATATROVE_COLORIZE_LOGS = get_env_bool("DATATROVE_COLORIZE_LOGS") +DATATROVE_COLORIZE_LOG_FILES = get_env_bool("DATATROVE_COLORIZE_LOG_FILES", False) def get_timestamp() -> str: @@ -29,21 +37,25 @@ def get_random_str(length=5): return "".join(random.choice(string.ascii_lowercase) for _ in range(length)) -def add_task_logger(logging_dir: DataFolder, rank: int, local_rank: int = 0): +def add_task_logger( + logging_dir, + rank: int, + local_rank: int = 0, +): """ Sets up logging for a given task Args: - logging_dir: DataFolder: + logging_dir: DataFolder rank: int: local_rank: 
int: (Default value = 0) - Returns: """ logger.remove() logfile = logging_dir.open(f"logs/task_{rank:05d}.log", "wt", encoding="utf-8") - logger.add(sys.stderr, level="INFO" if local_rank == 0 else "ERROR") - logger.add(logfile, colorize=True, level="DEBUG") + logger.add(sys.stderr, colorize=DATATROVE_COLORIZE_LOGS, level="INFO" if local_rank == 0 else "ERROR") + logger.add(logfile, colorize=DATATROVE_COLORIZE_LOG_FILES, level="DEBUG") + logger.info(f"Launching pipeline for {rank=}") return logfile @@ -53,14 +65,17 @@ def close_task_logger(logfile): Close logfile and reset logging setup Args: logfile: - Returns: """ logger.complete() - logger.remove() + setup_default_logger() # re-add default logger logfile.close() - logger.add(sys.stderr) # re-add default logger + + +def setup_default_logger(): + logger.remove() + logger.add(sys.stderr, colorize=DATATROVE_COLORIZE_LOGS) def log_pipeline(pipeline): @@ -74,3 +89,7 @@ def log_pipeline(pipeline): """ steps = "\n".join([pipe.__repr__() if callable(pipe) else "Iterable" for pipe in pipeline]) logger.info(f"\n--- 🛠️ PIPELINE 🛠\n{steps}") + + +# set colorization based on env vars +setup_default_logger() diff --git a/src/datatrove/utils/perplexity.py b/src/datatrove/utils/perplexity.py new file mode 100644 index 00000000..8fc84dc0 --- /dev/null +++ b/src/datatrove/utils/perplexity.py @@ -0,0 +1,164 @@ +# This file includes code from edugp/kenlm by Eduardo Gonzalez Ponferrada, +# licensed under the MIT License. The original code can be found at https://huggingface.co/edugp/kenlm. + +import re +from pathlib import Path +from typing import Dict + +from huggingface_hub import hf_hub_url + +from datatrove.io import cached_asset_path_or_download +from datatrove.utils.text import TextNormConfig, simplify_text + + +MODEL_REPO = "edugp/kenlm" + + +class SentencePiece: + def __init__( + self, + model_dataset: str, + model_name: str, + ): + super().__init__() + self.model_name = model_name + self.model_dataset = model_dataset + self._model = None + + @property + def model(self): + import sentencepiece + + if self._model is None: + path = cached_asset_path_or_download( + hf_hub_url(MODEL_REPO, str(Path(self.model_dataset, f"{self.model_name}.sp.model"))) + ) + self._model = sentencepiece.SentencePieceProcessor() + self._model.load(path) + return self._model + + def tokenize(self, text: dict) -> dict: + tokenized = self.model.encode_as_pieces(text) + return " ".join(tokenized) + + +class KenlmModel: + digit_re: re.Pattern = re.compile(r"\d") + unicode_punct: Dict[str, str] = { + ",": ",", + "。": ".", + "、": ",", + "„": '"', + "”": '"', + "“": '"', + "«": '"', + "»": '"', + "1": '"', + "」": '"', + "「": '"', + "《": '"', + "》": '"', + "´": "'", + "∶": ":", + ":": ":", + "?": "?", + "!": "!", + "(": "(", + ")": ")", + ";": ";", + "–": "-", + "—": " - ", + ".": ". 
", + "~": "~", + "’": "'", + "…": "...", + "━": "-", + "〈": "<", + "〉": ">", + "【": "[", + "】": "]", + "%": "%", + "►": "-", + } + unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]") + non_printing_chars_re = re.compile(f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]") + + def __init__( + self, + model_dataset: str, + language: str, + ): + self.model_dataset = model_dataset + self.language = language + self._tokenizer = None + self._model = None + + @property + def model(self): + import kenlm + + if self._model is None: + model_path = Path(self.model_dataset, f"{self.language}.arpa.bin") + path = cached_asset_path_or_download(hf_hub_url(MODEL_REPO, str(model_path))) + self._model = kenlm.Model(path) + return self._model + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = SentencePiece(self.model_dataset, self.language) + return self._tokenizer + + @classmethod + def from_pretrained( + cls, + model_dataset: str, + language: str, + ): + return cls( + model_dataset, + language, + ) + + def pp(self, log_score, length): + return 10.0 ** (-log_score / length) + + def get_perplexity(self, doc: str, normalize_cc_net: bool = True): + if normalize_cc_net: + doc = self.normalize( + doc, + ) + # Tokenize (after normalizing): See https://github.com/facebookresearch/cc_net/blob/bda555bd1cf1ee2e0b925363e62a61cd46c8b60d/cc_net/mine.py#L352 for full pipeline + doc = self.tokenizer.tokenize(doc) + doc_log_score, doc_length = 0, 0 + for line in doc.split("\n"): + log_score = self.model.score(line) + length = len(line.split()) + 1 + doc_log_score += log_score + doc_length += length + return round(self.pp(doc_log_score, doc_length), 1) + + def normalize( + self, + text: str, + ) -> str: + text = simplify_text( + text, + config=TextNormConfig( + lowercase=True, + norm_numbers=True, + norm_whitespace=False, + remove_punctuation=False, + norm_unicode_diacritics=True, + ), + ) + # TODO: integrate these options to simplify_text + text = self.replace_unicode_punct(text) + text = self.remove_non_printing_char(text) + return text + + def replace_unicode_punct(self, text: str) -> str: + return "".join(self.unicode_punct.get(c, c) for c in text) + + def remove_non_printing_char(self, text: str) -> str: + return self.non_printing_chars_re.sub("", text) diff --git a/src/datatrove/utils/stats.py b/src/datatrove/utils/stats.py index d222e6ed..6a4b96ce 100644 --- a/src/datatrove/utils/stats.py +++ b/src/datatrove/utils/stats.py @@ -45,7 +45,11 @@ def __repr__(self): return ", ".join(f"{key}: {stats}" for key, stats in self.items()) def to_dict(self): - return {a: b.to_dict() for a, b in self.items()} + return {a: (b.to_dict() if hasattr(b, "to_dict") else b) for a, b in self.items()} + + @classmethod + def from_dict(cls, data): + return MetricStatsDict(init={a: MetricStats.from_dict(b) for a, b in data.items()}) class Stats: @@ -304,7 +308,9 @@ def from_dict(cls, data): if isinstance(data, dict): total = data.get("total") mean = data.get("mean", 1) - n = data.get("n", total if mean != 1 else 1) + # We save n if we it has been added 1+ times and we didn't add just 1 -> mean == 1 + # This means that if mean == 1 and we don't have n, the n must be total, otherwise 1 + n = data.get("n", total if mean == 1 else 1) return cls( total=total, n=n, diff --git a/src/datatrove/utils/text.py b/src/datatrove/utils/text.py index a3759712..a2d713bc 100644 --- a/src/datatrove/utils/text.py +++ b/src/datatrove/utils/text.py @@ -1,14 +1,21 @@ -import hashlib import re -import 
struct import unicodedata from dataclasses import dataclass +from itertools import tee +from typing import Iterable + +from datatrove.utils.typeshelper import Languages +from datatrove.utils.word_tokenizers import load_word_tokenizer PUNCTUATION = "!/—”:%1〈&(、━\\【#%「」,】;+^]~“《„';’{|∶´[=-`*.(–?!:$~«〉,><》)?)。…@_.\"}►»" + "".join( - map(chr, (x for a, b in ((0, 9), (11, 13), (13, 32), (127, 160)) for x in range(a, b))) + map( + chr, + (x for a, b in ((0, 9), (11, 13), (13, 32), (127, 160)) for x in range(a, b)), + ) ) PUNCTUATION_SET = set(PUNCTUATION) +PUNCTUATION_TRANS = str.maketrans(PUNCTUATION, " " * len(PUNCTUATION)) @dataclass @@ -23,7 +30,7 @@ class TextNormConfig: DEF_TEXT_NORM_CONFIG = TextNormConfig() -NUMBERS_PATTERN = re.compile(r"\d+") +NUMBERS_PATTERN = re.compile(r"\d+(\.\d+)?") WHITESPACE_PATTERN = re.compile(r"\s+") # WARNING: english specific WEEKDAYS_PATTERN = re.compile(r"monday|tuesday|wednesday|thursday|friday|saturday|sunday") @@ -32,6 +39,9 @@ class TextNormConfig: def simplify_text(text: str, config=DEF_TEXT_NORM_CONFIG) -> str: """Performs the following operations to increase recall when looking for matches between documents: + - number normalization + - weekday normalization + - month normalization - lowercase text - replace all whitespace with a single " " - remove all punctuation @@ -44,47 +54,73 @@ def simplify_text(text: str, config=DEF_TEXT_NORM_CONFIG) -> str: Returns: modified text """ + # We should apply the transformation in such order so that, we do same transformations + # incrementaly as we would do if we applied each from scratch. + # Eg. + # 1|2|3 -> 000 + # vs + # 1|2|3 -> 0 + # lower case if config.lowercase: text = text.lower() - # remove consecutive spaces, newlines, tabs in the middle and in the beginning / end - if config.norm_whitespace: - text = WHITESPACE_PATTERN.sub(" ", text.strip()) - # remove punctuation - if config.remove_punctuation: - text = text.translate(str.maketrans("", "", PUNCTUATION)) - # diacritics/unicode normalization - if config.norm_unicode_diacritics: - text = "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn") if config.norm_numbers: text = NUMBERS_PATTERN.sub("0", text) if config.norm_weekdays: text = WEEKDAYS_PATTERN.sub("WEEKDAY", text) if config.norm_monthnames: text = MONTHS_PATTERN.sub("MONTH", text) - return text.strip() - - -# https://github.com/ekzhu/datasketch/blob/master/datasketch/hashfunc.py -def sha1_hash32(data): - """A 32-bit hash function based on SHA1. - - Args: - data (bytes): the data to generate 32-bit integer hash from. - Returns: - int: an integer hash value that can be encoded using 32 bits. 
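# --- Illustrative sketch (not part of this diff): scoring a document with the KenlmModel
# helper added in utils/perplexity.py. Assumes the "wikipedia" models are available in the
# edugp/kenlm Hub repo and that the optional `kenlm` and `sentencepiece` packages are installed.
from datatrove.utils.perplexity import KenlmModel

model = KenlmModel.from_pretrained("wikipedia", "en")
score = model.get_perplexity("The quick brown fox jumps over the lazy dog.")
print(score)  # lower perplexity ~ closer to the reference corpus distribution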
- """ - return struct.unpack(" int: + if not self._token_size: + self._token_size = 4 if self.tokenizer.get_vocab_size() > np.iinfo(np.uint16).max + 1 else 2 + return self._token_size + + @property + def token_format(self) -> str: + return "I" if self.token_size == 4 else "H" @property def tokenizer(self) -> "Tokenizer": if not self._tokenizer: if not self.tokenizer_name_or_path: raise ValueError("self.tokenizer_name_or_path needs to be set!") - self._tokenizer: "Tokenizer" = load_tokenizer(self.tokenizer_name_or_path) + self._tokenizer = load_tokenizer(self.tokenizer_name_or_path) if self._post_processor: self._tokenizer.post_processor = self._post_processor elif self.eos_token: @@ -44,3 +58,23 @@ def tokenizer(self) -> "Tokenizer": pair=None, ) return self._tokenizer + + +def batched(iterable, n): + """In python 3.12+ we could use itertools.batched instead + + One difference with itertools.batched: we return a list instead of a tuple + + Args: + iterable: + n: + + Returns: + + """ + # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + it = iter(iterable) + while batch := list(itertools.islice(it, n)): + yield batch diff --git a/src/datatrove/utils/typeshelper.py b/src/datatrove/utils/typeshelper.py index 52cf012c..5ea348b8 100644 --- a/src/datatrove/utils/typeshelper.py +++ b/src/datatrove/utils/typeshelper.py @@ -19,10 +19,103 @@ class Languages: portuguese = "pt" italian = "it" french = "fr" - swedish = "sv" romanian = "ro" german = "de" latin = "la" + czech = "cs" + danish = "da" + finnish = "fi" + greek = "el" + norwegian = "no" + polish = "pl" + russian = "ru" + slovenian = "sl" + swedish = "sv" + turkish = "tr" + dutch = "nl" + chinese = "zh" + japanese = "ja" + vietnamese = "vi" + indonesian = "id" + persian = "fa" + korean = "ko" + arabic = "ar" + thai = "th" + hindi = "hi" + bengali = "bn" + tamil = "ta" + hungarian = "hu" + ukrainian = "uk" + slovak = "sk" + bulgarian = "bg" + catalan = "ca" + croatian = "hr" + serbian = "sr" + lithuanian = "lt" + estonian = "et" + hebrew = "he" + latvian = "lv" + serbocroatian = "sh" # Deprecated + albanian = "sq" + azerbaijani = "az" + icelandic = "is" + macedonian = "mk" + georgian = "ka" + galician = "gl" + armenian = "hy" + basque = "eu" + swahili = "sw" + malay = "ms" + tagalog = "tl" + javanese = "jv" + punjabi = "pa" + bihari = "bh" # Deprecated + gujarati = "gu" + yoruba = "yo" + marathi = "mr" + urdu = "ur" + amharic = "am" + telugu = "te" + malayalam = "ml" + kannada = "kn" + nepali = "ne" + kazakh = "kk" + belarusian = "be" + burmese = "my" + esperanto = "eo" + uzbek = "uz" + khmer = "km" + tajik = "tg" + welsh = "cy" + norwegian_nynorsk = "nn" + bosnian = "bs" + sinhala = "si" + tatar = "tt" + afrikaans = "af" + oriya = "or" + kirghiz = "ky" + irish = "ga" + occitan = "oc" + kurdish = "ku" + lao = "lo" + luxembourgish = "lb" + bashkir = "ba" + western_frisian = "fy" + pashto = "ps" + maltese = "mt" + breton = "bt" + assamese = "as" + malagasy = "mg" + divehi = "dv" + yiddish = "yi" + somali = "so" + sanskrit = "sa" + sindhi = "sd" + turkmen = "tk" + south_azerbaijani = "azb" + sorani = "ckb" + cebuano = "ceb" + war = "war" class StatHints: diff --git a/src/datatrove/utils/word_tokenizers.py b/src/datatrove/utils/word_tokenizers.py new file mode 100644 index 00000000..309803f9 --- /dev/null +++ b/src/datatrove/utils/word_tokenizers.py @@ -0,0 +1,329 @@ +from abc import ABC, abstractmethod +from typing import Callable, Iterator + +from datatrove.utils._import_utils import 
check_required_dependencies +from datatrove.utils.typeshelper import Languages + + +def strip_strings(els: list[str]) -> list[str]: + return [el.strip() for el in els if len(el.strip()) > 0] + + +def simple_span_tokenize(text: str, sents: list[str]) -> Iterator[tuple[int, int]]: + start_index = 0 + for sent in sents: + start_char = text.index(sent, start_index) + end_char = start_char + len(sent) + start_index = end_char + yield start_char, end_char + + +class WordTokenizer(ABC): + @abstractmethod + def word_tokenize(self, text: str) -> list[str]: + pass + + @abstractmethod + def sent_tokenize(self, text: str) -> list[str]: + pass + + @abstractmethod + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + pass + + +class NLTKTokenizer(WordTokenizer): + def __init__(self, punkt_language: str): + super().__init__() + check_required_dependencies(f"{punkt_language} word tokenizer", ["nltk"]) + self.punkt_language = punkt_language + self._tokenizer = None + + @property + def tokenizer(self): + if not self._tokenizer: + from nltk import load + + self._tokenizer = load(f"tokenizers/punkt/{self.punkt_language}.pickle") + return self._tokenizer + + def word_tokenize(self, text) -> list[str]: + from nltk.tokenize import word_tokenize + + tokens = word_tokenize(text, language=self.punkt_language) + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + from nltk.tokenize import sent_tokenize + + sents = sent_tokenize(text, language=self.punkt_language) + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + return list(self.tokenizer.span_tokenize(text)) + + +class SpaCyTokenizer(WordTokenizer): + def __init__(self, spacy_language: str, config=None): + super().__init__() + check_required_dependencies(f"{spacy_language} word tokenizer", ["spacy"]) + if spacy_language == "vi": + check_required_dependencies(f"{spacy_language} word tokenizer", ["pyvi"]) + elif spacy_language == "zh": + check_required_dependencies(f"{spacy_language} word tokenizer", ["jieba"]) + self.spacy_language = spacy_language + self.config = config + self._tokenizer = None + + @property + def tokenizer(self): + if not self._tokenizer: + import spacy + + if self.config is None: + self._tokenizer = spacy.blank(self.spacy_language) + else: + self._tokenizer = spacy.blank(self.spacy_language, config=self.config) + self._tokenizer.add_pipe("sentencizer") + return self._tokenizer + + def word_tokenize(self, text: str) -> list[str]: + self.tokenizer.max_length = len(text) + 10 + tokens = [token.text for token in self.tokenizer(text, disable=["parser", "tagger", "ner"])] + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + self.tokenizer.max_length = len(text) + 10 + sents = [sent.text for sent in self.tokenizer(text, disable=["parser", "tagger", "ner"]).sents] + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + return [ + (sent.start_char, sent.end_char) + for sent in self.tokenizer(text, disable=["parser", "tagger", "ner"]).sents + ] + + +class StanzaTokenizer(WordTokenizer): + def __init__(self, stanza_language: str, **stanza_kwargs): + super().__init__() + check_required_dependencies(f"{stanza_language} word tokenizer", ["stanza"]) + self.stanza_language = stanza_language + self.stanza_kwargs = stanza_kwargs + self._tokenizer = None + + @property + def tokenizer(self): + if not self._tokenizer: + import stanza + from stanza.pipeline.core import DownloadMethod + + self._tokenizer = 
stanza.Pipeline( + self.stanza_language, + processors="tokenize", + download_method=DownloadMethod.REUSE_RESOURCES, + **self.stanza_kwargs, + ) + + return self._tokenizer + + def word_tokenize(self, text: str) -> list[str]: + doc = self.tokenizer(text) + tokens = [token.text for sentence in doc.sentences for token in sentence.tokens] + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + doc = self.tokenizer(text) + sents = [sentence.text for sentence in doc.sentences] + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + doc = self.tokenizer(text) + return [(sent.tokens[0].start_char, sent.tokens[-1].end_char) for sent in doc.sentences] + + +class ThaiTokenizer(WordTokenizer): + def __init__(self): + super().__init__() + check_required_dependencies("th word tokenizer", ["pythainlp"]) + + def word_tokenize(self, text: str) -> list[str]: + from pythainlp.tokenize import word_tokenize as th_word_tokenize + + tokens = th_word_tokenize(text, keep_whitespace=False, engine="newmm-safe") + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + from pythainlp.tokenize import sent_tokenize as th_sent_tokenize + + sents = th_sent_tokenize(text) + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + sents = self.sent_tokenize(text) + return list(simple_span_tokenize(text, sents)) + + +class IndicNLPTokenizer(WordTokenizer): + def __init__(self, language: str): + super().__init__() + self.language = language + check_required_dependencies(f"{language} word tokenizer", [("indicnlp", "indic-nlp-library")]) + + def word_tokenize(self, text) -> list[str]: + from indicnlp.tokenize.indic_tokenize import trivial_tokenize as indicnlp_trivial_tokenize + + tokens = indicnlp_trivial_tokenize(text, self.language) + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + from indicnlp.tokenize.sentence_tokenize import sentence_split + + sents = sentence_split(text, lang=self.language) + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + sents = self.sent_tokenize(text) + return list(simple_span_tokenize(text, sents)) + + +class KiwiTokenizer(WordTokenizer): + def __init__(self, model_type="sbg"): + super().__init__() + check_required_dependencies("ko word tokenizer", ["kiwipiepy"]) + self.model_type = model_type + self._tokenizer = None + + @property + def tokenizer(self): + if not self._tokenizer: + from kiwipiepy import Kiwi + + self._tokenizer = Kiwi(model_type=self.model_type) + return self._tokenizer + + def word_tokenize(self, text: str) -> list[str]: + tokens = [token.form for token in self.tokenizer.tokenize(text)] + return strip_strings(tokens) + + def sent_tokenize(self, text: str) -> list[str]: + sents = [sent.text for sent in self.tokenizer.split_into_sents(text)] + return strip_strings(sents) + + def span_tokenize(self, text: str) -> list[tuple[int, int]]: + return [(sent.start, sent.end) for sent in self.tokenizer.split_into_sents(text)] + + +# If you know a better tokenizer or better proxy language, please submit a PR +WORD_TOKENIZER_FACTORY: dict[str, Callable[[], WordTokenizer]] = { + Languages.english: lambda: NLTKTokenizer("english"), + Languages.korean: lambda: KiwiTokenizer(), + Languages.german: lambda: NLTKTokenizer("german"), + Languages.french: lambda: NLTKTokenizer("french"), + Languages.czech: lambda: NLTKTokenizer("czech"), + Languages.danish: lambda: 
NLTKTokenizer("danish"), + Languages.dutch: lambda: NLTKTokenizer("dutch"), + Languages.estonian: lambda: NLTKTokenizer("estonian"), + Languages.finnish: lambda: NLTKTokenizer("finnish"), + Languages.greek: lambda: NLTKTokenizer("greek"), + Languages.italian: lambda: NLTKTokenizer("italian"), + Languages.malayalam: lambda: NLTKTokenizer("malayalam"), + Languages.norwegian: lambda: NLTKTokenizer("norwegian"), + Languages.polish: lambda: NLTKTokenizer("polish"), + Languages.portuguese: lambda: NLTKTokenizer("portuguese"), + Languages.russian: lambda: NLTKTokenizer("russian"), + Languages.slovenian: lambda: NLTKTokenizer("slovene"), + Languages.spanish: lambda: NLTKTokenizer("spanish"), + Languages.swedish: lambda: NLTKTokenizer("swedish"), + Languages.turkish: lambda: NLTKTokenizer("turkish"), + Languages.chinese: lambda: SpaCyTokenizer("zh", {"nlp": {"tokenizer": {"segmenter": "jieba"}}}), + Languages.japanese: lambda: StanzaTokenizer("ja"), + Languages.vietnamese: lambda: SpaCyTokenizer("vi"), + Languages.indonesian: lambda: SpaCyTokenizer("id"), + Languages.persian: lambda: SpaCyTokenizer("fa"), + Languages.arabic: lambda: SpaCyTokenizer("ar"), + Languages.hindi: lambda: SpaCyTokenizer("hi"), + Languages.tamil: lambda: SpaCyTokenizer("ta"), + Languages.urdu: lambda: SpaCyTokenizer("ur"), + Languages.marathi: lambda: SpaCyTokenizer("mr"), + Languages.telugu: lambda: SpaCyTokenizer("te"), + Languages.hungarian: lambda: SpaCyTokenizer("hu"), + Languages.romanian: lambda: SpaCyTokenizer("ro"), + Languages.ukrainian: lambda: SpaCyTokenizer("uk"), + Languages.slovak: lambda: SpaCyTokenizer("sk"), + Languages.bulgarian: lambda: SpaCyTokenizer("bg"), + Languages.catalan: lambda: SpaCyTokenizer("ca"), + Languages.croatian: lambda: SpaCyTokenizer("hr"), + Languages.latin: lambda: SpaCyTokenizer("la"), + Languages.serbian: lambda: SpaCyTokenizer("sr"), + Languages.lithuanian: lambda: SpaCyTokenizer("lt"), + Languages.hebrew: lambda: SpaCyTokenizer("he"), + Languages.latvian: lambda: SpaCyTokenizer("lv"), + Languages.icelandic: lambda: SpaCyTokenizer("is"), + Languages.armenian: lambda: SpaCyTokenizer("hy"), + Languages.basque: lambda: SpaCyTokenizer("eu"), + Languages.thai: lambda: ThaiTokenizer(), + Languages.tagalog: lambda: SpaCyTokenizer("tl"), + Languages.albanian: lambda: SpaCyTokenizer("sq"), + Languages.macedonian: lambda: SpaCyTokenizer("mk"), + Languages.azerbaijani: lambda: SpaCyTokenizer("az"), + Languages.amharic: lambda: SpaCyTokenizer("am"), + Languages.bengali: lambda: SpaCyTokenizer("bn"), + Languages.malay: lambda: SpaCyTokenizer("ms"), + Languages.urdu: lambda: SpaCyTokenizer("ur"), + Languages.nepali: lambda: SpaCyTokenizer("ne"), + Languages.kazakh: lambda: StanzaTokenizer("kk"), + Languages.gujarati: lambda: SpaCyTokenizer("gu"), + Languages.kannada: lambda: SpaCyTokenizer("kn"), + Languages.welsh: lambda: StanzaTokenizer("cy"), + Languages.norwegian_nynorsk: lambda: NLTKTokenizer( + "norwegian" + ), # TODO: change to SpaCyTokenizer("nn") when spacy version>=3.7.4 + Languages.sinhala: lambda: SpaCyTokenizer("si"), + Languages.tatar: lambda: SpaCyTokenizer("tt"), + Languages.afrikaans: lambda: SpaCyTokenizer("af"), + Languages.kirghiz: lambda: SpaCyTokenizer("ky"), + Languages.irish: lambda: SpaCyTokenizer("ga"), + Languages.luxembourgish: lambda: SpaCyTokenizer("lb"), + Languages.maltese: lambda: StanzaTokenizer("mt"), + Languages.sanskrit: lambda: SpaCyTokenizer("sa"), + Languages.yoruba: lambda: SpaCyTokenizer("yo"), + Languages.serbocroatian: lambda: SpaCyTokenizer("sr"), + 
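# --- Illustrative aside (not part of this diff): using the tokenizer registry defined in
# word_tokenizers.py directly. English resolves to the NLTK punkt tokenizer, so `nltk`
# with the punkt data must be installed; tokenizers are lazily built and cached per language.
from datatrove.utils.typeshelper import Languages
from datatrove.utils.word_tokenizers import load_word_tokenizer

tokenizer = load_word_tokenizer(Languages.english)
text = "Datatrove processes data at scale. It now ships per-language word tokenizers."
print(tokenizer.sent_tokenize(text))      # two sentences
print(tokenizer.word_tokenize(text)[:5])  # first few word tokens
print(tokenizer.span_tokenize(text))      # [(start_char, end_char), ...] per sentence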
Languages.oriya: lambda: IndicNLPTokenizer("or"), + Languages.punjabi: lambda: IndicNLPTokenizer("sa"), + Languages.assamese: lambda: IndicNLPTokenizer("as"), + Languages.war: lambda: IndicNLPTokenizer("war"), + Languages.sindhi: lambda: IndicNLPTokenizer("sd"), + Languages.bosnian: lambda: SpaCyTokenizer("hr"), # Proxy + Languages.belarusian: lambda: SpaCyTokenizer("uk"), # Proxy + Languages.galician: lambda: NLTKTokenizer("portuguese"), # Proxy + Languages.esperanto: lambda: NLTKTokenizer("english"), # Proxy + Languages.occitan: lambda: SpaCyTokenizer("ca"), # Proxy + Languages.cebuano: lambda: NLTKTokenizer("english"), # Proxy + Languages.swahili: lambda: NLTKTokenizer("english"), # Proxy + Languages.javanese: lambda: NLTKTokenizer("english"), # Proxy + Languages.uzbek: lambda: NLTKTokenizer("turkish"), # Proxy, alternative ru + Languages.tajik: lambda: SpaCyTokenizer("ru"), # Proxy + Languages.kurdish: lambda: NLTKTokenizer("english"), # Proxy, multiple scripts! + Languages.sorani: lambda: SpaCyTokenizer("fa"), # Proxy + Languages.south_azerbaijani: lambda: SpaCyTokenizer("fa"), # Proxy + Languages.bashkir: lambda: SpaCyTokenizer("tt"), # Proxy + Languages.western_frisian: lambda: NLTKTokenizer("dutch"), # Proxy + Languages.breton: lambda: StanzaTokenizer("cy"), # Proxy + Languages.malagasy: lambda: NLTKTokenizer("english"), # Proxy + Languages.yiddish: lambda: SpaCyTokenizer("he"), # Proxy + Languages.somali: lambda: NLTKTokenizer("english"), # Proxy + Languages.turkmen: lambda: NLTKTokenizer("turkish"), # Proxy + Languages.pashto: lambda: SpaCyTokenizer("xx"), # Proxy +} + +WORD_TOKENIZER_CACHE: dict[str, WordTokenizer] = {} + + +def load_word_tokenizer(language: str) -> WordTokenizer: + if language not in WORD_TOKENIZER_CACHE: + if language not in WORD_TOKENIZER_FACTORY: + raise ValueError(f"Language '{language}' doesn't have a tokenizer.") + tokenizer = WORD_TOKENIZER_FACTORY[language]() + WORD_TOKENIZER_CACHE[language] = tokenizer + return WORD_TOKENIZER_CACHE[language] diff --git a/tests/pipeline/test_adapter_reader.py b/tests/pipeline/test_adapter_reader.py new file mode 100644 index 00000000..409e97ea --- /dev/null +++ b/tests/pipeline/test_adapter_reader.py @@ -0,0 +1,25 @@ +import unittest + +from datatrove.pipeline.readers import HuggingFaceDatasetReader + +from ..utils import require_datasets + + +@require_datasets +class TestAdapterReader(unittest.TestCase): + def test_adapter_reader(self): + def custom_adapter(self, data, path, id_in_file): + return { + "text": data[self.text_key] + "\n" + data["best_answer"], # Example usage of self to access text_key + "id": data.pop(self.id_key, f"{path}/{id_in_file}"), + } + + reader = HuggingFaceDatasetReader( + "truthful_qa", + dataset_options={"name": "generation", "split": "validation"}, + text_key="question", + id_key="", + adapter=custom_adapter, + ) + data = list(reader()) + assert len(data[0].text) == 104 diff --git a/tests/pipeline/test_bloom_filter.py b/tests/pipeline/test_bloom_filter.py index 4b603a61..aca0b24c 100644 --- a/tests/pipeline/test_bloom_filter.py +++ b/tests/pipeline/test_bloom_filter.py @@ -3,7 +3,8 @@ import unittest from datatrove.data import Document -from datatrove.pipeline.dedup.bloom_filter import SingleBloomFilter +from datatrove.pipeline.dedup.bloom_filter import BloomFilterConfig, SingleBloomFilter +from tests.utils import use_hash_configs TEXT_0 = ( @@ -84,14 +85,18 @@ TARGETS = [True] * 8 + [False] * 3 -class SentenceDedup(unittest.TestCase): +class BloomFilter(unittest.TestCase): def setUp(self): 
# Create a temporary directory self.tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp_dir) - def test_sd(self): - bloom_filter = SingleBloomFilter(output_folder=self.tmp_dir, m_bytes=2**10 - 1, k=7, expected_elements=866) + @use_hash_configs(precision=[32]) + def test_sd(self, hash_config): + bloom_filter = SingleBloomFilter( + output_folder=self.tmp_dir, + config=BloomFilterConfig(m_bytes=2**10 - 1, k=7, expected_elements=866, hash_config=hash_config), + ) for doc_idx, doc in enumerate(DOCS): is_unique = bloom_filter.step(doc) diff --git a/tests/pipeline/test_filters.py b/tests/pipeline/test_filters.py index ebbfae27..a5bd5174 100644 --- a/tests/pipeline/test_filters.py +++ b/tests/pipeline/test_filters.py @@ -84,10 +84,21 @@ def test_lambda(self): def test_language(self): language_filter = LanguageFilter(languages=("en", "it")) - self.assertTrue(language_filter.filter(Document(text=TEXT_LF_1, id="0"))) - self.assertFalse(language_filter.filter(Document(text=TEXT_LF_2, id="0"))) - self.assertFalse(language_filter.filter(Document(text=TEXT_LF_3, id="0"))) - self.assertTrue(language_filter.filter(Document(text=TEXT_LF_4, id="0"))) + doc1 = Document(text=TEXT_LF_1, id="0") + self.assertTrue(language_filter.filter(doc1)) + self.assertEqual(doc1.metadata["language"], "en") + + doc2 = Document(text=TEXT_LF_2, id="0") + self.assertFalse(language_filter.filter(doc2)) + self.assertEqual(doc2.metadata["language"], "fr") + + doc3 = Document(text=TEXT_LF_3, id="0") + self.assertFalse(language_filter.filter(doc3)) + self.assertEqual(doc3.metadata["language"], "pt") + + doc4 = Document(text=TEXT_LF_4, id="0") + self.assertTrue(language_filter.filter(doc4)) + self.assertEqual(doc4.metadata["language"], "it") def test_regex(self): regex_filter = RegexFilter(regex_exp=r"(?i)copyright") diff --git a/tests/pipeline/test_hf_reader.py b/tests/pipeline/test_hf_reader.py index 2c4267ba..5a02160c 100644 --- a/tests/pipeline/test_hf_reader.py +++ b/tests/pipeline/test_hf_reader.py @@ -12,4 +12,29 @@ def test_read_dataset(self): "truthful_qa", dataset_options={"name": "generation", "split": "validation"}, text_key="question" ) data = list(reader()) - assert len(data) == 817 + self.assertEqual(len(data), 817) + + def test_read_streaming_dataset(self): + reader = HuggingFaceDatasetReader( + "truthful_qa", + dataset_options={"name": "generation", "split": "validation"}, + text_key="question", + streaming=True, + ) + data = list(reader()) + self.assertEqual(len(data), 817) + + def test_sharding(self): + for shards in [1, 3]: + for streaming in [True, False]: + reader = HuggingFaceDatasetReader( + "huggingface/datatrove-tests", + dataset_options={"name": f"sharding-{shards}", "split": "train"}, + text_key="text", + streaming=streaming, + ) + data0 = list(reader(rank=0, world_size=2)) + data1 = list(reader(rank=1, world_size=2)) + + self.assertEqual(len(data0), 3) + self.assertEqual(len(data1), 2) diff --git a/tests/pipeline/test_minhash.py b/tests/pipeline/test_minhash.py index 3c67ebed..28bf7865 100644 --- a/tests/pipeline/test_minhash.py +++ b/tests/pipeline/test_minhash.py @@ -20,7 +20,7 @@ read_sigs, ) -from ..utils import require_nltk +from ..utils import require_nltk, require_xxhash, use_hash_configs lorem_ipsum = """Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aliquam euismod vel ante vitae rhoncus. Curabitur eu lectus et magna maximus facilisis eu non magna. Maecenas sed velit vitae est ornare placerat. Vestibulum quis consectetur nunc, a feugiat lorem. 
Cras in ipsum fringilla, vestibulum urna sit amet, viverra tortor. Orci varius natoque penatibus et magnis dis parturient montes, nascetur ridiculus mus. Morbi euismod vestibulum elit id placerat. Fusce malesuada ultricies condimentum. Cras tincidunt eget lorem nec hendrerit. Aenean mattis arcu dolor, id semper velit ullamcorper malesuada. Aliquam non ipsum et eros venenatis aliquet. Proin eleifend interdum scelerisque. Interdum et malesuada fames ac ante ipsum primis in faucibus. Mauris nunc sapien, molestie eget convallis at, maximus nec ipsum. Morbi quam diam, blandit ut mollis at, varius eu tellus. Maecenas sem justo, porttitor at odio nec, interdum posuere ex. @@ -35,220 +35,215 @@ @require_nltk +@require_xxhash class TestMinhash(unittest.TestCase): def setUp(self): # Create a temporary directory self.tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tmp_dir) - def test_signatures(self): - for use_64bit_hashes in (True, False): - config = MinhashConfig(use_64bit_hashes=use_64bit_hashes) - minhash = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures1"), config=config) - shingles = minhash.get_shingles(lorem_ipsum) - sig = minhash.get_signature(shingles) - minhash2 = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures2"), config=config) - # check consistency - assert sig == minhash2.get_signature(shingles) - - # check correct number of outputs - assert len(sig) == minhash.config.num_buckets - assert all((len(x) == minhash.config.hashes_per_bucket for x in sig)) - - # check similarity approximation - for pctd in range(0, 100, 5): - dec = pctd / 100 - endp = floor(len(lorem_ipsum) * dec) - textd = lorem_ipsum[:endp] + lorem_ipsum[len(lorem_ipsum) - 1 : endp : -1] - sigd = minhash.get_signature(minhash.get_shingles(textd)) - simil = ( - sum([1 if a == b else 0 for ba, bb in zip(sig, sigd) for a, b in zip(ba, bb)]) / minhash.num_hashes - ) - assert dec - 0.21 < simil < dec + 0.21 - - # check output file format and order - samples = [Document(f"sample {i}, {lorem_ipsum[i:: 10]}", id="test") for i in range(100)] - minhash(samples) - for bi in range(config.num_buckets): - with minhash.output_folder.open(f"bucket_{bi:03d}/00000.minhash.sig", "rb") as f: - prev = None - doc_ids = set() - S = np.dtype(config.hash_dtype).itemsize - for di in range(100): - data = struct.unpack( - f"<%s{config.hash_format}" % config.hashes_per_bucket, f.read(config.hashes_per_bucket * S) - ) - doc_id = struct.unpack("= prev - prev = data - assert 0 <= doc_id < 100 - doc_ids.add(doc_id) - assert len(doc_ids) == 100 - - def test_buckets_and_cluster(self): - for use_64bit_hashes in (True, False): - sigs_folder = os.path.join(self.tmp_dir, "b_signatures") - buckets_folder = os.path.join(self.tmp_dir, "b_buckets") - clusters_folder = os.path.join(self.tmp_dir, "b_clusters") - config = MinhashConfig(use_64bit_hashes=use_64bit_hashes) - - signatures_block = MinhashDedupSignature(output_folder=sigs_folder, config=config) - buckets_block = MinhashDedupBuckets( - input_folder=sigs_folder, - output_folder=buckets_folder, - config=config, - ) - - clusters = [[0, 20, 50], [400, 420], [800, 810, 820, 840, 860], [1205, 1215, 1225, 1245], [1600], [2000]] - - cluster_samples = [ - Document(text=lorem_ipsum[x : x + 300], id=f"{ci}_{xi}", metadata={"ci": ci, "xi": xi}) - for ci, cluster in enumerate(clusters) - for xi, x in enumerate(cluster) - ] - - signatures_block(cluster_samples) - # test file read - for fi, file in enumerate(buckets_block.input_folder.list_files()): - 
last = None - for sig in read_sigs(buckets_block.input_folder.open(file, "rb"), fi, config): - assert 0 <= sig.doc_id < 100 - assert last is None or sig.sig >= last - assert len(sig.sig) == config.hashes_per_bucket - last = sig.sig - - # test duplicate pairs - for b in range(config.num_buckets * 10): - buckets_block(None, rank=b, world_size=config.num_buckets * 10) - bucket_results_folder = get_datafolder(buckets_folder) - dup_files = bucket_results_folder.list_files(glob_pattern="*.dups") - pairs = defaultdict(set) - for dup_file in dup_files: - with bucket_results_folder.open(dup_file, "rb") as df: - while data := df.read(4 * struct.calcsize("I")): - f1, d1, f2, d2 = struct.unpack("<4I", data) - assert f1 == f2 == 0 - assert cluster_samples[d1].metadata["ci"] == cluster_samples[d2].metadata["ci"] - pairs[d1].add(d2) - pairs[d2].add(d1) - doc_id = 0 - for cluster in clusters: - print(cluster) - print(pairs) - for a in range(doc_id, doc_id + len(cluster)): - assert len(cluster) < 2 or any( - a in pairs[b] for b in range(doc_id, doc_id + len(cluster)) if a != b + @use_hash_configs() + def test_signatures(self, hash_config): + config = MinhashConfig(hash_config=hash_config) + minhash = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures1"), config=config) + shingles = minhash.get_shingles(lorem_ipsum) + sig = minhash.get_signature(shingles) + minhash2 = MinhashDedupSignature(output_folder=os.path.join(self.tmp_dir, "signatures2"), config=config) + # check consistency + assert sig == minhash2.get_signature(shingles) + + # check correct number of outputs + assert len(sig) == minhash.config.num_buckets + assert all((len(x) == minhash.config.hashes_per_bucket for x in sig)) + + # check similarity approximation + for pctd in range(0, 100, 5): + dec = pctd / 100 + endp = floor(len(lorem_ipsum) * dec) + textd = lorem_ipsum[:endp] + lorem_ipsum[len(lorem_ipsum) - 1 : endp : -1] + sigd = minhash.get_signature(minhash.get_shingles(textd)) + simil = sum([1 if a == b else 0 for ba, bb in zip(sig, sigd) for a, b in zip(ba, bb)]) / minhash.num_hashes + assert dec - 0.21 < simil < dec + 0.21 + + # check output file format and order + samples = [Document(f"sample {i}, {lorem_ipsum[i:: 10]}", id="test") for i in range(100)] + minhash(samples) + for bi in range(config.num_buckets): + with minhash.output_folder.open(f"bucket_{bi:03d}/00000.minhash.sig", "rb") as f: + prev = None + doc_ids = set() + S = np.dtype(config.hash_config.np_dtype).itemsize + for di in range(100): + data = struct.unpack( + f"<%s{config.hash_config.struct_format}" % config.hashes_per_bucket, + f.read(config.hashes_per_bucket * S), + ) + doc_id = struct.unpack("= prev + prev = data + assert 0 <= doc_id < 100 + doc_ids.add(doc_id) + assert len(doc_ids) == 100 + + @use_hash_configs() + def test_buckets_and_cluster(self, hash_config): + sigs_folder = os.path.join(self.tmp_dir, "b_signatures") + buckets_folder = os.path.join(self.tmp_dir, "b_buckets") + clusters_folder = os.path.join(self.tmp_dir, "b_clusters") + config = MinhashConfig(hash_config=hash_config) + + signatures_block = MinhashDedupSignature(output_folder=sigs_folder, config=config) + buckets_block = MinhashDedupBuckets( + input_folder=sigs_folder, + output_folder=buckets_folder, + config=config, + ) + + clusters = [[0, 20, 50], [400, 420], [800, 810, 820, 840, 860], [1205, 1215, 1225, 1245], [1600], [2000]] + + cluster_samples = [ + Document(text=lorem_ipsum[x : x + 400], id=f"{ci}_{xi}", metadata={"ci": ci, "xi": xi}) + for ci, cluster in 
enumerate(clusters) + for xi, x in enumerate(cluster) + ] + + signatures_block(cluster_samples) + # test file read + for fi, file in enumerate(buckets_block.input_folder.list_files()): + last = None + for sig in read_sigs(buckets_block.input_folder.open(file, "rb"), fi, config): + assert 0 <= sig.doc_id < 100 + assert last is None or sig.sig >= last + assert len(sig.sig) == config.hashes_per_bucket + last = sig.sig + + # test duplicate pairs + for b in range(config.num_buckets * 10): + buckets_block(None, rank=b, world_size=config.num_buckets * 10) + bucket_results_folder = get_datafolder(buckets_folder) + dup_files = bucket_results_folder.list_files(glob_pattern="*.dups") + pairs = defaultdict(set) + for dup_file in dup_files: + with bucket_results_folder.open(dup_file, "rb") as df: + while data := df.read(4 * struct.calcsize("I")): + f1, d1, f2, d2 = struct.unpack("<4I", data) + assert f1 == f2 == 0 + assert cluster_samples[d1].metadata["ci"] == cluster_samples[d2].metadata["ci"] + pairs[d1].add(d2) + pairs[d2].add(d1) + doc_id = 0 + for cluster in clusters: + print(cluster) + print(pairs) + for a in range(doc_id, doc_id + len(cluster)): + assert len(cluster) < 2 or any(a in pairs[b] for b in range(doc_id, doc_id + len(cluster)) if a != b) + doc_id += len(cluster) + + # clustering + cluster_block = MinhashDedupCluster(bucket_results_folder, clusters_folder, config=config) + cluster_block(None) + + cluster_results_folder = get_datafolder(clusters_folder) + remove_ids = set() + with cluster_results_folder.open(cluster_results_folder.list_files()[0], "rb") as df: + while data := df.read(struct.calcsize("I")): + remove_ids.add(struct.unpack("[]”.VERY.”very@\\ "very”.unusual@strange.example.com + + + +List of Invalid Email Addresses + +plainaddress +#@%^%#$@#$@#.com +@example.com +Joe Smith +email.example.com +email@example@example.com +.email@example.com +email.@example.com +email..email@example.com +あいうえお@example.com +email@example.com (Joe Smith) +email@example +email@-example.com +email@example.web +email@111.222.333.44444 +email@example..com +Abc..123@example.com + + + +List of Strange Invalid Email Addresses + +”(),:;<>[\\]@example.com +just”not”right@example.com +this\\ is"really"not\\allowed@example.com""" + +EMAIL_TEST_OUTPUT = r"""Use: for testing against email regex +ref: http://codefool.tumblr.com/post/15288874550/list-of-valid-and-invalid-email-addresses + + +List of Valid Email Addresses + +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL +"email"@example.com +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL +EMAIL + + + +List of Strange Valid Email Addresses + +much.”more\ unusual”@example.com +very.unusual.”@”.EMAIL +very.”(),:;<>[]”.VERY.”very@\ "very”.EMAIL + + + +List of Invalid Email Addresses + +plainaddress +#@%^%#$@#$@#.com +@example.com +Joe Smith +email.example.com +email@EMAIL +.EMAIL +email.@example.com +email..EMAIL +あいうえお@example.com +EMAIL (Joe Smith) +email@example +email@-example.com +EMAIL +EMAIL +email@example..com +Abc..EMAIL + + + +List of Strange Invalid Email Addresses + +”(),:;<>[\]@example.com +just”not”EMAIL +this\ is"really"not\EMAIL""" + + +class TestPIIRemoval(unittest.TestCase): + def test_pii_removal(self): + remover = PIIFormatter( + email_replacement="EMAIL", + ip_replacement="IP", + ) + self.assertEqual(remover.format(IP_TEST_INPUT), IP_TEST_OUTPUT) + self.assertEqual(remover.format(EMAIL_TEST_INPUT), EMAIL_TEST_OUTPUT) diff --git a/tests/pipeline/test_sentence_deduplication.py b/tests/pipeline/test_sentence_deduplication.py index d988c87c..8fd96acc 
100644 --- a/tests/pipeline/test_sentence_deduplication.py +++ b/tests/pipeline/test_sentence_deduplication.py @@ -13,7 +13,7 @@ SentenceFindDedups, ) -from ..utils import require_nltk +from ..utils import require_nltk, require_xxhash, use_hash_configs def get_random_string(n: int = 20): @@ -136,16 +136,20 @@ def get_random_string(n: int = 20): @require_nltk +@require_xxhash class SentenceDedup(unittest.TestCase): def setUp(self): # Create a temporary directory self.tmp_dir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, self.tmp_dir) + self.addCleanup(shutil.rmtree, self.tmp_dir, ignore_errors=True) def test_sd(self): - signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir + "/sigs") - find_duplicates = SentenceFindDedups(data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups") - dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=SentDedupConfig(min_doc_words=0)) + config = SentDedupConfig(min_doc_words=0, min_num_sentences=0) + signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) + find_duplicates = SentenceFindDedups( + data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups", config=config + ) + dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) signature_creation(data=DOCS) find_duplicates() @@ -153,10 +157,13 @@ def test_sd(self): self.assertEqual(doc.text, TARGETS[i]) def test_sd_worker(self): - signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir + "/sigs") + config = SentDedupConfig(min_doc_words=0, min_num_sentences=0) + signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) - find_duplicates = SentenceFindDedups(data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups") - dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=SentDedupConfig(min_doc_words=0)) + find_duplicates = SentenceFindDedups( + data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups", config=config + ) + dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) signature_creation(data=DOCS, rank=0, world_size=2) signature_creation(data=DOCS_2, rank=1, world_size=2) @@ -168,11 +175,17 @@ def test_sd_worker(self): for i, doc in enumerate(dedup_filter(data=copy.deepcopy(DOCS_2), rank=1, world_size=2)): self.assertEqual(doc.text, TARGETS_WS2_1[i]) - def test_distributed_find_dups(self): - signature_creation = SentenceDedupSignature(output_folder=self.tmp_dir + "/sigs", finder_workers=50) - - find_duplicates = SentenceFindDedups(data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups") - dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=SentDedupConfig(min_doc_words=0)) + @use_hash_configs() + def test_distributed_find_dups(self, hash_config): + config = SentDedupConfig(hash_config=hash_config, min_doc_words=0, min_num_sentences=0) + signature_creation = SentenceDedupSignature( + output_folder=self.tmp_dir + "/sigs", finder_workers=50, config=config + ) + + find_duplicates = SentenceFindDedups( + data_folder=self.tmp_dir + "/sigs", output_folder=self.tmp_dir + "/dups", config=config + ) + dedup_filter = SentenceDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) signature_creation(data=DOCS, rank=0, world_size=2) signature_creation(data=DOCS_2, rank=1, world_size=2) diff --git a/tests/pipeline/test_stats.py b/tests/pipeline/test_stats.py new file mode 100644 
index 00000000..adbbc357 --- /dev/null +++ b/tests/pipeline/test_stats.py @@ -0,0 +1,285 @@ +import json +import shutil +import tempfile +import unittest +from typing import get_args + +from datatrove.data import Document +from datatrove.io import get_datafolder +from datatrove.pipeline.stats import ( + DEFAULT_TOP_K_CONFIG, + GROUP, + STATS_MERGED_NAME, + DocStats, + LangStats, + LineStats, + ParagraphStats, + StatsMerger, + TokenStats, + TopKConfig, + WordsContaminationStats, + WordStats, +) +from datatrove.pipeline.stats.base import BaseStats +from datatrove.utils.stats import MetricStatsDict +from tests.utils import require_nltk, require_tldextract, require_tokenizers + + +class DummyStats(BaseStats): + def __init__( + self, output_folder, groups=get_args(GROUP), histogram_round_digits=2, top_k_config=DEFAULT_TOP_K_CONFIG + ): + super().__init__( + output_folder, + groups_to_compute=groups, + histogram_round_digits=histogram_round_digits, + top_k_config=top_k_config, + ) + + def extract_stats(self, doc: Document): + return {"stat": float(doc.text)} + + +DOCS = [ + Document("1.5", "1", metadata={"url": "test1.co.uk"}), + Document("2", "1", metadata={"url": "test1.co.uk"}), + Document("1", "2", metadata={"url": "test2.cz"}), + Document("1", "3", metadata={"url": "test3.cz"}), +] + + +@require_tldextract +class TestSummaryStats(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = get_datafolder(tempfile.mkdtemp()) + self.addCleanup(shutil.rmtree, self.tmp_dir.path) + + def test_grouping(self): + summary_stats = DummyStats(output_folder=self.tmp_dir) + list(summary_stats.run(DOCS, 0, 1)) + + with self.tmp_dir.open("summary/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["summary"].total, 5.5) + + with self.tmp_dir.open("fqdn/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["test1.co.uk"].total, 3.5) + self.assertEqual(stats["test2.cz"].total, 1) + self.assertEqual(stats["test3.cz"].total, 1) + + with self.tmp_dir.open("suffix/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["co.uk"].total, 3.5) + self.assertEqual(stats["cz"].total, 2) + + with self.tmp_dir.open("histogram/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["1.0"].total, 2) + self.assertEqual(stats["1.5"].total, 1) + self.assertEqual(stats["2.0"].total, 1) + + def test_histogram_rounding(self): + summary_stats = DummyStats(output_folder=self.tmp_dir, histogram_round_digits=0) + list(summary_stats.run(DOCS, 0, 1)) + + with self.tmp_dir.open("histogram/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["1.0"].total, 2) + self.assertEqual(stats["2.0"].total, 2) + + def test_compute_top_k(self): + top_k_config = TopKConfig(top_k=1, top_k_groups=["fqdn"]) + summary_stats = DummyStats(output_folder=self.tmp_dir, top_k_config=top_k_config) + list(summary_stats.run(DOCS, 0, 1)) + + with self.tmp_dir.open("fqdn/stat/00000.json") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["test1.co.uk"].total, 3.5) + self.assertEqual(stats["test2.cz"].total, 0) + + def test_merging_stats(self): + summary_stats = DummyStats(output_folder=self.tmp_dir) + merge_stats = StatsMerger(self.tmp_dir, self.tmp_dir) + + list(summary_stats.run(DOCS[0:2], 0, 2)) + list(summary_stats.run(DOCS[2:4], 1, 2)) + list(merge_stats.run(DOCS, 0, 1)) + with 
self.tmp_dir.open(f"summary/stat/{STATS_MERGED_NAME}") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["summary"].total, 5.5) + + def test_merging_top_k(self): + top_k_config = TopKConfig(top_k=1, top_k_groups=["fqdn"]) + summary_stats = DummyStats(output_folder=self.tmp_dir) + merge_stats = StatsMerger(self.tmp_dir, self.tmp_dir, top_k_config=top_k_config) + + list(summary_stats.run(DOCS[0:2], 0, 2)) + list(summary_stats.run(DOCS[2:4], 1, 2)) + list(merge_stats.run(DOCS, 0, 1)) + with self.tmp_dir.open(f"fqdn/stat/{STATS_MERGED_NAME}") as f: + stats = MetricStatsDict.from_dict(json.load(f)) + self.assertEqual(stats["test1.co.uk"].total, 3.5) + self.assertEqual(stats["test2.cz"].total, 0) + + +@require_tldextract +@require_tokenizers +@require_nltk +class TestStatsModules(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = get_datafolder(tempfile.mkdtemp()) + self.addCleanup(shutil.rmtree, self.tmp_dir.path) + + def load_computed_means(self, stat_names: list[str]) -> dict: + def load_stat_total(f) -> dict: + stat = MetricStatsDict.from_dict(json.load(f)) + return {k: v.total for k, v in stat.items()} + + computed_stats = {} + for stat in stat_names: + with self.tmp_dir.open(f"histogram/{stat}/00000.json") as f: + computed_stats[stat] = load_stat_total(f) + return computed_stats + + def test_line_stats(self): + docs = [ + Document("hello\nhow\nhow\nzyou?", "1", metadata={"url": "test.cz"}), + Document("test test a", "2", metadata={"url": "test.cz"}), + Document("* Hello", "2", metadata={"url": "test.cz"}), + ] + + expected_line_stats = { + "n_lines": {"4": 1, "1": 2}, + "avg_line_length": {"4.0": 1, "11.0": 1, "7.0": 1}, + "short_line_ratio_chars_3": {"0.5": 1, "0.0": 2}, + "long_line_ratio_chars_5": {"0.5": 1, "1.0": 2}, + "bullet_point_lines_ratio": {"0.0": 2, "1.0": 1}, + "line_duplicates": {"0.25": 1, "0.0": 2}, + "line_char_duplicates": {"0.188": 1, "0.0": 2}, + } + + line_stats = LineStats(self.tmp_dir, max_k_chars_per_line_tresholds=[3], min_k_chars_per_line_thresholds=[5]) + list(line_stats.run(docs)) + + computed_stats = self.load_computed_means(list(expected_line_stats.keys())) + self.assertEqual(computed_stats, expected_line_stats) + + def test_doc_stats(self): + docs = [ + Document("1~", "1", metadata={"url": "test.cz"}), + Document("Test ...", "2", metadata={"url": "test.cz"}), + ] + expected_doc_stats = { + "length": {"2": 1, "8": 1}, + "white_space_ratio": {"0.125": 1, "0.0": 1}, + "non_alpha_digit_ratio": {"0.5": 2}, + "digit_ratio": {"0.5": 1, "0.0": 1}, + "uppercase_ratio": {"0.125": 1, "0.0": 1}, + "elipsis_ratio": {"0.375": 1, "0.0": 1}, + "punctuation_ratio": {"0.5": 1, "0.375": 1}, + } + doc_stats = DocStats(self.tmp_dir) + + list(doc_stats.run(docs)) + computed_stats = self.load_computed_means(list(expected_doc_stats.keys())) + self.assertEqual(computed_stats, expected_doc_stats) + + def test_word_stats(self): + docs = [ + Document("okay\nokay", "1", metadata={"url": "test.cz"}), + Document("test test of", "2", metadata={"url": "test.cz"}), + Document("Looooooooong", "3", metadata={"url": "test.cz"}), + ] + + expected_word_stats = { + "n_words": {"2": 1, "3": 1, "1": 1}, + "avg_word_length": {"3.333": 1, "4.0": 1, "12.0": 1}, + "avg_words_per_line": {"1.0": 2, "3.0": 1}, + "short_word_ratio_3": {"0.333": 1, "0.0": 2}, + "long_word_ratio_7": {"0.0": 2, "1.0": 1}, + "type_token_ratio": {"0.5": 1, "0.667": 1, "1.0": 1}, + "uppercase_word_ratio": {"0.0": 3}, + "capitalized_word_ratio": {"0.0": 2, "1.0": 
1}, + "stop_word_ratio": {"0.0": 2, "0.333": 1}, + } + word_stats = WordStats( + self.tmp_dir, + short_word_max_chars_threshold=[3], + long_word_max_chars_threshold=[7], + groups_to_compute=["histogram"], + ) + list(word_stats.run(docs)) + computed_stats = self.load_computed_means(list(expected_word_stats.keys())) + self.assertEqual(computed_stats, expected_word_stats) + + def test_words_contamination(self): + docs = [ + Document("chat gpt loves the word delve and delve is word", "1", metadata={"url": "test.cz"}), + Document("chat gpt doesn't prefer any words", "2", metadata={"url": "test.cz"}), + ] + + expected_words_contamination = { + "words_contamination_delve": {"0.2": 1, "0.0": 1}, + } + + contamination_stats = WordsContaminationStats(self.tmp_dir, ["delve"]) + list(contamination_stats.run(docs)) + + computed_stats = self.load_computed_means(list(expected_words_contamination.keys())) + self.assertEqual(computed_stats, expected_words_contamination) + + def test_token_counter(self): + docs = [ + Document("hi how are you ?", "1", metadata={"url": "test.cz"}), + Document(" hi hi", "2", metadata={"url": "test.cz"}), + ] + + expected_token_counter = { + "token_count": {"5": 1, "2": 1}, + } + + token_counter = TokenStats(self.tmp_dir) + list(token_counter.run(docs)) + + computed_stats = self.load_computed_means(list(expected_token_counter.keys())) + self.assertEqual(computed_stats, expected_token_counter) + + def test_lang_stats(self): + docs = [ + Document("This is pure english text", "1", metadata={"url": "test.cz"}), + Document("Toto je český text", "2", metadata={"url": "test.cz"}), + ] + + expected_lang_stats = { + "fasttext_en": {"0.887": 1, "0.0": 1}, + } + + lang_stats = LangStats(self.tmp_dir, language="en") + list(lang_stats.run(docs)) + + computed_stats = self.load_computed_means(list(expected_lang_stats.keys())) + self.assertEqual(computed_stats, expected_lang_stats) + + def test_paragraph_stats(self): + docs = [ + Document( + "paragraph one\n\nparagraph two\n\nshort\n\nvery very long one", "1", metadata={"url": "test.cz"} + ), + ] + + expected_paragraph_stats = { + "n_paragraphs": {"4": 1}, + "avg_paragraph_length": {"12.25": 1}, + "short_paragraph_ratio_5": {"0.25": 1}, + "long_paragraph_ratio_15": {"0.25": 1}, + } + paragraph_stats = ParagraphStats( + self.tmp_dir, short_paragraph_max_chars_threshold=[5], long_paragraph_max_chars_threshold=[15] + ) + list(paragraph_stats.run(docs)) + + computed_stats = self.load_computed_means(list(expected_paragraph_stats.keys())) + self.assertEqual(computed_stats, expected_paragraph_stats) diff --git a/tests/pipeline/test_text.py b/tests/pipeline/test_text.py new file mode 100644 index 00000000..21337b9b --- /dev/null +++ b/tests/pipeline/test_text.py @@ -0,0 +1,20 @@ +import unittest + +from src.datatrove.utils.text import PUNCTUATION, TextNormConfig, simplify_text + + +class TestTextTransformation(unittest.TestCase): + def test_text_table_norm(self): + text = "|$17.56||1|\n|$15.37||2599|" + config = TextNormConfig(norm_numbers=True, remove_punctuation=True, norm_whitespace=True) + transformed_text = simplify_text(text, config) + expected_text = "0 0 0 0" + self.assertEqual(transformed_text, expected_text) + + def test_punc_normalization(self): + text = PUNCTUATION + config = TextNormConfig(remove_punctuation=True) + transformed_text = simplify_text(text, config) + # Should be just 0, because there is a strange 1 in special symbols + expected_text = "0" + self.assertEqual(transformed_text, expected_text) diff --git 
a/tests/pipeline/test_url_deduplication.py b/tests/pipeline/test_url_deduplication.py new file mode 100644 index 00000000..c063f1ef --- /dev/null +++ b/tests/pipeline/test_url_deduplication.py @@ -0,0 +1,163 @@ +import copy +import shutil +import tempfile +import unittest + +from datatrove.data import Document +from datatrove.pipeline.dedup.url_dedup import ( + UrlDedupConfig, + UrlDedupFilter, + UrlDedupSignature, + UrlFindDedups, +) +from tests.utils import require_xxhash, use_hash_configs + + +DOCS = [ + Document(text="", metadata={"url": "https://example.com"}, id="1"), + Document(text="", metadata={"url": "https://example.com"}, id="2"), + Document(text="", metadata={"url": "https://new-site.com"}, id="3"), + Document(text="", metadata={"url": "https://example.com"}, id="4"), + Document(text="", metadata={"url": "https://example2.com"}, id="5"), +] + +DOCS_1 = DOCS[:2] +DOCS_2 = DOCS[2:] + + +@require_xxhash +class UrlDedup(unittest.TestCase): + def setUp(self): + # Create a temporary directory + self.tmp_dir = tempfile.mkdtemp() + self.addCleanup(shutil.rmtree, self.tmp_dir) + + def test_url_deduplication(self): + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs") + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + lines_to_buffer=1000, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups") + + signature_creation(data=DOCS) + find_duplicates() + docs = list(dedup_filter(data=copy.deepcopy(DOCS))) + self.assertEqual(len(docs), 3) + self.assertEqual( + {doc.metadata["url"] for doc in docs}, + {doc.metadata["url"] for doc in DOCS}, + ) + + def test_url_deduplication_with_priority_highest_id(self): + config = UrlDedupConfig(document_priority=lambda x: int(x.id)) + + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + config=config, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) + + signature_creation(data=DOCS) + find_duplicates() + docs = list(dedup_filter(data=copy.deepcopy(DOCS))) + + expected_ids = [3, 4, 5] + self.assertEqual(len(docs), 3) + self.assertEqual({int(doc.id) for doc in docs}, set(expected_ids)) + + def test_url_deduplication_with_priority_lowest_id(self): + config = UrlDedupConfig(document_priority=lambda x: 5 - int(x.id) + 1) + + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + config=config, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) + + signature_creation(data=DOCS) + find_duplicates() + docs = list(dedup_filter(data=copy.deepcopy(DOCS))) + + expected_ids = [1, 3, 5] + self.assertEqual(len(docs), 3) + self.assertEqual({int(doc.id) for doc in docs}, set(expected_ids)) + + def test_url_deduplication_with_normalization(self): + config = UrlDedupConfig(url_normalizer=lambda x: x.replace("2", "")) + + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + config=config, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) + + signature_creation(data=DOCS) + find_duplicates() + docs = 
list(dedup_filter(data=copy.deepcopy(DOCS))) + + self.assertEqual(len(docs), 2) + self.assertEqual( + {doc.metadata["url"] for doc in docs}, + {"https://example.com", "https://new-site.com"}, + ) + + def test_sd_worker(self): + config = UrlDedupConfig(document_priority=lambda x: int(x.id)) + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs", config=config) + + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + config=config, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) + + signature_creation(data=DOCS_1, rank=0, world_size=2) + signature_creation(data=DOCS_2, rank=1, world_size=2) + find_duplicates() + + dedup_1 = list(dedup_filter(data=copy.deepcopy(DOCS_1), rank=0, world_size=2)) + dedup_2 = list(dedup_filter(data=copy.deepcopy(DOCS_2), rank=1, world_size=2)) + + self.assertEqual(len(dedup_1), 0) + self.assertEqual(len(dedup_2), 3) + self.assertEqual( + {doc.metadata["url"] for doc in dedup_2}, + {doc.metadata["url"] for doc in DOCS}, + ) + + @use_hash_configs() + def test_distributed_find_dups(self, hash_config): + config = UrlDedupConfig(document_priority=lambda x: int(x.id), hash_config=hash_config) + + signature_creation = UrlDedupSignature(output_folder=self.tmp_dir + "/sigs", finder_workers=50, config=config) + + find_duplicates = UrlFindDedups( + data_folder=self.tmp_dir + "/sigs", + output_folder=self.tmp_dir + "/dups", + config=config, + ) + dedup_filter = UrlDedupFilter(data_folder=self.tmp_dir + "/dups", config=config) + + signature_creation(data=DOCS_1, rank=0, world_size=2) + signature_creation(data=DOCS_2, rank=1, world_size=2) + for rank in range(50): + find_duplicates(rank=rank, world_size=50) + + dedup_docs = list(dedup_filter(data=copy.deepcopy(DOCS_1), rank=0, world_size=2)) + + dedup_docs_2 = list(dedup_filter(data=copy.deepcopy(DOCS_2), rank=1, world_size=2)) + self.assertEqual(len(dedup_docs), 0) + self.assertEqual(len(dedup_docs_2), 3) + self.assertEqual( + {doc.metadata["url"] for doc in dedup_docs_2}, + {doc.metadata["url"] for doc in DOCS}, + ) diff --git a/tests/pipeline/test_word_tokenizers.py b/tests/pipeline/test_word_tokenizers.py new file mode 100644 index 00000000..9f67a44d --- /dev/null +++ b/tests/pipeline/test_word_tokenizers.py @@ -0,0 +1,47 @@ +import unittest + +from nltk.tokenize import word_tokenize + +from datatrove.utils.word_tokenizers import WORD_TOKENIZER_FACTORY, load_word_tokenizer + + +SAMPLE_TEXT = ( + "I wish it need not have happened in my time,' said Frodo. 'So do I,' said Gandalf, 'and so do all who live to " + "see such times. But that is not for them to decide. All we have to decide is what to do with the time that is " + "given us.' Hello world! \n\n ქართული \n\t Hello\nworld! 
" +) + + +class TestWordTokenizers(unittest.TestCase): + def test_word_tokenizers(self): + for language in WORD_TOKENIZER_FACTORY.keys(): + tokenizer = load_word_tokenizer(language) + tokens = tokenizer.word_tokenize(SAMPLE_TEXT) + assert len(tokens) >= 1, f"'{language}' tokenizer doesn't output tokens" + is_stripped = [token == token.strip() for token in tokens] + assert all(is_stripped), f"'{language}' tokenizer tokens contain whitespaces" + + def test_sent_tokenizers(self): + for language in WORD_TOKENIZER_FACTORY.keys(): + tokenizer = load_word_tokenizer(language) + sents = tokenizer.sent_tokenize(SAMPLE_TEXT) + assert len(sents) >= 1, f"'{language}' tokenizer doesn't output sentences" + is_stripped = [sent == sent.strip() for sent in sents] + assert all(is_stripped), f"'{language}' tokenizer sentences contain whitespaces" + + def test_span_tokenizers(self): + for language in WORD_TOKENIZER_FACTORY.keys(): + tokenizer = load_word_tokenizer(language) + sents = tokenizer.sent_tokenize(SAMPLE_TEXT) + spans = tokenizer.span_tokenize(SAMPLE_TEXT) + assert len(spans) >= 1, f"'{language}' tokenizer doesn't output spans" + spans_match_sents = [sent in SAMPLE_TEXT[span[0] : span[1]] for sent, span in zip(sents, spans)] + assert all(spans_match_sents), f"'{language}' tokenizer spans don't match with sentences" + + def test_english_tokenizer(self): + nltk_words = word_tokenize(SAMPLE_TEXT, language="english") + + en_tokenizer = load_word_tokenizer("en") + tokenizer_words = en_tokenizer.word_tokenize(SAMPLE_TEXT) + + self.assertEqual(nltk_words, tokenizer_words, "NLTK tokenizer and multilingual tokenizer differ") diff --git a/tests/test_io.py b/tests/test_io.py index 27d60c15..4e5aee23 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,11 +1,15 @@ +import multiprocessing +import os import shutil import tempfile +import time import unittest +from functools import partial import boto3 import moto -from datatrove.io import get_datafolder +from datatrove.io import get_datafolder, safely_create_file EXAMPLE_DIRS = ("/home/testuser/somedir", "file:///home/testuser2/somedir", "s3://test-bucket/somedir") @@ -16,6 +20,12 @@ ) +def fake_do_download(cc, ll): + time.sleep(0.5) + with ll: + cc.value += 1 + + @moto.mock_aws class TestIO(unittest.TestCase): def setUp(self): @@ -34,3 +44,31 @@ def test_make_dirs(self): with df.open("subdir1/subdir2/some_path.txt", "wt") as f: f.write("hello") assert df.isdir("subdir1/subdir2") + + def test_safely_create_file_locking(self): + for runi, (completed_exists, lock_exists, expec_calls) in enumerate( + ( + (True, True, 0), + (False, True, 1), + (False, False, 1), + ) + ): + manager = multiprocessing.Manager() + counter = manager.Value("i", 0) + lock = manager.Lock() + + file_path = os.path.join(self.tmp_dir, str(runi), "myfile") + os.makedirs(os.path.join(self.tmp_dir, str(runi))) + + with manager.Pool(2) as pool: + if completed_exists: + open(file_path + ".completed", "a").close() + if lock_exists: + open(file_path + ".lock", "a").close() + + pool.starmap( + partial(safely_create_file, do_processing=partial(fake_do_download, cc=counter, ll=lock)), + [(file_path,) for _ in range(2)], + ) + + self.assertEqual(counter.value, expec_calls) diff --git a/tests/utils.py b/tests/utils.py index 18e3b1e1..3d076308 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,34 @@ +import itertools import unittest +from functools import wraps +from typing import get_args, get_type_hints + +from datatrove.utils.hashing import HashConfig + + +def use_hash_configs( + precision: 
list[int] = list(get_args(get_type_hints(HashConfig)["precision"])), hash_fc: list[str] = ["xxhash"] +): + """ + Decorator that runs the wrapped test function once for every combination of the given precision and hash_fc values, passing the resulting HashConfig as an extra positional argument + Args: + precision (list[int]): List of precision values to use. Defaults to all possible values. + hash_fc (list[str]): List of hash functions to use. Defaults to ["xxhash"]. + """ + + def wrapper(f): + @wraps(f) + def inner_wrapper(self: unittest.TestCase, *args, **kwargs): + for p, h in itertools.product(precision, hash_fc): + config = HashConfig(precision=p, hash_fc=h) + self.setUp() + f(self, *args, config, **kwargs) + self.tearDown() + self.doCleanups() + + return inner_wrapper + + return wrapper def require_nltk(test_case): @@ -95,3 +125,19 @@ def require_datasets(test_case): except ImportError: test_case = unittest.skip("test requires datasets")(test_case) return test_case + + +def require_xxhash(test_case): + try: + import xxhash # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires xxhash")(test_case) + return test_case + + +def require_lighteval(test_case): + try: + import lighteval # noqa: F401 + except ImportError: + test_case = unittest.skip("test requires lighteval")(test_case) + return test_case
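
Usage note (not part of the patch above): the `use_hash_configs` decorator added to tests/utils.py is what the dedup tests in this PR consume (e.g. `test_sd` and `test_distributed_find_dups` receive a `hash_config` argument). Below is a minimal sketch of how a new test could use it; the class and method names are hypothetical, and it assumes `tests/utils.py` is importable as `tests.utils`, as in the other test modules of this PR.

import unittest

from datatrove.utils.hashing import HashConfig
from tests.utils import require_xxhash, use_hash_configs


@require_xxhash  # skips the whole class when xxhash is not installed
class ExampleHashConfigTest(unittest.TestCase):
    @use_hash_configs()  # re-runs the body once per (precision, hash_fc) combination
    def test_runs_for_each_hash_config(self, hash_config):
        # The decorator passes the current HashConfig as a positional argument
        # and wraps each run in setUp()/tearDown()/doCleanups().
        self.assertIsInstance(hash_config, HashConfig)


if __name__ == "__main__":
    unittest.main()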