# HG changeset patch # User panny # Date 1378036805 -7200 # Node ID b46e6cd75dc67fc44b8b76835812eb8a9e171803 # Parent 46b0c7a08af7bfcfcb3a9a31774b119dfc1baebf improved interfaces diff -r 46b0c7a08af7 -r b46e6cd75dc6 src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 10:45:54 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Sun Sep 01 14:00:05 2013 +0200 @@ -28,11 +28,15 @@ fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]); fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); +fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); + val free_name = try (fn Free (v, _) => v); val const_name = try (fn Const (v, _) => v); +val undef_const = Const (@{const_name undefined}, dummyT); -fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); -fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else +fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1))) + |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n); +fun abs_tuple t = if t = undef_const then t else strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda; val simp_attrs = @{attributes [simp]}; @@ -53,9 +57,6 @@ user_eqn: term }; -fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1))) - |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n); - fun dissect_eqn lthy fun_names eqn' = let val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev, @@ -106,20 +107,21 @@ user_eqn = eqn'} end; -fun rewrite_map_arg funs_data get_indices rec_type res_type = +fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type = let - val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data)); - val fun_name = #fun_name fun_data; - val ctr_pos = length (#left_args fun_data); - val pT = HOLogic.mk_prodT (rec_type, res_type); val maybe_suc = Option.map (fn x => x + 1); fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT) | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b) | subst d t = - let val (u, vs) = strip_comb t in - if free_name u = SOME fun_name then + let + val (u, vs) = strip_comb t; + val maybe_fun_name_ctr_pos = + find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list; + val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos; + in + if is_some maybe_fun_name_ctr_pos then if d = SOME ~1 andalso length vs = ctr_pos then list_comb (permute_args ctr_pos (snd_const pT), vs) else if length vs > ctr_pos andalso is_some d @@ -136,42 +138,40 @@ subst (SOME ~1) end; -(* FIXME get rid of funs_data or get_indices *) -fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t = +fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t = let - val contains_fun = not o null o get_indices; fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b) - | subst bound_Ts (t as g $ y) = + | subst bound_Ts (t as g' $ y) = let - val is_ctr_arg = exists (exists (exists (equal y) o #ctr_args)) funs_data; val maybe_direct_y' = AList.lookup (op =) direct_calls y; val maybe_indirect_y' = AList.lookup (op =) indirect_calls y; - val (g_head, g_args) = strip_comb g; + val (g, g_args) = strip_comb g'; + val maybe_ctr_pos = + try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list; + val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse + primrec_error_eqn "too few arguments in recursive call" t; in - if not is_ctr_arg then - pairself (subst bound_Ts) (g, y) |> (op $) - else if contains_fun g_head then - (length g_args >= the (funs_data |> get_first (fn {fun_name, left_args, ...} :: _ => - if fst (dest_Free g_head) = fun_name then SOME (length left_args) else NONE)) (*###*) - orelse primrec_error_eqn "too few arguments in recursive call" t; - list_comb (the maybe_direct_y', g_args)) + if not (member (op =) ctr_args y) then + pairself (subst bound_Ts) (g', y) |> (op $) + else if is_some maybe_ctr_pos then + list_comb (the maybe_direct_y', g_args) else if is_some maybe_indirect_y' then - (if contains_fun g then t else y) - |> massage_indirect_rec_call lthy contains_fun - (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y') - |> (if contains_fun g then I else curry (op $) g) + (if has_call g' then t else y) + |> massage_indirect_rec_call lthy has_call + (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y') + |> (if has_call g' then I else curry (op $) g') else t end | subst _ t = t in subst [] t - |> (fn u => ((contains_fun u andalso (* FIXME detect this case earlier *) - primrec_error_eqn "recursive call not directly applied to constructor argument" t); u)) + |> tap (fn u => has_call u andalso (* FIXME detect this case earlier *) + primrec_error_eqn "recursive call not directly applied to constructor argument" t) end; -fun build_rec_arg lthy get_indices funs_data ctr_spec maybe_eqn_data = - if is_none maybe_eqn_data then Const (@{const_name undefined}, dummyT) else +fun build_rec_arg lthy funs_data has_call ctr_spec maybe_eqn_data = + if is_none maybe_eqn_data then undef_const else let val eqn_data = the maybe_eqn_data; val t = #rhs_term eqn_data; @@ -215,13 +215,15 @@ val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls'; val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data); + val fun_name_ctr_pos_list = + map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data; in t - |> subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls + |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls |> fold_rev absfree abstractions end; -fun build_defs lthy bs mxs funs_data rec_specs get_indices = +fun build_defs lthy bs mxs funs_data rec_specs has_call = let val n_funs = length funs_data; @@ -239,7 +241,7 @@ val recs = take n_funs rec_specs |> map #recx; val rec_args = ctr_spec_eqn_data_list |> sort ((op <) o pairself (#offset o fst) |> make_ord) - |> map (uncurry (build_rec_arg lthy get_indices funs_data) o apsnd (try the_single)); + |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single)); val ctr_poss = map (fn x => if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then primrec_error ("inconstant constructor pattern position for function " ^ @@ -253,7 +255,7 @@ |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs end; -fun find_rec_calls get_indices eqn_data = +fun find_rec_calls has_call eqn_data = let fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg | find (t as _ $ _) ctr_arg = @@ -265,7 +267,7 @@ find f' ctr_arg @ maps (fn x => find x ctr_arg) args' else let val (f, args) = chop n args' |>> curry list_comb f' in - if exists_subterm (not o null o get_indices) f then + if has_call f then f :: maps (fn x => find x ctr_arg) args else find f ctr_arg @ maps (fn x => find x ctr_arg) args @@ -288,16 +290,16 @@ |> map (fn (x, y) => the_single y handle List.Empty => primrec_error ("missing equations for function " ^ quote x)); - fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes - |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE) - |> map_filter I; - + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); val arg_Ts = map (#rec_type o hd) funs_data; val res_Ts = map (#res_type o hd) funs_data; val callssss = funs_data |> map (partition_eq ((op =) o pairself #ctr)) - |> map (maps (map_filter (find_rec_calls get_indices))); + |> map (maps (map_filter (find_rec_calls has_call))); + fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes + |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE) + |> map_filter I; val ((nontriv, rec_specs, _, induct_thm, induct_thms), lthy') = rec_specs_of bs arg_Ts res_Ts get_indices callssss lthy; @@ -308,7 +310,7 @@ primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^ " is not a constructor in left-hand side") user_eqn) eqns_data end; - val defs = build_defs lthy' bs mxs funs_data rec_specs get_indices; + val defs = build_defs lthy' bs mxs funs_data rec_specs has_call; fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data lthy = @@ -561,7 +563,7 @@ 0 upto length ctr_specs - 1 |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns |> Option.map #cond - |> the_default (Const (@{const_name undefined}, dummyT))) + |> the_default undef_const) |> fst o split_last; in (* FIXME: deal with #preds above *) @@ -633,7 +635,7 @@ val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list; val n_args = fold (curry (op +)) (map (K 1) (maps (map_filter #pred) ctr_specss) @ map (fn Direct_Corec _ => 3 | _ => 1) (maps (maps #calls) ctr_specss)) 0; - val corec_args = replicate n_args (Const (@{const_name undefined}, dummyT)) + val corec_args = replicate n_args undef_const |> fold2 build_corec_args_discs disc_eqnss ctr_specss |> fold2 (fn sel_eqns => fold (build_corec_args_sel sel_eqns)) sel_eqnss ctr_specss;