From 408b5dac00dff78192ed239a77d83be4b710809f Mon Sep 17 00:00:00 2001 From: Adam Page Date: Fri, 11 Aug 2023 11:11:18 -0500 Subject: [PATCH] feat: latest version. --- configs/arrhythmia-demo.json | 2 +- configs/evaluate-arrhythmia-model.json | 2 +- configs/export-arrhythmia-model.json | 2 +- configs/heartkit-demo.json | 6 ++--- evb/src/constants.h | 4 ++-- evb/src/main.cc | 4 ++-- heartkit/datasets/qtdb.py | 27 ++++++++++++++++++++++ heartkit/demo/defines.py | 4 +++- heartkit/demo/evb.py | 31 +++++++++----------------- heartkit/demo/pc.py | 18 +++++++++------ pyproject.toml | 2 +- 11 files changed, 63 insertions(+), 39 deletions(-) diff --git a/configs/arrhythmia-demo.json b/configs/arrhythmia-demo.json index dbf2b78f..c93316bd 100644 --- a/configs/arrhythmia-demo.json +++ b/configs/arrhythmia-demo.json @@ -10,5 +10,5 @@ "vid_pid": "51966:16385", "baudrate": "115200", "data_parallelism": 1, - "datasets": ["icentia11k"] + "dataset": "icentia11k" } diff --git a/configs/evaluate-arrhythmia-model.json b/configs/evaluate-arrhythmia-model.json index 968e7bbf..83a8050b 100644 --- a/configs/evaluate-arrhythmia-model.json +++ b/configs/evaluate-arrhythmia-model.json @@ -7,6 +7,6 @@ "test_patients": 1000, "test_size": 100000, "model_file": "./results/arrhythmia/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1", + "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0", "threshold": 0.75 } diff --git a/configs/export-arrhythmia-model.json b/configs/export-arrhythmia-model.json index 2da9e4a9..47d5b73d 100644 --- a/configs/export-arrhythmia-model.json +++ b/configs/export-arrhythmia-model.json @@ -7,7 +7,7 @@ "test_patients": 1000, "test_size": 10000, "model_file": "./results/arrhythmia/model.tf", - "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1", + "model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0", "quantization": true, "use_logits": false, "threshold": 0.80, diff --git a/configs/heartkit-demo.json b/configs/heartkit-demo.json index 47f502b3..9548c57e 100644 --- a/configs/heartkit-demo.json +++ b/configs/heartkit-demo.json @@ -2,8 +2,8 @@ "job_dir": "./results/heartkit-demo", "ds_path": "./datasets", "rest_address": "http://0.0.0.0:8000/api/v1", - "backend": "pc", - "frontend": "web", + "backend": "evb", + "frontend": "console", "arrhythmia_model": "./results/arrhythmia/model.tf", "segmentation_model": "./results/segmentation/model.tf", "beat_model": "./results/beat/model.tf", @@ -13,5 +13,5 @@ "vid_pid": "51966:16385", "baudrate": "115200", "data_parallelism": 1, - "datasets": ["icentia11k"] + "dataset": "ludb" } diff --git a/evb/src/constants.h b/evb/src/constants.h index 837a21cf..a98bd286 100644 --- a/evb/src/constants.h +++ b/evb/src/constants.h @@ -37,7 +37,7 @@ #define SEGMENTATION_QUANTIZE (1) #define SEG_MODEL_SIZE_KB (85) #define SEG_FRAME_LEN (512) -#define SEG_OVERLAP_LEN (20) +#define SEG_OVERLAP_LEN (56) #define SEG_STEP_SIZE (SEG_FRAME_LEN - 2 * SEG_OVERLAP_LEN) #define SEG_THRESHOLD (0.50) @@ -46,7 +46,7 @@ #define BEAT_QUANTIZE (1) #define BEAT_MODEL_SIZE_KB (60) #define BEAT_FRAME_LEN (160) -#define BEAT_THRESHOLD (0.75) +#define BEAT_THRESHOLD (0.50) // App block #define DISPLAY_LEN_USEC (2000000) diff --git a/evb/src/main.cc b/evb/src/main.cc index d69f7e28..b6c7693f 100644 --- a/evb/src/main.cc +++ b/evb/src/main.cc @@ -69,7 +69,7 @@ apply_arrhythmia_model() { uint32_t apply_segmentation_model() { uint32_t err = 0; - for (size_t i = 0; i < HK_DATA_LEN - SEG_FRAME_LEN + 1; i += SEG_FRAME_LEN) { + for (size_t i = 0; i < HK_DATA_LEN - SEG_FRAME_LEN + 1; i += SEG_STEP_SIZE) { err = segmentation_inference(&hkStore.ecgData[i], &hkStore.segMask[i], SEG_OVERLAP_LEN, SEG_THRESHOLD); } err |= segmentation_inference(&hkStore.ecgData[HK_DATA_LEN - SEG_FRAME_LEN], &hkStore.segMask[HK_DATA_LEN - SEG_FRAME_LEN], @@ -462,7 +462,6 @@ loop() { case START_COLLECT_STATE: print_to_pc("COLLECT_STATE\n"); start_collecting(); - clear_collect_mode(); hkStore.state = COLLECT_STATE; break; @@ -477,6 +476,7 @@ loop() { case STOP_COLLECT_STATE: stop_collecting(); + clear_collect_mode(); hkStore.state = hkStore.errorCode != 0 ? FAIL_STATE : PREPROCESS_STATE; break; diff --git a/heartkit/datasets/qtdb.py b/heartkit/datasets/qtdb.py index 41d7119e..9a6b7c0e 100644 --- a/heartkit/datasets/qtdb.py +++ b/heartkit/datasets/qtdb.py @@ -290,7 +290,34 @@ def segmentation_generator( # if num_samples > samples_per_patient: # break # # END FOR + # END FOR + def signal_generator(self, patient_generator: PatientGenerator, samples_per_patient: int = 1) -> SampleGenerator: + """ + Generate frames using patient generator. + from the segments in patient data by placing a frame in a random location within one of the segments. + Args: + patient_generator (PatientGenerator): Generator that yields a tuple of patient id and patient data. + Patient data may contain only signals, since labels are not used. + samples_per_patient (int): Samples per patient. + Returns: + SampleGenerator: Generator of input data of shape (frame_size, 1) + """ + for _, pt in patient_generator: + data = pt["data"][:] + if self.sampling_rate != self.target_rate: + data = resample_signal(data, self.sampling_rate, self.target_rate) + # END IF + for _ in range(samples_per_patient): + lead_idx = np.random.randint(data.shape[1]) + if data.shape[0] > self.frame_size: + frame_start = np.random.randint(data.shape[0] - self.frame_size) + else: + frame_start = 0 + frame_end = frame_start + self.frame_size + x = data[frame_start:frame_end, lead_idx].astype(np.float32).reshape((self.frame_size,)) + yield x + # END FOR # END FOR def get_patient_data_segments(self, patient: int) -> tuple[npt.NDArray, npt.NDArray]: diff --git a/heartkit/demo/defines.py b/heartkit/demo/defines.py index 133ba7d3..0825021f 100644 --- a/heartkit/demo/defines.py +++ b/heartkit/demo/defines.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, Extra, Field +DatasetTypes = Literal["icentia11k", "ludb", "qtdb", "synthetic"] + class AppState(StrEnum): """HeartKit backend app state""" @@ -59,7 +61,7 @@ class HeartDemoParams(BaseModel, extra=Extra.allow): beat_model: str | None = Field(default=None, description="Beat TF model path") # Dataset arguments - datasets: list[str] = Field(default_factory=list, description="Dataset names") + dataset: DatasetTypes = Field(default="ludb", description="Dataset name") ds_path: Path = Field(default_factory=Path, description="Dataset directory") sampling_rate: int = Field(200, description="Target sampling rate (Hz)") frame_size: int = Field(2000, description="Frame size") diff --git a/heartkit/demo/evb.py b/heartkit/demo/evb.py index 1d4c904c..76177482 100644 --- a/heartkit/demo/evb.py +++ b/heartkit/demo/evb.py @@ -15,7 +15,7 @@ from neuralspot.rpc import GenericDataOperations_PcToEvb as gen_pc2evb from neuralspot.rpc.utils import get_serial_transport -from ..datasets import SyntheticDataset +from ..datasets import IcentiaDataset, LudbDataset, SyntheticDataset from .client import HKRestClient from .defines import AppState, HeartDemoParams, HeartKitState, HKResult from .utils import setup_logger @@ -98,25 +98,16 @@ def create_data_generator(self) -> Generator[npt.NDArray[np.float32], None, None Returns: Generator[npt.NDArray[np.float32], None, None]: Data generator """ - - def default_gen(): - while True: - yield np.random.rand(self.params.frame_size).astype(np.float32) - - datasets: list[str] = self.params.datasets - if len(datasets) == 0: - data_gen = default_gen() - elif "icentia11k" in datasets: - ds = SyntheticDataset( - ds_path=str(self.params.ds_path), - frame_size=self.params.frame_size, - target_rate=self.params.sampling_rate, - ) - pt_gen = ds.uniform_patient_generator(ds.get_test_patient_ids()) - data_gen = ds.signal_generator(pt_gen, samples_per_patient=self.params.samples_per_patient) - else: - raise ValueError(f"Unsupported dataset: {datasets}") - + data_handlers = dict(icentia11k=IcentiaDataset, synthetic=SyntheticDataset, ludb=LudbDataset) + logger.info(f"Loading dataset {self.params.dataset}") + DataHandler = data_handlers.get(self.params.dataset, LudbDataset) + ds = DataHandler( + ds_path=str(self.params.ds_path), + frame_size=self.params.frame_size, + target_rate=self.params.sampling_rate, + ) + pt_gen = ds.uniform_patient_generator(ds.get_test_patient_ids()) + data_gen = ds.signal_generator(pt_gen, samples_per_patient=self.params.samples_per_patient) return data_gen def ns_rpc_data_sendBlockToPC(self, block: gen_pc2evb.common.dataBlock): diff --git a/heartkit/demo/pc.py b/heartkit/demo/pc.py index 6fd937f5..f2298d71 100644 --- a/heartkit/demo/pc.py +++ b/heartkit/demo/pc.py @@ -9,7 +9,7 @@ from neuralspot.tflite.model import load_model -from ..datasets import LudbDataset +from ..datasets import IcentiaDataset, LudbDataset, QtdbDataset, SyntheticDataset from ..defines import HeartBeat, HeartRate, HeartSegment from ..signal import ( compute_rr_intervals, @@ -54,7 +54,10 @@ def create_data_generator(self) -> Generator[npt.NDArray[np.float32], None, None Returns: Generator[npt.NDArray[np.float32], None, None]: Data generator """ - ds = LudbDataset( + data_handlers = dict(icentia11k=IcentiaDataset, synthetic=SyntheticDataset, ludb=LudbDataset, qtdb=QtdbDataset) + logger.info(f"Loading dataset {self.params.dataset}") + DataHandler = data_handlers.get(self.params.dataset, LudbDataset) + ds = DataHandler( ds_path=str(self.params.ds_path), frame_size=self.params.frame_size, target_rate=self.params.sampling_rate, @@ -72,11 +75,12 @@ def load_models(self): if self.params.beat_model: self.beat_model = load_model(self.params.beat_model) - def preprocess(self, data: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + def preprocess(self, data: npt.NDArray[np.float32]) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]: """Perform pre-processing to data""" - data = filter_signal(data, lowcut=0.5, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0) data = normalize_signal(data, eps=0.1, axis=None) - return data + ecg_data = filter_signal(data, lowcut=0.5, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0) + qrs_data = filter_signal(data, lowcut=10, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0) + return ecg_data, qrs_data def arrhythmia_inference(self, data: npt.NDArray[np.float32], threshold: float = 0.75) -> npt.NDArray[np.uint8]: """Apply arrhythmia model to data. @@ -204,7 +208,7 @@ def run(self): # Pre-process self.update_app_state(AppState.PREPROCESS_STATE) - data = self.preprocess(data=data) + data, qrs_data = self.preprocess(data=data) # Inference self.update_app_state(AppState.INFERENCE_STATE) @@ -221,7 +225,7 @@ def run(self): seg_mask, _ = self.segmentation_inference(data, threshold=0.7) # Apply HRV model and extract R peaks - rpeaks, rr_ints, _ = self.hrv_inference(data, seg_mask) + rpeaks, rr_ints, _ = self.hrv_inference(qrs_data, seg_mask) bpm = 60 / (np.mean(rr_ints) / self.params.sampling_rate) # bpm, _, rpeaks = compute_hrv(data, qrs_mask, self.params.sampling_rate) avg_rr = max(0, int(self.params.sampling_rate / (bpm / 60))) diff --git a/pyproject.toml b/pyproject.toml index 5ae666a5..a7cf0ed9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "heartkit" -version = "1.0.0" +version = "1.1.0" description = "AI driven heart monitoring kit for ultra low-power wearables." license = "BSD-3-Clause" authors = ["Adam Page "]