src/HOL/Tools/Old_Datatype/old_primrec.ML
changeset 58112 8081087096ad
parent 56245 84fc7dfa3cd4
child 58839 ccda99401bc8
equal deleted inserted replaced
58111:82db9ad610b9 58112:8081087096ad
       
     1 (*  Title:      HOL/Tools/Old_Datatype/old_primrec.ML
       
     2     Author:     Norbert Voelker, FernUni Hagen
       
     3     Author:     Stefan Berghofer, TU Muenchen
       
     4     Author:     Florian Haftmann, TU Muenchen
       
     5 
       
     6 Primitive recursive functions on datatypes.
       
     7 *)
       
     8 
       
     9 signature OLD_PRIMREC =
       
    10 sig
       
    11   val add_primrec: (binding * typ option * mixfix) list ->
       
    12     (Attrib.binding * term) list -> local_theory -> (term list * thm list) * local_theory
       
    13   val add_primrec_cmd: (binding * string option * mixfix) list ->
       
    14     (Attrib.binding * string) list -> local_theory -> (term list * thm list) * local_theory
       
    15   val add_primrec_global: (binding * typ option * mixfix) list ->
       
    16     (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
       
    17   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
       
    18     (binding * typ option * mixfix) list ->
       
    19     (Attrib.binding * term) list -> theory -> (term list * thm list) * theory
       
    20   val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
       
    21     local_theory -> (string * (term list * thm list)) * local_theory
       
    22 end;
       
    23 
       
    24 structure Old_Primrec : OLD_PRIMREC =
       
    25 struct
       
    26 
       
    27 exception PrimrecError of string * term option;
       
    28 
       
    29 fun primrec_error msg = raise PrimrecError (msg, NONE);
       
    30 fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
       
    31 
       
    32 
       
    33 (* preprocessing of equations *)
       
    34 
       
    35 fun process_eqn is_fixed spec rec_fns =
       
    36   let
       
    37     val (vs, Ts) = split_list (strip_qnt_vars @{const_name Pure.all} spec);
       
    38     val body = strip_qnt_body @{const_name Pure.all} spec;
       
    39     val (vs', _) = fold_map Name.variant vs (Name.make_context (fold_aterms
       
    40       (fn Free (v, _) => insert (op =) v | _ => I) body []));
       
    41     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
       
    42     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
       
    43       handle TERM _ => primrec_error "not a proper equation";
       
    44     val (recfun, args) = strip_comb lhs;
       
    45     val fname =
       
    46       (case recfun of
       
    47         Free (v, _) =>
       
    48           if is_fixed v then v
       
    49           else primrec_error "illegal head of function equation"
       
    50       | _ => primrec_error "illegal head of function equation");
       
    51 
       
    52     val (ls', rest)  = take_prefix is_Free args;
       
    53     val (middle, rs') = take_suffix is_Free rest;
       
    54     val rpos = length ls';
       
    55 
       
    56     val (constr, cargs') =
       
    57       if null middle then primrec_error "constructor missing"
       
    58       else strip_comb (hd middle);
       
    59     val (cname, T) = dest_Const constr
       
    60       handle TERM _ => primrec_error "ill-formed constructor";
       
    61     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
       
    62       primrec_error "cannot determine datatype associated with function"
       
    63 
       
    64     val (ls, cargs, rs) =
       
    65       (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
       
    66       handle TERM _ => primrec_error "illegal argument in pattern";
       
    67     val lfrees = ls @ rs @ cargs;
       
    68 
       
    69     fun check_vars _ [] = ()
       
    70       | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn;
       
    71   in
       
    72     if length middle > 1 then
       
    73       primrec_error "more than one non-variable in pattern"
       
    74     else
       
    75      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
       
    76       check_vars "extra variables on rhs: "
       
    77         (Term.add_frees rhs [] |> subtract (op =) lfrees
       
    78           |> filter_out (is_fixed o fst));
       
    79       (case AList.lookup (op =) rec_fns fname of
       
    80         NONE =>
       
    81           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
       
    82       | SOME (_, rpos', eqns) =>
       
    83           if AList.defined (op =) eqns cname then
       
    84             primrec_error "constructor already occurred as pattern"
       
    85           else if rpos <> rpos' then
       
    86             primrec_error "position of recursive argument inconsistent"
       
    87           else
       
    88             AList.update (op =)
       
    89               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
       
    90               rec_fns))
       
    91   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
       
    92 
       
    93 fun process_fun descr eqns (i, fname) (fnames, fnss) =
       
    94   let
       
    95     val (_, (tname, _, constrs)) = nth descr i;
       
    96 
       
    97     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
       
    98 
       
    99     fun subst [] t fs = (t, fs)
       
   100       | subst subs (Abs (a, T, t)) fs =
       
   101           fs
       
   102           |> subst subs t
       
   103           |-> (fn t' => pair (Abs (a, T, t')))
       
   104       | subst subs (t as (_ $ _)) fs =
       
   105           let
       
   106             val (f, ts) = strip_comb t;
       
   107           in
       
   108             if is_Free f
       
   109               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
       
   110               let
       
   111                 val (fname', _) = dest_Free f;
       
   112                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
       
   113                 val (ls, rs) = chop rpos ts
       
   114                 val (x', rs') =
       
   115                   (case rs of
       
   116                     x' :: rs => (x', rs)
       
   117                   | [] => primrec_error ("not enough arguments in recursive application\n" ^
       
   118                       "of function " ^ quote fname' ^ " on rhs"));
       
   119                 val (x, xs) = strip_comb x';
       
   120               in
       
   121                 (case AList.lookup (op =) subs x of
       
   122                   NONE =>
       
   123                     fs
       
   124                     |> fold_map (subst subs) ts
       
   125                     |-> (fn ts' => pair (list_comb (f, ts')))
       
   126                 | SOME (i', y) =>
       
   127                     fs
       
   128                     |> fold_map (subst subs) (xs @ ls @ rs')
       
   129                     ||> process_fun descr eqns (i', fname')
       
   130                     |-> (fn ts' => pair (list_comb (y, ts'))))
       
   131               end
       
   132             else
       
   133               fs
       
   134               |> fold_map (subst subs) (f :: ts)
       
   135               |-> (fn f' :: ts' => pair (list_comb (f', ts')))
       
   136           end
       
   137       | subst _ t fs = (t, fs);
       
   138 
       
   139     (* translate rec equations into function arguments suitable for rec comb *)
       
   140 
       
   141     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
       
   142       (case AList.lookup (op =) eqns cname of
       
   143         NONE => (warning ("No equation for constructor " ^ quote cname ^
       
   144           "\nin definition of function " ^ quote fname);
       
   145             (fnames', fnss', (Const (@{const_name undefined}, dummyT)) :: fns))
       
   146       | SOME (ls, cargs', rs, rhs, eq) =>
       
   147           let
       
   148             val recs = filter (Old_Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
       
   149             val rargs = map fst recs;
       
   150             val subs = map (rpair dummyT o fst)
       
   151               (rev (Term.rename_wrt_term rhs rargs));
       
   152             val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
       
   153               (Free x, (Old_Datatype_Aux.body_index y, Free z))) recs subs) rhs (fnames', fnss')
       
   154                 handle PrimrecError (s, NONE) => primrec_error_eqn s eq
       
   155           in
       
   156             (fnames'', fnss'', fold_rev absfree (cargs' @ subs @ ls @ rs) rhs' :: fns)
       
   157           end)
       
   158 
       
   159   in
       
   160     (case AList.lookup (op =) fnames i of
       
   161       NONE =>
       
   162         if exists (fn (_, v) => fname = v) fnames then
       
   163           primrec_error ("inconsistent functions for datatype " ^ quote tname)
       
   164         else
       
   165           let
       
   166             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
       
   167             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
       
   168               ((i, fname) :: fnames, fnss, [])
       
   169           in
       
   170             (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
       
   171           end
       
   172     | SOME fname' =>
       
   173         if fname = fname' then (fnames, fnss)
       
   174         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
       
   175   end;
       
   176 
       
   177 
       
   178 (* prepare functions needed for definitions *)
       
   179 
       
   180 fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
       
   181   (case AList.lookup (op =) fns i of
       
   182     NONE =>
       
   183       let
       
   184         val dummy_fns = map (fn (_, cargs) => Const (@{const_name undefined},
       
   185           replicate (length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs))
       
   186             dummyT ---> HOLogic.unitT)) constrs;
       
   187         val _ = warning ("No function definition for datatype " ^ quote tname)
       
   188       in
       
   189         (dummy_fns @ fs, defs)
       
   190       end
       
   191   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
       
   192 
       
   193 
       
   194 (* make definition *)
       
   195 
       
   196 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
       
   197   let
       
   198     val SOME (var, varT) = get_first (fn ((b, T), mx) =>
       
   199       if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes;
       
   200     val def_name = Thm.def_name (Long_Name.base_name fname);
       
   201     val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
       
   202       (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))))
       
   203     val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
       
   204   in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
       
   205 
       
   206 
       
   207 (* find datatypes which contain all datatypes in tnames' *)
       
   208 
       
   209 fun find_dts _ _ [] = []
       
   210   | find_dts dt_info tnames' (tname :: tnames) =
       
   211       (case Symtab.lookup dt_info tname of
       
   212         NONE => primrec_error (quote tname ^ " is not a datatype")
       
   213       | SOME (dt : Old_Datatype_Aux.info) =>
       
   214           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
       
   215             (tname, dt) :: (find_dts dt_info tnames' tnames)
       
   216           else find_dts dt_info tnames' tnames);
       
   217 
       
   218 
       
   219 (* distill primitive definition(s) from primrec specification *)
       
   220 
       
   221 fun distill ctxt fixes eqs =
       
   222   let
       
   223     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
       
   224       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
       
   225     val tnames = distinct (op =) (map (#1 o snd) eqns);
       
   226     val dts = find_dts (Old_Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
       
   227     val main_fns = map (fn (tname, {index, ...}) =>
       
   228       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
       
   229     val {descr, rec_names, rec_rewrites, ...} =
       
   230       if null dts then primrec_error
       
   231         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
       
   232       else snd (hd dts);
       
   233     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
       
   234     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
       
   235     val defs = map (make_def ctxt fixes fs) raw_defs;
       
   236     val names = map snd fnames;
       
   237     val names_eqns = map fst eqns;
       
   238     val _ =
       
   239       if eq_set (op =) (names, names_eqns) then ()
       
   240       else primrec_error ("functions " ^ commas_quote names_eqns ^
       
   241         "\nare not mutually recursive");
       
   242     val rec_rewrites' = map mk_meta_eq rec_rewrites;
       
   243     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
       
   244     fun prove ctxt defs =
       
   245       let
       
   246         val frees = fold (Variable.add_free_names ctxt) eqs [];
       
   247         val rewrites = rec_rewrites' @ map (snd o snd) defs;
       
   248       in
       
   249         map (fn eq => Goal.prove ctxt frees [] eq
       
   250           (fn {context = ctxt', ...} => EVERY [rewrite_goals_tac ctxt' rewrites, rtac refl 1])) eqs
       
   251       end;
       
   252   in ((prefix, (fs, defs)), prove) end
       
   253   handle PrimrecError (msg, some_eqn) =>
       
   254     error ("Primrec definition error:\n" ^ msg ^
       
   255       (case some_eqn of
       
   256         SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
       
   257       | NONE => ""));
       
   258 
       
   259 
       
   260 (* primrec definition *)
       
   261 
       
   262 fun add_primrec_simple fixes ts lthy =
       
   263   let
       
   264     val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
       
   265   in
       
   266     lthy
       
   267     |> fold_map Local_Theory.define defs
       
   268     |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
       
   269   end;
       
   270 
       
   271 local
       
   272 
       
   273 fun gen_primrec prep_spec raw_fixes raw_spec lthy =
       
   274   let
       
   275     val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);
       
   276     fun attr_bindings prefix = map (fn ((b, attrs), _) =>
       
   277       (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
       
   278     fun simp_attr_binding prefix =
       
   279       (Binding.qualify true prefix (Binding.name "simps"), @{attributes [simp, nitpick_simp]});
       
   280   in
       
   281     lthy
       
   282     |> add_primrec_simple fixes (map snd spec)
       
   283     |-> (fn (prefix, (ts, simps)) =>
       
   284       Spec_Rules.add Spec_Rules.Equational (ts, simps)
       
   285       #> fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps)
       
   286       #-> (fn simps' => Local_Theory.note (simp_attr_binding prefix, maps snd simps')
       
   287       #>> (fn (_, simps'') => (ts, simps''))))
       
   288   end;
       
   289 
       
   290 in
       
   291 
       
   292 val add_primrec = gen_primrec Specification.check_spec;
       
   293 val add_primrec_cmd = gen_primrec Specification.read_spec;
       
   294 
       
   295 end;
       
   296 
       
   297 fun add_primrec_global fixes specs thy =
       
   298   let
       
   299     val lthy = Named_Target.theory_init thy;
       
   300     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
       
   301     val simps' = Proof_Context.export lthy' lthy simps;
       
   302   in ((ts, simps'), Local_Theory.exit_global lthy') end;
       
   303 
       
   304 fun add_primrec_overloaded ops fixes specs thy =
       
   305   let
       
   306     val lthy = Overloading.overloading ops thy;
       
   307     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
       
   308     val simps' = Proof_Context.export lthy' lthy simps;
       
   309   in ((ts, simps'), Local_Theory.exit_global lthy') end;
       
   310 
       
   311 end;