# HG changeset patch # User panny # Date 1378253510 -7200 # Node ID 2101a97e6220f19f644b2ab19c711620814f5505 # Parent a1a78a2716822cd62edad2adef818f3f286e276a various refactoring; handle self-mappings; handle range types containing function types; diff -r a1a78a271682 -r 2101a97e6220 src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Tue Sep 03 21:46:42 2013 +0100 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Wed Sep 04 02:11:50 2013 +0200 @@ -36,8 +36,7 @@ fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (fn u => Abs (Name.uu, dummyT, u))) (0 upto n); -fun abs_tuple t = if t = undef_const then t else - strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda; +val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple; val simp_attrs = @{attributes [simp]}; @@ -107,7 +106,7 @@ user_eqn = eqn'} end; -fun rewrite_map_arg fun_name_ctr_pos_list rec_type res_type = +fun rewrite_map_arg get_ctr_pos rec_type res_type = let val pT = HOLogic.mk_prodT (rec_type, res_type); @@ -117,11 +116,9 @@ | subst d t = let val (u, vs) = strip_comb t; - val maybe_fun_name_ctr_pos = - find_first (equal (free_name u) o SOME o fst) fun_name_ctr_pos_list; - val (fun_name, ctr_pos) = the_default ("", ~1) maybe_fun_name_ctr_pos; + val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1; in - if is_some maybe_fun_name_ctr_pos then + if ctr_pos >= 0 then if d = SOME ~1 andalso length vs = ctr_pos then list_comb (permute_args ctr_pos (snd_const pT), vs) else if length vs > ctr_pos andalso is_some d @@ -138,7 +135,7 @@ subst (SOME ~1) end; -fun subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls t = +fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t = let fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b) | subst bound_Ts (t as g' $ y) = @@ -146,19 +143,18 @@ val maybe_direct_y' = AList.lookup (op =) direct_calls y; val maybe_indirect_y' = AList.lookup (op =) indirect_calls y; val (g, g_args) = strip_comb g'; - val maybe_ctr_pos = - try (snd o the o find_first (equal (free_name g) o SOME o fst)) fun_name_ctr_pos_list; - val _ = is_none maybe_ctr_pos orelse length g_args >= the maybe_ctr_pos orelse + val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1; + val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse primrec_error_eqn "too few arguments in recursive call" t; in if not (member (op =) ctr_args y) then pairself (subst bound_Ts) (g', y) |> (op $) - else if is_some maybe_ctr_pos then + else if ctr_pos >= 0 then list_comb (the maybe_direct_y', g_args) else if is_some maybe_indirect_y' then (if has_call g' then t else y) |> massage_indirect_rec_call lthy has_call - (rewrite_map_arg fun_name_ctr_pos_list) bound_Ts y (the maybe_indirect_y') + (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y') |> (if has_call g' then I else curry (op $) g') else t @@ -211,16 +207,17 @@ nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type))) indirect_calls'; + val fun_name_ctr_pos_list = + map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data; + val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1; val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls'; val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls'; - val abstractions = map dest_Free (args @ #left_args eqn_data @ #right_args eqn_data); - val fun_name_ctr_pos_list = - map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data; + val abstractions = args @ #left_args eqn_data @ #right_args eqn_data; in t - |> subst_rec_calls lthy fun_name_ctr_pos_list has_call ctr_args direct_calls indirect_calls - |> fold_rev absfree abstractions + |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls + |> fold_rev lambda abstractions end; fun build_defs lthy bs mxs funs_data rec_specs has_call = @@ -372,15 +369,16 @@ type co_eqn_data_disc = { fun_name: string, + fun_args: term list, ctr_no: int, (*###*) cond: term, user_eqn: term }; type co_eqn_data_sel = { fun_name: string, + fun_args: term list, ctr: term, sel: term, - fun_args: term list, rhs_term: term, user_eqn: term }; @@ -388,11 +386,10 @@ Disc of co_eqn_data_disc | Sel of co_eqn_data_sel; -fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps = +fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds = let fun find_subterm p = let (* FIXME \? *) - fun f (t as u $ v) = - fold_rev (curry merge_options) [if p t then SOME t else NONE, f u, f v] NONE + fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v) | f t = if p t then SOME t else NONE in f end; @@ -406,9 +403,8 @@ val discs = #ctr_specs corec_spec |> map #disc; val ctrs = #ctr_specs corec_spec |> map #ctr; - val n_ctrs = length ctrs; val not_disc = head_of imp_rhs = @{term Not}; - val _ = not_disc andalso n_ctrs <> 2 andalso + val _ = not_disc andalso length ctrs <> 2 andalso primrec_error_eqn "\ed discriminator for a type with \ 2 constructors" imp_rhs; val disc = find_subterm (member (op =) discs o head_of) imp_rhs; val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd) @@ -428,32 +424,28 @@ val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_; - val matched_conds = filter (equal fun_name o fst) matched_conds_ps |> map snd; - val imp_lhs = mk_conjs imp_lhs'; + val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs; + val imp_lhs = mk_conjs imp_lhs' + |> incr_boundvars (length fun_args) + |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0)) val cond = if catch_all then - if null matched_conds then fold_rev absfree (map dest_Free fun_args) @{const True} else - (strip_abs_vars (hd matched_conds), - mk_disjs (map strip_abs_body matched_conds) |> HOLogic.mk_not) - |-> fold_rev (fn (v, T) => fn u => Abs (v, T, u)) + matched_cond |> HOLogic.mk_not else if sequential then - HOLogic.mk_conj (HOLogic.mk_not (mk_disjs (map strip_abs_body matched_conds)), imp_lhs) - |> fold_rev absfree (map dest_Free fun_args) + HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs) else - imp_lhs |> fold_rev absfree (map dest_Free fun_args); - val matched_cond = - if sequential then fold_rev absfree (map dest_Free fun_args) imp_lhs else cond; + imp_lhs; - val matched_conds_ps' = if catch_all - then (fun_name, cond) :: filter (not_equal fun_name o fst) matched_conds_ps - else (fun_name, matched_cond) :: matched_conds_ps; + val matched_conds' = + (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds; in (Disc { fun_name = fun_name, + fun_args = fun_args, ctr_no = ctr_no, cond = cond, user_eqn = eqn' - }, matched_conds_ps') + }, matched_conds') end; fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn = @@ -473,15 +465,15 @@ in Sel { fun_name = fun_name, + fun_args = fun_args, ctr = #ctr ctr_spec, sel = sel, - fun_args = fun_args, rhs_term = rhs, user_eqn = eqn' } end; -fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps = +fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds = let val (lhs, rhs) = HOLogic.dest_eq imp_rhs; val fun_name = head_of lhs |> fst o dest_Free; @@ -491,10 +483,10 @@ handle Option.Option => primrec_error_eqn "not a constructor" ctr; val disc_imp_rhs = betapply (#disc ctr_spec, lhs); - val (maybe_eqn_data_disc, matched_conds_ps') = if length (#ctr_specs corec_spec) = 1 - then (NONE, matched_conds_ps) + val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1 + then (NONE, matched_conds) else apfst SOME (co_dissect_eqn_disc - sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds_ps); + sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds); val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args) |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg)); @@ -506,10 +498,10 @@ val eqns_data_sel = map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss; in - (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds_ps') + (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds') end; -fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds_ps = +fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds = let val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev, strip_qnt_body @{const_name all} eqn') @@ -531,65 +523,68 @@ if member (op =) discs head orelse is_some maybe_rhs andalso member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then - co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps + co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds |>> single else if member (op =) sels head then - ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds_ps) + ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds) else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then - co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds_ps + co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds else primrec_error_eqn "malformed function equation" eqn end; fun build_corec_args_discs disc_eqns ctr_specs = - let - val conds = map #cond disc_eqns; - val args' = - if length ctr_specs = 1 then [] - else if length disc_eqns = length ctr_specs then - fst (split_last conds) - else if length disc_eqns = length ctr_specs - 1 then - let val n = 0 upto length ctr_specs - 1 - |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in - if n = length ctr_specs - 1 then - conds - else - split_last conds - ||> (fn t => fold_rev absfree (strip_abs_vars t) (strip_abs_body t |> HOLogic.mk_not)) - |>> chop n - |> (fn ((l, r), x) => l @ (x :: r)) - end - else - 0 upto length ctr_specs - 1 - |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns - |> Option.map #cond - |> the_default undef_const) - |> fst o split_last; - in - (* FIXME: deal with #preds above *) - fold2 (fn idx => nth_map idx o K o abs_tuple) (map_filter #pred ctr_specs) args' - end; + if null disc_eqns then I else + let + val conds = map #cond disc_eqns; + val fun_args = #fun_args (hd disc_eqns); + val args = + if length ctr_specs = 1 then [] + else if length disc_eqns = length ctr_specs then + fst (split_last conds) + else if length disc_eqns = length ctr_specs - 1 then + let val n = 0 upto length ctr_specs - 1 + |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)) (*###*) in + if n = length ctr_specs - 1 then + conds + else + split_last conds + ||> HOLogic.mk_not + |>> chop n + |> (fn ((l, r), x) => l @ (x :: r)) + end + else + 0 upto length ctr_specs - 1 + |> map (fn idx => find_first (equal idx o #ctr_no) disc_eqns + |> Option.map #cond + |> the_default undef_const) + |> fst o split_last; + in + (* FIXME deal with #preds above *) + (map_filter #pred ctr_specs, args) + |-> fold2 (fn idx => fn t => nth_map idx + (K (subst_bounds (List.rev fun_args, t) + |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args)))) + end; fun build_corec_arg_no_call sel_eqns sel = find_first (equal sel o #sel) sel_eqns - |> try (fn SOME sel_eqn => (#fun_args sel_eqn |> map dest_Free, #rhs_term sel_eqn)) + |> try (fn SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn)) |> the_default ([], undef_const) - |-> abs_tuple oo fold_rev absfree; + |-> abs_tuple; fun build_corec_arg_direct_call lthy has_call sel_eqns sel = let val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns - - fun rewrite U T t = + fun rewrite is_end U T t = if U = @{typ bool} then @{term True} |> has_call t ? K @{term False} (* stop? *) - else if T = U = has_call t then undef_const - else if T = U then t (* end *) + else if is_end = has_call t then undef_const + else if is_end then t (* end *) else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *) - fun massage rhs_term t = - massage_direct_corec_call lthy has_call rewrite [] (body_type (fastype_of t)) rhs_term; - val abstract = abs_tuple oo fold_rev absfree o map dest_Free; + fun massage rhs_term is_end t = massage_direct_corec_call + lthy has_call (rewrite is_end) [] (range_type (fastype_of t)) rhs_term; in - if is_none maybe_sel_eqn then I else - massage (#rhs_term (the maybe_sel_eqn)) #> abstract (#fun_args (the maybe_sel_eqn)) + if is_none maybe_sel_eqn then K I else + abs_tuple (#fun_args (the maybe_sel_eqn)) oo massage (#rhs_term (the maybe_sel_eqn)) end; fun build_corec_arg_indirect_call sel_eqns sel = @@ -614,7 +609,7 @@ (build_corec_arg_no_call sel_eqns sel |> K)) no_calls' #> fold (fn (sel, (q, g, h)) => let val f = build_corec_arg_direct_call lthy has_call sel_eqns sel in - nth_map h f o nth_map g f o nth_map q f end) direct_calls' + nth_map h (f false) o nth_map g (f true) o nth_map q (f true) end) direct_calls' #> fold (fn (sel, n) => nth_map n (build_corec_arg_indirect_call sel_eqns sel |> K)) indirect_calls' end @@ -651,24 +646,25 @@ |> fold2 build_corec_args_discs disc_eqnss ctr_specss |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss; + fun currys Ts t = if length Ts <= 1 then t else + t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v) + (length Ts - 1 downto 0 |> map Bound) + |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts; + val _ = tracing ("corecursor arguments:\n \ " ^ space_implode "\n \ " (map (Syntax.string_of_term @{context}) corec_args)); fun uneq_pairs_rev xs = xs (* FIXME \? *) |> these o try (split_last #> (fn (ys, y) => uneq_pairs_rev ys @ map (pair y) ys)); val proof_obligations = if sequential then [] else - maps (uneq_pairs_rev o map #cond) disc_eqnss - |> map (fn (x, y) => ((strip_abs_body x, strip_abs_body y), strip_abs_vars x)) - |> map (apfst (apsnd HOLogic.mk_not #> pairself HOLogic.mk_Trueprop - #> apfst (curry (op $) @{const ==>}) #> (op $))) - |> map (fn (t, abs_vars) => fold_rev (fn (v, T) => fn u => - Const (@{const_name all}, (T --> @{typ prop}) --> @{typ prop}) $ - Abs (v, T, u)) abs_vars t); + maps (uneq_pairs_rev o map (fn {fun_args, cond, ...} => (fun_args, cond))) disc_eqnss + |> map (fn ((fun_args, x), (_, y)) => [x, HOLogic.mk_not y] + |> map (HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args)) + |> curry list_comb @{const ==>}); - fun currys Ts t = if length Ts <= 1 then t else - t $ foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v) - (length Ts - 1 downto 0 |> map Bound) - |> fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts; +val _ = tracing ("proof obligations:\n \ " ^ + space_implode "\n \ " (map (Syntax.string_of_term @{context}) proof_obligations)); + in map (list_comb o rpair corec_args) corecs |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss