[naga msl-out] Avoid UB by making all loops bounded. #6545

Merged on Nov 18, 2024 (1 commit)
96 changes: 47 additions & 49 deletions naga/src/back/msl/writer.rs
@@ -383,10 +383,10 @@ pub struct Writer<W> {
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,

/// Name of the loop reachability macro.
/// Name of the force-bounded-loop macro.
///
/// See `emit_loop_reachable_macro` for details.
loop_reachable_macro_name: String,
/// See `emit_force_bounded_loop_macro` for details.
force_bounded_loop_macro_name: String,
}

impl crate::Scalar {
@@ -682,7 +682,7 @@ impl<W: Write> Writer<W> {
#[cfg(test)]
put_block_stack_pointers: Default::default(),
struct_member_pads: FastHashSet::default(),
loop_reachable_macro_name: String::default(),
force_bounded_loop_macro_name: String::default(),
}
}

@@ -693,12 +693,13 @@ impl<W: Write> Writer<W> {
self.out
}

/// Define a macro to invoke before loops, to defeat MSL infinite loop
/// reasoning.
/// Define a macro to invoke at the bottom of each loop body, to
/// defeat MSL infinite loop reasoning.
///
/// If we haven't done so already, emit the definition of a preprocessor
/// macro to be invoked before each loop in the generated MSL, to ensure
/// that the MSL compiler's optimizations do not remove bounds checks.
/// macro to be invoked at the end of each loop body in the generated MSL,
/// to ensure that the MSL compiler's optimizations do not remove bounds
/// checks.
///
/// Only the first call to this function for a given module actually causes
/// the macro definition to be written. Subsequent loops can simply use the
@@ -764,52 +765,51 @@ impl<W: Write> Writer<W> {
/// nicely, after having stolen data from elsewhere in the GPU address
/// space.
///
/// Ideally, Naga would prevent UB entirely via some means that persuades
/// the MSL compiler that no loop Naga generates is infinite. One approach
/// would be to add inline assembly to each loop that is annotated as
/// potentially branching out of the loop, but which in fact generates no
/// instructions. Unfortunately, inline assembly is not handled correctly by
/// some Metal device drivers. Further experimentation hasn't produced a
/// satisfactory approach.
/// To avoid UB, Naga must persuade the MSL compiler that no loop Naga
/// generates is infinite. One approach would be to add inline assembly to
/// each loop that is annotated as potentially branching out of the loop,
/// but which in fact generates no instructions. Unfortunately, inline
/// assembly is not handled correctly by some Metal device drivers.
///
/// Instead, we accept that the MSL compiler may determine that some loops
/// are infinite, and focus instead on preventing the range analysis from
/// being affected. We transform *every* loop into something like this:
/// Instead, we add the following code to the bottom of every loop:
///
/// ```ignore
/// if (volatile bool unpredictable = true; unpredictable)
/// while (true) { }
/// if (volatile bool unpredictable = false; unpredictable)
/// break;
/// ```
///
/// Since the `volatile` qualifier prevents the compiler from assuming that
/// the `if` condition is true, it cannot be sure the infinite loop is
/// reached, and thus it cannot assume the entire structure is unreachable.
/// This prevents the range analysis impact described above.
/// Although the `if` condition will always be false in any real execution,
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
/// it must assume that the `break` might be reached, and hence that the
/// loop is not unbounded. This prevents the range analysis impact described
/// above.
///
/// Unfortunately, what makes this a kludge, not a hack, is that this
/// solution leaves the GPU executing a pointless conditional branch, at
/// runtime, before each loop. There's no part of the system that has a
/// global enough view to be sure that `unpredictable` is true, and remove
/// it from the code.
/// runtime, in every iteration of the loop. There's no part of the system
/// that has a global enough view to be sure that `unpredictable` is false,
/// and remove it from the code. Adding the branch also affects
/// optimization: for example, it's impossible to unroll this loop. This
/// transformation has been observed to significantly hurt performance.
///
/// To make our output a bit more legible, we pull the condition out into a
/// preprocessor macro defined at the top of the module.
///
/// This approach is also used by Chromium WebGPU's Dawn shader compiler, as of
/// <https://github.com/google/dawn/commit/ffd485c685040edb1e678165dcbf0e841cfa0298>.
fn emit_loop_reachable_macro(&mut self) -> BackendResult {
if !self.loop_reachable_macro_name.is_empty() {
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
if !self.force_bounded_loop_macro_name.is_empty() {
return Ok(());
}

self.loop_reachable_macro_name = self.namer.call("LOOP_IS_REACHABLE");
let loop_reachable_volatile_name = self.namer.call("unpredictable_jump_over_loop");
self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
writeln!(
self.out,
"#define {} if (volatile bool {} = true; {})",
self.loop_reachable_macro_name,
loop_reachable_volatile_name,
loop_reachable_volatile_name,
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
self.force_bounded_loop_macro_name,
loop_bounded_volatile_name,
loop_bounded_volatile_name,
)?;

Ok(())
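
The emit-once behavior of `emit_force_bounded_loop_macro` can be tried outside of Naga. What follows is a minimal sketch, not the actual writer: `MiniWriter` and its `String` output buffer are hypothetical stand-ins for `Writer<W>`, and the fixed macro and variable names are assumptions that replace the collision-free names `self.namer.call(...)` would return in the real code.

```rust
use std::fmt::Write;

/// Hypothetical stand-in for Naga's `Writer<W>`; only the parts needed to
/// show the emit-once pattern are modeled.
struct MiniWriter {
    out: String,
    /// Empty until the macro definition has been written.
    force_bounded_loop_macro_name: String,
}

impl MiniWriter {
    fn new() -> Self {
        Self {
            out: String::new(),
            force_bounded_loop_macro_name: String::new(),
        }
    }

    /// Write the `#define` on the first call only; later calls are no-ops.
    fn emit_force_bounded_loop_macro(&mut self) -> std::fmt::Result {
        if !self.force_bounded_loop_macro_name.is_empty() {
            return Ok(());
        }
        // Assumption: fixed names. The real writer asks its namer for
        // identifiers that cannot collide with names in the module.
        self.force_bounded_loop_macro_name = "LOOP_IS_BOUNDED".to_string();
        let volatile_name = "unpredictable_break_from_loop";
        writeln!(
            self.out,
            "#define {} {{ volatile bool {} = false; if ({}) break; }}",
            self.force_bounded_loop_macro_name, volatile_name, volatile_name,
        )
    }
}

fn main() {
    let mut w = MiniWriter::new();
    // Two loops in one module: the definition is still written only once.
    w.emit_force_bounded_loop_macro().unwrap();
    w.emit_force_bounded_loop_macro().unwrap();
    print!("{}", w.out);
}
```

Using the empty string as the "not yet emitted" state is why the per-module reset later in the diff only needs to call `clear()` on the field.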
@@ -3045,15 +3045,10 @@ impl<W: Write> Writer<W> {
ref continuing,
break_if,
} => {
self.emit_loop_reachable_macro()?;
if !continuing.is_empty() || break_if.is_some() {
let gate_name = self.namer.call("loop_init");
writeln!(self.out, "{level}bool {gate_name} = true;")?;
writeln!(
self.out,
"{level}{} while(true) {{",
self.loop_reachable_macro_name,
)?;
writeln!(self.out, "{level}while(true) {{",)?;
let lif = level.next();
let lcontinuing = lif.next();
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
@@ -3068,13 +3063,16 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "{lif}}}")?;
writeln!(self.out, "{lif}{gate_name} = false;")?;
} else {
writeln!(
self.out,
"{level}{} while(true) {{",
self.loop_reachable_macro_name,
)?;
writeln!(self.out, "{level}while(true) {{",)?;
}
self.put_block(level.next(), body, context)?;
self.emit_force_bounded_loop_macro()?;
writeln!(
self.out,
"{}{}",
level.next(),
self.force_bounded_loop_macro_name
)?;
writeln!(self.out, "{level}}}")?;
}
crate::Statement::Break => {
@@ -3553,7 +3551,7 @@ impl<W: Write> Writer<W> {
&[CLAMPED_LOD_LOAD_PREFIX],
&mut self.names,
);
self.loop_reachable_macro_name.clear();
self.force_bounded_loop_macro_name.clear();
self.struct_member_pads.clear();

writeln!(
14 changes: 9 additions & 5 deletions naga/tests/out/msl/atomicCompareExchange.msl
@@ -76,9 +76,8 @@ kernel void test_atomic_compare_exchange_i32_(
uint i = 0u;
int old = {};
bool exchanged = {};
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init) {
uint _e27 = i;
i = _e27 + 1u;
@@ -94,7 +93,7 @@ kernel void test_atomic_compare_exchange_i32_(
int _e8 = metal::atomic_load_explicit(&arr_i32_.inner[_e6], metal::memory_order_relaxed);
old = _e8;
exchanged = false;
LOOP_IS_REACHABLE while(true) {
while(true) {
bool _e12 = exchanged;
if (!(_e12)) {
} else {
@@ -109,8 +108,11 @@ kernel void test_atomic_compare_exchange_i32_(
old = _e23.old_value;
exchanged = _e23.exchanged;
}
#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
LOOP_IS_BOUNDED
}
}
LOOP_IS_BOUNDED
}
return;
}
@@ -123,7 +125,7 @@ kernel void test_atomic_compare_exchange_u32_(
uint old_1 = {};
bool exchanged_1 = {};
bool loop_init_1 = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init_1) {
uint _e27 = i_1;
i_1 = _e27 + 1u;
@@ -139,7 +141,7 @@ kernel void test_atomic_compare_exchange_u32_(
uint _e8 = metal::atomic_load_explicit(&arr_u32_.inner[_e6], metal::memory_order_relaxed);
old_1 = _e8;
exchanged_1 = false;
LOOP_IS_REACHABLE while(true) {
while(true) {
bool _e12 = exchanged_1;
if (!(_e12)) {
} else {
@@ -154,8 +156,10 @@ kernel void test_atomic_compare_exchange_u32_(
old_1 = _e23.old_value;
exchanged_1 = _e23.exchanged;
}
LOOP_IS_BOUNDED
}
}
LOOP_IS_BOUNDED
}
return;
}
5 changes: 3 additions & 2 deletions naga/tests/out/msl/boids.msl
@@ -55,9 +55,8 @@ kernel void main_(
vPos = _e8;
metal::float2 _e14 = particlesSrc.particles[index].vel;
vVel = _e14;
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init) {
uint _e91 = i;
i = _e91 + 1u;
@@ -106,6 +105,8 @@ kernel void main_(
int _e88 = cVelCount;
cVelCount = _e88 + 1;
}
#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
LOOP_IS_BOUNDED
}
int _e94 = cMassCount;
if (_e94 > 0) {
14 changes: 9 additions & 5 deletions naga/tests/out/msl/break-if.msl
@@ -7,15 +7,16 @@ using metal::uint;

void breakIfEmpty(
) {
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
bool loop_init = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init) {
if (true) {
break;
}
}
loop_init = false;
#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
LOOP_IS_BOUNDED
}
return;
}
@@ -26,7 +27,7 @@ void breakIfEmptyBody(
bool b = {};
bool c = {};
bool loop_init_1 = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init_1) {
b = a;
bool _e2 = b;
@@ -37,6 +38,7 @@ void breakIfEmptyBody(
}
}
loop_init_1 = false;
LOOP_IS_BOUNDED
}
return;
}
@@ -47,7 +49,7 @@ void breakIf(
bool d = {};
bool e = {};
bool loop_init_2 = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init_2) {
bool _e5 = e;
if (a_1 == e) {
@@ -58,6 +60,7 @@ void breakIf(
d = a_1;
bool _e2 = d;
e = a_1 != _e2;
LOOP_IS_BOUNDED
}
return;
}
@@ -66,7 +69,7 @@ void breakIfSeparateVariable(
) {
uint counter = 0u;
bool loop_init_3 = true;
LOOP_IS_REACHABLE while(true) {
while(true) {
if (!loop_init_3) {
uint _e5 = counter;
if (counter == 5u) {
@@ -76,6 +79,7 @@ void breakIfSeparateVariable(
loop_init_3 = false;
uint _e3 = counter;
counter = _e3 + 1u;
LOOP_IS_BOUNDED
}
return;
}
5 changes: 3 additions & 2 deletions naga/tests/out/msl/collatz.msl
@@ -19,8 +19,7 @@ uint collatz_iterations(
uint n = {};
uint i = 0u;
n = n_base;
#define LOOP_IS_REACHABLE if (volatile bool unpredictable_jump_over_loop = true; unpredictable_jump_over_loop)
LOOP_IS_REACHABLE while(true) {
while(true) {
uint _e4 = n;
if (_e4 > 1u) {
} else {
@@ -38,6 +37,8 @@ uint collatz_iterations(
uint _e20 = i;
i = _e20 + 1u;
}
#define LOOP_IS_BOUNDED { volatile bool unpredictable_break_from_loop = false; if (unpredictable_break_from_loop) break; }
LOOP_IS_BOUNDED
}
uint _e23 = i;
return _e23;
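
The updated snapshot files above share one shape: a plain `while(true) {`, the loop body, and then the macro invocation as the last line of the body. A rough sketch of that placement, assuming a hypothetical helper `write_bounded_loop` that is not part of Naga, four-space indentation per level (the indentation is not visible in this page), and a macro name passed in by the caller:

```rust
/// Render a `while(true)` loop in the shape the MSL snapshots show:
/// body statements first, then the macro invocation as the final line of
/// the body. `indent` is the nesting depth; four spaces per level is an
/// assumption made for this sketch.
fn write_bounded_loop(indent: usize, body: &[&str], macro_name: &str) -> String {
    let level = "    ".repeat(indent);
    let inner = "    ".repeat(indent + 1);
    let mut out = String::new();
    out.push_str(&format!("{level}while(true) {{\n"));
    for stmt in body {
        out.push_str(&format!("{inner}{stmt}\n"));
    }
    // The macro goes at the bottom of the body, one level deeper than the
    // `while` itself, so its `break` targets this loop.
    out.push_str(&format!("{inner}{macro_name}\n"));
    out.push_str(&format!("{level}}}\n"));
    out
}

fn main() {
    let text = write_bounded_loop(
        1,
        &["uint _e3 = counter;", "counter = _e3 + 1u;"],
        "LOOP_IS_BOUNDED",
    );
    print!("{text}");
}
```

Because the macro expands to a conditional `break`, it has to sit inside the loop it bounds, which is why nested loops in `atomicCompareExchange.msl` each get their own invocation.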