diff --git a/src/Ormolu/Printer/Meat/Declaration/Value.hs b/src/Ormolu/Printer/Meat/Declaration/Value.hs index 06f698325..268cc3000 100644 --- a/src/Ormolu/Printer/Meat/Declaration/Value.hs +++ b/src/Ormolu/Printer/Meat/Declaration/Value.hs @@ -451,14 +451,14 @@ p_stmt' s placer render = \case space sitcc $ p_hsLocalBinds binds ParStmt {} -> - -- 'ParStmt' should always be eliminated in 'flattenStmts' already, such + -- 'ParStmt' should always be eliminated in 'gatherStmts' already, such -- that it never occurs in 'p_stmt''. Consequently, handling it here -- would be redundant. notImplemented "ParStmt" TransStmt {..} -> -- 'TransStmt' only needs to account for render printing itself, since -- pretty printing of relevant statements (e.g., in 'trS_stmts') is - -- handled through 'flattenStmts'. + -- handled through 'gatherStmts'. case (trS_form, trS_by) of (ThenForm, Nothing) -> do txt "then" @@ -871,6 +871,7 @@ p_listComp s es = sitcc (vlayout singleLine multiLine) body = located es p_body p_body xs = do let (stmts, yield) = + -- TODO: use unsnoc when require GHC 9.8+ case xs of [] -> error $ "list comprehension unexpectedly had no expressions" _ -> (init xs, last xs) @@ -878,7 +879,7 @@ p_listComp s es = sitcc (vlayout singleLine multiLine) breakpoint txt "|" space - p_bodyParallels (flattenStmts stmts) + p_bodyParallels (gatherStmts stmts) -- print the list of list comprehension sections, e.g. -- [ "| x <- xs, y <- ys, let z = x <> y", "| a <- f z" ] @@ -888,33 +889,95 @@ p_listComp s es = sitcc (vlayout singleLine multiLine) -- [ "x <- xs", "y <- ys", "let z = x <> y" ] p_bodyParallelStmts = sep commaDel (located' (sitcc . p_stmt)) --- | Flatten the set of statements in a list comprehension. +-- | Gather the set of statements in a list comprehension. -- --- Concretely, expands all ParStmt constructors and extract out prefix --- statements in TransStmt. -flattenStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]] -flattenStmts = collect . map gatherStmt +-- For example, this code: +-- +-- @ +-- [ a + b + c + d +-- | a <- as, let b = a + a +-- | c <- cs +-- | d <- ds, then sort by f +-- ] +-- @ +-- +-- is parsed as roughly: +-- +-- @ +-- [ ParStmt +-- [ ParStmtBlock +-- [ BindStmt [| a <- as |] +-- , LetStmt [| let b = a + a |] +-- ] +-- , ParStmtBlock +-- [ BindStmt [| c <- cs |] +-- ] +-- , ParStmtBlock +-- [ TransStmt +-- [ BindStmt [| d <- ds |] +-- ] +-- [| then sort by f |] +-- ] +-- ] +-- , LastStmt [| a + b + c + d |] +-- ] +-- @ +-- +-- The final expression is parsed out in p_body, and the rest is passed +-- to this function. This function takes the above tree as input and +-- normalizes it into: +-- +-- @ +-- [ [ BindStmt [| a <- as |] +-- , LetStmt [| let b = a + a |] +-- ] +-- , [ BindStmt [| c <- cs |] +-- ] +-- , [ BindStmt [| d <- ds |] +-- , TransStmt [] [| then sortWith by f |] +-- ] +-- ] +-- @ +-- +-- Notes: +-- * The number of elements in the outer list is the number of pipes in +-- the comprehension; i.e. 1 unless -XParallelListComp is enabled +gatherStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]] +gatherStmts = \case + -- When -XParallelListComp is enabled + list comprehension has + -- multiple pipes, input will have exactly 1 element, and it + -- will be ParStmt. + [L _ (ParStmt _ blocks _ _)] -> + [ concatMap collectParStmts stmts + | ParStmtBlock _ stmts _ _ <- blocks + ] + -- When -XTransformListComp is enabled and there's only a + -- single pipe, the input will have exactly 1 element, and it + -- will be TransStmt. + [stmt@(L _ TransStmt {trS_stmts})] -> + [ map checkTransStmt trS_stmts ++ [stmt] + ] + -- Otherwise, list will not contain any ParStmt or TransStmt + stmts -> + [ map checkPlainStmt stmts + ] where - -- Note: not exactly sure what this does... - -- - -- From experimenting, it does this: - -- - -- >>> collect [[[stmt1]], [[stmt2, stmt3]], [[stmt4], [stmt5]], [[stmt6]]] - -- [[stmt1, stmt2, stmt3, stmt4, stmt6], [stmt5]] - -- - -- But it's not clear what the purpose is. - collect :: [[[ExprLStmt GhcPs]]] -> [[ExprLStmt GhcPs]] - collect = foldr (zipPrefixWith (<>)) [] + collectParStmts = \case + L _ ParStmt {} -> unexpected "ParStmt nested inside ParStmt" + stmt@(L _ TransStmt {trS_stmts}) -> map checkTransStmt trS_stmts ++ [stmt] + stmt -> [stmt] - gatherStmt :: ExprLStmt GhcPs -> [[ExprLStmt GhcPs]] - gatherStmt (L _ (ParStmt _ block _ _)) = - concatMap gatherStmtBlock block - gatherStmt (L s stmt@TransStmt {..}) = - collect $ map gatherStmt trS_stmts <> [[[L s stmt]]] - gatherStmt stmt = [[stmt]] + checkTransStmt = \case + L _ ParStmt {} -> unexpected "ParStmt nested inside TransStmt" + L _ TransStmt {} -> unexpected "TransStmt nested inside TransStmt" + stmt -> stmt - gatherStmtBlock :: ParStmtBlock GhcPs GhcPs -> [[ExprLStmt GhcPs]] - gatherStmtBlock (ParStmtBlock _ stmts _ _) = flattenStmts stmts + checkPlainStmt = \case + L _ ParStmt {} -> unexpected "ParStmt in list comprehension" + L _ TransStmt {} -> unexpected "TransStmt in list comprehension" + stmt -> stmt + + unexpected label = error $ "Unexpected " <> label <> "! Please file a bug." p_patSynBind :: PatSynBind GhcPs GhcPs -> R () p_patSynBind PSB {..} = do @@ -1319,15 +1382,6 @@ layoutToBraces = \case SingleLine -> useBraces MultiLine -> id --- | Same as 'zipWith', except only works on lists of the same type, and --- leaves extra elements at the end of the list. -zipPrefixWith :: (a -> a -> a) -> [a] -> [a] -> [a] -zipPrefixWith f = go - where - go [] ys = ys - go xs [] = xs - go (x : xs) (y : ys) = f x y : go xs ys - getGRHSSpan :: GRHS GhcPs (LocatedA body) -> SrcSpan getGRHSSpan (GRHS _ guards body) = combineSrcSpans' $ getLocA body :| map getLocA guards