Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonchinn178 committed Jun 5, 2024
1 parent dbc476c commit 4a68831
Showing 1 changed file with 89 additions and 35 deletions.
124 changes: 89 additions & 35 deletions src/Ormolu/Printer/Meat/Declaration/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -451,14 +451,14 @@ p_stmt' s placer render = \case
space
sitcc $ p_hsLocalBinds binds
ParStmt {} ->
-- 'ParStmt' should always be eliminated in 'flattenStmts' already, such
-- 'ParStmt' should always be eliminated in 'gatherStmts' already, such
-- that it never occurs in 'p_stmt''. Consequently, handling it here
-- would be redundant.
notImplemented "ParStmt"
TransStmt {..} ->
-- 'TransStmt' only needs to account for render printing itself, since
-- pretty printing of relevant statements (e.g., in 'trS_stmts') is
-- handled through 'flattenStmts'.
-- handled through 'gatherStmts'.
case (trS_form, trS_by) of
(ThenForm, Nothing) -> do
txt "then"
Expand Down Expand Up @@ -871,14 +871,15 @@ p_listComp s es = sitcc (vlayout singleLine multiLine)
body = located es p_body
p_body xs = do
let (stmts, yield) =
-- TODO: use unsnoc when require GHC 9.8+
case xs of
[] -> error $ "list comprehension unexpectedly had no expressions"
_ -> (init xs, last xs)
sitcc $ located yield p_stmt
breakpoint
txt "|"
space
p_bodyParallels (flattenStmts stmts)
p_bodyParallels (gatherStmts stmts)

-- print the list of list comprehension sections, e.g.
-- [ "| x <- xs, y <- ys, let z = x <> y", "| a <- f z" ]
Expand All @@ -888,33 +889,95 @@ p_listComp s es = sitcc (vlayout singleLine multiLine)
-- [ "x <- xs", "y <- ys", "let z = x <> y" ]
p_bodyParallelStmts = sep commaDel (located' (sitcc . p_stmt))

-- | Flatten the set of statements in a list comprehension.
-- | Gather the set of statements in a list comprehension.
--
-- Concretely, expands all ParStmt constructors and extract out prefix
-- statements in TransStmt.
flattenStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]]
flattenStmts = collect . map gatherStmt
-- For example, this code:
--
-- @
-- [ a + b + c + d
-- | a <- as, let b = a + a
-- | c <- cs
-- | d <- ds, then sort by f
-- ]
-- @
--
-- is parsed as roughly:
--
-- @
-- [ ParStmt
-- [ ParStmtBlock
-- [ BindStmt [| a <- as |]
-- , LetStmt [| let b = a + a |]
-- ]
-- , ParStmtBlock
-- [ BindStmt [| c <- cs |]
-- ]
-- , ParStmtBlock
-- [ TransStmt
-- [ BindStmt [| d <- ds |]
-- ]
-- [| then sort by f |]
-- ]
-- ]
-- , LastStmt [| a + b + c + d |]
-- ]
-- @
--
-- The final expression is parsed out in p_body, and the rest is passed
-- to this function. This function takes the above tree as input and
-- normalizes it into:
--
-- @
-- [ [ BindStmt [| a <- as |]
-- , LetStmt [| let b = a + a |]
-- ]
-- , [ BindStmt [| c <- cs |]
-- ]
-- , [ BindStmt [| d <- ds |]
-- , TransStmt [] [| then sortWith by f |]
-- ]
-- ]
-- @
--
-- Notes:
-- * The number of elements in the outer list is the number of pipes in
-- the comprehension; i.e. 1 unless -XParallelListComp is enabled
gatherStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]]
gatherStmts = \case
-- When -XParallelListComp is enabled + list comprehension has
-- multiple pipes, input will have exactly 1 element, and it
-- will be ParStmt.
[L _ (ParStmt _ blocks _ _)] ->
[ concatMap collectParStmts stmts
| ParStmtBlock _ stmts _ _ <- blocks
]
-- When -XTransformListComp is enabled and there's only a
-- single pipe, the input will have exactly 1 element, and it
-- will be TransStmt.
[stmt@(L _ TransStmt {trS_stmts})] ->
[ map checkTransStmt trS_stmts ++ [stmt]
]
-- Otherwise, list will not contain any ParStmt or TransStmt
stmts ->
[ map checkPlainStmt stmts
]
where
-- Note: not exactly sure what this does...
--
-- From experimenting, it does this:
--
-- >>> collect [[[stmt1]], [[stmt2, stmt3]], [[stmt4], [stmt5]], [[stmt6]]]
-- [[stmt1, stmt2, stmt3, stmt4, stmt6], [stmt5]]
--
-- But it's not clear what the purpose is.
collect :: [[[ExprLStmt GhcPs]]] -> [[ExprLStmt GhcPs]]
collect = foldr (zipPrefixWith (<>)) []
collectParStmts = \case
L _ ParStmt {} -> unexpected "ParStmt nested inside ParStmt"
stmt@(L _ TransStmt {trS_stmts}) -> map checkTransStmt trS_stmts ++ [stmt]
stmt -> [stmt]

gatherStmt :: ExprLStmt GhcPs -> [[ExprLStmt GhcPs]]
gatherStmt (L _ (ParStmt _ block _ _)) =
concatMap gatherStmtBlock block
gatherStmt (L s stmt@TransStmt {..}) =
collect $ map gatherStmt trS_stmts <> [[[L s stmt]]]
gatherStmt stmt = [[stmt]]
checkTransStmt = \case
L _ ParStmt {} -> unexpected "ParStmt nested inside TransStmt"
L _ TransStmt {} -> unexpected "TransStmt nested inside TransStmt"
stmt -> stmt

gatherStmtBlock :: ParStmtBlock GhcPs GhcPs -> [[ExprLStmt GhcPs]]
gatherStmtBlock (ParStmtBlock _ stmts _ _) = flattenStmts stmts
checkPlainStmt = \case
L _ ParStmt {} -> unexpected "ParStmt in list comprehension"
L _ TransStmt {} -> unexpected "TransStmt in list comprehension"
stmt -> stmt

unexpected label = error $ "Unexpected " <> label <> "! Please file a bug."

p_patSynBind :: PatSynBind GhcPs GhcPs -> R ()
p_patSynBind PSB {..} = do
Expand Down Expand Up @@ -1319,15 +1382,6 @@ layoutToBraces = \case
SingleLine -> useBraces
MultiLine -> id

-- | Same as 'zipWith', except only works on lists of the same type, and
-- leaves extra elements at the end of the list.
zipPrefixWith :: (a -> a -> a) -> [a] -> [a] -> [a]
zipPrefixWith f = go
where
go [] ys = ys
go xs [] = xs
go (x : xs) (y : ys) = f x y : go xs ys

getGRHSSpan :: GRHS GhcPs (LocatedA body) -> SrcSpan
getGRHSSpan (GRHS _ guards body) =
combineSrcSpans' $ getLocA body :| map getLocA guards
Expand Down

0 comments on commit 4a68831

Please sign in to comment.