added new primrec package
authorhaftmann
Thu Dec 06 15:10:09 2007 +0100 (2007-12-06)
changeset 25557ea6b11021e79
parent 25556 8d3b7c27049b
child 25558 5c317e8f5673
added new primrec package
NEWS
src/HOL/Inductive.thy
src/HOL/IsaMakefile
src/HOL/Library/Eval.thy
src/HOL/Nominal/nominal_atoms.ML
src/HOL/Nominal/nominal_package.ML
src/HOL/Nominal/nominal_primrec.ML
src/HOL/Tools/old_primrec_package.ML
src/HOL/Tools/primrec_package.ML
src/HOLCF/Tools/fixrec_package.ML
     1.1 --- a/NEWS	Thu Dec 06 12:58:01 2007 +0100
     1.2 +++ b/NEWS	Thu Dec 06 15:10:09 2007 +0100
     1.3 @@ -20,6 +20,11 @@
     1.4  
     1.5  *** HOL ***
     1.6  
     1.7 +* New primrec package.  Specification syntax conforms in style to
     1.8 +  definition/function/....  The "primrec" command distinguished old-style
     1.9 +  and new-style specifications by syntax.  The old primrec package is
    1.10 +  now named OldPrimrecPackage.
    1.11 +
    1.12  * Library/Multiset: {#a, b, c#} abbreviates {#a#} + {#b#} + {#c#}.
    1.13  
    1.14  * Constants "card", "internal_split", "option_map" now with authentic
     2.1 --- a/src/HOL/Inductive.thy	Thu Dec 06 12:58:01 2007 +0100
     2.2 +++ b/src/HOL/Inductive.thy	Thu Dec 06 15:10:09 2007 +0100
     2.3 @@ -17,6 +17,7 @@
     2.4    ("Tools/datatype_abs_proofs.ML")
     2.5    ("Tools/datatype_case.ML")
     2.6    ("Tools/datatype_package.ML")
     2.7 +  ("Tools/old_primrec_package.ML")
     2.8    ("Tools/primrec_package.ML")
     2.9    ("Tools/datatype_codegen.ML")
    2.10  begin
    2.11 @@ -328,6 +329,7 @@
    2.12  use "Tools/datatype_case.ML"
    2.13  use "Tools/datatype_package.ML"
    2.14  setup DatatypePackage.setup
    2.15 +use "Tools/old_primrec_package.ML"
    2.16  use "Tools/primrec_package.ML"
    2.17  
    2.18  use "Tools/datatype_codegen.ML"
     3.1 --- a/src/HOL/IsaMakefile	Thu Dec 06 12:58:01 2007 +0100
     3.2 +++ b/src/HOL/IsaMakefile	Thu Dec 06 15:10:09 2007 +0100
     3.3 @@ -132,6 +132,7 @@
     3.4    Tools/inductive_package.ML Tools/inductive_realizer.ML		\
     3.5    Tools/inductive_set_package.ML Tools/lin_arith.ML Tools/meson.ML	\
     3.6    Tools/metis_tools.ML Tools/numeral.ML Tools/numeral_syntax.ML		\
     3.7 +  Tools/old_primrec_package.ML \
     3.8    Tools/polyhash.ML Tools/primrec_package.ML Tools/prop_logic.ML 	\
     3.9    Tools/recdef_package.ML Tools/recfun_codegen.ML			\
    3.10    Tools/record_package.ML Tools/refute.ML Tools/refute_isar.ML		\
     4.1 --- a/src/HOL/Library/Eval.thy	Thu Dec 06 12:58:01 2007 +0100
     4.2 +++ b/src/HOL/Library/Eval.thy	Thu Dec 06 15:10:09 2007 +0100
     4.3 @@ -151,7 +151,7 @@
     4.4        thy
     4.5        |> Instance.instantiate (tycos, sorts, @{sort term_of})
     4.6             (pair ()) ((K o K) (Class.intro_classes_tac []))
     4.7 -      |> PrimrecPackage.gen_primrec thy_note thy_def "" defs
     4.8 +      |> OldPrimrecPackage.gen_primrec thy_note thy_def "" defs
     4.9        |> snd
    4.10      | NONE => thy;
    4.11    in DatatypePackage.interpretation interpretator end
     5.1 --- a/src/HOL/Nominal/nominal_atoms.ML	Thu Dec 06 12:58:01 2007 +0100
     5.2 +++ b/src/HOL/Nominal/nominal_atoms.ML	Thu Dec 06 15:10:09 2007 +0100
     5.3 @@ -166,7 +166,7 @@
     5.4          thy |> Sign.add_consts_i [("swap_" ^ ak_name, swapT, NoSyn)] 
     5.5              |> PureThy.add_defs_unchecked_i true [((name, def2),[])]
     5.6              |> snd
     5.7 -            |> PrimrecPackage.add_primrec_unchecked_i "" [(("", def1),[])]
     5.8 +            |> OldPrimrecPackage.add_primrec_unchecked_i "" [(("", def1),[])]
     5.9        end) ak_names_types thy2;
    5.10      
    5.11      (* declares a permutation function for every atom-kind acting  *)
    5.12 @@ -194,7 +194,7 @@
    5.13                      Const (swap_name, swapT) $ x $ (Const (qu_prm_name, prmT) $ xs $ a)));
    5.14        in
    5.15          thy |> Sign.add_consts_i [(prm_name, mk_permT T --> T --> T, NoSyn)] 
    5.16 -            |> PrimrecPackage.add_primrec_unchecked_i "" [(("", def1), []),(("", def2), [])]
    5.17 +            |> OldPrimrecPackage.add_primrec_unchecked_i "" [(("", def1), []),(("", def2), [])]
    5.18        end) ak_names_types thy3;
    5.19      
    5.20      (* defines permutation functions for all combinations of atom-kinds; *)
     6.1 --- a/src/HOL/Nominal/nominal_package.ML	Thu Dec 06 12:58:01 2007 +0100
     6.2 +++ b/src/HOL/Nominal/nominal_package.ML	Thu Dec 06 15:10:09 2007 +0100
     6.3 @@ -332,7 +332,7 @@
     6.4      val (perm_simps, thy2) = thy1 |>
     6.5        Sign.add_consts_i (map (fn (s, T) => (Sign.base_name s, T, NoSyn))
     6.6          (List.drop (perm_names_types, length new_type_names))) |>
     6.7 -      PrimrecPackage.add_primrec_unchecked_i "" perm_eqs;
     6.8 +      OldPrimrecPackage.add_primrec_unchecked_i "" perm_eqs;
     6.9  
    6.10      (**** prove that permutation functions introduced by unfolding are ****)
    6.11      (**** equivalent to already existing permutation functions         ****)
     7.1 --- a/src/HOL/Nominal/nominal_primrec.ML	Thu Dec 06 12:58:01 2007 +0100
     7.2 +++ b/src/HOL/Nominal/nominal_primrec.ML	Thu Dec 06 15:10:09 2007 +0100
     7.3 @@ -387,7 +387,7 @@
     7.4      val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq
     7.5        (HOLogic.dest_Trueprop (Logic.strip_imp_concl eq))))
     7.6        handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
     7.7 -    val (_, eqn_ts') = PrimrecPackage.unify_consts thy rec_ts eqn_ts
     7.8 +    val (_, eqn_ts') = OldPrimrecPackage.unify_consts thy rec_ts eqn_ts
     7.9    in
    7.10      gen_primrec_i note def alt_name
    7.11        (Option.map (map (Syntax.read_term_global thy)) invs)
     8.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     8.2 +++ b/src/HOL/Tools/old_primrec_package.ML	Thu Dec 06 15:10:09 2007 +0100
     8.3 @@ -0,0 +1,362 @@
     8.4 +(*  Title:      HOL/Tools/primrec_package.ML
     8.5 +    ID:         $Id$
     8.6 +    Author:     Stefan Berghofer, TU Muenchen and Norbert Voelker, FernUni Hagen
     8.7 +
     8.8 +Package for defining functions on datatypes by primitive recursion.
     8.9 +*)
    8.10 +
    8.11 +signature OLD_PRIMREC_PACKAGE =
    8.12 +sig
    8.13 +  val quiet_mode: bool ref
    8.14 +  val unify_consts: theory -> term list -> term list -> term list * term list
    8.15 +  val add_primrec: string -> ((bstring * string) * Attrib.src list) list
    8.16 +    -> theory -> thm list * theory
    8.17 +  val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
    8.18 +    -> theory -> thm list * theory
    8.19 +  val add_primrec_i: string -> ((bstring * term) * attribute list) list
    8.20 +    -> theory -> thm list * theory
    8.21 +  val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
    8.22 +    -> theory -> thm list * theory
    8.23 +  (* FIXME !? *)
    8.24 +  val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory)
    8.25 +    -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory)
    8.26 +    -> string -> ((bstring * attribute list) * term) list
    8.27 +    -> theory -> thm list * theory;
    8.28 +end;
    8.29 +
    8.30 +structure OldPrimrecPackage : OLD_PRIMREC_PACKAGE =
    8.31 +struct
    8.32 +
    8.33 +open DatatypeAux;
    8.34 +
    8.35 +exception RecError of string;
    8.36 +
    8.37 +fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    8.38 +fun primrec_eq_err thy s eq =
    8.39 +  primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq));
    8.40 +
    8.41 +
    8.42 +(* messages *)
    8.43 +
    8.44 +val quiet_mode = ref false;
    8.45 +fun message s = if ! quiet_mode then () else writeln s;
    8.46 +
    8.47 +
    8.48 +(*the following code ensures that each recursive set always has the
    8.49 +  same type in all introduction rules*)
    8.50 +fun unify_consts thy cs intr_ts =
    8.51 +  (let
    8.52 +    val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
    8.53 +    fun varify (t, (i, ts)) =
    8.54 +      let val t' = map_types (Logic.incr_tvar (i + 1)) (snd (Type.varify [] t))
    8.55 +      in (maxidx_of_term t', t'::ts) end;
    8.56 +    val (i, cs') = foldr varify (~1, []) cs;
    8.57 +    val (i', intr_ts') = foldr varify (i, []) intr_ts;
    8.58 +    val rec_consts = fold add_term_consts_2 cs' [];
    8.59 +    val intr_consts = fold add_term_consts_2 intr_ts' [];
    8.60 +    fun unify (cname, cT) =
    8.61 +      let val consts = map snd (filter (fn (c, _) => c = cname) intr_consts)
    8.62 +      in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
    8.63 +    val (env, _) = fold unify rec_consts (Vartab.empty, i');
    8.64 +    val subst = Type.freeze o map_types (Envir.norm_type env)
    8.65 +
    8.66 +  in (map subst cs', map subst intr_ts')
    8.67 +  end) handle Type.TUNIFY =>
    8.68 +    (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
    8.69 +
    8.70 +
    8.71 +(* preprocessing of equations *)
    8.72 +
    8.73 +fun process_eqn thy eq rec_fns =
    8.74 +  let
    8.75 +    val (lhs, rhs) =
    8.76 +      if null (term_vars eq) then
    8.77 +        HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    8.78 +        handle TERM _ => raise RecError "not a proper equation"
    8.79 +      else raise RecError "illegal schematic variable(s)";
    8.80 +
    8.81 +    val (recfun, args) = strip_comb lhs;
    8.82 +    val fnameT = dest_Const recfun handle TERM _ =>
    8.83 +      raise RecError "function is not declared as constant in theory";
    8.84 +
    8.85 +    val (ls', rest)  = take_prefix is_Free args;
    8.86 +    val (middle, rs') = take_suffix is_Free rest;
    8.87 +    val rpos = length ls';
    8.88 +
    8.89 +    val (constr, cargs') = if null middle then raise RecError "constructor missing"
    8.90 +      else strip_comb (hd middle);
    8.91 +    val (cname, T) = dest_Const constr
    8.92 +      handle TERM _ => raise RecError "ill-formed constructor";
    8.93 +    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    8.94 +      raise RecError "cannot determine datatype associated with function"
    8.95 +
    8.96 +    val (ls, cargs, rs) =
    8.97 +      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    8.98 +      handle TERM _ => raise RecError "illegal argument in pattern";
    8.99 +    val lfrees = ls @ rs @ cargs;
   8.100 +
   8.101 +    fun check_vars _ [] = ()
   8.102 +      | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
   8.103 +  in
   8.104 +    if length middle > 1 then
   8.105 +      raise RecError "more than one non-variable in pattern"
   8.106 +    else
   8.107 +     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
   8.108 +      check_vars "extra variables on rhs: "
   8.109 +        (map dest_Free (term_frees rhs) \\ lfrees);
   8.110 +      case AList.lookup (op =) rec_fns fnameT of
   8.111 +        NONE =>
   8.112 +          (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
   8.113 +      | SOME (_, rpos', eqns) =>
   8.114 +          if AList.defined (op =) eqns cname then
   8.115 +            raise RecError "constructor already occurred as pattern"
   8.116 +          else if rpos <> rpos' then
   8.117 +            raise RecError "position of recursive argument inconsistent"
   8.118 +          else
   8.119 +            AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns))
   8.120 +              rec_fns)
   8.121 +  end
   8.122 +  handle RecError s => primrec_eq_err thy s eq;
   8.123 +
   8.124 +fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
   8.125 +  let
   8.126 +    val (_, (tname, _, constrs)) = List.nth (descr, i);
   8.127 +
   8.128 +    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
   8.129 +
   8.130 +    fun subst [] t fs = (t, fs)
   8.131 +      | subst subs (Abs (a, T, t)) fs =
   8.132 +          fs
   8.133 +          |> subst subs t
   8.134 +          |-> (fn t' => pair (Abs (a, T, t')))
   8.135 +      | subst subs (t as (_ $ _)) fs =
   8.136 +          let
   8.137 +            val (f, ts) = strip_comb t;
   8.138 +          in
   8.139 +            if is_Const f andalso dest_Const f mem map fst rec_eqns then
   8.140 +              let
   8.141 +                val fnameT' as (fname', _) = dest_Const f;
   8.142 +                val (_, rpos, _) = the (AList.lookup (op =) rec_eqns fnameT');
   8.143 +                val ls = Library.take (rpos, ts);
   8.144 +                val rest = Library.drop (rpos, ts);
   8.145 +                val (x', rs) = (hd rest, tl rest)
   8.146 +                  handle Empty => raise RecError ("not enough arguments\
   8.147 +                   \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   8.148 +                val (x, xs) = strip_comb x'
   8.149 +              in case AList.lookup (op =) subs x
   8.150 +               of NONE =>
   8.151 +                    fs
   8.152 +                    |> fold_map (subst subs) ts
   8.153 +                    |-> (fn ts' => pair (list_comb (f, ts')))
   8.154 +                | SOME (i', y) =>
   8.155 +                    fs
   8.156 +                    |> fold_map (subst subs) (xs @ ls @ rs)
   8.157 +                    ||> process_fun thy descr rec_eqns (i', fnameT')
   8.158 +                    |-> (fn ts' => pair (list_comb (y, ts')))
   8.159 +              end
   8.160 +            else
   8.161 +              fs
   8.162 +              |> fold_map (subst subs) (f :: ts)
   8.163 +              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
   8.164 +          end
   8.165 +      | subst _ t fs = (t, fs);
   8.166 +
   8.167 +    (* translate rec equations into function arguments suitable for rec comb *)
   8.168 +
   8.169 +    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
   8.170 +      (case AList.lookup (op =) eqns cname of
   8.171 +          NONE => (warning ("No equation for constructor " ^ quote cname ^
   8.172 +            "\nin definition of function " ^ quote fname);
   8.173 +              (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns))
   8.174 +        | SOME (ls, cargs', rs, rhs, eq) =>
   8.175 +            let
   8.176 +              val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   8.177 +              val rargs = map fst recs;
   8.178 +              val subs = map (rpair dummyT o fst)
   8.179 +                (rev (rename_wrt_term rhs rargs));
   8.180 +              val (rhs', (fnameTs'', fnss'')) =
   8.181 +                  (subst (map (fn ((x, y), z) =>
   8.182 +                               (Free x, (body_index y, Free z)))
   8.183 +                          (recs ~~ subs)) rhs (fnameTs', fnss'))
   8.184 +                  handle RecError s => primrec_eq_err thy s eq
   8.185 +            in (fnameTs'', fnss'',
   8.186 +                (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   8.187 +            end)
   8.188 +
   8.189 +  in (case AList.lookup (op =) fnameTs i of
   8.190 +      NONE =>
   8.191 +        if exists (equal fnameT o snd) fnameTs then
   8.192 +          raise RecError ("inconsistent functions for datatype " ^ quote tname)
   8.193 +        else
   8.194 +          let
   8.195 +            val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
   8.196 +            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
   8.197 +              ((i, fnameT)::fnameTs, fnss, [])
   8.198 +          in
   8.199 +            (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   8.200 +          end
   8.201 +    | SOME fnameT' =>
   8.202 +        if fnameT = fnameT' then (fnameTs, fnss)
   8.203 +        else raise RecError ("inconsistent functions for datatype " ^ quote tname))
   8.204 +  end;
   8.205 +
   8.206 +
   8.207 +(* prepare functions needed for definitions *)
   8.208 +
   8.209 +fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   8.210 +  case AList.lookup (op =) fns i of
   8.211 +     NONE =>
   8.212 +       let
   8.213 +         val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
   8.214 +           replicate ((length cargs) + (length (List.filter is_rec_type cargs)))
   8.215 +             dummyT ---> HOLogic.unitT)) constrs;
   8.216 +         val _ = warning ("No function definition for datatype " ^ quote tname)
   8.217 +       in
   8.218 +         (dummy_fns @ fs, defs)
   8.219 +       end
   8.220 +   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
   8.221 +
   8.222 +
   8.223 +(* make definition *)
   8.224 +
   8.225 +fun make_def thy fs (fname, ls, rec_name, tname) =
   8.226 +  let
   8.227 +    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   8.228 +                    ((map snd ls) @ [dummyT])
   8.229 +                    (list_comb (Const (rec_name, dummyT),
   8.230 +                                fs @ map Bound (0 ::(length ls downto 1))))
   8.231 +    val def_name = Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def";
   8.232 +    val def_prop =
   8.233 +      singleton (Syntax.check_terms (ProofContext.init thy))
   8.234 +        (Logic.mk_equals (Const (fname, dummyT), rhs));
   8.235 +  in (def_name, def_prop) end;
   8.236 +
   8.237 +
   8.238 +(* find datatypes which contain all datatypes in tnames' *)
   8.239 +
   8.240 +fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   8.241 +  | find_dts dt_info tnames' (tname::tnames) =
   8.242 +      (case Symtab.lookup dt_info tname of
   8.243 +          NONE => primrec_err (quote tname ^ " is not a datatype")
   8.244 +        | SOME dt =>
   8.245 +            if tnames' subset (map (#1 o snd) (#descr dt)) then
   8.246 +              (tname, dt)::(find_dts dt_info tnames' tnames)
   8.247 +            else find_dts dt_info tnames' tnames);
   8.248 +
   8.249 +fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns =
   8.250 +  let
   8.251 +    fun constrs_of (_, (_, _, cs)) =
   8.252 +      map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   8.253 +    val params_of = these o AList.lookup (op =) (List.concat (map constrs_of rec_eqns));
   8.254 +  in
   8.255 +    induction
   8.256 +    |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr)))
   8.257 +    |> RuleCases.save induction
   8.258 +  end;
   8.259 +
   8.260 +local
   8.261 +
   8.262 +fun gen_primrec_i note def alt_name eqns_atts thy =
   8.263 +  let
   8.264 +    val (eqns, atts) = split_list eqns_atts;
   8.265 +    val dt_info = DatatypePackage.get_datatypes thy;
   8.266 +    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
   8.267 +    val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
   8.268 +    val dts = find_dts dt_info tnames tnames;
   8.269 +    val main_fns =
   8.270 +      map (fn (tname, {index, ...}) =>
   8.271 +        (index,
   8.272 +          (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns))
   8.273 +      dts;
   8.274 +    val {descr, rec_names, rec_rewrites, ...} =
   8.275 +      if null dts then
   8.276 +        primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   8.277 +      else snd (hd dts);
   8.278 +    val (fnameTs, fnss) =
   8.279 +      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
   8.280 +    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   8.281 +    val defs' = map (make_def thy fs) defs;
   8.282 +    val nameTs1 = map snd fnameTs;
   8.283 +    val nameTs2 = map fst rec_eqns;
   8.284 +    val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
   8.285 +            else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
   8.286 +              "\nare not mutually recursive");
   8.287 +    val primrec_name =
   8.288 +      if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   8.289 +    val (defs_thms', thy') =
   8.290 +      thy
   8.291 +      |> Sign.add_path primrec_name
   8.292 +      |> fold_map def (map (fn (name, t) => ((name, []), t)) defs');
   8.293 +    val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms';
   8.294 +    val _ = message ("Proving equations for primrec function(s) " ^
   8.295 +      commas_quote (map fst nameTs1) ^ " ...");
   8.296 +    val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
   8.297 +        (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
   8.298 +    val (simps', thy'') =
   8.299 +      thy'
   8.300 +      |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps);
   8.301 +    val simps'' = maps snd simps';
   8.302 +  in
   8.303 +    thy''
   8.304 +    |> note (("simps", [Simplifier.simp_add, RecfunCodegen.add_default]), simps'')
   8.305 +    |> snd
   8.306 +    |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns])
   8.307 +    |> snd
   8.308 +    |> Sign.parent_path
   8.309 +    |> pair simps''
   8.310 +  end;
   8.311 +
   8.312 +fun gen_primrec note def alt_name eqns thy =
   8.313 +  let
   8.314 +    val ((names, strings), srcss) = apfst split_list (split_list eqns);
   8.315 +    val atts = map (map (Attrib.attribute thy)) srcss;
   8.316 +    val eqn_ts = map (fn s => Syntax.read_prop_global thy s
   8.317 +      handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings;
   8.318 +    val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
   8.319 +      handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
   8.320 +    val (_, eqn_ts') = unify_consts thy rec_ts eqn_ts
   8.321 +  in
   8.322 +    gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy
   8.323 +  end;
   8.324 +
   8.325 +fun thy_note ((name, atts), thms) =
   8.326 +  PureThy.add_thmss [((name, thms), atts)] #-> (fn [thms] => pair (name, thms));
   8.327 +fun thy_def false ((name, atts), t) =
   8.328 +      PureThy.add_defs_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm))
   8.329 +  | thy_def true ((name, atts), t) =
   8.330 +      PureThy.add_defs_unchecked_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm));
   8.331 +
   8.332 +in
   8.333 +
   8.334 +val add_primrec = gen_primrec thy_note (thy_def false);
   8.335 +val add_primrec_unchecked = gen_primrec thy_note (thy_def true);
   8.336 +val add_primrec_i = gen_primrec_i thy_note (thy_def false);
   8.337 +val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);
   8.338 +fun gen_primrec note def alt_name specs =
   8.339 +  gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs);
   8.340 +
   8.341 +end;
   8.342 +
   8.343 +
   8.344 +(* see primrecr_package.ML (* outer syntax *)
   8.345 +
   8.346 +local structure P = OuterParse and K = OuterKeyword in
   8.347 +
   8.348 +val opt_unchecked_name =
   8.349 +  Scan.optional (P.$$$ "(" |-- P.!!!
   8.350 +    (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" ||
   8.351 +      P.name >> pair false) --| P.$$$ ")")) (false, "");
   8.352 +
   8.353 +val primrec_decl =
   8.354 +  opt_unchecked_name -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop);
   8.355 +
   8.356 +val _ =
   8.357 +  OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
   8.358 +    (primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
   8.359 +      Toplevel.theory (snd o
   8.360 +        (if unchecked then add_primrec_unchecked else add_primrec) alt_name
   8.361 +          (map P.triple_swap eqns))));
   8.362 +
   8.363 +end;*)
   8.364 +
   8.365 +end;
     9.1 --- a/src/HOL/Tools/primrec_package.ML	Thu Dec 06 12:58:01 2007 +0100
     9.2 +++ b/src/HOL/Tools/primrec_package.ML	Thu Dec 06 15:10:09 2007 +0100
     9.3 @@ -1,27 +1,15 @@
     9.4  (*  Title:      HOL/Tools/primrec_package.ML
     9.5      ID:         $Id$
     9.6 -    Author:     Stefan Berghofer, TU Muenchen and Norbert Voelker, FernUni Hagen
     9.7 +    Author:     Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
     9.8 +                Florian Haftmann, TU Muenchen
     9.9  
    9.10  Package for defining functions on datatypes by primitive recursion.
    9.11  *)
    9.12  
    9.13  signature PRIMREC_PACKAGE =
    9.14  sig
    9.15 -  val quiet_mode: bool ref
    9.16 -  val unify_consts: theory -> term list -> term list -> term list * term list
    9.17 -  val add_primrec: string -> ((bstring * string) * Attrib.src list) list
    9.18 -    -> theory -> thm list * theory
    9.19 -  val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
    9.20 -    -> theory -> thm list * theory
    9.21 -  val add_primrec_i: string -> ((bstring * term) * attribute list) list
    9.22 -    -> theory -> thm list * theory
    9.23 -  val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
    9.24 -    -> theory -> thm list * theory
    9.25 -  (* FIXME !? *)
    9.26 -  val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory)
    9.27 -    -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory)
    9.28 -    -> string -> ((bstring * attribute list) * term) list
    9.29 -    -> theory -> thm list * theory;
    9.30 +  val add_primrec: (string * typ option * mixfix) list ->
    9.31 +    ((bstring * Attrib.src list) * term) list -> local_theory -> thm list * local_theory
    9.32  end;
    9.33  
    9.34  structure PrimrecPackage : PRIMREC_PACKAGE =
    9.35 @@ -29,98 +17,71 @@
    9.36  
    9.37  open DatatypeAux;
    9.38  
    9.39 -exception RecError of string;
    9.40 -
    9.41 -fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    9.42 -fun primrec_eq_err thy s eq =
    9.43 -  primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq));
    9.44 -
    9.45 -
    9.46 -(* messages *)
    9.47 -
    9.48 -val quiet_mode = ref false;
    9.49 -fun message s = if ! quiet_mode then () else writeln s;
    9.50 -
    9.51 +exception PrimrecError of string * term option;
    9.52  
    9.53 -(*the following code ensures that each recursive set always has the
    9.54 -  same type in all introduction rules*)
    9.55 -fun unify_consts thy cs intr_ts =
    9.56 -  (let
    9.57 -    val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
    9.58 -    fun varify (t, (i, ts)) =
    9.59 -      let val t' = map_types (Logic.incr_tvar (i + 1)) (snd (Type.varify [] t))
    9.60 -      in (maxidx_of_term t', t'::ts) end;
    9.61 -    val (i, cs') = foldr varify (~1, []) cs;
    9.62 -    val (i', intr_ts') = foldr varify (i, []) intr_ts;
    9.63 -    val rec_consts = fold add_term_consts_2 cs' [];
    9.64 -    val intr_consts = fold add_term_consts_2 intr_ts' [];
    9.65 -    fun unify (cname, cT) =
    9.66 -      let val consts = map snd (filter (fn (c, _) => c = cname) intr_consts)
    9.67 -      in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
    9.68 -    val (env, _) = fold unify rec_consts (Vartab.empty, i');
    9.69 -    val subst = Type.freeze o map_types (Envir.norm_type env)
    9.70 +fun primrec_error msg = raise PrimrecError (msg, NONE);
    9.71 +fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    9.72  
    9.73 -  in (map subst cs', map subst intr_ts')
    9.74 -  end) handle Type.TUNIFY =>
    9.75 -    (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
    9.76 +fun message s = if ! Toplevel.debug then () else writeln s;
    9.77  
    9.78  
    9.79  (* preprocessing of equations *)
    9.80  
    9.81 -fun process_eqn thy eq rec_fns =
    9.82 +fun process_eqn is_fixed is_const spec rec_fns =
    9.83    let
    9.84 -    val (lhs, rhs) =
    9.85 -      if null (term_vars eq) then
    9.86 -        HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    9.87 -        handle TERM _ => raise RecError "not a proper equation"
    9.88 -      else raise RecError "illegal schematic variable(s)";
    9.89 -
    9.90 +    val vars = strip_qnt_vars "all" spec;
    9.91 +    val body = strip_qnt_body "all" spec;
    9.92 +    val eqn = curry subst_bounds (map Free (rev vars)) body;
    9.93 +    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    9.94 +      handle TERM _ => primrec_error "not a proper equation";
    9.95      val (recfun, args) = strip_comb lhs;
    9.96 -    val fnameT = dest_Const recfun handle TERM _ =>
    9.97 -      raise RecError "function is not declared as constant in theory";
    9.98 +    val fname = case recfun of Free (v, _) => if is_fixed v then v
    9.99 +          else primrec_error "illegal head of function equation"
   9.100 +      | _ => primrec_error "illegal head of function equation";
   9.101  
   9.102      val (ls', rest)  = take_prefix is_Free args;
   9.103      val (middle, rs') = take_suffix is_Free rest;
   9.104      val rpos = length ls';
   9.105  
   9.106 -    val (constr, cargs') = if null middle then raise RecError "constructor missing"
   9.107 +    val (constr, cargs') = if null middle then primrec_error "constructor missing"
   9.108        else strip_comb (hd middle);
   9.109      val (cname, T) = dest_Const constr
   9.110 -      handle TERM _ => raise RecError "ill-formed constructor";
   9.111 +      handle TERM _ => primrec_error "ill-formed constructor";
   9.112      val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
   9.113 -      raise RecError "cannot determine datatype associated with function"
   9.114 +      primrec_error "cannot determine datatype associated with function"
   9.115  
   9.116      val (ls, cargs, rs) =
   9.117        (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
   9.118 -      handle TERM _ => raise RecError "illegal argument in pattern";
   9.119 +      handle TERM _ => primrec_error "illegal argument in pattern";
   9.120      val lfrees = ls @ rs @ cargs;
   9.121  
   9.122      fun check_vars _ [] = ()
   9.123 -      | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
   9.124 +      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
   9.125    in
   9.126      if length middle > 1 then
   9.127 -      raise RecError "more than one non-variable in pattern"
   9.128 +      primrec_error "more than one non-variable in pattern"
   9.129      else
   9.130       (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
   9.131        check_vars "extra variables on rhs: "
   9.132 -        (map dest_Free (term_frees rhs) \\ lfrees);
   9.133 -      case AList.lookup (op =) rec_fns fnameT of
   9.134 +        (map dest_Free (term_frees rhs) |> subtract (op =) lfrees
   9.135 +          |> filter_out (is_const o fst) |> filter_out (is_fixed o fst));
   9.136 +      case AList.lookup (op =) rec_fns fname of
   9.137          NONE =>
   9.138 -          (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
   9.139 +          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
   9.140        | SOME (_, rpos', eqns) =>
   9.141            if AList.defined (op =) eqns cname then
   9.142 -            raise RecError "constructor already occurred as pattern"
   9.143 +            primrec_error "constructor already occurred as pattern"
   9.144            else if rpos <> rpos' then
   9.145 -            raise RecError "position of recursive argument inconsistent"
   9.146 +            primrec_error "position of recursive argument inconsistent"
   9.147            else
   9.148 -            AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns))
   9.149 +            AList.update (op =)
   9.150 +              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
   9.151                rec_fns)
   9.152 -  end
   9.153 -  handle RecError s => primrec_eq_err thy s eq;
   9.154 +  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
   9.155  
   9.156 -fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
   9.157 +fun process_fun descr eqns (i, fname) (fnames, fnss) =
   9.158    let
   9.159 -    val (_, (tname, _, constrs)) = List.nth (descr, i);
   9.160 +    val (_, (tname, _, constrs)) = nth descr i;
   9.161  
   9.162      (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
   9.163  
   9.164 @@ -133,14 +94,13 @@
   9.165            let
   9.166              val (f, ts) = strip_comb t;
   9.167            in
   9.168 -            if is_Const f andalso dest_Const f mem map fst rec_eqns then
   9.169 +            if is_Free f
   9.170 +              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   9.171                let
   9.172 -                val fnameT' as (fname', _) = dest_Const f;
   9.173 -                val (_, rpos, _) = the (AList.lookup (op =) rec_eqns fnameT');
   9.174 -                val ls = Library.take (rpos, ts);
   9.175 -                val rest = Library.drop (rpos, ts);
   9.176 -                val (x', rs) = (hd rest, tl rest)
   9.177 -                  handle Empty => raise RecError ("not enough arguments\
   9.178 +                val (fname', _) = dest_Free f;
   9.179 +                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   9.180 +                val (ls, x' :: rs) = chop rpos ts
   9.181 +                  handle Empty => primrec_error ("not enough arguments\
   9.182                     \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   9.183                  val (x, xs) = strip_comb x'
   9.184                in case AList.lookup (op =) subs x
   9.185 @@ -151,7 +111,7 @@
   9.186                  | SOME (i', y) =>
   9.187                      fs
   9.188                      |> fold_map (subst subs) (xs @ ls @ rs)
   9.189 -                    ||> process_fun thy descr rec_eqns (i', fnameT')
   9.190 +                    ||> process_fun descr eqns (i', fname')
   9.191                      |-> (fn ts' => pair (list_comb (y, ts')))
   9.192                end
   9.193              else
   9.194 @@ -163,41 +123,39 @@
   9.195  
   9.196      (* translate rec equations into function arguments suitable for rec comb *)
   9.197  
   9.198 -    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
   9.199 +    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   9.200        (case AList.lookup (op =) eqns cname of
   9.201            NONE => (warning ("No equation for constructor " ^ quote cname ^
   9.202              "\nin definition of function " ^ quote fname);
   9.203 -              (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns))
   9.204 +              (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
   9.205          | SOME (ls, cargs', rs, rhs, eq) =>
   9.206              let
   9.207                val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   9.208                val rargs = map fst recs;
   9.209                val subs = map (rpair dummyT o fst)
   9.210                  (rev (rename_wrt_term rhs rargs));
   9.211 -              val (rhs', (fnameTs'', fnss'')) =
   9.212 -                  (subst (map (fn ((x, y), z) =>
   9.213 -                               (Free x, (body_index y, Free z)))
   9.214 -                          (recs ~~ subs)) rhs (fnameTs', fnss'))
   9.215 -                  handle RecError s => primrec_eq_err thy s eq
   9.216 -            in (fnameTs'', fnss'',
   9.217 +              val (rhs', (fnames'', fnss'')) = (subst (map2 (fn (x, y) => fn z =>
   9.218 +                (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss'))
   9.219 +                  handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   9.220 +            in (fnames'', fnss'',
   9.221                  (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   9.222              end)
   9.223  
   9.224 -  in (case AList.lookup (op =) fnameTs i of
   9.225 +  in (case AList.lookup (op =) fnames i of
   9.226        NONE =>
   9.227 -        if exists (equal fnameT o snd) fnameTs then
   9.228 -          raise RecError ("inconsistent functions for datatype " ^ quote tname)
   9.229 +        if exists (fn (_, v) => fname = v) fnames then
   9.230 +          primrec_error ("inconsistent functions for datatype " ^ quote tname)
   9.231          else
   9.232            let
   9.233 -            val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
   9.234 -            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
   9.235 -              ((i, fnameT)::fnameTs, fnss, [])
   9.236 +            val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   9.237 +            val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   9.238 +              ((i, fname)::fnames, fnss, [])
   9.239            in
   9.240 -            (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   9.241 +            (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   9.242            end
   9.243 -    | SOME fnameT' =>
   9.244 -        if fnameT = fnameT' then (fnameTs, fnss)
   9.245 -        else raise RecError ("inconsistent functions for datatype " ^ quote tname))
   9.246 +    | SOME fname' =>
   9.247 +        if fname = fname' then (fnames, fnss)
   9.248 +        else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   9.249    end;
   9.250  
   9.251  
   9.252 @@ -219,17 +177,17 @@
   9.253  
   9.254  (* make definition *)
   9.255  
   9.256 -fun make_def thy fs (fname, ls, rec_name, tname) =
   9.257 +fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   9.258    let
   9.259 -    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   9.260 +    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   9.261                      ((map snd ls) @ [dummyT])
   9.262                      (list_comb (Const (rec_name, dummyT),
   9.263                                  fs @ map Bound (0 ::(length ls downto 1))))
   9.264 -    val def_name = Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def";
   9.265 -    val def_prop =
   9.266 -      singleton (Syntax.check_terms (ProofContext.init thy))
   9.267 -        (Logic.mk_equals (Const (fname, dummyT), rhs));
   9.268 -  in (def_name, def_prop) end;
   9.269 +    val def_name = Thm.def_name (Sign.base_name fname);
   9.270 +    val rhs = singleton (Syntax.check_terms ctxt) raw_rhs;
   9.271 +    val SOME mfx = get_first
   9.272 +      (fn ((v, _), mfx) => if v = fname then SOME mfx else NONE) fixes;
   9.273 +  in ((fname, mfx), ((def_name, []), rhs)) end;
   9.274  
   9.275  
   9.276  (* find datatypes which contain all datatypes in tnames' *)
   9.277 @@ -237,103 +195,87 @@
   9.278  fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   9.279    | find_dts dt_info tnames' (tname::tnames) =
   9.280        (case Symtab.lookup dt_info tname of
   9.281 -          NONE => primrec_err (quote tname ^ " is not a datatype")
   9.282 +          NONE => primrec_error (quote tname ^ " is not a datatype")
   9.283          | SOME dt =>
   9.284              if tnames' subset (map (#1 o snd) (#descr dt)) then
   9.285                (tname, dt)::(find_dts dt_info tnames' tnames)
   9.286              else find_dts dt_info tnames' tnames);
   9.287  
   9.288 -fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns =
   9.289 +
   9.290 +(* adapted induction rule *)
   9.291 +
   9.292 +fun prepare_induct ({descr, induction, ...}: datatype_info) eqns =
   9.293    let
   9.294      fun constrs_of (_, (_, _, cs)) =
   9.295        map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   9.296 -    val params_of = these o AList.lookup (op =) (List.concat (map constrs_of rec_eqns));
   9.297 +    val params_of = these o AList.lookup (op =) (List.concat (map constrs_of eqns));
   9.298    in
   9.299      induction
   9.300 -    |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr)))
   9.301 +    |> RuleCases.rename_params (map params_of (maps (map #1 o #3 o #2) descr))
   9.302      |> RuleCases.save induction
   9.303    end;
   9.304  
   9.305 +
   9.306 +(* primrec definition *)
   9.307 +
   9.308  local
   9.309  
   9.310 -fun gen_primrec_i note def alt_name eqns_atts thy =
   9.311 +fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
   9.312 +  let
   9.313 +    val ((fixes, spec), _) = prep_spec raw_fixes [(map o apsnd) single raw_spec] ctxt
   9.314 +  in (fixes, (map o apsnd) the_single spec) end;
   9.315 +
   9.316 +fun prove_spec ctxt rec_rewrites defs =
   9.317 +  let
   9.318 +    val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   9.319 +    fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   9.320 +    val _ = message "Proving equations for primrec function";
   9.321 +  in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   9.322 +
   9.323 +fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   9.324    let
   9.325 -    val (eqns, atts) = split_list eqns_atts;
   9.326 -    val dt_info = DatatypePackage.get_datatypes thy;
   9.327 -    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
   9.328 -    val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
   9.329 -    val dts = find_dts dt_info tnames tnames;
   9.330 -    val main_fns =
   9.331 -      map (fn (tname, {index, ...}) =>
   9.332 -        (index,
   9.333 -          (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns))
   9.334 -      dts;
   9.335 +    val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   9.336 +    val eqns = fold_rev (process_eqn (member (op =) (map (fst o fst) fixes))
   9.337 +      (Variable.is_const lthy) o snd) spec [];
   9.338 +    val tnames = distinct (op =) (map (#1 o snd) eqns);
   9.339 +    val dts = find_dts (DatatypePackage.get_datatypes
   9.340 +      (ProofContext.theory_of lthy)) tnames tnames;
   9.341 +    val main_fns = map (fn (tname, {index, ...}) =>
   9.342 +      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   9.343      val {descr, rec_names, rec_rewrites, ...} =
   9.344 -      if null dts then
   9.345 -        primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   9.346 +      if null dts then primrec_error
   9.347 +        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   9.348        else snd (hd dts);
   9.349 -    val (fnameTs, fnss) =
   9.350 -      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
   9.351 +    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   9.352      val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   9.353 -    val defs' = map (make_def thy fs) defs;
   9.354 -    val nameTs1 = map snd fnameTs;
   9.355 -    val nameTs2 = map fst rec_eqns;
   9.356 +    val nameTs1 = map snd fnames;
   9.357 +    val nameTs2 = map fst eqns;
   9.358      val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then ()
   9.359 -            else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
   9.360 -              "\nare not mutually recursive");
   9.361 -    val primrec_name =
   9.362 -      if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   9.363 -    val (defs_thms', thy') =
   9.364 -      thy
   9.365 -      |> Sign.add_path primrec_name
   9.366 -      |> fold_map def (map (fn (name, t) => ((name, []), t)) defs');
   9.367 -    val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms';
   9.368 -    val _ = message ("Proving equations for primrec function(s) " ^
   9.369 -      commas_quote (map fst nameTs1) ^ " ...");
   9.370 -    val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
   9.371 -        (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
   9.372 -    val (simps', thy'') =
   9.373 -      thy'
   9.374 -      |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps);
   9.375 -    val simps'' = maps snd simps';
   9.376 +      else primrec_error ("functions " ^ commas_quote nameTs2 ^
   9.377 +        "\nare not mutually recursive");
   9.378 +    val qualify = NameSpace.qualified
   9.379 +      (space_implode "_" (map (Sign.base_name o #1) defs));
   9.380 +    val simp_atts = [Attrib.internal (K Simplifier.simp_add),
   9.381 +      Code.add_default_func_attr (*FIXME*)];
   9.382    in
   9.383 -    thy''
   9.384 -    |> note (("simps", [Simplifier.simp_add, RecfunCodegen.add_default]), simps'')
   9.385 -    |> snd
   9.386 -    |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns])
   9.387 -    |> snd
   9.388 -    |> Sign.parent_path
   9.389 -    |> pair simps''
   9.390 -  end;
   9.391 -
   9.392 -fun gen_primrec note def alt_name eqns thy =
   9.393 -  let
   9.394 -    val ((names, strings), srcss) = apfst split_list (split_list eqns);
   9.395 -    val atts = map (map (Attrib.attribute thy)) srcss;
   9.396 -    val eqn_ts = map (fn s => Syntax.read_prop_global thy s
   9.397 -      handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings;
   9.398 -    val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
   9.399 -      handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
   9.400 -    val (_, eqn_ts') = unify_consts thy rec_ts eqn_ts
   9.401 -  in
   9.402 -    gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy
   9.403 -  end;
   9.404 -
   9.405 -fun thy_note ((name, atts), thms) =
   9.406 -  PureThy.add_thmss [((name, thms), atts)] #-> (fn [thms] => pair (name, thms));
   9.407 -fun thy_def false ((name, atts), t) =
   9.408 -      PureThy.add_defs_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm))
   9.409 -  | thy_def true ((name, atts), t) =
   9.410 -      PureThy.add_defs_unchecked_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm));
   9.411 +    lthy
   9.412 +    |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   9.413 +    |-> (fn defs => `(fn ctxt => prove_spec ctxt rec_rewrites defs spec))
   9.414 +    |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   9.415 +    |-> (fn simps' => LocalTheory.note Thm.theoremK
   9.416 +          ((qualify "simps", simp_atts), maps snd simps'))
   9.417 +    ||>> LocalTheory.note Thm.theoremK
   9.418 +          ((qualify "induct", []), [prepare_induct (#2 (hd dts)) eqns])
   9.419 +    |>> (snd o fst)
   9.420 +  end handle PrimrecError (msg, some_eqn) =>
   9.421 +    error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
   9.422 +     of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   9.423 +      | NONE => ""));
   9.424  
   9.425  in
   9.426  
   9.427 -val add_primrec = gen_primrec thy_note (thy_def false);
   9.428 -val add_primrec_unchecked = gen_primrec thy_note (thy_def true);
   9.429 -val add_primrec_i = gen_primrec_i thy_note (thy_def false);
   9.430 -val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);
   9.431 -fun gen_primrec note def alt_name specs =
   9.432 -  gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs);
   9.433 +val add_primrec = gen_primrec Specification.check_specification;
   9.434 +val add_primrec_cmd = gen_primrec Specification.read_specification;
   9.435  
   9.436  end;
   9.437  
   9.438 @@ -347,15 +289,27 @@
   9.439      (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" ||
   9.440        P.name >> pair false) --| P.$$$ ")")) (false, "");
   9.441  
   9.442 -val primrec_decl =
   9.443 +val old_primrec_decl =
   9.444    opt_unchecked_name -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop);
   9.445  
   9.446 +fun pipe_error t = P.!!! (Scan.fail_with (K
   9.447 +  (cat_lines ["Equations must be separated by " ^ quote "|", quote t])));
   9.448 +
   9.449 +val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead
   9.450 +  ((P.term :-- pipe_error) || Scan.succeed ("",""));
   9.451 +
   9.452 +val statements = P.enum1 "|" statement;
   9.453 +
   9.454 +val primrec_decl = P.opt_target -- P.fixes --| P.$$$ "where" -- statements;
   9.455 +
   9.456  val _ =
   9.457    OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
   9.458 -    (primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
   9.459 +    ((primrec_decl >> (fn ((opt_target, raw_fixes), raw_spec) =>
   9.460 +      Toplevel.local_theory opt_target (add_primrec_cmd raw_fixes raw_spec #> snd)))
   9.461 +    || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
   9.462        Toplevel.theory (snd o
   9.463 -        (if unchecked then add_primrec_unchecked else add_primrec) alt_name
   9.464 -          (map P.triple_swap eqns))));
   9.465 +        (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec) alt_name
   9.466 +          (map P.triple_swap eqns)))));
   9.467  
   9.468  end;
   9.469  
    10.1 --- a/src/HOLCF/Tools/fixrec_package.ML	Thu Dec 06 12:58:01 2007 +0100
    10.2 +++ b/src/HOLCF/Tools/fixrec_package.ML	Thu Dec 06 15:10:09 2007 +0100
    10.3 @@ -231,7 +231,7 @@
    10.4      val eqn_ts = map (prep_prop thy) strings;
    10.5      val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq)))
    10.6        handle TERM _ => fixrec_eq_err thy "not a proper equation" eq) eqn_ts;
    10.7 -    val (_, eqn_ts') = PrimrecPackage.unify_consts thy rec_ts eqn_ts;
    10.8 +    val (_, eqn_ts') = OldPrimrecPackage.unify_consts thy rec_ts eqn_ts;
    10.9      
   10.10      fun unconcat [] _ = []
   10.11        | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));