explicit abstract type of code certificates
authorhaftmann
Wed, 13 Jan 2010 12:20:37 +0100
changeset 34895 19fd499cddff
parent 34894 fadbdd350dd1
child 34896 a22b09addd78
child 34898 62d70417f8ce
explicit abstract type of code certificates
src/HOL/Tools/recfun_codegen.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_thingol.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Wed Jan 13 10:18:45 2010 +0100
+++ b/src/HOL/Tools/recfun_codegen.ML	Wed Jan 13 12:20:37 2010 +0100
@@ -36,7 +36,7 @@
       let val (_, T) = Code.const_typ_eqn thy thm
       in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T
         then [thm]
-        else [Code_Thingol.expand_eta thy 1 thm]
+        else [Code.expand_eta thy 1 thm]
       end
   | avoid_value thy thms = thms;
 
@@ -44,8 +44,9 @@
   let
     val c = AxClass.unoverload_const thy (raw_c, T);
     val raw_thms = Code.get_cert thy (Code_Preproc.preprocess_functrans thy) c
-      |> Code.eqns_of_cert thy
-      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
+      |> Code.equations_thms_cert thy
+      |> snd
+      |> map_filter (fn (_, (thm, proper)) => if proper then SOME thm else NONE)
       |> map (AxClass.overload thy)
       |> filter (is_instance T o snd o const_of o prop_of);
     val module_name = case Symtab.lookup (ModuleData.get thy) c
@@ -57,7 +58,6 @@
     raw_thms
     |> preprocess thy
     |> avoid_value thy
-    |> Code_Thingol.canonize_thms thy
     |> rpair module_name
   end;
 
--- a/src/Pure/Isar/code.ML	Wed Jan 13 10:18:45 2010 +0100
+++ b/src/Pure/Isar/code.ML	Wed Jan 13 12:20:37 2010 +0100
@@ -29,12 +29,15 @@
   val mk_eqn_liberal: theory -> thm -> (thm * bool) option
   val assert_eqn: theory -> thm * bool -> thm * bool
   val const_typ_eqn: theory -> thm -> string * typ
-  type cert = thm * bool list
+  val expand_eta: theory -> int -> thm -> thm
+  type cert
   val empty_cert: theory -> string -> cert
   val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
   val constrain_cert: theory -> sort list -> cert -> cert
-  val eqns_of_cert: theory -> cert -> (thm * bool) list
-  val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+  val typscheme_cert: theory -> cert -> (string * sort) list * typ
+  val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
+  val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+  val pretty_cert: theory -> cert -> Pretty.T list
 
   (*executable code*)
   val add_type: string -> theory -> theory
@@ -511,20 +514,71 @@
 
 fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
 
-fun assert_eqns_const thy c eqns =
+
+(* technical transformations of code equations *)
+
+fun expand_eta thy k thm =
+  let
+    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
+    val (_, args) = strip_comb lhs;
+    val l = if k = ~1
+      then (length o fst o strip_abs) rhs
+      else Int.max (0, k - length args);
+    val (raw_vars, _) = Term.strip_abs_eta l rhs;
+    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+      raw_vars;
+    fun expand (v, ty) thm = Drule.fun_cong_rule thm
+      (Thm.cterm_of thy (Var ((v, 0), ty)));
+  in
+    thm
+    |> fold expand vars
+    |> Conv.fconv_rule Drule.beta_eta_conversion
+  end;
+
+fun same_arity thy thms =
   let
-    fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
-      then eqn else error ("Wrong head of code equation,\nexpected constant "
-        ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)
-  in map (cert o assert_eqn thy) eqns end;
+    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+    val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
+  in map (expand_eta thy k) thms end;
+
+fun mk_desymbolization pre post mk vs =
+  let
+    val names = map (pre o fst o fst) vs
+      |> map (Name.desymbolize false)
+      |> Name.variant_list []
+      |> map post;
+  in map_filter (fn (((v, i), x), v') =>
+    if v = v' andalso i = 0 then NONE
+    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
+  end;
+
+fun desymbolize_tvars thy thms =
+  let
+    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
+  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
+
+fun desymbolize_vars thy thm =
+  let
+    val vs = Term.add_vars (Thm.prop_of thm) [];
+    val var_subst = mk_desymbolization I I Var vs;
+  in Thm.certify_instantiate ([], var_subst) thm end;
+
+fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy);
 
 
 (* code equation certificates *)
 
-type cert = thm * bool list;
+fun build_head thy (c, ty) =
+  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
 
-fun mk_head_cterm thy (c, ty) =
-  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
+fun get_head thy cert_thm =
+  let
+    val [head] = (#hyps o Thm.crep_thm) cert_thm;
+    val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
+  in (typscheme thy (c, ty), head) end;
+
+abstype cert = Cert of thm * bool list with
 
 fun empty_cert thy c = 
   let
@@ -535,14 +589,18 @@
       | NONE => Name.invent_list [] Name.aT (length tvars)
           |> map (fn v => TFree (v, []));
     val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
-    val chead = mk_head_cterm thy (c, ty);
-  in (Thm.weaken chead Drule.dummy_thm, []) end;
+    val chead = build_head thy (c, ty);
+  in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
 
 fun cert_of_eqns thy c [] = empty_cert thy c
-  | cert_of_eqns thy c eqns = 
+  | cert_of_eqns thy c raw_eqns = 
       let
-        val _ = assert_eqns_const thy c eqns;
+        val eqns = burrow_fst (canonize_thms thy) raw_eqns;
+        val _ = map (assert_eqn thy) eqns;
         val (thms, propers) = split_list eqns;
+        val _ = map (fn thm => if c = const_eqn thy thm then ()
+          else error ("Wrong head of code equation,\nexpected constant "
+            ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms;
         fun tvars_of T = rev (Term.add_tvarsT T []);
         val vss = map (tvars_of o snd o head_eqn) thms;
         fun inter_sorts vs =
@@ -551,55 +609,59 @@
         val vts = Name.names Name.context Name.aT sorts;
         val thms as thm :: _ =
           map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
-        val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
+        val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms))));
         fun head_conv ct = if can Thm.dest_comb ct
           then Conv.fun_conv head_conv ct
           else Conv.rewr_conv head_thm ct;
         val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
         val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
-      in (cert_thm, propers) end;
-
-fun head_cert thy cert_thm =
-  let
-    val [head] = Thm.hyps_of cert_thm;
-    val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
-  in ((c, typscheme thy (c, ty)), (head, h)) end;
+      in Cert (cert_thm, propers) end;
 
-fun constrain_cert thy sorts (cert_thm, propers) =
+fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
   let
-    val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
-    val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
-    val head' = (map_types o map_atyps)
-      (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
-    val inst = (map2 (fn (v, sort) => fn sort' =>
-      pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
+    val ((vs, _), head) = get_head thy cert_thm;
+    val subst = map2 (fn (v, sort) => fn sort' =>
+      (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
+    val head' = Thm.term_of head
+      |> (map_types o map_atyps)
+          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
+      |> Thm.cterm_of thy;
+    val inst = map2 (fn (v, sort) => fn (_, sort') =>
+      (((v, 0), sort), TFree (v, sort'))) vs subst;
     val cert_thm' = cert_thm
-      |> Thm.implies_intr (Thm.cterm_of thy head)
+      |> Thm.implies_intr head
       |> Thm.varifyT
-      |> Thm.instantiate inst
-      |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
-  in (cert_thm', propers) end;
+      |> Thm.certify_instantiate (inst, [])
+      |> Thm.elim_implies (Thm.assume head');
+  in (Cert (cert_thm', propers)) end;
 
-fun eqns_of_cert thy (cert_thm, []) = []
-  | eqns_of_cert thy (cert_thm, propers) =
-      let
-        val (_, (head, _)) = head_cert thy cert_thm;
-        val thms = cert_thm
-          |> LocalDefs.expand [Thm.cterm_of thy head]
-          |> Thm.varifyT
-          |> Conjunction.elim_balanced (length propers)
-      in thms ~~ propers end;
+fun typscheme_cert thy (Cert (cert_thm, _)) =
+  fst (get_head thy cert_thm);
 
-fun dest_cert thy (cert as (cert_thm, propers)) =
+fun equations_cert thy (cert as Cert (cert_thm, propers)) =
   let
-    val eqns = eqns_of_cert thy cert;
-    val ((_, vs_ty), _) = head_cert thy cert_thm;
-    val equations = if null propers then [] else cert_thm
-      |> Thm.prop_of
+    val tyscm = typscheme_cert thy cert;
+    val equations = if null propers then [] else
+      Thm.prop_of cert_thm
       |> Logic.dest_conjunction_balanced (length propers)
       |> map Logic.dest_equals
       |> (map o apfst) (snd o strip_comb)
-  in (vs_ty, equations ~~ eqns) end;
+  in (tyscm, equations) end;
+
+fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
+  let
+    val (tyscm, equations) = equations_cert thy cert;
+    val thms = if null propers then [] else
+      cert_thm
+      |> LocalDefs.expand [snd (get_head thy cert_thm)]
+      |> Thm.varifyT
+      |> Conjunction.elim_balanced (length propers)
+  in (tyscm, equations ~~ (thms ~~ propers)) end;
+
+fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
+  o snd o equations_thms_cert thy;
+
+end;
 
 
 (* code equation access *)
--- a/src/Tools/Code/code_preproc.ML	Wed Jan 13 10:18:45 2010 +0100
+++ b/src/Tools/Code/code_preproc.ML	Wed Jan 13 12:20:37 2010 +0100
@@ -199,11 +199,7 @@
   AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
   |> (map o apfst) (Code.string_of_const thy)
   |> sort (string_ord o pairself fst)
-  |> map (fn (s, cert) =>
-       (Pretty.block o Pretty.fbreaks) (
-         Pretty.str s
-         :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
-       ))
+  |> map (fn (s, cert) => (Pretty.block o Pretty.fbreaks) (Pretty.str s :: Code.pretty_cert thy cert))
   |> Pretty.chunks;
 
 
@@ -220,13 +216,13 @@
   map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
     o maps (#params o AxClass.get_info thy);
 
-fun typscheme_rhss thy c cert =
+fun typargs_rhss thy c cert =
   let
-    val (tyscm, equations) = Code.dest_cert thy cert;
+    val ((vs, _), equations) = Code.equations_cert thy cert;
     val rhss = [] |> (fold o fold o fold_aterms)
       (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
-        (map (op :: o swap o fst) equations);
-  in (tyscm, rhss) end;
+        (map (op :: o swap) equations);
+  in (vs, rhss) end;
 
 
 (* data structures *)
@@ -266,7 +262,7 @@
    of SOME (lhs, cert) => ((lhs, []), cert)
     | NONE => let
         val cert = Code.get_cert thy (preprocess thy) c;
-        val ((lhs, _), rhss) = typscheme_rhss thy c cert;
+        val (lhs, rhss) = typargs_rhss thy c cert;
       in ((lhs, rhss), cert) end;
 
 fun obtain_instance thy arities (inst as (class, tyco)) =
@@ -388,14 +384,6 @@
        handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
   end;
 
-fun inst_cert thy lhs cert =
-  let
-    val ((vs, _), _) = Code.dest_cert thy cert;
-    val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
-     of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
-      | NONE => sort) vs;
-  in Code.constrain_cert thy sorts cert end;
-
 fun add_arity thy vardeps (class, tyco) =
   AList.default (op =) ((class, tyco),
     map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
@@ -406,8 +394,8 @@
   else let
     val lhs = map_index (fn (k, (v, _)) =>
       (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
-    val cert = inst_cert thy lhs proto_cert;
-    val ((vs, _), rhss') = typscheme_rhss thy c cert;
+    val cert = Code.constrain_cert thy (map snd lhs) proto_cert;
+    val (vs, rhss') = typargs_rhss thy c cert;
     val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
   in (map (pair c) rhss' @ rhss, eqngr') end;
 
--- a/src/Tools/Code/code_thingol.ML	Wed Jan 13 10:18:45 2010 +0100
+++ b/src/Tools/Code/code_thingol.ML	Wed Jan 13 12:20:37 2010 +0100
@@ -86,8 +86,6 @@
     -> ((string * stmt) list * (string * stmt) list
       * ((string * stmt) list * (string * stmt) list)) list
 
-  val expand_eta: theory -> int -> thm -> thm
-  val canonize_thms: theory -> thm list -> thm list
   val read_const_exprs: theory -> string list -> string list * string list
   val consts_program: theory -> string list -> string list * (naming * program)
   val eval_conv: theory
@@ -397,60 +395,6 @@
 end; (* local *)
 
 
-(** technical transformations of code equations **)
-
-fun expand_eta thy k thm =
-  let
-    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
-    val (_, args) = strip_comb lhs;
-    val l = if k = ~1
-      then (length o fst o strip_abs) rhs
-      else Int.max (0, k - length args);
-    val (raw_vars, _) = Term.strip_abs_eta l rhs;
-    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
-      raw_vars;
-    fun expand (v, ty) thm = Drule.fun_cong_rule thm
-      (Thm.cterm_of thy (Var ((v, 0), ty)));
-  in
-    thm
-    |> fold expand vars
-    |> Conv.fconv_rule Drule.beta_eta_conversion
-  end;
-
-fun same_arity thy thms =
-  let
-    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
-    val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
-  in map (expand_eta thy k) thms end;
-
-fun mk_desymbolization pre post mk vs =
-  let
-    val names = map (pre o fst o fst) vs
-      |> map (Name.desymbolize false)
-      |> Name.variant_list []
-      |> map post;
-  in map_filter (fn (((v, i), x), v') =>
-    if v = v' andalso i = 0 then NONE
-    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
-  end;
-
-fun desymbolize_tvars thy thms =
-  let
-    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
-    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
-  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
-
-fun desymbolize_vars thy thm =
-  let
-    val vs = Term.add_vars (Thm.prop_of thm) [];
-    val var_subst = mk_desymbolization I I Var vs;
-  in Thm.certify_instantiate ([], var_subst) thm end;
-
-fun canonize_thms thy = map (Thm.transfer thy)
-  #> desymbolize_tvars thy
-  #> same_arity thy #> map (desymbolize_vars thy);
-
-
 (** statements, abstract programs **)
 
 type typscheme = (vname * sort) list * itype;
@@ -614,8 +558,8 @@
       #>> (fn class => Classparam (c, class));
     fun stmt_fun cert =
       let
-        val ((vs, ty), raw_eqns) = Code.dest_cert thy cert;
-        val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns);
+        val ((vs, ty), raw_eqns) = Code.equations_thms_cert thy cert;
+        val eqns = map snd raw_eqns;
       in
         fold_map (translate_tyvar_sort thy algbr eqngr) vs
         ##>> translate_typ thy algbr eqngr ty