added high-level recursor, not yet curried
authorblanchet
Sat, 08 Sep 2012 21:04:26 +0200
changeset 49199 7c9a3c67c55d
parent 49198 38af9102ee75
child 49200 73f9aede57a4
added high-level recursor, not yet curried
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Fri Sep 07 15:28:48 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -21,6 +21,8 @@
 
 val caseN = "case";
 
+fun retype_free (Free (s, _)) T = Free (s, T);
+
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
@@ -52,7 +54,7 @@
 fun args_of ((_, args), _) = args;
 fun ctr_mixfix_of (_, mx) = mx;
 
-fun prepare_data prepare_typ gfp specs fake_lthy lthy =
+fun prepare_datatype prepare_typ gfp specs fake_lthy lthy =
   let
     val constrained_As =
       map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
@@ -240,7 +242,7 @@
         val mss = map (map length) ctr_Tsss;
         val Css = map2 replicate ns Cs;
 
-        fun sugar_lfp lthy =
+        fun sugar_datatype no_defs_lthy =
           let
             val fp_y_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_iter))));
             val y_prod_Tss = map2 dest_sumTN ns fp_y_Ts;
@@ -248,22 +250,48 @@
             val g_Tss = map2 (map2 (curry (op --->))) y_Tsss Css;
             val iter_T = flat g_Tss ---> fp_T --> C;
 
+            val fp_z_Ts = map domain_type (fst (split_last (binder_types (fastype_of fp_rec))));
+            val z_prod_Tss = map2 dest_sumTN ns fp_z_Ts;
+            val z_Tsss = map2 (map2 dest_tupleT) mss z_prod_Tss;
+            val h_Tss = map2 (map2 (curry (op --->))) z_Tsss Css;
+            val rec_T = flat h_Tss ---> fp_T --> C;
+
             val ((gss, ysss), _) =
-              lthy
+              no_defs_lthy
               |> mk_Freess "f" g_Tss
               ||>> mk_Freesss "x" y_Tsss;
 
-            val iter_rhs =
-              fold_rev (fold_rev Term.lambda) gss
-                (Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
+            val hss = map2 (map2 retype_free) gss h_Tss;
+            val (zsss, _) =
+              no_defs_lthy
+              |> mk_Freesss "x" z_Tsss;
+
+            val iter_binder = Binding.suffix_name ("_" ^ iterN) b;
+            val rec_binder = Binding.suffix_name ("_" ^ recN) b;
+
+            val iter_free = Free (Binding.name_of iter_binder, iter_T);
+            val rec_free = Free (Binding.name_of rec_binder, rec_T);
+
+            val iter_spec =
+              mk_Trueprop_eq (fold (fn gs => fn t => Term.list_comb (t, gs)) gss iter_free,
+                Term.list_comb (fp_iter, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) gss ysss));
+            val rec_spec =
+              mk_Trueprop_eq (fold (fn hs => fn t => Term.list_comb (t, hs)) hss rec_free,
+                Term.list_comb (fp_rec, map2 (mk_sum_caseN oo map2 mk_uncurried_fun) hss zsss));
+
+            val (([raw_iter, raw_rec], [raw_iter_def, raw_rec_def]), (lthy', lthy)) = no_defs_lthy
+              |> apfst split_list o fold_map (fn (b, spec) =>
+                Specification.definition (SOME (b, NONE, NoSyn), ((Thm.def_binding b, []), spec))
+                #>> apsnd snd) [(iter_binder, iter_spec), (rec_binder, rec_spec)]
+              ||> `Local_Theory.restore;
           in
             lthy
           end;
 
-        fun sugar_gfp lthy = lthy;
+        fun sugar_codatatype no_defs_lthy = no_defs_lthy;
       in
-        wrap_data tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
-        |> (if gfp then sugar_gfp else sugar_lfp)
+        wrap_datatype tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
+        |> (if gfp then sugar_codatatype else sugar_datatype)
       end;
 
     val lthy'' =
@@ -277,7 +305,7 @@
     (timer; lthy'')
   end;
 
-fun data_cmd info specs lthy =
+fun datatype_cmd info specs lthy =
   let
     (*the "perhaps o try" below helps gracefully handles the case where the new type is defined in a
       locale and shadows an existing global type*)
@@ -286,7 +314,7 @@
         (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of spec)))) specs;
     val fake_lthy = Proof_Context.background_theory fake_thy lthy;
   in
-    prepare_data Syntax.read_typ info specs fake_lthy lthy
+    prepare_datatype Syntax.read_typ info specs fake_lthy lthy
   end;
 
 val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_binder
@@ -302,10 +330,10 @@
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "data"} "define BNF-based inductive datatypes"
-    (Parse.and_list1 parse_single_spec >> data_cmd false);
+    (Parse.and_list1 parse_single_spec >> datatype_cmd false);
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "codata"} "define BNF-based coinductive datatypes"
-    (Parse.and_list1 parse_single_spec >> data_cmd true);
+    (Parse.and_list1 parse_single_spec >> datatype_cmd true);
 
 end;
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Fri Sep 07 15:28:48 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Sat Sep 08 21:04:26 2012 +0200
@@ -9,7 +9,7 @@
 sig
   val no_binder: binding
   val mk_half_pairss: 'a list -> ('a * 'a) list list
-  val wrap_data: ({prems: thm list, context: Proof.context} -> tactic) list list ->
+  val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
 end;
 
@@ -62,7 +62,7 @@
   | Free (s, _) => s
   | _ => error "Cannot extract name of constructor";
 
-fun prepare_wrap_data prep_term ((raw_ctrs, raw_case), (raw_disc_binders, raw_sel_binderss))
+fun prepare_wrap_datatype prep_term ((raw_ctrs, raw_case), (raw_disc_binders, raw_sel_binderss))
   no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
@@ -507,22 +507,22 @@
     (goalss, after_qed, lthy')
   end;
 
-fun wrap_data tacss = (fn (goalss, after_qed, lthy) =>
+fun wrap_datatype tacss = (fn (goalss, after_qed, lthy) =>
   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   |> (fn thms => after_qed thms lthy)) oo
-  prepare_wrap_data (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
+  prepare_wrap_datatype (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
 
 val parse_bindings = Parse.$$$ "[" |-- Parse.list Parse.binding --| Parse.$$$ "]";
 val parse_bindingss = Parse.$$$ "[" |-- Parse.list parse_bindings --| Parse.$$$ "]";
 
-val wrap_data_cmd = (fn (goalss, after_qed, lthy) =>
+val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
   Proof.theorem NONE after_qed (map (map (rpair [])) goalss) lthy) oo
-  prepare_wrap_data Syntax.read_term;
+  prepare_wrap_datatype Syntax.read_term;
 
 val _ =
   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
     (((Parse.$$$ "[" |-- Parse.list Parse.term --| Parse.$$$ "]") -- Parse.term --
       Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))
-     >> wrap_data_cmd);
+     >> wrap_datatype_cmd);
 
 end;