src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
author smolkas
Wed Nov 28 12:25:43 2012 +0100 (2012-11-28 ago)
changeset 50278 05f8ec128e83
parent 50275 66cdf5c9b6c7
child 50557 31313171deb5
permissions -rw-r--r--
improved readability
     1 (*  Title:      HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Author:     Steffen Juilf Smolka, TU Muenchen
     4 
     5 Shrinking and preplaying of reconstructed isar proofs.
     6 *)
     7 
     8 signature SLEDGEHAMMER_SHRINK =
     9 sig
    10   type isar_step = Sledgehammer_Proof.isar_step
    11 	val shrink_proof : 
    12     bool -> Proof.context -> string -> string -> bool -> Time.time -> real
    13     -> isar_step list -> isar_step list * (bool * (bool * Time.time))
    14 end
    15 
    16 structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK =
    17 struct
    18 
    19 open Sledgehammer_Util
    20 open Sledgehammer_Proof
    21 
    22 (* Parameters *)
    23 val merge_timeout_slack = 1.2
    24 
    25 (* Data structures, orders *)
    26 val label_ord = prod_ord int_ord fast_string_ord o pairself swap
    27 structure Label_Table = Table(
    28   type key = label
    29   val ord = label_ord)
    30 
    31 (* clean vector interface *)
    32 fun get i v = Vector.sub (v, i)
    33 fun replace x i v = Vector.update (v, i, x)
    34 fun update f i v = replace (get i v |> f) i v
    35 fun v_map_index f v = Vector.foldr (op::) nil v |> map_index f |> Vector.fromList
    36 fun v_fold_index f v s =
    37   Vector.foldl (fn (x, (i, s)) => (i+1, f (i, x) s)) (0, s) v |> snd
    38 
    39 (* Queue interface to table *)
    40 fun pop tab key =
    41   let val v = hd (Inttab.lookup_list tab key) in
    42     (v, Inttab.remove_list (op =) (key, v) tab)
    43   end
    44 fun pop_max tab = pop tab (the (Inttab.max_key tab))
    45 fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab
    46 
    47 (* Timing *)
    48 fun ext_time_add (b1, t1) (b2, t2) = (b1 orelse b2, Time.+(t1,t2))
    49 val no_time = (false, seconds 0.0)
    50 fun take_time timeout tac arg =
    51   let val timing = Timing.start () in
    52     (TimeLimit.timeLimit timeout tac arg;
    53      Timing.result timing |> #cpu |> SOME)
    54     handle TimeLimit.TimeOut => NONE
    55   end
    56 
    57 
    58 (* Main function for shrinking proofs *)
    59 fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
    60                  isar_shrink proof =
    61 let
    62   (* handle metis preplay fail *)
    63   local
    64     open Unsynchronized
    65     val metis_fail = ref false
    66   in
    67     fun handle_metis_fail try_metis () =
    68       try_metis () handle _ => (metis_fail := true ; SOME (seconds 0.0))
    69     fun get_time lazy_time = 
    70       if !metis_fail then SOME (seconds 0.0) else Lazy.force lazy_time
    71     val metis_fail = fn () => !metis_fail
    72   end
    73   
    74   (* Shrink top level proof - do not shrink case splits *)
    75   fun shrink_top_level on_top_level ctxt proof =
    76   let
    77     (* proof vector *)
    78     val proof_vect = proof |> map SOME |> Vector.fromList
    79     val n = Vector.length proof_vect
    80     val n_metis = metis_steps_top_level proof
    81     val target_n_metis = Real.fromInt n_metis / isar_shrink |> Real.round
    82 
    83     (* table for mapping from (top-level-)label to proof position *)
    84     fun update_table (i, Assume (label, _)) = 
    85         Label_Table.update_new (label, i)
    86       | update_table (i, Prove (_, label, _, _)) =
    87         Label_Table.update_new (label, i)
    88       | update_table _ = I
    89     val label_index_table = fold_index update_table proof Label_Table.empty
    90 
    91     (* proof references *)
    92     fun refs (Prove (_, _, _, By_Metis (lfs, _))) =
    93         map_filter (Label_Table.lookup label_index_table) lfs
    94       | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
    95         map_filter (Label_Table.lookup label_index_table) lfs
    96           @ maps (maps refs) cases
    97       | refs _ = []
    98     val refed_by_vect =
    99       Vector.tabulate (n, (fn _ => []))
   100       |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
   101       |> Vector.map rev (* after rev, indices are sorted in ascending order *)
   102 
   103     (* candidates for elimination, use table as priority queue (greedy
   104        algorithm) *)
   105     fun add_if_cand proof_vect (i, [j]) =
   106         (case (the (get i proof_vect), the (get j proof_vect)) of
   107           (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
   108             cons (Term.size_of_term t, i)
   109         | _ => I)
   110       | add_if_cand _ _ = I
   111     val cand_tab = 
   112       v_fold_index (add_if_cand proof_vect) refed_by_vect []
   113       |> Inttab.make_list
   114 
   115     (* Metis Preplaying *)
   116     fun resolve_fact_names names =
   117       names
   118         |>> map string_for_label 
   119         |> op @
   120         |> maps (thms_of_name ctxt)
   121 
   122     fun try_metis timeout (succedent, Prove (_, _, t, byline)) =
   123       if not preplay then (fn () => SOME (seconds 0.0)) else
   124       (case byline of
   125         By_Metis fact_names =>
   126           let
   127             val facts = resolve_fact_names fact_names
   128             val goal =
   129               Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
   130             fun tac {context = ctxt, prems = _} =
   131               Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
   132           in
   133             take_time timeout (fn () => goal tac)
   134           end
   135       | Case_Split (cases, fact_names) =>
   136           let
   137             val facts = 
   138               resolve_fact_names fact_names
   139                 @ (case the succedent of
   140                     Prove (_, _, t, _) => 
   141                       Skip_Proof.make_thm (Proof_Context.theory_of ctxt) t
   142                   | Assume (_, t) =>
   143                       Skip_Proof.make_thm (Proof_Context.theory_of ctxt) t
   144                   | _ => error "Internal error: unexpected succedent of case split")
   145                 ::  map 
   146                       (hd #> (fn Assume (_, a) => Logic.mk_implies(a, t)
   147                                | _ => error "Internal error: malformed case split") 
   148                           #> Skip_Proof.make_thm (Proof_Context.theory_of ctxt)) 
   149                       cases
   150             val goal =
   151               Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
   152             fun tac {context = ctxt, prems = _} =
   153               Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
   154           in
   155             take_time timeout (fn () => goal tac)
   156           end)
   157       | try_metis _ _  = (fn () => SOME (seconds 0.0) )
   158 
   159     val try_metis_quietly = the_default NONE oo try oo try_metis
   160 
   161     (* cache metis preplay times in lazy time vector *)
   162     val metis_time =
   163       v_map_index
   164         (Lazy.lazy o handle_metis_fail o try_metis preplay_timeout 
   165           o apfst (fn i => try (the o get (i-1)) proof_vect) o apsnd the)
   166         proof_vect
   167     fun sum_up_time lazy_time_vector =
   168       Vector.foldl
   169         ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
   170            | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout))) 
   171           o apfst get_time)
   172         no_time lazy_time_vector
   173 
   174     (* Merging *)
   175     fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1)))
   176               (Prove (qs2, label2 , t, By_Metis (lfs2, gfs2))) =
   177       let
   178         val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
   179         val gfs = union (op =) gfs1 gfs2
   180       in Prove (qs2, label2, t, By_Metis (lfs, gfs)) end
   181       | merge _ _ = error "Internal error: Tring to merge unmergable isar steps"
   182 
   183     fun try_merge metis_time (s1, i) (s2, j) =
   184       (case get i metis_time |> Lazy.force of
   185         NONE => (NONE, metis_time)
   186       | SOME t1 =>
   187         (case get j metis_time |> Lazy.force of
   188           NONE => (NONE, metis_time)
   189         | SOME t2 =>
   190           let
   191             val s12 = merge s1 s2
   192             val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
   193           in
   194             case try_metis_quietly timeout (NONE, s12) () of
   195               NONE => (NONE, metis_time)
   196             | some_t12 =>
   197               (SOME s12, metis_time
   198                          |> replace (seconds 0.0 |> SOME |> Lazy.value) i
   199                          |> replace (Lazy.value some_t12) j)
   200 
   201           end))
   202 
   203     fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
   204       if Inttab.is_empty cand_tab 
   205         orelse n_metis' <= target_n_metis 
   206         orelse (on_top_level andalso n'<3)
   207       then
   208         (Vector.foldr
   209            (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
   210            [] proof_vect,
   211          sum_up_time metis_time)
   212       else
   213         let
   214           val (i, cand_tab) = pop_max cand_tab
   215           val j = get i refed_by |> the_single
   216           val s1 = get i proof_vect |> the
   217           val s2 = get j proof_vect |> the
   218         in
   219           case try_merge metis_time (s1, i) (s2, j) of
   220             (NONE, metis_time) =>
   221             merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
   222           | (s, metis_time) => 
   223           let
   224             val refs = refs s1
   225             val refed_by = refed_by |> fold
   226               (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
   227             val new_candidates =
   228               fold (add_if_cand proof_vect)
   229                 (map (fn i => (i, get i refed_by)) refs) []
   230             val cand_tab = add_list cand_tab new_candidates
   231             val proof_vect = proof_vect |> replace NONE i |> replace s j
   232           in
   233             merge_steps metis_time proof_vect refed_by cand_tab (n' - 1) (n_metis' - 1)
   234           end
   235         end
   236   in
   237     merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
   238   end
   239   
   240   fun shrink_proof' on_top_level ctxt proof = 
   241     let
   242       (* Enrich context with top-level facts *)
   243       val thy = Proof_Context.theory_of ctxt
   244       fun enrich_ctxt (Assume (label, t)) ctxt = 
   245           Proof_Context.put_thms false
   246             (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
   247         | enrich_ctxt (Prove (_, label, t, _)) ctxt =
   248           Proof_Context.put_thms false
   249             (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
   250         | enrich_ctxt _ ctxt = ctxt
   251       val rich_ctxt = fold enrich_ctxt proof ctxt
   252 
   253       (* Shrink case_splits and top-levl *)
   254       val ((proof, top_level_time), lower_level_time) = 
   255         proof |> shrink_case_splits rich_ctxt
   256               |>> shrink_top_level on_top_level rich_ctxt
   257     in
   258       (proof, ext_time_add lower_level_time top_level_time)
   259     end
   260 
   261   and shrink_case_splits ctxt proof =
   262     let
   263       fun shrink_each_and_collect_time shrink candidates =
   264         let fun f_m cand time = shrink cand ||> ext_time_add time
   265         in fold_map f_m candidates no_time end
   266       val shrink_case_split = shrink_each_and_collect_time (shrink_proof' false ctxt)
   267       fun shrink (Prove (qs, lbl, t, Case_Split (cases, facts))) =
   268           let val (cases, time) = shrink_case_split cases
   269           in (Prove (qs, lbl, t, Case_Split (cases, facts)), time) end
   270         | shrink step = (step, no_time)
   271     in 
   272       shrink_each_and_collect_time shrink proof
   273     end
   274 in
   275   shrink_proof' true ctxt proof
   276   |> apsnd (pair (metis_fail () ) )
   277 end
   278 
   279 end