src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
changeset 54948 516adecd99dd
parent 54927 a5a2598f0651
child 54951 e25b4d22082b
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Wed Jan 08 09:20:14 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Wed Jan 08 17:26:42 2014 +0100
@@ -93,6 +93,8 @@
 fun unexpected_corec_call ctxt t =
   error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t));
 
+fun order_list_duplicates xs = map snd (sort (int_ord o pairself fst) xs)
+
 val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
 val mk_dnf = mk_disjs o map mk_conjs;
@@ -474,12 +476,13 @@
   fun_T: typ,
   fun_args: term list,
   ctr: term,
-  ctr_no: int, (*FIXME*)
+  ctr_no: int,
   disc: term,
   prems: term list,
   auto_gen: bool,
   ctr_rhs_opt: term option,
   code_rhs_opt: term option,
+  eqn_pos: int,
   user_eqn: term
 };
 
@@ -492,6 +495,7 @@
   rhs_term: term,
   ctr_rhs_opt: term option,
   code_rhs_opt: term option,
+  eqn_pos: int,
   user_eqn: term
 };
 
@@ -500,7 +504,7 @@
   Sel of coeqn_data_sel;
 
 fun dissect_coeqn_disc fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list)
-    ctr_rhs_opt code_rhs_opt prems' concl matchedsss =
+    eqn_pos ctr_rhs_opt code_rhs_opt prems' concl matchedsss =
   let
     fun find_subterm p =
       let (* FIXME \<exists>? *)
@@ -558,12 +562,13 @@
       auto_gen = catch_all,
       ctr_rhs_opt = ctr_rhs_opt,
       code_rhs_opt = code_rhs_opt,
+      eqn_pos = eqn_pos,
       user_eqn = user_eqn
     }, matchedsss')
   end;
 
-fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) ctr_rhs_opt
-    code_rhs_opt eqn' of_spec_opt eqn =
+fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn_pos
+    ctr_rhs_opt code_rhs_opt eqn' of_spec_opt eqn =
   let
     val (lhs, rhs) = HOLogic.dest_eq eqn
       handle TERM _ =>
@@ -591,12 +596,13 @@
       rhs_term = rhs,
       ctr_rhs_opt = ctr_rhs_opt,
       code_rhs_opt = code_rhs_opt,
+      eqn_pos = eqn_pos,
       user_eqn = user_eqn
     }
   end;
 
-fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list) eqn'
-    code_rhs_opt prems concl matchedsss =
+fun dissect_coeqn_ctr fun_names sequentials (basic_ctr_specss : basic_corec_ctr_spec list list)
+    eqn_pos eqn' code_rhs_opt prems concl matchedsss =
   let
     val (lhs, rhs) = HOLogic.dest_eq concl;
     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
@@ -608,7 +614,7 @@
     val disc_concl = betapply (disc, lhs);
     val (eqn_data_disc_opt, matchedsss') = if length basic_ctr_specs = 1
       then (NONE, matchedsss)
-      else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss
+      else apfst SOME (dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos
           (SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt prems disc_concl matchedsss);
 
     val sel_concls = sels ~~ ctr_args
@@ -623,13 +629,13 @@
 *)
 
     val eqns_data_sel =
-      map (dissect_coeqn_sel fun_names basic_ctr_specss
+      map (dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos
         (SOME (abstract (List.rev fun_args) rhs)) code_rhs_opt eqn' (SOME ctr)) sel_concls;
   in
     (the_list eqn_data_disc_opt @ eqns_data_sel, matchedsss')
   end;
 
-fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss =
+fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss =
   let
     val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
@@ -651,13 +657,13 @@
 
     val sequentials = replicate (length fun_names) false;
   in
-    fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn'
+    fold_map2 (dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn'
         (SOME (abstract (List.rev fun_args) rhs)))
       ctr_premss ctr_concls matchedsss
   end;
 
 fun dissect_coeqn lthy has_call fun_names sequentials
-    (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' of_spec_opt matchedsss =
+    (basic_ctr_specss : basic_corec_ctr_spec list list) (eqn_pos, eqn') of_spec_opt matchedsss =
   let
     val eqn = drop_All eqn'
       handle TERM _ => primcorec_error_eqn "malformed function equation" eqn';
@@ -677,17 +683,17 @@
     if member (op =) discs head orelse
       is_some rhs_opt andalso
         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the rhs_opt) then
-      dissect_coeqn_disc fun_names sequentials basic_ctr_specss NONE NONE prems concl matchedsss
+      dissect_coeqn_disc fun_names sequentials basic_ctr_specss eqn_pos NONE NONE prems concl matchedsss
       |>> single
     else if member (op =) sels head then
-      ([dissect_coeqn_sel fun_names basic_ctr_specss NONE NONE eqn' of_spec_opt concl],
+      ([dissect_coeqn_sel fun_names basic_ctr_specss eqn_pos NONE NONE eqn' of_spec_opt concl],
        matchedsss)
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       member (op =) ctrs (head_of (unfold_let (the rhs_opt))) then
-      dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn' NONE prems concl matchedsss
+      dissect_coeqn_ctr fun_names sequentials basic_ctr_specss eqn_pos eqn' NONE prems concl matchedsss
     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
       null prems then
-      dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss
+      dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn_pos eqn' concl matchedsss
       |>> flat
     else
       primcorec_error_eqn "malformed function equation" eqn
@@ -834,6 +840,7 @@
           auto_gen = true,
           ctr_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE,
           code_rhs_opt = Option.map #ctr_rhs_opt sel_eqn_opt |> the_default NONE,
+          eqn_pos = Option.map (curry (op +) 1 o #eqn_pos) sel_eqn_opt |> the_default 1000 (*###*),
           user_eqn = undef_const};
       in
         chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
@@ -877,7 +884,7 @@
     val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts;
     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
     val eqns_data =
-      fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (map snd specs)
+      fold_map2 (dissect_coeqn lthy has_call fun_names sequentials basic_ctr_specss) (tag_list 0 (map snd specs))
         of_specs_opt []
       |> flat o fst;
 
@@ -897,7 +904,7 @@
     val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
           strong_coinduct_thms), lthy') =
       corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
-    val corec_specs = take actual_nn corec_specs'; (*FIXME*)
+    val corec_specs = take actual_nn corec_specs';
     val ctr_specss = map #ctr_specs corec_specs;
 
     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
@@ -1011,7 +1018,7 @@
             mk_excludesss excludes (length ctr_specs));
 
         fun prove_disc ({ctr_specs, ...} : corec_spec) excludesss
-            ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
+            ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) =
           if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then
             []
           else
@@ -1033,12 +1040,13 @@
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> pair (#disc (nth ctr_specs ctr_no))
+                |> pair eqn_pos
                 |> single
             end;
 
         fun prove_sel ({nested_map_idents, nested_map_comps, ctr_specs, ...} : corec_spec)
             (disc_eqns : coeqn_data_disc list) excludesss
-            ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) =
+            ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, eqn_pos, ...} : coeqn_data_sel) =
           let
             val SOME ctr_spec = find_first (curry (op =) ctr o #ctr) ctr_specs;
             val ctr_no = find_index (curry (op =) ctr o #ctr) ctr_specs;
@@ -1062,6 +1070,7 @@
             |> K |> Goal.prove lthy [] [] goal
             |> Thm.close_derivation
             |> pair sel
+            |> pair eqn_pos
           end;
 
         fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list)
@@ -1075,12 +1084,13 @@
               |> exists (null o snd)
           then [] else
             let
-              val (fun_name, fun_T, fun_args, prems, rhs_opt) =
+              val (fun_name, fun_T, fun_args, prems, rhs_opt, eqn_pos) =
                 (find_first (curry (op =) ctr o #ctr) disc_eqns,
                  find_first (curry (op =) ctr o #ctr) sel_eqns)
                 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x,
-                  #ctr_rhs_opt x))
-                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x))
+                  #ctr_rhs_opt x, #eqn_pos x))
+                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], #ctr_rhs_opt x,
+                  #eqn_pos x))
                 |> the o merge_options;
               val m = length prems;
               val goal =
@@ -1102,6 +1112,7 @@
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> pair ctr
+                |> pair eqn_pos
                 |> single
             end;
 
@@ -1149,7 +1160,7 @@
                       val ctr_conds_argss_opt = map prove_code_ctr ctr_specs;
                       val exhaustive_code =
                         exhaustive
-                        orelse forall null (map_filter (try (fst o the)) ctr_conds_argss_opt)
+                        orelse exists (is_some andf (null o fst o the)) ctr_conds_argss_opt
                         orelse forall is_some ctr_conds_argss_opt
                           andalso exists #auto_gen disc_eqns;
                       val rhs =
@@ -1192,14 +1203,15 @@
           end;
 
         val disc_alistss = map3 (map oo prove_disc) corec_specs excludessss disc_eqnss;
-        val disc_alists = map flat disc_alistss;
+        val disc_alists = map (map snd o flat) disc_alistss;
         val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss excludessss sel_eqnss;
-        val disc_thmsss = map (map (map snd)) disc_alistss;
-        val disc_thmss = map flat disc_thmsss;
-        val sel_thmss = map (map snd) sel_alists;
+        val disc_thmss = map (map snd o order_list_duplicates o flat) disc_alistss;
+        val disc_thmsss' = map (map (map (snd o snd))) disc_alistss;
+        val disc_thmss' = map flat disc_thmsss';
+        val sel_thmss = map (map snd o order_list_duplicates) sel_alists;
 
-        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss disc_thms
-            ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
+        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss' disc_thms
+            ({fun_name, fun_T, fun_args, ctr_no, prems, eqn_pos, ...} : coeqn_data_disc) =
           if null disc_thms orelse null exhaust_thms then
             []
           else
@@ -1214,23 +1226,26 @@
                 []
               else
                 mk_primcorec_disc_iff_tac lthy (map (fst o dest_Free) fun_args)
-                  (the_single exhaust_thms) (the_single disc_thms) disc_thmss (flat disc_excludess)
+                  (the_single exhaust_thms) (the_single disc_thms) disc_thmss' (flat disc_excludess)
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
+                |> pair eqn_pos
                 |> single
             end;
 
         val disc_iff_thmss = map5 (flat ooo map2 ooo prove_disc_iff) corec_specs exhaust_thmss
-          disc_thmsss disc_thmsss disc_eqnss;
+          disc_thmsss' disc_thmsss' disc_eqnss
+          |> map order_list_duplicates;
 
-        val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
-          ctr_specss;
-        val ctr_thmss = map (map snd) ctr_alists;
+        val ctr_alists = map5 (maps oooo prove_ctr) disc_alists (map (map snd) sel_alists) disc_eqnss
+          sel_eqnss ctr_specss;
+        val ctr_thmss' = map (map snd) ctr_alists;
+        val ctr_thmss = map (map snd o order_list) ctr_alists;
 
-        val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_alists
+        val code_thmss = map6 prove_code exhaustives disc_eqnss sel_eqnss nchotomy_thmss ctr_thmss'
           ctr_specss;
 
-        val simp_thmss = map2 append disc_thmss sel_thmss
+        val simp_thmss = map2 append disc_thmss sel_thmss;
 
         val common_name = mk_common_name fun_names;