src/HOL/Tools/Function/context_tree.ML
changeset 34232 36a2a3029fd3
parent 33519 e31a85f92ce9
child 35403 25a67a606782
equal deleted inserted replaced
34231:da4d7d40f2f9 34232:36a2a3029fd3
     1 (*  Title:      HOL/Tools/Function/context_tree.ML
     1 (*  Title:      HOL/Tools/Function/context_tree.ML
     2     Author:     Alexander Krauss, TU Muenchen
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     3 
     4 A package for general recursive function definitions. 
     4 A package for general recursive function definitions.
     5 Builds and traverses trees of nested contexts along a term.
     5 Builds and traverses trees of nested contexts along a term.
     6 *)
     6 *)
     7 
     7 
     8 signature FUNCTION_CTXTREE =
     8 signature FUNCTION_CTXTREE =
     9 sig
     9 sig
    10     type ctxt = (string * typ) list * thm list (* poor man's contexts: fixes + assumes *)
    10   (* poor man's contexts: fixes + assumes *)
    11     type ctx_tree
    11   type ctxt = (string * typ) list * thm list
    12 
    12   type ctx_tree
    13     (* FIXME: This interface is a mess and needs to be cleaned up! *)
    13 
    14     val get_function_congs : Proof.context -> thm list
    14   (* FIXME: This interface is a mess and needs to be cleaned up! *)
    15     val add_function_cong : thm -> Context.generic -> Context.generic
    15   val get_function_congs : Proof.context -> thm list
    16     val map_function_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    16   val add_function_cong : thm -> Context.generic -> Context.generic
    17 
    17   val map_function_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    18     val cong_add: attribute
    18 
    19     val cong_del: attribute
    19   val cong_add: attribute
    20 
    20   val cong_del: attribute
    21     val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    21 
    22 
    22   val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    23     val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree
    23 
    24 
    24   val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree
    25     val export_term : ctxt -> term -> term
    25 
    26     val export_thm : theory -> ctxt -> thm -> thm
    26   val export_term : ctxt -> term -> term
    27     val import_thm : theory -> ctxt -> thm -> thm
    27   val export_thm : theory -> ctxt -> thm -> thm
    28 
    28   val import_thm : theory -> ctxt -> thm -> thm
    29     val traverse_tree : 
    29 
       
    30   val traverse_tree :
    30    (ctxt -> term ->
    31    (ctxt -> term ->
    31    (ctxt * thm) list ->
    32    (ctxt * thm) list ->
    32    (ctxt * thm) list * 'b ->
    33    (ctxt * thm) list * 'b ->
    33    (ctxt * thm) list * 'b)
    34    (ctxt * thm) list * 'b)
    34    -> ctx_tree -> 'b -> 'b
    35    -> ctx_tree -> 'b -> 'b
    35 
    36 
    36     val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list
    37   val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list ->
       
    38     ctx_tree -> thm * (thm * thm) list
    37 end
    39 end
    38 
    40 
    39 structure Function_Ctx_Tree : FUNCTION_CTXTREE =
    41 structure Function_Ctx_Tree : FUNCTION_CTXTREE =
    40 struct
    42 struct
    41 
    43 
    62 val cong_del = Thm.declaration_attribute (map_function_congs o Thm.del_thm o safe_mk_meta_eq);
    64 val cong_del = Thm.declaration_attribute (map_function_congs o Thm.del_thm o safe_mk_meta_eq);
    63 
    65 
    64 
    66 
    65 type depgraph = int IntGraph.T
    67 type depgraph = int IntGraph.T
    66 
    68 
    67 datatype ctx_tree 
    69 datatype ctx_tree =
    68   = Leaf of term
    70   Leaf of term
    69   | Cong of (thm * depgraph * (ctxt * ctx_tree) list)
    71   | Cong of (thm * depgraph * (ctxt * ctx_tree) list)
    70   | RCall of (term * ctx_tree)
    72   | RCall of (term * ctx_tree)
    71 
    73 
    72 
    74 
    73 (* Maps "Trueprop A = B" to "A" *)
    75 (* Maps "Trueprop A = B" to "A" *)
    74 val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop
    76 val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop
    75 
    77 
    76 
    78 
    77 (*** Dependency analysis for congruence rules ***)
    79 (*** Dependency analysis for congruence rules ***)
    78 
    80 
    79 fun branch_vars t = 
    81 fun branch_vars t =
    80     let 
    82   let
    81       val t' = snd (dest_all_all t)
    83     val t' = snd (dest_all_all t)
    82       val (assumes, concl) = Logic.strip_horn t'
    84     val (assumes, concl) = Logic.strip_horn t'
    83     in (fold Term.add_vars assumes [], Term.add_vars concl [])
    85   in
       
    86     (fold Term.add_vars assumes [], Term.add_vars concl [])
       
    87   end
       
    88 
       
    89 fun cong_deps crule =
       
    90   let
       
    91     val num_branches = map_index (apsnd branch_vars) (prems_of crule)
       
    92   in
       
    93     IntGraph.empty
       
    94     |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
       
    95     |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) =>
       
    96          if i = j orelse null (inter (op =) c1 t2)
       
    97          then I else IntGraph.add_edge_acyclic (i,j))
       
    98        num_branches num_branches
    84     end
    99     end
    85 
   100 
    86 fun cong_deps crule =
   101 val default_congs =
    87     let
   102   map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}]
    88       val num_branches = map_index (apsnd branch_vars) (prems_of crule)
       
    89     in
       
    90       IntGraph.empty
       
    91         |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
       
    92         |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) => 
       
    93                if i = j orelse null (inter (op =) c1 t2)
       
    94                then I else IntGraph.add_edge_acyclic (i,j))
       
    95              num_branches num_branches
       
    96     end
       
    97     
       
    98 val default_congs = map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}] 
       
    99 
       
   100 
       
   101 
   103 
   102 (* Called on the INSTANTIATED branches of the congruence rule *)
   104 (* Called on the INSTANTIATED branches of the congruence rule *)
   103 fun mk_branch ctx t = 
   105 fun mk_branch ctx t =
   104     let
   106   let
   105       val (ctx', fixes, impl) = dest_all_all_ctx ctx t
   107     val (ctx', fixes, impl) = dest_all_all_ctx ctx t
   106       val (assms, concl) = Logic.strip_horn impl
   108     val (assms, concl) = Logic.strip_horn impl
   107     in
   109   in
   108       (ctx', fixes, assms, rhs_of concl)
   110     (ctx', fixes, assms, rhs_of concl)
   109     end
   111   end
   110     
   112 
   111 fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
   113 fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
   112     (let
   114   (let
   113        val thy = ProofContext.theory_of ctx
   115      val thy = ProofContext.theory_of ctx
   114 
   116 
   115        val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   117      val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   116        val (c, subs) = (concl_of r, prems_of r)
   118      val (c, subs) = (concl_of r, prems_of r)
   117 
   119 
   118        val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
   120      val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
   119        val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_term subst) subs
   121      val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_term subst) subs
   120        val inst = map (fn v =>
   122      val inst = map (fn v =>
   121         (cterm_of thy (Var v), cterm_of thy (Envir.subst_term subst (Var v)))) (Term.add_vars c [])
   123        (cterm_of thy (Var v), cterm_of thy (Envir.subst_term subst (Var v)))) (Term.add_vars c [])
   122      in
   124    in
   123    (cterm_instantiate inst r, dep, branches)
   125      (cterm_instantiate inst r, dep, branches)
   124      end
   126    end
   125     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   127    handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   126   | find_cong_rule _ _ _ [] _ = sys_error "Function/context_tree.ML: No cong rule found!"
   128   | find_cong_rule _ _ _ [] _ = sys_error "Function/context_tree.ML: No cong rule found!"
   127 
   129 
   128 
   130 
   129 fun mk_tree fvar h ctxt t =
   131 fun mk_tree fvar h ctxt t =
   130     let 
   132   let
   131       val congs = get_function_congs ctxt
   133     val congs = get_function_congs ctxt
   132       val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs) (* FIXME: Save in theory *)
   134 
   133 
   135     (* FIXME: Save in theory: *)
   134       fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   136     val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs)
   135         | matchcall _ = NONE
   137 
   136 
   138     fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   137       fun mk_tree' ctx t =
   139       | matchcall _ = NONE
   138           case matchcall t of
   140 
   139             SOME arg => RCall (t, mk_tree' ctx arg)
   141     fun mk_tree' ctx t =
   140           | NONE => 
   142       case matchcall t of
   141             if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   143         SOME arg => RCall (t, mk_tree' ctx arg)
   142             else 
   144       | NONE =>
   143               let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in
   145         if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   144                 Cong (r, dep, 
   146         else
   145                       map (fn (ctx', fixes, assumes, st) => 
   147           let
   146                               ((fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes), 
   148             val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t
   147                                mk_tree' ctx' st)) branches)
   149             fun subtree (ctx', fixes, assumes, st) =
   148               end
   150               ((fixes,
   149     in
   151                 map (assume o cterm_of (ProofContext.theory_of ctx)) assumes),
   150       mk_tree' ctxt t
   152                mk_tree' ctx' st)
   151     end
   153           in
   152     
   154             Cong (r, dep, map subtree branches)
       
   155           end
       
   156   in
       
   157     mk_tree' ctxt t
       
   158   end
   153 
   159 
   154 fun inst_tree thy fvar f tr =
   160 fun inst_tree thy fvar f tr =
   155     let
   161   let
   156       val cfvar = cterm_of thy fvar
   162     val cfvar = cterm_of thy fvar
   157       val cf = cterm_of thy f
   163     val cf = cterm_of thy f
   158                
   164 
   159       fun inst_term t = 
   165     fun inst_term t =
   160           subst_bound(f, abstract_over (fvar, t))
   166       subst_bound(f, abstract_over (fvar, t))
   161 
   167 
   162       val inst_thm = forall_elim cf o forall_intr cfvar 
   168     val inst_thm = forall_elim cf o forall_intr cfvar
   163 
   169 
   164       fun inst_tree_aux (Leaf t) = Leaf t
   170     fun inst_tree_aux (Leaf t) = Leaf t
   165         | inst_tree_aux (Cong (crule, deps, branches)) =
   171       | inst_tree_aux (Cong (crule, deps, branches)) =
   166           Cong (inst_thm crule, deps, map inst_branch branches)
   172         Cong (inst_thm crule, deps, map inst_branch branches)
   167         | inst_tree_aux (RCall (t, str)) =
   173       | inst_tree_aux (RCall (t, str)) =
   168           RCall (inst_term t, inst_tree_aux str)
   174         RCall (inst_term t, inst_tree_aux str)
   169       and inst_branch ((fxs, assms), str) = 
   175     and inst_branch ((fxs, assms), str) =
   170           ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms), inst_tree_aux str)
   176       ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms),
   171     in
   177        inst_tree_aux str)
   172       inst_tree_aux tr
   178   in
   173     end
   179     inst_tree_aux tr
       
   180   end
   174 
   181 
   175 
   182 
   176 (* Poor man's contexts: Only fixes and assumes *)
   183 (* Poor man's contexts: Only fixes and assumes *)
   177 fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)
   184 fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)
   178 
   185 
   179 fun export_term (fixes, assumes) =
   186 fun export_term (fixes, assumes) =
   180     fold_rev (curry Logic.mk_implies o prop_of) assumes 
   187  fold_rev (curry Logic.mk_implies o prop_of) assumes
   181  #> fold_rev (Logic.all o Free) fixes
   188  #> fold_rev (Logic.all o Free) fixes
   182 
   189 
   183 fun export_thm thy (fixes, assumes) =
   190 fun export_thm thy (fixes, assumes) =
   184     fold_rev (implies_intr o cprop_of) assumes
   191  fold_rev (implies_intr o cprop_of) assumes
   185  #> fold_rev (forall_intr o cterm_of thy o Free) fixes
   192  #> fold_rev (forall_intr o cterm_of thy o Free) fixes
   186 
   193 
   187 fun import_thm thy (fixes, athms) =
   194 fun import_thm thy (fixes, athms) =
   188     fold (forall_elim o cterm_of thy o Free) fixes
   195  fold (forall_elim o cterm_of thy o Free) fixes
   189  #> fold Thm.elim_implies athms
   196  #> fold Thm.elim_implies athms
   190 
   197 
   191 
   198 
   192 (* folds in the order of the dependencies of a graph. *)
   199 (* folds in the order of the dependencies of a graph. *)
   193 fun fold_deps G f x =
   200 fun fold_deps G f x =
   194     let
   201   let
   195       fun fill_table i (T, x) =
   202     fun fill_table i (T, x) =
   196           case Inttab.lookup T i of
   203       case Inttab.lookup T i of
   197             SOME _ => (T, x)
   204         SOME _ => (T, x)
   198           | NONE => 
   205       | NONE =>
   199             let
   206         let
   200               val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   207           val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   201               val (v, x'') = f (the o Inttab.lookup T') i x'
   208           val (v, x'') = f (the o Inttab.lookup T') i x'
   202             in
   209         in
   203               (Inttab.update (i, v) T', x'')
   210           (Inttab.update (i, v) T', x'')
   204             end
   211         end
   205             
   212 
   206       val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   213     val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   207     in
   214   in
   208       (Inttab.fold (cons o snd) T [], x)
   215     (Inttab.fold (cons o snd) T [], x)
   209     end
   216   end
   210     
   217 
   211 fun traverse_tree rcOp tr =
   218 fun traverse_tree rcOp tr =
   212     let 
   219   let
   213   fun traverse_help ctx (Leaf _) _ x = ([], x)
   220     fun traverse_help ctx (Leaf _) _ x = ([], x)
   214     | traverse_help ctx (RCall (t, st)) u x =
   221       | traverse_help ctx (RCall (t, st)) u x =
   215       rcOp ctx t u (traverse_help ctx st u x)
   222         rcOp ctx t u (traverse_help ctx st u x)
   216     | traverse_help ctx (Cong (_, deps, branches)) u x =
   223       | traverse_help ctx (Cong (_, deps, branches)) u x =
   217       let
   224       let
   218     fun sub_step lu i x =
   225         fun sub_step lu i x =
   219         let
   226           let
   220           val (ctx', subtree) = nth branches i
   227             val (ctx', subtree) = nth branches i
   221           val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   228             val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   222           val (subs, x') = traverse_help (compose ctx ctx') subtree used x
   229             val (subs, x') = traverse_help (compose ctx ctx') subtree used x
   223           val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *)
   230             val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *)
   224         in
   231           in
   225           (exported_subs, x')
   232             (exported_subs, x')
   226         end
   233           end
   227       in
   234       in
   228         fold_deps deps sub_step x
   235         fold_deps deps sub_step x
   229           |> apfst flat
   236         |> apfst flat
   230       end
   237       end
   231     in
   238   in
   232       snd o traverse_help ([], []) tr []
   239     snd o traverse_help ([], []) tr []
   233     end
   240   end
   234 
   241 
   235 fun rewrite_by_tree thy h ih x tr =
   242 fun rewrite_by_tree thy h ih x tr =
   236     let
   243   let
   237       fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
   244     fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
   238         | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   245       | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   239           let
   246         let
   240             val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   247           val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   241                                                      
   248 
   242             val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   249           val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   243                  |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   250             |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   244                                                     (* (a, h a) : G   *)
   251                                                     (* (a, h a) : G   *)
   245             val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   252           val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   246             val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *)
   253           val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *)
   247                      
   254 
   248             val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner
   255           val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner
   249             val h_a_eq_f_a = eq RS eq_reflection
   256           val h_a_eq_f_a = eq RS eq_reflection
   250             val result = transitive h_a'_eq_h_a h_a_eq_f_a
   257           val result = transitive h_a'_eq_h_a h_a_eq_f_a
   251           in
   258         in
   252             (result, x')
   259           (result, x')
   253           end
   260         end
   254         | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   261       | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   255           let
   262         let
   256             fun sub_step lu i x =
   263           fun sub_step lu i x =
   257                 let
   264             let
   258                   val ((fixes, assumes), st) = nth branches i
   265               val ((fixes, assumes), st) = nth branches i
   259                   val used = map lu (IntGraph.imm_succs deps i)
   266               val used = map lu (IntGraph.imm_succs deps i)
   260                              |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   267                 |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   261                              |> filter_out Thm.is_reflexive
   268                 |> filter_out Thm.is_reflexive
   262 
   269 
   263                   val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes
   270               val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes
   264                                  
   271 
   265                   val (subeq, x') = rewrite_help (fix @ fixes) (h_as @ assumes') x st
   272               val (subeq, x') =
   266                   val subeq_exp = export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   273                 rewrite_help (fix @ fixes) (h_as @ assumes') x st
   267                 in
   274               val subeq_exp =
   268                   (subeq_exp, x')
   275                 export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   269                 end
   276             in
   270                 
   277               (subeq_exp, x')
   271             val (subthms, x') = fold_deps deps sub_step x
   278             end
   272           in
   279           val (subthms, x') = fold_deps deps sub_step x
   273             (fold_rev (curry op COMP) subthms crule, x')
   280         in
   274           end
   281           (fold_rev (curry op COMP) subthms crule, x')
   275     in
   282         end
   276       rewrite_help [] [] x tr
   283   in
   277     end
   284     rewrite_help [] [] x tr
   278     
   285   end
       
   286 
   279 end
   287 end