clarified modules that contribute to datatype package;
authorwenzelm
Fri Dec 16 10:52:35 2011 +0100 (2011-12-16)
changeset 4589765cef0298158
parent 45896 100fb1f33e3e
child 45898 b619242b0439
clarified modules that contribute to datatype package;
src/HOL/HOLCF/Tools/fixrec.ML
src/HOL/Inductive.thy
src/HOL/IsaMakefile
src/HOL/Tools/Datatype/primrec.ML
src/HOL/Tools/primrec.ML
     1.1 --- a/src/HOL/HOLCF/Tools/fixrec.ML	Fri Dec 16 10:38:38 2011 +0100
     1.2 +++ b/src/HOL/HOLCF/Tools/fixrec.ML	Fri Dec 16 10:52:35 2011 +0100
     1.3 @@ -326,7 +326,7 @@
     1.4  (*************************************************************************)
     1.5  
     1.6  local
     1.7 -(* code adapted from HOL/Tools/primrec.ML *)
     1.8 +(* code adapted from HOL/Tools/Datatype/primrec.ML *)
     1.9  
    1.10  fun gen_fixrec
    1.11    prep_spec
     2.1 --- a/src/HOL/Inductive.thy	Fri Dec 16 10:38:38 2011 +0100
     2.2 +++ b/src/HOL/Inductive.thy	Fri Dec 16 10:52:35 2011 +0100
     2.3 @@ -7,16 +7,16 @@
     2.4  theory Inductive 
     2.5  imports Complete_Lattices
     2.6  uses
     2.7 +  "Tools/dseq.ML"
     2.8    ("Tools/inductive.ML")
     2.9 -  "Tools/dseq.ML"
    2.10 -  "Tools/Datatype/datatype_aux.ML"
    2.11 -  "Tools/Datatype/datatype_prop.ML"
    2.12 +  ("Tools/Datatype/datatype_aux.ML")
    2.13 +  ("Tools/Datatype/datatype_prop.ML")
    2.14    ("Tools/Datatype/datatype_abs_proofs.ML")
    2.15    ("Tools/Datatype/datatype_data.ML")
    2.16    ("Tools/Datatype/datatype_case.ML")
    2.17    ("Tools/Datatype/rep_datatype.ML")
    2.18 -  ("Tools/primrec.ML")
    2.19    ("Tools/Datatype/datatype_codegen.ML")
    2.20 +  ("Tools/Datatype/primrec.ML")
    2.21  begin
    2.22  
    2.23  subsection {* Least and greatest fixed points *}
    2.24 @@ -276,15 +276,14 @@
    2.25  
    2.26  text {* Package setup. *}
    2.27  
    2.28 +use "Tools/Datatype/datatype_aux.ML"
    2.29 +use "Tools/Datatype/datatype_prop.ML"
    2.30  use "Tools/Datatype/datatype_abs_proofs.ML"
    2.31  use "Tools/Datatype/datatype_data.ML" setup Datatype_Data.setup
    2.32  use "Tools/Datatype/datatype_case.ML" setup Datatype_Case.setup
    2.33  use "Tools/Datatype/rep_datatype.ML"
    2.34 -
    2.35 -use "Tools/Datatype/datatype_codegen.ML"
    2.36 -setup Datatype_Codegen.setup
    2.37 -
    2.38 -use "Tools/primrec.ML"
    2.39 +use "Tools/Datatype/datatype_codegen.ML" setup Datatype_Codegen.setup
    2.40 +use "Tools/Datatype/primrec.ML"
    2.41  
    2.42  text{* Lambda-abstractions with pattern matching: *}
    2.43  
     3.1 --- a/src/HOL/IsaMakefile	Fri Dec 16 10:38:38 2011 +0100
     3.2 +++ b/src/HOL/IsaMakefile	Fri Dec 16 10:52:35 2011 +0100
     3.3 @@ -218,6 +218,7 @@
     3.4    Tools/Datatype/datatype_data.ML \
     3.5    Tools/Datatype/datatype_prop.ML \
     3.6    Tools/Datatype/datatype_realizer.ML \
     3.7 +  Tools/Datatype/primrec.ML \
     3.8    Tools/Datatype/rep_datatype.ML \
     3.9    Tools/Function/context_tree.ML \
    3.10    Tools/Function/fun.ML \
    3.11 @@ -256,7 +257,6 @@
    3.12    Tools/lin_arith.ML \
    3.13    Tools/monomorph.ML \
    3.14    Tools/nat_arith.ML \
    3.15 -  Tools/primrec.ML \
    3.16    Tools/prop_logic.ML \
    3.17    Tools/refute.ML \
    3.18    Tools/rewrite_hol_proof.ML \
     4.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     4.2 +++ b/src/HOL/Tools/Datatype/primrec.ML	Fri Dec 16 10:52:35 2011 +0100
     4.3 @@ -0,0 +1,318 @@
     4.4 +(*  Title:      HOL/Tools/Datatype/primrec.ML
     4.5 +    Author:     Norbert Voelker, FernUni Hagen
     4.6 +    Author:     Stefan Berghofer, TU Muenchen
     4.7 +    Author:     Florian Haftmann, TU Muenchen
     4.8 +
     4.9 +Primitive recursive functions on datatypes.
    4.10 +*)
    4.11 +
    4.12 +signature PRIMREC =
    4.13 +sig
    4.14 +  val add_primrec: (binding * typ option * mixfix) list ->
    4.15 +    (Attrib.binding * term) list -> local_theory -> (term list * thm list) * local_theory
    4.16 +  val add_primrec_cmd: (binding * string option * mixfix) list ->
    4.17 +    (Attrib.binding * string) list -> local_theory -> (term list * thm list) * local_theory
    4.18 +  val add_primrec_global: (binding * typ option * mixfix) list ->
    4.19 +    (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    4.20 +  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    4.21 +    (binding * typ option * mixfix) list ->
    4.22 +    (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    4.23 +  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    4.24 +    local_theory -> (string * (term list * thm list)) * local_theory
    4.25 +end;
    4.26 +
    4.27 +structure Primrec : PRIMREC =
    4.28 +struct
    4.29 +
    4.30 +exception PrimrecError of string * term option;
    4.31 +
    4.32 +fun primrec_error msg = raise PrimrecError (msg, NONE);
    4.33 +fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    4.34 +
    4.35 +
    4.36 +(* preprocessing of equations *)
    4.37 +
    4.38 +fun process_eqn is_fixed spec rec_fns =
    4.39 +  let
    4.40 +    val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    4.41 +    val body = strip_qnt_body "all" spec;
    4.42 +    val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
    4.43 +      (fn Free (v, _) => insert (op =) v | _ => I) body []));
    4.44 +    val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    4.45 +    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    4.46 +      handle TERM _ => primrec_error "not a proper equation";
    4.47 +    val (recfun, args) = strip_comb lhs;
    4.48 +    val fname =
    4.49 +      (case recfun of
    4.50 +        Free (v, _) =>
    4.51 +          if is_fixed v then v
    4.52 +          else primrec_error "illegal head of function equation"
    4.53 +      | _ => primrec_error "illegal head of function equation");
    4.54 +
    4.55 +    val (ls', rest)  = take_prefix is_Free args;
    4.56 +    val (middle, rs') = take_suffix is_Free rest;
    4.57 +    val rpos = length ls';
    4.58 +
    4.59 +    val (constr, cargs') =
    4.60 +      if null middle then primrec_error "constructor missing"
    4.61 +      else strip_comb (hd middle);
    4.62 +    val (cname, T) = dest_Const constr
    4.63 +      handle TERM _ => primrec_error "ill-formed constructor";
    4.64 +    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    4.65 +      primrec_error "cannot determine datatype associated with function"
    4.66 +
    4.67 +    val (ls, cargs, rs) =
    4.68 +      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    4.69 +      handle TERM _ => primrec_error "illegal argument in pattern";
    4.70 +    val lfrees = ls @ rs @ cargs;
    4.71 +
    4.72 +    fun check_vars _ [] = ()
    4.73 +      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
    4.74 +  in
    4.75 +    if length middle > 1 then
    4.76 +      primrec_error "more than one non-variable in pattern"
    4.77 +    else
    4.78 +     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    4.79 +      check_vars "extra variables on rhs: "
    4.80 +        (Term.add_frees rhs [] |> subtract (op =) lfrees
    4.81 +          |> filter_out (is_fixed o fst));
    4.82 +      (case AList.lookup (op =) rec_fns fname of
    4.83 +        NONE =>
    4.84 +          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
    4.85 +      | SOME (_, rpos', eqns) =>
    4.86 +          if AList.defined (op =) eqns cname then
    4.87 +            primrec_error "constructor already occurred as pattern"
    4.88 +          else if rpos <> rpos' then
    4.89 +            primrec_error "position of recursive argument inconsistent"
    4.90 +          else
    4.91 +            AList.update (op =)
    4.92 +              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
    4.93 +              rec_fns))
    4.94 +  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    4.95 +
    4.96 +fun process_fun descr eqns (i, fname) (fnames, fnss) =
    4.97 +  let
    4.98 +    val (_, (tname, _, constrs)) = nth descr i;
    4.99 +
   4.100 +    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
   4.101 +
   4.102 +    fun subst [] t fs = (t, fs)
   4.103 +      | subst subs (Abs (a, T, t)) fs =
   4.104 +          fs
   4.105 +          |> subst subs t
   4.106 +          |-> (fn t' => pair (Abs (a, T, t')))
   4.107 +      | subst subs (t as (_ $ _)) fs =
   4.108 +          let
   4.109 +            val (f, ts) = strip_comb t;
   4.110 +          in
   4.111 +            if is_Free f
   4.112 +              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   4.113 +              let
   4.114 +                val (fname', _) = dest_Free f;
   4.115 +                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   4.116 +                val (ls, rs) = chop rpos ts
   4.117 +                val (x', rs') =
   4.118 +                  (case rs of
   4.119 +                    x' :: rs => (x', rs)
   4.120 +                  | [] => primrec_error ("not enough arguments in recursive application\n" ^
   4.121 +                      "of function " ^ quote fname' ^ " on rhs"));
   4.122 +                val (x, xs) = strip_comb x';
   4.123 +              in
   4.124 +                (case AList.lookup (op =) subs x of
   4.125 +                  NONE =>
   4.126 +                    fs
   4.127 +                    |> fold_map (subst subs) ts
   4.128 +                    |-> (fn ts' => pair (list_comb (f, ts')))
   4.129 +                | SOME (i', y) =>
   4.130 +                    fs
   4.131 +                    |> fold_map (subst subs) (xs @ ls @ rs')
   4.132 +                    ||> process_fun descr eqns (i', fname')
   4.133 +                    |-> (fn ts' => pair (list_comb (y, ts'))))
   4.134 +              end
   4.135 +            else
   4.136 +              fs
   4.137 +              |> fold_map (subst subs) (f :: ts)
   4.138 +              |-> (fn f' :: ts' => pair (list_comb (f', ts')))
   4.139 +          end
   4.140 +      | subst _ t fs = (t, fs);
   4.141 +
   4.142 +    (* translate rec equations into function arguments suitable for rec comb *)
   4.143 +
   4.144 +    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   4.145 +      (case AList.lookup (op =) eqns cname of
   4.146 +        NONE => (warning ("No equation for constructor " ^ quote cname ^
   4.147 +          "\nin definition of function " ^ quote fname);
   4.148 +            (fnames', fnss', (Const (@{const_name undefined}, dummyT)) :: fns))
   4.149 +      | SOME (ls, cargs', rs, rhs, eq) =>
   4.150 +          let
   4.151 +            val recs = filter (Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
   4.152 +            val rargs = map fst recs;
   4.153 +            val subs = map (rpair dummyT o fst)
   4.154 +              (rev (Term.rename_wrt_term rhs rargs));
   4.155 +            val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   4.156 +              (Free x, (Datatype_Aux.body_index y, Free z))) recs subs) rhs (fnames', fnss')
   4.157 +                handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   4.158 +          in
   4.159 +            (fnames'', fnss'', fold_rev absfree (cargs' @ subs @ ls @ rs) rhs' :: fns)
   4.160 +          end)
   4.161 +
   4.162 +  in
   4.163 +    (case AList.lookup (op =) fnames i of
   4.164 +      NONE =>
   4.165 +        if exists (fn (_, v) => fname = v) fnames then
   4.166 +          primrec_error ("inconsistent functions for datatype " ^ quote tname)
   4.167 +        else
   4.168 +          let
   4.169 +            val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   4.170 +            val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   4.171 +              ((i, fname) :: fnames, fnss, [])
   4.172 +          in
   4.173 +            (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
   4.174 +          end
   4.175 +    | SOME fname' =>
   4.176 +        if fname = fname' then (fnames, fnss)
   4.177 +        else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   4.178 +  end;
   4.179 +
   4.180 +
   4.181 +(* prepare functions needed for definitions *)
   4.182 +
   4.183 +fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   4.184 +  (case AList.lookup (op =) fns i of
   4.185 +    NONE =>
   4.186 +      let
   4.187 +        val dummy_fns = map (fn (_, cargs) => Const (@{const_name undefined},
   4.188 +          replicate (length cargs + length (filter Datatype_Aux.is_rec_type cargs))
   4.189 +            dummyT ---> HOLogic.unitT)) constrs;
   4.190 +        val _ = warning ("No function definition for datatype " ^ quote tname)
   4.191 +      in
   4.192 +        (dummy_fns @ fs, defs)
   4.193 +      end
   4.194 +  | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
   4.195 +
   4.196 +
   4.197 +(* make definition *)
   4.198 +
   4.199 +fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   4.200 +  let
   4.201 +    val SOME (var, varT) = get_first (fn ((b, T), mx) =>
   4.202 +      if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
   4.203 +    val def_name = Thm.def_name (Long_Name.base_name fname);
   4.204 +    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
   4.205 +      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
   4.206 +    val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
   4.207 +  in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
   4.208 +
   4.209 +
   4.210 +(* find datatypes which contain all datatypes in tnames' *)
   4.211 +
   4.212 +fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
   4.213 +  | find_dts dt_info tnames' (tname :: tnames) =
   4.214 +      (case Symtab.lookup dt_info tname of
   4.215 +        NONE => primrec_error (quote tname ^ " is not a datatype")
   4.216 +      | SOME dt =>
   4.217 +          if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   4.218 +            (tname, dt) :: (find_dts dt_info tnames' tnames)
   4.219 +          else find_dts dt_info tnames' tnames);
   4.220 +
   4.221 +
   4.222 +(* distill primitive definition(s) from primrec specification *)
   4.223 +
   4.224 +fun distill lthy fixes eqs =
   4.225 +  let
   4.226 +    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   4.227 +      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   4.228 +    val tnames = distinct (op =) (map (#1 o snd) eqns);
   4.229 +    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
   4.230 +    val main_fns = map (fn (tname, {index, ...}) =>
   4.231 +      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   4.232 +    val {descr, rec_names, rec_rewrites, ...} =
   4.233 +      if null dts then primrec_error
   4.234 +        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   4.235 +      else snd (hd dts);
   4.236 +    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   4.237 +    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   4.238 +    val defs = map (make_def lthy fixes fs) raw_defs;
   4.239 +    val names = map snd fnames;
   4.240 +    val names_eqns = map fst eqns;
   4.241 +    val _ =
   4.242 +      if eq_set (op =) (names, names_eqns) then ()
   4.243 +      else primrec_error ("functions " ^ commas_quote names_eqns ^
   4.244 +        "\nare not mutually recursive");
   4.245 +    val rec_rewrites' = map mk_meta_eq rec_rewrites;
   4.246 +    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   4.247 +    fun prove lthy defs =
   4.248 +      let
   4.249 +        val frees = fold (Variable.add_free_names lthy) eqs [];
   4.250 +        val rewrites = rec_rewrites' @ map (snd o snd) defs;
   4.251 +        fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   4.252 +      in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   4.253 +  in ((prefix, (fs, defs)), prove) end
   4.254 +  handle PrimrecError (msg, some_eqn) =>
   4.255 +    error ("Primrec definition error:\n" ^ msg ^
   4.256 +      (case some_eqn of
   4.257 +        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   4.258 +      | NONE => ""));
   4.259 +
   4.260 +
   4.261 +(* primrec definition *)
   4.262 +
   4.263 +fun add_primrec_simple fixes ts lthy =
   4.264 +  let
   4.265 +    val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
   4.266 +  in
   4.267 +    lthy
   4.268 +    |> fold_map Local_Theory.define defs
   4.269 +    |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
   4.270 +  end;
   4.271 +
   4.272 +local
   4.273 +
   4.274 +fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   4.275 +  let
   4.276 +    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
   4.277 +    fun attr_bindings prefix = map (fn ((b, attrs), _) =>
   4.278 +      (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
   4.279 +    fun simp_attr_binding prefix =
   4.280 +      (Binding.qualify true prefix (Binding.name "simps"), @{attributes [simp, nitpick_simp]});
   4.281 +  in
   4.282 +    lthy
   4.283 +    |> add_primrec_simple fixes (map snd spec)
   4.284 +    |-> (fn (prefix, (ts, simps)) =>
   4.285 +      Spec_Rules.add Spec_Rules.Equational (ts, simps)
   4.286 +      #> fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps)
   4.287 +      #-> (fn simps' => Local_Theory.note (simp_attr_binding prefix, maps snd simps')
   4.288 +      #>> (fn (_, simps'') => (ts, simps''))))
   4.289 +  end;
   4.290 +
   4.291 +in
   4.292 +
   4.293 +val add_primrec = gen_primrec Specification.check_spec;
   4.294 +val add_primrec_cmd = gen_primrec Specification.read_spec;
   4.295 +
   4.296 +end;
   4.297 +
   4.298 +fun add_primrec_global fixes specs thy =
   4.299 +  let
   4.300 +    val lthy = Named_Target.theory_init thy;
   4.301 +    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   4.302 +    val simps' = Proof_Context.export lthy' lthy simps;
   4.303 +  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   4.304 +
   4.305 +fun add_primrec_overloaded ops fixes specs thy =
   4.306 +  let
   4.307 +    val lthy = Overloading.overloading ops thy;
   4.308 +    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   4.309 +    val simps' = Proof_Context.export lthy' lthy simps;
   4.310 +  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   4.311 +
   4.312 +
   4.313 +(* outer syntax *)
   4.314 +
   4.315 +val _ =
   4.316 +  Outer_Syntax.local_theory "primrec" "define primitive recursive functions on datatypes"
   4.317 +    Keyword.thy_decl
   4.318 +    (Parse.fixes -- Parse_Spec.where_alt_specs
   4.319 +      >> (fn (fixes, specs) => add_primrec_cmd fixes specs #> snd));
   4.320 +
   4.321 +end;
     5.1 --- a/src/HOL/Tools/primrec.ML	Fri Dec 16 10:38:38 2011 +0100
     5.2 +++ /dev/null	Thu Jan 01 00:00:00 1970 +0000
     5.3 @@ -1,318 +0,0 @@
     5.4 -(*  Title:      HOL/Tools/primrec.ML
     5.5 -    Author:     Norbert Voelker, FernUni Hagen
     5.6 -    Author:     Stefan Berghofer, TU Muenchen
     5.7 -    Author:     Florian Haftmann, TU Muenchen
     5.8 -
     5.9 -Primitive recursive functions on datatypes.
    5.10 -*)
    5.11 -
    5.12 -signature PRIMREC =
    5.13 -sig
    5.14 -  val add_primrec: (binding * typ option * mixfix) list ->
    5.15 -    (Attrib.binding * term) list -> local_theory -> (term list * thm list) * local_theory
    5.16 -  val add_primrec_cmd: (binding * string option * mixfix) list ->
    5.17 -    (Attrib.binding * string) list -> local_theory -> (term list * thm list) * local_theory
    5.18 -  val add_primrec_global: (binding * typ option * mixfix) list ->
    5.19 -    (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    5.20 -  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    5.21 -    (binding * typ option * mixfix) list ->
    5.22 -    (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
    5.23 -  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    5.24 -    local_theory -> (string * (term list * thm list)) * local_theory
    5.25 -end;
    5.26 -
    5.27 -structure Primrec : PRIMREC =
    5.28 -struct
    5.29 -
    5.30 -exception PrimrecError of string * term option;
    5.31 -
    5.32 -fun primrec_error msg = raise PrimrecError (msg, NONE);
    5.33 -fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    5.34 -
    5.35 -
    5.36 -(* preprocessing of equations *)
    5.37 -
    5.38 -fun process_eqn is_fixed spec rec_fns =
    5.39 -  let
    5.40 -    val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    5.41 -    val body = strip_qnt_body "all" spec;
    5.42 -    val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
    5.43 -      (fn Free (v, _) => insert (op =) v | _ => I) body []));
    5.44 -    val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    5.45 -    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    5.46 -      handle TERM _ => primrec_error "not a proper equation";
    5.47 -    val (recfun, args) = strip_comb lhs;
    5.48 -    val fname =
    5.49 -      (case recfun of
    5.50 -        Free (v, _) =>
    5.51 -          if is_fixed v then v
    5.52 -          else primrec_error "illegal head of function equation"
    5.53 -      | _ => primrec_error "illegal head of function equation");
    5.54 -
    5.55 -    val (ls', rest)  = take_prefix is_Free args;
    5.56 -    val (middle, rs') = take_suffix is_Free rest;
    5.57 -    val rpos = length ls';
    5.58 -
    5.59 -    val (constr, cargs') =
    5.60 -      if null middle then primrec_error "constructor missing"
    5.61 -      else strip_comb (hd middle);
    5.62 -    val (cname, T) = dest_Const constr
    5.63 -      handle TERM _ => primrec_error "ill-formed constructor";
    5.64 -    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    5.65 -      primrec_error "cannot determine datatype associated with function"
    5.66 -
    5.67 -    val (ls, cargs, rs) =
    5.68 -      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    5.69 -      handle TERM _ => primrec_error "illegal argument in pattern";
    5.70 -    val lfrees = ls @ rs @ cargs;
    5.71 -
    5.72 -    fun check_vars _ [] = ()
    5.73 -      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
    5.74 -  in
    5.75 -    if length middle > 1 then
    5.76 -      primrec_error "more than one non-variable in pattern"
    5.77 -    else
    5.78 -     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    5.79 -      check_vars "extra variables on rhs: "
    5.80 -        (Term.add_frees rhs [] |> subtract (op =) lfrees
    5.81 -          |> filter_out (is_fixed o fst));
    5.82 -      (case AList.lookup (op =) rec_fns fname of
    5.83 -        NONE =>
    5.84 -          (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
    5.85 -      | SOME (_, rpos', eqns) =>
    5.86 -          if AList.defined (op =) eqns cname then
    5.87 -            primrec_error "constructor already occurred as pattern"
    5.88 -          else if rpos <> rpos' then
    5.89 -            primrec_error "position of recursive argument inconsistent"
    5.90 -          else
    5.91 -            AList.update (op =)
    5.92 -              (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
    5.93 -              rec_fns))
    5.94 -  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    5.95 -
    5.96 -fun process_fun descr eqns (i, fname) (fnames, fnss) =
    5.97 -  let
    5.98 -    val (_, (tname, _, constrs)) = nth descr i;
    5.99 -
   5.100 -    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
   5.101 -
   5.102 -    fun subst [] t fs = (t, fs)
   5.103 -      | subst subs (Abs (a, T, t)) fs =
   5.104 -          fs
   5.105 -          |> subst subs t
   5.106 -          |-> (fn t' => pair (Abs (a, T, t')))
   5.107 -      | subst subs (t as (_ $ _)) fs =
   5.108 -          let
   5.109 -            val (f, ts) = strip_comb t;
   5.110 -          in
   5.111 -            if is_Free f
   5.112 -              andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   5.113 -              let
   5.114 -                val (fname', _) = dest_Free f;
   5.115 -                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   5.116 -                val (ls, rs) = chop rpos ts
   5.117 -                val (x', rs') =
   5.118 -                  (case rs of
   5.119 -                    x' :: rs => (x', rs)
   5.120 -                  | [] => primrec_error ("not enough arguments in recursive application\n" ^
   5.121 -                      "of function " ^ quote fname' ^ " on rhs"));
   5.122 -                val (x, xs) = strip_comb x';
   5.123 -              in
   5.124 -                (case AList.lookup (op =) subs x of
   5.125 -                  NONE =>
   5.126 -                    fs
   5.127 -                    |> fold_map (subst subs) ts
   5.128 -                    |-> (fn ts' => pair (list_comb (f, ts')))
   5.129 -                | SOME (i', y) =>
   5.130 -                    fs
   5.131 -                    |> fold_map (subst subs) (xs @ ls @ rs')
   5.132 -                    ||> process_fun descr eqns (i', fname')
   5.133 -                    |-> (fn ts' => pair (list_comb (y, ts'))))
   5.134 -              end
   5.135 -            else
   5.136 -              fs
   5.137 -              |> fold_map (subst subs) (f :: ts)
   5.138 -              |-> (fn f' :: ts' => pair (list_comb (f', ts')))
   5.139 -          end
   5.140 -      | subst _ t fs = (t, fs);
   5.141 -
   5.142 -    (* translate rec equations into function arguments suitable for rec comb *)
   5.143 -
   5.144 -    fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   5.145 -      (case AList.lookup (op =) eqns cname of
   5.146 -        NONE => (warning ("No equation for constructor " ^ quote cname ^
   5.147 -          "\nin definition of function " ^ quote fname);
   5.148 -            (fnames', fnss', (Const (@{const_name undefined}, dummyT)) :: fns))
   5.149 -      | SOME (ls, cargs', rs, rhs, eq) =>
   5.150 -          let
   5.151 -            val recs = filter (Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
   5.152 -            val rargs = map fst recs;
   5.153 -            val subs = map (rpair dummyT o fst)
   5.154 -              (rev (Term.rename_wrt_term rhs rargs));
   5.155 -            val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   5.156 -              (Free x, (Datatype_Aux.body_index y, Free z))) recs subs) rhs (fnames', fnss')
   5.157 -                handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   5.158 -          in
   5.159 -            (fnames'', fnss'', fold_rev absfree (cargs' @ subs @ ls @ rs) rhs' :: fns)
   5.160 -          end)
   5.161 -
   5.162 -  in
   5.163 -    (case AList.lookup (op =) fnames i of
   5.164 -      NONE =>
   5.165 -        if exists (fn (_, v) => fname = v) fnames then
   5.166 -          primrec_error ("inconsistent functions for datatype " ^ quote tname)
   5.167 -        else
   5.168 -          let
   5.169 -            val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   5.170 -            val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   5.171 -              ((i, fname) :: fnames, fnss, [])
   5.172 -          in
   5.173 -            (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
   5.174 -          end
   5.175 -    | SOME fname' =>
   5.176 -        if fname = fname' then (fnames, fnss)
   5.177 -        else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   5.178 -  end;
   5.179 -
   5.180 -
   5.181 -(* prepare functions needed for definitions *)
   5.182 -
   5.183 -fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   5.184 -  (case AList.lookup (op =) fns i of
   5.185 -    NONE =>
   5.186 -      let
   5.187 -        val dummy_fns = map (fn (_, cargs) => Const (@{const_name undefined},
   5.188 -          replicate (length cargs + length (filter Datatype_Aux.is_rec_type cargs))
   5.189 -            dummyT ---> HOLogic.unitT)) constrs;
   5.190 -        val _ = warning ("No function definition for datatype " ^ quote tname)
   5.191 -      in
   5.192 -        (dummy_fns @ fs, defs)
   5.193 -      end
   5.194 -  | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
   5.195 -
   5.196 -
   5.197 -(* make definition *)
   5.198 -
   5.199 -fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   5.200 -  let
   5.201 -    val SOME (var, varT) = get_first (fn ((b, T), mx) =>
   5.202 -      if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
   5.203 -    val def_name = Thm.def_name (Long_Name.base_name fname);
   5.204 -    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
   5.205 -      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
   5.206 -    val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
   5.207 -  in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
   5.208 -
   5.209 -
   5.210 -(* find datatypes which contain all datatypes in tnames' *)
   5.211 -
   5.212 -fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
   5.213 -  | find_dts dt_info tnames' (tname :: tnames) =
   5.214 -      (case Symtab.lookup dt_info tname of
   5.215 -        NONE => primrec_error (quote tname ^ " is not a datatype")
   5.216 -      | SOME dt =>
   5.217 -          if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   5.218 -            (tname, dt) :: (find_dts dt_info tnames' tnames)
   5.219 -          else find_dts dt_info tnames' tnames);
   5.220 -
   5.221 -
   5.222 -(* distill primitive definition(s) from primrec specification *)
   5.223 -
   5.224 -fun distill lthy fixes eqs = 
   5.225 -  let
   5.226 -    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   5.227 -      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   5.228 -    val tnames = distinct (op =) (map (#1 o snd) eqns);
   5.229 -    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
   5.230 -    val main_fns = map (fn (tname, {index, ...}) =>
   5.231 -      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   5.232 -    val {descr, rec_names, rec_rewrites, ...} =
   5.233 -      if null dts then primrec_error
   5.234 -        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   5.235 -      else snd (hd dts);
   5.236 -    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   5.237 -    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   5.238 -    val defs = map (make_def lthy fixes fs) raw_defs;
   5.239 -    val names = map snd fnames;
   5.240 -    val names_eqns = map fst eqns;
   5.241 -    val _ =
   5.242 -      if eq_set (op =) (names, names_eqns) then ()
   5.243 -      else primrec_error ("functions " ^ commas_quote names_eqns ^
   5.244 -        "\nare not mutually recursive");
   5.245 -    val rec_rewrites' = map mk_meta_eq rec_rewrites;
   5.246 -    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   5.247 -    fun prove lthy defs =
   5.248 -      let
   5.249 -        val frees = fold (Variable.add_free_names lthy) eqs [];
   5.250 -        val rewrites = rec_rewrites' @ map (snd o snd) defs;
   5.251 -        fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   5.252 -      in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   5.253 -  in ((prefix, (fs, defs)), prove) end
   5.254 -  handle PrimrecError (msg, some_eqn) =>
   5.255 -    error ("Primrec definition error:\n" ^ msg ^
   5.256 -      (case some_eqn of
   5.257 -        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   5.258 -      | NONE => ""));
   5.259 -
   5.260 -
   5.261 -(* primrec definition *)
   5.262 -
   5.263 -fun add_primrec_simple fixes ts lthy =
   5.264 -  let
   5.265 -    val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
   5.266 -  in
   5.267 -    lthy
   5.268 -    |> fold_map Local_Theory.define defs
   5.269 -    |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
   5.270 -  end;
   5.271 -
   5.272 -local
   5.273 -
   5.274 -fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   5.275 -  let
   5.276 -    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
   5.277 -    fun attr_bindings prefix = map (fn ((b, attrs), _) =>
   5.278 -      (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
   5.279 -    fun simp_attr_binding prefix =
   5.280 -      (Binding.qualify true prefix (Binding.name "simps"), @{attributes [simp, nitpick_simp]});
   5.281 -  in
   5.282 -    lthy
   5.283 -    |> add_primrec_simple fixes (map snd spec)
   5.284 -    |-> (fn (prefix, (ts, simps)) =>
   5.285 -      Spec_Rules.add Spec_Rules.Equational (ts, simps)
   5.286 -      #> fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps)
   5.287 -      #-> (fn simps' => Local_Theory.note (simp_attr_binding prefix, maps snd simps')
   5.288 -      #>> (fn (_, simps'') => (ts, simps''))))
   5.289 -  end;
   5.290 -
   5.291 -in
   5.292 -
   5.293 -val add_primrec = gen_primrec Specification.check_spec;
   5.294 -val add_primrec_cmd = gen_primrec Specification.read_spec;
   5.295 -
   5.296 -end;
   5.297 -
   5.298 -fun add_primrec_global fixes specs thy =
   5.299 -  let
   5.300 -    val lthy = Named_Target.theory_init thy;
   5.301 -    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   5.302 -    val simps' = Proof_Context.export lthy' lthy simps;
   5.303 -  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   5.304 -
   5.305 -fun add_primrec_overloaded ops fixes specs thy =
   5.306 -  let
   5.307 -    val lthy = Overloading.overloading ops thy;
   5.308 -    val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   5.309 -    val simps' = Proof_Context.export lthy' lthy simps;
   5.310 -  in ((ts, simps'), Local_Theory.exit_global lthy') end;
   5.311 -
   5.312 -
   5.313 -(* outer syntax *)
   5.314 -
   5.315 -val _ =
   5.316 -  Outer_Syntax.local_theory "primrec" "define primitive recursive functions on datatypes"
   5.317 -    Keyword.thy_decl
   5.318 -    (Parse.fixes -- Parse_Spec.where_alt_specs
   5.319 -      >> (fn (fixes, specs) => add_primrec_cmd fixes specs #> snd));
   5.320 -
   5.321 -end;