Skip to content

Instantly share code, notes, and snippets.

@jozefg
Created March 25, 2016 07:19
Show Gist options
  • Save jozefg/56abfbc42e49f298458d to your computer and use it in GitHub Desktop.
Save jozefg/56abfbc42e49f298458d to your computer and use it in GitHub Desktop.

Revisions

  1. jozefg created this gist Mar 25, 2016.
    177 changes: 177 additions & 0 deletions PatCompile.hs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,177 @@
    {-# LANGUAGE DeriveTraversable #-}
    {-# LANGUAGE DeriveFoldable #-}
    {-# LANGUAGE DeriveFunctor #-}
    {-# LANGUAGE LambdaCase #-}
    {-# LANGUAGE GeneralizedNewtypeDeriving #-}
    module PatCompile where
    import Bound
    import Bound.Var
    import Bound.Scope
    import Control.Monad (ap)
    import Prelude.Extras (Show1 (..), Eq1 (..))
    import Data.List
    import Debug.Trace

    newtype Con = MkCon Int deriving (Eq, Show, Ord, Enum)

    data Constant = IntLit Int
    | CharLit Char
    | Con Con
    -- Non-pure data constants
    | TrueLit
    | FalseLit
    | If
    | EqInt
    | EqChar
    | Plus | Minus | Times | Div
    | RaiseError
    | MatchFail
    deriving (Eq, Show)

    data Pattern = PVar Int
    | PCon Con [Pattern]
    | PLit Constant
    deriving (Eq, Show)

    data Exp a = Var a
    | Const Constant
    | App (Exp a) (Exp a)
    | Let Pattern (Exp a) (Scope Int Exp a)
    | LetRec [Bind a] (Scope RecBV Exp a)
    | Lambda Pattern (Scope Int Exp a)
    | Bar (Exp a) (Exp a)
    | Case (Exp a) [Alt a] -- Invariant: Exhaustive, non-overlapping, "simple"
    deriving (Eq, Show, Functor, Foldable, Traversable)
    data RecBV = RecBV {patternNum :: Int, varNum :: Int} deriving (Eq, Show)
    data Bind a = Bind Pattern (Scope RecBV Exp a)
    deriving (Eq, Show, Functor, Foldable, Traversable)
    data Alt a = Alt Pattern (Scope Int Exp a)
    deriving (Eq, Show, Functor, Foldable, Traversable)

    instance Show1 Exp where
    instance Eq1 Exp where
    instance Applicative Exp where
    pure = return
    (<*>) = ap
    instance Monad Exp where
    return = Var
    e >>= f =
    case e of
    Var a -> f a
    Const c -> Const c
    App l r -> App (l >>= f) (r >>= f)
    Let p e1 e2 -> Let p (e1 >>= f) (e2 >>>= f)
    LetRec binds e2 -> LetRec (map bindBind binds) (e2 >>>= f)
    Lambda p body -> Lambda p (body >>>= f)
    Bar l r -> Bar (l >>= f) (r >>= f)
    Case e alts -> Case (e >>= f) (map bindAlt alts)
    where bindAlt (Alt p e) = Alt p (e >>>= f)
    bindBind (Bind p e) = Bind p (e >>>= f)

    mkVar :: a -> Exp (Var b (Exp a))
    mkVar = Var . F . Var

    fillBound :: Eq b => b -> a -> Scope b Exp a -> Scope b Exp a
    fillBound b new s =
    Scope $ splat mkVar (\b' -> if b == b' then mkVar new else Var (B b')) s

    stripScope :: (Show a, Show b) => Scope b Exp a -> Exp a
    stripScope s =
    case traverse id $ splat (Var . Just) (const (Var Nothing)) s of
    Nothing -> error $ "Not closed: " ++ show s
    Just e -> e

    -- The multi-match branch we're working with.
    data FlexibleAlt a = FAlt [Pattern] (Scope Int Exp a)
    deriving (Functor, Foldable, Traversable, Show)

    -- Compile a match on an expression with a list of branches and
    -- a default into a simplified Case expression. This is actually work
    -- because we require that Case expressions be exhaustive, not use
    -- nested patterns, and not be overlapping.
    --
    -- To simplify things, we have parallel matching and we demand that
    -- the list of list of alternatives is a rectangle.
    match :: (Eq a, Show a) => [a] -> [FlexibleAlt a] -> Exp a -> Exp a
    match scruts alts def
    -- Base case, we've compiled all scruts.
    | [] <- scruts =
    foldr Bar def $ map (\(FAlt [] e) -> stripScope e) alts

    -- All variable branches
    | Just branches <- allVars alts,
    scrut : remaining <- scruts =
    let new (i, ps, s) = FAlt ps (fillBound i scrut s)
    in match remaining (map new branches) def

    -- Next case, first scrut is matched against only constructors
    | Just branches <- allCons alts,
    scrut : remaining <- scruts =
    let ((_, args, _, _) : _) = branches
    newVars = [0 .. length args - 1]
    def' = F . Var <$> def
    gathered = groupBy (\(i, _, _, _) (j, _, _, _) -> i == j) branches
    branches' = map (\bs -> let (i : _, as, ps, ss) = unzip4 bs
    in (i, as, ps, ss))
    gathered
    new (i, argss, pss, ss) =
    Alt (PCon i (map PVar newVars)) . Scope $
    match (map B newVars ++ map (F . Var) remaining)
    [FAlt (args ++ ps) (F . Var <$> s) | (args, ps, s) <- zip3 argss pss ss]
    def'
    in Case (Var scrut) (map new branches')

    -- A degenerate version of the above where we're matching on literals
    | Just branches <- allLits alts,
    scrut : remaining <- scruts =
    let gathered = groupBy (\(i, _, _) (j, _, _) -> i == j) branches
    branches' = map (\bs -> let (i : _, ps, ss) = unzip3 bs
    in (i, ps, ss))
    gathered
    new (l, pss, ss) =
    Alt (PLit l) . abstract (const Nothing) $
    match remaining [FAlt ps s | (ps, s) <- zip pss ss] def
    in Case (Var scrut) (map new branches')

    -- A final case, we split apart overlapping patterns into chunks of
    -- nonoverlapping patterns and process them separately.
    | chunks <- splitChunks alts = foldr (match scruts) def chunks
    where allVars [] = Just []
    allVars (FAlt (PVar i : ps) s : alts) = ((i, ps, s) :) <$> allVars alts
    allVars _ = Nothing

    allCons [] = Just []
    allCons (FAlt (PCon c args : ps) s : alts) =
    ((c, args, ps, s) :) <$> allCons alts
    allCons _ = Nothing

    allLits [] = Just []
    allLits (FAlt (PLit l : ps) s : alts) =
    ((l, ps, s) :) <$> allLits alts
    allLits _ = Nothing

    splitChunks = groupBy $ \a b -> case (a, b) of
    (FAlt (PVar _ : _) _, FAlt (PVar _ : _) _) -> True
    (FAlt (PCon _ _ : _) _, FAlt (PCon _ _ : _) _) -> True
    (FAlt (PLit _ : _) _, FAlt (PLit _ : _) _) -> True
    _ -> False

    abstractF :: [String] -> Exp String -> Scope Int Exp String
    abstractF vars = abstract (flip elemIndex vars)

    app :: Exp a -> [Exp a] -> Exp a
    app = foldl App

    con :: Con -> Exp a
    con = Const . Con

    instance Num Con where
    fromInteger = MkCon . fromIntegral
    test =
    match ["hello", "world"]
    [ FAlt [PCon 0 [], PVar 0] $ abstractF ["x"] (Var "x")
    , FAlt [PVar 0, PCon 0 []] $ abstractF ["x"] (Var "x")
    , FAlt [PCon 1 [PVar 0, PVar 1], PCon 1 [PVar 2, PVar 3]]
    . abstractF ["x", "xs", "y", "ys"]
    $ app (con 1) [Var "x", app (con 1) [Var "y", app (Var "rec") [Var "xs", Var "ys"]]]]
    (Const MatchFail)