src/ZF/Tools/primrec_package.ML
author wenzelm
Thu Oct 30 16:55:29 2014 +0100 (2014-10-30)
changeset 58838 59203adfc33f
parent 58011 bc6bced136e5
child 59498 50b60f501b05
permissions -rw-r--r--
eliminated aliases;
     1 (*  Title:      ZF/Tools/primrec_package.ML
     2     Author:     Norbert Voelker, FernUni Hagen
     3     Author:     Stefan Berghofer, TU Muenchen
     4     Author:     Lawrence C Paulson, Cambridge University Computer Laboratory
     5 
     6 Package for defining functions on datatypes by primitive recursion.
     7 *)
     8 
     9 signature PRIMREC_PACKAGE =
    10 sig
    11   val add_primrec: ((binding * string) * Token.src list) list -> theory -> theory * thm list
    12   val add_primrec_i: ((binding * term) * attribute list) list -> theory -> theory * thm list
    13 end;
    14 
    15 structure PrimrecPackage : PRIMREC_PACKAGE =
    16 struct
    17 
    18 exception RecError of string;
    19 
    20 (*Remove outer Trueprop and equality sign*)
    21 val dest_eqn = FOLogic.dest_eq o FOLogic.dest_Trueprop;
    22 
    23 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
    24 
    25 fun primrec_eq_err sign s eq =
    26   primrec_err (s ^ "\nin equation\n" ^ Syntax.string_of_term_global sign eq);
    27 
    28 
    29 (* preprocessing of equations *)
    30 
    31 (*rec_fn_opt records equations already noted for this function*)
    32 fun process_eqn thy (eq, rec_fn_opt) =
    33   let
    34     val (lhs, rhs) =
    35         if null (Term.add_vars eq []) then
    36             dest_eqn eq handle TERM _ => raise RecError "not a proper equation"
    37         else raise RecError "illegal schematic variable(s)";
    38 
    39     val (recfun, args) = strip_comb lhs;
    40     val (fname, ftype) = dest_Const recfun handle TERM _ =>
    41       raise RecError "function is not declared as constant in theory";
    42 
    43     val (ls_frees, rest)  = take_prefix is_Free args;
    44     val (middle, rs_frees) = take_suffix is_Free rest;
    45 
    46     val (constr, cargs_frees) =
    47       if null middle then raise RecError "constructor missing"
    48       else strip_comb (hd middle);
    49     val (cname, _) = dest_Const constr
    50       handle TERM _ => raise RecError "ill-formed constructor";
    51     val con_info = the (Symtab.lookup (ConstructorsData.get thy) cname)
    52       handle Option.Option =>
    53       raise RecError "cannot determine datatype associated with function"
    54 
    55     val (ls, cargs, rs) = (map dest_Free ls_frees,
    56                            map dest_Free cargs_frees,
    57                            map dest_Free rs_frees)
    58       handle TERM _ => raise RecError "illegal argument in pattern";
    59     val lfrees = ls @ rs @ cargs;
    60 
    61     (*Constructor, frees to left of pattern, pattern variables,
    62       frees to right of pattern, rhs of equation, full original equation. *)
    63     val new_eqn = (cname, (rhs, cargs, eq))
    64 
    65   in
    66     if has_duplicates (op =) lfrees then
    67       raise RecError "repeated variable name in pattern"
    68     else if not (subset (op =) (Term.add_frees rhs [], lfrees)) then
    69       raise RecError "extra variables on rhs"
    70     else if length middle > 1 then
    71       raise RecError "more than one non-variable in pattern"
    72     else case rec_fn_opt of
    73         NONE => SOME (fname, ftype, ls, rs, con_info, [new_eqn])
    74       | SOME (fname', _, ls', rs', con_info': constructor_info, eqns) =>
    75           if AList.defined (op =) eqns cname then
    76             raise RecError "constructor already occurred as pattern"
    77           else if (ls <> ls') orelse (rs <> rs') then
    78             raise RecError "non-recursive arguments are inconsistent"
    79           else if #big_rec_name con_info <> #big_rec_name con_info' then
    80              raise RecError ("Mixed datatypes for function " ^ fname)
    81           else if fname <> fname' then
    82              raise RecError ("inconsistent functions for datatype " ^
    83                              #big_rec_name con_info)
    84           else SOME (fname, ftype, ls, rs, con_info, new_eqn::eqns)
    85   end
    86   handle RecError s => primrec_eq_err thy s eq;
    87 
    88 
    89 (*Instantiates a recursor equation with constructor arguments*)
    90 fun inst_recursor ((_ $ constr, rhs), cargs') =
    91     subst_atomic (#2 (strip_comb constr) ~~ map Free cargs') rhs;
    92 
    93 
    94 (*Convert a list of recursion equations into a recursor call*)
    95 fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
    96   let
    97     val fconst = Const(fname, ftype)
    98     val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs)
    99     and {big_rec_name, constructors, rec_rewrites, ...} = con_info
   100 
   101     (*Replace X_rec(args,t) by fname(ls,t,rs) *)
   102     fun use_fabs (_ $ t) = subst_bound (t, fabs)
   103       | use_fabs t       = t
   104 
   105     val cnames         = map (#1 o dest_Const) constructors
   106     and recursor_pairs = map (dest_eqn o concl_of) rec_rewrites
   107 
   108     fun absterm (Free x, body) = absfree x body
   109       | absterm (t, body) = Abs("rec", Ind_Syntax.iT, abstract_over (t, body))
   110 
   111     (*Translate rec equations into function arguments suitable for recursor.
   112       Missing cases are replaced by 0 and all cases are put into order.*)
   113     fun add_case ((cname, recursor_pair), cases) =
   114       let val (rhs, recursor_rhs, eq) =
   115             case AList.lookup (op =) eqns cname of
   116                 NONE => (warning ("no equation for constructor " ^ cname ^
   117                                   "\nin definition of function " ^ fname);
   118                          (Const (@{const_name zero}, Ind_Syntax.iT),
   119                           #2 recursor_pair, Const (@{const_name zero}, Ind_Syntax.iT)))
   120               | SOME (rhs, cargs', eq) =>
   121                     (rhs, inst_recursor (recursor_pair, cargs'), eq)
   122           val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs))
   123           val abs = List.foldr absterm rhs allowed_terms
   124       in
   125           if !Ind_Syntax.trace then
   126               writeln ("recursor_rhs = " ^
   127                        Syntax.string_of_term_global thy recursor_rhs ^
   128                        "\nabs = " ^ Syntax.string_of_term_global thy abs)
   129           else();
   130           if Logic.occs (fconst, abs) then
   131               primrec_eq_err thy
   132                    ("illegal recursive occurrences of " ^ fname)
   133                    eq
   134           else abs :: cases
   135       end
   136 
   137     val recursor = head_of (#1 (hd recursor_pairs))
   138 
   139     (** make definition **)
   140 
   141     (*the recursive argument*)
   142     val rec_arg =
   143       Free (singleton (Name.variant_list (map #1 (ls@rs))) (Long_Name.base_name big_rec_name),
   144         Ind_Syntax.iT)
   145 
   146     val def_tm = Logic.mk_equals
   147                     (subst_bound (rec_arg, fabs),
   148                      list_comb (recursor,
   149                                 List.foldr add_case [] (cnames ~~ recursor_pairs))
   150                      $ rec_arg)
   151 
   152   in
   153       if !Ind_Syntax.trace then
   154             writeln ("primrec def:\n" ^
   155                      Syntax.string_of_term_global thy def_tm)
   156       else();
   157       (Long_Name.base_name fname ^ "_" ^ Long_Name.base_name big_rec_name ^ "_def",
   158        def_tm)
   159   end;
   160 
   161 
   162 (* prepare functions needed for definitions *)
   163 
   164 fun add_primrec_i args thy =
   165   let
   166     val ((eqn_names, eqn_terms), eqn_atts) = apfst split_list (split_list args);
   167     val SOME (fname, ftype, ls, rs, con_info, eqns) =
   168       List.foldr (process_eqn thy) NONE eqn_terms;
   169     val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns);
   170 
   171     val ([def_thm], thy1) = thy
   172       |> Sign.add_path (Long_Name.base_name fname)
   173       |> Global_Theory.add_defs false [Thm.no_attributes (apfst Binding.name def)];
   174 
   175     val rewrites = def_thm :: map mk_meta_eq (#rec_rewrites con_info)
   176     val eqn_thms =
   177       eqn_terms |> map (fn t =>
   178         Goal.prove_global thy1 [] [] (Ind_Syntax.traceIt "next primrec equation = " thy1 t)
   179           (fn {context = ctxt, ...} =>
   180             EVERY [rewrite_goals_tac ctxt rewrites, resolve_tac @{thms refl} 1]));
   181 
   182     val (eqn_thms', thy2) =
   183       thy1
   184       |> Global_Theory.add_thms ((eqn_names ~~ eqn_thms) ~~ eqn_atts);
   185     val (_, thy3) =
   186       thy2
   187       |> Global_Theory.add_thmss [((Binding.name "simps", eqn_thms'), [Simplifier.simp_add])]
   188       ||> Sign.parent_path;
   189   in (thy3, eqn_thms') end;
   190 
   191 fun add_primrec args thy =
   192   add_primrec_i (map (fn ((name, s), srcs) =>
   193     ((name, Syntax.read_prop_global thy s), map (Attrib.attribute_cmd_global thy) srcs))
   194     args) thy;
   195 
   196 
   197 (* outer syntax *)
   198 
   199 val _ =
   200   Outer_Syntax.command @{command_spec "primrec"} "define primitive recursive functions on datatypes"
   201     (Scan.repeat1 (Parse_Spec.opt_thm_name ":" -- Parse.prop)
   202       >> (Toplevel.theory o (#1 oo (add_primrec o map Parse.triple_swap))));
   203 
   204 end;
   205