contextifying the proof procedure in the predicate compiler
authorbulwahn
Mon, 22 Mar 2010 08:30:13 +0100
changeset 35888 d902054e7ac6
parent 35887 f704ba9875f6
child 35889 c1f86c5d3827
contextifying the proof procedure in the predicate compiler
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
@@ -17,8 +17,8 @@
   val is_registered : theory -> string -> bool
   val function_name_of : Predicate_Compile_Aux.compilation -> theory
     -> string -> bool * Predicate_Compile_Aux.mode -> string
-  val predfun_intro_of: theory -> string -> Predicate_Compile_Aux.mode -> thm
-  val predfun_elim_of: theory -> string -> Predicate_Compile_Aux.mode -> thm
+  val predfun_intro_of: Proof.context -> string -> Predicate_Compile_Aux.mode -> thm
+  val predfun_elim_of: Proof.context -> string -> Predicate_Compile_Aux.mode -> thm
   val all_preds_of : theory -> string list
   val modes_of: Predicate_Compile_Aux.compilation
     -> theory -> string -> Predicate_Compile_Aux.mode list
@@ -244,7 +244,7 @@
 
 fun the_elim_of thy name = case #elim (the_pred_data thy name)
  of NONE => error ("No elimination rule for predicate " ^ quote name)
-  | SOME thm => Thm.transfer thy thm 
+  | SOME thm => Thm.transfer thy thm
   
 val has_elim = is_some o #elim oo the_pred_data;
 
@@ -282,13 +282,13 @@
       " of predicate " ^ name)
   | SOME data => data;
 
-val predfun_definition_of = #definition ooo the_predfun_data
+val predfun_definition_of = #definition ooo the_predfun_data o ProofContext.theory_of
 
-val predfun_intro_of = #intro ooo the_predfun_data
+val predfun_intro_of = #intro ooo the_predfun_data o ProofContext.theory_of
 
-val predfun_elim_of = #elim ooo the_predfun_data
+val predfun_elim_of = #elim ooo the_predfun_data o ProofContext.theory_of
 
-val predfun_neg_intro_of = #neg_intro ooo the_predfun_data
+val predfun_neg_intro_of = #neg_intro ooo the_predfun_data o ProofContext.theory_of
 
 (* diagnostic display functions *)
 
@@ -1900,7 +1900,7 @@
 (* MAJOR FIXME:  prove_params should be simple
  - different form of introrule for parameters ? *)
 
-fun prove_param options thy nargs t deriv =
+fun prove_param options ctxt nargs t deriv =
   let
     val  (f, args) = strip_comb (Envir.eta_contract t)
     val mode = head_mode_of deriv
@@ -1908,20 +1908,17 @@
     val ho_args = ho_args_of mode args
     val f_tac = case f of
       Const (name, T) => simp_tac (HOL_basic_ss addsimps 
-         [@{thm eval_pred}, predfun_definition_of thy name mode,
+         [@{thm eval_pred}, predfun_definition_of ctxt name mode,
          @{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
          @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
     | Free _ =>
-      (* rewrite with parameter equation *)
-    (* test: *)
-      Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems = prems,
-      asms = a, concl = concl, schematics = s} =>
+      Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems, asms, concl, schematics} =>
         let
           val prems' = maps dest_conjunct_prem (take nargs prems)
         in
           MetaSimplifier.rewrite_goal_tac
             (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
-        end) (ProofContext.init thy) 1 (* FIXME: proper context handling *)
+        end) ctxt 1
     | Abs _ => raise Fail "prove_param: No valid parameter term"
   in
     REPEAT_DETERM (rtac @{thm ext} 1)
@@ -1929,16 +1926,16 @@
     THEN f_tac 
     THEN print_tac options "after prove_param"
     THEN (REPEAT_DETERM (atac 1))
-    THEN (EVERY (map2 (prove_param options thy nargs) ho_args param_derivations))
+    THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations))
     THEN REPEAT_DETERM (rtac @{thm refl} 1)
   end
 
-fun prove_expr options thy nargs (premposition : int) (t, deriv) =
+fun prove_expr options ctxt nargs (premposition : int) (t, deriv) =
   case strip_comb t of
     (Const (name, T), args) =>
       let
         val mode = head_mode_of deriv
-        val introrule = predfun_intro_of thy name mode
+        val introrule = predfun_intro_of ctxt name mode
         val param_derivations = param_derivations_of deriv
         val ho_args = ho_args_of mode args
       in
@@ -1950,13 +1947,12 @@
         THEN atac 1
         THEN print_tac options "parameter goal"
         (* work with parameter arguments *)
-        THEN (EVERY (map2 (prove_param options thy nargs) ho_args param_derivations))
+        THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations))
         THEN (REPEAT_DETERM (atac 1))
       end
   | (Free _, _) =>
     print_tac options "proving parameter call.."
-    THEN Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems = prems,
-      asms = a, concl = cl, schematics = s} =>
+    THEN Subgoal.FOCUS_PREMS (fn {context = ctxt, params, prems, asms, concl, schematics} =>
         let
           val param_prem = nth prems premposition
           val (param, _) = strip_comb (HOLogic.dest_Trueprop (prop_of param_prem))
@@ -1970,14 +1966,14 @@
             param_prem
         in
           rtac param_prem' 1
-        end) (ProofContext.init thy) 1 (* FIXME: proper context handling *)
+        end) ctxt 1
     THEN print_tac options "after prove parameter call"
 
 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st;
 
 fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st
 
-fun check_format thy st =
+fun check_format ctxt st =
   let
     val concl' = Logic.strip_assums_concl (hd (prems_of st))
     val concl = HOLogic.dest_Trueprop concl'
@@ -1992,8 +1988,9 @@
       ((*tracing "expression is not valid";*) Seq.empty) (*error "check_format: wrong format"*)
   end
 
-fun prove_match options thy (out_ts : term list) =
+fun prove_match options ctxt out_ts =
   let
+    val thy = ProofContext.theory_of ctxt
     fun get_case_rewrite t =
       if (is_constructor thy t) then let
         val case_rewrites = (#case_rewrites (Datatype.the_info thy
@@ -2006,7 +2003,7 @@
      (* make this simpset better! *)
     asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
     THEN print_tac options "after prove_match:"
-    THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm HOL.if_P}] 1
+    THEN (DETERM (TRY (EqSubst.eqsubst_tac ctxt [0] [@{thm HOL.if_P}] 1
            THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1))))
            THEN print_tac options "if condition to be solved:"
            THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1 THEN print_tac options "after if simp; in SOLVED:"))
@@ -2017,8 +2014,9 @@
 
 (* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
 
-fun prove_sidecond thy t =
+fun prove_sidecond ctxt t =
   let
+    val thy = ProofContext.theory_of ctxt
     fun preds_of t nameTs = case strip_comb t of 
       (f as Const (name, T), args) =>
         if is_registered thy name then (name, T) :: nameTs
@@ -2026,7 +2024,7 @@
       | _ => nameTs
     val preds = preds_of t []
     val defs = map
-      (fn (pred, T) => predfun_definition_of thy pred
+      (fn (pred, T) => predfun_definition_of ctxt pred
         (all_input_of T))
         preds
   in 
@@ -2036,11 +2034,11 @@
     (* need better control here! *)
   end
 
-fun prove_clause options thy nargs mode (_, clauses) (ts, moded_ps) =
+fun prove_clause options ctxt nargs mode (_, clauses) (ts, moded_ps) =
   let
     val (in_ts, clause_out_ts) = split_mode mode ts;
     fun prove_prems out_ts [] =
-      (prove_match options thy out_ts)
+      (prove_match options ctxt out_ts)
       THEN print_tac options "before simplifying assumptions"
       THEN asm_full_simp_tac HOL_basic_ss' 1
       THEN print_tac options "before single intro rule"
@@ -2060,7 +2058,7 @@
               print_tac options "before clause:"
               (*THEN asm_simp_tac HOL_basic_ss 1*)
               THEN print_tac options "before prove_expr:"
-              THEN prove_expr options thy nargs premposition (t, deriv)
+              THEN prove_expr options ctxt nargs premposition (t, deriv)
               THEN print_tac options "after prove_expr:"
               THEN rec_tac
             end
@@ -2071,7 +2069,8 @@
               val rec_tac = prove_prems out_ts''' ps
               val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
               val neg_intro_rule =
-                Option.map (fn name => the (predfun_neg_intro_of thy name mode)) name
+                Option.map (fn name =>
+                  the (predfun_neg_intro_of ctxt name mode)) name
               val param_derivations = param_derivations_of deriv
               val params = ho_args_of mode args
             in
@@ -2085,7 +2084,7 @@
                   THEN etac (the neg_intro_rule) 1
                   THEN rotate_tac (~premposition) 1
                   THEN print_tac options "after applying not introduction rule"
-                  THEN (EVERY (map2 (prove_param options thy nargs) params param_derivations))
+                  THEN (EVERY (map2 (prove_param options ctxt nargs) params param_derivations))
                   THEN (REPEAT_DETERM (atac 1))
                 else
                   rtac @{thm not_predI'} 1
@@ -2098,10 +2097,10 @@
           | Sidecond t =>
            rtac @{thm if_predI} 1
            THEN print_tac options "before sidecond:"
-           THEN prove_sidecond thy t
+           THEN prove_sidecond ctxt t
            THEN print_tac options "after sidecond:"
            THEN prove_prems [] ps)
-      in (prove_match options thy out_ts)
+      in (prove_match options ctxt out_ts)
           THEN rest_tac
       end;
     val prems_tac = prove_prems in_ts moded_ps
@@ -2116,51 +2115,55 @@
   | select_sup _ 1 = [rtac @{thm supI1}]
   | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1));
 
-fun prove_one_direction options thy clauses preds pred mode moded_clauses =
+fun prove_one_direction options ctxt clauses preds pred mode moded_clauses =
   let
+    val thy = ProofContext.theory_of ctxt
     val T = the (AList.lookup (op =) preds pred)
     val nargs = length (binder_types T)
     val pred_case_rule = the_elim_of thy pred
   in
     REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
     THEN print_tac options "before applying elim rule"
-    THEN etac (predfun_elim_of thy pred mode) 1
+    THEN etac (predfun_elim_of ctxt pred mode) 1
     THEN etac pred_case_rule 1
     THEN print_tac options "after applying elim rule"
     THEN (EVERY (map
            (fn i => EVERY' (select_sup (length moded_clauses) i) i) 
              (1 upto (length moded_clauses))))
-    THEN (EVERY (map2 (prove_clause options thy nargs mode) clauses moded_clauses))
+    THEN (EVERY (map2 (prove_clause options ctxt nargs mode) clauses moded_clauses))
     THEN print_tac options "proved one direction"
   end;
 
 (** Proof in the other direction **)
 
-fun prove_match2 options thy out_ts = let
-  fun split_term_tac (Free _) = all_tac
-    | split_term_tac t =
-      if (is_constructor thy t) then let
-        val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t)
-        val num_of_constrs = length (#case_rewrites info)
-        (* special treatment of pairs -- because of fishing *)
-        val split_rules = case (fst o dest_Type o fastype_of) t of
-          "*" => [@{thm prod.split_asm}] 
-          | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
-        val (_, ts) = strip_comb t
-      in
-        (print_tac options ("Term " ^ (Syntax.string_of_term_global thy t) ^ 
-          "splitting with rules \n" ^
-        commas (map (Display.string_of_thm_global thy) split_rules)))
-        THEN TRY ((Splitter.split_asm_tac split_rules 1)
-        THEN (print_tac options "after splitting with split_asm rules")
-        (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
-          THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*)
-          THEN (REPEAT_DETERM_N (num_of_constrs - 1)
-            (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)))
-        THEN (assert_tac (Max_number_of_subgoals 2))
-        THEN (EVERY (map split_term_tac ts))
-      end
-    else all_tac
+fun prove_match2 options ctxt out_ts =
+  let
+    val thy = ProofContext.theory_of ctxt
+    fun split_term_tac (Free _) = all_tac
+      | split_term_tac t =
+        if (is_constructor thy t) then
+          let
+            val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t)
+            val num_of_constrs = length (#case_rewrites info)
+            (* special treatment of pairs -- because of fishing *)
+            val split_rules = case (fst o dest_Type o fastype_of) t of
+              "*" => [@{thm prod.split_asm}] 
+              | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
+            val (_, ts) = strip_comb t
+          in
+            (print_tac options ("Term " ^ (Syntax.string_of_term ctxt t) ^ 
+              "splitting with rules \n" ^
+            commas (map (Display.string_of_thm ctxt) split_rules)))
+            THEN TRY ((Splitter.split_asm_tac split_rules 1)
+            THEN (print_tac options "after splitting with split_asm rules")
+            (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
+              THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*)
+              THEN (REPEAT_DETERM_N (num_of_constrs - 1)
+                (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)))
+            THEN (assert_tac (Max_number_of_subgoals 2))
+            THEN (EVERY (map split_term_tac ts))
+          end
+      else all_tac
   in
     split_term_tac (HOLogic.mk_tuple out_ts)
     THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1)
@@ -2172,7 +2175,7 @@
 *)
 (* TODO: remove function *)
 
-fun prove_param2 options thy t deriv =
+fun prove_param2 options ctxt t deriv =
   let
     val (f, args) = strip_comb (Envir.eta_contract t)
     val mode = head_mode_of deriv
@@ -2180,7 +2183,7 @@
     val ho_args = ho_args_of mode args
     val f_tac = case f of
         Const (name, T) => full_simp_tac (HOL_basic_ss addsimps 
-           (@{thm eval_pred}::(predfun_definition_of thy name mode)
+           (@{thm eval_pred}::(predfun_definition_of ctxt name mode)
            :: @{thm "Product_Type.split_conv"}::[])) 1
       | Free _ => all_tac
       | _ => error "prove_param2: illegal parameter term"
@@ -2188,10 +2191,10 @@
     print_tac options "before simplification in prove_args:"
     THEN f_tac
     THEN print_tac options "after simplification in prove_args"
-    THEN EVERY (map2 (prove_param2 options thy) ho_args param_derivations)
+    THEN EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)
   end
 
-fun prove_expr2 options thy (t, deriv) = 
+fun prove_expr2 options ctxt (t, deriv) = 
   (case strip_comb t of
       (Const (name, T), args) =>
         let
@@ -2202,25 +2205,22 @@
           etac @{thm bindE} 1
           THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
           THEN print_tac options "prove_expr2-before"
-          THEN etac (predfun_elim_of thy name mode) 1
+          THEN etac (predfun_elim_of ctxt name mode) 1
           THEN print_tac options "prove_expr2"
-          THEN (EVERY (map2 (prove_param2 options thy) ho_args param_derivations))
+          THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations))
           THEN print_tac options "finished prove_expr2"
         end
       | _ => etac @{thm bindE} 1)
 
-(* FIXME: what is this for? *)
-(* replace defined by has_mode thy pred *)
-(* TODO: rewrite function *)
-fun prove_sidecond2 options thy t = let
+fun prove_sidecond2 options ctxt t = let
   fun preds_of t nameTs = case strip_comb t of 
     (f as Const (name, T), args) =>
-      if is_registered thy name then (name, T) :: nameTs
+      if is_registered (ProofContext.theory_of ctxt) name then (name, T) :: nameTs
         else fold preds_of args nameTs
     | _ => nameTs
   val preds = preds_of t []
   val defs = map
-    (fn (pred, T) => predfun_definition_of thy pred 
+    (fn (pred, T) => predfun_definition_of ctxt pred 
       (all_input_of T))
       preds
   in
@@ -2230,13 +2230,13 @@
    THEN print_tac options "after sidecond2 simplification"
    end
   
-fun prove_clause2 options thy pred mode (ts, ps) i =
+fun prove_clause2 options ctxt pred mode (ts, ps) i =
   let
-    val pred_intro_rule = nth (intros_of thy pred) (i - 1)
+    val pred_intro_rule = nth (intros_of (ProofContext.theory_of ctxt) pred) (i - 1)
     val (in_ts, clause_out_ts) = split_mode mode ts;
     fun prove_prems2 out_ts [] =
       print_tac options "before prove_match2 - last call:"
-      THEN prove_match2 options thy out_ts
+      THEN prove_match2 options ctxt out_ts
       THEN print_tac options "after prove_match2 - last call:"
       THEN (etac @{thm singleE} 1)
       THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
@@ -2262,7 +2262,7 @@
             val (_, out_ts''') = split_mode mode us
             val rec_tac = prove_prems2 out_ts''' ps
           in
-            (prove_expr2 options thy (t, deriv)) THEN rec_tac
+            (prove_expr2 options ctxt (t, deriv)) THEN rec_tac
           end
         | Negprem t =>
           let
@@ -2277,10 +2277,10 @@
             THEN etac @{thm bindE} 1
             THEN (if is_some name then
                 full_simp_tac (HOL_basic_ss addsimps
-                  [predfun_definition_of thy (the name) mode]) 1
+                  [predfun_definition_of ctxt (the name) mode]) 1
                 THEN etac @{thm not_predE} 1
                 THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
-                THEN (EVERY (map2 (prove_param2 options thy) ho_args param_derivations))
+                THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations))
               else
                 etac @{thm not_predE'} 1)
             THEN rec_tac
@@ -2288,10 +2288,10 @@
         | Sidecond t =>
           etac @{thm bindE} 1
           THEN etac @{thm if_predE} 1
-          THEN prove_sidecond2 options thy t
+          THEN prove_sidecond2 options ctxt t
           THEN prove_prems2 [] ps)
       in print_tac options "before prove_match2:"
-         THEN prove_match2 options thy out_ts
+         THEN prove_match2 options ctxt out_ts
          THEN print_tac options "after prove_match2:"
          THEN rest_tac
       end;
@@ -2305,15 +2305,15 @@
     THEN prems_tac
   end;
  
-fun prove_other_direction options thy pred mode moded_clauses =
+fun prove_other_direction options ctxt pred mode moded_clauses =
   let
     fun prove_clause clause i =
       (if i < length moded_clauses then etac @{thm supE} 1 else all_tac)
-      THEN (prove_clause2 options thy pred mode clause i)
+      THEN (prove_clause2 options ctxt pred mode clause i)
   in
     (DETERM (TRY (rtac @{thm unit.induct} 1)))
      THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
-     THEN (rtac (predfun_intro_of thy pred mode) 1)
+     THEN (rtac (predfun_intro_of ctxt pred mode) 1)
      THEN (REPEAT_DETERM (rtac @{thm refl} 2))
      THEN (if null moded_clauses then
          etac @{thm botE} 1
@@ -2332,9 +2332,9 @@
         (fn _ =>
         rtac @{thm pred_iffI} 1
         THEN print_tac options "after pred_iffI"
-        THEN prove_one_direction options thy clauses preds pred mode moded_clauses
+        THEN prove_one_direction options ctxt clauses preds pred mode moded_clauses
         THEN print_tac options "proved one direction"
-        THEN prove_other_direction options thy pred mode moded_clauses
+        THEN prove_other_direction options ctxt pred mode moded_clauses
         THEN print_tac options "proved other direction")
       else (fn _ => Skip_Proof.cheat_tac thy))
   end;
@@ -2478,6 +2478,7 @@
 
 fun add_code_equations thy preds result_thmss =
   let
+    val ctxt = ProofContext.init thy
     fun add_code_equation (predname, T) (pred, result_thms) =
       let
         val full_mode = fold_rev (curry Fun) (map (K Input) (binder_types T)) Bool
@@ -2493,10 +2494,10 @@
             val rhs = @{term Predicate.holds} $ (list_comb (predfun, args))
             val eq_term = HOLogic.mk_Trueprop
               (HOLogic.mk_eq (list_comb (Const (predname, T), args), rhs))
-            val def = predfun_definition_of thy predname full_mode
+            val def = predfun_definition_of ctxt predname full_mode
             val tac = fn _ => Simplifier.simp_tac
               (HOL_basic_ss addsimps [def, @{thm holds_eq}, @{thm eval_pred}]) 1
-            val eq = Goal.prove (ProofContext.init thy) arg_names [] eq_term tac
+            val eq = Goal.prove ctxt arg_names [] eq_term tac
           in
             (pred, result_thms @ [eq])
           end