src/HOL/Tools/Sledgehammer/sledgehammer_reconstruct.ML
changeset 50258 1c708d7728c7
parent 50257 bafbc4a3d976
child 50259 9c64a52ae499
equal deleted inserted replaced
50257:bafbc4a3d976 50258:1c708d7728c7
    51 open ATP_Problem
    51 open ATP_Problem
    52 open ATP_Proof
    52 open ATP_Proof
    53 open ATP_Problem_Generate
    53 open ATP_Problem_Generate
    54 open ATP_Proof_Reconstruct
    54 open ATP_Proof_Reconstruct
    55 open Sledgehammer_Util
    55 open Sledgehammer_Util
       
    56 open Sledgehammer_Annotate
    56 
    57 
    (* Instance of the ATP_Proof_Redirect functor keyed by step_name.  Keys are
       ordered by their first (string) component only, using fast_string_ord;
       the string list component is ignored by the ordering.  string_of likewise
       returns just the first component of a key. *)
    57 structure String_Redirect = ATP_Proof_Redirect(
    58 structure String_Redirect = ATP_Proof_Redirect(
    58   type key = step_name
    59   type key = step_name
    59   val ord = fn ((s, _ : string list), (s', _)) => fast_string_ord (s, s')
    60   val ord = fn ((s, _ : string list), (s', _)) => fast_string_ord (s, s')
    60   val string_of = fst)
    61   val string_of = fst)
   474          (* kill next to last line, which usually results in a trivial step *)
   475          (* kill next to last line, which usually results in a trivial step *)
   475          j <> 1) then
   476          j <> 1) then
   476        Inference_Step (name, role, t, rule, deps) :: lines  (* keep line *)
   477        Inference_Step (name, role, t, rule, deps) :: lines  (* keep line *)
   477      else
   478      else
   478        map (replace_dependencies_in_line (name, deps)) lines)  (* drop line *)
   479        map (replace_dependencies_in_line (name, deps)) lines)  (* drop line *)
   479 
       
   480 (** Type annotations **)
       
   481 
       
   (* Post-order traversal of a term that threads a state through all subterms.

      post_traverse_term_type' f env t s rebuilds t bottom-up.  For every
      subterm it calls f on the (already rebuilt) subterm, the type computed for
      that subterm and the current state, and returns whatever f returns, i.e. a
      value of the form ((term, state), typ).  For u $ v the head u is visited
      before the argument v; for an abstraction the body is visited before the
      abstraction itself.

      env lists the types of the enclosing binders, innermost first, so that
      Bound i is typed as nth env i.

      The type handed to f for u $ v is the second argument of the binary type
      that the traversal returned for u.  The previous version extracted it with
      a refutable val pattern, so any other shape of type made the function die
      with an anonymous Bind exception; that case now raises TERM with the
      offending application attached.  Inputs that were accepted before are
      handled exactly as before. *)
   fun post_traverse_term_type' f _ (t as Const (_, T)) s = f t T s
     | post_traverse_term_type' f _ (t as Free (_, T)) s = f t T s
     | post_traverse_term_type' f _ (t as Var (_, T)) s = f t T s
     | post_traverse_term_type' f env (t as Bound i) s = f t (nth env i) s
     | post_traverse_term_type' f env (Abs (x, T1, b)) s =
       let
         (* the bound variable's type is pushed for the duration of the body *)
         val ((b', s'), T2) = post_traverse_term_type' f (T1 :: env) b s
       in f (Abs (x, T1, b')) (T1 --> T2) s' end
     | post_traverse_term_type' f env (u $ v) s =
       (case post_traverse_term_type' f env u s of
         ((u', s'), Type (_, [_, T])) =>
           let
             (* the argument is traversed with the state left behind by the head *)
             val ((v', s''), _) = post_traverse_term_type' f env v s'
           in f (u' $ v') T s'' end
       | _ =>
           raise TERM
             ("post_traverse_term_type': cannot determine result type of application head",
              [u $ v]))
       
   495 
       
   (* Traverse t in post-order starting from state s.  f receives each rebuilt
      subterm, its type and the current state, and returns the replacement
      subterm paired with the new state.  The result is the pair returned by f
      for the root of t. *)
   fun post_traverse_term_type f s t =
     let
       (* adapt f to the primed traversal, which also wants the type back *)
       fun visit subterm typ state = (f subterm typ state, typ)
       val (result, _) = post_traverse_term_type' visit [] t s
     in result end
       
   (* Fold f over all subterms of t in post-order, starting from state s.  f
      receives each subterm, its type and the current state, and returns the
      new state; the term itself is left unchanged.  Returns the final state. *)
   fun post_fold_term_type f s t =
     let
       fun keep_term subterm typ state = (subterm, f subterm typ state)
     in snd (post_traverse_term_type keep_term s t) end
       
   500 
       
   501 (* Data structures, orders *)
       
   (* Lexicographic order on costs of the form (int, (int, int)). *)
   val cost_ord =
     let val tail_ord = prod_ord int_ord int_ord
     in prod_ord int_ord tail_ord end
       
   503 
       
   (* Tables keyed by lists of index names, compared with list_ord over
      Term_Ord.fast_indexname_ord. *)
   structure Var_Set_Tab = Table
   (
     type key = indexname list
     val ord = list_ord Term_Ord.fast_indexname_ord
   )
       
   507 
       
   508 (* (1) Generalize types *)
       
   (* Replace every type in t by dummyT and re-check the term in a copy of ctxt
      switched to Proof_Context.mode_pattern. *)
   fun generalize_types ctxt t =
     let
       val pattern_ctxt = Proof_Context.set_mode Proof_Context.mode_pattern ctxt
       val untyped = map_types (K dummyT) t
     in Syntax.check_term pattern_ctxt untyped end
       
   513 
       
   514 (* (2) Typing-spot table *)
       
   local

   (* insert the index name of a type variable into an ordered list; every other
      atomic type leaves the list untouched *)
   fun add_tvar (TVar (xi, _)) = Ord_List.insert Term_Ord.fast_indexname_ord xi
     | add_tvar _ = I

   (* ordered list of the index names of all type variables occurring in T *)
   fun tvars_of T = fold_atyps add_tvar T []

   (* Visit one subterm t of type T at post-order position pos.  If T mentions
      type variables, the table entry for that set of variables is created, or
      replaced when this spot is strictly cheaper under cost_ord.  The position
      counter advances in every case. *)
   fun record_spot t T (tab, pos) =
     let
       val tvars = tvars_of T
       val tab' =
         if null tvars then tab
         else
           let val cost = (size_of_typ T, (size_of_term t, pos)) in
             (case Var_Set_Tab.lookup tab tvars of
               SOME best =>
                 if cost_ord (cost, best) = LESS
                 then Var_Set_Tab.update (tvars, cost) tab
                 else tab
             | NONE => Var_Set_Tab.update_new (tvars, cost) tab)
           end
     in (tab', pos + 1) end

   in

   (* Map each set of type variables occurring in the type of some subterm to
      the cheapest spot (type size, (term size, position)) exhibiting it. *)
   val typing_spot_table =
     fst o post_fold_term_type record_spot (Var_Set_Tab.empty, 0)

   end
       
   537 
       
   538 (* (3) Reverse-greedy *)
       
   (* Select positions from a typing-spot table.  Spots are considered in order
      of decreasing cost; a spot is dropped when each of its type variables is
      still counted more than once among the spots not yet dropped, otherwise
      its position is kept.  Returns the list of kept positions. *)
   fun reverse_greedy typing_spot_tab =
     let
       (* add delta to the counter of every listed type variable; a variable
          without an entry starts from 0 *)
       fun shift_counts delta tvars counts =
         let
           fun shift v tab =
             (case Vartab.lookup tab v of
               NONE => Vartab.update (v, delta) tab
             | SOME n => Vartab.update (v, n + delta) tab)
         in fold shift tvars counts end

       (* true iff no listed variable has a counter of 1 or less *)
       fun all_covered_elsewhere counts tvars =
         not (exists (fn v => the (Vartab.lookup counts v) <= 1) tvars)

       fun consider (tvars, (_, (_, spot))) (kept, counts) =
         if all_covered_elsewhere counts tvars
         then (kept, shift_counts ~1 tvars counts)
         else (spot :: kept, counts)

       (* collect all table entries and count, per type variable, the number of
          entries whose key mentions it *)
       val (all_spots, initial_counts) =
         Var_Set_Tab.fold
           (fn entry as (tvars, _) => fn (entries, counts) =>
             (entry :: entries, shift_counts 1 tvars counts))
           typing_spot_tab ([], Vartab.empty)

       val by_decreasing_cost =
         sort_distinct (rev_order o cost_ord o pairself snd) all_spots

       val (kept_spots, _) = fold consider by_decreasing_cost ([], initial_counts)
     in kept_spots end
       
   557 
       
   558 (* (4) Introduce annotations *)
       
   (* Wrap selected subterms of t in type constraints.

      spots lists post-order positions, counted from 0 in the order in which
      post_fold_term_type / post_traverse_term_type visit subterms.  post1 and
      post2 only ever compare the current position with the head of the
      remaining list, so the positions must arrive in increasing order
      (annotate_types sorts them with int_ord before calling).  The constraint
      types are computed while walking t'; the constraints themselves are
      attached at the same positions of t. *)
   559 fun introduce_annotations ctxt spots t t' =
       
   560   let
       
   561     val thy = Proof_Context.theory_of ctxt
       
           (* types of all subterms; cons accumulates, so the list comes out in
              reverse visiting order *)
   562     val get_types = post_fold_term_type (K cons) []
       
           (* fold Sign.typ_match over the position-wise pairs of subterm types
              of the two given terms, starting from the empty environment *)
   563     fun match_types tp =
       
   564       fold (Sign.typ_match thy) (op ~~ (pairself get_types tp)) Vartab.empty
       
           (* unica' expects a sorted list and keeps the elements that occur
              exactly once; b is true while x has been seen only once so far *)
   565     fun unica' b x [] = if b then [x] else []
       
   566       | unica' b x (y :: ys) =
       
   567         if x = y then unica' false x ys
       
   568         else unica' true y ys |> b ? cons x
       
           (* elements of xs that occur exactly once, in the order given by ord *)
   569     fun unica ord xs =
       
   570       case sort ord xs of x :: ys => unica' true x ys | [] => []
       
           (* prepend the names of all TFrees occurring in a type *)
   571     val add_all_tfree_namesT = fold_atyps (fn TFree (x, _) => cons x | _ => I)
       
           (* In the types stored in env, replace by dummyT every TFree that is
              not declared in ctxt and whose name occurs exactly once across all
              of those types. *)
   572     fun erase_unica_tfrees env =
       
   573       let
       
               (* this val shadows the function unica only after its own
                  right-hand side, which still refers to the function *)
   574         val unica =
       
   575           Vartab.fold (add_all_tfree_namesT o snd o snd) env []
       
   576           |> filter_out (Variable.is_declared ctxt)
       
   577           |> unica fast_string_ord
       
   578         val erase_unica = map_atyps
       
   579           (fn T as TFree (s, _) =>
       
   580               if Ord_List.member fast_string_ord unica s then dummyT else T
       
   581             | T => T)
       
   582       in Vartab.map (K (apsnd erase_unica)) env end
       
           (* types of t' are paired with the types of t at the same positions *)
   583     val env = match_types (t', t) |> erase_unica_tfrees
       
           (* get_annot env T returns (flag, (env', T')).  flag tells whether T'
              carries any information, i.e. is not just dummyT.  A TVar is
              replaced by its substitution under env; if that is not dummyT, the
              variable is rebound to dummyT in env' so that later occurrences
              yield no further annotation.  TFrees always yield dummyT.  The
              arguments of a type constructor are processed from right to left
              (fold_rev), threading env through all of them. *)
   584     fun get_annot env (TFree _) = (false, (env, dummyT))
       
   585       | get_annot env (T as TVar (v, S)) =
       
   586         let val T' = Envir.subst_type env T in
       
   587           if T' = dummyT then (false, (env, dummyT))
       
   588           else (true, (Vartab.update (v, (S, dummyT)) env, T'))
       
   589         end
       
   590       | get_annot env (Type (S, Ts)) =
       
   591         (case fold_rev (fn T => fn (b, (env, Ts)) =>
       
   592                   let
       
   593                     val (b', (env', T)) = get_annot env T
       
   594                   in (b orelse b', (env', T :: Ts)) end)
       
   595                 Ts (false, (env, [])) of
       
   596            (true, (env', Ts)) => (true, (env', Type (S, Ts)))
       
   597          | (false, (env', _)) => (false, (env', dummyT)))
       
           (* First pass, over t': cp is the current position, ps the spots not
              yet reached.  When the next spot equals cp, the annotation type
              for this position is computed and the spot is consumed.  The flag
              returned by get_annot is discarded, so a spot may be paired with
              dummyT.  Once no spots remain the accumulator is passed through
              unchanged. *)
   598     fun post1 _ T (env, cp, ps as p :: ps', annots) =
       
   599         if p <> cp then
       
   600           (env, cp + 1, ps, annots)
       
   601         else
       
   602           let val (_, (env', T')) = get_annot env T in
       
   603             (env', cp + 1, ps', (p, T') :: annots)
       
   604           end
       
   605       | post1 _ _ accum = accum
       
   606     val (_, _, _, annots) = post_fold_term_type post1 (env, 0, spots, []) t'
       
           (* Second pass, over t: wrap the subterm at each recorded position in
              Type.constraint with the type computed by the first pass. *)
   607     fun post2 t _ (cp, annots as (p, T) :: annots') =
       
   608         if p <> cp then (t, (cp + 1, annots))
       
   609         else (Type.constraint T t, (cp + 1, annots'))
       
   610       | post2 t _ x = (t, x)
       
         (* annots was accumulated by consing, hence the rev to restore
            increasing position order *)
   611   in post_traverse_term_type post2 (0, rev annots) t |> fst end
       
   612 
       
   613 (* (5) Annotate *)
       
   (* Entry point: generalize the types of t, build the typing-spot table of
      the generalized term, select spots with reverse_greedy, sort the selected
      positions in increasing order and insert the corresponding type
      constraints into t. *)
   fun annotate_types ctxt t =
     let
       val generalized = generalize_types ctxt t
       val spot_tab = typing_spot_table generalized
       val chosen_spots = sort int_ord (reverse_greedy spot_tab)
     in introduce_annotations ctxt chosen_spots t generalized end
       
   622 
   480 
   (* indent_size is the integer 2; presumably the number of spaces per
      nesting level in printed proof text -- its uses are not visible here,
      confirm against them. *)
   623 val indent_size = 2
   481 val indent_size = 2
       (* no_label pairs the empty string with ~1; presumably a sentinel for a
          step without a label -- confirm against its uses. *)
   624 val no_label = ("", ~1)
   482 val no_label = ("", ~1)
   625 
   483 
   626 fun string_for_proof ctxt type_enc lam_trans i n =
   484 fun string_for_proof ctxt type_enc lam_trans i n =