src/HOL/BNF/Tools/bnf_def.ML
changeset 54189 c0186a0d8cb3
parent 54158 0af35cebe8ca
child 54236 e00009523727
--- a/src/HOL/BNF/Tools/bnf_def.ML	Mon Oct 21 23:45:27 2013 +0200
+++ b/src/HOL/BNF/Tools/bnf_def.ML	Tue Oct 22 14:17:12 2013 +0200
@@ -612,14 +612,12 @@
     val fact_policy = mk_fact_policy no_defs_lthy;
     val bnf_b = qualify raw_bnf_b;
     val live = length raw_sets;
-    val nwits = length raw_wits;
 
     val map_rhs = prep_term no_defs_lthy raw_map;
     val set_rhss = map (prep_term no_defs_lthy) raw_sets;
     val (bd_rhsT, bd_rhs) = (case prep_term no_defs_lthy raw_bd_Abs of
       Abs (_, T, t) => (T, t)
     | _ => error "Bad bound constant");
-    val wit_rhss = map (prep_term no_defs_lthy) raw_wits;
 
     fun err T =
       error ("Trying to register the type " ^ quote (Syntax.string_of_typ no_defs_lthy T) ^
@@ -672,21 +670,14 @@
           else map (fn i => set_name i (fn () => mk_suffix_binding (mk_setN i))) (1 upto live);
       in bs ~~ set_rhss end;
     val bd_bind_def = (fn () => def_qualify (mk_suffix_binding bdN), bd_rhs);
-    val wit_binds_defs =
-      let
-        val bs = if nwits = 1 then [fn () => def_qualify (mk_suffix_binding witN)]
-          else map (fn i => fn () => def_qualify (mk_suffix_binding (mk_witN i))) (1 upto nwits);
-      in bs ~~ wit_rhss end;
 
-    val (((((bnf_map_term, raw_map_def),
+    val ((((bnf_map_term, raw_map_def),
       (bnf_set_terms, raw_set_defs)),
-      (bnf_bd_term, raw_bd_def)),
-      (bnf_wit_terms, raw_wit_defs)), (lthy, lthy_old)) =
+      (bnf_bd_term, raw_bd_def)), (lthy, lthy_old)) =
         no_defs_lthy
         |> maybe_define true map_bind_def
         ||>> apfst split_list o fold_map (maybe_define true) set_binds_defs
         ||>> maybe_define true bd_bind_def
-        ||>> apfst split_list o fold_map (maybe_define true) wit_binds_defs
         ||> `(maybe_restore no_defs_lthy);
 
     val phi = Proof_Context.export_morphism lthy_old lthy;
@@ -694,7 +685,6 @@
     val bnf_map_def = Morphism.thm phi raw_map_def;
     val bnf_set_defs = map (Morphism.thm phi) raw_set_defs;
     val bnf_bd_def = Morphism.thm phi raw_bd_def;
-    val bnf_wit_defs = map (Morphism.thm phi) raw_wit_defs;
 
     val bnf_map = Morphism.term phi bnf_map_term;
 
@@ -713,7 +703,6 @@
     val bdT = Morphism.typ phi bd_rhsT;
     val bnf_bd =
       Term.subst_TVars (Term.add_tvar_namesT bdT [] ~~ CA_params) (Morphism.term phi bnf_bd_term);
-    val bnf_wits = map (normalize_wit CA_params CA alphas o Morphism.term phi) bnf_wit_terms;
 
     (*TODO: assert Ds = (TVars of bnf_map) \ (alphas @ betas) as sets*)
     val deads = (case Ds_opt of
@@ -770,7 +759,6 @@
     val bnf_sets_As = map (mk_bnf_t As') bnf_sets;
     val bnf_sets_Bs = map (mk_bnf_t Bs') bnf_sets;
     val bnf_bd_As = mk_bnf_t As' bnf_bd;
-    val bnf_wit_As = map (apsnd (mk_bnf_t As')) bnf_wits;
 
     val pre_names_lthy = lthy;
     val ((((((((((((((((((((((((fs, gs), hs), x), y), zs), ys), As),
@@ -827,9 +815,23 @@
       (fn () => def_qualify (if Binding.is_empty rel_b then mk_suffix_binding relN else rel_b),
          rel_rhs);
 
-    val ((bnf_rel_term, raw_rel_def), (lthy, lthy_old)) =
+    val wit_rhss =
+      if null raw_wits then
+        [fold_rev Term.absdummy As' (Term.list_comb (bnf_map_AsAs,
+          map2 (fn T => fn i => Term.absdummy T (Bound i)) As' (live downto 1)) $
+          Const (@{const_name undefined}, CA'))]
+      else map (prep_term no_defs_lthy) raw_wits;
+    val nwits = length wit_rhss;
+    val wit_binds_defs =
+      let
+        val bs = if nwits = 1 then [fn () => def_qualify (mk_suffix_binding witN)]
+          else map (fn i => fn () => def_qualify (mk_suffix_binding (mk_witN i))) (1 upto nwits);
+      in bs ~~ wit_rhss end;
+
+    val (((bnf_rel_term, raw_rel_def), (bnf_wit_terms, raw_wit_defs)), (lthy, lthy_old)) =
       lthy
       |> maybe_define (is_some raw_rel_opt) rel_bind_def
+      ||>> apfst split_list o fold_map (maybe_define (not (null raw_wits))) wit_binds_defs
       ||> `(maybe_restore lthy);
 
     val phi = Proof_Context.export_morphism lthy_old lthy;
@@ -841,6 +843,10 @@
     val rel = mk_bnf_rel pred2RTs CA' CB';
     val relAsAs = mk_bnf_rel self_pred2RTs CA' CA';
 
+    val bnf_wit_defs = map (Morphism.thm phi) raw_wit_defs;
+    val bnf_wits = map (normalize_wit CA_params CA alphas o Morphism.term phi) bnf_wit_terms;
+    val bnf_wit_As = map (apsnd (mk_bnf_t As')) bnf_wits;
+
     val map_id0_goal =
       let val bnf_map_app_id = Term.list_comb (bnf_map_AsAs, map HOLogic.id_const As') in
         mk_Trueprop_eq (bnf_map_app_id, HOLogic.id_const CA')
@@ -939,11 +945,14 @@
         map wit_goal (0 upto live - 1)
       end;
 
-    val wit_goalss = map mk_wit_goals bnf_wit_As;
+    val trivial_wit_tac = mk_trivial_wit_tac bnf_wit_defs;
 
-    fun after_qed thms lthy =
+    val wit_goalss =
+      (if null raw_wits then SOME trivial_wit_tac else NONE, map mk_wit_goals bnf_wit_As);
+
+    fun after_qed mk_wit_thms thms lthy =
       let
-        val (axioms, wit_thms) = apfst (mk_axioms live) (chop (length goals) thms);
+        val (axioms, nontriv_wit_thms) = apfst (mk_axioms live) (chop (length goals) thms);
 
         val bd_Card_order = #bd_card_order axioms RS @{thm conjunct2[OF card_order_on_Card_order]};
         val bd_Cinfinite = @{thm conjI} OF [#bd_cinfinite axioms, bd_Card_order];
@@ -1016,6 +1025,9 @@
 
         val set_map = map (fn thm => Lazy.lazy (fn () => mk_set_map thm)) (#set_map0 axioms);
 
+        val wit_thms =
+          if null nontriv_wit_thms then mk_wit_thms (map Lazy.force set_map) else nontriv_wit_thms;
+
         fun mk_in_bd () =
           let
             val bdT = fst (dest_relT (fastype_of bnf_bd_As));
@@ -1259,35 +1271,45 @@
   (bnf, Local_Theory.declaration {syntax = false, pervasive = true}
     (fn phi => Data.map (Symtab.default (key, morph_bnf phi bnf))) lthy);
 
-(* TODO: Once the invariant "nwits > 0" holds, remove "mk_conjunction_balanced'" and "rtac TrueI"
-   below *)
-fun mk_conjunction_balanced' [] = @{prop True}
-  | mk_conjunction_balanced' ts = Logic.mk_conjunction_balanced ts;
-
 fun bnf_def const_policy fact_policy qualify tacs wit_tac Ds map_b rel_b set_bs =
-  (fn (_, goals, wit_goalss, after_qed, lthy, one_step_defs) =>
+  (fn (_, goals, (triv_tac_opt, wit_goalss), after_qed, lthy, one_step_defs) =>
   let
-    val wits_tac =
-      K (TRYALL Goal.conjunction_tac) THEN' K (TRYALL (rtac TrueI)) THEN'
-      mk_unfold_thms_then_tac lthy one_step_defs wit_tac;
-    val wit_goals = map mk_conjunction_balanced' wit_goalss;
-    val wit_thms =
-      Goal.prove_sorry lthy [] [] (mk_conjunction_balanced' wit_goals) wits_tac
-      |> Conjunction.elim_balanced (length wit_goals)
-      |> map2 (Conjunction.elim_balanced o length) wit_goalss
-      |> map (map (Thm.close_derivation o Thm.forall_elim_vars 0));
+    fun mk_wits_tac set_maps =
+      K (TRYALL Goal.conjunction_tac) THEN'
+      (case triv_tac_opt of
+        SOME tac => tac set_maps
+      | NONE => mk_unfold_thms_then_tac lthy one_step_defs wit_tac);
+    val wit_goals = map Logic.mk_conjunction_balanced wit_goalss;
+    fun mk_wit_thms set_maps =
+      Goal.prove_sorry lthy [] [] (Logic.mk_conjunction_balanced wit_goals) (mk_wits_tac set_maps)
+        |> Conjunction.elim_balanced (length wit_goals)
+        |> map2 (Conjunction.elim_balanced o length) wit_goalss
+        |> map (map (Thm.close_derivation o Thm.forall_elim_vars 0));
   in
     map2 (Thm.close_derivation oo Goal.prove_sorry lthy [] [])
       goals (map (mk_unfold_thms_then_tac lthy one_step_defs) tacs)
-    |> (fn thms => after_qed (map single thms @ wit_thms) lthy)
+    |> (fn thms => after_qed mk_wit_thms (map single thms) lthy)
   end) oo prepare_def const_policy fact_policy qualify (K I) Ds map_b rel_b set_bs;
 
-val bnf_cmd = (fn (key, goals, wit_goals, after_qed, lthy, defs) =>
-  Proof.unfolding ([[(defs, [])]])
-    (Proof.theorem NONE (snd o register_bnf key oo after_qed)
-      (map (single o rpair []) goals @ map (map (rpair [])) wit_goals) lthy)) oo
-  prepare_def Do_Inline (user_policy Note_Some) I Syntax.read_term NONE Binding.empty Binding.empty
-    [];
+val bnf_cmd = (fn (key, goals, (triv_tac_opt, wit_goalss), after_qed, lthy, defs) =>
+  let
+    val wit_goals = map Logic.mk_conjunction_balanced wit_goalss;
+    fun mk_triv_wit_thms tac set_maps =
+      Goal.prove_sorry lthy [] [] (Logic.mk_conjunction_balanced wit_goals)
+        (K (TRYALL Goal.conjunction_tac) THEN' tac set_maps)
+        |> Conjunction.elim_balanced (length wit_goals)
+        |> map2 (Conjunction.elim_balanced o length) wit_goalss
+        |> map (map (Thm.close_derivation o Thm.forall_elim_vars 0));
+    val (mk_wit_thms, nontriv_wit_goals) = 
+      (case triv_tac_opt of
+        NONE => (fn _ => [], map (map (rpair [])) wit_goalss)
+      | SOME tac => (mk_triv_wit_thms tac, []));
+  in
+    Proof.unfolding ([[(defs, [])]])
+      (Proof.theorem NONE (snd o register_bnf key oo after_qed mk_wit_thms)
+        (map (single o rpair []) goals @ nontriv_wit_goals) lthy)
+  end) oo prepare_def Do_Inline (user_policy Note_Some) I Syntax.read_term NONE
+    Binding.empty Binding.empty [];
 
 fun print_bnfs ctxt =
   let
@@ -1324,7 +1346,9 @@
     "register a type as a bounded natural functor"
     ((parse_opt_binding_colon -- Parse.term --
        (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) -- Parse.term --
-       (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) -- Scan.option Parse.term)
+       (Scan.option ((@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}))
+         >> the_default []) --
+       Scan.option Parse.term)
        >> bnf_cmd);
 
 end;