Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
jderiu committed Jun 19, 2024
2 parents ba1c6f5 + 9fbab55 commit a550e1d
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 52 deletions.
2 changes: 1 addition & 1 deletion examples/local_minhash_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# you can also change ngrams or the number of buckets and their size here
minhash_config = MinhashConfig(use_64bit_hashes=True) # better precision -> fewer false positives (collisions)

corpus = 'swissdox'
corpus = 'curiavista'

S3_MINHASH_BASE_PATH = f"/work_space_data/{corpus}/minhash/"

Expand Down
18 changes: 8 additions & 10 deletions pipelines/curia_vista.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from swiss_ai.readers.curia_vista import RawCuriaVistaReader
from datatrove.pipeline.tokens import TokensCounter, LengthCounter
from datatrove.pipeline.writers import JsonlWriter
from swiss_ai.writers.jsonl import SwissAIJsonlWriter
from datatrove.pipeline.readers import JsonlReader
from datatrove.executor.local import LocalPipelineExecutor
from datetime import datetime

now = datetime.now()

if __name__ == '__main__':
table = 'Business'
table = 'Transcript'
trascr_cols = [
'Text'
]

now = datetime.now()
batch = now.strftime("%Y_%m_%d_%H_%M_%S")
Expand All @@ -20,17 +23,12 @@
RawCuriaVistaReader(
table=table,
progress=True,
columns=[
'SubmittedText',
'FederalCouncilResponseText',
'InitialSituation',
'Proceedings'
],
limit=100
columns=trascr_cols,
limit=1500000
),
TokensCounter(tokenizer_name_or_path='t5-small'),
LengthCounter(),
JsonlWriter(
SwissAIJsonlWriter(
output_folder=f"/work_space_data/curiavista/{table}/jsonl_{batch}"
)
]
Expand Down
76 changes: 76 additions & 0 deletions pipelines/hugginface_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
"""
import os, re
from datatrove.pipeline.readers.huggingface import HuggingFaceDatasetReader
from datatrove.pipeline.tokens import TokensCounter, LengthCounter
from swiss_ai.writers.jsonl import SwissAIJsonlWriter
from datatrove.executor.local import LocalPipelineExecutor

# NOTE(review): "HF_BASE" does not appear to be one of the environment variables
# documented for the Hugging Face libraries (e.g. HF_HOME, HF_DATASETS_CACHE) —
# confirm that something actually reads it. The dataset cache directory is also
# passed explicitly via ``cache_dir`` in the reader's ``dataset_options`` below.
os.environ["HF_BASE"] = "/work_space_data/hf_cache"


def find_years(text):
# Regex pattern to match four-digit numbers that are likely to be years
# This pattern matches any number from 1900 to 2099
pattern = r'\b(19[0-9]{2}|20[0-9]{2})\b'

# Find all matches in the text
years = re.findall(pattern, text)

return years


def _multilegal_adapter(data: dict, path: str, id_in_file: int | str):
years = find_years(data['text'])
if len(years) > 0:
#very crude estimation of the year..
year = max(int(year) for year in years if int(year) <= 2024)
else:
year = 2024
metadata = {
'language': data['language'],
'year': year,
'optional': {
'type': data['type'],
'jurisdiction': data['jurisdiction']
}
}

return {
"text": data.pop('text', ""),
"id": f"{path}/{id_in_file}",
"media": data.pop("media", []),
"metadata": metadata
}


if __name__ == '__main__':
    # Read the Danish case-law configuration of Multi_Legal_Pile, count tokens
    # and write the adapted documents as JSONL.
    pipeline = [
        HuggingFaceDatasetReader(
            dataset='joelniklaus/Multi_Legal_Pile',
            dataset_options={
                'split': 'train',
                'name': 'da_caselaw',
                'cache_dir': '/work_space_data/hf_cache',
                'trust_remote_code': True
            },
            progress=True,
            adapter=_multilegal_adapter,
            limit=1000
        ),
        TokensCounter(tokenizer_name_or_path='t5-small'),
        SwissAIJsonlWriter(
            output_folder="/work_space_data/multilegal_pile/jsonl"
        )
    ]

    # Named ``executor`` rather than ``exec`` so the builtin is not shadowed.
    executor = LocalPipelineExecutor(
        pipeline=pipeline,
        tasks=16,
        workers=16,
        start_method="spawn",
        logging_dir="/work_space_data/multilegal_pile/logging"
    )

    executor.run()
16 changes: 10 additions & 6 deletions pipelines/swissdox_raw.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""
"""

from swiss_ai.readers.swissdox import RawSwissDoxReader
from datatrove.pipeline.tokens import TokensCounter, LengthCounter
from datatrove.pipeline.writers import JsonlWriter
from swiss_ai.writers.jsonl import SwissAIJsonlWriter
from datatrove.executor.local import LocalPipelineExecutor
import os

Expand All @@ -14,15 +13,20 @@

if __name__ == "__main__":
pipeline = [
RawSwissDoxReader(data_folder="/work_space_data/swissdox", limit=-1),
TokensCounter(tokenizer_name_or_path="t5-small"),
RawSwissDoxReader(
data_folder="/work_space_data/swissdox",
limit=-1
),
LengthCounter(),
JsonlWriter(output_folder="/work_space_data/swissdox/jsonl"),
TokensCounter(tokenizer_name_or_path='t5-small'),
SwissAIJsonlWriter(
output_folder="/work_space_data/swissdox/jsonl"
)
]

exec = LocalPipelineExecutor(
pipeline=pipeline,
tasks=16,
tasks=64,
workers=16,
start_method="spawn",
logging_dir="/work_space_data/swissdox/logging",
Expand Down
2 changes: 1 addition & 1 deletion src/datatrove/executor/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def run(self):
)
# merged stats
stats = sum(stats, start=PipelineStats())
with self.logging_dir.open("stats.json", "wt") as statsfile:
with self.logging_dir.open("stats.json", "wt", encoding='utf-8') as statsfile:
stats.save_to_disk(statsfile)
logger.success(stats.get_repr(f"All {self.local_tasks} tasks"))
return stats
Expand Down
Empty file.
3 changes: 3 additions & 0 deletions src/swiss_ai/pipeline/pii_removal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine

87 changes: 60 additions & 27 deletions src/swiss_ai/readers/curia_vista.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import xml.etree.ElementTree as ET
from tqdm import tqdm

from langdetect import detect
from datatrove.io import DataFolderLike, get_datafolder
from datatrove.pipeline.readers.base import BaseReader, DocumentsPipeline

Expand Down Expand Up @@ -68,7 +68,7 @@ def parse_ids(self, id_url):
if not child.tag.endswith('entry'):
continue
idx = child[-1][0][-1].text
indices.add(idx)
indices.add(int(idx))
return indices

def retrieve_single_record_for_id(self, in_id):
Expand All @@ -91,13 +91,18 @@ def retrieve_single_record_for_id(self, in_id):
return all_data

def _curia_vista_adapter(self, data: dict, path: str, id_in_file: int | str):

text = ''.join([f'<h2>{col}<h2>{data[col]}' for col in self.columns if data[col] is not None])
meta_data = {k: v for k, v in data.items() if k not in self.columns}
if not text:
text = 'DUMMY_TEXT'
meta_data['delete'] = True
opt_meta_data = {k: v for k, v in data.items() if k not in self.columns}

lang = opt_meta_data['LanguageOfText'].lower() if opt_meta_data['LanguageOfText'] is not None else None
if lang is None:
lang = detect(text)

meta_data = {
'optional': opt_meta_data,
'language': lang,
'year': int(opt_meta_data['MeetingDate'][:4])
}

return {
"text": text,
Expand All @@ -107,25 +112,53 @@ def _curia_vista_adapter(self, data: dict, path: str, id_in_file: int | str):
}

def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
processed_ids = set([doc.id for doc in data])
processed_ids = set()
processed_dp = set()
try:
for document in data:
processed_ids.add(document.id)
dp = f'{document.id}_{document.metadata["language"]}'
if dp in processed_dp:
continue
processed_dp.add(dp)
if document.metadata["language"] is None:
document.metadata["language"] = detect(document.text)

yield document
except:
print('Noooo')

print(f'Already processed {len(processed_ids)} Documents')
if len(processed_ids) > 0:
last_id = max(processed_ids)
else:
last_id = 0

ids = ['dummy']
limit = self.limit
if not limit == -1:
id_url = f"{self.base_url}?$top={limit}&$filter=Language eq 'DE'&$select=ID"
else:
id_url = f"{self.base_url}?$filter=Language eq 'DE'&$select=ID"
ids = self.parse_ids(id_url)
ids = ids.difference(processed_ids)

for nr, entry_id in tqdm(enumerate(ids, start=1)):
with self.track_time():
entries = self.retrieve_single_record_for_id(entry_id)

if nr % 10 == 0:
time.sleep(self.timeout)

for data_dict in entries:
document = self.get_document_from_dict(data_dict, self.table, entry_id)
if not document:
continue
yield document

global_count = 0
while len(ids) > 0 and global_count < limit:
id_url = f"{self.base_url}?$filter=Language eq 'DE' and ID gt {last_id} &$orderby=ID&$select=ID&$top=100"
ids = self.parse_ids(id_url)
ids = ids.difference(processed_ids)
ids = sorted(ids)

for nr, entry_id in tqdm(enumerate(ids, start=1)):
with self.track_time():
entries = self.retrieve_single_record_for_id(entry_id)
global_count += len(entries)

last_id = entry_id
if nr % 10 == 0:
time.sleep(self.timeout)

for data_dict in entries:
document = self.get_document_from_dict(data_dict, self.table, entry_id)
if not document:
continue
yield document
if global_count >= limit:
break
processed_ids.add(entry_id)
time.sleep(60)
16 changes: 11 additions & 5 deletions src/swiss_ai/readers/swissdox.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def iterate_entries(self, f, meta_data: dict):
if ignroe_article:
continue
tmp_text = f"{tmp_text}\n{content}"

def load_meta_data(self, filepath):
meta_data_full = {}
with self.meta_data_folder.open(filepath, "r", encoding='utf-8', compression=self.compression) as mf:
Expand All @@ -108,11 +109,16 @@ def load_meta_data(self, filepath):
news_paper_short = sline[-3]
date = sline[-4]

meta_dict = json.loads(sdict)
meta_dict['news_paper_short'] = news_paper_short.strip()
meta_dict['news_paper'] = news_paper.strip()
meta_dict['pub_date'] = date.strip()
meta_dict['lang'] = lang
opt_meta_dict = json.loads(sdict)
opt_meta_dict['news_paper_short'] = news_paper_short.strip()
opt_meta_dict['news_paper'] = news_paper.strip()
opt_meta_dict['pub_date'] = date.strip()
meta_dict = {
'language': lang,
'year': int(date.strip().split('-')[0]),
'optional': opt_meta_dict
}

meta_data_full[lid] = meta_dict
return meta_data_full

Expand Down
4 changes: 2 additions & 2 deletions src/swiss_ai/writers/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import IO, Callable

from datatrove.io import DataFolderLike
from loguru import logger
from datatrove.pipeline.writers.disk_base import DiskWriter
from swiss_ai.utils.language_list import LANGUAGE_CODES
from datetime import datetime
Expand Down Expand Up @@ -85,7 +86,6 @@ def _check_required_metadata(required_metadata: dict):
def _write(self, document: dict, file_handler: IO, _filename: str):
passed_check = SwissAIJsonlWriter.check_document(document)
if not passed_check:
#TODO handle this better and give more descriptive feedback
raise ValueError('Document is not valid')
logger.warning(f'Document not valid: {str(document)}')

file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")

0 comments on commit a550e1d

Please sign in to comment.