added first version of user-space type system for class target
authorhaftmann
Mon, 08 Oct 2007 08:04:28 +0200
changeset 24901 d3cbf79769b9
parent 24900 5471709833a4
child 24902 49f002c3964e
added first version of user-space type system for class target
src/HOL/HOL.thy
src/HOL/Real/RealVector.thy
src/Pure/Isar/class.ML
src/Pure/Isar/theory_target.ML
--- a/src/HOL/HOL.thy	Mon Oct 08 08:04:26 2007 +0200
+++ b/src/HOL/HOL.thy	Mon Oct 08 08:04:28 2007 +0200
@@ -22,13 +22,13 @@
   "~~/src/Provers/eqsubst.ML"
   "~~/src/Provers/quantifier1.ML"
   ("simpdata.ML")
+  "~~/src/Tools/induct.ML"
   "~~/src/Tools/code/code_name.ML"
   "~~/src/Tools/code/code_funcgr.ML"
   "~~/src/Tools/code/code_thingol.ML"
   "~~/src/Tools/code/code_target.ML"
   "~~/src/Tools/code/code_package.ML"
   "~~/src/Tools/nbe.ML"
-  "~~/src/Tools/induct.ML"
 begin
 
 subsection {* Primitive logic *}
@@ -205,13 +205,13 @@
 subsubsection {* Generic classes and algebraic operations *}
 
 class default = type +
-  fixes default :: "'a"
+  fixes default :: 'a
 
 class zero = type + 
-  fixes zero :: "'a"  ("\<^loc>0")
+  fixes zero :: 'a  ("\<^loc>0")
 
 class one = type +
-  fixes one  :: "'a"  ("\<^loc>1")
+  fixes one  :: 'a  ("\<^loc>1")
 
 hide (open) const zero one
 
@@ -295,11 +295,6 @@
   less_eq  ("(_/ \<le> _)"  [51, 51] 50)
 
 notation (input)
-  greater  (infix ">" 50)
-
-notation (input)
-  greater_eq  (infix ">=" 50)
-and
   greater_eq  (infix "\<ge>" 50)
 
 syntax
--- a/src/HOL/Real/RealVector.thy	Mon Oct 08 08:04:26 2007 +0200
+++ b/src/HOL/Real/RealVector.thy	Mon Oct 08 08:04:28 2007 +0200
@@ -54,11 +54,6 @@
 
 end
 
-abbreviation
-  divideR :: "'a \<Rightarrow> real \<Rightarrow> 'a\<Colon>scaleR" (infixl "'/#" 70)
-where
-  "x /# r == scaleR (inverse r) x"
-
 notation (xsymbols)
   scaleR (infixr "*\<^sub>R" 75) and
   divideR (infixl "'/\<^sub>R" 70)
--- a/src/Pure/Isar/class.ML	Mon Oct 08 08:04:26 2007 +0200
+++ b/src/Pure/Isar/class.ML	Mon Oct 08 08:04:28 2007 +0200
@@ -45,13 +45,10 @@
   val param_const: theory -> string -> (string * string) option
   val params_of_sort: theory -> sort -> (string * (string * typ)) list
 
-  (*experimental*)
-  val init_ref: (sort -> Proof.context -> Proof.context) ref
-  val init: sort -> Proof.context -> Proof.context;
-  val init_exp: sort -> Proof.context -> Proof.context;
+  val init: class -> Proof.context -> Proof.context;
   val local_syntax: theory -> class -> bool
-  val add_abbrev_in_class: string -> (string * term) * Syntax.mixfix
-    -> theory -> term * theory
+  val add_abbrev_in_class: string -> Syntax.mode -> (string * term) * mixfix
+    -> theory -> theory
 end;
 
 structure Class : CLASS =
@@ -62,8 +59,8 @@
 fun fork_mixfix is_loc some_class mx =
   let
     val mx' = Syntax.unlocalize_mixfix mx;
-    val mx_global = if is_some some_class orelse (is_loc andalso mx = mx')
-      then NoSyn else mx';
+    val mx_global = if not is_loc orelse (is_some some_class andalso not (mx = mx'))
+      then mx' else NoSyn;
     val mx_local = if is_loc then mx else NoSyn;
   in (mx_global, mx_local) end;
 
@@ -398,6 +395,15 @@
 
 fun local_operation thy = Option.join oo AList.lookup (op =) o these_operations thy;
 
+fun sups_local_sort thy sort =
+  let
+    val sups = filter (is_some o lookup_class_data thy) sort
+      |> Sign.minimize_sort thy;
+    val local_sort = case sups
+     of sup :: _ => #local_sort (the_class_data thy sup)
+      | [] => sort;
+  in (sups, local_sort) end;
+
 fun local_syntax thy = #local_syntax o the_class_data thy;
 
 fun print_classes thy =
@@ -445,7 +451,7 @@
     |> Symtab.update (locale, class)
   ));
 
-fun register_const (class, (entry, def)) =
+fun register_const class (entry, def) =
   (ClassData.map o apfst o Graph.map_node class o map_class_data o apsnd)
     (fn (defs, operations) => (def :: defs, apsnd SOME entry :: operations));
 
@@ -545,18 +551,20 @@
 
 (** classes and class target **)
 
-(* class context initialization - experimental *)
+(* class context initialization *)
 
-fun get_remove_constraints sort ctxt =
+fun get_remove_constraints sups local_sort ctxt =
   let
-    val operations = these_operations (ProofContext.theory_of ctxt) sort;
+    val thy = ProofContext.theory_of ctxt;
+    val operations = these_operations thy sups;
     fun get_remove (c, _) ctxt =
       let
         val ty = ProofContext.the_const_constraint ctxt c;
-        val _ = tracing c;
+        val ty' = map_atyps (fn ty as TFree (v, _) => if v = Name.aT
+          then TFree (v, local_sort) else ty | ty => ty) ty;
       in
         ctxt
-        |> ProofContext.add_const_constraint (c, NONE)
+        |> ProofContext.add_const_constraint (c, SOME ty')
         |> pair (c, ty)
       end;
   in
@@ -564,51 +572,43 @@
     |> fold_map get_remove operations
   end;
 
-fun sort_term_check thy sort constraints =
+fun sort_term_check thy sups local_sort constraints ts ctxt =
   let
-    val local_operation = local_operation thy sort;
-    fun default_typ consts c = case AList.lookup (op =) constraints c
-     of SOME ty => SOME ty
-      | NONE => try (Consts.the_constraint consts) c;
-    fun infer_constraints ctxt ts =
-        TypeInfer.infer_types (ProofContext.pp ctxt)
-          (Sign.tsig_of (ProofContext.theory_of ctxt))
-          I (default_typ (ProofContext.consts_of ctxt)) (K NONE)
-          (Variable.names_of ctxt) (Variable.maxidx_of ctxt) NONE (map (rpair dummyT) ts)
-        |> fst |> map fst
-      handle TYPE (msg, _, _) => error msg;
-    fun check_typ c idx ty = case (nth (Sign.const_typargs thy (c, ty)) idx) (*FIXME localize*)
-     of TFree (v, _) => v = Name.aT
-      | TVar (vi, _) => TypeInfer.is_param vi (*FIXME substitute in all typs*)
-      | _ => false;
-    fun subst_operation (t as Const (c, ty)) = (case local_operation c
-         of SOME (t', idx) => if check_typ c idx ty then t' else t
-          | NONE => t)
-      | subst_operation t = t;
-    fun subst_operations ts ctxt =
-      ts
-      |> (map o map_aterms) subst_operation
-      |> infer_constraints ctxt
-      |> rpair ctxt; (*FIXME add constraints here*)
-  in subst_operations end;
+    val local_operation = local_operation thy sups;
+    val consts = ProofContext.consts_of ctxt;
+    fun check_typ (c, ty) (t', idx) = case nth (Consts.typargs consts (c, ty)) idx
+     of TFree (v, _) => if v = Name.aT
+          then apfst (AList.update (op =) ((c, ty), t')) else I
+      | TVar (vi, _) => if TypeInfer.is_param vi
+          then apfst (AList.update (op =) ((c, ty), t'))
+            #> apsnd (insert (op =) vi) else I
+      | _ => I;
+    fun add_const (Const (c, ty)) = (case local_operation c
+         of SOME (t', idx) => check_typ (c, ty) (t', idx)
+          | NONE => I)
+      | add_const _ = I;
+    val (cs, typarams) = (fold o fold_aterms) add_const ts ([], []);
+    val ts' = map (map_aterms
+        (fn t as Const (c, ty) => the_default t (AList.lookup (op =) cs (c, ty)) | t => t)
+      #> (map_types o map_type_tvar) (fn var as (vi, _) => if member (op =) typarams vi
+           then TFree (Name.aT, local_sort) else TVar var)) ts;
+    val ctxt' = fold (ProofContext.add_const_constraint o apsnd SOME) constraints ctxt;
+  in (ts', ctxt') end;
 
-fun init_exp sort ctxt =
+fun init_class_ctxt thy sups local_sort ctxt =
+  ctxt
+  |> Variable.declare_term
+      (Logic.mk_type (TFree (Name.aT, local_sort)))
+  |> get_remove_constraints sups local_sort
+  |-> (fn constraints => Context.proof_map (Syntax.add_term_check 50 "class"
+        (sort_term_check thy sups local_sort constraints)));
+
+fun init class ctxt =
   let
     val thy = ProofContext.theory_of ctxt;
-    val local_sort = (#local_sort o the_class_data thy) (hd sort);
-    val term_check = sort_term_check thy sort;
-  in
-    ctxt
-    |> Variable.declare_term
-        (Logic.mk_type (TFree (Name.aT, local_sort)))
-    |> get_remove_constraints sort
-    |-> (fn constraints => Context.proof_map (Syntax.add_term_check 50 "class"
-          (sort_term_check thy sort constraints)))
-  end;
-
-val init_ref = ref (K I : sort -> Proof.context -> Proof.context);
-fun init class = ! init_ref class;
-
+    val local_sort = (#local_sort o the_class_data thy) class;
+  in init_class_ctxt thy [class] local_sort ctxt end;
+  
 
 (* class definition *)
 
@@ -625,12 +625,8 @@
 fun gen_class_spec prep_class prep_expr process_expr thy raw_supclasses raw_includes_elems =
   let
     val supclasses = map (prep_class thy) raw_supclasses;
-    val sups = filter (is_some o lookup_class_data thy) supclasses
-      |> Sign.minimize_sort thy;
+    val (sups, local_sort) = sups_local_sort thy supclasses;
     val supsort = Sign.minimize_sort thy supclasses;
-    val local_sort = case sups
-     of sup :: _ => (#local_sort o the_class_data thy) sup
-      | [] => supsort;
     val suplocales = map (Locale.Locale o #locale o the_class_data thy) sups;
     val (raw_elems, includes) = fold_rev (fn Locale.Elem e => apfst (cons e)
       | Locale.Expr i => apsnd (cons (prep_expr thy i))) raw_includes_elems ([], []);
@@ -645,7 +641,7 @@
     ProofContext.init thy
     |> Locale.cert_expr supexpr [constrain]
     |> snd
-    |> init supsort
+    |> init_class_ctxt thy sups local_sort
     |> process_expr Locale.empty raw_elems
     |> fst
     |> (fn elems => ((((sups, supconsts), (supsort, local_sort, mergeexpr)),
@@ -670,20 +666,16 @@
     fun extract_params thy name_locale =
       let
         val params = Locale.parameters_of thy name_locale;
-        val local_sort = case AList.group (op =) ((maps typ_tfrees o map (snd o fst)) params)
-         of [(_, local_sort :: _)] => local_sort
-          | _ => Sign.defaultS thy
-          | vs => error ("exactly one type variable required: " ^ commas (map fst vs));
         val _ = if Sign.subsort thy (supsort, local_sort) then () else error
           ("Sort " ^ Sign.string_of_sort thy local_sort
             ^ " is less general than permitted least general sort "
             ^ Sign.string_of_sort thy supsort);
       in
-        (local_sort, (map fst params, params
+        (map fst params, params
         |> (map o apfst o apsnd o Term.map_type_tfree) (K (TFree (Name.aT, local_sort)))
-        |> (map o apsnd) (fork_mixfix true NONE #> fst)
+        |> (map o apsnd) (fork_mixfix true (SOME "") #> fst)
         |> chop (length supconsts)
-        |> snd))
+        |> snd)
       end;
     fun extract_assumes name_locale params thy cs =
       let
@@ -707,7 +699,7 @@
     |> Locale.add_locale_i (SOME "") bname mergeexpr elems
     |-> (fn name_locale => ProofContext.theory_result (
       `(fn thy => extract_params thy name_locale)
-      #-> (fn (_, (globals, params)) =>
+      #-> (fn (globals, params) =>
         AxClass.define_class_params (bname, supsort) params
           (extract_assumes name_locale params) other_consts
       #-> (fn (name_axclass, (consts, axioms)) =>
@@ -742,6 +734,12 @@
       | subst_aterm t = t;
   in Term.map_aterms subst_aterm end;
 
+fun mk_operation_entry thy (c, rhs) =
+  let
+    val typargs = Sign.const_typargs thy (c, fastype_of rhs);
+    val typidx = find_index (fn TFree (w, _) => Name.aT = w | _ => false) typargs;
+  in (c, (rhs, typidx)) end;
+
 fun add_const_in_class class ((c, rhs), syn) thy =
   let
     val prfx = (Logic.const_of_class o NameSpace.base) class;
@@ -759,17 +757,16 @@
     val ty'' = subst_typ ty';
     val c' = mk_name c;
     val def = (c, Logic.mk_equals (Const (c', ty'), rhs'));
-    val (syn', _) = fork_mixfix true NONE syn;
+    val (syn', _) = fork_mixfix true (SOME class) syn;
     fun interpret def thy =
       let
         val def' = symmetric def;
         val def_eq = Thm.prop_of def';
-        val typargs = Sign.const_typargs thy (c', fastype_of rhs);
-        val typidx = find_index (fn TFree (w, _) => Name.aT = w | _ => false) typargs;
+        val entry = mk_operation_entry thy (c', rhs);
       in
         thy
         |> class_interpretation class [def'] [def_eq]
-        |> register_const (class, ((c', (rhs, typidx)), def'))
+        |> register_const class (entry, def')
       end;
   in
     thy
@@ -783,21 +780,25 @@
     |> Sign.restore_naming thy
   end;
 
-fun add_abbrev_in_class class ((c, rhs), syn) thy =
+
+(* abbreviation in class target *)
+
+fun add_abbrev_in_class class prmode ((c, rhs), syn) thy =
   let
-    val local_sort = (#local_sort o the_class_data thy) class;
-    val subst_typ = Term.map_type_tfree (fn var as (w, sort) =>
-      if w = Name.aT then TFree (w, local_sort) else TFree var);
-    val ty = fastype_of rhs;
-    val rhs' = map_types subst_typ rhs;
+    val prfx = (Logic.const_of_class o NameSpace.base) class;
+    fun mk_name c =
+      let
+        val n1 = Sign.full_name thy c;
+        val n2 = NameSpace.qualifier n1;
+        val n3 = NameSpace.base n1;
+      in NameSpace.implode [n2, prfx, prfx, n3] end;
+    val c' = mk_name c;
+    val rhs' = export_fixes thy class rhs;
+    val ty' = fastype_of rhs';
   in
     thy
-    |> Sign.parent_path (*FIXME*)
-    |> Sign.add_abbrev Syntax.internalM [] (c, rhs)
-    |-> (fn (lhs as Const (c', _), _) => register_abbrev class c'
-      (*#> Sign.add_const_constraint (c', SOME ty)*)
-      #> pair lhs)
-    ||> Sign.restore_naming thy
+    |> Sign.add_notation prmode [(Const (c', ty'), syn)]
+    |> register_abbrev class c'
   end;
 
 
--- a/src/Pure/Isar/theory_target.ML	Mon Oct 08 08:04:26 2007 +0200
+++ b/src/Pure/Isar/theory_target.ML	Mon Oct 08 08:04:28 2007 +0200
@@ -115,7 +115,8 @@
         val U = map #2 xs ---> T;
         val t = Term.list_comb (Const (Sign.full_name thy c, U), map Free xs);
         val (mx1, mx2) = Class.fork_mixfix is_loc some_class mx;
-        val thy' = Sign.add_consts_authentic (ContextPosition.properties_of lthy) [(c, U, mx1)] thy;
+        val mx3 = if is_loc then NoSyn else mx1;
+        val thy' = Sign.add_consts_authentic (ContextPosition.properties_of lthy) [(c, U, mx3)] thy;
       in (((c, mx2), t), thy') end;
 
     fun const_class (SOME class) ((c, _), mx) (_, t) =
@@ -182,12 +183,18 @@
     val U = Term.fastype_of u;
     val u' = singleton (Variable.export_terms (Variable.declare_term u target) thy_ctxt) u;
     val (mx1, mx2) = Class.fork_mixfix is_loc some_class mx;
+    val mx3 = if is_loc then NoSyn else mx1;
+    fun add_abbrev_in_class NONE = K I
+      | add_abbrev_in_class (SOME class) =
+          Class.add_abbrev_in_class class prmode;
   in
     lthy
     |> LocalTheory.theory_result
         (Sign.add_abbrev (#1 prmode) (ContextPosition.properties_of lthy) (c, u'))
-    |-> (fn (lhs as Const (full_c, _), rhs) => LocalTheory.theory (Sign.add_notation prmode [(lhs, mx1)])
+    |-> (fn (lhs as Const (full_c, _), rhs) => LocalTheory.theory (Sign.add_notation prmode [(lhs, mx3)])
     #> is_loc ? internal_abbrev prmode ((c, mx2), Term.list_comb (Const (full_c, U), xs))
+    #> LocalTheory.raw_theory
+         (add_abbrev_in_class some_class ((c, Term.list_comb (Const (full_c, U), xs)), mx1))
     #> local_abbrev (c, rhs))
   end;
 
@@ -373,14 +380,10 @@
     val thy = ProofContext.theory_of ctxt;
     val is_loc = loc <> "";
     val some_class = Class.class_of_locale thy loc;
-    fun class_init (SOME class) =
-          Class.init [class]
-      | class_init NONE =
-          I;
   in
     ctxt
     |> Data.put (if is_loc then SOME loc else NONE)
-    |> class_init some_class
+    |> the_default I (Option.map Class.init some_class)
     |> LocalTheory.init (NameSpace.base loc)
      {pretty = pretty loc,
       consts = consts is_loc some_class,