Merge pull request #127 from BrainLesion/aggregate_fix
Aggregate fix
Hendrik-code authored Sep 3, 2024
2 parents b174726 + ace99fa commit 7d0b12a
Showing 4 changed files with 73 additions and 15 deletions.
examples/example_spine_statistics.py (16 changes: 9 additions & 7 deletions)
@@ -25,7 +25,7 @@


if __name__ == "__main__":
-    parallel_opt = "None"  # none, pool, joblib, future
+    parallel_opt = "future"  # none, pool, joblib, future
    #
    parallel_opt = parallel_opt.lower()

@@ -42,15 +42,17 @@
        for i in range(4):
            results = evaluator.evaluate(prediction_mask, reference_mask, f"sample{i}")
    elif parallel_opt == "joblib":
-        Parallel(n_jobs=4, backend="threading")(
-            delayed(evaluator.evaluate)(prediction_mask, reference_mask)
-            for i in range(4)
+        Parallel(n_jobs=5, backend="threading")(
+            delayed(evaluator.evaluate)(prediction_mask, reference_mask, f"sample{i}")
+            for i in range(10)
        )
    elif parallel_opt == "future":
-        with ProcessPoolExecutor() as executor:
+        with ProcessPoolExecutor(max_workers=5) as executor:
            futures = {
-                executor.submit(evaluator.evaluate, prediction_mask, reference_mask)
-                for i in range(4)
+                executor.submit(
+                    evaluator.evaluate, prediction_mask, reference_mask, f"sample{i}"
+                )
+                for i in range(10)
            }
            for future in tqdm(
                as_completed(futures), total=len(futures), desc="Panoptica Evaluation"
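
For context, a minimal sketch of how this example's results could be consumed once the parallel loop finishes, assuming evaluator is the Panoptica_Aggregator used throughout the example; the follow-up calls below are an illustration, not part of the commit:

    # Hypothetical follow-up: turn the aggregator's output buffer into
    # statistics and print them, using the ndigits parameter added in this PR.
    statistic = evaluator.make_statistic()
    statistic.print_summary(ndigits=3)
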
panoptica/configs/panoptica_evaluator_VERSE.yaml (34 changes: 34 additions & 0 deletions)
@@ -0,0 +1,34 @@
+!Panoptica_Evaluator
+decision_metric: null
+decision_threshold: null
+edge_case_handler: !EdgeCaseHandler
+  empty_list_std: !EdgeCaseResult NAN
+  listmetric_zeroTP_handling:
+    !Metric DSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric clDSC: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric IOU: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult ZERO,
+      empty_reference_result: !EdgeCaseResult ZERO, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult ZERO}
+    !Metric ASSD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult INF,
+      empty_reference_result: !EdgeCaseResult INF, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult INF}
+    !Metric RVD: !MetricZeroTPEdgeCaseHandling {empty_prediction_result: !EdgeCaseResult NAN,
+      empty_reference_result: !EdgeCaseResult NAN, no_instances_result: !EdgeCaseResult NAN,
+      normal: !EdgeCaseResult NAN}
+instance_metrics: [!Metric DSC, !Metric IOU, !Metric ASSD, !Metric RVD]
+global_metrics: [!Metric DSC, !Metric RVD, !Metric IOU]
+expected_input: !InputType UNMATCHED_INSTANCE
+instance_approximator: null
+instance_matcher: !NaiveThresholdMatching {allow_many_to_one: false, matching_metric: !Metric IOU,
+  matching_threshold: 0.5}
+log_times: false
+segmentation_class_groups: !SegmentationClassGroups
+  groups:
+    vertebra: !LabelGroup
+      single_instance: false
+      value_labels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26]
+verbose: false
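
A hedged sketch of loading this new VERSE config; panoptica round-trips evaluators through YAML files like the one above, but the loader entry point named here (load_from_config) is an assumption, not something this commit shows:

    from panoptica import Panoptica_Evaluator

    # Assumed loader name; adjust to the installed panoptica API.
    evaluator = Panoptica_Evaluator.load_from_config(
        "panoptica/configs/panoptica_evaluator_VERSE.yaml"
    )
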
panoptica/panoptica_aggregator.py (4 changes: 3 additions & 1 deletion)
@@ -85,7 +85,8 @@ def __init__(
        atexit.register(self.__exist_handler)

    def __exist_handler(self):
-        os.remove(self.__output_buffer_file)
+        if os.path.exists(self.__output_buffer_file):
+            os.remove(self.__output_buffer_file)

    def make_statistic(self) -> Panoptica_Statistic:
        with filelock:
@@ -120,6 +121,7 @@ def evaluate(
            _write_content(self.__output_buffer_file, [[subject_name]])

        # Run Evaluation (allowed in parallel)
+        print(f"Call evaluate on {subject_name}")
        res = self.__panoptica_evaluator.evaluate(
            prediction_arr,
            reference_arr,
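
The os.path.exists guard above matters because __exist_handler is registered with atexit: by the time it runs, the buffer file may already be gone (removed, for instance, by another process sharing it), and the bare os.remove then raised FileNotFoundError at interpreter shutdown. A generic sketch of the pattern, independent of panoptica:

    import atexit
    import os

    def _cleanup(path: str) -> None:
        # Remove the buffer only if it still exists; another process or an
        # earlier exit handler may already have deleted it.
        if os.path.exists(path):
            os.remove(path)

    atexit.register(_cleanup, "output_buffer.tsv")  # hypothetical path

A more defensive variant would wrap the removal in contextlib.suppress(FileNotFoundError), which also closes the small window between the existence check and the removal.
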
panoptica/panoptica_statistics.py (34 changes: 27 additions & 7 deletions)
@@ -74,7 +74,10 @@ def from_file(cls, file: str):
                if group_name not in value_dict:
                    value_dict[group_name] = {m: [] for m in metric_names}

-                value_dict[group_name][metric_name].append(float(value))
+                if len(value) > 0:
+                    value = float(value)
+                    if not np.isnan(value) and value != np.inf:
+                        value_dict[group_name][metric_name].append(float(value))

        return Panoptica_Statistic(subj_names=subj_names, value_dict=value_dict)

@@ -157,14 +160,14 @@ def avg_std(self, group, metric) -> tuple[float, float]:
        std = float(np.std(values))
        return (avg, std)

-    def print_summary(self):
+    def print_summary(self, ndigits: int = 3):
        summary = self.get_summary_dict()
        print()
        for g in self.__groupnames:
            print(f"Group {g}:")
            for m in self.__metricnames:
                avg, std = summary[g][m]
-                print(m, ":", avg, "+-", std)
+                print(m, ":", round(avg, ndigits), "+-", round(std, ndigits))
        print()

    def get_summary_figure(
@@ -201,12 +204,18 @@ def make_curve_over_setups(
    statistics_dict: dict[str | int | float, Panoptica_Statistic],
    metric: str,
    groups: list[str] | str | None = None,
+    alternate_groupnames: list[str] | str | None = None,
+    fig: None = None,
+    plot_dotsize: int | None = None,
+    plot_lines: bool = True,
):
    if groups is None:
        groups = list(statistics_dict.values())[0].groupnames
    #
    if isinstance(groups, str):
        groups = [groups]
+    if isinstance(alternate_groupnames, str):
+        alternate_groupnames = [alternate_groupnames]
    #
    for setupname, stat in statistics_dict.items():
        assert (
@@ -226,16 +235,27 @@
    else:
        X = range(len(setupnames))

-    fig = plt.figure()
+    if fig is None:
+        fig = plt.figure()

    if not convert_x_to_digit:
        plt.xticks(X, setupnames)

-    plt.ylabel("average " + metric)
+    plt.ylabel("Average " + metric)
    plt.grid("major")
    # Y values are average metric values in that group and metric
-    for g in groups:
+    for idx, g in enumerate(groups):
        Y = [stat.avg_std(g, metric)[0] for stat in statistics_dict.values()]
-        plt.plot(X, Y, label=g)
+
+        if plot_lines:
+            plt.plot(
+                X,
+                Y,
+                label=g if alternate_groupnames is None else alternate_groupnames[idx],
+            )
+
+        if plot_dotsize is not None:
+            plt.scatter(X, Y, s=plot_dotsize)

    plt.legend()
    return fig
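
A hedged usage sketch of the extended make_curve_over_setups; the .tsv paths and setup names are hypothetical, while from_file, the vertebra group, and the DSC metric all appear in this PR's diffs:

    from panoptica.panoptica_statistics import (
        Panoptica_Statistic,
        make_curve_over_setups,
    )

    # Hypothetical result files written by Panoptica_Aggregator.
    stats = {
        "baseline": Panoptica_Statistic.from_file("baseline.tsv"),
        "aggregate_fix": Panoptica_Statistic.from_file("aggregate_fix.tsv"),
    }

    fig = make_curve_over_setups(
        stats,
        metric="DSC",
        groups="vertebra",  # a single group name is broadcast to a list
        alternate_groupnames="Vertebrae",  # new: legend label replacing the group key
        plot_dotsize=12,  # new: draw a marker at each setup
        plot_lines=True,  # new: keep the connecting line
    )
    fig.savefig("dsc_over_setups.png")
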
