src/Pure/Isar/code.ML
changeset 34895 19fd499cddff
parent 34894 fadbdd350dd1
child 34901 0d6a2ae86525
--- a/src/Pure/Isar/code.ML	Wed Jan 13 10:18:45 2010 +0100
+++ b/src/Pure/Isar/code.ML	Wed Jan 13 12:20:37 2010 +0100
@@ -29,12 +29,15 @@
   val mk_eqn_liberal: theory -> thm -> (thm * bool) option
   val assert_eqn: theory -> thm * bool -> thm * bool
   val const_typ_eqn: theory -> thm -> string * typ
-  type cert = thm * bool list
+  val expand_eta: theory -> int -> thm -> thm
+  type cert
   val empty_cert: theory -> string -> cert
   val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
   val constrain_cert: theory -> sort list -> cert -> cert
-  val eqns_of_cert: theory -> cert -> (thm * bool) list
-  val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+  val typscheme_cert: theory -> cert -> (string * sort) list * typ
+  val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
+  val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+  val pretty_cert: theory -> cert -> Pretty.T list
 
   (*executable code*)
   val add_type: string -> theory -> theory
@@ -511,20 +514,71 @@
 
 fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
 
-fun assert_eqns_const thy c eqns =
+
+(* technical transformations of code equations *)
+
+fun expand_eta thy k thm =
+  let
+    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
+    val (_, args) = strip_comb lhs;
+    val l = if k = ~1
+      then (length o fst o strip_abs) rhs
+      else Int.max (0, k - length args);
+    val (raw_vars, _) = Term.strip_abs_eta l rhs;
+    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+      raw_vars;
+    fun expand (v, ty) thm = Drule.fun_cong_rule thm
+      (Thm.cterm_of thy (Var ((v, 0), ty)));
+  in
+    thm
+    |> fold expand vars
+    |> Conv.fconv_rule Drule.beta_eta_conversion
+  end;
+
+fun same_arity thy thms =
   let
-    fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
-      then eqn else error ("Wrong head of code equation,\nexpected constant "
-        ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)
-  in map (cert o assert_eqn thy) eqns end;
+    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+    val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
+  in map (expand_eta thy k) thms end;
+
+fun mk_desymbolization pre post mk vs =
+  let
+    val names = map (pre o fst o fst) vs
+      |> map (Name.desymbolize false)
+      |> Name.variant_list []
+      |> map post;
+  in map_filter (fn (((v, i), x), v') =>
+    if v = v' andalso i = 0 then NONE
+    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
+  end;
+
+fun desymbolize_tvars thy thms =
+  let
+    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
+  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
+
+fun desymbolize_vars thy thm =
+  let
+    val vs = Term.add_vars (Thm.prop_of thm) [];
+    val var_subst = mk_desymbolization I I Var vs;
+  in Thm.certify_instantiate ([], var_subst) thm end;
+
+fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy);
 
 
 (* code equation certificates *)
 
-type cert = thm * bool list;
+fun build_head thy (c, ty) =
+  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
 
-fun mk_head_cterm thy (c, ty) =
-  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
+fun get_head thy cert_thm =
+  let
+    val [head] = (#hyps o Thm.crep_thm) cert_thm;
+    val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
+  in (typscheme thy (c, ty), head) end;
+
+abstype cert = Cert of thm * bool list with
 
 fun empty_cert thy c = 
   let
@@ -535,14 +589,18 @@
       | NONE => Name.invent_list [] Name.aT (length tvars)
           |> map (fn v => TFree (v, []));
     val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
-    val chead = mk_head_cterm thy (c, ty);
-  in (Thm.weaken chead Drule.dummy_thm, []) end;
+    val chead = build_head thy (c, ty);
+  in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
 
 fun cert_of_eqns thy c [] = empty_cert thy c
-  | cert_of_eqns thy c eqns = 
+  | cert_of_eqns thy c raw_eqns = 
       let
-        val _ = assert_eqns_const thy c eqns;
+        val eqns = burrow_fst (canonize_thms thy) raw_eqns;
+        val _ = map (assert_eqn thy) eqns;
         val (thms, propers) = split_list eqns;
+        val _ = map (fn thm => if c = const_eqn thy thm then ()
+          else error ("Wrong head of code equation,\nexpected constant "
+            ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms;
         fun tvars_of T = rev (Term.add_tvarsT T []);
         val vss = map (tvars_of o snd o head_eqn) thms;
         fun inter_sorts vs =
@@ -551,55 +609,59 @@
         val vts = Name.names Name.context Name.aT sorts;
         val thms as thm :: _ =
           map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
-        val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
+        val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms))));
         fun head_conv ct = if can Thm.dest_comb ct
           then Conv.fun_conv head_conv ct
           else Conv.rewr_conv head_thm ct;
         val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
         val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
-      in (cert_thm, propers) end;
-
-fun head_cert thy cert_thm =
-  let
-    val [head] = Thm.hyps_of cert_thm;
-    val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
-  in ((c, typscheme thy (c, ty)), (head, h)) end;
+      in Cert (cert_thm, propers) end;
 
-fun constrain_cert thy sorts (cert_thm, propers) =
+fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
   let
-    val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
-    val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
-    val head' = (map_types o map_atyps)
-      (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
-    val inst = (map2 (fn (v, sort) => fn sort' =>
-      pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
+    val ((vs, _), head) = get_head thy cert_thm;
+    val subst = map2 (fn (v, sort) => fn sort' =>
+      (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
+    val head' = Thm.term_of head
+      |> (map_types o map_atyps)
+          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
+      |> Thm.cterm_of thy;
+    val inst = map2 (fn (v, sort) => fn (_, sort') =>
+      (((v, 0), sort), TFree (v, sort'))) vs subst;
     val cert_thm' = cert_thm
-      |> Thm.implies_intr (Thm.cterm_of thy head)
+      |> Thm.implies_intr head
       |> Thm.varifyT
-      |> Thm.instantiate inst
-      |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
-  in (cert_thm', propers) end;
+      |> Thm.certify_instantiate (inst, [])
+      |> Thm.elim_implies (Thm.assume head');
+  in (Cert (cert_thm', propers)) end;
 
-fun eqns_of_cert thy (cert_thm, []) = []
-  | eqns_of_cert thy (cert_thm, propers) =
-      let
-        val (_, (head, _)) = head_cert thy cert_thm;
-        val thms = cert_thm
-          |> LocalDefs.expand [Thm.cterm_of thy head]
-          |> Thm.varifyT
-          |> Conjunction.elim_balanced (length propers)
-      in thms ~~ propers end;
+fun typscheme_cert thy (Cert (cert_thm, _)) =
+  fst (get_head thy cert_thm);
 
-fun dest_cert thy (cert as (cert_thm, propers)) =
+fun equations_cert thy (cert as Cert (cert_thm, propers)) =
   let
-    val eqns = eqns_of_cert thy cert;
-    val ((_, vs_ty), _) = head_cert thy cert_thm;
-    val equations = if null propers then [] else cert_thm
-      |> Thm.prop_of
+    val tyscm = typscheme_cert thy cert;
+    val equations = if null propers then [] else
+      Thm.prop_of cert_thm
       |> Logic.dest_conjunction_balanced (length propers)
       |> map Logic.dest_equals
       |> (map o apfst) (snd o strip_comb)
-  in (vs_ty, equations ~~ eqns) end;
+  in (tyscm, equations) end;
+
+fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
+  let
+    val (tyscm, equations) = equations_cert thy cert;
+    val thms = if null propers then [] else
+      cert_thm
+      |> LocalDefs.expand [snd (get_head thy cert_thm)]
+      |> Thm.varifyT
+      |> Conjunction.elim_balanced (length propers)
+  in (tyscm, equations ~~ (thms ~~ propers)) end;
+
+fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
+  o snd o equations_thms_cert thy;
+
+end;
 
 
 (* code equation access *)