generalized "mk_co_iter" to handle mutualized (co)iterators
authorblanchet
Mon, 27 May 2013 13:30:08 +0200
changeset 52170 564be617ae84
parent 52169 418f5ad4c1c5
child 52171 012679d3a5af
generalized "mk_co_iter" to handle mutualized (co)iterators
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon May 27 12:21:17 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Mon May 27 13:30:08 2013 +0200
@@ -220,32 +220,32 @@
 val mk_ctor = mk_ctor_or_dtor range_type;
 val mk_dtor = mk_ctor_or_dtor domain_type;
 
-fun mk_co_iter thy lfp Ts Us t =
+fun mk_co_iter thy lfp fpT Us t =
   let
     val (bindings, body) = strip_type (fastype_of t);
     val (f_Us, prebody) = split_last bindings;
-    val Type (_, Ts0) = if lfp then prebody else body;
+    val fpT0 = if lfp then prebody else body;
     val Us0 = distinct (op =) (map (if lfp then body_type else domain_type) f_Us);
+    val rho = tvar_subst thy (fpT0 :: Us0) (fpT :: Us);
   in
-    Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+    Term.subst_TVars rho t
   end;
 
-fun mk_co_iter_new thy lfp fpTs Cs fp_iters0 =
+fun mk_co_iters thy lfp fpTs Cs ts0 =
   let
     val nn = length fpTs;
     val (fpTs0, Cs0) =
-      map ((not lfp ? swap) o dest_funT o snd o strip_typeN nn o fastype_of) fp_iters0
+      map ((not lfp ? swap) o dest_funT o snd o strip_typeN nn o fastype_of) ts0
       |> split_list;
     val rho = tvar_subst thy (fpTs0 @ Cs0) (fpTs @ Cs);
-    val subst = Term.subst_TVars rho;
   in
-    map subst fp_iters0
+    map (Term.subst_TVars rho) ts0
   end;
 
 val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
 
 fun mk_fp_iters thy lfp fpTs Cs fp_iters0 =
-  mk_co_iter_new thy lfp fpTs Cs fp_iters0
+  mk_co_iters thy lfp fpTs Cs fp_iters0
   |> (fn ts => (ts, mk_fp_iter_fun_types (hd ts)));
 
 fun unzip_recT fpTs T =
@@ -487,9 +487,11 @@
 
 fun define_fold_rec (fold_only, rec_only) mk_binding fpTs As Cs ctor_fold ctor_rec lthy0 =
   let
+    val thy = Proof_Context.theory_of lthy0;
+
     val nn = length fpTs;
 
-    val fpT_to_C = snd (strip_typeN nn (fastype_of ctor_fold));
+    val fpT_to_C as Type (_, [fpT, _]) = snd (strip_typeN nn (fastype_of ctor_fold));
 
     fun generate_iter (suf, ctor_iter, (fss, f_Tss, xsss)) =
       let
@@ -513,7 +515,7 @@
 
     val [fold_def, rec_def] = map (Morphism.thm phi) defs;
 
-    val [foldx, recx] = map (mk_co_iter true As Cs o Morphism.term phi) csts;
+    val [foldx, recx] = map (mk_co_iter thy true fpT Cs o Morphism.term phi) csts;
   in
     ((foldx, recx, fold_def, rec_def), lthy')
   end;
@@ -521,9 +523,11 @@
 fun define_unfold_corec (cs, cpss, unfold_only, corec_only) mk_binding fpTs As Cs dtor_unfold
     dtor_corec lthy0 =
   let
+    val thy = Proof_Context.theory_of lthy0;
+
     val nn = length fpTs;
 
-    val C_to_fpT = snd (strip_typeN nn (fastype_of dtor_unfold));
+    val C_to_fpT as Type (_, [_, fpT]) = snd (strip_typeN nn (fastype_of dtor_unfold));
 
     fun generate_coiter (suf, dtor_coiter, ((pfss, cqssss, cfssss),
         (f_sum_prod_Ts, f_Tsss, pf_Tss))) =
@@ -548,7 +552,7 @@
 
     val [unfold_def, corec_def] = map (Morphism.thm phi) defs;
 
-    val [unfold, corec] = map (mk_co_iter false As Cs o Morphism.term phi) csts;
+    val [unfold, corec] = map (mk_co_iter thy false fpT Cs o Morphism.term phi) csts;
   in
     ((unfold, corec, unfold_def, corec_def), lthy')
   end;