src/Pure/Isar/code.ML
changeset 37438 4906ab970316
parent 37425 b5492f611129
child 37448 3bd4b3809bee
--- a/src/Pure/Isar/code.ML	Tue Jun 15 08:32:32 2010 +0200
+++ b/src/Pure/Isar/code.ML	Tue Jun 15 11:38:39 2010 +0200
@@ -72,6 +72,7 @@
   val is_abstr: theory -> string -> bool
   val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
   val get_case_scheme: theory -> string -> (int * (int * string list)) option
+  val get_case_cong: theory -> string -> thm option
   val undefineds: theory -> string list
   val print_codesetup: theory -> unit
 
@@ -168,7 +169,7 @@
     (*with explicit history*),
   types: ((serial * ((string * sort) list * typ_spec)) list) Symtab.table
     (*with explicit history*),
-  cases: (int * (int * string list)) Symtab.table * unit Symtab.table
+  cases: ((int * (int * string list)) * thm) Symtab.table * unit Symtab.table
 };
 
 fun make_spec (history_concluded, ((signatures, functions), (types, cases))) =
@@ -935,7 +936,8 @@
   handle Bind => error "bad case certificate"
        | TERM _ => error "bad case certificate";
 
-fun get_case_scheme thy = Symtab.lookup ((fst o the_cases o the_exec) thy);
+fun get_case_scheme thy = Option.map fst o Symtab.lookup ((fst o the_cases o the_exec) thy);
+fun get_case_cong thy = Option.map snd o Symtab.lookup ((fst o the_cases o the_exec) thy);
 
 val undefineds = Symtab.keys o snd o the_cases o the_exec;
 
@@ -970,8 +972,8 @@
                       :: Pretty.str "of"
                       :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
       );
-    fun pretty_case (const, (_, (_, []))) = Pretty.str (string_of_const thy const)
-      | pretty_case (const, (_, (_, cos))) = (Pretty.block o Pretty.breaks) [
+    fun pretty_case (const, ((_, (_, [])), _)) = Pretty.str (string_of_const thy const)
+      | pretty_case (const, ((_, (_, cos)), _)) = (Pretty.block o Pretty.breaks) [
           Pretty.str (string_of_const thy const), Pretty.str "with",
           (Pretty.block o Pretty.commas o map (Pretty.str o string_of_const thy)) cos];
     val functions = the_functions exec
@@ -1108,14 +1110,34 @@
 
 (* cases *)
 
+fun case_cong thy case_const (num_args, (pos, constrs)) =
+  let
+    val ([x, y], ctxt) = Name.variants ["A", "A'"] Name.context;
+    val (zs, _) = Name.variants (replicate (num_args - 1) "") ctxt;
+    val (ws, vs) = chop pos zs;
+    val T = Logic.unvarifyT_global (const_typ thy case_const);
+    val Ts = (fst o strip_type) T;
+    val T_cong = nth Ts pos;
+    fun mk_prem z = Free (z, T_cong);
+    fun mk_concl z = list_comb (Const (case_const, T), map2 (curry Free) (ws @ z :: vs) Ts);
+    val (prem, concl) = pairself Logic.mk_equals (pairself mk_prem (x, y), pairself mk_concl (x, y));
+    fun tac { prems, ... } = Simplifier.rewrite_goals_tac prems
+      THEN ALLGOALS (ProofContext.fact_tac [Drule.reflexive_thm]);
+  in Skip_Proof.prove_global thy (x :: y :: zs) [prem] concl tac end;
+
 fun add_case thm thy =
   let
-    val (c, (k, case_pats)) = case_cert thm;
+    val (case_const, (k, case_pats)) = case_cert thm;
     val _ = case filter_out (is_constr thy) case_pats
      of [] => ()
       | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs));
-    val entry = (1 + Int.max (1, length case_pats), (k, case_pats))
-  in (map_exec_purge o map_cases o apfst) (Symtab.update (c, entry)) thy end;
+    val entry = (1 + Int.max (1, length case_pats), (k, case_pats));
+  in
+    thy
+    |> Theory.checkpoint
+    |> `(fn thy => case_cong thy case_const entry)
+    |-> (fn cong => (map_exec_purge o map_cases o apfst) (Symtab.update (case_const, (entry, cong))))
+  end;
 
 fun add_undefined c thy =
   (map_exec_purge o map_cases o apsnd) (Symtab.update (c, ())) thy;
@@ -1138,7 +1160,7 @@
             then insert (op =) c else I)
             ((the_functions o the_exec) thy) (old_proj :: old_constrs);
     fun drop_outdated_cases cases = fold Symtab.delete_safe
-      (Symtab.fold (fn (c, (_, (_, cos))) =>
+      (Symtab.fold (fn (c, ((_, (_, cos)), _)) =>
         if exists (member (op =) old_constrs) cos
           then insert (op =) c else I) cases []) cases;
   in