diff --git a/src/Data/Csv/Parser.hs b/src/Data/Csv/Parser.hs index d028167..82fe4d2 100644 --- a/src/Data/Csv/Parser.hs +++ b/src/Data/Csv/Parser.hs @@ -154,17 +154,20 @@ field !delim = do escapedField :: AL.Parser S.ByteString escapedField = do _ <- dquote - -- The scan state is 'True' if the previous character was a double - -- quote. We need to drop a trailing double quote left by scan. - s <- S.init <$> (A.scan False $ \s c -> if c == doubleQuote - then Just (not s) - else if s then Nothing - else Just False) - if doubleQuote `S.elem` s - then case Z.parse unescape s of - Right r -> return r - Left err -> fail err - else return s + -- The scan state is 'True' if the previous character was a double quote. + s' <- A.scan False $ \s c -> if c == doubleQuote + then Just (not s) + else if s then Nothing + else Just False + -- We need to drop a trailing double quote left by scan. + if S.null s' + then fail "trailing double quote" + else let s = S.init s' + in if doubleQuote `S.elem` s + then case Z.parse unescape s of + Right r -> return r + Left err -> fail err + else return s unescapedField :: Word8 -> AL.Parser S.ByteString unescapedField !delim = A.takeWhile (\ c -> c /= doubleQuote && diff --git a/tests/UnitTests.hs b/tests/UnitTests.hs index 4e625e0..12ea362 100644 --- a/tests/UnitTests.hs +++ b/tests/UnitTests.hs @@ -21,6 +21,7 @@ import qualified Data.Text as T import qualified Data.Text.Lazy as LT import qualified Data.Vector as V import qualified Data.Foldable as F +import Data.List (isPrefixOf) import Data.Word import Numeric.Natural import GHC.Generics (Generic) @@ -56,6 +57,14 @@ assertResult input expected res = case res of " input: " ++ show (BL8.unpack input) ++ "\n" ++ "parse error: " ++ err +decodeFailsWith :: BL.ByteString -> String -> Assertion +decodeFailsWith input expected = case decode NoHeader input of + Right r -> assertFailure $ "input: " ++ show (BL8.unpack input) ++ "\n" ++ + "retuned: " ++ show (r :: (V.Vector (V.Vector B.ByteString))) ++ "\n" ++ + "whereas should have failed with " <> expected + Left err -> assertBool ("got " <> err <> "\ninstead of " <> expected) + $ ("parse error ("++expected++")") `isPrefixOf` err + encodesAs :: [[B.ByteString]] -> BL.ByteString -> Assertion encodesAs input expected = encode (map V.fromList input) @?= expected @@ -166,6 +175,7 @@ positionalTests = [ testGroup "decode" $ map streamingDecodeTest decodeTests , testGroup "decodeWith" $ map streamingDecodeWithTest decodeWithTests ] + , testGroup "escaped" escapedTests ] where rfc4180Input = BL8.pack $ @@ -210,6 +220,15 @@ positionalTests = defEncAllEnq = defaultEncodeOptions { encQuoting = QuoteAll } defDec = defaultDecodeOptions + escapedTests = [ + testCase "escaped" $ + "\"x,y\",z\nbaz,\"bar\nfoo,\"" `decodesAs` [["x,y", "z"], ["baz", "bar\nfoo,"]], + testCase "escapedMalformed1" $ + "\"x,\"y" `decodeFailsWith` "Failed reading: satisfy", + testCase "escapedMalformed0" $ + "baz,\"" `decodeFailsWith` "Failed reading: trailing double quote" + ] + nameBasedTests :: [TF.Test] nameBasedTests = [ testGroup "encode" $ map encodeTest