src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML
changeset 32672 90f3ce5d27ae
parent 32668 b2de45007537
child 32740 9dd0a2f83429
--- a/src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML	Wed Sep 23 16:20:12 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML	Wed Sep 23 16:20:13 2009 +0200
@@ -69,19 +69,19 @@
 fun transform_ho_typ (T as Type ("fun", _)) =
   let
     val (Ts, T') = strip_type T
-  in (Ts @ [T']) ---> HOLogic.boolT end
+  in if T' = @{typ "bool"} then T else (Ts @ [T']) ---> HOLogic.boolT end
 | transform_ho_typ t = t
 
 fun transform_ho_arg arg = 
   case (fastype_of arg) of
-    (T as Type ("fun", _)) => (* if T = bool might be a relation already *)
+    (T as Type ("fun", _)) =>
       (case arg of
         Free (name, _) => Free (name, transform_ho_typ T)
       | _ => error "I am surprised")
 | _ => arg
 
 fun pred_type T =
-  let  
+  let
     val (Ts, T') = strip_type T
     val Ts' = map transform_ho_typ Ts
   in
@@ -194,22 +194,24 @@
 fun find_split_thm' thy (Const (@{const_name "If"}, _)) = SOME @{thm split_if}
   | find_split_thm' thy (Const (@{const_name "Let"}, _)) = SOME @{thm refl} (* TODO *)
   | find_split_thm' thy c = find_split_thm thy c
-  
-fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)  
-  
+
+fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
+
 fun folds_map f xs y =
   let
     fun folds_map' acc [] y = [(rev acc, y)]
       | folds_map' acc (x :: xs) y =
-        maps (fn (x, y) => folds_map' (x :: acc) xs y) (f x y)     
+        maps (fn (x, y) => folds_map' (x :: acc) xs y) (f x y)
     in
       folds_map' [] xs y
     end;
-      
+
 fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) =
   let
     fun mk_prems' (t as Const (name, T)) (names, prems) =
-      [(lookup_pred t, (names, prems))]
+      if is_constr thy name orelse (is_none (try lookup_pred t)) then
+        [(t ,(names, prems))]
+      else [(lookup_pred t, (names, prems))]
     | mk_prems' (t as Free (f, T)) (names, prems) = 
       [(lookup_pred t, (names, prems))]
     | mk_prems' t (names, prems) =
@@ -247,7 +249,7 @@
             val resvar = Free (resname, body_type (fastype_of t))
             val names' = resname :: names
             fun mk_prems'' (t as Const (c, _)) =
-              if is_constr thy c then
+              if is_constr thy c orelse (is_none (try lookup_pred t)) then
                 folds_map mk_prems' args (names', prems) |>
                 map
                   (fn (argvs, (names'', prems')) =>
@@ -259,6 +261,7 @@
                   val pred = lookup_pred t
                   val nparams = get_nparams pred
                   val (params, args) = chop nparams args
+                  val _ = tracing ("mk_prems'': " ^ (Syntax.string_of_term_global thy t) ^ " has " ^ string_of_int nparams ^ " parameters.")
                   val params' = map (mk_param lookup_pred) params
                 in
                   folds_map mk_prems' args (names', prems)
@@ -284,7 +287,7 @@
           end
   in
     mk_prems' t (names, prems)
-  end;    
+  end;
 
 (* assumption: mutual recursive predicates all have the same parameters. *)  
 fun define_predicates specs thy =
@@ -298,7 +301,7 @@
       (* create prednames *)
     val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
     val argss' = map (map transform_ho_arg) argss
-    val pnames = map (dest_Free) (distinct (op =) (flat argss' \\ flat argss))
+    val pnames = map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss'))
     val preds = map pred_of funs
     val prednames = map (fst o dest_Free) preds
     val funnames = map (fst o dest_Const) funs
@@ -344,26 +347,30 @@
             Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
         end
     fun mk_rewr_thm (func, pred) = @{thm refl}
-      
-    val intr_ts = maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))
-    val _ = map (tracing o (Syntax.string_of_term_global thy)) intr_ts
-    (* define new inductive predicates *)
-    val (ind_result, thy') =
-      Inductive.add_inductive_global (serial_string ())
-        {quiet_mode = false, verbose = false, kind = Thm.internalK,
-          alt_name = Binding.empty, coind = false, no_elim = false,
-          no_ind = false, skip_mono = false, fork_mono = false}
-        (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
-        pnames
-        (map (fn x => (Attrib.empty_binding, x)) intr_ts)
-        [] thy
-    val prednames = map (fst o dest_Const) (#preds ind_result)
-    (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
-    (* add constants to my table *)
-    val thy'' = thy'
-      |> Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames))
-  in
-    thy''
+  in    
+    case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
+      NONE => thy 
+    | SOME intr_ts => let
+        val _ = map (tracing o (Syntax.string_of_term_global thy)) intr_ts      
+      in
+        if is_some (try (map (cterm_of thy)) intr_ts) then
+          let
+            val (ind_result, thy') =
+              Inductive.add_inductive_global (serial_string ())
+                {quiet_mode = false, verbose = false, kind = Thm.internalK,
+                  alt_name = Binding.empty, coind = false, no_elim = false,
+                  no_ind = false, skip_mono = false, fork_mono = false}
+                (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
+                pnames
+                (map (fn x => (Attrib.empty_binding, x)) intr_ts)
+                [] thy
+            val prednames = map (fst o dest_Const) (#preds ind_result)
+            (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
+            (* add constants to my table *)
+          in Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy' end
+        else
+          thy
+      end
   end
 
 (* preprocessing intro rules - uses oracle *)
@@ -374,7 +381,7 @@
     fun lookup_pred (Const (name, T)) =
       (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
         SOME c => Const (c, pred_type T)
-      | NONE => Const (name, T))
+      | NONE => error ("Function " ^ name ^ " is not inductified"))
     | lookup_pred (Free (name, T)) = Free (name, T)
     | lookup_pred _ = error "lookup function is not defined!"
 
@@ -387,17 +394,20 @@
     val _ = tracing (Syntax.string_of_term_global thy intro_t)
     val (prems, concl) = Logic.strip_horn intro_t
     val frees = map fst (Term.add_frees intro_t [])
-    fun opt_dest_Not t = the_default t (try HOLogic.dest_not t)
     fun rewrite prem names =
       let
-        val (P, args) = (strip_comb o opt_dest_Not o HOLogic.dest_Trueprop) prem
+        val t = (HOLogic.dest_Trueprop prem)
+        val (lit, mk_lit) = case try HOLogic.dest_not t of
+            SOME t => (t, HOLogic.mk_not)
+          | NONE => (t, I)
+        val (P, args) = (strip_comb lit) 
       in
         folds_map (
           fn t => if (is_funtype (fastype_of t)) then (fn x => [(t, x)])
             else mk_prems thy (lookup_pred, get_nparams) t) args (names, [])
         |> map (fn (resargs, (names', prems')) =>
           let
-            val prem' = HOLogic.mk_Trueprop (list_comb (P, resargs))
+            val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))
           in (prem'::prems', names') end)
       end
     val intro_ts' = folds_map rewrite prems frees