code certificates as integral part of code generation
authorhaftmann
Tue, 12 Jan 2010 16:27:42 +0100
changeset 34891 99b9a6290446
parent 34877 ded5b770ec1c
child 34892 6144d233b99a
code certificates as integral part of code generation
src/HOL/Tools/recfun_codegen.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_thingol.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Tue Jan 12 09:59:45 2010 +0100
+++ b/src/HOL/Tools/recfun_codegen.ML	Tue Jan 12 16:27:42 2010 +0100
@@ -43,8 +43,10 @@
 fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else
   let
     val c = AxClass.unoverload_const thy (raw_c, T);
-    val raw_thms = Code.these_eqns thy c
+    val raw_thms = Code.get_cert thy I c
+      |> Code.eqns_of_cert thy
       |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
+      |> map (AxClass.overload thy)
       |> filter (is_instance T o snd o const_of o prop_of);
     val module_name = case Symtab.lookup (ModuleData.get thy) c
      of SOME module_name => module_name
--- a/src/Pure/Isar/code.ML	Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Pure/Isar/code.ML	Tue Jan 12 16:27:42 2010 +0100
@@ -28,17 +28,13 @@
   val mk_eqn_warning: theory -> thm -> (thm * bool) option
   val mk_eqn_liberal: theory -> thm -> (thm * bool) option
   val assert_eqn: theory -> thm * bool -> thm * bool
-  val assert_eqns_const: theory -> string
-    -> (thm * bool) list -> (thm * bool) list
   val const_typ_eqn: theory -> thm -> string * typ
-  val typscheme_eqn: theory -> thm -> (string * sort) list * typ
-  val typscheme_eqns: theory -> string -> thm list -> (string * sort) list * typ
-  val standard_typscheme: theory -> thm list -> thm list
   type cert = thm * bool list
-  val cert_of_eqns: theory -> (thm * bool) list -> cert
+  val empty_cert: theory -> string -> cert
+  val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
   val constrain_cert: theory -> sort list -> cert -> cert
-  val dest_cert: theory -> cert -> (string * ((string * sort) list * typ)) * ((term list * term) * bool) list
   val eqns_of_cert: theory -> cert -> (thm * bool) list
+  val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
 
   (*executable code*)
   val add_type: string -> theory -> theory
@@ -61,8 +57,7 @@
   val add_undefined: string -> theory -> theory
   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
   val get_datatype_of_constr: theory -> string -> string option
-  val these_eqns: theory -> string -> (thm * bool) list
-  val eqn_cert: theory -> string -> cert
+  val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
   val get_case_scheme: theory -> string -> (int * (int * string list)) option
   val undefineds: theory -> string list
   val print_codesetup: theory -> unit
@@ -531,20 +526,6 @@
 
 fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
 
-fun typscheme_eqn thy = typscheme thy o apsnd Logic.unvarifyT o const_typ_eqn thy;
-
-fun typscheme_eqns thy c [] = 
-      let
-        val raw_ty = const_typ thy c;
-        val tvars = Term.add_tvar_namesT raw_ty [];
-        val tvars' = case AxClass.class_of_param thy c
-         of SOME class => [TFree (Name.aT, [class])]
-          | NONE => Name.invent_list [] Name.aT (length tvars)
-              |> map (fn v => TFree (v, []));
-        val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
-      in logical_typscheme thy (c, ty) end
-  | typscheme_eqns thy c (thm :: _) = typscheme_eqn thy thm;
-
 fun assert_eqns_const thy c eqns =
   let
     fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
@@ -555,93 +536,97 @@
 
 (* code equation certificates *)
 
-fun standard_typscheme thy thms =
-  let
-    fun tvars_of T = rev (Term.add_tvarsT T []);
-    val vss = map (tvars_of o snd o head_eqn) thms;
-    fun inter_sorts vs =
-      fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs [];
-    val sorts = map_transpose inter_sorts vss;
-    val vts = Name.names Name.context Name.aT sorts
-      |> map (fn (v, sort) => TVar ((v, 0), sort));
-  in map2 (fn vs => Thm.certify_instantiate (vs ~~ vts, [])) vss thms end;
-
 type cert = thm * bool list;
 
-fun cert_of_eqns thy [] = (Drule.dummy_thm, [])
-  | cert_of_eqns thy eqns = 
+fun mk_head_cterm thy (c, ty) =
+  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
+
+fun empty_cert thy c = 
+  let
+    val raw_ty = const_typ thy c;
+    val tvars = Term.add_tvar_namesT raw_ty [];
+    val tvars' = case AxClass.class_of_param thy c
+     of SOME class => [TFree (Name.aT, [class])]
+      | NONE => Name.invent_list [] Name.aT (length tvars)
+          |> map (fn v => TFree (v, []));
+    val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
+    val chead = mk_head_cterm thy (c, ty);
+  in (Thm.weaken chead Drule.dummy_thm, []) end;
+
+fun cert_of_eqns thy c [] = empty_cert thy c
+  | cert_of_eqns thy c eqns = 
       let
-        val propers = map snd eqns;
-        val thms as thm :: _ = (map Thm.freezeT o standard_typscheme thy o map fst) eqns; (*FIXME*)
-        val (c, ty) = head_eqn thm;
-        val head_thm = Thm.assume (Thm.cterm_of thy (Logic.mk_equals
-          (Free ("HEAD", ty), Const (c, ty)))) |> Thm.symmetric;
+        val _ = assert_eqns_const thy c eqns;
+        val (thms, propers) = split_list eqns;
+        fun tvars_of T = rev (Term.add_tvarsT T []);
+        val vss = map (tvars_of o snd o head_eqn) thms;
+        fun inter_sorts vs =
+          fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs [];
+        val sorts = map_transpose inter_sorts vss;
+        val vts = Name.names Name.context Name.aT sorts;
+        val thms as thm :: _ =
+          map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
+        val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
         fun head_conv ct = if can Thm.dest_comb ct
           then Conv.fun_conv head_conv ct
           else Conv.rewr_conv head_thm ct;
         val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
-        val cert = Conjunction.intr_balanced (map rewrite_head thms);
-      in (cert, propers) end;
+        val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
+      in (cert_thm, propers) end;
 
-fun head_cert thy cert =
+fun head_cert thy cert_thm =
+  let
+    val [head] = Thm.hyps_of cert_thm;
+    val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
+  in ((c, typscheme thy (c, ty)), (head, h)) end;
+
+fun constrain_cert thy sorts (cert_thm, propers) =
   let
-    val [head] = Thm.hyps_of cert;
-    val (Free (h, _), Const (c, ty)) = (Logic.dest_equals o the_single o Thm.hyps_of) cert;
-  in ((c, typscheme thy (AxClass.unoverload_const thy (c, ty), ty)), (head, h)) end;
+    val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
+    val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
+    val head' = (map_types o map_atyps)
+      (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
+    val inst = (map2 (fn (v, sort) => fn sort' =>
+      pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
+    val cert_thm' = cert_thm
+      |> Thm.implies_intr (Thm.cterm_of thy head)
+      |> Thm.varifyT
+      |> Thm.instantiate inst
+      |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
+  in (cert_thm', propers) end;
 
-fun constrain_cert thy sorts (cert, []) = (cert, [])
-  | constrain_cert thy sorts (cert, propers) =
+fun eqns_of_cert thy (cert_thm, []) = []
+  | eqns_of_cert thy (cert_thm, propers) =
       let
-        val ((c, (vs, _)), (head, _)) = head_cert thy cert;
-        val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
-        val head' = (map_types o map_atyps)
-          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
-        val inst = (map2 (fn (v, sort) => fn sort' =>
-          pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts ,[]);
-        val cert' = cert
-          |> Thm.implies_intr (Thm.cterm_of thy head)
-          |> Thm.varifyT
-          |> Thm.instantiate inst
-          |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'))
-      in (cert', propers) end;
-
-fun dest_cert thy (cert, propers) =
-  let
-    val (c_vs_ty, (head, h)) = head_cert thy cert;
-    val equations = cert
-      |> Thm.prop_of
-      |> Logic.dest_conjunction_balanced (length propers)
-      |> map Logic.dest_equals
-      |> (map o apfst) strip_comb
-      |> (map o apfst) (fn (Free (h', _), ts) => case h = h' of True => ts)
-  in (c_vs_ty, equations ~~ propers) end;
-
-fun eqns_of_cert thy (cert, []) = []
-  | eqns_of_cert thy (cert, propers) =
-      let
-        val (_, (head, _)) = head_cert thy cert;
-        val thms = cert
+        val (_, (head, _)) = head_cert thy cert_thm;
+        val thms = cert_thm
           |> LocalDefs.expand [Thm.cterm_of thy head]
           |> Thm.varifyT
           |> Conjunction.elim_balanced (length propers)
       in thms ~~ propers end;
 
+fun dest_cert thy (cert as (cert_thm, propers)) =
+  let
+    val eqns = eqns_of_cert thy cert;
+    val ((_, vs_ty), _) = head_cert thy cert_thm;
+    val equations = if null propers then [] else cert_thm
+      |> Thm.prop_of
+      |> Logic.dest_conjunction_balanced (length propers)
+      |> map Logic.dest_equals
+      |> (map o apfst) (snd o strip_comb)
+  in (vs_ty, equations ~~ eqns) end;
+
 
 (* code equation access *)
 
-fun these_eqns thy c =
+fun get_cert thy f c =
   Symtab.lookup ((the_eqns o the_exec) thy) c
   |> Option.map (snd o snd o fst)
   |> these
   |> (map o apfst) (Thm.transfer thy)
-  |> burrow_fst (standard_typscheme thy);
-
-fun eqn_cert thy c =
-  Symtab.lookup ((the_eqns o the_exec) thy) c
-  |> Option.map (snd o snd o fst)
-  |> these
-  |> (map o apfst) (Thm.transfer thy)
-  |> cert_of_eqns thy;
+  |> f
+  |> (map o apfst) (AxClass.unoverload thy)
+  |> cert_of_eqns thy c;
 
 
 (* cases *)
--- a/src/Tools/Code/code_preproc.ML	Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_preproc.ML	Tue Jan 12 16:27:42 2010 +0100
@@ -18,7 +18,7 @@
 
   type code_algebra
   type code_graph
-  val eqns: code_graph -> string -> (thm * bool) list
+  val cert: code_graph -> string -> Code.cert
   val sortargs: code_graph -> string -> sort list
   val all: code_graph -> string list
   val pretty: theory -> code_graph -> Pretty.T
@@ -53,8 +53,8 @@
     let
       val pre = Simplifier.merge_ss (pre1, pre2);
       val post = Simplifier.merge_ss (post1, post2);
-      val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2);
-        (* FIXME handle AList.DUP (!?) *)
+      val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2)
+        handle AList.DUP => error ("Duplicate function transformer");
     in make_thmproc ((pre, post), functrans) end;
 
 structure Code_Preproc_Data = Theory_Data
@@ -102,23 +102,14 @@
 
 (* post- and preprocessing *)
 
-fun apply_functrans thy c _ [] = []
-  | apply_functrans thy c [] eqns = eqns
-  | apply_functrans thy c functrans eqns = eqns
-      |> perhaps (perhaps_loop (perhaps_apply functrans))
-      |> Code.assert_eqns_const thy c
-      (*FIXME in future, the check here should be more accurate wrt. type schemes
-      -- perhaps by means of upcoming code certificates with a corresponding
-         preprocessor protocol*);
-
 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
-fun eqn_conv conv =
+fun eqn_conv conv ct =
   let
     fun lhs_conv ct = if can Thm.dest_comb ct
       then Conv.combination_conv lhs_conv conv ct
       else Conv.all_conv ct;
-  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
 
 val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
 
@@ -129,17 +120,15 @@
   #> Logic.dest_equals
   #> snd;
 
-fun preprocess thy c eqns =
+fun preprocess thy eqns =
   let
     val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
     val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
       o the_thmproc) thy;
   in
     eqns
-    |> apply_functrans thy c functrans
+    |> perhaps (perhaps_loop (perhaps_apply functrans))
     |> (map o apfst) (rewrite_eqn pre)
-    |> (map o apfst) (AxClass.unoverload thy)
-    |> map (Code.assert_eqn thy)
   end;
 
 fun preprocess_conv thy ct =
@@ -196,20 +185,20 @@
 (** sort algebra and code equation graph types **)
 
 type code_algebra = (sort -> sort) * Sorts.algebra;
-type code_graph = ((string * sort) list * (thm * bool) list) Graph.T;
+type code_graph = ((string * sort) list * Code.cert) Graph.T;
 
-fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr);
-fun sortargs eqngr = map snd o fst o Graph.get_node eqngr
+fun cert eqngr = snd o Graph.get_node eqngr;
+fun sortargs eqngr = map snd o fst o Graph.get_node eqngr;
 fun all eqngr = Graph.keys eqngr;
 
 fun pretty thy eqngr =
   AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
   |> (map o apfst) (Code.string_of_const thy)
   |> sort (string_ord o pairself fst)
-  |> map (fn (s, thms) =>
+  |> map (fn (s, cert) =>
        (Pretty.block o Pretty.fbreaks) (
          Pretty.str s
-         :: map (Display.pretty_thm_global thy o fst) thms
+         :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
        ))
   |> Pretty.chunks;
 
@@ -227,12 +216,12 @@
   map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
     o maps (#params o AxClass.get_info thy);
 
-fun typscheme_rhss thy c eqns =
+fun typscheme_rhss thy c cert =
   let
-    val tyscm = Code.typscheme_eqns thy c (map fst eqns);
+    val (tyscm, equations) = Code.dest_cert thy cert;
     val rhss = [] |> (fold o fold o fold_aterms)
-      (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
-        (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
+      (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
+        (map (op :: o swap o fst) equations);
   in (tyscm, rhss) end;
 
 
@@ -259,7 +248,7 @@
       | NONE => Free;
 
 type vardeps_data = ((string * styp list) list * class list) Vargraph.T
-  * (((string * sort) list * (thm * bool) list) Symtab.table
+  * (((string * sort) list * Code.cert) Symtab.table
     * (class * string) list);
 
 val empty_vardeps_data : vardeps_data =
@@ -270,12 +259,11 @@
 
 fun obtain_eqns thy eqngr c =
   case try (Graph.get_node eqngr) c
-   of SOME (lhs, eqns) => ((lhs, []), [])
+   of SOME (lhs, cert) => ((lhs, []), cert)
     | NONE => let
-        val eqns = Code.these_eqns thy c
-          |> preprocess thy c;
-        val ((lhs, _), rhss) = typscheme_rhss thy c eqns;
-      in ((lhs, rhss), eqns) end;
+        val cert = Code.get_cert thy (preprocess thy) c;
+        val ((lhs, _), rhss) = typscheme_rhss thy c cert;
+      in ((lhs, rhss), cert) end;
 
 fun obtain_instance thy arities (inst as (class, tyco)) =
   case AList.lookup (op =) arities inst
@@ -396,32 +384,27 @@
        handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
   end;
 
-fun inst_thm thy tvars' thm =
+fun inst_cert thy lhs cert =
   let
-    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
-    val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
-    fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
-     of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
-          (tvar, (v, inter_sort (sort, sort'))))
-      | NONE => NONE;
-    val insts = map_filter mk_inst tvars;
-  in Thm.instantiate (insts, []) thm end;
+    val ((vs, _), _) = Code.dest_cert thy cert;
+    val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
+     of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
+      | NONE => sort) vs;
+  in Code.constrain_cert thy sorts cert end;
 
 fun add_arity thy vardeps (class, tyco) =
   AList.default (op =) ((class, tyco),
-    map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco));
+    map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
+      (Sign.arity_number thy tyco));
 
-fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) =
+fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) =
   if can (Graph.get_node eqngr) c then (rhss, eqngr)
   else let
     val lhs = map_index (fn (k, (v, _)) =>
       (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
-    val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
-      Vartab.update ((v, 0), sort)) lhs;
-    val eqns = proto_eqns
-      |> (map o apfst) (inst_thm thy inst_tab);
-    val ((vs, _), rhss') = typscheme_rhss thy c eqns;
-    val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr;
+    val cert = inst_cert thy lhs proto_cert;
+    val ((vs, _), rhss') = typscheme_rhss thy c cert;
+    val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
   in (map (pair c) rhss' @ rhss, eqngr') end;
 
 fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) =
@@ -435,7 +418,7 @@
     val pp = Syntax.pp_global thy;
     val algebra = Sorts.subalgebra pp (is_proper_class thy)
       (AList.lookup (op =) arities') (Sign.classes_of thy);
-    val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr);
+    val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr);
     fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra)
       (rhs ~~ sortargs eqngr' c);
     val eqngr'' = fold (fn (c, rhs) => fold
--- a/src/Tools/Code/code_thingol.ML	Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_thingol.ML	Tue Jan 12 16:27:42 2010 +0100
@@ -447,7 +447,7 @@
   in Thm.certify_instantiate ([], var_subst) thm end;
 
 fun canonize_thms thy = map (Thm.transfer thy)
-  #> Code.standard_typscheme thy #> desymbolize_tvars thy
+  #> desymbolize_tvars thy
   #> same_arity thy #> map (desymbolize_vars thy);
 
 
@@ -612,10 +612,10 @@
     fun stmt_classparam class =
       ensure_class thy algbr eqngr class
       #>> (fn class => Classparam (c, class));
-    fun stmt_fun raw_eqns =
+    fun stmt_fun cert =
       let
-        val eqns = burrow_fst (canonize_thms thy) raw_eqns;
-        val (vs, ty) = Code.typscheme_eqns thy c (map fst eqns);
+        val ((vs, ty), raw_eqns) = Code.dest_cert thy cert;
+        val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns);
       in
         fold_map (translate_tyvar_sort thy algbr eqngr) vs
         ##>> translate_typ thy algbr eqngr ty
@@ -626,7 +626,7 @@
      of SOME tyco => stmt_datatypecons tyco
       | NONE => (case AxClass.class_of_param thy c
          of SOME class => stmt_classparam class
-          | NONE => stmt_fun (Code_Preproc.eqns eqngr c))
+          | NONE => stmt_fun (Code_Preproc.cert eqngr c))
   in ensure_stmt lookup_const (declare_const thy) stmt_const c end
 and ensure_class thy (algbr as (_, algebra)) eqngr class =
   let
@@ -933,11 +933,7 @@
   let
     val (_, eqngr) = Code_Preproc.obtain thy consts [];
     val all_consts = Graph.all_succs eqngr consts;
-  in
-    eqngr
-    |> Graph.subgraph (member (op =) all_consts) 
-    |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy))
-  end;
+  in Graph.subgraph (member (op =) all_consts) eqngr end;
 
 fun code_thms thy = Pretty.writeln o Code_Preproc.pretty thy o code_depgr thy;