changed resolving depending predicates and fetching in the predicate compiler
authorbulwahn
Tue, 04 Aug 2009 08:34:56 +0200
changeset 32314 66bbad0bfef9
parent 32313 a984c04927b4
child 32315 79f324944be4
changed resolving depending predicates and fetching in the predicate compiler
src/HOL/ex/Predicate_Compile_ex.thy
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/Predicate_Compile_ex.thy	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile_ex.thy	Tue Aug 04 08:34:56 2009 +0200
@@ -17,7 +17,6 @@
 values 10 "{n. even n}"
 values 10 "{n. odd n}"
 
-
 inductive append :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list \<Rightarrow> bool" where
     append_Nil: "append [] xs xs"
   | append_Cons: "append xs ys zs \<Longrightarrow> append (x # xs) ys (x # zs)"
@@ -42,27 +41,28 @@
   | "f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) (x # ys) zs"
   | "\<not> f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) ys (x # zs)"
 
-(* FIXME: correct handling of parameters *)
-(*
-ML {* reset Predicate_Compile.do_proofs *}
 code_pred partition .
 
 thm partition.equation
-ML {* set Predicate_Compile.do_proofs *}
-*)
+
+inductive is_even :: "nat \<Rightarrow> bool"
+where
+  "n mod 2 = 0 \<Longrightarrow> is_even n"
+
+code_pred is_even .
 
 (* TODO: requires to handle abstractions in parameter positions correctly *)
-(*FIXME values 10 "{(ys, zs). partition (\<lambda>n. n mod 2 = 0)
-  [0, Suc 0, 2, 3, 4, 5, 6, 7] ys zs}" *)
+values 10 "{(ys, zs). partition is_even
+  [0, Suc 0, 2, 3, 4, 5, 6, 7] ys zs}"
 
+values 10 "{zs. partition is_even zs [0, 2] [3, 5]}"
+values 10 "{zs. partition is_even zs [0, 7] [3, 5]}"
 
 lemma [code_pred_intros]:
 "r a b ==> tranclp r a b"
 "r a b ==> tranclp r b c ==> tranclp r a c"
 by auto
 
-(* Setup requires quick and dirty proof *)
-
 code_pred tranclp
 proof -
   case tranclp
@@ -78,7 +78,7 @@
 code_pred succ .
 
 thm succ.equation
-(* FIXME: why does this not terminate? *)
+(* FIXME: why does this not terminate? -- value chooses mode [] --> [1] and then starts enumerating all successors *)
 (*
 values 20 "{n. tranclp succ 10 n}"
 values "{n. tranclp succ n 10}"
--- a/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
@@ -10,7 +10,7 @@
   (*val add_equations_of: bool -> string list -> theory -> theory *)
   val register_predicate : (thm list * thm * int) -> theory -> theory
   val is_registered : theory -> string -> bool
-  val fetch_pred_data : theory -> string -> (thm list * thm * int)  
+ (* val fetch_pred_data : theory -> string -> (thm list * thm * int)  *)
   val predfun_intro_of: theory -> string -> mode -> thm
   val predfun_elim_of: theory -> string -> mode -> thm
   val strip_intro_concl: int -> term -> term * (term list * term list)
@@ -78,7 +78,7 @@
   (*val rpred_prove_preds : theory -> term pred_mode_table -> thm pred_mode_table*)
   val rpred_compfuns : compilation_funs
   val dest_funT : typ -> typ * typ
-  val depending_preds_of : theory -> thm list -> string list
+ (* val depending_preds_of : theory -> thm list -> string list *)
   val add_quickcheck_equations : string list -> theory -> theory
   val add_sizelim_equations : string list -> theory -> theory
   val is_inductive_predicate : theory -> string -> bool
@@ -104,7 +104,7 @@
 
 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
 
-fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
+fun print_tac s = Seq.single; (* (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); *)
 fun debug_tac msg = (fn st => (Output.tracing msg; Seq.single st));
 
 val do_proofs = ref true;
@@ -465,7 +465,7 @@
   val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
         (fn {...} => etac elim 1) 
 in
-  ([intro], elim, 0)
+  ([intro], elim)
 end
 
 fun fetch_pred_data thy name =
@@ -481,8 +481,9 @@
         val pre_elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info)))
         val nparams = length (Inductive.params_of (#raw_induct result))
         val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
+        val (intros, elim) = if null intros then noclause thy name elim else (intros, elim)
       in
-        if null intros then noclause thy name elim else (intros, elim, nparams)
+        mk_pred_data ((intros, SOME elim, nparams), ([], [], []))
       end                                                                    
   | NONE => error ("No such predicate: " ^ quote name)
   
@@ -500,11 +501,17 @@
 fun is_inductive_predicate thy name =
   is_some (try (Inductive.the_inductive (ProofContext.init thy)) name)
 
-fun depending_preds_of thy intros = fold Term.add_const_names (map Thm.prop_of intros) []
-    |> filter (fn c => is_inductive_predicate thy c orelse is_registered thy c)
-
+fun depending_preds_of thy (key, value) =
+  let
+    val intros = (#intros o rep_pred_data) value
+  in
+    fold Term.add_const_names (map Thm.prop_of intros) []
+      |> filter (fn c => (not (c = key)) andalso (is_inductive_predicate thy c orelse is_registered thy c))
+  end;
+    
+    
 (* code dependency graph *)    
-
+(*
 fun dependencies_of thy name =
   let
     val (intros, elim, nparams) = fetch_pred_data thy name 
@@ -513,6 +520,7 @@
   in
     (data, keys)
   end;
+*)
 
 (* TODO: add_edges - by analysing dependencies *)
 fun add_intro thm thy = let
@@ -523,7 +531,7 @@
          (apfst (fn (intro, elim, nparams) => (thm::intro, elim, nparams)))) gr
      | NONE =>
        let
-         val nparams = the_default 0 (try (#3 o fetch_pred_data thy) name)
+         val nparams = the_default 0 (try (#nparams o rep_pred_data o (fetch_pred_data thy)) name)
        in Graph.new_node (name, mk_pred_data (([thm], NONE, nparams), ([], [], []))) gr end;
   in PredData.map cons_intro thy end
 
@@ -537,10 +545,14 @@
     fun set (intros, elim, _ ) = (intros, elim, nparams) 
   in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
     
-fun register_predicate (intros, elim, nparams) = let
-    val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd intros))))
+fun register_predicate (pre_intros, pre_elim, nparams) thy = let
+    val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd pre_intros))))
+    (* preprocessing *)
+    val intros = ind_set_codegen_preproc thy (map (preprocess_intro thy) pre_intros)
+    val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
   in
-    PredData.map (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), ([], [], []))))
+    PredData.map
+      (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), ([], [], [])))) thy
   end
 
 fun set_generator_name pred mode name = 
@@ -1813,9 +1825,9 @@
     val nparams = nparams_of thy (hd prednames)
     val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
     val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
-    val _ = Output.tracing ("extra_modes are: " ^
+    (*val _ = Output.tracing ("extra_modes are: " ^
       cat_lines (map (fn (name, modes) => name ^ " has modes:" ^
-      (commas (map string_of_mode modes))) extra_modes))
+      (commas (map string_of_mode modes))) extra_modes)) *)
     val _ $ u = Logic.strip_imp_concl (hd intrs);
     val params = List.take (snd (strip_comb u), nparams);
     val param_vs = maps term_vs params
@@ -1884,9 +1896,23 @@
   thy''
 end
 
+fun extend' value_of edges_of key (G, visited) =
+  let
+    val _ = Output.tracing ("calling extend' with " ^ key)  
+    val (G', v) = case try (Graph.get_node G) key of
+        SOME v => (G, v)
+      | NONE => (Graph.new_node (key, value_of key) G, value_of key)
+    val (G'', visited') = fold (extend' value_of edges_of) (edges_of (key, v) \\ visited)
+      (G', key :: visited) 
+  in
+    (fold (Graph.add_edge o (pair key)) (edges_of (key, v)) G'', visited')
+  end;
+
+fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, [])) 
+  
 fun gen_add_equations steps names thy =
   let
-    val thy' = PredData.map (fold (Graph.extend (dependencies_of thy)) names) thy |> Theory.checkpoint;
+    val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy |> Theory.checkpoint;
     fun strong_conn_of gr keys =
       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
     val scc = strong_conn_of (PredData.get thy') names
@@ -1962,16 +1988,13 @@
 (* TODO: must create state to prove multiple cases *)
 fun generic_code_pred prep_const raw_const lthy =
   let
-  
     val thy = ProofContext.theory_of lthy
     val const = prep_const thy raw_const
-    val _ = Output.tracing "extending graph"
-    val lthy' = LocalTheory.theory (PredData.map (Graph.extend (dependencies_of thy) const)) lthy
+    val lthy' = LocalTheory.theory (PredData.map
+        (extend (fetch_pred_data thy) (depending_preds_of thy) const)) lthy
       |> LocalTheory.checkpoint
-    val _ = Output.tracing "code_pred graph extended..."  
     val thy' = ProofContext.theory_of lthy'
     val preds = Graph.all_preds (PredData.get thy') [const] |> filter_out (has_elim thy')
-    
     fun mk_cases const =
       let
         val nparams = nparams_of thy' const