src/Pure/Isar/class.ML
changeset 26238 c30bb8182da2
parent 26167 ccc9007a7164
child 26247 b6608fbeaae1
     1.1 --- a/src/Pure/Isar/class.ML	Fri Mar 07 13:53:06 2008 +0100
     1.2 +++ b/src/Pure/Isar/class.ML	Fri Mar 07 13:53:07 2008 +0100
     1.3 @@ -70,15 +70,6 @@
     1.4        (Method.Basic (K (Method.SIMPLE_METHOD tac), Position.none), NONE)
     1.5    #> ProofContext.theory_of;
     1.6  
     1.7 -fun get_remove_global_constraint c thy =
     1.8 -  let
     1.9 -    val ty = Sign.the_const_constraint thy c;
    1.10 -  in
    1.11 -    thy
    1.12 -    |> Sign.add_const_constraint (c, NONE)
    1.13 -    |> pair (c, Logic.unvarifyT ty)
    1.14 -  end;
    1.15 -
    1.16  
    1.17  (** primitive axclass and instance commands **)
    1.18  
    1.19 @@ -345,15 +336,20 @@
    1.20  fun class_interpretation class facts defs thy =
    1.21    let
    1.22      val params = these_params thy [class];
    1.23 +    val consts = map (fst o snd) params;
    1.24 +    val constraints = map (fn c => map_atyps (K (TFree (Name.aT,
    1.25 +      [the (AxClass.class_of_param thy c)]))) (Sign.the_const_type thy c)) consts;
    1.26 +    val no_constraints = map (map_atyps (K (TFree (Name.aT, [])))) constraints;
    1.27 +    fun add_constraint c T = Sign.add_const_constraint (c, SOME T);
    1.28      val inst = (#inst o the_class_data thy) class;
    1.29      val tac = ALLGOALS (ProofContext.fact_tac facts);
    1.30      val prfx = class_prefix class;
    1.31    in
    1.32      thy
    1.33 -    |> fold_map (get_remove_global_constraint o fst o snd) params
    1.34 -    ||> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
    1.35 +    |> fold2 add_constraint consts no_constraints
    1.36 +    |> prove_interpretation tac ((false, prfx), []) (Locale.Locale class)
    1.37            (inst, map (fn def => (("", []), def)) defs)
    1.38 -    |-> (fn cs => fold (Sign.add_const_constraint o apsnd SOME) cs)
    1.39 +    |> fold2 add_constraint consts constraints
    1.40    end;
    1.41  
    1.42  fun prove_subclass (sub, sup) thm thy =
    1.43 @@ -398,26 +394,7 @@
    1.44  
    1.45  (* class context syntax *)
    1.46  
    1.47 -structure ClassSyntax = ProofDataFun(
    1.48 -  type T = {
    1.49 -    local_constraints: (string * typ) list,
    1.50 -    global_constraints: (string * typ) list,
    1.51 -    base_sort: sort,
    1.52 -    operations: (string * (typ * term)) list,
    1.53 -    unchecks: (term * term) list,
    1.54 -    passed: bool
    1.55 -  };
    1.56 -  fun init _ = {
    1.57 -    local_constraints = [],
    1.58 -    global_constraints = [],
    1.59 -    base_sort = [],
    1.60 -    operations = [],
    1.61 -    unchecks = [],
    1.62 -    passed = true
    1.63 -  };;
    1.64 -);
    1.65 -
    1.66 -fun synchronize_syntax sups base_sort ctxt =
    1.67 +fun synchronize_class_syntax sups base_sort ctxt =
    1.68    let
    1.69      val thy = ProofContext.theory_of ctxt;
    1.70      fun subst_class_typ sort = map_atyps
    1.71 @@ -430,80 +407,44 @@
    1.72      fun declare_const (c, _) =
    1.73        let val b = Sign.base_name c
    1.74        in Sign.intern_const thy b = c ? Variable.declare_const (b, c) end;
    1.75 +    fun improve (c, ty) = (case AList.lookup (op =) local_constraints c
    1.76 +     of SOME ty' => (case try (Type.raw_match (ty', ty)) Vartab.empty
    1.77 +         of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
    1.78 +             of SOME (_, ty' as TVar (tvar as (vi, _))) =>
    1.79 +                  if TypeInfer.is_param vi
    1.80 +                    then SOME (ty', TFree (Name.aT, base_sort))
    1.81 +                    else NONE
    1.82 +              | _ => NONE)
    1.83 +          | NONE => NONE)
    1.84 +      | NONE => NONE)
    1.85 +    fun subst (c, ty) = Option.map snd (AList.lookup (op =) operations c);
    1.86      val unchecks = map (fn (c, (_, (ty, t))) => (t, Const (c, ty))) operations;
    1.87    in
    1.88      ctxt
    1.89      |> fold declare_const local_constraints
    1.90      |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
    1.91 -    |> ClassSyntax.put {
    1.92 +    |> Overloading.map_improvable_syntax (K {
    1.93          local_constraints = local_constraints,
    1.94          global_constraints = global_constraints,
    1.95 -        base_sort = base_sort,
    1.96 -        operations = (map o apsnd) snd operations,
    1.97 +        improve = improve,
    1.98 +        subst = subst,
    1.99          unchecks = unchecks,
   1.100          passed = false
   1.101 -      }
   1.102 +      })
   1.103    end;
   1.104  
   1.105  fun refresh_syntax class ctxt =
   1.106    let
   1.107      val thy = ProofContext.theory_of ctxt;
   1.108      val base_sort = (#base_sort o the_class_data thy) class;
   1.109 -  in synchronize_syntax [class] base_sort ctxt end;
   1.110 -
   1.111 -val mark_passed = ClassSyntax.map
   1.112 -  (fn { local_constraints, global_constraints, base_sort, operations, unchecks, passed } =>
   1.113 -    { local_constraints = local_constraints, global_constraints = global_constraints,
   1.114 -      base_sort = base_sort, operations = operations, unchecks = unchecks, passed = true });
   1.115 -
   1.116 -fun sort_term_check ts ctxt =
   1.117 -  let
   1.118 -    val { local_constraints, global_constraints, base_sort, operations, passed, ... } =
   1.119 -      ClassSyntax.get ctxt;
   1.120 -    fun check_improve (Const (c, ty)) = (case AList.lookup (op =) local_constraints c
   1.121 -         of SOME ty0 => (case try (Type.raw_match (ty0, ty)) Vartab.empty
   1.122 -             of SOME tyenv => (case Vartab.lookup tyenv (Name.aT, 0)
   1.123 -                 of SOME (_, TVar (tvar as (vi, _))) =>
   1.124 -                      if TypeInfer.is_param vi then cons tvar else I
   1.125 -                  | _ => I)
   1.126 -              | NONE => I)
   1.127 -          | NONE => I)
   1.128 -      | check_improve _ = I;
   1.129 -    val improvements = (fold o fold_aterms) check_improve ts [];
   1.130 -    val ts' = (map o map_types o map_atyps) (fn ty as TVar tvar =>
   1.131 -        if member (op =) improvements tvar
   1.132 -          then TFree (Name.aT, base_sort) else ty | ty => ty) ts;
   1.133 -    fun check t0 = Envir.expand_term (fn Const (c, ty) => (case AList.lookup (op =) operations c
   1.134 -         of SOME (ty0, t) =>
   1.135 -              if Type.typ_instance (ProofContext.tsig_of ctxt) (ty, ty0)
   1.136 -              then SOME (ty0, check t) else NONE
   1.137 -          | NONE => NONE)
   1.138 -      | _ => NONE) t0;
   1.139 -    val ts'' = map check ts';
   1.140 -  in if eq_list (op aconv) (ts, ts'') andalso passed then NONE
   1.141 -  else
   1.142 -    ctxt
   1.143 -    |> fold (ProofContext.add_const_constraint o apsnd SOME) global_constraints
   1.144 -    |> mark_passed
   1.145 -    |> pair ts''
   1.146 -    |> SOME
   1.147 -  end;
   1.148 -
   1.149 -fun sort_term_uncheck ts ctxt =
   1.150 -  let
   1.151 -    val thy = ProofContext.theory_of ctxt;
   1.152 -    val unchecks = (#unchecks o ClassSyntax.get) ctxt;
   1.153 -    val ts' = map (Pattern.rewrite_term thy unchecks []) ts;
   1.154 -  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', ctxt) end;
   1.155 +  in synchronize_class_syntax [class] base_sort ctxt end;
   1.156  
   1.157  fun init_ctxt sups base_sort ctxt =
   1.158    ctxt
   1.159    |> Variable.declare_term
   1.160        (Logic.mk_type (TFree (Name.aT, base_sort)))
   1.161 -  |> synchronize_syntax sups base_sort
   1.162 -  |> Context.proof_map (
   1.163 -      Syntax.add_term_check 0 "class" sort_term_check
   1.164 -      #> Syntax.add_term_uncheck 0 "class" sort_term_uncheck)
   1.165 +  |> synchronize_class_syntax sups base_sort
   1.166 +  |> Overloading.add_improvable_syntax;
   1.167  
   1.168  fun init class thy =
   1.169    thy
   1.170 @@ -710,46 +651,52 @@
   1.171    |> find_first (fn (_, (v', _)) => v = v')
   1.172    |> Option.map (fst o fst);
   1.173  
   1.174 -fun confirm_declaration c = (map_instantiation o apsnd)
   1.175 -  (filter_out (fn (_, (c', _)) => c' = c));
   1.176 -
   1.177  
   1.178  (* syntax *)
   1.179  
   1.180 -fun inst_term_check ts lthy =
   1.181 +fun synchronize_inst_syntax ctxt =
   1.182    let
   1.183 -    val params = instantiation_params lthy;
   1.184 -    val tsig = ProofContext.tsig_of lthy;
   1.185 -    val thy = ProofContext.theory_of lthy;
   1.186 -
   1.187 -    fun check_improve (Const (c, ty)) = (case AxClass.inst_tyco_of thy (c, ty)
   1.188 -         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
   1.189 -             of SOME (_, ty') => perhaps (try (Type.typ_match tsig (ty, ty')))
   1.190 -              | NONE => I)
   1.191 -          | NONE => I)
   1.192 -      | check_improve _ = I;
   1.193 -    val subst_param = map_aterms (fn t as Const (c, ty) =>
   1.194 -        (case AxClass.inst_tyco_of thy (c, ty)
   1.195 +    val Instantiation { arities = (_, _, sorts), params = params } = Instantiation.get ctxt;
   1.196 +    val thy = ProofContext.theory_of ctxt;
   1.197 +    val operations = these_operations thy sorts;
   1.198 +    fun subst_class_typ sort = map_atyps
   1.199 +      (fn TFree _ => TVar ((Name.aT, 0), sort) | ty' => ty');
   1.200 +    val local_constraints =
   1.201 +      (map o apsnd) (subst_class_typ [] o fst o snd) operations;
   1.202 +    val global_constraints = map_filter (fn (c, (class, (ty, _))) =>
   1.203 +      if exists (fn ((c', _), _) => c = c') params
   1.204 +        then SOME (c, subst_class_typ [class] ty)
   1.205 +        else NONE) operations;
   1.206 +    fun improve (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
   1.207           of SOME tyco => (case AList.lookup (op =) params (c, tyco)
   1.208 -             of SOME v_ty => Free v_ty
   1.209 -              | NONE => t)
   1.210 -          | NONE => t)
   1.211 -      | t => t);
   1.212 -
   1.213 -    val improvement = (fold o fold_aterms) check_improve ts Vartab.empty;
   1.214 -    val ts' = (map o map_types) (Envir.typ_subst_TVars improvement) ts;
   1.215 -    val ts'' = map subst_param ts';
   1.216 -  in if eq_list (op aconv) (ts, ts'') then NONE else SOME (ts'', lthy) end;
   1.217 -
   1.218 -fun inst_term_uncheck ts lthy =
   1.219 -  let
   1.220 -    val params = instantiation_params lthy;
   1.221 -    val ts' = (map o map_aterms) (fn t as Free (v, ty) =>
   1.222 -       (case get_first (fn ((c, _), (v', _)) => if v = v' then SOME c else NONE) params
   1.223 -         of SOME c => Const (c, ty)
   1.224 -          | NONE => t)
   1.225 -      | t => t) ts;
   1.226 -  in if eq_list (op aconv) (ts, ts') then NONE else SOME (ts', lthy) end;
   1.227 +             of SOME (_, ty') => SOME (ty, ty')
   1.228 +              | NONE => NONE)
   1.229 +          | NONE => NONE;
   1.230 +          (*| NONE => (case map_filter
   1.231 +               (fn ((c', _), (_, ty')) => if c' = c then SOME ty' else NONE) params
   1.232 +             of [ty'] => (case Sign.const_typargs thy (c, ty)
   1.233 +                 of [TVar (vi, _)] => if TypeInfer.is_param vi then SOME (ty, ty') else NONE
   1.234 +                  | _ => NONE)
   1.235 +              | _ => NONE*);
   1.236 +    fun subst (c, ty) = case AxClass.inst_tyco_of thy (c, ty)
   1.237 +         of SOME tyco => (case AList.lookup (op =) params (c, tyco)
   1.238 +             of SOME (v_ty as (_, ty)) => SOME (ty, Free v_ty)
   1.239 +              | NONE => NONE)
   1.240 +          | NONE => NONE;
   1.241 +    val unchecks =
   1.242 +      map (fn ((c, _), v_ty as (_, ty)) => (Free v_ty, Const (c, ty))) params;
   1.243 +  in
   1.244 +    ctxt
   1.245 +    |> fold (ProofContext.add_const_constraint o apsnd SOME) local_constraints
   1.246 +    |> Overloading.map_improvable_syntax (K {
   1.247 +        local_constraints = local_constraints,
   1.248 +        global_constraints = global_constraints,
   1.249 +        improve = improve,
   1.250 +        subst = subst,
   1.251 +        unchecks = unchecks,
   1.252 +        passed = false
   1.253 +      })
   1.254 +  end;
   1.255  
   1.256  
   1.257  (* target *)
   1.258 @@ -786,12 +733,14 @@
   1.259      |> Instantiation.put (mk_instantiation ((tycos, vs, sort), params))
   1.260      |> fold (Variable.declare_term o Logic.mk_type o TFree) vs
   1.261      |> fold (Variable.declare_names o Free o snd) params
   1.262 -    |> fold (fn tyco => ProofContext.add_arity (tyco, map snd vs, sort)) tycos
   1.263 -    |> Context.proof_map (
   1.264 -        Syntax.add_term_check 0 "instance" inst_term_check
   1.265 -        #> Syntax.add_term_uncheck 0 "instance" inst_term_uncheck)
   1.266 +    |> synchronize_inst_syntax
   1.267 +    |> Overloading.add_improvable_syntax
   1.268    end;
   1.269  
   1.270 +fun confirm_declaration c = (map_instantiation o apsnd)
   1.271 +  (filter_out (fn (_, (c', _)) => c' = c))
   1.272 +  #> LocalTheory.target synchronize_inst_syntax
   1.273 +
   1.274  fun gen_instantiation_instance do_proof after_qed lthy =
   1.275    let
   1.276      val (tycos, vs, sort) = (#arities o the_instantiation) lthy;