improved rule calculation
authorhaftmann
Thu, 13 Dec 2007 07:09:06 +0100
changeset 25618 01f20279fea1
parent 25617 b495384e48e1
child 25619 e4d5cd384245
improved rule calculation
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Thu Dec 13 07:09:05 2007 +0100
+++ b/src/Pure/Isar/class.ML	Thu Dec 13 07:09:06 2007 +0100
@@ -22,8 +22,7 @@
 
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
-  val prove_subclass: class * class -> thm list -> Proof.context
-    -> theory -> theory
+  val prove_subclass: class * class -> thm -> theory -> theory
 
   val class_prefix: string -> string
   val is_class: theory -> class -> bool
@@ -71,22 +70,6 @@
       (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
   #> ProofContext.theory_of;
 
-fun OF_LAST thm1 thm2 = thm1 RSN (Thm.nprems_of thm2, thm2);
-
-fun strip_all_ofclass thy sort =
-  let
-    val typ = TVar ((Name.aT, 0), sort);
-    fun prem_inclass t =
-      case Logic.strip_imp_prems t
-       of ofcls :: _ => try Logic.dest_inclass ofcls
-        | [] => NONE;
-    fun strip_ofclass class thm =
-      thm OF (fst o AxClass.of_sort thy (typ, [class])) AxClass.cache;
-    fun strip thm = case (prem_inclass o Thm.prop_of) thm
-     of SOME (_, class) => thm |> strip_ofclass class |> strip
-      | NONE => thm;
-  in strip end;
-
 fun get_remove_global_constraint c thy =
   let
     val ty = Sign.the_const_constraint thy c;
@@ -151,27 +134,29 @@
     (*canonical interpretation*),
   morphism: morphism,
     (*partial morphism of canonical interpretation*)
-  intro: thm,
+  assm_intro: thm option,
+  of_class: thm,
+  axiom: thm option,
   defs: thm list,
   operations: (string * (class * (typ * term))) list
 };
 
 fun rep_class_data (ClassData d) = d;
-fun mk_class_data ((consts, base_sort, inst, morphism, intro),
+fun mk_class_data ((consts, base_sort, inst, morphism, assm_intro, of_class, axiom),
     (defs, operations)) =
   ClassData { consts = consts, base_sort = base_sort, inst = inst,
-    morphism = morphism, intro = intro, defs = defs,
-    operations = operations };
-fun map_class_data f (ClassData { consts, base_sort, inst, morphism, intro,
-    defs, operations }) =
-  mk_class_data (f ((consts, base_sort, inst, morphism, intro),
+    morphism = morphism, assm_intro = assm_intro, of_class = of_class, axiom = axiom, 
+    defs = defs, operations = operations };
+fun map_class_data f (ClassData { consts, base_sort, inst, morphism,
+    assm_intro, of_class, axiom, defs, operations }) =
+  mk_class_data (f ((consts, base_sort, inst, morphism, assm_intro, of_class, axiom),
     (defs, operations)));
 fun merge_class_data _ (ClassData { consts = consts,
-    base_sort = base_sort, inst = inst, morphism = morphism, intro = intro,
-    defs = defs1, operations = operations1 },
-  ClassData { consts = _, base_sort = _, inst = _, morphism = _, intro = _,
-    defs = defs2, operations = operations2 }) =
-  mk_class_data ((consts, base_sort, inst, morphism, intro),
+    base_sort = base_sort, inst = inst, morphism = morphism, assm_intro = assm_intro,
+    of_class = of_class, axiom = axiom, defs = defs1, operations = operations1 },
+  ClassData { consts = _, base_sort = _, inst = _, morphism = _, assm_intro = _,
+    of_class = _, axiom = _, defs = defs2, operations = operations2 }) =
+  mk_class_data ((consts, base_sort, inst, morphism, assm_intro, of_class, axiom),
     (Thm.merge_thms (defs1, defs2),
       AList.merge (op =) (K true) (operations1, operations2)));
 
@@ -212,9 +197,9 @@
 
 fun morphism thy = #morphism o the_class_data thy;
 
-fun these_intros thy =
-  Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
-    (ClassData.get thy) [];
+fun these_assm_intros thy =
+  Graph.fold (fn (_, (data, _)) => fold (insert Thm.eq_thm)
+    ((the_list o #assm_intro o rep_class_data) data)) (ClassData.get thy) [];
 
 fun these_operations thy =
   maps (#operations o the_class_data thy) o ancestry thy;
@@ -257,17 +242,17 @@
 
 (* updaters *)
 
-fun add_class_data ((class, superclasses), (cs, base_sort, inst, phi, intro)) thy =
+fun add_class_data ((class, superclasses),
+    (cs, base_sort, inst, phi, assm_intro, of_class, axiom)) thy =
   let
     val operations = map (fn (v_ty as (_, ty), (c, _)) =>
       (c, (class, (ty, Free v_ty)))) cs;
     val cs = (map o pairself) fst cs;
     val add_class = Graph.new_node (class,
-        mk_class_data ((cs, base_sort, map (SOME o Const) inst, phi, intro), ([], operations)))
+        mk_class_data ((cs, base_sort,
+          map (SOME o Const) inst, phi, assm_intro, of_class, axiom), ([], operations)))
       #> fold (curry Graph.add_edge class) superclasses;
-  in
-    ClassData.map add_class thy
-  end;
+  in ClassData.map add_class thy end;
 
 fun register_operation class (c, (t, some_def)) thy =
   let
@@ -304,34 +289,40 @@
     $> Morphism.typ_morphism subst_typ
   end;
 
-fun class_intro thy class sups =
+fun calculate_rules thy sups base_sort assm_axiom param_map class =
   let
-    fun class_elim class =
-      case (#axioms o AxClass.get_info thy) class
-       of [thm] => SOME (Drule.unconstrainTs thm)
-        | [] => NONE;
-    val pred_intro = case Locale.intros thy class
-     of ([ax_intro], [intro]) => intro |> OF_LAST ax_intro |> SOME
-      | ([intro], []) => SOME intro
-      | ([], [intro]) => SOME intro
-      | _ => NONE;
-    val pred_intro' = pred_intro
-      |> Option.map (fn intro => intro OF map_filter class_elim sups);
-    val class_intro = (#intro o AxClass.get_info thy) class;
-    val raw_intro = case pred_intro'
-     of SOME pred_intro => class_intro |> OF_LAST pred_intro
-      | NONE => class_intro;
-    val sort = Sign.super_classes thy class;
-    val typ = TVar ((Name.aT, 0), sort);
-    val defs = these_defs thy sups;
-  in
-    raw_intro
-    |> Drule.instantiate' [SOME (Thm.ctyp_of thy typ)] []
-    |> strip_all_ofclass thy sort
-    |> Thm.strip_shyps
-    |> MetaSimplifier.rewrite_rule defs
-    |> Drule.unconstrainTs
-  end;
+    (*FIXME use more primitves here rather than OF, simplifify code*)
+    fun the_option [x] = SOME x
+      | the_option [] = NONE;
+    fun VarA sort = TVar ((Name.aT, 0), sort);
+    fun FreeA sort = TFree (Name.aT, sort);
+    fun instantiate sort1 sort2 =
+      Thm.instantiate ([pairself (Thm.ctyp_of thy) (VarA sort1, FreeA sort2)], [])
+    val (proto_assm_intro, locale_intro) = pairself the_option (Locale.intros thy class);
+    val inst_ty = (map_atyps o K o VarA) base_sort;
+    val assm_intro = proto_assm_intro
+      |> Option.map (Thm.instantiate ([],
+           map (fn ((v, _), (c, ty)) => pairself (Thm.cterm_of thy)
+             (Var ((v, 0), inst_ty ty), Const (c, inst_ty ty))) param_map))
+      |> Option.map (MetaSimplifier.rewrite_rule (these_defs thy sups));
+    val axiom_premises = map_filter (#axiom o the_class_data thy) sups
+      @ the_list assm_axiom;
+    val axiom = case locale_intro
+     of SOME proto_axiom => SOME
+          ((instantiate base_sort [class] proto_axiom OF axiom_premises) |> Drule.standard)
+      | NONE => assm_axiom;
+    val class_intro = (instantiate [] base_sort o #intro o AxClass.get_info thy) class;
+    val of_class_sups = if null sups
+      then Drule.sort_triv thy (FreeA base_sort, base_sort)
+      else map (Drule.implies_intr_hyps o #of_class o the_class_data thy) sups;
+    val locale_dests = map Drule.standard (Locale.dests thy class);
+    fun mk_pred_triv () = (Thm.assume o Thm.cterm_of thy
+      o (map_types o map_atyps o K o FreeA) base_sort o Thm.prop_of o the) axiom;
+    val pred_trivs = case length locale_dests
+     of 0 => if is_none locale_intro then [] else [mk_pred_triv ()]
+      | n => replicate n (mk_pred_triv ());
+    val of_class = class_intro OF of_class_sups OF locale_dests OF pred_trivs;
+  in (assm_intro, of_class, axiom) end;
 
 fun class_interpretation class facts defs thy =
   let
@@ -347,15 +338,28 @@
     |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
   end;
 
+fun prove_subclass (sub, sup) thm thy =
+  let
+    val of_class = (Drule.standard o #of_class o the_class_data thy) sup;
+    val intro = Drule.standard (of_class OF [Drule.standard thm]);
+    val classrel = intro OF (the_list o #axiom o the_class_data thy) sub;
+  in
+    thy
+    |> AxClass.add_classrel classrel
+    |> prove_interpretation_in (ALLGOALS (ProofContext.fact_tac [thm]))
+         I (sub, Locale.Locale sup)
+    |> ClassData.map (Graph.add_edge (sub, sup))
+  end;
+
 fun intro_classes_tac facts st =
   let
     val thy = Thm.theory_of_thm st;
     val classes = Sign.all_classes thy;
     val class_trivs = map (Thm.class_triv thy) classes;
-    val class_intros = these_intros thy;
-    val axclass_intros = map_filter (try (#intro o AxClass.get_info thy)) classes;
+    val class_intros = map_filter (try (#intro o AxClass.get_info thy)) classes;
+    val assm_intros = these_assm_intros thy;
   in
-    Method.intros_tac (class_trivs @ class_intros @ axclass_intros) facts st
+    Method.intros_tac (class_trivs @ class_intros @ assm_intros) facts st
   end;
 
 fun default_intro_classes_tac [] = intro_classes_tac []
@@ -371,57 +375,6 @@
   ("default", Method.thms_ctxt_args (Method.METHOD oo default_tac),
     "apply some intro/elim rule")]);
 
-fun subclass_rule thy (sub, sup) =
-  let
-    val ctxt = Locale.init sub thy;
-    val ctxt_thy = ProofContext.init thy;
-    val props =
-      Locale.global_asms_of thy sup
-      |> maps snd
-      |> map (ObjectLogic.ensure_propT thy);
-    fun tac { prems, context } =
-      Locale.intro_locales_tac true context prems
-        ORELSE ALLGOALS assume_tac;
-  in
-    Goal.prove_multi ctxt [] [] props tac
-    |> map (Assumption.export false ctxt ctxt_thy)
-    |> Variable.export ctxt ctxt_thy
-  end;
-
-fun prove_single_subclass (sub, sup) thms ctxt thy =
-  let
-    val ctxt_thy = ProofContext.init thy;
-    val subclass_rule = Conjunction.intr_balanced thms
-      |> Assumption.export false ctxt ctxt_thy
-      |> singleton (Variable.export ctxt ctxt_thy);
-    val sub_inst = Thm.ctyp_of thy (TVar ((Name.aT, 0), [sub]));
-    val sub_ax = #axioms (AxClass.get_info thy sub);
-    val classrel =
-      #intro (AxClass.get_info thy sup)
-      |> Drule.instantiate' [SOME sub_inst] []
-      |> OF_LAST (subclass_rule OF sub_ax)
-      |> strip_all_ofclass thy (Sign.super_classes thy sup)
-      |> Thm.strip_shyps
-  in
-    thy
-    |> AxClass.add_classrel classrel
-    |> prove_interpretation_in (ALLGOALS (ProofContext.fact_tac thms))
-         I (sub, Locale.Locale sup)
-    |> ClassData.map (Graph.add_edge (sub, sup))
-  end;
-
-fun prove_subclass (sub, sup) thms ctxt thy =
-  let
-    val classes = ClassData.get thy;
-    val is_sup = not o null o curry (Graph.irreducible_paths classes) sub;
-    val supclasses = Graph.all_succs classes [sup] |> filter_out is_sup;
-    fun transform sup' = subclass_rule thy (sup, sup') |> map (fn thm => thm OF thms);
-  in
-    thy
-    |> fold_rev (fn sup' => prove_single_subclass (sub, sup')
-         (transform sup') ctxt) supclasses
- end;
-
 
 (** classes and class target **)
 
@@ -547,12 +500,10 @@
 fun gen_class_spec prep_class prep_expr process_expr thy raw_supclasses raw_includes_elems =
   let
     val supclasses = map (prep_class thy) raw_supclasses;
-    val sups = filter (is_class thy) supclasses;
-    fun the_base_sort class = lookup_class_data thy class
-      |> Option.map #base_sort
-      |> the_default [class];
-    val base_sort = Sign.minimize_sort thy (maps the_base_sort supclasses);
     val supsort = Sign.minimize_sort thy supclasses;
+    val sups = filter (is_class thy) supsort;
+    val base_sort = if null sups then supsort else
+      (#base_sort o the_class_data thy o hd) sups;
     val suplocales = map Locale.Locale sups;
     val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
       | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
@@ -577,30 +528,26 @@
 val read_class_spec = gen_class_spec Sign.intern_class Locale.intern_expr Locale.read_expr;
 val check_class_spec = gen_class_spec (K I) (K I) Locale.cert_expr;
 
-fun define_class_params (name, raw_superclasses) raw_consts raw_dep_axioms other_consts thy =
+fun define_class_params name class superclasses consts dep_axiom other_consts thy =
   let
-    val superclasses = map (Sign.certify_class thy) raw_superclasses;
-    val consts = (map o apfst o apsnd) (Sign.certify_typ thy) raw_consts;
     fun add_const ((c, ty), syn) =
       Sign.declare_const [] (c, Type.strip_sorts ty, syn) #>> Term.dest_Const;
-    fun mk_axioms cs thy =
-      raw_dep_axioms thy cs
-      |> (map o apsnd o map) (Sign.cert_prop thy)
-      |> rpair thy;
-    fun constrain_typs class = (map o apsnd o Term.map_type_tfree)
+    val constrain_typs = (map o apsnd o Term.map_type_tfree)
       (fn (v, _) => TFree (v, [class]))
+    fun the_option [x] = SOME x
+      | the_option [] = NONE;
   in
     thy
     |> Sign.add_path (Logic.const_of_class name)
     |> fold_map add_const consts
     ||> Sign.restore_naming thy
-    |-> (fn cs => mk_axioms cs
-    #-> (fn axioms_prop => AxClass.define_class (name, superclasses)
-           (map fst cs @ other_consts) axioms_prop
-    #-> (fn class => `(fn _ => constrain_typs class cs)
-    #-> (fn cs' => `(fn thy => AxClass.get_info thy class)
-    #-> (fn {axioms, ...} => fold (Sign.add_const_constraint o apsnd SOME) cs'
-    #> pair (class, (cs', axioms)))))))
+    |-> (fn cs => `(fn thy => dep_axiom thy cs)
+    #-> (fn axiom => AxClass.define_class (name, superclasses)
+           (map fst cs @ other_consts) [axiom]
+    #-> (fn _ => `(fn _ => constrain_typs cs)
+    #-> (fn cs' => `(fn thy => (the_option o #axioms o AxClass.get_info thy) class)
+    #-> (fn axiom => fold (Sign.add_const_constraint o apsnd SOME) cs'
+    #> pair (cs', axiom))))))
   end;
 
 fun gen_class prep_spec prep_param bname
@@ -618,7 +565,7 @@
       | fork_syntax x = pair x;
     val (elems, global_syn) = fold_map fork_syntax elems_syn [];
     fun globalize (c, ty) =
-      ((c, Term.map_type_tfree (K (TFree (Name.aT, base_sort))) ty),
+      ((c, map_atyps (K (TFree (Name.aT, base_sort))) ty),
         (the_default NoSyn o AList.lookup (op =) global_syn) c);
     fun extract_params thy =
       let
@@ -636,8 +583,10 @@
           ((Sign.base_name name, map (Attrib.attribute_i thy) atts),
             (map o map_aterms) subst ts);
       in
-        Locale.global_asms_of thy class
-        |> map prep_asm
+        Locale.intros thy class
+        |> fst
+        |> map (map_aterms subst o Logic.unvarify o Logic.strip_imp_concl o Thm.prop_of)
+        |> pair (bname ^ "_" ^ AxClass.axiomsN, [])
       end;
   in
     thy
@@ -646,20 +595,19 @@
     |> ProofContext.theory_of
     |> `extract_params
     |-> (fn (all_params, params) =>
-        define_class_params (bname, supsort) params
+        define_class_params bname class supsort params
           (extract_assumes params) other_consts
-      #-> (fn (_, (consts, axioms)) =>
-        `(fn thy => class_intro thy class sups)
-      #-> (fn class_intro =>
-        PureThy.note_thmss_qualified "" (NameSpace.append class classN)
-          [((introN, []), [([class_intro], [])])]
-      #-> (fn [(_, [class_intro])] =>
+      #-> (fn (consts, assm_axiom) =>
+        `(fn thy => calculate_rules thy sups base_sort assm_axiom
+          (all_params ~~ map snd supconsts @ consts) class)
+      #-> (fn (assm_intro, assm_proj, axiom) =>
         add_class_data ((class, sups),
           (map fst params ~~ consts, base_sort,
             mk_inst class (map snd supconsts @ consts),
-              calculate_morphism class (supconsts @ (map (fst o fst) params ~~ consts)), class_intro))
-      #> class_interpretation class axioms []
-      ))))
+              calculate_morphism class (supconsts @ (map (fst o fst) params ~~ consts)),
+          assm_intro, assm_proj, axiom))
+      #> class_interpretation class (the_list axiom) []
+      )))
     |> init class
     |> pair class
   end;
@@ -688,13 +636,15 @@
     val ty' = Term.fastype_of dict_def;
     val ty'' = Type.strip_sorts ty';
     val def_eq = Logic.mk_equals (Const (c', ty'), dict_def);
+    fun get_axiom thy = ((Thm.varifyT o Thm.symmetric o Thm.get_axiom_i thy) c', thy);
   in
     thy'
     |> Sign.declare_const pos (c, ty'', mx) |> snd
     |> Thm.add_def false false (c, def_eq)
     |>> Thm.symmetric
-    |-> (fn def => class_interpretation class [def] [Thm.prop_of def]
-          #> register_operation class (c', (dict', SOME (Thm.varifyT def))))
+    ||>> get_axiom
+    |-> (fn (def, def') => class_interpretation class [def] [Thm.prop_of def]
+          #> register_operation class (c', (dict', SOME def')))
     |> Sign.restore_naming thy
     |> Sign.add_const_constraint (c', SOME ty')
   end;