src/HOL/Tools/Function/scnp_reconstruct.ML
author blanchet
Wed, 08 Jun 2011 16:20:18 +0200
changeset 43293 a80cdc4b27a3
parent 42795 66fcc9882784
child 51717 9e7d1c139569
permissions -rw-r--r--
made "query" type systems a bit more sound -- local facts, e.g. the negated conjecture, may invalidate the infinity check, e.g. if we are proving that there exist two values of an infinite type, we can use the negated conjecture that there is only one value to derive unsound proofs unless the type is properly encoded

(*  Title:       HOL/Tools/Function/scnp_reconstruct.ML
    Author:      Armin Heller, TU Muenchen
    Author:      Alexander Krauss, TU Muenchen

Proof reconstruction for SCNP termination.
*)

(* Interface of the SCNP proof reconstruction module. *)
signature SCNP_RECONSTRUCT =
sig

  (* Size-change tactic: takes the proof context and an automation tactic;
     the implementation below tries the orders MAX, MS and MIN. *)
  val sizechange_tac : Proof.context -> tactic -> tactic

  (* Size-change tactic restricted to the given list of orders; the
     implementation below derives its automation tactic from auto_tac
     extended with the termination simp rules of the context. *)
  val decomp_scnp_tac : ScnpSolve.label list -> Proof.context -> tactic

  (* Theory setup; the implementation below registers the proof method
     size_change. *)
  val setup : theory -> theory

  (* Theory-dependent ingredients needed for the multiset order (MS).
     The individual fields are consumed by the reconstruction code in the
     structure below. *)
  datatype multiset_setup =
    Multiset of
    {
     msetT : typ -> typ,
     mk_mset : typ -> term list -> term,
     mset_regroup_conv : int list -> conv,
     mset_member_tac : int -> int -> tactic,
     mset_nonempty_tac : int -> tactic,
     mset_pwleq_tac : int -> tactic,
     set_of_simps : thm list,
     smsI' : thm,
     wmsI2'' : thm,
     wmsI1 : thm,
     reduction_pair : thm
    }


  (* Store a multiset setup in the theory data. *)
  val multiset_setup : multiset_setup -> theory -> theory

end

structure ScnpReconstruct : SCNP_RECONSTRUCT =
struct

(* Local alias for the profiling wrapper of Function_Common. *)
val PROFILE = Function_Common.PROFILE

(* Makes the names of ScnpSolve available unqualified -- presumably including
   the constructors GP, G, GTR, GEQ, MAX, MIN, MS and generate_certificate
   that are used below without being defined in this file; confirm against
   ScnpSolve. *)
open ScnpSolve

(* HOL type of natural numbers and the type of pairs of natural numbers. *)
val natT = HOLogic.natT
val nat_pairT = HOLogic.mk_prodT (natT, natT)

(* Theory dependencies *)

(* Ingredients for the multiset order (label MS), as used in this file:
     msetT, mk_mset      -- replace the set type and set constructor when the
                            level mapping is built for label MS
     mset_regroup_conv   -- conversion applied to both sides of the goal,
                            given a list of element indices
     mset_pwleq_tac      -- tactic run on subgoal 1 after one of the
                            introduction rules below has been applied
     set_of_simps        -- unfolded before the remaining elements are
                            handled by the MAX procedure
     smsI', wmsI2'', wmsI1 -- introduction rules; the reconstruction picks
                            one depending on strictness and on whether all
                            elements are covered weakly
     reduction_pair      -- rule applied for label MS in place of the
                            max/min reduction pair theorems
   mset_member_tac and mset_nonempty_tac are not referenced by the code
   visible in this file; presumably they are kept for users of the setup. *)
datatype multiset_setup =
  Multiset of
  {
   msetT : typ -> typ,
   mk_mset : typ -> term list -> term,
   mset_regroup_conv : int list -> conv,
   mset_member_tac : int -> int -> tactic,
   mset_nonempty_tac : int -> tactic,
   mset_pwleq_tac : int -> tactic,
   set_of_simps : thm list,
   smsI' : thm,
   wmsI2'' : thm,
   wmsI1 : thm,
   reduction_pair : thm
  }

(* Theory data slot holding the optional multiset setup.  It starts out
   empty, is left untouched by extend, and is combined with merge_options
   when theories are merged. *)
structure Multiset_Setup = Theory_Data
(
  type T = multiset_setup option
  val empty : T = NONE
  val extend : T -> T = I
  val merge : T * T -> T = merge_options
)

(* Record the given multiset setup in the theory data. *)
fun multiset_setup ms_setup = Multiset_Setup.put (SOME ms_setup)

(* Placeholder for the function-valued fields of the dummy setup below:
   raises an error as soon as it is applied. *)
fun undef _ = error "undef"

(* Look up the multiset setup of a theory.  If none has been stored, a dummy
   setup is returned whose function fields are undef, whose theorem fields
   are refl and whose simp rule list is empty. *)
fun get_multiset_setup thy =
  (case Multiset_Setup.get thy of
    SOME ms_setup => ms_setup
  | NONE =>
      Multiset
        { msetT = undef,
          mk_mset = undef,
          mset_regroup_conv = undef,
          mset_member_tac = undef,
          mset_nonempty_tac = undef,
          mset_pwleq_tac = undef,
          set_of_simps = [],
          smsI' = refl,
          wmsI2'' = refl,
          wmsI1 = refl,
          reduction_pair = refl })

(* Reduction pair theorem for an order label: fixed theorems for MAX and MIN,
   the supplied theorem msrp for MS. *)
fun order_rpair msrp label =
  (case label of
    MAX => @{thm max_rpair_set}
  | MS => msrp
  | MIN => @{thm min_rpair_set})

(* Pair (rule for the empty set, rule for insert) for the max order;
   the strict variants are returned when the flag is true. *)
fun ord_intros_max strict =
  if strict
  then (@{thm smax_emptyI}, @{thm smax_insertI})
  else (@{thm wmax_emptyI}, @{thm wmax_insertI})
(* Pair (rule for the empty set, rule for insert) for the min order;
   the strict variants are returned when the flag is true. *)
fun ord_intros_min strict =
  if strict
  then (@{thm smin_emptyI}, @{thm smin_insertI})
  else (@{thm wmin_emptyI}, @{thm wmin_insertI})

(* Translate termination data D and the calls cs into a graph problem GP.
   The first component lists the number of measures of each of the
   get_num_points D points; the second holds one graph per call.  A graph
   gets an edge (i, GTR, j) when a Less descent is recorded between measure i
   of the source point and measure j of the target point, an edge (i, GEQ, j)
   for a LessEq descent, and no edge otherwise. *)
fun gen_probl D cs =
  let
    fun measures_of pt = Termination.get_measures D pt
    fun num_measures pt = length (measures_of pt)

    fun graph_of_call call =
      let
        val (_, src, _, dst, _, _) = Termination.dest_call D call

        fun descent_edge i j acc =
          (case Termination.get_descent D call
                  (nth (measures_of src) i) (nth (measures_of dst) j) of
            SOME (Termination.Less _) => (i, GTR, j) :: acc
          | SOME (Termination.LessEq _) => (i, GEQ, j) :: acc
          | _ => acc)

        val src_indices = 0 upto num_measures src - 1
        val dst_indices = 0 upto num_measures dst - 1
      in
        G (src, dst, fold_product descent_edge src_indices dst_indices [])
      end

    val num_points = Termination.get_num_points D
  in
    GP (map_range num_measures num_points, map graph_of_call cs)
  end

(* General reduction pair application *)
(* Tactic working on subgoal 1: applies subsetI, eliminates the set
   comprehension and its existential quantifiers, unfolds inv_image_def,
   re-introduces a comprehension, splits the conjunction, substitutes the
   equation and finally unfolds split_conv, triv_forall_equality and
   sum.cases. *)
fun rem_inv_img ctxt =
  let
    val unfold = Local_Defs.unfold_tac ctxt
    val cleanup_simps =
      @{thms split_conv} @ @{thms triv_forall_equality} @ @{thms sum.cases}
  in
    EVERY
     [rtac @{thm subsetI} 1,
      etac @{thm CollectE} 1,
      REPEAT (etac @{thm exE} 1),
      unfold @{thms inv_image_def},
      rtac @{thm CollectI} 1,
      etac @{thm conjE} 1,
      etac @{thm ssubst} 1,
      unfold cleanup_simps]
  end

(* Sets *)

(* Set type over the given element type. *)
fun setT elemT = HOLogic.mk_setT elemT

(* Prove membership of the element at position m (counted from 0) of an
   explicitly enumerated set on subgoal i: m applications of insertI2
   followed by insertI1.

   A negative position makes the tactic fail via no_tac.  Callers pass the
   result of Library.find_index, which is ~1 when the element is absent.
   Without this guard the recursion would never reach 0 and, since the
   recursive call is evaluated eagerly, constructing the tactic would not
   terminate. *)
fun set_member_tac m i =
  if m < 0 then no_tac
  else if m = 0 then rtac @{thm insertI1} i
  else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i

(* Apply insert_not_empty to subgoal i. *)
fun set_nonempty_tac i = rtac @{thm insert_not_empty} i

(* Prove finiteness on subgoal i: try finite.emptyI first, otherwise apply
   finite.insertI and recurse.  The proof state is an explicit argument so
   that the recursive occurrence is only unfolded when the tactic runs. *)
fun set_finite_tac i st =
  (rtac @{thm finite.emptyI} i
   ORELSE (rtac @{thm finite.insertI} i THEN set_finite_tac i)) st


(* Reconstruction *)

(* Build the tactic that replays an SCNP certificate.

   Arguments:
     ctxt        -- proof context
     D           -- termination data
     cs          -- the list of calls
     GP (_, gs)  -- the graph problem; only its graphs gs are used
     certificate -- a tuple (label, lev, sl, covering): label selects the
                    order (MAX, MIN or MS); lev gives, per point, a list of
                    (measure index, tag) pairs; sl lists the indices of the
                    calls that are proved strictly; covering g strict j
                    optionally yields the pair covering entry j in graph g.

   Raises Fail get_desc_thm when a descent theorem of the requested
   strictness is not available. *)
fun reconstruct_tac ctxt D cs (GP (_, gs)) certificate =
  let
    val thy = Proof_Context.theory_of ctxt
    (* Multiset ingredients; only used when the certificate label is MS. *)
    val Multiset
          { msetT, mk_mset,
            mset_regroup_conv, mset_pwleq_tac, set_of_simps,
            smsI', wmsI2'', wmsI1, reduction_pair=ms_rp, ...} 
        = get_multiset_setup thy

    (* measure_fn p i is the i-th measure of point p. *)
    fun measure_fn p = nth (Termination.get_measures D p)

    (* Descent theorem for call number cidx between measures m1 and m2.
       A Less theorem is weakened with less_imp_le when a non-strict result
       is requested; a LessEq theorem cannot serve a strict request. *)
    fun get_desc_thm cidx m1 m2 bStrict =
      case Termination.get_descent D (nth cs cidx) m1 m2
       of SOME (Termination.Less thm) =>
          if bStrict then thm
          else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le}))
        | SOME (Termination.LessEq (thm, _))  =>
          if not bStrict then thm
          else raise Fail "get_desc_thm"
        | _ => raise Fail "get_desc_thm"

    val (label, lev, sl, covering) = certificate

    (* Tactic for the subgoal belonging to graph number g, proved strictly
       or non-strictly according to the flag. *)
    fun prove_lev strict g =
      let
        val G (p, q, _) = nth gs g

        (* Compare the tagged pair (j, b) of point q with the tagged pair
           (i, a) of point p.  tag_flag records that the tags alone give the
           required relation; the measure descent is then fetched with
           strictness (not tag_flag), and the subgoal left over is handed
           to arith_tac. *)
        fun less_proof strict (j, b) (i, a) =
          let
            val tag_flag = b < a orelse (not strict andalso b <= a)

            val stored_thm =
              get_desc_thm g (measure_fn p i) (measure_fn q j)
                             (not tag_flag)
              |> Conv.fconv_rule (Thm.beta_conversion true)

            val rule = if strict
              then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
              else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
          in
            rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
            THEN (if tag_flag then Arith_Data.arith_tac ctxt 1 else all_tac)
          end

        (* MAX: walk through the entries of lq; each one is related to the
           element of lp given by the covering function.  The empty case
           additionally needs finiteness and, when strict, non-emptiness. *)
        fun steps_tac MAX strict lq lp =
          let
            val (empty, step) = ord_intros_max strict
          in
            if length lq = 0
            then rtac empty 1 THEN set_finite_tac 1
                 THEN (if strict then set_nonempty_tac 1 else all_tac)
            else
              let
                val (j, b) :: rest = lq
                val (i, a) = the (covering g strict j)
                fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1
                val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
              in
                rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
              end
          end
          (* MIN: dual to MAX -- walk through the entries of lp and pick the
             covering element from lq. *)
          | steps_tac MIN strict lq lp =
          let
            val (empty, step) = ord_intros_min strict
          in
            if length lp = 0
            then rtac empty 1
                 THEN (if strict then set_nonempty_tac 1 else all_tac)
            else
              let
                val (i, a) :: rest = lp
                val (j, b) = the (covering g strict i)
                fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1
                val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
              in
                rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
              end
          end
          (* MS: qs are the entries of lq without a strict cover, ps their
             weak covers.  Both sides of the goal are regrouped by the
             positions of qs and ps, one of the multiset introduction rules
             is applied, the pairs qs/ps are compared non-strictly, and any
             remaining entries are handled by the strict MAX procedure. *)
          | steps_tac MS strict lq lp =
          let
            fun get_str_cover (j, b) =
              if is_some (covering g true j) then SOME (j, b) else NONE
            fun get_wk_cover (j, b) = the (covering g false j)

            val qs = subtract (op =) (map_filter get_str_cover lq) lq
            val ps = map get_wk_cover qs

            (* Positions of the elements ys within xs. *)
            fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys
            val iqs = indices lq qs
            val ips = indices lp ps

            local open Conv in
            fun t_conv a C =
              params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
            val goal_rewrite =
                t_conv arg1_conv (mset_regroup_conv iqs)
                then_conv t_conv arg_conv (mset_regroup_conv ips)
            end
          in
            CONVERSION goal_rewrite 1
            THEN (if strict then rtac smsI' 1
                  else if qs = lq then rtac wmsI2'' 1
                  else rtac wmsI1 1)
            THEN mset_pwleq_tac 1
            THEN EVERY (map2 (less_proof false) qs ps)
            THEN (if strict orelse qs <> lq
                  then Local_Defs.unfold_tac ctxt set_of_simps
                       THEN steps_tac MAX true
                       (subtract (op =) qs lq) (subtract (op =) ps lp)
                  else all_tac)
          end
      in
        rem_inv_img ctxt
        THEN steps_tac label strict (nth lev q) (nth lev p)
      end

    (* For label MS the multiset constructor and type replace the plain set
       ones; this setT shadows the top-level setT. *)
    val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)

    (* Term for the pair (measure i of point p applied to the bound
       variable, numeral tag). *)
    fun tag_pair p (i, tag) =
      HOLogic.pair_const natT natT $
        (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag

    (* Abstraction mapping an argument of point p to its collection of
       tagged measure pairs. *)
    fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p,
                           mk_set nat_pairT (map (tag_pair p) lm))

    (* The level mapping: the per-point abstractions combined with
       Termination.mk_sumcases and certified in the theory. *)
    val level_mapping =
      map_index pt_lev lev
        |> Termination.mk_sumcases D (setT nat_pairT)
        |> cterm_of thy
    in
      (* Regroup the union according to sl, apply the reduction pair lemmas
         and instantiate the level mapping; then unfold, split the goal with
         Un_least / empty_subsetI, and prove the calls listed in sl strictly
         and all remaining calls non-strictly. *)
      PROFILE "Proof Reconstruction"
        (CONVERSION (Conv.arg_conv (Conv.arg_conv (Function_Lib.regroup_union_conv sl))) 1
         THEN (rtac @{thm reduction_pair_lemma} 1)
         THEN (rtac @{thm rp_inv_image_rp} 1)
         THEN (rtac (order_rpair ms_rp label) 1)
         THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
         THEN unfold_tac @{thms rp_inv_image_def} (simpset_of ctxt)
         THEN Local_Defs.unfold_tac ctxt
           (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv})
         THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}]))
         THEN EVERY (map (prove_lev true) sl)
         THEN EVERY (map (prove_lev false) (subtract (op =) sl (0 upto length cs - 1))))
    end


(* One SCNP step on the calls of a goal: generate a certificate for the graph
   problem of the calls and, if one is found, replay it on subgoal i and then
   try to close the goal with wf_empty.  The order MS is dropped from the
   list of orders when no multiset setup is stored in the theory.  Without a
   certificate the tactic fails. *)
fun single_scnp_tac use_tags orders ctxt D =
  let
    fun solve_calls (cs, i) =
      let
        val thy = Proof_Context.theory_of ctxt
        val usable_orders =
          (case Multiset_Setup.get thy of
            SOME _ => orders
          | NONE => filter (fn lbl => lbl <> MS) orders)
        val problem = gen_probl D cs
      in
        (case generate_certificate use_tags usable_orders problem of
          SOME cert =>
            SELECT_GOAL (reconstruct_tac ctxt D cs problem cert) i
            THEN TRY (rtac @{thm wf_empty} i)
        | NONE => no_tac)
      end
  in
    Termination.CALLS solve_calls
  end

local open Termination in

(* Termination tactic that repeatedly, on all newly arising subgoals, tries
   an SCNP step (with tags enabled) and falls back to decomposition. *)
fun gen_decomp_scnp_tac orders autom_tac ctxt =
  let
    fun scnp_or_decompose D =
      REPEAT_ALL_NEW (single_scnp_tac true orders ctxt D ORELSE' decompose_tac D)
  in
    TERMINATION ctxt autom_tac scnp_or_decompose
  end
end

(* Full size-change tactic: optionally apply the termination rule and the
   wf-union step, then either close the goal with wf_empty or run the
   decomposing SCNP tactic on subgoal 1. *)
fun gen_sizechange_tac orders autom_tac ctxt =
  let
    val apply_rule_tac = TRY (Function_Common.apply_termination_rule ctxt 1)
    val union_tac = TRY (Termination.wf_union_tac ctxt)
    val finish_tac =
      rtac @{thm wf_empty} 1
      ORELSE gen_decomp_scnp_tac orders autom_tac ctxt 1
  in
    apply_rule_tac THEN union_tac THEN finish_tac
  end

(* Size-change tactic using all three orders, in the sequence MAX, MS, MIN. *)
fun sizechange_tac ctxt autom_tac =
  let
    val all_orders = [MAX, MS, MIN]
  in
    gen_sizechange_tac all_orders autom_tac ctxt
  end

(* Size-change tactic for the given orders, with auto_tac as automation; the
   simpset used by auto_tac is extended with the termination simp rules
   registered in the context. *)
fun decomp_scnp_tac orders ctxt =
  let
    val extra_simps = Function_Common.Termination_Simps.get ctxt
    fun with_extra_simps ss = ss addsimps extra_simps
    val simp_ctxt = map_simpset with_extra_simps ctxt
  in
    gen_sizechange_tac orders (auto_tac simp_ctxt) ctxt
  end


(* Method setup *)

(* Parser for the order arguments of the method: one or more of the keywords
   max, min and ms; if none is given, the list MAX, MS, MIN is used. *)
val orders =
  let
    fun keyword name label = Args.$$$ name >> K label
    val order = keyword "max" MAX || keyword "min" MIN || keyword "ms" MS
  in
    Scan.repeat1 order || Scan.succeed [MAX, MS, MIN]
  end

(* Register the proof method size_change: it parses the optional order
   keywords followed by clasimp modifier sections and runs decomp_scnp_tac
   with the parsed orders. *)
val setup =
  let
    fun size_change_method parsed_orders ctxt =
      SIMPLE_METHOD (decomp_scnp_tac parsed_orders ctxt)
  in
    Method.setup @{binding size_change}
      (Scan.lift orders --| Method.sections clasimp_modifiers >> size_change_method)
      "termination prover with graph decomposition and the NP subset of size change termination"
  end

end