src/HOL/Tools/function_package/context_tree.ML
author wenzelm
Tue, 03 Jul 2007 17:17:04 +0200
changeset 23530 438c5d2db482
parent 21237 b803f9870e97
child 23819 2040846d1bbe
permissions -rw-r--r--
CONVERSION tactical;

(*  Title:      HOL/Tools/function_package/context_tree.ML
    ID:         $Id$
    Author:     Alexander Krauss, TU Muenchen

A package for general recursive function definitions. 
Builds and traverses trees of nested contexts along a term.
*)


(* Interface of the context-tree module: construction, instantiation and
   traversal of trees of nested contexts (fixes and assumptions) along a
   term, plus helpers for moving terms/theorems in and out of such contexts. *)
signature FUNDEF_CTXTREE =
sig
    (* NOTE(review): this abstract type is declared here, but every value
       below mentions FundefCommon.ctx_tree directly. *)
    type ctx_tree

    (* FIXME: This interface is a mess and needs to be cleaned up! *)
    (* Dependency graph between the premises of a congruence rule;
       premises are numbered starting from 1. *)
    val cong_deps : thm -> int IntGraph.T
    (* The rules "cong" and "ext", each resolved with eq_reflection. *)
    val add_congs : thm list

    (* mk_tree congs fvar h ctx t: builds the context tree of t.
       congs pairs each congruence rule with its dependency graph,
       fvar is the free variable whose applications become RCall nodes. *)
    val mk_tree: (thm * FundefCommon.depgraph) list ->
                 (string * typ) -> term -> Proof.context -> term -> FundefCommon.ctx_tree

    (* inst_tree thy fvar f tr: replaces fvar by f in the terms, congruence
       rules and branch assumptions stored in the tree. *)
    val inst_tree: theory -> term -> term -> FundefCommon.ctx_tree
                   -> FundefCommon.ctx_tree

    (* Adds the names of all variables fixed in branches of the tree to the
       given list (without introducing duplicates). *)
    val add_context_varnames : FundefCommon.ctx_tree -> string list -> string list

    (* Moving terms and theorems across a "poor man's context", i.e. a pair
       of fixed variables and assumptions. *)
    val export_term : (string * typ) list * term list -> term -> term
    val export_thm : theory -> (string * typ) list * term list -> thm -> thm
    val import_thm : theory -> (string * typ) list * thm list -> thm -> thm


    (* traverse_tree rcOp tr x: walks the tree, calling rcOp at every RCall
       node with the current context, the node's term, the results
       contributed by branches this one depends on, and the pair computed
       for the node's subtree. Returns the final accumulator. *)
    val traverse_tree : 
   ((string * typ) list * thm list -> term ->
   (((string * typ) list * thm list) * thm) list ->
   (((string * typ) list * thm list) * thm) list * 'b ->
   (((string * typ) list * thm list) * thm) list * 'b)
   -> FundefCommon.ctx_tree -> 'b -> 'b

    (* rewrite_by_tree thy h ih x tr: derives an equation theorem along the
       tree, consuming one pair from x per RCall node; returns the theorem
       and the unconsumed rest of x. *)
    val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> FundefCommon.ctx_tree -> thm * (thm * thm) list
end

structure FundefCtxTree : FUNDEF_CTXTREE =
struct

open FundefCommon
open FundefLib


(* Maps "Trueprop (A = B)" to "B": strips Trueprop, splits the HOL equation
   and returns its right-hand side. *)
fun rhs_of t = snd (HOLogic.dest_eq (HOLogic.dest_Trueprop t))
(* Maps the meta-equation "A == B" to its right-hand side "B". *)
fun meta_rhs_of eq =
    let val (_, rhs) = Logic.dest_equals eq
    in rhs end



(*** Dependency analysis for congruence rules ***)

(* For one premise of a congruence rule, returns the pair
   (schematic variables occurring in its assumptions,
    schematic variables occurring in its conclusion).
   The premise is first taken apart with dest_all_all; only the second
   component of that result (the body) is inspected. *)
fun branch_vars t =
    let
      val (_, body) = dest_all_all t
      val prems = Logic.strip_imp_prems body
      val concl = Logic.strip_imp_concl body
      val prem_vars = fold (fn prem => fn vs => add_term_vars (prem, vs)) prems []
      val concl_vars = term_vars concl
    in
      (prem_vars, concl_vars)
    end

(* Computes the dependency graph between the premises ("branches") of the
   congruence rule crule. Branches are numbered from 1; node i carries the
   value i. There is an edge i -> j whenever i and j differ and a variable
   occurring in the assumptions of branch i also occurs in the conclusion of
   branch j. Edges are inserted with IntGraph.add_edge_acyclic. *)
fun cong_deps crule =
    let
      val branches = map branch_vars (prems_of crule)
      val indexed = (1 upto length branches) ~~ branches

      fun add_node (i, _) = IntGraph.new_node (i, i)

      fun add_dep ((i, (prem_vars, _)), (j, (_, concl_vars))) =
          if i = j then I
          else if null (prem_vars inter concl_vars) then I
          else IntGraph.add_edge_acyclic (i, j)
    in
      IntGraph.empty
        |> fold add_node indexed
        |> fold add_dep (product indexed indexed)
    end
    
(* The rules "cong" and "ext", each resolved with eq_reflection. *)
val add_congs = [cong RS eq_reflection, ext RS eq_reflection]



(* Called on the INSTANTIATED branches of the congruence rule.
   Takes the branch apart with dest_all_all_ctx and returns
   (extended context, fixes, assumptions of the branch,
    right-hand side of the branch's concluding equation). *)
fun mk_branch ctx t =
    let
      val (ctx', fixes, impl) = dest_all_all_ctx ctx t
      val assumes = Logic.strip_imp_prems impl
      val rhs = rhs_of (Logic.strip_imp_concl impl)
    in
      (ctx', fixes, assumes, rhs)
    end

(* Searches the list of (congruence rule, dependency graph) pairs, in order,
   for the first rule whose conclusion matches the meta-equation
   "t' == t", where t' is t rewritten with the rule "Free fvar ~> h".

   On success returns (instantiated rule, its dependency graph, branches),
   the branches being the instantiated and beta-normalised premises of the
   rule processed by mk_branch.

   A Pattern.MATCH raised anywhere while handling a candidate makes the
   search continue with the next one; running out of candidates is reported
   via sys_error. *)
fun find_cong_rule _ _ _ [] _ = sys_error "function_package/context_tree.ML: No cong rule found!"
  | find_cong_rule ctx fvar h ((r,dep)::rs) t =
    (let
       val thy = ProofContext.theory_of ctx

       val rewritten = Pattern.rewrite_term thy [(Free fvar, h)] [] t
       val goal = Logic.mk_equals (rewritten, t)
       val concl = concl_of r
       val prems = prems_of r

       val subst = Pattern.match thy (concl, goal) (Vartab.empty, Vartab.empty)
       val apply_subst = Envir.subst_vars subst

       val branches = map (mk_branch ctx o Envir.beta_norm o apply_subst) prems

       fun certify_var v = (cterm_of thy (Var v), cterm_of thy (apply_subst (Var v)))
       val inst = map certify_var (Term.add_vars concl [])
     in
       (cterm_instantiate inst r, dep, branches)
     end
     handle Pattern.MATCH => find_cong_rule ctx fvar h rs t)


(* If t is an application whose head is exactly "Free fvar", returns SOME of
   its argument; otherwise NONE. Only the outermost application is
   inspected. *)
fun matchcall fvar t =
    case t of
      head $ arg => if head = Free fvar then SOME arg else NONE
    | _ => NONE

(* Builds the context tree of the term t:
     - an application of Free fvar becomes an RCall node over the tree of
       its argument;
     - a term in which Free fvar does not occur becomes a Leaf;
     - otherwise a congruence rule is selected with find_cong_rule and a
       Cong node is built, with one subtree per branch of the rule.
   Branch assumptions are certified with the theory of the ctx passed to
   this call and turned into theorems with "assume"; each subtree is built
   in the context returned for its branch. *)
fun mk_tree congs fvar h ctx t =
    let
      val thy = ProofContext.theory_of ctx

      fun is_fvar (Free v) = (v = fvar)
        | is_fvar _ = false

      fun mk_sub (ctx', fixes, assumes, st) =
          (fixes, map (assume o cterm_of thy) assumes, mk_tree congs fvar h ctx' st)
    in
      case matchcall fvar t of
        SOME arg => RCall (t, mk_tree congs fvar h ctx arg)
      | NONE =>
        if exists_subterm is_fvar t then
          let
            val (r, dep, branches) = find_cong_rule ctx fvar h congs t
          in
            Cong (t, r, dep, map mk_sub branches)
          end
        else Leaf t
    end
    

(* Replaces fvar by f throughout the tree tr:
     - terms of Cong and RCall nodes are instantiated by abstracting over
       fvar and substituting f for the bound variable;
     - congruence rules are instantiated by forall_intr on fvar followed by
       forall_elim with f;
     - branch assumptions are re-assumed from their instantiated
       propositions.
   Leaf nodes and the fixes of branches are returned unchanged. *)
fun inst_tree thy fvar f tr =
    let
      val certified_fvar = cterm_of thy fvar
      val certified_f = cterm_of thy f

      fun subst_term t = subst_bound (f, abstract_over (fvar, t))

      fun subst_thm th = forall_elim certified_f (forall_intr certified_fvar th)

      val reassume = assume o cterm_of thy o subst_term o prop_of

      fun go (Leaf t) = Leaf t
        | go (RCall (t, sub)) = RCall (subst_term t, go sub)
        | go (Cong (t, crule, deps, branches)) =
          Cong (subst_term t, subst_thm crule, deps, map go_branch branches)
      and go_branch (fixes, assms, sub) =
          (fixes, map reassume assms, go sub)
    in
      go tr
    end



(* FIXME: remove *)
(* Adds to the given list the names of all variables fixed by branches of
   Cong nodes in the tree. For every branch the names from its subtree are
   added first, then the branch's own fixes; names already present are not
   added again. *)
fun add_context_varnames tr names =
    case tr of
      Leaf _ => names
    | RCall (_, st) => add_context_varnames st names
    | Cong (_, _, _, branches) =>
      let
        fun add_branch (fixes, _, st) acc =
            fold (fn (name, _) => insert (op =) name) fixes (add_context_varnames st acc)
      in
        fold add_branch branches names
      end
    

(* Poor man's contexts: Only fixes and assumes.
   Composition appends the second context to the first, component-wise. *)
fun compose (outer_fixes, outer_assms) (inner_fixes, inner_assms) =
    let
      val fixes = outer_fixes @ inner_fixes
      val assms = outer_assms @ inner_assms
    in
      (fixes, assms)
    end

(* Exports the term t out of the context (fixes, assumes): the assumptions
   are put in front of t as implications (first assumption outermost), then
   the fixed variables are quantified with mk_forall (first variable
   outermost). *)
fun export_term (fixes, assumes) t =
    t |> fold_rev (fn a => fn body => Logic.mk_implies (a, body)) assumes
      |> fold_rev (fn fx => mk_forall (Free fx)) fixes

(* Theorem counterpart of export_term: applies implies_intr for the
   assumptions (given as terms), then forall_intr for the fixed variables;
   in both lists the first element ends up outermost. *)
fun export_thm thy (fixes, assumes) th =
    let
      fun discharge a = implies_intr (cterm_of thy a)
      fun generalize fx = forall_intr (cterm_of thy (Free fx))
    in
      th |> fold_rev discharge assumes
         |> fold_rev generalize fixes
    end

(* Imports a theorem into the context (fixes, athms): applies forall_elim
   with the fixed variables, first to last, then feeds the assumption
   theorems, first to last, through implies_elim_swp. *)
fun import_thm thy (fixes, athms) th =
    let
      fun specialize fx = forall_elim (cterm_of thy (Free fx))
      val specialized = fold specialize fixes th
    in
      fold implies_elim_swp athms specialized
    end

(* Assumes prop relative to the context (fixes, athms).
   Returns the pair of
     - the exported form of prop (see export_term), and
     - the theorem obtained by assuming that exported form and importing it
       back into the context (see import_thm). *)
fun assume_in_ctxt thy (ctx as (fixes, athms)) prop =
    let
      val global_assum = export_term (fixes, map prop_of athms) prop
      val local_thm = import_thm thy ctx (assume (cterm_of thy global_assum))
    in
      (global_assum, local_thm)
    end


(* Folds in the order of the dependencies of a graph: the immediate
   successors of a node are always processed before the node itself, and
   every node is processed exactly once.

   G : the dependency graph (integer keys).
   f : called as "f lookup i x" for each node i, where x is the accumulator
       threaded through all calls and "lookup j" yields the result stored
       for an already processed node j (it fails via "the" for any other
       key). f returns the result for node i and the new accumulator.
   x : initial accumulator.

   Returns the list of per-node results together with the final
   accumulator. The list is built by consing while folding over the result
   table, so it is reversed with respect to the table's traversal order.

   NOTE: a node enters the table only after its successors are done, so the
   recursion does not terminate on a cyclic graph. *)
fun fold_deps G f x =
    let
      fun fill_table i (T, x) =
          case Inttab.lookup T i of
            SOME _ => (T, x)
          | NONE =>
            let
              val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x)
              (* Continue with x', the accumulator left behind by the
                 successors. Passing the stale x here would silently drop
                 whatever a successor visited for the first time inside
                 this call contributed. *)
              val (v, x'') = f (the o Inttab.lookup T') i x'
            in
              (Inttab.update (i, v) T', x'')
            end

      val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x)
    in
      (Inttab.fold (cons o snd) T [], x)
    end


(* Concatenates a list of lists, preserving the order of the elements. *)
fun flatten xss = List.concat xss

(* Walks the context tree tr, threading the accumulator x, and returns the
   final accumulator.

   At every RCall node, rcOp is applied to
     - the context (fixes, assumption theorems) accumulated from the root,
     - the term stored in the node,
     - the list of results "used" so far (see below), and
     - the pair (results, accumulator) computed for the node's subtree,
   and must return a new pair (results, accumulator).

   At a Cong node the branches are visited with fold_deps, i.e. in the order
   given by the dependency graph stored in the node. The "used" list passed
   into a branch is the incoming one, extended in front by the results of
   the branches it immediately depends on. Results coming out of a branch
   have their context component prefixed with the branch's own fixes and
   assumptions. Leaves contribute no results. *)
fun traverse_tree rcOp tr x =
    let
      fun walk _ (Leaf _) _ acc = ([], acc)
        | walk ctx (RCall (t, st)) used acc =
          rcOp ctx t used (walk ctx st used acc)
        | walk ctx (Cong (_, _, deps, branches)) used acc =
          let
            fun visit_branch lookup i acc_in =
                let
                  val (fixes, assumes, subtree) = nth branches (i - 1)
                  val local_ctx = (fixes, assumes)
                  val used_here =
                      fold_rev (fn j => fn us => lookup j @ us) (IntGraph.imm_succs deps i) used
                  val (subs, acc_out) = walk (compose ctx local_ctx) subtree used_here acc_in
                  val exported = map (fn (c, th) => (compose local_ctx c, th)) subs
                in
                  (exported, acc_out)
                end

            val (results, acc_out) = fold_deps deps visit_branch acc
          in
            (flatten results, acc_out)
          end

      val (_, final) = walk ([], []) tr [] x
    in
      final
    end


(* Tests whether the proposition of thm is a meta-equation whose two sides
   are equal as terms (compared with "="). *)
fun is_refl thm =
    case Logic.dest_equals (prop_of thm) of
      (lhs, rhs) => lhs = rhs

(* Derives a meta-equation theorem by recursion over the context tree tr.

   thy : theory used to certify terms.
   h   : term that is applied (via "combination") to the equation obtained
         for the argument of each RCall node.
   ih  : theorem that, at each RCall node, is instantiated with the call's
         argument and then has two premises discharged.
   x   : list of theorem pairs; exactly one pair is taken from its head for
         every RCall node, after the node's subtree has been processed.
   tr  : the tree.

   Returns the resulting theorem and the unconsumed rest of x.

   The patterns below are deliberately partial: an RCall node whose term is
   not an application, or running out of pairs in x, is not handled. *)
fun rewrite_by_tree thy h ih x tr =
    let
      (* fix  : variables fixed on the path from the root
         f_as : assumption theorems of the branches entered so far, as
                stored in the tree (only threaded through here)
         h_as : the same assumptions after simplification with the
                equations of the branches they depend on *)
      (* Leaf: nothing to rewrite, the result is the reflexive equation. *)
      fun rewrite_help fix f_as h_as x (Leaf t) = (reflexive (cterm_of thy t), x)
        | rewrite_help fix f_as h_as x (RCall (_ $ arg, st)) =
          let
            (* Process the argument first, then take the next pair. *)
            val (inner, (lRi,ha)::x') = rewrite_help fix f_as h_as x st
                                                     
             (* Need not use the simplifier here. Can use primitive steps! *)
            val rew_ha = if is_refl inner then I else simplify (HOL_basic_ss addsimps [inner])
           
            (* Lift the equation for the argument to an equation under h. *)
            val h_a_eq_h_a' = combination (reflexive (cterm_of thy h)) inner
            val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *)
                                 |> rew_ha
                      
            (* Instantiate ih with the call argument and discharge its two
               premises with lRi and iha. *)
            val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih
            val eq = implies_elim (implies_elim inst_ih lRi) iha
                     
            (* Turn the result into a meta-equation and chain both steps. *)
            val h_a'_eq_f_a' = eq RS eq_reflection
            val result = transitive h_a_eq_h_a' h_a'_eq_f_a'
          in
            (result, x')
          end
        | rewrite_help fix f_as h_as x (Cong (t, crule, deps, branches)) =
          let
            (* Handles branch number i; lu looks up the exported equations
               of branches that have already been processed. *)
            fun sub_step lu i x =
                let
                  val (fixes, assumes, st) = nth branches (i - 1)
                  (* Equations of the branches this one immediately depends
                     on, flipped with sym and turned into meta-equations. *)
                  val used = fold_rev (cons o lu) (IntGraph.imm_succs deps i) []
                  val used_rev = map (fn u_eq => (u_eq RS sym) RS eq_reflection) used
                  (* Simplify the branch assumptions with the non-reflexive
                     ones among these flipped equations. *)
                  val assumes' = map (simplify (HOL_basic_ss addsimps (filter_out is_refl used_rev))) assumes
                                 
                  val (subeq, x') = rewrite_help (fix @ fixes) (f_as @ assumes) (h_as @ assumes') x st
                  (* Export over the branch's fixes and its original
                     (unsimplified) assumptions. *)
                  val subeq_exp = export_thm thy (fixes, map prop_of assumes) (subeq RS meta_eq_to_obj_eq)
                in
                  (subeq_exp, x')
                end
                
            val (subthms, x') = fold_deps deps sub_step x
          in
            (* Compose the branch theorems onto the congruence rule. *)
            (fold_rev (curry op COMP) subthms crule, x')
          end
    in
      rewrite_help [] [] [] x tr
    end
    
end