cleaning the function flattening
authorbulwahn
Mon, 22 Mar 2010 08:30:13 +0100
changeset 35878 74a74828d682
parent 35877 295e1af6c8dc
child 35879 99818df5b8f5
cleaning the function flattening
src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Mon Mar 22 08:30:13 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Mon Mar 22 08:30:13 2010 +0100
@@ -9,7 +9,6 @@
   val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
   val rewrite_intro : theory -> thm -> thm list
   val pred_of_function : theory -> string -> string option
-  
   val add_function_predicate_translation : (term * term) -> theory -> theory
 end;
 
@@ -38,7 +37,6 @@
       SOME (Envir.subst_term subst p)
     end
   | _ => NONE
-  (*_ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t)*)
 
 fun pred_of_function thy name =
   case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT)) of
@@ -62,8 +60,8 @@
     (T as Type ("fun", _)) =>
       (case arg of
         Free (name, _) => Free (name, transform_ho_typ T)
-      | _ => error "I am surprised")
-| _ => arg
+      | _ => raise Fail "A non-variable term at a higher-order position")
+  | _ => arg
 
 fun pred_type T =
   let
@@ -120,51 +118,6 @@
     SOME (c, _) => Predicate_Compile_Data.keep_function thy c
   | _ => false
 
-(* dump:
-fun flatten' (t as Const (name, T)) (names, prems) =
-      (if is_constr thy name orelse (is_none (lookup_pred t)) then
-        [(t, (names, prems))]
-      else
-       (*(if is_none (try lookup_pred t) then
-          [(Abs ("uu", fastype_of t, HOLogic.mk_eq (t, Bound 0)), (names, prems))]
-        else*) [(the (lookup_pred t), (names, prems))])
-    | flatten' (t as Free (f, T)) (names, prems) = 
-      (case lookup_pred t of
-        SOME t' => [(t', (names, prems))]
-      | NONE => [(t, (names, prems))])
-    | flatten' (t as Abs _) (names, prems) =
-      if Predicate_Compile_Aux.is_predT (fastype_of t) then
-        ([(Envir.eta_contract t, (names, prems))])
-      else
-        let
-          val (vars, body) = strip_abs t
-          val _ = assert (fastype_of body = body_type (fastype_of body))
-          val absnames = Name.variant_list names (map fst vars)
-          val frees = map2 (curry Free) absnames (map snd vars)
-          val body' = subst_bounds (rev frees, body)
-          val resname = Name.variant (absnames @ names) "res"
-          val resvar = Free (resname, fastype_of body)
-          val t = flatten' body' ([], [])
-            |> map (fn (res, (inner_names, inner_prems)) =>
-              let
-                fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t)
-                val vTs = 
-                  fold Term.add_frees inner_prems []
-                  |> filter (fn (x, T) => member (op =) inner_names x)
-                val t = 
-                  fold mk_exists vTs
-                  (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (resvar, res) ::
-                    map HOLogic.dest_Trueprop inner_prems))
-              in
-                t
-              end)
-              |> foldr1 HOLogic.mk_disj
-              |> fold lambda (resvar :: rev frees)
-        in
-          [(t, (names, prems))]
-        end
-*)
-
 fun flatten thy lookup_pred t (names, prems) =
   let
     fun lift t (names, prems) =
@@ -253,8 +206,6 @@
           let
             val (f, args) = strip_comb t
             val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
-            (* TODO: contextify things - this line is to unvarify the split_thm *)
-            (*val ((_, [isplit_thm]), _) = Variable.import true [split_thm] (ProofContext.init thy)*)
             val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
             val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
             val subst = Pattern.match thy (split_t, t) (Vartab.empty, Vartab.empty)
@@ -306,16 +257,12 @@
                           fun mk_if T (b, t, e) =
                             Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
                           val Ts = binder_types (fastype_of t)
-                          val t = 
+                        in
                           list_abs (map (pair "x") Ts @ [("b", @{typ bool})],
                             mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
                             HOLogic.mk_eq (@{term True}, Bound 0),
                             HOLogic.mk_eq (@{term False}, Bound 0)))
-                        in
-                          t
                         end
-                    (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts))
-                    val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*)
                     val argvs' = map2 lift_arg Ts argvs
                     val resname = Name.variant names' "res"
                     val resvar = Free (resname, body_type (fastype_of t))
@@ -326,7 +273,7 @@
     map (apfst Envir.eta_contract) (flatten' (Pattern.eta_long [] t) (names, prems))
   end;
 
-(* assumption: mutual recursive predicates all have the same parameters. *)  
+(* assumption: mutual recursive predicates all have the same parameters. *)
 fun define_predicates specs thy =
   if forall (fn (const, _) => defined_const thy const) specs then
     ([], thy)
@@ -334,24 +281,22 @@
     let
       val consts = map fst specs
       val eqns = maps snd specs
-      (*val eqns = maps (Predicate_Compile_Preproc_Data.get_specification thy) consts*)
-        (* create prednames *)
+      (* create prednames *)
       val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
       val argss' = map (map transform_ho_arg) argss
-      (* TODO: higher order arguments also occur in tuples! *)
       fun is_lifted (t1, t2) = (fastype_of t2 = pred_type (fastype_of t1))
+     (* FIXME: higher order arguments also occur in tuples! *)
       val lifted_args = distinct (op =) (filter is_lifted (flat argss ~~ flat argss'))
       val preds = map pred_of funs
-        (* mapping from term (Free or Const) to term *)
+      (* mapping from term (Free or Const) to term *)
       val net = fold Item_Net.update
         ((funs ~~ preds) @ lifted_args)
           (Fun_Pred.get thy)
       fun lookup_pred t = lookup thy net t
       (* create intro rules *)
-    
       fun mk_intros ((func, pred), (args, rhs)) =
         if (body_type (fastype_of func) = @{typ bool}) then
-         (*TODO: preprocess predicate definition of rhs *)
+         (* TODO: preprocess predicate definition of rhs *)
           [Logic.list_implies ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
         else
           let
@@ -361,44 +306,28 @@
               Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
           end
       fun mk_rewr_thm (func, pred) = @{thm refl}
+      val intr_ts = maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))
+      val (ind_result, thy') =
+        thy
+        |> Sign.map_naming Name_Space.conceal
+        |> Inductive.add_inductive_global
+          {quiet_mode = false, verbose = false, alt_name = Binding.empty, coind = false,
+            no_elim = false, no_ind = false, skip_mono = false, fork_mono = false}
+          (map (fn (s, T) =>
+            ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
+          (map (dest_Free o snd) lifted_args)
+          (map (fn x => (Attrib.empty_binding, x)) intr_ts)
+          []
+        ||> Sign.restore_naming thy
+      val prednames = map (fst o dest_Const) (#preds ind_result)
+      (* add constants to my table *)
+      val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname)
+        (#intrs ind_result))) prednames
+      val thy'' = Fun_Pred.map
+        (fold Item_Net.update (map (apfst Logic.varify_global)
+          (distinct (op =) funs ~~ (#preds ind_result)))) thy'
     in
-      case (*try *)SOME (maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))) of
-        NONE =>
-          let val _ = tracing "error occured!" in ([], thy) end
-      | SOME intr_ts =>
-          if is_some (try (map (cterm_of thy)) intr_ts) then
-            let
-              val (ind_result, thy') =
-                thy
-                |> Sign.map_naming Name_Space.conceal
-                |> Inductive.add_inductive_global
-                  {quiet_mode = false, verbose = false, alt_name = Binding.empty, coind = false,
-                    no_elim = false, no_ind = false, skip_mono = false, fork_mono = false}
-                  (map (fn (s, T) =>
-                    ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
-                  []
-                  (map (fn x => (Attrib.empty_binding, x)) intr_ts)
-                  []
-                ||> Sign.restore_naming thy
-              val prednames = map (fst o dest_Const) (#preds ind_result)
-              (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
-              (* add constants to my table *)
-              
-              val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname)
-                (#intrs ind_result))) prednames
-              val thy'' = Fun_Pred.map
-                (fold Item_Net.update (map (apfst Logic.varify_global)
-                  (distinct (op =) funs ~~ (#preds ind_result)))) thy'
-              (*val _ = print_specs thy'' specs*)
-            in
-              (specs, thy'')
-            end
-          else
-            let
-              val _ = Output.tracing (
-              "Introduction rules of function_predicate are not welltyped: " ^
-                commas (map (Syntax.string_of_term_global thy) intr_ts))
-            in ([], thy) end
+      (specs, thy'')
     end
 
 fun rewrite_intro thy intro =