refactoring predicate compiler; repaired proof procedure to handle all test cases
authorbulwahn
Tue, 04 Aug 2009 08:34:56 +0200
changeset 32313 a984c04927b4
parent 32312 26a9d0c69b8b
child 32314 66bbad0bfef9
refactoring predicate compiler; repaired proof procedure to handle all test cases
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Tue Aug 04 08:34:56 2009 +0200
@@ -104,8 +104,8 @@
 
 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
 
-fun print_tac s = Seq.single; (* (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); *)
-fun debug_tac msg = Seq.single; (* (fn st => (tracing msg; Seq.single st)); *)
+fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
+fun debug_tac msg = (fn st => (Output.tracing msg; Seq.single st));
 
 val do_proofs = ref true;
 
@@ -425,41 +425,45 @@
       (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
     (Thm.transfer thy rule)
 
-fun preprocess_elim thy nargs elimrule = let
-   fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
-      HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
-    | replace_eqs t = t
-   fun preprocess_case t =
-   let
-     val params = Logic.strip_params t
-     val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
-     val assums_hyp' = assums1 @ (map replace_eqs assums2)
-     in list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t)) end
-   val prems = Thm.prems_of elimrule
-   val cases' = map preprocess_case (tl prems)
-   val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
- in
-   Thm.equal_elim
-     (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm eq_is_eq}])
-        (cterm_of thy elimrule')))
-     elimrule
- end;
+fun preprocess_elim thy nparams elimrule =
+  let
+    fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
+       HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
+     | replace_eqs t = t
+    val prems = Thm.prems_of elimrule
+    val nargs = length (snd (strip_comb (HOLogic.dest_Trueprop (hd prems)))) - nparams
+    fun preprocess_case t =
+     let
+       val params = Logic.strip_params t
+       val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
+       val assums_hyp' = assums1 @ (map replace_eqs assums2)
+     in
+       list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t))
+     end 
+    val cases' = map preprocess_case (tl prems)
+    val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
+  in
+    Thm.equal_elim
+      (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm eq_is_eq}])
+         (cterm_of thy elimrule')))
+      elimrule
+  end;
 
 (* special case: predicate with no introduction rule *)
-fun noclause thy predname = let
+fun noclause thy predname elim = let
   val T = (Logic.unvarifyT o Sign.the_const_type thy) predname
   val Ts = binder_types T
   val names = Name.variant_list []
         (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts)))
   val vs = map2 (curry Free) names Ts
-  val clausehd =  HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs))
+  val clausehd = HOLogic.mk_Trueprop (list_comb (Const (predname, T), vs))
   val intro_t = Logic.mk_implies (@{prop False}, clausehd)
   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
   val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P)
   val intro = Goal.prove (ProofContext.init thy) names [] intro_t
         (fn {...} => etac @{thm FalseE} 1)
   val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
-        (fn {...} => etac (the_elim_of thy predname) 1) 
+        (fn {...} => etac elim 1) 
 in
   ([intro], elim, 0)
 end
@@ -471,11 +475,15 @@
         fun is_intro_of intro =
           let
             val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
-          in (fst (dest_Const const) = name) end;
-        val intros = map (preprocess_intro thy) (filter is_intro_of (#intrs result)) 
-        val elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info)))
+          in (fst (dest_Const const) = name) end;      
+        val intros = ind_set_codegen_preproc thy ((map (preprocess_intro thy))
+          (filter is_intro_of (#intrs result)))
+        val pre_elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info)))
         val nparams = length (Inductive.params_of (#raw_induct result))
-      in if null intros then noclause thy name else (intros, elim, nparams) end                                                                    
+        val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
+      in
+        if null intros then noclause thy name elim else (intros, elim, nparams)
+      end                                                                    
   | NONE => error ("No such predicate: " ^ quote name)
   
 (* updaters *)
@@ -529,11 +537,10 @@
     fun set (intros, elim, _ ) = (intros, elim, nparams) 
   in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end
     
-fun register_predicate (intros, elim, nparams) thy = let
+fun register_predicate (intros, elim, nparams) = let
     val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd intros))))
   in
-    PredData.map (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), ([], [], [])))
-      #> fold Graph.add_edge (map (pair name) (depending_preds_of thy intros))) thy
+    PredData.map (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), ([], [], []))))
   end
 
 fun set_generator_name pred mode name = 
@@ -586,7 +593,7 @@
   let
     val Ts = binder_types T
     val (paramTs, (inargTs, outargTs)) = split_mode (iss, is) Ts
-    val paramTs' = map2 (fn SOME is => funT_of compfuns ([], is) | NONE => I) iss paramTs 
+    val paramTs' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss paramTs 
   in
     (paramTs' @ inargTs) ---> (mk_predT compfuns (mk_tupleT outargTs))
   end;
@@ -733,16 +740,6 @@
       RPredCompFuns.mk_rpredT T) $ random
   end;
  
-
-(* Remark: types of param_funT_of and funT_of are swapped - which is the more
-canonical order? *)
-(* maybe remove param_funT_of completely - by using funT_of *)
-fun param_funT_of compfuns T NONE = T
-  | param_funT_of compfuns T (SOME mode) =
-   let
-     val (Us1, Us2) = split_smode mode (binder_types T)
-   in Us1 ---> (mk_predT compfuns (mk_tupleT Us2)) end;
-
 (* Mode analysis *)
 
 (*** check if a term contains only constructor functions ***)
@@ -1087,7 +1084,7 @@
 
 fun compile_param thy compfuns (NONE, t) = t
   | compile_param thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
-   let  
+   let
      val (f, args) = strip_comb (Envir.eta_contract t)
      val (params, args') = chop (length ms) args
      val params' = map (compile_param thy compfuns) (ms ~~ params)
@@ -1095,7 +1092,7 @@
        case f of
          Const (name, T) =>
            mk_fun_of compfuns thy (name, T) (iss, is')
-       | Free (name, T) => Free (name, param_funT_of compfuns T (SOME is'))
+       | Free (name, T) => Free (name, funT_of compfuns (iss, is') T)
    in list_comb (f', params' @ args') end
    
 fun compile_expr size thy ((Mode (mode, is, ms)), t) =
@@ -1105,11 +1102,11 @@
          val params' = map (compile_param thy PredicateCompFuns.compfuns) (ms ~~ params)
        in
          case size of
-           NONE => list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params)
-         | SOME _ => list_comb (mk_sizelim_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params)
+           NONE => list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
+         | SOME _ => list_comb (mk_sizelim_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
        end
   | (Free (name, T), args) =>
-       list_comb (Free (name, param_funT_of PredicateCompFuns.compfuns T (SOME is)), args)
+       list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
           
 fun compile_gen_expr thy compfuns ((Mode (mode, is, ms)), t) =
   case strip_comb t of
@@ -1230,7 +1227,7 @@
 fun compile_pred compfuns mk_fun_of use_size thy all_vs param_vs s T mode moded_cls =
   let
     val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
-    val Ts1' = map2 (param_funT_of compfuns) Ts1 (fst mode)
+    val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
     val xnames = Name.variant_list (all_vs @ param_vs)
       (map (fn i => "x" ^ string_of_int i) (snd mode));
     val size_name = Name.variant (all_vs @ param_vs @ xnames) "size"
@@ -1293,15 +1290,16 @@
   val argnames = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   val (Ts1, Ts2) = chop (length iss) Ts;
-  val Ts1' = map2 (param_funT_of (PredicateCompFuns.compfuns)) Ts1 iss
+  val Ts1' = map2 (fn NONE => I | SOME is => funT_of (PredicateCompFuns.compfuns) ([], is)) iss Ts1
   val args = map Free (argnames ~~ (Ts1' @ Ts2))
-  val (params, (inargs, outargs)) = split_mode mode args 
+  val (params, ioargs) = chop (length iss) args
+  val (inargs, outargs) = split_smode is ioargs
   val param_names = Name.variant_list argnames
     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
   val param_vs = map Free (param_names ~~ Ts1)
   val (params', names) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ iss) []
-  val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ inargs @ outargs))
-  val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ inargs @ outargs))
+  val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ ioargs))
+  val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ ioargs))
   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
   val funargs = params @ inargs
   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
@@ -1309,7 +1307,7 @@
   val funpropI = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
                    mk_tuple outargs))
   val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
-  val _ = tracing (Syntax.string_of_term_global thy introtrm) 
+  val _ = Output.tracing (Syntax.string_of_term_global thy introtrm) 
   val simprules = [defthm, @{thm eval_pred},
                    @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
   val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1
@@ -1325,9 +1323,8 @@
   let
     fun string_of_mode mode = if null mode then "0"
       else space_implode "_" (map string_of_int mode)
-    fun string_of_HOmode m s =
-      case m of NONE => s | SOME mode => s ^ "_and_" ^ (string_of_mode mode)
-    val HOmode = fold string_of_HOmode (fst mode) ""
+    val HOmode = space_implode "_and_"
+      (fold (fn NONE => I | SOME mode => cons (string_of_mode mode)) (fst mode) [])
   in
     (Sign.full_bname thy (prefix ^ (Long_Name.base_name name))) ^
       (if HOmode = "" then "_" else "_for_" ^ HOmode ^ "_yields_") ^ (string_of_mode (snd mode))
@@ -1341,14 +1338,16 @@
     fun create_definition (mode as (iss, is)) thy = let
       val mode_cname = create_constname_of_mode thy "" name mode
       val mode_cbasename = Long_Name.base_name mode_cname
-      val Ts = binder_types T;
-      val (Ts1, (Us1, Us2)) = split_mode mode Ts;
-      val Ts1' = map2 (param_funT_of compfuns) Ts1 iss
+      val Ts = binder_types T
+      val (Ts1, Ts2) = chop (length iss) Ts
+      val (Us1, Us2) =  split_smode is Ts2
+      val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss Ts1
       val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (mk_tupleT Us2))
       val names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-      val xs = map Free (names ~~ (Ts1' @ Us1 @ Us2));
-      val (xparams, (xins, xouts)) = split_mode mode xs;
+      val xs = map Free (names ~~ (Ts1' @ Ts2));                   
+      val (xparams, xargs) = chop (length iss) xs;
+      val (xins, xouts) = split_smode is xargs 
       val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names
       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
         | mk_split_lambda [x] t = lambda x t
@@ -1360,7 +1359,7 @@
           mk_split_lambda' xs t
         end;
       val predterm = PredicateCompFuns.mk_Enum (mk_split_lambda xouts
-        (list_comb (Const (name, T), xparams' @ xins @ xouts)))
+        (list_comb (Const (name, T), xparams' @ xargs)))
       val lhs = list_comb (Const (mode_cname, funT), xparams @ xins)
       val def = Logic.mk_equals (lhs, predterm)
       val ([definition], thy') = thy |>
@@ -1384,7 +1383,7 @@
       let
         val mode_cname = create_constname_of_mode thy "sizelim_" name mode
         val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
-        val Ts1' = map2 (param_funT_of PredicateCompFuns.compfuns) Ts1 (fst mode)
+        val Ts1' = map2 (fn NONE => I | SOME is => funT_of PredicateCompFuns.compfuns ([], is)) (fst mode) Ts1
         val funT = (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (PredicateCompFuns.mk_predT (mk_tupleT Us2)) 
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
@@ -1401,7 +1400,7 @@
       let
         val mode_cname = create_constname_of_mode thy "gen_" name mode
         val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T);
-        val Ts1' = map2 (param_funT_of RPredCompFuns.compfuns) Ts1 (fst mode)
+        val Ts1' = map2 (fn NONE => I | SOME is => funT_of RPredCompFuns.compfuns ([], is)) (fst mode) Ts1
         val funT = (Ts1' @ Us1 @ [@{typ "code_numeral"}]) ---> (RPredCompFuns.mk_rpredT (mk_tupleT Us2)) 
       in
         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
@@ -1433,53 +1432,40 @@
 
 (* MAJOR FIXME:  prove_params should be simple
  - different form of introrule for parameters ? *)
-fun prove_param thy (NONE, t) =
-  all_tac 
-| prove_param thy (m as SOME (Mode (mode, is, ms)), t) =
-  REPEAT_DETERM (etac @{thm thin_rl} 1)
-  THEN REPEAT_DETERM (rtac @{thm ext} 1)
-  THEN (rtac @{thm iffI} 1)
-  THEN print_tac "prove_param"
-  (* proof in one direction *)
-  THEN (atac 1)
-  (* proof in the other direction *)
-  THEN (atac 1)
-  THEN print_tac "after prove_param"
-(*  let
+fun prove_param thy (NONE, t) = TRY (rtac @{thm refl} 1)
+  | prove_param thy (m as SOME (Mode (mode, is, ms)), t) =
+  let
     val  (f, args) = strip_comb t
     val (params, _) = chop (length ms) args
     val f_tac = case f of
-        Const (name, T) => simp_tac (HOL_basic_ss addsimps 
-           (@{thm eval_pred}::(predfun_definition_of thy name mode)::
-           @{thm "Product_Type.split_conv"}::[])) 1
-      | Free _ => all_tac
-      | Abs _ => error "TODO: implement here"
-  in  
-    print_tac "before simplification in prove_args:"
+      Const (name, T) => simp_tac (HOL_basic_ss addsimps 
+         (@{thm eval_pred}::(predfun_definition_of thy name mode)::
+         @{thm "Product_Type.split_conv"}::[])) 1
+    | Free _ => TRY (rtac @{thm refl} 1)
+    | Abs _ => error "prove_param: No valid parameter term"
+  in
+    REPEAT_DETERM (etac @{thm thin_rl} 1)
+    THEN REPEAT_DETERM (rtac @{thm ext} 1)
+    THEN print_tac "prove_param"
     THEN f_tac
     THEN print_tac "after simplification in prove_args"
-    THEN (EVERY (map (prove_param thy modes) (ms ~~ params)))
+    THEN (EVERY (map (prove_param thy) (ms ~~ params)))
     THEN (REPEAT_DETERM (atac 1))
   end
-*)
+
+    THEN print_tac "after prove_param"
 fun prove_expr thy (Mode (mode, is, ms), t, us) (premposition : int) =
   case strip_comb t of
     (Const (name, T), args) =>  
       let
         val introrule = predfun_intro_of thy name mode
-        (*val (in_args, out_args) = split_mode is us
-        val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop
-          (hd (Logic.strip_imp_prems (prop_of introrule))))
-        val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *)
-        val (_, args) = chop nparams rargs
-        val subst = map (pairself (cterm_of thy)) (args ~~ us)
-        val inst_introrule = Drule.cterm_instantiate subst introrule*)
         val (args1, args2) = chop (length ms) args
       in
         rtac @{thm bindI} 1
         THEN print_tac "before intro rule:"
         (* for the right assumption in first position *)
         THEN rotate_tac premposition 1
+        THEN debug_tac (Display.string_of_thm (ProofContext.init thy) introrule)
         THEN rtac introrule 1
         THEN print_tac "after intro rule"
         (* work with parameter arguments *)
@@ -1565,11 +1551,13 @@
               THEN (if (is_some name) then
                   simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1
                   THEN rtac @{thm not_predI} 1
+                  THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
+                  THEN (REPEAT_DETERM (atac 1))
                   (* FIXME: work with parameter arguments *)
                   THEN (EVERY (map (prove_param thy) (param_modes ~~ params)))
                 else
                   rtac @{thm not_predI'} 1)
-              THEN (REPEAT_DETERM (atac 1))
+                  THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
               THEN rec_tac
             end
           | Sidecond t =>
@@ -1597,10 +1585,7 @@
   let
     val T = the (AList.lookup (op =) preds pred)
     val nargs = length (binder_types T) - nparams_of thy pred
-    (* FIXME: preprocessing! *)
-    val pred_case_rule = singleton (ind_set_codegen_preproc thy)
-      (preprocess_elim thy nargs (the_elim_of thy pred))
-    (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@{thm Predicate.memb_code}])*)
+    val pred_case_rule = the_elim_of thy pred
   in
     REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
     THEN etac (predfun_elim_of thy pred mode) 1
@@ -1642,7 +1627,7 @@
 -- join both functions
 *)
 (* TODO: remove function *)
-(*
+
 fun prove_param2 thy (NONE, t) = all_tac 
   | prove_param2 thy (m as SOME (Mode (mode, is, ms)), t) = let
     val  (f, args) = strip_comb t
@@ -1658,7 +1643,7 @@
     THEN print_tac "after simplification in prove_args"
     THEN (EVERY (map (prove_param2 thy) (ms ~~ params)))
   end
-*)
+
 
 fun prove_expr2 thy (Mode (mode, is, ms), t) = 
   (case strip_comb t of
@@ -1670,7 +1655,7 @@
         (prop_of (predfun_elim_of thy name mode))))
       THEN (etac (predfun_elim_of thy name mode) 1)
       THEN print_tac "prove_expr2"
-      THEN (EVERY (map (prove_param thy) (ms ~~ args)))
+      THEN (EVERY (map (prove_param2 thy) (ms ~~ args)))
       THEN print_tac "finished prove_expr2"      
     | _ => etac @{thm bindE} 1)
     
@@ -1697,9 +1682,6 @@
 fun prove_clause2 thy modes pred (iss, is) (ts, ps) i =
   let
     val pred_intro_rule = nth (intros_of thy pred) (i - 1)
-      |> preprocess_intro thy
-      |> (fn thm => hd (ind_set_codegen_preproc thy [thm]))
-      (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *)
     val (in_ts, clause_out_ts) = split_smode is ts;
     fun prove_prems2 out_ts [] =
       print_tac "before prove_match2 - last call:"
@@ -1739,7 +1721,8 @@
             THEN (if is_some name then
                 full_simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1 
                 THEN etac @{thm not_predE} 1
-                THEN (EVERY (map (prove_param thy) (param_modes ~~ params)))
+                THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
+                THEN (EVERY (map (prove_param2 thy) (param_modes ~~ params)))
               else
                 etac @{thm not_predE'} 1)
             THEN rec_tac
@@ -1982,9 +1965,10 @@
   
     val thy = ProofContext.theory_of lthy
     val const = prep_const thy raw_const
-    
+    val _ = Output.tracing "extending graph"
     val lthy' = LocalTheory.theory (PredData.map (Graph.extend (dependencies_of thy) const)) lthy
       |> LocalTheory.checkpoint
+    val _ = Output.tracing "code_pred graph extended..."  
     val thy' = ProofContext.theory_of lthy'
     val preds = Graph.all_preds (PredData.get thy') [const] |> filter_out (has_elim thy')