define "case" constant
authorblanchet
Tue, 04 Sep 2012 18:14:58 +0200
changeset 49129 b5413cb7d860
parent 49128 1a86ef0a0210
child 49130 3c26e17b2849
define "case" constant
src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
src/HOL/Codatatype/Tools/bnf_fp_util.ML
src/HOL/Codatatype/Tools/bnf_gfp_util.ML
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 17:23:08 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_sugar.ML	Tue Sep 04 18:14:58 2012 +0200
@@ -19,6 +19,8 @@
 open BNF_GFP
 open BNF_FP_Sugar_Tactics
 
+val caseN = "case";
+
 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
 
 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
@@ -41,7 +43,7 @@
 
 fun type_args_constrained_of (((cAs, _), _), _) = cAs;
 val type_args_of = map fst o type_args_constrained_of;
-fun type_name_of (((_, b), _), _) = b;
+fun type_binder_of (((_, b), _), _) = b;
 fun mixfix_of_typ ((_, mx), _) = mx;
 fun ctr_specs_of (_, ctr_specs) = ctr_specs;
 
@@ -72,7 +74,7 @@
       Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
         As);
 
-    val bs = map type_name_of specs;
+    val bs = map type_binder_of specs;
     val Ts = map mk_T bs;
 
     val mixfixes = map mixfix_of_typ specs;
@@ -82,12 +84,12 @@
 
     val ctr_specss = map ctr_specs_of specs;
 
-    val disc_namess = map (map disc_of) ctr_specss;
-    val ctr_namess = map (map ctr_of) ctr_specss;
+    val disc_binderss = map (map disc_of) ctr_specss;
+    val ctr_binderss = map (map ctr_of) ctr_specss;
     val ctr_argsss = map (map args_of) ctr_specss;
     val ctr_mixfixess = map (map mixfix_of_ctr) ctr_specss;
 
-    val sel_namesss = map (map (map fst)) ctr_argsss;
+    val sel_bindersss = map (map (map fst)) ctr_argsss;
     val ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
 
     val (Bs, C) =
@@ -121,36 +123,44 @@
     val unfs = map (mk_unf As) raw_unfs;
     val flds = map (mk_fld As) raw_flds;
 
-    fun wrap_type (((((((((T, fld), unf), fld_unf), unf_fld), fld_inject), ctr_names), ctr_Tss),
-        disc_names), sel_namess) no_defs_lthy =
+    fun wrap_type ((((((((((b, T), fld), unf), fld_unf), unf_fld), fld_inject), ctr_binders),
+        ctr_Tss), disc_binders), sel_binderss) no_defs_lthy =
       let
-        val n = length ctr_names;
+        val n = length ctr_binders;
         val ks = 1 upto n;
         val ms = map length ctr_Tss;
 
         val unf_T = domain_type (fastype_of fld);
-
         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
+        val caseofC_Ts = map (fn Ts => Ts ---> C) ctr_Tss;
 
-        val (((u, v), xss), _) =
+        val ((((fs, u), v), xss), _) =
           lthy
-          |> yield_singleton (mk_Frees "u") unf_T
+          |> mk_Frees "f" caseofC_Ts
+          ||>> yield_singleton (mk_Frees "u") unf_T
           ||>> yield_singleton (mk_Frees "v") T
           ||>> mk_Freess "x" ctr_Tss;
 
-        val rhss =
+        val uncurried_fs =
+          map2 (fn f => fn xs =>
+            HOLogic.tupled_lambda (HOLogic.mk_tuple xs) (Term.list_comb (f, xs))) fs xss;
+
+        val ctr_rhss =
           map2 (fn k => fn xs =>
             fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
 
-        val ((raw_ctrs, raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
+        val caseof_binder = Binding.suffix_name ("_" ^ caseN) b;
+
+        val caseof_rhs = fold_rev Term.lambda (fs @ [v]) (mk_sum_caseN uncurried_fs $ (unf $ v));
+
+        val (((raw_ctrs, raw_ctr_defs), (raw_caseof, raw_caseof_def)), (lthy', lthy)) = no_defs_lthy
           |> apfst split_list o fold_map2 (fn b => fn rhs =>
                Local_Theory.define ((b, NoSyn), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
-             ctr_names rhss
+             ctr_binders ctr_rhss
+          ||>> (Local_Theory.define ((caseof_binder, NoSyn), ((Thm.def_binding caseof_binder, []),
+             caseof_rhs)) #>> apsnd snd)
           ||> `Local_Theory.restore;
 
-        val raw_caseof =
-          Const (@{const_name undefined}, map (fn Ts => Ts ---> C) ctr_Tss ---> T --> C);
-
         (*transforms defined frees into consts (and more)*)
         val phi = Proof_Context.export_morphism lthy lthy';
 
@@ -196,12 +206,12 @@
 
         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
       in
-        wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
+        wrap_data tacss ((ctrs, caseof), (disc_binders, sel_binderss)) lthy'
       end;
   in
     lthy'
-    |> fold wrap_type (Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_namess ~~
-      ctr_Tsss ~~ disc_namess ~~ sel_namesss)
+    |> fold wrap_type (bs ~~ Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~
+      ctr_binderss ~~ ctr_Tsss ~~ disc_binderss ~~ sel_bindersss)
   end;
 
 fun data_cmd info specs lthy =
@@ -210,17 +220,17 @@
       Proof_Context.theory_of lthy
       |> Theory.copy
       |> Sign.add_types_global (map (fn spec =>
-        (type_name_of spec, length (type_args_constrained_of spec), mixfix_of_typ spec)) specs)
+        (type_binder_of spec, length (type_args_constrained_of spec), mixfix_of_typ spec)) specs)
       |> Proof_Context.init_global
   in
     prepare_data Syntax.read_typ info specs fake_lthy lthy
   end;
 
-val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_name
+val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_binder
 
 val parse_ctr_arg =
   Parse.$$$ "(" |-- parse_opt_binding_colon -- Parse.typ --| Parse.$$$ ")" ||
-  (Parse.typ >> pair no_name);
+  (Parse.typ >> pair no_binder);
 
 val parse_single_spec =
   Parse.type_args_constrained -- Parse.binding -- Parse.opt_mixfix --
--- a/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 17:23:08 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_fp_util.ML	Tue Sep 04 18:14:58 2012 +0200
@@ -85,6 +85,8 @@
   val mk_Inl: term -> typ -> term
   val mk_Inr: term -> typ -> term
   val mk_InN: typ list -> term -> int -> term
+  val mk_sum_case: term -> term -> term
+  val mk_sum_caseN: term list -> term
 
   val mk_Field: term -> term
   val mk_union: term * term -> term
@@ -201,6 +203,18 @@
   | mk_InN (LT :: Ts) t m = mk_Inr (mk_InN Ts t (m - 1)) LT
   | mk_InN Ts t _ = raise (TYPE ("mk_InN", Ts, [t]));
 
+fun mk_sum_case f g =
+  let
+    val fT = fastype_of f;
+    val gT = fastype_of g;
+  in
+    Const (@{const_name sum_case},
+      fT --> gT --> mk_sumT (domain_type fT, domain_type gT) --> range_type fT) $ f $ g
+  end;
+
+fun mk_sum_caseN [f] = f
+  | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
+
 fun mk_Field r =
   let val T = fst (dest_relT (fastype_of r));
   in Const (@{const_name Field}, mk_relT (T, T) --> HOLogic.mk_setT T) $ r end;
--- a/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 17:23:08 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_gfp_util.ML	Tue Sep 04 18:14:58 2012 +0200
@@ -35,9 +35,6 @@
   val mk_undefined: typ -> term
   val mk_univ: term -> term
 
-  val mk_sum_case: term -> term -> term
-  val mk_sum_caseN: term list -> term
-
   val mk_specN: int -> thm -> thm
   val mk_sum_casesN: int -> int -> thm
 
@@ -182,18 +179,6 @@
       A $ f1 $ f2 $ b1 $ b2
   end;
 
-fun mk_sum_case f g =
-  let
-    val fT = fastype_of f;
-    val gT = fastype_of g;
-  in
-    Const (@{const_name sum_case},
-      fT --> gT --> mk_sumT (domain_type fT, domain_type gT) --> range_type fT) $ f $ g
-  end;
-
-fun mk_sum_caseN [f] = f
-  | mk_sum_caseN (f :: fs) = mk_sum_case f (mk_sum_caseN fs);
-
 fun mk_InN_not_InM 1 _ = @{thm Inl_not_Inr}
   | mk_InN_not_InM n m =
     if n > m then mk_InN_not_InM m n RS @{thm not_sym}
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 17:23:08 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Tue Sep 04 18:14:58 2012 +0200
@@ -7,7 +7,7 @@
 
 signature BNF_WRAP =
 sig
-  val no_name: binding
+  val no_binder: binding
   val mk_half_pairss: 'a list -> ('a * 'a) list list
   val wrap_data: ({prems: thm list, context: Proof.context} -> tactic) list list ->
     (term list * term) * (binding list * binding list list) -> local_theory -> local_theory
@@ -40,8 +40,8 @@
 val split_asmN = "split_asm";
 val weak_case_cong_thmsN = "weak_case_cong";
 
-val no_name = @{binding "*"};
-val fallback_name = @{binding _};
+val no_binder = @{binding "*"};
+val fallback_binder = @{binding _};
 
 fun pad_list x n xs = xs @ replicate (n - length xs) x;
 
@@ -61,7 +61,7 @@
   | Free (s, _) => s
   | _ => error "Cannot extract name of constructor";
 
-fun prepare_wrap_data prep_term ((raw_ctrs, raw_caseof), (raw_disc_names, raw_sel_namess))
+fun prepare_wrap_data prep_term ((raw_ctrs, raw_caseof), (raw_disc_binders, raw_sel_binderss))
   no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
@@ -96,36 +96,36 @@
 
     val ms = map length ctr_Tss;
 
-    val raw_disc_names' = pad_list no_name n raw_disc_names;
+    val raw_disc_binders' = pad_list no_binder n raw_disc_binders;
 
     fun can_rely_on_disc i =
-      not (Binding.eq_name (nth raw_disc_names' i, no_name)) orelse nth ms i = 0;
-    fun can_omit_disc_name k m =
+      not (Binding.eq_name (nth raw_disc_binders' i, no_binder)) orelse nth ms i = 0;
+    fun can_omit_disc_binder k m =
       n = 1 orelse m = 0 orelse (n = 2 andalso can_rely_on_disc (2 - k))
 
-    val fallback_disc_name = Binding.name o prefix is_N o Long_Name.base_name o name_of_ctr;
+    val fallback_disc_binder = Binding.name o prefix is_N o Long_Name.base_name o name_of_ctr;
 
-    val disc_names =
-      raw_disc_names'
+    val disc_binders =
+      raw_disc_binders'
       |> map4 (fn k => fn m => fn ctr => fn disc =>
-        if Binding.eq_name (disc, no_name) then
-          if can_omit_disc_name k m then NONE else SOME (fallback_disc_name ctr)
-        else if Binding.eq_name (disc, fallback_name) then
-          SOME (fallback_disc_name ctr)
+        if Binding.eq_name (disc, no_binder) then
+          if can_omit_disc_binder k m then NONE else SOME (fallback_disc_binder ctr)
+        else if Binding.eq_name (disc, fallback_binder) then
+          SOME (fallback_disc_binder ctr)
         else
           SOME disc) ks ms ctrs0;
 
-    val no_discs = map is_none disc_names;
+    val no_discs = map is_none disc_binders;
 
-    fun fallback_sel_name m l = Binding.name o mk_un_N m l o Long_Name.base_name o name_of_ctr;
+    fun fallback_sel_binder m l = Binding.name o mk_un_N m l o Long_Name.base_name o name_of_ctr;
 
-    val sel_namess =
-      pad_list [] n raw_sel_namess
+    val sel_binderss =
+      pad_list [] n raw_sel_binderss
       |> map3 (fn ctr => fn m => map2 (fn l => fn sel =>
-        if Binding.eq_name (sel, no_name) orelse Binding.eq_name (sel, fallback_name) then
-          fallback_sel_name m l ctr
+        if Binding.eq_name (sel, no_binder) orelse Binding.eq_name (sel, fallback_binder) then
+          fallback_sel_binder m l ctr
         else
-          sel) (1 upto m) o pad_list no_name m) ctrs0 ms;
+          sel) (1 upto m) o pad_list no_binder m) ctrs0 ms;
 
     fun mk_caseof Ts T =
       let
@@ -172,7 +172,7 @@
 
     fun not_other_disc_lhs i =
       HOLogic.mk_not
-        (case nth disc_names i of NONE => nth exist_xs_v_eq_ctrs i | SOME b => disc_free b $ v);
+        (case nth disc_binders i of NONE => nth exist_xs_v_eq_ctrs i | SOME b => disc_free b $ v);
 
     fun not_other_disc k =
       if n = 2 then Term.lambda v (not_other_disc_lhs (2 - k)) else error "Cannot use \"*\" here"
@@ -193,10 +193,10 @@
            else pair (not_other_disc k, missing_disc_def)
          | SOME b => Specification.definition (SOME (b, NONE, NoSyn),
              ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
-        ks ms exist_xs_v_eq_ctrs disc_names
+        ks ms exist_xs_v_eq_ctrs disc_binders
       ||>> apfst split_list o fold_map3 (fn bs => fn xs => fn k => apfst split_list o
           fold_map2 (fn b => fn x => Specification.definition (SOME (b, NONE, NoSyn),
-            ((Thm.def_binding b, []), sel_spec b x xs k)) #>> apsnd snd) bs xs) sel_namess xss ks
+            ((Thm.def_binding b, []), sel_spec b x xs k)) #>> apsnd snd) bs xs) sel_binderss xss ks
       ||> `Local_Theory.restore;
 
     (*transforms defined frees into consts (and more)*)