# HG changeset patch # User bulwahn # Date 1271844652 -7200 # Node ID 95ef0a3cf31c9e405d2c4c0e4a790ca4c5fbe612 # Parent 6e969ce3dfcc5f508a7b20e34e797f4feb244474 added switch detection to the predicate compiler diff -r 6e969ce3dfcc -r 95ef0a3cf31c src/HOL/Tools/Predicate_Compile/predicate_compile.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML Wed Apr 21 12:10:52 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML Wed Apr 21 12:10:52 2010 +0200 @@ -162,6 +162,7 @@ no_topmost_reordering = (chk "no_topmost_reordering"), no_higher_order_predicate = [], inductify = chk "inductify", + detect_switches = chk "detect_switches", compilation = compilation } end diff -r 6e969ce3dfcc -r 95ef0a3cf31c src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Wed Apr 21 12:10:52 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Wed Apr 21 12:10:52 2010 +0200 @@ -39,6 +39,7 @@ datatype indprem = Prem of term | Negprem of term | Sidecond of term | Generator of (string * typ) val dest_indprem : indprem -> term + val map_indprem : (term -> term) -> indprem -> indprem (* general syntactic functions *) val conjuncts : term -> term list val is_equationlike : thm -> bool @@ -111,6 +112,7 @@ specialise : bool, no_higher_order_predicate : string list, inductify : bool, + detect_switches : bool, compilation : compilation }; val expected_modes : options -> (string * mode list) option @@ -130,6 +132,7 @@ val specialise : options -> bool val no_higher_order_predicate : options -> string list val is_inductify : options -> bool + val detect_switches : options -> bool val compilation : options -> compilation val default_options : options val bool_options : string list @@ -394,6 +397,11 @@ | dest_indprem (Sidecond t) = t | dest_indprem (Generator _) = raise Fail "cannot destruct generator" +fun map_indprem f (Prem t) = Prem (f t) + | map_indprem f (Negprem t) = Negprem (f t) + | map_indprem f (Sidecond t) = Sidecond (f t) + | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T)))) + (* general syntactic functions *) (*Like dest_conj, but flattens conjunctions however nested*) @@ -677,6 +685,7 @@ fail_safe_function_flattening : bool, no_higher_order_predicate : string list, inductify : bool, + detect_switches : bool, compilation : compilation }; @@ -705,6 +714,8 @@ fun compilation (Options opt) = #compilation opt +fun detect_switches (Options opt) = #detect_switches opt + val default_options = Options { expected_modes = NONE, proposed_modes = NONE, @@ -723,12 +734,13 @@ fail_safe_function_flattening = false, no_higher_order_predicate = [], inductify = false, + detect_switches = true, compilation = Pred } val bool_options = ["show_steps", "show_intermediate_results", "show_proof_trace", "show_modes", "show_mode_inference", "show_compilation", "skip_proof", "inductify", "no_function_flattening", - "specialise", "no_topmost_reordering"] + "detect_switches", "specialise", "no_topmost_reordering"] fun print_step options s = if show_steps options then tracing s else () diff -r 6e969ce3dfcc -r 95ef0a3cf31c src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 21 12:10:52 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Wed Apr 21 12:10:52 2010 +0200 @@ -1541,7 +1541,6 @@ val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then set_needs_random s (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms) else I) modes thy - in ((moded_clauses, errors), thy') end; @@ -1645,7 +1644,7 @@ val name' = Name.variant (name :: names) "y"; val T = HOLogic.mk_tupleT (map fastype_of out_ts); val U = fastype_of success_t; - val U' = dest_predT compfuns U; + val U' = dest_predT compfuns U; val v = Free (name, T); val v' = Free (name', T); in @@ -1705,13 +1704,12 @@ end fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments - mode inp (ts, moded_ps) = + mode inp (in_ts, out_ts) moded_ps = let val compfuns = Comp_Mod.compfuns compilation_modifiers - val iss = ho_arg_modes_of mode + val iss = ho_arg_modes_of mode (* FIXME! *) val compile_match = compile_match compilation_modifiers additional_arguments param_vs iss ctxt - val (in_ts, out_ts) = split_mode mode ts; val (in_ts', (all_vs', eqs)) = fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []); fun compile_prems out_ts' vs names [] = @@ -1779,7 +1777,171 @@ mk_bind compfuns (mk_single compfuns inp, prem_t) end -fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls = +(* switch detection *) + +(** argument position of an inductive predicates and the executable functions **) + +type position = int * int list + +fun input_positions_pair Input = [[]] + | input_positions_pair Output = [] + | input_positions_pair (Fun _) = [] + | input_positions_pair (Pair (m1, m2)) = + map (cons 1) (input_positions_pair m1) @ map (cons 2) (input_positions_pair m2) + +fun input_positions_of_mode mode = flat (map_index + (fn (i, Input) => [(i, [])] + | (_, Output) => [] + | (_, Fun _) => [] + | (i, m as Pair (m1, m2)) => map (pair i) (input_positions_pair m)) + (Predicate_Compile_Aux.strip_fun_mode mode)) + +fun argument_position_pair mode [] = [] + | argument_position_pair (Pair (Fun _, m2)) (2 :: is) = argument_position_pair m2 is + | argument_position_pair (Pair (m1, m2)) (i :: is) = + (if eq_mode (m1, Output) andalso i = 2 then + argument_position_pair m2 is + else if eq_mode (m2, Output) andalso i = 1 then + argument_position_pair m1 is + else (i :: argument_position_pair (if i = 1 then m1 else m2) is)) + +fun argument_position_of mode (i, is) = + (i - (length (filter (fn Output => true | Fun _ => true | _ => false) + (List.take (strip_fun_mode mode, i)))), + argument_position_pair (nth (strip_fun_mode mode) i) is) + +fun nth_pair [] t = t + | nth_pair (1 :: is) (Const (@{const_name Pair}, _) $ t1 $ _) = nth_pair is t1 + | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2 + | nth_pair _ _ = raise Fail "unexpected input for nth_tuple" + +(** switch detection analysis **) + +fun find_switch_test thy (i, is) (ts, prems) = + let + val t = nth_pair is (nth ts i) + val T = fastype_of t + in + case T of + TFree _ => NONE + | Type (Tcon, _) => + (case Datatype_Data.get_constrs thy Tcon of + NONE => NONE + | SOME cs => + (case strip_comb t of + (Var _, []) => NONE + | (Free _, []) => NONE + | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE)) + end + +fun partition_clause thy pos moded_clauses = + let + fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value) + fun find_switch_test' moded_clause (cases, left) = + case find_switch_test thy pos moded_clause of + SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left) + | NONE => (cases, moded_clause :: left) + in + fold find_switch_test' moded_clauses ([], []) + end + +datatype switch_tree = + Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree + +fun mk_switch_tree thy mode moded_clauses = + let + fun select_best_switch moded_clauses input_position best_switch = + let + val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd))) + val partition = partition_clause thy input_position moded_clauses + val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE + in + case ord (switch, best_switch) of LESS => best_switch + | EQUAL => best_switch | GREATER => switch + end + fun detect_switches moded_clauses = + case fold (select_best_switch moded_clauses) (input_positions_of_mode mode) NONE of + SOME (best_pos, (switched_on, left_clauses)) => + Node ((best_pos, map (apsnd detect_switches) switched_on), + detect_switches left_clauses) + | NONE => Atom moded_clauses + in + detect_switches moded_clauses + end + +(** compilation of detected switches **) + +fun destruct_constructor_pattern (pat, obj) = + (case strip_comb pat of + (f as Free _, []) => cons (pat, obj) + | (Const (c, T), pat_args) => + (case strip_comb obj of + (Const (c', T'), obj_args) => + (if c = c' andalso T = T' then + fold destruct_constructor_pattern (pat_args ~~ obj_args) + else raise Fail "pattern and object mismatch") + | _ => raise Fail "unexpected object") + | _ => raise Fail "unexpected pattern") + + +fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode + in_ts' outTs switch_tree = + let + val compfuns = Comp_Mod.compfuns compilation_modifiers + val thy = ProofContext.theory_of ctxt + fun compile_switch_tree _ _ (Atom []) = NONE + | compile_switch_tree all_vs ctxt_eqs (Atom moded_clauses) = + let + val in_ts' = map (Pattern.rewrite_term thy ctxt_eqs []) in_ts' + fun compile_clause' (ts, moded_ps) = + let + val (ts, out_ts) = split_mode mode ts + val subst = fold destruct_constructor_pattern (in_ts' ~~ ts) [] + val (fsubst, pat') = List.partition (fn (_, Free _) => true | _ => false) subst + val moded_ps' = (map o apfst o map_indprem) + (Pattern.rewrite_term thy (map swap fsubst) []) moded_ps + val inp = HOLogic.mk_tuple (map fst pat') + val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat') + val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts + in + compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments + mode inp (in_ts', out_ts') moded_ps' + end + in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end + | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) = + let + val (i, is) = argument_position_of mode position + val inp_var = nth_pair is (nth in_ts' i) + val x = Name.variant all_vs "x" + val xt = Free (x, fastype_of inp_var) + fun compile_single_case ((c, T), switched) = + let + val Ts = binder_types T + val argnames = Name.variant_list (x :: all_vs) + (map (fn i => "c" ^ string_of_int i) (1 upto length Ts)) + val args = map2 (curry Free) argnames Ts + val pattern = list_comb (Const (c, T), args) + val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs + val compilation = the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs)) + (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched) + in + (pattern, compilation) + end + val switch = fst (Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var + ((map compile_single_case switched_clauses) @ + [(xt, mk_bot compfuns (HOLogic.mk_tupleT outTs))])) + in + case compile_switch_tree all_vs ctxt_eqs left_clauses of + NONE => SOME switch + | SOME left_comp => SOME (mk_sup compfuns (switch, left_comp)) + end + in + compile_switch_tree all_vs [] switch_tree + end + +(* compilation of predicates *) + +fun compile_pred options compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls = let val ctxt = ProofContext.init thy val compilation_modifiers = if pol then compilation_modifiers else @@ -1794,7 +1956,6 @@ (binder_types T) val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs) val funT = Comp_Mod.funT_of compilation_modifiers mode T - val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod) (fn T => fn (param_vs, names) => if is_param_type T then @@ -1806,14 +1967,24 @@ (param_vs, (all_vs @ param_vs)) val in_ts' = map_filter (map_filter_prod (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts - val cl_ts = - map (compile_clause compilation_modifiers - ctxt all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts')) moded_cls; - val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns - s T mode additional_arguments - (if null cl_ts then - mk_bot compfuns (HOLogic.mk_tupleT outTs) - else foldr1 (mk_sup compfuns) cl_ts) + val compilation = + if detect_switches options then + the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs)) + (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments + mode in_ts' outTs (mk_switch_tree thy mode moded_cls)) + else + let + val cl_ts = + map (fn (ts, moded_prems) => + compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments + mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls; + in + Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments + (if null cl_ts then + mk_bot compfuns (HOLogic.mk_tupleT outTs) + else + foldr1 (mk_sup compfuns) cl_ts) + end val fun_const = Const (function_name_of (Comp_Mod.compilation compilation_modifiers) (ProofContext.theory_of ctxt) s mode, funT) @@ -1822,7 +1993,7 @@ (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation)) end; -(* special setup for simpset *) +(** special setup for simpset **) val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}]) setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)) setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI})) @@ -2463,13 +2634,13 @@ fun join_preds_modes table1 table2 = map_preds_modes (fn pred => fn mode => fn value => (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1 - + fun maps_modes preds_modes_table = map (fn (pred, modes) => (pred, map (fn (mode, value) => value) modes)) preds_modes_table - -fun compile_preds comp_modifiers thy all_vs param_vs preds moded_clauses = - map_preds_modes (fn pred => compile_pred comp_modifiers thy all_vs param_vs pred + +fun compile_preds options comp_modifiers thy all_vs param_vs preds moded_clauses = + map_preds_modes (fn pred => compile_pred options comp_modifiers thy all_vs param_vs pred (the (AList.lookup (op =) preds pred))) moded_clauses fun prove options thy clauses preds moded_clauses compiled_terms = @@ -2482,7 +2653,6 @@ compiled_terms (* preparation of introduction rules into special datastructures *) - fun dest_prem thy params t = (case strip_comb t of (v as Free _, ts) => if member (op =) params v then Prem t else Sidecond t @@ -2626,9 +2796,6 @@ datatype steps = Steps of { define_functions : options -> (string * typ) list -> string * (bool * mode) list -> theory -> theory, - (*infer_modes : options -> (string * typ) list -> (string * mode list) list - -> string list -> (string * (term list * indprem list) list) list - -> theory -> ((moded_clause list pred_mode_table * string list) * theory),*) prove : options -> theory -> (string * (term list * indprem list) list) list -> (string * typ) list -> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table, add_code_equations : theory -> (string * typ) list @@ -2669,8 +2836,9 @@ |> Theory.checkpoint) val _ = print_step options "Compiling equations..." val compiled_terms = - Output.cond_timeit true "Compiling equations...." (fn _ => - compile_preds (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses) + (*Output.cond_timeit true "Compiling equations...." (fn _ =>*) + compile_preds options + (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses val _ = print_compiled_terms options thy'' compiled_terms val _ = print_step options "Proving equations..." val result_thms =