src/HOL/Tools/Sledgehammer/sledgehammer_compress.ML
changeset 51179 0d5f8812856f
parent 51178 06689dbfe072
child 51260 61bc5a3bef09
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_compress.ML	Mon Feb 18 12:16:02 2013 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_compress.ML	Mon Feb 18 12:16:27 2013 +0100
@@ -7,11 +7,11 @@
 
 signature SLEDGEHAMMER_COMPRESS =
 sig
-  type isar_step = Sledgehammer_Proof.isar_step
+  type isar_proof = Sledgehammer_Proof.isar_proof
   type preplay_time = Sledgehammer_Preplay.preplay_time
   val compress_proof :
     bool -> Proof.context -> string -> string -> bool -> Time.time option
-    -> real -> isar_step list -> isar_step list * (bool * preplay_time)
+    -> real -> isar_proof -> isar_proof * (bool * preplay_time)
 end
 
 structure Sledgehammer_Compress : SLEDGEHAMMER_COMPRESS =
@@ -68,38 +68,37 @@
       val metis_fail = fn () => !metis_fail
     end
 
-    (* compress proof on top level - do not compress subproofs *)
-    fun compress_top_level on_top_level ctxt proof =
+    (* compress top level steps - do not compress subproofs *)
+    fun compress_top_level on_top_level ctxt steps =
     let
-      (* proof vector *)
-      val proof_vect = proof |> map SOME |> Vector.fromList
-      val n = Vector.length proof_vect
-      val n_metis = metis_steps_top_level proof
+      (* proof step vector *)
+      val step_vect = steps |> map SOME |> Vector.fromList
+      val n = Vector.length step_vect
+      val n_metis = add_metis_steps_top_level steps 0
       val target_n_metis = Real.fromInt n_metis / isar_compress |> Real.round
 
-      (* table for mapping from (top-level-)label to proof position *)
-      fun update_table (i, Assume (l, _)) = Label_Table.update_new (l, i)
+      (* table for mapping from (top-level-)label to step_vect position *)
+      fun update_table (i, Prove (_, l, _, _)) = Label_Table.update_new (l, i)
         | update_table (i, Obtain (_, _, l, _, _)) = Label_Table.update_new (l, i)
-        | update_table (i, Prove (_, l, _, _)) = Label_Table.update_new (l, i)
         | update_table _ = I
-      val label_index_table = fold_index update_table proof Label_Table.empty
+      val label_index_table = fold_index update_table steps Label_Table.empty
       val lookup_indices = map_filter (Label_Table.lookup label_index_table)
 
-      (* proof references *)
-      fun refs (Obtain (_, _, _, _, By_Metis (subproofs, (lfs, _)))) =
-              maps (maps refs) subproofs @ lookup_indices lfs
-        | refs (Prove (_, _, _, By_Metis (subproofs, (lfs, _)))) =
-              maps (maps refs) subproofs @ lookup_indices lfs
-        | refs _ = []
+      (* proof step references *)
+      fun refs step =
+        (case byline_of_step step of
+          NONE => []
+        | SOME (By_Metis (subproofs, (lfs, _))) =>
+            maps (steps_of_proof #> maps refs) subproofs @ lookup_indices lfs)
       val refed_by_vect =
         Vector.tabulate (n, K [])
-        |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
+        |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) steps
         |> Vector.map rev (* after rev, indices are sorted in ascending order *)
 
       (* candidates for elimination, use table as priority queue (greedy
          algorithm) *)
-      fun add_if_cand proof_vect (i, [j]) =
-          (case (the (get i proof_vect), the (get j proof_vect)) of
+      fun add_if_cand step_vect (i, [j]) =
+          (case (the (get i step_vect), the (get j step_vect)) of
             (Prove (_, _, t, By_Metis _), Prove (_, _, _, By_Metis _)) =>
               cons (Term.size_of_term t, i)
           | (Prove (_, _, t, By_Metis _), Obtain (_, _, _, _, By_Metis _)) =>
@@ -107,7 +106,7 @@
           | _ => I)
         | add_if_cand _ _ = I
       val cand_tab =
-        v_fold_index (add_if_cand proof_vect) refed_by_vect []
+        v_fold_index (add_if_cand step_vect) refed_by_vect []
         |> Inttab.make_list
 
       (* cache metis preplay times in lazy time vector *)
@@ -119,7 +118,7 @@
              #> try_metis debug type_enc lam_trans ctxt preplay_timeout
              #> handle_metis_fail
              #> Lazy.lazy)
-          proof_vect
+          step_vect
 
       fun sum_up_time lazy_time_vector =
         Vector.foldl
@@ -127,17 +126,16 @@
           zero_preplay_time lazy_time_vector
 
       (* Merging *)
-      fun merge (Prove (_, label1, _, By_Metis (subproofs1, (lfs1, gfs1))))
-                                                                        step2 =
+      fun merge (Prove (_, lbl1, _, By_Metis (subproofs1, (lfs1, gfs1)))) step2 =
           let
             val (step_constructor, (subproofs2, (lfs2, gfs2))) =
               (case step2 of
-                Prove (qs2, label2, t, By_Metis x) =>
-                  (fn by => Prove (qs2, label2, t, by), x)
-              | Obtain (qs2, xs, label2, t, By_Metis x) =>
-                  (fn by => Obtain (qs2, xs, label2, t, by), x)
+                Prove (qs2, lbl2, t, By_Metis x) =>
+                  (fn by => Prove (qs2, lbl2, t, by), x)
+              | Obtain (qs2, xs, lbl2, t, By_Metis x) =>
+                  (fn by => Obtain (qs2, xs, lbl2, t, by), x)
               | _ => error "sledgehammer_compress: unmergeable Isar steps" )
-            val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
+            val lfs = remove (op =) lbl1 lfs2 |> union (op =) lfs1
             val gfs = union (op =) gfs1 gfs2
             val subproofs = subproofs1 @ subproofs2
           in step_constructor (By_Metis (subproofs, (lfs, gfs))) end
@@ -166,45 +164,45 @@
 
               end))
 
-      fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
+      fun merge_steps metis_time step_vect refed_by cand_tab n' n_metis' =
         if Inttab.is_empty cand_tab
           orelse n_metis' <= target_n_metis
           orelse (on_top_level andalso n'<3)
         then
           (Vector.foldr
-             (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
-             [] proof_vect,
+             (fn (NONE, steps) => steps | (SOME s, steps) => s :: steps)
+             [] step_vect,
            sum_up_time metis_time)
         else
           let
             val (i, cand_tab) = pop_max cand_tab
             val j = get i refed_by |> the_single
-            val s1 = get i proof_vect |> the
-            val s2 = get j proof_vect |> the
+            val s1 = get i step_vect |> the
+            val s2 = get j step_vect |> the
           in
             case try_merge metis_time (s1, i) (s2, j) of
               (NONE, metis_time) =>
-              merge_steps metis_time proof_vect refed_by cand_tab n' n_metis'
+              merge_steps metis_time step_vect refed_by cand_tab n' n_metis'
             | (s, metis_time) =>
             let
               val refs = refs s1
               val refed_by = refed_by |> fold
                 (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
               val new_candidates =
-                fold (add_if_cand proof_vect)
+                fold (add_if_cand step_vect)
                   (map (fn i => (i, get i refed_by)) refs) []
               val cand_tab = add_list cand_tab new_candidates
-              val proof_vect = proof_vect |> replace NONE i |> replace s j
+              val step_vect = step_vect |> replace NONE i |> replace s j
             in
-              merge_steps metis_time proof_vect refed_by cand_tab (n' - 1)
+              merge_steps metis_time step_vect refed_by cand_tab (n' - 1)
                           (n_metis' - 1)
             end
           end
     in
-      merge_steps metis_time proof_vect refed_by_vect cand_tab n n_metis
+      merge_steps metis_time step_vect refed_by_vect cand_tab n n_metis
     end
 
-    fun do_proof on_top_level ctxt proof =
+    fun do_proof on_top_level ctxt (Proof (fix, assms,steps)) =
       let
         (* Enrich context with top-level facts *)
         val thy = Proof_Context.theory_of ctxt
@@ -212,21 +210,25 @@
         fun enrich_with_fact l t =
           Proof_Context.put_thms false
             (string_for_label l, SOME [Skip_Proof.make_thm thy t])
-        fun enrich_with_step (Assume (l, t)) = enrich_with_fact l t
+        fun enrich_with_step (Prove (_, l, t, _)) = enrich_with_fact l t
           | enrich_with_step (Obtain (_, _, l, t, _)) = enrich_with_fact l t
-          | enrich_with_step (Prove (_, l, t, _)) = enrich_with_fact l t
           | enrich_with_step _ = I
-        val rich_ctxt = fold enrich_with_step proof ctxt
+        val enrich_with_steps = fold enrich_with_step
+        fun enrich_with_assms (Assume assms) =
+          fold (uncurry enrich_with_fact) assms
+        val rich_ctxt =
+          ctxt |> enrich_with_assms assms |> enrich_with_steps steps
 
-        (* compress subproofs and top-levl proof *)
-        val ((proof, top_level_time), lower_level_time) =
-          proof |> do_subproofs rich_ctxt
+        (* compress subproofs and top-levl steps *)
+        val ((steps, top_level_time), lower_level_time) =
+          steps |> do_subproofs rich_ctxt
                 |>> compress_top_level on_top_level rich_ctxt
       in
-        (proof, add_preplay_time lower_level_time top_level_time)
+        (Proof (fix, assms, steps),
+          add_preplay_time lower_level_time top_level_time)
       end
 
-    and do_subproofs ctxt proof =
+    and do_subproofs ctxt subproofs =
       let
         fun compress_each_and_collect_time compress subproofs =
           let fun f_m proof time = compress proof ||> add_preplay_time time
@@ -238,7 +240,7 @@
               in (Prove (qs, l, t, By_Metis(subproofs, facts)), time) end
           | compress atomic_step = (atomic_step, zero_preplay_time)
       in
-        compress_each_and_collect_time compress proof
+        compress_each_and_collect_time compress subproofs
       end
   in
     do_proof true ctxt proof