src/HOL/Tools/function_package/context_tree.ML
author krauss
Tue, 07 Aug 2007 14:49:58 +0200
changeset 24168 86a03a092062
parent 23819 2040846d1bbe
child 24977 9f98751c9628
permissions -rw-r--r--
simplified internal interfaces; cong rules are now handled directly by "context_tree.ML"

(*  Title:      HOL/Tools/function_package/context_tree.ML
    ID:         $Id$
    Author:     Alexander Krauss, TU Muenchen

A package for general recursive function definitions. 
Builds and traverses trees of nested contexts along a term.
*)


signature FUNDEF_CTXTREE =
sig
  (* Dependency graph between the premises ("branches") of a congruence rule. *)
  type depgraph
  (* Tree of nested contexts along a term, as built by mk_tree. *)
  type ctx_tree

  (* Congruence rules of the function package, stored as generic context data,
     and the attributes for declaring / removing them.
     FIXME: This interface is a mess and needs to be cleaned up! *)
  val get_fundef_congs : Context.generic -> thm list
  val add_fundef_cong : thm -> Context.generic -> Context.generic
  val map_fundef_congs : (thm list -> thm list) -> Context.generic -> Context.generic

  val cong_add : attribute
  val cong_del : attribute

  (* mk_tree fvar h ctxt t: context tree of t with respect to the free
     variable fvar. *)
  val mk_tree : (string * typ) -> term -> Proof.context -> term -> ctx_tree

  (* inst_tree thy fvar f tree: replaces fvar by f throughout the tree. *)
  val inst_tree : theory -> term -> term -> ctx_tree -> ctx_tree

  val add_context_varnames : ctx_tree -> string list -> string list

  (* Lightweight contexts: a pair (fixed variables, assumptions). *)
  val export_term : (string * typ) list * term list -> term -> term
  val export_thm : theory -> (string * typ) list * term list -> thm -> thm
  val import_thm : theory -> (string * typ) list * thm list -> thm -> thm

  (* Traversals over a context tree. *)
  val traverse_tree :
      ((string * typ) list * thm list
         -> term
         -> (((string * typ) list * thm list) * thm) list
         -> (((string * typ) list * thm list) * thm) list * 'b
         -> (((string * typ) list * thm list) * thm) list * 'b)
      -> ctx_tree -> 'b -> 'b

  val rewrite_by_tree :
      theory -> term -> thm -> (thm * thm) list -> ctx_tree
      -> thm * (thm * thm) list
end

structure FundefCtxTree : FUNDEF_CTXTREE =
struct

open FundefCommon
open FundefLib

(*** Congruence rules ***)

(* Generic data (accessible from theories and proof contexts) holding the
   congruence rules declared for the function package. *)
structure FundefCongs = GenericDataFun
(
  type T = thm list
  val empty = []
  val extend = I
  fun merge _ = Thm.merge_thms
);

val map_fundef_congs = FundefCongs.map
val get_fundef_congs = FundefCongs.get
val add_fundef_cong = FundefCongs.map o Thm.add_thm

(* Declaration attributes adding / deleting a congruence rule. The rule is
   first passed through safe_mk_meta_eq (from FundefCommon or FundefLib;
   NOTE(review): presumably normalises it to a meta equation -- confirm). *)
val cong_add = Thm.declaration_attribute (map_fundef_congs o Thm.add_thm o safe_mk_meta_eq);
val cong_del = Thm.declaration_attribute (map_fundef_congs o Thm.del_thm o safe_mk_meta_eq);


(*** Context trees ***)

(* Edges point from a branch to the branches it depends on (see cong_deps);
   fold_deps processes the targets of an edge before its source. *)
type depgraph = int IntGraph.T

(* Decomposition of a term along congruence rules (as built by mk_tree):
   - Leaf t: t does not contain the variable fvar given to mk_tree;
   - Cong (t, rule, deps, branches): t is split by the instantiated congruence
     rule; branch i (1-based) corresponds to the i-th premise of the rule and
     carries its local fixes, its local assumptions (as assumed theorems) and
     the subtree of the right-hand side of its conclusion;
   - RCall (t, sub): t is an application "fvar arg" and sub is the tree of arg. *)
datatype ctx_tree
  = Leaf of term
  | Cong of (term * thm * depgraph * ((string * typ) list * thm list * ctx_tree) list)
  | RCall of (term * ctx_tree)


(* Maps "Trueprop (A = B)" to "B" *)
val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop



(*** Dependency analysis for congruence rules ***)

(* For one premise of a congruence rule (outer binders removed with
   dest_all_all), returns the pair
   (schematic variables of its assumptions, schematic variables of its conclusion). *)
fun branch_vars t =
    let
      val t' = snd (dest_all_all t)
      val assumes = Logic.strip_imp_prems t'
      val concl = Logic.strip_imp_concl t'
    in (fold (curry add_term_vars) assumes [], term_vars concl)
    end

(* Dependency graph of a congruence rule: one node per premise, numbered from 1,
   and an edge (i, j), i <> j, whenever a schematic variable of the assumptions
   of premise i occurs in the conclusion of premise j. *)
fun cong_deps crule =
    let
      val branches = map branch_vars (prems_of crule)
      val num_branches = (1 upto (length branches)) ~~ branches
    in
      IntGraph.empty
        |> fold (fn (i, _) => IntGraph.new_node (i, i)) num_branches
        |> fold (fn ((i, (c1, _)), (j, (_, t2))) =>
                   if i = j orelse null (c1 inter t2) then I
                   else IntGraph.add_edge_acyclic (i, j))
                (product num_branches num_branches)
    end

(* The theorems cong and ext, turned into meta equations; mk_tree tries them
   after the declared congruence rules. *)
val add_congs = map (fn c => c RS eq_reflection) [cong, ext]



(* Called on the INSTANTIATED branches of the congruence rule. Returns the
   context produced by dest_all_all_ctx, the fixed variables, the assumptions,
   and the right-hand side of the (object-level) concluding equation. *)
fun mk_branch ctx t =
    let
      val (ctx', fixes, impl) = dest_all_all_ctx ctx t
    in
      (ctx', fixes, Logic.strip_imp_prems impl, rhs_of (Logic.strip_imp_concl impl))
    end

(* Tries the (rule, deps) pairs in order and uses the first rule whose
   conclusion matches "t' == t", where t' is t with Free fvar rewritten to h.
   Returns (instantiated rule, deps, instantiated branches). Rules failing with
   Pattern.MATCH are skipped; if none is left, a system error is raised. *)
fun find_cong_rule ctx fvar h ((r, dep) :: rs) t =
    (let
       val thy = ProofContext.theory_of ctx

       val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t)
       val (c, subs) = (concl_of r, prems_of r)

       val subst = Pattern.match thy (c, tt') (Vartab.empty, Vartab.empty)
       val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_vars subst) subs
       val inst = map (fn v => (cterm_of thy (Var v), cterm_of thy (Envir.subst_vars subst (Var v))))
                      (Term.add_vars c [])
     in
       (cterm_instantiate inst r, dep, branches)
     end
     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)
  | find_cong_rule _ _ _ [] _ = sys_error "function_package/context_tree.ML: No cong rule found!"


(* SOME arg if the term is the application "Free fvar $ arg", NONE otherwise. *)
fun matchcall fvar (a $ b) = if a = Free fvar then SOME b else NONE
  | matchcall _ _ = NONE

(* Builds the context tree of t: applications of Free fvar become RCall nodes,
   subterms without fvar become leaves, and any other term is split with the
   first matching congruence rule (declared rules first, then add_congs). *)
fun mk_tree fvar h ctxt t =
    let
      val congs = get_fundef_congs (Context.Proof ctxt)
      val congs_deps = map (fn c => (c, cong_deps c)) (congs @ add_congs) (* FIXME: Save in theory *)

      fun mk_tree' ctx t =
          case matchcall fvar t of
            SOME arg => RCall (t, mk_tree' ctx arg)
          | NONE =>
            if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t
            else
              let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in
                Cong (t, r, dep,
                      map (fn (ctx', fixes, assumes, st) =>
                              (fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes,
                               mk_tree' ctx' st)) branches)
              end
    in
      mk_tree' ctxt t
    end


(* Replaces the term fvar by f throughout a tree: in every term (abstract_over
   followed by subst_bound), in every congruence rule (forall_intr followed by
   forall_elim) and in every branch assumption (which is assumed anew). *)
fun inst_tree thy fvar f tr =
    let
      val cfvar = cterm_of thy fvar
      val cf = cterm_of thy f

      fun inst_term t =
          subst_bound(f, abstract_over (fvar, t))

      val inst_thm = forall_elim cf o forall_intr cfvar

      fun inst_tree_aux (Leaf t) = Leaf t
        | inst_tree_aux (Cong (t, crule, deps, branches)) =
          Cong (inst_term t, inst_thm crule, deps, map inst_branch branches)
        | inst_tree_aux (RCall (t, str)) =
          RCall (inst_term t, inst_tree_aux str)
      and inst_branch (fxs, assms, str) =
          (fxs, map (assume o cterm_of thy o inst_term o prop_of) assms, inst_tree_aux str)
    in
      inst_tree_aux tr
    end



(* Inserts the names of all variables fixed in some branch of the tree.
   FIXME: remove *)
fun add_context_varnames (Leaf _) = I
  | add_context_varnames (Cong (_, _, _, sub)) = fold (fn (fs, _, st) => fold (insert (op =) o fst) fs o add_context_varnames st) sub
  | add_context_varnames (RCall (_, st)) = add_context_varnames st


(* Poor man's contexts: Only fixes and assumes *)

(* Concatenates two contexts, the first one being the outer one. *)
fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2)

(* Discharges the assumptions as implications, then quantifies over the fixes
   with mk_forall. *)
fun export_term (fixes, assumes) =
    fold_rev (curry Logic.mk_implies) assumes #> fold_rev (mk_forall o Free) fixes

(* Theorem-level export: implies_intr for each assumption, then forall_intr
   for each fixed variable. *)
fun export_thm thy (fixes, assumes) =
    fold_rev (implies_intr o cterm_of thy) assumes
 #> fold_rev (forall_intr o cterm_of thy o Free) fixes

(* Inverse of export_thm: forall_elim with the fixed variables themselves, then
   implies_elim with the assumption theorems, in order. *)
fun import_thm thy (fixes, athms) =
    fold (forall_elim o cterm_of thy o Free) fixes
 #> fold (flip implies_elim) athms


(* Folds f over the nodes of G so that each node is processed after all of its
   immediate successors. f receives a lookup function for the results of nodes
   already processed, the node, and the current accumulator; it returns the
   node's result and the new accumulator. The accumulator is threaded through
   every call, including the calls for successors reached recursively.
   Returns (all node results, final accumulator); NOTE(review): the result list
   is built with Inttab.fold and cons, rewrite_by_tree relies on its order. *)
fun fold_deps G f x =
    let
      fun fill_table i (T, x) =
          case Inttab.lookup T i of
            SOME _ => (T, x)
          | NONE =>
            let
              val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
              (* x' (not x): keep the accumulator updates made by the successors *)
              val (v, x'') = f (the o Inttab.lookup T') i x'
            in
              (Inttab.update (i, v) T', x'')
            end

      val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
    in
      (Inttab.fold (cons o snd) T [], x)
    end


(* Concatenates a list of lists. *)
fun flatten xss = fold_rev append xss []

(* Traverses the tree and returns the final accumulator.
   At each recursive call t, the subtree of its argument is traversed first;
   then rcOp ctx t u (subs, acc) is applied to that subtree's result, where
   - ctx is the (fixes, assumptions) context of the enclosing branches;
   - u is the list of results of the branches the current one depends on,
     followed by the list inherited from outside.
   Branches of a Cong node are visited in dependency order (fold_deps). Their
   results are re-exported by prefixing the branch's own fixes and
   assumptions to each result's context. *)
fun traverse_tree rcOp tr x =
    let
      fun traverse_help _ (Leaf _) _ x = ([], x)
        | traverse_help ctx (RCall (t, st)) u x =
          rcOp ctx t u (traverse_help ctx st u x)
        | traverse_help ctx (Cong (_, _, deps, branches)) u x =
          let
            fun sub_step lu i x =
                let
                  val (fixes, assumes, subtree) = nth branches (i - 1)
                  val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u
                  val (subs, x') = traverse_help (compose ctx (fixes, assumes)) subtree used x
                  val exported_subs = map (apfst (compose (fixes, assumes))) subs
                in
                  (exported_subs, x')
                end
          in
            fold_deps deps sub_step x
              |> apfst flatten
          end
    in
      snd (traverse_help ([], []) tr [] x)
    end


(* Whether a meta equation has identical sides (Logic.dest_equals raises TERM
   if the proposition is not a meta equation). *)
fun is_refl thm = let val (l, r) = Logic.dest_equals (prop_of thm) in l = r end

(* Proves a meta equation for the term described by the tree, using the
   theorem ih (instantiated with the argument of each recursive call) and the
   pairs (lRi, ha) of x. One pair is consumed per recursive call, in traversal
   order: a call's argument is processed before the call itself. The val
   binding raises Bind if x runs out.
   Returns the equation and the unconsumed rest of x. *)
fun rewrite_by_tree thy h ih x tr =
    let
      fun rewrite_help _ _ _ x (Leaf t) = (reflexive (cterm_of thy t), x)
        | rewrite_help fix f_as h_as x (RCall (_ $ arg, st)) =
          let
            val (inner, (lRi, ha) :: x') = rewrite_help fix f_as h_as x st

            (* Need not use the simplifier here. Can use primitive steps! *)
            val rew_ha = if is_refl inner then I else simplify (HOL_basic_ss addsimps [inner])

            val h_a_eq_h_a' = combination (reflexive (cterm_of thy h)) inner
            val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
                                 |> rew_ha

            val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
            val eq = implies_elim (implies_elim inst_ih lRi) iha

            val h_a'_eq_f_a' = eq RS eq_reflection
            val result = transitive h_a_eq_h_a' h_a'_eq_f_a'
          in
            (result, x')
          end
        | rewrite_help fix f_as h_as x (Cong (_, crule, deps, branches)) =
          let
            (* Each branch is rewritten in its own context. Its assumptions on
               the h side are simplified with the reversed equations of the
               branches it depends on. *)
            fun sub_step lu i x =
                let
                  val (fixes, assumes, st) = nth branches (i - 1)
                  val used = map lu (IntGraph.imm_succs deps i)
                  val used_rev = map (fn u_eq => (u_eq RS sym) RS eq_reflection) used
                  val assumes' = map (simplify (HOL_basic_ss addsimps (filter_out is_refl used_rev))) assumes

                  val (subeq, x') = rewrite_help (fix @ fixes) (f_as @ assumes) (h_as @ assumes') x st
                  val subeq_exp = export_thm thy (fixes, map prop_of assumes) (subeq RS meta_eq_to_obj_eq)
                in
                  (subeq_exp, x')
                end

            val (subthms, x') = fold_deps deps sub_step x
          in
            (fold_rev (curry op COMP) subthms crule, x')
          end
    in
      rewrite_help [] [] [] x tr
    end

end