src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 54097 92c5bd3b342d
parent 54074 43cdae9524bf
child 54098 07a8145aaeba
--- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Thu Oct 10 08:23:57 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML	Fri Oct 11 16:31:23 2013 +0200
@@ -455,6 +455,8 @@
   disc: term,
   prems: term list,
   auto_gen: bool,
+  maybe_ctr_rhs: term option,
+  maybe_code_rhs: term option,
   user_eqn: term
 };
 
@@ -472,8 +474,8 @@
   Disc of co_eqn_data_disc |
   Sel of co_eqn_data_sel;
 
-fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) prems' concl
-    matchedsss =
+fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) maybe_ctr_rhs maybe_code_rhs
+    prems' concl matchedsss =
   let
     fun find_subterm p = let (* FIXME \<exists>? *)
       fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
@@ -514,9 +516,9 @@
       |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]);
 
     val user_eqn =
-      (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
-      |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop
-      |> Logic.list_implies;
+      (real_prems, concl)
+      |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop o abstract (List.rev fun_args)
+      |> curry Logic.list_all (map dest_Free fun_args) o Logic.list_implies;
   in
     (Disc {
       fun_name = fun_name,
@@ -527,6 +529,8 @@
       disc = #disc (nth ctr_specs ctr_no),
       prems = real_prems,
       auto_gen = catch_all,
+      maybe_ctr_rhs = maybe_ctr_rhs,
+      maybe_code_rhs = maybe_code_rhs,
       user_eqn = user_eqn
     }, matchedsss')
   end;
@@ -560,11 +564,11 @@
     }
   end;
 
-fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' prems concl
-    matchedsss =
+fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' maybe_code_rhs
+    prems concl matchedsss =
   let
     val (lhs, rhs) = HOLogic.dest_eq concl;
-    val fun_name = head_of lhs |> fst o dest_Free;
+    val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
     val (ctr, ctr_args) = strip_comb (unfold_let rhs);
     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
@@ -573,8 +577,8 @@
     val disc_concl = betapply (disc, lhs);
     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
       then (NONE, matchedsss)
-      else apfst SOME (co_dissect_eqn_disc
-          seq fun_names corec_specs prems disc_concl matchedsss);
+      else apfst SOME (co_dissect_eqn_disc seq fun_names corec_specs
+          (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
 
     val sel_concls = (sels ~~ ctr_args)
       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
@@ -582,7 +586,9 @@
 (*
 val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} concl ^ "\nto\n    \<cdot> " ^
  (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n    \<cdot> ")) "" ^
- space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls));
+ space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls) ^
+ "\nfor premise(s)\n    \<cdot> " ^
+ space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) prems));
 *)
 
     val eqns_data_sel = map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_concls;
@@ -593,7 +599,7 @@
 fun co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss =
   let
     val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
-    val fun_name = head_of lhs |> fst o dest_Free;
+    val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
 
     val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
@@ -610,10 +616,12 @@
         |> curry list_comb ctr
         |> curry HOLogic.mk_eq lhs);
   in
-    fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn') ctr_premss ctr_concls matchedsss
+    fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn'
+        (SOME (abstract (List.rev fun_args) rhs)))
+      ctr_premss ctr_concls matchedsss
   end;
 
-fun co_dissect_eqn lthy has_call seq fun_names (corec_specs : corec_spec list) eqn' of_spec
+fun co_dissect_eqn lthy seq has_call fun_names (corec_specs : corec_spec list) eqn' of_spec
     matchedsss =
   let
     val eqn = drop_All eqn'
@@ -634,13 +642,13 @@
     if member (op =) discs head orelse
       is_some maybe_rhs andalso
         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
-      co_dissect_eqn_disc seq fun_names corec_specs prems concl matchedsss
+      co_dissect_eqn_disc seq fun_names corec_specs NONE NONE prems concl matchedsss
       |>> single
     else if member (op =) sels head then
       ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec concl], matchedsss)
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
-      co_dissect_eqn_ctr seq fun_names corec_specs eqn' prems concl matchedsss
+      co_dissect_eqn_ctr seq fun_names corec_specs eqn' NONE prems concl matchedsss
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       null prems then
       co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss
@@ -670,11 +678,13 @@
     if is_none maybe_sel_eqn then (I, I, I) else
     let
       val {fun_args, rhs_term, ... } = the maybe_sel_eqn;
+      val bound_Ts = List.rev (map fastype_of fun_args);
       fun rewrite_q _ t = if has_call t then @{term False} else @{term True};
       fun rewrite_g _ t = if has_call t then undef_const else t;
       fun rewrite_h bound_Ts t =
         if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
-      fun massage f _ = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
+      fun massage f _ = massage_direct_corec_call lthy has_call f bound_Ts rhs_term
+        |> abs_tuple fun_args;
     in
       (massage rewrite_q,
        massage rewrite_g,
@@ -689,6 +699,7 @@
     if is_none maybe_sel_eqn then I else
     let
       val {fun_args, rhs_term, ...} = the maybe_sel_eqn;
+      val bound_Ts = List.rev (map fastype_of fun_args);
       fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
         | rewrite bound_Ts U T (t as _ $ _) =
           let val (u, vs) = strip_comb t in
@@ -702,7 +713,8 @@
         | rewrite _ U T t =
           if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
       fun massage t =
-        massage_indirect_corec_call lthy has_call rewrite [] (range_type (fastype_of t)) rhs_term
+        rhs_term
+        |> massage_indirect_corec_call lthy has_call rewrite bound_Ts (range_type (fastype_of t))
         |> abs_tuple fun_args;
     in
       massage
@@ -786,6 +798,8 @@
         disc = #disc (nth ctr_specs n),
         prems = maps (s_not_conj o #prems) disc_eqns,
         auto_gen = true,
+        maybe_ctr_rhs = NONE,
+        maybe_code_rhs = NONE,
         user_eqn = undef_const};
     in
       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
@@ -808,7 +822,7 @@
 
     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val eqns_data =
-      fold_map2 (co_dissect_eqn lthy has_call seq fun_names corec_specs)
+      fold_map2 (co_dissect_eqn lthy seq has_call fun_names corec_specs)
         (map snd specs) of_specs []
       |> flat o fst;
 
@@ -876,6 +890,7 @@
                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                 |> curry Logic.list_all (map dest_Free fun_args);
             in
+              if prems = [@{term False}] then [] else
               mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss
               |> K |> Goal.prove lthy [] [] t
               |> pair (#disc (nth ctr_specs ctr_no))
@@ -911,24 +926,27 @@
 
         fun prove_ctr disc_alist sel_alist (disc_eqns : co_eqn_data_disc list)
             (sel_eqns : co_eqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) =
+          (* don't try to prove theorems when some sel_eqns are missing *)
           if not (exists (equal ctr o #ctr) disc_eqns)
               andalso not (exists (equal ctr o #ctr) sel_eqns)
-            orelse (* don't try to prove theorems when some sel_eqns are missing *)
+            orelse
               filter (equal ctr o #ctr) sel_eqns
               |> fst o finds ((op =) o apsnd #sel) sels
               |> exists (null o snd)
           then [] else
             let
-              val (fun_name, fun_T, fun_args, prems) =
+              val (fun_name, fun_T, fun_args, prems, maybe_rhs) =
                 (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
-                |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
-                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
+                |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
+                  #maybe_ctr_rhs x))
+                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], NONE))
                 |> the o merge_options;
               val m = length prems;
-              val t = filter (equal ctr o #ctr) sel_eqns
-                |> fst o finds ((op =) o apsnd #sel) sels
-                |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
-                |> curry list_comb ctr
+              val t = (if is_some maybe_rhs then the maybe_rhs else
+                  filter (equal ctr o #ctr) sel_eqns
+                  |> fst o finds ((op =) o apsnd #sel) sels
+                  |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
+                  |> curry list_comb ctr)
                 |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
                   map Bound (length fun_args - 1 downto 0)))
                 |> HOLogic.mk_Trueprop
@@ -937,18 +955,87 @@
               val maybe_disc_thm = AList.lookup (op =) disc_alist disc;
               val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist);
             in
-              mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
-              |> K |> Goal.prove lthy [] [] t
-              |> single
+              if prems = [@{term False}] then [] else
+                mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
+                |> K |> Goal.prove lthy [] [] t
+                |> pair ctr
+                |> single
             end;
 
+        fun prove_code disc_eqns sel_eqns ctr_alist
+            {distincts, sel_splits, sel_split_asms, ctr_specs, ...} =
+(* FIXME doesn't work reliably yet *)
+[](*          let
+            val (fun_name, fun_T, fun_args, maybe_rhs) =
+              (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns,
+               find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns)
+              |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x))
+              ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE))
+              |> the o merge_options;
+
+            val maybe_rhs' = if is_some maybe_rhs then maybe_rhs else
+              let
+                fun prove_code_ctr {ctr, disc, sels, ...} =
+                  if not (exists (equal ctr o fst) ctr_alist) then NONE else
+                    let
+                      val (fun_name, fun_T, fun_args, prems) =
+                        (find_first (equal ctr o #ctr) disc_eqns,
+                         find_first (equal ctr o #ctr) sel_eqns)
+                        |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
+                        ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
+                        |> the o merge_options;
+                      val m = length prems;
+                      val t =
+                        filter (equal ctr o #ctr) sel_eqns
+                        |> fst o finds ((op =) o apsnd #sel) sels
+                        |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
+                        |> curry list_comb ctr;
+                    in
+                      SOME (prems, t)
+                    end;
+                val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs;
+              in
+                if exists is_none maybe_ctr_conds_argss then NONE else
+                  fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t)
+                    maybe_ctr_conds_argss (Const (@{const_name undefined}, body_type fun_T))
+                  |> SOME
+              end;
+          in
+            if is_none maybe_rhs' then [] else
+              let
+                val rhs = the maybe_rhs';
+                val bound_Ts = List.rev (map fastype_of fun_args);
+                val rhs' = expand_corec_code_rhs lthy has_call bound_Ts rhs;
+                val cond_ctrs = fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts rhs' [];
+                val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs;
+                val ms = map (Logic.count_prems o prop_of) ctr_thms;
+                val (t', t) = (rhs', rhs)
+                  |> pairself
+                    (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
+                      map Bound (length fun_args - 1 downto 0)))
+                    #> HOLogic.mk_Trueprop
+                    #> curry Logic.list_all (map dest_Free fun_args));
+                val discIs = map #discI ctr_specs;
+                val raw_code = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits
+                    sel_split_asms ms ctr_thms
+                  |> K |> Goal.prove lthy [] [] t';
+              in
+                mk_primcorec_code_of_raw_code_tac sel_splits raw_code
+                |> K |> Goal.prove lthy [] [] t
+                |> single
+              end
+          end;*)
+
         val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss;
         val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss;
-
         val disc_thmss = map (map snd) disc_alists;
         val sel_thmss = map (map snd) sel_alists;
-        val ctr_thmss = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
+
+        val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
           (map #ctr_specs corec_specs);
+        val ctr_thmss = map (map snd) ctr_alists;
+
+        val code_thmss = map4 prove_code disc_eqnss sel_eqnss ctr_alists corec_specs;
 
         val simp_thmss = map2 append disc_thmss sel_thmss
 
@@ -956,7 +1043,7 @@
 
         val notes =
           [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
-           (codeN, ctr_thmss(*FIXME*), code_nitpick_attrs),
+           (codeN, code_thmss, code_nitpick_attrs),
            (ctrN, ctr_thmss, []),
            (discN, disc_thmss, simp_attrs),
            (selN, sel_thmss, simp_attrs),