src/ZF/Tools/primrec_package.ML
author wenzelm
Wed Apr 13 18:45:25 2005 +0200 (2005-04-13)
changeset 15705 b5edb9dcec9a
parent 15703 727ef1b8b3ee
child 17057 0934ac31985f
permissions -rw-r--r--
*** MESSAGE REFERS TO PREVIOUS VERSION ***
Attrib.src;
     1 (*  Title:      ZF/Tools/primrec_package.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer and Norbert Voelker
     4     Copyright   1998  TU Muenchen
     5     ZF version by Lawrence C Paulson (Cambridge)
     6 
     7 Package for defining functions on datatypes by primitive recursion
     8 *)
     9 
    10 signature PRIMREC_PACKAGE =
    11 sig
    12   val quiet_mode: bool ref
    13   val add_primrec: ((bstring * string) * Attrib.src list) list -> theory -> theory * thm list
    14   val add_primrec_i: ((bstring * term) * theory attribute list) list -> theory -> theory * thm list
    15 end;
    16 
    17 structure PrimrecPackage : PRIMREC_PACKAGE =
    18 struct
    19 
    20 (* messages *)
    21 
    22 val quiet_mode = ref false;
    23 fun message s = if ! quiet_mode then () else writeln s;
    24 
    25 
    26 exception RecError of string;
    27 
    28 (*Remove outer Trueprop and equality sign*)
    29 val dest_eqn = FOLogic.dest_eq o FOLogic.dest_Trueprop;
    30 
    31 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    32 
    33 fun primrec_eq_err sign s eq =
    34   primrec_err (s ^ "\nin equation\n" ^ Sign.string_of_term sign eq);
    35 
    36 
    37 (* preprocessing of equations *)
    38 
    39 (*rec_fn_opt records equations already noted for this function*)
    40 fun process_eqn thy (eq, rec_fn_opt) =
    41   let
    42     val (lhs, rhs) =
    43         if null (term_vars eq) then
    44             dest_eqn eq handle TERM _ => raise RecError "not a proper equation"
    45         else raise RecError "illegal schematic variable(s)";
    46 
    47     val (recfun, args) = strip_comb lhs;
    48     val (fname, ftype) = dest_Const recfun handle TERM _ =>
    49       raise RecError "function is not declared as constant in theory";
    50 
    51     val (ls_frees, rest)  = take_prefix is_Free args;
    52     val (middle, rs_frees) = take_suffix is_Free rest;
    53 
    54     val (constr, cargs_frees) =
    55       if null middle then raise RecError "constructor missing"
    56       else strip_comb (hd middle);
    57     val (cname, _) = dest_Const constr
    58       handle TERM _ => raise RecError "ill-formed constructor";
    59     val con_info = valOf (Symtab.lookup (ConstructorsData.get thy, cname))
    60       handle Option =>
    61       raise RecError "cannot determine datatype associated with function"
    62 
    63     val (ls, cargs, rs) = (map dest_Free ls_frees,
    64                            map dest_Free cargs_frees,
    65                            map dest_Free rs_frees)
    66       handle TERM _ => raise RecError "illegal argument in pattern";
    67     val lfrees = ls @ rs @ cargs;
    68 
    69     (*Constructor, frees to left of pattern, pattern variables,
    70       frees to right of pattern, rhs of equation, full original equation. *)
    71     val new_eqn = (cname, (rhs, cargs, eq))
    72 
    73   in
    74     if not (null (duplicates lfrees)) then
    75       raise RecError "repeated variable name in pattern"
    76     else if not ((map dest_Free (term_frees rhs)) subset lfrees) then
    77       raise RecError "extra variables on rhs"
    78     else if length middle > 1 then
    79       raise RecError "more than one non-variable in pattern"
    80     else case rec_fn_opt of
    81         NONE => SOME (fname, ftype, ls, rs, con_info, [new_eqn])
    82       | SOME (fname', _, ls', rs', con_info': constructor_info, eqns) =>
    83           if isSome (assoc (eqns, cname)) then
    84             raise RecError "constructor already occurred as pattern"
    85           else if (ls <> ls') orelse (rs <> rs') then
    86             raise RecError "non-recursive arguments are inconsistent"
    87           else if #big_rec_name con_info <> #big_rec_name con_info' then
    88              raise RecError ("Mixed datatypes for function " ^ fname)
    89           else if fname <> fname' then
    90              raise RecError ("inconsistent functions for datatype " ^
    91                              #big_rec_name con_info)
    92           else SOME (fname, ftype, ls, rs, con_info, new_eqn::eqns)
    93   end
    94   handle RecError s => primrec_eq_err (sign_of thy) s eq;
    95 
    96 
    97 (*Instantiates a recursor equation with constructor arguments*)
    98 fun inst_recursor ((_ $ constr, rhs), cargs') =
    99     subst_atomic (#2 (strip_comb constr) ~~ map Free cargs') rhs;
   100 
   101 
   102 (*Convert a list of recursion equations into a recursor call*)
   103 fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
   104   let
   105     val fconst = Const(fname, ftype)
   106     val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs)
   107     and {big_rec_name, constructors, rec_rewrites, ...} = con_info
   108 
   109     (*Replace X_rec(args,t) by fname(ls,t,rs) *)
   110     fun use_fabs (_ $ t) = subst_bound (t, fabs)
   111       | use_fabs t       = t
   112 
   113     val cnames         = map (#1 o dest_Const) constructors
   114     and recursor_pairs = map (dest_eqn o concl_of) rec_rewrites
   115 
   116     fun absterm (Free(a,T), body) = absfree (a,T,body)
   117       | absterm (t,body) = Abs("rec", Ind_Syntax.iT, abstract_over (t, body))
   118 
   119     (*Translate rec equations into function arguments suitable for recursor.
   120       Missing cases are replaced by 0 and all cases are put into order.*)
   121     fun add_case ((cname, recursor_pair), cases) =
   122       let val (rhs, recursor_rhs, eq) =
   123             case assoc (eqns, cname) of
   124                 NONE => (warning ("no equation for constructor " ^ cname ^
   125                                   "\nin definition of function " ^ fname);
   126                          (Const ("0", Ind_Syntax.iT),
   127                           #2 recursor_pair, Const ("0", Ind_Syntax.iT)))
   128               | SOME (rhs, cargs', eq) =>
   129                     (rhs, inst_recursor (recursor_pair, cargs'), eq)
   130           val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs))
   131           val abs = foldr absterm rhs allowed_terms
   132       in
   133           if !Ind_Syntax.trace then
   134               writeln ("recursor_rhs = " ^
   135                        Sign.string_of_term (sign_of thy) recursor_rhs ^
   136                        "\nabs = " ^ Sign.string_of_term (sign_of thy) abs)
   137           else();
   138           if Logic.occs (fconst, abs) then
   139               primrec_eq_err (sign_of thy)
   140                    ("illegal recursive occurrences of " ^ fname)
   141                    eq
   142           else abs :: cases
   143       end
   144 
   145     val recursor = head_of (#1 (hd recursor_pairs))
   146 
   147     (** make definition **)
   148 
   149     (*the recursive argument*)
   150     val rec_arg = Free (variant (map #1 (ls@rs)) (Sign.base_name big_rec_name),
   151                         Ind_Syntax.iT)
   152 
   153     val def_tm = Logic.mk_equals
   154                     (subst_bound (rec_arg, fabs),
   155                      list_comb (recursor,
   156                                 foldr add_case [] (cnames ~~ recursor_pairs))
   157                      $ rec_arg)
   158 
   159   in
   160       if !Ind_Syntax.trace then
   161             writeln ("primrec def:\n" ^
   162                      Sign.string_of_term (sign_of thy) def_tm)
   163       else();
   164       (Sign.base_name fname ^ "_" ^ Sign.base_name big_rec_name ^ "_def",
   165        def_tm)
   166   end;
   167 
   168 
   169 (* prepare functions needed for definitions *)
   170 
   171 fun add_primrec_i args thy =
   172   let
   173     val ((eqn_names, eqn_terms), eqn_atts) = apfst split_list (split_list args);
   174     val SOME (fname, ftype, ls, rs, con_info, eqns) =
   175       foldr (process_eqn thy) NONE eqn_terms;
   176     val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns);
   177 
   178     val (thy1, [def_thm]) = thy
   179       |> Theory.add_path (Sign.base_name fname)
   180       |> (PureThy.add_defs_i false o map Thm.no_attributes) [def];
   181 
   182     val rewrites = def_thm :: map mk_meta_eq (#rec_rewrites con_info)
   183     val eqn_thms =
   184       (message ("Proving equations for primrec function " ^ fname);
   185       map (fn t => prove_goalw_cterm rewrites
   186         (Ind_Syntax.traceIt "next primrec equation = " (cterm_of (sign_of thy1) t))
   187         (fn _ => [rtac refl 1])) eqn_terms);
   188 
   189     val (thy2, eqn_thms') = thy1 |> PureThy.add_thms ((eqn_names ~~ eqn_thms) ~~ eqn_atts);
   190     val thy3 = thy2
   191       |> (#1 o PureThy.add_thmss [(("simps", eqn_thms'), [Simplifier.simp_add_global])])
   192       |> Theory.parent_path;
   193   in (thy3, eqn_thms') end;
   194 
   195 fun add_primrec args thy =
   196   add_primrec_i (map (fn ((name, s), srcs) =>
   197     ((name, Sign.simple_read_term (sign_of thy) propT s), map (Attrib.global_attribute thy) srcs))
   198     args) thy;
   199 
   200 
   201 (* outer syntax *)
   202 
   203 local structure P = OuterParse and K = OuterSyntax.Keyword in
   204 
   205 val primrec_decl = Scan.repeat1 (P.opt_thm_name ":" -- P.prop);
   206 
   207 val primrecP =
   208   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
   209     (primrec_decl >> (Toplevel.theory o (#1 oo (add_primrec o map P.triple_swap))));
   210 
   211 val _ = OuterSyntax.add_parsers [primrecP];
   212 
   213 end;
   214 
   215 end;
   216 
   217