src/HOL/Tools/Datatype/primrec.ML
changeset 45898 b619242b0439
parent 45897 65cef0298158
child 46961 5c6955f487e5
equal deleted inserted replaced
45897:65cef0298158 45898:b619242b0439
   204   in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
   204   in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
   205 
   205 
   206 
   206 
   207 (* find datatypes which contain all datatypes in tnames' *)
   207 (* find datatypes which contain all datatypes in tnames' *)
   208 
   208 
   (* find_dts dt_info tnames' tnames
      Walks tnames in order and looks each name up in dt_info.
      - An unknown name is reported via primrec_error (defined elsewhere in
        this file; presumably raises PrimrecError, which distill catches).
      - A found datatype is kept as (tname, info) only if every name in
        tnames' occurs among the type names listed in its #descr.
        Presumably this means all of tnames' belong to the same datatype
        declaration; distill reports "not mutually recursive" when no
        datatype qualifies.
      Returns the kept pairs in the order of tnames.
      Changeset note: the Datatype_Aux.info type constraint moved from the
      dt_info parameter of the first clause to the SOME pattern. *)
   209 fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
   209 fun find_dts _ _ [] = []
   210   | find_dts dt_info tnames' (tname :: tnames) =
   210   | find_dts dt_info tnames' (tname :: tnames) =
   211       (case Symtab.lookup dt_info tname of
   211       (case Symtab.lookup dt_info tname of
   212         NONE => primrec_error (quote tname ^ " is not a datatype")
   212         NONE => primrec_error (quote tname ^ " is not a datatype")
   213       | SOME dt =>
   213       | SOME (dt : Datatype_Aux.info) =>
   214           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   214           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   215             (tname, dt) :: (find_dts dt_info tnames' tnames)
   215             (tname, dt) :: (find_dts dt_info tnames' tnames)
   216           else find_dts dt_info tnames' tnames);
   216           else find_dts dt_info tnames' tnames);
   217 
   217 
   218 
   218 
   219 (* distill primitive definition(s) from primrec specification *)
   219 (* distill primitive definition(s) from primrec specification *)
   220 
   220 
   (* distill ctxt fixes eqs
      The first parameter is named lthy in the old revision and ctxt in the
      new one; only the name changed.
      - eqns: the equations in eqs, processed by process_eqn (defined
        elsewhere in this file).  A variable v counts as fixed if it is fixed
        in ctxt or if it is the name of one of the bindings in fixes.
      - tnames: the distinct type names taken from eqns (#1 o snd).
      - dts: datatypes registered in the theory of ctxt whose descr covers
        all of tnames (see find_dts).  The first of them supplies descr,
        rec_names and rec_rewrites.  If there are none, the datatypes are
        reported as not mutually recursive.
      - main_fns: for each entry of dts, its index paired with the function
        name of the first equation whose type name is that datatype.
      - defs: one definition per raw_defs entry, built by make_def.
      - The set of function names produced by process_fun must equal the set
        of function names occurring in eqns.  Otherwise the functions are
        reported as not mutually recursive.
      Returns ((prefix, (fs, defs)), prove).  prefix joins the base names of
      the raw_defs names with "_"; fs comes from get_fns.
      A PrimrecError raised while evaluating the body is re-raised via error
      with a "Primrec definition error" message.  The offending equation is
      appended when one is attached.  The handler does not cover later calls
      of the returned prove closure. *)
   221 fun distill lthy fixes eqs =
   221 fun distill ctxt fixes eqs =
   222   let
   222   let
   223     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
   223     val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
   224       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   224       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
   225     val tnames = distinct (op =) (map (#1 o snd) eqns);
   225     val tnames = distinct (op =) (map (#1 o snd) eqns);
   226     val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
   226     val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
   227     val main_fns = map (fn (tname, {index, ...}) =>
   227     val main_fns = map (fn (tname, {index, ...}) =>
   228       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   228       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
   229     val {descr, rec_names, rec_rewrites, ...} =
   229     val {descr, rec_names, rec_rewrites, ...} =
   230       if null dts then primrec_error
   230       if null dts then primrec_error
   231         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   231         ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   232       else snd (hd dts);
   232       else snd (hd dts);
   233     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   233     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   234     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   234     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   235     val defs = map (make_def lthy fixes fs) raw_defs;
   235     val defs = map (make_def ctxt fixes fs) raw_defs;
   236     val names = map snd fnames;
   236     val names = map snd fnames;
   237     val names_eqns = map fst eqns;
   237     val names_eqns = map fst eqns;
   238     val _ =
   238     val _ =
   239       if eq_set (op =) (names, names_eqns) then ()
   239       if eq_set (op =) (names, names_eqns) then ()
   240       else primrec_error ("functions " ^ commas_quote names_eqns ^
   240       else primrec_error ("functions " ^ commas_quote names_eqns ^
   241         "\nare not mutually recursive");
   241         "\nare not mutually recursive");
   242     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   242     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   243     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   243     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   (* prove ctxt defs: proves each equation of eqs with Goal.prove.
      - The free variable names of eqs (Variable.add_free_names) are passed
        as Goal.prove's variable list.
      - The tactic rewrites the goal with rec_rewrites' plus snd o snd of
        each def, then closes it with rtac refl.
      Presumably defs are the results of Local_Theory.define (as in
      add_primrec_simple), so snd o snd yields the definitional theorems;
      confirm for any other caller. *)
   244     fun prove lthy defs =
   244     fun prove ctxt defs =
   245       let
   245       let
   246         val frees = fold (Variable.add_free_names lthy) eqs [];
   246         val frees = fold (Variable.add_free_names ctxt) eqs [];
   247         val rewrites = rec_rewrites' @ map (snd o snd) defs;
   247         val rewrites = rec_rewrites' @ map (snd o snd) defs;
   248         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   248         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   249       in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   249       in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end;
   250   in ((prefix, (fs, defs)), prove) end
   250   in ((prefix, (fs, defs)), prove) end
   251   handle PrimrecError (msg, some_eqn) =>
   251   handle PrimrecError (msg, some_eqn) =>
   252     error ("Primrec definition error:\n" ^ msg ^
   252     error ("Primrec definition error:\n" ^ msg ^
   253       (case some_eqn of
   253       (case some_eqn of
   254         SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   254         SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
   255       | NONE => ""));
   255       | NONE => ""));
   256 
   256 
   257 
   257 
   258 (* primrec definition *)
   258 (* primrec definition *)
   259 
   259 
   (* add_primrec_simple fixes ts lthy
      Distills the primrec specification ts against lthy (see distill).
      Each resulting definition is then defined with Local_Theory.define,
      threading the local theory through fold_map.
      The result pairs (prefix, (map fst defs, proved equations)) with the
      updated local theory, via Isabelle's ` combinator.  The equations are
      proved by distill's prove in that updated context, after the
      definitions exist.
      The fs component returned by distill is unused here.  The new revision
      binds it to _ instead of the unused name fs. *)
   260 fun add_primrec_simple fixes ts lthy =
   260 fun add_primrec_simple fixes ts lthy =
   261   let
   261   let
   262     val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
   262     val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
   263   in
   263   in
   264     lthy
   264     lthy
   265     |> fold_map Local_Theory.define defs
   265     |> fold_map Local_Theory.define defs
   266     |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
   266     |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
   267   end;
   267   end;