lookup for datatype constructors considers type annotations to resolve overloading
authorhaftmann
Thu, 08 Oct 2009 19:33:03 +0200
changeset 32896 99cd75a18b78
parent 32895 6f3cdb4a9e11
child 32897 2b2c56530d25
lookup for datatype constructors considers type annotations to resolve overloading
src/HOL/Tools/Datatype/datatype.ML
src/HOL/Tools/Datatype/datatype_case.ML
src/HOL/Tools/Datatype/datatype_codegen.ML
src/HOL/Tools/Function/fundef_datatype.ML
--- a/src/HOL/Tools/Datatype/datatype.ML	Thu Oct 08 15:59:17 2009 +0200
+++ b/src/HOL/Tools/Datatype/datatype.ML	Thu Oct 08 19:33:03 2009 +0200
@@ -22,7 +22,7 @@
   val the_spec : theory -> string -> (string * sort) list * (string * typ list) list
   val get_constrs : theory -> string -> (string * typ) list option
   val get_all : theory -> info Symtab.table
-  val info_of_constr : theory -> string -> info option
+  val info_of_constr : theory -> string * typ -> info option
   val info_of_case : theory -> string -> info option
   val interpretation : (config -> string list -> theory -> theory) -> theory -> theory
   val distinct_simproc : simproc
@@ -47,7 +47,7 @@
 (
   type T =
     {types: info Symtab.table,
-     constrs: info Symtab.table,
+     constrs: (string * info) list Symtab.table,
      cases: info Symtab.table};
 
   val empty =
@@ -58,7 +58,7 @@
     ({types = types1, constrs = constrs1, cases = cases1},
      {types = types2, constrs = constrs2, cases = cases2}) =
     {types = Symtab.merge (K true) (types1, types2),
-     constrs = Symtab.merge (K true) (constrs1, constrs2),
+     constrs = Symtab.join (K (AList.merge (op =) (K true))) (constrs1, constrs2),
      cases = Symtab.merge (K true) (cases1, cases2)};
 );
 
@@ -68,18 +68,32 @@
       SOME info => info
     | NONE => error ("Unknown datatype " ^ quote name));
 
-val info_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
+fun info_of_constr thy (c, T) =
+  let
+    val tab = Symtab.lookup_list ((#constrs o DatatypesData.get) thy) c;
+    val hint = case strip_type T of (_, Type (tyco, _)) => SOME tyco | _ => NONE;
+    val default = if null tab then NONE
+      else SOME (snd (Library.last_elem tab))
+        (*conservative wrt. overloaded constructors*);
+  in case hint
+   of NONE => default
+    | SOME tyco => case AList.lookup (op =) tab tyco
+       of NONE => default (*permissive*)
+        | SOME info => SOME info
+  end;
+
 val info_of_case = Symtab.lookup o #cases o DatatypesData.get;
 
 fun register (dt_infos : (string * info) list) =
   DatatypesData.map (fn {types, constrs, cases} =>
-    {types = fold Symtab.update dt_infos types,
-     constrs = fold Symtab.default (*conservative wrt. overloaded constructors*)
-       (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
-          (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
-     cases = fold Symtab.update
-       (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
-       cases});
+    {types = types |> fold Symtab.update dt_infos,
+     constrs = constrs |> fold (fn (constr, dtname_info) =>
+         Symtab.map_default (constr, []) (cons dtname_info))
+       (maps (fn (dtname, info as {descr, index, ...}) =>
+          map (rpair (dtname, info) o fst)
+            (#3 (the (AList.lookup op = descr index)))) dt_infos),
+     cases = cases |> fold Symtab.update
+       (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)});
 
 (* complex queries *)
 
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Thu Oct 08 15:59:17 2009 +0200
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Thu Oct 08 19:33:03 2009 +0200
@@ -8,14 +8,14 @@
 signature DATATYPE_CASE =
 sig
   datatype config = Error | Warning | Quiet;
-  val make_case: (string -> DatatypeAux.info option) ->
+  val make_case: (string * typ -> DatatypeAux.info option) ->
     Proof.context -> config -> string list -> term -> (term * term) list ->
     term * (term * (int * bool)) list
   val dest_case: (string -> DatatypeAux.info option) -> bool ->
     string list -> term -> (term * (term * term) list) option
   val strip_case: (string -> DatatypeAux.info option) -> bool ->
     term -> (term * (term * term) list) option
-  val case_tr: bool -> (theory -> string -> DatatypeAux.info option)
+  val case_tr: bool -> (theory -> string * typ -> DatatypeAux.info option)
     -> Proof.context -> term list -> term
   val case_tr': (theory -> string -> DatatypeAux.info option) ->
     string -> Proof.context -> term list -> term
@@ -34,9 +34,9 @@
  * Get information about datatypes
  *---------------------------------------------------------------------------*)
 
-fun ty_info (tab : string -> DatatypeAux.info option) s =
-  case tab s of
-    SOME {descr, case_name, index, sorts, ...} =>
+fun ty_info tab sT =
+  case tab sT of
+    SOME ({descr, case_name, index, sorts, ...} : DatatypeAux.info) =>
       let
         val (_, (tname, dts, constrs)) = nth descr index;
         val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
@@ -216,7 +216,7 @@
                     pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
                   val (pref_patl, tm) = mk {path = rstp, rows = rows'}
                 in (map v_to_pats pref_patl, tm) end
-            | SOME (Const (cname, cT), i) => (case ty_info tab cname of
+            | SOME (Const (cname, cT), i) => (case ty_info tab (cname, cT) of
                 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
               | SOME {case_name, constructors} =>
                 let
--- a/src/HOL/Tools/Datatype/datatype_codegen.ML	Thu Oct 08 15:59:17 2009 +0200
+++ b/src/HOL/Tools/Datatype/datatype_codegen.ML	Thu Oct 08 19:33:03 2009 +0200
@@ -281,7 +281,7 @@
           if is_some (get_assoc_code thy (s, T)) then NONE else
           SOME (pretty_case thy defs dep module brack
             (#3 (the (AList.lookup op = descr index))) c ts gr )
-      | NONE => case (Datatype.info_of_constr thy s, strip_type T) of
+      | NONE => case (Datatype.info_of_constr thy (s, T), strip_type T) of
         (SOME {index, descr, ...}, (_, U as Type (tyname, _))) =>
           if is_some (get_assoc_code thy (s, T)) then NONE else
           let
--- a/src/HOL/Tools/Function/fundef_datatype.ML	Thu Oct 08 15:59:17 2009 +0200
+++ b/src/HOL/Tools/Function/fundef_datatype.ML	Thu Oct 08 19:33:03 2009 +0200
@@ -40,7 +40,7 @@
           let
             val (hd, args) = strip_comb t
           in
-            (((case Datatype.info_of_constr thy (fst (dest_Const hd)) of
+            (((case Datatype.info_of_constr thy (dest_Const hd) of
                  SOME _ => ()
                | NONE => err "Non-constructor pattern")
               handle TERM ("dest_Const", _) => err "Non-constructor patterns");