generic improvable syntax for targets
authorhaftmann
Fri, 07 Mar 2008 13:53:07 +0100
changeset 26238 c30bb8182da2
parent 26237 4bc6e3ff8b78
child 26239 e105d24d15c1
generic improvable syntax for targets
src/Pure/Isar/class.ML
src/Pure/Isar/overloading.ML
--- a/src/Pure/Isar/class.ML	Fri Mar 07 13:53:06 2008 +0100
+++ b/src/Pure/Isar/class.ML	Fri Mar 07 13:53:07 2008 +0100
@@ -70,15 +70,6 @@
       (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
   #> ProofContext.theory_of;
 
-fun get_remove_global_constraint c thy =
-  let
-    val ty = Sign.the_const_constraint thy c;
-  in
-    thy
-    |> Sign.add_const_constraint (c, NONE)
-    |> pair (c, Logic.unvarifyT ty)
-  end;
-
 
 (** primitive axclass and instance commands **)
 
@@ -345,15 +336,20 @@
 fun class_interpretation class facts defs thy =
   let
     val params = these_params thy [class];
+    val consts = map (fst o snd) params;
+    val constraints = map (fn c => map_atyps (K (TFree (Name.aT,
+      [the (AxClass.class_of_param thy c)]))) (Sign.the_const_type thy c)) consts;
+    val no_constraints = map (map_atyps (K (TFree (Name.aT, [])))) constraints;
+    fun add_constraint c T = Sign.add_const_constraint (c, SOME T);
     val inst = (#inst o the_class_data thy) class;
     val tac = ALLGOALS (ProofContext.fact_tac facts);
     val prfx = class_prefix class;
   in
     thy
-    |> fold_map (get_remove_global_constraint o fst o snd) params
-    ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
+    |> fold2 add_constraint consts no_constraints
+    |> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
           (inst, map (fn def => (("", []), def)) defs)
-    |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
+    |> fold2 add_constraint consts constraints
   end;
 
 fun prove_subclass (sub, sup) thm thy =
@@ -398,26 +394,7 @@
 
 (* class context syntax *)
 
-structure ClassSyntax = ProofDataFun(
-  type T = {
-    local_constraints: (string * typ) list,
-    global_constraints: (string * typ) list,
-    base_sort: sort,
-    operations: (string * (typ * term)) list,
-    unchecks: (term * term) list,
-    passed: bool
-  };
-  fun init _ = {
-    local_constraints = [],
-    global_constraints = [],
-    base_sort = [],
-    operations = [],
-    unchecks = [],
-    passed = true
-  };;
-);
-
-fun synchronize_syntax sups base_sort ctxt =
+fun synchronize_class_syntax sups base_sort ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     fun subst_class_typ sort = map_atyps
@@ -430,80 +407,44 @@
     fun declare_const (c, _) =
       let val b = Sign.base_name c
       in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
+    fun improve (c, ty) = (case AList.lookup (op =) local_constraints c
+     of SOME ty' => (case try (Type.raw_match (ty', ty)) Vartab.empty
+         of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
+             of SOME (_, ty' as TVar (tvar as (vi, _))) =>
+                  if TypeInfer.is_param vi
+                    then SOME (ty', TFree (Name.aT, base_sort))
+                    else NONE
+              | _ => NONE)
+          | NONE => NONE)
+      | NONE => NONE)
+    fun subst (c, ty) = Option.map snd (AList.lookup (op =) operations c);
     val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
   in
     ctxt
     |> fold declare_const local_constraints
     |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
-    |> ClassSyntax.put {
+    |> Overloading.map_improvable_syntax (K {
         local_constraints = local_constraints,
         global_constraints = global_constraints,
-        base_sort = base_sort,
-        operations = (map o apsnd) snd operations,
+        improve = improve,
+        subst = subst,
         unchecks = unchecks,
         passed = false
-      }
+      })
   end;
 
 fun refresh_syntax class ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
     val base_sort = (#base_sort o the_class_data thy) class;
-  in synchronize_syntax [class] base_sort ctxt end;
-
-val mark_passed = ClassSyntax.map
-  (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } =>
-    { local_constraints = local_constraints, global_constraints = global_constraints,
-      base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true });
-
-fun sort_term_check ts ctxt =
-  let
-    val { local_constraints, global_constraints, base_sort, operations, passed, ... } =
-      ClassSyntax.get ctxt;
-    fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c
-         of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty
-             of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
-                 of SOME (_, TVar (tvar as (vi, _))) =>
-                      if TypeInfer.is_param vi then cons tvar else I
-                  | _ => I)
-              | NONE => I)
-          | NONE => I)
-      | check_improve _ = I;
-    val improvements = (fold o fold_aterms) check_improve ts [];
-    val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar =>
-        if member (op =) improvements tvar
-          then TFree (Name.aT, base_sort) else ty | ty => ty) ts;
-    fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c
-         of SOME (ty0, t) =>
-              if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0)
-              then SOME (ty0, check t) else NONE
-          | NONE => NONE)
-      | _ => NONE) t0;
-    val ts'' = map check ts';
-  in if eq_list (op aconv) (ts, ts'') andalso passed then NONE
-  else
-    ctxt
-    |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
-    |> mark_passed
-    |> pair ts''
-    |> SOME
-  end;
-
-fun sort_term_uncheck ts ctxt =
-  let
-    val thy = ProofContext.theory_of ctxt;
-    val unchecks = (#unchecks o ClassSyntax.get) ctxt;
-    val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
-  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
+  in synchronize_class_syntax [class] base_sort ctxt end;
 
 fun init_ctxt sups base_sort ctxt =
   ctxt
   |> Variable.declare_term
       (Logic.mk_type (TFree (Name.aT, base_sort)))
-  |> synchronize_syntax sups base_sort
-  |> Context.proof_map (
-      Syntax.add_term_check 0 "class" sort_term_check
-      #> Syntax.add_term_uncheck 0 "class" sort_term_uncheck)
+  |> synchronize_class_syntax sups base_sort
+  |> Overloading.add_improvable_syntax;
 
 fun init class thy =
   thy
@@ -710,46 +651,52 @@
   |> find_first (fn (_, (v', _)) => v = v')
   |> Option.map (fst o fst);
 
-fun confirm_declaration c = (map_instantiation o apsnd)
-  (filter_out (fn (_, (c', _)) => c' = c));
-
 
 (* syntax *)
 
-fun inst_term_check ts lthy =
+fun synchronize_inst_syntax ctxt =
   let
-    val params = instantiation_params lthy;
-    val tsig = ProofContext.tsig_of lthy;
-    val thy = ProofContext.theory_of lthy;
-
-    fun check_improve (Const (c, ty)) = (case AxClass.inst_tyco_of thy (c, ty)
-         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
-             of SOME (_, ty') => perhaps (try (Type.typ_match tsig (ty, ty')))
-              | NONE => I)
-          | NONE => I)
-      | check_improve _ = I;
-    val subst_param = map_aterms (fn t as Const (c, ty) =>
-        (case AxClass.inst_tyco_of thy (c, ty)
+    val Instantiation { arities = (_, _, sorts), params = params } = Instantiation.get ctxt;
+    val thy = ProofContext.theory_of ctxt;
+    val operations = these_operations thy sorts;
+    fun subst_class_typ sort = map_atyps
+      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
+    val local_constraints =
+      (map o apsnd) (subst_class_typ [] o fst o snd) operations;
+    val global_constraints = map_filter (fn (c, (class, (ty, _))) =>
+      if exists (fn ((c', _), _) => c = c') params
+        then SOME (c, subst_class_typ [class] ty)
+        else NONE) operations;
+    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
          of SOME tyco => (case AList.lookup (op =) params (c, tyco)
-             of SOME v_ty => Free v_ty
-              | NONE => t)
-          | NONE => t)
-      | t => t);
-
-    val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
-    val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
-    val ts'' = map subst_param ts';
-  in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end;
-
-fun inst_term_uncheck ts lthy =
-  let
-    val params = instantiation_params lthy;
-    val ts' = (map o map_aterms) (fn t as Free (v, ty) =>
-       (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params
-         of SOME c => Const (c, ty)
-          | NONE => t)
-      | t => t) ts;
-  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+             of SOME (_, ty') => SOME (ty, ty')
+              | NONE => NONE)
+          | NONE => NONE;
+          (*| NONE => (case map_filter
+               (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params
+             of [ty'] => (case Sign.const_typargs thy (c, ty)
+                 of [TVar (vi, _)] => if TypeInfer.is_param vi then SOME (ty, ty') else NONE
+                  | _ => NONE)
+              | _ => NONE*);
+    fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
+         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
+             of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty)
+              | NONE => NONE)
+          | NONE => NONE;
+    val unchecks =
+      map (fn ((c, _), v_ty as (_, ty)) => (Free v_ty, Const (c, ty))) params;
+  in
+    ctxt
+    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
+    |> Overloading.map_improvable_syntax (K {
+        local_constraints = local_constraints,
+        global_constraints = global_constraints,
+        improve = improve,
+        subst = subst,
+        unchecks = unchecks,
+        passed = false
+      })
+  end;
 
 
 (* target *)
@@ -786,12 +733,14 @@
     |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
     |> fold (Variable.declare_term o Logic.mk_type o TFree) vs
     |> fold (Variable.declare_names o Free o snd) params
-    |> fold (fn tyco => ProofContext.add_arity (tyco, map snd vs, sort)) tycos
-    |> Context.proof_map (
-        Syntax.add_term_check 0 "instance" inst_term_check
-        #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck)
+    |> synchronize_inst_syntax
+    |> Overloading.add_improvable_syntax
   end;
 
+fun confirm_declaration c = (map_instantiation o apsnd)
+  (filter_out (fn (_, (c', _)) => c' = c))
+  #> LocalTheory.target synchronize_inst_syntax
+
 fun gen_instantiation_instance do_proof after_qed lthy =
   let
     val (tycos, vs, sort) = (#arities o the_instantiation) lthy;
--- a/src/Pure/Isar/overloading.ML	Fri Mar 07 13:53:06 2008 +0100
+++ b/src/Pure/Isar/overloading.ML	Fri Mar 07 13:53:07 2008 +0100
@@ -14,6 +14,11 @@
   val define: bool -> string -> string * term -> theory -> thm * theory
   val operation: Proof.context -> string -> (string * bool) option
   val pretty: Proof.context -> Pretty.T
+  
+  type improvable_syntax
+  val add_improvable_syntax: Proof.context -> Proof.context
+  val map_improvable_syntax: (improvable_syntax -> improvable_syntax)
+    -> Proof.context -> Proof.context
 end;
 
 structure Overloading: OVERLOADING =
@@ -44,45 +49,101 @@
   Thm.add_def (not checked) true (name, Logic.mk_equals (Const (c, Term.fastype_of t), t));
 
 
-(* syntax *)
+(* generic check/uncheck combinators for improvable constants *)
+
+type improvable_syntax = {
+  local_constraints: (string * typ) list,
+  global_constraints: (string * typ) list,
+  improve: string * typ -> (typ * typ) option,
+  subst: string * typ -> (typ * term) option,
+  unchecks: (term * term) list,
+  passed: bool
+};
 
-fun subst_operation overloading = map_aterms (fn t as Const (c, ty) =>
-    (case AList.lookup (op =) overloading (c, ty)
-     of SOME (v, _) => Free (v, ty)
-      | NONE => t)
-  | t => t);
+structure ImprovableSyntax = ProofDataFun(
+  type T = improvable_syntax;
+  fun init _ = {
+    local_constraints = [],
+    global_constraints = [],
+    improve = K NONE,
+    subst = K NONE,
+    unchecks = [],
+    passed = true
+  };
+);
 
-fun term_check ts lthy =
+val map_improvable_syntax = ImprovableSyntax.map;
+
+val mark_passed = map_improvable_syntax
+  (fn { local_constraints, global_constraints, improve, subst, unchecks, passed } =>
+    { local_constraints = local_constraints, global_constraints = global_constraints,
+      improve = improve, subst = subst, unchecks = unchecks, passed = true });
+
+fun improve_term_check ts ctxt =
   let
-    val overloading = get_overloading lthy;
-    val ts' = map (subst_operation overloading) ts;
-  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+    val { local_constraints, global_constraints, improve, subst, passed, ... } =
+      ImprovableSyntax.get ctxt;
+    val tsig = (Sign.tsig_of o ProofContext.theory_of) ctxt;
+    fun accumulate_improvements (Const (c, ty)) = (case improve (c, ty)
+         of SOME ty_ty' => (perhaps o try o Type.typ_match tsig) ty_ty'
+          | _ => I)
+      | accumulate_improvements _ = I;
+    val improvements = (fold o fold_aterms) accumulate_improvements ts Vartab.empty;
+    val ts' = (map o map_types) (Envir.typ_subst_TVars improvements) ts;
+    fun apply_subst t = Envir.expand_term (fn Const (c, ty) => (case subst (c, ty)
+         of SOME (ty', t') =>   
+              if Type.typ_instance tsig (ty, ty')
+              then SOME (ty', apply_subst t') else NONE
+          | NONE => NONE)
+      | _ => NONE) t;
+    val ts'' = map apply_subst ts';
+  in if eq_list (op aconv) (ts, ts'') andalso passed then NONE else
+    if passed then SOME (ts'', ctxt)
+    else SOME (ts'', ctxt
+      |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
+      |> mark_passed)
+  end;
 
-fun term_uncheck ts lthy =
+fun improve_term_uncheck ts ctxt =
   let
-    val overloading = get_overloading lthy;
-    fun subst (t as Free (v, ty)) = (case get_first (fn ((c, _), (v', _)) =>
-        if v = v' then SOME c else NONE) overloading
-         of SOME c => Const (c, ty)
-          | NONE => t)
-      | subst t = t;
-    val ts' = (map o map_aterms) subst ts;
-  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
+    val thy = ProofContext.theory_of ctxt;
+    val unchecks = (#unchecks o ImprovableSyntax.get) ctxt;
+    val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
+  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
+
+fun add_improvable_syntax ctxt = ctxt
+  |> Context.proof_map
+    (Syntax.add_term_check 0 "improvement" improve_term_check
+    #> Syntax.add_term_uncheck 0 "improvement" improve_term_uncheck)
+  |> fold (ProofContext.add_const_constraint o apsnd SOME)
+       ((#local_constraints o ImprovableSyntax.get) ctxt);
 
 
 (* target *)
 
-fun init overloading thy =
+fun init raw_overloading thy =
   let
-    val _ = if null overloading then error "At least one parameter must be given" else ();
+    val _ = if null raw_overloading then error "At least one parameter must be given" else ();
+    val overloading = map (fn (v, c_ty, checked) => (c_ty, (v, checked))) raw_overloading;
+    fun subst (c, ty) = case AList.lookup (op =) overloading (c, ty)
+     of SOME (v, _) => SOME (ty, Free (v, ty))
+      | NONE => NONE;
+    val unchecks =
+      map (fn (c_ty as (_, ty), (v, _)) => (Free (v, ty), Const c_ty)) overloading;
   in
     thy
     |> ProofContext.init
-    |> OverloadingData.put (map (fn (v, c_ty, checked) => (c_ty, (v, checked))) overloading)
-    |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) overloading
-    |> Context.proof_map (
-        Syntax.add_term_check 0 "overloading" term_check
-        #> Syntax.add_term_uncheck 0 "overloading" term_uncheck)
+    |> OverloadingData.put overloading
+    |> fold (fn (v, (_, ty), _) => Variable.declare_term (Free (v, ty))) raw_overloading
+    |> map_improvable_syntax (K {
+        local_constraints = [],
+        global_constraints = [],
+        improve = K NONE,
+        subst = subst,
+        unchecks = unchecks,
+        passed = false
+      })
+    |> add_improvable_syntax
   end;
 
 fun conclude lthy =