src/HOL/Tools/Function/scnp_reconstruct.ML
author wenzelm
Fri Mar 06 15:58:56 2015 +0100 (2015-03-06)
changeset 59621 291934bac95e
parent 59618 e6939796717e
child 59625 aacdce52b2fc
permissions -rw-r--r--
Thm.cterm_of and Thm.ctyp_of operate on local context;
haftmann@31775
     1
(*  Title:       HOL/Tools/Function/scnp_reconstruct.ML
    Author:      Armin Heller, TU Muenchen
    Author:      Alexander Krauss, TU Muenchen

Proof reconstruction for SCNP termination (the NP subset of size-change
termination).
*)
krauss@29125
     7
krauss@29125
     8
signature SCNP_RECONSTRUCT =
sig
  (*Proof infrastructure for the multiset (MS) order. It is installed via
    multiset_setup; without it, the MS order is not attempted.*)
  datatype multiset_setup =
    Multiset of
    {
     msetT : typ -> typ,
     mk_mset : typ -> term list -> term,
     mset_regroup_conv : int list -> conv,
     mset_member_tac : int -> int -> tactic,
     mset_nonempty_tac : int -> tactic,
     mset_pwleq_tac : int -> tactic,
     set_of_simps : thm list,
     smsI' : thm,
     wmsI2'' : thm,
     wmsI1 : thm,
     reduction_pair : thm
    }

  (*Store the multiset infrastructure in the theory data.*)
  val multiset_setup : multiset_setup -> theory -> theory

  (*Size-change termination tactic trying the orders MAX, MS and MIN.
    The tactic argument is handed to Termination.TERMINATION as its
    automation tactic.*)
  val sizechange_tac : Proof.context -> tactic -> tactic

  (*Size-change termination restricted to the given orders. Automation is
    auto with the termination_simp rules added to the simpset.*)
  val decomp_scnp_tac : ScnpSolve.label list -> Proof.context -> tactic
end
krauss@29125
    32
krauss@29125
    33
structure ScnpReconstruct : SCNP_RECONSTRUCT =
struct

val PROFILE = Function_Common.PROFILE

open ScnpSolve

val natT = HOLogic.natT
val nat_pairT = HOLogic.mk_prodT (natT, natT)


(* Theory dependencies *)

(*Proof infrastructure for the multiset (MS) order, registered via
  multiset_setup. If none is registered, single_scnp_tac drops MS from the
  orders it tries.*)
datatype multiset_setup =
  Multiset of
  {
   msetT : typ -> typ,
   mk_mset : typ -> term list -> term,
   mset_regroup_conv : int list -> conv,
   mset_member_tac : int -> int -> tactic,
   mset_nonempty_tac : int -> tactic,
   mset_pwleq_tac : int -> tactic,
   set_of_simps : thm list,
   smsI' : thm,
   wmsI2'' : thm,
   wmsI1 : thm,
   reduction_pair : thm
  }

structure Multiset_Setup = Theory_Data
(
  type T = multiset_setup option
  val empty = NONE
  val extend = I;
  val merge = merge_options
)

val multiset_setup = Multiset_Setup.put o SOME

(*Placeholder for unconfigured multiset operations; fails when called.*)
fun undef _ = error "undef"

(*The registered multiset setup. If there is none, a dummy is returned:
  its operations fail when called and its theorems are refl.*)
fun get_multiset_setup thy = Multiset_Setup.get thy
  |> the_default (Multiset
    { msetT = undef, mk_mset=undef,
      mset_regroup_conv=undef, mset_member_tac = undef,
      mset_nonempty_tac = undef, mset_pwleq_tac = undef,
      set_of_simps = [],reduction_pair = refl,
      smsI'=refl, wmsI2''=refl, wmsI1=refl })

(*Reduction pair for an order label. The MS pair comes from the multiset
  setup; MAX and MIN use fixed lemmas.*)
fun order_rpair _ MAX = @{thm max_rpair_set}
  | order_rpair msrp MS  = msrp
  | order_rpair _ MIN = @{thm min_rpair_set}

(*Introduction rules (empty case, insert step) for the strict/weak MAX order.*)
fun ord_intros_max true = (@{thm smax_emptyI}, @{thm smax_insertI})
  | ord_intros_max false = (@{thm wmax_emptyI}, @{thm wmax_insertI})

(*Introduction rules (empty case, insert step) for the strict/weak MIN order.*)
fun ord_intros_min true = (@{thm smin_emptyI}, @{thm smin_insertI})
  | ord_intros_min false = (@{thm wmin_emptyI}, @{thm wmin_insertI})

(*Translate the calls cs of termination problem D into a ScnpSolve problem.
  The problem consists of:
  - the number of measures of each of the n points;
  - one graph per call, with a GTR edge where the descent is strict (Less)
    and a GEQ edge where it is weak (LessEq).*)
fun gen_probl D cs =
  let
    val n = Termination.get_num_points D
    val arity = length o Termination.get_measures D
    fun measure p i = nth (Termination.get_measures D p) i

    fun mk_graph c =
      let
        val (_, p, _, q, _, _) = Termination.dest_call D c

        fun add_edge i j =
          case Termination.get_descent D c (measure p i) (measure q j)
           of SOME (Termination.Less _) => cons (i, GTR, j)
            | SOME (Termination.LessEq _) => cons (i, GEQ, j)
            | _ => I

        val edges =
          fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) []
      in
        G (p, q, edges)
      end
  in
    GP (map_range arity n, map mk_graph cs)
  end

(* General reduction pair application *)

(*Break down the subset/comprehension structure of subgoal 1 using subsetI,
  CollectE, exE, CollectI, conjE and ssubst. Along the way, unfold
  inv_image_def and then split_conv, triv_forall_equality and sum.case.*)
fun rem_inv_img ctxt =
  rtac @{thm subsetI} 1
  THEN etac @{thm CollectE} 1
  THEN REPEAT (etac @{thm exE} 1)
  THEN Local_Defs.unfold_tac ctxt @{thms inv_image_def}
  THEN rtac @{thm CollectI} 1
  THEN etac @{thm conjE} 1
  THEN etac @{thm ssubst} 1
  THEN Local_Defs.unfold_tac ctxt @{thms split_conv triv_forall_equality sum.case}


(* Sets *)

val setT = HOLogic.mk_setT

(*Prove membership of element m (0-based) of an explicit insert-set in
  subgoal i, using insertI2 m times and then insertI1.
  The tactic is built eagerly. A negative m, e.g. ~1 from find_index on a
  missing element, would otherwise recurse forever while building it, so it
  yields no_tac instead.*)
fun set_member_tac m i =
  if m < 0 then no_tac
  else if m = 0 then rtac @{thm insertI1} i
  else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i

val set_nonempty_tac = rtac @{thm insert_not_empty}

(*Prove finiteness of an explicit insert-set in subgoal i. The recursive call
  is eta-expanded so that it is only unfolded when the tactic runs.*)
fun set_finite_tac i =
  rtac @{thm finite.emptyI} i
  ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st))


(* Reconstruction *)

(*Replay a ScnpSolve certificate (label, lev, sl, covering) as a proof of
  the termination goal for the calls cs of D.*)
fun reconstruct_tac ctxt D cs (GP (_, gs)) certificate =
  let
    val thy = Proof_Context.theory_of ctxt
    val Multiset
          { msetT, mk_mset,
            mset_regroup_conv, mset_pwleq_tac, set_of_simps,
            smsI', wmsI2'', wmsI1, reduction_pair=ms_rp, ...}
        = get_multiset_setup thy

    fun measure_fn p = nth (Termination.get_measures D p)

    (*Descent theorem for call cidx from measure m1 to measure m2.
      If a strict theorem is stored but only weak descent is requested, it
      is weakened with less_imp_le.*)
    fun get_desc_thm cidx m1 m2 bStrict =
      (case Termination.get_descent D (nth cs cidx) m1 m2 of
        SOME (Termination.Less thm) =>
          if bStrict then thm
          else (thm COMP (Thm.lift_rule (Thm.cprop_of thm) @{thm less_imp_le}))
      | SOME (Termination.LessEq (thm, _))  =>
          if not bStrict then thm
          else raise Fail "get_desc_thm: strict descent required, but only weak descent known"
      | _ => raise Fail "get_desc_thm: no descent known")

    val (label, lev, sl, covering) = certificate

    (*Prove the (strict or weak) level-mapping inequality for graph g.*)
    fun prove_lev strict g =
      let
        val G (p, q, _) = nth gs g

        (*Compare the tagged measure (j, b) of q with the tagged measure
          (i, a) of p.
          If the tags alone decide the comparison (tag_flag), only weak
          descent is requested, and arith_tac closes the remaining subgoal.
          Otherwise strict descent is requested.*)
        fun less_proof strict (j, b) (i, a) =
          let
            val tag_flag = b < a orelse (not strict andalso b <= a)

            val stored_thm =
              get_desc_thm g (measure_fn p i) (measure_fn q j)
                             (not tag_flag)
              |> Conv.fconv_rule (Thm.beta_conversion true)

            val rule =
              if strict
              then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
              else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
          in
            rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
            THEN (if tag_flag then Arith_Data.arith_tac ctxt 1 else all_tac)
          end

        (*Walk the levels lq (of q) and lp (of p) according to the order.*)
        fun steps_tac MAX strict lq lp =
              let
                val (empty, step) = ord_intros_max strict
              in
                (case lq of
                  [] =>
                    rtac empty 1 THEN set_finite_tac 1
                    THEN (if strict then set_nonempty_tac 1 else all_tac)
                | (j, b) :: rest =>
                    let
                      val (i, a) = the (covering g strict j)
                      fun choose xs = set_member_tac (find_index (curry op = (i, a)) xs) 1
                      val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
                    in
                      rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
                    end)
              end
          | steps_tac MIN strict lq lp =
              let
                val (empty, step) = ord_intros_min strict
              in
                (case lp of
                  [] =>
                    rtac empty 1
                    THEN (if strict then set_nonempty_tac 1 else all_tac)
                | (i, a) :: rest =>
                    let
                      val (j, b) = the (covering g strict i)
                      fun choose xs = set_member_tac (find_index (curry op = (j, b)) xs) 1
                      val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
                    in
                      rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
                    end)
              end
          | steps_tac MS strict lq lp =
              let
                fun get_str_cover (j, b) =
                  if is_some (covering g true j) then SOME (j, b) else NONE
                fun get_wk_cover (j, _) = the (covering g false j)

                (*Elements of lq without a strict cover, and their weak covers.*)
                val qs = subtract (op =) (map_filter get_str_cover lq) lq
                val ps = map get_wk_cover qs

                fun indices xs ys = map (fn y => find_index (curry op = y) xs) ys
                val iqs = indices lq qs
                val ips = indices lp ps

                (*Regroup both multisets so that qs and ps are aligned.*)
                local open Conv in
                fun t_conv a C =
                  params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
                val goal_rewrite =
                    t_conv arg1_conv (mset_regroup_conv iqs)
                    then_conv t_conv arg_conv (mset_regroup_conv ips)
                end
              in
                CONVERSION goal_rewrite 1
                THEN (if strict then rtac smsI' 1
                      else if qs = lq then rtac wmsI2'' 1
                      else rtac wmsI1 1)
                THEN mset_pwleq_tac 1
                THEN EVERY (map2 (less_proof false) qs ps)
                THEN (if strict orelse qs <> lq
                      then Local_Defs.unfold_tac ctxt set_of_simps
                           THEN steps_tac MAX true
                           (subtract (op =) qs lq) (subtract (op =) ps lp)
                      else all_tac)
              end
      in
        rem_inv_img ctxt
        THEN steps_tac label strict (nth lev q) (nth lev p)
      end

    val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)

    (*The pair (measure i of point p applied to the bound variable, tag).*)
    fun tag_pair p (i, tag) =
      HOLogic.pair_const natT natT $
        (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag

    fun pt_lev (p, lm) =
      Abs ("x", Termination.get_types D p, mk_set nat_pairT (map (tag_pair p) lm))

    (*Level mapping: one set of tagged measures per point, combined by case
      distinction over the points of D.*)
    val level_mapping =
      map_index pt_lev lev
        |> Termination.mk_sumcases D (setT nat_pairT)
        |> Thm.cterm_of ctxt
    in
      PROFILE "Proof Reconstruction"
        (CONVERSION (Conv.arg_conv (Conv.arg_conv (Function_Lib.regroup_union_conv sl))) 1
         THEN (rtac @{thm reduction_pair_lemma} 1)
         THEN (rtac @{thm rp_inv_image_rp} 1)
         THEN (rtac (order_rpair ms_rp label) 1)
         THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
         THEN unfold_tac ctxt @{thms rp_inv_image_def}
         THEN Local_Defs.unfold_tac ctxt @{thms split_conv fst_conv snd_conv}
         THEN REPEAT (SOMEGOAL (resolve_tac ctxt [@{thm Un_least}, @{thm empty_subsetI}]))
         THEN EVERY (map (prove_lev true) sl)
         THEN EVERY (map (prove_lev false) (subtract (op =) sl (0 upto length cs - 1))))
    end


(*Try to solve the calls of subgoal i with one SCNP certificate.
  MS is skipped when no multiset setup is registered in the theory.*)
fun single_scnp_tac use_tags orders ctxt D = Termination.CALLS (fn (cs, i) =>
  let
    val ms_configured = is_some (Multiset_Setup.get (Proof_Context.theory_of ctxt))
    val orders' =
      if ms_configured then orders
      else filter_out (curry op = MS) orders
    val gp = gen_probl D cs
    val certificate = generate_certificate use_tags orders' gp
  in
    (case certificate of
      NONE => no_tac
    | SOME cert =>
        SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i
        THEN TRY (rtac @{thm wf_empty} i))
  end)

(*Alternate SCNP steps with graph decomposition until no progress is made.*)
fun gen_decomp_scnp_tac orders autom_tac ctxt =
  Termination.TERMINATION ctxt autom_tac (fn D =>
    let
      val decompose = Termination.decompose_tac ctxt D
      val scnp_full = single_scnp_tac true orders ctxt D
    in
      REPEAT_ALL_NEW (scnp_full ORELSE' decompose)
    end)

fun gen_sizechange_tac orders autom_tac ctxt =
  TRY (Function_Common.termination_rule_tac ctxt 1)
  THEN TRY (Termination.wf_union_tac ctxt)
  THEN (rtac @{thm wf_empty} 1 ORELSE gen_decomp_scnp_tac orders autom_tac ctxt 1)

fun sizechange_tac ctxt autom_tac =
  gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt

fun decomp_scnp_tac orders ctxt =
  let
    val extra_simps = Named_Theorems.get ctxt @{named_theorems termination_simp}
    val autom_tac = auto_tac (ctxt addsimps extra_simps)
  in
     gen_sizechange_tac orders autom_tac ctxt
  end


(* Method setup *)

(*Order keywords for the method; defaults to [MAX, MS, MIN].*)
val orders =
  Scan.repeat1
    ((Args.$$$ "max" >> K MAX) ||
     (Args.$$$ "min" >> K MIN) ||
     (Args.$$$ "ms" >> K MS))
  || Scan.succeed [MAX, MS, MIN]

val _ =
  Theory.setup
    (Method.setup @{binding size_change}
      (Scan.lift orders --| Method.sections clasimp_modifiers >>
        (fn orders => SIMPLE_METHOD o decomp_scnp_tac orders))
      "termination prover with graph decomposition and the NP subset of size change termination")

end