src/HOL/Tools/function_package/mutual.ML
author wenzelm
Fri Mar 20 15:24:18 2009 +0100 (2009-03-20)
changeset 30607 c3d1590debd8
parent 28965 1de908189869
child 30906 3c7a76e79898
permissions -rw-r--r--
eliminated global SIMPSET, CLASET etc. -- refer to explicit context;
krauss@19770
     1
(*  Title:      HOL/Tools/function_package/mutual.ML
krauss@19770
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@19770
     3
wenzelm@20878
     4
A package for general recursive function definitions.
krauss@19770
     5
Tools for mutually recursive definitions.
krauss@19770
     6
*)
krauss@19770
     7
wenzelm@20878
     8
signature FUNDEF_MUTUAL =
sig

  (* prepare_fundef_mutual config defname fixes eqs lthy

     Entry point for (possibly mutually recursive) definitions: the
     functions in [fixes] are combined into one function over a sum
     type, whose definition is prepared.  The result pairs the initial
     goal state with a proof continuation that turns the finished proof
     into a fundef_result for the original functions, and returns the
     updated local theory. *)
  val prepare_fundef_mutual :
      FundefCommon.fundef_config
      -> string                           (* defname *)
      -> ((string * typ) * mixfix) list   (* the functions being defined *)
      -> term list                        (* the defining equations *)
      -> local_theory
      -> (thm * (thm -> FundefCommon.fundef_result)) * local_theory
         (* ((goal state, proof continuation), updated local theory) *)

end
krauss@19770
    21
krauss@19770
    22
wenzelm@20878
    23
structure FundefMutual: FUNDEF_MUTUAL =
struct

open FundefLib
open FundefCommon


(* An equation as returned by split_def: (fname, qs, gs, args, rhs).
   qs are the variables that occur as loose bound variables in gs, args
   and rhs (see in_context, which instantiates them via subst_bounds);
   gs are the premises of the equation. *)
type qgar = string * (string * typ) list * term list * term list * term

(* NOTE(review): not referenced in this file and hidden by the signature
   ascription; kept for compatibility with the original source. *)
fun name_of_fqgar ((f, _, _, _, _): qgar) = f

(* Per-function information in a mutual definition.
   i        : 1-based position of the function among all functions
   i'       : 1-based position of its result type among the distinct
              result types
   fvar     : (name, type) of the function
   cargTs   : types of the curried arguments that are tupled into the sum
   f_def    : defining body, with the sum function abstracted as a loose
              bound variable (instantiated in define_projections)
   f        : the defined constant (SOME after define_projections)
   f_defthm : its defining theorem (SOME after define_projections) *)
datatype mutual_part =
  MutualPart of
   {
    i : int,
    i' : int,
    fvar : string * typ,
    cargTs: typ list,
    f_def: term,

    f: term option,
    f_defthm : thm option
   }

(* Information about the whole mutual definition.
   n        : number of functions
   n'       : number of distinct result types
   fsum_var : (name, type) of the combined sum function
   ST       : sum of the (tupled) argument types
   RST      : sum of the distinct result types
   parts    : one mutual_part per function
   fqgars   : the original equations, as split by split_def
   qglrs    : the equations rewritten in terms of the sum function
   fsum     : the sum-function term (SOME after define_projections) *)
datatype mutual_info =
  Mutual of
   {
    n : int,
    n' : int,
    fsum_var : string * typ,

    ST: typ,
    RST: typ,

    parts: mutual_part list,
    fqgars: qgar list,
    qglrs: ((string * typ) list * term list * term * term) list,

    fsum : term option
   }

(* Predicate names for the mutual induction rules: up to four functions
   get "P", "Q", "R", "S"; otherwise "P1", ..., "Pn". *)
fun mutual_induct_Pnames n =
    if n < 5 then fst (chop n ["P","Q","R","S"])
    else map (fn i => "P" ^ string_of_int i) (1 upto n)

(* The part for the function named fname; fails (via `the`) if there is
   no such part. *)
fun get_part fname =
    the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)

(* FIXME *)
(* Pairs t1 and t2, which may contain loose bound variables referring to
   the (name, type) list e; their types are computed under that binder
   context. *)
fun mk_prod_abs e (t1, t2) =
    let
      val bTs = rev (map snd e)
      val T1 = fastype_of1 (bTs, t1)
      val T2 = fastype_of1 (bTs, t2)
    in
      HOLogic.pair_const T1 T2 $ t1 $ t2
    end


(* Builds the mutual_info for the functions fs with equations eqs.
   Arguments of each function are tupled and injected into the sum type
   ST.  Results are injected into RST, which contains each distinct
   result type once.  Each original function is expressed as a
   projection of the single sum function named defname ^ "_sum". *)
fun analyze_eqs ctxt defname fs eqs =
    let
      val num = length fs
      val fqgars = map (split_def ctxt) eqs
      val arities = mk_arities fqgars

      (* Split a function type into the curried argument types covered by
         the function's arity (default 1) and the remaining result type. *)
      fun curried_types (fname, fT) =
          let
            val k = the_default 1 (Symtab.lookup arities fname)
            val (caTs, uaTs) = chop k (binder_types fT)
          in
            (caTs, uaTs ---> body_type fT)
          end

      val (caTss, resultTs) = split_list (map curried_types fs)
      val argTs = map (foldr1 HOLogic.mk_prodT) caTss

      val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs
      val n' = length dresultTs

      val RST = BalancedTree.make (uncurry SumTree.mk_sumT) dresultTs
      val ST = BalancedTree.make (uncurry SumTree.mk_sumT) argTs

      val fsum_type = ST --> RST

      val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
      val fsum_var = (fsum_var_name, fsum_type)

      (* The part for the i-th function, plus a rewrite (name, body) that
         replaces the function by its projection of the sum function. *)
      fun define (fvar as (n, _)) caTs resultT i =
          let
            (* FIXME: Bind xs properly *)
            val vars = map_index (fn (j, T) => Free ("x" ^ string_of_int j, T)) caTs
            val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1

            val f_exp =
                SumTree.mk_proj RST n' i'
                  (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
            (* computed once; used both for the rewrite and for f_def *)
            val f_abs = fold_rev lambda vars f_exp
            (* f_def: the same term with the sum function as a loose bound *)
            val def = Term.abstract_over (Free fsum_var, f_abs)
          in
            (MutualPart {i=i, i'=i', fvar=fvar, cargTs=caTs, f_def=def, f=NONE, f_defthm=NONE},
             (n, f_abs))
          end

      val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))

      (* Rewrite an equation of an original function into one for the sum
         function: tuple and inject the arguments, and inject the
         right-hand side after replacing the original functions by their
         projections. *)
      fun convert_eqs (f, qs, gs, args, rhs) =
          let
            val MutualPart {i, i', ...} = get_part f parts
          in
            (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
             SumTree.mk_inj RST n' i' (replace_frees rews rhs)
               |> Envir.beta_norm)
          end

      val qglrs = map convert_eqs fqgars
    in
      Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
              parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
    end


(* Defines each original function as a constant named after it, with
   the definition fname ^ "_def".  The body is the part's f_def with
   fsum substituted for the loose bound variable.  Each part records the
   new constant and theorem, and the resulting mutual_info records
   fsum. *)
fun define_projections fixes mutual fsum lthy =
    let
      fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
          let
            val ((f, (_, f_defthm)), lthy') =
                LocalTheory.define Thm.internalK
                  ((Binding.name fname, mixfix),
                   ((Binding.name (fname ^ "_def"), []), Term.subst_bound (fsum, f_def)))
                  lthy
          in
            (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
                         f=SOME f, f_defthm=SOME f_defthm},
             lthy')
          end

      val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
      val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
    in
      (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
                fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
       lthy')
    end


(* Opens the context of one equation.  The equation's bound variables
   are instantiated with free variables whose names are variants of the
   originals.  F is then called with the equation in that form and two
   functions:
     import: instantiates the outer quantifiers with those frees and
             discharges the premises gs by assumption;
     export: re-introduces the premises gs and generalizes over the
             frees again, restoring the original variable names.
   NOTE(review): the context extended by Variable.variant_fixes is
   discarded; F receives the original ctxt. *)
fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
    let
      val thy = ProofContext.theory_of ctxt

      val oqnames = map fst pre_qs
      val (qs, _) = Variable.variant_fixes oqnames ctxt
                    |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs

      fun inst t = subst_bounds (rev qs, t)
      val gs = map inst pre_gs
      val args = map inst pre_args
      val rhs = inst pre_rhs

      val cqs = map (cterm_of thy) qs
      val ags = map (assume o cterm_of thy) gs

      val import = fold forall_elim cqs
                   #> fold Thm.elim_implies ags

      val export = fold_rev (implies_intr o cprop_of) ags
                   #> fold_rev forall_intr_rename (oqnames ~~ cqs)
    in
      F ctxt (f, qs, gs, args, rhs) import export
    end

(* Proves the equation f args = rhs for the original function fname from
   the corresponding psimp of the sum function.  The proof unfolds the
   projection definitions, rewrites with the imported sum psimp, and
   simplifies with SumTree.proj_in_rules.  At most one premise of the
   psimp is supported: it is assumed for the proof and re-introduced
   afterwards.  The result is then exported. *)
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq =
    let
      (* the part must already be defined (f and f_defthm are SOME) *)
      val (MutualPart {f=SOME f, f_defthm=SOME _, ...}) = get_part fname parts

      val psimp = import sum_psimp_eq
      val (simp, restore_cond) =
          case cprems_of psimp of
            [] => (psimp, I)
          | [cond] => (implies_elim psimp (assume cond), implies_intr cond)
          | _ => sys_error "Too many conditions"

      val goal = HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)
    in
      Goal.prove ctxt [] [] goal
        (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs)
                 THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
                 THEN (simp_tac (local_simpset_of ctxt addsimps SumTree.proj_in_rules)) 1)
        |> restore_cond
        |> export
    end


(* FIXME HACK *)
(* Applies both sides of the equation thm to free variables x0, x1, ...
   of types caTs, beta-reduces, and turns those variables into
   schematic variables. *)
fun mk_applied_form ctxt caTs thm =
    let
      val thy = ProofContext.theory_of ctxt
      (* FIXME: Bind xs properly *)
      val xs = map_index (fn (i, T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs
    in
      fold (fn x => fn thm => combination thm (reflexive x)) xs thm
        |> Conv.fconv_rule (Thm.beta_conversion true)
        |> fold_rev forall_intr xs
        |> Thm.forall_elim_vars 0
    end


(* Derives one induction rule per original function from the induction
   rule `induct` of the sum function.  First, induct's predicate is
   instantiated with a sum-case over fresh predicates P, Q, ...; each
   predicate takes the curried arguments of its function.  The result
   is simplified with the sum-case split rules and all_f_defs.  Then one
   rule is projected out for each part. *)
fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, parts, ...}) =
    let
      val cert = cterm_of (ProofContext.theory_of lthy)
      val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} =>
                          Free (Pname, cargTs ---> HOLogic.boolT))
                       (mutual_induct_Pnames (length parts))
                       parts

      (* the tupled form of P: %(a0, ..., ak). P a0 ... ak *)
      fun mk_P (MutualPart {cargTs, ...}) P =
          let
            val avars = map_index (fn (i, T) => Var (("a", i), T)) cargTs
            val atup = foldr1 HOLogic.mk_prod avars
          in
            tupled_lambda atup (list_comb (P, avars))
          end

      val Ps = map2 mk_P parts newPs
      val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps

      val induct_inst =
          forall_elim (cert case_exp) induct
            |> full_simplify SumTree.sumcase_split_ss
            |> full_simplify (HOL_basic_ss addsimps all_f_defs)

      (* k counts the argument variables used so far, so that the
         variable names of different parts do not clash. *)
      fun project rule (MutualPart {cargTs, i, ...}) k =
          let
            (* FIXME! *)
            val afs = map_index (fn (j, T) => Free ("a" ^ string_of_int (j + k), T)) cargTs
            val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
          in
            (rule
               |> forall_elim (cert inj)
               |> full_simplify SumTree.sumcase_split_ss
               |> fold_rev (forall_intr o cert) (afs @ newPs),
             k + length cargTs)
          end
    in
      fst (fold_map (project induct_inst) parts 0)
    end


(* Wraps the continuation inner_cont of the sum-function proof.  It
   translates the sum-function results into results for the individual
   functions:
   - psimps and trsimps are recovered per function;
   - the induction rule is split per function;
   - termination and domintros are simplified with the applied
     definitions;
   - G, R and cases are passed through unchanged. *)
fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof =
    let
      val result = inner_cont proof
      val FundefResult {fs=[_], G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct],
                        termination, domintros} = result

      (* symmetric, applied forms of the definitions, and the constants *)
      val (all_f_defs, fs) =
          parts
          |> map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
                    (mk_applied_form lthy cargTs (symmetric f_def), f))
          |> split_list

      val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts

      fun mk_mpsimp fqgar sum_psimp =
          in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp

      val rew_ss = HOL_basic_ss addsimps all_f_defs
      val mpsimps = map2 mk_mpsimp fqgars psimps
      val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps
      val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
      val mtermination = full_simplify rew_ss termination
      val mdomintros = map_option (map (full_simplify rew_ss)) domintros
    in
      FundefResult { fs=fs, G=G, R=R,
                     psimps=mpsimps, simple_pinducts=minducts,
                     cases=cases, termination=mtermination,
                     domintros=mdomintros,
                     trsimps=mtrsimps}
    end


(* Entry point (see signature).  The steps are:
   1. analyze the beta-eta-contracted equations;
   2. prepare the single sum-function definition (with NoSyn);
   3. define the projections;
   4. return the goal state with a continuation that recovers the
      per-function results. *)
fun prepare_fundef_mutual config defname fixes eqss lthy =
    let
      val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss)
      val Mutual {fsum_var=(n, T), qglrs, ...} = mutual

      val ((fsum, goalstate, cont), lthy') =
          FundefCore.prepare_fundef config defname [((n, T), NoSyn)] qglrs lthy

      val (mutual', lthy'') = define_projections fixes mutual fsum lthy'

      val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
    in
      ((goalstate, mutual_cont), lthy'')
    end


end