src/Pure/Isar/class.ML
changeset 24901 d3cbf79769b9
parent 24847 bc15dcaed517
child 24914 95cda5dd58d5
     1.1 --- a/src/Pure/Isar/class.ML	Mon Oct 08 08:04:26 2007 +0200
     1.2 +++ b/src/Pure/Isar/class.ML	Mon Oct 08 08:04:28 2007 +0200
     1.3 @@ -45,13 +45,10 @@
     1.4    val param_const: theory -> string -> (string * string) option
     1.5    val params_of_sort: theory -> sort -> (string * (string * typ)) list
     1.6  
     1.7 -  (*experimental*)
     1.8 -  val init_ref: (sort -> Proof.context -> Proof.context) ref
     1.9 -  val init: sort -> Proof.context -> Proof.context;
    1.10 -  val init_exp: sort -> Proof.context -> Proof.context;
    1.11 +  val init: class -> Proof.context -> Proof.context;
    1.12    val local_syntax: theory -> class -> bool
    1.13 -  val add_abbrev_in_class: string -> (string * term) * Syntax.mixfix
    1.14 -    -> theory -> term * theory
    1.15 +  val add_abbrev_in_class: string -> Syntax.mode -> (string * term) * mixfix
    1.16 +    -> theory -> theory
    1.17  end;
    1.18  
    1.19  structure Class : CLASS =
    1.20 @@ -62,8 +59,8 @@
    1.21  fun fork_mixfix is_loc some_class mx =
    1.22    let
    1.23      val mx' = Syntax.unlocalize_mixfix mx;
    1.24 -    val mx_global = if is_some some_class orelse (is_loc andalso mx = mx')
    1.25 -      then NoSyn else mx';
    1.26 +    val mx_global = if not is_loc orelse (is_some some_class andalso not (mx = mx'))
    1.27 +      then mx' else NoSyn;
    1.28      val mx_local = if is_loc then mx else NoSyn;
    1.29    in (mx_global, mx_local) end;
    1.30  
    1.31 @@ -398,6 +395,15 @@
    1.32  
    1.33  fun local_operation thy = Option.join oo AList.lookup (op =) o these_operations thy;
    1.34  
    1.35 +fun sups_local_sort thy sort =
    1.36 +  let
    1.37 +    val sups = filter (is_some o lookup_class_data thy) sort
    1.38 +      |> Sign.minimize_sort thy;
    1.39 +    val local_sort = case sups
    1.40 +     of sup :: _ => #local_sort (the_class_data thy sup)
    1.41 +      | [] => sort;
    1.42 +  in (sups, local_sort) end;
    1.43 +
    1.44  fun local_syntax thy = #local_syntax o the_class_data thy;
    1.45  
    1.46  fun print_classes thy =
    1.47 @@ -445,7 +451,7 @@
    1.48      |> Symtab.update (locale, class)
    1.49    ));
    1.50  
    1.51 -fun register_const (class, (entry, def)) =
    1.52 +fun register_const class (entry, def) =
    1.53    (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
    1.54      (fn (defs, operations) => (def :: defs, apsnd SOME entry :: operations));
    1.55  
    1.56 @@ -545,18 +551,20 @@
    1.57  
    1.58  (** classes and class target **)
    1.59  
    1.60 -(* class context initialization - experimental *)
    1.61 +(* class context initialization *)
    1.62  
    1.63 -fun get_remove_constraints sort ctxt =
    1.64 +fun get_remove_constraints sups local_sort ctxt =
    1.65    let
    1.66 -    val operations = these_operations (ProofContext.theory_of ctxt) sort;
    1.67 +    val thy = ProofContext.theory_of ctxt;
    1.68 +    val operations = these_operations thy sups;
    1.69      fun get_remove (c, _) ctxt =
    1.70        let
    1.71          val ty = ProofContext.the_const_constraint ctxt c;
    1.72 -        val _ = tracing c;
    1.73 +        val ty' = map_atyps (fn ty as TFree (v, _) => if v = Name.aT
    1.74 +          then TFree (v, local_sort) else ty | ty => ty) ty;
    1.75        in
    1.76          ctxt
    1.77 -        |> ProofContext.add_const_constraint (c, NONE)
    1.78 +        |> ProofContext.add_const_constraint (c, SOME ty')
    1.79          |> pair (c, ty)
    1.80        end;
    1.81    in
    1.82 @@ -564,51 +572,43 @@
    1.83      |> fold_map get_remove operations
    1.84    end;
    1.85  
    1.86 -fun sort_term_check thy sort constraints =
    1.87 +fun sort_term_check thy sups local_sort constraints ts ctxt =
    1.88    let
    1.89 -    val local_operation = local_operation thy sort;
    1.90 -    fun default_typ consts c = case AList.lookup (op =) constraints c
    1.91 -     of SOME ty => SOME ty
    1.92 -      | NONE => try (Consts.the_constraint consts) c;
    1.93 -    fun infer_constraints ctxt ts =
    1.94 -        TypeInfer.infer_types (ProofContext.pp ctxt)
    1.95 -          (Sign.tsig_of (ProofContext.theory_of ctxt))
    1.96 -          I (default_typ (ProofContext.consts_of ctxt)) (K NONE)
    1.97 -          (Variable.names_of ctxt) (Variable.maxidx_of ctxt) NONE (map (rpair dummyT) ts)
    1.98 -        |> fst |> map fst
    1.99 -      handle TYPE (msg, _, _) => error msg;
   1.100 -    fun check_typ c idx ty = case (nth (Sign.const_typargs thy (c, ty)) idx) (*FIXME localize*)
   1.101 -     of TFree (v, _) => v = Name.aT
   1.102 -      | TVar (vi, _) => TypeInfer.is_param vi (*FIXME substitute in all typs*)
   1.103 -      | _ => false;
   1.104 -    fun subst_operation (t as Const (c, ty)) = (case local_operation c
   1.105 -         of SOME (t', idx) => if check_typ c idx ty then t' else t
   1.106 -          | NONE => t)
   1.107 -      | subst_operation t = t;
   1.108 -    fun subst_operations ts ctxt =
   1.109 -      ts
   1.110 -      |> (map o map_aterms) subst_operation
   1.111 -      |> infer_constraints ctxt
   1.112 -      |> rpair ctxt; (*FIXME add constraints here*)
   1.113 -  in subst_operations end;
   1.114 +    val local_operation = local_operation thy sups;
   1.115 +    val consts = ProofContext.consts_of ctxt;
   1.116 +    fun check_typ (c, ty) (t', idx) = case nth (Consts.typargs consts (c, ty)) idx
   1.117 +     of TFree (v, _) => if v = Name.aT
   1.118 +          then apfst (AList.update (op =) ((c, ty), t')) else I
   1.119 +      | TVar (vi, _) => if TypeInfer.is_param vi
   1.120 +          then apfst (AList.update (op =) ((c, ty), t'))
   1.121 +            #> apsnd (insert (op =) vi) else I
   1.122 +      | _ => I;
   1.123 +    fun add_const (Const (c, ty)) = (case local_operation c
   1.124 +         of SOME (t', idx) => check_typ (c, ty) (t', idx)
   1.125 +          | NONE => I)
   1.126 +      | add_const _ = I;
   1.127 +    val (cs, typarams) = (fold o fold_aterms) add_const ts ([], []);
   1.128 +    val ts' = map (map_aterms
   1.129 +        (fn t as Const (c, ty) => the_default t (AList.lookup (op =) cs (c, ty)) | t => t)
   1.130 +      #> (map_types o map_type_tvar) (fn var as (vi, _) => if member (op =) typarams vi
   1.131 +           then TFree (Name.aT, local_sort) else TVar var)) ts;
   1.132 +    val ctxt' = fold (ProofContext.add_const_constraint o apsnd SOME) constraints ctxt;
   1.133 +  in (ts', ctxt') end;
   1.134  
   1.135 -fun init_exp sort ctxt =
   1.136 +fun init_class_ctxt thy sups local_sort ctxt =
   1.137 +  ctxt
   1.138 +  |> Variable.declare_term
   1.139 +      (Logic.mk_type (TFree (Name.aT, local_sort)))
   1.140 +  |> get_remove_constraints sups local_sort
   1.141 +  |-> (fn constraints => Context.proof_map (Syntax.add_term_check 50 "class"
   1.142 +        (sort_term_check thy sups local_sort constraints)));
   1.143 +
   1.144 +fun init class ctxt =
   1.145    let
   1.146      val thy = ProofContext.theory_of ctxt;
   1.147 -    val local_sort = (#local_sort o the_class_data thy) (hd sort);
   1.148 -    val term_check = sort_term_check thy sort;
   1.149 -  in
   1.150 -    ctxt
   1.151 -    |> Variable.declare_term
   1.152 -        (Logic.mk_type (TFree (Name.aT, local_sort)))
   1.153 -    |> get_remove_constraints sort
   1.154 -    |-> (fn constraints => Context.proof_map (Syntax.add_term_check 50 "class"
   1.155 -          (sort_term_check thy sort constraints)))
   1.156 -  end;
   1.157 -
   1.158 -val init_ref = ref (K I : sort -> Proof.context -> Proof.context);
   1.159 -fun init class = ! init_ref class;
   1.160 -
   1.161 +    val local_sort = (#local_sort o the_class_data thy) class;
   1.162 +  in init_class_ctxt thy [class] local_sort ctxt end;
   1.163 +  
   1.164  
   1.165  (* class definition *)
   1.166  
   1.167 @@ -625,12 +625,8 @@
   1.168  fun gen_class_spec prep_class prep_expr process_expr thy raw_supclasses raw_includes_elems =
   1.169    let
   1.170      val supclasses = map (prep_class thy) raw_supclasses;
   1.171 -    val sups = filter (is_some o lookup_class_data thy) supclasses
   1.172 -      |> Sign.minimize_sort thy;
   1.173 +    val (sups, local_sort) = sups_local_sort thy supclasses;
   1.174      val supsort = Sign.minimize_sort thy supclasses;
   1.175 -    val local_sort = case sups
   1.176 -     of sup :: _ => (#local_sort o the_class_data thy) sup
   1.177 -      | [] => supsort;
   1.178      val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
   1.179      val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
   1.180        | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
   1.181 @@ -645,7 +641,7 @@
   1.182      ProofContext.init thy
   1.183      |> Locale.cert_expr supexpr [constrain]
   1.184      |> snd
   1.185 -    |> init supsort
   1.186 +    |> init_class_ctxt thy sups local_sort
   1.187      |> process_expr Locale.empty raw_elems
   1.188      |> fst
   1.189      |> (fn elems => ((((sups, supconsts), (supsort, local_sort, mergeexpr)),
   1.190 @@ -670,20 +666,16 @@
   1.191      fun extract_params thy name_locale =
   1.192        let
   1.193          val params = Locale.parameters_of thy name_locale;
   1.194 -        val local_sort = case AList.group (op =) ((maps typ_tfrees o map (snd o fst)) params)
   1.195 -         of [(_, local_sort :: _)] => local_sort
   1.196 -          | _ => Sign.defaultS thy
   1.197 -          | vs => error ("exactly one type variable required: " ^ commas (map fst vs));
   1.198          val _ = if Sign.subsort thy (supsort, local_sort) then () else error
   1.199            ("Sort " ^ Sign.string_of_sort thy local_sort
   1.200              ^ " is less general than permitted least general sort "
   1.201              ^ Sign.string_of_sort thy supsort);
   1.202        in
   1.203 -        (local_sort, (map fst params, params
   1.204 +        (map fst params, params
   1.205          |> (map o apfst o apsnd o Term.map_type_tfree) (K (TFree (Name.aT, local_sort)))
   1.206 -        |> (map o apsnd) (fork_mixfix true NONE #> fst)
   1.207 +        |> (map o apsnd) (fork_mixfix true (SOME "") #> fst)
   1.208          |> chop (length supconsts)
   1.209 -        |> snd))
   1.210 +        |> snd)
   1.211        end;
   1.212      fun extract_assumes name_locale params thy cs =
   1.213        let
   1.214 @@ -707,7 +699,7 @@
   1.215      |> Locale.add_locale_i (SOME "") bname mergeexpr elems
   1.216      |-> (fn name_locale => ProofContext.theory_result (
   1.217        `(fn thy => extract_params thy name_locale)
   1.218 -      #-> (fn (_, (globals, params)) =>
   1.219 +      #-> (fn (globals, params) =>
   1.220          AxClass.define_class_params (bname, supsort) params
   1.221            (extract_assumes name_locale params) other_consts
   1.222        #-> (fn (name_axclass, (consts, axioms)) =>
   1.223 @@ -742,6 +734,12 @@
   1.224        | subst_aterm t = t;
   1.225    in Term.map_aterms subst_aterm end;
   1.226  
   1.227 +fun mk_operation_entry thy (c, rhs) =
   1.228 +  let
   1.229 +    val typargs = Sign.const_typargs thy (c, fastype_of rhs);
   1.230 +    val typidx = find_index (fn TFree (w, _) => Name.aT = w | _ => false) typargs;
   1.231 +  in (c, (rhs, typidx)) end;
   1.232 +
   1.233  fun add_const_in_class class ((c, rhs), syn) thy =
   1.234    let
   1.235      val prfx = (Logic.const_of_class o NameSpace.base) class;
   1.236 @@ -759,17 +757,16 @@
   1.237      val ty'' = subst_typ ty';
   1.238      val c' = mk_name c;
   1.239      val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
   1.240 -    val (syn', _) = fork_mixfix true NONE syn;
   1.241 +    val (syn', _) = fork_mixfix true (SOME class) syn;
   1.242      fun interpret def thy =
   1.243        let
   1.244          val def' = symmetric def;
   1.245          val def_eq = Thm.prop_of def';
   1.246 -        val typargs = Sign.const_typargs thy (c', fastype_of rhs);
   1.247 -        val typidx = find_index (fn TFree (w, _) => Name.aT = w | _ => false) typargs;
   1.248 +        val entry = mk_operation_entry thy (c', rhs);
   1.249        in
   1.250          thy
   1.251          |> class_interpretation class [def'] [def_eq]
   1.252 -        |> register_const (class, ((c', (rhs, typidx)), def'))
   1.253 +        |> register_const class (entry, def')
   1.254        end;
   1.255    in
   1.256      thy
   1.257 @@ -783,21 +780,25 @@
   1.258      |> Sign.restore_naming thy
   1.259    end;
   1.260  
   1.261 -fun add_abbrev_in_class class ((c, rhs), syn) thy =
   1.262 +
   1.263 +(* abbreviation in class target *)
   1.264 +
   1.265 +fun add_abbrev_in_class class prmode ((c, rhs), syn) thy =
   1.266    let
   1.267 -    val local_sort = (#local_sort o the_class_data thy) class;
   1.268 -    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
   1.269 -      if w = Name.aT then TFree (w, local_sort) else TFree var);
   1.270 -    val ty = fastype_of rhs;
   1.271 -    val rhs' = map_types subst_typ rhs;
   1.272 +    val prfx = (Logic.const_of_class o NameSpace.base) class;
   1.273 +    fun mk_name c =
   1.274 +      let
   1.275 +        val n1 = Sign.full_name thy c;
   1.276 +        val n2 = NameSpace.qualifier n1;
   1.277 +        val n3 = NameSpace.base n1;
   1.278 +      in NameSpace.implode [n2, prfx, prfx, n3] end;
   1.279 +    val c' = mk_name c;
   1.280 +    val rhs' = export_fixes thy class rhs;
   1.281 +    val ty' = fastype_of rhs';
   1.282    in
   1.283      thy
   1.284 -    |> Sign.parent_path (*FIXME*)
   1.285 -    |> Sign.add_abbrev Syntax.internalM [] (c, rhs)
   1.286 -    |-> (fn (lhs as Const (c', _), _) => register_abbrev class c'
   1.287 -      (*#> Sign.add_const_constraint (c', SOME ty)*)
   1.288 -      #> pair lhs)
   1.289 -    ||> Sign.restore_naming thy
   1.290 +    |> Sign.add_notation prmode [(Const (c', ty'), syn)]
   1.291 +    |> register_abbrev class c'
   1.292    end;
   1.293  
   1.294