src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
changeset 50672 ab5b8b5c9cbe
parent 50557 31313171deb5
child 50678 027c09d7f6ec
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML	Wed Jan 02 13:31:13 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML	Wed Jan 02 15:44:00 2013 +0100
@@ -8,7 +8,7 @@
 signature SLEDGEHAMMER_SHRINK =
 sig
   type isar_step = Sledgehammer_Proof.isar_step
-	val shrink_proof : 
+  val shrink_proof :
     bool -> Proof.context -> string -> string -> bool -> Time.time option
     -> real -> isar_step list -> isar_step list * (bool * (bool * Time.time))
 end
@@ -58,225 +58,229 @@
 (* Main function for shrinking proofs *)
 fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
                  isar_shrink proof =
-let
-  (* 60 seconds seems like a good interpreation of "no timeout" *)
-  val preplay_timeout = preplay_timeout |> the_default (seconds 60.0)
-
-  (* handle metis preplay fail *)
-  local
-    open Unsynchronized
-    val metis_fail = ref false
-  in
-    fun handle_metis_fail try_metis () =
-      try_metis () handle _ => (metis_fail := true; SOME Time.zeroTime)
-    fun get_time lazy_time = 
-      if !metis_fail then SOME Time.zeroTime else Lazy.force lazy_time
-    val metis_fail = fn () => !metis_fail
-  end
-  
-  (* Shrink top level proof - do not shrink case splits *)
-  fun shrink_top_level on_top_level ctxt proof =
   let
-    (* proof vector *)
-    val proof_vect = proof |> map SOME |> Vector.fromList
-    val n = Vector.length proof_vect
-    val n_metis = metis_steps_top_level proof
-    val target_n_metis = Real.fromInt n_metis / isar_shrink |> Real.round
-
-    (* table for mapping from (top-level-)label to proof position *)
-    fun update_table (i, Assume (label, _)) = 
-        Label_Table.update_new (label, i)
-      | update_table (i, Prove (_, label, _, _)) =
-        Label_Table.update_new (label, i)
-      | update_table _ = I
-    val label_index_table = fold_index update_table proof Label_Table.empty
-
-    (* proof references *)
-    fun refs (Prove (_, _, _, By_Metis (lfs, _))) =
-        map_filter (Label_Table.lookup label_index_table) lfs
-      | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
-        map_filter (Label_Table.lookup label_index_table) lfs
-          @ maps (maps refs) cases
-      | refs _ = []
-    val refed_by_vect =
-      Vector.tabulate (n, (fn _ => []))
-      |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
-      |> Vector.map rev (* after rev, indices are sorted in ascending order *)
-
-    (* candidates for elimination, use table as priority queue (greedy
-       algorithm) *)
-    fun add_if_cand proof_vect (i, [j]) =
-        (case (the (get i proof_vect), the (get j proof_vect)) of
-          (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
-            cons (Term.size_of_term t, i)
-        | _ => I)
-      | add_if_cand _ _ = I
-    val cand_tab = 
-      v_fold_index (add_if_cand proof_vect) refed_by_vect []
-      |> Inttab.make_list
-
-    (* Metis Preplaying *)
-    fun resolve_fact_names names =
-      names
-        |>> map string_for_label 
-        |> op @
-        |> maps (thms_of_name ctxt)
-
-    fun try_metis timeout (succedent, Prove (_, _, t, byline)) =
-      if not preplay then (fn () => SOME Time.zeroTime) else
-      (case byline of
-        By_Metis fact_names =>
-          let
-            val facts = resolve_fact_names fact_names
-            val goal =
-              Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
-            fun tac {context = ctxt, prems = _} =
-              Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
-          in
-            take_time timeout (fn () => goal tac)
-          end
-      | Case_Split (cases, fact_names) =>
-          let
-            val facts = 
-              resolve_fact_names fact_names
-                @ (case the succedent of
-                    Prove (_, _, t, _) => 
-                      Skip_Proof.make_thm (Proof_Context.theory_of ctxt) t
-                  | Assume (_, t) =>
-                      Skip_Proof.make_thm (Proof_Context.theory_of ctxt) t
-                  | _ => error "Internal error: unexpected succedent of case split")
-                ::  map 
-                      (hd #> (fn Assume (_, a) => Logic.mk_implies(a, t)
-                               | _ => error "Internal error: malformed case split") 
-                          #> Skip_Proof.make_thm (Proof_Context.theory_of ctxt)) 
-                      cases
-            val goal =
-              Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
-            fun tac {context = ctxt, prems = _} =
-              Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
-          in
-            take_time timeout (fn () => goal tac)
-          end)
-      | try_metis _ _  = (fn () => SOME Time.zeroTime)
+    (* 60 seconds seems like a good interpreation of "no timeout" *)
+    val preplay_timeout = preplay_timeout |> the_default (seconds 60.0)
 
-    val try_metis_quietly = the_default NONE oo try oo try_metis
-
-    (* cache metis preplay times in lazy time vector *)
-    val metis_time =
-      v_map_index
-        (Lazy.lazy o handle_metis_fail o try_metis preplay_timeout 
-          o apfst (fn i => try (the o get (i-1)) proof_vect) o apsnd the)
-        proof_vect
-    fun sum_up_time lazy_time_vector =
-      Vector.foldl
-        ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
-           | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout))) 
-          o apfst get_time)
-        no_time lazy_time_vector
-
-    (* Merging *)
-    fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1)))
-              (Prove (qs2, label2 , t, By_Metis (lfs2, gfs2))) =
-      let
-        val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
-        val gfs = union (op =) gfs1 gfs2
-      in Prove (qs2, label2, t, By_Metis (lfs, gfs)) end
-      | merge _ _ = error "Internal error: Tring to merge unmergable isar steps"
-
-    fun try_merge metis_time (s1, i) (s2, j) =
-      (case get i metis_time |> Lazy.force of
-        NONE => (NONE, metis_time)
-      | SOME t1 =>
-        (case get j metis_time |> Lazy.force of
-          NONE => (NONE, metis_time)
-        | SOME t2 =>
-          let
-            val s12 = merge s1 s2
-            val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
-          in
-            case try_metis_quietly timeout (NONE, s12) () of
-              NONE => (NONE, metis_time)
-            | some_t12 =>
-              (SOME s12, metis_time
-                         |> replace (Time.zeroTime |> SOME |> Lazy.value) i
-                         |> replace (Lazy.value some_t12) j)
-
-          end))
-
-    fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
-      if Inttab.is_empty cand_tab 
-        orelse n_metis' <= target_n_metis 
-        orelse (on_top_level andalso n'<3)
-      then
-        (Vector.foldr
-           (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
-           [] proof_vect,
-         sum_up_time metis_time)
-      else
-        let
-          val (i, cand_tab) = pop_max cand_tab
-          val j = get i refed_by |> the_single
-          val s1 = get i proof_vect |> the
-          val s2 = get j proof_vect |> the
-        in
-          case try_merge metis_time (s1, i) (s2, j) of
-            (NONE, metis_time) =>
-            merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
-          | (s, metis_time) => 
-          let
-            val refs = refs s1
-            val refed_by = refed_by |> fold
-              (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
-            val new_candidates =
-              fold (add_if_cand proof_vect)
-                (map (fn i => (i, get i refed_by)) refs) []
-            val cand_tab = add_list cand_tab new_candidates
-            val proof_vect = proof_vect |> replace NONE i |> replace s j
-          in
-            merge_steps metis_time proof_vect refed_by cand_tab (n' - 1) (n_metis' - 1)
-          end
-        end
-  in
-    merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
-  end
-  
-  fun shrink_proof' on_top_level ctxt proof = 
-    let
-      (* Enrich context with top-level facts *)
-      val thy = Proof_Context.theory_of ctxt
-      fun enrich_ctxt (Assume (label, t)) ctxt = 
-          Proof_Context.put_thms false
-            (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
-        | enrich_ctxt (Prove (_, label, t, _)) ctxt =
-          Proof_Context.put_thms false
-            (string_for_label label, SOME [Skip_Proof.make_thm thy t]) ctxt
-        | enrich_ctxt _ ctxt = ctxt
-      val rich_ctxt = fold enrich_ctxt proof ctxt
-
-      (* Shrink case_splits and top-levl *)
-      val ((proof, top_level_time), lower_level_time) = 
-        proof |> shrink_case_splits rich_ctxt
-              |>> shrink_top_level on_top_level rich_ctxt
+    (* handle metis preplay fail *)
+    local
+      open Unsynchronized
+      val metis_fail = ref false
     in
-      (proof, ext_time_add lower_level_time top_level_time)
+      fun handle_metis_fail try_metis () =
+        try_metis () handle _ => (metis_fail := true; SOME Time.zeroTime)
+      fun get_time lazy_time =
+        if !metis_fail then SOME Time.zeroTime else Lazy.force lazy_time
+      val metis_fail = fn () => !metis_fail
     end
 
-  and shrink_case_splits ctxt proof =
+    (* Shrink top level proof - do not shrink case splits *)
+    fun shrink_top_level on_top_level ctxt proof =
     let
-      fun shrink_each_and_collect_time shrink candidates =
-        let fun f_m cand time = shrink cand ||> ext_time_add time
-        in fold_map f_m candidates no_time end
-      val shrink_case_split = shrink_each_and_collect_time (shrink_proof' false ctxt)
-      fun shrink (Prove (qs, lbl, t, Case_Split (cases, facts))) =
-          let val (cases, time) = shrink_case_split cases
-          in (Prove (qs, lbl, t, Case_Split (cases, facts)), time) end
-        | shrink step = (step, no_time)
-    in 
-      shrink_each_and_collect_time shrink proof
+      (* proof vector *)
+      val proof_vect = proof |> map SOME |> Vector.fromList
+      val n = Vector.length proof_vect
+      val n_metis = metis_steps_top_level proof
+      val target_n_metis = Real.fromInt n_metis / isar_shrink |> Real.round
+
+      (* table for mapping from (top-level-)label to proof position *)
+      fun update_table (i, Assume (l, _)) = Label_Table.update_new (l, i)
+        | update_table (i, Obtain (_, _, l, _, _)) = Label_Table.update_new (l, i)
+        | update_table (i, Prove (_, l, _, _)) = Label_Table.update_new (l, i)
+        | update_table _ = I
+      val label_index_table = fold_index update_table proof Label_Table.empty
+      val filter_refs = map_filter (Label_Table.lookup label_index_table)
+
+      (* proof references *)
+      fun refs (Obtain (_, _, _, _, By_Metis (lfs, _))) = filter_refs lfs
+        | refs (Prove (_, _, _, By_Metis (lfs, _))) = filter_refs lfs
+        | refs (Prove (_, _, _, Case_Split (cases, (lfs, _)))) =
+          filter_refs lfs @ maps (maps refs) cases
+        | refs _ = []
+      val refed_by_vect =
+        Vector.tabulate (n, (fn _ => []))
+        |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
+        |> Vector.map rev (* after rev, indices are sorted in ascending order *)
+
+      (* candidates for elimination, use table as priority queue (greedy
+         algorithm) *)
+      (* TODO: consider adding "Obtain" cases *)
+      fun add_if_cand proof_vect (i, [j]) =
+          (case (the (get i proof_vect), the (get j proof_vect)) of
+            (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
+            cons (Term.size_of_term t, i)
+          | _ => I)
+        | add_if_cand _ _ = I
+      val cand_tab =
+        v_fold_index (add_if_cand proof_vect) refed_by_vect []
+        |> Inttab.make_list
+
+      (* Metis Preplaying *)
+      fun resolve_fact_names names =
+        names
+          |>> map string_for_label
+          |> op @
+          |> maps (thms_of_name ctxt)
+
+      (* TODO: add "Obtain" case *)
+      fun try_metis timeout (succedent, Prove (_, _, t, byline)) =
+        if not preplay then K (SOME Time.zeroTime) else
+        (case byline of
+          By_Metis fact_names =>
+            let
+              val facts = resolve_fact_names fact_names
+              val goal =
+                Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
+              fun tac {context = ctxt, prems = _} =
+                Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
+            in
+              take_time timeout (fn () => goal tac)
+            end
+        | Case_Split (cases, fact_names) =>
+            let
+              val make_thm = Skip_Proof.make_thm (Proof_Context.theory_of ctxt)
+              val facts =
+                resolve_fact_names fact_names
+                  @ (case the succedent of
+                      Assume (_, t) => make_thm t
+                    | Obtain (_, _, _, t, _) => make_thm t
+                    | Prove (_, _, t, _) => make_thm t
+                    | _ => error "Internal error: unexpected succedent of case split")
+                  :: map (hd #> (fn Assume (_, a) => Logic.mk_implies (a, t)
+                                  | _ => error "Internal error: malformed case split")
+                             #> Skip_Proof.make_thm (Proof_Context.theory_of ctxt))
+                         cases
+              val goal =
+                Goal.prove (Config.put Metis_Tactic.verbose debug ctxt) [] [] t
+              fun tac {context = ctxt, prems = _} =
+                Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
+            in
+              take_time timeout (fn () => goal tac)
+            end)
+        | try_metis _ _  = K (SOME Time.zeroTime)
+
+      val try_metis_quietly = the_default NONE oo try oo try_metis
+
+      (* cache metis preplay times in lazy time vector *)
+      val metis_time =
+        v_map_index
+          (Lazy.lazy o handle_metis_fail o try_metis preplay_timeout
+            o apfst (fn i => try (the o get (i-1)) proof_vect) o apsnd the)
+          proof_vect
+      fun sum_up_time lazy_time_vector =
+        Vector.foldl
+          ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
+             | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout)))
+            o apfst get_time)
+          no_time lazy_time_vector
+
+      (* Merging *)
+      (* TODO: consider adding "Obtain" cases *)
+      fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1)))
+                (Prove (qs2, label2, t, By_Metis (lfs2, gfs2))) =
+          let
+            val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
+            val gfs = union (op =) gfs1 gfs2
+          in Prove (qs2, label2, t, By_Metis (lfs, gfs)) end
+        | merge _ _ = error "Internal error: Unmergeable Isar steps"
+
+      fun try_merge metis_time (s1, i) (s2, j) =
+        (case get i metis_time |> Lazy.force of
+          NONE => (NONE, metis_time)
+        | SOME t1 =>
+          (case get j metis_time |> Lazy.force of
+            NONE => (NONE, metis_time)
+          | SOME t2 =>
+            let
+              val s12 = merge s1 s2
+              val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
+            in
+              case try_metis_quietly timeout (NONE, s12) () of
+                NONE => (NONE, metis_time)
+              | some_t12 =>
+                (SOME s12, metis_time
+                           |> replace (Time.zeroTime |> SOME |> Lazy.value) i
+                           |> replace (Lazy.value some_t12) j)
+
+            end))
+
+      fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
+        if Inttab.is_empty cand_tab
+          orelse n_metis' <= target_n_metis
+          orelse (on_top_level andalso n'<3)
+        then
+          (Vector.foldr
+             (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
+             [] proof_vect,
+           sum_up_time metis_time)
+        else
+          let
+            val (i, cand_tab) = pop_max cand_tab
+            val j = get i refed_by |> the_single
+            val s1 = get i proof_vect |> the
+            val s2 = get j proof_vect |> the
+          in
+            case try_merge metis_time (s1, i) (s2, j) of
+              (NONE, metis_time) =>
+              merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
+            | (s, metis_time) =>
+            let
+              val refs = refs s1
+              val refed_by = refed_by |> fold
+                (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
+              val new_candidates =
+                fold (add_if_cand proof_vect)
+                  (map (fn i => (i, get i refed_by)) refs) []
+              val cand_tab = add_list cand_tab new_candidates
+              val proof_vect = proof_vect |> replace NONE i |> replace s j
+            in
+              merge_steps metis_time proof_vect refed_by cand_tab (n' - 1)
+                          (n_metis' - 1)
+            end
+          end
+    in
+      merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
     end
-in
-  shrink_proof' true ctxt proof
-  |> apsnd (pair (metis_fail () ) )
-end
+
+    fun do_proof on_top_level ctxt proof =
+      let
+        (* Enrich context with top-level facts *)
+        val thy = Proof_Context.theory_of ctxt
+        (* TODO: add Skolem variables to context? *)
+        fun enrich_with_fact l t =
+          Proof_Context.put_thms false
+            (string_for_label l, SOME [Skip_Proof.make_thm thy t])
+        fun enrich_with_step (Assume (l, t)) = enrich_with_fact l t
+          | enrich_with_step (Obtain (_, _, l, t, _)) = enrich_with_fact l t
+          | enrich_with_step (Prove (_, l, t, _)) = enrich_with_fact l t
+          | enrich_with_step _ = I
+        val rich_ctxt = fold enrich_with_step proof ctxt
+
+        (* Shrink case_splits and top-levl *)
+        val ((proof, top_level_time), lower_level_time) =
+          proof |> do_case_splits rich_ctxt
+                |>> shrink_top_level on_top_level rich_ctxt
+      in
+        (proof, ext_time_add lower_level_time top_level_time)
+      end
+
+    and do_case_splits ctxt proof =
+      let
+        fun shrink_each_and_collect_time shrink candidates =
+          let fun f_m cand time = shrink cand ||> ext_time_add time
+          in fold_map f_m candidates no_time end
+        val shrink_case_split =
+          shrink_each_and_collect_time (do_proof false ctxt)
+        fun shrink (Prove (qs, l, t, Case_Split (cases, facts))) =
+            let val (cases, time) = shrink_case_split cases
+            in (Prove (qs, l, t, Case_Split (cases, facts)), time) end
+          | shrink step = (step, no_time)
+      in
+        shrink_each_and_collect_time shrink proof
+      end
+  in
+    do_proof true ctxt proof
+    |> apsnd (pair (metis_fail ()))
+  end
 
 end