# HG changeset patch # User wenzelm # Date 1324029775 -3600 # Node ID b619242b0439e8c81ff6dd7d4a87644f281ace78 # Parent 65cef02981583b7a7e8a8f89d43cb1ed17db7193 tuned; diff -r 65cef0298158 -r b619242b0439 src/HOL/Tools/Datatype/datatype_case.ML --- a/src/HOL/Tools/Datatype/datatype_case.ML Fri Dec 16 10:52:35 2011 +0100 +++ b/src/HOL/Tools/Datatype/datatype_case.ML Fri Dec 16 11:02:55 2011 +0100 @@ -130,7 +130,7 @@ names = names, constraints = cnstrts, group = in_group'} :: part cs not_in_group - end + end; in part constructors rows end; fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats) @@ -143,7 +143,6 @@ let val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt); - val name = singleton (Name.variant_list used) "a"; fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1) | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) = if is_Free p then @@ -153,7 +152,10 @@ let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c) in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end in map expnd constructors end - else [row] + else [row]; + + val name = singleton (Name.variant_list used) "a"; + fun mk _ [] = raise CASE_ERROR ("no rows", ~1) | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *) | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row] @@ -277,19 +279,22 @@ val (u', used'') = prep_pat u used'; in (t' $ u', used'') end | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); + fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) = let val (l', cnstrts) = strip_constraints l in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end | dest_case1 t = case_error "dest_case1"; + fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u | dest_case2 t = [t]; + val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); - val case_tm = - make_case_untyped ctxt - (if err then Error else Warning) [] - (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT) - (flat cnstrts) t) cases; - in case_tm end + in + make_case_untyped ctxt + (if err then Error else Warning) [] + (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT) + (flat cnstrts) t) cases + end | case_tr _ _ _ = case_error "case_tr"; val trfun_setup = diff -r 65cef0298158 -r b619242b0439 src/HOL/Tools/Datatype/primrec.ML --- a/src/HOL/Tools/Datatype/primrec.ML Fri Dec 16 10:52:35 2011 +0100 +++ b/src/HOL/Tools/Datatype/primrec.ML Fri Dec 16 11:02:55 2011 +0100 @@ -206,11 +206,11 @@ (* find datatypes which contain all datatypes in tnames' *) -fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = [] +fun find_dts _ _ [] = [] | find_dts dt_info tnames' (tname :: tnames) = (case Symtab.lookup dt_info tname of NONE => primrec_error (quote tname ^ " is not a datatype") - | SOME dt => + | SOME (dt : Datatype_Aux.info) => if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then (tname, dt) :: (find_dts dt_info tnames' tnames) else find_dts dt_info tnames' tnames); @@ -218,12 +218,12 @@ (* distill primitive definition(s) from primrec specification *) -fun distill lthy fixes eqs = +fun distill ctxt fixes eqs = let - val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v + val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs []; val tnames = distinct (op =) (map (#1 o snd) eqns); - val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames; + val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames; val main_fns = map (fn (tname, {index, ...}) => (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts; val {descr, rec_names, rec_rewrites, ...} = @@ -232,7 +232,7 @@ else snd (hd dts); val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); - val defs = map (make_def lthy fixes fs) raw_defs; + val defs = map (make_def ctxt fixes fs) raw_defs; val names = map snd fnames; val names_eqns = map fst eqns; val _ = @@ -241,17 +241,17 @@ "\nare not mutually recursive"); val rec_rewrites' = map mk_meta_eq rec_rewrites; val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs); - fun prove lthy defs = + fun prove ctxt defs = let - val frees = fold (Variable.add_free_names lthy) eqs []; + val frees = fold (Variable.add_free_names ctxt) eqs []; val rewrites = rec_rewrites' @ map (snd o snd) defs; fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; - in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end; + in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end; in ((prefix, (fs, defs)), prove) end handle PrimrecError (msg, some_eqn) => error ("Primrec definition error:\n" ^ msg ^ (case some_eqn of - SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) + SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn) | NONE => "")); @@ -259,7 +259,7 @@ fun add_primrec_simple fixes ts lthy = let - val ((prefix, (fs, defs)), prove) = distill lthy fixes ts; + val ((prefix, (_, defs)), prove) = distill lthy fixes ts; in lthy |> fold_map Local_Theory.define defs