src/HOL/Tools/Function/function_context_tree.ML
author wenzelm
Sat Jul 25 23:41:53 2015 +0200 (2015-07-25)
changeset 60781 2da59cdf531c
parent 60695 757549b4bbe6
child 60801 7664e0916eec
permissions -rw-r--r--
updated to infer_instantiate;
tuned;
     1 (*  Title:      HOL/Tools/Function/function_context_tree.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Construction and traversal of trees of nested contexts along a term.
     5 *)
     6 
     7 signature FUNCTION_CONTEXT_TREE =
     8 sig
     9   (* poor man's contexts: fixes + assumes *)
    10   type ctxt = (string * typ) list * thm list
    11   type ctx_tree
    12 
    13   (* FIXME: This interface is a mess and needs to be cleaned up! *)
    14   val get_function_congs : Proof.context -> thm list
    15   val add_function_cong : thm -> Context.generic -> Context.generic
    16   val map_function_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    17 
    18   val cong_add: attribute
    19   val cong_del: attribute
    20 
    21   val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    22 
    23   val inst_tree: Proof.context -> term -> term -> ctx_tree -> ctx_tree
    24 
    25   val export_term : ctxt -> term -> term
    26   val export_thm : Proof.context -> ctxt -> thm -> thm
    27   val import_thm : Proof.context -> ctxt -> thm -> thm
    28 
    29   val traverse_tree :
    30    (ctxt -> term ->
    31    (ctxt * thm) list ->
    32    (ctxt * thm) list * 'b ->
    33    (ctxt * thm) list * 'b)
    34    -> ctx_tree -> 'b -> 'b
    35 
    36   val rewrite_by_tree : Proof.context -> term -> thm -> (thm * thm) list ->
    37     ctx_tree -> thm * (thm * thm) list
    38 end
    39 
    40 structure Function_Context_Tree : FUNCTION_CONTEXT_TREE =
    41 struct
       (* Context trees: a term is decomposed along the occurrences of a
          distinguished free variable fvar (presumably the function being
          defined).  Congruence rules determine, at each position, which
          variables are fixed and which assumptions are available there. *)
    42
       (* "Poor man's context": fixed variables as (name, type) pairs together
          with assumption theorems. *)
    43 type ctxt = (string * typ) list * thm list
    44
    45 open Function_Common
    46 open Function_Lib
    47
       (* Declared congruence rules, kept as Generic_Data (i.e. stored in a
          Context.generic); merging goes through Thm.merge_thms. *)
    48 structure FunctionCongs = Generic_Data
    49 (
    50   type T = thm list
    51   val empty = []
    52   val extend = I
    53   val merge = Thm.merge_thms
    54 );
    55
       (* Accessors for the declared rules.  NOTE: add_function_cong stores the
          theorem exactly as given, whereas cong_add below first passes it
          through safe_mk_meta_eq. *)
    56 val get_function_congs = FunctionCongs.get o Context.Proof
    57 val map_function_congs = FunctionCongs.map
    58 val add_function_cong = FunctionCongs.map o Thm.add_thm
    59
    60 (* congruence rules *)
    61
       (* Declaration attributes that insert / delete a rule after passing it
          through safe_mk_meta_eq (presumably normalising it to a
          meta-equation -- confirm in Function_Lib / Function_Common). *)
    62 val cong_add = Thm.declaration_attribute (map_function_congs o Thm.add_thm o safe_mk_meta_eq);
    63 val cong_del = Thm.declaration_attribute (map_function_congs o Thm.del_thm o safe_mk_meta_eq);
    64
    65
       (* Dependencies between the branches (premises) of a congruence rule:
          nodes are premise indices, each labelled with its own index. *)
    66 type depgraph = int Int_Graph.T
    67
       (* Leaf t:          subterm in which fvar does not occur.
          Cong (r, G, bs): instantiated congruence rule r; branch i of bs
                           belongs to premise i of r and carries that premise's
                           context (fixes, assumptions) and the tree of its
                           right-hand side.  G orders the branches by their
                           dependencies (see cong_deps).
          RCall (t, st):   an application t = fvar $ arg; st is the tree of arg. *)
    68 datatype ctx_tree =
    69   Leaf of term
    70   | Cong of (thm * depgraph * (ctxt * ctx_tree) list)
    71   | RCall of (term * ctx_tree)
    72
    73
    74 (* Maps "Trueprop (A = B)" to "B", i.e. the right-hand side *)
    75 val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop
    76
    77
    78 (*** Dependency analysis for congruence rules ***)
    79
       (* For one premise of a congruence rule: strip the outer parameters (via
          dest_all_all, presumably removing the leading !!-binders) and return
          the pair (schematic vars of its assumptions, schematic vars of its
          conclusion). *)
    80 fun branch_vars t =
    81   let
    82     val t' = snd (dest_all_all t)
    83     val (assumes, concl) = Logic.strip_horn t'
    84   in
    85     (fold Term.add_vars assumes [], Term.add_vars concl [])
    86   end
    87
       (* Dependency graph of a congruence rule: one node per premise, and an
          edge i -> j (i <> j) whenever a schematic variable of the assumptions
          of premise i also occurs in the conclusion of premise j, so branch i
          must be processed after branch j.  add_edge_acyclic presumably fails
          on cyclic dependencies.
          NOTE(review): despite its name, num_branches is the indexed list of
          (assumption vars, conclusion vars) per premise, not a count. *)
    88 fun cong_deps crule =
    89   let
    90     val num_branches = map_index (apsnd branch_vars) (Thm.prems_of crule)
    91   in
    92     Int_Graph.empty
    93     |> fold (fn (i,_)=> Int_Graph.new_node (i,i)) num_branches
    94     |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) =>
    95          if i = j orelse null (inter (op =) c1 t2)
    96          then I else Int_Graph.add_edge_acyclic (i,j))
    97        num_branches num_branches
    98     end
    99
       (* Fallback rules: the theorems cong and ext, turned into meta-equations
          with eq_reflection.  mk_tree appends them after the declared rules, so
          find_cong_rule tries them last. *)
   100 val default_congs =
   101   map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}]
   102
   103 (* Called on the INSTANTIATED branches of the congruence rule *)
       (* Focuses on the outer parameters of premise t (Variable.focus) and
          returns (focused context, fixed parameters as (name, type) pairs,
          assumptions, right-hand side of the conclusion). *)
   104 fun mk_branch ctxt t =
   105   let
   106     val ((params, impl), ctxt') = Variable.focus NONE t ctxt
   107     val (assms, concl) = Logic.strip_horn impl
   108   in
   109     (ctxt', map #2 params, assms, rhs_of concl)
   110   end
   111
       (* Tries the rules (r, dep) in list order.  For each rule, the conclusion
          of r is matched against the meta-equation  t' == t, where t' is t with
          Free fvar rewritten to h.  On success it returns:
          - r instantiated by the match,
          - its dependency graph dep,
          - the premises, instantiated, beta-normalised and split by mk_branch.
          If matching raises Pattern.MATCH, the next rule is tried.  Raises Fail
          when no rule matches. *)
   112 fun find_cong_rule ctxt fvar h ((r,dep)::rs) t =
   113      (let
   114         val thy = Proof_Context.theory_of ctxt
   115
   116         val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   117         val (c, subs) = (Thm.concl_of r, Thm.prems_of r)
   118
   119         val subst = Pattern.match thy (c, tt') (Vartab.empty, Vartab.empty)
   120         val branches = map (mk_branch ctxt o Envir.beta_norm o Envir.subst_term subst) subs
   121         val inst =
   122           map (fn v => (#1 v, Thm.cterm_of ctxt (Envir.subst_term subst (Var v))))
   123             (Term.add_vars c [])
   124       in
   125         (infer_instantiate ctxt inst r, dep, branches)
   126       end handle Pattern.MATCH => find_cong_rule ctxt fvar h rs t)
   127   | find_cong_rule _ _ _ [] _ = raise General.Fail "No cong rule found!"
   128
   129
       (* Builds the context tree of t with respect to the free variable fvar:
          - applications  Free fvar $ arg  become RCall nodes;
          - subterms not containing fvar become Leaf nodes;
          - any other subterm is split by the first congruence rule whose
            conclusion matches (declared rules first, then default_congs).
          h is the term substituted for fvar when matching (see
          find_cong_rule). *)
   130 fun mk_tree fvar h ctxt t =
   131   let
   132     val congs = get_function_congs ctxt
   133
   134     (* FIXME: Save in theory: *)
   135     val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs)
   136
   137     fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   138       | matchcall _ = NONE
   139
   140     fun mk_tree' ctxt t =
   141       case matchcall t of
   142         SOME arg => RCall (t, mk_tree' ctxt arg)
   143       | NONE =>
   144         if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   145         else
   146           let
   147             val (r, dep, branches) = find_cong_rule ctxt fvar h congs_deps t
                   (* NOTE(review): the branch assumptions are certified in ctxt,
                      not in the focused ctxt' that fixes their parameters --
                      confirm this is intended. *)
   148             fun subtree (ctxt', fixes, assumes, st) =
   149               ((fixes,
   150                 map (Thm.assume o Thm.cterm_of ctxt) assumes),
   151                mk_tree' ctxt' st)
   152           in
   153             Cong (r, dep, map subtree branches)
   154           end
   155   in
   156     mk_tree' ctxt t
   157   end
   158
       (* Replaces the free variable fvar by the term f throughout a tree:
          - RCall terms and branch assumptions via abstract_over / subst_bound
            (the assumptions are re-assumed after instantiation);
          - congruence theorems via forall_intr / forall_elim.
          Leaf terms are left unchanged, presumably because trees built by
          mk_tree only have fvar-free leaves.  Branch fixes and dependency
          graphs are kept as they are. *)
   159 fun inst_tree ctxt fvar f tr =
   160   let
   161     val cfvar = Thm.cterm_of ctxt fvar
   162     val cf = Thm.cterm_of ctxt f
   163
   164     fun inst_term t =
   165       subst_bound(f, abstract_over (fvar, t))
   166
   167     val inst_thm = Thm.forall_elim cf o Thm.forall_intr cfvar
   168
   169     fun inst_tree_aux (Leaf t) = Leaf t
   170       | inst_tree_aux (Cong (crule, deps, branches)) =
   171         Cong (inst_thm crule, deps, map inst_branch branches)
   172       | inst_tree_aux (RCall (t, str)) =
   173         RCall (inst_term t, inst_tree_aux str)
   174     and inst_branch ((fxs, assms), str) =
   175       ((fxs, map (Thm.assume o Thm.cterm_of ctxt o inst_term o Thm.prop_of) assms),
   176        inst_tree_aux str)
   177   in
   178     inst_tree_aux tr
   179   end
   180
   181
   182 (* Poor man's contexts: Only fixes and assumes *)
       (* Componentwise concatenation; the first context's entries come first. *)
   183 fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)
   184
       (* Discharges a context from a term.  The assumptions' propositions become
          premises, with the first assumption outermost.  The fixes become
          universal quantifiers, with the first fix outermost. *)
   185 fun export_term (fixes, assumes) =
   186  fold_rev (curry Logic.mk_implies o Thm.prop_of) assumes
   187  #> fold_rev (Logic.all o Free) fixes
   188
       (* Theorem-level counterpart of export_term, using implies_intr and
          forall_intr. *)
   189 fun export_thm ctxt (fixes, assumes) =
   190  fold_rev (Thm.implies_intr o Thm.cprop_of) assumes
   191  #> fold_rev (Thm.forall_intr o Thm.cterm_of ctxt o Free) fixes
   192
       (* Inverse direction: instantiates the outer quantifiers with the fixed
          Frees (outermost first) and then discharges premises with the given
          theorems, in order, via Thm.elim_implies. *)
   193 fun import_thm ctxt (fixes, athms) =
   194  fold (Thm.forall_elim o Thm.cterm_of ctxt o Free) fixes
   195  #> fold Thm.elim_implies athms
   196
   197
   198 (* folds in the order of the dependencies of a graph. *)
       (* Node i is processed only after all of its immediate successors in G.
          f takes:
          - a lookup function for values already computed,
          - the node i,
          - the accumulator,
          and returns (value for i, new accumulator).  Values are memoised in
          an Inttab, so every node is processed once.  The result is the list
          of all node values together with the final accumulator.
          NOTE(review): assuming Inttab.fold visits keys in ascending order,
          cons produces the list in DESCENDING key order.  The
          fold_rev / COMP in rewrite_by_tree appears to rely on this, so
          confirm before changing. *)
   199 fun fold_deps G f x =
   200   let
   201     fun fill_table i (T, x) =
   202       case Inttab.lookup T i of
   203         SOME _ => (T, x)
   204       | NONE =>
   205         let
   206           val (T', x') = Int_Graph.Keys.fold fill_table (Int_Graph.imm_succs G i) (T, x)
   207           val (v, x'') = f (the o Inttab.lookup T') i x'
   208         in
   209           (Inttab.update (i, v) T', x'')
   210         end
   211
   212     val (T, x) = fold fill_table (Int_Graph.keys G) (Inttab.empty, x)
   213   in
   214     (Inttab.fold (cons o snd) T [], x)
   215   end
   216
       (* Generic bottom-up traversal.
          - Leaf: contributes no results.
          - RCall: the argument subtree is traversed first.  rcOp is then
            called with the accumulated context, the call term, the list u of
            results this position depends on, and the subtree's
            (results, state).
          - Cong: branches are visited in dependency order (fold_deps).  Each
            branch gets the results of its immediate successors prepended to
            u, and its own results are composed with the branch context.
          Only the final state ('b) is returned. *)
   217 fun traverse_tree rcOp tr =
   218   let
   219     fun traverse_help ctxt (Leaf _) _ x = ([], x)
   220       | traverse_help ctxt (RCall (t, st)) u x =
   221           rcOp ctxt t u (traverse_help ctxt st u x)
   222       | traverse_help ctxt (Cong (_, deps, branches)) u x =
   223           let
   224             fun sub_step lu i x =
   225               let
   226                 val (ctxt', subtree) = nth branches i
                       (* results of the branches this one depends on, prepended to u *)
   227                 val used = Int_Graph.Keys.fold_rev (append o lu) (Int_Graph.imm_succs deps i) u
   228                 val (subs, x') = traverse_help (compose ctxt ctxt') subtree used x
   229                 val exported_subs = map (apfst (compose ctxt')) subs (* FIXME: Right order of composition? *)
   230               in
   231                 (exported_subs, x')
   232               end
   233           in
   234             fold_deps deps sub_step x
   235             |> apfst flat
   236           end
   237   in
   238     snd o traverse_help ([], []) tr []
   239   end
   240
       (* Builds, bottom-up, a meta-equation for the term represented by tr.
          Presumably its left-hand side is that term with the recursive calls
          rewritten via h (compare the goal matched in find_cong_rule).
          - Leaf t: reflexivity t == t.
          - RCall: first rewrites the argument (inner : a' == a).  It then
            takes the next pair (lRi, ha) from x, imports ha into the current
            context and rewrites it with inner.  Finally it uses ih to obtain
            h a' == f a.
          - Cong: branches are processed in dependency order.  Each branch's
            assumptions are simplified with the reversed equations of the
            branches it depends on, and the branch is rewritten recursively.
            The result is exported over the branch's original context and
            resolved into the congruence rule with COMP.
          Returns (equation, unconsumed rest of x).  The val binding on
          (lRi,ha)::x' is non-exhaustive, so running out of pairs at an RCall
          raises Bind. *)
   241 fun rewrite_by_tree ctxt h ih x tr =
   242   let
   243     fun rewrite_help _ _ x (Leaf t) = (Thm.reflexive (Thm.cterm_of ctxt t), x)
   244       | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   245         let
   246           val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   247
   248           val iha = import_thm ctxt (fix, h_as) ha (* (a', h a') : G *)
   249             |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   250                                                     (* (a, h a) : G   *)
                 (* instantiate ih at arg (presumably its first schematic variable)
                    and discharge its two premises with lRi and iha *)
   251           val inst_ih = instantiate' [] [SOME (Thm.cterm_of ctxt arg)] ih
   252           val eq = Thm.implies_elim (Thm.implies_elim inst_ih lRi) iha (* h a = f a *)
   253
   254           val h_a'_eq_h_a = Thm.combination (Thm.reflexive (Thm.cterm_of ctxt h)) inner
   255           val h_a_eq_f_a = eq RS eq_reflection
   256           val result = Thm.transitive h_a'_eq_h_a h_a_eq_f_a
   257         in
   258           (result, x')
   259         end
   260       | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   261         let
   262           fun sub_step lu i x =
   263             let
   264               val ((fixes, assumes), st) = nth branches i
                     (* equations of the branches this one depends on, reversed
                        with sym and made meta-equations; trivial ones dropped *)
   265               val used = map lu (Int_Graph.immediate_succs deps i)
   266                 |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   267                 |> filter_out Thm.is_reflexive
   268
   269               val assumes' = map (simplify (put_simpset HOL_basic_ss ctxt addsimps used)) assumes
   270
   271               val (subeq, x') =
   272                 rewrite_help (fix @ fixes) (h_as @ assumes') x st
                     (* exported over the ORIGINAL (unsimplified) assumptions,
                        as an object-level equation *)
   273               val subeq_exp =
   274                 export_thm ctxt (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   275             in
   276               (subeq_exp, x')
   277             end
   278           val (subthms, x') = fold_deps deps sub_step x
   279         in
                 (* resolve the branch equations into the premises of crule; the
                    order of subthms matters here (see the note on fold_deps) *)
   280           (fold_rev (curry op COMP) subthms crule, x')
   281         end
   282   in
   283     rewrite_help [] [] x tr
   284   end
   285
   286 end