src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
author smolkas
Wed, 09 Jan 2013 14:35:46 +0100
changeset 50779 6f571f6797bd
parent 50711 eb67eec63a8b
child 50780 4174abe2c5fd
permissions -rw-r--r--
preplay obtain steps

(*  Title:      HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
    Author:     Jasmin Blanchette, TU Muenchen
    Author:     Steffen Juilf Smolka, TU Muenchen

Shrinking and preplaying of reconstructed Isar proofs.
*)

signature SLEDGEHAMMER_SHRINK =
sig
  type isar_step = Sledgehammer_Proof.isar_step

  (* "shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
     isar_shrink proof" merges steps of the reconstructed Isar proof "proof",
     aiming to reduce its number of "metis" steps by the factor "isar_shrink".
       debug            re-raise preplay exceptions; make Metis verbose
       type_enc,
       lam_trans        passed on to Metis when preplaying
       preplay          whether to preplay steps with Metis at all
       preplay_timeout  per-step preplay timeout (NONE means 60 seconds)
     Result: the shrunk proof together with
       (metis_failed, (some_preplay_timed_out, total_preplay_time)) *)
  val shrink_proof
    : bool -> Proof.context -> string -> string -> bool
      -> Time.time option -> real
      -> isar_step list
      -> isar_step list * (bool * (bool * Time.time))
end

structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK =
struct

open Sledgehammer_Util
open Sledgehammer_Proof

(* Parameters *)

(* Factor applied to the sum of the preplay times of two steps to obtain the
   timeout for preplaying their merge *)
val merge_timeout_slack = 1.2

(* Data structures, orders *)

(* Orders labels by their integer component first, then by their string
   component *)
val label_ord = prod_ord int_ord fast_string_ord o pairself swap
structure Label_Table = Table(
  type key = label
  val ord = label_ord)

(* clean vector interface (curried, vector last, so that the functions
   compose with "|>") *)
fun get i v = Vector.sub (v, i)
fun replace x i v = Vector.update (v, i, x)
fun update f i v = replace (get i v |> f) i v
fun v_map_index f v = Vector.foldr (op::) nil v |> map_index f |> Vector.fromList
fun v_fold_index f v s =
  Vector.foldl (fn (x, (i, s)) => (i+1, f (i, x) s)) (0, s) v |> snd

(* Queue interface to table: an Inttab mapping keys to lists of values is used
   as a priority queue *)

(* removes and returns the first value stored under "key" (which must be
   present in "tab") *)
fun pop tab key =
  let val v = hd (Inttab.lookup_list tab key) in
    (v, Inttab.remove_list (op =) (key, v) tab)
  end
(* removes and returns a value stored under the largest key *)
fun pop_max tab = pop tab (the (Inttab.max_key tab))
(* inserts each (key, value) pair of "xs" into "tab" *)
fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab

(* Timing *)

(* Extended times pair a time with a flag; the flag is set when some preplay
   timed out (see "sum_up_time" below) *)
fun ext_time_add (b1, t1) (b2, t2) = (b1 orelse b2, Time.+(t1,t2))
val no_time = (false, Time.zeroTime)
(* Runs "tac arg" under "timeout" and returns SOME of the CPU time it took, or
   NONE if the timeout was exceeded *)
fun take_time timeout tac arg =
  let val timing = Timing.start () in
    (TimeLimit.timeLimit timeout tac arg;
     Timing.result timing |> #cpu |> SOME)
    handle TimeLimit.TimeOut => NONE
  end


(* Main function for shrinking proofs

   "shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
   isar_shrink proof" greedily merges "metis" steps that are referenced by
   exactly one other "metis" step into that step, separately on each proof
   level (the top level and inside case splits). On each level it aims for
   "n_metis / isar_shrink" metis steps. When "preplay" is set, a merge is only
   kept if the merged step can be preplayed by Metis within
   "merge_timeout_slack" times the preplay times of the original steps.
   Returns the shrunk proof together with
   (metis_failed, (some_preplay_timed_out, total_preplay_time)). *)
fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                 isar_shrink proof =
  let
    (* 60 seconds seems like a good interpretation of "no timeout" *)
    val preplay_timeout = preplay_timeout |> the_default (seconds 60.0)

    (* handle metis preplay fail: in debug mode the exception is re-raised;
       otherwise the failure is recorded and the step counts as taking zero
       time. Interrupts are always propagated so that the user can still
       cancel the computation. *)
    local
      open Unsynchronized
      val metis_fail = ref false
    in
      fun handle_metis_fail try_metis () =
        try_metis () handle exp =>
          if debug orelse Exn.is_interrupt exp then raise exp
          else (metis_fail := true; SOME Time.zeroTime)
      (* once some preplay has failed, every time is reported as zero *)
      fun get_time lazy_time =
        if !metis_fail then SOME Time.zeroTime else Lazy.force lazy_time
      val metis_fail = fn () => !metis_fail
    end

    (* Shrink proof on top level - do not shrink case splits *)
    fun shrink_top_level on_top_level ctxt proof =
    let
      (* proof vector; steps merged away are replaced by NONE *)
      val proof_vect = proof |> map SOME |> Vector.fromList
      val n = Vector.length proof_vect
      val n_metis = metis_steps_top_level proof
      val target_n_metis = Real.fromInt n_metis / isar_shrink |> Real.round

      (* table for mapping from (top-level-)label to proof position *)
      fun update_table (i, Assume (l, _)) = Label_Table.update_new (l, i)
        | update_table (i, Obtain (_, _, l, _, _)) = Label_Table.update_new (l, i)
        | update_table (i, Prove (_, l, _, _)) = Label_Table.update_new (l, i)
        | update_table _ = I
      val label_index_table = fold_index update_table proof Label_Table.empty
      val lookup_indices = map_filter (Label_Table.lookup label_index_table)

      (* proof references: positions of the top-level steps a step refers to,
         including references made from inside case splits *)
      fun refs (Obtain (_, _, _, _, By_Metis (lfs, _))) = lookup_indices lfs
        | refs (Prove (_, _, _, By_Metis (lfs, _))) = lookup_indices lfs
        | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
          lookup_indices lfs @ maps (maps refs) cases
        | refs _ = []
      (* entry i lists the positions of the steps that refer to step i *)
      val refed_by_vect =
        Vector.tabulate (n, (fn _ => []))
        |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
        |> Vector.map rev (* after rev, indices are sorted in ascending order *)

      (* candidates for elimination, use table as priority queue (greedy
         algorithm): a "metis" step referenced by exactly one other "metis"
         step, keyed by the size of its statement *)
      (* TODO: consider adding "Obtain" cases *)
      fun add_if_cand proof_vect (i, [j]) =
          (case (the (get i proof_vect), the (get j proof_vect)) of
            (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
            cons (Term.size_of_term t, i)
          | _ => I)
        | add_if_cand _ _ = I
      val cand_tab =
        v_fold_index (add_if_cand proof_vect) refed_by_vect []
        |> Inttab.make_list

      (* Metis Preplaying *)
      fun resolve_fact_names names =
        names
          |>> map string_for_label
          |> op @
          |> maps (thms_of_name ctxt)

      (* "try_metis timeout (succedent, step)" is a thunk that preplays "step"
         with Metis and yields SOME of the CPU time taken, or NONE if "timeout"
         is exceeded. "succedent" is the step preceding "step" (if any); its
         statement is used as a fact when preplaying a case split. Steps other
         than "Prove" and "Obtain" steps, and all steps when preplaying is
         disabled, take zero time. The goal and facts are built inside the
         thunk, so errors raised while building them (e.g., unresolvable fact
         names) surface only when the thunk is forced. At that point they are
         handled by "handle_metis_fail" or "try". *)
      exception ZeroTime
      fun try_metis timeout (succedent, step) () =
        if not preplay then SOME Time.zeroTime else
        (let
          val (t, byline, obtain) =
            (case step of
              Prove (_, _, t, byline) => (t, byline, false)
            | Obtain (_, xs, _, t, byline) =>
              (* proof obligation: !!thesis. (!!x. A x ==> thesis) ==> thesis
                 (see ~~/src/Pure/Isar/obtain.ML), stated here in HOL as
                 ALL thesis. (ALL x1 .. xn. A --> thesis) --> thesis *)
              let
                val thesis = Term.Free ("thesis", HOLogic.boolT)
                val prop =
                  HOLogic.mk_imp (HOLogic.dest_Trueprop t, thesis)
                  |> fold_rev (fn (x, T) => fn t => HOLogic.mk_all (x, T, t)) xs
                  |> rpair thesis
                  |> HOLogic.mk_imp
                  |> (fn t => HOLogic.mk_all ("thesis", HOLogic.boolT, t))
                  |> HOLogic.mk_Trueprop
              in
                (prop, byline, true)
              end
            | _ => raise ZeroTime)
          val make_thm = Skip_Proof.make_thm (Proof_Context.theory_of ctxt)
          val facts =
            (case byline of
              By_Metis fact_names => resolve_fact_names fact_names
            | Case_Split (cases, fact_names) =>
              resolve_fact_names fact_names
                @ (case the succedent of
                    Assume (_, t) => make_thm t
                  | Obtain (_, _, _, t, _) => make_thm t
                  | Prove (_, _, t, _) => make_thm t
                  | _ => error "Internal error: unexpected succedent of case split")
                :: map (hd #> (fn Assume (_, a) => Logic.mk_implies (a, t)
                                | _ => error "Internal error: malformed case split")
                           #> make_thm)
                     cases)
          val ctxt = ctxt |> Config.put Metis_Tactic.verbose debug
                          |> obtain ? Config.put Metis_Tactic.new_skolem true
          (* "ctxt" already carries the verbosity setting *)
          val goal = Goal.prove ctxt [] [] t
          fun tac {context = ctxt, prems = _} =
            Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
        in
          take_time timeout (fn () => goal tac) ()
        end
        handle ZeroTime => SOME Time.zeroTime)

      val try_metis_quietly = the_default NONE oo try oo try_metis

      (* cache metis preplay times in lazy time vector *)
      val metis_time =
        v_map_index
          (Lazy.lazy o handle_metis_fail o try_metis preplay_timeout
            o apfst (fn i => try (the o get (i-1)) proof_vect) o apsnd the)
          proof_vect
      (* total preplay time; the flag is set if some preplay timed out, in
         which case the full preplay timeout is counted for that step *)
      fun sum_up_time lazy_time_vector =
        Vector.foldl
          ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
             | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout)))
            o apfst get_time)
          no_time lazy_time_vector

      (* Merging *)
      (* TODO: consider adding "Obtain" cases *)
      (* merges the first step into the second one: the result keeps the
         second step's qualifiers, label, and statement, and uses the facts of
         both steps except the first step's own label *)
      fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1)))
                (Prove (qs2, label2, t, By_Metis (lfs2, gfs2))) =
          let
            val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
            val gfs = union (op =) gfs1 gfs2
          in Prove (qs2, label2, t, By_Metis (lfs, gfs)) end
        | merge _ _ = error "Internal error: Unmergeable Isar steps"

      (* tries to merge step i (s1) into step j (s2); on success, returns the
         merged step and a time vector in which step i takes zero time and
         step j takes the merged step's preplay time *)
      fun try_merge metis_time (s1, i) (s2, j) =
        (case get i metis_time |> Lazy.force of
          NONE => (NONE, metis_time)
        | SOME t1 =>
          (case get j metis_time |> Lazy.force of
            NONE => (NONE, metis_time)
          | SOME t2 =>
            let
              val s12 = merge s1 s2
              val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
            in
              case try_metis_quietly timeout (NONE, s12) () of
                NONE => (NONE, metis_time)
              | some_t12 =>
                (SOME s12, metis_time
                           |> replace (Time.zeroTime |> SOME |> Lazy.value) i
                           |> replace (Lazy.value some_t12) j)

            end))

      (* greedy main loop: pops the candidate with the largest statement and
         tries to merge it into the only step referring to it, until no
         candidates are left, the target number of "metis" steps is reached,
         or (on the top level) fewer than three steps remain *)
      fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
        if Inttab.is_empty cand_tab
          orelse n_metis' <= target_n_metis
          orelse (on_top_level andalso n'<3)
        then
          (Vector.foldr
             (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
             [] proof_vect,
           sum_up_time metis_time)
        else
          let
            val (i, cand_tab) = pop_max cand_tab
            val j = get i refed_by |> the_single
            val s1 = get i proof_vect |> the
            val s2 = get j proof_vect |> the
          in
            case try_merge metis_time (s1, i) (s2, j) of
              (NONE, metis_time) =>
              merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
            | (s, metis_time) =>
            let
              (* steps referenced by step i are now referenced by step j *)
              val refs = refs s1
              val refed_by = refed_by |> fold
                (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
              val new_candidates =
                fold (add_if_cand proof_vect)
                  (map (fn i => (i, get i refed_by)) refs) []
              val cand_tab = add_list cand_tab new_candidates
              val proof_vect = proof_vect |> replace NONE i |> replace s j
            in
              merge_steps metis_time proof_vect refed_by cand_tab (n' - 1)
                          (n_metis' - 1)
            end
          end
    in
      merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
    end

    (* shrinks "proof" inside its case splits and on its own level, in a
       context enriched with the proof's top-level facts; returns the shrunk
       proof and the accumulated extended preplay time *)
    fun do_proof on_top_level ctxt proof =
      let
        (* Enrich context with top-level facts *)
        val thy = Proof_Context.theory_of ctxt
        (* TODO: add Skolem variables to context? *)
        fun enrich_with_fact l t =
          Proof_Context.put_thms false
            (string_for_label l, SOME [Skip_Proof.make_thm thy t])
        fun enrich_with_step (Assume (l, t)) = enrich_with_fact l t
          | enrich_with_step (Obtain (_, _, l, t, _)) = enrich_with_fact l t
          | enrich_with_step (Prove (_, l, t, _)) = enrich_with_fact l t
          | enrich_with_step _ = I
        val rich_ctxt = fold enrich_with_step proof ctxt

        (* Shrink case splits and top level *)
        val ((proof, top_level_time), lower_level_time) =
          proof |> do_case_splits rich_ctxt
                |>> shrink_top_level on_top_level rich_ctxt
      in
        (proof, ext_time_add lower_level_time top_level_time)
      end

    (* shrinks the sub-proofs of every case split in "proof" (as non-top-level
       proofs) and accumulates the extended preplay time spent on them *)
    and do_case_splits ctxt proof =
      let
        fun shrink_each_and_collect_time shrink candidates =
          let fun f_m cand time = shrink cand ||> ext_time_add time
          in fold_map f_m candidates no_time end
        val shrink_case_split =
          shrink_each_and_collect_time (do_proof false ctxt)
        fun shrink (Prove (qs, l, t, Case_Split (cases, facts))) =
            let val (cases, time) = shrink_case_split cases
            in (Prove (qs, l, t, Case_Split (cases, facts)), time) end
          | shrink step = (step, no_time)
      in
        shrink_each_and_collect_time shrink proof
      end
  in
    do_proof true ctxt proof
    |> apsnd (pair (metis_fail ()))
  end

end