optionally provide extra dead variables to the FP constructions
authorblanchet
Tue, 04 Sep 2012 23:09:08 +0200
changeset 49134 846264f80f16
parent 49133 4680ac067cb8
child 49135 de13b454fa31
optionally provide extra dead variables to the FP constructions
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp.ML
src/HOL/Codatatype/Tools/bnf_lfp.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 21:51:31 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 23:09:08 2012 +0200
@@ -52,15 +52,16 @@
 fun args_of ((_, args), _) = args;
 fun mixfix_of_ctr (_, mx) = mx;
 
-val lfp_info = bnf_lfp;
-val gfp_info = bnf_gfp;
+val uncurry_fs =
+  map2 (fn f => fn xs => HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs)));
 
-fun prepare_data prepare_typ construct specs fake_lthy lthy =
+fun prepare_data prepare_typ gfp specs fake_lthy lthy =
   let
     val constrained_As =
       map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
       |> Library.foldr1 (merge_type_args_constrained lthy);
     val As = map fst constrained_As;
+    val As' = map dest_TFree As;
 
     val _ = (case duplicates (op =) As of [] => ()
       | T :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy T)));
@@ -110,7 +111,7 @@
     val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
 
     val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') =
-      fp_bnf construct bs eqs lthy;
+      fp_bnf (if gfp then bnf_gfp else bnf_lfp) bs As' eqs lthy;
 
     fun mk_unf_or_fld get_foldedT Ts t =
       let val Type (_, Ts0) = get_foldedT (fastype_of t) in
@@ -123,8 +124,8 @@
     val unfs = map (mk_unf As) raw_unfs;
     val flds = map (mk_fld As) raw_flds;
 
-    fun wrap_type ((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject), ctr_binders),
-        ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
+    fun pour_some_sugar_on_type ((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject),
+        ctr_binders), ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
       let
         val n = length ctr_binders;
         val ks = 1 upto n;
@@ -132,26 +133,23 @@
 
         val unf_T = domain_type (fastype_of fld);
         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
-        val caseC_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
+        val case_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
         val ((((fs, u), v), xss), _) =
           lthy
-          |> mk_Frees "f" caseC_Ts
+          |> mk_Frees "f" case_Ts
           ||>> yield_singleton (mk_Frees "u") unf_T
           ||>> yield_singleton (mk_Frees "v") T
           ||>> mk_Freess "x" ctr_Tss;
 
-        val uncurried_fs =
-          map2 (fn f => fn xs =>
-            HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
-
         val ctr_rhss =
           map2 (fn k => fn xs =>
             fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
 
         val case_binder = Binding.suffix_name ("_" ^ caseN) b;
 
-        val case_rhs = fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN uncurried_fs $ (unf $ v));
+        val case_rhs =
+          fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN (uncurry_fs fs xss) $ (unf $ v));
 
         val (((raw_ctrs, raw_ctr_defs), (raw_case, raw_case_def)), (lthy', lthy)) = no_defs_lthy
           |> apfst split_list o fold_map2 (fn b => fn rhs =>
@@ -206,12 +204,31 @@
             mk_case_tac ctxt n k m case_def ctr_def unf_fld) ks ms ctr_defs;
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
+
+        fun sugar_lfp lthy =
+          let
+(*###
+            val iter_Tss = map ( ) ctr_Tss
+            val iter_Ts = map (fn Ts => Ts ---> C) iter_Tss;
+
+            val iter_fs = map2 (fn Free (s, _) => fn T => Free (s, T)) fs iter_Ts
+
+            val uncurried_fs =
+              map2 (fn f => fn xs =>
+                HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
+*)
+          in
+            lthy
+          end;
+
+        fun sugar_gfp lthy = lthy;
       in
         wrap_data tacss ((ctrs, casex), (disc_binders, sel_binderss)) lthy'
+        |> (if gfp then sugar_gfp else sugar_lfp)
       end;
   in
     lthy'
-    |> fold wrap_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
+    |> fold pour_some_sugar_on_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
       ctr_binderss ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
   end;
 
@@ -240,10 +257,10 @@
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "data"} "define BNF-based inductive datatypes"
-    (Parse.and_list1 parse_single_spec >> data_cmd lfp_info);
+    (Parse.and_list1 parse_single_spec >> data_cmd false);
 
 val _ =
   Outer_Syntax.local_theory @{command_spec "codata"} "define BNF-based coinductive datatypes"
-    (Parse.and_list1 parse_single_spec >> data_cmd gfp_info);
+    (Parse.and_list1 parse_single_spec >> data_cmd true);
 
 end;
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 21:51:31 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 23:09:08 2012 +0200
@@ -98,10 +98,12 @@
 
   val fixpoint: ('a * 'a -> bool) -> ('a list -> 'a list) -> 'a list -> 'a list
 
-  val fp_bnf: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
-    binding list -> ((string * sort) * typ) list -> Proof.context -> 'a
-  val fp_bnf_cmd: (binding list -> typ list list -> BNF_Def.BNF list -> Proof.context -> 'a) ->
-    binding list * (string list * string list) -> Proof.context -> 'a
+  val fp_bnf: (binding list -> (string * sort) list -> typ list list -> BNF_Def.BNF list ->
+    local_theory -> 'a) ->
+    binding list -> (string * sort) list -> ((string * sort) * typ) list -> local_theory -> 'a
+  val fp_bnf_cmd: (binding list -> (string * sort) list -> typ list list -> BNF_Def.BNF list ->
+    local_theory -> 'a) ->
+    binding list * (string list * string list) -> local_theory -> 'a
 end;
 
 structure BNF_FP_Util : BNF_FP_UTIL =
@@ -255,7 +257,7 @@
 fun fp_sort lhss Ass = Library.sort (Term_Ord.typ_ord o pairself TFree)
   (subtract (op =) lhss (fold (fold (insert (op =))) Ass [])) @ lhss;
 
-fun mk_fp_bnf timer construct bs sort bnfs deadss livess unfold lthy =
+fun mk_fp_bnf timer construct bs resBs sort bnfs deadss livess unfold lthy =
   let
     val name = fold_rev (fn b => fn s => Binding.name_of b ^ s) bs "";
     fun qualify i bind =
@@ -266,7 +268,8 @@
       end;
 
     val Ass = map (map dest_TFree) livess;
-    val Ds = fold (fold Term.add_tfreesT) deadss [];
+    val resDs = fold (subtract (op =)) Ass resBs;
+    val Ds = fold (fold Term.add_tfreesT) deadss resDs;
 
     val _ = (case Library.inter (op =) Ds (fold (union (op =)) Ass []) of [] => ()
       | A :: _ => error ("Nonadmissible type recursion (cannot take fixed point of dead type \
@@ -284,14 +287,14 @@
 
     val timer = time (timer "Normalization & sealing of BNFs");
 
-    val res = construct bs Dss bnfs'' lthy'';
+    val res = construct bs resDs Dss bnfs'' lthy'';
 
     val timer = time (timer "FP construction in total");
   in
     res
   end;
 
-fun fp_bnf construct bs eqs lthy =
+fun fp_bnf construct bs resBs eqs lthy =
   let
     val timer = time (Timer.startRealTimer ());
     val (lhss, rhss) = split_list eqs;
@@ -300,7 +303,7 @@
       (fold_map2 (fn b => bnf_of_typ Smart_Inline (Binding.suffix_name "RAW" b) I sort) bs rhss
         (empty_unfold, lthy));
   in
-    mk_fp_bnf timer construct bs sort bnfs Dss Ass unfold lthy'
+    mk_fp_bnf timer construct bs resBs sort bnfs Dss Ass unfold lthy'
   end;
 
 fun fp_bnf_cmd construct (bs, (raw_lhss, raw_bnfs)) lthy =
@@ -313,7 +316,7 @@
         (bnf_of_typ Smart_Inline (Binding.suffix_name "RAW" b) I sort (Syntax.read_typ lthy rawT)))
         bs raw_bnfs (empty_unfold, lthy));
   in
-    mk_fp_bnf timer construct bs sort bnfs Dss Ass unfold lthy'
+    mk_fp_bnf timer construct bs [] sort bnfs Dss Ass unfold lthy'
   end;
 
 end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 21:51:31 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp.ML	Tue Sep 04 23:09:08 2012 +0200
@@ -9,8 +9,8 @@
 
 signature BNF_GFP =
 sig
-  val bnf_gfp: binding list -> typ list list -> BNF_Def.BNF list -> local_theory ->
-    (term list * term list * thm list * thm list * thm list) * local_theory
+  val bnf_gfp: binding list -> (string * sort) list -> typ list list -> BNF_Def.BNF list ->
+    local_theory -> (term list * term list * thm list * thm list * thm list) * local_theory
 end;
 
 structure BNF_GFP : BNF_GFP =
@@ -52,7 +52,7 @@
      ((i, I), nth (nth lwitss i) nwit) :: maps (tree_to_coind_wits lwitss) subtrees;
 
 (*all bnfs have the same lives*)
-fun bnf_gfp bs Dss_insts bnfs lthy =
+fun bnf_gfp bs resDs Dss_insts bnfs lthy =
   let
     val timer = time (Timer.startRealTimer ());
 
@@ -66,7 +66,7 @@
     (* TODO: check if m, n etc are sane *)
 
     val Dss = map (fn Ds => map TFree (fold Term.add_tfreesT Ds [])) Dss_insts;
-    val deads = distinct (op =) (flat Dss);
+    val deads = fold (union (op =)) Dss (map TFree resDs);
     val names_lthy = fold Variable.declare_typ deads lthy;
 
     (* tvars *)
@@ -2950,7 +2950,7 @@
         val Jbnf_notes =
           [(map_simpsN, map single folded_map_simp_thms),
           (set_inclN, set_incl_thmss),
-          (set_set_inclN, map flat set_set_incl_thmsss), (* nicer names? *)
+          (set_set_inclN, map flat set_set_incl_thmsss),
           (rel_unfoldN, map single Jrel_unfold_thms),
           (pred_unfoldN, map single Jpred_unfold_thms)] @
           map2 (fn i => fn thms => (mk_set_simpsN i, map single thms)) ls' folded_set_simp_thmss
--- a/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 21:51:31 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_lfp.ML	Tue Sep 04 23:09:08 2012 +0200
@@ -8,8 +8,8 @@
 
 signature BNF_LFP =
 sig
-  val bnf_lfp: binding list -> typ list list -> BNF_Def.BNF list -> local_theory ->
-    (term list * term list * thm list * thm list * thm list) * local_theory
+  val bnf_lfp: binding list -> (string * sort) list -> typ list list -> BNF_Def.BNF list ->
+    local_theory -> (term list * term list * thm list * thm list * thm list) * local_theory
 end;
 
 structure BNF_LFP : BNF_LFP =
@@ -23,7 +23,7 @@
 open BNF_LFP_Tactics
 
 (*all bnfs have the same lives*)
-fun bnf_lfp bs Dss_insts bnfs lthy =
+fun bnf_lfp bs resDs Dss_insts bnfs lthy =
   let
     val timer = time (Timer.startRealTimer ());
     val live = live_of_bnf (hd bnfs);
@@ -35,7 +35,7 @@
     (* TODO: check if m, n etc are sane *)
 
     val Dss = map (fn Ds => map TFree (fold Term.add_tfreesT Ds [])) Dss_insts;
-    val deads = distinct (op =) (flat Dss);
+    val deads = fold (union (op =)) Dss (map TFree resDs);
     val names_lthy = fold Variable.declare_typ deads lthy;
 
     (* tvars *)
@@ -1778,7 +1778,7 @@
         val Ibnf_notes =
           [(map_simpsN, map single folded_map_simp_thms),
           (set_inclN, set_incl_thmss),
-          (set_set_inclN, map flat set_set_incl_thmsss), (* nicer names? *)
+          (set_set_inclN, map flat set_set_incl_thmsss),
           (rel_unfoldN, map single Irel_unfold_thms),
           (pred_unfoldN, map single Ipred_unfold_thms)] @
           map2 (fn i => fn thms => (mk_set_simpsN i, map single thms)) ls' folded_set_simp_thmss