intermediate cleanup
authorhaftmann
Thu, 04 Oct 2007 19:41:50 +0200
changeset 24836 dab06e93ec28
parent 24835 8c26128f8997
child 24837 cacc5744be75
intermediate cleanup
src/Pure/Isar/class.ML
--- a/src/Pure/Isar/class.ML	Thu Oct 04 19:41:49 2007 +0200
+++ b/src/Pure/Isar/class.ML	Thu Oct 04 19:41:50 2007 +0200
@@ -25,7 +25,6 @@
   val intro_classes_tac: thm list -> tactic
   val default_intro_classes_tac: thm list -> tactic
   val class_of_locale: theory -> string -> class option
-  val local_syntax: theory -> class -> bool
   val print_classes: theory -> unit
 
   val instance_arity: (theory -> theory) -> arity list -> theory -> Proof.state
@@ -49,10 +48,10 @@
   (*experimental*)
   val init_ref: (sort -> Proof.context -> Proof.context) ref
   val init: sort -> Proof.context -> Proof.context;
-  val init_default: sort -> Proof.context -> Proof.context;
-  val local_param: theory -> sort -> string -> (term * (class * int)) option
-  val remove_constraints': sort -> theory -> (string * typ) list * theory
-  val remove_constraints: sort -> Proof.context -> (string * typ) list * Proof.context
+  val init_exp: sort -> Proof.context -> Proof.context;
+  val local_syntax: theory -> class -> bool
+  val add_abbrev_in_class: string -> (string * term) * Syntax.mixfix
+    -> theory -> term * theory
 end;
 
 structure Class : CLASS =
@@ -100,15 +99,6 @@
       | NONE => thm;
   in strip end;
 
-fun get_remove_constraint c thy =
-  let
-    val ty = Sign.the_const_constraint thy c;
-  in
-    thy
-    |> Sign.add_const_constraint (c, NONE)
-    |> pair (c, Logic.unvarifyT ty)
-  end;
-
 
 (** axclass command **)
 
@@ -256,6 +246,14 @@
      of [] => ()
       | dupl_tycos => error ("type constructors occur more than once in arities: "
           ^ (commas o map quote) dupl_tycos);
+    fun get_remove_constraint c thy =
+      let
+        val ty = Sign.the_const_constraint thy c;
+      in
+        thy
+        |> Sign.add_const_constraint (c, NONE)
+        |> pair (c, Logic.unvarifyT ty)
+      end;
     fun get_consts_class tyco ty class =
       let
         val cs = (these o Option.map snd o try (AxClass.params_of_class theory)) class;
@@ -329,29 +327,29 @@
 datatype class_data = ClassData of {
   locale: string,
   consts: (string * string) list
-    (*locale parameter ~> theory constant name*),
+    (*locale parameter ~> constant name*),
   local_sort: sort,
   inst: typ Symtab.table * term Symtab.table
     (*canonical interpretation*),
   intro: thm,
   local_syntax: bool,
   defs: thm list,
-  localized: (string * (term * (class * int))) list
-    (*theory constant name ~> (locale parameter, (class, instantiaton index of class typ))*)
+  operations: (string * (term * int) option) list
+    (*constant name ~> (locale term, instantiaton index of class typ)*)
 };
 
 fun rep_class_data (ClassData d) = d;
-fun mk_class_data ((locale, consts, local_sort, inst, intro, local_syntax), (defs, localized)) =
+fun mk_class_data ((locale, consts, local_sort, inst, intro, local_syntax), (defs, operations)) =
   ClassData { locale = locale, consts = consts, local_sort = local_sort, inst = inst, intro = intro,
-    local_syntax = local_syntax, defs = defs, localized = localized };
-fun map_class_data f (ClassData { locale, consts, local_sort, inst, intro, local_syntax, defs, localized }) =
-  mk_class_data (f ((locale, consts, local_sort, inst, intro, local_syntax), (defs, localized)))
+    local_syntax = local_syntax, defs = defs, operations = operations };
+fun map_class_data f (ClassData { locale, consts, local_sort, inst, intro, local_syntax, defs, operations }) =
+  mk_class_data (f ((locale, consts, local_sort, inst, intro, local_syntax), (defs, operations)))
 fun merge_class_data _ (ClassData { locale = locale, consts = consts, local_sort = local_sort, inst = inst,
-    intro = intro, local_syntax = local_syntax, defs = defs1, localized = localized1 },
+    intro = intro, local_syntax = local_syntax, defs = defs1, operations = operations1 },
   ClassData { locale = _, consts = _, local_sort = _, inst = _, intro = _, local_syntax = _,
-    defs = defs2, localized = localized2 }) =
+    defs = defs2, operations = operations2 }) =
   mk_class_data ((locale, consts, local_sort, inst, intro, local_syntax),
-    (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (localized1, localized2)));
+    (Thm.merge_thms (defs1, defs2), AList.merge (op =) (K true) (operations1, operations2)));
 
 fun merge_pair f1 f2 ((x1, y1), (x2, y2)) = (f1 (x1, x2), f2 (y1, y2));
 
@@ -395,12 +393,12 @@
   Graph.fold (fn (_, (data, _)) => insert Thm.eq_thm ((#intro o rep_class_data) data))
     ((fst o ClassData.get) thy) [];
 
-fun these_localized thy =
-  maps (#localized o the_class_data thy) o ancestry thy;
+fun these_operations thy =
+  maps (#operations o the_class_data thy) o ancestry thy;
 
-fun local_param thy = AList.lookup (op =) o these_localized thy;
+fun local_operation thy = Option.join oo AList.lookup (op =) o these_operations thy;
 
-fun local_syntax thy = #local_syntax o the_class_data thy
+fun local_syntax thy = #local_syntax o the_class_data thy;
 
 fun print_classes thy =
   let
@@ -441,15 +439,19 @@
   ClassData.map (fn (gr, tab) => (
     gr
     |> Graph.new_node (class, mk_class_data ((locale, (map o apfst) fst consts, local_sort, inst,
-         intro, local_syntax), ([], map (apsnd (rpair (class, 0) o Free) o swap) consts)))
+         intro, local_syntax), ([], map (apsnd (SOME o rpair 0 o Free) o swap) consts)))
     |> fold (curry Graph.add_edge class) superclasses,
     tab
     |> Symtab.update (locale, class)
   ));
 
-fun add_class_const_def (class, (entry, def)) =
+fun register_const (class, (entry, def)) =
   (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
-    (fn (defs, localized) => (def :: defs, (apsnd o apsnd) (pair class) entry :: localized));
+    (fn (defs, operations) => (def :: defs, apsnd SOME entry :: operations));
+
+fun register_abbrev class abbrev =
+  (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd o apsnd)
+    (cons (abbrev, NONE));
 
 
 (** rule calculation, tactics and methods **)
@@ -543,84 +545,65 @@
 
 (** classes and class target **)
 
-(* class context initialization *)
+(* class context initialization - experimental *)
 
-(*experimental*)
-fun get_remove_constraint_ctxt c ctxt =
+fun get_remove_constraints sort ctxt =
   let
-    val ty = ProofContext.the_const_constraint ctxt c;
+    val operations = these_operations (ProofContext.theory_of ctxt) sort;
+    fun get_remove (c, _) ctxt =
+      let
+        val ty = ProofContext.the_const_constraint ctxt c;
+        val _ = tracing c;
+      in
+        ctxt
+        |> ProofContext.add_const_constraint (c, NONE)
+        |> pair (c, ty)
+      end;
   in
     ctxt
-    |> ProofContext.add_const_constraint (c, NONE)
-    |> pair (c, ty)
-  end;
-
-fun remove_constraints' class thy =
-  thy |> fold_map (get_remove_constraint o fst) (these_localized thy class);
-
-fun remove_constraints class ctxt =
-  ctxt |> fold_map (get_remove_constraint_ctxt o fst) (these_localized (ProofContext.theory_of ctxt) class);
-
-fun default_typ ctxt constraints c =
-  case AList.lookup (op =) constraints c
-   of SOME ty => SOME ty
-    | NONE => try (Consts.the_constraint (ProofContext.consts_of ctxt)) c;
-
-fun infer_constraints ctxt constraints ts =
-    TypeInfer.infer_types (ProofContext.pp ctxt) (Sign.tsig_of (ProofContext.theory_of ctxt))
-     (Syntax.check_typs ctxt)
-      (default_typ ctxt constraints) (K NONE)
-      (Variable.names_of ctxt) (Variable.maxidx_of ctxt) (SOME true) (map (rpair dummyT) ts)
-    |> #1 |> map #1
-  handle TYPE (msg, _, _) => error msg
-
-fun subst_typ local_sort =
-  map_atyps (fn (t as TFree (v, _)) => if v = AxClass.param_tyvarname
-        then TFree (v, local_sort)
-        else t
-    | t => t);
-
-fun sort_typ_check thy sort =
-  let
-    val local_sort = (#local_sort o the_class_data thy) (hd sort);
-  in
-    pair o map (subst_typ local_sort)
+    |> fold_map get_remove operations
   end;
 
 fun sort_term_check thy sort constraints =
   let
-    val algebra = Sign.classes_of thy;
-    val local_sort = (#local_sort o the_class_data thy) (hd sort);
-    val v = AxClass.param_tyvarname;
-    val local_param = local_param thy sort;
-      (*FIXME efficiency*)
-    fun class_arg c idx ty =
-      let
-        val typargs = Sign.const_typargs thy (c, ty);
-        fun classtyp (TFree (w, _)) = w = v
-          | classtyp t = false;
-      in classtyp (nth typargs idx) end;
-    fun subst (t as Const (c, ty)) = (case local_param c
-         of NONE => t
-          | SOME (t', (_, idx)) => if class_arg c idx ty
-             then t' else t)
-      | subst t = t;
-  in fn ts => fn ctxt =>
-    ((map (map_aterms subst) #> infer_constraints ctxt constraints) ts, ctxt)
-  end;
+    val local_operation = local_operation thy sort;
+    fun default_typ consts c = case AList.lookup (op =) constraints c
+     of SOME ty => SOME ty
+      | NONE => try (Consts.the_constraint consts) c;
+    fun infer_constraints ctxt ts =
+        TypeInfer.infer_types (ProofContext.pp ctxt)
+          (Sign.tsig_of (ProofContext.theory_of ctxt))
+          I (default_typ (ProofContext.consts_of ctxt)) (K NONE)
+          (Variable.names_of ctxt) (Variable.maxidx_of ctxt) NONE (map (rpair dummyT) ts)
+        |> fst |> map fst
+      handle TYPE (msg, _, _) => error msg;
+    fun check_typ c idx ty = case (nth (Sign.const_typargs thy (c, ty)) idx) (*FIXME localize*)
+     of TFree (v, _) => v = AxClass.param_tyvarname
+      | TVar (vi, _) => TypeInfer.is_param vi (*FIXME substitute in all typs*)
+      | _ => false;
+    fun subst_operation (t as Const (c, ty)) = (case local_operation c
+         of SOME (t', idx) => if check_typ c idx ty then t' else t
+          | NONE => t)
+      | subst_operation t = t;
+    fun subst_operations ts ctxt =
+      ts
+      |> (map o map_aterms) subst_operation
+      |> infer_constraints ctxt
+      |> rpair ctxt; (*FIXME add constraints here*)
+  in subst_operations end;
 
-fun init_default sort ctxt =
+fun init_exp sort ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
-    val typ_check = sort_typ_check thy sort;
+    val local_sort = (#local_sort o the_class_data thy) (hd sort);
     val term_check = sort_term_check thy sort;
   in
     ctxt
-    |> remove_constraints sort
-    ||> Variable.declare_term (Logic.mk_type (TFree (AxClass.param_tyvarname, sort)))
-    ||> Context.proof_map (Syntax.add_typ_check 0 "class" typ_check)
-    |-> (fn constraints =>
-        Context.proof_map (Syntax.add_term_check 0 "class" (term_check constraints)))
+    |> Variable.declare_term
+        (Logic.mk_type (TFree (AxClass.param_tyvarname, local_sort)))
+    |> get_remove_constraints sort
+    |-> (fn constraints => Context.proof_map (Syntax.add_term_check 50 "class"
+          (sort_term_check thy sort constraints)))
   end;
 
 val init_ref = ref (K I : sort -> Proof.context -> Proof.context);
@@ -645,6 +628,9 @@
     val sups = filter (is_some o lookup_class_data thy) supclasses
       |> Sign.minimize_sort thy;
     val supsort = Sign.minimize_sort thy supclasses;
+    val local_sort = case sups
+     of sup :: _ => (#local_sort o the_class_data thy) sup
+      | [] => supsort;
     val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
     val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
       | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
@@ -662,7 +648,7 @@
     |> init supsort
     |> process_expr Locale.empty raw_elems
     |> fst
-    |> (fn elems => ((((sups, supconsts), (supsort, mergeexpr)),
+    |> (fn elems => ((((sups, supconsts), (supsort, local_sort, mergeexpr)),
           (*FIXME*) if null includes then constrain :: elems else elems)))
   end;
 
@@ -672,7 +658,7 @@
 fun gen_class prep_spec prep_param local_syntax bname
     raw_supclasses raw_includes_elems raw_other_consts thy =
   let
-    val (((sups, supconsts), (supsort, mergeexpr)), elems) =
+    val (((sups, supconsts), (supsort, local_sort, mergeexpr)), elems) =
       prep_spec thy raw_supclasses raw_includes_elems;
     val other_consts = map (prep_param thy) raw_other_consts;
     fun mk_instT class = Symtab.empty
@@ -721,7 +707,7 @@
     |> Locale.add_locale_i (SOME "") bname mergeexpr elems
     |-> (fn name_locale => ProofContext.theory_result (
       `(fn thy => extract_params thy name_locale)
-      #-> (fn (local_sort, (globals, params)) =>
+      #-> (fn (_, (globals, params)) =>
         AxClass.define_class_params (bname, supsort) params
           (extract_assumes name_locale params) other_consts
       #-> (fn (name_axclass, (consts, axioms)) =>
@@ -783,7 +769,7 @@
       in
         thy
         |> class_interpretation class [def'] [def_eq]
-        |> add_class_const_def (class, ((c', (rhs, typidx)), def'))
+        |> register_const (class, ((c', (rhs, typidx)), def'))
       end;
   in
     thy
@@ -797,6 +783,23 @@
     |> Sign.restore_naming thy
   end;
 
+fun add_abbrev_in_class class ((c, rhs), syn) thy =
+  let
+    val local_sort = (#local_sort o the_class_data thy) class;
+    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
+      if w = AxClass.param_tyvarname then TFree (w, local_sort) else TFree var);
+    val ty = fastype_of rhs;
+    val rhs' = map_types subst_typ rhs;
+  in
+    thy
+    |> Sign.parent_path (*FIXME*)
+    |> Sign.add_abbrev Syntax.internalM [] (c, rhs)
+    |-> (fn (lhs as Const (c', _), _) => register_abbrev class c'
+      (*#> Sign.add_const_constraint (c', SOME ty)*)
+      #> pair lhs)
+    ||> Sign.restore_naming thy
+  end;
+
 
 (* interpretation in class target *)