src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
author blanchet
Wed Dec 12 21:48:29 2012 +0100 (2012-12-12 ago)
changeset 50510 7e4f2f8d9b50
parent 50278 05f8ec128e83
child 50557 31313171deb5
permissions -rw-r--r--
export a pair of ML functions
smolkas@50263
     1
(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
smolkas@50263
     2
    Author:     Jasmin Blanchette, TU Muenchen
smolkas@50263
     3
    Author:     Steffen Juilf Smolka, TU Muenchen
smolkas@50263
     4
smolkas@50265
     5
Shrinking and preplaying of reconstructed Isar proofs.
smolkas@50263
     6
*)
smolkas@50263
     7
smolkas@50259
     8
signature SLEDGEHAMMER_SHRINK =
sig
  type isar_step = Sledgehammer_Proof.isar_step

  (* shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                  isar_shrink proof

     debug, type_enc and lam_trans configure the metis calls used for
     preplaying; preplay = false skips preplaying (all times count as zero);
     preplay_timeout bounds each individual preplay; isar_shrink is the
     factor by which the number of top-level metis steps should be divided.

     Returns the shrunk proof paired with
       (metis_failed, (some_preplay_timed_out, total_preplay_time)). *)
  val shrink_proof :
    bool -> Proof.context -> string -> string -> bool -> Time.time -> real ->
    isar_step list -> isar_step list * (bool * (bool * Time.time))
end
smolkas@50259
    15
smolkas@50269
    16
structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK =
struct

open Sledgehammer_Util
open Sledgehammer_Proof

(* Parameters *)

(* Factor applied to the summed preplay times of two steps to obtain the
   timeout for preplaying the step resulting from merging them. *)
val merge_timeout_slack = 1.2

(* Data structures, orders *)

(* Compares labels by their int component first, then by their string
   component. *)
val label_ord = prod_ord int_ord fast_string_ord o pairself swap

(* Tables keyed by labels, ordered by label_ord. *)
structure Label_Table = Table(
  type key = label
  val ord = label_ord)

(* clean vector interface: curried wrappers around the Basis Vector
   operations, suitable for |> pipelines *)

(* element at index i; raises Subscript when i is out of range *)
fun get i v = Vector.sub (v, i)
(* copy of v whose element at index i is x *)
fun replace x i v = Vector.update (v, i, x)
(* copy of v whose element at index i is f applied to the old element *)
fun update f i v = replace (get i v |> f) i v
(* maps f over (index, element) pairs, left to right *)
fun v_map_index f v = Vector.mapi f v
(* folds f over (index, element) pairs, left to right, starting from s *)
fun v_fold_index f v s =
  Vector.foldli (fn (i, x, acc) => f (i, x) acc) s v

(* Queue interface to table *)

(* Removes and returns the first value stored under key; raises Empty if
   the key holds no values. *)
fun pop tab key =
  let val v = hd (Inttab.lookup_list tab key) in
    (v, Inttab.remove_list (op =) (key, v) tab)
  end
(* pop under the largest key; raises Option on an empty table *)
fun pop_max tab = pop tab (the (Inttab.max_key tab))
(* adds every (key, value) pair of xs via Inttab.insert_list (op =) *)
fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab

(* Timing *)

(* Combines two (flag, time) pairs: flags are or-ed, times are added. *)
fun ext_time_add (b1, t1) (b2, t2) = (b1 orelse b2, Time.+(t1,t2))
(* neutral element of ext_time_add *)
val no_time = (false, seconds 0.0)
(* Runs tac arg under the given time limit, discarding its result.
   Returns SOME of the CPU time used, or NONE if the limit was hit. *)
fun take_time timeout tac arg =
  let val timing = Timing.start () in
    (TimeLimit.timeLimit timeout tac arg;
     Timing.result timing |> #cpu |> SOME)
    handle TimeLimit.TimeOut => NONE
  end

(* Main function for shrinking proofs.

   Greedily merges top-level metis steps that are referenced by exactly one
   other metis step into that step (largest statement first), keeping a
   merge only if the merged step can be preplayed within
   merge_timeout_slack times the summed preplay times of the two original
   steps (always accepted when preplay is false). Stops once the number of
   top-level metis steps drops to (metis steps / isar_shrink). Case splits
   are shrunk recursively.

   Returns the new proof paired with
     (metis_failed, (some_preplay_timed_out, total_preplay_time)). *)
fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                 isar_shrink proof =
let
  (* handle metis preplay fail: records whether preplaying an original step
     raised an exception; once that happened, get_time reports every
     preplay time as zero *)
  local
    open Unsynchronized
    val metis_fail = ref false
  in
    fun handle_metis_fail try_metis () =
      try_metis () handle exn =>
        (* interrupts signal cancellation of the enclosing computation and
           must never be swallowed *)
        (if Exn.is_interrupt exn then reraise exn else ();
         metis_fail := true; SOME (seconds 0.0))
    fun get_time lazy_time =
      if !metis_fail then SOME (seconds 0.0) else Lazy.force lazy_time
    val metis_fail = fn () => !metis_fail
  end

  (* Shrink top level proof - do not shrink case splits *)
  fun shrink_top_level on_top_level ctxt proof =
  let
    (* proof vector; eliminated steps become NONE *)
    val proof_vect = proof |> map SOME |> Vector.fromList
    val n = Vector.length proof_vect
    val n_metis = metis_steps_top_level proof
    val target_n_metis = Real.fromInt n_metis / isar_shrink |> Real.round

    (* table for mapping from (top-level-)label to proof position *)
    fun update_table (i, Assume (label, _)) =
        Label_Table.update_new (label, i)
      | update_table (i, Prove (_, label, _, _)) =
        Label_Table.update_new (label, i)
      | update_table _ = I
    val label_index_table = fold_index update_table proof Label_Table.empty

    (* proof references: positions of the top-level steps a step uses,
       including those used inside its case splits *)
    fun refs (Prove (_, _, _, By_Metis (lfs, _))) =
        map_filter (Label_Table.lookup label_index_table) lfs
      | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
        map_filter (Label_Table.lookup label_index_table) lfs
          @ maps (maps refs) cases
      | refs _ = []
    (* refed_by_vect at i: positions of the steps referencing step i *)
    val refed_by_vect =
      Vector.tabulate (n, (fn _ => []))
      |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
      |> Vector.map rev (* after rev, indices are sorted in ascending order *)

    (* candidates for elimination, use table as priority queue (greedy
       algorithm): a metis step referenced by exactly one metis step,
       keyed by the size of its statement *)
    fun add_if_cand proof_vect (i, [j]) =
        (case (the (get i proof_vect), the (get j proof_vect)) of
          (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
            cons (Term.size_of_term t, i)
        | _ => I)
      | add_if_cand _ _ = I
    val cand_tab =
      v_fold_index (add_if_cand proof_vect) refed_by_vect []
      |> Inttab.make_list

    (* Metis Preplaying *)

    (* theorems named by a (local labels, global names) pair *)
    fun resolve_fact_names names =
      names
        |>> map string_for_label
        |> op @
        |> maps (thms_of_name ctxt)

    (* Extra facts for preplaying a case split with goal t: the statement of
       the step passed as succedent (the step preceding the case split) and,
       for each case, the implication from its leading assumption to t. *)
    fun case_split_facts succedent t cases =
      let val thy = Proof_Context.theory_of ctxt in
        (case the succedent of
          Prove (_, _, t', _) => Skip_Proof.make_thm thy t'
        | Assume (_, t') => Skip_Proof.make_thm thy t'
        | _ => error "Internal error: unexpected succedent of case split")
        :: map
             (hd #> (fn Assume (_, a) => Logic.mk_implies(a, t)
                      | _ => error "Internal error: malformed case split")
                 #> Skip_Proof.make_thm thy)
             cases
      end

    (* Returns a thunk that preplays the given step with metis and yields
       SOME of the time taken, or NONE on timeout. The facts are resolved
       eagerly; only the metis call itself is deferred. *)
    fun try_metis timeout (succedent, Prove (_, _, t, byline)) =
        if not preplay then (fn () => SOME (seconds 0.0)) else
        let
          val facts =
            (case byline of
              By_Metis fact_names => resolve_fact_names fact_names
            | Case_Split (cases, fact_names) =>
                resolve_fact_names fact_names
                @ case_split_facts succedent t cases)
          val goal =
            Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
          fun tac {context = ctxt, prems = _} =
            Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
        in
          take_time timeout (fn () => goal tac)
        end
      | try_metis _ _ = (fn () => SOME (seconds 0.0))

    (* like try_metis, but any failure of the preplay yields NONE *)
    val try_metis_quietly = the_default NONE oo try oo try_metis

    (* cache metis preplay times in lazy time vector; each step is paired
       with its predecessor (NONE for the first step) *)
    val metis_time =
      v_map_index
        (Lazy.lazy o handle_metis_fail o try_metis preplay_timeout
          o apfst (fn i => try (the o get (i-1)) proof_vect) o apsnd the)
        proof_vect
    (* total preplay time; a timed-out step counts as preplay_timeout and
       sets the flag *)
    fun sum_up_time lazy_time_vector =
      Vector.foldl
        ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
           | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout)))
          o apfst get_time)
        no_time lazy_time_vector

    (* Merging *)

    (* merges step s1 into step s2, which references s1: s2's reference to
       s1 is replaced by s1's own facts *)
    fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1)))
              (Prove (qs2, label2 , t, By_Metis (lfs2, gfs2))) =
      let
        val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
        val gfs = union (op =) gfs1 gfs2
      in Prove (qs2, label2, t, By_Metis (lfs, gfs)) end
      | merge _ _ = error "Internal error: Tring to merge unmergable isar steps"

    (* Tries to merge step s1 (position i) into step s2 (position j). On
       success returns the merged step and the time vector with i set to
       zero and j set to the merged step's time; otherwise NONE and the
       unchanged time vector. *)
    fun try_merge metis_time (s1, i) (s2, j) =
      (case get i metis_time |> Lazy.force of
        NONE => (NONE, metis_time)
      | SOME t1 =>
        (case get j metis_time |> Lazy.force of
          NONE => (NONE, metis_time)
        | SOME t2 =>
          let
            val s12 = merge s1 s2
            val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
          in
            case try_metis_quietly timeout (NONE, s12) () of
              NONE => (NONE, metis_time)
            | some_t12 =>
              (SOME s12, metis_time
                         |> replace (seconds 0.0 |> SOME |> Lazy.value) i
                         |> replace (Lazy.value some_t12) j)
          end))

    (* Greedy loop: pops the candidate with the largest statement and tries
       to merge it into its unique referencing step. Stops when no
       candidates are left, the target number of metis steps is reached, or
       (on the top level) fewer than 3 steps remain. *)
    fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
      if Inttab.is_empty cand_tab
        orelse n_metis' <= target_n_metis
        orelse (on_top_level andalso n'<3)
      then
        (Vector.foldr
           (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
           [] proof_vect,
         sum_up_time metis_time)
      else
        let
          val (i, cand_tab) = pop_max cand_tab
          val j = get i refed_by |> the_single
          val s1 = get i proof_vect |> the
          val s2 = get j proof_vect |> the
        in
          case try_merge metis_time (s1, i) (s2, j) of
            (NONE, metis_time) =>
            merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
          | (s, metis_time) =>
          let
            (* the steps s1 referenced are now referenced by j instead of i *)
            val refs = refs s1
            val refed_by = refed_by |> fold
              (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
            val new_candidates =
              fold (add_if_cand proof_vect)
                (map (fn i => (i, get i refed_by)) refs) []
            val cand_tab = add_list cand_tab new_candidates
            val proof_vect = proof_vect |> replace NONE i |> replace s j
          in
            merge_steps metis_time proof_vect refed_by cand_tab (n' - 1) (n_metis' - 1)
          end
        end
  in
    merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
  end

  (* Shrinks the case splits of a proof level (recursively), then the level
     itself, in a context extended with the statements of the level's
     steps (built with Skip_Proof.make_thm). *)
  fun shrink_proof' on_top_level ctxt proof =
    let
      (* Enrich context with top-level facts *)
      val thy = Proof_Context.theory_of ctxt
      fun enrich_ctxt (Assume (label, t)) ctxt =
          Proof_Context.put_thms false
            (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
        | enrich_ctxt (Prove (_, label, t, _)) ctxt =
          Proof_Context.put_thms false
            (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
        | enrich_ctxt _ ctxt = ctxt
      val rich_ctxt = fold enrich_ctxt proof ctxt

      (* Shrink case_splits and top-level *)
      val ((proof, top_level_time), lower_level_time) =
        proof |> shrink_case_splits rich_ctxt
              |>> shrink_top_level on_top_level rich_ctxt
    in
      (proof, ext_time_add lower_level_time top_level_time)
    end

  (* Shrinks every case of every case split on this level, accumulating
     their preplay times. *)
  and shrink_case_splits ctxt proof =
    let
      fun shrink_each_and_collect_time shrink candidates =
        let fun f_m cand time = shrink cand ||> ext_time_add time
        in fold_map f_m candidates no_time end
      val shrink_case_split = shrink_each_and_collect_time (shrink_proof' false ctxt)
      fun shrink (Prove (qs, lbl, t, Case_Split (cases, facts))) =
          let val (cases, time) = shrink_case_split cases
          in (Prove (qs, lbl, t, Case_Split (cases, facts)), time) end
        | shrink step = (step, no_time)
    in
      shrink_each_and_collect_time shrink proof
    end
in
  (* metis_fail () is evaluated only after shrinking (and thus preplaying)
     has completed *)
  shrink_proof' true ctxt proof
  |> apsnd (pair (metis_fail ()))
end

end