src/HOL/Tools/Function/mutual.ML
author wenzelm
Thu Nov 19 14:46:33 2009 +0100 (2009-11-19)
changeset 33766 c679f05600cd
parent 33671 4b0f2599ed48
child 33855 cd8acf137c9c
permissions -rw-r--r--
adapted Local_Theory.define -- eliminated odd thm kind;
     1 (*  Title:      HOL/Tools/Function/mutual.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A package for general recursive function definitions.
     5 Tools for mutually recursive definitions.
     6 *)
     7 
     8 signature FUNCTION_MUTUAL =
     9 sig
    10 
    11   val prepare_function_mutual : Function_Common.function_config
    12                               -> string (* defname *)
    13                               -> ((string * typ) * mixfix) list
    14                               -> term list
    15                               -> local_theory
    16                               -> ((thm (* goalstate *)
    17                                    * (thm -> Function_Common.function_result) (* proof continuation *)
    18                                   ) * local_theory)
    19 
    20 end
    21 
    22 
    23 structure Function_Mutual: FUNCTION_MUTUAL =
    24 struct
    25 
    26 open Function_Lib
    27 open Function_Common
    28 
    29 type qgar = string * (string * typ) list * term list * term list * term
    30 
    31 fun name_of_fqgar ((f, _, _, _, _): qgar) = f
    32 
    33 datatype mutual_part =
    34   MutualPart of 
    35    {
    36     i : int,
    37     i' : int,
    38     fvar : string * typ,
    39     cargTs: typ list,
    40     f_def: term,
    41 
    42     f: term option,
    43     f_defthm : thm option
    44    }
    45    
    46 
    47 datatype mutual_info =
    48   Mutual of 
    49    { 
    50     n : int,
    51     n' : int,
    52     fsum_var : string * typ,
    53 
    54     ST: typ,
    55     RST: typ,
    56 
    57     parts: mutual_part list,
    58     fqgars: qgar list,
    59     qglrs: ((string * typ) list * term list * term * term) list,
    60 
    61     fsum : term option
    62    }
    63 
    64 fun mutual_induct_Pnames n =
    65     if n < 5 then fst (chop n ["P","Q","R","S"])
    66     else map (fn i => "P" ^ string_of_int i) (1 upto n)
    67 
    68 fun get_part fname =
    69     the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
    70                      
    71 (* FIXME *)
    72 fun mk_prod_abs e (t1, t2) =
    73     let
    74       val bTs = rev (map snd e)
    75       val T1 = fastype_of1 (bTs, t1)
    76       val T2 = fastype_of1 (bTs, t2)
    77     in
    78       HOLogic.pair_const T1 T2 $ t1 $ t2
    79     end;
    80 
    81 
    82 fun analyze_eqs ctxt defname fs eqs =
    83     let
    84       val num = length fs
    85         val fnames = map fst fs
    86         val fqgars = map (split_def ctxt) eqs
    87         val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    88                        |> AList.lookup (op =) #> the
    89 
    90         fun curried_types (fname, fT) =
    91             let
    92               val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
    93             in
    94                 (caTs, uaTs ---> body_type fT)
    95             end
    96 
    97         val (caTss, resultTs) = split_list (map curried_types fs)
    98         val argTs = map (foldr1 HOLogic.mk_prodT) caTss
    99 
   100         val dresultTs = distinct (op =) resultTs
   101         val n' = length dresultTs
   102 
   103         val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
   104         val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs
   105 
   106         val fsum_type = ST --> RST
   107 
   108         val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
   109         val fsum_var = (fsum_var_name, fsum_type)
   110 
   111         fun define (fvar as (n, T)) caTs resultT i =
   112             let
   113                 val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
   114                 val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1 
   115 
   116                 val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
   117                 val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
   118 
   119                 val rew = (n, fold_rev lambda vars f_exp)
   120             in
   121                 (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
   122             end
   123             
   124         val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
   125 
   126         fun convert_eqs (f, qs, gs, args, rhs) =
   127             let
   128               val MutualPart {i, i', ...} = get_part f parts
   129             in
   130               (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
   131                SumTree.mk_inj RST n' i' (replace_frees rews rhs)
   132                                |> Envir.beta_norm)
   133             end
   134 
   135         val qglrs = map convert_eqs fqgars
   136     in
   137         Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, 
   138                 parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
   139     end
   140 
   141 
   142 
   143 
   144 fun define_projections fixes mutual fsum lthy =
   145     let
   146       fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
   147           let
   148             val ((f, (_, f_defthm)), lthy') =
   149               Local_Theory.define
   150                 ((Binding.name fname, mixfix),
   151                   ((Binding.conceal (Binding.name (fname ^ "_def")), []),
   152                   Term.subst_bound (fsum, f_def))) lthy
   153           in
   154             (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
   155                          f=SOME f, f_defthm=SOME f_defthm },
   156              lthy')
   157           end
   158           
   159       val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
   160       val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
   161     in
   162       (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
   163                 fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
   164        lthy')
   165     end
   166 
   167 
   168 fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
   169     let
   170       val thy = ProofContext.theory_of ctxt
   171                 
   172       val oqnames = map fst pre_qs
   173       val (qs, ctxt') = Variable.variant_fixes oqnames ctxt
   174                         |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs
   175                         
   176       fun inst t = subst_bounds (rev qs, t)
   177       val gs = map inst pre_gs
   178       val args = map inst pre_args
   179       val rhs = inst pre_rhs
   180 
   181       val cqs = map (cterm_of thy) qs
   182       val ags = map (assume o cterm_of thy) gs
   183 
   184       val import = fold forall_elim cqs
   185                    #> fold Thm.elim_implies ags
   186 
   187       val export = fold_rev (implies_intr o cprop_of) ags
   188                    #> fold_rev forall_intr_rename (oqnames ~~ cqs)
   189     in
   190       F ctxt (f, qs, gs, args, rhs) import export
   191     end
   192 
   193 fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq =
   194     let
   195       val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts
   196 
   197       val psimp = import sum_psimp_eq
   198       val (simp, restore_cond) = case cprems_of psimp of
   199                                    [] => (psimp, I)
   200                                  | [cond] => (implies_elim psimp (assume cond), implies_intr cond)
   201                                  | _ => sys_error "Too many conditions"
   202     in
   203       Goal.prove ctxt [] [] 
   204                  (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
   205                  (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs)
   206                           THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
   207                           THEN (simp_tac (simpset_of ctxt addsimps SumTree.proj_in_rules)) 1)
   208         |> restore_cond 
   209         |> export
   210     end
   211 
   212 
   213 (* FIXME HACK *)
   214 fun mk_applied_form ctxt caTs thm =
   215     let
   216       val thy = ProofContext.theory_of ctxt
   217       val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
   218     in
   219       fold (fn x => fn thm => combination thm (reflexive x)) xs thm
   220            |> Conv.fconv_rule (Thm.beta_conversion true)
   221            |> fold_rev forall_intr xs
   222            |> Thm.forall_elim_vars 0
   223     end
   224 
   225 
   226 fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) =
   227     let
   228       val cert = cterm_of (ProofContext.theory_of lthy)
   229       val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} => 
   230                                        Free (Pname, cargTs ---> HOLogic.boolT))
   231                        (mutual_induct_Pnames (length parts))
   232                        parts
   233                        
   234       fun mk_P (MutualPart {cargTs, ...}) P =
   235           let
   236             val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
   237             val atup = foldr1 HOLogic.mk_prod avars
   238           in
   239             tupled_lambda atup (list_comb (P, avars))
   240           end
   241           
   242       val Ps = map2 mk_P parts newPs
   243       val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
   244                      
   245       val induct_inst =
   246           forall_elim (cert case_exp) induct
   247                       |> full_simplify SumTree.sumcase_split_ss
   248                       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
   249           
   250       fun project rule (MutualPart {cargTs, i, ...}) k =
   251           let
   252             val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
   253             val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
   254           in
   255             (rule
   256               |> forall_elim (cert inj)
   257               |> full_simplify SumTree.sumcase_split_ss
   258               |> fold_rev (forall_intr o cert) (afs @ newPs),
   259              k + length cargTs)
   260           end
   261     in
   262       fst (fold_map (project induct_inst) parts 0)
   263     end
   264     
   265 
   266 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   267     let
   268       val result = inner_cont proof
   269       val FunctionResult {fs=[f], G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
   270                         termination,domintros} = result
   271                                                                                                                
   272       val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   273                                      (mk_applied_form lthy cargTs (symmetric f_def), f))
   274                                  parts
   275                              |> split_list
   276 
   277       val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
   278                            
   279       fun mk_mpsimp fqgar sum_psimp =
   280           in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   281           
   282       val rew_ss = HOL_basic_ss addsimps all_f_defs
   283       val mpsimps = map2 mk_mpsimp fqgars psimps
   284       val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps
   285       val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   286       val mtermination = full_simplify rew_ss termination
   287       val mdomintros = map_option (map (full_simplify rew_ss)) domintros
   288     in
   289       FunctionResult { fs=fs, G=G, R=R,
   290                      psimps=mpsimps, simple_pinducts=minducts,
   291                      cases=cases, termination=mtermination,
   292                      domintros=mdomintros,
   293                      trsimps=mtrsimps}
   294     end
   295       
   296 fun prepare_function_mutual config defname fixes eqss lthy =
   297     let
   298       val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
   299       val Mutual {fsum_var=(n, T), qglrs, ...} = mutual
   300           
   301       val ((fsum, goalstate, cont), lthy') =
   302           Function_Core.prepare_function config defname [((n, T), NoSyn)] qglrs lthy
   303           
   304       val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
   305 
   306       val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
   307     in
   308       ((goalstate, mutual_cont), lthy'')
   309     end
   310 
   311     
   312 end