--- a/src/HOL/Nat.thy Mon Oct 16 14:07:18 2006 +0200
+++ b/src/HOL/Nat.thy Mon Oct 16 14:07:19 2006 +0200
@@ -114,6 +114,65 @@
inject Suc_Suc_eq
induction nat_induct [case_names 0 Suc]
+text {* fix syntax translation for nat case *}
+
+setup {*
+let
+ val thy = the_context ();
+ val info = DatatypePackage.the_datatype thy "nat";
+ val constrs = (#3 o snd o hd o #descr) info;
+ val constrs' = ["0", "Suc"];
+ val case_name = Sign.extern_const thy (#case_name info);
+ fun nat_case_tr' context ts =
+ if length ts <> length constrs + 1 then raise Match else
+ let
+ val (fs, x) = split_last ts;
+ fun strip_abs 0 t = ([], t)
+ | strip_abs i (Abs p) =
+ let val (x, u) = Syntax.atomic_abs_tr' p
+ in apfst (cons x) (strip_abs (i-1) u) end
+ | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
+ (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
+ fun is_dependent i t =
+ let val k = length (strip_abs_vars t) - i
+ in k < 0 orelse exists (fn j => j >= k)
+ (loose_bnos (strip_abs_body t))
+ end;
+ val cases = map (fn (((cname, dts), cname'), t) =>
+ (cname', strip_abs (length dts) t, is_dependent (length dts) t))
+ (constrs ~~ constrs' ~~ fs);
+ fun count_cases (_, _, true) = I
+ | count_cases (cname, (_, body), false) =
+ AList.map_default (op = : term * term -> bool)
+ (body, []) (cons cname)
+ val cases' = sort (int_ord o swap o pairself (length o snd))
+ (fold_rev count_cases cases []);
+ fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
+ list_comb (Syntax.const cname, vs) $ body;
+ fun is_undefined (Const ("undefined", _)) = true
+ | is_undefined _ = false;
+ in
+ Syntax.const "_case_syntax" $ x $
+ foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
+ (case find_first (is_undefined o fst) cases' of
+ SOME (_, cnames) =>
+ if length cnames = length constrs then [hd cases]
+ else filter_out (fn (_, (_, body), _) => is_undefined body) cases
+ | NONE => case cases' of
+ [] => cases
+ | (default, cnames) :: _ =>
+ if length cnames = 1 then cases
+ else if length cnames = length constrs then
+ [hd cases, ("dummy_pattern", ([], default), false)]
+ else
+ filter_out (fn (cname, _, _) => cname mem cnames) cases @
+ [("dummy_pattern", ([], default), false)]))
+ end
+in
+ Theory.add_advanced_trfuns ([], [], [(case_name, nat_case_tr')], [])
+end
+*}
+
lemma n_not_Suc_n: "n \<noteq> Suc n"
by (induct n) simp_all
@@ -1051,16 +1110,16 @@
instance nat :: eq ..
lemma [code func]:
- "OperationalEquality.eq (0\<Colon>nat) 0 = True" unfolding eq_def by auto
+ "Code_Generator.eq (0\<Colon>nat) 0 = True" unfolding eq_def by auto
lemma [code func]:
- "OperationalEquality.eq (Suc n) (Suc m) = OperationalEquality.eq n m" unfolding eq_def by auto
+ "Code_Generator.eq (Suc n) (Suc m) = Code_Generator.eq n m" unfolding eq_def by auto
lemma [code func]:
- "OperationalEquality.eq (Suc n) 0 = False" unfolding eq_def by auto
+ "Code_Generator.eq (Suc n) 0 = False" unfolding eq_def by auto
lemma [code func]:
- "OperationalEquality.eq 0 (Suc m) = False" unfolding eq_def by auto
+ "Code_Generator.eq 0 (Suc m) = False" unfolding eq_def by auto
code_typename
nat "IntDef.nat"
@@ -1083,7 +1142,7 @@
"op * \<Colon> nat \<Rightarrow> nat \<Rightarrow> nat" "IntDef.times_nat"
"op < \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" "IntDef.less_nat"
"op \<le> \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" "IntDef.less_eq_nat"
- "OperationalEquality.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" "IntDef.eq_nat"
+ "Code_Generator.eq \<Colon> nat \<Rightarrow> nat \<Rightarrow> bool" "IntDef.eq_nat"
nat_rec "IntDef.nat_rec"
nat_case "IntDef.nat_case"