--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Wed Apr 21 12:10:52 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Wed Apr 21 12:10:52 2010 +0200
@@ -39,6 +39,7 @@
datatype indprem = Prem of term | Negprem of term | Sidecond of term
| Generator of (string * typ)
val dest_indprem : indprem -> term
+ val map_indprem : (term -> term) -> indprem -> indprem
(* general syntactic functions *)
val conjuncts : term -> term list
val is_equationlike : thm -> bool
@@ -111,6 +112,7 @@
specialise : bool,
no_higher_order_predicate : string list,
inductify : bool,
+ detect_switches : bool,
compilation : compilation
};
val expected_modes : options -> (string * mode list) option
@@ -130,6 +132,7 @@
val specialise : options -> bool
val no_higher_order_predicate : options -> string list
val is_inductify : options -> bool
+ val detect_switches : options -> bool
val compilation : options -> compilation
val default_options : options
val bool_options : string list
@@ -394,6 +397,11 @@
| dest_indprem (Sidecond t) = t
| dest_indprem (Generator _) = raise Fail "cannot destruct generator"
+fun map_indprem f (Prem t) = Prem (f t)
+ | map_indprem f (Negprem t) = Negprem (f t)
+ | map_indprem f (Sidecond t) = Sidecond (f t)
+ | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
+
(* general syntactic functions *)
(*Like dest_conj, but flattens conjunctions however nested*)
@@ -677,6 +685,7 @@
fail_safe_function_flattening : bool,
no_higher_order_predicate : string list,
inductify : bool,
+ detect_switches : bool,
compilation : compilation
};
@@ -705,6 +714,8 @@
fun compilation (Options opt) = #compilation opt
+fun detect_switches (Options opt) = #detect_switches opt
+
val default_options = Options {
expected_modes = NONE,
proposed_modes = NONE,
@@ -723,12 +734,13 @@
fail_safe_function_flattening = false,
no_higher_order_predicate = [],
inductify = false,
+ detect_switches = true,
compilation = Pred
}
val bool_options = ["show_steps", "show_intermediate_results", "show_proof_trace", "show_modes",
"show_mode_inference", "show_compilation", "skip_proof", "inductify", "no_function_flattening",
- "specialise", "no_topmost_reordering"]
+ "detect_switches", "specialise", "no_topmost_reordering"]
fun print_step options s =
if show_steps options then tracing s else ()
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 21 12:10:52 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 21 12:10:52 2010 +0200
@@ -1541,7 +1541,6 @@
val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then
set_needs_random s (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms) else I)
modes thy
-
in
((moded_clauses, errors), thy')
end;
@@ -1645,7 +1644,7 @@
val name' = Name.variant (name :: names) "y";
val T = HOLogic.mk_tupleT (map fastype_of out_ts);
val U = fastype_of success_t;
- val U' = dest_predT compfuns U;
+ val U' = dest_predT compfuns U;
val v = Free (name, T);
val v' = Free (name', T);
in
@@ -1705,13 +1704,12 @@
end
fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
- mode inp (ts, moded_ps) =
+ mode inp (in_ts, out_ts) moded_ps =
let
val compfuns = Comp_Mod.compfuns compilation_modifiers
- val iss = ho_arg_modes_of mode
+ val iss = ho_arg_modes_of mode (* FIXME! *)
val compile_match = compile_match compilation_modifiers
additional_arguments param_vs iss ctxt
- val (in_ts, out_ts) = split_mode mode ts;
val (in_ts', (all_vs', eqs)) =
fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
fun compile_prems out_ts' vs names [] =
@@ -1779,7 +1777,171 @@
mk_bind compfuns (mk_single compfuns inp, prem_t)
end
-fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
+(* switch detection *)
+
+(** argument position of an inductive predicates and the executable functions **)
+
+type position = int * int list
+
+fun input_positions_pair Input = [[]]
+ | input_positions_pair Output = []
+ | input_positions_pair (Fun _) = []
+ | input_positions_pair (Pair (m1, m2)) =
+ map (cons 1) (input_positions_pair m1) @ map (cons 2) (input_positions_pair m2)
+
+fun input_positions_of_mode mode = flat (map_index
+ (fn (i, Input) => [(i, [])]
+ | (_, Output) => []
+ | (_, Fun _) => []
+ | (i, m as Pair (m1, m2)) => map (pair i) (input_positions_pair m))
+ (Predicate_Compile_Aux.strip_fun_mode mode))
+
+fun argument_position_pair mode [] = []
+ | argument_position_pair (Pair (Fun _, m2)) (2 :: is) = argument_position_pair m2 is
+ | argument_position_pair (Pair (m1, m2)) (i :: is) =
+ (if eq_mode (m1, Output) andalso i = 2 then
+ argument_position_pair m2 is
+ else if eq_mode (m2, Output) andalso i = 1 then
+ argument_position_pair m1 is
+ else (i :: argument_position_pair (if i = 1 then m1 else m2) is))
+
+fun argument_position_of mode (i, is) =
+ (i - (length (filter (fn Output => true | Fun _ => true | _ => false)
+ (List.take (strip_fun_mode mode, i)))),
+ argument_position_pair (nth (strip_fun_mode mode) i) is)
+
+fun nth_pair [] t = t
+ | nth_pair (1 :: is) (Const (@{const_name Pair}, _) $ t1 $ _) = nth_pair is t1
+ | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2
+ | nth_pair _ _ = raise Fail "unexpected input for nth_tuple"
+
+(** switch detection analysis **)
+
+fun find_switch_test thy (i, is) (ts, prems) =
+ let
+ val t = nth_pair is (nth ts i)
+ val T = fastype_of t
+ in
+ case T of
+ TFree _ => NONE
+ | Type (Tcon, _) =>
+ (case Datatype_Data.get_constrs thy Tcon of
+ NONE => NONE
+ | SOME cs =>
+ (case strip_comb t of
+ (Var _, []) => NONE
+ | (Free _, []) => NONE
+ | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
+ end
+
+fun partition_clause thy pos moded_clauses =
+ let
+ fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value)
+ fun find_switch_test' moded_clause (cases, left) =
+ case find_switch_test thy pos moded_clause of
+ SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left)
+ | NONE => (cases, moded_clause :: left)
+ in
+ fold find_switch_test' moded_clauses ([], [])
+ end
+
+datatype switch_tree =
+ Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree
+
+fun mk_switch_tree thy mode moded_clauses =
+ let
+ fun select_best_switch moded_clauses input_position best_switch =
+ let
+ val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd)))
+ val partition = partition_clause thy input_position moded_clauses
+ val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE
+ in
+ case ord (switch, best_switch) of LESS => best_switch
+ | EQUAL => best_switch | GREATER => switch
+ end
+ fun detect_switches moded_clauses =
+ case fold (select_best_switch moded_clauses) (input_positions_of_mode mode) NONE of
+ SOME (best_pos, (switched_on, left_clauses)) =>
+ Node ((best_pos, map (apsnd detect_switches) switched_on),
+ detect_switches left_clauses)
+ | NONE => Atom moded_clauses
+ in
+ detect_switches moded_clauses
+ end
+
+(** compilation of detected switches **)
+
+fun destruct_constructor_pattern (pat, obj) =
+ (case strip_comb pat of
+ (f as Free _, []) => cons (pat, obj)
+ | (Const (c, T), pat_args) =>
+ (case strip_comb obj of
+ (Const (c', T'), obj_args) =>
+ (if c = c' andalso T = T' then
+ fold destruct_constructor_pattern (pat_args ~~ obj_args)
+ else raise Fail "pattern and object mismatch")
+ | _ => raise Fail "unexpected object")
+ | _ => raise Fail "unexpected pattern")
+
+
+fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode
+ in_ts' outTs switch_tree =
+ let
+ val compfuns = Comp_Mod.compfuns compilation_modifiers
+ val thy = ProofContext.theory_of ctxt
+ fun compile_switch_tree _ _ (Atom []) = NONE
+ | compile_switch_tree all_vs ctxt_eqs (Atom moded_clauses) =
+ let
+ val in_ts' = map (Pattern.rewrite_term thy ctxt_eqs []) in_ts'
+ fun compile_clause' (ts, moded_ps) =
+ let
+ val (ts, out_ts) = split_mode mode ts
+ val subst = fold destruct_constructor_pattern (in_ts' ~~ ts) []
+ val (fsubst, pat') = List.partition (fn (_, Free _) => true | _ => false) subst
+ val moded_ps' = (map o apfst o map_indprem)
+ (Pattern.rewrite_term thy (map swap fsubst) []) moded_ps
+ val inp = HOLogic.mk_tuple (map fst pat')
+ val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat')
+ val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
+ in
+ compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
+ mode inp (in_ts', out_ts') moded_ps'
+ end
+ in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
+ | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
+ let
+ val (i, is) = argument_position_of mode position
+ val inp_var = nth_pair is (nth in_ts' i)
+ val x = Name.variant all_vs "x"
+ val xt = Free (x, fastype_of inp_var)
+ fun compile_single_case ((c, T), switched) =
+ let
+ val Ts = binder_types T
+ val argnames = Name.variant_list (x :: all_vs)
+ (map (fn i => "c" ^ string_of_int i) (1 upto length Ts))
+ val args = map2 (curry Free) argnames Ts
+ val pattern = list_comb (Const (c, T), args)
+ val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs
+ val compilation = the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
+ (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched)
+ in
+ (pattern, compilation)
+ end
+ val switch = fst (Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var
+ ((map compile_single_case switched_clauses) @
+ [(xt, mk_bot compfuns (HOLogic.mk_tupleT outTs))]))
+ in
+ case compile_switch_tree all_vs ctxt_eqs left_clauses of
+ NONE => SOME switch
+ | SOME left_comp => SOME (mk_sup compfuns (switch, left_comp))
+ end
+ in
+ compile_switch_tree all_vs [] switch_tree
+ end
+
+(* compilation of predicates *)
+
+fun compile_pred options compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
let
val ctxt = ProofContext.init thy
val compilation_modifiers = if pol then compilation_modifiers else
@@ -1794,7 +1956,6 @@
(binder_types T)
val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs)
val funT = Comp_Mod.funT_of compilation_modifiers mode T
-
val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
(fn T => fn (param_vs, names) =>
if is_param_type T then
@@ -1806,14 +1967,24 @@
(param_vs, (all_vs @ param_vs))
val in_ts' = map_filter (map_filter_prod
(fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
- val cl_ts =
- map (compile_clause compilation_modifiers
- ctxt all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts')) moded_cls;
- val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns
- s T mode additional_arguments
- (if null cl_ts then
- mk_bot compfuns (HOLogic.mk_tupleT outTs)
- else foldr1 (mk_sup compfuns) cl_ts)
+ val compilation =
+ if detect_switches options then
+ the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
+ (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
+ mode in_ts' outTs (mk_switch_tree thy mode moded_cls))
+ else
+ let
+ val cl_ts =
+ map (fn (ts, moded_prems) =>
+ compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
+ mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
+ in
+ Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
+ (if null cl_ts then
+ mk_bot compfuns (HOLogic.mk_tupleT outTs)
+ else
+ foldr1 (mk_sup compfuns) cl_ts)
+ end
val fun_const =
Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
(ProofContext.theory_of ctxt) s mode, funT)
@@ -1822,7 +1993,7 @@
(HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation))
end;
-(* special setup for simpset *)
+(** special setup for simpset **)
val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}])
setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
@@ -2463,13 +2634,13 @@
fun join_preds_modes table1 table2 =
map_preds_modes (fn pred => fn mode => fn value =>
(value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1
-
+
fun maps_modes preds_modes_table =
map (fn (pred, modes) =>
(pred, map (fn (mode, value) => value) modes)) preds_modes_table
-
-fun compile_preds comp_modifiers thy all_vs param_vs preds moded_clauses =
- map_preds_modes (fn pred => compile_pred comp_modifiers thy all_vs param_vs pred
+
+fun compile_preds options comp_modifiers thy all_vs param_vs preds moded_clauses =
+ map_preds_modes (fn pred => compile_pred options comp_modifiers thy all_vs param_vs pred
(the (AList.lookup (op =) preds pred))) moded_clauses
fun prove options thy clauses preds moded_clauses compiled_terms =
@@ -2482,7 +2653,6 @@
compiled_terms
(* preparation of introduction rules into special datastructures *)
-
fun dest_prem thy params t =
(case strip_comb t of
(v as Free _, ts) => if member (op =) params v then Prem t else Sidecond t
@@ -2626,9 +2796,6 @@
datatype steps = Steps of
{
define_functions : options -> (string * typ) list -> string * (bool * mode) list -> theory -> theory,
- (*infer_modes : options -> (string * typ) list -> (string * mode list) list
- -> string list -> (string * (term list * indprem list) list) list
- -> theory -> ((moded_clause list pred_mode_table * string list) * theory),*)
prove : options -> theory -> (string * (term list * indprem list) list) list -> (string * typ) list
-> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
add_code_equations : theory -> (string * typ) list
@@ -2669,8 +2836,9 @@
|> Theory.checkpoint)
val _ = print_step options "Compiling equations..."
val compiled_terms =
- Output.cond_timeit true "Compiling equations...." (fn _ =>
- compile_preds (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses)
+ (*Output.cond_timeit true "Compiling equations...." (fn _ =>*)
+ compile_preds options
+ (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses
val _ = print_compiled_terms options thy'' compiled_terms
val _ = print_step options "Proving equations..."
val result_thms =