src/HOL/Tools/function_package/scnp_reconstruct.ML
author wenzelm
Fri Mar 20 15:24:18 2009 +0100 (2009-03-20)
changeset 30607 c3d1590debd8
parent 30510 4120fc59dd85
child 30686 47a32dd1b86e
permissions -rw-r--r--
eliminated global SIMPSET, CLASET etc. -- refer to explicit context;
     1 (*  Title:       HOL/Tools/function_package/scnp_reconstruct.ML
     2     Author:      Armin Heller, TU Muenchen
     3     Author:      Alexander Krauss, TU Muenchen
     4 
     5 Proof reconstruction for SCNP
     6 *)
     7 
     8 signature SCNP_RECONSTRUCT =
     9 sig
    10 
    11   val sizechange_tac : Proof.context -> tactic -> tactic
    12 
    13   val decomp_scnp : ScnpSolve.label list -> Proof.context -> Proof.method
    14 
    15   val setup : theory -> theory
    16 
    17   datatype multiset_setup =
    18     Multiset of
    19     {
    20      msetT : typ -> typ,
    21      mk_mset : typ -> term list -> term,
    22      mset_regroup_conv : int list -> conv,
    23      mset_member_tac : int -> int -> tactic,
    24      mset_nonempty_tac : int -> tactic,
    25      mset_pwleq_tac : int -> tactic,
    26      set_of_simps : thm list,
    27      smsI' : thm,
    28      wmsI2'' : thm,
    29      wmsI1 : thm,
    30      reduction_pair : thm
    31     }
    32 
    33 
    34   val multiset_setup : multiset_setup -> theory -> theory
    35 
    36 end
    37 
    38 structure ScnpReconstruct : SCNP_RECONSTRUCT =
    39 struct
    40 
    41 val PROFILE = FundefCommon.PROFILE
    42 fun TRACE x = if ! FundefCommon.profile then Output.tracing x else ()
    43 
    44 open ScnpSolve
    45 
    46 val natT = HOLogic.natT
    47 val nat_pairT = HOLogic.mk_prodT (natT, natT)
    48 
    49 (* Theory dependencies *)
    50 
    51 datatype multiset_setup =
    52   Multiset of
    53   {
    54    msetT : typ -> typ,
    55    mk_mset : typ -> term list -> term,
    56    mset_regroup_conv : int list -> conv,
    57    mset_member_tac : int -> int -> tactic,
    58    mset_nonempty_tac : int -> tactic,
    59    mset_pwleq_tac : int -> tactic,
    60    set_of_simps : thm list,
    61    smsI' : thm,
    62    wmsI2'' : thm,
    63    wmsI1 : thm,
    64    reduction_pair : thm
    65   }
    66 
    67 structure MultisetSetup = TheoryDataFun
    68 (
    69   type T = multiset_setup option
    70   val empty = NONE
    71   val copy = I;
    72   val extend = I;
    73   fun merge _ (v1, v2) = if is_some v2 then v2 else v1
    74 )
    75 
    76 val multiset_setup = MultisetSetup.put o SOME
    77 
    78 fun undef x = error "undef"
    79 fun get_multiset_setup thy = MultisetSetup.get thy
    80   |> the_default (Multiset
    81 { msetT = undef, mk_mset=undef,
    82   mset_regroup_conv=undef, mset_member_tac = undef,
    83   mset_nonempty_tac = undef, mset_pwleq_tac = undef,
    84   set_of_simps = [],reduction_pair = refl,
    85   smsI'=refl, wmsI2''=refl, wmsI1=refl })
    86 
    87 fun order_rpair _ MAX = @{thm max_rpair_set}
    88   | order_rpair msrp MS  = msrp
    89   | order_rpair _ MIN = @{thm min_rpair_set}
    90 
    91 fun ord_intros_max true =
    92     (@{thm smax_emptyI}, @{thm smax_insertI})
    93   | ord_intros_max false =
    94     (@{thm wmax_emptyI}, @{thm wmax_insertI})
    95 fun ord_intros_min true =
    96     (@{thm smin_emptyI}, @{thm smin_insertI})
    97   | ord_intros_min false =
    98     (@{thm wmin_emptyI}, @{thm wmin_insertI})
    99 
   100 fun gen_probl D cs =
   101   let
   102     val n = Termination.get_num_points D
   103     val arity = length o Termination.get_measures D
   104     fun measure p i = nth (Termination.get_measures D p) i
   105 
   106     fun mk_graph c =
   107       let
   108         val (_, p, _, q, _, _) = Termination.dest_call D c
   109 
   110         fun add_edge i j =
   111           case Termination.get_descent D c (measure p i) (measure q j)
   112            of SOME (Termination.Less _) => cons (i, GTR, j)
   113             | SOME (Termination.LessEq _) => cons (i, GEQ, j)
   114             | _ => I
   115 
   116         val edges =
   117           fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) []
   118       in
   119         G (p, q, edges)
   120       end
   121   in
   122     GP (map arity (0 upto n - 1), map mk_graph cs)
   123   end
   124 
   125 (* General reduction pair application *)
   126 fun rem_inv_img ctxt =
   127   let
   128     val unfold_tac = LocalDefs.unfold_tac ctxt
   129   in
   130     rtac @{thm subsetI} 1
   131     THEN etac @{thm CollectE} 1
   132     THEN REPEAT (etac @{thm exE} 1)
   133     THEN unfold_tac @{thms inv_image_def}
   134     THEN rtac @{thm CollectI} 1
   135     THEN etac @{thm conjE} 1
   136     THEN etac @{thm ssubst} 1
   137     THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality}
   138                      @ @{thms sum.cases})
   139   end
   140 
   141 (* Sets *)
   142 
   143 val setT = HOLogic.mk_setT
   144 
   145 fun set_member_tac m i =
   146   if m = 0 then rtac @{thm insertI1} i
   147   else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i
   148 
   149 val set_nonempty_tac = rtac @{thm insert_not_empty}
   150 
   151 fun set_finite_tac i =
   152   rtac @{thm finite.emptyI} i
   153   ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st))
   154 
   155 
   156 (* Reconstruction *)
   157 
(* Replays an SCNP certificate as a proof of the current goal.

   certificate = (label, lev, sl, covering):
   - label: the order (MAX, MIN or MS) used in the reduction pair.
   - lev: for every point, the list of (measure index, tag) pairs making up
     its level mapping.
   - sl: the indices of calls proved strictly decreasing; all other calls
     of cs are proved non-strictly.
   - covering g strict x: the optional covering (index, tag) pair for x in
     graph g.

   The goal is first rewritten with regroup_union_conv sl, presumably to
   group the strict calls.  It is then reduced through reduction_pair_lemma
   and rp_inv_image_rp to order_rpair instantiated with the level mapping.
   Every call is finally discharged by prove_lev. *)
fun reconstruct_tac ctxt D cs (gp as GP (_, gs)) certificate =
  let
    val thy = ProofContext.theory_of ctxt
    val Multiset
          { msetT, mk_mset,
            mset_regroup_conv, mset_member_tac,
            mset_nonempty_tac, mset_pwleq_tac, set_of_simps,
            smsI', wmsI2'', wmsI1, reduction_pair=ms_rp } 
        = get_multiset_setup thy

    (* measure_fn p i: the i-th measure of point p *)
    fun measure_fn p = nth (Termination.get_measures D p)

    (* Descent theorem of call number cidx between measures m1 and m2.
       A Less theorem is weakened via less_imp_le when a non-strict one is
       requested.  Asking for a strict one where only LessEq is recorded,
       or for any call without recorded descent, is an internal error. *)
    fun get_desc_thm cidx m1 m2 bStrict =
      case Termination.get_descent D (nth cs cidx) m1 m2
       of SOME (Termination.Less thm) =>
          if bStrict then thm
          else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le}))
        | SOME (Termination.LessEq (thm, _))  =>
          if not bStrict then thm
          else sys_error "get_desc_thm"
        | _ => sys_error "get_desc_thm"

    val (label, lev, sl, covering) = certificate

    (* Proves the (strict or non-strict) descent of call g w.r.t. the
       level mapping, using the order selected by label. *)
    fun prove_lev strict g =
      let
        val G (p, q, el) = nth gs g

        (* Compares callee pair (j, b) with caller pair (i, a).
           tag_flag: the tag comparison b < a (or b <= a when non-strict)
           already decides the case.  Then only a non-strict descent
           theorem is requested, and arith_tac closes the remaining
           subgoal. *)
        fun less_proof strict (j, b) (i, a) =
          let
            val tag_flag = b < a orelse (not strict andalso b <= a)

            val stored_thm =
              get_desc_thm g (measure_fn p i) (measure_fn q j)
                             (not tag_flag)
              |> Conv.fconv_rule (Thm.beta_conversion true)

            val rule = if strict
              then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
              else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
          in
            rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
            THEN (if tag_flag then arith_tac ctxt 1 else all_tac)
          end

        (* MAX: walk the callee list lq.  Every element is covered by an
           element of the caller list lp, located with set_member_tac. *)
        fun steps_tac MAX strict lq lp =
          let
            val (empty, step) = ord_intros_max strict
          in
            if length lq = 0
            then rtac empty 1 THEN set_finite_tac 1
                 THEN (if strict then set_nonempty_tac 1 else all_tac)
            else
              let
                val (j, b) :: rest = lq
                val (i, a) = the (covering g strict j)
                fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1
                val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
              in
                rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
              end
          end
          (* MIN: dual to MAX; walk the caller list lp, covering each
             element by an element of lq. *)
          | steps_tac MIN strict lq lp =
          let
            val (empty, step) = ord_intros_min strict
          in
            if length lp = 0
            then rtac empty 1
                 THEN (if strict then set_nonempty_tac 1 else all_tac)
            else
              let
                val (i, a) :: rest = lp
                val (j, b) = the (covering g strict i)
                fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1
                val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
              in
                rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
              end
          end
          (* MS: split lq into
             - qs, the elements without a strict cover, each matched with
               its weak cover in ps;
             - the rest, handled by the MAX procedure after unfolding
               set_of_simps (when strict or when the rest is non-empty).
             mset_regroup_conv first reorders both multisets so that qs and
             ps are grouped. *)
          | steps_tac MS strict lq lp =
          let
            fun get_str_cover (j, b) =
              if is_some (covering g true j) then SOME (j, b) else NONE
            fun get_wk_cover (j, b) = the (covering g false j)

            val qs = lq \\ map_filter get_str_cover lq
            val ps = map get_wk_cover qs

            fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys
            val iqs = indices lq qs
            val ips = indices lp ps

            local open Conv in
            fun t_conv a C =
              params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
            val goal_rewrite =
                t_conv arg1_conv (mset_regroup_conv iqs)
                then_conv t_conv arg_conv (mset_regroup_conv ips)
            end
          in
            CONVERSION goal_rewrite 1
            THEN (if strict then rtac smsI' 1
                  else if qs = lq then rtac wmsI2'' 1
                  else rtac wmsI1 1)
            THEN mset_pwleq_tac 1
            THEN EVERY (map2 (less_proof false) qs ps)
            THEN (if strict orelse qs <> lq
                  then LocalDefs.unfold_tac ctxt set_of_simps
                       THEN steps_tac MAX true (lq \\ qs) (lp \\ ps)
                  else all_tac)
          end
      in
        rem_inv_img ctxt
        THEN steps_tac label strict (nth lev q) (nth lev p)
      end

    (* set construction and set type: multisets for MS, plain sets otherwise *)
    val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT)

    (* (measure i of p applied to the bound argument, tag) as a nat pair *)
    fun tag_pair p (i, tag) =
      HOLogic.pair_const natT natT $
        (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag

    (* level mapping of point p as a lambda term over its argument type *)
    fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p,
                           mk_set nat_pairT (map (tag_pair p) lm))

    (* all points' level mappings, combined with mk_sumcases *)
    val level_mapping =
      map_index pt_lev lev
        |> Termination.mk_sumcases D (setT nat_pairT)
        |> cterm_of thy
    in
      (* NOTE(review): unfold_tac below is the simpset-taking unfold_tac in
         scope here, not the LocalDefs.unfold_tac used on the next line. *)
      PROFILE "Proof Reconstruction"
        (CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv sl))) 1
         THEN (rtac @{thm reduction_pair_lemma} 1)
         THEN (rtac @{thm rp_inv_image_rp} 1)
         THEN (rtac (order_rpair ms_rp label) 1)
         THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
         THEN unfold_tac @{thms rp_inv_image_def} (local_simpset_of ctxt)
         THEN LocalDefs.unfold_tac ctxt
           (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv})
         THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}]))
         THEN EVERY (map (prove_lev true) sl)
         THEN EVERY (map (prove_lev false) ((0 upto length cs - 1) \\ sl)))
    end
   301 
   302 
   303 
   304 local open Termination in
   305 fun print_cell (SOME (Less _)) = "<"
   306   | print_cell (SOME (LessEq _)) = "\<le>"
   307   | print_cell (SOME (None _)) = "-"
   308   | print_cell (SOME (False _)) = "-"
   309   | print_cell (NONE) = "?"
   310 
   311 fun print_error ctxt D = CALLS (fn (cs, i) =>
   312   let
   313     val np = get_num_points D
   314     val ms = map (get_measures D) (0 upto np - 1)
   315     val tys = map (get_types D) (0 upto np - 1)
   316     fun index xs = (1 upto length xs) ~~ xs
   317     fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs
   318     val ims = index (map index ms)
   319     val _ = Output.tracing (concat (outp "fn #" ":\n" (concat o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims))
   320     fun print_call (k, c) =
   321       let
   322         val (_, p, _, q, _, _) = dest_call D c
   323         val _ = Output.tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ 
   324                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1))
   325         val caller_ms = nth ms p
   326         val callee_ms = nth ms q
   327         val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms)
   328         fun print_ln (i : int, l) = concat (Int.toString i :: "   " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l)
   329         val _ = Output.tracing (concat (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ 
   330                                         " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" 
   331                                 ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries)))
   332       in
   333         true
   334       end
   335     fun list_call (k, c) =
   336       let
   337         val (_, p, _, q, _, _) = dest_call D c
   338         val _ = Output.tracing ("call #" ^ (Int.toString k) ^ ": fn " ^
   339                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ 
   340                                 (Syntax.string_of_term ctxt c))
   341       in true end
   342     val _ = forall list_call ((1 upto length cs) ~~ cs)
   343     val _ = forall print_call ((1 upto length cs) ~~ cs)
   344   in
   345     all_tac
   346   end)
   347 end
   348 
   349 
   350 fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) =>
   351   let
   352     val ms_configured = is_some (MultisetSetup.get (ProofContext.theory_of ctxt))
   353     val orders' = if ms_configured then orders
   354                   else filter_out (curry op = MS) orders
   355     val gp = gen_probl D cs
   356 (*    val _ = TRACE ("SCNP instance: " ^ makestring gp)*)
   357     val certificate = generate_certificate use_tags orders' gp
   358 (*    val _ = TRACE ("Certificate: " ^ makestring certificate)*)
   359 
   360   in
   361     case certificate
   362      of NONE => err_cont D i
   363       | SOME cert =>
   364           SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i
   365           THEN (rtac @{thm wf_empty} i ORELSE cont D i)
   366   end)
   367 
(* Builds a continuation-passing termination strategy and runs it under
   TERMINATION.  Each stage takes a success continuation c and a failure
   continuation e; the whole strategy uses err_cont for both.  Rounds:
   1. Descent.derive_diag, then repeated SCNP attempts without tags.
   2. Repeated graph decomposition, trying SCNP without tags on the parts.
   3. Descent.derive_all, then repeated SCNP attempts with tags combined
      with decomposition. *)
fun gen_decomp_scnp_tac orders autom_tac ctxt err_cont =
  let
    (* REPEAT, TERMINATION etc. below are Termination's continuation-style
       combinators; REPEAT is applied to (c, e) arguments, unlike the
       generic tactical. *)
    open Termination
    val derive_diag = Descent.derive_diag ctxt autom_tac
    val derive_all = Descent.derive_all ctxt autom_tac
    val decompose = Decompose.decompose_tac ctxt autom_tac
    val scnp_no_tags = single_scnp_tac false orders ctxt
    val scnp_full = single_scnp_tac true orders ctxt

    fun first_round c e =
        derive_diag (REPEAT scnp_no_tags c e)

    val second_round =
        REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e)

    (* oo: derive_all applied to the result of the two-argument REPEAT *)
    val third_round =
        derive_all oo
        REPEAT (fn c => fn e =>
          scnp_full (decompose c c) e)

    (* Sequencing: s2 runs after s1 in both of s1's outcomes.  After a
       success of s1, a failure of s2 still continues with c; after a
       failure of s1, a failure of s2 goes to e. *)
    fun Then s1 s2 c e = s1 (s2 c c) (s2 c e)

    val strategy = Then (Then first_round second_round) third_round

  in
    TERMINATION ctxt (strategy err_cont err_cont)
  end
   395 
   396 fun gen_sizechange_tac orders autom_tac ctxt err_cont =
   397   TRY (FundefCommon.apply_termination_rule ctxt 1)
   398   THEN TRY (Termination.wf_union_tac ctxt)
   399   THEN
   400    (rtac @{thm wf_empty} 1
   401     ORELSE gen_decomp_scnp_tac orders autom_tac ctxt err_cont 1)
   402 
   403 fun sizechange_tac ctxt autom_tac =
   404   gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt (K (K all_tac))
   405 
   406 fun decomp_scnp orders ctxt =
   407   let
   408     val extra_simps = FundefCommon.TerminationSimps.get ctxt
   409     val autom_tac = auto_tac (local_clasimpset_of ctxt addsimps2 extra_simps)
   410   in
   411     SIMPLE_METHOD
   412       (gen_sizechange_tac orders autom_tac ctxt (print_error ctxt))
   413   end
   414 
   415 
   416 (* Method setup *)
   417 
   418 val orders =
   419   (Scan.repeat1
   420     ((Args.$$$ "max" >> K MAX) ||
   421      (Args.$$$ "min" >> K MIN) ||
   422      (Args.$$$ "ms" >> K MS))
   423   || Scan.succeed [MAX, MS, MIN])
   424 
   425 val setup = Method.add_method
   426   ("sizechange", Method.sectioned_args (Scan.lift orders) clasimp_modifiers decomp_scnp,
   427    "termination prover with graph decomposition and the NP subset of size change termination")
   428 
   429 end