src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 54160 a179353111db
parent 54157 5874be04e1f9
child 54161 496f9af15b39
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 18 17:47:25 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 18 19:03:39 2013 +0200
@@ -476,8 +476,8 @@
   Disc of coeqn_data_disc |
   Sel of coeqn_data_sel;
 
-fun dissect_coeqn_disc seq fun_names (ctr_specss : corec_ctr_spec list list) maybe_ctr_rhs
-    maybe_code_rhs prems' concl matchedsss =
+fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
+    maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss =
   let
     fun find_subterm p = let (* FIXME \<exists>? *)
       fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
@@ -489,23 +489,23 @@
       |> the
       handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl;
     val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
-    val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
+    val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name);
 
-    val discs = map #disc ctr_specs;
-    val ctrs = map #ctr ctr_specs;
+    val discs = map #disc basic_ctr_specs;
+    val ctrs = map #ctr basic_ctr_specs;
     val not_disc = head_of concl = @{term Not};
     val _ = not_disc andalso length ctrs <> 2 andalso
       primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" concl;
-    val disc = find_subterm (member (op =) discs o head_of) concl;
+    val disc' = find_subterm (member (op =) discs o head_of) concl;
     val eq_ctr0 = concl |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
         |> (fn SOME t => let val n = find_index (equal t) ctrs in
           if n >= 0 then SOME n else NONE end | _ => NONE);
-    val _ = is_some disc orelse is_some eq_ctr0 orelse
+    val _ = is_some disc' orelse is_some eq_ctr0 orelse
       primrec_error_eqn "no discriminator in equation" concl;
     val ctr_no' =
-      if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
+      if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs;
     val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
-    val ctr = #ctr (nth ctr_specs ctr_no);
+    val {ctr, disc, ...} = nth basic_ctr_specs ctr_no;
 
     val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_;
     val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
@@ -528,7 +528,7 @@
       fun_args = fun_args,
       ctr = ctr,
       ctr_no = ctr_no,
-      disc = #disc (nth ctr_specs ctr_no),
+      disc = disc,
       prems = real_prems,
       auto_gen = catch_all,
       maybe_ctr_rhs = maybe_ctr_rhs,
@@ -537,7 +537,8 @@
     }, matchedsss')
   end;
 
-fun dissect_coeqn_sel fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec eqn =
+fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' of_spec
+    eqn =
   let
     val (lhs, rhs) = HOLogic.dest_eq eqn
       handle TERM _ =>
@@ -546,12 +547,12 @@
     val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
       handle TERM _ =>
         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
-    val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name)
+    val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name)
       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
-    val ctr_spec =
+    val {ctr, ...} =
       if is_some of_spec
-      then the (find_first (equal (the of_spec) o #ctr) ctr_specs)
-      else ctr_specs |> filter (exists (equal sel) o #sels) |> the_single
+      then the (find_first (equal (the of_spec) o #ctr) basic_ctr_specs)
+      else filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single
         handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn;
     val user_eqn = drop_All eqn';
   in
@@ -559,27 +560,27 @@
       fun_name = fun_name,
       fun_T = fun_T,
       fun_args = fun_args,
-      ctr = #ctr ctr_spec,
+      ctr = ctr,
       sel = sel,
       rhs_term = rhs,
       user_eqn = user_eqn
     }
   end;
 
-fun dissect_coeqn_ctr seq fun_names (ctr_specss : corec_ctr_spec list list) eqn' maybe_code_rhs
-    prems concl matchedsss =
+fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
+    maybe_code_rhs prems concl matchedsss =
   let
     val (lhs, rhs) = HOLogic.dest_eq concl;
     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
-    val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
+    val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name);
     val (ctr, ctr_args) = strip_comb (unfold_let rhs);
-    val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
+    val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs)
       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
 
     val disc_concl = betapply (disc, lhs);
-    val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
+    val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1
       then (NONE, matchedsss)
-      else apfst SOME (dissect_coeqn_disc seq fun_names ctr_specss
+      else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss
           (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
 
     val sel_concls = (sels ~~ ctr_args)
@@ -593,19 +594,20 @@
  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) prems));
 *)
 
-    val eqns_data_sel = map (dissect_coeqn_sel fun_names ctr_specss eqn' (SOME ctr)) sel_concls;
+    val eqns_data_sel =
+      map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls;
   in
     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   end;
 
-fun dissect_coeqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss =
+fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss =
   let
     val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
-    val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
+    val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name);
 
     val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
-        if member ((op =) o apsnd #ctr) ctr_specs ctr
+        if member ((op =) o apsnd #ctr) basic_ctr_specs ctr
         then cons (ctr, cs)
         else primrec_error_eqn "not a constructor" ctr) [] rhs' []
       |> AList.group (op =);
@@ -618,13 +620,13 @@
         |> curry list_comb ctr
         |> curry HOLogic.mk_eq lhs);
   in
-    fold_map2 (dissect_coeqn_ctr false fun_names ctr_specss eqn'
+    fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn'
         (SOME (abstract (List.rev fun_args) rhs)))
       ctr_premss ctr_concls matchedsss
   end;
 
-fun dissect_coeqn lthy seq has_call fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec
-    matchedsss =
+fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list)
+    eqn' of_spec matchedsss =
   let
     val eqn = drop_All eqn'
       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
@@ -637,23 +639,23 @@
 
     val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
 
-    val discs = maps (map #disc) ctr_specss;
-    val sels = maps (maps #sels) ctr_specss;
-    val ctrs = maps (map #ctr) ctr_specss;
+    val discs = maps (map #disc) basic_ctr_specss;
+    val sels = maps (maps #sels) basic_ctr_specss;
+    val ctrs = maps (map #ctr) basic_ctr_specss;
   in
     if member (op =) discs head orelse
       is_some maybe_rhs andalso
         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
-      dissect_coeqn_disc seq fun_names ctr_specss NONE NONE prems concl matchedsss
+      dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss
       |>> single
     else if member (op =) sels head then
-      ([dissect_coeqn_sel fun_names ctr_specss eqn' of_spec concl], matchedsss)
+      ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' of_spec concl], matchedsss)
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
-      dissect_coeqn_ctr seq fun_names ctr_specss eqn' NONE prems concl matchedsss
+      dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       null prems then
-      dissect_coeqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss
+      dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss
       |>> flat
     else
       primrec_error_eqn "malformed function equation" eqn
@@ -747,9 +749,8 @@
 fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list)
     (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) =
   let
-    val corec_specs' = take (length bs) corec_specs;
-    val corecs = map #corec corec_specs';
-    val ctr_specss = map #ctr_specs corec_specs';
+    val corecs = map #corec corec_specs;
+    val ctr_specss = map #ctr_specs corec_specs;
     val corec_args = hd corecs
       |> fst o split_last o binder_types o fastype_of
       |> map (Const o pair @{const_name undefined})
@@ -808,27 +809,49 @@
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
     end;
 
+fun find_corec_calls has_call basic_ctr_specs {ctr, sel, rhs_term, ...} =
+  let
+    val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs
+      |> find_index (equal sel) o #sels o the;
+    fun find (Abs (_, _, b)) = find b
+      | find (t as _ $ _) = strip_comb t |>> find ||> maps find |> (op @)
+      | find f = if is_Free f andalso has_call f then [f] else [];
+  in
+    find rhs_term
+    |> K |> nth_map sel_no |> AList.map_entry (op =) ctr
+  end;
+
 fun add_primcorec simple seq fixes specs of_specs lthy =
   let
     val (bs, mxs) = map_split (apfst fst) fixes;
     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
 
-    val callssss = []; (* FIXME *)
+    val fun_names = map Binding.name_of bs;
+    val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
+    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
+    val eqns_data =
+      fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs)
+        of_specs []
+      |> flat o fst;
+
+    val callssss =
+      map_filter (try (fn Sel x => x)) eqns_data
+      |> partition_eq ((op =) o pairself #fun_name)
+      |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
+      |> map (flat o snd)      |> map2 (fold o find_corec_calls has_call) basic_ctr_specss
+      |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} =>
+        (ctr, map (K []) sels))) basic_ctr_specss);
+
+(*
+val _ = tracing ("callssss = " ^ @{make_string} callssss);
+*)
 
     val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
           strong_coinduct_thms), lthy') =
       corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
-
     val actual_nn = length bs;
-    val fun_names = map Binding.name_of bs;
     val corec_specs = take actual_nn corec_specs'; (*###*)
 
-    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
-    val eqns_data =
-      fold_map2 (dissect_coeqn lthy seq has_call fun_names (map #ctr_specs corec_specs))
-        (map snd specs) of_specs []
-      |> flat o fst;
-
     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
       |> partition_eq ((op =) o pairself #fun_name)
       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names