--- a/src/Pure/Tools/rule_insts.ML Fri Mar 20 21:16:42 2015 +0100
+++ b/src/Pure/Tools/rule_insts.ML Fri Mar 20 22:18:40 2015 +0100
@@ -34,10 +34,9 @@
structure Rule_Insts: RULE_INSTS =
struct
-(** reading instantiations **)
+(** read instantiations **)
-fun partition_insts mixed_insts =
- List.partition (fn (((x, _), _), _) => String.isPrefix "'" x) mixed_insts;
+local
fun error_var msg (xi, pos) =
error (msg ^ quote (Term.string_of_vname xi) ^ Position.here pos);
@@ -52,26 +51,6 @@
SOME T => T
| NONE => error_var "No such variable in theorem: " (xi, pos));
-local
-
-fun instantiate inst =
- Term_Subst.instantiate ([], map (fn (xi, t) => ((xi, Term.fastype_of t), t)) inst) #>
- Envir.beta_norm;
-
-fun make_instT f v =
- let
- val T = TVar v;
- val T' = f T;
- in if T = T' then NONE else SOME (v, T') end;
-
-fun make_inst f v =
- let
- val t = Var v;
- val t' = f t;
- in if t aconv t' then NONE else SOME (v, t') end;
-
-in
-
fun readT ctxt tvars ((xi, pos), s) =
let
val S = the_sort tvars (xi, pos);
@@ -94,37 +73,56 @@
val tyenv' = Vartab.fold (fn (xi, (S, T)) => cons ((xi, S), T)) tyenv [];
in (ts', tyenv') end;
+fun instantiate inst =
+ Term_Subst.instantiate ([], map (fn (xi, t) => ((xi, Term.fastype_of t), t)) inst) #>
+ Envir.beta_norm;
+
+fun make_instT f v =
+ let
+ val T = TVar v;
+ val T' = f T;
+ in if T = T' then NONE else SOME (v, T') end;
+
+fun make_inst f v =
+ let
+ val t = Var v;
+ val t' = f t;
+ in if t aconv t' then NONE else SOME (v, t') end;
+
+in
+
fun read_insts ctxt mixed_insts thm =
let
- val (type_insts, term_insts) = partition_insts mixed_insts;
+ val (type_insts, term_insts) =
+ List.partition (fn (((x, _), _), _) => String.isPrefix "'" x) mixed_insts;
val tvars = Thm.fold_terms Term.add_tvars thm [];
val vars = Thm.fold_terms Term.add_vars thm [];
-
- (* type instantiations *)
-
+ (*explicit type instantiations*)
val instT1 = Term_Subst.instantiateT (map (readT ctxt tvars) type_insts);
val vars1 = map (apsnd instT1) vars;
-
- (* term instantiations *)
-
+ (*term instantiations*)
val (xs, ss) = split_list term_insts;
val Ts = map (the_type vars1) xs;
val (ts, inferred) = read_termTs ctxt ss Ts;
+ (*implicit type instantiations*)
val instT2 = Term_Subst.instantiateT inferred;
val vars2 = map (apsnd instT2) vars1;
val inst2 = instantiate (map #1 xs ~~ ts);
-
- (* result *)
-
val inst_tvars = map_filter (make_instT (instT2 o instT1)) tvars;
val inst_vars = map_filter (make_inst inst2) vars2;
in (inst_tvars, inst_vars) end;
+end;
+
+
+
+(** forward rules **)
+
fun where_rule ctxt mixed_insts fixes thm =
let
val ctxt' = ctxt
@@ -151,8 +149,6 @@
zip_vars (rev (Term.add_vars (Thm.concl_of thm) [])) concl_args;
in where_rule ctxt insts fixes thm end;
-end;
-
fun read_instantiate ctxt insts xs =
where_rule ctxt insts (map (fn x => (Binding.name x, NONE, NoSyn)) xs);
@@ -200,9 +196,6 @@
fun bires_inst_tac bires_flag ctxt mixed_insts thm i st = CSUBGOAL (fn (cgoal, _) =>
let
- val (Tinsts, tinsts) = partition_insts mixed_insts;
-
-
(* goal context *)
val goal = Thm.term_of cgoal;
@@ -217,42 +210,32 @@
|> Proof_Context.add_fixes (map (fn (x, T) => (Binding.name x, SOME T, NoSyn)) params);
- (* preprocess rule *)
-
- val tvars = Thm.fold_terms Term.add_tvars thm [];
- val vars = Thm.fold_terms Term.add_vars thm [];
-
- val Tinsts_env = map (readT ctxt' tvars) Tinsts;
- val (xis, ss) = split_list tinsts;
- val Ts = map (Term_Subst.instantiateT Tinsts_env o the_type vars) xis;
+ (* lift and instantiate rule *)
- val (ts, envT) =
- read_termTs (Proof_Context.set_mode Proof_Context.mode_schematic ctxt') ss Ts;
- val envT' = map (fn (v, T) => (TVar v, T)) (envT @ Tinsts_env);
- val cenv =
- map (fn ((xi, _), t) => apply2 (Thm.cterm_of ctxt') (Var (xi, fastype_of t), t))
- (xis ~~ ts);
-
-
- (* lift and instantiate rule *)
+ val (inst_tvars, inst_vars) =
+ read_insts (Proof_Context.set_mode Proof_Context.mode_schematic ctxt')
+ mixed_insts thm;
val maxidx = Thm.maxidx_of st;
val paramTs = map #2 params;
val inc = maxidx + 1;
- fun lift_var (Var ((a, j), T)) = Var ((a, j + inc), paramTs ---> Logic.incr_tvar inc T)
- | lift_var t = raise TERM ("Variable expected", [t]);
+ fun lift_var ((a, j), T) =
+ Var ((a, j + inc), paramTs ---> Logic.incr_tvar inc T);
fun lift_term t =
- fold_rev absfree (param_names ~~ paramTs) (Logic.incr_indexes (paramTs, inc) t);
- fun lift_inst (cv, ct) = (cterm_fun lift_var cv, cterm_fun lift_term ct);
- val lift_tvar = apply2 (Thm.ctyp_of ctxt' o Logic.incr_tvar inc);
+ fold_rev absfree (param_names ~~ paramTs)
+ (Logic.incr_indexes (paramTs, inc) t);
- val rule =
- Drule.instantiate_normalize
- (map lift_tvar envT', map lift_inst cenv)
+ val inst_tvars' = inst_tvars
+ |> map (apply2 (Thm.ctyp_of ctxt' o Logic.incr_tvar inc) o apfst TVar);
+ val inst_vars' = inst_vars
+ |> map (fn (v, t) => apply2 (Thm.cterm_of ctxt') (lift_var v, lift_term t));
+
+ val thm' =
+ Drule.instantiate_normalize (inst_tvars', inst_vars')
(Thm.lift_rule cgoal thm);
in
- compose_tac ctxt' (bires_flag, rule, Thm.nprems_of thm) i
+ compose_tac ctxt' (bires_flag, thm', Thm.nprems_of thm) i
end) i st;
val res_inst_tac = bires_inst_tac false;