added switch detection to the predicate compiler
authorbulwahn
Wed Apr 21 12:10:52 2010 +0200 (2010-04-21)
changeset 3625495ef0a3cf31c
parent 36253 6e969ce3dfcc
child 36255 f8b3381e1437
added switch detection to the predicate compiler
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
     1.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Wed Apr 21 12:10:52 2010 +0200
     1.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Wed Apr 21 12:10:52 2010 +0200
     1.3 @@ -162,6 +162,7 @@
     1.4        no_topmost_reordering = (chk "no_topmost_reordering"),
     1.5        no_higher_order_predicate = [],
     1.6        inductify = chk "inductify",
     1.7 +      detect_switches = chk "detect_switches",
     1.8        compilation = compilation
     1.9      }
    1.10    end
     2.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Wed Apr 21 12:10:52 2010 +0200
     2.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Wed Apr 21 12:10:52 2010 +0200
     2.3 @@ -39,6 +39,7 @@
     2.4    datatype indprem = Prem of term | Negprem of term | Sidecond of term
     2.5      | Generator of (string * typ)
     2.6    val dest_indprem : indprem -> term
     2.7 +  val map_indprem : (term -> term) -> indprem -> indprem
     2.8    (* general syntactic functions *)
     2.9    val conjuncts : term -> term list
    2.10    val is_equationlike : thm -> bool
    2.11 @@ -111,6 +112,7 @@
    2.12      specialise : bool,
    2.13      no_higher_order_predicate : string list,
    2.14      inductify : bool,
    2.15 +    detect_switches : bool,
    2.16      compilation : compilation
    2.17    };
    2.18    val expected_modes : options -> (string * mode list) option
    2.19 @@ -130,6 +132,7 @@
    2.20    val specialise : options -> bool
    2.21    val no_higher_order_predicate : options -> string list
    2.22    val is_inductify : options -> bool
    2.23 +  val detect_switches : options -> bool
    2.24    val compilation : options -> compilation
    2.25    val default_options : options
    2.26    val bool_options : string list
    2.27 @@ -394,6 +397,11 @@
    2.28    | dest_indprem (Sidecond t) = t
    2.29    | dest_indprem (Generator _) = raise Fail "cannot destruct generator"
    2.30  
    2.31 +fun map_indprem f (Prem t) = Prem (f t)
    2.32 +  | map_indprem f (Negprem t) = Negprem (f t)
    2.33 +  | map_indprem f (Sidecond t) = Sidecond (f t)
    2.34 +  | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
    2.35 +
    2.36  (* general syntactic functions *)
    2.37  
    2.38  (*Like dest_conj, but flattens conjunctions however nested*)
    2.39 @@ -677,6 +685,7 @@
    2.40    fail_safe_function_flattening : bool,
    2.41    no_higher_order_predicate : string list,
    2.42    inductify : bool,
    2.43 +  detect_switches : bool,
    2.44    compilation : compilation
    2.45  };
    2.46  
    2.47 @@ -705,6 +714,8 @@
    2.48  
    2.49  fun compilation (Options opt) = #compilation opt
    2.50  
    2.51 +fun detect_switches (Options opt) = #detect_switches opt
    2.52 +
    2.53  val default_options = Options {
    2.54    expected_modes = NONE,
    2.55    proposed_modes = NONE,
    2.56 @@ -723,12 +734,13 @@
    2.57    fail_safe_function_flattening = false,
    2.58    no_higher_order_predicate = [],
    2.59    inductify = false,
    2.60 +  detect_switches = true,
    2.61    compilation = Pred
    2.62  }
    2.63  
    2.64  val bool_options = ["show_steps", "show_intermediate_results", "show_proof_trace", "show_modes",
    2.65    "show_mode_inference", "show_compilation", "skip_proof", "inductify", "no_function_flattening",
    2.66 -  "specialise", "no_topmost_reordering"]
    2.67 +  "detect_switches", "specialise", "no_topmost_reordering"]
    2.68  
    2.69  fun print_step options s =
    2.70    if show_steps options then tracing s else ()
     3.1 --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Apr 21 12:10:52 2010 +0200
     3.2 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Apr 21 12:10:52 2010 +0200
     3.3 @@ -1541,7 +1541,6 @@
     3.4      val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then
     3.5        set_needs_random s (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms) else I)
     3.6        modes thy
     3.7 -
     3.8    in
     3.9      ((moded_clauses, errors), thy')
    3.10    end;
    3.11 @@ -1645,7 +1644,7 @@
    3.12      val name' = Name.variant (name :: names) "y";
    3.13      val T = HOLogic.mk_tupleT (map fastype_of out_ts);
    3.14      val U = fastype_of success_t;
    3.15 -    val U' = dest_predT compfuns U;        
    3.16 +    val U' = dest_predT compfuns U;
    3.17      val v = Free (name, T);
    3.18      val v' = Free (name', T);
    3.19    in
    3.20 @@ -1705,13 +1704,12 @@
    3.21    end
    3.22  
    3.23  fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
    3.24 -  mode inp (ts, moded_ps) =
    3.25 +  mode inp (in_ts, out_ts) moded_ps =
    3.26    let
    3.27      val compfuns = Comp_Mod.compfuns compilation_modifiers
    3.28 -    val iss = ho_arg_modes_of mode
    3.29 +    val iss = ho_arg_modes_of mode (* FIXME! *)
    3.30      val compile_match = compile_match compilation_modifiers
    3.31        additional_arguments param_vs iss ctxt
    3.32 -    val (in_ts, out_ts) = split_mode mode ts;
    3.33      val (in_ts', (all_vs', eqs)) =
    3.34        fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
    3.35      fun compile_prems out_ts' vs names [] =
    3.36 @@ -1779,7 +1777,171 @@
    3.37      mk_bind compfuns (mk_single compfuns inp, prem_t)
    3.38    end
    3.39  
    3.40 -fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
    3.41 +(* switch detection *)
    3.42 +
    3.43 +(** argument position of an inductive predicates and the executable functions **)
    3.44 +
    3.45 +type position = int * int list
    3.46 +
    3.47 +fun input_positions_pair Input = [[]]
    3.48 +  | input_positions_pair Output = []
    3.49 +  | input_positions_pair (Fun _) = []
    3.50 +  | input_positions_pair (Pair (m1, m2)) =
    3.51 +    map (cons 1) (input_positions_pair m1) @ map (cons 2) (input_positions_pair m2)
    3.52 +
    3.53 +fun input_positions_of_mode mode = flat (map_index
    3.54 +   (fn (i, Input) => [(i, [])]
    3.55 +   | (_, Output) => []
    3.56 +   | (_, Fun _) => []
    3.57 +   | (i, m as Pair (m1, m2)) => map (pair i) (input_positions_pair m))
    3.58 +     (Predicate_Compile_Aux.strip_fun_mode mode))
    3.59 +
    3.60 +fun argument_position_pair mode [] = []
    3.61 +  | argument_position_pair (Pair (Fun _, m2)) (2 :: is) = argument_position_pair m2 is
    3.62 +  | argument_position_pair (Pair (m1, m2)) (i :: is) =
    3.63 +    (if eq_mode (m1, Output) andalso i = 2 then
    3.64 +      argument_position_pair m2 is
    3.65 +    else if eq_mode (m2, Output) andalso i = 1 then
    3.66 +      argument_position_pair m1 is
    3.67 +    else (i :: argument_position_pair (if i = 1 then m1 else m2) is))
    3.68 +
    3.69 +fun argument_position_of mode (i, is) =
    3.70 +  (i - (length (filter (fn Output => true | Fun _ => true | _ => false)
    3.71 +    (List.take (strip_fun_mode mode, i)))),
    3.72 +  argument_position_pair (nth (strip_fun_mode mode) i) is)
    3.73 +
    3.74 +fun nth_pair [] t = t
    3.75 +  | nth_pair (1 :: is) (Const (@{const_name Pair}, _) $ t1 $ _) = nth_pair is t1
    3.76 +  | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2
    3.77 +  | nth_pair _ _ = raise Fail "unexpected input for nth_tuple"
    3.78 +
    3.79 +(** switch detection analysis **)
    3.80 +
    3.81 +fun find_switch_test thy (i, is) (ts, prems) =
    3.82 +  let
    3.83 +    val t = nth_pair is (nth ts i)
    3.84 +    val T = fastype_of t
    3.85 +  in
    3.86 +    case T of
    3.87 +      TFree _ => NONE
    3.88 +    | Type (Tcon, _) =>
    3.89 +      (case Datatype_Data.get_constrs thy Tcon of
    3.90 +        NONE => NONE
    3.91 +      | SOME cs =>
    3.92 +        (case strip_comb t of
    3.93 +          (Var _, []) => NONE
    3.94 +        | (Free _, []) => NONE
    3.95 +        | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
    3.96 +  end
    3.97 +
    3.98 +fun partition_clause thy pos moded_clauses =
    3.99 +  let
   3.100 +    fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value)
   3.101 +    fun find_switch_test' moded_clause (cases, left) =
   3.102 +      case find_switch_test thy pos moded_clause of
   3.103 +        SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left)
   3.104 +      | NONE => (cases, moded_clause :: left)
   3.105 +  in
   3.106 +    fold find_switch_test' moded_clauses ([], [])
   3.107 +  end
   3.108 +
   3.109 +datatype switch_tree =
   3.110 +  Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree
   3.111 +
   3.112 +fun mk_switch_tree thy mode moded_clauses =
   3.113 +  let
   3.114 +    fun select_best_switch moded_clauses input_position best_switch =
   3.115 +      let
   3.116 +        val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd)))
   3.117 +        val partition = partition_clause thy input_position moded_clauses
   3.118 +        val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE
   3.119 +      in
   3.120 +        case ord (switch, best_switch) of LESS => best_switch
   3.121 +          | EQUAL => best_switch | GREATER => switch
   3.122 +      end
   3.123 +    fun detect_switches moded_clauses =
   3.124 +      case fold (select_best_switch moded_clauses) (input_positions_of_mode mode) NONE of
   3.125 +        SOME (best_pos, (switched_on, left_clauses)) =>
   3.126 +          Node ((best_pos, map (apsnd detect_switches) switched_on),
   3.127 +            detect_switches left_clauses)
   3.128 +      | NONE => Atom moded_clauses
   3.129 +  in
   3.130 +    detect_switches moded_clauses
   3.131 +  end
   3.132 +
   3.133 +(** compilation of detected switches **)
   3.134 +
   3.135 +fun destruct_constructor_pattern (pat, obj) =
   3.136 +  (case strip_comb pat of
   3.137 +    (f as Free _, []) => cons (pat, obj)
   3.138 +  | (Const (c, T), pat_args) =>
   3.139 +    (case strip_comb obj of
   3.140 +      (Const (c', T'), obj_args) =>
   3.141 +        (if c = c' andalso T = T' then
   3.142 +          fold destruct_constructor_pattern (pat_args ~~ obj_args)
   3.143 +        else raise Fail "pattern and object mismatch")
   3.144 +    | _ => raise Fail "unexpected object")
   3.145 +  | _ => raise Fail "unexpected pattern")
   3.146 +
   3.147 +
   3.148 +fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode
   3.149 +  in_ts' outTs switch_tree =
   3.150 +  let
   3.151 +    val compfuns = Comp_Mod.compfuns compilation_modifiers
   3.152 +    val thy = ProofContext.theory_of ctxt
   3.153 +    fun compile_switch_tree _ _ (Atom []) = NONE
   3.154 +      | compile_switch_tree all_vs ctxt_eqs (Atom moded_clauses) =
   3.155 +        let
   3.156 +          val in_ts' = map (Pattern.rewrite_term thy ctxt_eqs []) in_ts'
   3.157 +          fun compile_clause' (ts, moded_ps) =
   3.158 +            let
   3.159 +              val (ts, out_ts) = split_mode mode ts
   3.160 +              val subst = fold destruct_constructor_pattern (in_ts' ~~ ts) []
   3.161 +              val (fsubst, pat') = List.partition (fn (_, Free _) => true | _ => false) subst
   3.162 +              val moded_ps' = (map o apfst o map_indprem)
   3.163 +                (Pattern.rewrite_term thy (map swap fsubst) []) moded_ps
   3.164 +              val inp = HOLogic.mk_tuple (map fst pat')
   3.165 +              val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat')
   3.166 +              val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
   3.167 +            in
   3.168 +              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
   3.169 +                mode inp (in_ts', out_ts') moded_ps'
   3.170 +            end
   3.171 +        in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
   3.172 +    | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
   3.173 +      let
   3.174 +        val (i, is) = argument_position_of mode position
   3.175 +        val inp_var = nth_pair is (nth in_ts' i)
   3.176 +        val x = Name.variant all_vs "x"
   3.177 +        val xt = Free (x, fastype_of inp_var)
   3.178 +        fun compile_single_case ((c, T), switched) =
   3.179 +          let
   3.180 +            val Ts = binder_types T
   3.181 +            val argnames = Name.variant_list (x :: all_vs)
   3.182 +              (map (fn i => "c" ^ string_of_int i) (1 upto length Ts))
   3.183 +            val args = map2 (curry Free) argnames Ts
   3.184 +            val pattern = list_comb (Const (c, T), args)
   3.185 +            val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs
   3.186 +            val compilation = the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
   3.187 +              (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched)
   3.188 +        in
   3.189 +          (pattern, compilation)
   3.190 +        end
   3.191 +        val switch = fst (Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var
   3.192 +          ((map compile_single_case switched_clauses) @
   3.193 +            [(xt, mk_bot compfuns (HOLogic.mk_tupleT outTs))]))
   3.194 +      in
   3.195 +        case compile_switch_tree all_vs ctxt_eqs left_clauses of
   3.196 +          NONE => SOME switch
   3.197 +        | SOME left_comp => SOME (mk_sup compfuns (switch, left_comp))
   3.198 +      end
   3.199 +  in
   3.200 +    compile_switch_tree all_vs [] switch_tree
   3.201 +  end
   3.202 +
   3.203 +(* compilation of predicates *)
   3.204 +
   3.205 +fun compile_pred options compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
   3.206    let
   3.207      val ctxt = ProofContext.init thy
   3.208      val compilation_modifiers = if pol then compilation_modifiers else
   3.209 @@ -1794,7 +1956,6 @@
   3.210        (binder_types T)
   3.211      val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs)
   3.212      val funT = Comp_Mod.funT_of compilation_modifiers mode T
   3.213 -    
   3.214      val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
   3.215        (fn T => fn (param_vs, names) =>
   3.216          if is_param_type T then
   3.217 @@ -1806,14 +1967,24 @@
   3.218          (param_vs, (all_vs @ param_vs))
   3.219      val in_ts' = map_filter (map_filter_prod
   3.220        (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
   3.221 -    val cl_ts =
   3.222 -      map (compile_clause compilation_modifiers
   3.223 -        ctxt all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts')) moded_cls;
   3.224 -    val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns
   3.225 -      s T mode additional_arguments
   3.226 -      (if null cl_ts then
   3.227 -        mk_bot compfuns (HOLogic.mk_tupleT outTs)
   3.228 -      else foldr1 (mk_sup compfuns) cl_ts)
   3.229 +    val compilation =
   3.230 +      if detect_switches options then
   3.231 +        the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
   3.232 +          (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
   3.233 +            mode in_ts' outTs (mk_switch_tree thy mode moded_cls))
   3.234 +      else
   3.235 +        let
   3.236 +          val cl_ts =
   3.237 +            map (fn (ts, moded_prems) => 
   3.238 +              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
   3.239 +              mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
   3.240 +        in
   3.241 +          Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
   3.242 +            (if null cl_ts then
   3.243 +              mk_bot compfuns (HOLogic.mk_tupleT outTs)
   3.244 +            else
   3.245 +              foldr1 (mk_sup compfuns) cl_ts)
   3.246 +        end
   3.247      val fun_const =
   3.248        Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
   3.249        (ProofContext.theory_of ctxt) s mode, funT)
   3.250 @@ -1822,7 +1993,7 @@
   3.251        (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation))
   3.252    end;
   3.253  
   3.254 -(* special setup for simpset *)                  
   3.255 +(** special setup for simpset **)
   3.256  val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}])
   3.257    setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
   3.258    setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
   3.259 @@ -2463,13 +2634,13 @@
   3.260  fun join_preds_modes table1 table2 =
   3.261    map_preds_modes (fn pred => fn mode => fn value =>
   3.262      (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1
   3.263 -    
   3.264 +
   3.265  fun maps_modes preds_modes_table =
   3.266    map (fn (pred, modes) =>
   3.267      (pred, map (fn (mode, value) => value) modes)) preds_modes_table
   3.268 -    
   3.269 -fun compile_preds comp_modifiers thy all_vs param_vs preds moded_clauses =
   3.270 -  map_preds_modes (fn pred => compile_pred comp_modifiers thy all_vs param_vs pred
   3.271 +
   3.272 +fun compile_preds options comp_modifiers thy all_vs param_vs preds moded_clauses =
   3.273 +  map_preds_modes (fn pred => compile_pred options comp_modifiers thy all_vs param_vs pred
   3.274        (the (AList.lookup (op =) preds pred))) moded_clauses
   3.275  
   3.276  fun prove options thy clauses preds moded_clauses compiled_terms =
   3.277 @@ -2482,7 +2653,6 @@
   3.278      compiled_terms
   3.279  
   3.280  (* preparation of introduction rules into special datastructures *)
   3.281 -
   3.282  fun dest_prem thy params t =
   3.283    (case strip_comb t of
   3.284      (v as Free _, ts) => if member (op =) params v then Prem t else Sidecond t
   3.285 @@ -2626,9 +2796,6 @@
   3.286  datatype steps = Steps of
   3.287    {
   3.288    define_functions : options -> (string * typ) list -> string * (bool * mode) list -> theory -> theory,
   3.289 -  (*infer_modes : options -> (string * typ) list -> (string * mode list) list
   3.290 -    -> string list -> (string * (term list * indprem list) list) list
   3.291 -    -> theory -> ((moded_clause list pred_mode_table * string list) * theory),*)
   3.292    prove : options -> theory -> (string * (term list * indprem list) list) list -> (string * typ) list
   3.293      -> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
   3.294    add_code_equations : theory -> (string * typ) list
   3.295 @@ -2669,8 +2836,9 @@
   3.296        |> Theory.checkpoint)
   3.297      val _ = print_step options "Compiling equations..."
   3.298      val compiled_terms =
   3.299 -      Output.cond_timeit true "Compiling equations...." (fn _ =>
   3.300 -      compile_preds (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses)
   3.301 +      (*Output.cond_timeit true "Compiling equations...." (fn _ =>*)
   3.302 +        compile_preds options
   3.303 +          (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses
   3.304      val _ = print_compiled_terms options thy'' compiled_terms
   3.305      val _ = print_step options "Proving equations..."
   3.306      val result_thms =