src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 36254 95ef0a3cf31c
parent 36252 beba03215d8f
child 36258 f459a0cc3241
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Apr 21 12:10:52 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Wed Apr 21 12:10:52 2010 +0200
@@ -1541,7 +1541,6 @@
     val thy' = fold (fn (s, ms) => if member (op =) (map fst preds) s then
       set_needs_random s (map_filter (fn ((true, m), true) => SOME m | _ => NONE) ms) else I)
       modes thy
-
   in
     ((moded_clauses, errors), thy')
   end;
@@ -1645,7 +1644,7 @@
     val name' = Name.variant (name :: names) "y";
     val T = HOLogic.mk_tupleT (map fastype_of out_ts);
     val U = fastype_of success_t;
-    val U' = dest_predT compfuns U;        
+    val U' = dest_predT compfuns U;
     val v = Free (name, T);
     val v' = Free (name', T);
   in
@@ -1705,13 +1704,12 @@
   end
 
 fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
-  mode inp (ts, moded_ps) =
+  mode inp (in_ts, out_ts) moded_ps =
   let
     val compfuns = Comp_Mod.compfuns compilation_modifiers
-    val iss = ho_arg_modes_of mode
+    val iss = ho_arg_modes_of mode (* FIXME! *)
     val compile_match = compile_match compilation_modifiers
       additional_arguments param_vs iss ctxt
-    val (in_ts, out_ts) = split_mode mode ts;
     val (in_ts', (all_vs', eqs)) =
       fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
     fun compile_prems out_ts' vs names [] =
@@ -1779,7 +1777,171 @@
     mk_bind compfuns (mk_single compfuns inp, prem_t)
   end
 
-fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
+(* switch detection *)
+
+(** argument position of an inductive predicates and the executable functions **)
+
+type position = int * int list
+
+fun input_positions_pair Input = [[]]
+  | input_positions_pair Output = []
+  | input_positions_pair (Fun _) = []
+  | input_positions_pair (Pair (m1, m2)) =
+    map (cons 1) (input_positions_pair m1) @ map (cons 2) (input_positions_pair m2)
+
+fun input_positions_of_mode mode = flat (map_index
+   (fn (i, Input) => [(i, [])]
+   | (_, Output) => []
+   | (_, Fun _) => []
+   | (i, m as Pair (m1, m2)) => map (pair i) (input_positions_pair m))
+     (Predicate_Compile_Aux.strip_fun_mode mode))
+
+fun argument_position_pair mode [] = []
+  | argument_position_pair (Pair (Fun _, m2)) (2 :: is) = argument_position_pair m2 is
+  | argument_position_pair (Pair (m1, m2)) (i :: is) =
+    (if eq_mode (m1, Output) andalso i = 2 then
+      argument_position_pair m2 is
+    else if eq_mode (m2, Output) andalso i = 1 then
+      argument_position_pair m1 is
+    else (i :: argument_position_pair (if i = 1 then m1 else m2) is))
+
+fun argument_position_of mode (i, is) =
+  (i - (length (filter (fn Output => true | Fun _ => true | _ => false)
+    (List.take (strip_fun_mode mode, i)))),
+  argument_position_pair (nth (strip_fun_mode mode) i) is)
+
+fun nth_pair [] t = t
+  | nth_pair (1 :: is) (Const (@{const_name Pair}, _) $ t1 $ _) = nth_pair is t1
+  | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2
+  | nth_pair _ _ = raise Fail "unexpected input for nth_tuple"
+
+(** switch detection analysis **)
+
+fun find_switch_test thy (i, is) (ts, prems) =
+  let
+    val t = nth_pair is (nth ts i)
+    val T = fastype_of t
+  in
+    case T of
+      TFree _ => NONE
+    | Type (Tcon, _) =>
+      (case Datatype_Data.get_constrs thy Tcon of
+        NONE => NONE
+      | SOME cs =>
+        (case strip_comb t of
+          (Var _, []) => NONE
+        | (Free _, []) => NONE
+        | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
+  end
+
+fun partition_clause thy pos moded_clauses =
+  let
+    fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value)
+    fun find_switch_test' moded_clause (cases, left) =
+      case find_switch_test thy pos moded_clause of
+        SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left)
+      | NONE => (cases, moded_clause :: left)
+  in
+    fold find_switch_test' moded_clauses ([], [])
+  end
+
+datatype switch_tree =
+  Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree
+
+fun mk_switch_tree thy mode moded_clauses =
+  let
+    fun select_best_switch moded_clauses input_position best_switch =
+      let
+        val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd)))
+        val partition = partition_clause thy input_position moded_clauses
+        val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE
+      in
+        case ord (switch, best_switch) of LESS => best_switch
+          | EQUAL => best_switch | GREATER => switch
+      end
+    fun detect_switches moded_clauses =
+      case fold (select_best_switch moded_clauses) (input_positions_of_mode mode) NONE of
+        SOME (best_pos, (switched_on, left_clauses)) =>
+          Node ((best_pos, map (apsnd detect_switches) switched_on),
+            detect_switches left_clauses)
+      | NONE => Atom moded_clauses
+  in
+    detect_switches moded_clauses
+  end
+
+(** compilation of detected switches **)
+
+fun destruct_constructor_pattern (pat, obj) =
+  (case strip_comb pat of
+    (f as Free _, []) => cons (pat, obj)
+  | (Const (c, T), pat_args) =>
+    (case strip_comb obj of
+      (Const (c', T'), obj_args) =>
+        (if c = c' andalso T = T' then
+          fold destruct_constructor_pattern (pat_args ~~ obj_args)
+        else raise Fail "pattern and object mismatch")
+    | _ => raise Fail "unexpected object")
+  | _ => raise Fail "unexpected pattern")
+
+
+fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode
+  in_ts' outTs switch_tree =
+  let
+    val compfuns = Comp_Mod.compfuns compilation_modifiers
+    val thy = ProofContext.theory_of ctxt
+    fun compile_switch_tree _ _ (Atom []) = NONE
+      | compile_switch_tree all_vs ctxt_eqs (Atom moded_clauses) =
+        let
+          val in_ts' = map (Pattern.rewrite_term thy ctxt_eqs []) in_ts'
+          fun compile_clause' (ts, moded_ps) =
+            let
+              val (ts, out_ts) = split_mode mode ts
+              val subst = fold destruct_constructor_pattern (in_ts' ~~ ts) []
+              val (fsubst, pat') = List.partition (fn (_, Free _) => true | _ => false) subst
+              val moded_ps' = (map o apfst o map_indprem)
+                (Pattern.rewrite_term thy (map swap fsubst) []) moded_ps
+              val inp = HOLogic.mk_tuple (map fst pat')
+              val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat')
+              val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
+            in
+              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
+                mode inp (in_ts', out_ts') moded_ps'
+            end
+        in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
+    | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
+      let
+        val (i, is) = argument_position_of mode position
+        val inp_var = nth_pair is (nth in_ts' i)
+        val x = Name.variant all_vs "x"
+        val xt = Free (x, fastype_of inp_var)
+        fun compile_single_case ((c, T), switched) =
+          let
+            val Ts = binder_types T
+            val argnames = Name.variant_list (x :: all_vs)
+              (map (fn i => "c" ^ string_of_int i) (1 upto length Ts))
+            val args = map2 (curry Free) argnames Ts
+            val pattern = list_comb (Const (c, T), args)
+            val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs
+            val compilation = the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
+              (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched)
+        in
+          (pattern, compilation)
+        end
+        val switch = fst (Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var
+          ((map compile_single_case switched_clauses) @
+            [(xt, mk_bot compfuns (HOLogic.mk_tupleT outTs))]))
+      in
+        case compile_switch_tree all_vs ctxt_eqs left_clauses of
+          NONE => SOME switch
+        | SOME left_comp => SOME (mk_sup compfuns (switch, left_comp))
+      end
+  in
+    compile_switch_tree all_vs [] switch_tree
+  end
+
+(* compilation of predicates *)
+
+fun compile_pred options compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
   let
     val ctxt = ProofContext.init thy
     val compilation_modifiers = if pol then compilation_modifiers else
@@ -1794,7 +1956,6 @@
       (binder_types T)
     val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs)
     val funT = Comp_Mod.funT_of compilation_modifiers mode T
-    
     val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
       (fn T => fn (param_vs, names) =>
         if is_param_type T then
@@ -1806,14 +1967,24 @@
         (param_vs, (all_vs @ param_vs))
     val in_ts' = map_filter (map_filter_prod
       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
-    val cl_ts =
-      map (compile_clause compilation_modifiers
-        ctxt all_vs param_vs additional_arguments mode (HOLogic.mk_tuple in_ts')) moded_cls;
-    val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns
-      s T mode additional_arguments
-      (if null cl_ts then
-        mk_bot compfuns (HOLogic.mk_tupleT outTs)
-      else foldr1 (mk_sup compfuns) cl_ts)
+    val compilation =
+      if detect_switches options then
+        the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
+          (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
+            mode in_ts' outTs (mk_switch_tree thy mode moded_cls))
+      else
+        let
+          val cl_ts =
+            map (fn (ts, moded_prems) => 
+              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
+              mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
+        in
+          Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
+            (if null cl_ts then
+              mk_bot compfuns (HOLogic.mk_tupleT outTs)
+            else
+              foldr1 (mk_sup compfuns) cl_ts)
+        end
     val fun_const =
       Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
       (ProofContext.theory_of ctxt) s mode, funT)
@@ -1822,7 +1993,7 @@
       (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation))
   end;
 
-(* special setup for simpset *)                  
+(** special setup for simpset **)
 val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ [@{thm Pair_eq}])
   setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
   setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
@@ -2463,13 +2634,13 @@
 fun join_preds_modes table1 table2 =
   map_preds_modes (fn pred => fn mode => fn value =>
     (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1
-    
+
 fun maps_modes preds_modes_table =
   map (fn (pred, modes) =>
     (pred, map (fn (mode, value) => value) modes)) preds_modes_table
-    
-fun compile_preds comp_modifiers thy all_vs param_vs preds moded_clauses =
-  map_preds_modes (fn pred => compile_pred comp_modifiers thy all_vs param_vs pred
+
+fun compile_preds options comp_modifiers thy all_vs param_vs preds moded_clauses =
+  map_preds_modes (fn pred => compile_pred options comp_modifiers thy all_vs param_vs pred
       (the (AList.lookup (op =) preds pred))) moded_clauses
 
 fun prove options thy clauses preds moded_clauses compiled_terms =
@@ -2482,7 +2653,6 @@
     compiled_terms
 
 (* preparation of introduction rules into special datastructures *)
-
 fun dest_prem thy params t =
   (case strip_comb t of
     (v as Free _, ts) => if member (op =) params v then Prem t else Sidecond t
@@ -2626,9 +2796,6 @@
 datatype steps = Steps of
   {
   define_functions : options -> (string * typ) list -> string * (bool * mode) list -> theory -> theory,
-  (*infer_modes : options -> (string * typ) list -> (string * mode list) list
-    -> string list -> (string * (term list * indprem list) list) list
-    -> theory -> ((moded_clause list pred_mode_table * string list) * theory),*)
   prove : options -> theory -> (string * (term list * indprem list) list) list -> (string * typ) list
     -> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
   add_code_equations : theory -> (string * typ) list
@@ -2669,8 +2836,9 @@
       |> Theory.checkpoint)
     val _ = print_step options "Compiling equations..."
     val compiled_terms =
-      Output.cond_timeit true "Compiling equations...." (fn _ =>
-      compile_preds (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses)
+      (*Output.cond_timeit true "Compiling equations...." (fn _ =>*)
+        compile_preds options
+          (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses
     val _ = print_compiled_terms options thy'' compiled_terms
     val _ = print_step options "Proving equations..."
     val result_thms =