src/HOL/Tools/Function/mutual.ML
author krauss
Sat Jan 02 23:18:58 2010 +0100 (2010-01-02)
changeset 34232 36a2a3029fd3
parent 33855 cd8acf137c9c
child 35624 c4e29a0bb8c1
permissions -rw-r--r--
new year's resolution: reindented code in function package
     1 (*  Title:      HOL/Tools/Function/mutual.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A package for general recursive function definitions.
     5 Tools for mutually recursive definitions.
     6 *)
     7 
     8 signature FUNCTION_MUTUAL =
     9 sig
    10   (* Prepare a mutually recursive definition: yields the initial goal
    11      state, a proof continuation producing the function_result, and the
    12      updated local theory. *)
    13   val prepare_function_mutual :
    14     Function_Common.function_config -> string (* defname *) ->
    15     ((string * typ) * mixfix) list (* fixes *) -> term list (* eqs *) ->
    16     local_theory ->
    17     (thm (* goal state *)
    18       * (thm -> Function_Common.function_result) (* proof continuation *))
    19     * local_theory
    20 end
    21 
    22 
    23 structure Function_Mutual: FUNCTION_MUTUAL =
    24 struct
    25 
    26 open Function_Lib
    27 open Function_Common
    28 
       (* A split equation as produced by split_def (see analyze_eqs):
          (function name, quantified variables, guards, arguments, right-hand side).
          The term components refer to the quantified variables through loose
          bound variables; in_context instantiates them with fresh Frees. *)
    29 type qgar = string * (string * typ) list * term list * term list * term
    30 
       (* Per-function data of a mutual definition. *)
    31 datatype mutual_part = MutualPart of
       (* i:        1-based position of the function; selects its injection into ST
          i':       1-based position of its result type among the distinct result
                    types; selects its projection out of RST
          fvar:     name and type of the function being defined
          cargTs:   types of its curried arguments
          f_def:    its definiens, with the sum function as a loose bound variable
          f:        the defined constant (filled in by define_projections)
          f_defthm: the definition theorem of f (filled in by define_projections) *)
    32  {i : int,
    33   i' : int,
    34   fvar : string * typ,
    35   cargTs: typ list,
    36   f_def: term,
    37 
    38   f: term option,
    39   f_defthm : thm option}
    40 
       (* Description of a whole mutual definition. *)
    41 datatype mutual_info = Mutual of
       (* n, n':    number of functions / number of distinct result types
          fsum_var: name and type (ST --> RST) of the sum function variable
          ST, RST:  sum of the tupled argument types / of the distinct result types
          parts:    per-function data, in the order of the fixes
          fqgars:   the split original equations
          qglrs:    the equations restated for the sum function
          fsum:     the sum function term (filled in by define_projections) *)
    42  {n : int,
    43   n' : int,
    44   fsum_var : string * typ,
    45 
    46   ST: typ,
    47   RST: typ,
    48 
    49   parts: mutual_part list,
    50   fqgars: qgar list,
    51   qglrs: ((string * typ) list * term list * term * term) list,
    52 
    53   fsum : term option}
    54 
       (* Names of the induction predicates for n functions: the first n of
          P, Q, R, S when n < 5, and P1, ..., Pn otherwise. *)
    55 fun mutual_induct_Pnames n =
    56   if n < 5 then fst (chop n ["P","Q","R","S"])
    57   else map (fn i => "P" ^ string_of_int i) (1 upto n)
    58 
       (* The part whose function is named fname; fails (via the) if there is none. *)
    59 fun get_part fname =
    60   the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
    61 
    62 (* FIXME *)
       (* Pair t1 and t2, computing their types with fastype_of1 so that the terms
          may contain loose bound variables. e lists (name, type) of those
          variables; it is reversed because its last element corresponds to
          Bound 0 (cf. subst_bounds (rev qs, _) in in_context). *)
    63 fun mk_prod_abs e (t1, t2) =
    64   let
    65     val bTs = rev (map snd e)
    66     val T1 = fastype_of1 (bTs, t1)
    67     val T2 = fastype_of1 (bTs, t2)
    68   in
    69     HOLogic.pair_const T1 T2 $ t1 $ t2
    70   end
    71 
       (* Encode the mutually recursive functions fs (name/type pairs) with
          defining equations eqs as one function on sums. The result records a
          variable fsum_var : ST --> RST, each function's definiens as a
          projection out of RST of fsum_var applied to an injection into ST of
          its tupled curried arguments, and the equations restated in terms of
          fsum_var (qglrs). *)
    72 fun analyze_eqs ctxt defname fs eqs =
    73   let
    74     val num = length fs
    75     val fqgars = map (split_def ctxt) eqs
           (* Arity of a function = number of arguments in the first of its
              equations found in fqgars. NOTE(review): presumably all equations
              of a function have the same arity; not checked here. *)
    76     val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    77       |> AList.lookup (op =) #> the
    78 
           (* Split fT into its first (arity_of fname) argument types -- the curried
              arguments -- and a result type built from the remaining ones. *)
    79     fun curried_types (fname, fT) =
    80       let
    81         val (caTs, uaTs) = chop (arity_of fname) (binder_types fT)
    82       in
    83         (caTs, uaTs ---> body_type fT)
    84       end
    85 
    86     val (caTss, resultTs) = split_list (map curried_types fs)
           (* Each function's curried arguments are tupled into one argument. *)
    87     val argTs = map (foldr1 HOLogic.mk_prodT) caTss
    88 
           (* Functions with equal result types share one summand of RST. *)
    89     val dresultTs = distinct (op =) resultTs
    90     val n' = length dresultTs
    91 
    92     val RST = Balanced_Tree.make (uncurry SumTree.mk_sumT) dresultTs
    93     val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) argTs
    94 
    95     val fsum_type = ST --> RST
    96 
           (* Fresh name derived from defname ^ "_sum"; the extended context is
              discarded. *)
    97     val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
    98     val fsum_var = (fsum_var_name, fsum_type)
    99 
           (* For function fvar (curried argument types caTs, result type resultT,
              position i) build its part -- whose f_def abstracts fsum_var to a
              loose bound variable -- and the pair (name, replacement term) used
              to rewrite its calls in right-hand sides. *)
   100     fun define (fvar as (n, _)) caTs resultT i =
   101       let
   102         val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
   103         val i' = find_index (fn Ta => Ta = resultT) dresultTs + 1
   104 
   105         val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
   106         val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
   107 
   108         val rew = (n, fold_rev lambda vars f_exp)
   109       in
   110         (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
   111       end
   112 
           (* Positions i run from 1 to num in the order of fs. *)
   113     val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
   114 
           (* Restate one equation for fsum_var: the left-hand side becomes the
              injection into ST of the tupled arguments; the right-hand side
              becomes the injection into RST of rhs after replace_frees with rews
              (presumably substituting each function by its projection term),
              then beta-normalized. *)
   115     fun convert_eqs (f, qs, gs, args, rhs) =
   116       let
   117         val MutualPart {i, i', ...} = get_part f parts
   118       in
   119         (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
   120          SumTree.mk_inj RST n' i' (replace_frees rews rhs)
   121          |> Envir.beta_norm)
   122       end
   123 
   124     val qglrs = map convert_eqs fqgars
   125   in
   126     Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
   127       parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
   128   end
   129 
       (* Define every function of the mutual definition as a constant whose body
          is its f_def with the sum function term fsum substituted for the loose
          bound variable. fixes supplies the mixfix syntax and is paired with the
          parts positionally (both follow the order of the fixes). Returns the
          description with f, f_defthm and fsum filled in, plus the new theory. *)
   130 fun define_projections fixes mutual fsum lthy =
   131   let
   132     fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
   133       let
               (* Defines fname with its mixfix; the fname_def binding is concealed. *)
   134         val ((f, (_, f_defthm)), lthy') =
   135           Local_Theory.define
   136             ((Binding.name fname, mixfix),
   137               ((Binding.conceal (Binding.name (fname ^ "_def")), []),
   138               Term.subst_bound (fsum, f_def))) lthy
   139       in
   140         (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
   141            f=SOME f, f_defthm=SOME f_defthm },
   142          lthy')
   143       end
   144 
   145     val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
   146     val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
   147   in
   148     (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
   149        fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
   150      lthy')
   151   end
   152 
       (* Run F on a split equation in a local view. The quantified variables
          pre_qs get fresh Free names (Variable.variant_fixes) that instantiate
          the loose bound variables of guards, arguments and right-hand side; the
          extended context is discarded, so F receives ctxt itself. F also gets
            import: instantiates a theorem's outer quantifiers with qs and
                    discharges its premises with the assumed guards;
            export: re-introduces the guards as premises and generalizes over
                    qs with forall_intr_rename, using the original names. *)
   153 fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
   154   let
   155     val thy = ProofContext.theory_of ctxt
   156 
   157     val oqnames = map fst pre_qs
   158     val (qs, _) = Variable.variant_fixes oqnames ctxt
   159       |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs
   160 
   161     fun inst t = subst_bounds (rev qs, t)
   162     val gs = map inst pre_gs
   163     val args = map inst pre_args
   164     val rhs = inst pre_rhs
   165 
   166     val cqs = map (cterm_of thy) qs
   167     val ags = map (assume o cterm_of thy) gs
   168 
   169     val import = fold forall_elim cqs
   170       #> fold Thm.elim_implies ags
   171 
           (* fold_rev, so that guards and quantifiers are re-introduced with the
              first one outermost. *)
   172     val export = fold_rev (implies_intr o cprop_of) ags
   173       #> fold_rev forall_intr_rename (oqnames ~~ cqs)
   174   in
   175     F ctxt (f, qs, gs, args, rhs) import export
   176   end
   177 
       (* Used as the F argument of in_context (after partial application to
          all_orig_fdefs and parts): derive the psimp rule "f args = rhs" of
          function fname from the corresponding psimp rule of the sum function.
          The goal is proved by unfolding all_orig_fdefs, substituting with the
          imported sum rule, and simplifying with SumTree.proj_in_rules; the
          result is then exported. *)
   178 fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs)
   179   import (export : thm -> thm) sum_psimp_eq =
   180   let
   181     val (MutualPart {f=SOME f, ...}) = get_part fname parts
   182 
   183     val psimp = import sum_psimp_eq
           (* At most one remaining premise is supported: it is assumed for the
              proof and re-introduced by restore_cond afterwards. *)
   184     val (simp, restore_cond) =
   185       case cprems_of psimp of
   186         [] => (psimp, I)
   187       | [cond] => (implies_elim psimp (assume cond), implies_intr cond)
   188       | _ => sys_error "Too many conditions"
   189 
   190   in
   191     Goal.prove ctxt [] []
   192       (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
   193       (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs)
   194          THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
   195          THEN (simp_tac (simpset_of ctxt addsimps SumTree.proj_in_rules)) 1)
   196     |> restore_cond
   197     |> export
   198   end
   199 
       (* Turn thm, a meta-equation between terms taking arguments of types caTs,
          into one between their applications: both sides are applied to fresh
          Frees x0, x1, ..., the whole proposition is beta-normalized, and the
          Frees are generalized and turned into schematic variables. *)
   200 fun mk_applied_form ctxt caTs thm =
   201   let
   202     val thy = ProofContext.theory_of ctxt
   203     val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
   204   in
   205     fold (fn x => fn thm => combination thm (reflexive x)) xs thm
   206     |> Conv.fconv_rule (Thm.beta_conversion true)
   207     |> fold_rev forall_intr xs
   208     |> Thm.forall_elim_vars 0
   209   end
   210 
       (* Derive one induction rule per function from the sum function's
          induction rule induct. First, induct's outermost quantifier is
          instantiated with a case distinction over ST that applies a fresh
          predicate (named by mutual_induct_Pnames) to each function's argument
          tuple, and the result is simplified (also with all_f_defs). Then, per
          function, the next quantifier is instantiated with the injection of a
          tuple of fresh Frees, and the rule is simplified and generalized over
          these Frees and all predicates.
          NOTE(review): assumes induct quantifies first over the predicate and
          then over the sum argument -- TODO confirm. *)
   211 fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
   212   let
   213     val cert = cterm_of (ProofContext.theory_of lthy)
   214     val newPs =
   215       map2 (fn Pname => fn MutualPart {cargTs, ...} =>
   216           Free (Pname, cargTs ---> HOLogic.boolT))
   217         (mutual_induct_Pnames (length parts)) parts
   218 
           (* The predicate P taken on a tuple: tupled_lambda over schematic
              variables ?a0, ?a1, ... of the curried argument types. *)
   219     fun mk_P (MutualPart {cargTs, ...}) P =
   220       let
   221         val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
   222         val atup = foldr1 HOLogic.mk_prod avars
   223       in
   224         tupled_lambda atup (list_comb (P, avars))
   225       end
   226 
   227     val Ps = map2 mk_P parts newPs
   228     val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps
   229 
   230     val induct_inst =
   231       forall_elim (cert case_exp) induct
   232       |> full_simplify SumTree.sumcase_split_ss
   233       |> full_simplify (HOL_basic_ss addsimps all_f_defs)
   234 
           (* k counts the Frees used by earlier parts, so the names "a<k>" do not
              repeat across parts (they are not checked against other names). *)
   235     fun project rule (MutualPart {cargTs, i, ...}) k =
   236       let
   237         val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *)
   238         val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
   239       in
   240         (rule
   241          |> forall_elim (cert inj)
   242          |> full_simplify SumTree.sumcase_split_ss
   243          |> fold_rev (forall_intr o cert) (afs @ newPs),
   244          k + length cargTs)
   245       end
   246   in
   247     fst (fold_map (project induct_inst) parts 0)
   248   end
   249 
       (* Proof continuation of the mutual definition. Runs inner_cont (the
          continuation for the sum function) on proof and translates its result
          to the individual functions:
          - fs become the defined constants;
          - psimps and trsimps are recovered per original equation (paired with
            fqgars positionally);
          - the single simple_pinduct is split into one rule per function;
          - termination and domintros are simplified with all_f_defs.
          G, R and cases are passed through unchanged. The val binding fails to
          match unless inner_cont yields exactly one simple_pinduct. *)
   250 fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
   251   let
   252     val result = inner_cont proof
   253     val FunctionResult {G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
   254       termination, domintros, ...} = result
   255 
           (* all_f_defs: the definitions reversed and applied to variables
              (mk_applied_form), presumably rewriting sum-function expressions
              back into applications of the individual functions. *)
   256     val (all_f_defs, fs) =
   257       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
   258         (mk_applied_form lthy cargTs (symmetric f_def), f))
   259       parts
   260       |> split_list
   261 
   262     val all_orig_fdefs =
   263       map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts
   264 
   265     fun mk_mpsimp fqgar sum_psimp =
   266       in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp
   267 
   268     val rew_ss = HOL_basic_ss addsimps all_f_defs
   269     val mpsimps = map2 mk_mpsimp fqgars psimps
   270     val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps
   271     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
   272     val mtermination = full_simplify rew_ss termination
   273     val mdomintros = map_option (map (full_simplify rew_ss)) domintros
   274   in
   275     FunctionResult { fs=fs, G=G, R=R,
   276       psimps=mpsimps, simple_pinducts=minducts,
   277       cases=cases, termination=mtermination,
   278       domintros=mdomintros, trsimps=mtrsimps}
   279   end
   280 
       (* Entry point. Encodes the beta-eta-contracted equations eqss for the
          functions in fixes as a single sum function (analyze_eqs). Prepares its
          definition with Function_Core.prepare_function (without syntax) and
          defines the individual functions as projections of it. Returns the goal
          state together with the continuation built by mk_partial_rules_mutual,
          plus the resulting local theory. *)
   281 fun prepare_function_mutual config defname fixes eqss lthy =
   282   let
   283     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
   284       analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
   285 
   286     val ((fsum, goalstate, cont), lthy') =
   287       Function_Core.prepare_function config defname [((n, T), NoSyn)] qglrs lthy
   288 
   289     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
   290 
   291     val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
   292   in
   293     ((goalstate, mutual_cont), lthy'')
   294   end
   295 
   296 end