instantiate schematics as projections to avoid HOU trouble
authorblanchet
Fri, 03 Jan 2014 11:26:44 +0100
changeset 54923 ffed2452f5f6
parent 54922 494fd4ec3850
child 54924 44373f3560c7
instantiate schematics as projections to avoid HOU trouble
src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML
src/HOL/BNF/Tools/bnf_fp_util.ML
src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Fri Jan 03 10:48:48 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar_tactics.ML	Fri Jan 03 11:26:44 2014 +0100
@@ -45,11 +45,6 @@
 
 val ss_if_True_False = simpset_of (ss_only @{thms if_True if_False} @{context});
 
-fun mk_proj T k =
-  let val binders = binder_types T in
-    fold_rev (fn T => fn t => Abs (Name.uu, T, t)) binders (Bound (length binders - k))
-  end;
-
 fun hhf_concl_conv cv ctxt ct =
   (case Thm.term_of ct of
     Const (@{const_name all}, _) $ Abs _ =>
@@ -60,9 +55,11 @@
   let
     val fs = Term.add_vars (prop_of thm) []
       |> filter (fn (_, Type (@{type_name fun}, [_, T'])) => T' <> HOLogic.boolT | _ => false);
-    val cfs = map (fn f as (_, T) => (certify ctxt (Var f), certify ctxt (mk_proj T k))) fs;
+    fun mk_cfp (f as (_, T)) =
+      (certify ctxt (Var f), certify ctxt (mk_proj T (num_binder_types T) k));
+    val cfps = map mk_cfp fs;
   in
-    Drule.cterm_instantiate cfs thm
+    Drule.cterm_instantiate cfps thm
   end;
 
 val co_induct_inst_as_projs_tac = PRIMITIVE oo co_induct_inst_as_projs;
@@ -135,7 +132,7 @@
 fun mk_induct_tac ctxt nn ns mss kkss ctr_defs ctor_induct' set_maps pre_set_defss =
   let val n = Integer.sum ns in
     unfold_thms_tac ctxt ctr_defs THEN HEADGOAL (rtac ctor_induct') THEN
-    co_induct_inst_as_projs_tac ctxt 1 THEN
+    co_induct_inst_as_projs_tac ctxt 0 THEN
     EVERY (map4 (EVERY oooo map3 o mk_induct_discharge_prem_tac ctxt nn n set_maps) pre_set_defss
       mss (unflat mss (1 upto n)) kkss)
   end;
@@ -165,10 +162,10 @@
     discss selss =
   let val ks = 1 upto n in
     EVERY' ([rtac allI, rtac allI, rtac impI, select_prem_tac nn (dtac meta_spec) kk,
-        dtac meta_spec, dtac meta_mp, atac, rtac exhaust, K (co_induct_inst_as_projs_tac ctxt 1),
+        dtac meta_spec, dtac meta_mp, atac, rtac exhaust, K (co_induct_inst_as_projs_tac ctxt 0),
         hyp_subst_tac ctxt] @
       map4 (fn k => fn ctr_def => fn discs => fn sels =>
-        EVERY' ([rtac exhaust, K (co_induct_inst_as_projs_tac ctxt 2)] @
+        EVERY' ([rtac exhaust, K (co_induct_inst_as_projs_tac ctxt 1)] @
           map2 (fn k' => fn discs' =>
             if k' = k then
               mk_coinduct_same_ctr_tac ctxt rel_eqs' pre_rel_def dtor_ctor ctr_def discs sels
--- a/src/HOL/BNF/Tools/bnf_fp_util.ML	Fri Jan 03 10:48:48 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_fp_util.ML	Fri Jan 03 11:26:44 2014 +0100
@@ -138,6 +138,8 @@
   val mk_sumTN: typ list -> typ
   val mk_sumTN_balanced: typ list -> typ
 
+  val mk_proj: typ -> int -> int -> term
+
   val mk_convol: term * term -> term
 
   val Inl_const: typ -> typ -> term
@@ -374,6 +376,11 @@
 val mk_sumTN = Library.foldr1 mk_sumT;
 val mk_sumTN_balanced = Balanced_Tree.make mk_sumT;
 
+fun mk_proj T n k =
+  let val (binders, _) = strip_typeN n T in
+    fold_rev (fn T => fn t => Abs (Name.uu, T, t)) binders (Bound (n - k - 1))
+  end;
+
 fun mk_convol (f, g) =
   let
     val (fU, fTU) = `range_type (fastype_of f);
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Fri Jan 03 10:48:48 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Fri Jan 03 11:26:44 2014 +0100
@@ -459,7 +459,7 @@
 fun abstract vs =
   let fun a n (t $ u) = a n t $ a n u
         | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b)
-        | a n t = let val idx = find_index (equal t) vs in
+        | a n t = let val idx = find_index (curry (op =) t) vs in
             if idx < 0 then t else Bound (n + idx) end
   in a 0 end;
 
@@ -521,12 +521,12 @@
       primcorec_error_eqn "negated discriminator for a type with \<noteq> 2 constructors" concl;
     val disc' = find_subterm (member (op =) discs o head_of) concl;
     val eq_ctr0 = concl |> perhaps (try HOLogic.dest_not) |> try (HOLogic.dest_eq #> snd)
-        |> (fn SOME t => let val n = find_index (equal t) ctrs in
+        |> (fn SOME t => let val n = find_index (curry (op =) t) ctrs in
           if n >= 0 then SOME n else NONE end | _ => NONE);
     val _ = is_some disc' orelse is_some eq_ctr0 orelse
       primcorec_error_eqn "no discriminator in equation" concl;
     val ctr_no' =
-      if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs;
+      if is_none disc' then the eq_ctr0 else find_index (curry (op =) (head_of (the disc'))) discs;
     val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
     val {ctr, disc, ...} = nth basic_ctr_specs ctr_no;
 
@@ -575,8 +575,8 @@
         primcorec_error_eqn "malformed selector argument in left-hand side" eqn;
     val {ctr, ...} =
       (case maybe_of_spec of
-        SOME of_spec => the (find_first (equal of_spec o #ctr) basic_ctr_specs)
-      | NONE => filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single
+        SOME of_spec => the (find_first (curry (op =) of_spec o #ctr) basic_ctr_specs)
+      | NONE => filter (exists (curry (op =) sel) o #sels) basic_ctr_specs |> the_single
           handle List.Empty => primcorec_error_eqn "ambiguous selector - use \"of\"" eqn);
     val user_eqn = drop_All eqn';
   in
@@ -600,7 +600,7 @@
     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
     val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name;
     val (ctr, ctr_args) = strip_comb (unfold_let rhs);
-    val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs)
+    val {disc, sels, ...} = the (find_first (curry (op =) ctr o #ctr) basic_ctr_specs)
       handle Option.Option => primcorec_error_eqn "not a constructor" ctr;
 
     val disc_concl = betapply (disc, lhs);
@@ -700,13 +700,13 @@
     |> K |> nth_map (the (#pred (nth ctr_specs ctr_no)));
 
 fun build_corec_arg_no_call (sel_eqns : coeqn_data_sel list) sel =
-  find_first (equal sel o #sel) sel_eqns
+  find_first (curry (op =) sel o #sel) sel_eqns
   |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term)
   |> the_default undef_const
   |> K;
 
 fun build_corec_args_mutual_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
-  (case find_first (equal sel o #sel) sel_eqns of
+  (case find_first (curry (op =) sel o #sel) sel_eqns of
     NONE => (I, I, I)
   | SOME {fun_args, rhs_term, ... } =>
     let
@@ -722,7 +722,7 @@
     end);
 
 fun build_corec_arg_nested_call lthy has_call (sel_eqns : coeqn_data_sel list) sel =
-  (case find_first (equal sel o #sel) sel_eqns of
+  (case find_first (curry (op =) sel o #sel) sel_eqns of
     NONE => I
   | SOME {fun_args, rhs_term, ...} =>
     let
@@ -749,7 +749,7 @@
 
 fun build_corec_args_sel lthy has_call (all_sel_eqns : coeqn_data_sel list)
     (ctr_spec : corec_ctr_spec) =
-  (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of
+  (case filter (curry (op =) (#ctr ctr_spec) o #ctr) all_sel_eqns of
     [] => I
   | sel_eqns =>
     let
@@ -816,10 +816,11 @@
     else
       let
         val n = 0 upto length ctr_specs
-          |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
+          |> the o find_first (fn idx => not (exists (curry (op =) idx o #ctr_no) disc_eqns));
         val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
           |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
-        val maybe_sel_eqn = find_first (equal (Binding.name_of fun_binding) o #fun_name) sel_eqns;
+        val maybe_sel_eqn =
+          find_first (curry (op =) (Binding.name_of fun_binding) o #fun_name) sel_eqns;
         val extra_disc_eqn = {
           fun_name = Binding.name_of fun_binding,
           fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
@@ -839,8 +840,8 @@
 
 fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
   let
-    val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs
-      |> find_index (equal sel) o #sels o the;
+    val sel_no = find_first (curry (op =) ctr o #ctr) basic_ctr_specs
+      |> find_index (curry (op =) sel) o #sels o the;
     fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else [];
   in
     find rhs_term
@@ -904,7 +905,7 @@
     val _ = disc_eqnss' |> map (fn x =>
       let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
         primcorec_error_eqns "excess discriminator formula in definition"
-          (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
+          (maps (fn t => filter (curry (op =) (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);
 
     val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
@@ -992,9 +993,10 @@
                 end)
             de_facto_exhaustives disc_eqnss
           |> list_all_fun_args [("P", HOLogic.boolT)]
-          |> map3 (fn disc_eqns => fn [] => K []
+          |> map3 (fn disc_eqns as {fun_args, ...} :: _ => fn [] => K []
               | [nchotomy_thm] => fn [goal] =>
-                [mk_primcorec_exhaust_tac (length disc_eqns) nchotomy_thm
+                [mk_primcorec_exhaust_tac lthy ("" (* for "P" *) :: map (fst o dest_Free) fun_args)
+                   (length disc_eqns) nchotomy_thm
                  |> K |> Goal.prove lthy [] [] goal
                  |> Thm.close_derivation])
             disc_eqnss nchotomy_thmss;
@@ -1038,11 +1040,11 @@
             (disc_eqns : coeqn_data_disc list) excludesss
             ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) =
           let
-            val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs;
-            val ctr_no = find_index (equal ctr o #ctr) ctr_specs;
+            val SOME ctr_spec = find_first (curry (op =) ctr o #ctr) ctr_specs;
+            val ctr_no = find_index (curry (op =) ctr o #ctr) ctr_specs;
             val prems = the_default (maps (s_not_conj o #prems) disc_eqns)
-              (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
-            val sel_corec = find_index (equal sel) (#sels ctr_spec)
+              (find_first (curry (op =) ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
+            val sel_corec = find_index (curry (op =) sel) (#sels ctr_spec)
               |> nth (#sel_corecs ctr_spec);
             val k = 1 + ctr_no;
             val m = length prems;
@@ -1065,16 +1067,17 @@
         fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list)
             (sel_eqns : coeqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) =
           (* don't try to prove theorems when some sel_eqns are missing *)
-          if not (exists (equal ctr o #ctr) disc_eqns)
-              andalso not (exists (equal ctr o #ctr) sel_eqns)
+          if not (exists (curry (op =) ctr o #ctr) disc_eqns)
+              andalso not (exists (curry (op =) ctr o #ctr) sel_eqns)
             orelse
-              filter (equal ctr o #ctr) sel_eqns
+              filter (curry (op =) ctr o #ctr) sel_eqns
               |> fst o finds ((op =) o apsnd #sel) sels
               |> exists (null o snd)
           then [] else
             let
               val (fun_name, fun_T, fun_args, prems, maybe_rhs) =
-                (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
+                (find_first (curry (op =) ctr o #ctr) disc_eqns,
+                 find_first (curry (op =) ctr o #ctr) sel_eqns)
                 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
                   #maybe_ctr_rhs x))
                 ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #maybe_ctr_rhs x))
@@ -1084,7 +1087,7 @@
                 (if is_some maybe_rhs then
                    the maybe_rhs
                  else
-                   filter (equal ctr o #ctr) sel_eqns
+                   filter (curry (op =) ctr o #ctr) sel_eqns
                    |> fst o finds ((op =) o apsnd #sel) sels
                    |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
                    |> curry list_comb ctr)
@@ -1130,12 +1133,12 @@
                   | NONE =>
                     let
                       fun prove_code_ctr {ctr, sels, ...} =
-                        if not (exists (equal ctr o fst) ctr_alist) then NONE else
+                        if not (exists (curry (op =) ctr o fst) ctr_alist) then NONE else
                           let
-                            val prems = find_first (equal ctr o #ctr) disc_eqns
+                            val prems = find_first (curry (op =) ctr o #ctr) disc_eqns
                               |> Option.map #prems |> the_default [];
                             val t =
-                              filter (equal ctr o #ctr) sel_eqns
+                              filter (curry (op =) ctr o #ctr) sel_eqns
                               |> fst o finds ((op =) o apsnd #sel) sels
                               |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x))
                                 #-> abstract)
@@ -1210,8 +1213,8 @@
               if prems = [@{term False}] then
                 []
               else
-                mk_primcorec_disc_iff_tac lthy (the_single exhaust_thms) (the_single disc_thms)
-                  disc_thmss (flat disc_excludess)
+                mk_primcorec_disc_iff_tac lthy (map (fst o dest_Free) fun_args)
+                  (the_single exhaust_thms) (the_single disc_thms) disc_thmss (flat disc_excludess)
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> single
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML	Fri Jan 03 10:48:48 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML	Fri Jan 03 11:26:44 2014 +0100
@@ -12,8 +12,9 @@
   val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic
   val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list ->
     tactic
-  val mk_primcorec_disc_iff_tac: Proof.context -> thm -> thm -> thm list list -> thm list -> tactic
-  val mk_primcorec_exhaust_tac: int -> thm -> tactic
+  val mk_primcorec_disc_iff_tac: Proof.context -> string list -> thm -> thm -> thm list list ->
+    thm list -> tactic
+  val mk_primcorec_exhaust_tac: Proof.context -> string list -> int -> thm -> tactic
   val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list ->
     thm list -> int list -> thm list -> thm option -> tactic
   val mk_primcorec_sel_tac: Proof.context -> thm list -> thm list -> thm list -> thm list ->
@@ -25,6 +26,7 @@
 
 open BNF_Util
 open BNF_Tactics
+open BNF_FP_Util
 
 val atomize_conjL = @{thm atomize_conjL};
 val falseEs = @{thms not_TrueE FalseE};
@@ -34,10 +36,28 @@
 val split_if_asm = @{thm split_if_asm};
 val split_connectI = @{thms allI impI conjI};
 
-fun mk_primcorec_exhaust_tac n nchotomy =
+fun exhaust_inst_as_projs ctxt frees thm =
+  let
+    val num_frees = length frees;
+    val fs = Term.add_vars (prop_of thm) [] |> filter (can dest_funT o snd);
+    fun find s = find_index (curry (op =) s) frees;
+    fun mk_cfp (f as ((s, _), T)) =
+      (certify ctxt (Var f), certify ctxt (mk_proj T num_frees (find s)));
+    val cfps = map mk_cfp fs;
+  in
+    Drule.cterm_instantiate cfps thm
+  end;
+
+val exhaust_inst_as_projs_tac = PRIMITIVE oo exhaust_inst_as_projs;
+
+fun distinct_in_prems_tac distincts =
+  eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac;
+
+fun mk_primcorec_exhaust_tac ctxt frees n nchotomy =
   let val ks = 1 upto n in
     HEADGOAL (atac ORELSE'
       cut_tac nchotomy THEN'
+      K (exhaust_inst_as_projs_tac ctxt frees) THEN'
       EVERY' (map (fn k =>
           (if k < n then etac disjE else K all_tac) THEN'
           REPEAT o (dtac meta_mp THEN' atac ORELSE'
@@ -78,9 +98,10 @@
 fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss =
   mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss;
 
-fun mk_primcorec_disc_iff_tac ctxt fun_exhaust fun_disc fun_discss disc_excludes =
+fun mk_primcorec_disc_iff_tac ctxt frees fun_exhaust fun_disc fun_discss disc_excludes =
   HEADGOAL (rtac iffI THEN'
     rtac fun_exhaust THEN'
+    K (exhaust_inst_as_projs_tac ctxt frees) THEN'
     EVERY' (map (fn [] => etac FalseE
         | [fun_disc'] =>
           if Thm.eq_thm (fun_disc', fun_disc) then
@@ -122,9 +143,6 @@
     end
   | _ => split);
 
-fun distinct_in_prems_tac distincts =
-  eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac;
-
 fun mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms m fun_ctr =
   let
     val splits' =