src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33137 0d16c07f8d24
parent 33135 422cac7d6e31
child 33138 e2e23987c59a
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -12,14 +12,12 @@
   type smode = (int * int list option) list
   type mode = smode option list * smode
   datatype tmode = Mode of mode * smode * tmode option list;
-  (*val add_equations_of: bool -> string list -> theory -> theory *)
   val register_predicate : (thm list * thm * int) -> theory -> theory
   val register_intros : thm list -> theory -> theory
   val is_registered : theory -> string -> bool
- (* val fetch_pred_data : theory -> string -> (thm list * thm * int)  *)
   val predfun_intro_of: theory -> string -> mode -> thm
   val predfun_elim_of: theory -> string -> mode -> thm
-  val strip_intro_concl: int -> term -> term * (term list * term list)
+    (*  val strip_intro_concl: int -> term -> term * (term list * term list)*)
   val predfun_name_of: theory -> string -> mode -> string
   val all_preds_of : theory -> string list
   val modes_of: theory -> string -> mode list
@@ -37,22 +35,15 @@
   val print_all_modes: theory -> unit
   val do_proofs: bool Unsynchronized.ref
   val mk_casesrule : Proof.context -> int -> thm list -> term
-  val analyze_compr: theory -> int option -> term -> term
-  val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
+    (*  val analyze_compr: theory -> compfuns -> int option * bool -> term -> term*)
+  val eval_ref : (unit -> term Predicate.pred) option Unsynchronized.ref
+  val random_eval_ref : (unit -> int * int -> term Predicate.pred * (int * int)) option Unsynchronized.ref
   val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
   val code_pred_intros_attrib : attribute
   (* used by Quickcheck_Generator *) 
-  (*val funT_of : mode -> typ -> typ
-  val mk_if_pred : term -> term
-  val mk_Eval : term * term -> term*)
-  val mk_tupleT : typ list -> typ
-(*  val mk_predT :  typ -> typ *)
   (* temporary for testing of the compilation *)
   datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term |
     GeneratorPrem of term list * term | Generator of (string * typ);
- (* val prepare_intrs: theory -> string list ->
-    (string * typ) list * int * string list * string list * (string * mode list) list *
-    (string * (term list * indprem list) list) list * (string * (int option list * int)) list*)
   datatype compilation_funs = CompilationFuns of {
     mk_predT : typ -> typ,
     dest_predT : typ -> typ,
@@ -900,7 +891,8 @@
 
 fun mk_not t = error "Negation is not defined for RPred"
 
-fun mk_map t = error "FIXME" (*FIXME*)
+fun mk_map T1 T2 tf tp = Const (@{const_name RPred.map},
+  (T1 --> T2) --> mk_rpredT T1 --> mk_rpredT T2) $ tf $ tp
 
 fun lift_pred t =
   let
@@ -1188,7 +1180,7 @@
     (p, map (fn m =>
       (m, map (the o check_mode_clause with_generator thy param_vs modes gen_modes m) rs)) ms)
   end;
-  
+
 fun fixp f (x : (string * mode list) list) =
   let val y = f x
   in if x = y then x else fixp f y end;
@@ -1295,34 +1287,33 @@
      list_comb (f', params' @ args')
    end
 
-fun compile_expr depth_limited thy ((Mode (mode, is, ms)), t) =
+fun compile_expr depth_limited thy ((Mode (mode, is, ms)), t) inargs =
   case strip_comb t of
     (Const (name, T), params) =>
        let
          val params' = map (compile_param depth_limited thy PredicateCompFuns.compfuns) (ms ~~ params)
          val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of
        in
-         list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
+         list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params' @ inargs)
        end
-  | (Free (name, T), args) =>
+  | (Free (name, T), params) =>
        let 
          val funT_of = if depth_limited then depth_limited_funT_of else funT_of
        in
-         list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
+         list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), params @ inargs)
        end;
        
-fun compile_gen_expr depth thy compfuns ((Mode (mode, is, ms)), t) inargs =
+fun compile_gen_expr depth thy ((Mode (mode, is, ms)), t) inargs =
   case strip_comb t of
     (Const (name, T), params) =>
       let
-        val params' = map (compile_param depth thy PredicateCompFuns.compfuns) (ms ~~ params)
+        val params' = map (compile_param depth thy RPredCompFuns.compfuns) (ms ~~ params)
       in
-        list_comb (mk_generator_of compfuns thy (name, T) mode, params' @ inargs)
+        list_comb (mk_generator_of RPredCompFuns.compfuns thy (name, T) mode, params' @ inargs)
       end
-    | (Free (name, T), params) =>
-    lift_pred compfuns
-    (list_comb (Free (name, depth_limited_funT_of PredicateCompFuns.compfuns ([], is) T), params @ inargs))
-      
+  | (Free (name, T), params) =>
+    lift_pred RPredCompFuns.compfuns
+      (list_comb (Free (name, depth_limited_funT_of RPredCompFuns.compfuns ([], is) T), params @ inargs))
           
 (** specific rpred functions -- move them to the correct place in this file *)
 
@@ -1431,7 +1422,7 @@
                      NONE => in_ts
                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
                    val u = lift_pred compfuns
-                     (list_comb (compile_expr (is_some depth) thy (mode, t), args))
+                     (compile_expr (is_some depth) thy (mode, t) args)
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1443,7 +1434,7 @@
                      NONE => in_ts
                    | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
                    val u = lift_pred compfuns (mk_not PredicateCompFuns.compfuns
-                   (list_comb (compile_expr (is_some depth) thy (mode, t), args)))
+                     (compile_expr (is_some depth) thy (mode, t) args))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1460,7 +1451,7 @@
                    val args = case depth of
                      NONE => in_ts
                      | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
-                   val u = compile_gen_expr (is_some depth) thy compfuns (mode, t) args
+                   val u = compile_gen_expr (is_some depth) thy (mode, t) args
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1729,9 +1720,9 @@
   let
     val Ts = binder_types T
     val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
-    val paramTs' = map2 (fn SOME is => depth_limited_funT_of PredicateCompFuns.compfuns ([], is) | NONE => I) iss paramTs 
+    val paramTs' = map2 (fn SOME is => generator_funT_of ([], is) | NONE => I) iss paramTs
   in
-    (paramTs' @ inargTs @ [@{typ "code_numeral"}]) ---> (mk_predT RPredCompFuns.compfuns (mk_tupleT outargTs))
+    (paramTs' @ inargTs @ [@{typ "bool"}, @{typ "code_numeral"}]) ---> (mk_predT RPredCompFuns.compfuns (mk_tupleT outargTs))
   end
 
 fun rpred_create_definitions preds (name, modes) thy =
@@ -2301,7 +2292,6 @@
     val result_thms = #prove steps options thy' clauses preds (extra_modes @ modes)
       moded_clauses compiled_terms
     val qname = #qname steps
-    (* val attrib = gn thy => Attrib.attribute_i thy Code.add_eqn_attrib *)
     val attrib = fn thy => Attrib.attribute_i thy (Attrib.internal (K (Thm.declaration_attribute
       (fn thm => Context.mapping (Code.add_eqn thm) I))))
     val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
@@ -2435,9 +2425,10 @@
 (* transformation for code generation *)
 
 val eval_ref = Unsynchronized.ref (NONE : (unit -> term Predicate.pred) option);
+val random_eval_ref = Unsynchronized.ref (NONE : (unit -> int * int -> term Predicate.pred * (int * int)) option);
 
 (*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
-fun analyze_compr thy depth_limit t_compr =
+fun analyze_compr thy compfuns (depth_limit, random) t_compr =
   let
     val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
       | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
@@ -2448,6 +2439,8 @@
       (fn (i, t) => case t of Bound j => if j < length Ts then NONE
         else SOME (i+1) | _ => SOME (i+1)) args); (*FIXME dangling bounds should not occur*)
     val user_mode' = map (rpair NONE) user_mode
+    val all_modes_of = if random then all_generator_modes_of else all_modes_of
+    val compile_expr = if random then compile_gen_expr else compile_expr
     val modes = filter (fn Mode (_, is, _) => is = user_mode')
       (modes_of_term (all_modes_of thy) (list_comb (pred, params)));
     val m = case modes
@@ -2457,10 +2450,12 @@
       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
                 ^ Syntax.string_of_term_global thy t_compr); m);
     val (inargs, outargs) = split_smode user_mode' args;
-    val inargs' = case depth_limit of NONE => inargs
+    val inargs' =
+      case depth_limit of
+        NONE => inargs
       | SOME d => inargs @ [@{term "True"}, HOLogic.mk_number @{typ "code_numeral"} d]
-    val t_pred = list_comb (compile_expr (is_some depth_limit) thy
-      (m, list_comb (pred, params)), inargs');
+    val t_pred = compile_expr (is_some depth_limit) thy
+      (m, list_comb (pred, params)) inargs';
     val t_eval = if null outargs then t_pred else
       let
         val outargs_bounds = map (fn Bound i => i) outargs;
@@ -2473,22 +2468,30 @@
         val arrange = funpow (length outargs_bounds - 1) HOLogic.mk_split
           (Term.list_abs (map (pair "") outargsTs,
             HOLogic.mk_ptuple fp T_compr (map Bound arrange_bounds)))
-      in mk_map PredicateCompFuns.compfuns T_pred T_compr arrange t_pred end
+      in mk_map compfuns T_pred T_compr arrange t_pred end
   in t_eval end;
 
-fun eval thy depth_limit t_compr =
+fun eval thy (options as (depth_limit, random)) t_compr =
   let
-    val t = analyze_compr thy depth_limit t_compr;
-    val T = dest_predT PredicateCompFuns.compfuns (fastype_of t);
-    val t' = mk_map PredicateCompFuns.compfuns T HOLogic.termT (HOLogic.term_of_const T) t;
-  in (T, Code_ML.eval NONE ("Predicate_Compile_Core.eval_ref", eval_ref) Predicate.map thy t' []) end;
+    val compfuns = if random then RPredCompFuns.compfuns else PredicateCompFuns.compfuns
+    val t = analyze_compr thy compfuns options t_compr;
+    val T = dest_predT compfuns (fastype_of t);
+    val t' = mk_map compfuns T HOLogic.termT (HOLogic.term_of_const T) t;
+    val eval =
+      if random then
+        Code_ML.eval NONE ("Predicate_Compile_Core.random_eval_ref", random_eval_ref)
+            (fn proc => fn g => fn s => g s |>> Predicate.map proc) thy t' []
+          |> Random_Engine.run
+      else
+        Code_ML.eval NONE ("Predicate_Compile_Core.eval_ref", eval_ref) Predicate.map thy t' []
+  in (T, eval) end;
 
-fun values ctxt depth_limit k t_compr =
+fun values ctxt options k t_compr =
   let
     val thy = ProofContext.theory_of ctxt;
-    val (T, t) = eval thy depth_limit t_compr;
+    val (T, ts) = eval thy options t_compr;
+    val (ts, _) = Predicate.yieldn k ts;
     val setT = HOLogic.mk_setT T;
-    val (ts, _) = Predicate.yieldn k t;
     val elemsT = HOLogic.mk_set T ts;
   in if k = ~1 orelse length ts < k then elemsT
     else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr
@@ -2501,11 +2504,11 @@
   in
   end;
   *)
-fun values_cmd modes depth_limit k raw_t state =
+fun values_cmd modes options k raw_t state =
   let
     val ctxt = Toplevel.context_of state;
     val t = Syntax.read_term ctxt raw_t;
-    val t' = values ctxt depth_limit k t;
+    val t' = values ctxt options k t;
     val ty' = Term.type_of t';
     val ctxt' = Variable.auto_fixes t' ctxt;
     val p = PrintMode.with_modes modes (fn () =>
@@ -2517,15 +2520,20 @@
 
 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
 
-val _ = List.app OuterKeyword.keyword ["depth_limit"]
+val _ = List.app OuterKeyword.keyword ["depth_limit", "random"]
 
-val opt_depth_limit =
-  Scan.optional (P.$$$ "[" |-- P.$$$ "depth_limit" |-- P.$$$ "=" |-- P.nat --| P.$$$ "]" >> SOME) NONE
+val options =
+  let
+    val depth_limit = Scan.optional (P.$$$ "depth_limit" |-- P.$$$ "=" |-- P.nat >> SOME) NONE
+    val random = Scan.optional (P.$$$ "random" >> K true) false
+  in
+    Scan.optional (P.$$$ "[" |-- depth_limit -- random --| P.$$$ "]") (NONE, false)
+  end
 
 val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag
-  (opt_modes -- opt_depth_limit -- Scan.optional P.nat ~1 -- P.term
-    >> (fn (((modes, depth_limit), k), t) => Toplevel.no_timing o Toplevel.keep
-        (values_cmd modes depth_limit k t)));
+  (opt_modes -- options -- Scan.optional P.nat ~1 -- P.term
+    >> (fn (((modes, options), k), t) => Toplevel.no_timing o Toplevel.keep
+        (values_cmd modes options k t)));
 
 end;