{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
-- The instance methods below name their instance-head type variables
-- (f, g, a, b, ...) in @-applications, and 'gcase' names its own signature's
-- type variables in its body: both require lexically scoped type variables.
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
-- Equality constraints (a ~ b) are only accepted by GHC versions before 9.4
-- when GADTs or TypeFamilies is enabled.
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

import GHC.Generics

-- | Generic eliminator ("case") for any type with a 'Generic' instance.
--
-- @gcase v@ takes one handler per constructor, in declaration order. Each
-- handler takes that constructor's fields, in order, and every handler
-- returns the same result type @r@. The handler for the constructor that
-- @v@ was built with is applied to its fields; the others are ignored.
--
-- > gcase :: Either x y -> (x -> r) -> (y -> r) -> r
-- > gcase :: Maybe a    -> r -> (a -> r) -> r
-- > gcase :: (x, y)     -> (x -> y -> r) -> r
-- > gcase :: ()         -> r -> r
gcase :: forall t b r. (Generic t, GCaseSum (Rep t) r b r) => t -> b
gcase v = gCaseSum @(Rep t) @r @b @r (from v) id

-- Two classes: GCaseSum decomposes sums (:+:), GCaseProd decomposes
-- products (:*:).
--
-- Sums: for f :+: g the eliminator has the shape (f -> r) -> (g -> r) -> r.
--   - A value of f must consume the (f -> r) handler and skip (g -> r).
--   - A value of g must skip (f -> r) and consume (g -> r).
--
-- A GCaseSum instance for g therefore provides two methods:
--   - gCaseSum processes a value of g. It receives a continuation (r -> a):
--     once g's handler has produced the final result r, the continuation
--     skips the handlers that come after g's and returns r.
--   - gCaseSumSkip skips the handlers belonging to g. Its argument (of type
--     a) is the function that goes on to process the handlers after g's.
--
-- Example, for (Either x y), where x and y stand for the generic
-- representations of the Left and Right constructors:
--   - Top level, with a = r and b = (x -> r) -> (y -> r) -> r:
--       gCaseSum :: Either x y -> (r -> r) -> (x -> r) -> (y -> r) -> r
--     (gCaseSumSkip is not used at the top level.)
--   - Left component, with a = (y -> r) -> r and
--     b = (x -> r) -> (y -> r) -> r:
--       gCaseSum     :: x -> (r -> (y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--       gCaseSumSkip :: ((y -> r) -> r) -> (x -> r) -> (y -> r) -> r
--     (gCaseSumSkip is used when the value is a Right.)
--   - Right component, with a = r and b = (y -> r) -> r:
--       gCaseSum     :: y -> (r -> r) -> (y -> r) -> r
--       gCaseSumSkip :: r -> (y -> r) -> r
--     (gCaseSumSkip is used when the value is a Left: after the (x -> r)
--     handler has produced r, it drops the (y -> r) handler.)

-- | Decompose a generic sum @g@. See the comment above for the roles of
-- @a@, @b@ and @r@.
class GCaseSum g a b r where
  -- | Eliminate a value of @g@: apply the matching handler and pass its
  -- result to the continuation, which skips the remaining handlers.
  gCaseSum :: forall x. g x -> (r -> a) -> b
  -- | Skip every handler that belongs to @g@.
  gCaseSumSkip :: a -> b

-- | Sum: consume the handlers of the side that holds the value and skip the
-- handlers of the other side. The intermediate type @b@ is the eliminator
-- type left over once f's handlers have been taken.
instance (GCaseSum f b c r, GCaseSum g a b r) => GCaseSum (f :+: g) a c r where
  -- Consume f's handlers; afterwards skip g's.
  gCaseSum (L1 inl) = gCaseSum @f @b @c inl . fmap (gCaseSumSkip @g @a @b @r)
  -- Skip f's handlers, then consume g's.
  gCaseSum (R1 inr) = gCaseSumSkip @f @b @c @r . gCaseSum @g @a @b inr
  -- Skip both f's and g's handlers.
  gCaseSumSkip = gCaseSumSkip @f @b @c @r . gCaseSumSkip @g @a @b @r

-- | Empty sum (a type with no constructors, e.g. Void): gcase :: Void -> r.
-- There are no handlers, so there is nothing to consume or skip.
instance (a ~ b) => GCaseSum V1 a b r where
  gCaseSum v = case v of {}
  gCaseSumSkip = id

-- | Unwrap the datatype-level metadata wrapper.
instance GCaseSum f a b r => GCaseSum (M1 D w f) a b r where
  gCaseSum (M1 inner) = gCaseSum @f @a @b @r inner
  gCaseSumSkip = gCaseSumSkip @f @a @b @r

-- | A single constructor. Its handler has type @c@ (built by 'GCaseProd'
-- from the constructor's fields and returning @r@). gCaseSum applies the
-- handler to the fields and passes the result to the continuation;
-- gCaseSumSkip ignores the handler.
instance (ca ~ (c -> a), GCaseProd f c r) => GCaseSum (M1 C w f) a ca r where
  gCaseSum (M1 fields) k handler = k (gCaseProd fields handler)
  gCaseSumSkip rest _ = rest

-- Products: the handler for a constructor with fields x and y has the shape
-- x -> y -> r.
--
--   - Whole product (GCaseProd (x :*: y) c d), with c = x -> y -> r, d = r:
--       gCaseProd :: (x :*: y) -> (x -> y -> r) -> r
--   - Field x (GCaseProd x c d), with c = x -> y -> r, d = y -> r:
--       gCaseProd :: x -> (x -> y -> r) -> y -> r
--   - Field y (GCaseProd y c d), with c = y -> r, d = r:
--       gCaseProd :: y -> (y -> r) -> r
--
-- Each field (the M1 S instance) is handled by gCaseProd = \a k -> k a
-- (i.e. flip ($)), and the (:*:) instance chains the two halves with (.).

-- | Feed the fields of a generic product @g@, in order, to a handler of
-- type @c@, leaving a value of type @d@.
class GCaseProd g c d where
  gCaseProd :: forall x. g x -> c -> d

-- | Apply the handler to the fields of @f@ first, then to those of @g@.
instance (GCaseProd f c d, GCaseProd g d e) => GCaseProd (f :*: g) c e where
  gCaseProd (front :*: back) = gCaseProd back . gCaseProd @f @c @d front

-- | A single field: pass it to the handler.
instance (ad ~ (a -> d)) => GCaseProd (M1 S w (K1 i a)) ad d where
  gCaseProd (M1 (K1 field)) k = k field

-- | No fields (e.g. (), or Nothing in Maybe a): the handler is already the
-- result, as in gcase :: () -> r -> r.
instance (r ~ r') => GCaseProd U1 r r' where
  gCaseProd U1 result = result

infix 4 =?

-- | Test assertion: return silently when both values are equal, otherwise
-- raise an 'error' that shows both values.
(=?) :: (Eq a, Show a) => a -> a -> IO ()
actual =? expected
  | actual == expected = pure ()
  | otherwise = error (show actual ++ " /= " ++ show expected)

-- | Example with three constructors of different arities.
data Three a = Uno a | Dos a a | Tres a a a
  deriving (Generic)

-- | Example with four constructors. GHC splits the constructor list in half
-- when deriving the representation, so this exercises a sum nested on the
-- left of (:+:), which 'Three' does not.
data Arity = A0 | A1 Int | A2 Int Int | A4 Int Int Int Int
  deriving (Generic)

-- | Run the examples; prints nothing if every check passes.
main :: IO ()
main = do
  -- Either: only the handler matching the constructor is applied.
  gcase (Left (3 :: Int)) id id =? 3
  gcase (Right (3 :: Int)) id id =? 3
  gcase (Left (3 :: Int)) (* 2) (+ 1) =? 6
  gcase (Right (3 :: Int)) (* 2) (+ 1) =? 4
  -- Maybe: a plain value for Nothing, a function for Just.
  gcase (Just (3 :: Int)) 0 id =? 3
  gcase (Nothing :: Maybe Int) 0 id =? 0
  -- Bool: two nullary constructors, False first.
  gcase True 'f' 't' =? 't'
  gcase False 'f' 't' =? 'f'
  -- Tuples: fields are passed in order.
  gcase (40, 2) (+) =? (42 :: Int)
  gcase (40, 2) (-) =? (38 :: Int)
  -- Unit: the single handler is the result.
  gcase () 42 =? (42 :: Int)
  -- Three constructors.
  gcase (Uno 100) id (+) (\x y z -> x + y + z) =? (100 :: Int)
  gcase (Dos 100 20) id (-) (\x y z -> x + y + z) =? (80 :: Int)
  gcase (Tres 100 20 3) id (+) (\x y z -> x + y + z) =? (123 :: Int)
  -- Four constructors (nested sums on both sides, four-field product).
  gcase A0 0 negate (-) digits =? (0 :: Int)
  gcase (A1 7) 0 negate (-) digits =? (-7)
  gcase (A2 7 3) 0 negate (-) digits =? 4
  gcase (A4 1 2 3 4) 0 negate (-) digits =? 1234
  where
    digits a b c d = a * 1000 + b * 100 + c * 10 + d