src/HOL/Tools/Predicate_Compile/predicate_compile_proof.ML
author bulwahn
Thu, 20 Oct 2011 09:11:13 +0200
changeset 45214 66ba67adafab
parent 42816 ba14eafef077
child 45450 dc2236b19a3d
permissions -rw-r--r--
modernizing predicate_compile_quickcheck

(*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_proof.ML
    Author:     Lukas Bulwahn, TU Muenchen

Proof procedure for the compiler from predicates specified by intro/elim rules to equations.
*)

(* Interface of the proof procedure.  The single exported operation, prove_pred,
   takes (in this order) the compiler options, the theory, the clauses of all
   predicates, the predicate names with their types, the name of the predicate
   to treat, a flag paired with the mode, and the moded clauses paired with the
   proposition to prove; it returns the proved theorem. *)
signature PREDICATE_COMPILE_PROOF =
sig
  type indprem = Predicate_Compile_Aux.indprem
  type mode = Predicate_Compile_Aux.mode
  val prove_pred:
    Predicate_Compile_Aux.options -> theory ->
    (string * (term list * indprem list) list) list ->
    (string * typ) list -> string -> bool * mode ->
    (term list * (indprem * Mode_Inference.mode_derivation) list) list * term ->
    Thm.thm
end;

structure Predicate_Compile_Proof : PREDICATE_COMPILE_PROOF =
struct

open Predicate_Compile_Aux;
open Core_Data;
open Mode_Inference;

(* debug stuff *)

(* Tracing tactic: prints the goal state together with message s when proof
   tracing is switched on in options; otherwise passes the state through unchanged. *)
fun print_tac options s =
  (case show_proof_trace options of
    true => Tactical.print_tac s
  | false => Seq.single);

(** auxiliary **)

(* Assertions that can be checked on a goal state. *)
datatype assertion = Max_number_of_subgoals of int

(* Passes the state st through unchanged if it has at most `bound` subgoals;
   otherwise raises Fail with a printout of the offending goal state. *)
fun assert_tac (Max_number_of_subgoals bound) st =
  let
    (* only rendered when the assertion is violated *)
    fun goal_state_text () =
      Pretty.string_of (Pretty.chunks (Goal_Display.pretty_goals_without_context st))
  in
    if nprems_of st > bound then
      raise Fail ("assert_tac: Numbers of subgoals mismatch at goal state :"
        ^ "\n" ^ goal_state_text ())
    else Seq.single st
  end;


(** special setup for simpset **)
(* HOL_basic_ss extended by HOL.simp_thms and Pair_eq, with custom solvers.
   NOTE(review): setSolver presumably replaces the previously installed solver
   rather than adding to it, in which case only "True_solver" is effective --
   confirm against the Simplifier sources before relying on "all_tac_solver". *)
val HOL_basic_ss' =
  let
    val extra_simps = @{thms HOL.simp_thms} @ [@{thm Pair_eq}]
    val all_tac_solver = mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)
    val true_solver = mk_solver "True_solver" (fn _ => rtac @{thm TrueI})
  in
    HOL_basic_ss addsimps extra_simps
      setSolver all_tac_solver
      setSolver true_solver
  end

(* auxiliary functions *)

(* True exactly for types built with the Type constructor. *)
fun is_Type T =
  (case T of
    Type _ => true
  | _ => false)

(* Returns true if t is an application of a datatype constructor (such a term
   would consequently be split), and false otherwise.  The head constant of t
   is looked up among the constructor names recorded in the datatype
   description of the type of t. *)
fun is_constructor thy t =
  let
    val T = fastype_of t
  in
    is_Type T andalso
      (case Datatype.get_info thy (fst (dest_Type T)) of
        NONE => false
      | SOME info =>
          let
            val constr_names =
              maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
          in
            (case head_of t of
              Const (name, _) => member (op =) constr_names name
            | _ => false)
          end)
  end

(* MAJOR FIXME:  prove_params should be simple
 - different form of introrule for parameters ? *)

(* Tactic for the goal that arises for one higher-order parameter t.

   options:  controls proof tracing (via print_tac)
   ctxt:     proof context, used to look up predfun definitions and to focus
             on the premises of the goal
   nargs:    number of leading goal premises that are split into conjuncts and
             used, with their orientation reversed, as rewrite rules when the
             head of t is a free variable
   t:        the parameter term; it is eta-contracted and split into head and
             arguments
   deriv:    mode derivation providing the head mode and the derivations of the
             higher-order arguments, which are treated recursively

   Raises Fail if the head of t is neither a constant nor a free variable. *)
fun prove_param options ctxt nargs t deriv =
  let
    val (f, args) = strip_comb (Envir.eta_contract t)
    val mode = head_mode_of deriv
    val param_derivations = param_derivations_of deriv
    val ho_args = ho_args_of mode args
    val f_tac =
      (case f of
        Const (name, _) =>
          simp_tac (HOL_basic_ss addsimps
            [@{thm eval_pred}, predfun_definition_of ctxt name mode,
             @{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
             @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
      | Free _ =>
          Subgoal.FOCUS_PREMS (fn {prems, ...} =>
            let
              val prems' = maps dest_conjunct_prem (take nargs prems)
            in
              (* rewrite the goal with the reversed premise equations *)
              Simplifier.rewrite_goal_tac
                (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
            end) ctxt 1
        (* catch-all: the original matched only Abs here, so Var and Bound heads
           escaped as an anonymous Match exception *)
      | _ => raise Fail "prove_param: No valid parameter term")
  in
    REPEAT_DETERM (rtac @{thm ext} 1)
    THEN print_tac options "prove_param"
    THEN f_tac
    THEN print_tac options "after prove_param"
    THEN (REPEAT_DETERM (atac 1))
    THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations))
    THEN REPEAT_DETERM (rtac @{thm refl} 1)
  end

(* Tactic proving the call of a predicate or of a predicate parameter.

   options:       controls proof tracing (via print_tac)
   ctxt:          proof context used for looking up the predfun introduction rule
   nargs:         number of leading premises holding the argument equations
   premposition:  index of the premise belonging to this call; it is used as
                  rotation offset in the constant case and as index into the
                  focused premises in the parameter case
   (t, deriv):    the called term and its mode derivation

   Raises Fail if the head of t is neither a constant nor a free variable, or
   if (in the parameter case) none of the first nargs premises is an equation
   whose right-hand side is the head of the selected premise. *)
fun prove_expr options ctxt nargs (premposition : int) (t, deriv) =
  (case strip_comb t of
    (Const (name, _), args) =>
      let
        val mode = head_mode_of deriv
        val introrule = predfun_intro_of ctxt name mode
        val param_derivations = param_derivations_of deriv
        val ho_args = ho_args_of mode args
      in
        print_tac options "before intro rule:"
        THEN rtac introrule 1
        THEN print_tac options "after intro rule"
        (* for the right assumption in first position *)
        THEN rotate_tac premposition 1
        THEN atac 1
        THEN print_tac options "parameter goal"
        (* work with parameter arguments *)
        THEN (EVERY (map2 (prove_param options ctxt nargs) ho_args param_derivations))
        THEN (REPEAT_DETERM (atac 1))
      end
  | (Free _, _) =>
      print_tac options "proving parameter call.."
      THEN Subgoal.FOCUS_PREMS (fn {prems, ...} =>
        let
          val param_prem = nth prems premposition
          val (param, _) = strip_comb (HOLogic.dest_Trueprop (prop_of param_prem))
          val prems' = maps dest_conjunct_prem (take nargs prems)
          fun param_rewrite prem =
            param = snd (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of prem)))
          (* explicit failure instead of a refutable "val SOME ..." binding,
             which would raise an uninformative Bind exception *)
          val rew_eq =
            (case find_first param_rewrite prems' of
              SOME th => th
            | NONE =>
                raise Fail "prove_expr: no rewrite equation found for parameter call")
          val param_prem' = Raw_Simplifier.rewrite_rule
            (map (fn th => th RS @{thm eq_reflection})
              [rew_eq RS @{thm sym}, @{thm split_beta}, @{thm fst_conv}, @{thm snd_conv}])
            param_prem
        in
          rtac param_prem' 1
        end) ctxt 1
      THEN print_tac options "after prove parameter call"
    (* previously an anonymous Match exception *)
  | _ => raise Fail "prove_expr: No valid predicate call")

(* Keeps only those results of tac that have exactly one subgoal less than st. *)
fun SOLVED tac st =
  let
    val expected = nprems_of st - 1
  in
    FILTER (fn st' => nprems_of st' = expected) tac st
  end;

(* Keeps only those results of tac that have no subgoals left. *)
fun SOLVEDALL tac st =
  let
    fun no_subgoals st' = (nprems_of st' = 0)
  in
    FILTER no_subgoals tac st
  end

(* Filter tactic on the shape of the first subgoal: its conclusion must be an
   Eval whose predicate expression is headed by Predicate.bind or
   Predicate.single.  The state is returned unchanged in that case; otherwise
   the empty sequence is returned (no error is raised).  The ctxt argument is
   not used. *)
fun check_format ctxt st =
  let
    val first_goal = hd (prems_of st)
    val eval_prop = HOLogic.dest_Trueprop (Logic.strip_assums_concl first_goal)
    val (pred_expr, _) = PredicateCompFuns.dest_Eval eval_prop
  in
    (case head_of pred_expr of
      Const (@{const_name Predicate.bind}, _) => Seq.single st
    | Const (@{const_name Predicate.single}, _) => Seq.single st
    | _ => Seq.empty)
  end

(* Tactic that simplifies the pattern matching on the terms out_ts in subgoal 1.

   options:  controls proof tracing (via print_tac)
   ctxt:     proof context; its theory is used for the datatype lookups
   nargs:    number of leading premises that are split into conjuncts and used,
             with reversed orientation, as rewrite rules inside the SUBPROOF
   out_ts:   terms whose datatype constructors determine the case rewrite rules
             added to the simpset

   After the simplification the tactic optionally (DETERM o TRY) resolves with
   eval_if_P and solves the resulting condition inside a SUBPROOF. *)
fun prove_match options ctxt nargs out_ts =
  let
    val thy = Proof_Context.theory_of ctxt
    (* rule for entering the "then" branch of an if-expression under Predicate.eval *)
    val eval_if_P =
      @{lemma "P ==> Predicate.eval x z ==> Predicate.eval (if P then x else y) z" by simp} 
    (* case rewrite rules of the datatype of t and, recursively, of all its
       arguments; empty if t is not a constructor application *)
    fun get_case_rewrite t =
      if (is_constructor thy t) then let
        val case_rewrites = (#case_rewrites (Datatype.the_info thy
          ((fst o dest_Type o fastype_of) t)))
        in fold (union Thm.eq_thm) (case_rewrites :: map get_case_rewrite (snd (strip_comb t))) [] end
      else []
    (* the rules for unit and prod are always included *)
    val simprules = insert Thm.eq_thm @{thm "unit.cases"} (insert Thm.eq_thm @{thm "prod.cases"}
      (fold (union Thm.eq_thm) (map get_case_rewrite out_ts) []))
  (* replace TRY by determining if it is necessary - are there equations when calling compile match? *)
  in
     (* make this simpset better! *)
    asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
    THEN print_tac options "after prove_match:"
    (* optional step: commit to the first result of the if-handling, or skip it *)
    THEN (DETERM (TRY 
           (rtac eval_if_P 1
           (* note: the record pattern rebinds ctxt inside the SUBPROOF *)
           THEN (SUBPROOF (fn {context = ctxt, params, prems, asms, concl, schematics} =>
             (* split conjunctions; each first conjunct must be solved by simplification *)
             (REPEAT_DETERM (rtac @{thm conjI} 1
             THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1))))
             THEN print_tac options "if condition to be solved:"
             THEN asm_simp_tac HOL_basic_ss' 1
             THEN TRY (
                let
                  val prems' = maps dest_conjunct_prem (take nargs prems)
                in
                  (* rewrite with the reversed equations among the first nargs premises *)
                  Simplifier.rewrite_goal_tac
                    (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
                end
             THEN REPEAT_DETERM (rtac @{thm refl} 1))
             THEN print_tac options "after if simp; in SUBPROOF") ctxt 1))))
    THEN print_tac options "after if simplification"
  end;

(* corresponds to compile_fun -- maybe call that also compile_sidecond? *)

(* Tactic for a side condition t: simplifies subgoal 1 with HOL.simp_thms,
   eval_pred and the all-input predfun definitions of the registered
   predicates occurring in t.  The search does not descend into the arguments
   of a registered predicate. *)
fun prove_sidecond ctxt t =
  let
    fun collect_preds tm acc =
      (case strip_comb tm of
        (Const (c, T), args) =>
          if is_registered ctxt c then (c, T) :: acc
          else fold collect_preds args acc
      | _ => acc)
    fun all_input_definition (c, T) = predfun_definition_of ctxt c (all_input_of T)
    val defs = map all_input_definition (collect_preds t [])
    val ss = HOL_basic_ss addsimps (@{thms HOL.simp_thms} @ (@{thm eval_pred} :: defs))
  in
    (* need better control here! *)
    simp_tac ss 1
  end

(* Tactic proving the goal for a single clause of the predicate.

   options:         controls proof tracing (via print_tac)
   ctxt:            proof context
   nargs:           number of leading premises holding the argument equations;
                    also the offset added to premise positions
   mode:            mode of the predicate; splits ts into inputs and outputs
   (_, clauses):    second component: the premises of this clause, used to
                    locate the position of each moded premise
   (ts, moded_ps):  clause arguments and the premises paired with their mode
                    derivations, in the order in which they are processed

   The tactic applies bindI and singleI and then handles the premises one by
   one with the local function prove_prems. *)
fun prove_clause options ctxt nargs mode (_, clauses) (ts, moded_ps) =
  let
    val (in_ts, clause_out_ts) = split_mode mode ts;
    (* no premises left: match the remaining patterns and close with singleI
       (singleI_unit if the clause has no output terms) *)
    fun prove_prems out_ts [] =
      (prove_match options ctxt nargs out_ts)
      THEN print_tac options "before simplifying assumptions"
      THEN asm_full_simp_tac HOL_basic_ss' 1
      THEN print_tac options "before single intro rule"
      THEN Subgoal.FOCUS_PREMS
             (fn {context = ctxt, params = params, prems, asms, concl, schematics} =>
              let
                val prems' = maps dest_conjunct_prem (take nargs prems)
              in
                (* rewrite with the reversed equations among the first nargs premises *)
                Simplifier.rewrite_goal_tac
                  (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
              end) ctxt 1
      THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
    | prove_prems out_ts ((p, deriv) :: ps) =
      let
        (* NOTE(review): find_index presumably yields ~1 when p does not occur
           in clauses, which would silently give nargs - 1 here -- confirm that
           p is always a member *)
        val premposition = (find_index (equal p) clauses) + nargs
        (* shadows the mode of the enclosing function with the mode of this premise *)
        val mode = head_mode_of deriv
        val rest_tac =
          rtac @{thm bindI} 1
          THEN (case p of Prem t =>
            let
              val (_, us) = strip_comb t
              val (_, out_ts''') = split_mode mode us
              val rec_tac = prove_prems out_ts''' ps
            in
              print_tac options "before clause:"
              (*THEN asm_simp_tac HOL_basic_ss 1*)
              THEN print_tac options "before prove_expr:"
              THEN prove_expr options ctxt nargs premposition (t, deriv)
              THEN print_tac options "after prove_expr:"
              THEN rec_tac
            end
          | Negprem t =>
            let
              (* from here on t denotes the head of the negated premise *)
              val (t, args) = strip_comb t
              val (_, out_ts''') = split_mode mode args
              val rec_tac = prove_prems out_ts''' ps
              val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
              val neg_intro_rule =
                Option.map (fn name =>
                  the (predfun_neg_intro_of ctxt name mode)) name
              val param_derivations = param_derivations_of deriv
              val params = ho_args_of mode args
            in
              print_tac options "before prove_neg_expr:"
              THEN full_simp_tac (HOL_basic_ss addsimps
                [@{thm split_eta}, @{thm split_beta}, @{thm fst_conv},
                 @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
              THEN (if (is_some name) then
                  print_tac options "before applying not introduction rule"
                  THEN Subgoal.FOCUS_PREMS
                    (fn {context = ctxt, params = params, prems, asms, concl, schematics} =>
                      rtac (the neg_intro_rule) 1
                      THEN rtac (nth prems premposition) 1) ctxt 1
                  THEN print_tac options "after applying not introduction rule"
                  THEN (EVERY (map2 (prove_param options ctxt nargs) params param_derivations))
                  THEN (REPEAT_DETERM (atac 1))
                else
                  rtac @{thm not_predI'} 1
                  (* test: *)
                  THEN dtac @{thm sym} 1
                  THEN asm_full_simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1)
                  (* despite its indentation, the next line is outside the
                     if-expression (closed by the parenthesis above) and is
                     therefore executed in both cases *)
                  THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
              THEN rec_tac
            end
          | Sidecond t =>
           rtac @{thm if_predI} 1
           THEN print_tac options "before sidecond:"
           THEN prove_sidecond ctxt t
           THEN print_tac options "after sidecond:"
           (* a side condition produces no output terms *)
           THEN prove_prems [] ps)
      in (prove_match options ctxt nargs out_ts)
          THEN rest_tac
      end;
    val prems_tac = prove_prems in_ts moded_ps
  in
    print_tac options "Proving clause..."
    THEN rtac @{thm bindI} 1
    THEN rtac @{thm singleI} 1
    THEN prems_tac
  end;

(* Resolution steps that select the i-th of n alternatives of a nested sup:
   (i - 1) times supI2, followed by supI1 unless the last alternative is meant. *)
fun select_sup n i =
  (case (n, i) of
    (1, 1) => []
  | (_, 1) => [rtac @{thm supI1}]
  | _ => rtac @{thm supI2} :: select_sup (n - 1) (i - 1));

(* Tactic for the direction of the proof that prove_pred attacks first: the
   predfun elimination rule and the elimination rule of pred are applied to
   subgoal 1, then for every clause the matching sup alternative is selected
   and the clause is proved with prove_clause.  The number of arguments is
   taken from the type registered for pred in preds. *)
fun prove_one_direction options ctxt clauses preds pred mode moded_clauses =
  let
    val pred_typ = the (AList.lookup (op =) preds pred)
    val nargs = length (binder_types pred_typ)
    val pred_case_rule = the_elim_of ctxt pred
    val num_clauses = length moded_clauses
    val split_all_tac = REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
    val predfun_elim_tac = etac (predfun_elim_of ctxt pred mode) 1
    (* chain of supI1/supI2 steps for clause i, applied to subgoal i *)
    fun pick_clause_tac i = EVERY' (select_sup num_clauses i) i
    val select_tacs = map pick_clause_tac (1 upto num_clauses)
    val clause_tacs = map2 (prove_clause options ctxt nargs mode) clauses moded_clauses
  in
    split_all_tac
    THEN print_tac options "before applying elim rule"
    THEN predfun_elim_tac
    THEN etac pred_case_rule 1
    THEN print_tac options "after applying elim rule"
    THEN EVERY select_tacs
    THEN EVERY clause_tacs
    THEN print_tac options "proved one direction"
  end;

(** Proof in the other direction **)

(* Tactic that splits the assumption of subgoal 1 according to the datatype
   constructors occurring in the tuple of out_ts.

   options:  controls proof tracing
   ctxt:     proof context; its theory is used for the datatype lookups
   out_ts:   output terms; they are combined into one tuple and traversed

   For every constructor application the split_asm rule of its datatype is
   tried, the subgoals of the other constructors are eliminated with botE, and
   the arguments are processed recursively.  Free variables and terms that are
   not constructor applications are left alone.  Finally a split of an
   if-expression in the assumptions is tried. *)
fun prove_match2 options ctxt out_ts =
  let
    val thy = Proof_Context.theory_of ctxt
    fun split_term_tac (Free _) = all_tac
      | split_term_tac t =
        if (is_constructor thy t) then
          let
            val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t)
            val num_of_constrs = length (#case_rewrites info)
            val (_, ts) = strip_comb t
            (* the term and theorem are only rendered when tracing is enabled;
               the message now separates the term from the following text *)
            val trace_tac =
              if show_proof_trace options then
                print_tac options ("Term " ^ (Syntax.string_of_term ctxt t) ^
                  " splitting with rules \n" ^ Display.string_of_thm ctxt (#split_asm info))
              else all_tac
          in
            trace_tac
            THEN TRY ((Splitter.split_asm_tac [#split_asm info] 1)
              THEN (print_tac options "after splitting with split_asm rules")
            (* THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
              THEN (DETERM (TRY (etac @{thm Pair_inject} 1)))*)
              (* discard the subgoals of all constructors but one *)
              THEN (REPEAT_DETERM_N (num_of_constrs - 1)
                (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)))
            THEN (assert_tac (Max_number_of_subgoals 2))
            THEN (EVERY (map split_term_tac ts))
          end
      else all_tac
  in
    split_term_tac (HOLogic.mk_tuple out_ts)
    THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1)
    THEN (etac @{thm botE} 2))))
  end

(* VERY LARGE SIMILARITY to function prove_param -- join both functions *)
(* TODO: remove function *)

(* Tactic for one higher-order parameter t in the other proof direction: for a
   constant head, subgoal 1 is simplified with eval_pred, the predfun
   definition for the head mode of deriv and split_conv; a free head needs no
   work.  The higher-order arguments are then treated recursively.  Any other
   head is reported as an error. *)
fun prove_param2 options ctxt t deriv =
  let
    val (head, args) = strip_comb (Envir.eta_contract t)
    val mode = head_mode_of deriv
    val param_derivations = param_derivations_of deriv
    val ho_args = ho_args_of mode args
    val head_tac =
      (case head of
        Const (name, _) =>
          full_simp_tac (HOL_basic_ss addsimps
            [@{thm eval_pred}, predfun_definition_of ctxt name mode,
             @{thm "Product_Type.split_conv"}]) 1
      | Free _ => all_tac
      | _ => error "prove_param2: illegal parameter term")
  in
    print_tac options "before simplification in prove_args:"
    THEN head_tac
    THEN print_tac options "after simplification in prove_args"
    THEN EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)
  end

(* Tactic eliminating the call t in the other proof direction.  bindE is always
   applied first; if the head of t is a constant, the predfun elimination rule
   for the head mode of deriv follows and the higher-order arguments are
   handled with prove_param2. *)
fun prove_expr2 options ctxt (t, deriv) =
  let
    val bind_elim_tac = etac @{thm bindE} 1
  in
    (case strip_comb t of
      (Const (name, _), args) =>
        let
          val mode = head_mode_of deriv
          val param_derivations = param_derivations_of deriv
          val ho_args = ho_args_of mode args
          val split_all_tac = REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
          val elim_rule = predfun_elim_of ctxt name mode
          val params_tac =
            EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations)
        in
          bind_elim_tac
          THEN split_all_tac
          THEN print_tac options "prove_expr2-before"
          THEN etac elim_rule 1
          THEN print_tac options "prove_expr2"
          THEN params_tac
          THEN print_tac options "finished prove_expr2"
        end
    | _ => bind_elim_tac)
  end

(* Tactic for a side condition t in the other proof direction: fully
   simplifies subgoal 1 with HOL_basic_ss' extended by eval_pred and the
   all-input predfun definitions of the registered predicates occurring in t.
   The search does not descend into the arguments of a registered predicate. *)
fun prove_sidecond2 options ctxt t =
  let
    fun collect_preds tm acc =
      (case strip_comb tm of
        (Const (c, T), args) =>
          if is_registered ctxt c then (c, T) :: acc
          else fold collect_preds args acc
      | _ => acc)
    val defs =
      map (fn (c, T) => predfun_definition_of ctxt c (all_input_of T)) (collect_preds t [])
    (* "::" binds tighter than "addsimps"; the parentheses only make this explicit *)
    val ss = HOL_basic_ss' addsimps (@{thm eval_pred} :: defs)
  in
    (* only simplify the one assumption *)
    (* need better control here! *)
    full_simp_tac ss 1
    THEN print_tac options "after sidecond2 simplification"
  end
  
(* Tactic for the i-th clause in the other proof direction.

   options:   controls proof tracing (via print_tac)
   ctxt:      proof context
   pred:      name of the predicate; its introduction rules are looked up
   mode:      mode of the predicate; splits ts into inputs and outputs
   (ts, ps):  clause arguments and the premises paired with their mode derivations
   i:         1-based index of the clause; selects the introduction rule

   The tactic eliminates bind and single in the assumptions and then processes
   the premises one by one with the local function prove_prems2. *)
fun prove_clause2 options ctxt pred mode (ts, ps) i =
  let
    (* introduction rule of the clause; i counts from 1, nth from 0 *)
    val pred_intro_rule = nth (intros_of ctxt pred) (i - 1)
    (* clause_out_ts is not used in this function *)
    val (in_ts, clause_out_ts) = split_mode mode ts;
    val split_ss = HOL_basic_ss' addsimps [@{thm split_eta}, @{thm split_beta},
      @{thm fst_conv}, @{thm snd_conv}, @{thm pair_collapse}]
    (* no premises left: split the remaining patterns, eliminate single and
       try to close the goal with the introduction rule *)
    fun prove_prems2 out_ts [] =
      print_tac options "before prove_match2 - last call:"
      THEN prove_match2 options ctxt out_ts
      THEN print_tac options "after prove_match2 - last call:"
      THEN (etac @{thm singleE} 1)
      THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
      THEN (asm_full_simp_tac HOL_basic_ss' 1)
      THEN TRY (
        (REPEAT_DETERM (etac @{thm Pair_inject} 1))
        THEN (asm_full_simp_tac HOL_basic_ss' 1)
        
        (* the introduction rule and the treatment of all its premises must
           together remove exactly one subgoal *)
        THEN SOLVED (print_tac options "state before applying intro rule:"
        THEN (rtac pred_intro_rule
        (* How to handle equality correctly? *)
        THEN_ALL_NEW (K (print_tac options "state before assumption matching")
        (* solve by assumption directly, or after simplification with split_ss *)
        THEN' (atac ORELSE' ((CHANGED o asm_full_simp_tac split_ss) THEN' (TRY o atac)))
          THEN' (K (print_tac options "state after pre-simplification:"))
        THEN' (K (print_tac options "state after assumption matching:")))) 1))
    | prove_prems2 out_ts ((p, deriv) :: ps) =
      let
        (* shadows the mode of the enclosing function with the mode of this premise *)
        val mode = head_mode_of deriv
        val rest_tac = (case p of
          Prem t =>
          let
            val (_, us) = strip_comb t
            val (_, out_ts''') = split_mode mode us
            val rec_tac = prove_prems2 out_ts''' ps
          in
            (prove_expr2 options ctxt (t, deriv)) THEN rec_tac
          end
        | Negprem t =>
          let
            val (_, args) = strip_comb t
            val (_, out_ts''') = split_mode mode args
            val rec_tac = prove_prems2 out_ts''' ps
            (* name of the negated predicate, if its head is a constant *)
            val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
            val param_derivations = param_derivations_of deriv
            val ho_args = ho_args_of mode args
          in
            print_tac options "before neg prem 2"
            THEN etac @{thm bindE} 1
            THEN (if is_some name then
                full_simp_tac (HOL_basic_ss addsimps
                  [predfun_definition_of ctxt (the name) mode]) 1
                THEN etac @{thm not_predE} 1
                THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
                THEN (EVERY (map2 (prove_param2 options ctxt) ho_args param_derivations))
              else
                etac @{thm not_predE'} 1)
            THEN rec_tac
          end 
        | Sidecond t =>
          etac @{thm bindE} 1
          THEN etac @{thm if_predE} 1
          THEN prove_sidecond2 options ctxt t
          (* a side condition produces no output terms *)
          THEN prove_prems2 [] ps)
      in print_tac options "before prove_match2:"
         THEN prove_match2 options ctxt out_ts
         THEN print_tac options "after prove_match2:"
         THEN rest_tac
      end;
    val prems_tac = prove_prems2 in_ts ps 
  in
    print_tac options "starting prove_clause2"
    THEN etac @{thm bindE} 1
    THEN (etac @{thm singleE'} 1)
    THEN (TRY (etac @{thm Pair_inject} 1))
    THEN print_tac options "after singleE':"
    THEN prems_tac
  end;
 
(* Tactic for the second direction of the proof: after applying the predfun
   introduction rule, the sup of the clauses is eliminated step by step and
   each clause is handled by prove_clause2.  A predicate without clauses is
   closed with botE. *)
fun prove_other_direction options ctxt pred mode moded_clauses =
  let
    val num_clauses = length moded_clauses
    (* supE is needed for every clause except the last one *)
    fun clause_tac clause i =
      let
        val sup_tac = if i >= num_clauses then all_tac else etac @{thm supE} 1
      in
        sup_tac THEN prove_clause2 options ctxt pred mode clause i
      end
  in
    DETERM (TRY (rtac @{thm unit.induct} 1))
    THEN REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))
    THEN rtac (predfun_intro_of ctxt pred mode) 1
    THEN REPEAT_DETERM (rtac @{thm refl} 2)
    THEN
      (case moded_clauses of
        [] => etac @{thm botE} 1
      | _ => EVERY (map2 clause_tac moded_clauses (1 upto num_clauses)))
  end;

(** proof procedure **)

(* Entry point: proves compiled_term in the global context of thy, with all
   free variables of compiled_term fixed.  Unless skip_proof is set in
   options, the proof applies pred_iffI and then proves both directions;
   otherwise the goal is closed with Skip_Proof.cheat_tac.  The clauses of
   pred are looked up in the given association list (none if pred is absent).
   The pol component is not used. *)
fun prove_pred options thy clauses preds pred (pol, mode) (moded_clauses, compiled_term) =
  let
    val ctxt = Proof_Context.init_global thy
    val pred_clauses = the_default [] (AList.lookup (op =) clauses pred)
    val fixes = Term.add_free_names compiled_term []
    fun full_proof _ =
      rtac @{thm pred_iffI} 1
      THEN print_tac options "after pred_iffI"
      THEN prove_one_direction options ctxt pred_clauses preds pred mode moded_clauses
      THEN print_tac options "proved one direction"
      THEN prove_other_direction options ctxt pred mode moded_clauses
      THEN print_tac options "proved other direction"
    fun cheat _ = Skip_Proof.cheat_tac thy
  in
    Goal.prove ctxt fixes [] compiled_term
      (if skip_proof options then cheat else full_proof)
  end;

end;