improved treatment of case certificates
authorhaftmann
Mon Feb 23 21:38:36 2009 +0100 (2009-02-23)
changeset 30076f3043dafef5f
parent 30065 c9a1e0f7621b
child 30077 c5920259850c
improved treatment of case certificates
src/HOL/Tools/datatype_codegen.ML
src/Pure/Isar/code.ML
     1.1 --- a/src/HOL/Tools/datatype_codegen.ML	Mon Feb 23 10:07:57 2009 +0100
     1.2 +++ b/src/HOL/Tools/datatype_codegen.ML	Mon Feb 23 21:38:36 2009 +0100
     1.3 @@ -6,8 +6,8 @@
     1.4  
     1.5  signature DATATYPE_CODEGEN =
     1.6  sig
     1.7 -  val get_eq: theory -> string -> thm list
     1.8 -  val get_case_cert: theory -> string -> thm
     1.9 +  val mk_eq: theory -> string -> thm list
    1.10 +  val mk_case_cert: theory -> string -> thm
    1.11    val setup: theory -> theory
    1.12  end;
    1.13  
    1.14 @@ -323,7 +323,7 @@
    1.15  
    1.16  (* case certificates *)
    1.17  
    1.18 -fun get_case_cert thy tyco =
    1.19 +fun mk_case_cert thy tyco =
    1.20    let
    1.21      val raw_thms =
    1.22        (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
    1.23 @@ -357,10 +357,13 @@
    1.24  fun add_datatype_cases dtco thy =
    1.25    let
    1.26      val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
    1.27 -    val certs = get_case_cert thy dtco;
    1.28 +    val cert = mk_case_cert thy dtco;
    1.29 +    fun add_case_liberal thy = thy
    1.30 +      |> try (Code.add_case cert)
    1.31 +      |> the_default thy;
    1.32    in
    1.33      thy
    1.34 -    |> Code.add_case certs
    1.35 +    |> add_case_liberal
    1.36      |> fold_rev Code.add_default_eqn case_rewrites
    1.37    end;
    1.38  
    1.39 @@ -369,10 +372,10 @@
    1.40  
    1.41  local
    1.42  
    1.43 -val not_sym = thm "HOL.not_sym";
    1.44 -val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
    1.45 -val refl = thm "refl";
    1.46 -val eqTrueI = thm "eqTrueI";
    1.47 +val not_sym = @{thm HOL.not_sym};
    1.48 +val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI];
    1.49 +val refl = @{thm refl};
    1.50 +val eqTrueI = @{thm eqTrueI};
    1.51  
    1.52  fun mk_distinct cos =
    1.53    let
    1.54 @@ -397,7 +400,7 @@
    1.55  
    1.56  in
    1.57  
    1.58 -fun get_eq thy dtco =
    1.59 +fun mk_eq thy dtco =
    1.60    let
    1.61      val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
    1.62      fun mk_triv_inject co =
    1.63 @@ -445,7 +448,7 @@
    1.64        in (thm', lthy') end;
    1.65      fun tac thms = Class.intro_classes_tac []
    1.66        THEN ALLGOALS (ProofContext.fact_tac thms);
    1.67 -    fun get_eq' thy dtco = get_eq thy dtco
    1.68 +    fun mk_eq' thy dtco = mk_eq thy dtco
    1.69        |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq])
    1.70        |> map Simpdata.mk_eq
    1.71        |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}])
    1.72 @@ -460,10 +463,10 @@
    1.73                ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
    1.74            |> Simpdata.mk_eq
    1.75            |> AxClass.unoverload thy;
    1.76 -        fun get_thms () = (eq_refl, false)
    1.77 -          :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco));
    1.78 +        fun mk_thms () = (eq_refl, false)
    1.79 +          :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco));
    1.80        in
    1.81 -        Code.add_eqnl (const, Lazy.lazy get_thms) thy
    1.82 +        Code.add_eqnl (const, Lazy.lazy mk_thms) thy
    1.83        end;
    1.84    in
    1.85      thy
     2.1 --- a/src/Pure/Isar/code.ML	Mon Feb 23 10:07:57 2009 +0100
     2.2 +++ b/src/Pure/Isar/code.ML	Mon Feb 23 21:38:36 2009 +0100
     2.3 @@ -157,7 +157,7 @@
     2.4      (*with explicit history*),
     2.5    dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
     2.6      (*with explicit history*),
     2.7 -  cases: (int * string list) Symtab.table * unit Symtab.table
     2.8 +  cases: (int * (int * string list)) Symtab.table * unit Symtab.table
     2.9  };
    2.10  
    2.11  fun mk_spec ((concluded_history, eqns), (dtyps, cases)) =
    2.12 @@ -574,12 +574,7 @@
    2.13  
    2.14  fun del_eqns c = change_eqns true c (K (false, Lazy.value []));
    2.15  
    2.16 -fun get_case_scheme thy c = case Symtab.lookup ((fst o the_cases o the_exec) thy) c
    2.17 - of SOME (base_case_scheme as (_, case_pats)) =>
    2.18 -      if forall (is_some o get_datatype_of_constr thy) case_pats
    2.19 -      then SOME (1 + Int.max (1, length case_pats), base_case_scheme)
    2.20 -      else NONE
    2.21 -  | NONE => NONE;
    2.22 +fun get_case_scheme thy = Symtab.lookup ((fst o the_cases o the_exec) thy);
    2.23  
    2.24  val is_undefined = Symtab.defined o snd o the_cases o the_exec;
    2.25  
    2.26 @@ -589,11 +584,17 @@
    2.27    let
    2.28      val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs;
    2.29      val (tyco, vs_cos) = Code_Unit.constrset_of_consts thy cs;
    2.30 +    val old_cs = (map fst o snd o get_datatype thy) tyco;
    2.31 +    fun drop_outdated_cases cases = fold Symtab.delete_safe
    2.32 +      (Symtab.fold (fn (c, (_, (_, cos))) =>
    2.33 +        if exists (member (op =) old_cs) cos
    2.34 +          then insert (op =) c else I) cases []) cases;
    2.35    in
    2.36      thy
    2.37      |> map_exec_purge NONE
    2.38          ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos))
    2.39 -        #> map_eqns (fold (Symtab.delete_safe o fst) cs))
    2.40 +        #> map_eqns (fold (Symtab.delete_safe o fst) cs)
    2.41 +        #> (map_cases o apfst) drop_outdated_cases)
    2.42      |> TypeInterpretation.data (tyco, serial ())
    2.43    end;
    2.44  
    2.45 @@ -607,10 +608,12 @@
    2.46  
    2.47  fun add_case thm thy =
    2.48    let
    2.49 -    val entry as (c, _) = Code_Unit.case_cert thm;
    2.50 -  in
    2.51 -    (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update entry) thy
    2.52 -  end;
    2.53 +    val (c, (k, case_pats)) = Code_Unit.case_cert thm;
    2.54 +    val _ = case filter (is_none o get_datatype_of_constr thy) case_pats
    2.55 +     of [] => ()
    2.56 +      | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs));
    2.57 +    val entry = (1 + Int.max (1, length case_pats), (k, case_pats))
    2.58 +  in (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update (c, entry)) thy end;
    2.59  
    2.60  fun add_undefined c thy =
    2.61    (map_exec_purge (SOME [c]) o map_cases o apsnd) (Symtab.update (c, ())) thy;