src/ZF/Tools/primrec_package.ML
author wenzelm
Sat Jan 14 22:25:34 2006 +0100 (2006-01-14)
changeset 18688 abf0f018b5ec
parent 18377 0e1d025d57b3
child 18728 6790126ab5f6
permissions -rw-r--r--
generic attributes;
paulson@6050
     1
(*  Title:      ZF/Tools/primrec_package.ML
paulson@6050
     2
    ID:         $Id$
paulson@6050
     3
    Author:     Stefan Berghofer and Norbert Voelker
paulson@6050
     4
    Copyright   1998  TU Muenchen
paulson@6050
     5
    ZF version by Lawrence C Paulson (Cambridge)
paulson@6050
     6
paulson@6050
     7
Package for defining functions on datatypes by primitive recursion
paulson@6050
     8
*)
paulson@6050
     9
paulson@6050
    10
(*Interface of the primitive recursion package*)
signature PRIMREC_PACKAGE =
sig
  (*when set to true, the progress message of the package is not printed*)
  val quiet_mode: bool ref

  (*external version: each equation is a named string with attribute
    sources; the strings are read as propositions in the given theory*)
  val add_primrec:
    ((bstring * string) * Attrib.src list) list -> theory -> theory * thm list

  (*internal version: each equation is a named term with theory attributes;
    yields the extended theory and the stored equation theorems*)
  val add_primrec_i:
    ((bstring * term) * theory attribute list) list -> theory -> theory * thm list
end;
paulson@6050
    16
paulson@6050
    17
structure PrimrecPackage : PRIMREC_PACKAGE =
paulson@6050
    18
struct
paulson@6050
    19
wenzelm@12183
    20
(* messages *)
wenzelm@12183
    21
wenzelm@12183
    22
(*flag consulted by message below; initially false*)
val quiet_mode = ref false;

(*print text with writeln unless quiet_mode is set*)
fun message text = if not (! quiet_mode) then writeln text else ();
wenzelm@12183
    24
wenzelm@12183
    25
paulson@6050
    26
(*raised while checking an equation; carries the error text*)
exception RecError of string;

(*Remove outer Trueprop and equality sign, giving the two sides as a pair*)
fun dest_eqn prop = FOLogic.dest_eq (FOLogic.dest_Trueprop prop);
paulson@6050
    30
paulson@6050
    31
(*abort via error, with the package's standard prefix before the text*)
fun primrec_err text = error ("Primrec definition error:\n" ^ text);

(*as primrec_err, but also quotes the offending equation, printed using
  the given signature*)
fun primrec_eq_err sign reason eqn =
  let val shown = Sign.string_of_term sign eqn
  in primrec_err (reason ^ "\nin equation\n" ^ shown) end;
paulson@6050
    35
wenzelm@12183
    36
paulson@6050
    37
(* preprocessing of equations *)
paulson@6050
    38
paulson@6050
    39
(*Check one equation and merge it into the data collected so far.
  rec_fn_opt records equations already noted for this function: it is NONE
  for the first equation handled, otherwise
  SOME (fname, ftype, ls, rs, con_info, eqns).
  Every RecError raised below is reported as a primrec error quoting eq.*)
fun process_eqn thy (eq, rec_fn_opt) =
  let
    fun fail msg = raise RecError msg;

    (*schematic variables are rejected before the equation is taken apart*)
    val (lhs, rhs) =
      if not (null (term_vars eq)) then fail "illegal schematic variable(s)"
      else (dest_eqn eq handle TERM _ => fail "not a proper equation");

    (*the head of the left-hand side must be a constant*)
    val (recfun, args) = strip_comb lhs;
    val (fname, ftype) =
      dest_Const recfun
        handle TERM _ => fail "function is not declared as constant in theory";

    (*arguments = leading Frees, the part in between, trailing Frees*)
    val (ls_frees, rest) = take_prefix is_Free args;
    val (middle, rs_frees) = take_suffix is_Free rest;

    (*the first term of the middle part is taken as the constructor pattern*)
    val (constr, cargs_frees) =
      (case middle of
        [] => fail "constructor missing"
      | pat :: _ => strip_comb pat);
    val (cname, _) =
      dest_Const constr handle TERM _ => fail "ill-formed constructor";
    val con_info =
      (case Symtab.lookup (ConstructorsData.get thy) cname of
        SOME info => info
      | NONE => fail "cannot determine datatype associated with function");

    (*all three argument groups must consist of Frees only*)
    val (ls, cargs, rs) =
      (map dest_Free ls_frees, map dest_Free cargs_frees, map dest_Free rs_frees)
        handle TERM _ => fail "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    (*Entry stored under the constructor name: rhs of equation, pattern
      variables, full original equation.*)
    val new_eqn = (cname, (rhs, cargs, eq));

    (*checks on this equation alone, in the order they are reported*)
    val _ =
      if null (duplicates lfrees) then ()
      else fail "repeated variable name in pattern";
    val _ =
      if (map dest_Free (term_frees rhs)) subset lfrees then ()
      else fail "extra variables on rhs";
    val _ =
      if length middle > 1 then fail "more than one non-variable in pattern"
      else ();
  in
    (case rec_fn_opt of
      NONE => SOME (fname, ftype, ls, rs, con_info, [new_eqn])
    | SOME (fname', _, ls', rs', con_info': constructor_info, eqns) =>
        let
          (*checks against the equations seen before*)
          val _ =
            if AList.defined (op =) eqns cname then
              fail "constructor already occurred as pattern"
            else ();
          val _ =
            if ls = ls' andalso rs = rs' then ()
            else fail "non-recursive arguments are inconsistent";
          val _ =
            if #big_rec_name con_info = #big_rec_name con_info' then ()
            else fail ("Mixed datatypes for function " ^ fname);
          val _ =
            if fname = fname' then ()
            else fail ("inconsistent functions for datatype " ^
                       #big_rec_name con_info);
        in SOME (fname, ftype, ls, rs, con_info, new_eqn :: eqns) end)
  end
  handle RecError s => primrec_eq_err (sign_of thy) s eq;
paulson@6050
    95
paulson@6050
    96
paulson@6050
    97
(*Instantiates a recursor equation with constructor arguments: the
  arguments of the constructor application on the equation's lhs are
  replaced in its rhs, position by position, by the Frees built from cargs'*)
fun inst_recursor ((_ $ constr, rhs), cargs') =
  let
    val (_, formals) = strip_comb constr;
    val actuals = map Free cargs';
  in subst_atomic (formals ~~ actuals) rhs end;
paulson@6050
   100
paulson@6050
   101
paulson@6050
   102
(*Convert a list of recursion equations into a recursor call.
  Returns the pair (definition name, definition term); the term equates
  fname(ls, x, rs) with the recursor applied to one case per constructor
  and to x, where x is a fresh Free standing for the recursive argument.*)
fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
  let
    val {big_rec_name, constructors, rec_rewrites, ...} = con_info;

    fun show t = Sign.string_of_term (sign_of thy) t;

    (*the text is only built when tracing is switched on*)
    fun trace mk_text =
      if ! Ind_Syntax.trace then writeln (mk_text ()) else ();

    val zero = Const ("0", Ind_Syntax.iT);
    val fconst = Const (fname, ftype);

    (*fname(ls, Bound 0, rs): the recursive position is a loose bound*)
    val fabs = list_comb (fconst, map Free ls @ Bound 0 :: map Free rs);

    (*Replace X_rec(args,t) by fname(ls,t,rs)*)
    fun use_fabs (_ $ t) = subst_bound (t, fabs)
      | use_fabs t = t;

    val cnames = map (fn c => #1 (dest_Const c)) constructors;
    val recursor_pairs = map (fn th => dest_eqn (concl_of th)) rec_rewrites;

    (*abstract body over one term: Frees by name, anything else under
      the bound-variable name "rec"*)
    fun absterm (Free (a, T), body) = absfree (a, T, body)
      | absterm (t, body) = Abs ("rec", Ind_Syntax.iT, abstract_over (t, body));

    (*Translate rec equations into function arguments suitable for recursor.
      Missing cases are replaced by 0 and all cases are put into order.*)
    fun add_case ((cname, recursor_pair as (_, generic_rhs)), cases) =
      let
        val (rhs, recursor_rhs, eq) =
          (case AList.lookup (op =) eqns cname of
            SOME (user_rhs, cargs', user_eq) =>
              (user_rhs, inst_recursor (recursor_pair, cargs'), user_eq)
          | NONE =>
              (warning ("no equation for constructor " ^ cname ^
                        "\nin definition of function " ^ fname);
               (zero, generic_rhs, zero)));
        val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs));
        val case_abs = foldr absterm rhs allowed_terms;
        val _ =
          trace (fn () =>
            "recursor_rhs = " ^ show recursor_rhs ^ "\nabs = " ^ show case_abs);
      in
        (*after abstraction no occurrence of the function itself may remain*)
        if Logic.occs (fconst, case_abs) then
          primrec_eq_err (sign_of thy)
            ("illegal recursive occurrences of " ^ fname) eq
        else case_abs :: cases
      end;

    val recursor = head_of (#1 (hd recursor_pairs));

    (** make definition **)

    (*the recursive argument*)
    val rec_arg =
      Free (variant (map #1 (ls @ rs)) (Sign.base_name big_rec_name),
            Ind_Syntax.iT);

    val cases = foldr add_case [] (cnames ~~ recursor_pairs);
    val def_tm =
      Logic.mk_equals
        (subst_bound (rec_arg, fabs), list_comb (recursor, cases) $ rec_arg);
    val def_name =
      Sign.base_name fname ^ "_" ^ Sign.base_name big_rec_name ^ "_def";
  in
    trace (fn () => "primrec def:\n" ^ show def_tm);
    (def_name, def_tm)
  end;
paulson@6050
   167
paulson@6050
   168
paulson@6050
   169
(* prepare functions needed for definitions *)
paulson@6050
   170
wenzelm@12183
   171
(*Define a function by primitive recursion from already-parsed equations.
  args: list of ((theorem name, equation term), attributes).
  Returns the extended theory together with the stored equation theorems.
  An empty list of equations is reported as a primrec error (it used to
  surface as an uninformative Bind exception from a refutable pattern).*)
fun add_primrec_i args thy =
  let
    val ((eqn_names, eqn_terms), eqn_atts) = apfst split_list (split_list args);

    (*check all equations, accumulating the description of the function*)
    val (fname, ftype, ls, rs, con_info, eqns) =
      (case foldr (process_eqn thy) NONE eqn_terms of
        SOME info => info
      | NONE => primrec_err "no equations given");
    val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns);

    (*add the definition, inside the name path given by the function's base name*)
    val ([def_thm], thy1) = thy
      |> Theory.add_path (Sign.base_name fname)
      |> (PureThy.add_defs_i false o map Thm.no_attributes) [def];

    (*each equation is proved by rewriting with the definition and the
      recursor equations, then closing the goal with refl*)
    val rewrites = def_thm :: map mk_meta_eq (#rec_rewrites con_info)
    val eqn_thms =
     (message ("Proving equations for primrec function " ^ fname);
      eqn_terms |> map (fn t =>
        standard (Goal.prove thy1 [] [] (Ind_Syntax.traceIt "next primrec equation = " thy1 t)
          (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1]))));

    (*store the equations under their names with the supplied attributes*)
    val (eqn_thms', thy2) =
      thy1
      |> PureThy.add_thms ((eqn_names ~~ eqn_thms) ~~ eqn_atts);

    (*also store them together as "simps" with the simp_add attribute,
      then leave the name path entered above*)
    val (_, thy3) =
      thy2
      |> PureThy.add_thmss [(("simps", eqn_thms'), [Attrib.theory Simplifier.simp_add])]
      ||> Theory.parent_path;
  in (thy3, eqn_thms') end;
wenzelm@12183
   197
wenzelm@12183
   198
(*External version of add_primrec_i: each equation string is read as a
  proposition in thy and each attribute source is turned into a global
  attribute before the internal version is called*)
fun add_primrec args thy =
  let
    fun read_prop s = Sign.simple_read_term (sign_of thy) propT s;
    fun prep ((name, s), srcs) =
      ((name, read_prop s), map (Attrib.global_attribute thy) srcs);
  in add_primrec_i (map prep args) thy end;
wenzelm@12183
   202
wenzelm@12183
   203
wenzelm@12183
   204
(* outer syntax *)
paulson@6050
   205
wenzelm@17057
   206
local

structure P = OuterParse and K = OuterKeyword;

(*semantic action of the command: rearrange every parsed item with
  P.triple_swap, run add_primrec and keep only the resulting theory*)
fun primrec_cmd decls =
  Toplevel.theory (fn thy => #1 (add_primrec (map P.triple_swap decls) thy));

in

(*one or more propositions, each with an optional theorem name ending in ":"*)
val primrec_decl = Scan.repeat1 (P.opt_thm_name ":" -- P.prop);

val primrecP =
  OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
    (primrec_decl >> primrec_cmd);

val _ = OuterSyntax.add_parsers [primrecP];

end;
wenzelm@12183
   217
wenzelm@12183
   218
end;
wenzelm@15705
   219
wenzelm@15705
   220