--- a/src/HOL/Tools/function_package/mutual.ML Mon Jun 25 00:36:42 2007 +0200
+++ b/src/HOL/Tools/function_package/mutual.ML Mon Jun 25 12:16:27 2007 +0200
@@ -31,6 +31,36 @@
(* Theory dependencies *)
val sum_case_rules = thms "Sum_Type.sum_cases"
val split_apply = thm "Product_Type.split"
+val projl_inl = thm "Sum_Type.Projl_Inl"
+val projr_inr = thm "Sum_Type.Projr_Inr"
+
+
+(* Sum types *)
+fun mk_sumT LT RT = Type ("+", [LT, RT])
+fun mk_sumcase TL TR T l r = Const (@{const_name "Sum_Type.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
+
+val App = curry op $
+
+fun mk_inj ST n i =
+ BalancedTree.access
+ { init = (ST, I : term -> term),
+ left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, App (Const (@{const_name "Inl"}, LT --> T)) o inj)),
+ right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, App (Const (@{const_name "Inr"}, RT --> T)) o inj))} n i
+ |> snd
+
+fun mk_proj ST n i =
+ BalancedTree.access
+ { init = (ST, I : term -> term),
+ left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, proj o App (Const (@{const_name "Projl"}, T --> LT)))),
+ right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, proj o App (Const (@{const_name "Projr"}, T --> RT))))} n i
+ |> snd
+
+fun mk_sumcases T fs =
+ BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT))
+ (map (fn f => (f, domain_type (fastype_of f))) fs)
+ |> fst
+
+
type qgar = string * (string * typ) list * term list * term list * term
@@ -39,10 +69,10 @@
datatype mutual_part =
MutualPart of
{
+ i : int,
+ i' : int,
fvar : string * typ,
cargTs: typ list,
- pthA: SumTools.sum_path,
- pthR: SumTools.sum_path,
f_def: term,
f: term option,
@@ -53,12 +83,12 @@
datatype mutual_info =
Mutual of
{
+ n : int,
+ n' : int,
fsum_var : string * typ,
ST: typ,
RST: typ,
- streeA: SumTools.sum_tree,
- streeR: SumTools.sum_tree,
parts: mutual_part list,
fqgars: qgar list,
@@ -87,6 +117,7 @@
fun analyze_eqs ctxt defname fs eqs =
let
+ val num = length fs
val fnames = map fst fs
val fqgars = map split_def eqs
val arities = mk_arities fqgars
@@ -102,40 +133,44 @@
val (caTss, resultTs) = split_list (map curried_types fs)
val argTs = map (foldr1 HOLogic.mk_prodT) caTss
- val (RST,streeR, pthsR) = SumTools.mk_tree_distinct resultTs
- val (ST, streeA, pthsA) = SumTools.mk_tree argTs
+ val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs
+ val n' = length dresultTs
+
+ val RST = BalancedTree.make (uncurry mk_sumT) dresultTs
+ val ST = BalancedTree.make (uncurry mk_sumT) argTs
val fsum_type = ST --> RST
val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt
val fsum_var = (fsum_var_name, fsum_type)
- fun define (fvar as (n, T)) caTs pthA pthR =
+ fun define (fvar as (n, T)) caTs resultT i =
let
- val vars = map_index (fn (i,T) => Free ("x" ^ string_of_int i, T)) caTs (* FIXME: Bind xs properly *)
+ val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *)
+ val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1
- val f_exp = SumTools.mk_proj streeR pthR (Free fsum_var $ SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod vars))
+ val f_exp = mk_proj RST n' i' (Free fsum_var $ mk_inj ST num i (foldr1 HOLogic.mk_prod vars))
val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp)
val rew = (n, fold_rev lambda vars f_exp)
in
- (MutualPart {fvar=fvar,cargTs=caTs,pthA=pthA,pthR=pthR,f_def=def,f=NONE,f_defthm=NONE}, rew)
+ (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew)
end
- val (parts, rews) = split_list (map4 define fs caTss pthsA pthsR)
+ val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num))
fun convert_eqs (f, qs, gs, args, rhs) =
let
- val MutualPart {pthA, pthR, ...} = get_part f parts
+ val MutualPart {i, i', ...} = get_part f parts
in
- (qs, gs, SumTools.mk_inj streeA pthA (foldr1 (mk_prod_abs qs) args),
- SumTools.mk_inj streeR pthR (replace_frees rews rhs)
- |> Envir.norm_term (Envir.empty 0))
+ (qs, gs, mk_inj ST num i (foldr1 (mk_prod_abs qs) args),
+ mk_inj RST n' i' (replace_frees rews rhs)
+ |> Envir.beta_norm)
end
val qglrs = map convert_eqs fqgars
in
- Mutual {fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR,
+ Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST,
parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE}
end
@@ -144,22 +179,22 @@
fun define_projections fixes mutual fsum lthy =
let
- fun def ((MutualPart {fvar=(fname, fT), cargTs, pthA, pthR, f_def, ...}), (_, mixfix)) lthy =
+ fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy =
let
val ((f, (_, f_defthm)), lthy') =
LocalTheory.def Thm.internalK ((fname, mixfix),
((fname ^ "_def", []), Term.subst_bound (fsum, f_def)))
lthy
in
- (MutualPart {fvar=(fname, fT), cargTs=cargTs, pthA=pthA, pthR=pthR, f_def=f_def,
+ (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def,
f=SOME f, f_defthm=SOME f_defthm },
lthy')
end
- val Mutual { fsum_var, ST, RST, streeA, streeR, parts, fqgars, qglrs, ... } = mutual
+ val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual
val (parts', lthy') = fold_map def (parts ~~ fixes) lthy
in
- (Mutual { fsum_var=fsum_var, ST=ST, RST=RST, streeA=streeA, streeR=streeR, parts=parts',
+ (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts',
fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum },
lthy')
end
@@ -184,7 +219,7 @@
val ags = map (assume o cterm_of thy) gs
val import = fold forall_elim cqs
- #> fold implies_elim_swp ags
+ #> fold (flip implies_elim) ags
val export = fold_rev (implies_intr o cprop_of) ags
#> fold_rev forall_intr_rename (oqnames ~~ cqs)
@@ -194,7 +229,7 @@
fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq =
let
- val (MutualPart {f=SOME f, f_defthm=SOME f_def, pthR, ...}) = get_part fname parts
+ val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts
val psimp = import sum_psimp_eq
val (simp, restore_cond) = case cprems_of psimp of
@@ -206,7 +241,7 @@
(HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs))
(fn _ => SIMPSET (unfold_tac all_orig_fdefs)
THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1
- THEN SIMPSET' (fn ss => simp_tac (ss addsimps [SumTools.projl_inl, SumTools.projr_inr])) 1)
+ THEN SIMPSET' (fn ss => simp_tac (ss addsimps [projl_inl, projr_inr])) 1)
|> restore_cond
|> export
end
@@ -223,9 +258,9 @@
|> fold_rev forall_intr xs
|> forall_elim_vars 0
end
-
+
-fun mutual_induct_rules lthy induct all_f_defs (Mutual {RST, parts, streeA, ...}) =
+fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) =
let
val cert = cterm_of (ProofContext.theory_of lthy)
val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} =>
@@ -242,26 +277,25 @@
end
val Ps = map2 mk_P parts newPs
- val case_exp = SumTools.mk_sumcases streeA HOLogic.boolT Ps
+ val case_exp = mk_sumcases HOLogic.boolT Ps
val induct_inst =
forall_elim (cert case_exp) induct
|> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
|> full_simplify (HOL_basic_ss addsimps all_f_defs)
- fun mk_proj rule (MutualPart {cargTs, pthA, ...}) =
+ fun project rule (MutualPart {cargTs, i, ...}) =
let
- val afs = map_index (fn (i,T) => Free ("a" ^ string_of_int i, T)) cargTs
- val inj = SumTools.mk_inj streeA pthA (foldr1 HOLogic.mk_prod afs)
+ val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int j, T)) cargTs
+ val inj = mk_inj ST n i (foldr1 HOLogic.mk_prod afs)
in
rule
|> forall_elim (cert inj)
|> full_simplify (HOL_basic_ss addsimps (split_apply :: sum_case_rules))
|> fold_rev (forall_intr o cert) (afs @ newPs)
end
-
in
- map (mk_proj induct_inst) parts
+ map (project induct_inst) parts
end
@@ -295,7 +329,6 @@
trsimps=mtrsimps}
end
-
(* puts an object in the "right bucket" *)
fun store_grouped P x [] = []
| store_grouped P x ((l, xs)::bs) =
--- a/src/HOL/Tools/function_package/sum_tools.ML Mon Jun 25 00:36:42 2007 +0200
+++ /dev/null Thu Jan 01 00:00:00 1970 +0000
@@ -1,124 +0,0 @@
-(* Title: HOL/Tools/function_package/sum_tools.ML
- ID: $Id$
- Author: Alexander Krauss, TU Muenchen
-
-A package for general recursive function definitions.
-Tools for mutual recursive definitions. This could actually be useful for other packages, too, but needs
-some cleanup first...
-
-*)
-
-signature SUM_TOOLS =
-sig
- type sum_tree
- type sum_path
-
- val projl_inl: thm
- val projr_inr: thm
-
- val mk_tree : typ list -> typ * sum_tree * sum_path list
- val mk_tree_distinct : typ list -> typ * sum_tree * sum_path list
-
- val mk_proj: sum_tree -> sum_path -> term -> term
- val mk_inj: sum_tree -> sum_path -> term -> term
-
- val mk_sumcases: sum_tree -> typ -> term list -> term
-end
-
-
-structure SumTools: SUM_TOOLS =
-struct
-
-val inlN = "Sum_Type.Inl"
-val inrN = "Sum_Type.Inr"
-val sumcaseN = "Sum_Type.sum_case"
-
-val projlN = "Sum_Type.Projl"
-val projrN = "Sum_Type.Projr"
-val projl_inl = thm "Sum_Type.Projl_Inl"
-val projr_inr = thm "Sum_Type.Projr_Inr"
-
-fun mk_sumT LT RT = Type ("+", [LT, RT])
-fun mk_sumcase TL TR T l r = Const (sumcaseN, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r
-
-datatype sum_tree
- = Leaf of typ
- | Branch of (typ * (typ * sum_tree) * (typ * sum_tree))
-
-type sum_path = bool list (* true: left, false: right *)
-
-fun sum_type_of (Leaf T) = T
- | sum_type_of (Branch (ST,(LT,_),(RT,_))) = ST
-
-
-fun mk_tree Ts =
- let
- fun mk_tree' 1 [T] = (T, Leaf T, [[]])
- | mk_tree' n Ts =
- let
- val n2 = n div 2
- val (lTs, rTs) = chop n2 Ts
- val (TL, ltree, lpaths) = mk_tree' n2 lTs
- val (TR, rtree, rpaths) = mk_tree' (n - n2) rTs
- val T = mk_sumT TL TR
- val pths = map (cons true) lpaths @ map (cons false) rpaths
- in
- (T, Branch (T, (TL, ltree), (TR, rtree)), pths)
- end
- in
- mk_tree' (length Ts) Ts
- end
-
-
-fun mk_tree_distinct Ts =
- let
- fun insert_once T Ts =
- let
- val i = find_index_eq T Ts
- in
- if i = ~1 then (length Ts, Ts @ [T]) else (i, Ts)
- end
-
- val (idxs, dist_Ts) = fold_map insert_once Ts []
-
- val (ST, tree, pths) = mk_tree dist_Ts
- in
- (ST, tree, map (nth pths) idxs)
- end
-
-
-fun mk_inj (Leaf _) [] t = t
- | mk_inj (Branch (ST, (LT, tr), _)) (true::pth) t =
- Const (inlN, LT --> ST) $ mk_inj tr pth t
- | mk_inj (Branch (ST, _, (RT, tr))) (false::pth) t =
- Const (inrN, RT --> ST) $ mk_inj tr pth t
- | mk_inj _ _ _ = sys_error "mk_inj"
-
-fun mk_proj (Leaf _) [] t = t
- | mk_proj (Branch (ST, (LT, tr), _)) (true::pth) t =
- mk_proj tr pth (Const (projlN, ST --> LT) $ t)
- | mk_proj (Branch (ST, _, (RT, tr))) (false::pth) t =
- mk_proj tr pth (Const (projrN, ST --> RT) $ t)
- | mk_proj _ _ _ = sys_error "mk_proj"
-
-
-fun mk_sumcases tree T ts =
- let
- fun mk_sumcases' (Leaf _) (t::ts) = (t,ts)
- | mk_sumcases' (Branch (ST, (LT, ltr), (RT, rtr))) ts =
- let
- val (lcase, ts') = mk_sumcases' ltr ts
- val (rcase, ts'') = mk_sumcases' rtr ts'
- in
- (mk_sumcase LT RT T lcase rcase, ts'')
- end
- | mk_sumcases' _ [] = sys_error "mk_sumcases"
- in
- fst (mk_sumcases' tree ts)
- end
-
-end
-
-
-
-