tuned interface of structure Code
authorhaftmann
Tue, 07 Jul 2009 17:21:27 +0200
changeset 31957 a9742afd403e
parent 31956 c3844c4d0c2c
child 31958 2133f596c520
tuned interface of structure Code
src/HOL/Tools/recfun_codegen.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_target.ML
src/Tools/Code/code_thingol.ML
src/Tools/nbe.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -29,7 +29,7 @@
         val (thm', _) = Code.mk_eqn thy (K false) (thm, true)
       in
         thy
-        |> ModuleData.map (Symtab.update (Code.const_eqn thy thm', module_name))
+        |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
         |> Code.add_eqn thm'
       end;
 
@@ -44,7 +44,7 @@
 fun expand_eta thy [] = []
   | expand_eta thy (thms as thm :: _) =
       let
-        val (_, ty) = Code.const_typ_eqn thm;
+        val (_, ty) = Code.const_typ_eqn thy thm;
       in if null (Term.add_tvarsT ty []) orelse (null o fst o strip_type) ty
         then thms
         else map (Code.expand_eta thy 1) thms
--- a/src/Pure/Isar/code.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Pure/Isar/code.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -7,22 +7,18 @@
 
 signature CODE =
 sig
+  (*constants*)
+  val string_of_const: theory -> string -> string
+  val args_number: theory -> string -> int
+  val check_const: theory -> term -> string
+  val read_bare_const: theory -> string -> string * typ
+  val read_const: theory -> string -> string
+  val typscheme: theory -> string * typ -> (string * sort) list * typ
+
   (*constructor sets*)
   val constrset_of_consts: theory -> (string * typ) list
     -> string * ((string * sort) list * (string * typ list) list)
 
-  (*typ instantiations*)
-  val typscheme: theory -> string * typ -> (string * sort) list * typ
-  val inst_thm: theory -> sort Vartab.table -> thm -> thm
-
-  (*constants*)
-  val string_of_typ: theory -> typ -> string
-  val string_of_const: theory -> string -> string
-  val no_args: theory -> string -> int
-  val check_const: theory -> term -> string
-  val read_bare_const: theory -> string -> string * typ
-  val read_const: theory -> string -> string
-
   (*constant aliasses*)
   val add_const_alias: thm -> theory -> theory
   val triv_classes: theory -> class list
@@ -35,22 +31,12 @@
   val assert_eqn: theory -> thm * bool -> thm * bool
   val assert_eqns_const: theory -> string
     -> (thm * bool) list -> (thm * bool) list
-  val const_typ_eqn: thm -> string * typ
-  val const_eqn: theory -> thm -> string
+  val const_typ_eqn: theory -> thm -> string * typ
   val typscheme_eqn: theory -> thm -> (string * sort) list * typ
   val expand_eta: theory -> int -> thm -> thm
-  val rewrite_eqn: simpset -> thm -> thm
-  val rewrite_head: thm list -> thm -> thm
   val norm_args: theory -> thm list -> thm list 
   val norm_varnames: theory -> thm list -> thm list
 
-  (*case certificates*)
-  val case_cert: thm -> string * (int * string list)
-
-  (*infrastructure*)
-  val add_attribute: string * attribute parser -> theory -> theory
-  val purge_data: theory -> theory
-
   (*executable content*)
   val add_datatype: (string * typ) list -> theory -> theory
   val add_datatype_cmd: string list -> theory -> theory
@@ -58,25 +44,26 @@
     (string * ((string * sort) list * (string * typ list) list)
       -> theory -> theory) -> theory -> theory
   val add_eqn: thm -> theory -> theory
+  val add_eqnl: string * (thm * bool) list lazy -> theory -> theory
   val add_nbe_eqn: thm -> theory -> theory
   val add_default_eqn: thm -> theory -> theory
   val add_default_eqn_attribute: attribute
   val add_default_eqn_attrib: Attrib.src
   val del_eqn: thm -> theory -> theory
   val del_eqns: string -> theory -> theory
-  val add_eqnl: string * (thm * bool) list lazy -> theory -> theory
   val add_case: thm -> theory -> theory
   val add_undefined: string -> theory -> theory
-
-  (*data retrieval*)
   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
   val get_datatype_of_constr: theory -> string -> string option
-  val default_typscheme: theory -> string -> (string * sort) list * typ
   val these_eqns: theory -> string -> (thm * bool) list
   val all_eqns: theory -> (thm * bool) list
   val get_case_scheme: theory -> string -> (int * (int * string list)) option
   val undefineds: theory -> string list
   val print_codesetup: theory -> unit
+
+  (*infrastructure*)
+  val add_attribute: string * attribute parser -> theory -> theory
+  val purge_data: theory -> theory
 end;
 
 signature CODE_DATA_ARGS =
@@ -117,7 +104,7 @@
  of SOME (c, tyco) => Sign.extern_const thy c ^ " " ^ enclose "[" "]" (Sign.extern_type thy tyco)
   | NONE => Sign.extern_const thy c;
 
-fun no_args thy = length o fst o strip_type o Sign.the_const_type thy;
+fun args_number thy = length o fst o strip_type o Sign.the_const_type thy;
 
 
 (* utilities *)
@@ -125,21 +112,7 @@
 fun typscheme thy (c, ty) =
   let
     val ty' = Logic.unvarifyT ty;
-    fun dest (TFree (v, sort)) = (v, sort)
-      | dest ty = error ("Illegal type parameter in type scheme: " ^ Syntax.string_of_typ_global thy ty);
-    val vs = map dest (Sign.const_typargs thy (c, ty'));
-  in (vs, Type.strip_sorts ty') end;
-
-fun inst_thm thy tvars' thm =
-  let
-    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
-    val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
-    fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
-     of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
-          (tvar, (v, inter_sort (sort, sort'))))
-      | NONE => NONE;
-    val insts = map_filter mk_inst tvars;
-  in Thm.instantiate (insts, []) thm end;
+  in (map dest_TFree (Sign.const_typargs thy (c, ty')), Type.strip_sorts ty') end;
 
 fun expand_eta thy k thm =
   let
@@ -173,23 +146,6 @@
     |> Conv.fconv_rule Drule.beta_eta_conversion
   end;
 
-fun eqn_conv conv =
-  let
-    fun lhs_conv ct = if can Thm.dest_comb ct
-      then (Conv.combination_conv lhs_conv conv) ct
-      else Conv.all_conv ct;
-  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
-
-fun head_conv conv =
-  let
-    fun lhs_conv ct = if can Thm.dest_comb ct
-      then (Conv.fun_conv lhs_conv) ct
-      else conv ct;
-  in Conv.fun_conv (Conv.arg_conv lhs_conv) end;
-
-val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
-val rewrite_head = Conv.fconv_rule o head_conv o MetaSimplifier.rewrite false;
-
 fun norm_args thy thms =
   let
     val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
@@ -265,6 +221,19 @@
   end;
 
 
+(* reading constants as terms *)
+
+fun check_bare_const thy t = case try dest_Const t
+ of SOME c_ty => c_ty
+  | NONE => error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
+
+fun check_const thy = AxClass.unoverload_const thy o check_bare_const thy;
+
+fun read_bare_const thy = check_bare_const thy o Syntax.read_term_global thy;
+
+fun read_const thy = AxClass.unoverload_const thy o read_bare_const thy;
+
+
 (* const aliasses *)
 
 structure ConstAlias = TheoryDataFun
@@ -280,16 +249,29 @@
 
 fun add_const_alias thm thy =
   let
-    val lhs_rhs = case try Logic.dest_equals (Thm.prop_of thm)
+    val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
+     of SOME ofclass_eq => ofclass_eq
+      | _ => error ("Bad certificate: " ^ Display.string_of_thm thm);
+    val (T, class) = case try Logic.dest_of_class ofclass
+     of SOME T_class => T_class
+      | _ => error ("Bad certificate: " ^ Display.string_of_thm thm);
+    val tvar = case try Term.dest_TVar T
+     of SOME tvar => tvar
+      | _ => error ("Bad type: " ^ Display.string_of_thm thm);
+    val _ = if Term.add_tvars eqn [] = [tvar] then ()
+      else error ("Inconsistent type: " ^ Display.string_of_thm thm);
+    val lhs_rhs = case try Logic.dest_equals eqn
      of SOME lhs_rhs => lhs_rhs
-      | _ => error ("Not an equation: " ^ Display.string_of_thm thm);
-    val c_c' = case try (pairself (AxClass.unoverload_const thy o dest_Const)) lhs_rhs
+      | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
+    val c_c' = case try (pairself (check_const thy)) lhs_rhs
      of SOME c_c' => c_c'
-      | _ => error ("Not an equation with two constants: " ^ Display.string_of_thm thm);
-    val some_class = the_list (AxClass.class_of_param thy (snd c_c'));
+      | _ => error ("Not an equation with two constants: "
+          ^ Syntax.string_of_term_global thy eqn);
+    val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()
+      else error ("Inconsistent class: " ^ Display.string_of_thm thm);
   in thy |>
     ConstAlias.map (fn (alias, classes) =>
-      ((c_c', thm) :: alias, fold (insert (op =)) some_class classes))
+      ((c_c', thm) :: alias, insert (op =) class classes))
   end;
 
 fun resubst_alias thy =
@@ -306,19 +288,6 @@
 val triv_classes = snd o ConstAlias.get;
 
 
-(* reading constants as terms *)
-
-fun check_bare_const thy t = case try dest_Const t
- of SOME c_ty => c_ty
-  | NONE => error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
-
-fun check_const thy = AxClass.unoverload_const thy o check_bare_const thy;
-
-fun read_bare_const thy = check_bare_const thy o Syntax.read_term_global thy;
-
-fun read_const thy = AxClass.unoverload_const thy o read_bare_const thy;
-
-
 (* constructor sets *)
 
 fun constrset_of_consts thy cs =
@@ -440,8 +409,6 @@
 
 fun assert_eqn thy is_constr = error_thm (gen_assert_eqn thy is_constr is_constr);
 
-val const_typ_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
-
 
 (*those following are permissive wrt. to overloaded constants!*)
 
@@ -456,14 +423,14 @@
   o try_thm (gen_assert_eqn thy is_constr_head (K true))
   o rpair false o LocalDefs.meta_rewrite_rule (ProofContext.init thy);
 
-fun const_typ_eqn_unoverload thy thm =
+fun const_typ_eqn thy thm =
   let
-    val (c, ty) = const_typ_eqn thm;
+    val (c, ty) = (dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
     val c' = AxClass.unoverload_const thy (c, ty);
   in (c', ty) end;
 
-fun typscheme_eqn thy = typscheme thy o const_typ_eqn_unoverload thy;
-fun const_eqn thy = fst o const_typ_eqn_unoverload thy;
+fun typscheme_eqn thy = typscheme thy o const_typ_eqn thy;
+fun const_eqn thy = fst o const_typ_eqn thy;
 
 
 (* case cerificates *)
@@ -787,7 +754,7 @@
     val dtyps = the_dtyps exec
       |> Symtab.dest
       |> map (fn (dtco, (_, (vs, cos)) :: _) =>
-          (Syntax.string_of_typ_global thy (Type (dtco, map TFree vs)), cos))
+          (string_of_typ thy (Type (dtco, map TFree vs)), cos))
       |> sort (string_ord o pairself fst)
   in
     (Pretty.writeln o Pretty.chunks) [
@@ -817,7 +784,7 @@
             val max' = Thm.maxidx_of thm' + 1;
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
-        val ty1 :: tys = map (snd o const_typ_eqn) thms';
+        val ty1 :: tys = map (snd o const_typ_eqn thy) thms';
         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
           handle Type.TUNIFY =>
             error ("Type unificaton failed, while unifying code equations\n"
@@ -963,19 +930,6 @@
   Symtab.dest ((the_eqns o the_exec) thy)
   |> maps (Lazy.force o snd o snd o fst o snd);
 
-fun default_typscheme thy c =
-  let
-    fun the_const_typscheme c = (curry (typscheme thy) c o snd o dest_Const
-      o TermSubst.zero_var_indexes o curry Const "" o Sign.the_const_type thy) c;
-    fun strip_sorts (vs, ty) = (map (fn (v, _) => (v, [])) vs, ty);
-  in case AxClass.class_of_param thy c
-   of SOME class => ([(Name.aT, [class])], snd (the_const_typscheme c))
-    | NONE => if is_constr thy c
-        then strip_sorts (the_const_typscheme c)
-        else case get_eqns thy c
-         of (thm, _) :: _ => (typscheme_eqn thy o Drule.zero_var_indexes) thm
-          | [] => strip_sorts (the_const_typscheme c) end;
-
 end; (*struct*)
 
 
--- a/src/Tools/Code/code_preproc.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -102,6 +102,15 @@
 
 fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
+fun eqn_conv conv =
+  let
+    fun lhs_conv ct = if can Thm.dest_comb ct
+      then Conv.combination_conv lhs_conv conv ct
+      else Conv.all_conv ct;
+  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+
+val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
+
 fun term_of_conv thy f =
   Thm.cterm_of thy
   #> f
@@ -117,7 +126,7 @@
   in
     eqns
     |> apply_functrans thy c functrans
-    |> (map o apfst) (Code.rewrite_eqn pre)
+    |> (map o apfst) (rewrite_eqn pre)
     |> (map o apfst) (AxClass.unoverload thy)
     |> map (Code.assert_eqn thy)
     |> burrow_fst (Code.norm_args thy)
@@ -213,9 +222,19 @@
   (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
     (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
 
+fun default_typscheme_of thy c =
+  let
+    val ty = (snd o dest_Const o TermSubst.zero_var_indexes o curry Const c
+      o Type.strip_sorts o Sign.the_const_type thy) c;
+  in case AxClass.class_of_param thy c
+   of SOME class => ([(Name.aT, [class])], ty)
+    | NONE => Code.typscheme thy (c, ty)
+  end;
+
 fun tyscm_rhss_of thy c eqns =
   let
-    val tyscm = case eqns of [] => Code.default_typscheme thy c
+    val tyscm = case eqns
+     of [] => default_typscheme_of thy c
       | ((thm, _) :: _) => Code.typscheme_eqn thy thm;
     val rhss = consts_of thy eqns;
   in (tyscm, rhss) end;
@@ -381,6 +400,17 @@
        handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
   end;
 
+fun inst_thm thy tvars' thm =
+  let
+    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
+    val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
+    fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
+     of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
+          (tvar, (v, inter_sort (sort, sort'))))
+      | NONE => NONE;
+    val insts = map_filter mk_inst tvars;
+  in Thm.instantiate (insts, []) thm end;
+
 fun add_arity thy vardeps (class, tyco) =
   AList.default (op =)
     ((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
@@ -394,7 +424,7 @@
     val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
       Vartab.update ((v, 0), sort)) lhs;
     val eqns = proto_eqns
-      |> (map o apfst) (Code.inst_thm thy inst_tab);
+      |> (map o apfst) (inst_thm thy inst_tab);
     val (tyscm, rhss') = tyscm_rhss_of thy c eqns;
     val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr;
   in (map (pair c) rhss' @ rhss, eqngr') end;
--- a/src/Tools/Code/code_target.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/Code/code_target.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -286,7 +286,7 @@
 fun gen_add_syntax_const prep_const target raw_c raw_syn thy =
   let
     val c = prep_const thy raw_c;
-    fun check_args (syntax as (n, _)) = if n > Code.no_args thy c
+    fun check_args (syntax as (n, _)) = if n > Code.args_number thy c
       then error ("Too many arguments in syntax for constant " ^ quote c)
       else syntax;
   in case raw_syn
--- a/src/Tools/Code/code_thingol.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/Code/code_thingol.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -627,8 +627,8 @@
     fun arg_types num_args ty = (fst o chop num_args o fst o strip_type) ty;
     val tys = arg_types num_args (snd c_ty);
     val ty = nth tys t_pos;
-    fun mk_constr c t = let val n = Code.no_args thy c
-      in ((c, arg_types (Code.no_args thy c) (fastype_of t) ---> ty), n) end;
+    fun mk_constr c t = let val n = Code.args_number thy c
+      in ((c, arg_types n (fastype_of t) ---> ty), n) end;
     val constrs = if null case_pats then []
       else map2 mk_constr case_pats (nth_drop t_pos ts);
     fun casify naming constrs ty ts =
--- a/src/Tools/nbe.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/nbe.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -393,10 +393,11 @@
           let
             val ts' = take_until is_dict ts;
             val c = const_of_idx idx;
-            val (_, T) = Code.default_typscheme thy c;
-            val T' = map_type_tfree (fn (v, _) => TypeInfer.param typidx (v, [])) T;
+            val T = map_type_tvar (fn ((v, i), _) =>
+              TypeInfer.param typidx (v ^ string_of_int i, []))
+                (Sign.the_const_type thy c);
             val typidx' = typidx + 1;
-          in of_apps bounds (Term.Const (c, T'), ts') typidx' end
+          in of_apps bounds (Term.Const (c, T), ts') typidx' end
       | of_univ bounds (BVar (n, ts)) typidx =
           of_apps bounds (Bound (bounds - n - 1), ts) typidx
       | of_univ bounds (t as Abs _) typidx =