merged
authorhaftmann
Mon, 23 Feb 2009 21:38:45 +0100
changeset 30077 c5920259850c
parent 30075 ff5b4900d9a5 (current diff)
parent 30076 f3043dafef5f (diff)
child 30078 beee83623cc9
child 30083 41a20af1fb77
merged
--- a/src/HOL/Tools/datatype_codegen.ML	Mon Feb 23 21:34:14 2009 +0100
+++ b/src/HOL/Tools/datatype_codegen.ML	Mon Feb 23 21:38:45 2009 +0100
@@ -6,8 +6,8 @@
 
 signature DATATYPE_CODEGEN =
 sig
-  val get_eq: theory -> string -> thm list
-  val get_case_cert: theory -> string -> thm
+  val mk_eq: theory -> string -> thm list
+  val mk_case_cert: theory -> string -> thm
   val setup: theory -> theory
 end;
 
@@ -323,7 +323,7 @@
 
 (* case certificates *)
 
-fun get_case_cert thy tyco =
+fun mk_case_cert thy tyco =
   let
     val raw_thms =
       (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
@@ -357,10 +357,13 @@
 fun add_datatype_cases dtco thy =
   let
     val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
-    val certs = get_case_cert thy dtco;
+    val cert = mk_case_cert thy dtco;
+    fun add_case_liberal thy = thy
+      |> try (Code.add_case cert)
+      |> the_default thy;
   in
     thy
-    |> Code.add_case certs
+    |> add_case_liberal
     |> fold_rev Code.add_default_eqn case_rewrites
   end;
 
@@ -369,10 +372,10 @@
 
 local
 
-val not_sym = thm "HOL.not_sym";
-val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
-val refl = thm "refl";
-val eqTrueI = thm "eqTrueI";
+val not_sym = @{thm HOL.not_sym};
+val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI];
+val refl = @{thm refl};
+val eqTrueI = @{thm eqTrueI};
 
 fun mk_distinct cos =
   let
@@ -397,7 +400,7 @@
 
 in
 
-fun get_eq thy dtco =
+fun mk_eq thy dtco =
   let
     val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
     fun mk_triv_inject co =
@@ -445,7 +448,7 @@
       in (thm', lthy') end;
     fun tac thms = Class.intro_classes_tac []
       THEN ALLGOALS (ProofContext.fact_tac thms);
-    fun get_eq' thy dtco = get_eq thy dtco
+    fun mk_eq' thy dtco = mk_eq thy dtco
       |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq])
       |> map Simpdata.mk_eq
       |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}])
@@ -460,10 +463,10 @@
               ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
           |> Simpdata.mk_eq
           |> AxClass.unoverload thy;
-        fun get_thms () = (eq_refl, false)
-          :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco));
+        fun mk_thms () = (eq_refl, false)
+          :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco));
       in
-        Code.add_eqnl (const, Lazy.lazy get_thms) thy
+        Code.add_eqnl (const, Lazy.lazy mk_thms) thy
       end;
   in
     thy
--- a/src/Pure/Isar/code.ML	Mon Feb 23 21:34:14 2009 +0100
+++ b/src/Pure/Isar/code.ML	Mon Feb 23 21:38:45 2009 +0100
@@ -157,7 +157,7 @@
     (*with explicit history*),
   dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
     (*with explicit history*),
-  cases: (int * string list) Symtab.table * unit Symtab.table
+  cases: (int * (int * string list)) Symtab.table * unit Symtab.table
 };
 
 fun mk_spec ((concluded_history, eqns), (dtyps, cases)) =
@@ -574,12 +574,7 @@
 
 fun del_eqns c = change_eqns true c (K (false, Lazy.value []));
 
-fun get_case_scheme thy c = case Symtab.lookup ((fst o the_cases o the_exec) thy) c
- of SOME (base_case_scheme as (_, case_pats)) =>
-      if forall (is_some o get_datatype_of_constr thy) case_pats
-      then SOME (1 + Int.max (1, length case_pats), base_case_scheme)
-      else NONE
-  | NONE => NONE;
+fun get_case_scheme thy = Symtab.lookup ((fst o the_cases o the_exec) thy);
 
 val is_undefined = Symtab.defined o snd o the_cases o the_exec;
 
@@ -589,11 +584,17 @@
   let
     val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs;
     val (tyco, vs_cos) = Code_Unit.constrset_of_consts thy cs;
+    val old_cs = (map fst o snd o get_datatype thy) tyco;
+    fun drop_outdated_cases cases = fold Symtab.delete_safe
+      (Symtab.fold (fn (c, (_, (_, cos))) =>
+        if exists (member (op =) old_cs) cos
+          then insert (op =) c else I) cases []) cases;
   in
     thy
     |> map_exec_purge NONE
         ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos))
-        #> map_eqns (fold (Symtab.delete_safe o fst) cs))
+        #> map_eqns (fold (Symtab.delete_safe o fst) cs)
+        #> (map_cases o apfst) drop_outdated_cases)
     |> TypeInterpretation.data (tyco, serial ())
   end;
 
@@ -607,10 +608,12 @@
 
 fun add_case thm thy =
   let
-    val entry as (c, _) = Code_Unit.case_cert thm;
-  in
-    (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update entry) thy
-  end;
+    val (c, (k, case_pats)) = Code_Unit.case_cert thm;
+    val _ = case filter (is_none o get_datatype_of_constr thy) case_pats
+     of [] => ()
+      | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs));
+    val entry = (1 + Int.max (1, length case_pats), (k, case_pats))
+  in (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update (c, entry)) thy end;
 
 fun add_undefined c thy =
   (map_exec_purge (SOME [c]) o map_cases o apsnd) (Symtab.update (c, ())) thy;