proper abbreviations in class
authorhaftmann
Tue, 22 Apr 2008 08:33:12 +0200
changeset 26730 bbb5e6904d78
parent 26729 43a72d892594
child 26731 48df747c8543
proper abbreviations in class
src/Pure/Isar/class.ML
src/Pure/Isar/overloading.ML
--- a/src/Pure/Isar/class.ML	Tue Apr 22 08:33:10 2008 +0200
+++ b/src/Pure/Isar/class.ML	Tue Apr 22 08:33:12 2008 +0200
@@ -191,7 +191,6 @@
 
 fun these_defs thy = maps (these o Option.map #defs o lookup_class_data thy) o ancestry thy;
 
-fun partial_morphism thy class = #morphism (the_class_data thy class) thy [];
 fun morphism thy class = #morphism (the_class_data thy class) thy (these_defs thy [class]);
 
 fun these_assm_intros thy =
@@ -438,7 +437,7 @@
     ctxt
     |> fold declare_const primary_constraints
     |> Overloading.map_improvable_syntax (K (((primary_constraints, secondary_constraints),
-        ((improve, subst), unchecks)), false))
+        (((improve, subst), true), unchecks)), false))
     |> Overloading.set_primary_constraints
   end;
 
@@ -581,7 +580,7 @@
   let
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
-    val phi = partial_morphism thy' class;
+    val phi = morphism thy' class;
 
     val c' = Sign.full_name thy' c;
     val dict' = Morphism.term phi dict;
@@ -608,14 +607,16 @@
   let
     val prfx = class_prefix class;
     val thy' = thy |> Sign.add_path prfx;
-    val phi = morphism thy class;
 
+    val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty)))
+      (these_operations thy [class]);
     val c' = Sign.full_name thy' c;
-    val rhs' = Morphism.term phi rhs;
-    val ty' = Logic.unvarifyT (Term.fastype_of rhs');
+    val rhs' = Pattern.rewrite_term thy unchecks [] rhs;
+    val rhs'' = map_types Logic.varifyT rhs';
+    val ty' = Term.fastype_of rhs';
   in
     thy'
-    |> Sign.add_abbrev (#1 prmode) pos (c, map_types Type.strip_sorts rhs') |> snd
+    |> Sign.add_abbrev (#1 prmode) pos (c, map_types Type.strip_sorts rhs'') |> snd
     |> Sign.add_const_constraint (c', SOME ty')
     |> Sign.notation true prmode [(Const (c', ty'), mx)]
     |> register_operation class (c', (rhs', NONE))
@@ -673,8 +674,8 @@
   in
     ctxt
     |> Overloading.map_improvable_syntax
-         (fn (((primary_constraints, _), ((improve, _), _)), _) =>
-            (((primary_constraints, []), ((improve, subst), unchecks)), false))
+         (fn (((primary_constraints, _), (((improve, _), _), _)), _) =>
+            (((primary_constraints, []), (((improve, subst), false), unchecks)), false))
   end;
 
 
@@ -744,7 +745,7 @@
     |> fold (Variable.declare_names o Free o snd) inst_params
     |> (Overloading.map_improvable_syntax o apfst)
          (fn ((_, _), ((_, subst), unchecks)) =>
-            ((primary_constraints, []), ((improve, K NONE), [])))
+            ((primary_constraints, []), (((improve, K NONE), false), [])))
     |> Overloading.add_improvable_syntax
     |> Context.proof_map (Syntax.add_term_check 0 "resorting" resort_check)
     |> synchronize_inst_syntax
--- a/src/Pure/Isar/overloading.ML	Tue Apr 22 08:33:10 2008 +0200
+++ b/src/Pure/Isar/overloading.ML	Tue Apr 22 08:33:12 2008 +0200
@@ -28,7 +28,7 @@
 (** generic check/uncheck combinators for improvable constants **)
 
 type improvable_syntax = ((((string * typ) list * (string * typ) list) *
-  (((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) *
+  ((((string * typ -> (typ * typ) option) * (string * typ -> (typ * term) option)) * bool) *
     (term * term) list)) * bool);
 
 structure ImprovableSyntax = ProofDataFun(
@@ -37,6 +37,7 @@
     secondary_constraints: (string * typ) list,
     improve: string * typ -> (typ * typ) option,
     subst: string * typ -> (typ * term) option,
+    consider_abbrevs: bool,
     unchecks: (term * term) list,
     passed: bool
   };
@@ -45,26 +46,32 @@
     secondary_constraints = [],
     improve = K NONE,
     subst = K NONE,
+    consider_abbrevs = false,
     unchecks = [],
     passed = true
   };
 );
 
 fun map_improvable_syntax f = ImprovableSyntax.map (fn { primary_constraints,
-  secondary_constraints, improve, subst, unchecks, passed } => let
-    val (((primary_constraints', secondary_constraints'), ((improve', subst'), unchecks')), passed')
-      = f (((primary_constraints, secondary_constraints), ((improve, subst), unchecks)), passed)
+  secondary_constraints, improve, subst, consider_abbrevs, unchecks, passed } => let
+    val (((primary_constraints', secondary_constraints'),
+      (((improve', subst'), consider_abbrevs'), unchecks')), passed')
+        = f (((primary_constraints, secondary_constraints),
+            (((improve, subst), consider_abbrevs), unchecks)), passed)
   in { primary_constraints = primary_constraints', secondary_constraints = secondary_constraints',
-    improve = improve', subst = subst', unchecks = unchecks', passed = passed'
+    improve = improve', subst = subst', consider_abbrevs = consider_abbrevs',
+    unchecks = unchecks', passed = passed'
   } end);
 
 val mark_passed = (map_improvable_syntax o apsnd) (K true);
 
 fun improve_term_check ts ctxt =
   let
-    val { primary_constraints, secondary_constraints, improve, subst, passed, ... } =
-      ImprovableSyntax.get ctxt;
+    val { primary_constraints, secondary_constraints, improve, subst,
+      consider_abbrevs, passed, ... } = ImprovableSyntax.get ctxt;
     val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
+    val is_abbrev = consider_abbrevs andalso ProofContext.is_abbrev_mode ctxt;
+    val passed_or_abbrev = passed orelse is_abbrev;
     fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
          of SOME ty_ty' => Type.typ_match tsig ty_ty'
           | _ => I)
@@ -77,9 +84,9 @@
               then SOME (ty', apply_subst t') else NONE
           | NONE => NONE)
         | _ => NONE) t;
-    val ts'' = map apply_subst ts';
-  in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else
-    if passed then SOME (ts'', ctxt)
+    val ts'' = if is_abbrev then ts' else map apply_subst ts';
+  in if eq_list (op aconv) (ts, ts'') andalso passed_or_abbrev then NONE else
+    if passed_or_abbrev then SOME (ts'', ctxt)
     else SOME (ts'', ctxt
       |> fold (ProofContext.add_const_constraint o apsnd SOME) secondary_constraints
       |> mark_passed)
@@ -147,7 +154,7 @@
     |> ProofContext.init
     |> OverloadingData.put overloading
     |> fold (fn ((_, ty), (v, _)) => Variable.declare_names (Free (v, ty))) overloading
-    |> map_improvable_syntax (K ((([], []), ((K NONE, subst), unchecks)), false))
+    |> map_improvable_syntax (K ((([], []), (((K NONE, subst), false), unchecks)), false))
     |> add_improvable_syntax
   end;