allow default names for selectors via wildcard (_) + fix wrong index (k)
authorblanchet
Fri, 31 Aug 2012 15:04:03 +0200
changeset 49046 3c5eba97d93a
parent 49045 7d9631754bba
child 49047 f57c4bb10575
allow default names for selectors via wildcard (_) + fix wrong index (k)
src/HOL/Codatatype/Tools/bnf_sugar.ML
--- a/src/HOL/Codatatype/Tools/bnf_sugar.ML	Fri Aug 31 15:04:03 2012 +0200
+++ b/src/HOL/Codatatype/Tools/bnf_sugar.ML	Fri Aug 31 15:04:03 2012 +0200
@@ -16,6 +16,8 @@
 open BNF_FP_Util
 open BNF_Sugar_Tactics
 
+val is_N = "is_";
+
 val case_congN = "case_cong"
 val case_discsN = "case_discs"
 val casesN = "cases"
@@ -29,19 +31,24 @@
 val split_asmN = "split_asm"
 val weak_case_cong_thmsN = "weak_case_cong"
 
+val default_name = @{binding _};
+
 fun mk_half_pairs [] = []
   | mk_half_pairs (x :: xs) = fold_rev (cons o pair x) xs (mk_half_pairs xs);
 
-fun index_of_half_row _ 0 = 0
-  | index_of_half_row n j = index_of_half_row n (j - 1) + n - j;
-
-fun index_of_half_cell n j k = index_of_half_row n j + k - (j + 1);
+fun index_of_half_cell n j k = j * (2 * n - (j + 1)) div 2 + k - (j + 1);
 
 val mk_Trueprop_eq = HOLogic.mk_Trueprop o HOLogic.mk_eq;
 
 fun eta_expand_caseof_arg xs f_xs = fold_rev Term.lambda xs f_xs;
 
-fun prepare_sugar prep_term (((raw_ctrs, raw_caseof), disc_names), sel_namess) no_defs_lthy =
+fun name_of_ctr t =
+  case head_of t of
+    Const (s, _) => s
+  | Free (s, _) => s
+  | _ => error "Cannot extract name of constructor";
+
+fun prepare_sugar prep_term (((raw_ctrs, raw_caseof), raw_disc_names), sel_namess) no_defs_lthy =
   let
     (* TODO: sanity checks on arguments *)
 
@@ -50,6 +57,13 @@
     val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
     val caseof0 = prep_term no_defs_lthy raw_caseof;
 
+    val disc_names =
+      map2 (fn ctr => fn disc =>
+        if Binding.eq_name (disc, default_name) then
+          Binding.name (prefix is_N (Long_Name.base_name (name_of_ctr ctr)))
+        else
+          disc) ctrs0 raw_disc_names;
+
     val n = length ctrs0;
     val ks = 1 upto n;
 
@@ -227,7 +241,7 @@
           let
             fun get_distinct_thm k k' =
               if k > k' then nth half_distinct_thms (index_of_half_cell n (k' - 1) (k - 1))
-              else nth other_half_distinct_thms (index_of_half_cell n (k' - 1) (k' - 1))
+              else nth other_half_distinct_thms (index_of_half_cell n (k - 1) (k' - 1))
             fun mk_thm ((k, discI), not_disc) k' =
               if k = k' then refl RS discI else get_distinct_thm k k' RS not_disc;
           in