--- a/src/HOL/Tools/inductive_codegen.ML Thu Nov 19 17:53:22 2009 -0800
+++ b/src/HOL/Tools/inductive_codegen.ML Thu Nov 19 20:09:56 2009 -0800
@@ -7,7 +7,10 @@
signature INDUCTIVE_CODEGEN =
sig
val add : string option -> int option -> attribute
+ val test_fn : (int * int * int -> term list option) Unsynchronized.ref
+ val test_term: Proof.context -> term -> int -> term list option
val setup : theory -> theory
+ val quickcheck_setup : theory -> theory
end;
structure InductiveCodegen : INDUCTIVE_CODEGEN =
@@ -124,7 +127,8 @@
fun print_modes modes = message ("Inferred modes:\n" ^
cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
- string_of_mode ms)) modes));
+ (fn (m, rnd) => string_of_mode m ^
+ (if rnd then " (random)" else "")) ms)) modes));
val term_vs = map (fst o fst o dest_Var) o OldTerm.term_vars;
val terms_vs = distinct (op =) o maps term_vs;
@@ -152,14 +156,17 @@
fun cprods xss = List.foldr (map op :: o cprod) [[]] xss;
-datatype mode = Mode of (int list option list * int list) * int list * mode option list;
+datatype mode = Mode of ((int list option list * int list) * bool) * int list * mode option list;
+
+fun needs_random (Mode ((_, b), _, ms)) =
+ b orelse exists (fn NONE => false | SOME m => needs_random m) ms;
fun modes_of modes t =
let
val ks = 1 upto length (binder_types (fastype_of t));
- val default = [Mode (([], ks), ks, [])];
+ val default = [Mode ((([], ks), false), ks, [])];
fun mk_modes name args = Option.map
- (maps (fn (m as (iss, is)) =>
+ (maps (fn (m as ((iss, is), _)) =>
let
val (args1, args2) =
if length args < length iss then
@@ -180,8 +187,8 @@
in (case strip_comb t of
(Const ("op =", Type (_, [T, _])), _) =>
- [Mode (([], [1]), [1], []), Mode (([], [2]), [2], [])] @
- (if is_eqT T then [Mode (([], [1, 2]), [1, 2], [])] else [])
+ [Mode ((([], [1]), false), [1], []), Mode ((([], [2]), false), [2], [])] @
+ (if is_eqT T then [Mode ((([], [1, 2]), false), [1, 2], [])] else [])
| (Const (name, _), args) => the_default default (mk_modes name args)
| (Var ((name, _), _), args) => the (mk_modes name args)
| (Free (name, _), args) => the (mk_modes name args)
@@ -190,68 +197,101 @@
datatype indprem = Prem of term list * term * bool | Sidecond of term;
+fun missing_vars vs ts = subtract (fn (x, ((y, _), _)) => x = y) vs
+ (fold Term.add_vars ts []);
+
+fun monomorphic_vars vs = null (fold (Term.add_tvarsT o snd) vs []);
+
+fun mode_ord p = int_ord (pairself (fn (Mode ((_, rnd), _, _), vs) =>
+ length vs + (if null vs then 0 else 1) + (if rnd then 1 else 0)) p);
+
fun select_mode_prem thy modes vs ps =
- find_first (is_some o snd) (ps ~~ map
- (fn Prem (us, t, is_set) => find_first (fn Mode (_, is, _) =>
- let
- val (in_ts, out_ts) = get_args is 1 us;
- val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts;
- val vTs = maps term_vTs out_ts';
- val dupTs = map snd (duplicates (op =) vTs) @
- map_filter (AList.lookup (op =) vTs) vs;
- in
- subset (op =) (terms_vs (in_ts @ in_ts'), vs) andalso
- forall (is_eqT o fastype_of) in_ts' andalso
- subset (op =) (term_vs t, vs) andalso
- forall is_eqT dupTs
- end)
- (if is_set then [Mode (([], []), [], [])]
- else modes_of modes t handle Option =>
- error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
- | Sidecond t => if subset (op =) (term_vs t, vs) then SOME (Mode (([], []), [], []))
- else NONE) ps);
+ sort (mode_ord o pairself (hd o snd))
+ (filter_out (null o snd) (ps ~~ map
+ (fn Prem (us, t, is_set) => sort mode_ord
+ (List.mapPartial (fn m as Mode (_, is, _) =>
+ let
+ val (in_ts, out_ts) = get_args is 1 us;
+ val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts;
+ val vTs = maps term_vTs out_ts';
+ val dupTs = map snd (duplicates (op =) vTs) @
+ map_filter (AList.lookup (op =) vTs) vs;
+ val missing_vs = missing_vars vs (t :: in_ts @ in_ts')
+ in
+ if forall (is_eqT o fastype_of) in_ts' andalso forall is_eqT dupTs
+ andalso monomorphic_vars missing_vs
+ then SOME (m, missing_vs)
+ else NONE
+ end)
+ (if is_set then [Mode ((([], []), false), [], [])]
+ else modes_of modes t handle Option =>
+ error ("Bad predicate: " ^ Syntax.string_of_term_global thy t)))
+ | Sidecond t =>
+ let val missing_vs = missing_vars vs [t]
+ in
+ if monomorphic_vars missing_vs
+ then [(Mode ((([], []), false), [], []), missing_vs)]
+ else []
+ end)
+ ps));
-fun check_mode_clause thy arg_vs modes (iss, is) (ts, ps) =
+fun use_random () = "random_ind" mem !Codegen.mode;
+
+fun check_mode_clause thy arg_vs modes ((iss, is), rnd) (ts, ps) =
let
val modes' = modes @ map_filter
- (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
+ (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [(([], js), false)]))
(arg_vs ~~ iss);
- fun check_mode_prems vs [] = SOME vs
- | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of
- NONE => NONE
- | SOME (x, _) => check_mode_prems
+ fun check_mode_prems vs rnd [] = SOME (vs, rnd)
+ | check_mode_prems vs rnd ps = (case select_mode_prem thy modes' vs ps of
+ (x, (m, []) :: _) :: _ => check_mode_prems
(case x of Prem (us, _, _) => union (op =) vs (terms_vs us) | _ => vs)
- (filter_out (equal x) ps));
+ (rnd orelse needs_random m)
+ (filter_out (equal x) ps)
+ | (_, (_, vs') :: _) :: _ =>
+ if use_random () then
+ check_mode_prems (union (op =) vs (map (fst o fst) vs')) true ps
+ else NONE
+ | _ => NONE);
val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (get_args is 1 ts));
val in_vs = terms_vs in_ts;
- val concl_vs = terms_vs ts
in
- forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso
- forall (is_eqT o fastype_of) in_ts' andalso
- (case check_mode_prems (union (op =) arg_vs in_vs) ps of
- NONE => false
- | SOME vs => subset (op =) (concl_vs, vs))
+ if forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso
+ forall (is_eqT o fastype_of) in_ts'
+ then (case check_mode_prems (union (op =) arg_vs in_vs) rnd ps of
+ NONE => NONE
+ | SOME (vs, rnd') =>
+ let val missing_vs = missing_vars vs ts
+ in
+ if null missing_vs orelse
+ use_random () andalso monomorphic_vars missing_vs
+ then SOME (rnd' orelse not (null missing_vs))
+ else NONE
+ end)
+ else NONE
end;
fun check_modes_pred thy arg_vs preds modes (p, ms) =
let val SOME rs = AList.lookup (op =) preds p
- in (p, filter (fn m => case find_index
- (not o check_mode_clause thy arg_vs modes m) rs of
- ~1 => true
- | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
- p ^ " violates mode " ^ string_of_mode m); false)) ms)
+ in (p, List.mapPartial (fn m as (m', _) =>
+ let val xs = map (check_mode_clause thy arg_vs modes m) rs
+ in case find_index is_none xs of
+ ~1 => SOME (m', exists (fn SOME b => b) xs)
+ | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
+ p ^ " violates mode " ^ string_of_mode m'); NONE)
+ end) ms)
end;
-fun fixp f (x : (string * (int list option list * int list) list) list) =
+fun fixp f (x : (string * ((int list option list * int list) * bool) list) list) =
let val y = f x
in if x = y then x else fixp f y end;
fun infer_modes thy extra_modes arities arg_vs preds = fixp (fn modes =>
map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
- (map (fn (s, (ks, k)) => (s, cprod (cprods (map
+ (map (fn (s, (ks, k)) => (s, map (rpair false) (cprod (cprods (map
(fn NONE => [NONE]
| SOME k' => map SOME (subsets 1 k')) ks),
- subsets 1 k))) arities);
+ subsets 1 k)))) arities);
(**** code generation ****)
@@ -318,7 +358,7 @@
apfst single (invoke_codegen thy defs dep module brack t gr)
| compile_expr _ _ _ _ _ _ (SOME _, Var ((name, _), _)) gr =
([str name], gr)
- | compile_expr thy defs dep module brack modes (SOME (Mode (mode, _, ms)), t) gr =
+ | compile_expr thy defs dep module brack modes (SOME (Mode ((mode, _), _, ms)), t) gr =
(case strip_comb t of
(Const (name, _), args) =>
if name = @{const_name "op ="} orelse AList.defined op = modes name then
@@ -344,7 +384,7 @@
fun compile_clause thy defs dep module all_vs arg_vs modes (iss, is) (ts, ps) inp gr =
let
val modes' = modes @ map_filter
- (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
+ (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [(([], js), false)]))
(arg_vs ~~ iss);
fun check_constrt t (names, eqs) =
@@ -371,24 +411,41 @@
val (out_ps', gr4) =
fold_map (invoke_codegen thy defs dep module false) out_ts''' gr3;
val (eq_ps', gr5) = fold_map compile_eq eqs' gr4;
+ val vs' = distinct (op =) (flat (vs :: map term_vs out_ts'));
+ val missing_vs = missing_vars vs' out_ts;
+ val final_p = Pretty.block
+ [str "DSeq.single", Pretty.brk 1, mk_tuple out_ps]
in
- (compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
- (Pretty.block [str "DSeq.single", Pretty.brk 1, mk_tuple out_ps])
- (exists (not o is_exhaustive) out_ts'''), gr5)
+ if null missing_vs then
+ (compile_match (snd nvs) (eq_ps @ eq_ps') out_ps'
+ final_p (exists (not o is_exhaustive) out_ts'''), gr5)
+ else
+ let
+ val (pat_p, gr6) = invoke_codegen thy defs dep module true
+ (HOLogic.mk_tuple (map Var missing_vs)) gr5;
+ val gen_p = mk_gen gr6 module true [] ""
+ (HOLogic.mk_tupleT (map snd missing_vs))
+ in
+ (compile_match (snd nvs) eq_ps' out_ps'
+ (Pretty.block [str "DSeq.generator ", gen_p,
+ str " :->", Pretty.brk 1,
+ compile_match [] eq_ps [pat_p] final_p false])
+ (exists (not o is_exhaustive) out_ts'''),
+ gr6)
+ end
end
| compile_prems out_ts vs names ps gr =
let
val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
- val SOME (p, mode as SOME (Mode (_, js, _))) = select_mode_prem thy modes' vs' ps;
- val ps' = filter_out (equal p) ps;
val (out_ts', (names', eqs)) = fold_map check_constrt out_ts (names, []);
val (out_ts'', nvs) = fold_map distinct_v out_ts' (names', map (fn x => (x, [x])) vs);
val (out_ps, gr0) = fold_map (invoke_codegen thy defs dep module false) out_ts'' gr;
val (eq_ps, gr1) = fold_map compile_eq eqs gr0;
in
- (case p of
- Prem (us, t, is_set) =>
+ (case hd (select_mode_prem thy modes' vs' ps) of
+ (p as Prem (us, t, is_set), (mode as Mode (_, js, _), []) :: _) =>
let
+ val ps' = filter_out (equal p) ps;
val (in_ts, out_ts''') = get_args js 1 us;
val (in_ps, gr2) = fold_map
(invoke_codegen thy defs dep module true) in_ts gr1;
@@ -398,7 +455,7 @@
(if null in_ps then [] else [Pretty.brk 1]) @
separate (Pretty.brk 1) in_ps)
(compile_expr thy defs dep module false modes
- (mode, t) gr2)
+ (SOME mode, t) gr2)
else
apfst (fn p => Pretty.breaks [str "DSeq.of_list", str "(case", p,
str "of", str "Set", str "xs", str "=>", str "xs)"])
@@ -411,8 +468,9 @@
[str " :->", Pretty.brk 1, rest]))
(exists (not o is_exhaustive) out_ts''), gr4)
end
- | Sidecond t =>
+ | (p as Sidecond t, [(_, [])]) =>
let
+ val ps' = filter_out (equal p) ps;
val (side_p, gr2) = invoke_codegen thy defs dep module true t gr1;
val (rest, gr3) = compile_prems [] vs' (fst nvs) ps' gr2;
in
@@ -420,6 +478,19 @@
(Pretty.block [str "?? ", side_p,
str " :->", Pretty.brk 1, rest])
(exists (not o is_exhaustive) out_ts''), gr3)
+ end
+ | (_, (_, missing_vs) :: _) =>
+ let
+ val T = HOLogic.mk_tupleT (map snd missing_vs);
+ val (_, gr2) = invoke_tycodegen thy defs dep module false T gr1;
+ val gen_p = mk_gen gr2 module true [] "" T;
+ val (rest, gr3) = compile_prems
+ [HOLogic.mk_tuple (map Var missing_vs)] vs' (fst nvs) ps gr2
+ in
+ (compile_match (snd nvs) eq_ps out_ps
+ (Pretty.block [str "DSeq.generator", Pretty.brk 1,
+ gen_p, str " :->", Pretty.brk 1, rest])
+ (exists (not o is_exhaustive) out_ts''), gr3)
end)
end;
@@ -450,7 +521,7 @@
fun compile_preds thy defs dep module all_vs arg_vs modes preds gr =
let val (prs, (gr', _)) = fold_map (fn (s, cls) =>
- fold_map (fn mode => fn (gr', prfx') => compile_pred thy defs
+ fold_map (fn (mode, _) => fn (gr', prfx') => compile_pred thy defs
dep module prfx' all_vs arg_vs modes s cls mode gr')
(((the o AList.lookup (op =) modes) s))) preds (gr, "fun ")
in
@@ -460,7 +531,7 @@
(**** processing of introduction rules ****)
exception Modes of
- (string * (int list option list * int list) list) list *
+ (string * ((int list option list * int list) * bool) list) list *
(string * (int option list * int)) list;
fun lookup_modes gr dep = apfst flat (apsnd flat (ListPair.unzip
@@ -480,7 +551,7 @@
(s,
case AList.lookup (op =) cs (s : string) of
NONE => xs
- | SOME xs' => inter (op =) xs' xs) :: constrain cs ys;
+ | SOME xs' => inter (op = o apfst fst) xs' xs) :: constrain cs ys;
fun mk_extra_defs thy defs gr dep names module ts =
fold (fn name => fn gr =>
@@ -573,6 +644,8 @@
if is_query then fst (fold mk_mode ts2 (([], []), 1))
else (ts2, 1 upto length (binder_types T) - k);
val mode = find_mode gr1 dep s u modes is;
+ val _ = if is_query orelse not (needs_random (the mode)) then ()
+ else warning ("Illegal use of random data generators in " ^ s);
val (in_ps, gr2) = fold_map (invoke_codegen thy defs dep module true) ts' gr1;
val (ps, gr3) = compile_expr thy defs dep module false modes (mode, u) gr2;
in
@@ -700,4 +773,91 @@
Scan.option (Args.$$$ "params" |-- Args.colon |-- OuterParse.nat) >> uncurry add))
"introduction rules for executable predicates";
+(**** Quickcheck involving inductive predicates ****)
+
+val test_fn : (int * int * int -> term list option) Unsynchronized.ref =
+ Unsynchronized.ref (fn _ => NONE);
+
+fun strip_imp p =
+ let val (q, r) = HOLogic.dest_imp p
+ in strip_imp r |>> cons q end
+ handle TERM _ => ([], p);
+
+fun deepen bound f i =
+ if i > bound then NONE
+ else (case f i of
+ NONE => deepen bound f (i + 1)
+ | SOME x => SOME x);
+
+val depth_bound_value =
+ Config.declare false "ind_quickcheck_depth" (Config.Int 10);
+val depth_bound = Config.int depth_bound_value;
+
+val depth_start_value =
+ Config.declare false "ind_quickcheck_depth_start" (Config.Int 1);
+val depth_start = Config.int depth_start_value;
+
+val random_values_value =
+ Config.declare false "ind_quickcheck_random" (Config.Int 5);
+val random_values = Config.int random_values_value;
+
+val size_offset_value =
+ Config.declare false "ind_quickcheck_size_offset" (Config.Int 0);
+val size_offset = Config.int size_offset_value;
+
+fun test_term ctxt t =
+ let
+ val thy = ProofContext.theory_of ctxt;
+ val (xs, p) = strip_abs t;
+ val args' = map_index (fn (i, (_, T)) => ("arg" ^ string_of_int i, T)) xs;
+ val args = map Free args';
+ val (ps, q) = strip_imp p;
+ val Ts = map snd xs;
+ val T = Ts ---> HOLogic.boolT;
+ val rl = Logic.list_implies
+ (map (HOLogic.mk_Trueprop o curry subst_bounds (rev args)) ps @
+ [HOLogic.mk_Trueprop (HOLogic.mk_not (subst_bounds (rev args, q)))],
+ HOLogic.mk_Trueprop (list_comb (Free ("quickcheckp", T), args)));
+ val (_, thy') = Inductive.add_inductive_global
+ {quiet_mode=true, verbose=false, alt_name=Binding.empty, coind=false,
+ no_elim=true, no_ind=false, skip_mono=false, fork_mono=false}
+ [((Binding.name "quickcheckp", T), NoSyn)] []
+ [(Attrib.empty_binding, rl)] [] (Theory.copy thy);
+ val pred = HOLogic.mk_Trueprop (list_comb
+ (Const (Sign.intern_const thy' "quickcheckp", T),
+ map Term.dummy_pattern Ts));
+ val (code, gr) = setmp_CRITICAL mode ["term_of", "test", "random_ind"]
+ (generate_code_i thy' [] "Generated") [("testf", pred)];
+ val s = "structure TestTerm =\nstruct\n\n" ^
+ cat_lines (map snd code) ^
+ "\nopen Generated;\n\n" ^ string_of
+ (Pretty.block [str "val () = InductiveCodegen.test_fn :=",
+ Pretty.brk 1, str "(fn p =>", Pretty.brk 1,
+ str "case Seq.pull (testf p) of", Pretty.brk 1,
+ str "SOME ", mk_tuple [mk_tuple (map (str o fst) args'), str "_"],
+ str " =>", Pretty.brk 1, str "SOME ",
+ Pretty.block (str "[" ::
+ Pretty.commas (map (fn (s, T) => Pretty.block
+ [mk_term_of gr "Generated" false T, Pretty.brk 1, str s]) args') @
+ [str "]"]), Pretty.brk 1,
+ str "| NONE => NONE);"]) ^
+ "\n\nend;\n";
+ val _ = ML_Context.eval_in (SOME ctxt) false Position.none s;
+ val values = Config.get ctxt random_values;
+ val bound = Config.get ctxt depth_bound;
+ val start = Config.get ctxt depth_start;
+ val offset = Config.get ctxt size_offset;
+ val test_fn' = !test_fn;
+ fun test k = deepen bound (fn i =>
+ (priority ("Search depth: " ^ string_of_int i);
+ test_fn' (i, values, k+offset))) start;
+ in test end;
+
+val quickcheck_setup =
+ Attrib.register_config depth_bound_value #>
+ Attrib.register_config depth_start_value #>
+ Attrib.register_config random_values_value #>
+ Attrib.register_config size_offset_value #>
+ Quickcheck.add_generator ("SML_inductive", test_term);
+
end;