tuned primitive inferences
authorhaftmann
Wed, 19 Dec 2007 22:33:44 +0100
changeset 25711 91cee0cefaf7
parent 25710 4cdf7de81e1b
child 25712 f488a37cfad4
tuned primitive inferences
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Wed Dec 19 17:40:48 2007 +0100
+++ b/src/Pure/Isar/class.ML	Wed Dec 19 22:33:44 2007 +0100
@@ -132,7 +132,7 @@
   base_sort: sort,
   inst: term option list
     (*canonical interpretation*),
-  morphism: morphism,
+  morphism: theory -> thm list -> morphism,
     (*partial morphism of canonical interpretation*)
   assm_intro: thm option,
   of_class: thm,
@@ -195,7 +195,8 @@
 
 fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
 
-fun morphism thy = #morphism o the_class_data thy;
+fun partial_morphism thy class = #morphism (the_class_data thy class) thy [];
+fun morphism thy class = #morphism (the_class_data thy class) thy (these_defs thy [class]);
 
 fun these_assm_intros thy =
   Graph.fold (fn (_, (data, _)) => fold (insert Thm.eq_thm)
@@ -243,7 +244,7 @@
 (* updaters *)
 
 fun add_class_data ((class, superclasses),
-    (params, base_sort, inst, phi, assm_intro, of_class, axiom)) thy =
+    (params, base_sort, inst, phi, axiom, assm_intro, of_class)) thy =
   let
     val operations = map (fn (v_ty as (_, ty), (c, _)) =>
       (c, (class, (ty, Free v_ty)))) params;
@@ -276,6 +277,7 @@
 
 fun calculate thy sups base_sort assm_axiom param_map class =
   let
+    (*static parts of morphism*)
     val subst_typ = map_atyps (fn TFree (v, sort) =>
           if v = Name.aT then TVar ((v, 0), [class]) else TVar ((v, 0), sort)
       | ty => ty);
@@ -283,52 +285,62 @@
          of SOME (c, _) => Const (c, ty)
           | NONE => t)
       | subst_aterm t = t;
-    val subst_term = map_aterms subst_aterm #> map_types subst_typ;
-    val matches = ([pairself (Thm.ctyp_of thy o TVar o pair (Name.aT, 0))
-      (base_sort, [class])], map (fn (v, (c, ty)) => pairself (Thm.cterm_of thy)
-        (Var ((v, 0), map_atyps (fn _ => TVar ((Name.aT, 0), [class])) ty),
-          Const (c, map_atyps (fn _ => TVar ((Name.aT, 0), [class])) ty))) param_map);
-    val inst_thm = Thm.instantiate matches;
+    fun instantiate thy sort = Thm.instantiate ([pairself (Thm.ctyp_of thy o TVar o pair (Name.aT, 0))
+      (base_sort, sort)], map (fn (v, (c, ty)) => pairself (Thm.cterm_of thy)
+        (Var ((v, 0), map_atyps (fn _ => TVar ((Name.aT, 0), sort)) ty),
+          Const (c, map_atyps (fn _ => TVar ((Name.aT, 0), sort)) ty))) param_map);
+    val instantiate_base_sort = instantiate thy base_sort;
+    val instantiate_class = instantiate thy [class];
     val (proto_assm_intro, locale_intro) = Locale.intros thy class
       |> pairself (try the_single);
     val axiom_premises = map_filter (#axiom o the_class_data thy) sups
       @ the_list assm_axiom;
-    val axiom = case locale_intro
-     of SOME proto_axiom => SOME ((inst_thm proto_axiom OF axiom_premises) |> Drule.standard)
-      | NONE => assm_axiom;
-    val lift_axiom = case axiom of SOME axiom =>
-          (fn thm => Thm.implies_elim (inst_thm thm) axiom)
+    val axiom = locale_intro
+      |> Option.map (Drule.standard o (fn thm => thm OF axiom_premises) o instantiate_class)
+      |> (fn x as SOME _ => x | NONE => assm_axiom);
+    val lift_axiom = case axiom
+     of SOME axiom => (fn thm => Thm.implies_elim thm axiom)
       | NONE => I;
-    val subst_thm = Drule.standard' #> inst_thm #> lift_axiom;
-    val morphism = Morphism.term_morphism subst_term
-      $> Morphism.typ_morphism subst_typ
-      $> Morphism.thm_morphism subst_thm;
 
-    (*FIXME use more primitives here rather than OF, simplifify code*)
-    fun VarA sort = TVar ((Name.aT, 0), sort);
-    fun FreeA sort = TFree (Name.aT, sort);
-    fun instantiate sort1 sort2 =
-      Thm.instantiate ([pairself (Thm.ctyp_of thy) (VarA sort1, FreeA sort2)], [])
-    val inst_ty = (map_atyps o K o VarA) base_sort;
+    (*dynamic parts of morphism*)
+    fun rew_term thy defs = Pattern.rewrite_term thy
+      (map (Logic.dest_equals o Thm.prop_of) defs) [];
+    fun subst_term thy defs = map_aterms subst_aterm #> rew_term thy defs
+      #> map_types subst_typ;
+    fun subst_thm defs = Drule.standard' #> instantiate_class #> lift_axiom
+      #> MetaSimplifier.rewrite_rule defs;
+    fun morphism thy defs = 
+      Morphism.typ_morphism subst_typ
+      $> Morphism.term_morphism (subst_term thy defs)
+      $> Morphism.thm_morphism (subst_thm defs);
+
+    (*class rules*)
+    val defs = these_defs thy sups;
     val assm_intro = proto_assm_intro
-      |> Option.map (Thm.instantiate ([],
-           map (fn (v, (c, ty)) => pairself (Thm.cterm_of thy)
-             (Var ((v, 0), inst_ty ty), Const (c, inst_ty ty))) param_map))
-      |> Option.map (MetaSimplifier.rewrite_rule (these_defs thy sups))
+      |> Option.map instantiate_base_sort
+      |> Option.map (MetaSimplifier.rewrite_rule defs)
       |> Option.map Goal.close_result;
-    val class_intro = (instantiate [] base_sort o #intro o AxClass.get_info thy) class;
+    val fixate = Thm.instantiate
+      (map (pairself (Thm.ctyp_of thy)) [(TVar ((Name.aT, 0), []), TFree (Name.aT, base_sort)),
+        (TVar ((Name.aT, 0), base_sort), TFree (Name.aT, base_sort))], [])
+    val class_intro = (fixate o #intro o AxClass.get_info thy) class;
     val of_class_sups = if null sups
-      then Drule.sort_triv thy (FreeA base_sort, base_sort)
-      else map (Drule.implies_intr_hyps o #of_class o the_class_data thy) sups;
+      then map (fixate o Thm.class_triv thy) base_sort
+      else map (fixate o #of_class o the_class_data thy) sups;
     val locale_dests = map Drule.standard' (Locale.dests thy class);
-    fun mk_pred_triv () = (Thm.assume o Thm.cterm_of thy
-      o (map_types o map_atyps o K o FreeA) base_sort o Thm.prop_of o the) axiom;
-    val pred_trivs = case length locale_dests
-     of 0 => if is_none locale_intro then [] else [mk_pred_triv ()]
-      | n => replicate n (mk_pred_triv ());
+    val num_trivs = case length locale_dests
+     of 0 => if is_none axiom then 0 else 1
+      | n => n;
+    val pred_trivs = if num_trivs = 0 then []
+      else the axiom
+        |> Thm.prop_of
+        |> (map_types o map_atyps o K) (TFree (Name.aT, base_sort))
+        |> (Thm.assume o Thm.cterm_of thy)
+        |> replicate num_trivs;
     val of_class = (class_intro OF of_class_sups OF locale_dests OF pred_trivs)
+      |> Drule.standard'
       |> Goal.close_result;
-  in (morphism, assm_intro, of_class, axiom) end;
+  in (morphism, axiom, assm_intro, of_class) end;
 
 fun class_interpretation class facts defs thy =
   let
@@ -346,7 +358,7 @@
 
 fun prove_subclass (sub, sup) thm thy =
   let
-    val of_class = (Drule.standard' o #of_class o the_class_data thy) sup;
+    val of_class = (#of_class o the_class_data thy) sup;
     val intro = Drule.standard' (of_class OF [Drule.standard' thm]);
     val classrel = intro OF (the_list o #axiom o the_class_data thy) sub;
   in
@@ -600,9 +612,9 @@
     |> adjungate_axclass bname class base_sort sups supsort supparams global_syntax other_consts
     |-> (fn (param_map, params, assm_axiom) =>
          `(fn thy => calculate thy sups base_sort assm_axiom param_map class)
-    #-> (fn (morphism, assm_intro, assm_proj, axiom) =>
+    #-> (fn (morphism, axiom, assm_intro, of_class) =>
         add_class_data ((class, sups), (params, base_sort,
-          map snd param_map, morphism, assm_intro, assm_proj, axiom))
+          map snd param_map, morphism, axiom, assm_intro, of_class))
     #> class_interpretation class (the_list axiom) []))
     |> init class
     |> pair class
@@ -624,7 +636,7 @@
   let
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
-    val phi = morphism thy' class;
+    val phi = partial_morphism thy' class;
 
     val c' = Sign.full_name thy' c;
     val dict' = Morphism.term phi dict;
@@ -652,8 +664,7 @@
     val phi = morphism thy class;
 
     val c' = Sign.full_name thy' c;
-    val rews = map (Logic.dest_equals o Thm.prop_of) (these_defs thy' [class])
-    val rhs' = (Pattern.rewrite_term thy rews [] o Morphism.term phi) rhs;
+    val rhs' = Morphism.term phi rhs;
     val ty' = Logic.unvarifyT (Term.fastype_of rhs');
   in
     thy'
@@ -704,14 +715,6 @@
 
 (* syntax *)
 
-fun subst_param thy params = map_aterms (fn t as Const (c, ty) =>
-    (case AxClass.inst_tyco_of thy (c, ty)
-     of SOME tyco => (case AList.lookup (op =) params (c, tyco)
-         of SOME v_ty => Free v_ty
-          | NONE => t)
-      | NONE => t)
-  | t => t);
-
 fun inst_term_check ts lthy =
   let
     val params = instantiation_params lthy;
@@ -724,9 +727,17 @@
               | NONE => I)
           | NONE => I)
       | check_improve _ = I;
+    val subst_param = map_aterms (fn t as Const (c, ty) =>
+        (case AxClass.inst_tyco_of thy (c, ty)
+         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+             of SOME v_ty => Free v_ty
+              | NONE => t)
+          | NONE => t)
+      | t => t);
+
     val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
     val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
-    val ts'' = map (subst_param thy params) ts';
+    val ts'' = map subst_param ts';
   in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end;
 
 fun inst_term_uncheck ts lthy =