Skip to content

Commit

Permalink
feat: latest models
Browse files Browse the repository at this point in the history
  • Loading branch information
apage224 committed Aug 10, 2023
1 parent abd53fa commit e28e5ac
Show file tree
Hide file tree
Showing 19 changed files with 10,796 additions and 12,544 deletions.
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ heartkit --task arrhythmia --mode train --config ./configs/train-arrhythmia-mode
The `evaluate` command will evaluate the performance of the model on the reserved test set. A confidence threshold can also be set such that a label is only assigned when the model's probability is greater than the threshold; otherwise, a label of inconclusive will be assigned.

```bash
heartkit --task arrhythmia --mode evaluate --config ./configs/test-arrhythmia-model.json
heartkit --task arrhythmia --mode evaluate --config ./configs/evaluate-arrhythmia-model.json
```

#### __4. Export Model__
Expand Down Expand Up @@ -137,12 +137,11 @@ HeartKit leverages several open-source datasets for training each of the HeartKi

The following table provides the latest performance and accuracy results of all models when running on Apollo4 Plus EVB.

| Task | Params | FLOPS | Metric |
| -------------- | -------- | ------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU |
| Arrhythmia | 50K | 3.6M | 99.0% F1 |
| Beat | 73K | 2.2M | 91.5% F1 |
| HRV | N/A | N/A | N/A |
| Task | Params | FLOPS | Metric | Cycles/Inf | Time/Inf |
| -------------- | -------- | ------- | ---------- | ---------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU | 531ms | 102M |
| Arrhythmia | 50K | 3.6M | 99.0% F1 | 465ms | 89M |
| Beat | 73K | 2.2M | 91.5% F1 | 241ms | 46M |


## References
Expand Down
2 changes: 1 addition & 1 deletion configs/evaluate-arrhythmia-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"samples_per_patient": [50, 400, 200],
"test_patients": 1000,
"test_size": 100000,
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"model_file": "./results/arrhythmia/model.tf",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"threshold": 0.75
}
2 changes: 1 addition & 1 deletion configs/evaluate-segmentation-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"frame_size": 512,
"samples_per_patient": 100,
"test_size": 10000,
"num_pts": 400,
"num_pts": 600,
"use_logits": false,
"datasets": ["synthetic", "ludb"],
"model_file": "./results/segmentation/model.tf",
Expand Down
2 changes: 1 addition & 1 deletion configs/export-arrhythmia-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"samples_per_patient": [5, 40, 20],
"test_patients": 1000,
"test_size": 10000,
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"model_file": "./results/arrhythmia/model.tf",
"model_file_rmt": "wandb://ambiq/model-registry/heartkit-arrhythmia:v1",
"quantization": true,
"use_logits": false,
"threshold": 0.80,
Expand Down
57 changes: 57 additions & 0 deletions configs/pretrain-segmentation-model.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"job_dir": "./results/segmentation-pre",
"ds_path": "./datasets",
"sampling_rate": 200,
"frame_size": 512,
"samples_per_patient": 100,
"val_samples_per_patient": 100,
"val_patients": 0.10,
"batch_size": 512,
"buffer_size": 25000,
"epochs": 80,
"steps_per_epoch": 100,
"val_metric": "loss",
"datasets": ["ludb", "synthetic"],
"lr_rate": 5e-3,
"num_pts": 600,
"quantization": false,
"augmentations": [
{
"name": "baseline_wander",
"args": {
"amplitude": [0.5, 1.5],
"frequency": [0.4, 0.5]
}
},
{
"name": "motion_noise",
"args": {
"amplitude": [0.5, 1.5],
"frequency": [0.4, 0.7]
}
},
{
"name": "burst_noise",
"args": {
"burst_number": [2, 10],
"amplitude": [0.5, 1.5],
"frequency": [40, 100]
}
},
{
"name": "powerline_noise",
"args": {
"amplitude": [0.005, 0.01],
"frequency": [50, 60]
}
},
{
"name": "noise_sources",
"args": {
"num_sources": [1, 8],
"amplitude": [0.05, 0.25],
"frequency": [10, 40]
}
}
]
}
2 changes: 1 addition & 1 deletion configs/train-beat-model.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"ds_path": "./datasets",
"sampling_rate": 200,
"frame_size": 160,
"samples_per_patient": [50, 400, 400],
"samples_per_patient": [100, 400, 400],
"val_samples_per_patient": [50, 100, 100],
"train_patients": 10000,
"val_file": "./results/beat-1000pt-800ms-200fs-4n.pkl",
Expand Down
8 changes: 5 additions & 3 deletions configs/train-segmentation-model.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"job_dir": "./results/segmentation",
"model_file": "./results/segmentation-pre/model.tf",
"ds_path": "./datasets",
"sampling_rate": 200,
"frame_size": 512,
Expand All @@ -8,12 +9,13 @@
"val_patients": 0.10,
"batch_size": 512,
"buffer_size": 25000,
"epochs": 50,
"epochs": 40,
"steps_per_epoch": 100,
"val_metric": "loss",
"datasets": ["ludb", "synthetic"],
"lr_rate": 5e-3,
"num_pts": 400,
"lr_rate": 1e-4,
"lr_cycles": 1,
"num_pts": 600,
"quantization": true,
"augmentations": [
{
Expand Down
4 changes: 1 addition & 3 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ The ECG segmentation model serves as the backbone and is used to annotate every

## HRV Head

The HRV head uses only DSP and statistics (i.e. no network is used). The segmentation results are stitched together and used to derive several useful metrics including heart rate, rhythm and RR interval.

The HRV head uses only DSP and statistics (i.e. no neural network is used). Using a combination of segmentation results and QRS filter, the HRV head detects R peak candidates. RR intervals are extracted and filtered, and then used to derive a variety of HRV metrics including heart rate, rhythm, SDNN, SDRR, SDANN, etc. All of the identified R peaks are further fed to the beat classifier head. Note that if segmentation model is not enabled, HRV head falls back to identifying R peaks purely on gradient of QRS signal.

## Arrhythmia Head

The arrhythmia head is used to detect the presence of Atrial Fibrillation (AFIB) or Atrial Flutter (AFL). Note that if heart arrhythmia is detected, the remaining heads are skipped. The arrhythmia model utilizes a 1-D CNN built using MBConv style blocks that incorporate expansion, inverted residuals, and squeeze and excitation layers. Furthermore, longer filter and stide lengths are utilized in the initial layers to capture more temporal dependencies.


## Beat Head

The beat head is used to extract individual beats and classify them as either normal, premature/ectopic atrial contraction (PAC), premature/ectopic ventricular contraction (PVC), or noise. In addition to the target beat, the surrounding beats are also fed into the network as context. The “neighboring” beats are determined based on the average RR interval and not the actual R peak. The beat head also utilizes a 1-D CNN built using MBConv style blocks.
12 changes: 6 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ HeartKit leverages several open-source datasets for training each of the HeartKi

The following table provides the latest performance and accuracy results of all models when running on Apollo4 Plus EVB. Additional result details can be found in [Results Section](./results.md).

| Task | Params | FLOPS | Metric |
| -------------- | -------- | ------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU |
| Arrhythmia | 50K | 3.6M | 99.0% F1 |
| Beat | 73K | 2.2M | 91.5% F1 |
| HRV | N/A | N/A | N/A |
| Task | Params | FLOPS | Metric | Cycles/Inf | Time/Inf |
| -------------- | -------- | ------- | ---------- | ---------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU | 531ms | 102M |
| Arrhythmia | 50K | 3.6M | 99.0% F1 | 465ms | 89M |
| Beat | 73K | 2.2M | 91.5% F1 | 241ms | 46M |


## References

Expand Down
26 changes: 13 additions & 13 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ The `train` command is used to train a HeartKit model. The following command wil
hk.arrhythmia.train_model(hk.defines.HeartTrainParams(
job_dir="./results/arrhythmia",
ds_path="./datasets",
sampling_rate=250,
frame_size=1000,
samples_per_patient=[100, 800, 400],
val_samples_per_patient=[100, 800, 400],
sampling_rate=200,
frame_size=800,
samples_per_patient=[100, 800, 800],
val_samples_per_patient=[100, 800, 800],
train_patients=10000,
val_patients=0.20,
val_size=100000,
batch_size=6144,
val_patients=0.10,
val_size=200000,
batch_size=256,
buffer_size=100000,
epochs=100,
steps_per_epoch=10,
steps_per_epoch=20,
val_metric="loss",
datasets=["icentia11k"]
))
Expand All @@ -130,9 +130,9 @@ The `evaluate` command will evaluate the performance of the model on the reserve
hk.arrhythmia.evaluate_model(hk.defines.HeartTestParams(
job_dir="./results/arrhythmia",
ds_path="./datasets",
sampling_rate=250,
frame_size=1000,
samples_per_patient=[100, 800, 400],
sampling_rate=200,
frame_size=800,
samples_per_patient=[100, 800, 800],
test_patients=1000,
test_size=100000,
model_file="./results/arrhythmia/model.tf",
Expand Down Expand Up @@ -160,8 +160,8 @@ The `export` command will convert the trained TensorFlow model into both TensorF
hk.arrhythmia.export_model(hk.defines.HeartExportParams(
job_dir="./results/arrhythmia",
ds_path="./datasets",
sampling_rate=250,
frame_size=1000,
sampling_rate=200,
frame_size=800,
samples_per_patient=[100, 500, 100],
model_file="./results/arrhythmia/model.tf",
quantization=true,
Expand Down
11 changes: 5 additions & 6 deletions docs/results.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

The following table provides performance and accuracy results of all models when running on Apollo4 Plus EVB.

| Task | Params | FLOPS | Metric |
| -------------- | -------- | ------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU |
| Arrhythmia | 50K | 3.6M | 99.0% F1 |
| Beat | 73K | 2.2M | 91.5% F1 |
| HRV | N/A | N/A | N/A |
| Task | Params | FLOPS | Metric | Cycles/Inf | Time/Inf |
| -------------- | -------- | ------- | ---------- | ---------- | ---------- |
| Segmentation | 33K | 6.5M | 87.0% IOU | 531ms | 102M |
| Arrhythmia | 50K | 3.6M | 99.0% F1 | 465ms | 89M |
| Beat | 73K | 2.2M | 91.5% F1 | 241ms | 46M |

## Segmentation Results

Expand Down
12 changes: 11 additions & 1 deletion docs/tutorials/heartkit-demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@ Please follow [EVB Setup Guide](./evb-setup.md) to prepare EVB and connect to PC

### 1. Train all the models

1.1 Train the segmentation model:
1.1 Train and fine-tune the segmentation model:

```bash
heartkit \
--task segmentation \
--mode train \
--config ./configs/pretrain-segmentation-model.json
```

```bash
heartkit \
Expand All @@ -33,6 +40,9 @@ heartkit \
--config ./configs/train-segmentation-model.json
```

!!! note
The second train command uses quantization-aware training to reduce accuracy drop when exporting to 8-bit.

1.2 Train the arrhythmia model:

```bash
Expand Down
Loading

0 comments on commit e28e5ac

Please sign in to comment.