# HG changeset patch # User panny # Date 1377965913 -7200 # Node ID 63015d03530139fd4b6578295752f50cf88fe63a # Parent a1cd4126a1c494fa4da942dae6642e968579df33 handle selector formulae with no corecursive calls diff -r a1cd4126a1c4 -r 63015d035301 src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 00:40:21 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sat Aug 31 18:18:33 2013 +0200 @@ -29,6 +29,8 @@ fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); +fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else + strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda; val simp_attrs = @{attributes [simp]}; @@ -398,23 +400,23 @@ (* Primcorec *) -type co_eqn_data_dtr_disc = { +type co_eqn_data_disc = { fun_name: string, - ctr_no: int, + ctr_no: int, (*###*) cond: term, user_eqn: term }; -type co_eqn_data_dtr_sel = { +type co_eqn_data_sel = { fun_name: string, - ctr_no: int, - sel_no: int, + ctr: term, + sel: term, fun_args: term list, rhs_term: term, user_eqn: term }; datatype co_eqn_data = - Dtr_Disc of co_eqn_data_dtr_disc | - Dtr_Sel of co_eqn_data_dtr_sel + Disc of co_eqn_data_disc | + Sel of co_eqn_data_sel; fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps = let @@ -476,7 +478,7 @@ then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps else (fun_name, matched_cond) :: matched_conds_ps; in - (Dtr_Disc { + (Disc { fun_name = fun_name, ctr_no = ctr_no, cond = cond, @@ -495,15 +497,14 @@ primrec_error_eqn "malformed selector argument in left-hand side" eqn; val corec_spec = the (AList.lookup (op =) fun_name_corec_spec_list fun_name) handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn; - val ((ctr_spec, ctr_no), sel) = #ctr_specs corec_spec + val (ctr_spec, sel) = #ctr_specs corec_spec |> the o get_index (try (the o find_first (equal sel) o #sels)) - |>> `(nth (#ctr_specs corec_spec)); - val sel_no = find_index (equal sel) (#sels ctr_spec); + |>> nth (#ctr_specs corec_spec); in - Dtr_Sel { + Sel { fun_name = fun_name, - ctr_no = ctr_no, - sel_no = sel_no, + ctr = #ctr ctr_spec, + sel = sel, fun_args = fun_args, rhs_term = rhs, user_eqn = eqn' @@ -518,21 +519,24 @@ val (ctr, ctr_args) = strip_comb rhs; val ctr_spec = the (find_first (equal ctr o #ctr) (#ctr_specs corec_spec)) handle Option.Option => primrec_error_eqn "not a constructor" ctr; + val disc_imp_rhs = betapply (#disc ctr_spec, lhs); - val (eqn_data_disc, matched_conds_ps') = co_dissect_eqn_disc - sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps; + val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1 + then (NONE, matched_conds_ps) + else apfst SOME (co_dissect_eqn_disc + sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps); val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args) |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg)); val _ = warning ("reduced\n " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n \ " ^ - Syntax.string_of_term @{context} disc_imp_rhs ^ "\n \ " ^ + (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n \ ")) "" ^ space_implode "\n \ " (map (Syntax.string_of_term @{context}) sel_imp_rhss)); val eqns_data_sel = - map (co_dissect_eqn_sel fun_name_corec_spec_list @{const True}(*###*)) sel_imp_rhss; + map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss; in - (eqn_data_disc :: eqns_data_sel, matched_conds_ps') + (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps') end; fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps = @@ -540,9 +544,8 @@ val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev, strip_qnt_body @{const_name all} eqn') handle TERM _ => primrec_error_eqn "malformed function equation" eqn'; - val (imp_lhs', imp_rhs) = - (map HOLogic.dest_Trueprop (Logic.strip_imp_prems eqn), - HOLogic.dest_Trueprop (Logic.strip_imp_concl eqn)); + val (imp_lhs', imp_rhs) = Logic.strip_horn eqn + |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop; val head = imp_rhs |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq)) @@ -568,10 +571,10 @@ primrec_error_eqn "malformed function equation" eqn end; -fun build_corec_args_discs ctr_specs disc_eqns = +fun build_corec_args_discs disc_eqns ctr_specs = let val conds = map #cond disc_eqns; - val args = + val args' = if length ctr_specs = 1 then [] else if length disc_eqns = length ctr_specs then fst (split_last conds) @@ -592,33 +595,54 @@ |> Option.map #cond |> the_default (Const (@{const_name undefined}, dummyT))) |> fst o split_last; - fun finish t = - let val n = length (fastype_of t |> binder_types) in - if t = Const (@{const_name undefined}, dummyT) then t - else if n = 0 then Abs (Name.uu_, @{typ unit}, t) - else if n = 1 then t - else Const (@{const_name prod_case}, dummyT) $ t - end; in - map finish args + (* FIXME: deal with #preds above *) + fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args' end; -fun build_corec_args_sel sel_eqns ctr_spec = - let - (* FIXME *) - val n_args = fold (curry (op +)) (map (fn Direct_Corec _ => 3 | _ => 1) (#calls ctr_spec)) 0; - in - replicate n_args (Const (@{const_name undefined}, dummyT)) +fun build_corec_args_sel all_sel_eqns ctr_spec = + let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in + if null sel_eqns then I else + let + val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec; + +val _ = warning ("sels / calls:\n \ " ^ space_implode "\n \ " (map ((op ^) o + apfst (Syntax.string_of_term @{context}) o apsnd (curry (op ^) " / " o @{make_string})) + (sel_call_list))); + + (* FIXME get rid of dummy_no_calls' *) + val dummy_no_calls' = map_filter (try (apsnd (fn Dummy_No_Corec n => n))) sel_call_list; + val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list; + val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list; + val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list; + + fun build_arg_no_call sel = find_first (equal sel o #sel) sel_eqns |> #rhs_term o the; + fun build_arg_direct_call sel = primrec_error "not implemented yet"; + fun build_arg_indirect_call sel = primrec_error "not implemented yet"; + + val update_args = I + #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx + (build_arg_no_call sel |> K)) no_calls' + #> fold (fn (sel, rec_arg_idx) => nth_map rec_arg_idx + (build_arg_indirect_call sel |> K)) indirect_calls' + #> fold (fn (sel, (q_idx, g_idx, h_idx)) => + let val (q, g, h) = build_arg_indirect_call sel in + nth_map q_idx (K q) o nth_map g_idx (K g) o nth_map h_idx (K h) end) direct_calls'; + + val arg_idxs = maps (fn (_, (x, y, z)) => [x, y, z]) direct_calls' @ + maps (map snd) [dummy_no_calls', no_calls', indirect_calls']; + val abs_args = fold (fn idx => nth_map idx + (abs_tuple o fold_rev absfree (sel_eqns |> #fun_args o hd |> map dest_Free))) arg_idxs; + in + abs_args o update_args + end end; fun co_build_defs lthy sequential bs arg_Tss fun_name_corec_spec_list eqns_data = let val fun_names = map Binding.name_of bs; -(* fun group _ [] = [] (* FIXME \? *) - | group eq (x :: xs) = - let val (xs', ys) = List.partition (eq x) xs in (x :: xs') :: group eq ys end;*) - val disc_eqnss = map_filter (try (fn Dtr_Disc x => x)) eqns_data + val disc_eqnss = map_filter (try (fn Disc x => x)) eqns_data |> partition_eq ((op =) o pairself #fun_name) |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd); @@ -630,20 +654,20 @@ val _ = warning ("disc_eqnss:\n \ " ^ space_implode "\n \ " (map @{make_string} disc_eqnss)); - val sel_eqnss = map_filter (try (fn Dtr_Sel x => x)) eqns_data + val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data |> partition_eq ((op =) o pairself #fun_name) |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |> fst - |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd); + |> map (flat o snd); val _ = warning ("sel_eqnss:\n \ " ^ space_implode "\n \ " (map @{make_string} sel_eqnss)); - fun splice (xs :: xss) (ys :: yss) = xs @ ys @ splice xss yss (* FIXME \? *) - | splice xss yss = flat xss @ flat yss; val corecs = map (#corec o snd) fun_name_corec_spec_list; - val corec_args = (map snd fun_name_corec_spec_list ~~ disc_eqnss ~~ sel_eqnss) - |> maps (fn (({ctr_specs, ...}, disc_eqns), sel_eqns) => - splice (build_corec_args_discs ctr_specs disc_eqns |> map single) - (map (build_corec_args_sel sel_eqns) ctr_specs)); + val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list; + val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @ + map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0; + val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT)) + |> fold2 build_corec_args_discs disc_eqnss ctr_specss + |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss; val _ = warning ("corecursor arguments:\n \ " ^ space_implode "\n \ " (map (Syntax.string_of_term @{context}) corec_args));