src/HOL/Tools/function_package/mutual.ML
changeset 20523 36a59e5d0039
parent 19922 984ae977f7aa
child 20534 b147d0c13f6e
--- a/src/HOL/Tools/function_package/mutual.ML	Wed Sep 13 00:38:38 2006 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML	Wed Sep 13 12:05:50 2006 +0200
@@ -10,12 +10,17 @@
 signature FUNDEF_MUTUAL = 
 sig
   
-  val prepare_fundef_mutual : thm list -> term list list -> theory ->
-                              (FundefCommon.mutual_info * string * (FundefCommon.prep_result * theory))
+  val prepare_fundef_mutual : ((string * typ) * mixfix) list 
+                              -> term list 
+                              -> local_theory 
+                              -> ((FundefCommon.mutual_info * string * FundefCommon.prep_result) * local_theory)
 
 
-  val mk_partial_rules_mutual : theory -> FundefCommon.mutual_info -> FundefCommon.prep_result -> thm -> 
+  val mk_partial_rules_mutual : Proof.context -> FundefCommon.mutual_info -> FundefCommon.prep_result -> thm -> 
                                 FundefCommon.fundef_mresult
+
+  val sort_by_function : FundefCommon.mutual_info -> string list -> 'a list -> 'a list list
+
 end
 
 
@@ -24,46 +29,83 @@
 
 open FundefCommon
 
+(* Theory dependencies *)
+val sum_case_rules = thms "Datatype.sum.cases"
+val split_apply = thm "Product_Type.split"
+
 
 
-fun check_const (Const C) = C
-  | check_const _ = raise ERROR "Head symbol of every left hand side must be a constant." (* FIXME: Output the equation here *)
+fun mutual_induct_Pnames n = 
+    if n < 5 then fst (chop n ["P","Q","R","S"])
+    else map (fn i => "P" ^ string_of_int i) (1 upto n)
+	 
+
+fun check_head fs t =
+    if (case t of 
+          (Free (n, _)) => n mem fs
+        | _ => false)
+    then dest_Free t
+    else raise ERROR "Head symbol of every left hand side must be the new function." (* FIXME: Output the equation here *)
 
 
+fun open_all_all (Const ("all", _) $ Abs (n, T, b)) = apfst (cons (n, T)) (open_all_all b)
+  | open_all_all t = ([], t)
 
 
 
-fun split_def geq =
+(* Builds a curried clause description in abstracted form *)
+fun split_def fnames geq =
     let
-	val gs = Logic.strip_imp_prems geq
-	val eq = Logic.strip_imp_concl geq
-	val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
-	val (fc, args) = strip_comb f_args
-	val f = check_const fc
-		    
-	val qs = fold_rev Term.add_frees args []
-		 
-	val rhs_new_vars = (Term.add_frees rhs []) \\ qs
-	val _ = if null rhs_new_vars then () 
-		else raise ERROR "Variables occur on right hand side only: " (* FIXME: Output vars here *)
+      val (qs, imp) = open_all_all geq
+
+      val gs = Logic.strip_imp_prems imp
+      val eq = Logic.strip_imp_concl imp
+
+      val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
+      val (fc, args) = strip_comb f_args
+      val f as (fname, _) = check_head fnames fc
+
+      val add_bvs = fold_aterms (fn Bound i => insert (op =) i | _ => I)
+      val rhs_only = (add_bvs rhs [] \\ fold add_bvs args [])
+                        |> map (fst o nth (rev qs))
+      val _ = if null rhs_only then () 
+	      else raise ERROR "Variables occur on right hand side only." (* FIXME: Output vars *)
     in
-	((f, length args), (qs, gs, args, rhs))
+	((f, length args), (fname, qs, gs, args, rhs))
     end
 
+fun get_part fname =
+    the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname)
 
-fun analyze_eqs thy eqss =
+(* FIXME *)
+fun mk_prod_abs e (t1, t2) =
+    let 
+      val bTs = rev (map snd e)
+      val T1 = fastype_of1 (bTs, t1)
+      val T2 = fastype_of1 (bTs, t2)
+    in
+      HOLogic.pair_const T1 T2 $ t1 $ t2
+    end;
+
+
+fun analyze_eqs ctxt fnames eqs =
     let
+      (* FIXME: Add check for number of arguments
 	fun all_equal ((x as ((n:string,T), k:int))::xs) = if forall (fn ((n',_),k') => n = n' andalso k = k') xs then x
 							   else raise ERROR ("All equations in a block must describe the same "
-									     ^ "constant and have the same number of arguments.")
+									     ^ "function and have the same number of arguments.")
+       *)
 								      
-	val def_infoss = map (split_list o map split_def) eqss
-	val (consts, qgarss) = split_list (map (fn (Cis, eqs) => (all_equal Cis, eqs)) def_infoss)
+        val (consts, fqgars) = split_list (map (split_def fnames) eqs)
 
-	val cnames = map (fst o fst) consts
-	val check_rcs = exists_Const (fn (n,_) => if n mem cnames 
-						  then raise ERROR "Recursive Calls not allowed in premises." else false)
-	val _ = forall (forall (fn (_, gs, _, _) => forall check_rcs gs)) qgarss
+        val different_consts = distinct (eq_fst (eq_fst eq_str)) consts
+	val cnames = map (fst o fst) different_consts
+
+	val check_rcs = exists_subterm (fn Free (n, _) => if n mem cnames 
+						          then raise ERROR "Recursive Calls not allowed in premises." else false
+                                         | _ => false)
+                        
+	val _ = forall (fn (_, _, gs, _, _) => forall check_rcs gs) fqgars
 
 	fun curried_types ((_,T), k) =
 	    let
@@ -72,61 +114,87 @@
 		(caTs, uaTs ---> body_type T)
 	    end
 
-	val (caTss, resultTs) = split_list (map curried_types consts)
+	val (caTss, resultTs) = split_list (map curried_types different_consts)
 	val argTs = map (foldr1 HOLogic.mk_prodT) caTss
 
-	val (RST,streeR, pthsR) = SumTools.mk_tree resultTs
+	val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs
 	val (ST, streeA, pthsA) = SumTools.mk_tree argTs
 
 	val def_name = foldr1 (fn (a,b) => a ^ "_" ^ b) (map Sign.base_name cnames)
-	val sfun_xname = def_name ^ "_sum"
-	val sfun_type = ST --> RST
+	val fsum_type = ST --> RST
 
-    	val thy = Sign.add_consts_i [(sfun_xname, sfun_type, NoSyn)] thy (* Add the sum function *)
-	val sfun = Const (Sign.full_name thy sfun_xname, sfun_type)
+        val ([fsum_var_name], _) = Variable.add_fixes [ def_name ^ "_sum" ] ctxt
+        val fsum_var = (fsum_var_name, fsum_type)
 
-	fun define (((((n, T), _), caTs), (pthA, pthR)), qgars) (thy, rews) = 
+	fun define (fvar as (n, T), _) caTs pthA pthR = 
 	    let 
-		val fxname = Sign.base_name n
-		val f = Const (n, T)
-		val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs
+		val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *)
 
-		val f_exp = SumTools.mk_proj streeR pthR (sfun $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars))
-		val def = Logic.mk_equals (list_comb (f, vars), f_exp)
-
-		val ([f_def], thy) = PureThy.add_defs_i false [((fxname ^ "_def", def), [])] thy
-		val rews' = (f, fold_rev lambda vars f_exp) :: rews
+		val f_exp = SumTools.mk_proj streeR pthR (Free fsum_var $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars))
+		val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
+                                          
+		val rew = (n, fold_rev lambda vars f_exp) 
 	    in
-		(MutualPart {f_name=fxname, const=(n, T),cargTs=caTs,pthA=pthA,pthR=pthR,qgars=qgars,f_def=f_def}, (thy, rews'))
+		(MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew)
 	    end
 
-	val (parts, (thy, rews)) = fold_map define (((consts ~~ caTss)~~ (pthsA ~~ pthsR)) ~~ qgarss) (thy, [])
+	val (parts, rews) = split_list (map4 define different_consts caTss pthsA pthsR)
 
-	fun mk_qglrss (MutualPart {qgars, pthA, pthR, ...}) =
-	    let
-		fun convert_eqs (qs, gs, args, rhs) =
-		    (map Free qs, gs, SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod args), 
-		     SumTools.mk_inj streeR pthR (Pattern.rewrite_term thy rews [] rhs))
-	    in
-		map convert_eqs qgars
-	    end
-	    
-	val qglrss = map mk_qglrss parts
+        fun convert_eqs (f, qs, gs, args, rhs) =
+            let
+              val MutualPart {pthA, pthR, ...} = get_part f parts
+            in
+	      (qs, gs, SumTools.mk_inj streeA pthA (foldr1 (mk_prod_abs qs) args), 
+	       SumTools.mk_inj streeR pthR (replace_frees rews rhs)
+                               |> Envir.norm_term (Envir.empty 0))
+            end
+
+	val qglrs = map convert_eqs fqgars
     in
-	(Mutual {name=def_name,sum_const=dest_Const sfun, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts, qglrss=qglrss}, thy)
+	Mutual {defname=def_name,fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, 
+                parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
     end
 
 
 
 
-fun prepare_fundef_mutual congs eqss thy =
+fun define_projections fixes mutual fsum lthy =
+    let
+      fun def ((MutualPart {fvar=(fname, fT), cargTs, pthA, pthR, f_def, ...}), (_, mixfix)) lthy =
+          let
+            val ((f, (_, f_defthm)), lthy') = LocalTheory.def ((fname, mixfix), 
+                                                               ((fname ^ "_def", []), Term.subst_bound (fsum, f_def))) 
+                                                              lthy
+          in
+            (MutualPart {fvar=(fname, fT), cargTs=cargTs, pthA=pthA, pthR=pthR, f_def=f_def, 
+                        f=SOME f, f_defthm=SOME f_defthm },
+             lthy')
+          end
+
+      val Mutual { defname, fsum_var, ST, RST, streeA, streeR, parts, fqgars, qglrs, ... } = mutual
+      val (parts', lthy') = fold_map def (parts ~~ fixes) lthy 
+    in
+      (Mutual { defname=defname, fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts', 
+                fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
+       lthy')
+    end
+
+
+
+  
+
+
+fun prepare_fundef_mutual fixes eqss lthy =
     let 
-	val (mutual, thy) = analyze_eqs thy eqss
-	val Mutual {name, sum_const, qglrss, ...} = mutual
-	val global_glrs = flat qglrss
-	val used = fold (fn (qs, _, _, _) => fold (curry op ins_string o fst o dest_Free) qs) global_glrs []
+	val mutual = analyze_eqs lthy (map (fst o fst) fixes) eqss
+	val Mutual {defname, fsum_var=(n, T), qglrs, ...} = mutual
+
+        val (prep_result, fsum, lthy') =
+            FundefPrep.prepare_fundef defname (n, T, NoSyn) qglrs lthy
+
+        val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
     in
-	(mutual, name, FundefPrep.prepare_fundef thy congs name (Const sum_const) global_glrs used)
+      ((mutual', defname, prep_result), lthy'')
     end
 
 
@@ -140,48 +208,76 @@
 	transitive (symmetric lhs_conv) (transitive thm rhs_conv)
     end
 
-
-
-
-fun map_mutual2 f (Mutual {parts, ...}) =
-    map2 (fn (p as MutualPart {qgars, ...}) => map2 (f p) qgars) parts
-
+fun beta_reduce thm = Thm.equal_elim (Thm.beta_conversion true (cprop_of thm)) thm
 
 
-fun recover_mutual_psimp thy RST streeR all_f_defs (MutualPart {f_def, pthR, ...}) (_,_,args,_) sum_psimp =
+fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F =
     let
-	val conds = cprems_of sum_psimp (* dom-condition and guards *)
-	val plain_eq = sum_psimp
-                         |> fold (implies_elim_swp o assume) conds
+      val thy = ProofContext.theory_of ctxt
 
-	val x = Free ("x", RST)
+      val oqnames = map fst pre_qs
+      val (qs, ctxt') = Variable.invent_fixes oqnames ctxt
+                                           |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs
 
-	val f_def_inst = instantiate' [] (map (SOME o cterm_of thy) args) (Thm.freezeT f_def) (* FIXME: freezeT *)
+      fun inst t = subst_bounds (rev qs, t)
+      val gs = map inst pre_gs
+      val args = map inst pre_args
+      val rhs = inst pre_rhs
+
+      val cqs = map (cterm_of thy) qs
+      val ags = map (assume o cterm_of thy) gs
+                
+      val import = fold forall_elim cqs
+                   #> fold implies_elim_swp ags
+
+      val export = fold_rev (implies_intr o cprop_of) ags
+                   #> fold_rev forall_intr_rename (oqnames ~~ cqs)
     in
-	reflexive (cterm_of thy (lambda x (SumTools.mk_proj streeR pthR x)))  (*  PR(x) == PR(x) *)
-		  |> (fn it => combination it (plain_eq RS eq_reflection))
-		  |> beta_norm_eq (*  PR(S(I(as))) == PR(IR(...)) *)
-		  |> transitive f_def_inst (*  f ... == PR(IR(...)) *)
-		  |> simplify (HOL_basic_ss addsimps [SumTools.projl_inl, SumTools.projr_inr]) (*  f ... == ... *)
-		  |> simplify (HOL_basic_ss addsimps all_f_defs) (*  f ... == ... *)
-		  |> (fn it => it RS meta_eq_to_obj_eq)
-		  |> fold_rev implies_intr conds
+      F (f, qs, gs, args, rhs) import export
     end
 
 
+fun recover_mutual_psimp thy RST streeR all_f_defs parts (f, _, _, args, _) import (export : thm -> thm) sum_psimp_eq =
+    let
+      val (MutualPart {f_defthm=SOME f_def, pthR, ...}) = get_part f parts
 
+      val psimp = import sum_psimp_eq
+      val (simp, restore_cond) = case cprems_of psimp of
+                                   [] => (psimp, I)
+                                 | [cond] => (implies_elim psimp (assume cond), implies_intr cond)
+                                 | _ => sys_error "Too many conditions"
+
+      val x = Free ("x", RST)
+
+      val f_def_inst = fold (fn arg => fn thm => combination thm (reflexive (cterm_of thy arg))) args (Thm.freezeT f_def) (* FIXME *)
+                            |> beta_reduce
+    in
+      reflexive (cterm_of thy (lambda x (SumTools.mk_proj streeR pthR x)))  (*  PR(x) == PR(x) *)
+                |> (fn it => combination it (simp RS eq_reflection))
+	        |> beta_norm_eq (*  PR(S(I(as))) == PR(IR(...)) *)
+                |> transitive f_def_inst (*  f ... == PR(IR(...)) *)
+                |> simplify (HOL_basic_ss addsimps [SumTools.projl_inl, SumTools.projr_inr]) (*  f ... == ... *)
+                |> simplify (HOL_basic_ss addsimps all_f_defs) (*  f ... == ... *)
+                |> (fn it => it RS meta_eq_to_obj_eq)
+                |> restore_cond
+                |> export
+    end 
 
 
-fun mutual_induct_Pnames n = 
-    if n < 5 then fst (chop n ["P","Q","R","S"])
-    else map (fn i => "P" ^ string_of_int i) (1 upto n)
-	 
-	 
-val sum_case_rules = thms "Datatype.sum.cases"
-val split_apply = thm "Product_Type.split"
+(* FIXME HACK *)
+fun mk_applied_form ctxt caTs thm = 
+    let
+      val thy = ProofContext.theory_of ctxt
+      val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *)
+    in
+      fold (fn x => fn thm => combination thm (reflexive x)) xs thm
+           |> beta_reduce
+           |> fold_rev forall_intr xs
+           |> forall_elim_vars 0
+    end
+
 		     
-		     
-fun mutual_induct_rules thy induct all_f_defs (Mutual {qglrss, RST, parts, streeA, ...}) =
+fun mutual_induct_rules thy induct all_f_defs (Mutual {RST, parts, streeA, ...}) =
     let
 	fun mk_P (MutualPart {cargTs, ...}) Pname =
 	    let
@@ -210,29 +306,53 @@
 	    end
 
     in
-	map (mk_proj induct_inst) parts
+      map (mk_proj induct_inst) parts
     end
-    
+
     
 
 
 
-fun mk_partial_rules_mutual thy (m as Mutual {qglrss, RST, parts, streeR, ...}) data result =
+fun mk_partial_rules_mutual lthy (m as Mutual {RST, parts, streeR, fqgars, ...}) data prep_result =
     let
-	val result = FundefProof.mk_partial_rules thy data result
-	val FundefResult {f, G, R, completeness, psimps, subset_pinduct,simple_pinduct,total_intro,dom_intros} = result
-
-	val sum_psimps = Library.unflat qglrss psimps
+      val thy = ProofContext.theory_of lthy
+                
+      val result = FundefProof.mk_partial_rules thy data prep_result
+      val FundefResult {f, G, R, completeness, psimps, subset_pinduct,simple_pinduct,total_intro,dom_intros} = result
+                                                                                                               
+      val all_f_defs = map (fn MutualPart {f_defthm = SOME f_def, cargTs, ...} => 
+                               mk_applied_form lthy cargTs (symmetric (Thm.freezeT f_def))) 
+                           parts
+                           |> print
+                          
+      fun mk_mpsimp fqgar sum_psimp = 
+          in_context lthy fqgar (recover_mutual_psimp thy RST streeR all_f_defs parts) sum_psimp
+          
+      val mpsimps = map2 mk_mpsimp fqgars psimps
+                    
+      val minducts = mutual_induct_rules thy simple_pinduct all_f_defs m
+      val termination = full_simplify (HOL_basic_ss addsimps all_f_defs) total_intro
+    in
+      FundefMResult { f=f, G=G, R=R,
+		      psimps=mpsimps, subset_pinducts=[subset_pinduct], simple_pinducts=minducts,
+		      cases=completeness, termination=termination, domintros=dom_intros }
+    end
 
-	val all_f_defs = map (fn MutualPart {f_def, ...} => symmetric f_def) parts
-	val mpsimps = map_mutual2 (recover_mutual_psimp thy RST streeR all_f_defs) m sum_psimps
-	val minducts = mutual_induct_rules thy simple_pinduct all_f_defs m
-        val termination = full_simplify (HOL_basic_ss addsimps all_f_defs) total_intro
-    in
-	FundefMResult { f=f, G=G, R=R,
-			psimps=mpsimps, subset_pinducts=[subset_pinduct], simple_pinducts=minducts,
-			cases=completeness, termination=termination, domintros=dom_intros}
-    end
+
+
+(* puts an object in the "right bucket" *) 
+fun store_grouped P x [] = []
+  | store_grouped P x ((l, xs)::bs) = 
+    if P (x, l) then ((l, x::xs)::bs) else ((l, xs)::store_grouped P x bs)
+
+fun sort_by_function (Mutual {fqgars, ...}) names xs =
+      fold_rev (store_grouped (eq_str o apfst fst))  (* fill *)
+               (map name_of_fqgar fqgars ~~ xs)      (* the name-thm pairs *)
+               (map (rpair []) names)                (* in the empty buckets labeled with names *)
+
+         |> map (snd #> map snd)                     (* and remove the labels afterwards *)
+
+
     
 
 end