# HG changeset patch # User haftmann # Date 1243351774 -7200 # Node ID 580510315ddaf449e6592f32e322a0a8ef1ab4b7 # Parent 900ebbc35e30f3ef1ed6670853971a930735505f add_primrec_simple diff -r 900ebbc35e30 -r 580510315dda src/HOL/Tools/primrec_package.ML --- a/src/HOL/Tools/primrec_package.ML Tue May 26 17:29:33 2009 +0200 +++ b/src/HOL/Tools/primrec_package.ML Tue May 26 17:29:34 2009 +0200 @@ -16,6 +16,8 @@ val add_primrec_overloaded: (string * (string * typ) * bool) list -> (binding * typ option * mixfix) list -> (Attrib.binding * term) list -> theory -> thm list * theory + val add_primrec_simple: ((binding * typ) * mixfix) list -> (binding * term) list -> + local_theory -> (string * thm list list) * local_theory end; structure PrimrecPackage : PRIMREC_PACKAGE = @@ -211,22 +213,12 @@ else find_dts dt_info tnames' tnames); -(* primrec definition *) +(* distill primitive definition(s) from primrec specification *) -local - -fun prove_spec ctxt names rec_rewrites defs eqs = +fun distill lthy fixes eqs = let - val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs; - fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; - val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names); - in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end; - -fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy = - let - val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy); val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v - orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec []; + orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs []; val tnames = distinct (op =) (map (#1 o snd) eqns); val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames; val main_fns = map (fn (tname, {index, ...}) => @@ -236,31 +228,59 @@ ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") else snd (hd dts); val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); - val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); - val names1 = map snd fnames; - val names2 = map fst eqns; - val _ = if gen_eq_set (op =) (names1, names2) then () - else primrec_error ("functions " ^ commas_quote names2 ^ + val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); + val defs = map (make_def lthy fixes fs) raw_defs; + val names = map snd fnames; + val names_eqns = map fst eqns; + val _ = if gen_eq_set (op =) (names, names_eqns) then () + else primrec_error ("functions " ^ commas_quote names_eqns ^ "\nare not mutually recursive"); - val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs); - val qualify = Binding.qualify false prefix; - val spec' = (map o apfst) - (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec; - val simp_atts = map (Attrib.internal o K) - [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]; + val rec_rewrites' = map mk_meta_eq rec_rewrites; + val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs); + fun prove lthy defs = + let + val rewrites = rec_rewrites' @ map (snd o snd) defs; + fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; + val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names); + in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end; + in ((prefix, (fs, defs)), prove) end + handle PrimrecError (msg, some_eqn) => + error ("Primrec definition error:\n" ^ msg ^ (case some_eqn + of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) + | NONE => "")); + + +(* primrec definition *) + +fun add_primrec_simple fixes spec lthy = + let + val ((prefix, (fs, defs)), prove) = distill lthy fixes (map snd spec); + in + lthy + |> fold_map (LocalTheory.define Thm.definitionK) defs + |-> (fn defs => `(fn lthy => (prefix, prove lthy defs))) + end; + +local + +fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy = + let + val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy); + fun attr_bindings prefix = map (fn ((b, attrs), _) => + (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec; + fun simp_attr_binding prefix = (Binding.qualify false prefix (Binding.name "simps"), + map (Attrib.internal o K) + [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]); in lthy |> set_group ? LocalTheory.set_group (serial_string ()) - |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs - |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec')) - |-> (fn simps => fold_map (LocalTheory.note Thm.generatedK) simps) - |-> (fn simps' => LocalTheory.note Thm.generatedK - ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps')) + |> add_primrec_simple fixes spec + |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK) + (attr_bindings prefix ~~ simps) + #-> (fn simps' => LocalTheory.note Thm.generatedK + (simp_attr_binding prefix, maps snd simps'))) |>> snd - end handle PrimrecError (msg, some_eqn) => - error ("Primrec definition error:\n" ^ msg ^ (case some_eqn - of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) - | NONE => "")); + end; in