# HG changeset patch # User berghofe # Date 1177420267 -7200 # Node ID 2fc921376a860d461fccc9855b5bd73a3a6cfb92 # Parent 292dbccd875598e9da5fd1ccdbcf1f9f6f779b25 - Moved parse / print translations for case to datatype_case.ML - Added new functions datatype_of_constr and datatype_of_case to retrieve datatype corresponding to name of constructor or case combinator. diff -r 292dbccd8755 -r 2fc921376a86 src/HOL/Tools/datatype_package.ML --- a/src/HOL/Tools/datatype_package.ML Tue Apr 24 15:07:27 2007 +0200 +++ b/src/HOL/Tools/datatype_package.ML Tue Apr 24 15:11:07 2007 +0200 @@ -65,9 +65,15 @@ val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table val get_datatype : theory -> string -> DatatypeAux.datatype_info option val the_datatype : theory -> string -> DatatypeAux.datatype_info + val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option + val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option val get_datatype_constrs : theory -> string -> (string * typ) list option val print_datatypes : theory -> unit + val make_case : Proof.context -> bool -> string list -> term -> + (term * term) list -> term * (term * (int * bool)) list + val strip_case: Proof.context -> bool -> + term -> (term * (term * term) list) option val setup: theory -> theory end; @@ -84,31 +90,53 @@ structure DatatypesData = TheoryDataFun (struct val name = "HOL/datatypes"; - type T = datatype_info Symtab.table; + type T = + {types: datatype_info Symtab.table, + constrs: datatype_info Symtab.table, + cases: datatype_info Symtab.table}; - val empty = Symtab.empty; + val empty = + {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty}; val copy = I; val extend = I; - fun merge _ tabs : T = Symtab.merge (K true) tabs; + fun merge _ + ({types = types1, constrs = constrs1, cases = cases1}, + {types = types2, constrs = constrs2, cases = cases2}) = + {types = Symtab.merge (K true) (types1, types2), + constrs = Symtab.merge (K true) (constrs1, constrs2), + cases = Symtab.merge (K true) (cases1, cases2)}; - fun print sg tab = + fun print sg ({types, ...} : T) = Pretty.writeln (Pretty.strs ("datatypes:" :: - map #1 (NameSpace.extern_table (Sign.type_space sg, tab)))); + map #1 (NameSpace.extern_table (Sign.type_space sg, types)))); end); -val get_datatypes = DatatypesData.get; -val put_datatypes = DatatypesData.put; +val get_datatypes = #types o DatatypesData.get; +val map_datatypes = DatatypesData.map; val print_datatypes = DatatypesData.print; (** theory information about datatypes **) +fun put_dt_infos (dt_infos : (string * datatype_info) list) = + map_datatypes (fn {types, constrs, cases} => + {types = fold Symtab.update dt_infos types, + constrs = fold Symtab.update + (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst) + (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs, + cases = fold Symtab.update + (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos) + cases}); + val get_datatype = Symtab.lookup o get_datatypes; fun the_datatype thy name = (case get_datatype thy name of SOME info => info | NONE => error ("Unknown datatype " ^ quote name)); +val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get; +val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get; + fun get_datatype_descr thy dtco = get_datatype thy dtco |> Option.map (fn info as { descr, index, ... } => @@ -392,127 +420,23 @@ (**** translation rules for case ****) -fun case_tr ctxt [t, u] = - let - val thy = ProofContext.theory_of ctxt; - fun case_error s name ts = raise TERM ("Error in case expression" ^ - getOpt (Option.map (curry op ^ " for datatype ") name, "") ^ ":\n" ^ s, ts); - fun dest_case1 (Const ("_case1", _) $ t $ u) = - (case strip_comb t of - (Const (s, _), ts) => - (case try (unprefix Syntax.constN) s of - SOME c => (c, ts) - | NONE => (Sign.intern_const thy s, ts)) - | (Free (s, _), ts) => (Sign.intern_const thy s, ts) - | _ => case_error "Head is not a constructor" NONE [t, u], u) - | dest_case1 t = raise TERM ("dest_case1", [t]); - fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u - | dest_case2 t = [t]; - val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u); - val tab = Symtab.dest (get_datatypes thy); - val (cases', default) = (case split_last cases of - (cases', (("dummy_pattern", []), t)) => (cases', SOME t) - | _ => (cases, NONE)) - fun abstr (Free (x, T)) body = Term.absfree (x, T, body) - | abstr (Const ("_constrain", _) $ Free (x, T) $ tT) body = - Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT - | abstr (Const ("Pair", _) $ x $ y) body = - Syntax.const "split" $ (abstr x o abstr y) body - | abstr t _ = case_error "Illegal pattern" NONE [t]; - in case find_first (fn (_, {descr, index, ...}) => - exists (equal cname o fst) (#3 (snd (nth descr index)))) tab of - NONE => case_error ("Not a datatype constructor: " ^ cname) NONE [u] - | SOME (tname, {descr, sorts, case_name, index, ...}) => - let - val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then - case_error "Illegal occurrence of '_' dummy pattern" (SOME tname) [u] else (); - val (_, (_, dts, constrs)) = nth descr index; - fun find_case (s, dt) cases = - (case find_first (equal s o fst o fst) cases' of - NONE => (list_abs (map (rpair dummyT) - (DatatypeProp.make_tnames (map (typ_of_dtyp descr sorts) dt)), - case default of - NONE => (warning ("No clause for constructor " ^ s ^ - " in case expression"); Const ("HOL.undefined", dummyT)) - | SOME t => t), cases) - | SOME (c as ((_, vs), t)) => - if length dt <> length vs then - case_error ("Wrong number of arguments for constructor " ^ s) - (SOME tname) vs - else (fold_rev abstr vs t, remove (op =) c cases)) - val (fs, cases'') = fold_map find_case constrs cases' - in case (cases'', length constrs = length cases', default) of - ([], true, SOME _) => - case_error "Extra '_' dummy pattern" (SOME tname) [u] - | (_ :: _, _, _) => - let val extra = distinct (op =) (map (fst o fst) cases'') - in case extra \\ map fst constrs of - [] => case_error ("More than one clause for constructor(s) " ^ - commas extra) (SOME tname) [u] - | extra' => case_error ("Illegal constructor(s): " ^ commas extra') - (SOME tname) [u] - end - | _ => list_comb (Syntax.const case_name, fs) $ t - end - end - | case_tr _ ts = raise TERM ("case_tr", ts); +fun make_case ctxt = DatatypeCase.make_case + (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt; + +fun strip_case ctxt = DatatypeCase.strip_case + (datatype_of_case (ProofContext.theory_of ctxt)); -fun case_tr' constrs ctxt ts = - if length ts <> length constrs + 1 then raise Match else - let - val consts = ProofContext.consts_of ctxt; - - val (fs, x) = split_last ts; - fun strip_abs 0 t = ([], t) - | strip_abs i (Abs p) = - let val (x, u) = Syntax.atomic_abs_tr' p - in apfst (cons x) (strip_abs (i-1) u) end - | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of - (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u)); - fun is_dependent i t = - let val k = length (strip_abs_vars t) - i - in k < 0 orelse exists (fn j => j >= k) - (loose_bnos (strip_abs_body t)) - end; - val cases = map (fn ((cname, dts), t) => - (Consts.extern_early consts cname, - strip_abs (length dts) t, is_dependent (length dts) t)) - (constrs ~~ fs); - fun count_cases (_, _, true) = I - | count_cases (cname, (_, body), false) = - AList.map_default (op = : term * term -> bool) - (body, []) (cons cname) - val cases' = sort (int_ord o swap o pairself (length o snd)) - (fold_rev count_cases cases []); - fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $ - list_comb (Syntax.const cname, vs) $ body; - fun is_undefined (Const ("HOL.undefined", _)) = true - | is_undefined _ = false; - in - Syntax.const "_case_syntax" $ x $ - foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1 - (case find_first (is_undefined o fst) cases' of - SOME (_, cnames) => - if length cnames = length constrs then [hd cases] - else filter_out (fn (_, (_, body), _) => is_undefined body) cases - | NONE => case cases' of - [] => cases - | (default, cnames) :: _ => - if length cnames = 1 then cases - else if length cnames = length constrs then - [hd cases, ("dummy_pattern", ([], default), false)] - else - filter_out (fn (cname, _, _) => member (op =) cnames cname) cases @ - [("dummy_pattern", ([], default), false)])) - end; - -fun make_case_tr' case_names descr = maps - (fn ((_, (_, _, constrs)), case_name) => - map (rpair (case_tr' constrs)) (NameSpace.accesses' case_name)) - (descr ~~ case_names); +fun add_case_tr' case_names thy = + Theory.add_advanced_trfuns ([], [], + map (fn case_name => + let val case_name' = Sign.const_syntax_name thy case_name + in (case_name', DatatypeCase.case_tr' datatype_of_case case_name') + end) case_names, []) thy; val trfun_setup = - Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], []); + Theory.add_advanced_trfuns ([], + [("_case_syntax", DatatypeCase.case_tr datatype_of_constr)], + [], []); (* prepare types *) @@ -586,16 +510,19 @@ ||> Theory.parent_path |-> (fn [ax] => pair ax)) (tnames ~~ tss); -fun specify_consts args thy = +fun gen_specify_consts add args thy = let val specs = map (fn (c, T, mx) => Const (Sign.full_name thy (Syntax.const_name c mx), T)) args; in thy - |> Sign.add_consts_i args + |> add args |> Theory.add_finals_i false specs end; +val specify_consts = gen_specify_consts Sign.add_consts_i; +val specify_consts_authentic = gen_specify_consts Sign.add_consts_authentic; + fun add_datatype_axm flat_names new_type_names descr sorts types_syntax constr_syntax dt_info case_names_induct case_names_exhausts thy = let @@ -660,7 +587,7 @@ (** case combinators **) - |> specify_consts (map (fn ((name, T), Ts) => + |> specify_consts_authentic (map (fn ((name, T), Ts) => (name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts)); val reccomb_names' = map (Sign.full_name thy2') reccomb_names; @@ -731,11 +658,11 @@ val thy12 = thy11 - |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), []) + |> add_case_tr' case_names' |> Theory.add_path (space_implode "_" new_type_names) |> add_rules simps case_thms size_thms rec_thms inject distinct weak_case_congs Simplifier.cong_add - |> put_datatypes (fold Symtab.update dt_infos dt_info) + |> put_dt_infos dt_infos |> add_cases_induct dt_infos induct |> Theory.parent_path |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) @@ -792,11 +719,11 @@ val thy12 = thy11 - |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), []) + |> add_case_tr' case_names |> Theory.add_path (space_implode "_" new_type_names) |> add_rules simps case_thms size_thms rec_thms inject distinct weak_case_congs (Simplifier.attrib (op addcongs)) - |> put_datatypes (fold Symtab.update dt_infos dt_info) + |> put_dt_infos dt_infos |> add_cases_induct dt_infos induct |> Theory.parent_path |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd @@ -897,10 +824,10 @@ val thy11 = thy10 - |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, []) + |> add_case_tr' case_names |> add_rules simps case_thms size_thms rec_thms inject distinct weak_case_congs (Simplifier.attrib (op addcongs)) - |> put_datatypes (fold Symtab.update dt_infos dt_info) + |> put_dt_infos dt_infos |> add_cases_induct dt_infos induction' |> Theory.parent_path |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)