Commit
fix fp16 overflow in totalsegmentator <3
FabianIsensee committed Oct 27, 2023
1 parent de48541 commit d78f17f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions nnunetv2/inference/predict_from_raw_data.py
@@ -607,7 +607,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
                                               device=results_device)
                 if self.use_gaussian:
                     gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
-                                                value_scaling_factor=1000,
+                                                value_scaling_factor=10,
                                                 device=results_device)
             except RuntimeError:
                 # sometimes the stuff is too large for GPUs. In that case fall back to CPU
@@ -620,7 +620,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
                                               device=results_device)
                 if self.use_gaussian:
                     gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
-                                                value_scaling_factor=1000,
+                                                value_scaling_factor=10,
                                                 device=results_device)
             finally:
                 empty_cache(self.device)
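
For context, a minimal sketch (not part of the commit) of why the larger scaling factor can overflow: the gaussian importance map is multiplied with the predicted logits during sliding-window aggregation, and when that happens in half precision the products can exceed the fp16 maximum of 65504. The logit magnitude of 80 below is an illustrative assumption:

import torch

# an fp16 logit of ~80 is plausible for a confident prediction
logit = torch.tensor(80., dtype=torch.float16)

# old peak weight (value_scaling_factor=1000) vs. new peak weight (value_scaling_factor=10)
print(logit * torch.tensor(1000., dtype=torch.float16))  # tensor(inf, dtype=torch.float16) -> 80000 > 65504
print(logit * torch.tensor(10., dtype=torch.float16))    # tensor(800., dtype=torch.float16) -> fits comfortably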

5 comments on commit d78f17f

@josegcpa

We were just checking a similar issue and the same was happening with us at fp32... so good fix :-)

@FabianIsensee
Member Author

This should not happen in fp32, as the value range is much higher. Are you certain this was what caused it? We only ran into the problem because we use fp16 here.
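
For reference, a quick check of the representable ranges (nothing nnU-Net-specific) shows the two dtypes differ by many orders of magnitude:

import torch
print(torch.finfo(torch.float16).max)  # 65504.0
print(torch.finfo(torch.float32).max)  # ~3.4028e+38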

@josegcpa

> This should not happen in fp32, as the value range is much higher. Are you certain this was what caused it? We only ran into the problem because we use fp16 here.

It was only in a few instances, but it still happened. I changed it from 1000 to 1 exactly where 1000 was changed to 10, and we stopped having this issue, so we are fairly certain that this is what caused it. In any case it is fixed, so that's the most important thing.

@FabianIsensee
Member Author

Glad this fixed it, then :-) Maybe we need to add an explicit inf check at some point in the pipeline to avoid future problems o.O

@josegcpa

I think maybe something like

# approximate dtype maxima: ~3.4e38 for fp32, ~6.5e4 for fp16 (the exact fp16 max is 65504)
max_value_precision = 3.4e38 if predicted_logits.dtype is torch.float32 else 6.5e4
# replace +/-inf entries with the largest finite value of the same sign
predicted_logits = torch.where(torch.isinf(predicted_logits), torch.sign(predicted_logits) * max_value_precision, predicted_logits)

would be an easy fix for it?
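
As a sketch, torch.finfo could supply the same limits without hard-coding the constants, which keeps the check correct for any dtype:

max_value_precision = torch.finfo(predicted_logits.dtype).max
predicted_logits = torch.where(torch.isinf(predicted_logits),
                               torch.sign(predicted_logits) * max_value_precision,
                               predicted_logits)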

A broader question I had here - why is there a scaling factor in the gaussian? Is the point to avoid underflow? If so, then maybe the gaussian could be 'adapted' to

gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                            value_scaling_factor=1000,
                            device=results_device)
# small positive floors near the bottom of the fp32 / fp16 ranges, so no weight underflows to exactly zero
minimum_precision_value = 1.2e-38 if gaussian.dtype is torch.float32 else 6e-8
gaussian = torch.clip(gaussian, minimum_precision_value)

and this would avoid the earlier solution, which is slightly more computationally intensive
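
Similarly, as a sketch, the clip floor could be taken from torch.finfo rather than hard-coded; note that finfo's tiny is the smallest normal value (about 6.1e-5 in fp16), so it is a slightly more conservative floor than the denormal 6e-8 above:

minimum_precision_value = torch.finfo(gaussian.dtype).tiny  # smallest positive normal value of the dtype
gaussian = torch.clip(gaussian, min=minimum_precision_value)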
