1--
2-- Copyright 2018, Data61
3-- Commonwealth Scientific and Industrial Research Organisation (CSIRO)
4-- ABN 41 687 119 230.
5--
6-- This software may be distributed and modified according to the terms of
7-- the GNU General Public License version 2. Note that NO WARRANTY is provided.
8-- See "LICENSE_GPLv2.txt" for details.
9--
10-- @TAG(DATA61_GPL)
11--
12
13{- LANGUAGE AllowAmbiguousTypes -}
14{-# LANGUAGE DataKinds #-}
15{- LANGUAGE DeriveDataTypeable -}
16{-# LANGUAGE DeriveFunctor #-}
17{-# LANGUAGE ExistentialQuantification #-}
18{-# LANGUAGE FlexibleContexts #-}
19{-# LANGUAGE FlexibleInstances #-}
20{-# LANGUAGE GADTs #-}
21{-# LANGUAGE GeneralizedNewtypeDeriving #-}
22{- LANGUAGE InstanceSigs -}
23{-# LANGUAGE KindSignatures #-}
24{-# LANGUAGE LambdaCase #-}
25{-# LANGUAGE MultiWayIf #-}
26{-# LANGUAGE PatternGuards #-}
27{-# LANGUAGE PolyKinds #-}
28{-# LANGUAGE Rank2Types #-}
29{-# LANGUAGE ScopedTypeVariables #-}
30{- LANGUAGE StandaloneDeriving #-}
31{-# LANGUAGE TupleSections #-}
32{-# LANGUAGE TypeFamilies #-}
33{-# LANGUAGE TypeOperators #-}
34{-# LANGUAGE UndecidableInstances #-}
35
36module Cogent.Inference where
37
38import Cogent.Common.Syntax
39import Cogent.Common.Types
40import Cogent.Compiler
41import Cogent.Core
42import Cogent.Dargent.Allocation
43import Cogent.Dargent.Util
44import Cogent.Util
45import Cogent.PrettyPrint (indent')
46import Data.Ex
47import Data.Fin
48import Data.Nat
49import qualified Data.OMap as OM
50import Data.PropEq
51import Data.Vec hiding (repeat, splitAt, length, zipWith, zip, unzip)
52import qualified Data.Vec as Vec
53
54import Control.Applicative
55import Control.Arrow
56import Control.Monad.Except hiding (fmap, forM_)
57import Control.Monad.Reader hiding (fmap, forM_)
58import Control.Monad.State hiding (fmap, forM_)
59import Control.Monad.Trans.Maybe
60import Data.Foldable (forM_)
61import Data.Function (on)
62import qualified Data.IntMap as IM
63import Data.Map (Map)
64import Data.Maybe (isJust)
65import qualified Data.Map as M
66import Data.Monoid
67#if __GLASGOW_HASKELL__ < 709
68import Data.Traversable(traverse)
69#endif
70import Lens.Micro (_2)
71import Lens.Micro.Mtl (view)
72import Text.PrettyPrint.ANSI.Leijen (Pretty, pretty)
73import qualified Unsafe.Coerce as Unsafe (unsafeCoerce)  -- NOTE: used safely to coerce phantom types only
74
75import Debug.Trace
76
-- | Abort type checking with a @GUARD:@-prefixed error message unless the
-- given condition holds.
guardShow :: String -> Bool -> TC t v b ()
guardShow msg ok
  | ok        = return ()
  | otherwise = TC . throwError $ "GUARD: " ++ msg
79
-- | Like 'guardShow', but the error message is a header line followed by the
-- given body lines (each terminated by a newline).
guardShow' :: String -> [String] -> Bool -> TC t v b ()
guardShow' header body ok
  | ok        = return ()
  | otherwise = TC . throwError $ concat ["GUARD: ", header, "\n", unlines body]
82
83-- ----------------------------------------------------------------------------
84-- Type reconstruction
85
-- | Whether a value of the first type may be explicitly cast to the second.
--
-- This covers types that do not simply satisfy subtyping / share a
-- representation:
--
-- * primitive types related by 'isSubtypePrim';
-- * sum types, where every alternative of the source must exist in the
--   target with a subtype payload and the same taken flag, and every target
--   alternative missing from the source must be taken.
--
-- Any other pair of types is not upcastable.
isUpcastable :: (Show b, Eq b) => Type t b -> Type t b -> TC t v b Bool
isUpcastable (TPrim p1) (TPrim p2) = return $ isSubtypePrim p1 p2
isUpcastable (TSum s1) (TSum s2) = do
  kept <- allM fits s1
  -- other tags are all taken
  let dropped = all (\(c, (_, taken)) -> isJust (lookup c s1) || taken) s2
  return (kept && dropped)
  where
    fits (c, (t, b)) = case lookup c s2 of
      Nothing       -> return False
      Just (t', b') -> (&&) <$> t `isSubtype` t' <*> pure (b == b')
isUpcastable _ _ = return False
96
-- | @t1 \`isSubtype\` t2@ holds exactly when the least upper bound of @t1@
-- and @t2@ exists and is @t2@ itself.
isSubtype :: (Show b, Eq b) => Type t b -> Type t b -> TC t v b Bool
isSubtype t1 t2 = maybe False (== t2) <$> runMaybeT (t1 `lub` t2)
100
-- | Unfold the recursive parameter @v@ by one level.
--
-- The type bound to @v@ in the recursive context is returned, with every
-- recursive-parameter reference inside it that has no context yet
-- (@TRPar _ Nothing@) given the context in scope at that point. A recursive
-- record (@Rec v'@) extends the context with its own binding before its
-- fields are processed.
--
-- Calling this without a context, or with a parameter that is not bound in
-- the context, is an internal error.
unroll :: RecParName -> RecContext (Type t b) -> Type t b
unroll _ Nothing = __impossible "unroll: no recursive context available"
unroll v (Just ctxt) = erp (Just ctxt) (M.findWithDefault missing v ctxt)
  where
    missing = __impossible "unroll: recursive parameter not found in its context"

    -- Embed rec pars
    erp :: RecContext (Type t b) -> Type t b -> Type t b
    erp c (TCon n ts s)    = TCon n (map (erp c) ts) s
    erp c (TFun t1 t2)     = TFun (erp c t1) (erp c t2)
    erp c (TSum alts)      = TSum $ map (\(tag, (ty, taken)) -> (tag, (erp c ty, taken))) alts
    erp c (TProduct t1 t2) = TProduct (erp c t1) (erp c t2)
    erp (Just c) r@(TRecord rp fs s) =
      let c' = case rp of Rec v' -> M.insert v' r c; _ -> c
      in TRecord rp (map (\(f, (ty, taken)) -> (f, (erp (Just c') ty, taken))) fs) s
    -- Only references without a context yet are filled in
    erp c (TRPar v' Nothing) = TRPar v' c
#ifdef BUILTIN_ARRAYS
    erp c (TArray t l s h) = TArray (erp c t) l s h
#endif
    erp _ t = t
119
-- | Least upper bound ('LUB') or greatest lower bound ('GLB') of two types
-- with respect to subtyping.
--
-- The inputs may differ in the taken flags of record fields and sum
-- alternatives, and (with @BUILTIN_ARRAYS@) in array holes. Function types
-- are handled contravariantly in their argument via 'theOtherB'. The result
-- is 'Nothing' (through 'MaybeT') when a record field, or an array element
-- whose hole status differs, is not discardable. Structurally incomparable
-- types are an internal error ('__impossible').
bound :: (Show b, Eq b) => Bound -> Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
-- Identical types are their own bound.
bound _ t1 t2 | t1 == t2 = return t1
-- Records with the same field names, sigil and recursive parameter: bound
-- the field types pointwise. LUB marks a field taken if it is taken in either
-- input, GLB only if taken in both. When the taken flags differ, the bounded
-- field type must be discardable, otherwise there is no bound.
bound b (TRecord rp1 fs1 s1) (TRecord rp2 fs2 s2)
  | map fst fs1 == map fst fs2, s1 == s2, rp1 == rp2 = do
    let op = case b of LUB -> (||); GLB -> (&&)
    blob <- flip3 zipWithM fs2 fs1 $ \(f1,(t1,b1)) (_, (t2,b2)) -> do
      t <- bound b t1 t2
      ok <- lift $ if b1 == b2 then return True
                               else kindcheck t >>= \k -> return (canDiscard k)
      return ((f1, (t, b1 `op` b2)), ok)
    let (fs, oks) = unzip blob
    if and oks then return $ TRecord rp1 fs s1
               else MaybeT (return Nothing)
-- Sums with the same set of tags: LUB marks a tag taken only if it is taken
-- in both inputs, GLB if taken in either. The alternatives of the result are
-- listed in 'M.toList' (ascending tag) order.
bound b (TSum s1) (TSum s2) | s1' <- M.fromList s1, s2' <- M.fromList s2, M.keys s1' == M.keys s2' = do
  let op = case b of LUB -> (&&); GLB -> (||)
  s <- flip3 unionWithKeyM s2' s1' $ \k (t1,b1) (t2,b2) -> (,) <$> bound b t1 t2 <*> pure (b1 `op` b2)
  return $ TSum $ M.toList s
-- Products, abstract types (same name and sigil) and functions are bounded
-- componentwise; function arguments use the opposite bound.
bound b (TProduct t11 t12) (TProduct t21 t22) = TProduct <$> bound b t11 t21 <*> bound b t12 t22
bound b (TCon c1 t1 s1) (TCon c2 t2 s2) | c1 == c2, s1 == s2 = TCon c1 <$> zipWithM (bound b) t1 t2 <*> pure s1
bound b (TFun t1 s1) (TFun t2 s2) = TFun <$> bound (theOtherB b) t1 t2 <*> bound b s1 s2
-- At this point, we can assume recursive parameters and records agree
bound b t1@(TRecord rp fs s) t2@(TRPar v ctxt)    = return t2
bound b t1@(TRPar v ctxt)    t2@(TRecord rp fs s) = return t2
bound b t1@(TRPar v1 c1)     t2@(TRPar v2 c2)     = return t2
#ifdef BUILTIN_ARRAYS
-- Arrays of the same length and sigil: bound the element types. If exactly
-- one side has a hole, the element type must be discardable. LUB keeps a
-- hole present on either side, GLB keeps only a hole present on both.
bound b (TArray t1 l1 s1 mhole1) (TArray t2 l2 s2 mhole2)
  | l1 == l2, s1 == s2 = do
      t <- bound b t1 t2
      ok <- lift $ case (mhole1, mhole2) of
                     (Nothing, Nothing) -> return True
                     (Just i1, Just i2) -> return $ i1 == i2  -- FIXME: change to propositional equality / zilinc
                     _ -> kindcheck t >>= \k -> return (canDiscard k)
      let mhole = combineHoles b mhole1 mhole2
      if ok then return $ TArray t l1 s1 mhole
            else MaybeT (return Nothing)
  where
    combineHoles b Nothing   Nothing   = Nothing
    combineHoles b (Just i1) (Just _ ) = Just i1
    combineHoles b Nothing   (Just i2) = case b of GLB -> Nothing; LUB -> Just i2
    combineHoles b (Just i1) Nothing   = case b of GLB -> Nothing; LUB -> Just i1
#endif
-- Anything else (including records or sums whose shapes differ) is not
-- comparable and indicates an internal error.
bound _ t1 t2 = __impossible ("bound: not comparable:\n" ++ show t1 ++ "\n" ++ 
                              "----------------------------------------\n" ++ show t2 ++ "\n")
163
-- | Least upper bound of two types: 'bound' specialised to 'LUB'.
lub :: (Show b, Eq b) => Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
lub t1 t2 = bound LUB t1 t2
166
-- | Greatest lower bound of two types: 'bound' specialised to 'GLB'.
glb :: (Show b, Eq b) => Type t b -> Type t b -> MaybeT (TC t v b) (Type t b)
glb t1 t2 = bound GLB t1 t2
169
170-- checkUExpr_B :: UntypedExpr -> TC t v Bool
171-- checkUExpr_B (E (Op op [e])) = return True
172-- checkUExpr_B (E (Op op [e1,e2])) = return True
173-- checkUExpr_B _ = return True
174
175
-- | Apply the @!@ (bang) type operator.
--
-- Plain type variables become banged variables. Abstract types and records
-- get 'bangSigil' applied to their sigil. The operator is pushed into
-- abstract type arguments, sum alternatives, product components and record
-- fields, and (with @BUILTIN_ARRAYS@) into array elements. Function types,
-- primitives, strings, recursive parameters, unit and already banged or
-- unboxed variables are returned unchanged.
bang :: Type t b -> Type t b
bang ty = case ty of
  TVar v          -> TVarBang v
  TCon n ts s     -> TCon n (map bang ts) (bangSigil s)
  TSum alts       -> TSum (map (second (first bang)) alts)
  TProduct t1 t2  -> TProduct (bang t1) (bang t2)
  TRecord rp fs s -> TRecord rp (map (second (first bang)) fs) (bangSigil s)
#ifdef BUILTIN_ARRAYS
  TArray t l s tkns -> TArray (bang t) l (bangSigil s) tkns
#endif
  TVarBang _      -> ty
  TVarUnboxed _   -> ty
  TFun _ _        -> ty
  TPrim _         -> ty
  TString         -> ty
  TRPar _ _       -> ty
  TUnit           -> ty
192
-- | Apply the @#@ (unbox) type operator.
--
-- Type variables become unboxed variables. Abstract types and records get
-- 'unboxSigil' applied to their sigil. Everything else is returned unchanged.
unbox :: Type t b -> Type t b
unbox ty = case ty of
  TVar v          -> TVarUnboxed v
  TVarBang v      -> TVarUnboxed v
  TCon n ts s     -> TCon n ts (unboxSigil s)
  TRecord rp ts s -> TRecord rp ts (unboxSigil s)
  -- NOTE that @#@ type operator behaves differently to @!@.
  -- The application of @#@ should NOT be pushed inside of a
  -- data type. / zilinc
  _               -> ty
202
203
-- | Instantiate the type variables of a type with the given types.
--
-- Variable @i@ is replaced by the @i@-th element of @vs@. Banged and unboxed
-- variables get 'bang' / 'unbox' applied to their replacement. Recursive
-- contexts are substituted too, and (with @BUILTIN_ARRAYS@) so are the
-- length and hole expressions of arrays, via 'substituteLE'.
substitute :: Vec t (Type u b) -> Type t b -> Type u b
substitute vs ty = case ty of
  TVar v          -> vs `at` v
  TVarBang v      -> bang (vs `at` v)
  TVarUnboxed v   -> unbox (vs `at` v)
  TCon n ts s     -> TCon n (map go ts) s
  TFun ti to      -> TFun (go ti) (go to)
  TPrim i         -> TPrim i
  TString         -> TString
  TProduct t1 t2  -> TProduct (go t1) (go t2)
  TRecord rp ts s -> TRecord rp (map (second (first go)) ts) s
  TSum ts         -> TSum (map (second (first go)) ts)
  TUnit           -> TUnit
  TRPar v m       -> TRPar v (M.map go <$> m)
#ifdef BUILTIN_ARRAYS
  TArray t l s mhole -> TArray (go t) (substituteLE vs l) s (substituteLE vs <$> mhole)
#endif
  where
    go = substitute vs
220
-- | Instantiate layout variables throughout a type.
--
-- The sigils of records (and, with @BUILTIN_ARRAYS@, arrays) have their
-- layouts rewritten by 'substituteS'. Abstract type arguments, function
-- types, products, record fields and sum alternatives are traversed
-- recursively. Any other type is returned unchanged.
substituteL :: [DataLayout BitRange] -> Type t b -> Type t b
substituteL ls ty = case ty of
  TCon n ts s     -> TCon n (map go ts) s
  TFun ti to      -> TFun (go ti) (go to)
  TProduct t1 t2  -> TProduct (go t1) (go t2)
  TRecord rp ts s -> TRecord rp (map (second (first go)) ts) (substituteS ls s)
  TSum ts         -> TSum (map (second (first go)) ts)
#ifdef BUILTIN_ARRAYS
  TArray t l s mhole -> TArray (go t) l (substituteS ls s) mhole
#endif
  _               -> ty
  where
    go = substituteL ls
231
-- | Instantiate layout variables inside the data layout of a sigil.
--
-- @ls@ supplies the layouts for the layout variables in scope. A
-- @VarLayout n s@ is replaced by the @n@-th element of @ls@, shifted with
-- 'offset' by @s@. Sum and record layouts (and, with @BUILTIN_ARRAYS@, array
-- layouts) are rewritten recursively. Any other layout is kept as is.
-- Unboxed sigils and C layouts contain no layout variables and are returned
-- unchanged.
--
-- A layout variable that refers to a 'CLayout', or that is out of range of
-- @ls@, is an internal error.
substituteS :: [DataLayout BitRange] -> Sigil (DataLayout BitRange) -> Sigil (DataLayout BitRange)
substituteS _  Unboxed              = Unboxed
substituteS _  (Boxed b CLayout)    = Boxed b CLayout
substituteS ls (Boxed b (Layout l)) = Boxed b . Layout $ go l
  where
    go :: DataLayout' BitRange -> DataLayout' BitRange
    go layout = case layout of
      VarLayout n s -> case drop (natToInt n) ls of
        CLayout   : _ -> __impossible "substituteS in Inference: CLayout shouldn't be here"
        Layout l' : _ -> offset s l'
        []            -> __impossible "substituteS in Inference: layout variable out of range"
      -- Each alternative maps to a pair whose second component is the
      -- layout. Only that layout is rewritten.
      SumLayout tag alts -> SumLayout tag $ M.map (second go) alts
      RecordLayout fs    -> RecordLayout $ M.map go fs
#ifdef BUILTIN_ARRAYS
      ArrayLayout e      -> ArrayLayout $ go e
#endif
      _                  -> layout
257
-- | Instantiate the type variables occurring in a layout-level expression.
--
-- Every embedded type (function type arguments, constructor, promote and
-- cast types) is rewritten with 'substitute'. Sub-expressions are traversed
-- recursively. Variables and literals are rebuilt unchanged.
substituteLE :: Vec t (Type u b) -> LExpr t b -> LExpr u b
substituteLE vs expr = case expr of
  LVariable va       -> LVariable va
  LFun fn ts ls      -> LFun fn (fmap sub ts) ls
  LOp op es          -> LOp op (fmap go es)
  LApp e1 e2         -> LApp (go e1) (go e2)
  LCon tn e t        -> LCon tn (go e) (sub t)
  LUnit              -> LUnit
  LILit n t          -> LILit n t
  LSLit s            -> LSLit s
  LLet a e e'        -> LLet a (go e) (go e')
  LLetBang bs a e e' -> LLetBang bs a (go e) (go e')
  LTuple e1 e2       -> LTuple (go e1) (go e2)
  LStruct fs         -> LStruct (fmap (second go) fs)
  LIf c th el        -> LIf (go c) (go th) (go el)
  LCase e tn (a1, e1) (a2, e2) ->
    LCase (go e) tn (a1, go e1) (a2, go e2)
  LEsac e            -> LEsac (go e)
  LSplit as e e'     -> LSplit as (go e) (go e')
  LMember e f        -> LMember (go e) f
  LTake as r f e'    -> LTake as (go r) f (go e')
  LPut r f e         -> LPut (go r) f (go e)
  LPromote t e       -> LPromote (sub t) (go e)
  LCast t e          -> LCast (sub t) (go e)
  where
    go  = substituteLE vs
    sub = substitute vs
283
-- | Drop every entry whose key equals the given key from an association list.
remove :: (Eq a) => a -> [(a,b)] -> [(a,b)]
remove k = filter (\(key, _) -> key /= k)
286
-- | Apply a function to the value of every entry whose key equals the given
-- key in an association list. Other entries and the order are preserved.
adjust :: (Eq a) => a -> (b -> b) -> [(a,b)] -> [(a,b)]
adjust k f = map step
  where
    step (key, val)
      | key == k  = (key, f val)
      | otherwise = (key, val)
289
290
-- | The type-checking monad, a stack of:
--
-- * 'ExceptT' 'String': type errors (raised e.g. by 'guardShow');
-- * 'ReaderT': the kinds of the @t@ type variables in scope, paired with the
--   types of the known functions (read by 'lookupKind' and 'funType');
-- * 'State': the typing context of the @v@ term variables. A slot becomes
--   'Nothing' once a non-shareable variable has been used (see
--   'useVariable').
newtype TC (t :: Nat) (v :: Nat) b x
  = TC {unTC :: ExceptT String
                        (ReaderT (Vec t Kind, Map FunName (FunctionType b))
                                 (State (Vec v (Maybe (Type t b)))))
                        x}
  deriving (Functor, Applicative, Alternative, Monad, MonadPlus,
            MonadReader (Vec t Kind, Map FunName (FunctionType b)))
298
#if MIN_VERSION_base(4,13,0)
-- Failed pattern binds in 'TC' do-blocks (e.g. @Just t <- ...@) are routed
-- to '__impossible', i.e. treated as internal errors rather than turned into
-- a recoverable 'throwError' failure.
instance MonadFail (TC t v b) where
  fail msg = __impossible msg
#endif
303
infixl 4 <||>
-- | Applicative-style combination of two 'TC' computations that must both be
-- checked starting from the same variable context, for example the two
-- branches of an @if@ or @case@.
--
-- The left computation runs first. The context is then reset to the
-- snapshot taken before it ran, and the right computation runs. The final
-- context is the one left by the right computation. The left one (@x1@) is
-- captured but currently not compared (see the note in the body).
(<||>) :: TC t v b (x -> y) -> TC t v b x -> TC t v b y
(TC a) <||> (TC b) = TC $ do x <- get
                             f <- a
                             x1 <- get
                             put x
                             arg <- b
                             x2 <- get
                             -- XXX | unTC $ guardShow "<||>" $ x1 == x2
                             -- \ ^^^ NOTE: This check is taken out to fix
                             -- #296.  The issue here is that, if we define a
                             -- variable of permission D alone (w/o S), it will
                             -- be marked as used after it's been used, which
                             -- is correct. But when it is used in one branch
                             -- but not in the other one, which is allowed as
                             -- it's droppable, it will be marked as used in
                             -- the context of one branch but not the other and
                             -- render the two contexts different. The formal
                             -- specification requires that both contexts are
                             -- the same, but it is tantamount to merging two
                             -- different (correct) contexts correctly, which
                             -- can be established in the typing proof.
                             -- / v.jackson, zilinc
                             return (f arg)
328
-- | Result type of a primitive operator applied to operands of the given
-- types.
--
-- * Arithmetic and bitwise binary operators take two operands of the same
--   non-Boolean primitive type and return that type.
-- * Comparisons take two operands of the same non-Boolean primitive type and
--   return 'Boolean'.
-- * 'And', 'Or', 'Eq' and 'NEq' on two Booleans return 'Boolean'.
-- * 'Not' on a Boolean returns 'Boolean'.
-- * 'Complement' on a non-Boolean primitive returns that type.
--
-- Any other combination is an internal error. This function never returns
-- 'Nothing'.
opType :: Op -> [Type t b] -> Maybe (Type t b)
opType opr operands = case (opr, operands) of
  (_, [TPrim p1, TPrim p2])
    | opr `elem` arithmetic, p1 == p2, p1 /= Boolean -> Just (TPrim p1)
    | opr `elem` comparison, p1 == p2, p1 /= Boolean -> Just (TPrim Boolean)
  (_, [TPrim Boolean, TPrim Boolean])
    | opr `elem` [And, Or, Eq, NEq] -> Just (TPrim Boolean)
  (Not, [TPrim Boolean]) -> Just (TPrim Boolean)
  (Complement, [TPrim p])
    | p /= Boolean -> Just (TPrim p)
  _ -> __impossible "opType"
  where
    arithmetic = [Plus, Minus, Times, Divide, Mod, BitAnd, BitOr, BitXor, LShift, RShift]
    comparison = [Gt, Lt, Le, Ge, Eq, NEq]
342
-- | Look up a variable's type in the context, marking it as used.
--
-- Returns the slot's current contents. If the variable is present and its
-- type is not shareable, the slot is cleared to 'Nothing' so the variable
-- cannot be used again.
useVariable :: Fin v -> TC t v b (Maybe (Type t b))
useVariable v = TC $ do
  entry <- gets (`at` v)
  forM_ entry $ \ty -> do
    shareable <- canShare <$> unTC (kindcheck ty)
    unless shareable $ modify (\ctx -> update ctx v Nothing)
  return entry
351
-- | Look up the declared type of a function in the reader environment.
funType :: CoreFunName -> TC t v b (Maybe (FunctionType b))
funType v = TC $ asks (M.lookup (unCoreFunName v) . snd)
354
-- | Run a 'TC' computation with the given reader environment and initial
-- variable context. On success, return the final context paired with the
-- result. On failure, return the error message.
runTC :: TC t v b x
      -> (Vec t Kind, Map FunName (FunctionType b))
      -> Vec v (Maybe (Type t b))
      -> Either String (Vec v (Maybe (Type t b)), x)
runTC (TC a) readers st =
  let (outcome, finalCtx) = runState (runReaderT (runExceptT a) readers) st
  in fmap (\x -> (finalCtx, x)) outcome
362
363-- XXX | tc_debug :: [Definition UntypedExpr a] -> IO ()
364-- XXX | tc_debug = flip tc_debug' M.empty
365-- XXX |   where
366-- XXX |     tc_debug' :: [Definition UntypedExpr a] -> Map FunName FunctionType -> IO ()
367-- XXX |     tc_debug' [] _ = putStrLn "tc2... OK!"
368-- XXX |     tc_debug' ((FunDef _ fn ts t rt e):ds) reader =
369-- XXX |       case runTC (infer e) (fmap snd ts, reader) (Cons (Just t) Nil) of
370-- XXX |         Left x -> putStrLn $ "tc2... failed! Due to: " ++ x
371-- XXX |         Right _ -> tc_debug' ds (M.insert fn (FT (fmap snd ts) t rt) reader)
372-- XXX |     tc_debug' ((AbsDecl _ fn ts t rt):ds) reader = tc_debug' ds (M.insert fn (FT (fmap snd ts) t rt) reader)
373-- XXX |     tc_debug' (_:ds) reader = tc_debug' ds reader
374
-- | Re-run type inference on already typed definitions: their types are
-- erased with 'untypeD' and the definitions are checked again with 'tc'.
retype :: (Show b, Eq b, Pretty b, a ~ b)
       => [Definition TypedExpr a b]
       -> Either String [Definition TypedExpr a b]
retype = fmap fst . tc . map untypeD
379
-- | Type check a list of definitions in order.
--
-- Returns the typed definitions together with the final map of function
-- types. Each function body is inferred and checked against its declared
-- result type, with its single argument bound at index 0. A function's own
-- type is already in the map while its body is checked, so recursion works.
-- Abstract declarations only extend the map. Other definitions are passed
-- through unchanged. The first type error aborts the whole check.
tc :: (Show b, Eq b, Pretty b, a ~ b)
   => [Definition UntypedExpr a b]
   -> Either String ([Definition TypedExpr a b], Map FunName (FunctionType b))
tc = flip tc' M.empty
  where
    tc' :: (Show b, Eq b, Pretty b, a ~ b)
        => [Definition UntypedExpr a b]
        -> Map FunName (FunctionType b)  -- the reader
        -> Either String ([Definition TypedExpr a b], Map FunName (FunctionType b))
    tc' [] reader = return ([], reader)
    tc' ((FunDef attr fn ks ls t rt e):ds) reader = do
      -- Build the function type once. It is used both for checking this body
      -- (enabling recursion) and for checking the remaining definitions.
      let ft      = FT (snd <$> ks) (snd <$> ls) t rt
          reader' = M.insert fn ft reader
      (_, e') <- runTC (infer e >>= flip typecheck rt) (fmap snd ks, reader') (Cons (Just t) Nil)
      first (FunDef attr fn ks ls t rt e' :) <$> tc' ds reader'
    -- Declarations without bodies contain no expressions, so the coercion
    -- only changes the phantom expression type.
    tc' (d@(AbsDecl _ fn ks ls t rt):ds) reader = (first (Unsafe.unsafeCoerce d:)) <$> tc' ds (M.insert fn (FT (fmap snd ks) (fmap snd ls) t rt) reader)
    tc' (d:ds) reader = (first (Unsafe.unsafeCoerce d:)) <$> tc' ds reader
398
-- | Like 'tc', but discard the final map of function types.
tc_ :: (Show b, Eq b, Pretty b, a ~ b)
    => [Definition UntypedExpr a b]
    -> Either String [Definition TypedExpr a b]
tc_ defs = fst <$> tc defs
403
-- | Infer the types of top-level constants.
--
-- Each constant is inferred in an empty type and variable context, using the
-- given function types. The first failure aborts. On success the typed
-- constants are returned together with the unchanged function-type map.
tcConsts :: [CoreConst UntypedExpr]
         -> Map FunName (FunctionType VarName)
         -> Either String ([CoreConst TypedExpr], Map FunName (FunctionType VarName))
tcConsts ds reader = do
  typed <- traverse checkConst ds
  return (typed, reader)
  where
    checkConst (v, e) = do
      (_, e') <- runTC (infer e) (Nil, reader) Nil
      return (v, e')
412
-- | Run a computation with one extra variable of type @t@ bound at index 0.
--
-- The extended context is dropped again afterwards. If the new variable was
-- not used up (its slot is still 'Just'), its type must be discardable.
-- Otherwise the check fails with "Didn't use linear variable".
withBinding :: Type t b -> TC t ('Suc v) b x -> TC t v b x
withBinding t x
  = TC $ do
      env   <- ask
      outer <- get
      case runTC x env (Cons (Just t) outer) of
        Left err -> throwError err
        Right (Cons slot rest, r) ->
          case slot of
            Nothing -> do
              put rest
              return r
            Just leftover -> do
              discardable <- canDiscard <$> unTC (kindcheck leftover)
              if discardable
                then put rest >> return r
                else throwError "Didn't use linear variable"
424
-- | Run a computation with several extra variables in scope, introduced one
-- at a time with 'withBinding'. The first element of the vector is the
-- innermost binding, i.e. it ends up at index 0.
withBindings :: Vec k (Type t b) -> TC t (v :+: k) b x -> TC t v b x
withBindings Nil         = id
withBindings (Cons x xs) = withBindings xs . withBinding x
428
-- | Run a computation with the listed variables temporarily banged.
--
-- Before running, each listed variable's type (if still present) has 'bang'
-- applied. Afterwards, each listed slot is restored to its value from before
-- the call. Other slots keep whatever the computation left in them.
withBang :: [Fin v] -> TC t v b x -> TC t v b x
withBang vs (TC x) = TC $ do
  saved <- get
  forM_ vs $ \v -> modify (modifyAt v (fmap bang))
  ret <- x
  forM_ vs $ \v -> modify (modifyAt v (const (saved `at` v)))
  return ret
435
-- | Look up the kind of a type variable in the reader environment.
lookupKind :: Fin t -> TC t v b Kind
lookupKind f = TC $ asks ((`at` f) . fst)
438
-- | Compute the kind of a type, given a way to find the kinds of type
-- variables.
--
-- Banged variables get 'bangKind' applied. Abstract types and records
-- contribute their sigil's kind. Products, untaken record fields, untaken
-- sum alternatives and (with @BUILTIN_ARRAYS@) array elements are combined
-- with the 'Kind' monoid. Unboxed variables, functions, primitives, strings,
-- unit and recursive parameters have the neutral kind 'mempty'.
kindcheck_ :: (Monad m) => (Fin t -> m Kind) -> Type t a -> m Kind
kindcheck_ f ty = case ty of
  TVar v         -> f v
  TVarBang v     -> bangKind <$> f v
  TVarUnboxed _  -> return mempty
  TCon _ _ s     -> return $ sigilKind s
  TFun _ _       -> return mempty
  TPrim _        -> return mempty
  TString        -> return mempty
  TProduct t1 t2 -> mappend <$> kindcheck_ f t1 <*> kindcheck_ f t2
  TRecord _ fs s ->
    mconcat . (sigilKind s :) <$> mapM (kindcheck_ f . fst . snd) (filter (not . snd . snd) fs)
  TSum alts      ->
    mconcat <$> mapM (kindcheck_ f . fst . snd) (filter (not . snd . snd) alts)
  TUnit          -> return mempty
  TRPar _ _      -> return mempty
#ifdef BUILTIN_ARRAYS
  TArray t _ s _ -> mappend <$> kindcheck_ f t <*> pure (sigilKind s)
#endif
456
-- | Compute the kind of a type, taking the kinds of type variables from the
-- 'TC' reader environment (via 'lookupKind').
kindcheck :: Type t a -> TC t v b Kind
kindcheck = kindcheck_ lookupKind
458
459
-- | Check a typed expression against an expected type.
--
-- If the expression's type already equals the expected type it is returned
-- unchanged. If it is a strict subtype, the expression is wrapped with
-- 'promote'. Otherwise this is an internal error, reported with both types
-- pretty-printed.
typecheck :: (Pretty a, Show a, Eq a) => TypedExpr t v a a -> Type t a -> TC t v a (TypedExpr t v a a)
typecheck e t
  | t == inferred = return e
  | otherwise     = do
      sub <- isSubtype inferred t
      if sub
        then return (promote t e)
        else __impossible $ "Inferred type of\n" ++
                            show (indent' $ pretty e) ++
                            "\ndoesn't agree with the given type signature:\n" ++
                            "Inferred type:\n" ++
                            show (indent' $ pretty inferred) ++
                            "\nGiven type:\n" ++
                            show (indent' $ pretty t) ++ "\n"
  where
    inferred = exprType e
473
-- | Infer the type of an untyped Core expression, producing a typed
-- expression.
--
-- How the checker treats its input:
--
-- * Variable uses go through 'useVariable', which clears non-shareable
--   variables from the context.
-- * Binders are introduced with 'withBinding' / 'withBindings', which reject
--   unused non-discardable variables.
-- * Sub-terms whose type is a strict subtype of the required type are
--   wrapped with 'promote'.
--
-- Malformed input is generally not reported as a user error:
--
-- * lazy @let@ patterns fail when forced;
-- * failed @Just t <- ...@ binds go to the 'MonadFail' instance;
-- * some cases call '__impossible' or 'error' directly.
--
-- 'guardShow' / 'guardShow'' failures surface as @GUARD:@ errors in the 'TC'
-- result.
infer :: (Pretty a, Show a, Eq a) => UntypedExpr t v a a -> TC t v a (TypedExpr t v a a)
-- Primitive operators: the result type is given by 'opType' from the operand
-- types.
infer (E (Op o es))
   = do es' <- mapM infer es
        let Just t = opType o (map exprType es')
        return (TE t (Op o es'))
-- Literals carry their own type.
infer (E (ILit i t)) = return (TE (TPrim t) (ILit i t))
infer (E (SLit s)) = return (TE TString (SLit s))
#ifdef BUILTIN_ARRAYS
infer (E (ALit [])) = __impossible "We don't allow 0-size array literals"
-- Array literal: the element type is the least upper bound of all element
-- types, and the length is the number of elements.
infer (E (ALit es))
   = do es' <- mapM infer es
        let ts = map exprType es'
            n = LILit (fromIntegral $ length es) U32
        t <- lubAll ts
        -- NOTE(review): isSub is computed but never checked, and elements
        -- whose type differs from t are not promoted. Confirm whether a
        -- guard and a promote were intended here.
        isSub <- allM (`isSubtype` t) ts
        return (TE (TArray t n Unboxed Nothing) (ALit es'))
  where
    lubAll :: (Show b, Eq b) => [Type t b] -> TC t v b (Type t b)
    lubAll [] = __impossible "lubAll: empty list"
    lubAll [t] = return t
    lubAll (t1:t2:ts) = do Just t <- runMaybeT $ lub t1 t2
                           lubAll (t:ts)
-- Array indexing: the array type must be shareable. Bounds are not checked.
infer (E (ArrayIndex arr idx))
   = do arr'@(TE ta _) <- infer arr
        let TArray te l _ _ = ta
        idx' <- infer idx
        -- guardShow ("arr-idx out of bound") $ idx >= 0 && idx < l  -- no way to check it. need ref types. / zilinc
        guardShow ("arr-idx on non-linear") . canShare =<< kindcheck ta
        return (TE te (ArrayIndex arr' idx'))
-- Map over two arrays together. The body f is checked with one element type
-- from each array in scope. The result type pairs the two input array types.
infer (E (ArrayMap2 (as,f) (e1,e2)))
   = do e1'@(TE t1 _) <- infer e1
        e2'@(TE t2 _) <- infer e2
        let TArray te1 l1 _ _ = t1
            TArray te2 l2 _ _ = t2
        f' <- withBindings (Cons te2 (Cons te1 Nil)) $ infer f
        let t = case __cogent_ftuples_as_sugar of
                  False -> TProduct t1 t2
                  True  -> TRecord NonRec (zipWith (\f t -> (f,(t,False))) tupleFieldNames [t1,t2]) Unboxed
        return $ TE t $ ArrayMap2 (as,f') (e1',e2')
-- Pop: e2 is checked with the head element and the remaining array (length
-- decremented by one) in scope.
infer (E (Pop a e1 e2))
   = do e1'@(TE t1 _) <- infer e1
        let TArray te l s tkns = t1
            thd = te
            ttl = TArray te (LOp Minus [l, LILit 1 U32]) s tkns
        -- guardShow "arr-pop on a singleton array" $ l > 1
        e2'@(TE t2 _) <- withBindings (Cons thd (Cons ttl Nil)) $ infer e2
        return (TE t2 (Pop a e1' e2'))
-- Singleton: the result is the array's element type.
infer (E (Singleton e))
   = do e'@(TE t _) <- infer e
        let TArray te l _ _ = t
        -- guardShow "singleton on a non-singleton array" $ l == 1
        return (TE te (Singleton e'))
-- Take an element: the array is expected to have no hole yet. e is checked
-- with the element and the array with a hole at index i in scope.
infer (E (ArrayTake as arr i e))
   = do arr'@(TE tarr _) <- infer arr
        i' <- infer i
        let TArray telt len s Nothing = tarr
            tarr' = TArray telt len s (Just $ texprToLExpr id i')
        e'@(TE te _) <- withBindings (Cons telt (Cons tarr' Nil)) $ infer e
        return (TE te $ ArrayTake as arr' i' e')
-- Put an element back: the result array type has no hole.
infer (E (ArrayPut arr i e))
   = do arr'@(TE tarr _) <- infer arr
        i' <- infer i
        e'@(TE te _)   <- infer e
        -- FIXME: all the checks are disabled here, for the lack of a proper
        -- refinement type system. Also, we cannot know the exact index that
        -- is being put, thus there's no way that we can infer the precise type
        -- for the new array (tarr').
        let TArray telm len s tkns = tarr
        -- XXX | mi <- evalExpr i'
        -- XXX | guardShow "@put index not a integral constant" $ isJust mi
        -- XXX | let Just i'' = mi
        -- XXX | guardShow "@put index is out of range" $ i'' `IM.member` tkns
        -- XXX | let Just itkn = IM.lookup i'' tkns
        -- XXX | k <- kindcheck telm
        -- XXX | unless itkn $ guardShow "@put a non-Discardable untaken element" $ canDiscard k
        let tarr' = TArray telm len s Nothing
        return (TE tarr' (ArrayPut arr' i' e'))
#endif
-- Variable use: the variable must still be in the context. It is used up
-- unless its type is shareable.
infer (E (Variable v))
   = do Just t <- useVariable (fst v)
        return (TE t (Variable v))
-- Function reference: instantiate the declared function type with the given
-- type and layout arguments. For each type argument, combining the declared
-- kind with the argument's kind must leave the declared kind unchanged.
infer (E (Fun f ts ls note))
   | ExI (Flip ts') <- Vec.fromList ts
   , ExI (Flip ls') <- Vec.fromList ls
   -- NOTE(review): myMap below is bound but never used.
   = do myMap <- ask
        x <- funType f
        case x of
          Just (FT ks lts ti to) ->
            case (Vec.length ts' =? Vec.length ks, Vec.length ls' =? Vec.length lts)
              of (Just Refl, Just Refl) -> let ti' = substitute ts' $ substituteL ls ti
                                               to' = substitute ts' $ substituteL ls to
                                            in do forM_ (Vec.zip ts' ks) $ \(t, k) -> do
                                                    k' <- kindcheck t
                                                    when ((k <> k') /= k) $ __impossible "kind not matched in type instantiation"
                                                  return $ TE (TFun ti' to') (Fun f ts ls note)
                 _ -> __impossible "lengths don't match"
          _        -> error $ "Something went wrong in lookup of function type: '" ++ unCoreFunName f ++ "'"
-- Application: the argument type must be a subtype of the formal parameter
-- type, and the argument is promoted when the two differ.
infer (E (App e1 e2))
   = do e1'@(TE (TFun ti to) _) <- infer e1
        e2'@(TE ti' _) <- infer e2
        isSub <- ti' `isSubtype` ti
        guardShow ("app (actual: " ++ show ti' ++ "; formal: " ++ show ti ++ ")") $ isSub
        if ti' /= ti then return $ TE to (App e1' (promote ti e2'))
                     else return $ TE to (App e1' e2')
infer (E (Let a e1 e2))
   = do e1' <- infer e1
        e2' <- withBinding (exprType e1') (infer e2)
        return $ TE (exprType e2') (Let a e1' e2')
-- let!: the listed variables are banged while e1 is checked. The type of e1
-- must be escapable so that no banged value leaks into e2.
infer (E (LetBang vs a e1 e2))
   = do e1' <- withBang (map fst vs) (infer e1)
        k <- kindcheck (exprType e1')
        guardShow "let!" $ canEscape k
        e2' <- withBinding (exprType e1') (infer e2)
        return $ TE (exprType e2') (LetBang vs a e1' e2')
infer (E Unit) = return $ TE TUnit Unit
infer (E (Tuple e1 e2))
   = do e1' <- infer e1
        e2' <- infer e2
        return $ TE (TProduct (exprType e1') (exprType e2')) (Tuple e1' e2')
-- Variant constructor: tfull is expected to be a sum in which the tag is
-- present and untaken.
infer (E (Con tag e tfull))
   = do e' <- infer e
        -- Find type of payload for given tag
        let TSum ts          = tfull
            Just (t, False) = lookup tag ts
        -- Make sure to promote the payload to type t if necessary
        e'' <- typecheck e' t
        return $ TE tfull (Con tag e'' tfull)
-- Conditional: both branches are checked from the same initial context, and
-- each is promoted to the least upper bound of the two branch types.
infer (E (If ec et ee))
   = do ec' <- infer ec
        guardShow "if-1" $ exprType ec' == TPrim Boolean
        (et', ee') <- (,) <$> infer et <||> infer ee  -- have to use applicative functor, as they share the same initial env
        let tt = exprType et'
            te = exprType ee'
        Just tlub <- runMaybeT $ tt `lub` te
        isSub <- (&&) <$> tt `isSubtype` tlub <*> te `isSubtype` tlub
        guardShow' "if-2" ["Then type:", show (pretty tt) ++ ";", "else type:", show (pretty te)] isSub
        let et'' = if tt /= tlub then promote tlub et' else et'
            ee'' = if te /= tlub then promote tlub ee' else ee'
        return $ TE tlub (If ec' et'' ee'')
-- Case on one tag. The match branch binds the payload. The rest branch binds
-- the scrutinee's sum type with this tag marked taken. If the tag is already
-- taken in the scrutinee's type, the scrutinee is promoted to a type where it
-- is untaken. Both branches are promoted to the least upper bound of their
-- types.
infer (E (Case e tag (lt,at,et) (le,ae,ee)))
   = do e' <- infer e
        let TSum ts = exprType e'
            Just (t, taken) = lookup tag ts
            restt = TSum $ adjust tag (second $ const True) ts  -- set the tag to taken
        let e'' = case taken of
                    True  -> promote (TSum $ OM.toList $ OM.adjust (\(t,True) -> (t,False)) tag $ OM.fromList ts) e'
                    False -> e'
        (et',ee') <- (,) <$>  withBinding t     (infer et)
                         <||> withBinding restt (infer ee)
        let tt = exprType et'
            te = exprType ee'
        Just tlub <- runMaybeT $ tt `lub` te
        isSub <- (&&) <$> tt `isSubtype` tlub <*> te `isSubtype` tlub
        guardShow' "case" ["Match type:", show (pretty tt) ++ ";", "rest type:", show (pretty te)] isSub
        let et'' = if tt /= tlub then promote tlub et' else et'
            ee'' = if te /= tlub then promote tlub ee' else ee'
        return $ TE tlub (Case e'' tag (lt,at,et'') (le,ae,ee''))
-- Esac: the sum must have exactly one untaken alternative, whose payload
-- type is the result.
infer (E (Esac e))
   = do e'@(TE (TSum ts) _) <- infer e
        let t1 = filter (not . snd . snd) ts
        case t1 of
          [(_, (t, False))] -> return $ TE t (Esac e')
          _ -> __impossible $ "infer: esac (t1 = " ++ show t1 ++ ", ts = " ++ show ts ++ ")"
infer (E (Split a e1 e2))
   = do e1' <- infer e1
        let (TProduct t1 t2) = exprType e1'
        e2' <- withBindings (Cons t1 (Cons t2 Nil)) (infer e2)
        return $ TE (exprType e2') (Split a e1' e2')
-- Member access: the record must be shareable, and the field index must be
-- in range and the field not taken.
infer (E (Member e f))
   = do e'@(TE t _) <- infer e  -- canShare
        let TRecord _ fs _ = t
        guardShow "member-1" . canShare =<< kindcheck t
        guardShow "member-2" $ f < length fs
        let (_,(tau,c)) = fs !! f
        guardShow "member-3" $ not c  -- not taken
        return $ TE tau (Member e' f)
-- Struct literal: an unboxed, non-recursive record with no fields taken.
infer (E (Struct fs))
   = do let (ns,es) = unzip fs
        es' <- mapM infer es
        let ts' = zipWith (\n e' -> (n, (exprType e', False))) ns es'
        return $ TE (TRecord NonRec ts' Unboxed) $ Struct $ zip ns es'
-- Take a field: the record must not be read-only and the field is expected
-- to be untaken. e2 is checked with the field value and the record with the
-- field marked taken in scope.
infer (E (Take a e f e2))
   = do e'@(TE t _) <- infer e
        -- trace ("@@@@t is " ++ show t) $ return ()
        let TRecord rp ts s = t
        -- a common cause of this error is taking a field when you could have used member
        guardShow ("take: sigil cannot be readonly: " ++ show (pretty e)) $ not (readonly s)
        guardShow "take-1" $ f < length ts
        let (init, (fn,(tau,False)):rest) = splitAt f ts
        -- NOTE(review): k is computed but not used in this equation.
        k <- kindcheck tau
        e2' <- withBindings (Cons tau (Cons (TRecord rp (init ++ (fn,(tau,True)):rest) s) Nil)) (infer e2)  -- take that field regardless of its shareability
        return $ TE (exprType e2') (Take a e' f e2')
-- Put a field: the record must not be read-only. If the field is not taken,
-- its current value is overwritten and must be discardable. The new value is
-- promoted to the field type if needed, and the field becomes untaken.
infer (E (Put e1 f e2))
   = do e1'@(TE t1 _) <- infer e1
        let TRecord rp ts s = t1
        guardShow "put: sigil not readonly" $ not (readonly s)
        guardShow "put-1" $ f < length ts
        let (init, (fn,(tau,taken)):rest) = splitAt f ts
        k <- kindcheck tau
        unless taken $ guardShow "put-2" $ canDiscard k  -- if it's not taken, then it has to be discardable; if taken, then just put
        e2'@(TE t2 _) <- infer e2
        isSub <- t2 `isSubtype` tau
        guardShow "put-3" isSub
        let e2'' = if t2 /= tau then promote tau e2' else e2'
        return $ TE (TRecord rp (init ++ (fn,(tau,False)):rest) s) (Put e1' f e2'')  -- put it regardless
-- Explicit cast: only allowed between upcastable types (see 'isUpcastable').
infer (E (Cast ty e))
   = do (TE t e') <- infer e
        guardShow ("cast: " ++ show t ++ " <<< " ++ show ty) =<< t `isUpcastable` ty
        return $ TE ty (Cast ty $ TE t e')
-- Type annotation: the inferred type must be a subtype of the annotation.
-- A promotion is only inserted when the two types differ.
infer (E (Promote ty e))
   = do (TE t e') <- infer e
        guardShow ("promote: " ++ show t ++ " << " ++ show ty) =<< t `isSubtype` ty
        return $ if t /= ty then promote ty $ TE t e'
                            else TE t e'  -- see NOTE [How to handle type annotations?] in Desugar
688
689
-- | Promote an expression to a given type, pushing the promotion as far down
-- as possible.
--
-- The promotion is pushed into the continuations of 'Let', 'LetBang', 'If'
-- and 'Case'. Directly nested promotions are collapsed into one. Any other
-- expression is wrapped in a single 'Promote' node.
--
-- This structure is useful when destructing runs of case expressions, for
-- example in "Cogent.Isabelle.Compound". Consider this ternary case:
--
-- > Case scrutinee tag1
-- >  when_tag1
-- >  (Promote ty
-- >    (Case (Var 0) tag2
-- >      when_tag2
-- >      (Promote ty
-- >        (Let
-- >          (Esac (Var 0))
-- >          when_tag3))))
--
-- Here the promote expressions obscure the nested pattern-matching structure
-- of the program. Pushing the promotes down instead gives:
--
-- > Case scrutinee tag1
-- >  when_tag1
-- >  (Case (Var 0) tag2
-- >    (Promote ty when_tag2)
-- >    (Let
-- >      (Esac (Var 0))
-- >      (Promote ty when_tag3)))
--
-- so that the promotion and the pattern matching are separate.
--
-- A-normalisation produces a similar structure. However, when squashing case
-- expressions for the shallow embedding, we want this to apply to desugared
-- as well as normalised code.
promote :: Type t b -> TypedExpr t v a b -> TypedExpr t v a b
promote t te@(TE _ body) = case body of
  -- Continuation forms: push the promotion into the continuations.
  Let a e1 e2        -> TE t $ Let a e1 (promote t e2)
  LetBang vs a e1 e2 -> TE t $ LetBang vs a e1 (promote t e2)
  If ec et ee        -> TE t $ If ec (promote t et) (promote t ee)
  Case scrut tag (l1, a1, e1) (l2, a2, e2) ->
    TE t $ Case scrut tag (l1, a1, promote t e1) (l2, a2, promote t e2)
  -- A promotion of a promotion only needs the outer target type.
  Promote _ inner    -> promote t inner
  -- Nothing to push into: wrap the expression as-is.
  _                  -> TE t $ Promote t te
733
734
735