add multi-law-pile pipeline
jderiu committed Apr 25, 2024
1 parent 80bd7b7 commit 1d8dd8a
Showing 4 changed files with 81 additions and 6 deletions.
77 changes: 77 additions & 0 deletions pipelines/hugginface_pipeline.py
@@ -0,0 +1,77 @@
"""
"""
import os, re
from datatrove.pipeline.readers.huggingface import HuggingFaceDatasetReader
from datatrove.pipeline.tokens import TokensCounter, LengthCounter
from swiss_ai.writers.jsonl import SwissAIJsonlWriter
from datatrove.executor.local import LocalPipelineExecutor

os.environ["HF_BASE"] = "/work_space_data/hf_cache"


def find_years(text):
    # Regex pattern to match four-digit numbers that are likely to be years
    # This pattern matches any number from 1900 to 2099
    pattern = r'\b(19[0-9]{2}|20[0-9]{2})\b'

    # Find all matches in the text
    years = re.findall(pattern, text)

    return years


def _multilegal_adapter(data: dict, path: str, id_in_file: int | str):
    # Very crude estimation of the document year: take the latest year mentioned
    # in the text that is not in the future; fall back to 2024 if none is found.
    years = [int(year) for year in find_years(data['text']) if int(year) <= 2024]
    year = max(years) if years else 2024
    metadata = {
        'language': data['language'],
        'year': year,
        'optional': {
            'type': data['type'],
            'jurisdiction': data['jurisdiction']
        }
    }

    return {
        "text": data.pop('text', ""),
        "id": f"{path}/{id_in_file}",
        "media": data.pop("media", []),
        "metadata": metadata
    }


if __name__ == '__main__':
    pipeline = [
        HuggingFaceDatasetReader(
            dataset='joelniklaus/Multi_Legal_Pile',
            dataset_options={
                'split': 'train',
                'name': 'da_caselaw',
                'cache_dir': '/work_space_data/hf_cache',
                'trust_remote_code': True
            },
            progress=True,
            adapter=_multilegal_adapter,
            limit=1000
        ),
        TokensCounter(tokenizer_name_or_path='t5-small'),
        LengthCounter(),
        SwissAIJsonlWriter(
            output_folder="/work_space_data/multilegal_pile/jsonl"
        )
    ]

    exec = LocalPipelineExecutor(
        pipeline=pipeline,
        tasks=16,
        workers=1,
        start_method="spawn",
        logging_dir="/work_space_data/multilegal_pile/logging"
    )

    exec.run()
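
For reference, a minimal sketch of what _multilegal_adapter emits for a single Multi_Legal_Pile row. The sample record below is made up for illustration and assumes the module above is importable:

sample = {
    'text': 'Judgment delivered in 1998 and amended in 2003.',
    'language': 'da',
    'type': 'caselaw',
    'jurisdiction': 'Denmark',
}
doc = _multilegal_adapter(sample, path='da_caselaw', id_in_file=0)
# doc['id']       == 'da_caselaw/0'
# doc['media']    == []
# doc['metadata'] == {'language': 'da', 'year': 2003,
#                     'optional': {'type': 'caselaw', 'jurisdiction': 'Denmark'}}
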
2 changes: 2 additions & 0 deletions pipelines/swissdox_raw.py
@@ -7,6 +7,8 @@
from datatrove.pipeline.writers import JsonlWriter
from datatrove.executor.local import LocalPipelineExecutor

os.environ["HF_BASE"] = "/work_space_data/hf_cache/"

if __name__ == '__main__':
    pipeline = [
        RawSwissDoxReader(
1 change: 1 addition & 0 deletions src/datatrove/pipeline/readers/huggingface.py
@@ -58,6 +58,7 @@ def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1
            yield from data
        # sadly sharding in this way with streaming is not supported by HF datasets yet, so no streaming
        ds = load_dataset(self.dataset, **self.dataset_options)

        shard = ds.shard(world_size, rank, contiguous=True)
        with tqdm(total=self.limit if self.limit != -1 else None) if self.progress else nullcontext() as pbar:
            li = 0
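
The reader distributes work by slicing the loaded dataset per task with ds.shard(world_size, rank, contiguous=True). A small illustration of that call using the datasets library (toy data; the exact split of leftover rows across shards may vary slightly between datasets versions):

from datasets import Dataset

ds = Dataset.from_dict({'text': [f'doc {i}' for i in range(10)]})
world_size = 4
for rank in range(world_size):
    shard = ds.shard(num_shards=world_size, index=rank, contiguous=True)
    print(rank, shard['text'])
# With contiguous=True each rank gets an unbroken slice of the dataset,
# e.g. ranks 0 and 1 get 3 documents each, ranks 2 and 3 get 2 each.
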
7 changes: 1 addition & 6 deletions src/swiss_ai/writers/jsonl.py
@@ -56,12 +56,7 @@ def check_document(document: dict):
        if metadata is None or type(metadata) is not dict:
            return False

        required_metadata = metadata.get('required', None)

        if required_metadata is None or type(required_metadata) is not dict:
            return False

        required_check = SwissAIJsonlWriter._check_required_metadata(required_metadata)
        required_check = SwissAIJsonlWriter._check_required_metadata(metadata)
        if not required_check:
            return False

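
With this change, check_document validates the required keys directly on the top-level metadata dict instead of a nested 'required' sub-dict. A sketch of a document that should pass, shaped like the adapter output above; which keys _check_required_metadata actually demands is an assumption based on that adapter, not confirmed by the writer's source:

document = {
    'text': 'some legal text',
    'id': 'da_caselaw/0',
    'media': [],
    'metadata': {
        'language': 'da',   # assumed required
        'year': 2003,       # assumed required
        'optional': {'type': 'caselaw', 'jurisdiction': 'Denmark'},
    },
}
# SwissAIJsonlWriter.check_document(document) is expected to return True
# under the assumptions above.
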
