git.haldean.org del / 30e5126 src / Language / Del / TypeCheck.hs
30e5126

Tree @30e5126 (Download .tar.gz)

TypeCheck.hs @30e5126raw · history · blame

{-# LANGUAGE OverloadedStrings #-}

module Language.Del.TypeCheck (typeCheck) where

import qualified Data.Map.Strict as Map
import qualified Data.Text as T
import qualified Data.Traversable as Tr
import           Language.Del.AST
import qualified Language.Del.Builtins as B
import           Language.Del.CompileDefs
import           Language.Del.TypeTable

-- | Typing environment threaded through the checker: the program being
-- checked together with a map from identifier name to that identifier's type.
-- Within this module the map is seeded from user definitions and builtins
-- ('newContext') and extended with function parameters ('checkDef') and
-- @let@-bound names ('exprType').
data TypeContext = TypeContext AST (Map.Map T.Text Type)

-- | Subscript names accepted on a 'Float3Type' value: the three numeric
-- component names followed by the three letter component names.
f3Subscripts :: [T.Text]
f3Subscripts = numericNames ++ letterNames
  where
    numericNames = ["0", "1", "2"]
    letterNames  = ["x", "y", "z"]

-- | Build the top-level typing environment for a program. Every definition
-- contributes a 'FuncType' made from its parameter types and return type, and
-- every builtin contributes the type reported by 'B.getType'.
newContext :: AST -> TypeContext
newContext ast@(AST defs _) = TypeContext ast globals
  where
    -- Map.union is left-biased, so a user definition takes precedence over a
    -- builtin that has the same name.
    globals = Map.map signature defs `Map.union` Map.map B.getType B.builtins
    signature (Def _ ret params _) = FuncType (map snd params) ret

-- | Require that an expected and an actual type are equal.
--
-- Takes the name of the operator or function being checked, the expression
-- to report on failure, and a triple of (expected type, actual type, argument
-- index). Succeeds with unit when the types match; otherwise fails with
-- 'BadArgType' carrying all of that information.
eq :: T.Text -> Expr -> (Type, Type, Int) -> Artifact ()
eq opName site (expected, actual, position)
    | expected == actual = return ()
    | otherwise          = Failure $ BadArgType opName expected actual position site

-- | Check that every expression in a non-empty operand list has the same
-- type, and return that type.
--
-- Arguments: the typing context, the enclosing expression (used only for
-- error reporting), the operator name (used only for error reporting), and
-- the operands.
--
-- The first operand's type is the reference; each later operand is compared
-- against it with 'eq'. An empty operand list is reported as a 'SystemError'.
unify :: TypeContext -> Expr -> T.Text -> [Expr] -> Artifact Type
unify _ _ n [] = Failure . SystemError $ T.append "tried to unify empty expr list for " n
unify tc e n (x:xs) = do
    t <- exprType tc x
    ts <- mapM (exprType tc) xs
    -- Operand indices in error reports are 1-based and the first operand is
    -- the reference, so the remaining operands are numbered from 2. With a
    -- single operand there is nothing to compare and this is a no-op.
    mapM_ (eq n e) (zip3 (repeat t) ts [2..])
    return t

-- | Type of a subscript expression, given the subscript name and the type of
-- the value being subscripted.
--
-- A 'Float3Type' value subscripted with one of 'f3Subscripts' yields
-- 'FloatType'; any other subscript on it fails with 'BadSubscript'. Every
-- other type fails with 'Unsubscriptable'.
subType :: T.Text -> Type -> Artifact Type
subType sub ty = case ty of
    Float3Type
        | sub `elem` f3Subscripts -> return FloatType
        | otherwise -> Failure $ BadSubscript Float3Type sub f3Subscripts
    _ -> Failure $ Unsubscriptable ty sub

-- | Check the arguments of an application against the parameter types of the
-- function being applied.
--
-- Arguments: the head expression (its 'show' output is used as the name in
-- error reports), the whole application expression (used for error
-- reporting), the expected parameter types, and the actual argument types.
--
-- Fails with 'BadArity' when the two lists differ in length; otherwise
-- compares them pairwise with 'eq', numbering arguments from 1.
checkZip :: Expr -> Expr -> [Type] -> [Type] -> Artifact ()
checkZip hd e ex ac
    | nExpected /= nActual = Failure $ BadArity name nExpected nActual
    | otherwise            = mapM_ (eq name e) (zip3 ex ac [1..])
  where
    name      = T.pack $ show hd
    nExpected = length ex
    nActual   = length ac

-- | Type of an application.
--
-- Arguments: the typing context, the whole application expression (used for
-- error reporting), the head expression being applied, and the argument
-- expressions.
--
-- The head must have a 'FuncType'; its parameter types are checked against
-- the argument types with 'checkZip' and its result type is returned. Any
-- other head type fails with 'BadApplyHead'.
applyType :: TypeContext -> Expr -> Expr -> [Expr] -> Artifact Type
applyType tc e hd args = do
    headType <- exprType tc hd
    case headType of
        FuncType paramTypes resultType -> do
            argTypes <- mapM (exprType tc) args
            checkZip hd e paramTypes argTypes
            return resultType
        other -> Failure $ BadApplyHead hd other

-- | Compute the type of an expression in the given context, failing if the
-- expression is ill-typed.
--
-- * Arithmetic operators require all operands to share one type and yield
--   that type.
-- * Comparison operators require both operands to share one type and yield
--   'BoolType'.
-- * Negation yields the type of its operand.
-- * Identifiers are looked up in the context; unknown names fail with
--   'BadId'. Scoped identifiers are first rewritten with 'canonicalId'.
-- * @let@ types its body in a context extended with the bound name.
-- * @if@ requires a 'BoolType' condition and branches of equal type.
exprType :: TypeContext -> Expr -> Artifact Type
exprType tc@(TypeContext ast types) e = case e of
    Add xs              -> sameType "+" xs
    Mul xs              -> sameType "*" xs
    Sub l r             -> sameType "-" [l, r]
    Div l r             -> sameType "/" [l, r]
    Pow l r             -> sameType "^" [l, r]
    Gt l r              -> comparison ">" l r
    Lt l r              -> comparison "<" l r
    GtE l r             -> comparison ">=" l r
    LtE l r             -> comparison "<=" l r
    Negate x            -> exprType tc x
    Subscript x s       -> exprType tc x >>= subType s
    Apply hd args       -> applyType tc e hd args
    NativeCall n args   -> applyType tc e (Id n) args
    Id name             -> maybe (Failure $ BadId name) return (Map.lookup name types)
    ScopedId _ _        -> exprType tc (canonicalId e)
    Number (DelInt _)   -> return IntType
    Number (DelFloat _) -> return FloatType
    Boolean _           -> return BoolType
    Let n v b           -> do
        vtype <- exprType tc v
        -- The bound name is visible only while typing the body.
        exprType (TypeContext ast (Map.insert n vtype types)) b
    If c t f            -> do
        -- All three sub-expressions are typed before any of the if-specific
        -- checks run, so an error inside a branch is reported first.
        ct <- exprType tc c
        tt <- exprType tc t
        ft <- exprType tc f
        ifResult ct tt ft
  where
    -- All operands must share a single type, which is the result.
    sameType = unify tc e
    -- Operands must share a single type; the result is always boolean.
    comparison op l r = unify tc e op [l, r] >> return BoolType
    -- The condition is checked before the branch types are compared.
    ifResult ct tt ft
        | ct /= BoolType = Failure $ BadIfCondition ct e
        | tt /= ft       = Failure $ IfTypeMismatch tt ft e
        | otherwise      = return tt

-- | Type-check a single definition: type its body with the parameters in
-- scope and require the result to equal the declared return type, failing
-- with 'ReturnTypeMismatch' otherwise.
checkDef :: TypeContext -> Def -> Artifact ()
checkDef (TypeContext ast types) (Def n r a b) = do
    bodyType <- exprType (TypeContext ast scope) b
    if bodyType == r
        then return ()
        else Failure $ ReturnTypeMismatch n r bodyType
  where
    -- Map.union is left-biased, so a parameter takes precedence over an outer
    -- name that is spelled the same.
    scope = Map.fromList a `Map.union` types

-- | Type-check every definition in a program. Returns the program unchanged
-- when all definitions check, and the failure otherwise.
typeCheck :: AST -> Artifact AST
typeCheck a@(AST m _) = do
    _ <- Tr.mapM (checkDef ctx) m
    return a
  where
    -- One shared top-level context is used for every definition.
    ctx = newContext a