construct high-level iterator RHS
authorblanchet
Thu, 06 Sep 2012 02:56:21 +0200
changeset 49176 6d29d2db5f88
parent 49175 eab51f249c70
child 49177 db8ce685073f
construct high-level iterator RHS
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_lfp.ML
src/HOL/Codatatype/Tools/bnf_util.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Thu Sep 06 02:56:21 2012 +0200
@@ -52,9 +52,6 @@
 fun args_of ((_, args), _) = args;
 fun mixfix_of_ctr (_, mx) = mx;
 
-val uncurry_fs =
-  map2 (fn f => fn xs => HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs)));
-
 fun prepare_data prepare_typ gfp specs fake_lthy lthy =
   let
     val constrained_As =
@@ -75,7 +72,7 @@
         As);
 
     val bs = map type_binder_of specs;
-    val Ts = map mk_T bs;
+    val fp_Ts = map mk_T bs;
 
     val mixfixes = map mixfix_of_typ specs;
 
@@ -98,35 +95,35 @@
       | A' :: _ => error ("Extra type variables on rhs: " ^
           quote (Syntax.string_of_typ lthy (TFree A'))));
 
-    val (Bs, C) =
+    val ((Cs, Xs), _) =
       lthy
       |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
       |> mk_TFrees N
-      ||> the_single o fst o mk_TFrees 1;
+      ||>> mk_TFrees N;
 
-    fun is_same_rec (T as Type (s, Us)) (Type (s', Us')) =
+    fun is_same_recT (T as Type (s, Us)) (Type (s', Us')) =
         s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
           quote (Syntax.string_of_typ fake_lthy T)))
-      | is_same_rec _ _ = false
+      | is_same_recT _ _ = false;
 
-    fun freeze_rec (T as Type (s, Us)) =
-        (case find_index (is_same_rec T) Ts of
-          ~1 => Type (s, map freeze_rec Us)
-        | i => nth Bs i)
-      | freeze_rec T = T;
+    fun freeze_recXs (T as Type (s, Us)) =
+        (case find_index (is_same_recT T) fp_Ts of
+          ~1 => Type (s, map freeze_recXs Us)
+        | i => nth Xs i)
+      | freeze_recXs T = T;
 
-    val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
-    val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
+    val ctr_TsssXs = map (map (map freeze_recXs)) ctr_Tsss;
+    val sum_prod_TsXs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssXs;
 
-    val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
+    val eqs = map dest_TFree Xs ~~ sum_prod_TsXs;
 
-    val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') =
+    val ((raw_unfs, raw_flds, raw_fp_iters, raw_fp_recs, unf_flds, fld_unfs, fld_injects), lthy') =
       fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs mixfixes As' eqs lthy;
 
     val timer = time (Timer.startRealTimer ());
 
-    fun mk_unf_or_fld get_foldedT Ts t =
-      let val Type (_, Ts0) = get_foldedT (fastype_of t) in
+    fun mk_unf_or_fld get_T Ts t =
+      let val Type (_, Ts0) = get_T (fastype_of t) in
         Term.subst_atomic_types (Ts0 ~~ Ts) t
       end;
 
@@ -136,10 +133,23 @@
     val unfs = map (mk_unf As) raw_unfs;
     val flds = map (mk_fld As) raw_flds;
 
-    fun pour_sugar_on_type (((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject),
-        ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
+    fun mk_fp_iter_or_rec Ts Us t =
       let
-        val n = length ctr_binders;
+        val (binders, body) = strip_type (fastype_of t);
+        val Type (_, Ts0) = if gfp then body else List.last binders;
+        val Us0 = map (if gfp then domain_type else body_type) (fst (split_last binders));
+      in
+        Term.subst_atomic_types (Ts0 @ Us0 ~~ Ts @ Us) t
+      end;
+
+    val fp_iters = map (mk_fp_iter_or_rec As Cs) raw_fp_iters;
+    val fp_recs = map (mk_fp_iter_or_rec As Cs) raw_fp_recs;
+
+    fun pour_sugar_on_type ((((((((((((((b, fp_T), C), fld), unf), fp_iter), fp_rec), fld_unf),
+          unf_fld), fld_inject), ctr_binders), ctr_mixfixes), ctr_Tss), disc_binders), sel_binderss)
+        no_defs_lthy =
+      let
+        val n = length ctr_Tss;
         val ks = 1 upto n;
         val ms = map length ctr_Tss;
 
@@ -147,11 +157,11 @@
         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
         val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
-        val ((((fs, u), v), xss), _) =
+        val ((((u, v), fs), xss), _) =
           lthy
-          |> mk_Frees "f" case_Ts
-          ||>> yield_singleton (mk_Frees "u") unf_T
-          ||>> yield_singleton (mk_Frees "v") T
+          |> yield_singleton (mk_Frees "u") unf_T
+          ||>> yield_singleton (mk_Frees "v") fp_T
+          ||>> mk_Frees "f" case_Ts
           ||>> mk_Freess "x" ctr_Tss;
 
         val ctr_rhss =
@@ -161,7 +171,7 @@
         val case_binder = Binding.suffix_name ("_" ^ caseN) b;
 
         val case_rhs =
-          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
+          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (map2 mk_uncurried_fun fs xss) $ (unf $ v));
 
         val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy
           |> apfst split_list o fold_map3 (fn b => fn mx => fn rhs =>
@@ -189,8 +199,8 @@
                     (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
               in
                 Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
-                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld)
-                    (certify lthy unf) fld_unf unf_fld)
+                  mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, fp_T])
+                    (certify lthy fld) (certify lthy unf) fld_unf unf_fld)
                 |> Thm.close_derivation
                 |> Morphism.thm phi
               end;
@@ -219,24 +229,30 @@
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
 
+        (* (co)iterators, (co)recursors, (co)induction *)
+
+        val is_recT = member (op =) fp_Ts;
+
+        val ns = map length ctr_Tsss;
+        val mss = map (map length) ctr_Tsss;
+        val Css = map2 replicate ns Cs;
+
         fun sugar_lfp lthy =
           let
-(*###
-            val fld_iter = @{term True}; (*###*)
+            val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
+            val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
+            val y_Tsss = map2 (map2 dest_tupleT) mss y_prod_Tss;
+            val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
+            val iter_T = flat g_Tss ---> fp_T --> C;
 
-            val iter_Tss = map (fn Ts => Ts) (*###*) ctr_Tss;
-            val iter_Ts = map (fn Ts => Ts ---> C) iter_Tss;
-
-            val iter_fs = map2 (fn Free (s, _) => fn T => Free (s, T)) fs iter_Ts
+            val ((gss, ysss), _) =
+              lthy
+              |> mk_Freess "f" g_Tss
+              ||>> apfst (unflat y_Tsss) o mk_Freess "x" (flat y_Tsss);
 
             val iter_rhs =
-              fold_rev Term.lambda fs (fld_iter $ mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
-
-
-            val uncurried_fs =
-              map2 (fn f => fn xs =>
-                HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
-*)
+              fold_rev (fold_rev Term.lambda) gss
+                (Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
           in
             lthy
           end;
@@ -248,8 +264,9 @@
       end;
 
     val lthy'' =
-      fold pour_sugar_on_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
-        ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss) lthy';
+      fold pour_sugar_on_type (bs ~~ fp_Ts ~~ Cs ~~ flds ~~ unfs ~~ fp_iters ~~ fp_recs ~~
+        fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_binderss ~~ ctr_mixfixess ~~ ctr_Tsss ~~
+        disc_binderss ~~ sel_bindersss) lthy';
 
     val timer = time (timer ("Constructors, discriminators, selectors, etc., for the new " ^
       (if gfp then "co" else "") ^ "datatype"));
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Thu Sep 06 02:56:21 2012 +0200
@@ -88,6 +88,11 @@
   val mk_sum_case: term -> term -> term
   val mk_sum_caseN: term list -> term
 
+  val dest_sumTN: int -> typ -> typ list
+  val dest_tupleT: int -> typ -> typ list
+
+  val mk_uncurried_fun: term -> term list -> term
+
   val mk_Field: term -> term
   val mk_union: term * term -> term
 
@@ -219,6 +224,16 @@
 fun mk_sum_caseN [f] = f
   | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
 
+fun dest_sumTN 1 T = [T]
+  | dest_sumTN n (Type (@{type_name sum}, [T, T'])) = T :: dest_sumTN (n - 1) T';
+
+(* TODO: move something like this to "HOLogic"? *)
+fun dest_tupleT 0 @{typ unit} = []
+  | dest_tupleT 1 T = [T]
+  | dest_tupleT n (Type (@{type_name prod}, [T, T'])) = T :: dest_tupleT (n - 1) T';
+
+fun mk_uncurried_fun f xs = HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs));
+
 fun mk_Field r =
   let val T = fst (dest_relT (fastype_of r));
   in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Thu Sep 06 02:56:21 2012 +0200
@@ -11,7 +11,7 @@
 sig
   val bnf_gfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
     BNF_Def.BNF list -> local_theory ->
-    (term list * term list * thm list * thm list * thm list) * local_theory
+    (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
 end;
 
 structure BNF_GFP : BNF_GFP =
@@ -1965,8 +1965,9 @@
 
     (*transforms defined frees into consts*)
     val phi = Proof_Context.export_morphism lthy_old lthy;
-    val coiters = map (fst o dest_Const o Morphism.term phi) coiter_frees;
-    fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiters (i - 1), Library.foldr (op -->)
+    val coiters = map (Morphism.term phi) coiter_frees;
+    val coiter_names = map (fst o dest_Const) coiters;
+    fun mk_coiter Ts ss i = Term.list_comb (Const (nth coiter_names (i - 1), Library.foldr (op -->)
       (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
     val coiter_defs = map ((fn thm => thm RS fun_cong) o Morphism.thm phi) coiter_def_frees;
 
@@ -2158,8 +2159,9 @@
 
     (*transforms defined frees into consts*)
     val phi = Proof_Context.export_morphism lthy_old lthy;
-    val corecs = map (fst o dest_Const o Morphism.term phi) corec_frees;
-    fun mk_corec ss i = Term.list_comb (Const (nth corecs (i - 1), Library.foldr (op -->)
+    val corecs = map (Morphism.term phi) corec_frees;
+    val corec_names = map (fst o dest_Const) corecs;
+    fun mk_corec ss i = Term.list_comb (Const (nth corec_names (i - 1), Library.foldr (op -->)
       (map fastype_of ss, domain_type (fastype_of (nth ss (i - 1))) --> nth Ts (i - 1))), ss);
     val corec_defs = map (Morphism.thm phi) corec_def_frees;
 
@@ -2990,7 +2992,7 @@
             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
           bs thmss)
   in
-    ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
+    ((unfs, flds, coiters, corecs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
       lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
   end;
 
--- a/src/HOL/Codatatype/Tools/bnf_lfp.ML	Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML	Thu Sep 06 02:56:21 2012 +0200
@@ -10,7 +10,7 @@
 sig
   val bnf_lfp: binding list -> mixfix list -> (string * sort) list -> typ list list ->
     BNF_Def.BNF list -> local_theory ->
-    (term list * term list * thm list * thm list * thm list) * local_theory
+    (term list * term list * term list * term list * thm list * thm list * thm list) * local_theory
 end;
 
 structure BNF_LFP : BNF_LFP =
@@ -1078,8 +1078,9 @@
 
     (*transforms defined frees into consts*)
     val phi = Proof_Context.export_morphism lthy_old lthy;
-    val iters = map (fst o dest_Const o Morphism.term phi) iter_frees;
-    fun mk_iter Ts ss i = Term.list_comb (Const (nth iters (i - 1), Library.foldr (op -->)
+    val iters = map (Morphism.term phi) iter_frees;
+    val iter_names = map (fst o dest_Const) iters;
+    fun mk_iter Ts ss i = Term.list_comb (Const (nth iter_names (i - 1), Library.foldr (op -->)
       (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
     val iter_defs = map (Morphism.thm phi) iter_def_frees;
 
@@ -1239,8 +1240,9 @@
 
     (*transforms defined frees into consts*)
     val phi = Proof_Context.export_morphism lthy_old lthy;
-    val recs = map (fst o dest_Const o Morphism.term phi) rec_frees;
-    fun mk_rec ss i = Term.list_comb (Const (nth recs (i - 1), Library.foldr (op -->)
+    val recs = map (Morphism.term phi) rec_frees;
+    val rec_names = map (fst o dest_Const) recs;
+    fun mk_rec ss i = Term.list_comb (Const (nth rec_names (i - 1), Library.foldr (op -->)
       (map fastype_of ss, nth Ts (i - 1) --> range_type (fastype_of (nth ss (i - 1))))), ss);
     val rec_defs = map (Morphism.thm phi) rec_def_frees;
 
@@ -1813,7 +1815,7 @@
             ((Binding.qualify true (Binding.name_of b) (Binding.name thmN), []), [(thms, [])]))
           bs thmss)
   in
-    ((unfs, flds, unf_fld_thms, fld_unf_thms, fld_inject_thms),
+    ((unfs, flds, iters, recs, unf_fld_thms, fld_unf_thms, fld_inject_thms),
       lthy |> Local_Theory.notes (common_notes @ notes) |> snd)
   end;
 
--- a/src/HOL/Codatatype/Tools/bnf_util.ML	Thu Sep 06 01:37:24 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Thu Sep 06 02:56:21 2012 +0200
@@ -273,8 +273,8 @@
 fun mk_Frees x Ts ctxt = mk_fresh_names ctxt (length Ts) x
   |>> (fn names => map2 (curry Free) names Ts);
 fun mk_Freess x Tss ctxt =
-  fold_map2 (fn name => fn Ts => fn ctxt =>
-    mk_fresh_names ctxt (length Ts) name) (mk_names (length Tss) x) Tss ctxt
+  fold_map2 (fn name => fn Ts => fn ctxt => mk_fresh_names ctxt (length Ts) name)
+    (mk_names (length Tss) x) Tss ctxt
   |>> (fn namess => map2 (map2 (curry Free)) namess Tss);
 fun mk_Frees' x Ts ctxt = mk_fresh_names ctxt (length Ts) x
   |>> (fn names => `(map Free) (names ~~ Ts));