allow same selector name for several constructors
authorblanchet
Mon, 10 Sep 2012 17:36:02 +0200
changeset 49258 84f13469d7f0
parent 49257 e9cdacf44cc3
child 49259 b21c03c7a097
allow same selector name for several constructors
src/HOL/Codatatype/Tools/bnf_wrap.ML
--- a/src/HOL/Codatatype/Tools/bnf_wrap.ML	Mon Sep 10 17:36:02 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_wrap.ML	Mon Sep 10 17:36:02 2012 +0200
@@ -46,6 +46,9 @@
 
 fun pad_list x n xs = xs @ replicate (n - length xs) x;
 
+fun unflat_lookup _ _ [] = []
+  | unflat_lookup eq ps (xs :: xss) = map (the o AList.lookup eq ps) xs :: unflat_lookup eq ps xss;
+
 fun mk_half_pairss' _ [] = []
   | mk_half_pairss' indent (y :: ys) =
     indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
@@ -84,15 +87,15 @@
 
     val _ = if n > 0 then () else error "No constructors specified";
 
-    val Type (T_name, As0) = body_type (fastype_of (hd ctrs0));
-    val b = Binding.qualified_name T_name;
+    val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
+    val b = Binding.qualified_name fpT_name;
 
     val (As, B) =
       no_defs_lthy
       |> mk_TFrees (length As0)
       ||> the_single o fst o mk_TFrees 1;
 
-    val T = Type (T_name, As);
+    val fpT = Type (fpT_name, As);
     val ctrs = map (mk_ctr As) ctrs0;
     val ctr_Tss = map (binder_types o fastype_of) ctrs;
 
@@ -146,8 +149,8 @@
       ||>> mk_Freess "y" ctr_Tss
       ||>> mk_Frees "f" case_Ts
       ||>> mk_Frees "g" case_Ts
-      ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") T
-      ||>> yield_singleton (mk_Frees "w") T
+      ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "v") fpT
+      ||>> yield_singleton (mk_Frees "w") fpT
       ||>> yield_singleton (apfst (op ~~) oo mk_Frees' "P") HOLogic.boolT;
 
     val q = Free (fst p', B --> HOLogic.boolT);
@@ -170,10 +173,7 @@
     val exist_xs_v_eq_ctrs =
       map2 (fn xctr => fn xs => list_exists_free xs (HOLogic.mk_eq (v, xctr))) xctrs xss;
 
-    fun mk_sel_case_args k xs x T =
-      map2 (fn Ts => fn i => if i = k then fold_rev Term.lambda xs x else mk_undef T Ts) ctr_Tss ks;
-
-    fun disc_free b = Free (Binding.name_of b, T --> HOLogic.boolT);
+    fun disc_free b = Free (Binding.name_of b, fpT --> HOLogic.boolT);
 
     fun disc_spec b exist_xs_v_eq_ctr = mk_Trueprop_eq (disc_free b $ v, exist_xs_v_eq_ctr);
 
@@ -186,18 +186,30 @@
     fun alternate_disc k =
       if n = 2 then Term.lambda v (alternate_disc_lhs (3 - k)) else error "Cannot use \"*\" here"
 
-    fun sel_spec b x xs k =
-      let val T' = fastype_of x in
-        mk_Trueprop_eq (Free (Binding.name_of b, T --> T') $ v,
-          Term.list_comb (mk_case As T', mk_sel_case_args k xs x T') $ v)
+    fun mk_sel_case_args proto_sels T =
+      map2 (fn Ts => fn i =>
+        case AList.lookup (op =) proto_sels i of
+          NONE => mk_undef T Ts
+        | SOME (xs, x) => fold_rev Term.lambda xs x) ctr_Tss ks;
+
+    (* TODO: check types of tail of list *)
+    fun sel_spec b (proto_sels as ((_, (_, x)) :: _)) =
+      let val T = fastype_of x in
+        mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
+          Term.list_comb (mk_case As T, mk_sel_case_args proto_sels T) $ v)
       end;
 
     val missing_unique_disc_def = TrueI; (*arbitrary marker*)
     val missing_alternate_disc_def = FalseE; (*arbitrary marker*)
 
-    (* TODO: Allow use of same selector for several constructors *)
+    val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
 
-    val (((raw_discs, raw_disc_defs), (raw_selss, raw_sel_defss)), (lthy', lthy)) =
+    val sel_bundles = AList.group Binding.eq_name (flat sel_binderss ~~ flat proto_selss);
+    val sel_binders = map fst sel_bundles;
+
+    fun unflat_sels xs = unflat_lookup Binding.eq_name (sel_binders ~~ xs) sel_binderss;
+
+    val (((raw_discs, raw_disc_defs), (raw_sels, raw_sel_defs)), (lthy', lthy)) =
       no_defs_lthy
       |> apfst split_list o fold_map4 (fn k => fn m => fn exist_xs_v_eq_ctr =>
         fn NONE =>
@@ -207,19 +219,19 @@
          | SOME b => Specification.definition (SOME (b, NONE, NoSyn),
              ((Thm.def_binding b, []), disc_spec b exist_xs_v_eq_ctr)) #>> apsnd snd)
         ks ms exist_xs_v_eq_ctrs disc_binders
-      ||>> apfst split_list o fold_map3 (fn bs => fn xs => fn k => apfst split_list o
-          fold_map2 (fn b => fn x => Specification.definition (SOME (b, NONE, NoSyn),
-            ((Thm.def_binding b, []), sel_spec b x xs k)) #>> apsnd snd) bs xs) sel_binderss xss ks
+      ||>> apfst split_list o fold_map (fn (b, proto_sels) =>
+        Specification.definition (SOME (b, NONE, NoSyn),
+          ((Thm.def_binding b, []), sel_spec b proto_sels)) #>> apsnd snd) sel_bundles
       ||> `Local_Theory.restore;
 
     (*transforms defined frees into consts (and more)*)
     val phi = Proof_Context.export_morphism lthy lthy';
 
     val disc_defs = map (Morphism.thm phi) raw_disc_defs;
-    val sel_defss = map (map (Morphism.thm phi)) raw_sel_defss;
+    val sel_defss = unflat_sels (map (Morphism.thm phi) raw_sel_defs);
 
     val discs0 = map (Morphism.term phi) raw_discs;
-    val selss0 = map (map (Morphism.term phi)) raw_selss;
+    val selss0 = unflat_sels (map (Morphism.term phi) raw_sels);
 
     fun mk_disc_or_sel Ts c =
       Term.subst_atomic_types (snd (Term.dest_Type (domain_type (fastype_of c))) ~~ Ts) c;
@@ -288,23 +300,8 @@
           end;
 
         val sel_thmss =
-          let
-            fun mk_thm k xs goal_case case_thm x sel_def =
-              let
-                val T = fastype_of x;
-                val cTs =
-                  map ((fn T' => certifyT lthy (if T' = B then T else T')) o TFree)
-                    (rev (Term.add_tfrees goal_case []));
-                val cxs = map (certify lthy) (mk_sel_case_args k xs x T);
-              in
-                Local_Defs.fold lthy [sel_def]
-                  (Drule.instantiate' (map SOME cTs) (map SOME cxs) case_thm)
-              end;
-            fun mk_thms k xs goal_case case_thm sel_defs =
-              map2 (mk_thm k xs (strip_all_body goal_case) case_thm) xs sel_defs;
-          in
-            map5 mk_thms ks xss goal_cases case_thms sel_defss
-          end;
+          map2 (fn case_thm => map (fn sel_def => case_thm RS (sel_def RS trans))) case_thms
+            sel_defss;
 
         fun mk_unique_disc_def () =
           let