1-- 2-- Copyright 2018, Data61 3-- Commonwealth Scientific and Industrial Research Organisation (CSIRO) 4-- ABN 41 687 119 230. 5-- 6-- This software may be distributed and modified according to the terms of 7-- the GNU General Public License version 2. Note that NO WARRANTY is provided. 8-- See "LICENSE_GPLv2.txt" for details. 9-- 10-- @TAG(DATA61_GPL) 11-- 12 13{- LANGUAGE AllowAmbiguousTypes -} 14{-# LANGUAGE DataKinds #-} 15{- LANGUAGE DeriveDataTypeable -} 16{-# LANGUAGE DeriveFunctor #-} 17{-# LANGUAGE ExistentialQuantification #-} 18{-# LANGUAGE FlexibleContexts #-} 19{-# LANGUAGE FlexibleInstances #-} 20{-# LANGUAGE GADTs #-} 21{-# LANGUAGE GeneralizedNewtypeDeriving #-} 22{- LANGUAGE InstanceSigs -} 23{-# LANGUAGE KindSignatures #-} 24{-# LANGUAGE LambdaCase #-} 25{-# LANGUAGE MultiWayIf #-} 26{-# LANGUAGE PatternGuards #-} 27{-# LANGUAGE PolyKinds #-} 28{-# LANGUAGE Rank2Types #-} 29{-# LANGUAGE ScopedTypeVariables #-} 30{- LANGUAGE StandaloneDeriving #-} 31{-# LANGUAGE TupleSections #-} 32{-# LANGUAGE TypeFamilies #-} 33{-# LANGUAGE TypeOperators #-} 34{-# LANGUAGE UndecidableInstances #-} 35 36module Cogent.Inference where 37 38import Cogent.Common.Syntax 39import Cogent.Common.Types 40import Cogent.Compiler 41import Cogent.Core 42import Cogent.Dargent.Allocation 43import Cogent.Dargent.Util 44import Cogent.Util 45import Cogent.PrettyPrint (indent') 46import Data.Ex 47import Data.Fin 48import Data.Nat 49import qualified Data.OMap as OM 50import Data.PropEq 51import Data.Vec hiding (repeat, splitAt, length, zipWith, zip, unzip) 52import qualified Data.Vec as Vec 53 54import Control.Applicative 55import Control.Arrow 56import Control.Monad.Except hiding (fmap, forM_) 57import Control.Monad.Reader hiding (fmap, forM_) 58import Control.Monad.State hiding (fmap, forM_) 59import Control.Monad.Trans.Maybe 60import Data.Foldable (forM_) 61import Data.Function (on) 62import qualified Data.IntMap as IM 63import Data.Map (Map) 64import Data.Maybe (isJust) 65import 
qualified Data.Map as M
import Data.Monoid
#if __GLASGOW_HASKELL__ < 709
import Data.Traversable(traverse)
#endif
import Lens.Micro (_2)
import Lens.Micro.Mtl (view)
import Text.PrettyPrint.ANSI.Leijen (Pretty, pretty)
import qualified Unsafe.Coerce as Unsafe (unsafeCoerce) -- NOTE: used safely to coerce phantom types only

import Debug.Trace

-- | Abort the current type-checking computation with a @"GUARD: "@-prefixed
-- error message when the condition does not hold.
guardShow :: String -> Bool -> TC t v b ()
guardShow x b = unless b $ TC (throwError $ "GUARD: " ++ x)

-- | Like 'guardShow', but the message is a headline followed by the given
-- detail lines (each terminated by a newline).
guardShow' :: String -> [String] -> Bool -> TC t v b ()
guardShow' mh mb b = unless b $ TC (throwError $ "GUARD: " ++ mh ++ "\n" ++ unlines mb)

-- ----------------------------------------------------------------------------
-- Type reconstruction

-- | Decide whether a value of the first type may be cast to the second one,
-- for types that don't have the same representation / don't satisfy subtyping.
--
-- * Primitive types: delegated to 'isSubtypePrim'.
--
-- * Variants: every alternative of the source must exist in the target with a
--   subtype payload and the same taken-flag, and every alternative that only
--   exists in the target must already be taken.
--
-- Any other pair of types is not upcastable.
isUpcastable :: (Show b, Eq b) => Type t b -> Type t b -> TC t v b Bool
isUpcastable (TPrim p1) (TPrim p2) = return $ isSubtypePrim p1 p2
isUpcastable (TSum s1) (TSum s2) = do
  c1 <- flip allM s1 $ \(c,(t,b)) -> case lookup c s2 of
    Nothing      -> return False
    Just (t',b') -> (&&) <$> t `isSubtype` t' <*> pure (b == b')
  -- other tags are all taken
  c2 <- flip allM s2 (\(c,(_,b)) -> return $ case lookup c s1 of Nothing -> b; Just _ -> True)
  return $ c1 && c2
isUpcastable _ _ = return False

-- | @t1 `isSubtype` t2@ holds exactly when the least upper bound of the two
-- types exists and is @t2@ itself.
isSubtype :: (Show b, Eq b) => Type t b -> Type t b -> TC t v b Bool
isSubtype t1 t2 = maybe False (== t2) <$> runMaybeT (t1 `lub` t2)

-- | Unfold the recursive parameter @v@ by looking up its definition in the
-- recursive context (looked up with 'M.!', so @v@ must be present).
-- While descending into the definition, every recursive parameter that has
-- no context attached receives the context in scope at that point, and a
-- record annotated @Rec x@ extends that context with @x@ bound to the record
-- itself.
unroll :: RecParName -> RecContext (Type t b) -> Type t b
unroll _ Nothing = __impossible "unroll: recursive parameter has no context to unroll against"
unroll v (Just ctxt) = erp (Just ctxt) (ctxt M.! v)
  where
    -- Embed rec pars
    erp :: RecContext (Type t b) -> Type t b -> Type t b
    erp c (TCon n ts s) = TCon n (map (erp c) ts) s
    erp c (TFun t1 t2) = TFun (erp c t1) (erp c t2)
    erp c (TSum r) = TSum $ map (\(a,(t,b)) -> (a, (erp c t, b))) r
    erp c (TProduct t1 t2) = TProduct (erp c t1) (erp c t2)
    erp (Just c) t@(TRecord rp fs s) =
      let c' = case rp of Rec v -> M.insert v t c; _ -> c
      in TRecord rp (map (\(a,(t,b)) -> (a, (erp (Just c') t, b))) fs) s
    -- Context must be Nothing at this point
    erp c (TRPar v Nothing) = TRPar v c
#ifdef BUILTIN_ARRAYS
    erp c (TArray t l s h) = TArray (erp c t) l s h
#endif
    erp _ t = t

-- | Compute the least upper bound ('LUB') or greatest lower bound ('GLB') of
-- two types. Records and variants must agree on their field / tag names and
-- the bound is taken componentwise; for function types the argument position
-- uses the opposite bound ('theOtherB'). The computation fails in 'MaybeT'
-- when the two types disagree on which record fields or array elements are
-- taken and the affected type cannot be discarded (see the individual
-- cases). Pairs of types with different shapes are treated as impossible.
bound :: (Show b, Eq b) => Bound -> Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
bound _ t1 t2 | t1 == t2 = return t1
bound b (TRecord rp1 fs1 s1) (TRecord rp2 fs2 s2)
  | map fst fs1 == map fst fs2, s1 == s2, rp1 == rp2 = do
    let op = case b of LUB -> (||); GLB -> (&&)
    blob <- flip3 zipWithM fs2 fs1 $ \(f1,(t1,b1)) (_, (t2,b2)) -> do
      t <- bound b t1 t2
      -- A field whose taken-status changes must be discardable.
      ok <- lift $ if b1 == b2 then return True
                               else kindcheck t >>= \k -> return (canDiscard k)
      return ((f1, (t, b1 `op` b2)), ok)
    let (fs, oks) = unzip blob
    if and oks then return $ TRecord rp1 fs s1
               else MaybeT (return Nothing)
bound b (TSum s1) (TSum s2) | s1' <- M.fromList s1, s2' <- M.fromList s2, M.keys s1' == M.keys s2' = do
  let op = case b of LUB -> (&&); GLB -> (||)
  s <- flip3 unionWithKeyM s2' s1' $ \_ (t1,b1) (t2,b2) -> (,) <$> bound b t1 t2 <*> pure (b1 `op` b2)
  return $ TSum $ M.toList s
bound b (TProduct t11 t12) (TProduct t21 t22) = TProduct <$> bound b t11 t21 <*> bound b t12 t22
bound b (TCon c1 t1 s1) (TCon c2 t2 s2) | c1 == c2, s1 == s2 = TCon c1 <$> zipWithM (bound b) t1 t2 <*> pure s1
bound b (TFun t1 s1) (TFun t2 s2) = TFun <$> bound (theOtherB b) t1 t2 <*> bound b s1 s2
-- At this point, we can assume recursive parameters and records agree
bound _ TRecord{} t2@TRPar{} = return t2
bound _ TRPar{} t2@TRecord{} = return t2
bound _ TRPar{} t2@TRPar{} = return t2
#ifdef BUILTIN_ARRAYS
bound b (TArray t1 l1 s1 mhole1) (TArray t2 l2 s2 mhole2)
  | l1 == l2, s1 == s2 = do
    t <- bound b t1 t2
    ok <- lift $ case (mhole1, mhole2) of
      (Nothing, Nothing) -> return True
      (Just i1, Just i2) -> return $ i1 == i2 -- FIXME: change to propositional equality / zilinc
      _ -> kindcheck t >>= \k -> return (canDiscard k)
    let mhole = combineHoles b mhole1 mhole2
    if ok then return $ TArray t l1 s1 mhole
          else MaybeT (return Nothing)
  where
    combineHoles _ Nothing Nothing = Nothing
    combineHoles _ (Just i1) (Just _ ) = Just i1
    combineHoles b Nothing (Just i2) = case b of GLB -> Nothing; LUB -> Just i2
    combineHoles b (Just i1) Nothing = case b of GLB -> Nothing; LUB -> Just i1
#endif
bound _ t1 t2 = __impossible ("bound: not comparable:\n" ++ show t1 ++ "\n" ++
                              "----------------------------------------\n" ++ show t2 ++ "\n")

-- | Least upper bound of two types (see 'bound').
lub :: (Show b, Eq b) => Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
lub = bound LUB

-- | Greatest lower bound of two types (see 'bound').
glb :: (Show b, Eq b) => Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
glb = bound GLB

-- checkUExpr_B :: UntypedExpr -> TC t v Bool
-- checkUExpr_B (E (Op op [e])) = return True
-- checkUExpr_B (E (Op op [e1,e2])) = return True
-- checkUExpr_B _ = return True


-- | Apply the read-only (@!@) operator to a type. It is pushed into type
-- arguments, variant alternatives, products, record fields and array
-- elements, and sigils are made read-only via 'bangSigil'; function types
-- and recursive parameters are left unchanged.
bang :: Type t b -> Type t b
bang (TVar v) = TVarBang v
bang (TVarBang v) = TVarBang v
bang (TVarUnboxed v) = TVarUnboxed v
bang (TCon n ts s) = TCon n (map bang ts) (bangSigil s)
bang (TFun ti to) = TFun ti to
bang (TPrim i) = TPrim i
bang (TString) = TString
bang (TSum ts) = TSum (map (second $ first bang) ts)
bang (TProduct t1 t2) = TProduct (bang t1) (bang t2)
bang (TRecord rp ts s) = TRecord rp (map (second $ first bang) ts) (bangSigil s)
bang (TRPar n ctxt) = TRPar n ctxt
bang (TUnit) = TUnit
#ifdef BUILTIN_ARRAYS
bang (TArray t l s tkns) = TArray (bang t) l (bangSigil s) tkns
#endif

-- | Apply the unboxed (@#@) operator; only the outermost type is affected.
unbox :: Type t b -> Type t b
unbox (TVar v) = TVarUnboxed v
unbox (TVarBang v) = TVarUnboxed v
unbox (TVarUnboxed v) = TVarUnboxed v
unbox (TCon n ts s) = TCon n ts (unboxSigil s)
unbox (TRecord rp ts s) = TRecord rp ts (unboxSigil s)
unbox t = t  -- NOTE that @#@ type operator behaves differently to @!@.
             -- The application of @#@ should NOT be pushed inside of a
             -- data type. / zilinc


-- | Instantiate the type variables of a type with the given types. Banged
-- and unboxed variables have 'bang' / 'unbox' applied to their replacement.
substitute :: Vec t (Type u b) -> Type t b -> Type u b
substitute vs (TVar v) = vs `at` v
substitute vs (TVarBang v) = bang (vs `at` v)
substitute vs (TVarUnboxed v) = unbox (vs `at` v)
substitute vs (TCon n ts s) = TCon n (map (substitute vs) ts) s
substitute vs (TFun ti to) = TFun (substitute vs ti) (substitute vs to)
substitute _ (TPrim i) = TPrim i
substitute _ (TString) = TString
substitute vs (TProduct t1 t2) = TProduct (substitute vs t1) (substitute vs t2)
substitute vs (TRecord rp ts s) = TRecord rp (map (second (first $ substitute vs)) ts) s
substitute vs (TSum ts) = TSum (map (second (first $ substitute vs)) ts)
substitute _ (TUnit) = TUnit
substitute vs (TRPar v m) = TRPar v $ fmap (M.map (substitute vs)) m
#ifdef BUILTIN_ARRAYS
substitute vs (TArray t l s mhole) = TArray (substitute vs t) (substituteLE vs l) s (fmap (substituteLE vs) mhole)
#endif

-- | Instantiate layout variables in the sigils of records (and arrays),
-- recursing through type arguments, functions, products, fields and variants.
substituteL :: [DataLayout BitRange] -> Type t b -> Type t b
substituteL ls (TCon n ts s) = TCon n (map (substituteL ls) ts) s
substituteL ls (TFun ti to) = TFun (substituteL ls ti) (substituteL ls to)
substituteL ls (TProduct t1 t2) = TProduct (substituteL ls t1) (substituteL ls t2)
substituteL ls (TRecord rp ts s) = TRecord rp (map (second (first $ substituteL ls)) ts) (substituteS ls s)
substituteL ls (TSum ts) = TSum (map (second (first $ substituteL ls)) ts)
#ifdef BUILTIN_ARRAYS
substituteL ls (TArray t l s mhole) = TArray (substituteL ls t) l (substituteS ls s) mhole
#endif
substituteL _ t = t

-- | Instantiate layout variables in a sigil. A layout variable is replaced
-- by the list element at its index, shifted with 'offset'; unboxed sigils
-- and C layouts are returned unchanged.
substituteS :: [DataLayout BitRange] -> Sigil (DataLayout BitRange) -> Sigil (DataLayout BitRange)
substituteS ls Unboxed = Unboxed
substituteS ls (Boxed b CLayout) = Boxed b CLayout
substituteS ls (Boxed b (Layout l)) = Boxed b . Layout $ substituteS' ls l
  where
    substituteS' :: [DataLayout BitRange] -> DataLayout' BitRange -> DataLayout' BitRange
    substituteS' ls l = case l of
      VarLayout n s -> case ls !! (natToInt n) of
        CLayout -> __impossible "substituteS in Inference: CLayout shouldn't be here"
        Layout l -> offset s l
      -- Nested layouts are rewritten in place; keys (and, for variants, the
      -- value paired with each alternative's layout) are kept as they are.
      SumLayout tag alts -> SumLayout tag $ M.map (second (substituteS' ls)) alts
      RecordLayout fs -> RecordLayout $ M.map (substituteS' ls) fs
#ifdef BUILTIN_ARRAYS
      ArrayLayout e -> ArrayLayout $ substituteS' ls e
#endif
      _ -> l

-- | Instantiate the type variables occurring in a layout expression (see
-- 'substitute').
substituteLE :: Vec t (Type u b) -> LExpr t b -> LExpr u b
substituteLE vs = \case
  LVariable va -> LVariable va
  LFun fn ts ls -> LFun fn (fmap (substitute vs) ts) ls
  LOp op es -> LOp op $ fmap go es
  LApp e1 e2 -> LApp (go e1) (go e2)
  LCon tn e t -> LCon tn (go e) (substitute vs t)
  LUnit -> LUnit
  LILit n t -> LILit n t
  LSLit s -> LSLit s
  LLet a e e' -> LLet a (go e) (go e')
  LLetBang bs a e e' -> LLetBang bs a (go e) (go e')
  LTuple e1 e2 -> LTuple (go e1) (go e2)
  LStruct fs -> LStruct $ fmap (second go) fs
  LIf c th el -> LIf (go c) (go th) (go el)
  LCase e tn (a1,e1) (a2,e2)
    -> LCase (go e) tn (a1,go e1) (a2,go e2)
  LEsac e -> LEsac $ go e
  LSplit as e e' -> LSplit as (go e) (go e')
  LMember e f -> LMember (go e) f
  LTake as rec f e' -> LTake as (go rec) f (go e')
  LPut rec f e -> LPut (go rec) f (go e)
  LPromote t e -> LPromote (substitute vs t) (go e)
  LCast t e -> LCast (substitute vs t) (go e)
  where go = substituteLE vs

-- | Delete every entry with the given key from an association list.
remove :: (Eq a) => a -> [(a,b)] -> [(a,b)]
remove k = filter ((/= k) . fst)

-- | Apply a function to the value of every entry with the given key.
adjust :: (Eq a) => a -> (b -> b) -> [(a,b)] -> [(a,b)]
adjust k f = map (\(a,b) -> (a,) $ if a == k then f b else b)


-- | The type-checking monad.
--
-- * Errors are reported as 'String's.
--
-- * The read-only environment holds the kinds of the type variables in scope
--   and the types of the known functions.
--
-- * The state holds, for each term variable in scope, its type, or 'Nothing'
--   once the variable has been used up (see 'useVariable').
newtype TC (t :: Nat) (v :: Nat) b x
  = TC {unTC :: ExceptT String
                  (ReaderT (Vec t Kind, Map FunName (FunctionType b))
                           (State (Vec v (Maybe (Type t b)))))
                  x}
  deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
            MonadReader (Vec t Kind, Map FunName (FunctionType b)))

#if MIN_VERSION_base(4,13,0)
-- Pattern-match failures in do-blocks (e.g. @Just t <- ...@) are compiler bugs.
instance MonadFail (TC t v b) where
  fail = __impossible
#endif

-- | Apply a function computation to an argument computation where both start
-- from the same initial variable context; the context left behind by the
-- second computation becomes the final one.
infixl 4 <||>
(<||>) :: TC t v b (x -> y) -> TC t v b x -> TC t v b y
(TC a) <||> (TC b) = TC $ do
  x <- get
  f <- a
  x1 <- get
  put x
  arg <- b
  x2 <- get
  -- XXX | unTC $ guardShow "<||>" $ x1 == x2
  -- \ ^^^ NOTE: This check is taken out to fix
  -- #296. The issue here is that, if we define a
  -- variable of permission D alone (w/o S), it will
  -- be marked as used after it's been used, which
  -- is correct. But when it is used in one branch
  -- but not in the other one, which is allowed as
  -- it's droppable, it will be marked as used in
  -- the context of one branch but not the other and
  -- render the two contexts different. The formal
  -- specification requires that both contexts are
  -- the same, but it is tantamount to merging two
  -- different (correct) contexts correctly, which
  -- can be established in the typing proof.
  -- / v.jackson, zilinc
  return (f arg)

-- | Result type of a primitive operator applied to arguments of the given
-- types. Any combination not listed here is treated as impossible.
opType :: Op -> [Type t b] -> Maybe (Type t b)
opType opr [TPrim p1, TPrim p2]
  | opr `elem` [Plus, Minus, Times, Divide, Mod,
                BitAnd, BitOr, BitXor, LShift, RShift],
    p1 == p2, p1 /= Boolean = Just $ TPrim p1
opType opr [TPrim p1, TPrim p2]
  | opr `elem` [Gt, Lt, Le, Ge, Eq, NEq],
    p1 == p2, p1 /= Boolean = Just $ TPrim Boolean
opType opr [TPrim Boolean, TPrim Boolean]
  | opr `elem` [And, Or, Eq, NEq] = Just $ TPrim Boolean
opType Not [TPrim Boolean] = Just $ TPrim Boolean
opType Complement [TPrim p] | p /= Boolean = Just $ TPrim p
opType _ _ = __impossible "opType"

-- | Return the current type of a variable ('Nothing' if already used up).
-- A variable whose kind does not permit sharing is marked as used up.
useVariable :: Fin v -> TC t v b (Maybe (Type t b))
useVariable v = TC $ do
  ret <- (`at` v) <$> get
  case ret of
    Nothing -> return ret
    Just t -> do
      ok <- canShare <$> unTC (kindcheck t)
      unless ok $ modify (\s -> update s v Nothing)
      return ret

-- | Look up the type of a known function in the environment.
funType :: CoreFunName -> TC t v b (Maybe (FunctionType b))
funType v = TC $ (M.lookup (unCoreFunName v) . snd) <$> ask

-- | Run a type-checking computation, returning the final variable context
-- together with the result, or the error message.
runTC :: TC t v b x
      -> (Vec t Kind, Map FunName (FunctionType b))
      -> Vec v (Maybe (Type t b))
      -> Either String (Vec v (Maybe (Type t b)), x)
runTC (TC a) readers st = case runState (runReaderT (runExceptT a) readers) st of
  (Left x, _) -> Left x
  (Right x, s) -> Right (s,x)

-- XXX | tc_debug :: [Definition UntypedExpr a] -> IO ()
-- XXX | tc_debug = flip tc_debug' M.empty
-- XXX |   where
-- XXX |     tc_debug' :: [Definition UntypedExpr a] -> Map FunName FunctionType -> IO ()
-- XXX |     tc_debug' [] _ = putStrLn "tc2... OK!"
-- XXX |     tc_debug' ((FunDef _ fn ts t rt e):ds) reader =
-- XXX |       case runTC (infer e) (fmap snd ts, reader) (Cons (Just t) Nil) of
-- XXX |         Left x -> putStrLn $ "tc2... failed!
Due to: " ++ x
-- XXX |         Right _ -> tc_debug' ds (M.insert fn (FT (fmap snd ts) t rt) reader)
-- XXX |     tc_debug' ((AbsDecl _ fn ts t rt):ds) reader = tc_debug' ds (M.insert fn (FT (fmap snd ts) t rt) reader)
-- XXX |     tc_debug' (_:ds) reader = tc_debug' ds reader

-- | Re-run type inference over already-typed definitions by first erasing
-- their type annotations ('untypeD').
retype :: (Show b, Eq b, Pretty b, a ~ b)
       => [Definition TypedExpr a b]
       -> Either String [Definition TypedExpr a b]
retype ds = fmap fst $ tc $ map untypeD ds

-- | Type-check a list of definitions in order. Returns the typed definitions
-- and the final function-type environment. Each function body is checked
-- with its own type already in scope (allowing recursion) plus the types of
-- all earlier functions and abstract declarations.
tc :: (Show b, Eq b, Pretty b, a ~ b)
   => [Definition UntypedExpr a b]
   -> Either String ([Definition TypedExpr a b], Map FunName (FunctionType b))
tc = flip tc' M.empty
  where
    tc' :: (Show b, Eq b, Pretty b, a ~ b)
        => [Definition UntypedExpr a b]
        -> Map FunName (FunctionType b)  -- the reader
        -> Either String ([Definition TypedExpr a b], Map FunName (FunctionType b))
    tc' [] reader = return ([], reader)
    tc' ((FunDef attr fn ks ls t rt e):ds) reader =
      -- Enable recursion by inserting this function's type into the function type dictionary
      let ft = FT (snd <$> ks) (snd <$> ls) t rt
          reader' = M.insert fn ft reader
      in case runTC (infer e >>= flip typecheck rt) (fmap snd ks, reader') (Cons (Just t) Nil) of
           Left x -> Left x
           Right (_, e') -> first (FunDef attr fn ks ls t rt e' :) <$> tc' ds reader'
    tc' (d@(AbsDecl _ fn ks ls t rt):ds) reader =
      first (Unsafe.unsafeCoerce d :) <$> tc' ds (M.insert fn (FT (fmap snd ks) (fmap snd ls) t rt) reader)
    -- Other definitions carry no expressions; their phantom index is coerced.
    tc' (d:ds) reader = first (Unsafe.unsafeCoerce d :) <$> tc' ds reader

-- | 'tc' without the final function-type environment.
tc_ :: (Show b, Eq b, Pretty b, a ~ b)
    => [Definition UntypedExpr a b]
    -> Either String [Definition TypedExpr a b]
tc_ = fmap fst . tc

-- | Type-check constant definitions with no type variables and no term
-- variables in scope. The function-type environment is returned unchanged.
tcConsts :: [CoreConst UntypedExpr]
         -> Map FunName (FunctionType VarName)
         -> Either String ([CoreConst TypedExpr], Map FunName (FunctionType VarName))
tcConsts [] reader = return ([], reader)
tcConsts ((v,e):ds) reader =
  case runTC (infer e) (Nil, reader) Nil of
    Left x -> Left x
    Right (_,e') -> first ((v,e'):) <$> tcConsts ds reader

-- | Run a computation with one extra term variable (index 0) of the given
-- type. Afterwards the variable must have been used up unless its kind
-- permits discarding.
withBinding :: Type t b -> TC t ('Suc v) b x -> TC t v b x
withBinding t x = TC $ do
  readers <- ask
  st <- get
  case runTC x readers (Cons (Just t) st) of
    Left e -> throwError e
    Right (Cons Nothing s, r) -> do put s; return r
    Right (Cons (Just t') s, r) -> do
      ok <- canDiscard <$> unTC (kindcheck t')
      if ok then put s >> return r
            else throwError "Didn't use linear variable"

-- | Bind several variables at once; the first element of the vector becomes
-- variable 0 inside the computation.
withBindings :: Vec k (Type t b) -> TC t (v :+: k) b x -> TC t v b x
withBindings Nil m = m
withBindings (Cons x xs) m = withBindings xs (withBinding x m)

-- | Apply 'bang' to the listed variables while running the computation, then
-- restore their original context entries.
withBang :: [Fin v] -> TC t v b x -> TC t v b x
withBang vs (TC x) = TC $ do
  st <- get
  mapM_ (\v -> modify (modifyAt v (fmap bang))) vs
  ret <- x
  mapM_ (\v -> modify (modifyAt v (const $ st `at` v))) vs
  return ret

-- | Kind of a type variable, from the read-only environment.
lookupKind :: Fin t -> TC t v b Kind
lookupKind f = TC ((`at` f) . fst <$> ask)

-- | Compute the kind of a type, given the kinds of its type variables.
-- Products, records and variants combine the kinds of their components;
-- taken record fields and taken variant alternatives are ignored.
kindcheck_ :: (Monad m) => (Fin t -> m Kind) -> Type t a -> m Kind
kindcheck_ f (TVar v) = f v
kindcheck_ f (TVarBang v) = bangKind <$> f v
kindcheck_ f (TVarUnboxed v) = return mempty
kindcheck_ f (TCon n vs s) = return $ sigilKind s
kindcheck_ f (TFun ti to) = return mempty
kindcheck_ f (TPrim i) = return mempty
kindcheck_ f (TString) = return mempty
kindcheck_ f (TProduct t1 t2) = mappend <$> kindcheck_ f t1 <*> kindcheck_ f t2
kindcheck_ f (TRecord _ ts s) = mconcat <$> ((sigilKind s :) <$> mapM (kindcheck_ f . fst . snd) (filter (not . snd . snd) ts))
kindcheck_ f (TSum ts) = mconcat <$> mapM (kindcheck_ f . fst . snd) (filter (not . snd . snd) ts)
kindcheck_ f (TUnit) = return mempty
kindcheck_ f (TRPar _ _) = return mempty

#ifdef BUILTIN_ARRAYS
kindcheck_ f (TArray t l s _) = mappend <$> kindcheck_ f t <*> pure (sigilKind s)
#endif

-- | 'kindcheck_' using the kinds in the type-checking environment.
kindcheck :: Type t a -> TC t v b Kind
kindcheck = kindcheck_ lookupKind


-- | Check an already-typed expression against an expected type. If the
-- inferred type is a strict subtype, the expression is wrapped with
-- 'promote'; any other mismatch is treated as impossible.
typecheck :: (Pretty a, Show a, Eq a) => TypedExpr t v a a -> Type t a -> TC t v a (TypedExpr t v a a)
typecheck e t = do
  let t' = exprType e
  isSub <- isSubtype t' t
  if | t == t' -> return e
     | isSub -> return (promote t e)
     | otherwise -> __impossible $ "Inferred type of\n" ++
                                   show (indent' $ pretty e) ++
                                   "\ndoesn't agree with the given type signature:\n" ++
                                   "Inferred type:\n" ++
                                   show (indent' $ pretty t') ++
                                   "\nGiven type:\n" ++
                                   show (indent' $ pretty t) ++ "\n"

-- | Infer the type of an untyped expression, producing the type-annotated
-- expression. Failed checks are reported through 'guardShow' / 'guardShow''
-- (as @GUARD:@ errors); other malformed inputs fail via '__impossible' or a
-- pattern-match failure.
infer :: (Pretty a, Show a, Eq a) => UntypedExpr t v a a -> TC t v a (TypedExpr t v a a)
infer (E (Op o es))
  = do es' <- mapM infer es
       let Just t = opType o (map exprType es')
       return (TE t (Op o es'))
infer (E (ILit i t)) = return (TE (TPrim t) (ILit i t))
infer (E (SLit s)) = return (TE TString (SLit s))
#ifdef BUILTIN_ARRAYS
infer (E (ALit [])) = __impossible "We don't allow 0-size array literals"
infer (E (ALit es))
  = do es' <- mapM infer es
       let ts = map exprType es'
           n = LILit (fromIntegral $ length es) U32
       t <- lubAll ts
       -- Every element must be a subtype of the element type of the array.
       isSub <- allM (`isSubtype` t) ts
       guardShow "alit: elements are not subtypes of their least upper bound" isSub
       return (TE (TArray t n Unboxed Nothing) (ALit es'))
  where
    lubAll :: (Show b, Eq b) => [Type t b] -> TC t v b (Type t b)
    lubAll [] = __impossible "lubAll: empty list"
    lubAll [t] = return t
    lubAll (t1:t2:ts) = do Just t <- runMaybeT $ lub t1 t2
                           lubAll (t:ts)
infer (E (ArrayIndex arr idx))
  = do arr'@(TE ta _) <- infer arr
       let TArray te l _ _ = ta
       idx' <- infer idx
       -- guardShow ("arr-idx out of bound") $ idx >= 0 && idx < l  -- no way to check it. need ref types. / zilinc
       guardShow ("arr-idx on non-linear") . canShare =<< kindcheck ta
       return (TE te (ArrayIndex arr' idx'))
infer (E (ArrayMap2 (as,f) (e1,e2)))
  = do e1'@(TE t1 _) <- infer e1
       e2'@(TE t2 _) <- infer e2
       let TArray te1 l1 _ _ = t1
           TArray te2 l2 _ _ = t2
       f' <- withBindings (Cons te2 (Cons te1 Nil)) $ infer f
       let t = case __cogent_ftuples_as_sugar of
                 False -> TProduct t1 t2
                 True -> TRecord NonRec (zipWith (\f t -> (f,(t,False))) tupleFieldNames [t1,t2]) Unboxed
       return $ TE t $ ArrayMap2 (as,f') (e1',e2')
infer (E (Pop a e1 e2))
  = do e1'@(TE t1 _) <- infer e1
       let TArray te l s tkns = t1
           thd = te
           ttl = TArray te (LOp Minus [l, LILit 1 U32]) s tkns
       -- guardShow "arr-pop on a singleton array" $ l > 1
       e2'@(TE t2 _) <- withBindings (Cons thd (Cons ttl Nil)) $ infer e2
       return (TE t2 (Pop a e1' e2'))
infer (E (Singleton e))
  = do e'@(TE t _) <- infer e
       let TArray te l _ _ = t
       -- guardShow "singleton on a non-singleton array" $ l == 1
       return (TE te (Singleton e'))
infer (E (ArrayTake as arr i e))
  = do arr'@(TE tarr _) <- infer arr
       i' <- infer i
       let TArray telt len s Nothing = tarr
           tarr' = TArray telt len s (Just $ texprToLExpr id i')
       e'@(TE te _) <- withBindings (Cons telt (Cons tarr' Nil)) $ infer e
       return (TE te $ ArrayTake as arr' i' e')
infer (E (ArrayPut arr i e))
  = do arr'@(TE tarr _) <- infer arr
       i' <- infer i
       e'@(TE te _) <- infer e
       -- FIXME: all the checks are disabled here, for the lack of a proper
       -- refinement type system. Also, we cannot know the exact index that
       -- is being put, thus there's no way that we can infer the precise type
       -- for the new array (tarr').
       let TArray telm len s tkns = tarr
       -- XXX | mi <- evalExpr i'
       -- XXX | guardShow "@put index not a integral constant" $ isJust mi
       -- XXX | let Just i'' = mi
       -- XXX | guardShow "@put index is out of range" $ i'' `IM.member` tkns
       -- XXX | let Just itkn = IM.lookup i'' tkns
       -- XXX | k <- kindcheck telm
       -- XXX | unless itkn $ guardShow "@put a non-Discardable untaken element" $ canDiscard k
       let tarr' = TArray telm len s Nothing
       return (TE tarr' (ArrayPut arr' i' e'))
#endif
infer (E (Variable v))
  = do Just t <- useVariable (fst v)
       return (TE t (Variable v))
infer (E (Fun f ts ls note))
  | ExI (Flip ts') <- Vec.fromList ts
  , ExI (Flip ls') <- Vec.fromList ls
  = do x <- funType f
       case x of
         Just (FT ks lts ti to) ->
           case (Vec.length ts' =? Vec.length ks, Vec.length ls' =? Vec.length lts) of
             (Just Refl, Just Refl) -> do
               let ti' = substitute ts' $ substituteL ls ti
                   to' = substitute ts' $ substituteL ls to
               -- Each type argument must have (at least) the declared kind.
               forM_ (Vec.zip ts' ks) $ \(t, k) -> do
                 k' <- kindcheck t
                 when ((k <> k') /= k) $ __impossible "kind not matched in type instantiation"
               return $ TE (TFun ti' to') (Fun f ts ls note)
             _ -> __impossible "lengths don't match"
         _ -> error $ "Something went wrong in lookup of function type: '" ++ unCoreFunName f ++ "'"
infer (E (App e1 e2))
  = do e1'@(TE (TFun ti to) _) <- infer e1
       e2'@(TE ti' _) <- infer e2
       isSub <- ti' `isSubtype` ti
       guardShow ("app (actual: " ++ show ti' ++ "; formal: " ++ show ti ++ ")") $ isSub
       if ti' /= ti then return $ TE to (App e1' (promote ti e2'))
                    else return $ TE to (App e1' e2')
infer (E (Let a e1 e2))
  = do e1' <- infer e1
       e2' <- withBinding (exprType e1') (infer e2)
       return $ TE (exprType e2') (Let a e1' e2')
infer (E (LetBang vs a e1 e2))
  = do e1' <- withBang (map fst vs) (infer e1)
       k <- kindcheck (exprType e1')
       guardShow "let!" $ canEscape k
       e2' <- withBinding (exprType e1') (infer e2)
       return $ TE (exprType e2') (LetBang vs a e1' e2')
infer (E Unit) = return $ TE TUnit Unit
infer (E (Tuple e1 e2))
  = do e1' <- infer e1
       e2' <- infer e2
       return $ TE (TProduct (exprType e1') (exprType e2')) (Tuple e1' e2')
infer (E (Con tag e tfull))
  = do e' <- infer e
       -- Find type of payload for given tag
       let TSum ts = tfull
           Just (t, False) = lookup tag ts
       -- Make sure to promote the payload to type t if necessary
       e'' <- typecheck e' t
       return $ TE tfull (Con tag e'' tfull)
infer (E (If ec et ee))
  = do ec' <- infer ec
       guardShow "if-1" $ exprType ec' == TPrim Boolean
       (et', ee') <- (,) <$> infer et <||> infer ee  -- have to use applicative functor, as they share the same initial env
       let tt = exprType et'
           te = exprType ee'
       Just tlub <- runMaybeT $ tt `lub` te
       isSub <- (&&) <$> tt `isSubtype` tlub <*> te `isSubtype` tlub
       guardShow' "if-2" ["Then type:", show (pretty tt) ++ ";", "else type:", show (pretty te)] isSub
       let et'' = if tt /= tlub then promote tlub et' else et'
           ee'' = if te /= tlub then promote tlub ee' else ee'
       return $ TE tlub (If ec' et'' ee'')
infer (E (Case e tag (lt,at,et) (le,ae,ee)))
  = do e' <- infer e
       let TSum ts = exprType e'
           Just (t, taken) = lookup tag ts
           restt = TSum $ adjust tag (second $ const True) ts  -- set the tag to taken
       -- If the matched tag is already taken, first promote the scrutinee to
       -- the variant type in which that tag is untaken.
       let e'' = case taken of
                   True -> promote (TSum $ OM.toList $ OM.adjust (\(t,True) -> (t,False)) tag $ OM.fromList ts) e'
                   False -> e'
       (et',ee') <- (,) <$> withBinding t (infer et)
                        <||> withBinding restt (infer ee)
       let tt = exprType et'
           te = exprType ee'
       Just tlub <- runMaybeT $ tt `lub` te
       isSub <- (&&) <$> tt `isSubtype` tlub <*> te `isSubtype` tlub
       guardShow' "case" ["Match type:", show (pretty tt) ++ ";", "rest type:", show (pretty te)] isSub
       let et'' = if tt /= tlub then promote tlub et' else et'
           ee'' = if te /= tlub then promote tlub ee' else ee'
       return $ TE tlub (Case e'' tag (lt,at,et'') (le,ae,ee''))
infer (E (Esac e))
  = do e'@(TE (TSum ts) _) <- infer e
       let t1 = filter (not . snd . snd) ts
       case t1 of
         [(_, (t, False))] -> return $ TE t (Esac e')
         _ -> __impossible $ "infer: esac (t1 = " ++ show t1 ++ ", ts = " ++ show ts ++ ")"
infer (E (Split a e1 e2))
  = do e1' <- infer e1
       let (TProduct t1 t2) = exprType e1'
       e2' <- withBindings (Cons t1 (Cons t2 Nil)) (infer e2)
       return $ TE (exprType e2') (Split a e1' e2')
infer (E (Member e f))
  = do e'@(TE t _) <- infer e  -- canShare
       let TRecord _ fs _ = t
       guardShow "member-1" . canShare =<< kindcheck t
       guardShow "member-2" $ f < length fs
       let (_,(tau,c)) = fs !! f
       guardShow "member-3" $ not c  -- not taken
       return $ TE tau (Member e' f)
infer (E (Struct fs))
  = do let (ns,es) = unzip fs
       es' <- mapM infer es
       let ts' = zipWith (\n e' -> (n, (exprType e', False))) ns es'
       return $ TE (TRecord NonRec ts' Unboxed) $ Struct $ zip ns es'
infer (E (Take a e f e2))
  = do e'@(TE t _) <- infer e
       -- trace ("@@@@t is " ++ show t) $ return ()
       let TRecord rp ts s = t
       -- a common cause of this error is taking a field when you could have used member
       guardShow ("take: sigil cannot be readonly: " ++ show (pretty e)) $ not (readonly s)
       guardShow "take-1" $ f < length ts
       let (init, (fn,(tau,taken)):rest) = splitAt f ts
       -- A field that has already been taken cannot be taken again.
       guardShow "take-2" $ not taken
       e2' <- withBindings (Cons tau (Cons (TRecord rp (init ++ (fn,(tau,True)):rest) s) Nil)) (infer e2)  -- take that field regardless of its shareability
       return $ TE (exprType e2') (Take a e' f e2')
infer (E (Put e1 f e2))
  = do e1'@(TE t1 _) <- infer e1
       let TRecord rp ts s = t1
       guardShow "put: sigil not readonly" $ not (readonly s)
       guardShow "put-1" $ f < length ts
       let (init, (fn,(tau,taken)):rest) = splitAt f ts
       k <- kindcheck tau
       unless taken $ guardShow "put-2" $ canDiscard k  -- if it's not taken, then it has to be discardable; if taken, then just put
       e2'@(TE t2 _) <- infer e2
       isSub <- t2 `isSubtype` tau
       guardShow "put-3" isSub
       let e2'' = if t2 /= tau then promote tau e2' else e2'
       return $ TE (TRecord rp (init ++ (fn,(tau,False)):rest) s) (Put e1' f e2'')  -- put it regardless
infer (E (Cast ty e))
  = do (TE t e') <- infer e
       guardShow ("cast: " ++ show t ++ " <<< " ++ show ty) =<< t `isUpcastable` ty
       return $ TE ty (Cast ty $ TE t e')
infer (E (Promote ty e))
  = do (TE t e') <- infer e
       guardShow ("promote: " ++ show t ++ " << " ++ show ty) =<< t `isSubtype` ty
       return $ if t /= ty then promote ty $ TE t e'
                           else TE t e'  -- see NOTE [How to handle type annotations?] in Desugar


-- | Promote an expression to a given type, pushing down the promote as far as possible.
-- This structure is useful when destructuring runs of case expressions, for example in Cogent.Isabelle.Compound.
--
-- Consider this example of a ternary case:
-- > Case scrutinee tag1
-- >   when_tag1
-- >   (Promote ty
-- >     (Case (Var 0) tag2
-- >       when_tag2
-- >       (Promote ty
-- >         (Let
-- >           (Esac (Var 0))
-- >           when_tag3))))
--
-- Here, the promote expressions obscure the nested pattern-matching structure of the program.
-- We would like instead to push down the promotes to the following:
-- > Case scrutinee tag1
-- >   when_tag1
-- >   (Case (Var 0) tag2
-- >     (Promote ty when_tag2)
-- >     (Let
-- >       (Esac (Var 0))
-- >       (Promote ty when_tag3)))
--
-- In this pushed version, the promotion and the pattern matching are separate.
--
-- A-normalisation results in a similar structure, but when squashing case expressions for the
-- shallow embedding, we want this to apply to desugared as well as normalised.
--
-- Implementation note: the continuation forms ('Let', 'LetBang', 'If' and
-- 'Case') receive the promotion in each of their tails. A nested 'Promote'
-- is replaced by the outer target type rather than stacked. Any other
-- expression is wrapped in a single 'Promote' node. In every case the
-- resulting node is annotated with the target type.
promote :: Type t b -> TypedExpr t v a b -> TypedExpr t v a b
promote target texpr@(TE _ expr) =
  case expr of
    Let a bnd body ->
      retag (Let a bnd (promote target body))
    LetBang vs a bnd body ->
      retag (LetBang vs a bnd (promote target body))
    If cond thenBranch elseBranch ->
      retag (If cond (promote target thenBranch) (promote target elseBranch))
    Case scrut tag (l1, a1, matchBranch) (l2, a2, restBranch) ->
      retag (Case scrut tag (l1, a1, promote target matchBranch)
                            (l2, a2, promote target restBranch))
    Promote _ inner ->
      promote target inner
    _ ->
      retag (Promote target texpr)
  where
    -- Re-annotate a node with the target type.
    retag = TE target