src/HOL/ex/predicate_compile.ML
changeset 32306 19f55947d4d5
parent 32287 65d5c5b30747
child 32307 55166cd57a6d
--- a/src/HOL/ex/predicate_compile.ML	Tue Aug 04 01:01:23 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
@@ -17,6 +17,7 @@
   val predfun_name_of: theory -> string -> mode -> string
   val all_preds_of : theory -> string list
   val modes_of: theory -> string -> mode list
+  val string_of_mode : mode -> string
   val intros_of: theory -> string -> thm list
   val nparams_of: theory -> string -> int
   val add_intro: thm -> theory -> theory
@@ -25,12 +26,17 @@
   val code_pred: string -> Proof.context -> Proof.state
   val code_pred_cmd: string -> Proof.context -> Proof.state
   val print_stored_rules: theory -> unit
+  val print_all_modes: theory -> unit
   val do_proofs: bool ref
   val mk_casesrule : Proof.context -> int -> thm list -> term
   val analyze_compr: theory -> term -> term
   val eval_ref: (unit -> term Predicate.pred) option ref
   val add_equations : string -> theory -> theory
   val code_pred_intros_attrib : attribute
+  (* used by Quickcheck_Generator *) 
+  val funT_of : mode -> typ -> typ
+  val mk_if_pred : term -> term
+  val mk_Eval : term * term -> term
 end;
 
 structure Predicate_Compile : PREDICATE_COMPILE =
@@ -43,8 +49,7 @@
 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
 
 fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
-fun new_print_tac s = Tactical.print_tac s
-fun debug_tac msg = (fn st => (Output.tracing msg; Seq.single st));
+fun debug_tac msg = (fn st => (tracing msg; Seq.single st));
 
 val do_proofs = ref true;
 
@@ -113,7 +118,7 @@
 
 val mk_sup = HOLogic.mk_binop @{const_name sup};
 
-fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
+fun mk_if_pred cond = Const (@{const_name Predicate.if_pred},
   HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond;
 
 fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
@@ -248,6 +253,18 @@
     fold print preds ()
   end;
 
+fun print_all_modes thy =
+  let
+    val _ = writeln ("Inferred modes:")
+    fun print (pred, modes) u =
+      let
+        val _ = writeln ("predicate: " ^ pred)
+        val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
+      in u end  
+  in
+    fold print (all_modes_of thy) ()
+  end
+
 (** preprocessing rules **)  
 
 fun imp_prems_conv cv ct =
@@ -465,12 +482,13 @@
           end
         end)) (AList.lookup op = modes name)
 
-  in (case strip_comb t of
+  in
+    case strip_comb (Envir.eta_contract t) of
       (Const (name, _), args) => the_default default (mk_modes name args)
     | (Var ((name, _), _), args) => the (mk_modes name args)
     | (Free (name, _), args) => the (mk_modes name args)
-    | (Abs _, []) => modes_of_param default modes t 
-    | _ => default)
+    | (Abs _, []) => error "Abs at param position" (* modes_of_param default modes t *)
+    | _ => default
   end
 
 datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term;
@@ -529,7 +547,7 @@
   in (p, List.filter (fn m => case find_index
     (not o check_mode_clause thy param_vs modes m) rs of
       ~1 => true
-    | i => (tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
+    | i => (Output.tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
       p ^ " violates mode " ^ string_of_mode m); false)) ms)
   end;
 
@@ -547,18 +565,18 @@
 
 (* term construction *)
 
-(* for simple modes (e.g. parameters) only: better call it param_funT *)
-(* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
-fun funT_of T NONE = T
-  | funT_of T (SOME mode) = let
+(* Remark: types of param_funT_of and funT_of are swapped - which is the more
+canonical order? *)
+fun param_funT_of T NONE = T
+  | param_funT_of T (SOME mode) = let
      val Ts = binder_types T;
      val (Us1, Us2) = get_args mode Ts
    in Us1 ---> (mk_pred_enumT (mk_tupleT Us2)) end;
 
-fun funT'_of (iss, is) T = let
+fun funT_of (iss, is) T = let
     val Ts = binder_types T
     val (paramTs, argTs) = chop (length iss) Ts
-    val paramTs' = map2 (fn SOME is => funT'_of ([], is) | NONE => I) iss paramTs 
+    val paramTs' = map2 (fn SOME is => funT_of ([], is) | NONE => I) iss paramTs 
     val (inargTs, outargTs) = get_args is argTs
   in
     (paramTs' @ inargTs) ---> (mk_pred_enumT (mk_tupleT outargTs))
@@ -636,9 +654,9 @@
         val f' = case f of
             Const (name, T) =>
               if AList.defined op = modes name then
-                Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T)
+                Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T)
               else error "compile param: Not an inductive predicate with correct mode"
-          | Free (name, T) => Free (name, funT_of T (SOME is'))
+          | Free (name, T) => Free (name, param_funT_of T (SOME is'))
         val outTs = dest_tupleT (dest_pred_enumT (body_type (fastype_of f')))
         val out_vs = map Free (out_names ~~ outTs)
         val params' = map (compile_param thy modes) (ms ~~ params)
@@ -662,9 +680,9 @@
      val f' = case f of
         Const (name, T) =>
           if AList.defined op = modes name then
-             Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T)
+             Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T)
           else error "compile param: Not an inductive predicate with correct mode"
-      | Free (name, T) => Free (name, funT_of T (SOME is'))
+      | Free (name, T) => Free (name, param_funT_of T (SOME is'))
    in list_comb (f', params' @ args') end
  | compile_param _ _ _ = error "compile params"
   
@@ -684,7 +702,7 @@
        | (Free (name, T), args) =>
          (*if name mem param_vs then *)
          (* Higher order mode call *)
-         let val r = Free (name, funT_of T (SOME is))
+         let val r = Free (name, param_funT_of T (SOME is))
          in list_comb (r, args) end)
   | compile_expr _ _ _ = error "not a valid inductive expression"
 
@@ -746,7 +764,7 @@
                  let
                    val rest = compile_prems [] vs' names'' ps';
                  in
-                   (mk_if_predenum t, rest)
+                   (mk_if_pred t, rest)
                  end
           in
             compile_match thy constr_vs' eqs out_ts'' 
@@ -761,7 +779,7 @@
   let
     val Ts = binder_types T;
     val (Ts1, Ts2) = chop (length param_vs) Ts;
-    val Ts1' = map2 funT_of Ts1 (fst mode)
+    val Ts1' = map2 param_funT_of Ts1 (fst mode)
     val (Us1, Us2) = get_args (snd mode) Ts2;
     val xnames = Name.variant_list param_vs
       (map (fn i => "x" ^ string_of_int i) (snd mode));
@@ -817,7 +835,7 @@
   val argnames = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   val (Ts1, Ts2) = chop nparams Ts;
-  val Ts1' = map2 funT_of Ts1 (fst mode)
+  val Ts1' = map2 param_funT_of Ts1 (fst mode)
   val args = map Free (argnames ~~ (Ts1' @ Ts2))
   val (params, io_args) = chop nparams args
   val (inargs, outargs) = get_args (snd mode) io_args
@@ -834,7 +852,7 @@
   val funpropI = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs),
                    mk_tuple outargs))
   val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
-  val _ = Output.tracing (Syntax.string_of_term_global thy introtrm) 
+  val _ = tracing (Syntax.string_of_term_global thy introtrm) 
   val simprules = [defthm, @{thm eval_pred},
                    @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
   val unfolddef_tac = (Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1)
@@ -860,7 +878,7 @@
         ^ (string_of_mode (snd mode))
       val Ts = binder_types T;
       val (Ts1, Ts2) = chop nparams Ts;
-      val Ts1' = map2 funT_of Ts1 (fst mode)
+      val Ts1' = map2 param_funT_of Ts1 (fst mode)
       val (Us1, Us2) = get_args (snd mode) Ts2;
       val names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
@@ -921,12 +939,12 @@
   REPEAT_DETERM (etac @{thm thin_rl} 1)
   THEN REPEAT_DETERM (rtac @{thm ext} 1)
   THEN (rtac @{thm iffI} 1)
-  THEN new_print_tac "prove_param"
+  THEN print_tac "prove_param"
   (* proof in one direction *)
   THEN (atac 1)
   (* proof in the other direction *)
   THEN (atac 1)
-  THEN new_print_tac "after prove_param"
+  THEN print_tac "after prove_param"
 (*  let
     val  (f, args) = strip_comb t
     val (params, _) = chop (length ms) args
@@ -964,10 +982,10 @@
         (* for the right assumption in first position *)
         THEN rotate_tac premposition 1
         THEN rtac introrule 1
-        THEN new_print_tac "after intro rule"
+        THEN print_tac "after intro rule"
         (* work with parameter arguments *)
         THEN (atac 1)
-        THEN (new_print_tac "parameter goal")
+        THEN (print_tac "parameter goal")
         THEN (EVERY (map (prove_param thy modes) (ms ~~ args1)))
         THEN (REPEAT_DETERM (atac 1)) end)
       else error "Prove expr if case not implemented"
@@ -1110,7 +1128,7 @@
          (fn i => EVERY' (select_sup (length clauses) i) i) 
            (1 upto (length clauses))))
   THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses))
-  THEN new_print_tac "proved one direction"
+  THEN print_tac "proved one direction"
 end;
 
 (*******************************************************************************************************)
@@ -1168,15 +1186,15 @@
       if AList.defined op = modes name then
         etac @{thm bindE} 1
         THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
-        THEN new_print_tac "prove_expr2-before"
+        THEN print_tac "prove_expr2-before"
         THEN (debug_tac (Syntax.string_of_term_global thy
           (prop_of (predfun_elim_of thy name mode))))
         THEN (etac (predfun_elim_of thy name mode) 1)
-        THEN new_print_tac "prove_expr2"
+        THEN print_tac "prove_expr2"
         (* TODO -- FIXME: replace remove_last_goal*)
         (* THEN (EVERY (replicate (length args) (remove_last_goal thy))) *)
         THEN (EVERY (map (prove_param thy modes) (ms ~~ args)))
-        THEN new_print_tac "finished prove_expr2"
+        THEN print_tac "finished prove_expr2"
       
       else error "Prove expr2 if case not implemented"
     | _ => etac @{thm bindE} 1)
@@ -1273,7 +1291,7 @@
     end;
   val prems_tac = prove_prems2 in_ts' param_vs ps 
 in
-  new_print_tac "starting prove_clause2"
+  print_tac "starting prove_clause2"
   THEN etac @{thm bindE} 1
   THEN (etac @{thm singleE'} 1)
   THEN (TRY (etac @{thm Pair_inject} 1))
@@ -1401,10 +1419,10 @@
   val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
   val _ = tracing "Compiling equations..."
   val ts = compile_preds thy' all_vs param_vs (extra_modes @ modes) clauses'
-  val _ = map (Output.tracing o (Syntax.string_of_term_global thy')) (flat ts)
+(*  val _ = map (tracing o (Syntax.string_of_term_global thy')) (flat ts) *)
   val pred_mode =
     maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses' 
-  val _ = Output.tracing "Proving equations..."
+  val _ = tracing "Proving equations..."
   val result_thms =
     prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
   val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
@@ -1486,7 +1504,6 @@
         assumes = [("", Logic.strip_imp_prems case_rule)],
         binds = [], cases = []}) cases_rules
     val case_env = map2 (fn p => fn c => (Long_Name.base_name p, SOME c)) preds cases
-    val _ = Output.tracing (commas (map fst case_env))
     val lthy'' = ProofContext.add_cases true case_env lthy'
     
     fun after_qed thms =