src/HOL/Tools/function_package/mutual.ML
author haftmann
Wed Jun 07 16:55:39 2006 +0200 (2006-06-07)
changeset 19818 5c5c1208a3fa
parent 19782 48c4632e2c28
child 19876 11d447d5d68c
permissions -rw-r--r--
adding case theorems for code generator
     1 (*  Title:      HOL/Tools/function_package/mutual.ML
     2     ID:         $Id$
     3     Author:     Alexander Krauss, TU Muenchen
     4 
     5 A package for general recursive function definitions. 
     6 Tools for mutually recursive definitions.
     7 
     8 *)
     9 
    10 signature FUNDEF_MUTUAL = 
    11 sig
    12   
    13   val prepare_fundef_mutual : thm list -> term list list -> theory ->
    14                               (FundefCommon.mutual_info * string * (FundefCommon.prep_result * theory))
    15 
    16 
    17   val mk_partial_rules_mutual : theory -> thm list -> FundefCommon.mutual_info -> FundefCommon.prep_result -> thm -> thm list ->
    18                                 FundefCommon.fundef_mresult
    19 end
    20 
    21 
    22 structure FundefMutual: FUNDEF_MUTUAL = 
    23 struct
    24 
    25 open FundefCommon
    26 
    27 
    28 
    29 fun check_const (Const C) = C
    30   | check_const _ = raise ERROR "Head symbol of every left hand side must be a constant." (* FIXME: Output the equation here *)
    31 
    32 
    33 
    34 
    35 
    36 fun split_def geq =
    37     let
    38 	val gs = Logic.strip_imp_prems geq
    39 	val eq = Logic.strip_imp_concl geq
    40 	val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    41 	val (fc, args) = strip_comb f_args
    42 	val f = check_const fc
    43 		    
    44 	val qs = fold_rev Term.add_frees args []
    45 		 
    46 	val rhs_new_vars = (Term.add_frees rhs []) \\ qs
    47 	val _ = if null rhs_new_vars then () 
    48 		else raise ERROR "Variables occur on right hand side only: " (* FIXME: Output vars here *)
    49     in
    50 	((f, length args), (qs, gs, args, rhs))
    51     end
    52 
    53 
    54 fun analyze_eqs thy eqss =
    55     let
    56 	fun all_equal ((x as ((n:string,T), k:int))::xs) = if forall (fn ((n',_),k') => n = n' andalso k = k') xs then x
    57 							   else raise ERROR ("All equations in a block must describe the same "
    58 									     ^ "constant and have the same number of arguments.")
    59 								      
    60 	val def_infoss = map (split_list o map split_def) eqss
    61 	val (consts, qgarss) = split_list (map (fn (Cis, eqs) => (all_equal Cis, eqs)) def_infoss)
    62 
    63 	val cnames = map (fst o fst) consts
    64 	val check_rcs = exists_Const (fn (n,_) => if n mem cnames 
    65 						  then raise ERROR "Recursive Calls not allowed in premises." else false)
    66 	val _ = forall (forall (fn (_, gs, _, _) => forall check_rcs gs)) qgarss
    67 
    68 	fun curried_types ((_,T), k) =
    69 	    let
    70 		val (caTs, uaTs) = chop k (binder_types T)
    71 	    in 
    72 		(caTs, uaTs ---> body_type T)
    73 	    end
    74 
    75 	val (caTss, resultTs) = split_list (map curried_types consts)
    76 	val argTs = map (foldr1 HOLogic.mk_prodT) caTss
    77 
    78 	val (RST,streeR, pthsR) = SumTools.mk_tree resultTs
    79 	val (ST, streeA, pthsA) = SumTools.mk_tree argTs
    80 
    81 	val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name cnames)
    82 	val sfun_xname = def_name ^ "_sum"
    83 	val sfun_type = ST --> RST
    84 
    85     	val thy = Sign.add_consts_i [(sfun_xname, sfun_type, NoSyn)] thy (* Add the sum function *)
    86 	val sfun = Const (Sign.full_name thy sfun_xname, sfun_type)
    87 
    88 	fun define (((((n, T), _), caTs), (pthA, pthR)), qgars) (thy, rews) = 
    89 	    let 
    90 		val fxname = Sign.base_name n
    91 		val f = Const (n, T)
    92 		val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs
    93 
    94 		val f_exp = SumTools.mk_proj streeR pthR (sfun $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars))
    95 		val def = Logic.mk_equals (list_comb (f, vars), f_exp)
    96 
    97 		val ([f_def], thy) = PureThy.add_defs_i false [((fxname ^ "_def", def), [])] thy
    98 		val rews' = (f, fold_rev lambda vars f_exp) :: rews
    99 	    in
   100 		(MutualPart {f_name=fxname, const=(n, T),cargTs=caTs,pthA=pthA,pthR=pthR,qgars=qgars,f_def=f_def}, (thy, rews'))
   101 	    end
   102 
   103 	val (parts, (thy, rews)) = fold_map define (((consts ~~ caTss)~~ (pthsA ~~ pthsR)) ~~ qgarss) (thy, [])
   104 
   105 	fun mk_qglrss (MutualPart {qgars, pthA, pthR, ...}) =
   106 	    let
   107 		fun convert_eqs (qs, gs, args, rhs) =
   108 		    (map Free qs, gs, SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod args), 
   109 		     SumTools.mk_inj streeR pthR (Pattern.rewrite_term thy rews [] rhs))
   110 	    in
   111 		map convert_eqs qgars
   112 	    end
   113 	    
   114 	val qglrss = map mk_qglrss parts
   115     in
   116 	(Mutual {name=def_name,sum_const=dest_Const sfun, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts, qglrss=qglrss}, thy)
   117     end
   118 
   119 
   120 
   121 
   122 fun prepare_fundef_mutual congs eqss thy =
   123     let 
   124 	val (mutual, thy) = analyze_eqs thy eqss
   125 	val Mutual {name, sum_const, qglrss, ...} = mutual
   126 	val global_glrs = flat qglrss
   127 	val used = fold (fn (qs, _, _, _) => fold (curry op ins_string o fst o dest_Free) qs) global_glrs []
   128     in
   129 	(mutual, name, FundefPrep.prepare_fundef thy congs name (Const sum_const) global_glrs used)
   130     end
   131 
   132 
   133 (* Beta-reduce both sides of a meta-equality *)
   134 fun beta_norm_eq thm = 
   135     let
   136 	val (lhs, rhs) = dest_equals (cprop_of thm)
   137 	val lhs_conv = beta_conversion false lhs 
   138 	val rhs_conv = beta_conversion false rhs 
   139     in
   140 	transitive (symmetric lhs_conv) (transitive thm rhs_conv)
   141     end
   142 
   143 
   144 
   145 
   146 fun map_mutual2 f (Mutual {parts, ...}) =
   147     map2 (fn (p as MutualPart {qgars, ...}) => map2 (f p) qgars) parts
   148 
   149 
   150 
   151 fun recover_mutual_psimp thy RST streeR all_f_defs (MutualPart {f_def, pthR, ...}) (_,_,args,_) sum_psimp =
   152     let
   153 	val conds = cprems_of sum_psimp (* dom-condition and guards *)
   154 	val plain_eq = sum_psimp
   155                          |> fold (implies_elim_swp o assume) conds
   156 
   157 	val x = Free ("x", RST)
   158 
   159 	val f_def_inst = instantiate' [] (map (SOME o cterm_of thy) args) (freezeT f_def) (* FIXME: freezeT *)
   160     in
   161 	reflexive (cterm_of thy (lambda x (SumTools.mk_proj streeR pthR x)))  (*  PR(x) == PR(x) *)
   162 		  |> (fn it => combination it (plain_eq RS eq_reflection))
   163 		  |> beta_norm_eq (*  PR(S(I(as))) == PR(IR(...)) *)
   164 		  |> transitive f_def_inst (*  f ... == PR(IR(...)) *)
   165 		  |> simplify (HOL_basic_ss addsimps [SumTools.projl_inl, SumTools.projr_inr]) (*  f ... == ... *)
   166 		  |> simplify (HOL_basic_ss addsimps all_f_defs) (*  f ... == ... *)
   167 		  |> (fn it => it RS meta_eq_to_obj_eq)
   168 		  |> fold_rev implies_intr conds
   169     end
   170 
   171 
   172 
   173 
   174 
   175 fun mutual_induct_Pnames n = 
   176     if n < 5 then fst (chop n ["P","Q","R","S"])
   177     else map (fn i => "P" ^ string_of_int i) (1 upto n)
   178 	 
   179 	 
   180 val sum_case_rules = thms "Datatype.sum.cases"
   181 val split_apply = thm "Product_Type.split"
   182 		     
   183 		     
   184 fun mutual_induct_rules thy induct all_f_defs (Mutual {qglrss, RST, parts, streeA, ...}) =
   185     let
   186 	fun mk_P (MutualPart {cargTs, ...}) Pname =
   187 	    let
   188 		val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs
   189 		val atup = foldr1 HOLogic.mk_prod avars
   190 	    in
   191 		tupled_lambda atup (list_comb (Free (Pname, cargTs ---> HOLogic.boolT), avars))
   192 	    end
   193 	    
   194 	val Ps = map2 mk_P parts (mutual_induct_Pnames (length parts))
   195 	val case_exp = SumTools.mk_sumcases streeA HOLogic.boolT Ps
   196 		       
   197 	val induct_inst = 
   198 	    forall_elim (cterm_of thy case_exp) induct
   199 			|> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
   200 		        |> full_simplify (HOL_basic_ss addsimps all_f_defs) 
   201 
   202 	fun mk_proj rule (MutualPart {cargTs, pthA, ...}) =
   203 	    let
   204 		val afs = map_index (fn (i,T) => Free ("a" ^ string_of_int i, T)) cargTs
   205 		val inj = SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod afs)
   206 	    in
   207 		rule 
   208 		    |> forall_elim (cterm_of thy inj)
   209 		    |> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
   210 	    end
   211 
   212     in
   213 	map (mk_proj induct_inst) parts
   214     end
   215     
   216     
   217 
   218 
   219 
   220 fun mk_partial_rules_mutual thy congs (m as Mutual {qglrss, RST, parts, streeR, ...}) data complete_thm compat_thms =
   221     let
   222 	val result = FundefProof.mk_partial_rules thy congs data complete_thm compat_thms 
   223 	val FundefResult {f, G, R, compatibility, completeness, psimps, subset_pinduct,simple_pinduct,total_intro,dom_intros} = result
   224 
   225 	val sum_psimps = Library.unflat qglrss psimps
   226 
   227 	val all_f_defs = map (fn MutualPart {f_def, ...} => symmetric f_def) parts
   228 	val mpsimps = map_mutual2 (recover_mutual_psimp thy RST streeR all_f_defs) m sum_psimps
   229 	val minducts = mutual_induct_rules thy simple_pinduct all_f_defs m
   230         val termination = full_simplify (HOL_basic_ss addsimps all_f_defs) total_intro
   231     in
   232 	FundefMResult { f=f, G=G, R=R,
   233 			psimps=mpsimps, subset_pinducts=[subset_pinduct], simple_pinducts=minducts,
   234 			cases=completeness, termination=termination, domintros=dom_intros}
   235     end
   236     
   237 
   238 end
   239 
   240 
   241 
   242 
   243 
   244 
   245 
   246 
   247 
   248 
   249 
   250 
   251 
   252 
   253 
   254 
   255 
   256 
   257 
   258 
   259 
   260 
   261 
   262 
   263