--- a/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML Wed Feb 19 08:34:32 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML Wed Feb 19 08:34:33 2014 +0100
@@ -19,12 +19,15 @@
rec_thms: thm list};
type lfp_rec_extension =
- {is_new_datatype: Proof.context -> string -> bool,
+ {nested_simps: thm list,
+ is_new_datatype: Proof.context -> string -> bool,
get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) ->
(term * term list list) list list -> local_theory ->
typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
- massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
- typ list -> term -> term -> term -> term};
+ rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list -> term ->
+ term -> term -> term};
+
+ exception PRIMREC of string * term list;
val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
@@ -45,11 +48,13 @@
struct
open Ctr_Sugar
+open Ctr_Sugar_Util
open Ctr_Sugar_General_Tactics
-open BNF_Util
-open BNF_FP_Util
open BNF_FP_Rec_Sugar_Util
+val inductN = "induct"
+val simpsN = "simps"
+
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
@@ -57,10 +62,6 @@
exception OLD_PRIMREC of unit;
exception PRIMREC of string * term list;
-fun primrec_error str = raise PRIMREC (str, []);
-fun primrec_error_eqn str eqn = raise PRIMREC (str, [eqn]);
-fun primrec_error_eqns str eqns = raise PRIMREC (str, eqns);
-
datatype rec_call =
No_Rec of int * typ |
Mutual_Rec of (int * typ) * (int * typ) |
@@ -89,12 +90,13 @@
rec_thms: thm list};
type lfp_rec_extension =
- {is_new_datatype: Proof.context -> string -> bool,
+ {nested_simps: thm list,
+ is_new_datatype: Proof.context -> string -> bool,
get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) ->
(term * term list list) list list -> local_theory ->
typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
- massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) ->
- typ list -> term -> term -> term -> term};
+ rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list -> term ->
+ term -> term -> term};
structure Data = Theory_Data
(
@@ -106,6 +108,11 @@
val register_lfp_rec_extension = Data.put o SOME;
+fun nested_simps ctxt =
+ (case Data.get (Proof_Context.theory_of ctxt) of
+ SOME {nested_simps, ...} => nested_simps
+ | NONE => []);
+
fun is_new_datatype ctxt =
(case Data.get (Proof_Context.theory_of ctxt) of
SOME {is_new_datatype, ...} => is_new_datatype ctxt
@@ -116,9 +123,9 @@
SOME {get_basic_lfp_sugars, ...} => get_basic_lfp_sugars bs arg_Ts get_indices callssss lthy
| NONE => error "Not implemented yet");
-fun massage_nested_rec_call ctxt =
+fun rewrite_nested_rec_call ctxt =
(case Data.get (Proof_Context.theory_of ctxt) of
- SOME {massage_nested_rec_call, ...} => massage_nested_rec_call ctxt);
+ SOME {rewrite_nested_rec_call, ...} => rewrite_nested_rec_call ctxt);
fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy0 =
let
@@ -193,9 +200,6 @@
val undef_const = Const (@{const_name undefined}, dummyT);
-fun permute_args n t =
- list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
-
type eqn_data = {
fun_name: string,
rec_type: typ,
@@ -212,30 +216,30 @@
let
val eqn = drop_all eqn' |> HOLogic.dest_Trueprop
handle TERM _ =>
- primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
+ raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
val (lhs, rhs) = HOLogic.dest_eq eqn
handle TERM _ =>
- primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
+ raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
val (fun_name, args) = strip_comb lhs
|>> (fn x => if is_Free x then fst (dest_Free x)
- else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
+ else raise PRIMREC ("malformed function equation (does not start with free)", [eqn]));
val (left_args, rest) = take_prefix is_Free args;
val (nonfrees, right_args) = take_suffix is_Free rest;
val num_nonfrees = length nonfrees;
val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
- primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
- primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
+ raise PRIMREC ("constructor pattern missing in left-hand side", [eqn]) else
+ raise PRIMREC ("more than one non-variable argument in left-hand side", [eqn]);
val _ = member (op =) fun_names fun_name orelse
- primrec_error_eqn "malformed function equation (does not start with function name)" eqn
+ raise PRIMREC ("malformed function equation (does not start with function name)", [eqn]);
val (ctr, ctr_args) = strip_comb (the_single nonfrees);
val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
- primrec_error_eqn "partially applied constructor in pattern" eqn;
+ raise PRIMREC ("partially applied constructor in pattern", [eqn]);
val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
- primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
- "\" in left-hand side") eqn end;
+ raise PRIMREC ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
+ "\" in left-hand side", [eqn]) end;
val _ = forall is_Free ctr_args orelse
- primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
+ raise PRIMREC ("non-primitive pattern in left-hand side", [eqn]);
val _ =
let val b = fold_aterms (fn x as Free (v, _) =>
if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
@@ -243,8 +247,8 @@
not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
in
null b orelse
- primrec_error_eqn ("extra variable(s) in right-hand side: " ^
- commas (map (Syntax.string_of_term lthy) b)) eqn
+ raise PRIMREC ("extra variable(s) in right-hand side: " ^
+ commas (map (Syntax.string_of_term lthy) b), [eqn])
end;
in
{fun_name = fun_name,
@@ -258,39 +262,11 @@
user_eqn = eqn'}
end;
-fun rewrite_map_arg get_ctr_pos rec_type res_type =
- let
- val pT = HOLogic.mk_prodT (rec_type, res_type);
-
- fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
- | subst d (Abs (v, T, b)) =
- Abs (v, if d = SOME ~1 then pT else T, subst (Option.map (Integer.add 1) d) b)
- | subst d t =
- let
- val (u, vs) = strip_comb t;
- val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
- in
- if ctr_pos >= 0 then
- if d = SOME ~1 andalso length vs = ctr_pos then
- list_comb (permute_args ctr_pos (snd_const pT), vs)
- else if length vs > ctr_pos andalso is_some d
- andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
- list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
- else
- primrec_error_eqn ("recursive call not directly applied to constructor argument") t
- else
- list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
- end
- in
- subst (SOME ~1)
- end;
-
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
let
fun try_nested_rec bound_Ts y t =
AList.lookup (op =) nested_calls y
- |> Option.map (fn y' =>
- massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);
+ |> Option.map (fn y' => rewrite_nested_rec_call lthy has_call get_ctr_pos bound_Ts y y' t);
fun subst bound_Ts (t as g' $ y) =
let
@@ -307,7 +283,7 @@
(case try (get_ctr_pos o fst o dest_Free) g of
SOME ctr_pos =>
(length g_args >= ctr_pos orelse
- primrec_error_eqn "too few arguments in recursive call" t;
+ raise PRIMREC ("too few arguments in recursive call", [t]);
(case AList.lookup (op =) mutual_calls y of
SOME y' => list_comb (y', g_args)
| NONE => subst_rec ()))
@@ -320,7 +296,7 @@
fun subst' t =
if has_call t then
(* FIXME detect this case earlier? *)
- primrec_error_eqn "recursive call not directly applied to constructor argument" t
+ raise PRIMREC ("recursive call not directly applied to constructor argument", [t])
else
try_nested_rec [] (head_of t) t |> the_default t
in
@@ -378,9 +354,9 @@
(take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
|> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
##> (fn x => null x orelse
- primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
+ raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
- primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
+ raise PRIMREC ("multiple equations for constructor", map #user_eqn x));
val ctr_spec_eqn_data_list =
ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
@@ -391,8 +367,8 @@
|> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
val ctr_poss = map (fn x =>
if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
- primrec_error ("inconstant constructor pattern position for function " ^
- quote (#fun_name (hd x)))
+ raise PRIMREC ("inconstant constructor pattern position for function " ^
+ quote (#fun_name (hd x)), [])
else
hd x |> #left_args |> length) funs_data;
in
@@ -435,8 +411,7 @@
fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
unfold_thms_tac ctxt fun_defs THEN
HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
- unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @
- map_idents) THEN
+ unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
HEADGOAL (rtac refl);
fun prepare_primrec fixes specs lthy0 =
@@ -450,7 +425,7 @@
|> partition_eq ((op =) o pairself #fun_name)
|> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
|> map (fn (x, y) => the_single y
- handle List.Empty => primrec_error ("missing equations for function " ^ quote x));
+ handle List.Empty => raise PRIMREC ("missing equations for function " ^ quote x, []));
val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
val arg_Ts = map (#rec_type o hd) funs_data;
@@ -466,7 +441,7 @@
val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of
[] => ()
- | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
+ | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));
val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
rec_specs_of bs arg_Ts res_Ts (get_free_indices fixes) callssss lthy0;
@@ -476,8 +451,8 @@
val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
val _ =
map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
- primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
- " is not a constructor in left-hand side") user_eqn) eqns_data;
+ raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
+ " is not a constructor in left-hand side", [user_eqn])) eqns_data;
val defs = build_defs lthy bs mxs funs_data rec_specs has_call;
@@ -508,8 +483,8 @@
val notes =
(if n2m then
- map2 (fn name => fn thm =>
- (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms)
+ map2 (fn name => fn thm => (name, inductN, [thm], [])) fun_names
+ (take actual_nn induct_thms)
else
[])
|> map (fn (prefix, thmN, thms, attrs) =>
@@ -531,7 +506,7 @@
fun add_primrec_simple fixes ts lthy =
let
val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
- handle ERROR str => primrec_error str;
+ handle ERROR str => raise PRIMREC (str, []);
in
lthy'
|> fold_map Local_Theory.define defs
@@ -548,7 +523,7 @@
lthy =
let
val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
- val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
+ val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []);
val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);