src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33138 e2e23987c59a
parent 33137 0d16c07f8d24
child 33139 9c01ee6f8ee9
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -388,7 +388,7 @@
 (* diagnostic display functions *)
 
 fun print_modes modes =
-  tracing ("Inferred modes:\n" ^
+  Output.tracing ("Inferred modes:\n" ^
     cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
       string_of_mode ms)) modes));
 
@@ -398,7 +398,7 @@
       ^ (string_of_entry pred mode entry)  
     fun print_pred (pred, modes) =
       "predicate " ^ pred ^ ": " ^ cat_lines (map (print_mode pred) modes)
-    val _ = tracing (cat_lines (map print_pred pred_mode_table))
+    val _ = Output.tracing (cat_lines (map print_pred pred_mode_table))
   in () end;
 
 fun string_of_prem thy (Prem (ts, p)) =
@@ -774,8 +774,6 @@
   mk_sup : term * term -> term,
   mk_if : term -> term,
   mk_not : term -> term,
-(*  funT_of : mode -> typ -> typ, *)
-(*  mk_fun_of : theory -> (string * typ) -> mode -> term, *) 
   mk_map : typ -> typ -> term -> term -> term,
   lift_pred : term -> term
 };
@@ -788,8 +786,6 @@
 fun mk_sup (CompilationFuns funs) = #mk_sup funs
 fun mk_if (CompilationFuns funs) = #mk_if funs
 fun mk_not (CompilationFuns funs) = #mk_not funs
-(*fun funT_of (CompilationFuns funs) = #funT_of funs*)
-(*fun mk_fun_of (CompilationFuns funs) = #mk_fun_of funs*)
 fun mk_map (CompilationFuns funs) = #mk_map funs
 fun lift_pred (CompilationFuns funs) = #lift_pred funs
 
@@ -889,7 +885,8 @@
 fun mk_if cond = Const (@{const_name RPred.if_rpred},
   HOLogic.boolT --> mk_rpredT HOLogic.unitT) $ cond;
 
-fun mk_not t = error "Negation is not defined for RPred"
+fun mk_not t = let val T = mk_rpredT HOLogic.unitT
+  in Const (@{const_name RPred.not_rpred}, T --> T) $ t end
 
 fun mk_map T1 T2 tf tp = Const (@{const_name RPred.map},
   (T1 --> T2) --> mk_rpredT T1 --> mk_rpredT T2) $ tf $ tp
@@ -924,7 +921,7 @@
   let
     val Ts = binder_types T
     val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
-    val paramTs' = map2 (fn SOME is => depth_limited_funT_of PredicateCompFuns.compfuns ([], is) | NONE => I) iss paramTs 
+    val paramTs' = map2 (fn SOME is => depth_limited_funT_of compfuns ([], is) | NONE => I) iss paramTs 
   in
     (paramTs' @ inargTs @ [@{typ bool}, @{typ "code_numeral"}]) ---> (mk_predT compfuns (mk_tupleT outargTs))
   end;  
@@ -963,16 +960,16 @@
 fun term_vTs tm =
   fold_aterms (fn Free xT => cons xT | _ => I) tm [];
 
-(*FIXME this function should not be named merge... make it local instead*)
-fun merge xs [] = xs
-  | merge [] ys = ys
-  | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
-      else y::merge (x::xs) ys;
-
-fun subsets i j = if i <= j then
-       let val is = subsets (i+1) j
-       in merge (map (fn ks => i::ks) is) is end
-     else [[]];
+fun subsets i j =
+  if i <= j then
+    let
+      fun merge xs [] = xs
+        | merge [] ys = ys
+        | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
+            else y::merge (x::xs) ys;
+      val is = subsets (i+1) j
+    in merge (map (fn ks => i::ks) is) is end
+  else [[]];
      
 (* FIXME: should be in library - cprod = map_prod I *)
 fun cprod ([], ys) = []
@@ -1025,7 +1022,7 @@
 *)
 fun modes_of_term modes t =
   let
-    val ks = map_index (fn (i, T) => (i, NONE)) (binder_types (fastype_of t));
+    val ks = map_index (fn (i, T) => (i + 1, NONE)) (binder_types (fastype_of t));
     val default = [Mode (([], ks), ks, [])];
     fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
         let
@@ -1205,16 +1202,21 @@
 fun infer_modes_with_generator options thy extra_modes all_modes param_vs clauses =
   let
     val prednames = map fst clauses
-    val extra_modes = all_modes_of thy
+    val extra_modes' = all_modes_of thy
     val gen_modes = all_generator_modes_of thy
       |> filter_out (fn (name, _) => member (op =) prednames name)
-    val starting_modes = remove_from extra_modes all_modes
+    val starting_modes = remove_from extra_modes' all_modes
+    fun eq_mode (m1, m2) = (m1 = m2)
     val modes =
       fixp (fn modes =>
-        map (check_modes_pred options true thy param_vs clauses extra_modes (gen_modes @ modes)) modes)
-         starting_modes 
+        map (check_modes_pred options true thy param_vs clauses extra_modes' (gen_modes @ modes)) modes)
+         starting_modes
   in
-    map (get_modes_pred true thy param_vs clauses extra_modes (gen_modes @ modes)) modes
+    AList.join (op =)
+    (fn _ => fn ((mps1, mps2)) =>
+      merge (fn ((m1, _), (m2, _)) => eq_mode (m1, m2)) (mps1, mps2))
+    (infer_modes options thy extra_modes all_modes param_vs clauses,
+    map (get_modes_pred true thy param_vs clauses extra_modes (gen_modes @ modes)) modes)
   end;
 
 (* term construction *)
@@ -1270,13 +1272,12 @@
     fold_rev lambda vs (f (list_comb (t, vs)))
   end;
 
-fun compile_param depth_limited thy compfuns (NONE, t) = t
-  | compile_param depth_limited thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
+fun compile_param depth_limited thy compfuns mk_fun_of (NONE, t) = t
+  | compile_param depth_limited thy compfuns mk_fun_of (m as SOME (Mode ((iss, is'), is, ms)), t) =
    let
      val (f, args) = strip_comb (Envir.eta_contract t)
      val (params, args') = chop (length ms) args
-     val params' = map (compile_param depth_limited thy compfuns) (ms ~~ params)
-     val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of
+     val params' = map (compile_param depth_limited thy compfuns mk_fun_of) (ms ~~ params)
      val funT_of = if depth_limited then depth_limited_funT_of else funT_of
      val f' =
        case f of
@@ -1287,34 +1288,36 @@
      list_comb (f', params' @ args')
    end
 
-fun compile_expr depth_limited thy ((Mode (mode, is, ms)), t) inargs =
+fun compile_expr depth_limited thy compfuns mk_fun_of ((Mode (mode, is, ms)), t) inargs =
   case strip_comb t of
     (Const (name, T), params) =>
        let
-         val params' = map (compile_param depth_limited thy PredicateCompFuns.compfuns) (ms ~~ params)
-         val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of
+         val params' = map (compile_param depth_limited thy compfuns mk_fun_of) (ms ~~ params)
+           (*val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of*)
+         val _ = if mode = ([], [(0, NONE)]) then error "something is wrong" else ()
+           val _ = Output.tracing ("compile_expr mode: " ^ string_of_mode mode)
        in
-         list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params' @ inargs)
+         (*lift_pred compfuns*)(list_comb (mk_fun_of compfuns thy (name, T) mode, params' @ inargs))
        end
   | (Free (name, T), params) =>
        let 
          val funT_of = if depth_limited then depth_limited_funT_of else funT_of
        in
-         list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), params @ inargs)
+         list_comb (Free (name, funT_of compfuns ([], is) T), params @ inargs)
        end;
-       
-fun compile_gen_expr depth thy ((Mode (mode, is, ms)), t) inargs =
+
+fun compile_gen_expr depth_limited thy compfuns mk_fun_of ((Mode (mode, is, ms)), t) inargs =
   case strip_comb t of
     (Const (name, T), params) =>
       let
-        val params' = map (compile_param depth thy RPredCompFuns.compfuns) (ms ~~ params)
+        val params' = map (compile_param depth_limited thy RPredCompFuns.compfuns mk_fun_of) (ms ~~ params)
       in
         list_comb (mk_generator_of RPredCompFuns.compfuns thy (name, T) mode, params' @ inargs)
       end
   | (Free (name, T), params) =>
-    lift_pred RPredCompFuns.compfuns
-      (list_comb (Free (name, depth_limited_funT_of RPredCompFuns.compfuns ([], is) T), params @ inargs))
-          
+  (*lift_pred RPredCompFuns.compfuns*)
+  (list_comb (Free (name, depth_limited_funT_of RPredCompFuns.compfuns ([], is) T), params @ inargs))
+
 (** specific rpred functions -- move them to the correct place in this file *)
 
 fun mk_Eval_of depth ((x, T), NONE) names = (x, names)
@@ -1383,7 +1386,7 @@
       | map_params t = t
     in map_aterms map_params arg end
   
-fun compile_clause compfuns depth thy all_vs param_vs (iss, is) inp (ts, moded_ps) =
+fun compile_clause compfuns mk_fun_of depth thy all_vs param_vs (iss, is) inp (ts, moded_ps) =
   let
     fun check_constrt t (names, eqs) =
       if is_constrt thy t then (t, (names, eqs)) else
@@ -1421,8 +1424,8 @@
                    val args = case depth of
                      NONE => in_ts
                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
-                   val u = lift_pred compfuns
-                     (compile_expr (is_some depth) thy (mode, t) args)
+                   val u =
+                     (compile_expr (is_some depth) thy compfuns mk_fun_of (mode, t) args)
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1433,8 +1436,8 @@
                    val args = case depth of
                      NONE => in_ts
                    | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
-                   val u = lift_pred compfuns (mk_not PredicateCompFuns.compfuns
-                     (compile_expr (is_some depth) thy (mode, t) args))
+                 val u = (*lift_pred compfuns*) (mk_not compfuns
+                     (compile_expr (is_some depth) thy compfuns mk_fun_of (mode, t) args))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1449,9 +1452,9 @@
                  let
                    val (in_ts, out_ts''') = split_smode is us;
                    val args = case depth of
-                     NONE => in_ts
+                       NONE => in_ts
                      | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
-                   val u = compile_gen_expr (is_some depth) thy (mode, t) args
+                   val u = compile_gen_expr (is_some depth) thy compfuns mk_fun_of (mode, t) args
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1477,7 +1480,7 @@
 	  val (Ts1, Ts2) = chop (length (fst mode)) (binder_types T)
     val (Us1, Us2) = split_smodeT (snd mode) Ts2
     val funT_of = if depth_limited then depth_limited_funT_of else funT_of
-    val Ts1' = map2 (fn NONE => I | SOME is => funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
+    val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
   	fun mk_input_term (i, NONE) =
 		    [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
 		  | mk_input_term (i, SOME pis) = case HOLogic.strip_tupleT (nth Ts2 (i - 1)) of
@@ -1501,7 +1504,7 @@
       else
         NONE
     val cl_ts =
-      map (compile_clause compfuns decr_depth
+      map (compile_clause compfuns mk_fun_of decr_depth
         thy all_vs param_vs mode (mk_tuple in_ts)) moded_cls;
     val compilation = foldr1 (mk_sup compfuns) cl_ts
     val T' = mk_predT compfuns (mk_tupleT Us2)
@@ -1730,11 +1733,12 @@
     val T = AList.lookup (op =) preds name |> the
     fun create_definition mode thy =
       let
+      val _ = Output.tracing ("mode: " ^ string_of_mode mode)
         val mode_cname = create_constname_of_mode thy "gen_" name mode
         val funT = generator_funT_of mode T
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
-        |> set_generator_name name mode mode_cname 
+        |> set_generator_name name mode mode_cname
       end;
   in
     fold create_definition modes thy
@@ -2409,7 +2413,7 @@
         goal_ctxt |> LocalTheory.theory (fold set_elim global_thms #>
           (if is_rpred options then
             (add_equations options [const] #>
-             add_depth_limited_equations options [const] #> add_quickcheck_equations options [const])
+              (*add_depth_limited_equations options [const] #> *)add_quickcheck_equations options [const])
            else if is_depth_limited options then
              add_depth_limited_equations options [const]
            else
@@ -2454,7 +2458,9 @@
       case depth_limit of
         NONE => inargs
       | SOME d => inargs @ [@{term "True"}, HOLogic.mk_number @{typ "code_numeral"} d]
-    val t_pred = compile_expr (is_some depth_limit) thy
+    val mk_fun_of = if random then mk_generator_of else
+      if (is_some depth_limit) then mk_depth_limited_fun_of else mk_fun_of
+    val t_pred = compile_expr (is_some depth_limit) thy compfuns mk_fun_of
       (m, list_comb (pred, params)) inargs';
     val t_eval = if null outargs then t_pred else
       let