restructuring function flattening
authorbulwahn
Mon, 22 Mar 2010 08:30:12 +0100
changeset 35875 b0d24a74b06b
parent 35874 bcfa6b4b21c6
child 35876 ac44e2312f0a
restructuring function flattening
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 22 08:30:12 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Mon Mar 22 08:30:12 2010 +0100
@@ -111,7 +111,8 @@
       val intross5 =
         map (fn (s, ths) => (overload_const thy''' s, map (AxClass.overload thy''') ths)) intross4
       val intross6 = map_specs (map (expand_tuples thy''')) intross5
-      val _ = print_intross options thy''' "introduction rules before registering: " intross6
+      val intross7 = map_specs (map (eta_contract_ho_arguments thy''')) intross6
+      val _ = print_intross options thy''' "introduction rules before registering: " intross7
       val _ = print_step options "Registering introduction rules..."
       val thy'''' = fold Predicate_Compile_Core.register_intros intross6 thy'''
     in
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 22 08:30:12 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 22 08:30:12 2010 +0100
@@ -399,6 +399,18 @@
     Logic.list_implies (maps f premises, head)
   end
 
+fun map_concl f intro =
+  let
+    val (premises, head) = Logic.strip_horn intro
+  in
+    Logic.list_implies (premises, f head)
+  end
+
+(* combinators to apply a function to all basic parts of nested products *)
+
+fun map_products f (Const ("Pair", T) $ t1 $ t2) =
+  Const ("Pair", T) $ map_products f t1 $ map_products f t2
+  | map_products f t = f t
 
 (* split theorems of case expressions *)
 
@@ -619,4 +631,15 @@
     intro'''''
   end
 
+(* eta contract higher-order arguments *)
+
+
+fun eta_contract_ho_arguments thy intro =
+  let
+    fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom))
+  in
+    map_term thy (map_concl f o map_atoms f) intro
+  end
+
+
 end;
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Mon Mar 22 08:30:12 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Mon Mar 22 08:30:12 2010 +0100
@@ -37,7 +37,8 @@
     in
       SOME (Envir.subst_term subst p)
     end
-  | _ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t)
+  | _ => NONE
+  (*_ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t)*)
 
 fun pred_of_function thy name =
   case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT)) of
@@ -119,9 +120,8 @@
     SOME (c, _) => Predicate_Compile_Data.keep_function thy c
   | _ => false
 
-fun flatten thy lookup_pred t (names, prems) =
-  let
-    fun flatten' (t as Const (name, T)) (names, prems) =
+(* dump:
+fun flatten' (t as Const (name, T)) (names, prems) =
       (if is_constr thy name orelse (is_none (lookup_pred t)) then
         [(t, (names, prems))]
       else
@@ -163,7 +163,55 @@
         in
           [(t, (names, prems))]
         end
-    | flatten' t (names, prems) =
+*)
+
+fun flatten thy lookup_pred t (names, prems) =
+  let
+    fun lift t (names, prems) =
+      case lookup_pred (Envir.eta_contract t) of
+        SOME pred => [(pred, (names, prems))]
+      | NONE =>
+        let
+          val (vars, body) = strip_abs t
+          val _ = assert (fastype_of body = body_type (fastype_of body))
+          val absnames = Name.variant_list names (map fst vars)
+          val frees = map2 (curry Free) absnames (map snd vars)
+          val body' = subst_bounds (rev frees, body)
+          val resname = Name.variant (absnames @ names) "res"
+          val resvar = Free (resname, fastype_of body)
+          val t = flatten' body' ([], [])
+            |> map (fn (res, (inner_names, inner_prems)) =>
+              let
+                fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t)
+                val vTs = 
+                  fold Term.add_frees inner_prems []
+                  |> filter (fn (x, T) => member (op =) inner_names x)
+                val t = 
+                  fold mk_exists vTs
+                  (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (res, resvar) ::
+                    map HOLogic.dest_Trueprop inner_prems))
+              in
+                t
+              end)
+              |> foldr1 HOLogic.mk_disj
+              |> fold lambda (resvar :: rev frees)
+        in
+          [(t, (names, prems))]
+        end
+    and flatten_or_lift (t, T) (names, prems) =
+      if fastype_of t = T then
+        flatten' t (names, prems)
+      else
+        (* note pred_type might be to general! *)
+        if (pred_type (fastype_of t) = T) then
+          lift t (names, prems)
+        else
+          error ("unexpected input for flatten or lift" ^ Syntax.string_of_term_global thy t ^
+          ", " ^  Syntax.string_of_typ_global thy T)
+    and flatten' (t as Const (name, T)) (names, prems) = [(t, (names, prems))]
+      | flatten' (t as Free (f, T)) (names, prems) = [(t, (names, prems))]
+      | flatten' (t as Abs _) (names, prems) = [(t, (names, prems))]
+      | flatten' (t as _ $ _) (names, prems) =
       if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then
         [(t, (names, prems))]
       else
@@ -172,11 +220,14 @@
             (let
               val (_, [B, x, y]) = strip_comb t
             in
-              (flatten' x (names, prems)
-              |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B) :: prems))))
-              @ (flatten' y (names, prems)
-              |> map (fn (res, (names, prems)) =>
-                (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B)) :: prems))))
+              flatten' B (names, prems)
+              |> maps (fn (B', (names, prems)) =>
+                (flatten' x (names, prems)
+                |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B') :: prems))))
+                @ (flatten' y (names, prems)
+                |> map (fn (res, (names, prems)) =>
+                  (* in general unsound! *)
+                  (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B')) :: prems)))))
             end)
         | Const (@{const_name "Let"}, _) => 
             (let
@@ -232,57 +283,47 @@
         else
           let
             val (f, args) = strip_comb t
-            (* TODO: special procedure for higher-order functions: split arguments in
-              simple types and function types *)
             val args = map (Pattern.eta_long []) args
-            val resname = Name.variant names "res"
-            val resvar = Free (resname, body_type (fastype_of t))
             val _ = assert (fastype_of t = body_type (fastype_of t))
-            val names' = resname :: names
-            val t' = lookup_pred f
-            val Ts = case t' of
+            val f' = lookup_pred f
+            val Ts = case f' of
               SOME pred => (fst (split_last (binder_types (fastype_of pred))))
-            | NONE => binder_types (fastype_of t)
-            val namesprems =
-              case t' of
-                NONE =>
-                  folds_map flatten' args (names', prems) |>
-                  map
-                    (fn (argvs, (names'', prems')) =>
-                    let
-                      val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
-                    in (names'', prem :: prems') end)
-              | SOME pred =>
-                  folds_map flatten' args (names', prems)
-                  |> map (fn (argvs, (names'', prems')) =>
-                    let
-                      fun lift_arg T t =
-                        if (fastype_of t) = T then t
-                        else
-                          let
-                            val _ = assert (T =
-                              (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool}))
-                            fun mk_if T (b, t, e) =
-                              Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
-                            val Ts = binder_types (fastype_of t)
-                            val t = 
-                            list_abs (map (pair "x") Ts @ [("b", @{typ bool})],
-                              mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
-                              HOLogic.mk_eq (@{term True}, Bound 0),
-                              HOLogic.mk_eq (@{term False}, Bound 0)))
-                          in
-                            t
-                          end
-                      (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts))
-                      val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*)
-                      val argvs' = map2 lift_arg Ts argvs
-                      val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar]))
-                    in (names'', prem :: prems') end)
+            | NONE => binder_types (fastype_of f)
           in
-            map (pair resvar) namesprems
+            folds_map flatten_or_lift (args ~~ Ts) (names, prems) |>
+            (case f' of
+              NONE =>
+                map (fn (argvs, (names', prems')) => (list_comb (f, argvs), (names', prems')))
+            | SOME pred =>
+                map (fn (argvs, (names', prems')) =>
+                  let
+                    fun lift_arg T t =
+                      if (fastype_of t) = T then t
+                      else
+                        let
+                          val _ = assert (T =
+                            (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool}))
+                          fun mk_if T (b, t, e) =
+                            Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
+                          val Ts = binder_types (fastype_of t)
+                          val t = 
+                          list_abs (map (pair "x") Ts @ [("b", @{typ bool})],
+                            mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
+                            HOLogic.mk_eq (@{term True}, Bound 0),
+                            HOLogic.mk_eq (@{term False}, Bound 0)))
+                        in
+                          t
+                        end
+                    (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts))
+                    val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*)
+                    val argvs' = map2 lift_arg Ts argvs
+                    val resname = Name.variant names' "res"
+                    val resvar = Free (resname, body_type (fastype_of t))
+                    val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar]))
+                  in (resvar, (resname :: names', prem :: prems')) end))
           end
   in
-    flatten' (Pattern.eta_long [] t) (names, prems)
+    map (apfst Envir.eta_contract) (flatten' (Pattern.eta_long [] t) (names, prems))
   end;
 
 (* assumption: mutual recursive predicates all have the same parameters. *)  
@@ -373,12 +414,6 @@
   let
     (*val _ = tracing ("Rewriting intro with registered mapping for: " ^
       commas (Symtab.keys (Pred_Compile_Preproc.get thy)))*)
-    (*fun lookup_pred (Const (name, T)) =
-      (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
-        SOME c => SOME (Const (c, pred_type T))
-      | NONE => NONE)
-    | lookup_pred _ = NONE
-    *)
     fun lookup_pred t = lookup thy (Fun_Pred.get thy) t
     val intro_t = Logic.unvarify_global (prop_of intro)
     val (prems, concl) = Logic.strip_horn intro_t