src/HOL/Tools/function_package/scnp_reconstruct.ML
changeset 29125 d41182a8135c
child 29183 f1648e009dc1
equal deleted inserted replaced
29117:5a79ec2fedfb 29125:d41182a8135c
       
     1 (*  Title:       HOL/Tools/function_package/scnp_reconstruct.ML
       
     2     Author:      Armin Heller, TU Muenchen
       
     3     Author:      Alexander Krauss, TU Muenchen
       
     4 
       
     5 Proof reconstruction for SCNP
       
     6 *)
       
     7 
       
     8 signature SCNP_RECONSTRUCT =
       
     9 sig
       
    10 
       
    11   val decomp_scnp : ScnpSolve.label list -> Proof.context -> method
       
    12 
       
    13   val setup : theory -> theory
       
    14 
       
    15   datatype multiset_setup =
       
    16     Multiset of
       
    17     {
       
    18      msetT : typ -> typ,
       
    19      mk_mset : typ -> term list -> term,
       
    20      mset_regroup_conv : int list -> conv,
       
    21      mset_member_tac : int -> int -> tactic,
       
    22      mset_nonempty_tac : int -> tactic,
       
    23      mset_pwleq_tac : int -> tactic,
       
    24      set_of_simps : thm list,
       
    25      smsI' : thm,
       
    26      wmsI2'' : thm,
       
    27      wmsI1 : thm,
       
    28      reduction_pair : thm
       
    29     }
       
    30 
       
    31 
       
    32   val multiset_setup : multiset_setup -> theory -> theory
       
    33 
       
    34 end
       
    35 
       
    36 structure ScnpReconstruct : SCNP_RECONSTRUCT =
       
    37 struct
       
    38 
       
    39 val PROFILE = FundefCommon.PROFILE
       
    40 fun TRACE x = if ! FundefCommon.profile then Output.tracing x else ()
       
    41 
       
    42 open ScnpSolve
       
    43 
       
    44 val natT = HOLogic.natT
       
    45 val nat_pairT = HOLogic.mk_prodT (natT, natT)
       
    46 
       
    47 (* Theory dependencies *)
       
    48 
       
    49 datatype multiset_setup =
       
    50   Multiset of
       
    51   {
       
    52    msetT : typ -> typ,
       
    53    mk_mset : typ -> term list -> term,
       
    54    mset_regroup_conv : int list -> conv,
       
    55    mset_member_tac : int -> int -> tactic,
       
    56    mset_nonempty_tac : int -> tactic,
       
    57    mset_pwleq_tac : int -> tactic,
       
    58    set_of_simps : thm list,
       
    59    smsI' : thm,
       
    60    wmsI2'' : thm,
       
    61    wmsI1 : thm,
       
    62    reduction_pair : thm
       
    63   }
       
    64 
       
    65 structure MultisetSetup = TheoryDataFun
       
    66 (
       
    67   type T = multiset_setup option
       
    68   val empty = NONE
       
    69   val copy = I;
       
    70   val extend = I;
       
    71   fun merge _ (v1, v2) = if is_some v2 then v2 else v1
       
    72 )
       
    73 
       
    74 val multiset_setup = MultisetSetup.put o SOME
       
    75 
       
    76 fun undef x = error "undef"
       
    77 fun get_multiset_setup thy = MultisetSetup.get thy
       
    78   |> the_default (Multiset
       
    79 { msetT = undef, mk_mset=undef,
       
    80   mset_regroup_conv=undef, mset_member_tac = undef,
       
    81   mset_nonempty_tac = undef, mset_pwleq_tac = undef,
       
    82   set_of_simps = [],reduction_pair = refl,
       
    83   smsI'=refl, wmsI2''=refl, wmsI1=refl })
       
    84 
       
    85 fun order_rpair _ MAX = @{thm max_rpair_set}
       
    86   | order_rpair msrp MS  = msrp
       
    87   | order_rpair _ MIN = @{thm min_rpair_set}
       
    88 
       
    89 fun ord_intros_max true =
       
    90     (@{thm smax_emptyI}, @{thm smax_insertI})
       
    91   | ord_intros_max false =
       
    92     (@{thm wmax_emptyI}, @{thm wmax_insertI})
       
    93 fun ord_intros_min true =
       
    94     (@{thm smin_emptyI}, @{thm smin_insertI})
       
    95   | ord_intros_min false =
       
    96     (@{thm wmin_emptyI}, @{thm wmin_insertI})
       
    97 
       
    98 fun gen_probl D cs =
       
    99   let
       
   100     val n = Termination.get_num_points D
       
   101     val arity = length o Termination.get_measures D
       
   102     fun measure p i = nth (Termination.get_measures D p) i
       
   103 
       
   104     fun mk_graph c =
       
   105       let
       
   106         val (_, p, _, q, _, _) = Termination.dest_call D c
       
   107 
       
   108         fun add_edge i j =
       
   109           case Termination.get_descent D c (measure p i) (measure q j)
       
   110            of SOME (Termination.Less _) => cons (i, GTR, j)
       
   111             | SOME (Termination.LessEq _) => cons (i, GEQ, j)
       
   112             | _ => I
       
   113 
       
   114         val edges =
       
   115           fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) []
       
   116       in
       
   117         G (p, q, edges)
       
   118       end
       
   119   in
       
   120     GP (map arity (0 upto n - 1), map mk_graph cs)
       
   121   end
       
   122 
       
   123 (* General reduction pair application *)
       
   124 fun rem_inv_img ctxt =
       
   125   let
       
   126     val unfold_tac = LocalDefs.unfold_tac ctxt
       
   127   in
       
   128     rtac @{thm subsetI} 1
       
   129     THEN etac @{thm CollectE} 1
       
   130     THEN REPEAT (etac @{thm exE} 1)
       
   131     THEN unfold_tac @{thms inv_image_def}
       
   132     THEN rtac @{thm CollectI} 1
       
   133     THEN etac @{thm conjE} 1
       
   134     THEN etac @{thm ssubst} 1
       
   135     THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality}
       
   136                      @ @{thms Sum_Type.sum_cases})
       
   137   end
       
   138 
       
   139 (* Sets *)
       
   140 
       
   141 val setT = HOLogic.mk_setT
       
   142 
       
   143 fun mk_set T [] = Const (@{const_name "{}"}, setT T)
       
   144   | mk_set T (x :: xs) =
       
   145       Const (@{const_name insert}, T --> setT T --> setT T) $
       
   146             x $ mk_set T xs
       
   147 
       
   148 fun set_member_tac m i =
       
   149   if m = 0 then rtac @{thm insertI1} i
       
   150   else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i
       
   151 
       
   152 val set_nonempty_tac = rtac @{thm insert_not_empty}
       
   153 
       
   154 fun set_finite_tac i =
       
   155   rtac @{thm finite.emptyI} i
       
   156   ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st))
       
   157 
       
   158 
       
   159 (* Reconstruction *)
       
   160 
       
   161 fun reconstruct_tac ctxt D cs (gp as GP (_, gs)) certificate =
       
   162   let
       
   163     val thy = ProofContext.theory_of ctxt
       
   164     val Multiset
       
   165           { msetT, mk_mset,
       
   166             mset_regroup_conv, mset_member_tac,
       
   167             mset_nonempty_tac, mset_pwleq_tac, set_of_simps,
       
   168             smsI', wmsI2'', wmsI1, reduction_pair=ms_rp } 
       
   169         = get_multiset_setup thy
       
   170 
       
   171     fun measure_fn p = nth (Termination.get_measures D p)
       
   172 
       
   173     fun get_desc_thm cidx m1 m2 bStrict =
       
   174       case Termination.get_descent D (nth cs cidx) m1 m2
       
   175        of SOME (Termination.Less thm) =>
       
   176           if bStrict then thm
       
   177           else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le}))
       
   178         | SOME (Termination.LessEq (thm, _))  =>
       
   179           if not bStrict then thm
       
   180           else sys_error "get_desc_thm"
       
   181         | _ => sys_error "get_desc_thm"
       
   182 
       
   183     val (label, lev, sl, covering) = certificate
       
   184 
       
   185     fun prove_lev strict g =
       
   186       let
       
   187         val G (p, q, el) = nth gs g
       
   188 
       
   189         fun less_proof strict (j, b) (i, a) =
       
   190           let
       
   191             val tag_flag = b < a orelse (not strict andalso b <= a)
       
   192 
       
   193             val stored_thm =
       
   194               get_desc_thm g (measure_fn p i) (measure_fn q j)
       
   195                              (not tag_flag)
       
   196               |> Conv.fconv_rule (Thm.beta_conversion true)
       
   197 
       
   198             val rule = if strict
       
   199               then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1}
       
   200               else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1}
       
   201           in
       
   202             rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm)
       
   203             THEN (if tag_flag then arith_tac ctxt 1 else all_tac)
       
   204           end
       
   205 
       
   206         fun steps_tac MAX strict lq lp =
       
   207           let
       
   208             val (empty, step) = ord_intros_max strict
       
   209           in
       
   210             if length lq = 0
       
   211             then rtac empty 1 THEN set_finite_tac 1
       
   212                  THEN (if strict then set_nonempty_tac 1 else all_tac)
       
   213             else
       
   214               let
       
   215                 val (j, b) :: rest = lq
       
   216                 val (i, a) = the (covering g strict j)
       
   217                 fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1
       
   218                 val solve_tac = choose lp THEN less_proof strict (j, b) (i, a)
       
   219               in
       
   220                 rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp
       
   221               end
       
   222           end
       
   223           | steps_tac MIN strict lq lp =
       
   224           let
       
   225             val (empty, step) = ord_intros_min strict
       
   226           in
       
   227             if length lp = 0
       
   228             then rtac empty 1
       
   229                  THEN (if strict then set_nonempty_tac 1 else all_tac)
       
   230             else
       
   231               let
       
   232                 val (i, a) :: rest = lp
       
   233                 val (j, b) = the (covering g strict i)
       
   234                 fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1
       
   235                 val solve_tac = choose lq THEN less_proof strict (j, b) (i, a)
       
   236               in
       
   237                 rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest
       
   238               end
       
   239           end
       
   240           | steps_tac MS strict lq lp =
       
   241           let
       
   242             fun get_str_cover (j, b) =
       
   243               if is_some (covering g true j) then SOME (j, b) else NONE
       
   244             fun get_wk_cover (j, b) = the (covering g false j)
       
   245 
       
   246             val qs = lq \\ map_filter get_str_cover lq
       
   247             val ps = map get_wk_cover qs
       
   248 
       
   249             fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys
       
   250             val iqs = indices lq qs
       
   251             val ips = indices lp ps
       
   252 
       
   253             local open Conv in
       
   254             fun t_conv a C =
       
   255               params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt
       
   256             val goal_rewrite =
       
   257                 t_conv arg1_conv (mset_regroup_conv iqs)
       
   258                 then_conv t_conv arg_conv (mset_regroup_conv ips)
       
   259             end
       
   260           in
       
   261             CONVERSION goal_rewrite 1
       
   262             THEN (if strict then rtac smsI' 1
       
   263                   else if qs = lq then rtac wmsI2'' 1
       
   264                   else rtac wmsI1 1)
       
   265             THEN mset_pwleq_tac 1
       
   266             THEN EVERY (map2 (less_proof false) qs ps)
       
   267             THEN (if strict orelse qs <> lq
       
   268                   then LocalDefs.unfold_tac ctxt set_of_simps
       
   269                        THEN steps_tac MAX true (lq \\ qs) (lp \\ ps)
       
   270                   else all_tac)
       
   271           end
       
   272       in
       
   273         rem_inv_img ctxt
       
   274         THEN steps_tac label strict (nth lev q) (nth lev p)
       
   275       end
       
   276 
       
   277     val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (mk_set, setT)
       
   278 
       
   279     fun tag_pair p (i, tag) =
       
   280       HOLogic.pair_const natT natT $
       
   281         (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag
       
   282 
       
   283     fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p,
       
   284                            mk_set nat_pairT (map (tag_pair p) lm))
       
   285 
       
   286     val level_mapping =
       
   287       map_index pt_lev lev
       
   288         |> Termination.mk_sumcases D (setT nat_pairT)
       
   289         |> cterm_of thy
       
   290     in
       
   291       PROFILE "Proof Reconstruction"
       
   292         (CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv sl))) 1
       
   293          THEN (rtac @{thm reduction_pair_lemma} 1)
       
   294          THEN (rtac @{thm rp_inv_image_rp} 1)
       
   295          THEN (rtac (order_rpair ms_rp label) 1)
       
   296          THEN PRIMITIVE (instantiate' [] [SOME level_mapping])
       
   297          THEN unfold_tac @{thms rp_inv_image_def} (simpset_of thy)
       
   298          THEN LocalDefs.unfold_tac ctxt
       
   299            (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv})
       
   300          THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}]))
       
   301          THEN EVERY (map (prove_lev true) sl)
       
   302          THEN EVERY (map (prove_lev false) ((0 upto length cs - 1) \\ sl)))
       
   303     end
       
   304 
       
   305 
       
   306 
       
   307 local open Termination in
       
   308 fun print_cell (SOME (Less _)) = "<"
       
   309   | print_cell (SOME (LessEq _)) = "\<le>"
       
   310   | print_cell (SOME (None _)) = "-"
       
   311   | print_cell (SOME (False _)) = "-"
       
   312   | print_cell (NONE) = "?"
       
   313 
       
   314 fun print_error ctxt D = CALLS (fn (cs, i) =>
       
   315   let
       
   316     val np = get_num_points D
       
   317     val ms = map (get_measures D) (0 upto np - 1)
       
   318     val tys = map (get_types D) (0 upto np - 1)
       
   319     fun index xs = (1 upto length xs) ~~ xs
       
   320     fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs
       
   321     val ims = index (map index ms)
       
   322     val _ = Output.tracing (concat (outp "fn #" ":\n" (concat o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims))
       
   323     fun print_call (k, c) =
       
   324       let
       
   325         val (_, p, _, q, _, _) = dest_call D c
       
   326         val _ = Output.tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ 
       
   327                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1))
       
   328         val caller_ms = nth ms p
       
   329         val callee_ms = nth ms q
       
   330         val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms)
       
   331         fun print_ln (i : int, l) = concat (Int.toString i :: "   " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l)
       
   332         val _ = Output.tracing (concat (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ 
       
   333                                         " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" 
       
   334                                 ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries)))
       
   335       in
       
   336         true
       
   337       end
       
   338     fun list_call (k, c) =
       
   339       let
       
   340         val (_, p, _, q, _, _) = dest_call D c
       
   341         val _ = Output.tracing ("call #" ^ (Int.toString k) ^ ": fn " ^
       
   342                                 Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ 
       
   343                                 (Syntax.string_of_term ctxt c))
       
   344       in true end
       
   345     val _ = forall list_call ((1 upto length cs) ~~ cs)
       
   346     val _ = forall print_call ((1 upto length cs) ~~ cs)
       
   347   in
       
   348     all_tac
       
   349   end)
       
   350 end
       
   351 
       
   352 
       
   353 fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) =>
       
   354   let
       
   355     val gp = gen_probl D cs
       
   356 (*    val _ = TRACE ("SCNP instance: " ^ makestring gp)*)
       
   357     val certificate = generate_certificate use_tags orders gp
       
   358 (*    val _ = TRACE ("Certificate: " ^ makestring certificate)*)
       
   359 
       
   360     val ms_configured = is_some (MultisetSetup.get (ProofContext.theory_of ctxt))
       
   361     in
       
   362     case certificate
       
   363      of NONE => err_cont D i
       
   364       | SOME cert =>
       
   365         if not ms_configured andalso #1 cert = MS
       
   366         then err_cont D i
       
   367         else SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i
       
   368              THEN (rtac @{thm wf_empty} i ORELSE cont D i)
       
   369   end)
       
   370 
       
   371 fun decomp_scnp_tac orders autom_tac ctxt err_cont =
       
   372   let
       
   373     open Termination
       
   374     val derive_diag = Descent.derive_diag ctxt autom_tac
       
   375     val derive_all = Descent.derive_all ctxt autom_tac
       
   376     val decompose = Decompose.decompose_tac ctxt autom_tac
       
   377     val scnp_no_tags = single_scnp_tac false orders ctxt
       
   378     val scnp_full = single_scnp_tac true orders ctxt
       
   379 
       
   380     fun first_round c e =
       
   381         derive_diag (REPEAT scnp_no_tags c e)
       
   382 
       
   383     val second_round =
       
   384         REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e)
       
   385 
       
   386     val third_round =
       
   387         derive_all oo
       
   388         REPEAT (fn c => fn e =>
       
   389           scnp_full (decompose c c) e)
       
   390 
       
   391     fun Then s1 s2 c e = s1 (s2 c c) (s2 c e)
       
   392 
       
   393     val strategy = Then (Then first_round second_round) third_round
       
   394 
       
   395   in
       
   396     TERMINATION ctxt (strategy err_cont err_cont)
       
   397   end
       
   398 
       
   399 fun decomp_scnp orders ctxt =
       
   400   let
       
   401     val extra_simps = FundefCommon.TerminationSimps.get ctxt
       
   402     val autom_tac = auto_tac (local_clasimpset_of ctxt addsimps2 extra_simps)
       
   403   in
       
   404     Method.SIMPLE_METHOD
       
   405       (TRY (FundefCommon.apply_termination_rule ctxt 1)
       
   406        THEN TRY Termination.wf_union_tac
       
   407        THEN
       
   408          (rtac @{thm wf_empty} 1
       
   409           ORELSE decomp_scnp_tac orders autom_tac ctxt (print_error ctxt) 1))
       
   410   end
       
   411 
       
   412 
       
   413 (* Method setup *)
       
   414 
       
   415 val orders =
       
   416   (Scan.repeat1
       
   417     ((Args.$$$ "max" >> K MAX) ||
       
   418      (Args.$$$ "min" >> K MIN) ||
       
   419      (Args.$$$ "ms" >> K MS))
       
   420   || Scan.succeed [MAX, MS, MIN])
       
   421 
       
   422 val setup = Method.add_method
       
   423   ("sizechange", Method.sectioned_args (Scan.lift orders) clasimp_modifiers decomp_scnp,
       
   424    "termination prover with graph decomposition and the NP subset of size change termination")
       
   425 
       
   426 end