src/HOL/Tools/Function/mutual.ML
author wenzelm
Tue Sep 29 22:48:24 2009 +0200 (2009-09-29)
changeset 32765 3032c0308019
parent 32149 ef59550a55d3
child 33099 b8cdd3d73022
permissions -rw-r--r--
modernized Balanced_Tree;
     1 (*  Title:      HOL/Tools/Function/mutual.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A package for general recursive function definitions.
     5 Tools for mutually recursive definitions.
     6 *)
     7 
     8 signature FUNDEF_MUTUAL =
     9 sig
    10   (* Prepares a mutually recursive function definition.  The result pairs the
    11      goal state with a proof continuation, along with the updated local theory. *)
    12   val prepare_fundef_mutual :
    13     FundefCommon.fundef_config ->        (* configuration *)
    14     string ->                            (* defname *)
    15     ((string * typ) * mixfix) list ->    (* fixes *)
    16     term list ->                         (* equations *)
    17     local_theory ->
    18     (thm * (thm -> FundefCommon.fundef_result)) * local_theory
    19     (* = ((goal state, proof continuation), updated local theory) *)
    20 end
    21 
    22 
    23 structure FundefMutual: FUNDEF_MUTUAL =
    24 struct
       (* Reduction of mutually recursive definitions to a single function.
          analyze_eqs gives the n mutual functions one combined function named
          after defname ^ _sum, of type ST --> RST, where ST is a balanced sum of
          the tupled curried-argument types and RST a balanced sum of the distinct
          result types.  After FundefCore.prepare_fundef has produced that function
          (fsum), define_projections defines every original function as a
          projection of it, and mk_partial_rules_mutual translates the resulting
          rules (psimps, induction, termination, ...) back to the originals. *)
    25 
    26 open FundefLib
    27 open FundefCommon
    28 
    29 
    30 
    31 
       (* One equation as returned by split_def: (function name, quantified
          variables, premises, argument terms, right-hand side).  The last three
          components may refer to the quantified variables as loose bound
          variables; in_context instantiates them with subst_bounds. *)
    32 type qgar = string * (string * typ) list * term list * term list * term
    33 
       (* Name component of a qgar.
          NOTE(review): not used in this structure and not exported through
          FUNDEF_MUTUAL. *)
    34 fun name_of_fqgar ((f, _, _, _, _): qgar) = f
    35 
       (* Per-function data of a mutual definition.  As built by analyze_eqs,
          f and f_defthm are NONE; define_projections fills them in. *)
    36 datatype mutual_part =
    37   MutualPart of 
    38    {
    39     i : int,              (* 1-based position of the function; injection index into ST *)
    40     i' : int,             (* 1-based index of its result type among the distinct result types; projection index out of RST *)
    41     fvar : string * typ,  (* name and type of the original function *)
    42     cargTs: typ list,     (* types of the curried arguments taken by its equations *)
    43     f_def: term,          (* definition body with the sum function abstracted out as a loose bound variable *)
    44 
    45     f: term option,       (* term for the defined function, as returned by LocalTheory.define *)
    46     f_defthm : thm option (* the corresponding definition theorem *)
    47    }
    48    
    49 
       (* Data of the whole mutual definition.  As built by analyze_eqs, fsum is
          NONE; define_projections sets it. *)
    50 datatype mutual_info =
    51   Mutual of 
    52    { 
    53     n : int,                  (* number of mutual functions *)
    54     n' : int,                 (* number of distinct result types *)
    55     fsum_var : string * typ,  (* fresh name and type ST --> RST of the sum function *)
    56 
    57     ST: typ,                  (* balanced sum of the tupled argument types *)
    58     RST: typ,                 (* balanced sum of the distinct result types *)
    59 
    60     parts: mutual_part list,  (* one part per function, in the order of fs *)
    61     fqgars: qgar list,        (* the split original equations *)
    62     qglrs: ((string * typ) list * term list * term * term) list,  (* equations for the sum function: quantified vars, premises, lhs argument, rhs *)
    63 
    64     fsum : term option        (* the sum function term, once known *)
    65    }
    66 
       (* Names for the induction predicates of n functions: a prefix of
          P, Q, R, S when n < 5, otherwise P1, ..., Pn. *)
    67 fun mutual_induct_Pnames n =
    68     if n < 5 then fst (chop n ["P","Q","R","S"])
    69     else map (fn i => "P" ^ string_of_int i) (1 upto n)
    70 
       (* The part whose function name is fname; fails (via `the`) if there
          is none. *)
    71 fun get_part fname =
    72     the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
    73                      
    74 (* FIXME *)
       (* Pairs t1 and t2 with HOLogic.pair_const, computing their types with
          fastype_of1 under the binder types rev (map snd e), so both terms may
          contain loose bound variables.  Presumably e lists the enclosing binders
          outermost first (convert_eqs passes the quantified variables qs);
          confirm. *)
    75 fun mk_prod_abs e (t1, t2) =
    76     let
    77       val bTs = rev (map snd e)
    78       val T1 = fastype_of1 (bTs, t1)
    79       val T2 = fastype_of1 (bTs, t2)
    80     in
    81       HOLogic.pair_const T1 T2 $ t1 $ t2
    82     end;
    83 
    84 
       (* Builds the Mutual description for the functions fs (name/type pairs)
          and their equations eqs: splits the equations, computes the sum types
          ST and RST, picks a fresh name based on defname ^ _sum for the sum
          function, creates one mutual_part per function, and translates every
          equation into an equation for the sum function (qglrs). *)
    85 fun analyze_eqs ctxt defname fs eqs =
    86     let
    87       val num = length fs
               (* NOTE(review): fnames is not used below. *)
    88         val fnames = map fst fs
    89         val fqgars = map (split_def ctxt) eqs
               (* Arity of a function = number of arguments in its (first) equation;
                  fails via `the` for a function without equations. *)
    90         val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    91                        |> AList.lookup (op =) #> the
    92 
               (* Splits fT into the types of its first arity_of fname arguments and
                  the type of the remainder, which serves as the result type. *)
    93         fun curried_types (fname, fT) =
    94             let
    95               val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
    96             in
    97                 (caTs, uaTs ---> body_type fT)
    98             end
    99 
   100         val (caTss, resultTs) = split_list (map curried_types fs)
               (* tuple type of each function's curried arguments *)
   101         val argTs = map (foldr1 HOLogic.mk_prodT) caTss
   102 
   103         val dresultTs = distinct (op =) resultTs
   104         val n' = length dresultTs
   105 
   106         val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
   107         val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs
   108 
   109         val fsum_type = ST --> RST
   110 
               (* Only the fresh name is kept; the extended context is discarded. *)
   111         val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
   112         val fsum_var = (fsum_var_name, fsum_type)
   113 
               (* Part for function number i.  f_exp applies the sum function to the
                  i-th injection into ST of the argument tuple and projects out the
                  i'-th component of RST.  def is lambda vars. f_exp with the sum
                  function abstracted out; rew pairs the function name with the same
                  lambda term, which convert_eqs hands to replace_frees
                  (presumably to replace calls of that function; confirm). *)
   114         fun define (fvar as (n, T)) caTs resultT i =
   115             let
   116                 val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
   117                 val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1 
   118 
   119                 val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
   120                 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
   121 
   122                 val rew = (n, fold_rev lambda vars f_exp)
   123             in
   124                 (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
   125             end
   126             
   127         val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
   128 
               (* Translates one equation: the left-hand side becomes the i-th
                  injection into ST of the tupled arguments (paired with mk_prod_abs,
                  as they may mention the quantified variables); the right-hand side
                  is rewritten with rews, injected as the i'-th component of RST and
                  beta-normalized. *)
   129         fun convert_eqs (f, qs, gs, args, rhs) =
   130             let
   131               val MutualPart {i, i', ...} = get_part f parts
   132             in
   133               (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
   134                SumTree.mk_inj RST n' i' (replace_frees rews rhs)
   135                                |> Envir.beta_norm)
   136             end
   137 
   138         val qglrs = map convert_eqs fqgars
   139     in
   140         Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, 
   141                 parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
   142     end
   143 
   144 
   145 
   146 
       (* Defines every original function, with the mixfix from fixes, as its
          f_def instantiated with fsum (name fname, fact name fname ^ _def), and
          returns the Mutual value with f, f_defthm and fsum filled in, together
          with the updated local theory.  Assumes parts and fixes correspond
          positionally (prepare_fundef_mutual builds parts from map fst fixes). *)
   147 fun define_projections fixes mutual fsum lthy =
   148     let
   149       fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
   150           let
   151             val ((f, (_, f_defthm)), lthy') =
   152               LocalTheory.define Thm.internalK ((Binding.name fname, mixfix),
   153                                             ((Binding.name (fname ^ "_def"), []), Term.subst_bound (fsum, f_def)))
   154                               lthy
   155           in
   156             (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
   157                          f=SOME f, f_defthm=SOME f_defthm },
   158              lthy')
   159           end
   160           
   161       val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
   162       val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
   163     in
   164       (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
   165                 fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
   166        lthy')
   167     end
   168 
   169 
       (* Runs F on an equation after replacing its quantified variables with
          fresh free variables (Variable.variant_fixes).  F receives the context,
          the instantiated equation, and two theorem transformers:
            import: instantiates the outer quantifiers with those frees and
                    discharges the premises with the assumed gs;
            export: reintroduces the premises gs and generalizes over the frees
                    again, renamed back to the original variable names.
          NOTE(review): F is given ctxt, not the extended ctxt'. *)
   170 fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
   171     let
   172       val thy = ProofContext.theory_of ctxt
   173                 
   174       val oqnames = map fst pre_qs
   175       val (qs, ctxt') = Variable.variant_fixes oqnames ctxt
   176                         |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs
   177                         
   178       fun inst t = subst_bounds (rev qs, t)
   179       val gs = map inst pre_gs
   180       val args = map inst pre_args
   181       val rhs = inst pre_rhs
   182 
   183       val cqs = map (cterm_of thy) qs
   184       val ags = map (assume o cterm_of thy) gs
   185 
   186       val import = fold forall_elim cqs
   187                    #> fold Thm.elim_implies ags
   188 
   189       val export = fold_rev (implies_intr o cprop_of) ags
   190                    #> fold_rev forall_intr_rename (oqnames ~~ cqs)
   191     in
   192       F ctxt (f, qs, gs, args, rhs) import export
   193     end
   194 
       (* Derives the psimp rule of the original function fname from the rule
          sum_psimp_eq of the sum function: proves f args = rhs by unfolding
          all_orig_fdefs, rewriting with the imported sum rule and simplifying with
          SumTree.proj_in_rules.  A single premise of the imported rule is assumed
          during the proof and reintroduced afterwards; more than one premise is a
          sys_error.  The result is passed through export. *)
   195 fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq =
   196     let
   197       val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts
   198 
   199       val psimp = import sum_psimp_eq
   200       val (simp, restore_cond) = case cprems_of psimp of
   201                                    [] => (psimp, I)
   202                                  | [cond] => (implies_elim psimp (assume cond), implies_intr cond)
   203                                  | _ => sys_error "Too many conditions"
   204     in
   205       Goal.prove ctxt [] [] 
   206                  (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
   207                  (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs)
   208                           THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
   209                           THEN (simp_tac (simpset_of ctxt addsimps SumTree.proj_in_rules)) 1)
   210         |> restore_cond 
   211         |> export
   212     end
   213 
   214 
   215 (* FIXME HACK *)
       (* Applies both sides of the equation thm to frees x0, x1, ... of the types
          caTs (via combination with reflexive), beta-normalizes, generalizes over
          the frees and turns them into schematic variables. *)
   216 fun mk_applied_form ctxt caTs thm =
   217     let
   218       val thy = ProofContext.theory_of ctxt
   219       val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
   220     in
   221       fold (fn x => fn thm => combination thm (reflexive x)) xs thm
   222            |> Conv.fconv_rule (Thm.beta_conversion true)
   223            |> fold_rev forall_intr xs
   224            |> Thm.forall_elim_vars 0
   225     end
   226 
   227 
       (* Turns the induction rule induct of the sum function into one rule per
          mutual function.  induct is instantiated with a sum case over one tupled
          predicate per function (named by mutual_induct_Pnames) and simplified
          with sumcase_split_ss and all_f_defs.  For each part the result is then
          instantiated with the i-th injection of fresh frees a<k>, a<k+1>, ...,
          simplified again, and generalized over those frees and the predicates;
          k offsets the free names across parts. *)
   228 fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) =
   229     let
   230       val cert = cterm_of (ProofContext.theory_of lthy)
   231       val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} => 
   232                                        Free (Pname, cargTs ---> HOLogic.boolT))
   233                        (mutual_induct_Pnames (length parts))
   234                        parts
   235                        
           (* P applied to schematic arguments ?a, viewed as a function of their tuple *)
   236       fun mk_P (MutualPart {cargTs, ...}) P =
   237           let
   238             val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
   239             val atup = foldr1 HOLogic.mk_prod avars
   240           in
   241             tupled_lambda atup (list_comb (P, avars))
   242           end
   243           
   244       val Ps = map2 mk_P parts newPs
   245       val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
   246                      
   247       val induct_inst =
   248           forall_elim (cert case_exp) induct
   249                       |> full_simplify SumTree.sumcase_split_ss
   250                       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
   251           
   252       fun project rule (MutualPart {cargTs, i, ...}) k =
   253           let
   254             val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
   255             val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
   256           in
   257             (rule
   258               |> forall_elim (cert inj)
   259               |> full_simplify SumTree.sumcase_split_ss
   260               |> fold_rev (forall_intr o cert) (afs @ newPs),
   261              k + length cargTs)
   262           end
   263     in
   264       fst (fold_map (project induct_inst) parts 0)
   265     end
   266     
   267 
       (* Proof continuation for the mutual case: runs inner_cont on the proof,
          whose result must describe exactly one (sum) function with exactly one
          simple pinduct rule, and translates each component (psimps, trsimps,
          induction, termination, domintros) to the original functions. *)
   268 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   269     let
   270       val result = inner_cont proof
   271       val FundefResult {fs=[f], G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
   272                         termination,domintros} = result
   273                                                                                                                
           (* all_f_defs: each function's symmetric definition in applied form;
              fs: the defined function terms *)
   274       val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   275                                      (mk_applied_form lthy cargTs (symmetric f_def), f))
   276                                  parts
   277                              |> split_list
   278 
   279       val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
   280                            
   281       fun mk_mpsimp fqgar sum_psimp =
   282           in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   283           
   284       val rew_ss = HOL_basic_ss addsimps all_f_defs
   285       val mpsimps = map2 mk_mpsimp fqgars psimps
   286       val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps
   287       val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   288       val mtermination = full_simplify rew_ss termination
   289       val mdomintros = map_option (map (full_simplify rew_ss)) domintros
   290     in
   291       FundefResult { fs=fs, G=G, R=R,
   292                      psimps=mpsimps, simple_pinducts=minducts,
   293                      cases=cases, termination=mtermination,
   294                      domintros=mdomintros,
   295                      trsimps=mtrsimps}
   296     end
   297       
       (* Entry point: analyzes the equations (after beta-eta contraction), sets
          up the single sum function via FundefCore.prepare_fundef, defines the
          original functions as projections of it, and returns the goal state
          with the mutual proof continuation, plus the updated local theory. *)
   298 fun prepare_fundef_mutual config defname fixes eqss lthy =
   299     let
   300       val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
   301       val Mutual {fsum_var=(n, T), qglrs, ...} = mutual
   302           
   303       val ((fsum, goalstate, cont), lthy') =
   304           FundefCore.prepare_fundef config defname [((n, T), NoSyn)] qglrs lthy
   305           
   306       val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
   307 
   308       val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
   309     in
   310       ((goalstate, mutual_cont), lthy'')
   311     end
   312 
   313     
   314 end
   307 
   308       val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
   309     in
   310       ((goalstate, mutual_cont), lthy'')
   311     end
   312 
   313     
   314 end