diff --git a/rinja/src/html.rs b/rinja/src/html.rs
index 45c1d727..bbba830e 100644
--- a/rinja/src/html.rs
+++ b/rinja/src/html.rs
@@ -1,71 +1,131 @@
-use std::fmt;
-use std::num::NonZeroU8;
+use std::{fmt, str};
#[allow(unused)]
-pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
- let mut escaped_buf = *b"&#__;";
+pub(crate) fn write_escaped_str(mut dest: impl fmt::Write, src: &str) -> fmt::Result {
+ // This implementation reads one byte after another.
+ // It's not very fast, but should work well enough until portable SIMD gets stabilized.
+
+ let mut escaped_buf = ESCAPED_BUF_INIT;
let mut last = 0;
- for (index, byte) in string.bytes().enumerate() {
- let escaped = match byte {
- MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
- _ => None,
- };
- if let Some(escaped) = escaped {
- escaped_buf[2] = escaped[0].get();
- escaped_buf[3] = escaped[1].get();
- fmt.write_str(&string[last..index])?;
- fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
+ for (index, byte) in src.bytes().enumerate() {
+ if let Some(escaped) = get_escaped(byte) {
+ [escaped_buf[2], escaped_buf[3]] = escaped;
+ write_str_if_nonempty(&mut dest, &src[last..index])?;
+ // SAFETY: the content of `escaped_buf` is pure ASCII
+ dest.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })?;
last = index + 1;
}
}
- fmt.write_str(&string[last..])
+ write_str_if_nonempty(&mut dest, &src[last..])
}
#[allow(unused)]
-pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
- fmt.write_str(match (c.is_ascii(), c as u8) {
- (true, b'"') => "&#34;",
- (true, b'&') => "&#38;",
- (true, b'\'') => "&#39;",
- (true, b'<') => "&#60;",
- (true, b'>') => "&#62;",
- _ => return fmt.write_char(c),
- })
+pub(crate) fn write_escaped_char(mut dest: impl fmt::Write, c: char) -> fmt::Result {
+ if !c.is_ascii() {
+ dest.write_char(c)
+ } else if let Some(escaped) = get_escaped(c as u8) {
+ let mut escaped_buf = ESCAPED_BUF_INIT;
+ [escaped_buf[2], escaped_buf[3]] = escaped;
+ // SAFETY: the content of `escaped_buf` is pure ASCII
+ dest.write_str(unsafe { str::from_utf8_unchecked(&escaped_buf[..ESCAPED_BUF_LEN]) })
+ } else {
+ // RATIONALE: `write_char(c)` gets optimized if it is known that `c.is_ascii()`
+ dest.write_char(c)
+ }
}
-const MIN_CHAR: u8 = b'"';
-const MAX_CHAR: u8 = b'>';
+/// If the byte needs HTML escaping, return its codepoint as two ASCII decimal digits, else `None`.
+#[inline(always)]
+fn get_escaped(byte: u8) -> Option<[u8; 2]> {
+ match byte {
+ MIN_CHAR..=MAX_CHAR => match TABLE.lookup[(byte - MIN_CHAR) as usize] {
+ 0 => None,
+ escaped => Some(escaped.to_ne_bytes()),
+ },
+ _ => None,
+ }
+}
+
+#[inline(always)]
+fn write_str_if_nonempty(output: &mut impl fmt::Write, input: &str) -> fmt::Result {
+ if !input.is_empty() {
+ output.write_str(input)
+ } else {
+ Ok(())
+ }
+}
+
+/// List of characters that need HTML escaping, not necessarily in ordinal order.
+const CHARS: &[u8] = br#""&'<>"#;
+
+/// The character with the lowest codepoint that needs HTML escaping.
+const MIN_CHAR: u8 = {
+ let mut v = u8::MAX;
+ let mut i = 0;
+ while i < CHARS.len() {
+ if v > CHARS[i] {
+ v = CHARS[i];
+ }
+ i += 1;
+ }
+ v
+};
+
+/// The character with the highest codepoint that needs HTML escaping.
+const MAX_CHAR: u8 = {
+ let mut v = u8::MIN;
+ let mut i = 0;
+ while i < CHARS.len() {
+ if v < CHARS[i] {
+ v = CHARS[i];
+ }
+ i += 1;
+ }
+ v
+};
+
+/// Number of codepoints from the lowest to the highest character that needs escaping, inclusive.
+const CHAR_RANGE: usize = (MAX_CHAR - MIN_CHAR + 1) as usize;
struct Table {
_align: [usize; 0],
- lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
+ lookup: [u16; CHAR_RANGE],
}
+/// For characters that need HTML escaping, the codepoint formatted as decimal digits,
+/// otherwise `b"\0\0"`. Starting at [`MIN_CHAR`].
const TABLE: Table = {
- const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
- assert!(MIN_CHAR <= c && c <= MAX_CHAR);
-
- let n0 = match NonZeroU8::new(c / 10 + b'0') {
- Some(n) => n,
- None => panic!(),
- };
- let n1 = match NonZeroU8::new(c % 10 + b'0') {
- Some(n) => n,
- None => panic!(),
- };
- Some([n0, n1])
- }
-
let mut table = Table {
_align: [],
- lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
+ lookup: [0; CHAR_RANGE],
};
-
- table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
- table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
- table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
- table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
- table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
+ let mut i = 0;
+ while i < CHARS.len() {
+ let c = CHARS[i];
+ let h = c / 10 + b'0';
+ let l = c % 10 + b'0';
+ table.lookup[(c - MIN_CHAR) as usize] = u16::from_ne_bytes([h, l]);
+ i += 1;
+ }
table
};
+
+// RATIONALE: llvm generates better code if the buffer is register sized
+const ESCAPED_BUF_INIT: [u8; 8] = *b"&#__;\0\0\0";
+const ESCAPED_BUF_LEN: usize = b"&#__;".len();
+
+#[test]
+fn simple() {
+ let mut buf = String::new();
+ write_escaped_str(&mut buf, "