added generator compilation of higher-order predicates; refined mode analysis for generators; some tuning
authorbulwahn
Tue, 04 Aug 2009 08:34:56 +0200
changeset 32316 1d83ac469459
parent 32315 79f324944be4
child 32317 b4b871808223
added generator compilation of higher-order predicates; refined mode analysis for generators; some tuning
src/HOL/ex/Predicate_Compile_ex.thy
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/Predicate_Compile_ex.thy	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile_ex.thy	Tue Aug 04 08:34:56 2009 +0200
@@ -71,6 +71,11 @@
 
 thm tranclp.equation
 
+setup {* Predicate_Compile.add_sizelim_equations [@{const_name tranclp}] *}
+setup {* fn thy => exception_trace (fn () => Predicate_Compile.add_quickcheck_equations [@{const_name tranclp}] thy)  *}
+
+thm tranclp.rpred_equation
+
 inductive succ :: "nat \<Rightarrow> nat \<Rightarrow> bool" where
     "succ 0 1"
   | "succ m n \<Longrightarrow> succ (Suc m) (Suc n)"
--- a/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
@@ -203,6 +203,11 @@
     | SOME js => enclose "[" "]" (commas (map string_of_int js)))
        (iss @ [SOME is]));
 
+fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
+  "predmode: " ^ (string_of_mode predmode) ^ 
+  (if null param_modes then "" else
+    "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
+    
 datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term |
   GeneratorPrem of term list * term | Generator of (string * typ);
 
@@ -351,12 +356,12 @@
     val _ = Output.tracing (cat_lines (map print_pred pred_mode_table))
   in () end;
 
-fun string_of_moded_prem thy (Prem (ts, p), Mode (_, is, _)) =
+fun string_of_moded_prem thy (Prem (ts, p), tmode) =
     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
-    "(mode: " ^ (space_implode ", " (map string_of_int is)) ^ ")"
-  | string_of_moded_prem thy (GeneratorPrem (ts, p), Mode (_, is, _)) =
+    "(" ^ (string_of_tmode tmode) ^ ")"
+  | string_of_moded_prem thy (GeneratorPrem (ts, p), Mode (predmode, is, _)) =
     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
-    "(generator_mode: " ^ (space_implode ", " (map string_of_int is)) ^ ")"
+    "(generator_mode: " ^ (string_of_mode predmode) ^ ")"
   | string_of_moded_prem thy (Generator (v, T), _) =
     "Generator for " ^ v ^ " of Type " ^ (Syntax.string_of_typ_global thy T)
   | string_of_moded_prem thy (Negprem (ts, p), Mode (_, is, _)) =
@@ -928,19 +933,28 @@
 
 fun gen_prem (Prem (us, t)) = GeneratorPrem (us, t) 
   | gen_prem _ = error "gen_prem : invalid input for gen_prem"
+
+fun param_gen_prem param_vs (p as Prem (us, t as Free (v, _))) =
+  if member (op =) param_vs v then
+    GeneratorPrem (us, t)
+  else p  
+  | param_gen_prem param_vs p = p
   
 fun check_mode_clause with_generator thy param_vs modes gen_modes (iss, is) (ts, ps) =
   let
     val modes' = modes @ List.mapPartial
       (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
         (param_vs ~~ iss);
+    val gen_modes' = gen_modes @ List.mapPartial
+      (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
+        (param_vs ~~ iss);  
     val vTs = distinct (op =) ((fold o fold_prem) Term.add_frees ps (fold Term.add_frees ts []))
     val prem_vs = distinct (op =) ((fold o fold_prem) Term.add_free_names ps [])
     fun check_mode_prems acc_ps vs [] = SOME (acc_ps, vs)
       | check_mode_prems acc_ps vs ps = (case select_mode_prem thy modes' vs ps of
           NONE =>
             (if with_generator then
-              (case select_mode_prem thy gen_modes vs ps of
+              (case select_mode_prem thy gen_modes' vs ps of
                   SOME (p, SOME mode) => check_mode_prems ((gen_prem p, mode) :: acc_ps) 
                   (case p of Prem (us, _) => vs union terms_vs us | _ => vs)
                   (filter_out (equal p) ps)
@@ -956,7 +970,7 @@
                   end)
             else
               NONE)
-        | SOME (p, SOME mode) => check_mode_prems ((p, mode) :: acc_ps) 
+        | SOME (p, SOME mode) => check_mode_prems ((if with_generator then param_gen_prem param_vs p else p, mode) :: acc_ps) 
             (case p of Prem (us, _) => vs union terms_vs us | _ => vs)
             (filter_out (equal p) ps))
     val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (split_smode is ts));
@@ -1012,15 +1026,24 @@
     map (get_modes_pred with_generator thy param_vs preds (modes @ extra_modes) []) modes
   end;
 
+fun remove_from rem [] = []
+  | remove_from rem ((k, vs) :: xs) =
+    (case AList.lookup (op =) rem k of
+      NONE => (k, vs)
+    | SOME vs' => (k, vs \\ vs'))
+    :: remove_from rem xs
+    
 fun infer_modes_with_generator thy extra_modes arities param_vs preds =
   let
     val prednames = map fst preds
+    val extra_modes = all_modes_of thy
     val gen_modes = all_generator_modes_of thy
       |> filter_out (fn (name, _) => member (op =) prednames name)
+    val starting_modes = remove_from extra_modes (modes_of_arities arities) 
     val modes =
       fixp (fn modes =>
         map (check_modes_pred true thy param_vs preds extra_modes (gen_modes @ modes)) modes)
-          (modes_of_arities arities)
+         starting_modes 
   in
     map (get_modes_pred true thy param_vs preds extra_modes (gen_modes @ modes)) modes
   end;
@@ -1131,7 +1154,9 @@
   case strip_comb t of
     (Const (name, T), params) =>
        let
+         val _ = Output.tracing (Syntax.string_of_term_global thy t)
          val params' = map (compile_param size thy PredicateCompFuns.compfuns) (ms ~~ params)
+         val _ = Output.tracing "..."
          val mk_fun_of = case size of NONE => mk_fun_of | SOME _ => mk_sizelim_fun_of
        in
          list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
@@ -1146,12 +1171,14 @@
 fun compile_gen_expr size thy compfuns ((Mode (mode, is, ms)), t) =
   case strip_comb t of
     (Const (name, T), params) =>
-       let
-         val params' = map (compile_param size thy compfuns) (ms ~~ params)
-       in
-         list_comb (mk_generator_of compfuns thy (name, T) mode, params')
-       end
-       
+      let
+        val params' = map (compile_param size thy compfuns) (ms ~~ params)
+      in
+        list_comb (mk_generator_of compfuns thy (name, T) mode, params')
+      end
+    | (Free (name, T), args) =>
+      list_comb (Free (name, sizelim_funT_of RPredCompFuns.compfuns ([], is) T), args)
+          
 (** specific rpred functions -- move them to the correct place in this file *)
 
 (* uncommented termify code; causes more trouble than expected at first *) 
@@ -1416,9 +1443,6 @@
     fun create_definition mode thy =
       let
         val mode_cname = create_constname_of_mode thy "sizelim_" name mode
-        (* val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
-        val Ts1' = map2 (fn NONE => I | SOME is => size_funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
-         (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (PredicateCompFuns.mk_predT (mk_tupleT Us2)) *)
         val funT = sizelim_funT_of PredicateCompFuns.compfuns mode T
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
@@ -1434,9 +1458,7 @@
     fun create_definition mode thy =
       let
         val mode_cname = create_constname_of_mode thy "gen_" name mode
-        val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T);
-        val Ts1' = map2 (fn NONE => I | SOME is => funT_of RPredCompFuns.compfuns ([], is)) (fst mode) Ts1
-        val funT = (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (RPredCompFuns.mk_rpredT (mk_tupleT Us2)) 
+        val funT = sizelim_funT_of RPredCompFuns.compfuns mode T
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
         |> set_generator_name name mode mode_cname 
@@ -1846,8 +1868,8 @@
     val intrs = maps (intros_of thy) prednames
       |> map (Logic.unvarify o prop_of)
     val nparams = nparams_of thy (hd prednames)
+    val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
     val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
-    val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
     (*val _ = Output.tracing ("extra_modes are: " ^
       cat_lines (map (fn (name, modes) => name ^ " has modes:" ^
       (commas (map string_of_mode modes))) extra_modes)) *)
@@ -1896,6 +1918,8 @@
   val _ = Output.tracing "Infering modes..."
   val moded_clauses = #infer_modes steps thy extra_modes arities param_vs clauses 
   val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
+  val _ = print_modes modes
+  val _ = print_moded_clauses thy moded_clauses
   val _ = Output.tracing "Defining executable functions..."
   val thy' = fold (#create_definitions steps preds) modes thy
     |> Theory.checkpoint
@@ -1934,7 +1958,8 @@
   
 fun gen_add_equations steps names thy =
   let
-    val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy |> Theory.checkpoint;
+    val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
+      |> Theory.checkpoint;
     fun strong_conn_of gr keys =
       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
     val scc = strong_conn_of (PredData.get thy') names