1(*  Title:      Tools/eqsubst.ML
2    Author:     Lucas Dixon, University of Edinburgh
3
4Perform a substitution using an equation.
5*)
6
(*
  Interface of eqsubst: single-step rewriting of a subgoal's conclusion
  or of one of its assumptions with a given equation.
*)
signature EQSUBST =
sig
  (* ((type instantiations, term instantiations),
      fake named type abs env, type abs env, outer term) *)
  type match =
    ((indexname * (sort * typ)) list * (indexname * (typ * term)) list)
    * (string * typ) list
    * (string * typ) list
    * term

  (* (context, maxidx, focus term to search under) *)
  type searchinfo = Proof.context * int * Zipper.T

  (* outcome of skipping occurrences: the rest of the search, or how many
     occurrences are still left to skip *)
  datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

  (* skipping occurrences *)
  val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
  val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
  val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq

  (* rewriting the conclusion of a subgoal *)
  val eqsubst_tac: Proof.context ->
    int list -> (* occurrences to rewrite, [0] for any *)
    thm list -> int -> tactic
  val eqsubst_tac': Proof.context ->
    (searchinfo -> term -> match Seq.seq) -> (* search function *)
    thm -> (* equation to rewrite with *)
    int -> (* subgoal number in the goal state *)
    thm -> (* goal state *)
    thm Seq.seq (* rewritten goal states *)

  (* rewriting an assumption of a subgoal *)
  val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
  val eqsubst_asm_tac': Proof.context ->
    (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic

  (* searching for places to substitute *)
  val valid_match_start: Zipper.T -> bool
  val search_lr_all: Zipper.T -> Zipper.T Seq.seq
  val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
  val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
end;
49
50structure EqSubst: EQSUBST =
51struct
52
(* Turn a theorem into the list of rewrite rules produced by the context's
   Simplifier.mksimps (object "=" becomes meta "=="), with the schematic
   variable indexes of each rule reset to zero. *)
fun prep_meta_eq ctxt th =
  map Drule.zero_var_indexes (Simplifier.mksimps ctxt th);
56
(* Generalise the given (certified) free variables of a theorem: quantify
   over them with meta-level universal quantifiers, then eliminate each of
   those quantifiers again with a schematic variable of index 0. *)
fun unfix_frees frees th =
  th
  |> Drule.forall_intr_list frees
  |> fold (fn _ => Thm.forall_elim_var 0) frees;
60
61
(* A match is
     ((type instantiations, term instantiations),
      fake-named bound-variable environment,
      renamed bound-variable environment,
      outer term with the focus abstracted out). *)
type match =
  ((indexname * (sort * typ)) list * (indexname * (typ * term)) list)
  * (string * typ) list
  * (string * typ) list
  * term;

(* Search info: (context, maximum variable index in use, zipper focused on
   the term to search under). *)
type searchinfo = Proof.context * int * Zipper.T;
73
74
(* Outcome of skipping non-empty sub-sequences of a sequence of sequences:
   either SkipSeq with the remaining sequence, starting at the wanted one,
   or, when the input ran out first, SkipMore with the count still to skip.
   The count lets skipping carry on into another search. *)
datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq;
80
(* Advance s to its m-th non-empty sub-sequence (m <= 1 means the first)
   and return SkipSeq with that sub-sequence followed by all later ones.
   Empty sub-sequences are not counted. If s is exhausted first, return
   SkipMore k, where k is m minus the number of non-empty sub-sequences
   seen. *)
fun skipto_skipseq m s =
  let
    fun go remaining seqs =
      (case Seq.pull seqs of
        NONE => SkipMore remaining
      | SOME (sub, rest) =>
          if is_none (Seq.pull sub) then go remaining rest
          else if remaining <= 1 then SkipSeq (Seq.cons sub rest)
          else go (remaining - 1) rest);
  in go m s end;
92
(* Build the outer term of a match:
     Abs ("fooabs", T => ty, mkuptermfunc (Bound n $ Bound (n - 1) $ ... $ Bound 0))
   where n = length Ts, ty = type_of t, and the argument types T are the
   types of Ts in reverse order.
   mkuptermfunc is expected to plug its argument back into the context the
   focus came from (prep_zipper_match passes Zipper.C.apply). The result is
   therefore the original term with the focus replaced by the new bound
   variable, applied to the locally bound variables. *)
fun mk_foo_match mkuptermfunc Ts t =
  let
    val nbounds = length Ts;
    val fooabs_type = rev (map snd Ts) ---> Term.type_of t;
    (* apply acc to Bound (k - 1), ..., Bound 0 in that order *)
    fun apply_bounds 0 acc = acc
      | apply_bounds k acc = apply_bounds (k - 1) (acc $ Bound (k - 1));
    val applied_head = apply_bounds nbounds (Bound nbounds);
  in Abs ("fooabs", fooabs_type, mkuptermfunc applied_head) end;
109
(* Name of the fake free variable standing for a bound variable called n. *)
fun mk_fake_bound_name n = ":b_" ^ n;

(* Replace the loose bound variables of t by fake free variables.
   Ts holds the (name, type) pairs of the enclosing binders. Its first
   element is treated as Bound 0, because Term.subst_bounds replaces
   Bound k by the k-th element of its list.
   Names are made distinct from each other, assigned from the end of Ts,
   but are not checked against other names occurring in t.
   Returns (fake frees, renamed Ts, t with its loose bounds replaced). *)
(* THINK: is order of Ts correct...? or reversed? *)
fun fakefree_badbounds Ts t =
  let
    fun add_bound (n, ty) (fakes, renamed, used) =
      let val fresh = singleton (Name.variant_list used) n in
        ((mk_fake_bound_name fresh, ty) :: fakes,
         (fresh, ty) :: renamed,
         fresh :: used)
      end;
    val (fakes, renamed, _) = fold_rev add_bound Ts ([], [], []);
  in (fakes, renamed, Term.subst_bounds (map Free fakes, t)) end;
124
(* Prepare the focus of zipper z for matching.
   - The loose bound variables of the focus are replaced by fake frees,
     since no abstraction encloses them there.
   - The bound-variable environments are built.
   - The outer context term, with the focus abstracted out, is built for
     rewriting with RW_Inst.rw.
   Returns (prepared focus, (fake frees, renamed bounds, outer term)). *)
fun prep_zipper_match z =
  let
    val zctxt = Zipper.ctxt z;
    val (fake_bounds, bounds, focus) =
      fakefree_badbounds (Zipper.C.nty_ctxt zctxt) (Zipper.trm z);
    val outer = mk_foo_match (Zipper.C.apply zctxt) bounds focus;
  in (focus, (fake_bounds, bounds, outer)) end;
139
(* Unification with exceptions handled.
   ctxt is the proof context and ix the maximum variable index in use. The
   types of the two terms of the pair a are unified first; if that fails
   (Type.TUNIFY) the result is Seq.empty.
   Otherwise the result is a lazy sequence of instantiations
   (type env entries, term env entries), one per unifier found by
   Unify.smash_unifiers. ListPair.UnequalLengths or Term.TERM raised
   while starting unification gives an empty sequence; raised while pulling
   further unifiers, it ends the sequence. *)
(* NOTE(review): clean_unify_z passes the focus subterm as "pat" and the
   rewrite lhs as "tgt", i.e. the reverse of what the names suggest;
   presumably harmless because unification is symmetric. *)
fun clean_unify ctxt ix (a as (pat, tgt)) =
  let
    (* type info will be re-derived, maybe this can be cached
       for efficiency? *)
    val pat_ty = Term.type_of pat;
    val tgt_ty = Term.type_of tgt;
    (* FIXME is it OK to ignore the type instantiation info?
       or should I be using it? *)
    val typs_unify =
      SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
        handle Type.TUNIFY => NONE;
  in
    (case typs_unify of
      SOME (typinsttab, ix2) =>
        let
          (* FIXME is it right to throw away the flexes?
             or should I be using them somehow? *)
          fun mk_insts env =
            (Vartab.dest (Envir.type_env env),
             Vartab.dest (Envir.term_env env));
          (* start term unification from the type unifier just computed *)
          val initenv =
            Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
          val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
            handle ListPair.UnequalLengths => Seq.empty
              | Term.TERM _ => Seq.empty;
          (* unifiers are computed lazily on pull, so the same exceptions
             are also guarded here, on every step *)
          fun clean_unify' useq () =
            (case (Seq.pull useq) of
               NONE => NONE
             | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
            handle ListPair.UnequalLengths => NONE
              | Term.TERM _ => NONE;
        in
          (Seq.make (clean_unify' useq))
        end
    | NONE => Seq.empty)
  end;
178
(* Unification for zippers: unify the equation's left-hand side pat with the
   focus of z. The focus is first prepared by prep_zipper_match, with loose
   bounds replaced by fake frees. Each unifier is paired with:
   - the fake-bound environment,
   - the renamed bound-variable environment,
   - the outer term,
   giving a match. *)
fun clean_unify_z ctxt maxidx pat z =
  let
    val (focus, (fake_env, bound_env, outer)) = prep_zipper_match z;
    fun to_match insts = (insts, fake_env, bound_env, outer);
  in clean_unify ctxt maxidx (focus, pat) |> Seq.map to_match end;
189
190
(* Descend through function positions of applications and bodies of
   abstractions to the leftmost leaf of a term. *)
fun bot_left_leaf_of t =
  (case t of
    f $ _ => bot_left_leaf_of f
  | Abs (_, _, body) => bot_left_leaf_of body
  | leaf => leaf);
194
(* A position is worth trying only if the leftmost leaf of its term is not
   a schematic variable. Terms headed by a Var always unify, trivially and
   uninterestingly. *)
fun valid_match_start z = not (Term.is_Var (bot_left_leaf_of (Zipper.trm z)));
201
(* Enumerate every position under z, without any validity filter, in the
   order given by ZipperSearch.all_bl_ur. *)
fun search_lr_all z = ZipperSearch.all_bl_ur z;
204
(* Lazily search the positions under a zipper that satisfy validf.
   For each node, the entries handed to Zipper.lzy_search are:
   - application: the function part, then the node itself (if valid),
     then the argument;
   - abstraction: the node itself (if valid), then its body. *)
fun search_lr_valid validf =
  let
    fun expand z =
      let
        val here = if validf z then [Zipper.Here z] else [];
        fun look move = Zipper.LookIn (move z);
      in
        (case Zipper.trm z of
          _ $ _ => look Zipper.move_down_left :: here @ [look Zipper.move_down_right]
        | Abs _ => here @ [look Zipper.move_down_abs]
        | _ => here)
      end;
  in Zipper.lzy_search expand end;
218
(* Lazily search, bottom-up, the positions under a zipper that satisfy
   validf. Every node's children are handed to Zipper.lzy_search before the
   node itself. For an application that means function part, then argument,
   then the node (if valid). *)
fun search_bt_valid validf =
  let
    fun expand_bt_lr z =
      let
        val here = if validf z then [Zipper.Here z] else [];
        fun look move = Zipper.LookIn (move z);
      in
        (case Zipper.trm z of
          _ $ _ => look Zipper.move_down_left :: look Zipper.move_down_right :: here
        | Abs _ => look Zipper.move_down_abs :: here
        | _ => here)
      end;
  in Zipper.lzy_search expand_bt_lr end;
232
(* Run search f from the focus z (bounded by Zipper.limit_apply). For each
   position found, produce the lazy sequence of unifiers of lhs with it. *)
fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
  Zipper.limit_apply f z
  |> Seq.map (fn pos => clean_unify_z ctxt maxidx lhs pos);
235
(* unifiers at every position, with no filtering *)
fun searchf_lr_unify_all sinfo lhs =
  searchf_unify_gen search_lr_all sinfo lhs;

(* unifiers only at positions whose leftmost leaf is not a schematic
   variable (see valid_match_start); search_lr_valid order *)
fun searchf_lr_unify_valid sinfo lhs =
  searchf_unify_gen (search_lr_valid valid_match_start) sinfo lhs;

(* as above, but in bottom-up order (search_bt_valid) *)
fun searchf_bt_unify_valid sinfo lhs =
  searchf_unify_gen (search_bt_valid valid_match_start) sinfo lhs;
243
(* Rewrite the conclusion of subgoal i of st with match m.
   - cfvs: certified free variables that stand for the subgoal's parameters.
   - conclthm: a theorem about the conclusion alone.
   - rule: the meta equation to rewrite with.
   The rewritten theorem is generalised over cfvs, beta-eta normalised and
   then resolved with subgoal i. *)
fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
  let
    val rewritten = RW_Inst.rw ctxt m rule conclthm;
    val generalised = unfix_frees cfvs rewritten;
    val normalised = Conv.fconv_rule Drule.beta_eta_conversion generalised;
  in resolve_tac ctxt [normalised] i st end;
254
(* Prepare substitution in the conclusion of subgoal i of gth. The rule is
   assumed to have zeroed variable indexes, so gth's indexes are lifted by
   one first.
   Returns ((cfvs, conclthm), searchinfo):
   - cfvs: the certified free variables that fix the subgoal's parameters;
   - conclthm: the trivial theorem C ==> C for the fixed conclusion C;
   - searchinfo: its zipper focus is presumably the premise C of conclthm
     (down-left, then down-right from the top). *)
fun prep_concl_subst ctxt i gth =
  let
    val lifted = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of lifted);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val conclthm =
      Thm.trivial (Thm.cterm_of ctxt (Logic.strip_imp_concl fixedbody));
    val focus =
      Zipper.move_down_right (Zipper.move_down_left (Zipper.mktop (Thm.prop_of conclthm)));
  in
    ((cfvs, conclthm), (ctxt, Thm.maxidx_of lifted, focus))
  end;
276
(* Substitute in the conclusion of subgoal i of st, using an object- or
   meta-level equation. Each meta equation derived from instepthm is tried
   in turn. searchf proposes matches of its left-hand side, and each match
   is applied with apply_subst_in_concl. *)
fun eqsubst_tac' ctxt searchf instepthm i st =
  let
    val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
    fun rewrite_with_thm r =
      Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
        (searchf searchinfo (#1 (Logic.dest_equals (Thm.concl_of r))));
  in Seq.maps rewrite_with_thm (Seq.of_list (prep_meta_eq ctxt instepthm)) end;
288
289
290(* General substitution of multiple occurrences using one of
291   the given theorems *)
292
(* Run srchf and keep its results from the occ-th non-empty sub-sequence
   onwards, flattened (occ <= 1 keeps everything). If there are not enough
   non-empty sub-sequences, the result is empty. *)
fun skip_first_occs_search occ srchf sinfo lhs =
  (case skipto_skipseq occ (srchf sinfo lhs) of
    SkipSeq ss => Seq.flat ss
  | SkipMore _ => Seq.empty);
297
(* Rewrite occurrences in the conclusion of subgoal i of st with one of thms.
   The "occs" argument is a list of integers giving which occurrences, w.r.t.
   the search order, to rewrite (0 and 1 both mean the first). Backtracking
   will also find later occurrences, but all earlier ones are skipped, so
   [0] finds all rewrites.
   Occurrences are handled from the highest number downwards; presumably
   this keeps the numbering of those still to do unchanged. Fails
   (Seq.empty) if i is not a valid subgoal index (i < 1 or i > nprems). *)
fun eqsubst_tac ctxt occs thms i st =
  let val nprems = Thm.nprems_of st in
    if i < 1 orelse nprems < i then Seq.empty
    else
      let
        val thmseq = Seq.of_list thms;
        (* re-target subgoal i, shifted by however many subgoals the
           previously applied rewrites added *)
        fun apply_occ occ st =
          thmseq |> Seq.maps (fn r =>
            eqsubst_tac' ctxt
              (skip_first_occs_search occ searchf_lr_unify_valid) r
              (i + (Thm.nprems_of st - nprems)) st);
        val sorted_occs = Library.sort (rev_order o int_ord) occs;
      in
        Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
      end
  end;
318
319
(* Apply a substitution inside assumption j of subgoal i of st, keeping the
   assumption in the same place.
   - rule: the meta equation to rewrite with.
   - (cfvs, j, _, pth): the prepared assumption, as built by
     prep_subst_in_asm.
   - m: the match to rewrite with.
   The assumption is rewritten into an elimination-style rule, which is
   applied with dresolve_tac. The premises are then rotated back. *)
fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
  let
    val st2 = Thm.rotate_rule (j - 1) i st; (* put premise j first *)
    (* NOTE(review): Seq.hd assumes prune_params_tac yields at least one
       result *)
    val preelimrule =
      RW_Inst.rw ctxt m rule pth
      |> (Seq.hd o prune_params_tac ctxt)
      |> Thm.permute_prems 0 ~1 (* put old asm first *)
      |> unfix_frees cfvs (* unfix any global params *)
      |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
  in
    (* ~j because new asm starts at back, thus we subtract 1 *)
    Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
      (dresolve_tac ctxt [preelimrule] i st2)
  end;
335
336
(* Prepare substitution in the j-th premise of subgoal i of gth, using a
   meta-level equation. The rule is assumed to have zeroed variable
   indexes, so gth's indexes are lifted by one first.
   Returns ((cfvs, j, number of premises of that assumption, pth),
   searchinfo), where pth is the trivial theorem A ==> A for the fixed
   assumption A. The zipper focus is A, the right-hand side of pth.
   Note: this work is repeated for every assumption, which could be
   shared when searching over several assumptions. *)
fun prep_subst_in_asm ctxt i gth j =
  let
    val lifted = Thm.incr_indexes 1 gth;
    val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of lifted);
    val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
    val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
    val pth = Thm.trivial (Thm.cterm_of ctxt asmt);
    val focus = Zipper.move_down_right (Zipper.mktop (Thm.prop_of pth));
  in
    ((cfvs, j, length (Logic.strip_imp_prems asmt), pth),
     (ctxt, Thm.maxidx_of lifted, focus))
  end;
362
(* Prepare substitution in every assumption of subgoal i of gth, in order
   (assumption 1 first). *)
fun prep_subst_in_asms ctxt i gth =
  let val nasms = length (Logic.prems_of_goal (Thm.prop_of gth) i)
  in map (prep_subst_in_asm ctxt i gth) (Library.upto (1, nasms)) end;
368
369
(* Substitute in an assumption of subgoal i of st, using an object- or
   meta-level equation instepthm.
   searchf gets an assumption's search info and the number of occurrences
   still to skip (skipocc initially). Assumptions are visited in order:
   when one runs out (SkipMore), its remaining count carries over to the
   next assumption. Once an occurrence is found, later assumptions are
   searched from their first occurrence, but only when the sequence is
   pulled that far (backtracking). *)
fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
  let
    val asmpreps = prep_subst_in_asms ctxt i st;
    val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
    fun rewrite_with_thm r =
      let
        val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
        fun occ_search _ [] = Seq.empty
          | occ_search occ ((asminfo, searchinfo) :: moreasms) =
              (case searchf searchinfo occ lhs of
                SkipMore remaining => occ_search remaining moreasms
              | SkipSeq ss =>
                  (* Seq.append takes its second argument strictly, so delay
                     the search of later assumptions until it is needed *)
                  Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
                    (Seq.make (fn () => Seq.pull (occ_search 1 moreasms))));
      in
        occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
      end;
  in stepthms |> Seq.maps rewrite_with_thm end;
389
390
(* Run searchf and skip to its occ-th non-empty sub-sequence. The result is
   left as a skipseq so that any deficit (SkipMore) can carry over to the
   next assumption. *)
fun skip_first_asm_occs_search searchf sinfo occ lhs =
  searchf sinfo lhs |> skipto_skipseq occ;
393
(* Rewrite occurrences inside the assumptions of subgoal i of st with one of
   thms. occs lists which occurrences, w.r.t. the search order over all
   assumptions, to rewrite (0 and 1 both mean the first). Later occurrences
   are found on backtracking.
   Occurrences are handled from the highest number downwards. Fails
   (Seq.empty) if i is not a valid subgoal index (i < 1 or i > nprems). *)
fun eqsubst_asm_tac ctxt occs thms i st =
  let val nprems = Thm.nprems_of st in
    if i < 1 orelse nprems < i then Seq.empty
    else
      let
        val thmseq = Seq.of_list thms;
        (* re-target subgoal i, shifted by however many subgoals the
           previously applied rewrites added *)
        fun apply_occ occ st =
          thmseq |> Seq.maps (fn r =>
            eqsubst_asm_tac' ctxt
              (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
              (i + (Thm.nprems_of st - nprems)) st);
        val sorted_occs = Library.sort (rev_order o int_ord) occs;
      in
        Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
      end
  end;
410
(* Register the proof method "subst" ("single-step substitution").
   Arguments, in order:
   - an optional "asm" mode flag (Args.mode "asm");
   - an optional parenthesised list of occurrence numbers, default [0];
   - the theorems to rewrite with.
   With the flag set, eqsubst_asm_tac rewrites in an assumption; otherwise
   eqsubst_tac rewrites the conclusion. *)
val _ =
  Theory.setup
    (Method.setup \<^binding>\<open>subst\<close>
      (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
        Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
          SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
      "single-step substitution");
421
422end;
423