Commit

feat: latest version.
apage224 committed Aug 11, 2023
1 parent 8e00452 commit 408b5da
Showing 11 changed files with 63 additions and 39 deletions.
configs/arrhythmia-demo.json (2 changes: 1 addition & 1 deletion)
@@ -10,5 +10,5 @@
"vid_pid": "51966:16385",
"baudrate": "115200",
"data_parallelism": 1,
"datasets": ["icentia11k"]
"dataset": "icentia11k"
}
configs/evaluate-arrhythmia-model.json (2 changes: 1 addition & 1 deletion)
@@ -7,6 +7,6 @@
"test_patients": 1000,
"test_size": 100000,
"model_file": "./results/arrhythmia/model.tf",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0",
"threshold": 0.75
}
configs/export-arrhythmia-model.json (2 changes: 1 addition & 1 deletion)
@@ -7,7 +7,7 @@
"test_patients": 1000,
"test_size": 10000,
"model_file": "./results/arrhythmia/model.tf",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v0",
"quantization": true,
"use_logits": false,
"threshold": 0.80,
configs/heartkit-demo.json (6 changes: 3 additions & 3 deletions)
@@ -2,8 +2,8 @@
"job_dir": "./results/heartkit-demo",
"ds_path": "./datasets",
"rest_address": "http://0.0.0.0:8000/api/v1",
"backend": "pc",
"frontend": "web",
"backend": "evb",
"frontend": "console",
"arrhythmia_model": "./results/arrhythmia/model.tf",
"segmentation_model": "./results/segmentation/model.tf",
"beat_model": "./results/beat/model.tf",
@@ -13,5 +13,5 @@
"vid_pid": "51966:16385",
"baudrate": "115200",
"data_parallelism": 1,
"datasets": ["icentia11k"]
"dataset": "ludb"
}
evb/src/constants.h (4 changes: 2 additions & 2 deletions)
@@ -37,7 +37,7 @@
#define SEGMENTATION_QUANTIZE (1)
#define SEG_MODEL_SIZE_KB (85)
#define SEG_FRAME_LEN (512)
- #define SEG_OVERLAP_LEN (20)
+ #define SEG_OVERLAP_LEN (56)
#define SEG_STEP_SIZE (SEG_FRAME_LEN - 2 * SEG_OVERLAP_LEN)
#define SEG_THRESHOLD (0.50)

@@ -46,7 +46,7 @@
#define BEAT_QUANTIZE (1)
#define BEAT_MODEL_SIZE_KB (60)
#define BEAT_FRAME_LEN (160)
- #define BEAT_THRESHOLD (0.75)
+ #define BEAT_THRESHOLD (0.50)

// App block
#define DISPLAY_LEN_USEC (2000000)
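
A quick sanity check of the derived step size (not part of the diff): with the new overlap, SEG_STEP_SIZE = SEG_FRAME_LEN - 2 * SEG_OVERLAP_LEN = 512 - 2 * 56 = 400 samples, so each 512-sample segmentation window advances by 400 samples and keeps 56 samples of context on either edge; the old overlap of 20 gave a step of 472.

    # Illustrative check only; constants copied from constants.h above
    SEG_FRAME_LEN = 512
    SEG_OVERLAP_LEN = 56  # this commit; previously 20
    SEG_STEP_SIZE = SEG_FRAME_LEN - 2 * SEG_OVERLAP_LEN
    print(SEG_STEP_SIZE)  # 400 (was 472 with the old overlap)
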
evb/src/main.cc (4 changes: 2 additions & 2 deletions)
@@ -69,7 +69,7 @@ apply_arrhythmia_model() {
uint32_t
apply_segmentation_model() {
uint32_t err = 0;
- for (size_t i = 0; i < HK_DATA_LEN - SEG_FRAME_LEN + 1; i += SEG_FRAME_LEN) {
+ for (size_t i = 0; i < HK_DATA_LEN - SEG_FRAME_LEN + 1; i += SEG_STEP_SIZE) {
err = segmentation_inference(&hkStore.ecgData[i], &hkStore.segMask[i], SEG_OVERLAP_LEN, SEG_THRESHOLD);
}
err |= segmentation_inference(&hkStore.ecgData[HK_DATA_LEN - SEG_FRAME_LEN], &hkStore.segMask[HK_DATA_LEN - SEG_FRAME_LEN],
@@ -462,7 +462,6 @@ loop() {
case START_COLLECT_STATE:
print_to_pc("COLLECT_STATE\n");
start_collecting();
- clear_collect_mode();
hkStore.state = COLLECT_STATE;
break;

@@ -477,6 +476,7 @@

case STOP_COLLECT_STATE:
stop_collecting();
+ clear_collect_mode();
hkStore.state = hkStore.errorCode != 0 ? FAIL_STATE : PREPROCESS_STATE;
break;

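
The stride fix above matters because segmentation windows are meant to overlap: stepping by SEG_FRAME_LEN would place windows end-to-end, while SEG_STEP_SIZE advances by the frame length minus both overlap margins, presumably so each window's edge regions are also covered by a neighbor's interior. The extra segmentation_inference call after the loop handles the tail of the buffer. A small placement sketch in Python (HK_DATA_LEN here is a hypothetical value chosen only for illustration):

    # Window placement with the corrected stride (values illustrative)
    HK_DATA_LEN = 2000            # hypothetical buffer length, for illustration only
    SEG_FRAME_LEN = 512
    SEG_OVERLAP_LEN = 56
    SEG_STEP_SIZE = SEG_FRAME_LEN - 2 * SEG_OVERLAP_LEN   # 400

    starts = list(range(0, HK_DATA_LEN - SEG_FRAME_LEN + 1, SEG_STEP_SIZE))
    starts.append(HK_DATA_LEN - SEG_FRAME_LEN)             # final tail window, as in the code above
    print(starts)  # [0, 400, 800, 1200, 1488] -> overlapping 512-sample windows spanning the buffer
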
heartkit/datasets/qtdb.py (27 changes: 27 additions & 0 deletions)
@@ -290,7 +290,34 @@ def segmentation_generator(
# if num_samples > samples_per_patient:
# break
# # END FOR
# END FOR

+     def signal_generator(self, patient_generator: PatientGenerator, samples_per_patient: int = 1) -> SampleGenerator:
+         """Generate random signal frames from patient data by placing each frame at a
+         random location within the patient's record.
+
+         Args:
+             patient_generator (PatientGenerator): Generator that yields a tuple of patient id and patient data.
+                 Patient data may contain only signals, since labels are not used.
+             samples_per_patient (int): Samples per patient.
+
+         Returns:
+             SampleGenerator: Generator of input data of shape (frame_size,)
+         """
+         for _, pt in patient_generator:
+             data = pt["data"][:]
+             if self.sampling_rate != self.target_rate:
+                 data = resample_signal(data, self.sampling_rate, self.target_rate)
+             # END IF
+             for _ in range(samples_per_patient):
+                 lead_idx = np.random.randint(data.shape[1])
+                 if data.shape[0] > self.frame_size:
+                     frame_start = np.random.randint(data.shape[0] - self.frame_size)
+                 else:
+                     frame_start = 0
+                 frame_end = frame_start + self.frame_size
+                 x = data[frame_start:frame_end, lead_idx].astype(np.float32).reshape((self.frame_size,))
+                 yield x
+             # END FOR
+         # END FOR

def get_patient_data_segments(self, patient: int) -> tuple[npt.NDArray, npt.NDArray]:
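
For context, a minimal sketch of how the new signal_generator might be driven, mirroring the calls shown in heartkit/demo/evb.py and pc.py below (constructor arguments and helper names are taken from that code; the path and sizes here are placeholders):

    # Hypothetical usage sketch; argument values are placeholders
    from heartkit.datasets import QtdbDataset

    ds = QtdbDataset(ds_path="./datasets", frame_size=2000, target_rate=200)
    pt_gen = ds.uniform_patient_generator(ds.get_test_patient_ids())
    for frame in ds.signal_generator(pt_gen, samples_per_patient=2):
        print(frame.shape, frame.dtype)  # (2000,) float32, one randomly chosen lead per frame
        break
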
heartkit/demo/defines.py (4 changes: 3 additions & 1 deletion)
@@ -6,6 +6,8 @@

from pydantic import BaseModel, Extra, Field

+ DatasetTypes = Literal["icentia11k", "ludb", "qtdb", "synthetic"]


class AppState(StrEnum):
"""HeartKit backend app state"""
@@ -59,7 +61,7 @@ class HeartDemoParams(BaseModel, extra=Extra.allow):
beat_model: str | None = Field(default=None, description="Beat TF model path")

# Dataset arguments
- datasets: list[str] = Field(default_factory=list, description="Dataset names")
+ dataset: DatasetTypes = Field(default="ludb", description="Dataset name")
ds_path: Path = Field(default_factory=Path, description="Dataset directory")
sampling_rate: int = Field(200, description="Target sampling rate (Hz)")
frame_size: int = Field(2000, description="Frame size")
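
A minimal sketch of how the new singular dataset field is consumed, assuming the demo loads one of the JSON configs above directly into the pydantic model (the loading code itself is not part of this diff):

    # Sketch only: parse a demo config into the updated params model
    import json

    from heartkit.demo.defines import HeartDemoParams

    with open("configs/heartkit-demo.json", "r", encoding="utf-8") as f:
        params = HeartDemoParams(**json.load(f))
    print(params.dataset)  # "ludb"; must be one of the DatasetTypes literals, else validation fails
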
heartkit/demo/evb.py (31 changes: 11 additions & 20 deletions)
@@ -15,7 +15,7 @@
from neuralspot.rpc import GenericDataOperations_PcToEvb as gen_pc2evb
from neuralspot.rpc.utils import get_serial_transport

- from ..datasets import SyntheticDataset
+ from ..datasets import IcentiaDataset, LudbDataset, SyntheticDataset
from .client import HKRestClient
from .defines import AppState, HeartDemoParams, HeartKitState, HKResult
from .utils import setup_logger
@@ -98,25 +98,16 @@ def create_data_generator(self) -> Generator[npt.NDArray[np.float32], None, None
Returns:
Generator[npt.NDArray[np.float32], None, None]: Data generator
"""

-         def default_gen():
-             while True:
-                 yield np.random.rand(self.params.frame_size).astype(np.float32)
-
-         datasets: list[str] = self.params.datasets
-         if len(datasets) == 0:
-             data_gen = default_gen()
-         elif "icentia11k" in datasets:
-             ds = SyntheticDataset(
-                 ds_path=str(self.params.ds_path),
-                 frame_size=self.params.frame_size,
-                 target_rate=self.params.sampling_rate,
-             )
-             pt_gen = ds.uniform_patient_generator(ds.get_test_patient_ids())
-             data_gen = ds.signal_generator(pt_gen, samples_per_patient=self.params.samples_per_patient)
-         else:
-             raise ValueError(f"Unsupported dataset: {datasets}")
+         data_handlers = dict(icentia11k=IcentiaDataset, synthetic=SyntheticDataset, ludb=LudbDataset)
+         logger.info(f"Loading dataset {self.params.dataset}")
+         DataHandler = data_handlers.get(self.params.dataset, LudbDataset)
+         ds = DataHandler(
+             ds_path=str(self.params.ds_path),
+             frame_size=self.params.frame_size,
+             target_rate=self.params.sampling_rate,
+         )
+         pt_gen = ds.uniform_patient_generator(ds.get_test_patient_ids())
+         data_gen = ds.signal_generator(pt_gen, samples_per_patient=self.params.samples_per_patient)
          return data_gen

def ns_rpc_data_sendBlockToPC(self, block: gen_pc2evb.common.dataBlock):
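
One detail of the lookup above worth noting: data_handlers.get() falls back to LudbDataset instead of raising, so a dataset value not registered by a backend (for example "qtdb" on the EVB side, which only maps icentia11k, synthetic, and ludb) silently resolves to LUDB. A tiny illustration of that fallback with stand-in classes:

    # Stand-in classes; only the dict.get fallback behavior is being shown
    class LudbDataset: ...
    class IcentiaDataset: ...
    class SyntheticDataset: ...

    data_handlers = dict(icentia11k=IcentiaDataset, synthetic=SyntheticDataset, ludb=LudbDataset)
    DataHandler = data_handlers.get("qtdb", LudbDataset)  # name not registered -> fallback
    print(DataHandler.__name__)  # LudbDataset
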
heartkit/demo/pc.py (18 changes: 11 additions & 7 deletions)
@@ -9,7 +9,7 @@

from neuralspot.tflite.model import load_model

- from ..datasets import LudbDataset
+ from ..datasets import IcentiaDataset, LudbDataset, QtdbDataset, SyntheticDataset
from ..defines import HeartBeat, HeartRate, HeartSegment
from ..signal import (
compute_rr_intervals,
@@ -54,7 +54,10 @@ def create_data_generator(self) -> Generator[npt.NDArray[np.float32], None, None
Returns:
Generator[npt.NDArray[np.float32], None, None]: Data generator
"""
-         ds = LudbDataset(
+         data_handlers = dict(icentia11k=IcentiaDataset, synthetic=SyntheticDataset, ludb=LudbDataset, qtdb=QtdbDataset)
+         logger.info(f"Loading dataset {self.params.dataset}")
+         DataHandler = data_handlers.get(self.params.dataset, LudbDataset)
+         ds = DataHandler(
              ds_path=str(self.params.ds_path),
              frame_size=self.params.frame_size,
              target_rate=self.params.sampling_rate,
@@ -72,11 +75,12 @@ def load_models(self):
if self.params.beat_model:
self.beat_model = load_model(self.params.beat_model)

-     def preprocess(self, data: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]:
+     def preprocess(self, data: npt.NDArray[np.float32]) -> tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
          """Perform pre-processing to data"""
-         data = filter_signal(data, lowcut=0.5, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0)
-         data = normalize_signal(data, eps=0.1, axis=None)
-         return data
+         ecg_data = filter_signal(data, lowcut=0.5, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0)
+         qrs_data = filter_signal(data, lowcut=10, highcut=30, order=3, sample_rate=self.params.sampling_rate, axis=0)
+         return ecg_data, qrs_data

def arrhythmia_inference(self, data: npt.NDArray[np.float32], threshold: float = 0.75) -> npt.NDArray[np.uint8]:
"""Apply arrhythmia model to data.
@@ -204,7 +208,7 @@ def run(self):

# Pre-process
self.update_app_state(AppState.PREPROCESS_STATE)
- data = self.preprocess(data=data)
+ data, qrs_data = self.preprocess(data=data)

# Inference
self.update_app_state(AppState.INFERENCE_STATE)
@@ -221,7 +225,7 @@
seg_mask, _ = self.segmentation_inference(data, threshold=0.7)

# Apply HRV model and extract R peaks
- rpeaks, rr_ints, _ = self.hrv_inference(data, seg_mask)
+ rpeaks, rr_ints, _ = self.hrv_inference(qrs_data, seg_mask)
bpm = 60 / (np.mean(rr_ints) / self.params.sampling_rate)
# bpm, _, rpeaks = compute_hrv(data, qrs_mask, self.params.sampling_rate)
avg_rr = max(0, int(self.params.sampling_rate / (bpm / 60)))
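
The preprocess change above now produces two filtered views of the incoming signal: a 0.5-30 Hz band used by the segmentation and arrhythmia models, and a narrower 10-30 Hz band that run() now passes to hrv_inference, presumably to emphasize QRS energy for R-peak picking. A rough stand-alone sketch of the same idea using SciPy directly (this assumes filter_signal behaves like a zero-phase Butterworth band-pass; it is not the heartkit implementation itself):

    # Sketch of the two-band pre-processing under the stated assumption
    import numpy as np
    from scipy import signal

    def two_band_preprocess(data: np.ndarray, fs: int = 200) -> tuple[np.ndarray, np.ndarray]:
        ecg_sos = signal.butter(3, [0.5, 30], btype="bandpass", fs=fs, output="sos")
        qrs_sos = signal.butter(3, [10, 30], btype="bandpass", fs=fs, output="sos")
        ecg_data = signal.sosfiltfilt(ecg_sos, data, axis=0)  # full ECG band for segmentation/arrhythmia
        qrs_data = signal.sosfiltfilt(qrs_sos, data, axis=0)  # QRS-weighted band for R-peak detection
        return ecg_data.astype(np.float32), qrs_data.astype(np.float32)
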
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
[tool.poetry]
name = "heartkit"
version = "1.0.0"
version = "1.1.0"
description = "AI driven heart monitoring kit for ultra low-power wearables."
license = "BSD-3-Clause"
authors = ["Adam Page <adam.page@ambiq.com>"]
