--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 00:40:21 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 18:18:33 2013 +0200
@@ -29,6 +29,8 @@
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
+fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else
+ strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
val simp_attrs = @{attributes [simp]};
@@ -398,23 +400,23 @@
(* Primcorec *)
-type co_eqn_data_dtr_disc = {
+type co_eqn_data_disc = {
fun_name: string,
- ctr_no: int,
+ ctr_no: int, (*###*)
cond: term,
user_eqn: term
};
-type co_eqn_data_dtr_sel = {
+type co_eqn_data_sel = {
fun_name: string,
- ctr_no: int,
- sel_no: int,
+ ctr: term,
+ sel: term,
fun_args: term list,
rhs_term: term,
user_eqn: term
};
datatype co_eqn_data =
- Dtr_Disc of co_eqn_data_dtr_disc |
- Dtr_Sel of co_eqn_data_dtr_sel
+ Disc of co_eqn_data_disc |
+ Sel of co_eqn_data_sel;
fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps =
let
@@ -476,7 +478,7 @@
then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps
else (fun_name, matched_cond) :: matched_conds_ps;
in
- (Dtr_Disc {
+ (Disc {
fun_name = fun_name,
ctr_no = ctr_no,
cond = cond,
@@ -495,15 +497,14 @@
primrec_error_eqn "malformed selector argument in left-hand side" eqn;
val corec_spec = the (AList.lookup (op =) fun_name_corec_spec_list fun_name)
handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
- val ((ctr_spec, ctr_no), sel) = #ctr_specs corec_spec
+ val (ctr_spec, sel) = #ctr_specs corec_spec
|> the o get_index (try (the o find_first (equal sel) o #sels))
- |>> `(nth (#ctr_specs corec_spec));
- val sel_no = find_index (equal sel) (#sels ctr_spec);
+ |>> nth (#ctr_specs corec_spec);
in
- Dtr_Sel {
+ Sel {
fun_name = fun_name,
- ctr_no = ctr_no,
- sel_no = sel_no,
+ ctr = #ctr ctr_spec,
+ sel = sel,
fun_args = fun_args,
rhs_term = rhs,
user_eqn = eqn'
@@ -518,21 +519,24 @@
val (ctr, ctr_args) = strip_comb rhs;
val ctr_spec = the (find_first (equal ctr o #ctr) (#ctr_specs corec_spec))
handle Option.Option => primrec_error_eqn "not a constructor" ctr;
+
val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
- val (eqn_data_disc, matched_conds_ps') = co_dissect_eqn_disc
- sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps;
+ val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1
+ then (NONE, matched_conds_ps)
+ else apfst SOME (co_dissect_eqn_disc
+ sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps);
val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
|> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
val _ = warning ("reduced\n " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n \<cdot> " ^
- Syntax.string_of_term @{context} disc_imp_rhs ^ "\n \<cdot> " ^
+ (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n \<cdot> ")) "" ^
space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
val eqns_data_sel =
- map (co_dissect_eqn_sel fun_name_corec_spec_list @{const True}(*###*)) sel_imp_rhss;
+ map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
in
- (eqn_data_disc :: eqns_data_sel, matched_conds_ps')
+ (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps')
end;
fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps =
@@ -540,9 +544,8 @@
val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
strip_qnt_body @{const_name all} eqn')
handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
- val (imp_lhs', imp_rhs) =
- (map HOLogic.dest_Trueprop (Logic.strip_imp_prems eqn),
- HOLogic.dest_Trueprop (Logic.strip_imp_concl eqn));
+ val (imp_lhs', imp_rhs) = Logic.strip_horn eqn
+ |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
val head = imp_rhs
|> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
@@ -568,10 +571,10 @@
primrec_error_eqn "malformed function equation" eqn
end;
-fun build_corec_args_discs ctr_specs disc_eqns =
+fun build_corec_args_discs disc_eqns ctr_specs =
let
val conds = map #cond disc_eqns;
- val args =
+ val args' =
if length ctr_specs = 1 then []
else if length disc_eqns = length ctr_specs then
fst (split_last conds)
@@ -592,33 +595,54 @@
|> Option.map #cond
|> the_default (Const (@{const_name undefined}, dummyT)))
|> fst o split_last;
- fun finish t =
- let val n = length (fastype_of t |> binder_types) in
- if t = Const (@{const_name undefined}, dummyT) then t
- else if n = 0 then Abs (Name.uu_, @{typ unit}, t)
- else if n = 1 then t
- else Const (@{const_name prod_case}, dummyT) $ t
- end;
in
- map finish args
+ (* FIXME: deal with #preds above *)
+ fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args'
end;
-fun build_corec_args_sel sel_eqns ctr_spec =
- let
- (* FIXME *)
- val n_args = fold (curry (op +)) (map (fn Direct_Corec _ => 3 | _ => 1) (#calls ctr_spec)) 0;
- in
- replicate n_args (Const (@{const_name undefined}, dummyT))
+fun build_corec_args_sel all_sel_eqns ctr_spec =
+ let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in
+ if null sel_eqns then I else
+ let
+ val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
+
+val _ = warning ("sels / calls:\n \<cdot> " ^ space_implode "\n \<cdot> " (map ((op ^) o
+ apfst (Syntax.string_of_term @{context}) o apsnd (curry (op ^) " / " o @{make_string}))
+ (sel_call_list)));
+
+ (* FIXME get rid of dummy_no_calls' *)
+ val dummy_no_calls' = map_filter (try (apsnd (fn Dummy_No_Corec n => n))) sel_call_list;
+ val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
+ val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list;
+ val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
+
+ fun build_arg_no_call sel = find_first (equal sel o #sel) sel_eqns |> #rhs_term o the;
+ fun build_arg_direct_call sel = primrec_error "not implemented yet";
+ fun build_arg_indirect_call sel = primrec_error "not implemented yet";
+
+ val update_args = I
+ #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
+ (build_arg_no_call sel |> K)) no_calls'
+ #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx
+ (build_arg_indirect_call sel |> K)) indirect_calls'
+ #> fold (fn (sel, (q_idx, g_idx, h_idx)) =>
+ let val (q, g, h) = build_arg_indirect_call sel in
+ nth_map q_idx (K q) o nth_map g_idx (K g) o nth_map h_idx (K h) end) direct_calls';
+
+ val arg_idxs = maps (fn (_, (x, y, z)) => [x, y, z]) direct_calls' @
+ maps (map snd) [dummy_no_calls', no_calls', indirect_calls'];
+ val abs_args = fold (fn idx => nth_map idx
+ (abs_tuple o fold_rev absfree (sel_eqns |> #fun_args o hd |> map dest_Free))) arg_idxs;
+ in
+ abs_args o update_args
+ end
end;
fun co_build_defs lthy sequential bs arg_Tss fun_name_corec_spec_list eqns_data =
let
val fun_names = map Binding.name_of bs;
-(* fun group _ [] = [] (* FIXME \<exists>? *)
- | group eq (x :: xs) =
- let val (xs', ys) = List.partition (eq x) xs in (x :: xs') :: group eq ys end;*)
- val disc_eqnss = map_filter (try (fn Dtr_Disc x => x)) eqns_data
+ val disc_eqnss = map_filter (try (fn Disc x => x)) eqns_data
|> partition_eq ((op =) o pairself #fun_name)
|> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst
|> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
@@ -630,20 +654,20 @@
val _ = warning ("disc_eqnss:\n \<cdot> " ^ space_implode "\n \<cdot> " (map @{make_string} disc_eqnss));
- val sel_eqnss = map_filter (try (fn Dtr_Sel x => x)) eqns_data
+ val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
|> partition_eq ((op =) o pairself #fun_name)
|> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst
- |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
+ |> map (flat o snd);
val _ = warning ("sel_eqnss:\n \<cdot> " ^ space_implode "\n \<cdot> " (map @{make_string} sel_eqnss));
- fun splice (xs :: xss) (ys :: yss) = xs @ ys @ splice xss yss (* FIXME \<exists>? *)
- | splice xss yss = flat xss @ flat yss;
val corecs = map (#corec o snd) fun_name_corec_spec_list;
- val corec_args = (map snd fun_name_corec_spec_list ~~ disc_eqnss ~~ sel_eqnss)
- |> maps (fn (({ctr_specs, ...}, disc_eqns), sel_eqns) =>
- splice (build_corec_args_discs ctr_specs disc_eqns |> map single)
- (map (build_corec_args_sel sel_eqns) ctr_specs));
+ val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;
+ val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @
+ map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0;
+ val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT))
+ |> fold2 build_corec_args_discs disc_eqnss ctr_specss
+ |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;
val _ = warning ("corecursor arguments:\n \<cdot> " ^
space_implode "\n \<cdot> " (map (Syntax.string_of_term @{context}) corec_args));