properly synchronize parallel lists
authorblanchet
Thu, 02 Jan 2014 09:50:22 +0100
changeset 54910 0ec2cccbf8ad
parent 54909 63db983c6953
child 54911 6a6980245ce0
properly synchronize parallel lists
src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML
src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML	Thu Jan 02 09:50:22 2014 +0100
@@ -809,30 +809,32 @@
 
 fun mk_actual_disc_eqns fun_binding arg_Ts exhaustive ({ctr_specs, ...} : corec_spec)
     (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) =
-  if exhaustive orelse length disc_eqns <> length ctr_specs - 1 then
-    disc_eqns
-  else
-    let
-      val n = 0 upto length ctr_specs
-        |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
-      val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
-        |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
-      val maybe_sel_eqn = find_first (equal (Binding.name_of fun_binding) o #fun_name) sel_eqns;
-      val extra_disc_eqn = {
-        fun_name = Binding.name_of fun_binding,
-        fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
-        fun_args = fun_args,
-        ctr = #ctr (nth ctr_specs n),
-        ctr_no = n,
-        disc = #disc (nth ctr_specs n),
-        prems = maps (s_not_conj o #prems) disc_eqns,
-        auto_gen = true,
-        maybe_ctr_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
-        maybe_code_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
-        user_eqn = undef_const};
-    in
-      chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
-    end;
+  let val num_disc_eqns = length disc_eqns in
+    if num_disc_eqns < length ctr_specs - 1 andalso num_disc_eqns > 1 then
+      disc_eqns
+    else
+      let
+        val n = 0 upto length ctr_specs
+          |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
+        val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
+          |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
+        val maybe_sel_eqn = find_first (equal (Binding.name_of fun_binding) o #fun_name) sel_eqns;
+        val extra_disc_eqn = {
+          fun_name = Binding.name_of fun_binding,
+          fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
+          fun_args = fun_args,
+          ctr = #ctr (nth ctr_specs n),
+          ctr_no = n,
+          disc = #disc (nth ctr_specs n),
+          prems = maps (s_not_conj o #prems) disc_eqns,
+          auto_gen = true,
+          maybe_ctr_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
+          maybe_code_rhs = Option.map #maybe_ctr_rhs maybe_sel_eqn |> the_default NONE,
+          user_eqn = undef_const};
+      in
+        chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
+      end
+  end;
 
 fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) =
   let
@@ -931,7 +933,8 @@
       |> split_list o map split_list;
 
     val list_all_fun_args =
-      map2 ((fn {fun_args, ...} => map (curry Logic.list_all (map dest_Free fun_args))) o hd)
+      map2 (fn [] => I
+          | {fun_args, ...} :: _ => map (curry Logic.list_all (map dest_Free fun_args)))
         disc_eqnss;
 
     val syntactic_exhaustives =
@@ -1180,14 +1183,16 @@
               end)
           end;
 
-        val disc_alists = map3 (maps oo prove_disc) corec_specs excludessss disc_eqnss;
+        val disc_alistss = map3 (map oo prove_disc) corec_specs excludessss disc_eqnss;
+        val disc_alists = map flat disc_alistss;
         val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss excludessss sel_eqnss;
-        val disc_thmss = map (map snd) disc_alists;
+        val disc_thmsss = map (map (map snd)) disc_alistss;
+        val disc_thmss = map flat disc_thmsss;
         val sel_thmss = map (map snd) sel_alists;
 
-        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms discs
+        fun prove_disc_iff ({ctr_specs, ...} : corec_spec) exhaust_thms disc_thmss disc_thms
             ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) =
-          if null discs orelse null exhaust_thms then
+          if null disc_thms orelse null exhaust_thms then
             []
           else
             let
@@ -1200,15 +1205,15 @@
               if prems = [@{term False}] then
                 []
               else
-                mk_primcorec_disc_iff_tac lthy (ctr_no + 1) (the_single exhaust_thms) discs
-                  (flat disc_excludess)
+                mk_primcorec_disc_iff_tac lthy (the_single exhaust_thms) (the_single disc_thms)
+                  (flat disc_thmss) (flat disc_excludess)
                 |> K |> Goal.prove lthy [] [] goal
                 |> Thm.close_derivation
                 |> single
             end;
 
-        val disc_iff_thmss = map4 (maps ooo prove_disc_iff) corec_specs exhaust_thmss disc_thmss
-          disc_eqnss;
+        val disc_iff_thmss = map5 (flat ooo map2 ooo prove_disc_iff) corec_specs exhaust_thmss
+          disc_thmsss disc_thmsss disc_eqnss;
 
         val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
           ctr_specss;
--- a/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML	Thu Jan 02 09:50:22 2014 +0100
+++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML	Thu Jan 02 09:50:22 2014 +0100
@@ -12,7 +12,7 @@
   val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic
   val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list ->
     tactic
-  val mk_primcorec_disc_iff_tac: Proof.context -> int -> thm -> thm list -> thm list -> tactic
+  val mk_primcorec_disc_iff_tac: Proof.context -> thm -> thm -> thm list -> thm list -> tactic
   val mk_primcorec_exhaust_tac: int -> thm -> tactic
   val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list ->
     thm list -> int list -> thm list -> thm option -> tactic
@@ -77,22 +77,17 @@
 fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss =
   mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss;
 
-fun mk_primcorec_disc_iff_tac ctxt k fun_exhaust fun_discs disc_excludes =
-  let
-    val n = length fun_discs;
-    val ks = 1 upto n;
-  in
-    HEADGOAL (rtac iffI THEN'
-      rtac fun_exhaust THEN'
-      EVERY' (map2 (fn k' => fn disc =>
-          if k' = k then
-            REPEAT_DETERM o (atac ORELSE' rtac TrueI ORELSE' etac conjI)
-          else
-            dtac disc THEN' (REPEAT_DETERM o atac) THEN' dresolve_tac disc_excludes THEN'
-            etac notE THEN' atac)
-        ks fun_discs) THEN'
-      rtac (unfold_thms ctxt [atomize_conjL] (nth fun_discs (k - 1))) THEN_MAYBE' atac)
-  end;
+fun mk_primcorec_disc_iff_tac ctxt fun_exhaust fun_disc fun_discs disc_excludes =
+  HEADGOAL (rtac iffI THEN'
+    rtac fun_exhaust THEN'
+    EVERY' (map (fn fun_disc' =>
+        if Thm.eq_thm (fun_disc', fun_disc) then
+          REPEAT_DETERM o (atac ORELSE' rtac TrueI ORELSE' etac conjI)
+        else
+          dtac fun_disc' THEN' (REPEAT_DETERM o atac) THEN' dresolve_tac disc_excludes THEN'
+          etac notE THEN' atac)
+      fun_discs) THEN'
+    rtac (unfold_thms ctxt [atomize_conjL] fun_disc) THEN_MAYBE' atac);
 
 fun mk_primcorec_sel_tac ctxt defs distincts splits split_asms map_idents map_comps fun_sel k m
     exclsss =