# HG changeset patch # User haftmann # Date 1153485987 -7200 # Node ID 36737fb58614cb4b8e1da2594ce7cd853e93deee # Parent 0a8ca32f6e64a6bf38a3db7e3e2a671ddbf295dd exported equation transformator diff -r 0a8ca32f6e64 -r 36737fb58614 src/HOL/Tools/primrec_package.ML --- a/src/HOL/Tools/primrec_package.ML Fri Jul 21 14:45:43 2006 +0200 +++ b/src/HOL/Tools/primrec_package.ML Fri Jul 21 14:46:27 2006 +0200 @@ -8,14 +8,15 @@ signature PRIMREC_PACKAGE = sig val quiet_mode: bool ref + val mk_combdefs: theory -> term list -> (string * term) list val add_primrec: string -> ((bstring * string) * Attrib.src list) list - -> theory -> theory * thm list + -> theory -> thm list * theory val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list - -> theory -> theory * thm list + -> theory -> thm list * theory val add_primrec_i: string -> ((bstring * term) * attribute list) list - -> theory -> theory * thm list + -> theory -> thm list * theory val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list - -> theory -> theory * thm list + -> theory -> thm list * theory end; structure PrimrecPackage : PRIMREC_PACKAGE = @@ -26,8 +27,8 @@ exception RecError of string; fun primrec_err s = error ("Primrec definition error:\n" ^ s); -fun primrec_eq_err sign s eq = - primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq)); +fun primrec_eq_err thy s eq = + primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq)); (* messages *) @@ -38,13 +39,13 @@ (* preprocessing of equations *) -fun process_eqn sign (eq, rec_fns) = +fun process_eqn thy (eq, rec_fns) = let val (lhs, rhs) = - if null (term_vars eq) then - HOLogic.dest_eq (HOLogic.dest_Trueprop eq) - handle TERM _ => raise RecError "not a proper equation" - else raise RecError "illegal schematic variable(s)"; + if null (term_vars eq) then + HOLogic.dest_eq (HOLogic.dest_Trueprop eq) + handle TERM _ => raise RecError "not a proper equation" + else raise RecError "illegal schematic variable(s)"; val (recfun, args) = strip_comb lhs; val fnameT = dest_Const recfun handle TERM _ => @@ -61,9 +62,8 @@ val (tname, _) = dest_Type (body_type T) handle TYPE _ => raise RecError "cannot determine datatype associated with function" - val (ls, cargs, rs) = (map dest_Free ls', - map dest_Free cargs', - map dest_Free rs') + val (ls, cargs, rs) = + (map dest_Free ls', map dest_Free cargs', map dest_Free rs') handle TERM _ => raise RecError "illegal argument in pattern"; val lfrees = ls @ rs @ cargs; @@ -88,9 +88,9 @@ AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns)) rec_fns) end - handle RecError s => primrec_eq_err sign s eq; + handle RecError s => primrec_eq_err thy s eq; -fun process_fun sign descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) = +fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) = let val (_, (tname, _, constrs)) = List.nth (descr, i); @@ -122,7 +122,7 @@ | SOME (i', y) => let val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs); - val fs'' = process_fun sign descr rec_eqns ((i', fnameT'), fs') + val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs') in (fs'', list_comb (y, ts')) end) end @@ -145,15 +145,15 @@ val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs); val rargs = map fst recs; val subs = map (rpair dummyT o fst) - (rev (rename_wrt_term rhs rargs)); + (rev (rename_wrt_term rhs rargs)); val ((fnameTs'', fnss''), rhs') = - (subst (map (fn ((x, y), z) => - (Free x, (body_index y, Free z))) - (recs ~~ subs)) - ((fnameTs', fnss'), rhs)) - handle RecError s => primrec_eq_err sign s eq + (subst (map (fn ((x, y), z) => + (Free x, (body_index y, Free z))) + (recs ~~ subs)) + ((fnameTs', fnss'), rhs)) + handle RecError s => primrec_eq_err thy s eq in (fnameTs'', fnss'', - (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) + (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) end) in (case AList.lookup (op =) fnameTs i of @@ -176,7 +176,7 @@ (* prepare functions needed for definitions *) -fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) = +fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) = case AList.lookup (op =) fns i of NONE => let @@ -192,15 +192,15 @@ (* make definition *) -fun make_def sign fs (fname, ls, rec_name, tname) = +fun make_def thy fs (fname, ls, rec_name, tname) = let val rhs = foldr (fn (T, t) => Abs ("", T, t)) - (list_comb (Const (rec_name, dummyT), - fs @ map Bound (0 ::(length ls downto 1)))) - ((map snd ls) @ [dummyT]); + (list_comb (Const (rec_name, dummyT), + fs @ map Bound (0 ::(length ls downto 1)))) + ((map snd ls) @ [dummyT]); val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def", - Logic.mk_equals (Const (fname, dummyT), rhs)) - in Theory.inferT_axm sign defpair end; + Logic.mk_equals (Const (fname, dummyT), rhs)) + in Theory.inferT_axm thy defpair end; (* find datatypes which contain all datatypes in tnames' *) @@ -225,12 +225,10 @@ |> RuleCases.save induction end; -fun gen_primrec_i unchecked alt_name eqns_atts thy = +fun mk_defs thy eqns = let - val (eqns, atts) = split_list eqns_atts; - val sg = Theory.sign_of thy; val dt_info = DatatypePackage.get_datatypes thy; - val rec_eqns = foldr (process_eqn sg) [] (map snd eqns); + val rec_eqns = foldr (process_eqn thy) [] eqns; val tnames = distinct (op =) (map (#1 o snd) rec_eqns); val dts = find_dts dt_info tnames tnames; val main_fns = @@ -242,10 +240,20 @@ if null dts then primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") else snd (hd dts); - val (fnameTs, fnss) = foldr (process_fun sg descr rec_eqns) - ([], []) main_fns; + val (fnameTs, fnss) = + foldr (process_fun thy descr rec_eqns) ([], []) main_fns; val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names); - val defs' = map (make_def sg fs) defs; + val defs' = map (make_def thy fs) defs; + in (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') end; + +fun mk_combdefs thy = + #6 o mk_defs thy o map (ObjectLogic.ensure_propT thy); + +fun gen_primrec_i unchecked alt_name eqns_atts thy = + let + val (eqns, atts) = split_list eqns_atts; + val (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') = + mk_defs thy (map snd eqns); val nameTs1 = map snd fnameTs; val nameTs2 = map fst rec_eqns; val primrec_name = @@ -261,26 +269,25 @@ commas_quote (map fst nameTs1) ^ " ..."); val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns; - val (simps', thy'') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy'; - val thy''' = thy'' - |> (snd o PureThy.add_thmss [(("simps", simps'), - [Simplifier.simp_add, RecfunCodegen.add NONE])]) - |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])]) - |> Theory.parent_path + val (simps', thy'') = thy' |> PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts); in - (thy''', simps') + thy'' + |> (snd o PureThy.add_thmss [(("simps", simps'), + [Simplifier.simp_add, RecfunCodegen.add NONE])]) + |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])]) + |> Theory.parent_path + |> pair simps' end; fun gen_primrec unchecked alt_name eqns thy = let - val sign = Theory.sign_of thy; val ((names, strings), srcss) = apfst split_list (split_list eqns); val atts = map (map (Attrib.attribute thy)) srcss; - val eqn_ts = map (fn s => term_of (Thm.read_cterm sign (s, propT)) + val eqn_ts = map (fn s => term_of (Thm.read_cterm thy (s, propT)) handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings; val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq))) - handle TERM _ => primrec_eq_err sign "not a proper equation" eq) eqn_ts; - val (_, eqn_ts') = InductivePackage.unify_consts (sign_of thy) rec_ts eqn_ts + handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts; + val (_, eqn_ts') = InductivePackage.unify_consts thy rec_ts eqn_ts in gen_primrec_i unchecked alt_name (names ~~ eqn_ts' ~~ atts) thy end; @@ -306,7 +313,7 @@ val primrecP = OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl (primrec_decl >> (fn ((unchecked, alt_name), eqns) => - Toplevel.theory (#1 o + Toplevel.theory (snd o (if unchecked then add_primrec_unchecked else add_primrec) alt_name (map P.triple_swap eqns))));