{-# LANGUAGE OverloadedStrings #-}
module Language.Del.TypeCheck (typeCheck) where
import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import qualified Data.Traversable as Tr
import Language.Del.AST
import qualified Language.Del.Builtins as B
import Language.Del.CompileDefs
import Language.Del.TypeTable
-- | Typing environment used while checking an 'AST'.
--
-- Holds the AST being checked together with a map from identifier names to
-- their 'Type's. 'newContext' seeds the map with user definitions and
-- builtins; @Let@ bindings and definition parameters extend it locally.
data TypeContext = TypeContext AST (Map.Map T.Text Type)
-- | Subscripts accepted on a 'Float3Type' value, in the order they are
-- reported in 'BadSubscript' errors: the indices @0@, @1@, @2@ followed by
-- the names @x@, @y@, @z@. Each is a one-character 'T.Text'.
f3Subscripts :: [T.Text]
f3Subscripts = map T.singleton "012xyz"
-- | Build the top-level typing environment for an 'AST'.
--
-- Each user 'Def' contributes a 'FuncType' made from its argument types and
-- its declared return type. Each entry of 'B.builtins' contributes its
-- 'B.getType'. 'Map.union' is left-biased, so a user definition shadows a
-- builtin with the same name.
newContext :: AST -> TypeContext
newContext ast@(AST defs _) = TypeContext ast (Map.union userTypes builtinTypes)
  where
    userTypes = defSignature <$> defs
    builtinTypes = B.getType <$> B.builtins
    defSignature (Def _ ret params _) = FuncType (map snd params) ret
-- | Check a single argument's type.
--
-- Succeeds when the expected and actual types are equal. Otherwise it fails
-- with 'BadArgType', carrying:
--
-- * the operator or function name,
-- * both types,
-- * the argument position supplied by the caller,
-- * the enclosing expression.
eq :: T.Text -> Expr -> (Type, Type, Int) -> Artifact ()
eq name expr (expected, actual, pos)
  | expected == actual = return ()
  | otherwise = Failure $ BadArgType name expected actual pos expr
-- | Infer the single type shared by all operands of operator @n@.
--
-- @e@ is the whole expression, used in error reports. The types of all
-- operands are inferred first. Each operand after the first must then have
-- the same type as the first, otherwise 'eq' reports 'BadArgType'. The first
-- operand is position 1, so numbering of the checked operands starts at 2.
-- The first operand's type is returned.
--
-- An empty operand list is an internal error ('SystemError').
--
-- NOTE(review): no check restricts the shared type to numeric types, so for
-- example two booleans unify to 'BoolType'.
unify :: TypeContext -> Expr -> T.Text -> [Expr] -> Artifact Type
unify _ _ n [] = Failure . SystemError $ T.append "tried to unify empty expr list for " n
unify tc e n (x:xs) = do
  t <- exprType tc x
  ts <- mapM (exprType tc) xs
  -- Only success or failure of the checks matters; mapM_ discards the unit
  -- results instead of building a throw-away list.
  -- A single-operand list needs no special case: the check below is then a
  -- no-op and the first operand's type is returned.
  mapM_ (eq n e) (zip3 (repeat t) ts [2 ..])
  return t
-- | Type of subscripting a value of type @t@ with subscript @sub@.
--
-- Only 'Float3Type' is subscriptable. A subscript listed in 'f3Subscripts'
-- yields 'FloatType', and any other subscript fails with 'BadSubscript'.
-- Every other type fails with 'Unsubscriptable'.
subType :: T.Text -> Type -> Artifact Type
subType sub t = case t of
  Float3Type
    | sub `elem` f3Subscripts -> return FloatType
    | otherwise -> Failure $ BadSubscript Float3Type sub f3Subscripts
  _ -> Failure $ Unsubscriptable t sub
-- | Check an application's argument types against its parameter types.
--
-- @hd@ is the applied head, whose 'show' text names it in errors. @e@ is the
-- whole application, @ex@ the expected (parameter) types and @ac@ the actual
-- (argument) types.
--
-- A length mismatch fails with 'BadArity' (expected count, then actual
-- count). Otherwise each pair is compared with 'eq', numbering arguments
-- from 1.
checkZip :: Expr -> Expr -> [Type] -> [Type] -> Artifact ()
checkZip hd e ex ac
  | expected /= actual = Failure $ BadArity name expected actual
  | otherwise = mapM_ (eq name e) (zip3 ex ac [1 ..])
  where
    -- Computed once; previously rebuilt for the error and for every 'eq'.
    name = T.pack (show hd)
    expected = length ex
    actual = length ac
-- | Result type of applying head @hd@ to @args@; @e@ is the whole expression.
--
-- The head is typed first. If it is a 'FuncType', the arguments are typed and
-- checked against the parameter types with 'checkZip', and the declared
-- return type is the result. Any other head type fails with 'BadApplyHead'.
applyType :: TypeContext -> Expr -> Expr -> [Expr] -> Artifact Type
applyType tc e hd args = do
  headType <- exprType tc hd
  case headType of
    FuncType paramTypes retType -> do
      argTypes <- mapM (exprType tc) args
      checkZip hd e paramTypes argTypes
      return retType
    other -> Failure $ BadApplyHead hd other
-- | Infer the type of an expression in the given context.
--
-- * Arithmetic (@+@, @*@, @-@, @/@, @^@): all operands must share one type
--   (via 'unify'), which is the result.
-- * Comparisons (@>@, @<@, @>=@, @<=@): operands are unified; the result is
--   'BoolType'.
-- * @Negate@ has its operand's type.
-- * @Subscript@ is typed by 'subType'.
-- * @Apply@ and @NativeCall@ are typed by 'applyType'; a native call uses
--   the name as an @Id@ head.
-- * Identifiers are looked up in the context (unknown names give 'BadId').
--   Scoped identifiers are typed through 'canonicalId'.
-- * Literals have their literal type.
-- * @Let n v b@ types @b@ with @n@ bound to the type of @v@.
-- * @If@ requires a 'BoolType' condition and equal branch types. All three
--   parts are typed before either check.
exprType :: TypeContext -> Expr -> Artifact Type
exprType tc@(TypeContext ast types) e = case e of
  Add xs -> arith "+" xs
  Mul xs -> arith "*" xs
  Sub x1 x2 -> arith "-" [x1, x2]
  Div x1 x2 -> arith "/" [x1, x2]
  Pow x1 x2 -> arith "^" [x1, x2]
  Gt x1 x2 -> comparison ">" x1 x2
  Lt x1 x2 -> comparison "<" x1 x2
  GtE x1 x2 -> comparison ">=" x1 x2
  LtE x1 x2 -> comparison "<=" x1 x2
  Negate x -> exprType tc x
  Subscript x s -> exprType tc x >>= subType s
  Apply x args -> applyType tc e x args
  NativeCall n args -> applyType tc e (Id n) args
  Id name -> maybe (Failure $ BadId name) return (Map.lookup name types)
  ScopedId _ _ -> exprType tc (canonicalId e)
  Number (DelInt _) -> return IntType
  Number (DelFloat _) -> return FloatType
  Boolean _ -> return BoolType
  Let n v b -> do
    vtype <- exprType tc v
    exprType (TypeContext ast (Map.insert n vtype types)) b
  If c t f -> do
    ct <- exprType tc c
    tt <- exprType tc t
    ft <- exprType tc f
    checkIf ct tt ft
  where
    -- Operands must agree; their common type is the result.
    arith = unify tc e
    -- Operands must agree; the result is always a boolean.
    comparison op a b = unify tc e op [a, b] >> return BoolType
    -- Condition check takes priority over the branch-agreement check.
    checkIf ct tt ft
      | ct /= BoolType = Failure $ BadIfCondition ct e
      | tt /= ft = Failure $ IfTypeMismatch tt ft e
      | otherwise = return tt
-- | Check that a definition's body has its declared return type.
--
-- The body is typed in the given context extended with the definition's
-- parameters. 'Map.union' is left-biased, so a parameter shadows a global
-- name. A mismatch fails with 'ReturnTypeMismatch' (name, declared type,
-- inferred type).
checkDef :: TypeContext -> Def -> Artifact ()
checkDef (TypeContext ast types) (Def n r a b) = do
  bodyType <- exprType bodyContext b
  if bodyType == r
    then return ()
    else Failure $ ReturnTypeMismatch n r bodyType
  where
    bodyContext = TypeContext ast (Map.fromList a `Map.union` types)
-- | Type-check every definition of an 'AST'.
--
-- Each definition is checked with 'checkDef' against one shared context
-- built by 'newContext'. On success the AST is returned unchanged.
typeCheck :: AST -> Artifact AST
typeCheck ast@(AST defs _) = do
  _ <- Tr.traverse (checkDef ctx) defs
  return ast
  where
    ctx = newContext ast