explicit type schemes for functions
authorhaftmann
Fri May 23 16:05:07 2008 +0200 (2008-05-23)
changeset 26970bc28e7bcb765
parent 26969 cf3f998d0631
child 26971 160117247294
explicit type schemes for functions
src/Pure/Isar/code.ML
src/Pure/Isar/code_unit.ML
src/Tools/nbe.ML
     1.1 --- a/src/Pure/Isar/code.ML	Fri May 23 16:05:04 2008 +0200
     1.2 +++ b/src/Pure/Isar/code.ML	Fri May 23 16:05:07 2008 +0200
     1.3 @@ -38,7 +38,7 @@
     1.4    val get_datatype_of_constr: theory -> string -> string option
     1.5    val get_case_data: theory -> string -> (int * string list) option
     1.6    val is_undefined: theory -> string -> bool
     1.7 -  val default_typ: theory -> string -> typ
     1.8 +  val default_typ: theory -> string -> (string * sort) list * typ
     1.9  
    1.10    val preprocess_conv: cterm -> thm
    1.11    val preprocess_term: theory -> term -> term
    1.12 @@ -89,7 +89,7 @@
    1.13    val empty = [];
    1.14    val copy = I;
    1.15    val extend = I;
    1.16 -  fun merge _ = AList.merge (op =) (K true);
    1.17 +  fun merge _ = AList.merge (op = : string * string -> bool) (K true);
    1.18  );
    1.19  
    1.20  fun add_attribute (attr as (name, _)) =
    1.21 @@ -510,9 +510,14 @@
    1.22  
    1.23  (** theorem transformation and certification **)
    1.24  
    1.25 +fun const_of thy = dest_Const o fst o strip_comb o fst o Logic.dest_equals
    1.26 +  o ObjectLogic.drop_judgment thy o Thm.plain_prop_of;
    1.27 +
    1.28 +fun const_of_func thy = AxClass.unoverload_const thy o const_of thy;
    1.29 +
    1.30  fun common_typ_funcs [] = []
    1.31    | common_typ_funcs [thm] = [thm]
    1.32 -  | common_typ_funcs (thms as thm :: _) =
    1.33 +  | common_typ_funcs (thms as thm :: _) = (*FIXME is too general*)
    1.34        let
    1.35          val thy = Thm.theory_of_thm thm;
    1.36          fun incr_thm thm max =
    1.37 @@ -521,7 +526,7 @@
    1.38              val max' = Thm.maxidx_of thm' + 1;
    1.39            in (thm', max') end;
    1.40          val (thms', maxidx) = fold_map incr_thm thms 0;
    1.41 -        val ty1 :: tys = map (snd o CodeUnit.head_func) thms';
    1.42 +        val ty1 :: tys = map (snd o const_of thy) thms';
    1.43          fun unify ty env = Sign.typ_unify thy (ty1, ty) env
    1.44            handle Type.TUNIFY =>
    1.45              error ("Type unificaton failed, while unifying defining equations\n"
    1.46 @@ -533,8 +538,6 @@
    1.47            cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
    1.48        in map (Thm.instantiate (instT, [])) thms' end;
    1.49  
    1.50 -fun const_of_func thy = AxClass.unoverload_const thy o CodeUnit.head_func;
    1.51 -
    1.52  fun certify_const thy const thms =
    1.53    let
    1.54      fun cert thm = if const = const_of_func thy thm
    1.55 @@ -569,7 +572,7 @@
    1.56        |> map (Thm.transfer thy)
    1.57      fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys
    1.58        | sorts_of tys = map (snd o dest_TVar) tys;
    1.59 -    val sorts = map (sorts_of o Sign.const_typargs thy o CodeUnit.head_func) funcs;
    1.60 +    val sorts = map (sorts_of o Sign.const_typargs thy o const_of thy) funcs;
    1.61    in sorts end;
    1.62  
    1.63  fun weakest_constraints thy algebra (class, tyco) =
    1.64 @@ -640,14 +643,14 @@
    1.65      fun check_typ_classparam tyco (c, thm) =
    1.66            let
    1.67              val SOME class = AxClass.class_of_param thy c;
    1.68 -            val (_, ty) = CodeUnit.head_func thm;
    1.69 +            val (_, ty) = const_of thy thm;
    1.70              val ty_decl = classparam_weakest_typ thy class (c, tyco);
    1.71              val ty_strongest = classparam_strongest_typ thy class (c, tyco);
    1.72              fun constrain thm = 
    1.73                let
    1.74                  val max = Thm.maxidx_of thm + 1;
    1.75                  val ty_decl' = Logic.incr_tvar max ty_decl;
    1.76 -                val (_, ty') = CodeUnit.head_func thm;
    1.77 +                val (_, ty') = const_of thy thm;
    1.78                  val (env, _) = Sign.typ_unify thy (ty_decl', ty') (Vartab.empty, max);
    1.79                  val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
    1.80                    cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
    1.81 @@ -669,7 +672,7 @@
    1.82            end;
    1.83      fun check_typ_fun (c, thm) =
    1.84        let
    1.85 -        val (_, ty) = CodeUnit.head_func thm;
    1.86 +        val (_, ty) = const_of thy thm;
    1.87          val ty_decl = Sign.the_const_type thy c;
    1.88        in if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
    1.89          then thm
    1.90 @@ -926,8 +929,9 @@
    1.91    |> map (CodeUnit.rewrite_func ((#inlines o the_thmproc o the_exec) thy))
    1.92    |> fold (fn (_, (_, f)) => apply_inline_proc thy f) ((#inline_procs o the_thmproc o the_exec) thy)
    1.93  (*FIXME - must check: rewrite rule, defining equation, proper constant |> map (snd o check_func false thy) *)
    1.94 -  |> common_typ_funcs
    1.95 -  |> map (AxClass.unoverload thy);
    1.96 +  |> map (AxClass.unoverload thy)
    1.97 +  |> common_typ_funcs;
    1.98 +
    1.99  
   1.100  fun preprocess_conv ct =
   1.101    let
   1.102 @@ -984,10 +988,10 @@
   1.103    end;
   1.104  
   1.105  fun default_typ thy c = case default_typ_proto thy c
   1.106 - of SOME ty => ty
   1.107 + of SOME ty => CodeUnit.typscheme thy (c, ty)
   1.108    | NONE => (case get_funcs thy c
   1.109       of thm :: _ => snd (CodeUnit.head_func (AxClass.unoverload thy thm))
   1.110 -      | [] => Sign.the_const_type thy c);
   1.111 +      | [] => CodeUnit.typscheme thy (c, Sign.the_const_type thy c));
   1.112  
   1.113  end; (*local*)
   1.114  
     2.1 --- a/src/Pure/Isar/code_unit.ML	Fri May 23 16:05:04 2008 +0200
     2.2 +++ b/src/Pure/Isar/code_unit.ML	Fri May 23 16:05:07 2008 +0200
     2.3 @@ -14,6 +14,7 @@
     2.4    val try_thm: (thm -> thm) -> thm -> thm option
     2.5  
     2.6    (*typ instantiations*)
     2.7 +  val typscheme: theory -> string * typ -> (string * sort) list * typ
     2.8    val inst_thm: sort Vartab.table -> thm -> thm
     2.9    val constrain_thm: sort -> thm -> thm
    2.10  
    2.11 @@ -39,7 +40,7 @@
    2.12    val assert_rew: thm -> thm
    2.13    val mk_rew: thm -> thm
    2.14    val mk_func: thm -> thm
    2.15 -  val head_func: thm -> string * typ
    2.16 +  val head_func: thm -> string * ((string * sort) list * typ)
    2.17    val expand_eta: int -> thm -> thm
    2.18    val rewrite_func: thm list -> thm -> thm
    2.19    val norm_args: thm list -> thm list 
    2.20 @@ -72,6 +73,13 @@
    2.21  
    2.22  (* utilities *)
    2.23  
    2.24 +fun typscheme thy (c, ty) =
    2.25 +  let
    2.26 +    fun dest (TVar ((v, 0), sort)) = (v, sort)
    2.27 +      | dest ty = error ("Illegal type parameter in type scheme: " ^ Syntax.string_of_typ_global thy ty);
    2.28 +    val vs = map dest (Sign.const_typargs thy (c, ty));
    2.29 +  in (vs, ty) end;
    2.30 +
    2.31  fun inst_thm tvars' thm =
    2.32    let
    2.33      val thy = Thm.theory_of_thm thm;
    2.34 @@ -297,7 +305,7 @@
    2.35      fun ty_sorts (c, ty) =
    2.36        let
    2.37          val ty_decl = (Logic.unvarifyT o Sign.the_const_type thy) c;
    2.38 -        val (tyco, vs_decl) = last_typ (c, ty) ty_decl;
    2.39 +        val (tyco, _) = last_typ (c, ty) ty_decl;
    2.40          val (_, vs) = last_typ (c, ty) ty;
    2.41        in ((tyco, map snd vs), (c, (map fst vs, ty_decl))) end;
    2.42      fun add ((tyco', sorts'), c) ((tyco, sorts), cs) =
    2.43 @@ -399,8 +407,8 @@
    2.44    let
    2.45      val thy = Thm.theory_of_thm thm;
    2.46      val Const (c, ty) = (fst o strip_comb o fst o Logic.dest_equals
    2.47 -      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of) thm;
    2.48 -  in (c, ty) end;
    2.49 +      o (*ObjectLogic.drop_judgment thy o *)Thm.plain_prop_of) thm;
    2.50 +  in (c, typscheme thy (c, ty)) end;
    2.51  
    2.52  
    2.53  (* case cerificates *)
     3.1 --- a/src/Tools/nbe.ML	Fri May 23 16:05:04 2008 +0200
     3.2 +++ b/src/Tools/nbe.ML	Fri May 23 16:05:07 2008 +0200
     3.3 @@ -326,7 +326,7 @@
     3.4            let
     3.5              val ts' = take_until is_dict ts;
     3.6              val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
     3.7 -            val T = Code.default_typ thy c;
     3.8 +            val (_, T) = Code.default_typ thy c;
     3.9              val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
    3.10              val typidx' = typidx + maxidx_of_typ T' + 1;
    3.11            in of_apps bounds (Term.Const (c, T'), ts') typidx' end