{-# LANGUAGE CPP #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Language.Haskell.TH.Desugar.Subst
-- Copyright   :  (C) 2018 Richard Eisenberg
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Capture-avoiding substitutions on 'DType's
--
----------------------------------------------------------------------------

module Language.Haskell.TH.Desugar.Subst (
  DSubst,

  -- * Capture-avoiding substitution
  substTy, substTyVarBndrs, unionSubsts, unionMaybeSubsts,

  -- * Matching a type template against a type
  IgnoreKinds(..), matchTy
  ) where

import Data.List
import qualified Data.Map as M
import qualified Data.Set as S

import Language.Haskell.TH.Desugar.AST
import Language.Haskell.TH.Syntax
import Language.Haskell.TH.Desugar.Util

#if __GLASGOW_HASKELL__ < 710
import Control.Applicative
#endif

-- | A substitution is just a map from names to types. Each key is the 'Name'
-- of a type variable; the associated 'DType' is what 'substTy' replaces
-- occurrences of that variable with. Variables absent from the map are left
-- untouched.
type DSubst = M.Map Name DType

-- | Capture-avoiding substitution on types.
--
-- Every occurrence of a type variable that is a key of the substitution is
-- replaced by the mapped type; all other variables are left as they are.
-- When descending under a 'DForallT', each bound variable is first renamed
-- to a fresh name (see 'substTyVarBndrs'), so a type in the range of the
-- substitution can never be captured by the forall's binders.
substTy :: Quasi q => DSubst -> DType -> q DType
substTy vars ty = case ty of
  DForallT tvbs cxt body ->
    substTyVarBndrs vars tvbs $ \vars' tvbs' ->
      -- the context is processed before the body, as in a do-block
      DForallT tvbs' <$> mapM (substTy vars') cxt <*> substTy vars' body
  DAppT fun arg   -> DAppT     <$> go fun <*> go arg
  DAppKindT t k   -> DAppKindT <$> go t   <*> go k
  DSigT t k       -> DSigT     <$> go t   <*> go k
  DVarT n         -> return (M.findWithDefault ty n vars)
  DConT {}        -> return ty
  DArrowT         -> return ty
  DLitT {}        -> return ty
  DWildCardT      -> return ty
  where
    -- recurse with the unchanged substitution (no binders crossed)
    go t = substTy vars t

-- | Rename a list of type-variable binders to fresh names and hand the
-- result to a continuation.
--
-- Each binder is processed by 'substTvb' and the substitution is threaded
-- through the list with 'mapAccumLM'. This presumably runs left to right, as
-- its name suggests, so a binder's kind sees the renamings of the binders
-- before it. The continuation receives the extended substitution, which maps
-- every old binder name to its fresh variable, together with the renamed
-- binders.
substTyVarBndrs :: Quasi q => DSubst -> [DTyVarBndr]
                -> (DSubst -> [DTyVarBndr] -> q a)
                -> q a
substTyVarBndrs vars tvbs thing =
  mapAccumLM substTvb vars tvbs >>= \(extended, renamed) ->
    thing extended renamed

-- | Rename a single binder to a fresh name built from its 'nameBase'.
--
-- Returns the substitution extended with @old |-> DVarT new@, paired with
-- the renamed binder. For a kinded binder, the kind is rewritten with the
-- substitution as it was /before/ this binder was added. The fresh name is
-- generated before the kind is substituted.
substTvb :: Quasi q => DSubst -> DTyVarBndr
         -> q (DSubst, DTyVarBndr)
substTvb vars tvb = case tvb of
  DPlainTV n -> do
    n' <- fresh n
    return (extend n n', DPlainTV n')
  DKindedTV n k -> do
    n' <- fresh n
    k' <- substTy vars k
    return (extend n n', DKindedTV n' k')
  where
    fresh old = qNewName (nameBase old)
    extend old new = M.insert old (DVarT new) vars

-- | Computes the union of two substitutions. Returns 'Nothing' if both
-- substitutions map the same variable to different (non-'==') types;
-- otherwise returns @Just@ the combined map.
unionSubsts :: DSubst -> DSubst -> Maybe DSubst
unionSubsts s1 s2
    -- compare the two bindings of every shared key, in ascending key order
  | and (M.elems (M.intersectionWith (==) s1 s2)) = Just (s1 `M.union` s2)
  | otherwise                                     = Nothing

---------------------------
-- Matching

-- | Whether 'matchTy' should look through kind signatures that appear in
-- the type template.
data IgnoreKinds
  = YesIgnore  -- ^ Skip kind signatures in the template and match the
               --   underlying type (the result may be ill-kinded).
  | NoIgnore   -- ^ Fail to match when a kind signature in the template
               --   is reached.

-- | @matchTy ign tmpl targ@ matches a type template @tmpl@ against a type
-- target @targ@. This returns a Map from names of type variables in the
-- type template to types if the types indeed match up, or @Nothing@ otherwise.
-- A variable occurring several times in the template must be matched
-- against '=='-equal types at every occurrence, or the result is @Nothing@.
--
-- With 'NoIgnore', the @Just@ case guarantees that every type variable
-- mentioned in the template is mapped by the returned substitution. With
-- 'YesIgnore' this guarantee is weaker. Kind signatures in the template are
-- skipped entirely, so a variable that occurs /only/ inside such a kind is
-- not mapped.
--
-- The first argument @ign@ tells @matchTy@ whether to ignore kind signatures
-- in the template. A kind signature in the template might mean that a type
-- variable has a more restrictive kind than otherwise possible, and that
-- mapping that type variable to a type of a different kind could be disastrous.
-- So, if we don't ignore kind signatures, matching fails with @Nothing@ as soon
-- as a signature in the template is reached. If we do ignore kind signatures,
-- it's possible the returned map will be ill-kinded. Use at your own risk.
-- Kind signatures on the /target/ are always looked through. However, a
-- template variable matched directly against a signed target is bound to the
-- signed type, signature included.
--
-- Visible kind applications ('DAppKindT') are not looked through: a template
-- containing one never matches, and a target containing one can only be
-- matched by a template variable.
--
-- This function is partial. It calls 'error' when matching reaches a
-- 'DForallT' in either the template or the target. The exception is a forall
-- that a template type variable matches as a whole; that variable is simply
-- bound to the forall type.
matchTy :: IgnoreKinds -> DType -> DType -> Maybe DSubst
matchTy _   (DVarT var_name) arg = Just $ M.singleton var_name arg
  -- If a pattern has a kind signature, it's really easy to get
  -- this wrong, so only look through it when asked to ignore kinds.
matchTy ign (DSigT ty _ki) arg = case ign of
  YesIgnore -> matchTy ign ty arg
  NoIgnore  -> Nothing
  -- But we can safely ignore kind signatures on the target.
matchTy ign pat (DSigT ty _ki) = matchTy ign pat ty
matchTy _   (DForallT {}) _ =
  error "Cannot match a forall in a pattern"
matchTy _   _ (DForallT {}) =
  error "Cannot match a forall in a target"
matchTy ign (DAppT pat1 pat2) (DAppT arg1 arg2) =
  -- Both halves must match, and their bindings must agree.
  unionMaybeSubsts [matchTy ign pat1 arg1, matchTy ign pat2 arg2]
matchTy _   (DConT pat_con) (DConT arg_con)
  | pat_con == arg_con = Just M.empty
matchTy _   DArrowT DArrowT = Just M.empty
matchTy _   (DLitT pat_lit) (DLitT arg_lit)
  | pat_lit == arg_lit = Just M.empty
matchTy _ _ _ = Nothing

-- | Combine a list of optional substitutions with 'unionSubsts', starting
-- from the empty substitution.
--
-- The result is 'Nothing' if any element is 'Nothing', or if two of the
-- substitutions map the same variable to different types. An empty list
-- yields @Just M.empty@.
unionMaybeSubsts :: [Maybe DSubst] -> Maybe DSubst
unionMaybeSubsts = foldl' step (Just M.empty)
  where
    -- Merge the next substitution into the running union; once the running
    -- union has become 'Nothing' it stays 'Nothing'.
    step :: Maybe DSubst -> Maybe DSubst -> Maybe DSubst
    step acc next = acc >>= \sofar -> next >>= unionSubsts sofar