src/HOL/Tools/primrec_package.ML
author haftmann
Fri Dec 07 15:07:56 2007 +0100 (2007-12-07)
changeset 25570 fdfbbb92dadf
parent 25566 33f740c5e022
child 25604 6c1714b9b805
permissions -rw-r--r--
proper treatment of code theorems for primrec
     1 (*  Title:      HOL/Tools/primrec_package.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
     4                 Florian Haftmann, TU Muenchen
     5 
     6 Package for defining functions on datatypes by primitive recursion.
     7 *)
     8 
     9 signature PRIMREC_PACKAGE =
    10 sig
    11   val add_primrec: (string * typ option * mixfix) list ->
    12     ((bstring * Attrib.src list) * term) list -> local_theory -> thm list * local_theory
    13 end;
    14 
    15 structure PrimrecPackage : PRIMREC_PACKAGE =
    16 struct
    17 
    18 open DatatypeAux;
    19 
    20 exception PrimrecError of string * term option;
    21 
    22 fun primrec_error msg = raise PrimrecError (msg, NONE);
    23 fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    24 
    25 fun message s = if ! Toplevel.debug then () else writeln s;
    26 
    27 
    28 (* preprocessing of equations *)
    29 
    30 fun process_eqn is_fixed spec rec_fns =
    31   let
    32     val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    33     val body = strip_qnt_body "all" spec;
    34     val (vs', _) = Name.variants vs (Name.make_context (fold_aterms
    35       (fn Free (v, _) => insert (op =) v | _ => I) body []));
    36     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    37     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    38       handle TERM _ => primrec_error "not a proper equation";
    39     val (recfun, args) = strip_comb lhs;
    40     val fname = case recfun of Free (v, _) => if is_fixed v then v
    41           else primrec_error "illegal head of function equation"
    42       | _ => primrec_error "illegal head of function equation";
    43 
    44     val (ls', rest)  = take_prefix is_Free args;
    45     val (middle, rs') = take_suffix is_Free rest;
    46     val rpos = length ls';
    47 
    48     val (constr, cargs') = if null middle then primrec_error "constructor missing"
    49       else strip_comb (hd middle);
    50     val (cname, T) = dest_Const constr
    51       handle TERM _ => primrec_error "ill-formed constructor";
    52     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    53       primrec_error "cannot determine datatype associated with function"
    54 
    55     val (ls, cargs, rs) =
    56       (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    57       handle TERM _ => primrec_error "illegal argument in pattern";
    58     val lfrees = ls @ rs @ cargs;
    59 
    60     fun check_vars _ [] = ()
    61       | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
    62   in
    63     if length middle > 1 then
    64       primrec_error "more than one non-variable in pattern"
    65     else
    66      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    67       check_vars "extra variables on rhs: "
    68         (map dest_Free (term_frees rhs) |> subtract (op =) lfrees
    69           |> filter_out (is_fixed o fst));
    70       case AList.lookup (op =) rec_fns fname of
    71         NONE =>
    72           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
    73       | SOME (_, rpos', eqns) =>
    74           if AList.defined (op =) eqns cname then
    75             primrec_error "constructor already occurred as pattern"
    76           else if rpos <> rpos' then
    77             primrec_error "position of recursive argument inconsistent"
    78           else
    79             AList.update (op =)
    80               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
    81               rec_fns)
    82   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    83 
    84 fun process_fun descr eqns (i, fname) (fnames, fnss) =
    85   let
    86     val (_, (tname, _, constrs)) = nth descr i;
    87 
    88     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    89 
    90     fun subst [] t fs = (t, fs)
    91       | subst subs (Abs (a, T, t)) fs =
    92           fs
    93           |> subst subs t
    94           |-> (fn t' => pair (Abs (a, T, t')))
    95       | subst subs (t as (_ $ _)) fs =
    96           let
    97             val (f, ts) = strip_comb t;
    98           in
    99             if is_Free f
   100               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   101               let
   102                 val (fname', _) = dest_Free f;
   103                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   104                 val (ls, x' :: rs) = chop rpos ts
   105                   handle Empty => primrec_error ("not enough arguments\
   106                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   107                 val (x, xs) = strip_comb x'
   108               in case AList.lookup (op =) subs x
   109                of NONE =>
   110                     fs
   111                     |> fold_map (subst subs) ts
   112                     |-> (fn ts' => pair (list_comb (f, ts')))
   113                 | SOME (i', y) =>
   114                     fs
   115                     |> fold_map (subst subs) (xs @ ls @ rs)
   116                     ||> process_fun descr eqns (i', fname')
   117                     |-> (fn ts' => pair (list_comb (y, ts')))
   118               end
   119             else
   120               fs
   121               |> fold_map (subst subs) (f :: ts)
   122               |-> (fn (f'::ts') => pair (list_comb (f', ts')))
   123           end
   124       | subst _ t fs = (t, fs);
   125 
   126     (* translate rec equations into function arguments suitable for rec comb *)
   127 
   128     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   129       (case AList.lookup (op =) eqns cname of
   130           NONE => (warning ("No equation for constructor " ^ quote cname ^
   131             "\nin definition of function " ^ quote fname);
   132               (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
   133         | SOME (ls, cargs', rs, rhs, eq) =>
   134             let
   135               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   136               val rargs = map fst recs;
   137               val subs = map (rpair dummyT o fst)
   138                 (rev (rename_wrt_term rhs rargs));
   139               val (rhs', (fnames'', fnss'')) = (subst (map2 (fn (x, y) => fn z =>
   140                 (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss'))
   141                   handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   142             in (fnames'', fnss'',
   143                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   144             end)
   145 
   146   in (case AList.lookup (op =) fnames i of
   147       NONE =>
   148         if exists (fn (_, v) => fname = v) fnames then
   149           primrec_error ("inconsistent functions for datatype " ^ quote tname)
   150         else
   151           let
   152             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   153             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   154               ((i, fname)::fnames, fnss, [])
   155           in
   156             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   157           end
   158     | SOME fname' =>
   159         if fname = fname' then (fnames, fnss)
   160         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   161   end;
   162 
   163 
   164 (* prepare functions needed for definitions *)
   165 
   166 fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   167   case AList.lookup (op =) fns i of
   168      NONE =>
   169        let
   170          val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
   171            replicate ((length cargs) + (length (List.filter is_rec_type cargs)))
   172              dummyT ---> HOLogic.unitT)) constrs;
   173          val _ = warning ("No function definition for datatype " ^ quote tname)
   174        in
   175          (dummy_fns @ fs, defs)
   176        end
   177    | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
   178 
   179 
   180 (* make definition *)
   181 
   182 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   183   let
   184     val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   185                     ((map snd ls) @ [dummyT])
   186                     (list_comb (Const (rec_name, dummyT),
   187                                 fs @ map Bound (0 ::(length ls downto 1))))
   188     val def_name = Thm.def_name (Sign.base_name fname);
   189     val rhs = singleton (Syntax.check_terms ctxt) raw_rhs;
   190     val SOME mfx = get_first
   191       (fn ((v, _), mfx) => if v = fname then SOME mfx else NONE) fixes;
   192   in ((fname, mfx), ((def_name, []), rhs)) end;
   193 
   194 
   195 (* find datatypes which contain all datatypes in tnames' *)
   196 
   197 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   198   | find_dts dt_info tnames' (tname::tnames) =
   199       (case Symtab.lookup dt_info tname of
   200           NONE => primrec_error (quote tname ^ " is not a datatype")
   201         | SOME dt =>
   202             if tnames' subset (map (#1 o snd) (#descr dt)) then
   203               (tname, dt)::(find_dts dt_info tnames' tnames)
   204             else find_dts dt_info tnames' tnames);
   205 
   206 
   207 (* adapted induction rule *)
   208 
   209 fun prepare_induct ({descr, induction, ...}: datatype_info) eqns =
   210   let
   211     fun constrs_of (_, (_, _, cs)) =
   212       map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   213     val params_of = these o AList.lookup (op =) (List.concat (map constrs_of eqns));
   214   in
   215     induction
   216     |> RuleCases.rename_params (map params_of (maps (map #1 o #3 o #2) descr))
   217     |> RuleCases.save induction
   218   end;
   219 
   220 
   221 (* primrec definition *)
   222 
   223 local
   224 
   225 fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
   226   let
   227     val ((fixes, spec), _) = prep_spec
   228       raw_fixes (map (single o apsnd single) raw_spec) ctxt
   229   in (fixes, map (apsnd the_single) spec) end;
   230 
   231 fun prove_spec ctxt names rec_rewrites defs =
   232   let
   233     val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs;
   234     fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   235     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
   236   in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end;
   237 
   238 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
   239   let
   240     val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec;
   241     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   242       orelse exists (fn ((w, _), _) => v = w) fixes) o snd) spec [];
   243     val tnames = distinct (op =) (map (#1 o snd) eqns);
   244     val dts = find_dts (DatatypePackage.get_datatypes
   245       (ProofContext.theory_of lthy)) tnames tnames;
   246     val main_fns = map (fn (tname, {index, ...}) =>
   247       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   248     val {descr, rec_names, rec_rewrites, ...} =
   249       if null dts then primrec_error
   250         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   251       else snd (hd dts);
   252     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   253     val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   254     val names1 = map snd fnames;
   255     val names2 = map fst eqns;
   256     val _ = if gen_eq_set (op =) (names1, names2) then ()
   257       else primrec_error ("functions " ^ commas_quote names2 ^
   258         "\nare not mutually recursive");
   259     val qualify = NameSpace.qualified
   260       (space_implode "_" (map (Sign.base_name o #1) defs));
   261     val simp_atts = map (Attrib.internal o K)
   262       [Simplifier.simp_add, RecfunCodegen.add NONE];
   263   in
   264     lthy
   265     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   266     |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec))
   267     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   268     |-> (fn simps' => LocalTheory.note Thm.theoremK
   269           ((qualify "simps", simp_atts), maps snd simps'))
   270     ||>> LocalTheory.note Thm.theoremK
   271           ((qualify "induct", []), [prepare_induct (#2 (hd dts)) eqns])
   272     |>> (snd o fst)
   273   end handle PrimrecError (msg, some_eqn) =>
   274     error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
   275      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   276       | NONE => ""));
   277 
   278 in
   279 
   280 val add_primrec = gen_primrec Specification.check_specification;
   281 val add_primrec_cmd = gen_primrec Specification.read_specification;
   282 
   283 end;
   284 
   285 
   286 (* outer syntax *)
   287 
   288 local structure P = OuterParse and K = OuterKeyword in
   289 
   290 val opt_unchecked_name =
   291   Scan.optional (P.$$$ "(" |-- P.!!!
   292     (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" ||
   293       P.name >> pair false) --| P.$$$ ")")) (false, "");
   294 
   295 val old_primrec_decl =
   296   opt_unchecked_name -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop);
   297 
   298 fun pipe_error t = P.!!! (Scan.fail_with (K
   299   (cat_lines ["Equations must be separated by " ^ quote "|", quote t])));
   300 
   301 val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead
   302   ((P.term :-- pipe_error) || Scan.succeed ("",""));
   303 
   304 val statements = P.enum1 "|" statement;
   305 
   306 val primrec_decl = P.opt_target -- P.fixes --| P.$$$ "where" -- statements;
   307 
   308 val _ =
   309   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
   310     ((primrec_decl >> (fn ((opt_target, raw_fixes), raw_spec) =>
   311       Toplevel.local_theory opt_target (add_primrec_cmd raw_fixes raw_spec #> snd)))
   312     || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
   313       Toplevel.theory (snd o
   314         (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec) alt_name
   315           (map P.triple_swap eqns)))));
   316 
   317 end;
   318 
   319 end;