diff --git a/rinja_derive/Cargo.toml b/rinja_derive/Cargo.toml index 3614cb71..ffac3f1e 100644 --- a/rinja_derive/Cargo.toml +++ b/rinja_derive/Cargo.toml @@ -27,6 +27,7 @@ with-warp = [] [dependencies] parser = { package = "rinja_parser", version = "0.2.0", path = "../rinja_parser" } basic-toml = { version = "0.1.1", optional = true } +memchr = "2" mime = "0.3" mime_guess = "2" once_map = "0.4.18" diff --git a/rinja_derive/src/generator.rs b/rinja_derive/src/generator.rs index 3ad8547e..386fb84a 100644 --- a/rinja_derive/src/generator.rs +++ b/rinja_derive/src/generator.rs @@ -1950,19 +1950,19 @@ impl Buffer { } fn write_writer(&mut self, s: &str) -> usize { - if self.discard { - // nothing to do - } else if !self.last_was_write_str { - write!(self.buf, "writer.write_str({s:#?})?;").unwrap(); - self.last_was_write_str = true; - } else { - // strip trailing `")?;`, leaving an unterminated string - let len = self.buf.strip_suffix("\")?;").unwrap().len(); - self.buf.truncate(len); - // append the new string, adding a stray `"` in the mid of the string - write!(self.buf, "{s:#?})?;").unwrap(); - // left shift new string by one to overwrite the stray `"` - self.buf.replace_range(len..=len, ""); + const OPEN: &str = r#"writer.write_str(""#; + const CLOSE: &str = r#"")?;"#; + + if !self.discard { + if !self.last_was_write_str { + self.last_was_write_str = true; + self.buf.push_str(OPEN); + } else { + // strip trailing `")?;` (relies on `buf` ending with CLOSE), leaving an open string + self.buf.truncate(self.buf.len() - CLOSE.len()) + } + string_escape(&mut self.buf, s); + self.buf.push_str(CLOSE); } s.len() } @@ -2278,3 +2278,26 @@ fn normalize_identifier(ident: &str) -> &str { // SAFETY: We know that the input byte slice is pure-ASCII. unsafe { std::str::from_utf8_unchecked(&replacement[..ident.len() + 2]) } } + +/// Similar to `write!(dest, "{src:?}")`, but only escapes the strictly needed characters +/// (`\`, `"` and `\r`), and without the surrounding `"…"` quotation marks. 
+pub(crate) fn string_escape(dest: &mut String, src: &str) { + // SAFETY: we only push valid UTF-8: `src` is cut at ASCII bytes and the escapes are ASCII + let dest = unsafe { dest.as_mut_vec() }; + let src = src.as_bytes(); + let mut last = 0; + + // According to <https://doc.rust-lang.org/reference/tokens.html#string-literals>, every + // character is valid except `" \ IsolatedCR`. We don't test if the `\r` is isolated or not, + // but always escape it. + for x in memchr::memchr3_iter(b'\\', b'"', b'\r', src) { + dest.extend(&src[last..x]); + dest.extend(match src[x] { + b'\\' => br#"\\"#, + b'\"' => br#"\""#, + _ => br#"\r"#, + }); + last = x + 1; + } + dest.extend(&src[last..]); +} diff --git a/rinja_derive_standalone/Cargo.toml b/rinja_derive_standalone/Cargo.toml index cf9ab034..657daa61 100644 --- a/rinja_derive_standalone/Cargo.toml +++ b/rinja_derive_standalone/Cargo.toml @@ -26,6 +26,7 @@ with-warp = [] [dependencies] parser = { package = "rinja_parser", version = "0.2.0", path = "../rinja_parser" } basic-toml = { version = "0.1.1", optional = true } +memchr = "2" mime = "0.3" mime_guess = "2" once_map = "0.4.18"