src/Pure/Isar/class.ML
changeset 25195 62638dcafe38
parent 25163 f737a88a3248
child 25209 bc21d8de18a9
--- a/src/Pure/Isar/class.ML	Thu Oct 25 19:27:52 2007 +0200
+++ b/src/Pure/Isar/class.ML	Thu Oct 25 19:27:53 2007 +0200
@@ -26,6 +26,8 @@
   val refresh_syntax: class -> Proof.context -> Proof.context
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
+  val prove_subclass: class * class -> thm list -> Proof.context
+    -> theory -> theory
   val print_classes: theory -> unit
   val uncheck: bool ref
 
@@ -61,6 +63,13 @@
       (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
   #> ProofContext.theory_of;
 
+fun prove_interpretation_in tac after_qed (name, expr) =
+  Locale.interpretation_in_locale
+      (ProofContext.theory after_qed) (name, expr)
+  #> Proof.global_terminal_proof
+      (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
+  #> ProofContext.theory_of;
+
 fun OF_LAST thm1 thm2 = thm1 RSN (Thm.nprems_of thm2, thm2);
 
 fun strip_all_ofclass thy sort =
@@ -315,29 +324,31 @@
     (*partial morphism of canonical interpretation*)
   intro: thm,
   defs: thm list,
-  operations: (string * ((term * (typ * int)) * (term * typ))) list
-    (*constant name ~> ((locale term,
-        (constant constraint, instantiaton index of class typ)), locale term & typ for uncheck)*)
+  operations: (string * (term * (typ * int))) list,
+    (*constant name ~> (locale term,
+        (constant constraint, instantiaton index of class typ))*)
+  unchecks: (term * term) list
 };
 
 fun rep_class_data (ClassData d) = d;
 fun mk_class_data ((consts, base_sort, inst, morphism, intro),
-    (defs, operations)) =
+    (defs, (operations, unchecks))) =
   ClassData { consts = consts, base_sort = base_sort, inst = inst,
     morphism = morphism, intro = intro, defs = defs,
-    operations = operations };
+    operations = operations, unchecks = unchecks };
 fun map_class_data f (ClassData { consts, base_sort, inst, morphism, intro,
-    defs, operations }) =
+    defs, operations, unchecks }) =
   mk_class_data (f ((consts, base_sort, inst, morphism, intro),
-    (defs, operations)));
+    (defs, (operations, unchecks))));
 fun merge_class_data _ (ClassData { consts = consts,
     base_sort = base_sort, inst = inst, morphism = morphism, intro = intro,
-    defs = defs1, operations = operations1 },
+    defs = defs1, operations = operations1, unchecks = unchecks1 },
   ClassData { consts = _, base_sort = _, inst = _, morphism = _, intro = _,
-    defs = defs2, operations = operations2 }) =
+    defs = defs2, operations = operations2, unchecks = unchecks2 }) =
   mk_class_data ((consts, base_sort, inst, morphism, intro),
     (Thm.merge_thms (defs1, defs2),
-      AList.merge (op =) (K true) (operations1, operations2)));
+      (AList.merge (op =) (K true) (operations1, operations2),
+        Library.merge (op aconv o pairself snd) (unchecks1, unchecks2))));
 
 structure ClassData = TheoryDataFun
 (
@@ -383,7 +394,8 @@
 fun these_operations thy =
   maps (#operations o the_class_data thy) o ancestry thy;
 
-fun local_operation thy = AList.lookup (op =) o these_operations thy;
+fun these_unchecks thy =
+  maps (#unchecks o the_class_data thy) o ancestry thy;
 
 fun sups_base_sort thy sort =
   let
@@ -435,30 +447,35 @@
 
 fun add_class_data ((class, superclasses), (cs, base_sort, inst, phi, intro)) thy =
   let
-    val operations = map (fn (v_ty as (_, ty'), (c, ty)) =>
-      (c, ((Free v_ty, (Logic.varifyT ty, 0)), (Free v_ty, ty')))) cs;
+    val operations = map (fn (v_ty, (c, ty)) =>
+      (c, ((Free v_ty, (Logic.varifyT ty, 0))))) cs;
+    val unchecks = map (fn ((v, ty'), (c, _)) =>
+      (Free (v, Type.strip_sorts ty'), Const (c, Type.strip_sorts ty'))) cs;
     val cs = (map o pairself) fst cs;
     val add_class = Graph.new_node (class,
-        mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], operations)))
+        mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], (operations, unchecks))))
       #> fold (curry Graph.add_edge class) superclasses;
   in
     ClassData.map add_class thy
   end;
 
-fun register_operation class ((c, (t, t_rev)), some_def) thy =
+fun register_operation class (c, ((t, some_t_rev), some_def)) thy =
   let
     val ty = Sign.the_const_constraint thy c;
     val typargs = Sign.const_typargs thy (c, ty);
     val typidx = find_index (fn TVar ((v, _), _) => Name.aT = v | _ => false) typargs;
-    val t_rev' = (map_types o map_atyps)
-      (fn ty as TFree (v, _) => if Name.aT = v then TFree (v, []) else ty | ty => ty) t_rev;
-    val ty' = Term.fastype_of t_rev';
+    fun mk_uncheck t_rev =
+      let
+        val t_rev' = map_types Type.strip_sorts t_rev;
+        val ty' = Term.fastype_of t_rev';
+      in (t_rev', Const (c, ty')) end;
+    val some_t_rev' = Option.map mk_uncheck some_t_rev;
   in
     thy
     |> (ClassData.map o Graph.map_node class o map_class_data o apsnd)
-      (fn (defs, operations) =>
+      (fn (defs, (operations, unchecks)) =>
         (fold cons (the_list some_def) defs,
-          (c, ((t, (ty, typidx)), (t_rev', ty'))) :: operations))
+          ((c, (t, (ty, typidx))) :: operations, fold cons (the_list some_t_rev') unchecks)))
   end;
 
 
@@ -550,6 +567,56 @@
   ("default", Method.thms_ctxt_args (Method.METHOD oo default_tac),
     "apply some intro/elim rule")]);
 
+fun subclass_rule thy (sub, sup) =
+  let
+    val ctxt = Locale.init sub thy;
+    val ctxt_thy = ProofContext.init thy;
+    val props =
+      Locale.global_asms_of thy sup
+      |> maps snd
+      |> map (ObjectLogic.ensure_propT thy);
+    fun tac { prems, context } =
+      Locale.intro_locales_tac true context prems
+        ORELSE ALLGOALS assume_tac;
+  in
+    Goal.prove_multi ctxt [] [] props tac
+    |> map (Assumption.export false ctxt ctxt_thy)
+    |> Variable.export ctxt ctxt_thy
+  end;
+
+fun prove_single_subclass (sub, sup) thms ctxt thy =
+  let
+    val ctxt_thy = ProofContext.init thy;
+    val subclass_rule = Conjunction.intr_balanced thms
+      |> Assumption.export false ctxt ctxt_thy
+      |> singleton (Variable.export ctxt ctxt_thy);
+    val sub_inst = Thm.ctyp_of thy (TVar ((Name.aT, 0), [sub]));
+    val sub_ax = #axioms (AxClass.get_info thy sub);
+    val classrel =
+      #intro (AxClass.get_info thy sup)
+      |> Drule.instantiate' [SOME sub_inst] []
+      |> OF_LAST (subclass_rule OF sub_ax)
+      |> strip_all_ofclass thy (Sign.super_classes thy sup)
+      |> Thm.strip_shyps
+  in
+    thy
+    |> AxClass.add_classrel classrel
+    |> prove_interpretation_in (ALLGOALS (ProofContext.fact_tac thms))
+         I (sub, Locale.Locale sup)
+    |> ClassData.map (Graph.add_edge (sub, sup))
+  end;
+
+fun prove_subclass (sub, sup) thms ctxt thy =
+  let
+    val supclasses = Sign.complete_sort thy [sup]
+      |> filter_out (fn class => Sign.subsort thy ([sub], [class]));
+    fun transform sup' = subclass_rule thy (sup, sup') |> map (fn thm => thm OF thms);
+  in
+    thy
+    |> fold_rev (fn sup' => prove_single_subclass (sub, sup')
+         (transform sup') ctxt) supclasses
+ end;
+
 
 (** classes and class target **)
 
@@ -560,7 +627,7 @@
     constraints: (string * typ) list,
     base_sort: sort,
     local_operation: string * typ -> (typ * term) option,
-    rews: (term * term) list,
+    unchecks: (term * term) list,
     passed: bool
   } option;
   fun init _ = NONE;
@@ -568,27 +635,30 @@
 
 fun synchronize_syntax thy sups base_sort ctxt =
   let
+    (* constraints *)
     val operations = these_operations thy sups;
-
-    (* constraints *)
-    fun local_constraint (c, ((_, (ty, _)),_ )) =
+    fun local_constraint (c, (_, (ty, _))) =
       let
         val ty' = ty
           |> map_atyps (fn ty as TVar ((v, 0), _) =>
                if v = Name.aT then TVar ((v, 0), base_sort) else ty)
           |> SOME;
       in (c, ty') end
-    val constraints = (map o apsnd) (fst o snd o fst) operations;
+    val constraints = (map o apsnd) (fst o snd) operations;
 
     (* check phase *)
     val typargs = Consts.typargs (ProofContext.consts_of ctxt);
-    fun check_const (c, ty) ((t, (_, idx)), _) =
+    fun check_const (c, ty) (t, (_, idx)) =
       ((nth (typargs (c, ty)) idx), t);
     fun local_operation (c_ty as (c, _)) = AList.lookup (op =) operations c
       |> Option.map (check_const c_ty);
 
     (* uncheck phase *)
-    val rews = map (fn (c, (_, (t, ty))) => (t, Const (c, ty))) operations;
+    val basify =
+      map_atyps (fn ty as TFree (v, _) => if Name.aT = v then TFree (v, base_sort)
+        else ty | ty => ty);
+    val unchecks = these_unchecks thy sups
+      |> (map o pairself o map_types) basify;
   in
     ctxt
     |> fold (ProofContext.add_const_constraint o local_constraint) operations
@@ -596,7 +666,7 @@
         constraints = constraints,
         base_sort = base_sort,
         local_operation = local_operation,
-        rews = rews,
+        unchecks = unchecks,
         passed = false
       }))
   end;
@@ -608,9 +678,9 @@
   in synchronize_syntax thy [class] base_sort ctxt end;
 
 val mark_passed = (ClassSyntax.map o Option.map)
-  (fn { constraints, base_sort, local_operation, rews, passed } =>
+  (fn { constraints, base_sort, local_operation, unchecks, passed } =>
     { constraints = constraints, base_sort = base_sort,
-      local_operation = local_operation, rews = rews, passed = true });
+      local_operation = local_operation, unchecks = unchecks, passed = true });
 
 fun sort_term_check ts ctxt =
   let
@@ -647,9 +717,9 @@
   let
     (*FIXME abbreviations*)
     val thy = ProofContext.theory_of ctxt;
-    val rews = (#rews o the o ClassSyntax.get) ctxt;
+    val unchecks = (#unchecks o the o ClassSyntax.get) ctxt;
     val ts' = if ! uncheck
-      then map (Pattern.rewrite_term thy rews []) ts else ts;
+      then map (Pattern.rewrite_term thy unchecks []) ts else ts;
   in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
 
 fun init_ctxt thy sups base_sort ctxt =
@@ -802,6 +872,7 @@
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
     val phi = morphism thy' class;
+    val base_sort = (#base_sort o the_class_data thy) class;
 
     val c' = Sign.full_name thy' c;
     val dict' = (map_types Logic.unvarifyT o Morphism.term phi) dict;
@@ -817,7 +888,7 @@
     |> Thm.add_def false (c, def_eq)
     |>> Thm.symmetric
     |-> (fn def => class_interpretation class [def] [Thm.prop_of def]
-          #> register_operation class ((c', (dict, dict')), SOME (Thm.varifyT def)))
+          #> register_operation class (c', ((dict, SOME dict'), SOME (Thm.varifyT def))))
     |> Sign.restore_naming thy
     |> Sign.add_const_constraint (c', SOME ty')
   end;
@@ -834,8 +905,7 @@
     val c' = Sign.full_name thy' c;
     val rews = map (Logic.dest_equals o Thm.prop_of) (these_defs thy' [class])
     val rhs' = (Pattern.rewrite_term thy rews [] o Morphism.term phi) rhs;
-    val rhs'' = map_types Logic.unvarifyT rhs';
-    val ty' = Term.fastype_of rhs'';
+    val ty' = (Logic.unvarifyT o Term.fastype_of) rhs';
 
     val c'' = NameSpace.full (Sign.naming_of thy' |> NameSpace.add_path prfx) c;
     val ty'' = (Type.strip_sorts o Logic.unvarifyT) (Sign.the_const_constraint thy' c'');
@@ -846,7 +916,7 @@
     |> Sign.add_abbrev (#1 prmode) pos (c, map_types Type.strip_sorts rhs') |> snd
     |> Sign.add_const_constraint (c', SOME ty')
     |> Sign.notation true prmode [(Const (c', ty'), mx)]
-    |> register_operation class ((c', (rhs, rhs'')), NONE)
+    |> register_operation class (c', ((rhs, NONE), NONE))
     |> Sign.restore_naming thy
   end;