added to process higher-order arguments by adding new constants
authorbulwahn
Sat, 24 Oct 2009 16:55:42 +0200
changeset 33121 9b10dc5da0e0
parent 33120 ca77d8c34ce2
child 33122 7d01480cc8e3
added to process higher-order arguments by adding new constants
src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
--- a/src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -133,6 +133,56 @@
   end;
 
 fun preprocess_term t thy = error "preprocess_pred_term: to implement" 
-  
-  
+
+fun is_Abs (Abs _) = true
+  | is_Abs _       = false
+
+fun flat_higher_order_arguments (intross, thy) =
+  let
+    fun process constname atom (new_defs, thy) =
+      let
+        val (pred, args) = strip_comb atom
+        val abs_args = filter is_Abs args
+        fun replace_abs_arg (abs_arg as Abs _ ) (new_defs, thy) =
+          let
+            val _ = tracing ("Introduce new constant for " ^
+              Syntax.string_of_term_global thy abs_arg)
+            val vars = map Var (Term.add_vars abs_arg [])
+            val abs_arg' = Logic.unvarify abs_arg
+            val frees = map Free (Term.add_frees abs_arg' [])
+            val constname = Name.variant []
+              ((Long_Name.base_name constname) ^ "_hoaux")
+            val full_constname = Sign.full_bname thy constname
+            val constT = map fastype_of frees ---> (fastype_of abs_arg')
+            val const = Const (full_constname, constT)
+            val lhs = list_comb (const, frees)
+            val def = Logic.mk_equals (lhs, abs_arg')
+            val _ = tracing (Syntax.string_of_term_global thy def)
+            val ([definition], thy') = thy
+              |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
+              |> PureThy.add_defs false [((Binding.name (constname ^ "_def"), def), [])]
+            
+          in
+            (list_comb (Logic.varify const, vars), ((full_constname, [definition])::new_defs, thy'))
+          end
+        | replace_abs_arg arg (new_defs, thy) = (arg, (new_defs, thy))
+
+      val (args', (new_defs', thy')) = fold_map replace_abs_arg args (new_defs, thy)
+            (*        val _ = if not (null abs_args) then error "Found some abs argument" else ()*)
+      in
+        (list_comb (pred, args'), (new_defs', thy'))
+      end
+    fun flat_intro intro (new_defs, thy) =
+      let
+        val constname = "dummy"
+        val (intro_ts, (new_defs, thy)) = fold_map_atoms (process constname) (prop_of intro) (new_defs, thy)
+        val th = setmp quick_and_dirty true (SkipProof.make_thm thy) intro_ts
+      in
+        (th, (new_defs, thy))
+      end
+    val (intross', (new_defs, thy')) = fold_map (fold_map flat_intro) intross ([], thy)
+  in
+    (intross', (new_defs, thy'))
+  end
+
 end;
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -89,13 +89,40 @@
     []
   else [intro]
 
-
 fun print_intross thy msg intross =
   tracing (msg ^ 
     (space_implode "; " (map 
       (fn intros => commas (map (Display.string_of_thm_global thy) intros)) intross)))
 
-  
+fun process_specification specs thy' =
+  let
+  val specs = map (apsnd (map
+  (fn th => if is_equationlike th then Pred_Compile_Data.normalize_equation thy' th else th))) specs
+    val (intross1, thy'') = apfst flat (fold_map Predicate_Compile_Pred.preprocess specs thy')
+    val _ = print_intross thy'' "Flattened introduction rules: " intross1
+    val _ = priority "Replacing functions in introrules..."
+      (*  val _ = burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross  *)
+    val intross2 =
+      if fail_safe_mode then
+        case try (burrow (maps (Predicate_Compile_Fun.rewrite_intro thy''))) intross1 of
+          SOME intross => intross
+        | NONE => let val _ = warning "Function replacement failed!" in intross1 end
+      else burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross1
+    val _ = print_intross thy'' "Introduction rules with replaced functions: " intross2
+    val _ = priority "Introducing new constants for abstractions at higher-order argument positions..."
+    val (intross3, (new_defs, thy''')) = Predicate_Compile_Pred.flat_higher_order_arguments (intross2, thy'')
+    val _ = tracing ("Now derive introduction rules for new_defs: "
+        ^ space_implode "\n" 
+        (map (fn (c, ths) => c ^ ": " ^ 
+    commas (map (Display.string_of_thm_global thy''') ths)) new_defs))
+  val (new_intross, thy'''')  = if not (null new_defs) then
+    process_specification new_defs thy'''
+    else ([], thy''')
+  in
+    (intross3 @ new_intross, thy'''')
+  end
+
+
 fun preprocess_strong_conn_constnames gr constnames thy =
   let
     val get_specs = map (fn k => (k, Graph.get_node gr k))
@@ -108,27 +135,18 @@
       thy |> not (null funnames) ? Predicate_Compile_Fun.define_predicates
       (get_specs funnames)
     val _ = priority "Compiling predicates to flat introrules..."
-    val (intross1, thy'') = apfst flat (fold_map Predicate_Compile_Pred.preprocess
-      (get_specs prednames) thy')
-    val _ = print_intross thy'' "Flattened introduction rules: " intross1
-    val _ = priority "Replacing functions in introrules..."
-      (*  val _ = burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross  *)
-    val intross2 =
-      if fail_safe_mode then
-        case try (burrow (maps (Predicate_Compile_Fun.rewrite_intro thy''))) intross1 of
-          SOME intross => intross
-        | NONE => let val _ = warning "Function replacement failed!" in intross1 end
-      else burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross1
-    val _ = print_intross thy'' "Introduction rules with replaced functions: " intross2
-    val intross3 = map (maps remove_pointless_clauses) intross2
-    val _ = print_intross thy'' "After removing pointless clauses: " intross3
-    val intross4 = burrow (map (AxClass.overload thy'')) intross3
-    val intross5 = burrow (map (simplify_fst_snd o expand_tuples thy'')) intross4
-    val _ = print_intross thy'' "introduction rules before registering: " intross5
+    val specs = (get_specs prednames) 
+    val (intross3, thy''') = process_specification specs thy'
+    val _ = print_intross thy''' "Introduction rules with new constants: " intross3
+    val intross4 = map (maps remove_pointless_clauses) intross3
+    val _ = print_intross thy''' "After removing pointless clauses: " intross4
+    val intross5 = burrow (map (AxClass.overload thy''')) intross4
+    val intross6 = burrow (map (simplify_fst_snd o expand_tuples thy''')) intross5
+    val _ = print_intross thy''' "introduction rules before registering: " intross6
     val _ = priority "Registering intro rules..."
-    val thy''' = fold Predicate_Compile_Core.register_intros intross5 thy''
+    val thy'''' = fold Predicate_Compile_Core.register_intros intross6 thy'''
   in
-    thy'''
+    thy''''
   end;
 
 fun preprocess const thy =