src/HOL/Tools/Sledgehammer/sledgehammer_shrink.ML
changeset 50924 beb95bf66b21
parent 50923 141d8f575f6f
child 51128 0021ea861129
equal deleted inserted replaced
50923:141d8f575f6f 50924:beb95bf66b21
     6 *)
     6 *)
     7 
     7 
     8 signature SLEDGEHAMMER_SHRINK =
     8 signature SLEDGEHAMMER_SHRINK =
     9 sig
     9 sig
    10   type isar_step = Sledgehammer_Proof.isar_step
    10   type isar_step = Sledgehammer_Proof.isar_step
       
    11   type preplay_time = Sledgehammer_Preplay.preplay_time
    11   val shrink_proof :
    12   val shrink_proof :
    12     bool -> Proof.context -> string -> string -> bool -> Time.time option
    13     bool -> Proof.context -> string -> string -> bool -> Time.time option
    13     -> real -> isar_step list -> isar_step list * (bool * (bool * Time.time))
    14     -> real -> isar_step list -> isar_step list * (bool * preplay_time)
    14 end
    15 end
    15 
    16 
    16 structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK =
    17 structure Sledgehammer_Shrink : SLEDGEHAMMER_SHRINK =
    17 struct
    18 struct
    18 
    19 
    42   let val v = hd (Inttab.lookup_list tab key) in
    43   let val v = hd (Inttab.lookup_list tab key) in
    43     (v, Inttab.remove_list (op =) (key, v) tab)
    44     (v, Inttab.remove_list (op =) (key, v) tab)
    44   end
    45   end
    45 fun pop_max tab = pop tab (the (Inttab.max_key tab))
    46 fun pop_max tab = pop tab (the (Inttab.max_key tab))
    46 fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab
    47 fun add_list tab xs = fold (Inttab.insert_list (op =)) xs tab
    47 
       
    48 (* Timing *)
       
    49 fun ext_time_add (b1, t1) (b2, t2) = (b1 orelse b2, Time.+(t1,t2))
       
    50 val no_time = (false, Time.zeroTime)
       
    51 
    48 
    52 (* Main function for shrinking proofs *)
    49 (* Main function for shrinking proofs *)
    53 fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
    50 fun shrink_proof debug ctxt type_enc lam_trans preplay preplay_timeout
    54                  isar_shrink proof =
    51                  isar_shrink proof =
    55   let
    52   let
    62       val metis_fail = ref false
    59       val metis_fail = ref false
    63     in
    60     in
    64       fun handle_metis_fail try_metis () =
    61       fun handle_metis_fail try_metis () =
    65         try_metis () handle exn =>
    62         try_metis () handle exn =>
    66           (if Exn.is_interrupt exn orelse debug then reraise exn
    63           (if Exn.is_interrupt exn orelse debug then reraise exn
    67            else metis_fail := true; SOME Time.zeroTime)
    64            else metis_fail := true; some_preplay_time)
    68       fun get_time lazy_time =
    65       fun get_time lazy_time =
    69         if !metis_fail then SOME Time.zeroTime else Lazy.force lazy_time
    66         if !metis_fail andalso not (Lazy.is_finished lazy_time)
       
    67           then some_preplay_time
       
    68           else Lazy.force lazy_time
    70       val metis_fail = fn () => !metis_fail
    69       val metis_fail = fn () => !metis_fail
    71     end
    70     end
    72 
    71 
    73     (* Shrink proof on top level - do not shrink case splits *)
    72     (* Shrink proof on top level - do not shrink case splits *)
    74     fun shrink_top_level on_top_level ctxt proof =
    73     fun shrink_top_level on_top_level ctxt proof =
   113         |> Inttab.make_list
   112         |> Inttab.make_list
   114 
   113 
   115       (* cache metis preplay times in lazy time vector *)
   114       (* cache metis preplay times in lazy time vector *)
   116       val metis_time =
   115       val metis_time =
   117         v_map_index
   116         v_map_index
   118           (if not preplay then K (SOME Time.zeroTime) #> Lazy.value
   117           (if not preplay then K (zero_preplay_time) #> Lazy.value
   119            else
   118            else
   120              apsnd the
   119              apsnd the (* step *)
   121              #> apfst (fn i => try (get (i-1) #> the) proof_vect)
   120              #> apfst (fn i => try (get (i-1) #> the) proof_vect) (* succedent *)
   122              #> try_metis debug type_enc lam_trans ctxt preplay_timeout
   121              #> try_metis debug type_enc lam_trans ctxt preplay_timeout
   123              #> handle_metis_fail
   122              #> handle_metis_fail
   124              #> Lazy.lazy)
   123              #> Lazy.lazy)
   125           proof_vect
   124           proof_vect
   126 
   125 
   127       fun sum_up_time lazy_time_vector =
   126       fun sum_up_time lazy_time_vector =
   128         Vector.foldl
   127         Vector.foldl
   129           ((fn (SOME t, (b, ts)) => (b, Time.+(t, ts))
   128           (apfst get_time #> uncurry add_preplay_time)
   130              | (NONE, (_, ts)) => (true, Time.+(ts, preplay_timeout)))
   129           zero_preplay_time lazy_time_vector
   131             o apfst get_time)
       
   132           no_time lazy_time_vector
       
   133 
   130 
   134       (* Merging *)
   131       (* Merging *)
   135       (* TODO: consider adding "Obtain" cases *)
       
   136       fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1))) step2 =
   132       fun merge (Prove (_, label1, _, By_Metis (lfs1, gfs1))) step2 =
   137           let
   133           let
   138             val (step_constructor, lfs2, gfs2) =
   134             val (step_constructor, lfs2, gfs2) =
   139               (case step2 of
   135               (case step2 of
   140                 (Prove (qs2, label2, t, By_Metis (lfs2, gfs2))) =>
   136                 (Prove (qs2, label2, t, By_Metis (lfs2, gfs2))) =>
   141                   (fn by => Prove (qs2, label2, t, by), lfs2, gfs2)
   137                   (fn by => Prove (qs2, label2, t, by), lfs2, gfs2)
   142               | (Obtain(qs2, xs, label2, t, By_Metis (lfs2, gfs2))) =>
   138               | (Obtain (qs2, xs, label2, t, By_Metis (lfs2, gfs2))) =>
   143                   (fn by => Obtain (qs2, xs, label2, t, by), lfs2, gfs2)
   139                   (fn by => Obtain (qs2, xs, label2, t, by), lfs2, gfs2)
   144               | _ => error "Internal error: unmergeable Isar steps" )
   140               | _ => error "sledgehammer_shrink: unmergeable Isar steps" )
   145             val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
   141             val lfs = remove (op =) label1 lfs2 |> union (op =) lfs1
   146             val gfs = union (op =) gfs1 gfs2
   142             val gfs = union (op =) gfs1 gfs2
   147           in step_constructor (By_Metis (lfs, gfs)) end
   143           in step_constructor (By_Metis (lfs, gfs)) end
   148         | merge _ _ = error "Internal error: unmergeable Isar steps"
   144         | merge _ _ = error "sledgehammer_shrink: unmergeable Isar steps"
   149 
   145 
   150       fun try_merge metis_time (s1, i) (s2, j) =
   146       fun try_merge metis_time (s1, i) (s2, j) =
   151         (case get i metis_time |> Lazy.force of
   147         if not preplay then (merge s1 s2 |> SOME, metis_time)
   152           NONE => (NONE, metis_time)
   148         else
   153         | SOME t1 =>
   149           (case get i metis_time |> Lazy.force of
   154           (case get j metis_time |> Lazy.force of
   150             (true, _) => (NONE, metis_time)
   155             NONE => (NONE, metis_time)
   151           | (_, t1) =>
   156           | SOME t2 =>
   152             (case get j metis_time |> Lazy.force of
   157             let
   153               (true, _) => (NONE, metis_time)
   158               val s12 = merge s1 s2
   154             | (_, t2) =>
   159               val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
   155               let
   160             in
   156                 val s12 = merge s1 s2
   161               case try_metis_quietly debug type_enc lam_trans ctxt timeout
   157                 val timeout = time_mult merge_timeout_slack (Time.+(t1, t2))
   162               (NONE, s12) () of
   158               in
   163                 NONE => (NONE, metis_time)
   159                 case try_metis_quietly debug type_enc lam_trans ctxt timeout
   164               | some_t12 =>
   160                 (NONE, s12) () of
   165                 (SOME s12, metis_time
   161                   (true, _) => (NONE, metis_time)
   166                            |> replace (Time.zeroTime |> SOME |> Lazy.value) i
   162                 | exact_time =>
   167                            |> replace (Lazy.value some_t12) j)
   163                   (SOME s12, metis_time
   168 
   164                              |> replace (zero_preplay_time |> Lazy.value) i
   169             end))
   165                              |> replace (Lazy.value exact_time) j)
       
   166 
       
   167               end))
   170 
   168 
   171       fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
   169       fun merge_steps metis_time proof_vect refed_by cand_tab n' n_metis' =
   172         if Inttab.is_empty cand_tab
   170         if Inttab.is_empty cand_tab
   173           orelse n_metis' <= target_n_metis
   171           orelse n_metis' <= target_n_metis
   174           orelse (on_top_level andalso n'<3)
   172           orelse (on_top_level andalso n'<3)
   223         (* Shrink case_splits and top-level *)
   221         (* Shrink case_splits and top-level *)
   224         val ((proof, top_level_time), lower_level_time) =
   222         val ((proof, top_level_time), lower_level_time) =
   225           proof |> do_case_splits rich_ctxt
   223           proof |> do_case_splits rich_ctxt
   226                 |>> shrink_top_level on_top_level rich_ctxt
   224                 |>> shrink_top_level on_top_level rich_ctxt
   227       in
   225       in
   228         (proof, ext_time_add lower_level_time top_level_time)
   226         (proof, add_preplay_time lower_level_time top_level_time)
   229       end
   227       end
   230 
   228 
   231     and do_case_splits ctxt proof =
   229     and do_case_splits ctxt proof =
   232       let
   230       let
   233         fun shrink_each_and_collect_time shrink candidates =
   231         fun shrink_each_and_collect_time shrink candidates =
   234           let fun f_m cand time = shrink cand ||> ext_time_add time
   232           let fun f_m cand time = shrink cand ||> add_preplay_time time
   235           in fold_map f_m candidates no_time end
   233           in fold_map f_m candidates zero_preplay_time end
   236         val shrink_case_split =
   234         val shrink_case_split =
   237           shrink_each_and_collect_time (do_proof false ctxt)
   235           shrink_each_and_collect_time (do_proof false ctxt)
   238         fun shrink (Prove (qs, l, t, Case_Split (cases, facts))) =
   236         fun shrink (Prove (qs, l, t, Case_Split (cases, facts))) =
   239             let val (cases, time) = shrink_case_split cases
   237             let val (cases, time) = shrink_case_split cases
   240             in (Prove (qs, l, t, Case_Split (cases, facts)), time) end
   238             in (Prove (qs, l, t, Case_Split (cases, facts)), time) end
   241           | shrink step = (step, no_time)
   239           | shrink step = (step, zero_preplay_time)
   242       in
   240       in
   243         shrink_each_and_collect_time shrink proof
   241         shrink_each_and_collect_time shrink proof
   244       end
   242       end
   245   in
   243   in
   246     do_proof true ctxt proof
   244     do_proof true ctxt proof