now the predicate compilere handles the predicate without introduction rules better as before
authorbulwahn
Sat, 24 Oct 2009 16:55:43 +0200
changeset 33146 bf852ef586f2
parent 33145 1a22f7ca1dfc
child 33147 180dc60bd88c
now the predicate compilere handles the predicate without introduction rules better as before
src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
src/HOL/ex/Predicate_Compile_ex.thy
--- a/src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/pred_compile_pred.ML	Sat Oct 24 16:55:43 2009 +0200
@@ -126,7 +126,7 @@
     val (intros', (local_defs, thy')) = flatten_intros constname intros thy
     val (intross, thy'') = fold_map preprocess local_defs thy'
   in
-    (intros' :: flat intross,thy'')
+    ((constname, intros') :: flat intross,thy'')
   end;
 
 fun preprocess_term t thy = error "preprocess_pred_term: to implement" 
@@ -158,14 +158,11 @@
             val ([definition], thy') = thy
               |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
               |> PureThy.add_defs false [((Binding.name (constname ^ "_def"), def), [])]
-            
           in
             (list_comb (Logic.varify const, vars), ((full_constname, [definition])::new_defs, thy'))
           end
         | replace_abs_arg arg (new_defs, thy) = (arg, (new_defs, thy))
-
-      val (args', (new_defs', thy')) = fold_map replace_abs_arg args (new_defs, thy)
-            (*        val _ = if not (null abs_args) then error "Found some abs argument" else ()*)
+        val (args', (new_defs', thy')) = fold_map replace_abs_arg args (new_defs, thy)
       in
         (list_comb (pred, args'), (new_defs', thy'))
       end
@@ -178,7 +175,13 @@
       in
         (th, (new_defs, thy))
       end
-    val (intross', (new_defs, thy')) = fold_map (fold_map flat_intro) intross ([], thy)
+    fun fold_map_spec f [] s = ([], s)
+      | fold_map_spec f ((c, ths) :: specs) s =
+        let
+          val (ths', s') = f ths s
+          val (specs', s'') = fold_map_spec f specs s'
+        in ((c, ths') :: specs', s'') end
+    val (intross', (new_defs, thy')) = fold_map_spec (fold_map flat_intro) intross ([], thy)
   in
     (intross', (new_defs, thy'))
   end
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Sat Oct 24 16:55:43 2009 +0200
@@ -29,14 +29,18 @@
 fun print_intross options thy msg intross =
   if show_intermediate_results options then
    Output.tracing (msg ^ 
-    (space_implode "; " (map 
-      (fn intros => commas (map (Display.string_of_thm_global thy) intros)) intross)))
+    (space_implode "\n" (map 
+      (fn (c, intros) => "Introduction rule(s) of " ^ c ^ ":\n" ^
+         commas (map (Display.string_of_thm_global thy) intros)) intross)))
   else ()
       
 fun print_specs thy specs =
   map (fn (c, thms) => "Constant " ^ c ^ " has specification:\n"
     ^ (space_implode "\n" (map (Display.string_of_thm_global thy) thms)) ^ "\n") specs
 
+fun map_specs f specs =
+  map (fn (s, ths) => (s, f ths)) specs
+
 fun process_specification options specs thy' =
   let
     val _ = print_step options "Compiling predicates to flat introrules..."
@@ -47,10 +51,10 @@
     val _ = print_step options "Replacing functions in introrules..."
     val intross2 =
       if fail_safe_mode then
-        case try (burrow (maps (Predicate_Compile_Fun.rewrite_intro thy''))) intross1 of
+        case try (map_specs (maps (Predicate_Compile_Fun.rewrite_intro thy''))) intross1 of
           SOME intross => intross
         | NONE => let val _ = warning "Function replacement failed!" in intross1 end
-      else burrow (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross1
+      else map_specs (maps (Predicate_Compile_Fun.rewrite_intro thy'')) intross1
     val _ = print_intross options thy'' "Introduction rules with replaced functions: " intross2
     val _ = print_step options "Introducing new constants for abstractions at higher-order argument positions..."
     val (intross3, (new_defs, thy''')) = Predicate_Compile_Pred.flat_higher_order_arguments (intross2, thy'')
@@ -80,10 +84,10 @@
     val specs = (get_specs prednames) @ fun_pred_specs
     val (intross3, thy''') = process_specification options specs thy'
     val _ = print_intross options thy''' "Introduction rules with new constants: " intross3
-    val intross4 = map (maps remove_pointless_clauses) intross3
+    val intross4 = map_specs (maps remove_pointless_clauses) intross3
     val _ = print_intross options thy''' "After removing pointless clauses: " intross4
-    val intross5 = map (map (AxClass.overload thy''')) intross4
-    val intross6 = map (map (expand_tuples thy''')) intross5
+      (*val intross5 = map (fn s, ths) => ( s, map (AxClass.overload thy''') ths)) intross4*)
+    val intross6 = map_specs (map (expand_tuples thy''')) intross4
     val _ = print_intross options thy''' "introduction rules before registering: " intross6
     val _ = print_step options "Registering introduction rules..."
     val thy'''' = fold Predicate_Compile_Core.register_intros intross6 thy'''
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:43 2009 +0200
@@ -12,8 +12,8 @@
   type smode = (int * int list option) list
   type mode = smode option list * smode
   datatype tmode = Mode of mode * smode * tmode option list;
-  val register_predicate : (thm list * thm * int) -> theory -> theory
-  val register_intros : thm list -> theory -> theory
+  val register_predicate : (string * thm list * thm * int) -> theory -> theory
+  val register_intros : string * thm list -> theory -> theory
   val is_registered : theory -> string -> bool
   val predfun_intro_of: theory -> string -> mode -> thm
   val predfun_elim_of: theory -> string -> mode -> thm
@@ -34,7 +34,7 @@
   val set_nparams : string -> int -> theory -> theory
   val print_stored_rules: theory -> unit
   val print_all_modes: theory -> unit
-  val mk_casesrule : Proof.context -> int -> thm list -> term
+  val mk_casesrule : Proof.context -> term -> int -> thm list -> term
   val eval_ref : (unit -> term Predicate.pred) option Unsynchronized.ref
   val random_eval_ref : (unit -> int * int -> term Predicate.pred * (int * int)) option Unsynchronized.ref
   val code_pred_intros_attrib : attribute
@@ -465,46 +465,55 @@
    end) handle Type.TUNIFY =>
      (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
 
-fun import_intros _ [] ctxt = ([], ctxt)
-  | import_intros nparams (th :: ths) ctxt =
+fun import_intros inp_pred nparams [] ctxt =
+  let
+    val ([outp_pred], ctxt') = Variable.import_terms false [inp_pred] ctxt
+    val (paramTs, _) = chop nparams (binder_types (fastype_of outp_pred))
+    val (param_names, ctxt'') = Variable.variant_fixes (map (fn i => "p" ^ (string_of_int i))
+      (1 upto nparams)) ctxt'
+    val params = map Free (param_names ~~ paramTs)
+    in (((outp_pred, params), []), ctxt') end
+  | import_intros inp_pred nparams (th :: ths) ctxt =
     let
       val ((_, [th']), ctxt') = Variable.import false [th] ctxt
       val thy = ProofContext.theory_of ctxt'
       val (pred, (params, args)) = strip_intro_concl nparams (prop_of th')
       val ho_args = filter (is_predT o fastype_of) args
+      fun subst_of (pred', pred) =
+        let
+          val subst = Sign.typ_match thy (fastype_of pred', fastype_of pred) Vartab.empty
+        in map (fn (indexname, (s, T)) => ((indexname, s), T)) (Vartab.dest subst) end
       fun instantiate_typ th =
         let
           val (pred', _) = strip_intro_concl 0 (prop_of th)
           val _ = if not (fst (dest_Const pred) = fst (dest_Const pred')) then
             error "Trying to instantiate another predicate" else ()
-          val subst = Sign.typ_match thy
-            (fastype_of pred', fastype_of pred) Vartab.empty
-          val subst' = map (fn (indexname, (s, T)) => ((indexname, s), T))
-            (Vartab.dest subst)
-        in Thm.certify_instantiate (subst', []) th end;
+        in Thm.certify_instantiate (subst_of (pred', pred), []) th end;
       fun instantiate_ho_args th =
         let
           val (_, (params', args')) = strip_intro_concl nparams (prop_of th)
           val ho_args' = map dest_Var (filter (is_predT o fastype_of) args')
         in Thm.certify_instantiate ([], map dest_Var params' ~~ params) th end
+      val outp_pred =
+        Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
       val ((_, ths'), ctxt1) =
         Variable.import false (map (instantiate_typ #> instantiate_ho_args) ths) ctxt'
     in
-      (th' :: ths', ctxt1)
+      (((outp_pred, params), th' :: ths'), ctxt1)
     end
 
 (* generation of case rules from user-given introduction rules *)
 
-fun mk_casesrule ctxt nparams introrules =
+fun mk_casesrule ctxt pred nparams introrules =
   let
-    val (intros_th, ctxt1) = import_intros nparams introrules ctxt
+    val (((pred, params), intros_th), ctxt1) = import_intros pred nparams introrules ctxt
     val intros = map prop_of intros_th
-    val (pred, (params, args)) = strip_intro_concl nparams (hd intros)
     val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
     val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
+    val (_, argsT) = chop nparams (binder_types (fastype_of pred))
     val (argnames, ctxt3) = Variable.variant_fixes
-      (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt2
-    val argvs = map2 (curry Free) argnames (map fastype_of args)
+      (map (fn i => "a" ^ string_of_int i) (1 upto length argsT)) ctxt2
+    val argvs = map2 (curry Free) argnames argsT
     fun mk_case intro =
       let
         val (_, (_, args)) = strip_intro_concl nparams intro
@@ -596,14 +605,16 @@
           in (fst (dest_Const const) = name) end;      
         val intros = ind_set_codegen_preproc thy
           (map (expand_tuples thy #> preprocess_intro thy) (filter is_intro_of (#intrs result)))
-        val pre_elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info)))
+        val index = find_index (fn s => s = name) (#names (fst info))
+        val pre_elim = nth (#elims result) index
+        val pred = nth (#preds result) index
         val nparams = length (Inductive.params_of (#raw_induct result))
         (*val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams 
           (expand_tuples_elim pre_elim))*)
         val elim =
           (Drule.standard o (setmp quick_and_dirty true (SkipProof.make_thm thy)))
-          (mk_casesrule (ProofContext.init thy) nparams intros)
-        val (intros, elim) = if null intros then noclause thy name elim else (intros, elim)
+          (mk_casesrule (ProofContext.init thy) pred nparams intros)
+        val (intros, elim) = (*if null intros then noclause thy name elim else*) (intros, elim)
       in
         mk_pred_data ((intros, SOME elim, nparams), ([], [], []))
       end                                                                    
@@ -631,7 +642,6 @@
       |> filter (fn c => (not (c = key)) andalso (is_inductive_predicate thy c orelse is_registered thy c))
   end;
 
-
 (* code dependency graph *)
 (*
 fun dependencies_of thy name =
@@ -643,6 +653,7 @@
     (data, keys)
   end;
 *)
+
 (* guessing number of parameters *)
 fun find_indexes pred xs =
   let
@@ -678,31 +689,34 @@
 fun set_nparams name nparams = let
     fun set (intros, elim, _ ) = (intros, elim, nparams) 
   in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
-    
-fun register_predicate (pre_intros, pre_elim, nparams) thy =
+
+fun register_predicate (constname, pre_intros, pre_elim, nparams) thy =
   let
-    val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd pre_intros))))
     (* preprocessing *)
     val intros = ind_set_codegen_preproc thy (map (preprocess_intro thy) pre_intros)
     val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
   in
-    if not (member (op =) (Graph.keys (PredData.get thy)) name) then
+    if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
       PredData.map
-        (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), ([], [], [])))) thy
+        (Graph.new_node (constname, mk_pred_data ((intros, SOME elim, nparams), ([], [], [])))) thy
     else thy
   end
 
-fun register_intros pre_intros thy =
+fun register_intros (constname, pre_intros) thy =
   let
-    val (c, T) = dest_Const (fst (strip_intro_concl 0 (prop_of (hd pre_intros))))
+    val T = Sign.the_const_type thy constname
     fun constname_of_intro intr = fst (dest_Const (fst (strip_intro_concl 0 (prop_of intr))))
-    val _ = if not (forall (fn intr => constname_of_intro intr = c) pre_intros) then
-      error "register_intros: Introduction rules of different constants are used" else ()
+    val _ = if not (forall (fn intr => constname_of_intro intr = constname) pre_intros) then
+      error ("register_intros: Introduction rules of different constants are used\n" ^
+        "expected rules for " ^ constname ^ ", but received rules for " ^
+          commas (map constname_of_intro pre_intros))
+      else ()
+    val pred = Const (constname, T)
     val nparams = guess_nparams T
     val pre_elim = 
       (Drule.standard o (setmp quick_and_dirty true (SkipProof.make_thm thy)))
-      (mk_casesrule (ProofContext.init thy) nparams pre_intros)
-  in register_predicate (pre_intros, pre_elim, nparams) thy end
+      (mk_casesrule (ProofContext.init thy) pred nparams pre_intros)
+  in register_predicate (constname, pre_intros, pre_elim, nparams) thy end
 
 fun set_generator_name pred mode name = 
   let
@@ -1074,7 +1088,8 @@
   else ()
 
 fun check_modes_pred options with_generator thy param_vs clauses modes gen_modes (p, ms) =
-  let val SOME rs = AList.lookup (op =) clauses p
+  let
+    val rs = case AList.lookup (op =) clauses p of SOME rs => rs | NONE => []
   in (p, List.filter (fn m => case find_index
     (is_none o check_mode_clause with_generator thy param_vs modes gen_modes m) rs of
       ~1 => true
@@ -1083,7 +1098,7 @@
 
 fun get_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms) =
   let
-    val SOME rs = AList.lookup (op =) clauses p
+    val rs = case AList.lookup (op =) clauses p of SOME rs => rs | NONE => []
   in
     (p, map (fn m =>
       (m, map (the o check_mode_clause with_generator thy param_vs modes gen_modes m) rs)) ms)
@@ -1391,7 +1406,9 @@
       map (compile_clause compilation_modifiers compfuns
         thy all_vs param_vs additional_arguments mode (mk_tuple in_ts)) moded_cls;
     val compilation = #wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
-        (foldr1 (mk_sup compfuns) cl_ts)
+      (if null cl_ts then
+        mk_bot compfuns (mk_tupleT Us2)
+      else foldr1 (mk_sup compfuns) cl_ts)
     val fun_const =
       Const (#const_name_of compilation_modifiers thy s mode,
         #funT_of compilation_modifiers compfuns mode T)
@@ -1733,7 +1750,7 @@
     (* need better control here! *)
   end
 
-fun prove_clause thy nargs modes (iss, is) (_, clauses) (ts, moded_ps) =
+fun prove_clause options thy nargs modes (iss, is) (_, clauses) (ts, moded_ps) =
   let
     val (in_ts, clause_out_ts) = split_smode is ts;
     fun prove_prems out_ts [] =
@@ -1789,7 +1806,8 @@
       end;
     val prems_tac = prove_prems in_ts moded_ps
   in
-    rtac @{thm bindI} 1
+    print_tac' options "Proving clause..."
+    THEN rtac @{thm bindI} 1
     THEN rtac @{thm singleI} 1
     THEN prems_tac
   end;
@@ -1811,7 +1829,7 @@
     THEN (EVERY (map
            (fn i => EVERY' (select_sup (length moded_clauses) i) i) 
              (1 upto (length moded_clauses))))
-    THEN (EVERY (map2 (prove_clause thy nargs modes mode) clauses moded_clauses))
+    THEN (EVERY (map2 (prove_clause options thy nargs modes mode) clauses moded_clauses))
     THEN print_tac "proved one direction"
   end;
 
@@ -1973,7 +1991,7 @@
     THEN prems_tac
   end;
  
-fun prove_other_direction thy modes pred mode moded_clauses =
+fun prove_other_direction options thy modes pred mode moded_clauses =
   let
     fun prove_clause clause i =
       (if i < length moded_clauses then etac @{thm supE} 1 else all_tac)
@@ -1983,7 +2001,9 @@
      THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
      THEN (rtac (predfun_intro_of thy pred mode) 1)
      THEN (REPEAT_DETERM (rtac @{thm refl} 2))
-     THEN (EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses))))
+     THEN (if null moded_clauses then
+         etac @{thm botE} 1
+       else EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses))))
   end;
 
 (** proof procedure **)
@@ -1991,7 +2011,7 @@
 fun prove_pred options thy clauses preds modes pred mode (moded_clauses, compiled_term) =
   let
     val ctxt = ProofContext.init thy
-    val clauses = the (AList.lookup (op =) clauses pred)
+    val clauses = case AList.lookup (op =) clauses pred of SOME rs => rs | NONE => []
   in
     Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term
       (if not (skip_proof options) then
@@ -2000,7 +2020,7 @@
 				THEN print_tac' options "after pred_iffI"
         THEN prove_one_direction options thy clauses preds modes pred mode moded_clauses
         THEN print_tac' options "proved one direction"
-        THEN prove_other_direction thy modes pred mode moded_clauses
+        THEN prove_other_direction options thy modes pred mode moded_clauses
         THEN print_tac' options "proved other direction")
       else (fn _ => setmp quick_and_dirty true SkipProof.cheat_tac thy))
   end;
@@ -2051,13 +2071,19 @@
   let
     val intrs = map prop_of intros
     val nparams = nparams_of thy (hd prednames)
-    val preds = distinct (fn ((c1, _), (c2, _)) => c1 = c2) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
-    val (preds, intrs) = unify_consts thy (map Const preds) intrs
+    val preds = map (fn c => Const (c, Sign.the_const_type thy c)) prednames
+    val (preds, intrs) = unify_consts thy preds intrs
     val ([preds, intrs], _) = fold_burrow (Variable.import_terms false) [preds, intrs] (ProofContext.init thy)
     val preds = map dest_Const preds
     val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
-    val _ $ u = Logic.strip_imp_concl (hd intrs);
-    val params = List.take (snd (strip_comb u), nparams);
+    val params = case intrs of
+        [] =>
+          let
+            val (paramTs, _) = chop nparams (binder_types (snd (hd preds)))
+            val param_names = Name.variant_list [] (map (fn i => "p" ^ string_of_int i) (1 upto length paramTs))
+          in map Free (param_names ~~ paramTs) end
+      | intr :: _ => fst (chop nparams
+        (snd (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr)))))
     val param_vs = maps term_vs params
     val all_vs = terms_vs intrs
     fun add_clause intr (clauses, arities) =
@@ -2314,9 +2340,11 @@
     val preds = Graph.all_preds (PredData.get thy') [const] |> filter_out (has_elim thy')
     fun mk_cases const =
       let
+        val T = Sign.the_const_type thy const
+        val pred = Const (const, T)
         val nparams = nparams_of thy' const
         val intros = intros_of thy' const
-      in mk_casesrule lthy' nparams intros end  
+      in mk_casesrule lthy' pred nparams intros end  
     val cases_rules = map mk_cases preds
     val cases =
       map (fn case_rule => RuleCases.Case {fixes = [],
--- a/src/HOL/ex/Predicate_Compile_ex.thy	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile_ex.thy	Sat Oct 24 16:55:43 2009 +0200
@@ -43,7 +43,7 @@
     "append [] xs xs"
   | "append xs ys zs \<Longrightarrow> append (x # xs) ys (x # zs)"
 
-code_pred (mode: [1, 2], [3], [2, 3], [1, 3], [1, 2, 3]) [inductify] append .
+code_pred (mode: [1, 2], [3], [2, 3], [1, 3], [1, 2, 3]) append .
 code_pred [depth_limited] append .
 code_pred [rpred] append .
 
@@ -366,7 +366,7 @@
 
 thm lexn.rpred_equation
 
-code_pred [inductify] lenlex .
+code_pred [inductify, show_steps] lenlex .
 thm lenlex.equation
 
 code_pred [inductify, rpred] lenlex .
@@ -425,13 +425,20 @@
 code_pred [inductify] Image .
 thm Image.equation
 (*TODO: *)
-(*code_pred [inductify] Id_on .*)
+ML {* Toplevel.debug := true *}
+declare Id_on_def[unfolded UNION_def, code_pred_def]
+
+code_pred [inductify] Id_on .
+thm Id_on.equation
 code_pred [inductify] Domain .
 thm Domain.equation
 code_pred [inductify] Range .
 thm sym_def
 code_pred [inductify] Field .
-(* code_pred [inductify] refl_on .*)
+declare Sigma_def[unfolded UNION_def, code_pred_def]
+declare refl_on_def[unfolded UNION_def, code_pred_def]
+code_pred [inductify] refl_on .
+thm refl_on.equation
 code_pred [inductify] total_on .
 thm total_on.equation
 (*