# HG changeset patch # User haftmann # Date 1235421516 -3600 # Node ID f3043dafef5faccad1ac2a6d3468d0eb78615ed8 # Parent c9a1e0f7621be9f949e7edc7ed2d778b86a8b540 improved treatment of case certificates diff -r c9a1e0f7621b -r f3043dafef5f src/HOL/Tools/datatype_codegen.ML --- a/src/HOL/Tools/datatype_codegen.ML Mon Feb 23 10:07:57 2009 +0100 +++ b/src/HOL/Tools/datatype_codegen.ML Mon Feb 23 21:38:36 2009 +0100 @@ -6,8 +6,8 @@ signature DATATYPE_CODEGEN = sig - val get_eq: theory -> string -> thm list - val get_case_cert: theory -> string -> thm + val mk_eq: theory -> string -> thm list + val mk_case_cert: theory -> string -> thm val setup: theory -> theory end; @@ -323,7 +323,7 @@ (* case certificates *) -fun get_case_cert thy tyco = +fun mk_case_cert thy tyco = let val raw_thms = (#case_rewrites o DatatypePackage.the_datatype thy) tyco; @@ -357,10 +357,13 @@ fun add_datatype_cases dtco thy = let val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco; - val certs = get_case_cert thy dtco; + val cert = mk_case_cert thy dtco; + fun add_case_liberal thy = thy + |> try (Code.add_case cert) + |> the_default thy; in thy - |> Code.add_case certs + |> add_case_liberal |> fold_rev Code.add_default_eqn case_rewrites end; @@ -369,10 +372,10 @@ local -val not_sym = thm "HOL.not_sym"; -val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI]; -val refl = thm "refl"; -val eqTrueI = thm "eqTrueI"; +val not_sym = @{thm HOL.not_sym}; +val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI]; +val refl = @{thm refl}; +val eqTrueI = @{thm eqTrueI}; fun mk_distinct cos = let @@ -397,7 +400,7 @@ in -fun get_eq thy dtco = +fun mk_eq thy dtco = let val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco; fun mk_triv_inject co = @@ -445,7 +448,7 @@ in (thm', lthy') end; fun tac thms = Class.intro_classes_tac [] THEN ALLGOALS (ProofContext.fact_tac thms); - fun get_eq' thy dtco = get_eq thy dtco + fun mk_eq' thy dtco = mk_eq thy dtco |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq]) |> map Simpdata.mk_eq |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]) @@ -460,10 +463,10 @@ ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []) |> Simpdata.mk_eq |> AxClass.unoverload thy; - fun get_thms () = (eq_refl, false) - :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco)); + fun mk_thms () = (eq_refl, false) + :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco)); in - Code.add_eqnl (const, Lazy.lazy get_thms) thy + Code.add_eqnl (const, Lazy.lazy mk_thms) thy end; in thy diff -r c9a1e0f7621b -r f3043dafef5f src/Pure/Isar/code.ML --- a/src/Pure/Isar/code.ML Mon Feb 23 10:07:57 2009 +0100 +++ b/src/Pure/Isar/code.ML Mon Feb 23 21:38:36 2009 +0100 @@ -157,7 +157,7 @@ (*with explicit history*), dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table (*with explicit history*), - cases: (int * string list) Symtab.table * unit Symtab.table + cases: (int * (int * string list)) Symtab.table * unit Symtab.table }; fun mk_spec ((concluded_history, eqns), (dtyps, cases)) = @@ -574,12 +574,7 @@ fun del_eqns c = change_eqns true c (K (false, Lazy.value [])); -fun get_case_scheme thy c = case Symtab.lookup ((fst o the_cases o the_exec) thy) c - of SOME (base_case_scheme as (_, case_pats)) => - if forall (is_some o get_datatype_of_constr thy) case_pats - then SOME (1 + Int.max (1, length case_pats), base_case_scheme) - else NONE - | NONE => NONE; +fun get_case_scheme thy = Symtab.lookup ((fst o the_cases o the_exec) thy); val is_undefined = Symtab.defined o snd o the_cases o the_exec; @@ -589,11 +584,17 @@ let val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs; val (tyco, vs_cos) = Code_Unit.constrset_of_consts thy cs; + val old_cs = (map fst o snd o get_datatype thy) tyco; + fun drop_outdated_cases cases = fold Symtab.delete_safe + (Symtab.fold (fn (c, (_, (_, cos))) => + if exists (member (op =) old_cs) cos + then insert (op =) c else I) cases []) cases; in thy |> map_exec_purge NONE ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos)) - #> map_eqns (fold (Symtab.delete_safe o fst) cs)) + #> map_eqns (fold (Symtab.delete_safe o fst) cs) + #> (map_cases o apfst) drop_outdated_cases) |> TypeInterpretation.data (tyco, serial ()) end; @@ -607,10 +608,12 @@ fun add_case thm thy = let - val entry as (c, _) = Code_Unit.case_cert thm; - in - (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update entry) thy - end; + val (c, (k, case_pats)) = Code_Unit.case_cert thm; + val _ = case filter (is_none o get_datatype_of_constr thy) case_pats + of [] => () + | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs)); + val entry = (1 + Int.max (1, length case_pats), (k, case_pats)) + in (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update (c, entry)) thy end; fun add_undefined c thy = (map_exec_purge (SOME [c]) o map_cases o apsnd) (Symtab.update (c, ())) thy;