added first support for higher-order function translation
authorbulwahn
Sat, 24 Oct 2009 16:55:42 +0200
changeset 33122 7d01480cc8e3
parent 33121 9b10dc5da0e0
child 33123 3c7c4372f9ad
added first support for higher-order function translation
src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
--- a/src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -5,7 +5,7 @@
 
 signature PREDICATE_COMPILE_FUN =
 sig
-  val define_predicates : (string * thm list) list -> theory -> theory
+val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
   val rewrite_intro : theory -> thm -> thm list
   val setup_oracle : theory -> theory
   val pred_of_function : theory -> string -> string option
@@ -104,20 +104,25 @@
 
 fun mk_param lookup_pred (t as Free (v, _)) = lookup_pred t
   | mk_param lookup_pred t =
-  let
-    val (vs, body) = strip_abs t
-    val names = Term.add_free_names body []
-    val vs_names = Name.variant_list names (map fst vs)
-    val vs' = map2 (curry Free) vs_names (map snd vs)
-    val body' = subst_bounds (rev vs', body)
-    val (f, args) = strip_comb body'
-    val resname = Name.variant (vs_names @ names) "res"
-    val resvar = Free (resname, body_type (fastype_of body'))
-    val P = lookup_pred f
-    val pred_body = list_comb (P, args @ [resvar])
-    val param = fold_rev lambda (vs' @ [resvar]) pred_body
-  in param end;
-
+  if Predicate_Compile_Aux.is_predT (fastype_of t) then
+    t
+  else
+    error "not implemented"
+  (*  
+    let
+      val (vs, body) = strip_abs t
+      val names = Term.add_free_names body []
+      val vs_names = Name.variant_list names (map fst vs)
+      val vs' = map2 (curry Free) vs_names (map snd vs)
+      val body' = subst_bounds (rev vs', body)
+      val (f, args) = strip_comb body'
+      val resname = Name.variant (vs_names @ names) "res"
+      val resvar = Free (resname, body_type (fastype_of body'))
+      val P = lookup_pred f
+      val pred_body = list_comb (P, args @ [resvar])
+      val param = fold_rev lambda (vs' @ [resvar]) pred_body
+    in param end;
+  *)
 
 (* creates the list of premises for every intro rule *)
 (* theory -> term -> (string list, term list list) *)
@@ -217,6 +222,10 @@
       else [(lookup_pred t, (names, prems))]
     | mk_prems' (t as Free (f, T)) (names, prems) = 
       [(lookup_pred t, (names, prems))]
+    | mk_prems' (t as Abs _) (names, prems) =
+      if Predicate_Compile_Aux.is_predT (fastype_of t) then
+      [(t, (names, prems))] else error "mk_prems': Abs "
+      (* mk_param *)
     | mk_prems' t (names, prems) =
       if Predicate_Compile_Aux.is_constrt thy t then
         [(t, (names, prems))]
@@ -288,14 +297,7 @@
                      in (names', prem :: prems') end)
                 end
             | mk_prems'' t =
-                let
-                  val _ = tracing ("must define new constant for "
-                    ^ (Syntax.string_of_term_global thy t))
-                in 
-                  (*if is_predT (fastype_of t) then
-                  else*)
-                  error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
-                end
+              error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
           in
             map (pair resvar) (mk_prems'' f)
           end
@@ -306,7 +308,7 @@
 (* assumption: mutual recursive predicates all have the same parameters. *)  
 fun define_predicates specs thy =
   if forall (fn (const, _) => member (op =) (Symtab.keys (Pred_Compile_Preproc.get thy)) const) specs then
-    thy
+    ([], thy)
   else
   let
     val consts = map fst specs
@@ -363,9 +365,10 @@
     fun mk_rewr_thm (func, pred) = @{thm refl}
   in
     case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
-      NONE => thy 
+      NONE => ([], thy) 
     | SOME intr_ts => let
-        val _ = map (tracing o (Syntax.string_of_term_global thy)) intr_ts      
+        val _ = map (tracing o (Syntax.string_of_term_global thy)) intr_ts
+        val _ = map (cterm_of thy) intr_ts
       in
         if is_some (try (map (cterm_of thy)) intr_ts) then
           let
@@ -381,9 +384,18 @@
             val prednames = map (fst o dest_Const) (#preds ind_result)
             (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
             (* add constants to my table *)
-          in Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy' end
+            val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname) (#intrs ind_result))) prednames
+            val thy'' = Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy'
+          in
+            (specs, thy'')
+          end
         else
-          thy
+          let
+            val (p, _) = strip_comb (HOLogic.dest_Trueprop (hd (Logic.strip_imp_prems (hd intr_ts))))
+            val (_, T) = dest_Const p
+            val _ = tracing (Syntax.string_of_typ_global thy T)
+            val _ = tracing "Introduction rules of function_predicate are not welltyped"
+          in ([], thy) end
       end
   end
 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -11,7 +11,7 @@
 struct
 
 (* options *)
-val fail_safe_mode = true
+val fail_safe_mode = false
 
 open Predicate_Compile_Aux;
 
@@ -94,6 +94,10 @@
     (space_implode "; " (map 
       (fn intros => commas (map (Display.string_of_thm_global thy) intros)) intross)))
 
+fun print_specs thy specs =
+  map (fn (c, thms) => "Constant " ^ c ^ " has specification:\n"
+    ^ (space_implode "\n" (map (Display.string_of_thm_global thy) thms)) ^ "\n") specs
+
 fun process_specification specs thy' =
   let
   val specs = map (apsnd (map
@@ -101,7 +105,6 @@
     val (intross1, thy'') = apfst flat (fold_map Predicate_Compile_Pred.preprocess specs thy')
     val _ = print_intross thy'' "Flattened introduction rules: " intross1
     val _ = priority "Replacing functions in introrules..."
-      (*  val _ = burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross  *)
     val intross2 =
       if fail_safe_mode then
         case try (burrow (maps (Predicate_Compile_Fun.rewrite_intro thy''))) intross1 of
@@ -131,11 +134,12 @@
     (* untangle recursion by defining predicates for all functions *)
     val _ = priority "Compiling functions to predicates..."
     val _ = Output.tracing ("funnames: " ^ commas funnames)
-    val thy' =
-      thy |> not (null funnames) ? Predicate_Compile_Fun.define_predicates
-      (get_specs funnames)
+    val (fun_pred_specs, thy') =
+      if not (null funnames) then Predicate_Compile_Fun.define_predicates
+      (get_specs funnames) thy else ([], thy)
+    val _ = print_specs thy' fun_pred_specs
     val _ = priority "Compiling predicates to flat introrules..."
-    val specs = (get_specs prednames) 
+    val specs = (get_specs prednames) @ fun_pred_specs
     val (intross3, thy''') = process_specification specs thy'
     val _ = print_intross thy''' "Introduction rules with new constants: " intross3
     val intross4 = map (maps remove_pointless_clauses) intross3