src/HOL/Tools/Function/context_tree.ML
changeset 34232 36a2a3029fd3
parent 33519 e31a85f92ce9
child 35403 25a67a606782
     1.1 --- a/src/HOL/Tools/Function/context_tree.ML	Sat Jan 02 23:18:58 2010 +0100
     1.2 +++ b/src/HOL/Tools/Function/context_tree.ML	Sat Jan 02 23:18:58 2010 +0100
     1.3 @@ -1,39 +1,41 @@
     1.4  (*  Title:      HOL/Tools/Function/context_tree.ML
     1.5      Author:     Alexander Krauss, TU Muenchen
     1.6  
     1.7 -A package for general recursive function definitions. 
     1.8 +A package for general recursive function definitions.
     1.9  Builds and traverses trees of nested contexts along a term.
    1.10  *)
    1.11  
    1.12  signature FUNCTION_CTXTREE =
    1.13  sig
    1.14 -    type ctxt = (string * typ) list * thm list (* poor man's contexts: fixes + assumes *)
    1.15 -    type ctx_tree
    1.16 +  (* poor man's contexts: fixes + assumes *)
    1.17 +  type ctxt = (string * typ) list * thm list
    1.18 +  type ctx_tree
    1.19  
    1.20 -    (* FIXME: This interface is a mess and needs to be cleaned up! *)
    1.21 -    val get_function_congs : Proof.context -> thm list
    1.22 -    val add_function_cong : thm -> Context.generic -> Context.generic
    1.23 -    val map_function_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    1.24 +  (* FIXME: This interface is a mess and needs to be cleaned up! *)
    1.25 +  val get_function_congs : Proof.context -> thm list
    1.26 +  val add_function_cong : thm -> Context.generic -> Context.generic
    1.27 +  val map_function_congs : (thm list -> thm list) -> Context.generic -> Context.generic
    1.28  
    1.29 -    val cong_add: attribute
    1.30 -    val cong_del: attribute
    1.31 +  val cong_add: attribute
    1.32 +  val cong_del: attribute
    1.33  
    1.34 -    val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    1.35 +  val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree
    1.36  
    1.37 -    val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree
    1.38 +  val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree
    1.39  
    1.40 -    val export_term : ctxt -> term -> term
    1.41 -    val export_thm : theory -> ctxt -> thm -> thm
    1.42 -    val import_thm : theory -> ctxt -> thm -> thm
    1.43 +  val export_term : ctxt -> term -> term
    1.44 +  val export_thm : theory -> ctxt -> thm -> thm
    1.45 +  val import_thm : theory -> ctxt -> thm -> thm
    1.46  
    1.47 -    val traverse_tree : 
    1.48 +  val traverse_tree :
    1.49     (ctxt -> term ->
    1.50     (ctxt * thm) list ->
    1.51     (ctxt * thm) list * 'b ->
    1.52     (ctxt * thm) list * 'b)
    1.53     -> ctx_tree -> 'b -> 'b
    1.54  
    1.55 -    val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list
    1.56 +  val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list ->
    1.57 +    ctx_tree -> thm * (thm * thm) list
    1.58  end
    1.59  
    1.60  structure Function_Ctx_Tree : FUNCTION_CTXTREE =
    1.61 @@ -64,8 +66,8 @@
    1.62  
    1.63  type depgraph = int IntGraph.T
    1.64  
    1.65 -datatype ctx_tree 
    1.66 -  = Leaf of term
    1.67 +datatype ctx_tree =
    1.68 +  Leaf of term
    1.69    | Cong of (thm * depgraph * (ctxt * ctx_tree) list)
    1.70    | RCall of (term * ctx_tree)
    1.71  
    1.72 @@ -76,204 +78,210 @@
    1.73  
    1.74  (*** Dependency analysis for congruence rules ***)
    1.75  
    1.76 -fun branch_vars t = 
    1.77 -    let 
    1.78 -      val t' = snd (dest_all_all t)
    1.79 -      val (assumes, concl) = Logic.strip_horn t'
    1.80 -    in (fold Term.add_vars assumes [], Term.add_vars concl [])
    1.81 -    end
    1.82 +fun branch_vars t =
    1.83 +  let
    1.84 +    val t' = snd (dest_all_all t)
    1.85 +    val (assumes, concl) = Logic.strip_horn t'
    1.86 +  in
    1.87 +    (fold Term.add_vars assumes [], Term.add_vars concl [])
    1.88 +  end
    1.89  
    1.90  fun cong_deps crule =
    1.91 -    let
    1.92 -      val num_branches = map_index (apsnd branch_vars) (prems_of crule)
    1.93 -    in
    1.94 -      IntGraph.empty
    1.95 -        |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
    1.96 -        |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) => 
    1.97 -               if i = j orelse null (inter (op =) c1 t2)
    1.98 -               then I else IntGraph.add_edge_acyclic (i,j))
    1.99 -             num_branches num_branches
   1.100 +  let
   1.101 +    val num_branches = map_index (apsnd branch_vars) (prems_of crule)
   1.102 +  in
   1.103 +    IntGraph.empty
   1.104 +    |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches
   1.105 +    |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) =>
   1.106 +         if i = j orelse null (inter (op =) c1 t2)
   1.107 +         then I else IntGraph.add_edge_acyclic (i,j))
   1.108 +       num_branches num_branches
   1.109      end
   1.110 -    
   1.111 -val default_congs = map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}] 
   1.112  
   1.113 -
   1.114 +val default_congs =
   1.115 +  map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}]
   1.116  
   1.117  (* Called on the INSTANTIATED branches of the congruence rule *)
   1.118 -fun mk_branch ctx t = 
   1.119 -    let
   1.120 -      val (ctx', fixes, impl) = dest_all_all_ctx ctx t
   1.121 -      val (assms, concl) = Logic.strip_horn impl
   1.122 -    in
   1.123 -      (ctx', fixes, assms, rhs_of concl)
   1.124 -    end
   1.125 -    
   1.126 +fun mk_branch ctx t =
   1.127 +  let
   1.128 +    val (ctx', fixes, impl) = dest_all_all_ctx ctx t
   1.129 +    val (assms, concl) = Logic.strip_horn impl
   1.130 +  in
   1.131 +    (ctx', fixes, assms, rhs_of concl)
   1.132 +  end
   1.133 +
   1.134  fun find_cong_rule ctx fvar h ((r,dep)::rs) t =
   1.135 -    (let
   1.136 -       val thy = ProofContext.theory_of ctx
   1.137 +  (let
   1.138 +     val thy = ProofContext.theory_of ctx
   1.139  
   1.140 -       val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   1.141 -       val (c, subs) = (concl_of r, prems_of r)
   1.142 +     val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
   1.143 +     val (c, subs) = (concl_of r, prems_of r)
   1.144  
   1.145 -       val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
   1.146 -       val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_term subst) subs
   1.147 -       val inst = map (fn v =>
   1.148 -        (cterm_of thy (Var v), cterm_of thy (Envir.subst_term subst (Var v)))) (Term.add_vars c [])
   1.149 -     in
   1.150 -   (cterm_instantiate inst r, dep, branches)
   1.151 -     end
   1.152 -    handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   1.153 +     val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty)
   1.154 +     val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_term subst) subs
   1.155 +     val inst = map (fn v =>
   1.156 +       (cterm_of thy (Var v), cterm_of thy (Envir.subst_term subst (Var v)))) (Term.add_vars c [])
   1.157 +   in
   1.158 +     (cterm_instantiate inst r, dep, branches)
   1.159 +   end
   1.160 +   handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
   1.161    | find_cong_rule _ _ _ [] _ = sys_error "Function/context_tree.ML: No cong rule found!"
   1.162  
   1.163  
   1.164  fun mk_tree fvar h ctxt t =
   1.165 -    let 
   1.166 -      val congs = get_function_congs ctxt
   1.167 -      val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs) (* FIXME: Save in theory *)
   1.168 +  let
   1.169 +    val congs = get_function_congs ctxt
   1.170  
   1.171 -      fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   1.172 -        | matchcall _ = NONE
   1.173 +    (* FIXME: Save in theory: *)
   1.174 +    val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs)
   1.175 +
   1.176 +    fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE
   1.177 +      | matchcall _ = NONE
   1.178  
   1.179 -      fun mk_tree' ctx t =
   1.180 -          case matchcall t of
   1.181 -            SOME arg => RCall (t, mk_tree' ctx arg)
   1.182 -          | NONE => 
   1.183 -            if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   1.184 -            else 
   1.185 -              let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in
   1.186 -                Cong (r, dep, 
   1.187 -                      map (fn (ctx', fixes, assumes, st) => 
   1.188 -                              ((fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes), 
   1.189 -                               mk_tree' ctx' st)) branches)
   1.190 -              end
   1.191 -    in
   1.192 -      mk_tree' ctxt t
   1.193 -    end
   1.194 -    
   1.195 +    fun mk_tree' ctx t =
   1.196 +      case matchcall t of
   1.197 +        SOME arg => RCall (t, mk_tree' ctx arg)
   1.198 +      | NONE =>
   1.199 +        if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
   1.200 +        else
   1.201 +          let
   1.202 +            val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t
   1.203 +            fun subtree (ctx', fixes, assumes, st) =
   1.204 +              ((fixes,
   1.205 +                map (assume o cterm_of (ProofContext.theory_of ctx)) assumes),
   1.206 +               mk_tree' ctx' st)
   1.207 +          in
   1.208 +            Cong (r, dep, map subtree branches)
   1.209 +          end
   1.210 +  in
   1.211 +    mk_tree' ctxt t
   1.212 +  end
   1.213  
   1.214  fun inst_tree thy fvar f tr =
   1.215 -    let
   1.216 -      val cfvar = cterm_of thy fvar
   1.217 -      val cf = cterm_of thy f
   1.218 -               
   1.219 -      fun inst_term t = 
   1.220 -          subst_bound(f, abstract_over (fvar, t))
   1.221 +  let
   1.222 +    val cfvar = cterm_of thy fvar
   1.223 +    val cf = cterm_of thy f
   1.224  
   1.225 -      val inst_thm = forall_elim cf o forall_intr cfvar 
   1.226 +    fun inst_term t =
   1.227 +      subst_bound(f, abstract_over (fvar, t))
   1.228 +
   1.229 +    val inst_thm = forall_elim cf o forall_intr cfvar
   1.230  
   1.231 -      fun inst_tree_aux (Leaf t) = Leaf t
   1.232 -        | inst_tree_aux (Cong (crule, deps, branches)) =
   1.233 -          Cong (inst_thm crule, deps, map inst_branch branches)
   1.234 -        | inst_tree_aux (RCall (t, str)) =
   1.235 -          RCall (inst_term t, inst_tree_aux str)
   1.236 -      and inst_branch ((fxs, assms), str) = 
   1.237 -          ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms), inst_tree_aux str)
   1.238 -    in
   1.239 -      inst_tree_aux tr
   1.240 -    end
   1.241 +    fun inst_tree_aux (Leaf t) = Leaf t
   1.242 +      | inst_tree_aux (Cong (crule, deps, branches)) =
   1.243 +        Cong (inst_thm crule, deps, map inst_branch branches)
   1.244 +      | inst_tree_aux (RCall (t, str)) =
   1.245 +        RCall (inst_term t, inst_tree_aux str)
   1.246 +    and inst_branch ((fxs, assms), str) =
   1.247 +      ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms),
   1.248 +       inst_tree_aux str)
   1.249 +  in
   1.250 +    inst_tree_aux tr
   1.251 +  end
   1.252  
   1.253  
   1.254  (* Poor man's contexts: Only fixes and assumes *)
   1.255  fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)
   1.256  
   1.257  fun export_term (fixes, assumes) =
   1.258 -    fold_rev (curry Logic.mk_implies o prop_of) assumes 
   1.259 + fold_rev (curry Logic.mk_implies o prop_of) assumes
   1.260   #> fold_rev (Logic.all o Free) fixes
   1.261  
   1.262  fun export_thm thy (fixes, assumes) =
   1.263 -    fold_rev (implies_intr o cprop_of) assumes
   1.264 + fold_rev (implies_intr o cprop_of) assumes
   1.265   #> fold_rev (forall_intr o cterm_of thy o Free) fixes
   1.266  
   1.267  fun import_thm thy (fixes, athms) =
   1.268 -    fold (forall_elim o cterm_of thy o Free) fixes
   1.269 + fold (forall_elim o cterm_of thy o Free) fixes
   1.270   #> fold Thm.elim_implies athms
   1.271  
   1.272  
   1.273  (* folds in the order of the dependencies of a graph. *)
   1.274  fun fold_deps G f x =
   1.275 -    let
   1.276 -      fun fill_table i (T, x) =
   1.277 -          case Inttab.lookup T i of
   1.278 -            SOME _ => (T, x)
   1.279 -          | NONE => 
   1.280 -            let
   1.281 -              val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   1.282 -              val (v, x'') = f (the o Inttab.lookup T') i x'
   1.283 -            in
   1.284 -              (Inttab.update (i, v) T', x'')
   1.285 -            end
   1.286 -            
   1.287 -      val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   1.288 -    in
   1.289 -      (Inttab.fold (cons o snd) T [], x)
   1.290 -    end
   1.291 -    
   1.292 +  let
   1.293 +    fun fill_table i (T, x) =
   1.294 +      case Inttab.lookup T i of
   1.295 +        SOME _ => (T, x)
   1.296 +      | NONE =>
   1.297 +        let
   1.298 +          val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
   1.299 +          val (v, x'') = f (the o Inttab.lookup T') i x'
   1.300 +        in
   1.301 +          (Inttab.update (i, v) T', x'')
   1.302 +        end
   1.303 +
   1.304 +    val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
   1.305 +  in
   1.306 +    (Inttab.fold (cons o snd) T [], x)
   1.307 +  end
   1.308 +
   1.309  fun traverse_tree rcOp tr =
   1.310 -    let 
   1.311 -  fun traverse_help ctx (Leaf _) _ x = ([], x)
   1.312 -    | traverse_help ctx (RCall (t, st)) u x =
   1.313 -      rcOp ctx t u (traverse_help ctx st u x)
   1.314 -    | traverse_help ctx (Cong (_, deps, branches)) u x =
   1.315 +  let
   1.316 +    fun traverse_help ctx (Leaf _) _ x = ([], x)
   1.317 +      | traverse_help ctx (RCall (t, st)) u x =
   1.318 +        rcOp ctx t u (traverse_help ctx st u x)
   1.319 +      | traverse_help ctx (Cong (_, deps, branches)) u x =
   1.320        let
   1.321 -    fun sub_step lu i x =
   1.322 -        let
   1.323 -          val (ctx', subtree) = nth branches i
   1.324 -          val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   1.325 -          val (subs, x') = traverse_help (compose ctx ctx') subtree used x
   1.326 -          val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *)
   1.327 -        in
   1.328 -          (exported_subs, x')
   1.329 -        end
   1.330 +        fun sub_step lu i x =
   1.331 +          let
   1.332 +            val (ctx', subtree) = nth branches i
   1.333 +            val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
   1.334 +            val (subs, x') = traverse_help (compose ctx ctx') subtree used x
   1.335 +            val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *)
   1.336 +          in
   1.337 +            (exported_subs, x')
   1.338 +          end
   1.339        in
   1.340          fold_deps deps sub_step x
   1.341 -          |> apfst flat
   1.342 +        |> apfst flat
   1.343        end
   1.344 -    in
   1.345 -      snd o traverse_help ([], []) tr []
   1.346 -    end
   1.347 +  in
   1.348 +    snd o traverse_help ([], []) tr []
   1.349 +  end
   1.350  
   1.351  fun rewrite_by_tree thy h ih x tr =
   1.352 -    let
   1.353 -      fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
   1.354 -        | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   1.355 -          let
   1.356 -            val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   1.357 -                                                     
   1.358 -            val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   1.359 -                 |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   1.360 +  let
   1.361 +    fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
   1.362 +      | rewrite_help fix h_as x (RCall (_ $ arg, st)) =
   1.363 +        let
   1.364 +          val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *)
   1.365 +
   1.366 +          val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
   1.367 +            |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner))))
   1.368                                                      (* (a, h a) : G   *)
   1.369 -            val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   1.370 -            val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *)
   1.371 -                     
   1.372 -            val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner
   1.373 -            val h_a_eq_f_a = eq RS eq_reflection
   1.374 -            val result = transitive h_a'_eq_h_a h_a_eq_f_a
   1.375 -          in
   1.376 -            (result, x')
   1.377 -          end
   1.378 -        | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   1.379 -          let
   1.380 -            fun sub_step lu i x =
   1.381 -                let
   1.382 -                  val ((fixes, assumes), st) = nth branches i
   1.383 -                  val used = map lu (IntGraph.imm_succs deps i)
   1.384 -                             |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   1.385 -                             |> filter_out Thm.is_reflexive
   1.386 +          val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
   1.387 +          val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *)
   1.388 +
   1.389 +          val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner
   1.390 +          val h_a_eq_f_a = eq RS eq_reflection
   1.391 +          val result = transitive h_a'_eq_h_a h_a_eq_f_a
   1.392 +        in
   1.393 +          (result, x')
   1.394 +        end
   1.395 +      | rewrite_help fix h_as x (Cong (crule, deps, branches)) =
   1.396 +        let
   1.397 +          fun sub_step lu i x =
   1.398 +            let
   1.399 +              val ((fixes, assumes), st) = nth branches i
   1.400 +              val used = map lu (IntGraph.imm_succs deps i)
   1.401 +                |> map (fn u_eq => (u_eq RS sym) RS eq_reflection)
   1.402 +                |> filter_out Thm.is_reflexive
   1.403  
   1.404 -                  val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes
   1.405 -                                 
   1.406 -                  val (subeq, x') = rewrite_help (fix @ fixes) (h_as @ assumes') x st
   1.407 -                  val subeq_exp = export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   1.408 -                in
   1.409 -                  (subeq_exp, x')
   1.410 -                end
   1.411 -                
   1.412 -            val (subthms, x') = fold_deps deps sub_step x
   1.413 -          in
   1.414 -            (fold_rev (curry op COMP) subthms crule, x')
   1.415 -          end
   1.416 -    in
   1.417 -      rewrite_help [] [] x tr
   1.418 -    end
   1.419 -    
   1.420 +              val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes
   1.421 +
   1.422 +              val (subeq, x') =
   1.423 +                rewrite_help (fix @ fixes) (h_as @ assumes') x st
   1.424 +              val subeq_exp =
   1.425 +                export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq)
   1.426 +            in
   1.427 +              (subeq_exp, x')
   1.428 +            end
   1.429 +          val (subthms, x') = fold_deps deps sub_step x
   1.430 +        in
   1.431 +          (fold_rev (curry op COMP) subthms crule, x')
   1.432 +        end
   1.433 +  in
   1.434 +    rewrite_help [] [] x tr
   1.435 +  end
   1.436 +
   1.437  end