src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML
changeset 50004 c96e8e40d789
parent 49994 ceceb403eb4e
child 50005 e9a9bff107da
--- a/src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML	Fri Nov 02 14:23:54 2012 +0100
+++ b/src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML	Fri Nov 02 16:16:48 2012 +0100
@@ -23,8 +23,8 @@
   type one_line_params =
     play * string * (string * stature) list * minimize_command * int * int
   type isar_params =
-    bool * bool * real * string Symtab.table * (string * stature) list vector
-    * int Symtab.table * string proof * thm
+    bool * bool * Time.time * real * string Symtab.table
+    * (string * stature) list vector * int Symtab.table * string proof * thm
 
   val smtN : string
   val string_for_reconstructor : reconstructor -> string
@@ -506,8 +506,8 @@
 
 (* (2) Typing-spot Table *)
 local
-fun key_of_atype (TVar (idxn, _)) =
-    Ord_List.insert Term_Ord.fast_indexname_ord idxn
+fun key_of_atype (TVar (z, _)) =
+    Ord_List.insert Term_Ord.fast_indexname_ord z
   | key_of_atype _ = I
 fun key_of_type T = fold_atyps key_of_atype T []
 fun update_tab t T (tab, pos) =
@@ -755,97 +755,200 @@
 
 val merge_timeout_slack = 1.2
 
-fun shrink_locally ctxt type_enc lam_trans isar_shrinkage proof =
+val label_ord = prod_ord int_ord fast_string_ord o pairself swap
+
+structure Label_Table = Table(
+  type key = label
+  val ord = label_ord)
+
+fun shrink_proof debug ctxt type_enc lam_trans preplay
+                 preplay_timeout isar_shrinkage proof =
   let
-    (* Merging spots, greedy algorithm *)
+    (* clean vector interface *)
+    fun get i v = Vector.sub (v, i)
+    fun replace x i v = Vector.update (v, i, x)
+    fun update f i v = replace (get i v |> f) i v
+    fun v_fold_index f v s =
+      Vector.foldl (fn (x, (i, s)) => (i+1, f (i, x) s)) (0, s) v |> snd
+
+    (* Queue interface to table *)
+    fun pop tab key =
+      let val v = hd (Inttab.lookup_list tab key) in
+        (v, Inttab.remove_list (op =) (key, v) tab)
+      end
+    fun pop_max tab = pop tab (the (Inttab.max_key tab))
+    val is_empty = Inttab.is_empty
+    fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab
+
+    (* proof vector *)
+    val proof_vect = proof |> map SOME |> Vector.fromList
+    val n = Vector.length proof_vect
+    val n_target = Real.fromInt n / isar_shrinkage |> Real.round
+
+    (* table for mapping from label to proof position *)
+    fun update_table (i, Prove (_, label, _, _)) =
+        Label_Table.update_new (label, i)
+      | update_table _ = I
+    val label_index_table = fold_index update_table proof Label_Table.empty
+
+    (* proof references *)
+    fun refs (Prove (_, _, _, By_Metis (refs, _))) =
+      map (the o Label_Table.lookup label_index_table) refs
+      | refs _ = []
+    val refed_by_vect =
+      Vector.tabulate (n, (fn _ => []))
+      |> fold_index (fn (i, step) => fold (update (cons i)) (refs step)) proof
+      |> Vector.map rev (* after rev, indices are sorted in ascending order *)
+
+    (* candidates for elimination, use table as priority queue (greedy
+       algorithm) *)
     fun cost (Prove (_, _ , t, _)) = Term.size_of_term t
-      | cost _ = ~1
-    fun can_merge (Prove (_, lbl, _, By_Metis _))
-                  (Prove (_, _, _, By_Metis _)) =
-        (lbl = no_label)
-      | can_merge _ _ = false
-    val merge_spots = 
-      fold_index (fn (i, s2) => fn (s1, pile) =>
-          (s2, pile |> can_merge s1 s2 ? cons (i, cost s1)))
-        (tl proof) (hd proof, [])
-      |> snd |> sort (rev_order o int_ord o pairself snd) |> map fst
+      | cost _ = 0
+    val cand_ord =  rev_order o prod_ord int_ord int_ord
+    val cand_tab =
+      v_fold_index
+        (fn (i, [_]) => cons (get i proof_vect |> the |> cost, i)
+        | _ => I) refed_by_vect []
+      |> Inttab.make_list
 
     (* Enrich context with local facts *)
     val thy = Proof_Context.theory_of ctxt
-    fun sorry t = Skip_Proof.make_thm thy t
-    fun enrich_ctxt' (Prove (_, lbl, t, _)) ctxt = 
-        ctxt |> lbl <> no_label
-          ? Proof_Context.put_thms false (string_for_label lbl, SOME [sorry t])
+    fun sorry t = Skip_Proof.make_thm thy (HOLogic.mk_Trueprop t)
+    fun enrich_ctxt' (Prove (_, label, t, _)) ctxt =
+        Proof_Context.put_thms false (string_for_label label, SOME [sorry t])
+                               ctxt
       | enrich_ctxt' _ ctxt = ctxt
     val rich_ctxt = fold enrich_ctxt' proof ctxt
 
     (* Timing *)
-    fun take_time tac arg =
+    fun take_time timeout tac arg =
       let val timing = Timing.start () in
-        (tac arg; Timing.result timing |> #cpu)
+        (TimeLimit.timeLimit timeout tac arg;
+         Timing.result timing |> #cpu |> SOME)
+        handle _ => NONE
       end
-    fun try_metis (Prove (qs, _, t, By_Metis fact_names)) s0 =
+    val sum_up_time =
+      Vector.foldl
+      ((fn (SOME t, (b, s)) => (b, t + s)
+         | (NONE, (_, s)) => (true, preplay_timeout + s)) o apfst Lazy.force)
+      (false, seconds 0.0)
+
+    (* Metis Preplaying *)
+    fun try_metis timeout (Prove (_, _, t, By_Metis fact_names)) =
+      if not preplay then (fn () => SOME (seconds 0.0)) else
+        let
+          val facts =
+            fact_names
+            |>> map string_for_label |> op @
+            |> map (the_single o thms_of_name rich_ctxt)
+          val goal = Goal.prove (Config.put Metis_Tactic.verbose debug ctxt)
+            [] [] (HOLogic.mk_Trueprop t)
+          fun tac {context = ctxt, prems = _} =
+            Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
+        in
+          take_time timeout (fn () => goal tac)
+        end
+
+    (* Lazy metis time vector, cache *)
+    val metis_time =
+      Vector.map (Lazy.lazy o try_metis preplay_timeout o the) proof_vect
+
+    (* Merging *)
+    fun merge (Prove (qs1, label1, _, By_Metis (lfs1, gfs1)))
+              (Prove (qs2, label2 , t, By_Metis (lfs2, gfs2))) =
       let
-        fun thmify (Prove (_, _, t, _)) = sorry t
-        val facts =
-          fact_names
-          |>> map string_for_label |> op @
-          |> map (the_single o thms_of_name rich_ctxt)
-          |> (if member (op =) qs Then then cons (the s0 |> thmify) else I)
-        val goal = Goal.prove ctxt [] [] t
-        fun tac {context = ctxt, prems = _} =
-          Metis_Tactic.metis_tac [type_enc] lam_trans ctxt facts 1
-      in
-        take_time (fn () => goal tac)
-      end
-  
-    (* Merging *)
-    fun merge (Prove (qs1, _, _, By_Metis (ls1, ss1))) 
-              (Prove (qs2, lbl , t, By_Metis (ls2, ss2))) =
-      let
-        val qs =
-          inter (op =) qs1 qs2 (* FIXME: Is this correct? *)
+        val qs = inter (op =) qs1 qs2 (* FIXME: Is this correct? *)
           |> member (op =) (union (op =) qs1 qs2) Ultimately ? cons Ultimately
           |> member (op =) qs2 Show ? cons Show
-      in Prove (qs, lbl, t, By_Metis (ls1 @ ls2, ss1 @ ss2)) end
-    fun try_merge proof i =
-      let
-        val (front, s0, s1, s2, tail) = 
-          case (proof, i) of
-            ((s1 :: s2 :: proof), 0) => ([], NONE, s1, s2, proof)
-          | _ =>
-            let val (front, s0 :: s1 :: s2 :: tail) = chop (i - 1) proof in
-              (front, SOME s0, s1, s2, tail)
-            end
-        val s12 = merge s1 s2
-        val t1 = try_metis s1 s0 ()
-        val t2 = try_metis s2 (SOME s1) ()
-        val timeout =
-          Time.+ (t1, t2) |> Time.toReal |> curry Real.* merge_timeout_slack
-                  |> Time.fromReal
-      in
-        (TimeLimit.timeLimit timeout (try_metis s12 s0) ();
-         SOME (front @ (the_list s0 @ s12 :: tail)))
-        handle _ => NONE
-      end
-    fun spill_shrinkage shrinkage = isar_shrinkage + shrinkage - 1.0
-    fun merge_steps _ proof [] = proof
-      | merge_steps shrinkage proof (i :: is) = 
-        if shrinkage < 1.5 then
-          merge_steps (spill_shrinkage shrinkage) proof is
-        else case try_merge proof i of
-          NONE => merge_steps (spill_shrinkage shrinkage) proof is
-        | SOME proof' =>
-          merge_steps (shrinkage - 1.0) proof'
-            (map (fn j => if j > i then j - 1 else j) is)
-  in merge_steps isar_shrinkage proof merge_spots end
+        val ls = remove (op =) label1 lfs2 |> union (op =) lfs1
+        val ss = union (op =) gfs1 gfs2
+      in Prove (qs, label2, t, By_Metis (ls, ss)) end
+    fun try_merge metis_time (s1, i) (s2, j) =
+      (case get i metis_time |> Lazy.force of
+        NONE => (NONE, metis_time)
+      | SOME t1 =>
+        (case get j metis_time |> Lazy.force of
+          NONE => (NONE, metis_time)
+        | SOME t2 =>
+          let
+            val s12 = merge s1 s2
+            val timeout =
+              t1 + t2 |> Time.toReal |> curry Real.* merge_timeout_slack
+                      |> Time.fromReal
+          in
+            case try_metis timeout s12 () of
+              NONE => (NONE, metis_time)
+            | some_t12 =>
+              (SOME s12, metis_time
+                         |> replace (seconds 0.0 |> SOME |> Lazy.value) i
+                         |> replace (Lazy.value some_t12) j)
+
+          end))
+
+    fun merge_steps metis_time proof_vect refed_by cand_tab n' =
+      if is_empty cand_tab orelse n' <= n_target orelse n'<3 then
+        (sum_up_time metis_time,
+         Vector.foldr
+           (fn (NONE, proof) => proof | (SOME s, proof) => s :: proof)
+           [] proof_vect)
+      else
+        let
+          val (i, cand_tab) = pop_max cand_tab
+          val j = get i refed_by |> the_single
+          val s1 = get i proof_vect |> the
+          val s2 = get j proof_vect |> the
+        in
+          case try_merge metis_time (s1, i) (s2, j) of
+            (NONE, metis_time) =>
+            merge_steps metis_time proof_vect refed_by cand_tab n'
+          | (s, metis_time) => let
+            val refs = refs s1
+            val refed_by = refed_by |> fold
+              (update (Ord_List.remove int_ord i #> Ord_List.insert int_ord j)) refs
+            val new_candidates =
+              fold
+                (fn (i, [_]) => cons (cost (get i proof_vect |> the), i) | _ => I)
+                (map (fn i => (i, get i refed_by)) refs) []
+              |> sort cand_ord
+            val cand_tab = add_list cand_tab new_candidates
+            val proof_vect = proof_vect |> replace NONE i |> replace s j
+          in
+            merge_steps metis_time proof_vect refed_by cand_tab (n' - 1)
+          end
+        end
+  in
+    merge_steps metis_time proof_vect refed_by_vect cand_tab n
+  end
+
+val chain_direct_proof =
+  let
+    fun succedent_of_step (Prove (_, label, _, _)) = SOME label
+      | succedent_of_step (Assume (label, _)) = SOME label
+      | succedent_of_step _ = NONE
+    fun chain_inf (SOME label0)
+                  (step as Prove (qs, label, t, By_Metis (lfs, gfs))) =
+        if member (op =) lfs label0 then
+          Prove (Then :: qs, label, t,
+                 By_Metis (filter_out (curry (op =) label0) lfs, gfs))
+        else
+          step
+      | chain_inf _ (Prove (qs, label, t, Case_Split (proofs, facts))) =
+        Prove (qs, label, t, Case_Split ((map (chain_proof NONE) proofs), facts))
+      | chain_inf _ step = step
+    and chain_proof _ [] = []
+      | chain_proof (prev as SOME _) (i :: is) =
+        chain_inf prev i :: chain_proof (succedent_of_step i) is
+      | chain_proof _ (i :: is) =
+        i :: chain_proof (succedent_of_step i) is
+  in chain_proof NONE end
 
 type isar_params =
-  bool * bool * real * string Symtab.table * (string * stature) list vector
-  * int Symtab.table * string proof * thm
+  bool * bool * Time.time * real * string Symtab.table
+  * (string * stature) list vector * int Symtab.table * string proof * thm
 
 fun isar_proof_text ctxt isar_proofs
-    (debug, verbose, isar_shrinkage, pool, fact_names, sym_tab, atp_proof, goal)
+    (debug, verbose, preplay_timeout, isar_shrinkage,
+     pool, fact_names, sym_tab, atp_proof, goal)
     (one_line_params as (_, _, _, _, subgoal, subgoal_count)) =
   let
     val (params, hyp_ts, concl_t) = strip_subgoal ctxt goal subgoal
@@ -855,6 +958,7 @@
       if is_typed_helper_used_in_atp_proof atp_proof then full_typesN
       else partial_typesN
     val lam_trans = lam_trans_from_atp_proof atp_proof metis_default_lam_trans
+    val preplay = preplay_timeout <> seconds 0.0
 
     fun isar_proof_of () =
       let
@@ -907,7 +1011,6 @@
                  By_Metis (fold (add_fact_from_dependency fact_names
                                  o the_single) gamma ([], [])))
         fun do_inf outer (Have z) = do_have outer [] z
-          | do_inf outer (Hence z) = do_have outer [Then] z
           | do_inf outer (Cases cases) =
             let val c = succedent_of_cases cases in
               Prove (maybe_show outer c [Ultimately], label_of_clause c,
@@ -917,17 +1020,22 @@
         and do_case outer (c, infs) =
           Assume (label_of_clause c, prop_of_clause c) ::
           map (do_inf outer) infs
-        val isar_proof =
-          (if null params then [] else [Fix params]) @
-          (ref_graph
-           |> redirect_graph axioms tainted
-           |> chain_direct_proof
-           |> map (do_inf true)
-           |> kill_duplicate_assumptions_in_proof
-           |> kill_useless_labels_in_proof
-           |> relabel_proof
-           |> shrink_locally ctxt type_enc lam_trans
-                (if isar_proofs then isar_shrinkage else 1000.0))
+        val (ext_time, isar_proof) =
+          ref_graph
+          |> redirect_graph axioms tainted
+          |> map (do_inf true)
+          |> kill_duplicate_assumptions_in_proof
+          |> (if isar_shrinkage <= 1.0 andalso isar_proofs then
+                pair (true, seconds 0.0)
+              else
+                shrink_proof debug ctxt type_enc lam_trans preplay
+                     preplay_timeout
+                     (if isar_proofs then isar_shrinkage else 1000.0))
+       (* ||> reorder_proof_to_minimize_jumps (* ? *) *)
+          ||> chain_direct_proof
+          ||> kill_useless_labels_in_proof
+          ||> relabel_proof
+          ||> not (null params) ? cons (Fix params)
         val num_steps = length isar_proof
         val isar_text =
           string_for_proof ctxt type_enc lam_trans subgoal subgoal_count
@@ -945,6 +1053,10 @@
              "Structured proof" ^
              (if verbose then
                 " (" ^ string_of_int num_steps ^ " step" ^ plural_s num_steps ^
+                (if preplay andalso isar_shrinkage > 1.0 then
+                   ", " ^ string_from_ext_time ext_time
+                 else
+                   "") ^
                 ")"
               else
                 "")