merged
authorwenzelm
Tue, 11 Sep 2012 23:27:19 +0200
changeset 49299 f9f240dfb50b
parent 49298 36e551d3af3b (diff)
parent 49296 313369027391 (current diff)
child 49300 c707df2e2083
merged
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 23:26:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 11 23:27:19 2012 +0200
@@ -7,7 +7,10 @@
 
 signature BNF_FP_SUGAR =
 sig
-  (* TODO: programmatic interface *)
+  val datatyp: bool ->
+    bool * ((((typ * sort) list * binding) * mixfix) * ((((binding * binding) *
+      (binding * typ) list) * (binding * term) list) * mixfix) list) list ->
+    local_theory -> local_theory
 end;
 
 structure BNF_FP_Sugar : BNF_FP_SUGAR =
@@ -44,9 +47,11 @@
     | SOME T' => T')
   | typ_subst inst T = the_default T (AList.lookup (op =) inst T);
 
-fun retype_free (Free (s, _)) T = Free (s, T);
+fun resort_tfree S (TFree (s, _)) = TFree (s, S);
 
-val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs))
+fun retype_free T (Free (s, _)) = Free (s, T);
+
+val lists_bmoc = fold (fn xs => fn t => Term.list_comb (t, xs));
 
 fun mk_predT T = T --> HOLogic.boolT;
 
@@ -66,23 +71,10 @@
 
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
-fun merge_type_arg_constrained ctxt (T, c) (T', c') =
-  if T = T' then
-    (case (c, c') of
-      (_, NONE) => (T, c)
-    | (NONE, _) => (T, c')
-    | _ =>
-      if c = c' then
-        (T, c)
-      else
-        error ("Inconsistent sort constraints for type variable " ^
-          quote (Syntax.string_of_typ ctxt T)))
-  else
-    cannot_merge_types ();
+fun merge_type_arg T T' = if T = T' then T else cannot_merge_types ();
 
-fun merge_type_args_constrained ctxt (cAs, cAs') =
-  if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
-  else cannot_merge_types ();
+fun merge_type_args (As, As') =
+  if length As = length As' then map2 merge_type_arg As As' else cannot_merge_types ();
 
 fun type_args_constrained_of (((cAs, _), _), _) = cAs;
 val type_args_of = map fst o type_args_constrained_of;
@@ -96,31 +88,45 @@
 fun defaults_of ((_, ds), _) = ds;
 fun ctr_mixfix_of (_, mx) = mx;
 
-fun prepare_datatype prepare_typ prepare_term lfp (no_dests, specs) fake_lthy no_defs_lthy =
+fun define_datatype prepare_constraint prepare_typ prepare_term lfp (no_dests, specs)
+    no_defs_lthy0 =
   let
+    (* TODO: sanity checks on arguments *)
+
     val _ = if not lfp andalso no_dests then error "Cannot define destructor-less codatatypes"
       else ();
 
-    val constrained_As =
-      map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
-      |> Library.foldr1 (merge_type_args_constrained no_defs_lthy);
-    val As = map fst constrained_As;
-    val As' = map dest_TFree As;
+    val N = length specs;
+
+    fun prepare_type_arg (ty, c) =
+      let val TFree (s, _) = prepare_typ no_defs_lthy0 ty in
+        TFree (s, prepare_constraint no_defs_lthy0 c)
+      end;
+
+    val Ass0 = map (map prepare_type_arg o type_args_constrained_of) specs;
+    val unsorted_Ass0 = map (map (resort_tfree HOLogic.typeS)) Ass0;
+    val unsorted_As = Library.foldr1 merge_type_args unsorted_Ass0;
 
-    val _ = (case duplicates (op =) As of [] => ()
-      | A :: _ => error ("Duplicate type parameter " ^
-          quote (Syntax.string_of_typ no_defs_lthy A)));
+    val ((Bs, Cs), no_defs_lthy) =
+      no_defs_lthy0
+      |> fold (Variable.declare_typ o resort_tfree dummyS) unsorted_As
+      |> mk_TFrees N
+      ||>> mk_TFrees N;
 
-    (* TODO: use sort constraints on type args *)
-
-    val N = length specs;
+    (* TODO: cleaner handling of fake contexts, without "background_theory" *)
+    (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
+      locale and shadows an existing global type*)
+    val fake_thy =
+      Theory.copy #> fold (fn spec => perhaps (try (Sign.add_type no_defs_lthy
+        (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
+    val fake_lthy = Proof_Context.background_theory fake_thy no_defs_lthy;
 
     fun mk_fake_T b =
       Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
-        As);
+        unsorted_As);
 
     val bs = map type_binder_of specs;
-    val fakeTs = map mk_fake_T bs;
+    val fake_Ts = map mk_fake_T bs;
 
     val mixfixes = map mixfix_of specs;
 
@@ -135,39 +141,41 @@
     val ctr_mixfixess = map (map ctr_mixfix_of) ctr_specss;
 
     val sel_bindersss = map (map (map fst)) ctr_argsss;
-    val fake_ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
-
+    val fake_ctr_Tsss0 = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
     val raw_sel_defaultsss = map (map defaults_of) ctr_specss;
 
+    val (Ass as As :: _) :: fake_ctr_Tsss =
+      burrow (burrow (Syntax.check_typs fake_lthy)) (Ass0 :: fake_ctr_Tsss0);
+
+    val _ = (case duplicates (op =) unsorted_As of [] => ()
+      | A :: _ => error ("Duplicate type parameter " ^
+          quote (Syntax.string_of_typ no_defs_lthy A)));
+
     val rhs_As' = fold (fold (fold Term.add_tfreesT)) fake_ctr_Tsss [];
-    val _ = (case subtract (op =) As' rhs_As' of
+    val _ = (case subtract (op =) (map dest_TFree As) rhs_As' of
         [] => ()
       | A' :: _ => error ("Extra type variables on rhs: " ^
           quote (Syntax.string_of_typ no_defs_lthy (TFree A'))));
 
-    val ((Cs, Xs), _) =
-      no_defs_lthy
-      |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
-      |> mk_TFrees N
-      ||>> mk_TFrees N;
-
     fun eq_fpT (T as Type (s, Us)) (Type (s', Us')) =
         s = s' andalso (Us = Us' orelse error ("Illegal occurrence of recursive type " ^
           quote (Syntax.string_of_typ fake_lthy T)))
       | eq_fpT _ _ = false;
 
     fun freeze_fp (T as Type (s, Us)) =
-        (case find_index (eq_fpT T) fakeTs of ~1 => Type (s, map freeze_fp Us) | j => nth Xs j)
+        (case find_index (eq_fpT T) fake_Ts of ~1 => Type (s, map freeze_fp Us) | j => nth Bs j)
       | freeze_fp T = T;
 
-    val ctr_TsssXs = map (map (map freeze_fp)) fake_ctr_Tsss;
-    val ctr_sum_prod_TsXs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssXs;
+    val ctr_TsssBs = map (map (map freeze_fp)) fake_ctr_Tsss;
+    val ctr_sum_prod_TsBs = map (mk_sumTN_balanced o map HOLogic.mk_tupleT) ctr_TsssBs;
 
-    val eqs = map dest_TFree Xs ~~ ctr_sum_prod_TsXs;
+    val fp_eqs =
+      map dest_TFree Bs ~~ map (Term.typ_subst_atomic (As ~~ unsorted_As)) ctr_sum_prod_TsBs;
 
     val (pre_bnfs, ((unfs0, flds0, fp_iters0, fp_recs0, unf_flds, fld_unfs, fld_injects,
         fp_iter_thms, fp_rec_thms), lthy)) =
-      fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes As' eqs no_defs_lthy;
+      fp_bnf (if lfp then bnf_lfp else bnf_gfp) bs mixfixes (map dest_TFree unsorted_As) fp_eqs
+        no_defs_lthy0;
 
     val add_nested_bnf_names =
       let
@@ -179,7 +187,7 @@
       in snd oo add end;
 
     val nested_bnfs =
-      map_filter (bnf_of lthy) (fold (fold (fold add_nested_bnf_names)) ctr_TsssXs []);
+      map_filter (bnf_of lthy) (fold (fold (fold add_nested_bnf_names)) ctr_TsssBs []);
 
     val timer = time (Timer.startRealTimer ());
 
@@ -196,7 +204,7 @@
 
     val fpTs = map (domain_type o fastype_of) unfs;
 
-    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Xs ~~ fpTs)))) ctr_TsssXs;
+    val ctr_Tsss = map (map (map (Term.typ_subst_atomic (Bs ~~ fpTs)))) ctr_TsssBs;
     val ns = map length ctr_Tsss;
     val kss = map (fn n => 1 upto n) ns;
     val mss = map (map length) ctr_Tsss;
@@ -242,8 +250,8 @@
               dest_sumTN_balanced n o domain_type) ns mss fp_rec_fun_Ts;
           val h_Tss = map2 (map2 (fold_rev (curry (op --->)))) z_Tssss Css;
 
-          val hss = map2 (map2 retype_free) gss h_Tss;
-          val zssss_hd = map2 (map2 (map2 (fn y => fn T :: _ => retype_free y T))) ysss z_Tssss;
+          val hss = map2 (map2 retype_free) h_Tss gss;
+          val zssss_hd = map2 (map2 (map2 (retype_free o hd))) z_Tssss ysss;
           val (zssss_tl, _) =
             lthy
             |> mk_Freessss "y" (map (map (map tl)) z_Tssss);
@@ -293,7 +301,7 @@
 
           val (s_Tssss, h_sum_prod_Ts, h_Tssss, ph_Tss) = mk_types dest_corec_sumT fp_rec_fun_Ts;
 
-          val hssss_hd = map2 (map2 (map2 (fn [g] => fn T :: _ => retype_free g T))) gssss h_Tssss;
+          val hssss_hd = map2 (map2 (map2 (fn T :: _ => fn [g] => retype_free T g))) h_Tssss gssss;
           val ((sssss, hssss_tl), _) =
             lthy
             |> mk_Freessss "q" s_Tssss
@@ -685,18 +693,9 @@
     (timer; lthy')
   end;
 
-fun datatype_cmd lfp (bundle as (_, specs)) lthy =
-  let
-    (* TODO: cleaner handling of fake contexts, without "background_theory" *)
-    (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
-      locale and shadows an existing global type*)
-    val fake_thy = Theory.copy
-      #> fold (fn spec => perhaps (try (Sign.add_type lthy
-        (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
-    val fake_lthy = Proof_Context.background_theory fake_thy lthy;
-  in
-    prepare_datatype Syntax.read_typ Syntax.read_term lfp bundle fake_lthy lthy
-  end;
+val datatyp = define_datatype (K I) (K I) (K I);
+
+val datatype_cmd = define_datatype Typedecl.read_constraint Syntax.parse_typ Syntax.read_term;
 
 val parse_opt_binding_colon = Scan.optional (Parse.binding --| @{keyword ":"}) no_binder
 
--- a/src/HOL/Codatatype/Tools/bnf_util.ML	Tue Sep 11 23:26:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_util.ML	Tue Sep 11 23:27:19 2012 +0200
@@ -47,6 +47,7 @@
   val mk_fresh_names: Proof.context -> int -> string -> string list * Proof.context
   val mk_TFrees: int -> Proof.context -> typ list * Proof.context
   val mk_TFreess: int list -> Proof.context -> typ list list * Proof.context
+  val mk_TFrees': sort list -> Proof.context -> typ list * Proof.context
   val mk_Frees: string -> typ list -> Proof.context -> term list * Proof.context
   val mk_Freess: string -> typ list list -> Proof.context -> term list list * Proof.context
   val mk_Freesss: string -> typ list list list -> Proof.context ->
@@ -282,9 +283,10 @@
 
 (** Fresh variables **)
 
-fun mk_TFrees n = apfst (map TFree) o Variable.invent_types (replicate n (HOLogic.typeS));
-fun mk_TFreess ns = apfst (map (map TFree)) o
-  fold_map Variable.invent_types (map (fn n => replicate n (HOLogic.typeS)) ns);
+val mk_TFrees' = apfst (map TFree) oo Variable.invent_types;
+
+fun mk_TFrees n = mk_TFrees' (replicate n HOLogic.typeS);
+val mk_TFreess = fold_map mk_TFrees;
 
 fun mk_names n x = if n = 1 then [x] else map (fn i => x ^ string_of_int i) (1 upto n);
 
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 11 23:26:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 11 23:27:19 2012 +0200
@@ -96,7 +96,7 @@
 
     val (As, B) =
       no_defs_lthy
-      |> mk_TFrees (length As0)
+      |> mk_TFrees' (map Type.sort_of_atyp As0)
       ||> the_single o fst o mk_TFrees 1;
 
     val fpT = Type (fpT_name, As);
@@ -572,6 +572,10 @@
   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   |> (fn thms => after_qed thms lthy)) oo prepare_wrap_datatype (K I);
 
+val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
+  Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
+  prepare_wrap_datatype Syntax.read_term;
+
 fun parse_bracket_list parser = @{keyword "["} |-- Parse.list parser --|  @{keyword "]"};
 
 val parse_bindings = parse_bracket_list Parse.binding;
@@ -581,10 +585,6 @@
 val parse_bound_terms = parse_bracket_list parse_bound_term;
 val parse_bound_termss = parse_bracket_list parse_bound_terms;
 
-val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
-  Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
-  prepare_wrap_datatype Syntax.read_term;
-
 val parse_wrap_options =
   Scan.optional (@{keyword "("} |-- (@{keyword "no_dests"} >> K true) --| @{keyword ")"}) false;