--- a/src/HOL/Tools/datatype_package.ML Tue Apr 24 15:07:27 2007 +0200
+++ b/src/HOL/Tools/datatype_package.ML Tue Apr 24 15:11:07 2007 +0200
@@ -65,9 +65,15 @@
val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
val get_datatype : theory -> string -> DatatypeAux.datatype_info option
val the_datatype : theory -> string -> DatatypeAux.datatype_info
+ val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option
+ val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option
val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option
val get_datatype_constrs : theory -> string -> (string * typ) list option
val print_datatypes : theory -> unit
+ val make_case : Proof.context -> bool -> string list -> term ->
+ (term * term) list -> term * (term * (int * bool)) list
+ val strip_case: Proof.context -> bool ->
+ term -> (term * (term * term) list) option
val setup: theory -> theory
end;
@@ -84,31 +90,53 @@
structure DatatypesData = TheoryDataFun
(struct
val name = "HOL/datatypes";
- type T = datatype_info Symtab.table;
+ type T =
+ {types: datatype_info Symtab.table,
+ constrs: datatype_info Symtab.table,
+ cases: datatype_info Symtab.table};
- val empty = Symtab.empty;
+ val empty =
+ {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
val copy = I;
val extend = I;
- fun merge _ tabs : T = Symtab.merge (K true) tabs;
+ fun merge _
+ ({types = types1, constrs = constrs1, cases = cases1},
+ {types = types2, constrs = constrs2, cases = cases2}) =
+ {types = Symtab.merge (K true) (types1, types2),
+ constrs = Symtab.merge (K true) (constrs1, constrs2),
+ cases = Symtab.merge (K true) (cases1, cases2)};
- fun print sg tab =
+ fun print sg ({types, ...} : T) =
Pretty.writeln (Pretty.strs ("datatypes:" ::
- map #1 (NameSpace.extern_table (Sign.type_space sg, tab))));
+ map #1 (NameSpace.extern_table (Sign.type_space sg, types))));
end);
-val get_datatypes = DatatypesData.get;
-val put_datatypes = DatatypesData.put;
+val get_datatypes = #types o DatatypesData.get;
+val map_datatypes = DatatypesData.map;
val print_datatypes = DatatypesData.print;
(** theory information about datatypes **)
+fun put_dt_infos (dt_infos : (string * datatype_info) list) =
+ map_datatypes (fn {types, constrs, cases} =>
+ {types = fold Symtab.update dt_infos types,
+ constrs = fold Symtab.update
+ (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
+ (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
+ cases = fold Symtab.update
+ (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
+ cases});
+
val get_datatype = Symtab.lookup o get_datatypes;
fun the_datatype thy name = (case get_datatype thy name of
SOME info => info
| NONE => error ("Unknown datatype " ^ quote name));
+val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
+val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get;
+
fun get_datatype_descr thy dtco =
get_datatype thy dtco
|> Option.map (fn info as { descr, index, ... } =>
@@ -392,127 +420,23 @@
(**** translation rules for case ****)
-fun case_tr ctxt [t, u] =
- let
- val thy = ProofContext.theory_of ctxt;
- fun case_error s name ts = raise TERM ("Error in case expression" ^
- getOpt (Option.map (curry op ^ " for datatype ") name, "") ^ ":\n" ^ s, ts);
- fun dest_case1 (Const ("_case1", _) $ t $ u) =
- (case strip_comb t of
- (Const (s, _), ts) =>
- (case try (unprefix Syntax.constN) s of
- SOME c => (c, ts)
- | NONE => (Sign.intern_const thy s, ts))
- | (Free (s, _), ts) => (Sign.intern_const thy s, ts)
- | _ => case_error "Head is not a constructor" NONE [t, u], u)
- | dest_case1 t = raise TERM ("dest_case1", [t]);
- fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
- | dest_case2 t = [t];
- val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
- val tab = Symtab.dest (get_datatypes thy);
- val (cases', default) = (case split_last cases of
- (cases', (("dummy_pattern", []), t)) => (cases', SOME t)
- | _ => (cases, NONE))
- fun abstr (Free (x, T)) body = Term.absfree (x, T, body)
- | abstr (Const ("_constrain", _) $ Free (x, T) $ tT) body =
- Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
- | abstr (Const ("Pair", _) $ x $ y) body =
- Syntax.const "split" $ (abstr x o abstr y) body
- | abstr t _ = case_error "Illegal pattern" NONE [t];
- in case find_first (fn (_, {descr, index, ...}) =>
- exists (equal cname o fst) (#3 (snd (nth descr index)))) tab of
- NONE => case_error ("Not a datatype constructor: " ^ cname) NONE [u]
- | SOME (tname, {descr, sorts, case_name, index, ...}) =>
- let
- val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
- case_error "Illegal occurrence of '_' dummy pattern" (SOME tname) [u] else ();
- val (_, (_, dts, constrs)) = nth descr index;
- fun find_case (s, dt) cases =
- (case find_first (equal s o fst o fst) cases' of
- NONE => (list_abs (map (rpair dummyT)
- (DatatypeProp.make_tnames (map (typ_of_dtyp descr sorts) dt)),
- case default of
- NONE => (warning ("No clause for constructor " ^ s ^
- " in case expression"); Const ("HOL.undefined", dummyT))
- | SOME t => t), cases)
- | SOME (c as ((_, vs), t)) =>
- if length dt <> length vs then
- case_error ("Wrong number of arguments for constructor " ^ s)
- (SOME tname) vs
- else (fold_rev abstr vs t, remove (op =) c cases))
- val (fs, cases'') = fold_map find_case constrs cases'
- in case (cases'', length constrs = length cases', default) of
- ([], true, SOME _) =>
- case_error "Extra '_' dummy pattern" (SOME tname) [u]
- | (_ :: _, _, _) =>
- let val extra = distinct (op =) (map (fst o fst) cases'')
- in case extra \\ map fst constrs of
- [] => case_error ("More than one clause for constructor(s) " ^
- commas extra) (SOME tname) [u]
- | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
- (SOME tname) [u]
- end
- | _ => list_comb (Syntax.const case_name, fs) $ t
- end
- end
- | case_tr _ ts = raise TERM ("case_tr", ts);
+fun make_case ctxt = DatatypeCase.make_case
+ (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt;
+
+fun strip_case ctxt = DatatypeCase.strip_case
+ (datatype_of_case (ProofContext.theory_of ctxt));
-fun case_tr' constrs ctxt ts =
- if length ts <> length constrs + 1 then raise Match else
- let
- val consts = ProofContext.consts_of ctxt;
-
- val (fs, x) = split_last ts;
- fun strip_abs 0 t = ([], t)
- | strip_abs i (Abs p) =
- let val (x, u) = Syntax.atomic_abs_tr' p
- in apfst (cons x) (strip_abs (i-1) u) end
- | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
- (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
- fun is_dependent i t =
- let val k = length (strip_abs_vars t) - i
- in k < 0 orelse exists (fn j => j >= k)
- (loose_bnos (strip_abs_body t))
- end;
- val cases = map (fn ((cname, dts), t) =>
- (Consts.extern_early consts cname,
- strip_abs (length dts) t, is_dependent (length dts) t))
- (constrs ~~ fs);
- fun count_cases (_, _, true) = I
- | count_cases (cname, (_, body), false) =
- AList.map_default (op = : term * term -> bool)
- (body, []) (cons cname)
- val cases' = sort (int_ord o swap o pairself (length o snd))
- (fold_rev count_cases cases []);
- fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
- list_comb (Syntax.const cname, vs) $ body;
- fun is_undefined (Const ("HOL.undefined", _)) = true
- | is_undefined _ = false;
- in
- Syntax.const "_case_syntax" $ x $
- foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
- (case find_first (is_undefined o fst) cases' of
- SOME (_, cnames) =>
- if length cnames = length constrs then [hd cases]
- else filter_out (fn (_, (_, body), _) => is_undefined body) cases
- | NONE => case cases' of
- [] => cases
- | (default, cnames) :: _ =>
- if length cnames = 1 then cases
- else if length cnames = length constrs then
- [hd cases, ("dummy_pattern", ([], default), false)]
- else
- filter_out (fn (cname, _, _) => member (op =) cnames cname) cases @
- [("dummy_pattern", ([], default), false)]))
- end;
-
-fun make_case_tr' case_names descr = maps
- (fn ((_, (_, _, constrs)), case_name) =>
- map (rpair (case_tr' constrs)) (NameSpace.accesses' case_name))
- (descr ~~ case_names);
+fun add_case_tr' case_names thy =
+ Theory.add_advanced_trfuns ([], [],
+ map (fn case_name =>
+ let val case_name' = Sign.const_syntax_name thy case_name
+ in (case_name', DatatypeCase.case_tr' datatype_of_case case_name')
+ end) case_names, []) thy;
val trfun_setup =
- Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], []);
+ Theory.add_advanced_trfuns ([],
+ [("_case_syntax", DatatypeCase.case_tr datatype_of_constr)],
+ [], []);
(* prepare types *)
@@ -586,16 +510,19 @@
||> Theory.parent_path
|-> (fn [ax] => pair ax)) (tnames ~~ tss);
-fun specify_consts args thy =
+fun gen_specify_consts add args thy =
let
val specs = map (fn (c, T, mx) =>
Const (Sign.full_name thy (Syntax.const_name c mx), T)) args;
in
thy
- |> Sign.add_consts_i args
+ |> add args
|> Theory.add_finals_i false specs
end;
+val specify_consts = gen_specify_consts Sign.add_consts_i;
+val specify_consts_authentic = gen_specify_consts Sign.add_consts_authentic;
+
fun add_datatype_axm flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
case_names_induct case_names_exhausts thy =
let
@@ -660,7 +587,7 @@
(** case combinators **)
- |> specify_consts (map (fn ((name, T), Ts) =>
+ |> specify_consts_authentic (map (fn ((name, T), Ts) =>
(name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts));
val reccomb_names' = map (Sign.full_name thy2') reccomb_names;
@@ -731,11 +658,11 @@
val thy12 =
thy11
- |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), [])
+ |> add_case_tr' case_names'
|> Theory.add_path (space_implode "_" new_type_names)
|> add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs Simplifier.cong_add
- |> put_datatypes (fold Symtab.update dt_infos dt_info)
+ |> put_dt_infos dt_infos
|> add_cases_induct dt_infos induct
|> Theory.parent_path
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
@@ -792,11 +719,11 @@
val thy12 =
thy11
- |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), [])
+ |> add_case_tr' case_names
|> Theory.add_path (space_implode "_" new_type_names)
|> add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs (Simplifier.attrib (op addcongs))
- |> put_datatypes (fold Symtab.update dt_infos dt_info)
+ |> put_dt_infos dt_infos
|> add_cases_induct dt_infos induct
|> Theory.parent_path
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
@@ -897,10 +824,10 @@
val thy11 =
thy10
- |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, [])
+ |> add_case_tr' case_names
|> add_rules simps case_thms size_thms rec_thms inject distinct
weak_case_congs (Simplifier.attrib (op addcongs))
- |> put_datatypes (fold Symtab.update dt_infos dt_info)
+ |> put_dt_infos dt_infos
|> add_cases_induct dt_infos induction'
|> Theory.parent_path
|> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)