more robust iterator construction (needed for mutualized FPs)
authorblanchet
Tue, 07 May 2013 21:09:47 +0200
changeset 51911 6c425d450a8c
parent 51910 31bb70ddee7e
child 51912 a6b963bc46f0
more robust iterator construction (needed for mutualized FPs)
src/HOL/BNF/Tools/bnf_fp_def_sugar.ML
--- a/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue May 07 21:09:46 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_def_sugar.ML	Tue May 07 21:09:47 2013 +0200
@@ -25,6 +25,7 @@
   val morph_fp_sugar: morphism -> fp_sugar -> fp_sugar
   val fp_sugar_of: Proof.context -> string -> fp_sugar option
 
+  val tvar_subst: theory -> typ list -> typ list -> ((string * int) * typ) list
   val exists_subtype_in: typ list -> typ -> bool
   val nesty_bnfs: Proof.context -> typ list list list -> typ list -> BNF_Def.bnf list
   val indexify_fst: ''a list -> (int -> ''a * 'b -> 'c) -> ''a * 'b -> 'c
@@ -165,6 +166,9 @@
 val simp_attrs = @{attributes [simp]};
 val code_simp_attrs = Code.add_default_eqn_attrib :: simp_attrs;
 
+fun tvar_subst thy Ts Us =
+  Vartab.fold (cons o apsnd snd) (fold (Sign.typ_match thy) (Ts ~~ Us) Vartab.empty) [];
+
 val exists_subtype_in = Term.exists_subtype o member (op =);
 
 fun resort_tfree S (TFree (s, _)) = TFree (s, S);
@@ -228,8 +232,18 @@
 
 val mk_fp_iter_fun_types = fst o split_last o binder_types o fastype_of;
 
-fun mk_fp_iter lfp As Cs =
-  map (mk_co_iter lfp As Cs) #> (fn ts => (ts, mk_fp_iter_fun_types (hd ts)));
+fun mk_fp_iters ctxt lfp fpTs Cs fp_iters0 =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    val nn = length fpTs;
+    val (fpTs0, Cs0) =
+      map ((not lfp ? swap) o dest_funT o snd o strip_typeN nn o fastype_of) fp_iters0
+      |> split_list;
+    val rho = tvar_subst thy (fpTs0 @ Cs0) (fpTs @ Cs);
+    val subst = Term.subst_TVars rho;
+  in
+    map subst fp_iters0 |> (fn ts => (ts, mk_fp_iter_fun_types (hd ts)))
+  end;
 
 fun unzip_recT fpTs T =
   let
@@ -337,8 +351,8 @@
 
 fun mk_un_fold_co_rec_prelims lfp fpTs As Cs ns mss fp_folds0 fp_recs0 lthy =
   let
-    val (fp_folds, fp_fold_fun_Ts) = mk_fp_iter lfp As Cs fp_folds0;
-    val (fp_recs, fp_rec_fun_Ts) = mk_fp_iter lfp As Cs fp_recs0;
+    val (fp_folds, fp_fold_fun_Ts) = mk_fp_iters lthy lfp fpTs Cs fp_folds0;
+    val (fp_recs, fp_rec_fun_Ts) = mk_fp_iters lthy lfp fpTs Cs fp_recs0;
 
     val ((fold_rec_args_types, unfold_corec_args_types), lthy') =
       if lfp then
@@ -553,8 +567,8 @@
 
     val fp_b_names = map base_name_of_typ fpTs;
 
-    val (_, ctor_fold_fun_Ts) = mk_fp_iter true As Cs ctor_folds0;
-    val (_, ctor_rec_fun_Ts) = mk_fp_iter true As Cs ctor_recs0;
+    val (_, ctor_fold_fun_Ts) = mk_fp_iters lthy true fpTs Cs ctor_folds0;
+    val (_, ctor_rec_fun_Ts) = mk_fp_iters lthy true fpTs Cs ctor_recs0;
 
     val (((gss, _, _), (hss, _, _)), names_lthy0) =
       mk_fold_rec_args_types fpTs Cs ns mss ctor_fold_fun_Ts ctor_rec_fun_Ts lthy;
@@ -708,8 +722,8 @@
 
     val fp_b_names = map base_name_of_typ fpTs;
 
-    val (_, dtor_unfold_fun_Ts) = mk_fp_iter false As Cs dtor_unfolds0;
-    val (_, dtor_corec_fun_Ts) = mk_fp_iter false As Cs dtor_corecs0;
+    val (_, dtor_unfold_fun_Ts) = mk_fp_iters lthy false fpTs Cs dtor_unfolds0;
+    val (_, dtor_corec_fun_Ts) = mk_fp_iters lthy false fpTs Cs dtor_corecs0;
 
     val ctrss = map (map (mk_ctr As) o #ctrs) ctr_sugars;
     val discss = map (map (mk_disc_or_sel As) o #discs) ctr_sugars;