tuned proof procedure; added size-limiting predicate compilation of higher order predicates; added guessing of number parameters for registrating predicates; removed debug messages
authorbulwahn
Tue, 04 Aug 2009 08:34:56 +0200
changeset 32315 79f324944be4
parent 32314 66bbad0bfef9
child 32316 1d83ac469459
tuned proof procedure; added size-limiting predicate compilation of higher order predicates; added guessing of number parameters for registrating predicates; removed debug messages
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
@@ -93,6 +93,9 @@
   val compile_clause : compilation_funs -> term option -> (term list -> term) ->
     theory -> string list -> string list -> mode -> term -> moded_clause -> term
   val preprocess_intro : theory -> thm -> thm
+  val is_constrt : theory -> term -> bool
+  val is_predT : typ -> bool
+  val guess_nparams : typ -> int
 end;
 
 structure Predicate_Compile : PREDICATE_COMPILE =
@@ -105,7 +108,7 @@
 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
 
 fun print_tac s = Seq.single; (* (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); *)
-fun debug_tac msg = (fn st => (Output.tracing msg; Seq.single st));
+fun debug_tac msg = Seq.single; (* (fn st => (Output.tracing msg; Seq.single st)); *)
 
 val do_proofs = ref true;
 
@@ -160,7 +163,6 @@
   let
     val (_, B) = dest_funT (fastype_of t)
     val (C, A) = dest_funT (fastype_of u)
-    val _ = tracing (Syntax.string_of_typ_global @{theory} A)
   in
     Const(@{const_name "Fun.comp"}, (A --> B) --> (C --> A) --> C --> B) $ t $ u
   end;
@@ -521,17 +523,32 @@
     (data, keys)
   end;
 *)
+(* guessing number of parameters *)
+fun find_indexes pred xs =
+  let
+    fun find is n [] = is
+      | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
+  in rev (find [] 0 xs) end;
 
-(* TODO: add_edges - by analysing dependencies *)
+fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
+  | is_predT _ = false
+  
+fun guess_nparams T =
+  let
+    val argTs = binder_types T
+    val nparams = fold (curry Int.max)
+      (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
+  in nparams end;
+
 fun add_intro thm thy = let
-   val (name, _) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
+   val (name, T) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
    fun cons_intro gr =
      case try (Graph.get_node gr) name of
        SOME pred_data => Graph.map_node name (map_pred_data
          (apfst (fn (intro, elim, nparams) => (thm::intro, elim, nparams)))) gr
      | NONE =>
        let
-         val nparams = the_default 0 (try (#nparams o rep_pred_data o (fetch_pred_data thy)) name)
+         val nparams = the_default (guess_nparams T)  (try (#nparams o rep_pred_data o (fetch_pred_data thy)) name)
        in Graph.new_node (name, mk_pred_data (([thm], NONE, nparams), ([], [], []))) gr end;
   in PredData.map cons_intro thy end
 
@@ -1094,37 +1111,43 @@
   | compile_param_ext _ _ _ _ = error "compile params"
 *)
 
-fun compile_param thy compfuns (NONE, t) = t
-  | compile_param thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
+fun compile_param size thy compfuns (NONE, t) = t
+  | compile_param size thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
    let
      val (f, args) = strip_comb (Envir.eta_contract t)
      val (params, args') = chop (length ms) args
-     val params' = map (compile_param thy compfuns) (ms ~~ params)
+     val params' = map (compile_param size thy compfuns) (ms ~~ params)
+     val mk_fun_of = case size of NONE => mk_fun_of | SOME _ => mk_sizelim_fun_of
+     val funT_of = case size of NONE => funT_of | SOME _ => sizelim_funT_of
      val f' =
        case f of
          Const (name, T) =>
            mk_fun_of compfuns thy (name, T) (iss, is')
        | Free (name, T) => Free (name, funT_of compfuns (iss, is') T)
+       | _ => error ("PredicateCompiler: illegal parameter term")
    in list_comb (f', params' @ args') end
    
 fun compile_expr size thy ((Mode (mode, is, ms)), t) =
   case strip_comb t of
     (Const (name, T), params) =>
        let
-         val params' = map (compile_param thy PredicateCompFuns.compfuns) (ms ~~ params)
+         val params' = map (compile_param size thy PredicateCompFuns.compfuns) (ms ~~ params)
+         val mk_fun_of = case size of NONE => mk_fun_of | SOME _ => mk_sizelim_fun_of
        in
-         case size of
-           NONE => list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
-         | SOME _ => list_comb (mk_sizelim_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
+         list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
        end
   | (Free (name, T), args) =>
-       list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
-          
-fun compile_gen_expr thy compfuns ((Mode (mode, is, ms)), t) =
+       let 
+         val funT_of = case size of NONE => funT_of | SOME _ => sizelim_funT_of 
+       in
+         list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
+       end;
+       
+fun compile_gen_expr size thy compfuns ((Mode (mode, is, ms)), t) =
   case strip_comb t of
     (Const (name, T), params) =>
        let
-         val params' = map (compile_param thy compfuns) (ms ~~ params)
+         val params' = map (compile_param size thy compfuns) (ms ~~ params)
        in
          list_comb (mk_generator_of compfuns thy (name, T) mode, params')
        end
@@ -1215,7 +1238,7 @@
                    val args = case size of
                      NONE => in_ts
                    | SOME size_t => in_ts @ [size_t]
-                   val u = list_comb (compile_gen_expr thy compfuns (mode, t), args)
+                   val u = list_comb (compile_gen_expr size thy compfuns (mode, t), args)
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
                    (u, rest)
@@ -1239,6 +1262,7 @@
 fun compile_pred compfuns mk_fun_of use_size thy all_vs param_vs s T mode moded_cls =
   let
     val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
+    val funT_of = if use_size then sizelim_funT_of else funT_of 
     val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
     val xnames = Name.variant_list (all_vs @ param_vs)
       (map (fn i => "x" ^ string_of_int i) (snd mode));
@@ -1319,7 +1343,6 @@
   val funpropI = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
                    mk_tuple outargs))
   val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
-  val _ = Output.tracing (Syntax.string_of_term_global thy introtrm) 
   val simprules = [defthm, @{thm eval_pred},
                    @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
   val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1
@@ -1344,7 +1367,6 @@
   
 fun create_definitions preds (name, modes) thy =
   let
-    val _ = tracing "create definitions"
     val compfuns = PredicateCompFuns.compfuns
     val T = AList.lookup (op =) preds name |> the
     fun create_definition (mode as (iss, is)) thy = let
@@ -1394,9 +1416,10 @@
     fun create_definition mode thy =
       let
         val mode_cname = create_constname_of_mode thy "sizelim_" name mode
-        val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
-        val Ts1' = map2 (fn NONE => I | SOME is => funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
-        val funT = (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (PredicateCompFuns.mk_predT (mk_tupleT Us2)) 
+        (* val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
+        val Ts1' = map2 (fn NONE => I | SOME is => size_funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
+         (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (PredicateCompFuns.mk_predT (mk_tupleT Us2)) *)
+        val funT = sizelim_funT_of PredicateCompFuns.compfuns mode T
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
         |> set_sizelim_function_name name mode mode_cname 
@@ -1447,7 +1470,7 @@
 fun prove_param thy (NONE, t) = TRY (rtac @{thm refl} 1)
   | prove_param thy (m as SOME (Mode (mode, is, ms)), t) =
   let
-    val  (f, args) = strip_comb t
+    val  (f, args) = strip_comb (Envir.eta_contract t)
     val (params, _) = chop (length ms) args
     val f_tac = case f of
       Const (name, T) => simp_tac (HOL_basic_ss addsimps 
@@ -1465,7 +1488,6 @@
     THEN (REPEAT_DETERM (atac 1))
   end
 
-    THEN print_tac "after prove_param"
 fun prove_expr thy (Mode (mode, is, ms), t, us) (premposition : int) =
   case strip_comb t of
     (Const (name, T), args) =>  
@@ -1642,13 +1664,14 @@
 
 fun prove_param2 thy (NONE, t) = all_tac 
   | prove_param2 thy (m as SOME (Mode (mode, is, ms)), t) = let
-    val  (f, args) = strip_comb t
+    val  (f, args) = strip_comb (Envir.eta_contract t)
     val (params, _) = chop (length ms) args
     val f_tac = case f of
         Const (name, T) => full_simp_tac (HOL_basic_ss addsimps 
            (@{thm eval_pred}::(predfun_definition_of thy name mode)
            :: @{thm "Product_Type.split_conv"}::[])) 1
       | Free _ => all_tac
+      | _ => error "prove_param2: illegal parameter term"
   in  
     print_tac "before simplification in prove_args:"
     THEN f_tac
@@ -1870,7 +1893,7 @@
   val _ = Output.tracing ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
   val (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) =
     prepare_intrs thy prednames
-  val _ = tracing "Infering modes..."
+  val _ = Output.tracing "Infering modes..."
   val moded_clauses = #infer_modes steps thy extra_modes arities param_vs clauses 
   val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
   val _ = Output.tracing "Defining executable functions..."
@@ -1898,7 +1921,6 @@
 
 fun extend' value_of edges_of key (G, visited) =
   let
-    val _ = Output.tracing ("calling extend' with " ^ key)  
     val (G', v) = case try (Graph.get_node G) key of
         SOME v => (G, v)
       | NONE => (Graph.new_node (key, value_of key) G, value_of key)