src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 35884 362bfc2ca0ee
parent 35881 aa412e08bfee
child 35885 7b39120a1494
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
@@ -189,13 +189,14 @@
 datatype predfun_data = PredfunData of {
   definition : thm,
   intro : thm,
-  elim : thm
+  elim : thm,
+  neg_intro : thm option
 };
 
 fun rep_predfun_data (PredfunData data) = data;
 
-fun mk_predfun_data (definition, intro, elim) =
-  PredfunData {definition = definition, intro = intro, elim = elim}
+fun mk_predfun_data (definition, ((intro, elim), neg_intro)) =
+  PredfunData {definition = definition, intro = intro, elim = elim, neg_intro = neg_intro}
 
 datatype pred_data = PredData of {
   intros : thm list,
@@ -291,6 +292,8 @@
 
 val predfun_elim_of = #elim ooo the_predfun_data
 
+val predfun_neg_intro_of = #neg_intro ooo the_predfun_data
+
 (* diagnostic display functions *)
 
 fun print_modes options thy modes =
@@ -470,7 +473,7 @@
     in
       (HOLogic.mk_prod (t1, t2), st'')
     end
-  | mk_args2 (T as Type ("fun", _)) (params, ctxt) = 
+  (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) = 
     let
       val (S, U) = strip_type T
     in
@@ -482,36 +485,80 @@
         in
           (Free (x, T), (params, ctxt'))
         end
-    end
+    end*)
   | mk_args2 T (params, ctxt) =
     let
       val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
     in
       (Free (x, T), (params, ctxt'))
     end
-  
+
 fun mk_casesrule ctxt pred introrules =
   let
+    (* TODO: can be simplified if parameters are not treated specially ? *)
     val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
+    (* TODO: distinct required ? -- test case with more than one parameter! *)
+    val params = distinct (op aconv) params
     val intros = map prop_of intros_th
     val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
     val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
     val argsT = binder_types (fastype_of pred)
+    (* TODO: can be simplified if parameters are not treated specially ? <-- see uncommented code! *)
     val (argvs, _) = fold_map mk_args2 argsT (params, ctxt2)
     fun mk_case intro =
       let
         val (_, args) = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl) intro
         val prems = Logic.strip_imp_prems intro
-        val eqprems = map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) argvs args
-        val frees = (fold o fold_aterms)
-          (fn t as Free _ =>
-              if member (op aconv) params t then I else insert (op aconv) t
-           | _ => I) (args @ prems) []
+        val eqprems =
+          map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) argvs args
+        val frees = map Free (fold Term.add_frees (args @ prems) [])
       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
     val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
     val cases = map mk_case intros
   in Logic.list_implies (assm :: cases, prop) end;
 
+fun dest_conjunct_prem th =
+  case HOLogic.dest_Trueprop (prop_of th) of
+    (Const ("op &", _) $ t $ t') =>
+      dest_conjunct_prem (th RS @{thm conjunct1})
+        @ dest_conjunct_prem (th RS @{thm conjunct2})
+    | _ => [th]
+
+fun prove_casesrule ctxt (pred, (pre_cases_rule, nparams)) cases_rule =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val nargs = length (binder_types (fastype_of pred))
+    fun PEEK f dependent_tactic st = dependent_tactic (f st) st
+    fun meta_eq_of th = th RS @{thm eq_reflection}
+    val tuple_rew_rules = map meta_eq_of [@{thm fst_conv}, @{thm snd_conv}, @{thm Pair_eq}]
+    fun instantiate i n {context = ctxt, params = p, prems = prems,
+      asms = a, concl = cl, schematics = s}  =
+      let
+        val (cases, (eqs, prems)) = apsnd (chop (nargs - nparams)) (chop n prems)
+        val case_th = MetaSimplifier.simplify true
+        (@{thm Predicate.eq_is_eq} :: map meta_eq_of eqs)
+          (nth cases (i - 1))
+        val prems' = maps (dest_conjunct_prem o MetaSimplifier.simplify true tuple_rew_rules) prems
+        val pats = map (swap o HOLogic.dest_eq o HOLogic.dest_Trueprop) (take nargs (prems_of case_th))
+        val (_, tenv) = fold (Pattern.match thy) pats (Vartab.empty, Vartab.empty)
+        fun term_pair_of (ix, (ty,t)) = (Var (ix,ty), t)
+        val inst = map (pairself (cterm_of thy) o term_pair_of) (Vartab.dest tenv)
+        val thesis = Thm.instantiate ([], inst) case_th OF (replicate nargs @{thm refl}) OF prems'
+      in
+        (rtac thesis 1)
+      end
+    val tac =
+      etac pre_cases_rule 1
+      THEN
+      (PEEK nprems_of
+        (fn n =>
+          ALLGOALS (fn i =>
+            MetaSimplifier.rewrite_goal_tac [@{thm split_paired_all}] i
+            THEN (SUBPROOF (instantiate i n) ctxt i))))
+  in
+    Goal.prove ctxt (Term.add_free_names cases_rule []) [] cases_rule (fn _ => tac)
+  end
+
 (** preprocessing rules **)
 
 fun imp_prems_conv cv ct =
@@ -552,7 +599,7 @@
     val bigeq = (Thm.symmetric (Conv.implies_concl_conv
       (MetaSimplifier.rewrite true [@{thm Predicate.eq_is_eq}])
         (cterm_of thy elimrule')))
-    val tac = (fn _ => Skip_Proof.cheat_tac thy)    
+    val tac = (fn _ => Skip_Proof.cheat_tac thy)
     val eq = Goal.prove ctxt' [] [] (Logic.mk_equals ((Thm.prop_of elimrule), elimrule')) tac
   in
     Thm.equal_elim eq elimrule |> singleton (Variable.export ctxt' ctxt)
@@ -575,11 +622,14 @@
         val index = find_index (fn s => s = name) (#names (fst info))
         val pre_elim = nth (#elims result) index
         val pred = nth (#preds result) index
+        val nparams = length (Inductive.params_of (#raw_induct result))
         (*val elim = singleton (Inductive_Set.codegen_preproc thy) (preprocess_elim thy nparams 
           (expand_tuples_elim pre_elim))*)
+        (* FIXME: missing Inductive Set preprocessing *)
+        val ctxt = ProofContext.init thy
+        val elim_t = mk_casesrule ctxt pred intros
         val elim =
-          (Drule.export_without_context o Skip_Proof.make_thm thy)
-          (mk_casesrule (ProofContext.init thy) pred intros)
+          prove_casesrule ctxt (pred, (pre_elim, nparams)) elim_t
       in
         mk_pred_data ((intros, SOME elim), no_compilation)
       end
@@ -1630,7 +1680,7 @@
     
     val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
       (fn T => fn (param_vs, names) =>
-        if is_param_type T then
+        if is_param_type T then                                                
           (Free (hd param_vs, T), (tl param_vs, names))
         else
           let
@@ -1743,8 +1793,24 @@
     val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predpropE, P)], P)
     val elimthm = Goal.prove (ProofContext.init thy)
       (argnames @ ["y", "P"]) [] elimtrm (fn _ => unfolddef_tac)
+    val opt_neg_introthm =
+      if is_all_input mode then
+        let
+          val neg_predpropI = HOLogic.mk_Trueprop (HOLogic.mk_not (list_comb (pred, args')))
+          val neg_funpropI =
+            HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval
+              (PredicateCompFuns.mk_not (list_comb (funtrm, inargs)), HOLogic.unit))
+          val neg_introtrm = Logic.list_implies (neg_predpropI :: param_eqs, neg_funpropI)
+          val tac =
+            Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps
+              (@{thm if_False} :: @{thm Predicate.not_pred_eq} :: simprules)) 1
+            THEN rtac @{thm Predicate.singleI} 1
+        in SOME (Goal.prove (ProofContext.init thy) (argnames @ hoarg_names') []
+            neg_introtrm (fn _ => tac))
+        end
+      else NONE
   in
-    (introthm, elimthm)
+    ((introthm, elimthm), opt_neg_introthm)
   end
 
 fun create_constname_of_mode options thy prefix name T mode = 
@@ -1780,11 +1846,11 @@
         val ([definition], thy') = thy |>
           Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
           PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
-        val (intro, elim) =
+        val rules as ((intro, elim), _) =
           create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
         in thy'
           |> set_function_name Pred name mode mode_cname
-          |> add_predfun_data name mode (definition, intro, elim)
+          |> add_predfun_data name mode (definition, rules)
           |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
           |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim)  |> snd
           |> Theory.checkpoint
@@ -1834,7 +1900,7 @@
 (* MAJOR FIXME:  prove_params should be simple
  - different form of introrule for parameters ? *)
 
-fun prove_param options thy t deriv =
+fun prove_param options thy nargs t deriv =
   let
     val  (f, args) = strip_comb (Envir.eta_contract t)
     val mode = head_mode_of deriv
@@ -1842,21 +1908,32 @@
     val ho_args = ho_args_of mode args
     val f_tac = case f of
       Const (name, T) => simp_tac (HOL_basic_ss addsimps 
-         ([@{thm eval_pred}, (predfun_definition_of thy name mode),
-         @{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
-         @{thm "snd_conv"}, @{thm pair_collapse}, @{thm "Product_Type.split_conv"}])) 1
-    | Free _ => TRY (rtac @{thm refl} 1)
-    | Abs _ => error "prove_param: No valid parameter term"
+         [@{thm eval_pred}, predfun_definition_of thy name mode,
+         @{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
+         @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
+    | Free _ =>
+      (* rewrite with parameter equation *)
+    (* test: *)
+      Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems = prems,
+      asms = a, concl = concl, schematics = s} =>
+        let
+          val prems' = maps dest_conjunct_prem (take nargs prems)
+        in
+          MetaSimplifier.rewrite_goal_tac
+            (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
+        end) (ProofContext.init thy) 1 (* FIXME: proper context handling *)
+    | Abs _ => raise Fail "prove_param: No valid parameter term"
   in
     REPEAT_DETERM (rtac @{thm ext} 1)
     THEN print_tac' options "prove_param"
-    THEN f_tac
-    THEN print_tac' options "after simplification in prove_args"
+    THEN f_tac 
+    THEN print_tac' options "after prove_param"
     THEN (REPEAT_DETERM (atac 1))
-    THEN (EVERY (map2 (prove_param options thy) ho_args param_derivations))
+    THEN (EVERY (map2 (prove_param options thy nargs) ho_args param_derivations))
+    THEN REPEAT_DETERM (rtac @{thm refl} 1)
   end
 
-fun prove_expr options thy (premposition : int) (t, deriv) =
+fun prove_expr options thy nargs (premposition : int) (t, deriv) =
   case strip_comb t of
     (Const (name, T), args) =>
       let
@@ -1866,25 +1943,36 @@
         val ho_args = ho_args_of mode args
       in
         print_tac' options "before intro rule:"
+        THEN rtac introrule 1
+        THEN print_tac' options "after intro rule"
         (* for the right assumption in first position *)
         THEN rotate_tac premposition 1
-        THEN debug_tac (Display.string_of_thm (ProofContext.init thy) introrule)
-        THEN rtac introrule 1
-        THEN print_tac' options "after intro rule"
-        (* work with parameter arguments *)
         THEN atac 1
         THEN print_tac' options "parameter goal"
-        THEN (EVERY (map2 (prove_param options thy) ho_args param_derivations))
+        (* work with parameter arguments *)
+        THEN (EVERY (map2 (prove_param options thy nargs) ho_args param_derivations))
         THEN (REPEAT_DETERM (atac 1))
       end
-  | _ =>
-    asm_full_simp_tac
-      (HOL_basic_ss' addsimps [@{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
-         @{thm "snd_conv"}, @{thm pair_collapse}]) 1
-    THEN (atac 1)
+  | (Free _, _) =>
+    print_tac' options "proving parameter call.."
+    THEN Subgoal.FOCUS_PREMS (fn {context = ctxt, params = params, prems = prems,
+      asms = a, concl = cl, schematics = s} =>
+        let
+          val param_prem = nth prems premposition
+          val (param, _) = strip_comb (HOLogic.dest_Trueprop (prop_of param_prem))
+          val prems' = maps dest_conjunct_prem (take nargs prems)
+          fun param_rewrite prem =
+            param = snd (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of prem)))
+          val SOME rew_eq = find_first param_rewrite prems'
+          val param_prem' = MetaSimplifier.rewrite_rule
+            (map (fn th => th RS @{thm eq_reflection})
+              [rew_eq RS @{thm sym}, @{thm split_beta}, @{thm fst_conv}, @{thm snd_conv}])
+            param_prem
+        in
+          rtac param_prem' 1
+        end) (ProofContext.init thy) 1 (* FIXME: proper context handling *)
     THEN print_tac' options "after prove parameter call"
 
-
 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st;
 
 fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st
@@ -1972,7 +2060,7 @@
               print_tac' options "before clause:"
               (*THEN asm_simp_tac HOL_basic_ss 1*)
               THEN print_tac' options "before prove_expr:"
-              THEN prove_expr options thy premposition (t, deriv)
+              THEN prove_expr options thy nargs premposition (t, deriv)
               THEN print_tac' options "after prove_expr:"
               THEN rec_tac
             end
@@ -1982,6 +2070,8 @@
               val (_, out_ts''') = split_mode mode args
               val rec_tac = prove_prems out_ts''' ps
               val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
+              val neg_intro_rule =
+                Option.map (fn name => the (predfun_neg_intro_of thy name mode)) name
               val param_derivations = param_derivations_of deriv
               val params = ho_args_of mode args
             in
@@ -1990,22 +2080,18 @@
                 [@{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
                  @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
               THEN (if (is_some name) then
-                  print_tac' options ("before unfolding definition " ^
-                    (Display.string_of_thm_global thy
-                      (predfun_definition_of thy (the name) mode)))
-                  
-                  THEN simp_tac (HOL_basic_ss addsimps
-                    [predfun_definition_of thy (the name) mode]) 1
-                  THEN rtac @{thm not_predI} 1
-                  THEN print_tac' options "after applying rule not_predI"
-                  THEN full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True},
-                    @{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
-                    @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
-                  THEN (REPEAT_DETERM (atac 1))
-                  THEN (EVERY (map2 (prove_param options thy) params param_derivations))
+                  print_tac' options "before applying not introduction rule"
+                  THEN rotate_tac premposition 1
+                  THEN etac (the neg_intro_rule) 1
+                  THEN rotate_tac (~premposition) 1
+                  THEN print_tac' options "after applying not introduction rule"
+                  THEN (EVERY (map2 (prove_param options thy nargs) params param_derivations))
                   THEN (REPEAT_DETERM (atac 1))
                 else
-                  rtac @{thm not_predI'} 1)
+                  rtac @{thm not_predI'} 1
+                  (* test: *)
+                  THEN dtac @{thm sym} 1
+                  THEN asm_full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1)
                   THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
               THEN rec_tac
             end
@@ -2040,6 +2126,7 @@
     THEN print_tac' options "before applying elim rule"
     THEN etac (predfun_elim_of thy pred mode) 1
     THEN etac pred_case_rule 1
+    THEN print_tac' options "after applying elim rule"
     THEN (EVERY (map
            (fn i => EVERY' (select_sup (length moded_clauses) i) i) 
              (1 upto (length moded_clauses))))
@@ -2115,9 +2202,7 @@
           etac @{thm bindE} 1
           THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
           THEN print_tac "prove_expr2-before"
-          THEN (debug_tac (Syntax.string_of_term_global thy
-            (prop_of (predfun_elim_of thy name mode))))
-          THEN (etac (predfun_elim_of thy name mode) 1)
+          THEN etac (predfun_elim_of thy name mode) 1
           THEN print_tac "prove_expr2"
           THEN (EVERY (map2 (prove_param2 thy) ho_args param_derivations))
           THEN print_tac "finished prove_expr2"