crude support for type aliasses and corresponding constant signatures
authorhaftmann
Wed, 02 Dec 2009 17:53:35 +0100
changeset 33940 317933ce3712
parent 33939 fcb50b497763
child 33941 40408e6b833b
crude support for type aliasses and corresponding constant signatures
src/Pure/Isar/code.ML
--- a/src/Pure/Isar/code.ML	Wed Dec 02 17:53:34 2009 +0100
+++ b/src/Pure/Isar/code.ML	Wed Dec 02 17:53:35 2009 +0100
@@ -12,6 +12,10 @@
   val read_bare_const: theory -> string -> string * typ
   val read_const: theory -> string -> string
   val string_of_const: theory -> string -> string
+  val cert_signature: theory -> typ -> typ
+  val read_signature: theory -> string -> typ
+  val const_typ: theory -> string -> typ
+  val subst_signatures: theory -> term -> term
   val args_number: theory -> string -> int
 
   (*constructor sets*)
@@ -31,6 +35,10 @@
   val standard_typscheme: theory -> thm list -> thm list
 
   (*executable code*)
+  val add_type: string -> theory -> theory
+  val add_type_cmd: string -> theory -> theory
+  val add_signature: string * typ -> theory -> theory
+  val add_signature_cmd: string * string -> theory -> theory
   val add_datatype: (string * typ) list -> theory -> theory
   val add_datatype_cmd: string list -> theory -> theory
   val type_interpretation:
@@ -102,6 +110,8 @@
 
 (* constants *)
 
+fun typ_equiv tys = Type.raw_instance tys andalso Type.raw_instance (swap tys);
+
 fun check_bare_const thy t = case try dest_Const t
  of SOME c_ty => c_ty
   | NONE => error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
@@ -147,6 +157,7 @@
 
 datatype spec = Spec of {
   history_concluded: bool,
+  signatures: int Symtab.table * typ Symtab.table,
   eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
     (*with explicit history*),
   dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
@@ -154,16 +165,19 @@
   cases: (int * (int * string list)) Symtab.table * unit Symtab.table
 };
 
-fun make_spec (history_concluded, (eqns, (dtyps, cases))) =
-  Spec { history_concluded = history_concluded, eqns = eqns, dtyps = dtyps, cases = cases };
-fun map_spec f (Spec { history_concluded = history_concluded, eqns = eqns,
-  dtyps = dtyps, cases = cases }) =
-  make_spec (f (history_concluded, (eqns, (dtyps, cases))));
-fun merge_spec (Spec { history_concluded = _, eqns = eqns1,
+fun make_spec (history_concluded, ((signatures, eqns), (dtyps, cases))) =
+  Spec { history_concluded = history_concluded,
+    signatures = signatures, eqns = eqns, dtyps = dtyps, cases = cases };
+fun map_spec f (Spec { history_concluded = history_concluded, signatures = signatures,
+  eqns = eqns, dtyps = dtyps, cases = cases }) =
+  make_spec (f (history_concluded, ((signatures, eqns), (dtyps, cases))));
+fun merge_spec (Spec { history_concluded = _, signatures = (tycos1, sigs1), eqns = eqns1,
     dtyps = dtyps1, cases = (cases1, undefs1) },
-  Spec { history_concluded = _, eqns = eqns2,
+  Spec { history_concluded = _, signatures = (tycos2, sigs2), eqns = eqns2,
     dtyps = dtyps2, cases = (cases2, undefs2) }) =
   let
+    val signatures = (Symtab.merge (op =) (tycos1, tycos2),
+      Symtab.merge typ_equiv (sigs1, sigs2));
     fun merge_eqns ((_, history1), (_, history2)) =
       let
         val raw_history = AList.merge (op = : serial * serial -> bool)
@@ -176,14 +190,16 @@
     val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
     val cases = (Symtab.merge (K true) (cases1, cases2),
       Symtab.merge (K true) (undefs1, undefs2));
-  in make_spec (false, (eqns, (dtyps, cases))) end;
+  in make_spec (false, ((signatures, eqns), (dtyps, cases))) end;
 
 fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
+fun the_signatures (Spec { signatures, ... }) = signatures;
 fun the_eqns (Spec { eqns, ... }) = eqns;
 fun the_dtyps (Spec { dtyps, ... }) = dtyps;
 fun the_cases (Spec { cases, ... }) = cases;
 val map_history_concluded = map_spec o apfst;
-val map_eqns = map_spec o apsnd o apfst;
+val map_signatures = map_spec o apsnd o apfst o apfst;
+val map_eqns = map_spec o apsnd o apfst o apsnd;
 val map_dtyps = map_spec o apsnd o apsnd o apfst;
 val map_cases = map_spec o apsnd o apsnd o apsnd;
 
@@ -236,11 +252,11 @@
 structure Code_Data = TheoryDataFun
 (
   type T = spec * data Unsynchronized.ref;
-  val empty = (make_spec (false,
-    (Symtab.empty, (Symtab.empty, (Symtab.empty, Symtab.empty)))), Unsynchronized.ref empty_data);
+  val empty = (make_spec (false, (((Symtab.empty, Symtab.empty), Symtab.empty),
+    (Symtab.empty, (Symtab.empty, Symtab.empty)))), Unsynchronized.ref empty_data);
   fun copy (spec, data) = (spec, Unsynchronized.ref (! data));
   val extend = copy;
-  fun merge pp ((spec1, data1), (spec2, data2)) =
+  fun merge _ ((spec1, data1), (spec2, data2)) =
     (merge_spec (spec1, spec2), Unsynchronized.ref empty_data);
 );
 
@@ -334,7 +350,44 @@
 
 (* constants *)
 
-fun args_number thy = length o fst o strip_type o Sign.the_const_type thy;
+fun arity_number thy tyco = case Symtab.lookup ((fst o the_signatures o the_exec) thy) tyco
+ of SOME n => n
+  | NONE => Sign.arity_number thy tyco;
+
+fun build_tsig thy =
+  let
+    val (tycos, _) = (the_signatures o the_exec) thy;
+    val decls = (#types o Type.rep_tsig o Sign.tsig_of) thy
+      |> apsnd (Symtab.fold (fn (tyco, n) =>
+          Symtab.update (tyco, Type.LogicalType n)) tycos);
+  in Type.build_tsig ((Name_Space.empty "", Sorts.empty_algebra), [], decls) end;
+
+fun cert_signature thy = Logic.varifyT o Type.cert_typ (build_tsig thy) o Type.no_tvars;
+
+fun read_signature thy = cert_signature thy o Type.strip_sorts
+  o Syntax.parse_typ (ProofContext.init thy);
+
+fun expand_signature thy = Type.cert_typ_mode Type.mode_syntax (Sign.tsig_of thy);
+
+fun lookup_typ thy = Symtab.lookup ((snd o the_signatures o the_exec) thy);
+
+fun const_typ thy c = case lookup_typ thy c
+ of SOME ty => ty
+  | NONE => (Type.strip_sorts o Sign.the_const_type thy) c;
+
+fun subst_signature thy c ty =
+  let
+    fun mk_subst (Type (tyco, tys1)) (ty2 as Type (tyco2, tys2)) =
+          fold2 mk_subst tys1 tys2
+      | mk_subst ty (TVar (v, sort)) = Vartab.update (v, ([], ty))
+  in case lookup_typ thy c
+   of SOME ty' => Envir.subst_type (mk_subst ty (expand_signature thy ty') Vartab.empty) ty'
+    | NONE => ty
+  end;
+
+fun subst_signatures thy = map_aterms (fn Const (c, ty) => Const (c, subst_signature thy c ty) | t => t);
+
+fun args_number thy = length o fst o strip_type o const_typ thy;
 
 
 (* datatypes *)
@@ -355,9 +408,10 @@
         val _ = if length tfrees <> length vs
           then no_constr "type variables missing in datatype" c_ty else ();
       in (tyco, vs) end;
-    fun ty_sorts (c, ty) =
+    fun ty_sorts (c, raw_ty) =
       let
-        val ty_decl = (Logic.unvarifyT o Sign.the_const_type thy) c;
+        val ty = subst_signature thy c raw_ty;
+        val ty_decl = (Logic.unvarifyT o const_typ thy) c;
         val (tyco, _) = last_typ (c, ty) ty_decl;
         val (_, vs) = last_typ (c, ty) ty;
       in ((tyco, map snd vs), (c, (map fst vs, ty))) end;
@@ -382,13 +436,13 @@
 fun get_datatype thy tyco =
   case these (Symtab.lookup ((the_dtyps o the_exec) thy) tyco)
    of (_, spec) :: _ => spec
-    | [] => Sign.arity_number thy tyco
+    | [] => arity_number thy tyco
         |> Name.invents Name.context Name.aT
         |> map (rpair [])
         |> rpair [];
 
 fun get_datatype_of_constr thy c =
-  case (snd o strip_type o Sign.the_const_type thy) c
+  case (snd o strip_type o const_typ thy) c
    of Type (tyco, _) => if member (op =) ((map fst o snd o get_datatype thy) tyco) c
        then SOME tyco else NONE
     | _ => NONE;
@@ -446,21 +500,25 @@
           ("Variable with application on left hand side of equation\n"
             ^ Display.string_of_thm_global thy thm)
       | check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
-      | check n (Const (c_ty as (c, ty))) = if n = (length o fst o strip_type) ty
-          then if not proper orelse is_constr_pat (AxClass.unoverload_const thy c_ty)
-            then ()
-            else bad_thm (quote c ^ " is not a constructor, on left hand side of equation\n"
-              ^ Display.string_of_thm_global thy thm)
-          else bad_thm
-            ("Partially applied constant " ^ quote c ^ " on left hand side of equation\n"
-               ^ Display.string_of_thm_global thy thm);
+      | check n (Const (c_ty as (_, ty))) =
+          let
+            val c' = AxClass.unoverload_const thy c_ty
+          in if n = (length o fst o strip_type o subst_signature thy c') ty
+            then if not proper orelse is_constr_pat c'
+              then ()
+              else bad_thm (quote c ^ " is not a constructor, on left hand side of equation\n"
+                ^ Display.string_of_thm_global thy thm)
+            else bad_thm
+              ("Partially applied constant " ^ quote c ^ " on left hand side of equation\n"
+                 ^ Display.string_of_thm_global thy thm)
+          end;
     val _ = map (check 0) args;
     val _ = if not proper orelse is_linear thm then ()
       else bad_thm ("Duplicate variables on left hand side of equation\n"
         ^ Display.string_of_thm_global thy thm);
     val _ = if (is_none o AxClass.class_of_param thy) c
       then ()
-      else bad_thm ("Polymorphic constant as head in equation\n"
+      else bad_thm ("Overloaded constant as head in equation\n"
         ^ Display.string_of_thm_global thy thm)
     val _ = if not (is_constr thy c)
       then ()
@@ -488,29 +546,34 @@
 fun mk_eqn_liberal thy = Option.map (fn (thm, _) => (thm, is_linear thm))
   o try_thm (gen_assert_eqn thy (K true)) o rpair false o meta_rewrite thy;
 
-(*those following are permissive wrt. to overloaded constants!*)
+val head_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
 
-val head_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
 fun const_typ_eqn thy thm =
   let
     val (c, ty) = head_eqn thm;
     val c' = AxClass.unoverload_const thy (c, ty);
+      (*permissive wrt. to overloaded constants!*)
   in (c', ty) end;
+
 fun const_eqn thy = fst o const_typ_eqn thy;
 
-fun typscheme thy (c, ty) =
+fun raw_typscheme thy (c, ty) =
   (map dest_TFree (Sign.const_typargs thy (c, ty)), Type.strip_sorts ty);
+
+fun typscheme thy (c, ty) = raw_typscheme thy (c, subst_signature thy c ty);
+
 fun typscheme_eqn thy = typscheme thy o apsnd Logic.unvarifyT o const_typ_eqn thy;
+
 fun typscheme_eqns thy c [] = 
       let
-        val raw_ty = Sign.the_const_type thy c;
+        val raw_ty = const_typ thy c;
         val tvars = Term.add_tvar_namesT raw_ty [];
         val tvars' = case AxClass.class_of_param thy c
          of SOME class => [TFree (Name.aT, [class])]
           | NONE => Name.invent_list [] Name.aT (length tvars)
               |> map (fn v => TFree (v, []));
         val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
-      in typscheme thy (c, ty) end
+      in raw_typscheme thy (c, ty) end
   | typscheme_eqns thy c (thms as thm :: _) = typscheme_eqn thy thm;
 
 fun assert_eqns_const thy c eqns =
@@ -639,6 +702,34 @@
 
 (** declaring executable ingredients **)
 
+(* constant signatures *)
+
+fun add_type tyco thy =
+  case Symtab.lookup ((snd o #types o Type.rep_tsig o Sign.tsig_of) thy) tyco
+   of SOME (Type.Abbreviation (vs, _, _)) =>
+          (map_exec_purge NONE o map_signatures o apfst)
+            (Symtab.update (tyco, length vs)) thy
+    | _ => error ("No such type abbreviation: " ^ quote tyco);
+
+fun add_type_cmd s thy = add_type (Sign.intern_type thy s) thy;
+
+fun gen_add_signature prep_const prep_signature (raw_c, raw_ty) thy =
+  let
+    val c = prep_const thy raw_c;
+    val ty = prep_signature thy raw_ty;
+    val ty' = expand_signature thy ty;
+    val ty'' = Sign.the_const_type thy c;
+    val _ = if typ_equiv (ty', ty'') then () else
+      error ("Illegal constant signature: " ^ Syntax.string_of_typ_global thy ty);
+  in
+    thy
+    |> (map_exec_purge NONE o map_signatures o apsnd) (Symtab.update (c, ty))
+  end;
+
+val add_signature = gen_add_signature (K I) cert_signature;
+val add_signature_cmd = gen_add_signature read_const read_signature;
+
+
 (* datatypes *)
 
 structure Type_Interpretation =