diff -r 0af35cebe8ca -r a179353111db src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Fri Oct 18 17:47:25 2013 +0200 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Fri Oct 18 19:03:39 2013 +0200 @@ -476,8 +476,8 @@ Disc of coeqn_data_disc | Sel of coeqn_data_sel; -fun dissect_coeqn_disc seq fun_names (ctr_specss : corec_ctr_spec list list) maybe_ctr_rhs - maybe_code_rhs prems' concl matchedsss = +fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) + maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss = let fun find_subterm p = let (* FIXME \? *) fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v) @@ -489,23 +489,23 @@ |> the handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl; val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free; - val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name); + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name); - val discs = map #disc ctr_specs; - val ctrs = map #ctr ctr_specs; + val discs = map #disc basic_ctr_specs; + val ctrs = map #ctr basic_ctr_specs; val not_disc = head_of concl = @{term Not}; val _ = not_disc andalso length ctrs <> 2 andalso primrec_error_eqn "\ed discriminator for a type with \ 2 constructors" concl; - val disc = find_subterm (member (op =) discs o head_of) concl; + val disc' = find_subterm (member (op =) discs o head_of) concl; val eq_ctr0 = concl |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd) |> (fn SOME t => let val n = find_index (equal t) ctrs in if n >= 0 then SOME n else NONE end | _ => NONE); - val _ = is_some disc orelse is_some eq_ctr0 orelse + val _ = is_some disc' orelse is_some eq_ctr0 orelse primrec_error_eqn "no discriminator in equation" concl; val ctr_no' = - if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs; + if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs; val ctr_no = if not_disc then 1 - ctr_no' else ctr_no'; - val ctr = #ctr (nth ctr_specs ctr_no); + val {ctr, disc, ...} = nth basic_ctr_specs ctr_no; val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_; val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default []; @@ -528,7 +528,7 @@ fun_args = fun_args, ctr = ctr, ctr_no = ctr_no, - disc = #disc (nth ctr_specs ctr_no), + disc = disc, prems = real_prems, auto_gen = catch_all, maybe_ctr_rhs = maybe_ctr_rhs, @@ -537,7 +537,8 @@ }, matchedsss') end; -fun dissect_coeqn_sel fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec eqn = +fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' of_spec + eqn = let val (lhs, rhs) = HOLogic.dest_eq eqn handle TERM _ => @@ -546,12 +547,12 @@ val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free handle TERM _ => primrec_error_eqn "malformed selector argument in left-hand side" eqn; - val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name) + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name) handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn; - val ctr_spec = + val {ctr, ...} = if is_some of_spec - then the (find_first (equal (the of_spec) o #ctr) ctr_specs) - else ctr_specs |> filter (exists (equal sel) o #sels) |> the_single + then the (find_first (equal (the of_spec) o #ctr) basic_ctr_specs) + else filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn; val user_eqn = drop_All eqn'; in @@ -559,27 +560,27 @@ fun_name = fun_name, fun_T = fun_T, fun_args = fun_args, - ctr = #ctr ctr_spec, + ctr = ctr, sel = sel, rhs_term = rhs, user_eqn = user_eqn } end; -fun dissect_coeqn_ctr seq fun_names (ctr_specss : corec_ctr_spec list list) eqn' maybe_code_rhs - prems concl matchedsss = +fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' + maybe_code_rhs prems concl matchedsss = let val (lhs, rhs) = HOLogic.dest_eq concl; val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; - val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name); + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name); val (ctr, ctr_args) = strip_comb (unfold_let rhs); - val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs) + val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs) handle Option.Option => primrec_error_eqn "not a constructor" ctr; val disc_concl = betapply (disc, lhs); - val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1 + val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1 then (NONE, matchedsss) - else apfst SOME (dissect_coeqn_disc seq fun_names ctr_specss + else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss); val sel_concls = (sels ~~ ctr_args) @@ -593,19 +594,20 @@ space_implode "\n \ " (map (Syntax.string_of_term @{context}) prems)); *) - val eqns_data_sel = map (dissect_coeqn_sel fun_names ctr_specss eqn' (SOME ctr)) sel_concls; + val eqns_data_sel = + map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls; in (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss') end; -fun dissect_coeqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss = +fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss = let val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []); val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; - val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name); + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name); val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ => - if member ((op =) o apsnd #ctr) ctr_specs ctr + if member ((op =) o apsnd #ctr) basic_ctr_specs ctr then cons (ctr, cs) else primrec_error_eqn "not a constructor" ctr) [] rhs' [] |> AList.group (op =); @@ -618,13 +620,13 @@ |> curry list_comb ctr |> curry HOLogic.mk_eq lhs); in - fold_map2 (dissect_coeqn_ctr false fun_names ctr_specss eqn' + fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn' (SOME (abstract (List.rev fun_args) rhs))) ctr_premss ctr_concls matchedsss end; -fun dissect_coeqn lthy seq has_call fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec - matchedsss = +fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) + eqn' of_spec matchedsss = let val eqn = drop_All eqn' handle TERM _ => primrec_error_eqn "malformed function equation" eqn'; @@ -637,23 +639,23 @@ val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq); - val discs = maps (map #disc) ctr_specss; - val sels = maps (maps #sels) ctr_specss; - val ctrs = maps (map #ctr) ctr_specss; + val discs = maps (map #disc) basic_ctr_specss; + val sels = maps (maps #sels) basic_ctr_specss; + val ctrs = maps (map #ctr) basic_ctr_specss; in if member (op =) discs head orelse is_some maybe_rhs andalso member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then - dissect_coeqn_disc seq fun_names ctr_specss NONE NONE prems concl matchedsss + dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss |>> single else if member (op =) sels head then - ([dissect_coeqn_sel fun_names ctr_specss eqn' of_spec concl], matchedsss) + ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' of_spec concl], matchedsss) else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then - dissect_coeqn_ctr seq fun_names ctr_specss eqn' NONE prems concl matchedsss + dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso null prems then - dissect_coeqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss + dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss |>> flat else primrec_error_eqn "malformed function equation" eqn @@ -747,9 +749,8 @@ fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list) (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) = let - val corec_specs' = take (length bs) corec_specs; - val corecs = map #corec corec_specs'; - val ctr_specss = map #ctr_specs corec_specs'; + val corecs = map #corec corec_specs; + val ctr_specss = map #ctr_specs corec_specs; val corec_args = hd corecs |> fst o split_last o binder_types o fastype_of |> map (Const o pair @{const_name undefined}) @@ -808,27 +809,49 @@ chop n disc_eqns ||> cons extra_disc_eqn |> (op @) end; +fun find_corec_calls has_call basic_ctr_specs {ctr, sel, rhs_term, ...} = + let + val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs + |> find_index (equal sel) o #sels o the; + fun find (Abs (_, _, b)) = find b + | find (t as _ $ _) = strip_comb t |>> find ||> maps find |> (op @) + | find f = if is_Free f andalso has_call f then [f] else []; + in + find rhs_term + |> K |> nth_map sel_no |> AList.map_entry (op =) ctr + end; + fun add_primcorec simple seq fixes specs of_specs lthy = let val (bs, mxs) = map_split (apfst fst) fixes; val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list; - val callssss = []; (* FIXME *) + val fun_names = map Binding.name_of bs; + val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts; + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); + val eqns_data = + fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs) + of_specs [] + |> flat o fst; + + val callssss = + map_filter (try (fn Sel x => x)) eqns_data + |> partition_eq ((op =) o pairself #fun_name) + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names + |> map (flat o snd) |> map2 (fold o find_corec_calls has_call) basic_ctr_specss + |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} => + (ctr, map (K []) sels))) basic_ctr_specss); + +(* +val _ = tracing ("callssss = " ^ @{make_string} callssss); +*) val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms, strong_coinduct_thms), lthy') = corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; - val actual_nn = length bs; - val fun_names = map Binding.name_of bs; val corec_specs = take actual_nn corec_specs'; (*###*) - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); - val eqns_data = - fold_map2 (dissect_coeqn lthy seq has_call fun_names (map #ctr_specs corec_specs)) - (map snd specs) of_specs [] - |> flat o fst; - val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data |> partition_eq ((op =) o pairself #fun_name) |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names