moving the preprocessing of introduction rules after the code_pred command; added tuple expansion preprocessing of elimination rule
authorbulwahn
Thu, 23 Sep 2010 17:22:44 +0200
changeset 39657 5e57675b7e40
parent 39656 f398f66969ce
child 39658 b3644e40f661
moving the preprocessing of introduction rules after the code_pred command; added tuple expansion preprocessing of elimination rule
src/HOL/Predicate_Compile_Examples/Predicate_Compile_Tests.thy
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Tests.thy	Thu Sep 23 14:50:18 2010 +0200
+++ b/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Tests.thy	Thu Sep 23 17:22:44 2010 +0200
@@ -1327,10 +1327,10 @@
 text {* The global introduction rules must be redeclared as introduction rules and then 
   one could invoke code_pred. *}
 
-declare A.hd_predicate_in_locale.intros [unfolded Predicate.eq_is_eq[symmetric], code_pred_intro]
+declare A.hd_predicate_in_locale.intros [code_pred_intro]
 
 code_pred (expected_modes: i => i => i => bool, i => i => o => bool) A.hd_predicate_in_locale
-unfolding eq_is_eq by (auto elim: A.hd_predicate_in_locale.cases)
+by (auto elim: A.hd_predicate_in_locale.cases)
     
 interpretation A partial_hd .
 thm hd_predicate_in_locale.intros
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Sep 23 14:50:18 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Sep 23 17:22:44 2010 +0200
@@ -142,8 +142,11 @@
   val default_options : options
   val bool_options : string list
   val print_step : options -> string -> unit
+  (* conversions *)
+  val imp_prems_conv : conv -> conv
   (* simple transformations *)
   val expand_tuples : theory -> thm -> thm
+  val expand_tuples_elim : Proof.context -> thm -> thm
   val eta_contract_ho_arguments : theory -> thm -> thm
   val remove_equalities : theory -> thm -> thm
   val remove_pointless_clauses : thm -> thm list
@@ -789,35 +792,38 @@
 
 (** tuple processing **)
 
+fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
+  | rewrite_args (arg::args) (pats, intro_t, ctxt) = 
+    (case HOLogic.strip_tupleT (fastype_of arg) of
+      (Ts as _ :: _ :: _) =>
+      let
+        fun rewrite_arg' (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
+          (args, (pats, intro_t, ctxt)) = rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
+          | rewrite_arg' (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
+            let
+              val thy = ProofContext.theory_of ctxt
+              val ([x, y], ctxt') = Variable.variant_fixes ["x", "y"] ctxt
+              val pat = (t, HOLogic.mk_prod (Free (x, T1), Free (y, T2)))
+              val intro_t' = Pattern.rewrite_term thy [pat] [] intro_t
+              val args' = map (Pattern.rewrite_term thy [pat] []) args
+            in
+              rewrite_arg' (Free (y, T2), T2) (args', (pat::pats, intro_t', ctxt'))
+            end
+          | rewrite_arg' _ (args, (pats, intro_t, ctxt)) = (args, (pats, intro_t, ctxt))
+        val (args', (pats, intro_t', ctxt')) = rewrite_arg' (arg, fastype_of arg)
+          (args, (pats, intro_t, ctxt))
+      in
+        rewrite_args args' (pats, intro_t', ctxt')
+      end
+  | _ => rewrite_args args (pats, intro_t, ctxt))
+
+fun rewrite_prem atom =
+  let
+    val (_, args) = strip_comb atom
+  in rewrite_args args end
+
 fun expand_tuples thy intro =
   let
-    fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
-      | rewrite_args (arg::args) (pats, intro_t, ctxt) = 
-      (case HOLogic.strip_tupleT (fastype_of arg) of
-        (Ts as _ :: _ :: _) =>
-        let
-          fun rewrite_arg' (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
-            (args, (pats, intro_t, ctxt)) = rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
-            | rewrite_arg' (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
-              let
-                val ([x, y], ctxt') = Variable.variant_fixes ["x", "y"] ctxt
-                val pat = (t, HOLogic.mk_prod (Free (x, T1), Free (y, T2)))
-                val intro_t' = Pattern.rewrite_term thy [pat] [] intro_t
-                val args' = map (Pattern.rewrite_term thy [pat] []) args
-              in
-                rewrite_arg' (Free (y, T2), T2) (args', (pat::pats, intro_t', ctxt'))
-              end
-            | rewrite_arg' _ (args, (pats, intro_t, ctxt)) = (args, (pats, intro_t, ctxt))
-          val (args', (pats, intro_t', ctxt')) = rewrite_arg' (arg, fastype_of arg)
-            (args, (pats, intro_t, ctxt))
-        in
-          rewrite_args args' (pats, intro_t', ctxt')
-        end
-      | _ => rewrite_args args (pats, intro_t, ctxt))
-    fun rewrite_prem atom =
-      let
-        val (_, args) = strip_comb atom
-      in rewrite_args args end
     val ctxt = ProofContext.init_global thy
     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
     val intro_t = prop_of intro'
@@ -842,6 +848,68 @@
     intro'''''
   end
 
+(*** conversions ***)
+
+fun imp_prems_conv cv ct =
+  case Thm.term_of ct of
+    Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
+  | _ => Conv.all_conv ct
+
+fun all_params_conv cv ctxt ct =
+  if Logic.is_all (Thm.term_of ct)
+  then Conv.arg_conv (Conv.abs_conv (all_params_conv cv o #2) ctxt) ct
+  else cv ctxt ct;
+  
+fun expand_tuples_elim ctxt elimrule =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val ((_, [elimrule]), ctxt1) = Variable.import false [elimrule] ctxt
+    val prems = Thm.prems_of elimrule
+    val nargs = length (snd (strip_comb (HOLogic.dest_Trueprop (hd prems))))
+    fun preprocess_case t =
+      let
+        val (param_names, param_Ts)  = split_list (Logic.strip_params t)
+        val prop = Logic.list_implies (Logic.strip_assums_hyp t, Logic.strip_assums_concl t)
+        val (free_names, ctxt2) = Variable.variant_fixes param_names ctxt1
+        val frees = map Free (free_names ~~ param_Ts)
+        val prop' = subst_bounds (rev frees, prop)
+        val (eqs, prems) = chop nargs (Logic.strip_imp_prems prop')
+        val rhss = map (snd o HOLogic.dest_eq o HOLogic.dest_Trueprop) eqs
+        val (pats, prop'', ctxt2) = fold 
+          rewrite_prem (map HOLogic.dest_Trueprop prems)
+            (rewrite_args rhss ([], prop', ctxt2)) 
+        val new_frees = fold Term.add_frees (frees @ map snd pats) [] (* FIXME: frees are not minimal and not ordered *)
+      in
+        fold Logic.all (map Free new_frees) prop''
+      end
+    val cases' = map preprocess_case (tl prems)
+    val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
+    val tac = (fn _ => Skip_Proof.cheat_tac thy)
+    val eq = Goal.prove ctxt1 [] [] (Logic.mk_equals ((Thm.prop_of elimrule), elimrule')) tac
+    val exported_elimrule' = Thm.equal_elim eq elimrule |> singleton (Variable.export ctxt1 ctxt)
+    val elimrule'' = Conv.fconv_rule (imp_prems_conv (all_params_conv (fn ctxt => Conv.concl_conv nargs 
+      (Simplifier.full_rewrite
+        (HOL_basic_ss addsimps [@{thm fst_conv}, @{thm snd_conv}, @{thm Pair_eq}]))) ctxt1)) 
+      exported_elimrule'
+    (* splitting conjunctions introduced by Pair_eq*)
+    fun split_conj prem =
+      map HOLogic.mk_Trueprop (conjuncts (HOLogic.dest_Trueprop prem))
+    fun map_cases f t =
+      let
+        val (prems, concl) = Logic.strip_horn t
+        val ([pred], prems') = chop 1 prems
+        fun map_params f t =
+          let
+            val prop = Logic.list_implies (Logic.strip_assums_hyp t, Logic.strip_assums_concl t)
+          in Term.list_all (Logic.strip_params t, f prop) end 
+        val prems'' = map (map_params f) prems'
+      in
+        Logic.list_implies (pred :: prems'', concl)
+      end
+    val elimrule''' = map_term thy (map_cases (maps_premises split_conj)) elimrule''
+   in
+     elimrule'''
+  end
 (** eta contract higher-order arguments **)
 
 fun eta_contract_ho_arguments thy intro =
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Sep 23 14:50:18 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Sep 23 17:22:44 2010 +0200
@@ -218,6 +218,7 @@
 datatype pred_data = PredData of {
   intros : (string option * thm) list,
   elim : thm option,
+  preprocessed : bool,
   function_names : (compilation * (mode * string) list) list,
   predfun_data : (mode * predfun_data) list,
   needs_random : mode list
@@ -225,12 +226,12 @@
 
 fun rep_pred_data (PredData data) = data;
 
-fun mk_pred_data ((intros, elim), (function_names, (predfun_data, needs_random))) =
-  PredData {intros = intros, elim = elim,
+fun mk_pred_data (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))) =
+  PredData {intros = intros, elim = elim, preprocessed = preprocessed,
     function_names = function_names, predfun_data = predfun_data, needs_random = needs_random}
 
-fun map_pred_data f (PredData {intros, elim, function_names, predfun_data, needs_random}) =
-  mk_pred_data (f ((intros, elim), (function_names, (predfun_data, needs_random))))
+fun map_pred_data f (PredData {intros, elim, preprocessed, function_names, predfun_data, needs_random}) =
+  mk_pred_data (f (((intros, elim), preprocessed), (function_names, (predfun_data, needs_random))))
 
 fun eq_option eq (NONE, NONE) = true
   | eq_option eq (SOME x, SOME y) = eq (x, y)
@@ -613,23 +614,20 @@
 
 (** preprocessing rules **)
 
-fun imp_prems_conv cv ct =
-  case Thm.term_of ct of
-    Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
-  | _ => Conv.all_conv ct
-
 fun Trueprop_conv cv ct =
   case Thm.term_of ct of
     Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct  
   | _ => raise Fail "Trueprop_conv"
 
-fun preprocess_intro thy rule =
+fun preprocess_equality thy rule =
   Conv.fconv_rule
     (imp_prems_conv
       (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
     (Thm.transfer thy rule)
 
-fun preprocess_elim ctxt elimrule =
+fun preprocess_intro thy = expand_tuples thy #> preprocess_equality thy
+
+fun preprocess_equality_elim ctxt elimrule =
   let
     fun replace_eqs (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, T) $ lhs $ rhs)) =
        HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
@@ -640,11 +638,11 @@
     val nargs = length (snd (strip_comb (HOLogic.dest_Trueprop (hd prems))))
     fun preprocess_case t =
       let
-       val params = Logic.strip_params t
-       val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
-       val assums_hyp' = assums1 @ (map replace_eqs assums2)
+        val params = Logic.strip_params t
+        val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
+        val assums_hyp' = assums1 @ (map replace_eqs assums2)
       in
-       list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t))
+        list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t))
       end
     val cases' = map preprocess_case (tl prems)
     val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
@@ -657,6 +655,8 @@
     Thm.equal_elim eq elimrule |> singleton (Variable.export ctxt' ctxt)
   end;
 
+fun preprocess_elim ctxt = expand_tuples_elim ctxt #> preprocess_equality_elim ctxt
+  
 val no_compilation = ([], ([], []))
 
 fun fetch_pred_data ctxt name =
@@ -668,17 +668,15 @@
             val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
           in (fst (dest_Const const) = name) end;
         val thy = ProofContext.theory_of ctxt
-        val intros =
-          (map (expand_tuples thy #> preprocess_intro thy) (filter is_intro_of (#intrs result)))
+        val intros = map (preprocess_intro thy) (filter is_intro_of (#intrs result))
         val index = find_index (fn s => s = name) (#names (fst info))
         val pre_elim = nth (#elims result) index
         val pred = nth (#preds result) index
+        val elim_t = mk_casesrule ctxt pred intros
         val nparams = length (Inductive.params_of (#raw_induct result))
-        val elim_t = mk_casesrule ctxt pred intros
-        val elim =
-          prove_casesrule ctxt (pred, (pre_elim, nparams)) elim_t
+        val elim = prove_casesrule ctxt (pred, (pre_elim, nparams)) elim_t
       in
-        mk_pred_data ((map (pair NONE) intros, SOME elim), no_compilation)
+        mk_pred_data (((map (pair NONE) intros, SOME elim), true), no_compilation)
       end
   | NONE => error ("No such predicate: " ^ quote name)
 
@@ -695,6 +693,10 @@
     val intros = map (Thm.prop_of o snd) ((#intros o rep_pred_data) value)
   in
     fold Term.add_const_names intros []
+      |> (fn cs =>
+        if member (op =) cs @{const_name "HOL.eq"} then
+          insert (op =) @{const_name Predicate.eq} cs
+        else cs)
       |> filter (fn c => (not (c = key)) andalso
         (is_inductive_predicate ctxt c orelse is_registered ctxt c))
   end;
@@ -705,26 +707,24 @@
     fun cons_intro gr =
      case try (Graph.get_node gr) name of
        SOME pred_data => Graph.map_node name (map_pred_data
-         (apfst (apfst (fn intros => intros @ [(opt_case_name, thm)])))) gr
-     | NONE => Graph.new_node (name, mk_pred_data (([(opt_case_name, thm)], NONE), no_compilation)) gr
+         (apfst (apfst (apfst (fn intros => intros @ [(opt_case_name, thm)]))))) gr
+     | NONE => Graph.new_node (name, mk_pred_data ((([(opt_case_name, thm)], NONE), false), no_compilation)) gr
   in PredData.map cons_intro thy end
 
 fun set_elim thm =
   let
     val (name, _) = dest_Const (fst 
       (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm)))))
-    fun set (intros, _) = (intros, SOME thm)
-  in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
+  in PredData.map (Graph.map_node name (map_pred_data (apfst (apfst (apsnd (K (SOME thm))))))) end
 
-fun register_predicate (constname, pre_intros, pre_elim) thy =
+fun register_predicate (constname, intros, elim) thy =
   let
-    val intros = map (pair NONE o preprocess_intro thy) pre_intros
-    val elim = preprocess_elim (ProofContext.init_global thy) pre_elim
+    val named_intros = map (pair NONE) intros
   in
     if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
       PredData.map
         (Graph.new_node (constname,
-          mk_pred_data ((intros, SOME elim), no_compilation))) thy
+          mk_pred_data (((named_intros, SOME elim), false), no_compilation))) thy
     else thy
   end
 
@@ -795,7 +795,7 @@
     val alt_compilations = map (apsnd fst) compilations
   in
     PredData.map (Graph.new_node
-      (pred_name, mk_pred_data (([], SOME @{thm refl}), (dummy_function_names, ([], needs_random)))))
+      (pred_name, mk_pred_data ((([], SOME @{thm refl}), true), (dummy_function_names, ([], needs_random)))))
     #> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations))
   end
 
@@ -2911,8 +2911,26 @@
     fun strong_conn_of gr keys =
       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
     val scc = strong_conn_of (PredData.get thy') names
-    
-    val thy'' = fold_rev
+    fun preprocess name thy =
+      PredData.map (Graph.map_node name (map_pred_data (apfst (fn (rules, preprocessed) =>
+        if preprocessed then (rules, preprocessed)
+        else
+          let
+            val (named_intros, SOME elim) = rules
+            val named_intros' = map (apsnd (preprocess_intro thy)) named_intros
+            val pred = Const (name, Sign.the_const_type thy name)
+            val ctxt = ProofContext.init_global thy
+            val elim_t = mk_casesrule ctxt pred (map snd named_intros')
+            val nparams = (case try (Inductive.the_inductive ctxt) name of
+                SOME (_, result) => length (Inductive.params_of (#raw_induct result))
+              | NONE => 0)
+            val elim' = prove_casesrule ctxt (pred, (elim, 0)) elim_t
+          in
+            ((named_intros', SOME elim'), true)
+          end))))
+        thy
+    val thy'' = fold preprocess (flat scc) thy'
+    val thy''' = fold_rev
       (fn preds => fn thy =>
         if not (forall (defined (ProofContext.init_global thy)) preds) then
           let
@@ -2924,8 +2942,8 @@
             add_equations_of steps mode_analysis_options options preds thy
           end
         else thy)
-      scc thy' |> Theory.checkpoint
-  in thy'' end
+      scc thy'' |> Theory.checkpoint
+  in thy''' end
 
 val add_equations = gen_add_equations
   (Steps {