src/HOL/Tools/primrec_package.ML
author bulwahn
Sat May 16 15:23:52 2009 +0200 (2009-05-16)
changeset 31172 74d72ba262fb
parent 30487 a14ff49d3083
child 31174 f1f1e9b53c81
permissions -rw-r--r--
added collection of simplification rules of recursive functions for quickcheck
     1 (*  Title:      HOL/Tools/primrec_package.ML
     2     Author:     Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
     3                 Florian Haftmann, TU Muenchen
     4 
     5 Package for defining functions on datatypes by primitive recursion.
     6 *)
     7 
     8 signature PRIMREC_PACKAGE =
     9 sig
         (* local-theory entry point on checked input: the specification is
            prepared with Specification.check_spec *)
    10   val add_primrec: (binding * typ option * mixfix) list ->
    11     (Attrib.binding * term) list -> local_theory -> thm list * local_theory
         (* local-theory entry point on string input: the specification is
            prepared with Specification.read_spec; used by the "primrec" command *)
    12   val add_primrec_cmd: (binding * string option * mixfix) list ->
    13     (Attrib.binding * string) list -> local_theory -> thm list * local_theory
         (* theory-level wrapper: runs add_primrec in a target created by
            TheoryTarget.init NONE and exits back to the theory *)
    14   val add_primrec_global: (binding * typ option * mixfix) list ->
    15     (Attrib.binding * term) list -> theory -> thm list * theory
         (* theory-level wrapper: runs add_primrec in a target created by
            TheoryTarget.overloading from the first argument *)
    16   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    17     (binding * typ option * mixfix) list ->
    18     (Attrib.binding * term) list -> theory -> thm list * theory
    19 end;
    20 
    21 structure PrimrecPackage : PRIMREC_PACKAGE =
    22 struct
    23 
    24 open DatatypeAux;
    25 
    26 exception PrimrecError of string * term option;
       (* PrimrecError carries a message and, optionally, the equation the
          message refers to. *)
    27 
       (* raise an error that is not (yet) associated with an equation *)
    28 fun primrec_error msg = raise PrimrecError (msg, NONE);
       (* raise an error that refers to the equation eqn *)
    29 fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
    30 
       (* emit s via tracing, but only while the Toplevel.debug flag is set *)
    31 fun message s = if ! Toplevel.debug then tracing s else ();
    32 
    33 
    34 (* preprocessing of equations *)
    35 
    36 fun process_eqn is_fixed spec rec_fns =
       (* Analyse the single specification equation spec and record it in rec_fns.

          rec_fns is an association list from function name to
            (datatype name, position of the recursive argument,
             list of (constructor name, (ls, cargs, rs, rhs, eqn))).
          is_fixed decides whether a free variable name is acceptable as the
          head of the equation.

          All failures are raised as PrimrecError; the handler on the last line
          attaches spec to every error that was raised without an equation. *)
    37   let
           (* strip the outer "all" quantifiers and replace the bound variables by
              Frees whose names are made distinct from the free names in the body *)
    38     val (vs, Ts) = split_list (strip_qnt_vars "all" spec);
    39     val body = strip_qnt_body "all" spec;
    40     val (vs', _) = Name.variants vs (Name.make_context (fold_aterms
    41       (fn Free (v, _) => insert (op =) v | _ => I) body []));
    42     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    43     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    44       handle TERM _ => primrec_error "not a proper equation";
           (* the head of the lhs must be a Free accepted by is_fixed *)
    45     val (recfun, args) = strip_comb lhs;
    46     val fname = case recfun of Free (v, _) => if is_fixed v then v
    47           else primrec_error "illegal head of function equation"
    48       | _ => primrec_error "illegal head of function equation";
    49 
           (* split the arguments into leading Frees, a middle part, and trailing
              Frees; the number of leading Frees is the recursive position *)
    50     val (ls', rest)  = take_prefix is_Free args;
    51     val (middle, rs') = take_suffix is_Free rest;
    52     val rpos = length ls';
    53 
           (* the first middle argument must be a constant applied to arguments;
              the datatype name is read off the result type of that constant *)
    54     val (constr, cargs') = if null middle then primrec_error "constructor missing"
    55       else strip_comb (hd middle);
    56     val (cname, T) = dest_Const constr
    57       handle TERM _ => primrec_error "ill-formed constructor";
    58     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    59       primrec_error "cannot determine datatype associated with function"
    60 
           (* every pattern argument, including the constructor arguments, must be a Free *)
    61     val (ls, cargs, rs) =
    62       (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
    63       handle TERM _ => primrec_error "illegal argument in pattern";
    64     val lfrees = ls @ rs @ cargs;
    65 
           (* NOTE(review): primrec_error raises before its result could be applied
              to eqn, so the trailing eqn argument has no effect here; the error is
              raised with NONE and the final handler attaches spec instead *)
    66     fun check_vars _ [] = ()
    67       | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
    68   in
    69     if length middle > 1 then
    70       primrec_error "more than one non-variable in pattern"
    71     else
    72      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
            (* free variables of the rhs must occur in the pattern or satisfy is_fixed *)
    73       check_vars "extra variables on rhs: "
    74         (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees
    75           |> filter_out (is_fixed o fst));
            (* first equation for fname: create its entry; otherwise extend the
               entry after checking constructor uniqueness and recursive position *)
    76       case AList.lookup (op =) rec_fns fname of
    77         NONE =>
    78           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
    79       | SOME (_, rpos', eqns) =>
    80           if AList.defined (op =) eqns cname then
    81             primrec_error "constructor already occurred as pattern"
    82           else if rpos <> rpos' then
    83             primrec_error "position of recursive argument inconsistent"
    84           else
    85             AList.update (op =)
    86               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
    87               rec_fns)
    88   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    89 
    90 fun process_fun descr eqns (i, fname) (fnames, fnss) =
       (* Translate the equations of function fname, which recurses over the
          datatype with index i in descr, into arguments for the recursion
          combinator.

          The accumulated state is (fnames, fnss):
            fnames  association list from datatype index to function name
            fnss    association list from datatype index to
                    (function name, ls of the first equation, combinator arguments)
          eqns is the association list built by process_eqn. *)
    91   let
    92     val (_, (tname, _, constrs)) = nth descr i;
    93 
    94     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    95 
           (* subst threads the (fnames, fnss) state through the term traversal:
              when a recursive call on a constructor argument listed in subs is
              found, process_fun is invoked for the function being called *)
    96     fun subst [] t fs = (t, fs)
    97       | subst subs (Abs (a, T, t)) fs =
    98           fs
    99           |> subst subs t
   100           |-> (fn t' => pair (Abs (a, T, t')))
   101       | subst subs (t as (_ $ _)) fs =
   102           let
   103             val (f, ts) = strip_comb t;
   104           in
                   (* is the head one of the functions being defined? *)
   105             if is_Free f
   106               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   107               let
   108                 val (fname', _) = dest_Free f;
   109                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
                       (* split the actual arguments at the recursive position of fname' *)
   110                 val (ls, rs) = chop rpos ts
   111                 val (x', rs') = case rs
   112                  of x' :: rs => (x', rs)
   113                   | [] => primrec_error ("not enough arguments in recursive application\n"
   114                       ^ "of function " ^ quote fname' ^ " on rhs");
   115                 val (x, xs) = strip_comb x';
   116               in case AList.lookup (op =) subs x
   117                of NONE =>
                           (* not a call on a recursive constructor argument:
                              keep the head and only rewrite the arguments *)
   118                     fs
   119                     |> fold_map (subst subs) ts
   120                     |-> (fn ts' => pair (list_comb (f, ts')))
   121                 | SOME (i', y) =>
                           (* replace the call by y applied to the rewritten
                              xs @ ls @ rs', and process fname' for index i' *)
   122                     fs
   123                     |> fold_map (subst subs) (xs @ ls @ rs')
   124                     ||> process_fun descr eqns (i', fname')
   125                     |-> (fn ts' => pair (list_comb (y, ts')))
   126               end
   127             else
   128               fs
   129               |> fold_map (subst subs) (f :: ts)
   130               |-> (fn (f'::ts') => pair (list_comb (f', ts')))
   131           end
   132       | subst _ t fs = (t, fs);
   133 
   134     (* translate rec equations into function arguments suitable for rec comb *)
   135 
           (* note: the eqns parameter of trans shadows the outer eqns; here it is
              the per-constructor equation list of the current function *)
   136     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   137       (case AList.lookup (op =) eqns cname of
                 (* constructor without equation: warn and use HOL.undefined *)
   138           NONE => (warning ("No equation for constructor " ^ quote cname ^
   139             "\nin definition of function " ^ quote fname);
   140               (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
   141         | SOME (ls, cargs', rs, rhs, eq) =>
   142             let
                     (* recursive constructor arguments, and one variable per such
                        argument, named via Term.rename_wrt_term with respect to rhs *)
   143               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   144               val rargs = map fst recs;
   145               val subs = map (rpair dummyT o fst)
   146                 (rev (Term.rename_wrt_term rhs rargs));
   147               val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   148                 (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
   149                   handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   150             in (fnames'', fnss'',
   151                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   152             end)
   153 
   154   in (case AList.lookup (op =) fnames i of
   155       NONE =>
               (* index i not handled yet; fname must not already belong to another index *)
   156         if exists (fn (_, v) => fname = v) fnames then
   157           primrec_error ("inconsistent functions for datatype " ^ quote tname)
   158         else
   159           let
   160             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
                   (* (i, fname) is registered before the constructors are translated *)
   161             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   162               ((i, fname)::fnames, fnss, [])
   163           in
                   (* #1 (snd (hd eqns)) is the ls component of the first equation *)
   164             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   165           end
   166     | SOME fname' =>
               (* index i already handled: accept only the same function name *)
   167         if fname = fname' then (fnames, fnss)
   168         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   169   end;
   170 
   171 
   172 (* prepare functions needed for definitions *)
   173 
   (* Contribute the combinator arguments of one descriptor entry to (fs, defs).
      If fns holds an entry for datatype index i, its arguments are prepended to
      fs and a definition record (fname, ls, rec_name, tname) is added to defs.
      Otherwise a warning is issued and one HOL.undefined placeholder per
      constructor is prepended instead. *)
   fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
     let
       (* placeholder for one constructor: a dummy-typed parameter for every
          constructor argument plus one for every recursive argument *)
       fun undefined_case (_, cargs) =
         let
           val n_args = length cargs;
           val n_rec = length (List.filter is_rec_type cargs);
         in
           Const ("HOL.undefined", replicate (n_args + n_rec) dummyT ---> HOLogic.unitT)
         end;
     in
       (case AList.lookup (op =) fns i of
         SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs)
       | NONE =>
           let
             val dummy_fns = map undefined_case constrs;
             val _ = warning ("No function definition for datatype " ^ quote tname);
           in (dummy_fns @ fs, defs) end)
     end;
   186 
   187 
   188 (* make definition *)
   189 
   (* Build the definition of function fname in terms of the recursion
      combinator rec_name.

      fixes     the declared functions as ((binding, type), mixfix)
      fs        the arguments of the recursion combinator
      ls        the parameters preceding the recursive argument

      Returns ((binding, mixfix), ((definition name, []), rhs)), where rhs has
      been checked against the declared type of fname.

      Raises PrimrecError if fname is not declared in fixes.  This can happen
      when the head of an equation was accepted only because it is fixed in the
      surrounding context; the former refutable binding "val SOME ... =" failed
      with a bare Bind exception in that case. *)
   fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
     let
       val (var, varT) =
         (case get_first (fn ((b, T), mx) =>
             if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes of
           SOME decl => decl
         | NONE => primrec_error ("no declaration for function " ^ quote fname ^
             " among the fixes of this definition"));
       val def_name = Thm.def_name (Long_Name.base_name fname);
       (* abstract over ls and then over the recursive argument; inside, Bound 0
          is the recursive argument and Bound (length ls) ... Bound 1 are the ls *)
       val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
         (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
       val rhs = singleton (Syntax.check_terms ctxt)
         (TypeInfer.constrain varT raw_rhs);
     in (var, ((Binding.name def_name, []), rhs)) end;
   200 
   201 
   202 (* find datatypes which contain all datatypes in tnames' *)
   203 
   (* Walk through the given type names in order and keep, paired with its
      datatype info, every one whose descriptor mentions all names in tnames'.
      A name without an entry in dt_info aborts with a PrimrecError. *)
   fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
     | find_dts dt_info tnames' (tname :: tnames) =
         let
           val dt =
             (case Symtab.lookup dt_info tname of
               SOME info => info
             | NONE => primrec_error (quote tname ^ " is not a datatype"));
           val covers_all = tnames' subset (map (#1 o snd) (#descr dt));
           val others = find_dts dt_info tnames' tnames;
         in
           if covers_all then (tname, dt) :: others else others
         end;
   212 
   213 
   214 (* primrec definition *)
   215 
   216 local
   217 
   (* Prove every equation in eqs by rewriting the goal with the recursion
      rules (turned into meta equations) and the definitional theorems from
      defs, then closing it with refl.  Each result is paired with its
      original binding as a singleton theorem list. *)
   fun prove_spec ctxt names rec_rewrites defs eqs =
     let
       val meta_rec_rules = map mk_meta_eq rec_rewrites;
       val def_thms = map (fn (_, (_, def_thm)) => def_thm) defs;
       val rewrites = meta_rec_rules @ def_thms;
       val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
       fun prove_eq (a, t) =
         (a, [Goal.prove ctxt [] [] t
           (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])]);
     in map prove_eq eqs end;
   224 
   225 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy =
       (* Common implementation of the primrec definition.

          set_group   if true, a fresh LocalTheory group is set before defining
          prep_spec   prepares raw_fixes / raw_spec into fixes and a specification
          Returns the theorems noted under "simps" together with the updated
          local theory.  Any PrimrecError is turned into a user error that
          includes the offending equation, if one is attached. *)
   226   let
   227     val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
           (* a head is admissible if it is fixed in lthy or declared in fixes *)
   228     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   229       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec [];
   230     val tnames = distinct (op =) (map (#1 o snd) eqns);
           (* datatypes among tnames whose descriptor covers all of tnames *)
   231     val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames;
           (* for each such datatype: its index and the first function defined over it *)
   232     val main_fns = map (fn (tname, {index, ...}) =>
   233       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   234     val {descr, rec_names, rec_rewrites, ...} =
   235       if null dts then primrec_error
   236         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   237       else snd (hd dts);
   238     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   239     val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
           (* the functions reached by process_fun must be exactly those with equations *)
   240     val names1 = map snd fnames;
   241     val names2 = map fst eqns;
   242     val _ = if gen_eq_set (op =) (names1, names2) then ()
   243       else primrec_error ("functions " ^ commas_quote names2 ^
   244         "\nare not mutually recursive");
           (* theorem bindings are qualified by the base names of the defined
              functions joined with "_"; every equation also receives
              Code.add_default_eqn_attrib *)
   245     val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs);
   246     val qualify = Binding.qualify false prefix;
   247     val spec' = (map o apfst)
   248       (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec;
   249     val simp_atts = map (Attrib.internal o K)
   250       [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add];
   251   in
           (* define the functions, prove the user equations, note each of them,
              and finally note all of them together as "simps" with simp_atts *)
   252     lthy
   253     |> set_group ? LocalTheory.set_group (serial_string ())
   254     |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs
   255     |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec'))
   256     |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps)
   257     |-> (fn simps' => LocalTheory.note Thm.theoremK
   258           ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps'))
   259     |>> snd
   260   end handle PrimrecError (msg, some_eqn) =>
   261     error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
   262      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   263       | NONE => ""));
   264 
   265 in
   266 
   267 val add_primrec = gen_primrec false Specification.check_spec;
       (* ^ checked input (types and terms); no new LocalTheory group is set *)
   268 val add_primrec_cmd = gen_primrec true Specification.read_spec;
       (* ^ string input prepared by Specification.read_spec; a fresh LocalTheory
            group is set before the definitions are made *)
   269 
   270 end;
   271 
   (* Theory-level variant of add_primrec: the definition is carried out in a
      target obtained from TheoryTarget.init NONE, the resulting theorems are
      exported from the final into the initial target context, and the target
      is left again via LocalTheory.exit_global. *)
   fun add_primrec_global fixes specs thy =
     let
       val initial = TheoryTarget.init NONE thy;
       val (raw_simps, final) = initial |> add_primrec fixes specs;
     in
       (ProofContext.export final initial raw_simps, LocalTheory.exit_global final)
     end;
   278 
   (* Theory-level variant of add_primrec for overloaded definitions: the
      definition is carried out in a target obtained from
      TheoryTarget.overloading ops, the resulting theorems are exported from the
      final into the initial target context, and the target is left again via
      LocalTheory.exit_global. *)
   fun add_primrec_overloaded ops fixes specs thy =
     let
       val initial = TheoryTarget.overloading ops thy;
       val (raw_simps, final) = initial |> add_primrec fixes specs;
     in
       (ProofContext.export final initial raw_simps, LocalTheory.exit_global final)
     end;
   285 
   286 
   287 (* outer syntax *)
   288 
   289 local structure P = OuterParse and K = OuterKeyword in
       (* Registers the "primrec" command.  Two input formats are accepted: the
          current one (optional target, fixes, "where"-style specifications),
          which is tried first, and the old one, which is handed over to
          OldPrimrecPackage. *)
   290 
       (* old format only: optional parenthesised prefix, either
          "(unchecked [name])" or "(name)"; defaults to (false, "") *)
   291 val opt_unchecked_name =
   292   Scan.optional (P.$$$ "(" |-- P.!!!
   293     (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" ||
   294       P.name >> pair false) --| P.$$$ ")")) (false, "");
   295 
       (* old format: the prefix above followed by one or more optionally named
          propositions; only the name of each binding is kept *)
   296 val old_primrec_decl =
   297   opt_unchecked_name -- Scan.repeat1 ((SpecParse.opt_thm_name ":" >> apfst Binding.name_of) -- P.prop);
   298 
       (* current format: optional target, fixes, and specifications *)
   299 val primrec_decl = P.opt_target -- P.fixes -- SpecParse.where_alt_specs;
   300 
   301 val _ =
   302   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
           (* current format: run add_primrec_cmd in the given target and drop the theorems *)
   303     ((primrec_decl >> (fn ((opt_target, fixes), specs) =>
   304       Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd)))
           (* old format: theory transaction using the checked or unchecked variant
              of OldPrimrecPackage, depending on the parsed flag *)
   305     || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
   306       Toplevel.theory (snd o
   307         (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec)
   308           alt_name (map P.triple_swap eqns)))));
   309 
   310 end;
   311 
   312 end;