src/HOL/Tools/Function/context_tree.ML
author wenzelm
Fri Jul 17 23:11:40 2009 +0200 (2009-07-17)
changeset 32035 8e77b6a250d5
parent 31775 2b04504fcb69
child 33037 b22e44496dc2
permissions -rw-r--r--
tuned/modernized Envir.subst_XXX;
     1 (*  Title:      HOL/Tools/Function/context_tree.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A package for general recursive function definitions. 
     5 Builds and traverses trees of nested contexts along a term.
     6 *)
     7 
     8 signature FUNDEF_CTXTREE =
     9 sig
    10     type ctxt = (string * typ) list * thm list (* poor man's contexts: fixes + assumes *)
    11     type ctx_tree
    12 
    13     (* FIXME: This interface is a mess and needs to be cleaned up! *)
    14     val get_fundef_congs : Proof.context -> thm list
    15     val add_fundef_cong : thm -> Context.generic -> Context.generic
    16     val map_fundef_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    17 
    18     val cong_add: attribute
    19     val cong_del: attribute
    20 
    21     val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    22 
    23     val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree
    24 
    25     val export_term : ctxt -> term -> term
    26     val export_thm : theory -> ctxt -> thm -> thm
    27     val import_thm : theory -> ctxt -> thm -> thm
    28 
    29     val traverse_tree : 
    30    (ctxt -> term ->
    31    (ctxt * thm) list ->
    32    (ctxt * thm) list * 'b ->
    33    (ctxt * thm) list * 'b)
    34    -> ctx_tree -> 'b -> 'b
    35 
    36     val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list
    37 end
    38 
    39 structure FundefCtxTree : FUNDEF_CTXTREE =
    40 struct
    41 
    42 type ctxt = (string * typ) list * thm list
    43 
    44 open FundefCommon
    45 open FundefLib
    46 
    47 structure FundefCongs = GenericDataFun
    48 (
    49   type T = thm list
    50   val empty = []
    51   val extend = I
    52   fun merge _ = Thm.merge_thms
    53 );
    54 
    55 val get_fundef_congs = FundefCongs.get o Context.Proof
    56 val map_fundef_congs = FundefCongs.map
    57 val add_fundef_cong = FundefCongs.map o Thm.add_thm
    58 
    59 (* congruence rules *)
    60 
    61 val cong_add = Thm.declaration_attribute (map_fundef_congs o Thm.add_thm o safe_mk_meta_eq);
    62 val cong_del = Thm.declaration_attribute (map_fundef_congs o Thm.del_thm o safe_mk_meta_eq);
    63 
    64 
    65 type depgraph = int IntGraph.T
    66 
    67 datatype ctx_tree 
    68   = Leaf of term
    69   | Cong of (thm * depgraph * (ctxt * ctx_tree) list)
    70   | RCall of (term * ctx_tree)
    71 
    72 
    73 (* Maps "Trueprop A = B" to "A" *)
    74 val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop
    75 
    76 
    77 (*** Dependency analysis for congruence rules ***)
    78 
    79 fun branch_vars t = 
    80     let 
    81       val t' = snd (dest_all_all t)
    82       val (assumes, concl) = Logic.strip_horn t'
    83     in (fold Term.add_vars assumes [], Term.add_vars concl [])
    84     end
    85 
    86 fun cong_deps crule =
    87     let
    88       val num_branches = map_index (apsnd branch_vars) (prems_of crule)
    89     in
    90       IntGraph.empty
    91         |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
    92         |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) => 
    93                if i = j orelse null (c1 inter t2) 
    94                then I else IntGraph.add_edge_acyclic (i,j))
    95              num_branches num_branches
    96     end
    97     
    98 val default_congs = map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}] 
    99 
   100 
   101 
   102 (* Called on the INSTANTIATED branches of the congruence rule *)
   103 fun mk_branch ctx t = 
   104     let
   105       val (ctx', fixes, impl) = dest_all_all_ctx ctx t
   106       val (assms, concl) = Logic.strip_horn impl
   107     in
   108       (ctx', fixes, assms, rhs_of concl)
   109     end
   110     
   111 fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
   112     (let
   113        val thy = ProofContext.theory_of ctx
   114 
   115        val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   116        val (c, subs) = (concl_of r, prems_of r)
   117 
   118        val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
   119        val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_term subst) subs
   120        val inst = map (fn v =>
   121         (cterm_of thy (Var v), cterm_of thy (Envir.subst_term subst (Var v)))) (Term.add_vars c [])
   122      in
   123    (cterm_instantiate inst r, dep, branches)
   124      end
   125     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   126   | find_cong_rule _ _ _ [] _ = sys_error "Function/context_tree.ML: No cong rule found!"
   127 
   128 
   129 fun mk_tree fvar h ctxt t =
   130     let 
   131       val congs = get_fundef_congs ctxt
   132       val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs) (* FIXME: Save in theory *)
   133 
   134       fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   135         | matchcall _ = NONE
   136 
   137       fun mk_tree' ctx t =
   138           case matchcall t of
   139             SOME arg => RCall (t, mk_tree' ctx arg)
   140           | NONE => 
   141             if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   142             else 
   143               let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in
   144                 Cong (r, dep, 
   145                       map (fn (ctx', fixes, assumes, st) => 
   146                               ((fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes), 
   147                                mk_tree' ctx' st)) branches)
   148               end
   149     in
   150       mk_tree' ctxt t
   151     end
   152     
   153 
   154 fun inst_tree thy fvar f tr =
   155     let
   156       val cfvar = cterm_of thy fvar
   157       val cf = cterm_of thy f
   158                
   159       fun inst_term t = 
   160           subst_bound(f, abstract_over (fvar, t))
   161 
   162       val inst_thm = forall_elim cf o forall_intr cfvar 
   163 
   164       fun inst_tree_aux (Leaf t) = Leaf t
   165         | inst_tree_aux (Cong (crule, deps, branches)) =
   166           Cong (inst_thm crule, deps, map inst_branch branches)
   167         | inst_tree_aux (RCall (t, str)) =
   168           RCall (inst_term t, inst_tree_aux str)
   169       and inst_branch ((fxs, assms), str) = 
   170           ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms), inst_tree_aux str)
   171     in
   172       inst_tree_aux tr
   173     end
   174 
   175 
   176 (* Poor man's contexts: Only fixes and assumes *)
   177 fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)
   178 
   179 fun export_term (fixes, assumes) =
   180     fold_rev (curry Logic.mk_implies o prop_of) assumes 
   181  #> fold_rev (Logic.all o Free) fixes
   182 
   183 fun export_thm thy (fixes, assumes) =
   184     fold_rev (implies_intr o cprop_of) assumes
   185  #> fold_rev (forall_intr o cterm_of thy o Free) fixes
   186 
   187 fun import_thm thy (fixes, athms) =
   188     fold (forall_elim o cterm_of thy o Free) fixes
   189  #> fold Thm.elim_implies athms
   190 
   191 
   192 (* folds in the order of the dependencies of a graph. *)
   193 fun fold_deps G f x =
   194     let
   195       fun fill_table i (T, x) =
   196           case Inttab.lookup T i of
   197             SOME _ => (T, x)
   198           | NONE => 
   199             let
   200               val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   201               val (v, x'') = f (the o Inttab.lookup T') i x'
   202             in
   203               (Inttab.update (i, v) T', x'')
   204             end
   205             
   206       val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   207     in
   208       (Inttab.fold (cons o snd) T [], x)
   209     end
   210     
   211 fun traverse_tree rcOp tr =
   212     let 
   213   fun traverse_help ctx (Leaf _) _ x = ([], x)
   214     | traverse_help ctx (RCall (t, st)) u x =
   215       rcOp ctx t u (traverse_help ctx st u x)
   216     | traverse_help ctx (Cong (_, deps, branches)) u x =
   217       let
   218     fun sub_step lu i x =
   219         let
   220           val (ctx', subtree) = nth branches i
   221           val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   222           val (subs, x') = traverse_help (compose ctx ctx') subtree used x
   223           val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *)
   224         in
   225           (exported_subs, x')
   226         end
   227       in
   228         fold_deps deps sub_step x
   229           |> apfst flat
   230       end
   231     in
   232       snd o traverse_help ([], []) tr []
   233     end
   234 
   235 fun rewrite_by_tree thy h ih x tr =
   236     let
   237       fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
   238         | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   239           let
   240             val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   241                                                      
   242             val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   243                  |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   244                                                     (* (a, h a) : G   *)
   245             val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   246             val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *)
   247                      
   248             val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner
   249             val h_a_eq_f_a = eq RS eq_reflection
   250             val result = transitive h_a'_eq_h_a h_a_eq_f_a
   251           in
   252             (result, x')
   253           end
   254         | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   255           let
   256             fun sub_step lu i x =
   257                 let
   258                   val ((fixes, assumes), st) = nth branches i
   259                   val used = map lu (IntGraph.imm_succs deps i)
   260                              |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   261                              |> filter_out Thm.is_reflexive
   262 
   263                   val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes
   264                                  
   265                   val (subeq, x') = rewrite_help (fix @ fixes) (h_as @ assumes') x st
   266                   val subeq_exp = export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   267                 in
   268                   (subeq_exp, x')
   269                 end
   270                 
   271             val (subthms, x') = fold_deps deps sub_step x
   272           in
   273             (fold_rev (curry op COMP) subthms crule, x')
   274           end
   275     in
   276       rewrite_help [] [] x tr
   277     end
   278     
   279 end