explicit abstract type of code certificates
authorhaftmann
Wed Jan 13 12:20:37 2010 +0100 (2010-01-13)
changeset 3489519fd499cddff
parent 34894 fadbdd350dd1
child 34896 a22b09addd78
child 34898 62d70417f8ce
explicit abstract type of code certificates
src/HOL/Tools/recfun_codegen.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_thingol.ML
     1.1 --- a/src/HOL/Tools/recfun_codegen.ML	Wed Jan 13 10:18:45 2010 +0100
     1.2 +++ b/src/HOL/Tools/recfun_codegen.ML	Wed Jan 13 12:20:37 2010 +0100
     1.3 @@ -36,7 +36,7 @@
     1.4        let val (_, T) = Code.const_typ_eqn thy thm
     1.5        in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T
     1.6          then [thm]
     1.7 -        else [Code_Thingol.expand_eta thy 1 thm]
     1.8 +        else [Code.expand_eta thy 1 thm]
     1.9        end
    1.10    | avoid_value thy thms = thms;
    1.11  
    1.12 @@ -44,8 +44,9 @@
    1.13    let
    1.14      val c = AxClass.unoverload_const thy (raw_c, T);
    1.15      val raw_thms = Code.get_cert thy (Code_Preproc.preprocess_functrans thy) c
    1.16 -      |> Code.eqns_of_cert thy
    1.17 -      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
    1.18 +      |> Code.equations_thms_cert thy
    1.19 +      |> snd
    1.20 +      |> map_filter (fn (_, (thm, proper)) => if proper then SOME thm else NONE)
    1.21        |> map (AxClass.overload thy)
    1.22        |> filter (is_instance T o snd o const_of o prop_of);
    1.23      val module_name = case Symtab.lookup (ModuleData.get thy) c
    1.24 @@ -57,7 +58,6 @@
    1.25      raw_thms
    1.26      |> preprocess thy
    1.27      |> avoid_value thy
    1.28 -    |> Code_Thingol.canonize_thms thy
    1.29      |> rpair module_name
    1.30    end;
    1.31  
     2.1 --- a/src/Pure/Isar/code.ML	Wed Jan 13 10:18:45 2010 +0100
     2.2 +++ b/src/Pure/Isar/code.ML	Wed Jan 13 12:20:37 2010 +0100
     2.3 @@ -29,12 +29,15 @@
     2.4    val mk_eqn_liberal: theory -> thm -> (thm * bool) option
     2.5    val assert_eqn: theory -> thm * bool -> thm * bool
     2.6    val const_typ_eqn: theory -> thm -> string * typ
     2.7 -  type cert = thm * bool list
     2.8 +  val expand_eta: theory -> int -> thm -> thm
     2.9 +  type cert
    2.10    val empty_cert: theory -> string -> cert
    2.11    val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
    2.12    val constrain_cert: theory -> sort list -> cert -> cert
    2.13 -  val eqns_of_cert: theory -> cert -> (thm * bool) list
    2.14 -  val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
    2.15 +  val typscheme_cert: theory -> cert -> (string * sort) list * typ
    2.16 +  val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
    2.17 +  val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
    2.18 +  val pretty_cert: theory -> cert -> Pretty.T list
    2.19  
    2.20    (*executable code*)
    2.21    val add_type: string -> theory -> theory
    2.22 @@ -511,20 +514,71 @@
    2.23  
    2.24  fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
    2.25  
    2.26 -fun assert_eqns_const thy c eqns =
    2.27 +
    2.28 +(* technical transformations of code equations *)
    2.29 +
    2.30 +fun expand_eta thy k thm =
    2.31 +  let
    2.32 +    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
    2.33 +    val (_, args) = strip_comb lhs;
    2.34 +    val l = if k = ~1
    2.35 +      then (length o fst o strip_abs) rhs
    2.36 +      else Int.max (0, k - length args);
    2.37 +    val (raw_vars, _) = Term.strip_abs_eta l rhs;
    2.38 +    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
    2.39 +      raw_vars;
    2.40 +    fun expand (v, ty) thm = Drule.fun_cong_rule thm
    2.41 +      (Thm.cterm_of thy (Var ((v, 0), ty)));
    2.42 +  in
    2.43 +    thm
    2.44 +    |> fold expand vars
    2.45 +    |> Conv.fconv_rule Drule.beta_eta_conversion
    2.46 +  end;
    2.47 +
    2.48 +fun same_arity thy thms =
    2.49    let
    2.50 -    fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
    2.51 -      then eqn else error ("Wrong head of code equation,\nexpected constant "
    2.52 -        ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)
    2.53 -  in map (cert o assert_eqn thy) eqns end;
    2.54 +    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
    2.55 +    val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
    2.56 +  in map (expand_eta thy k) thms end;
    2.57 +
    2.58 +fun mk_desymbolization pre post mk vs =
    2.59 +  let
    2.60 +    val names = map (pre o fst o fst) vs
    2.61 +      |> map (Name.desymbolize false)
    2.62 +      |> Name.variant_list []
    2.63 +      |> map post;
    2.64 +  in map_filter (fn (((v, i), x), v') =>
    2.65 +    if v = v' andalso i = 0 then NONE
    2.66 +    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
    2.67 +  end;
    2.68 +
    2.69 +fun desymbolize_tvars thy thms =
    2.70 +  let
    2.71 +    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
    2.72 +    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
    2.73 +  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
    2.74 +
    2.75 +fun desymbolize_vars thy thm =
    2.76 +  let
    2.77 +    val vs = Term.add_vars (Thm.prop_of thm) [];
    2.78 +    val var_subst = mk_desymbolization I I Var vs;
    2.79 +  in Thm.certify_instantiate ([], var_subst) thm end;
    2.80 +
    2.81 +fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy);
    2.82  
    2.83  
    2.84  (* code equation certificates *)
    2.85  
    2.86 -type cert = thm * bool list;
    2.87 +fun build_head thy (c, ty) =
    2.88 +  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
    2.89  
    2.90 -fun mk_head_cterm thy (c, ty) =
    2.91 -  Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
    2.92 +fun get_head thy cert_thm =
    2.93 +  let
    2.94 +    val [head] = (#hyps o Thm.crep_thm) cert_thm;
    2.95 +    val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
    2.96 +  in (typscheme thy (c, ty), head) end;
    2.97 +
    2.98 +abstype cert = Cert of thm * bool list with
    2.99  
   2.100  fun empty_cert thy c = 
   2.101    let
   2.102 @@ -535,14 +589,18 @@
   2.103        | NONE => Name.invent_list [] Name.aT (length tvars)
   2.104            |> map (fn v => TFree (v, []));
   2.105      val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
   2.106 -    val chead = mk_head_cterm thy (c, ty);
   2.107 -  in (Thm.weaken chead Drule.dummy_thm, []) end;
   2.108 +    val chead = build_head thy (c, ty);
   2.109 +  in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
   2.110  
   2.111  fun cert_of_eqns thy c [] = empty_cert thy c
   2.112 -  | cert_of_eqns thy c eqns = 
   2.113 +  | cert_of_eqns thy c raw_eqns = 
   2.114        let
   2.115 -        val _ = assert_eqns_const thy c eqns;
   2.116 +        val eqns = burrow_fst (canonize_thms thy) raw_eqns;
   2.117 +        val _ = map (assert_eqn thy) eqns;
   2.118          val (thms, propers) = split_list eqns;
   2.119 +        val _ = map (fn thm => if c = const_eqn thy thm then ()
   2.120 +          else error ("Wrong head of code equation,\nexpected constant "
   2.121 +            ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms;
   2.122          fun tvars_of T = rev (Term.add_tvarsT T []);
   2.123          val vss = map (tvars_of o snd o head_eqn) thms;
   2.124          fun inter_sorts vs =
   2.125 @@ -551,55 +609,59 @@
   2.126          val vts = Name.names Name.context Name.aT sorts;
   2.127          val thms as thm :: _ =
   2.128            map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
   2.129 -        val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
   2.130 +        val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms))));
   2.131          fun head_conv ct = if can Thm.dest_comb ct
   2.132            then Conv.fun_conv head_conv ct
   2.133            else Conv.rewr_conv head_thm ct;
   2.134          val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
   2.135          val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
   2.136 -      in (cert_thm, propers) end;
   2.137 -
   2.138 -fun head_cert thy cert_thm =
   2.139 -  let
   2.140 -    val [head] = Thm.hyps_of cert_thm;
   2.141 -    val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
   2.142 -  in ((c, typscheme thy (c, ty)), (head, h)) end;
   2.143 +      in Cert (cert_thm, propers) end;
   2.144  
   2.145 -fun constrain_cert thy sorts (cert_thm, propers) =
   2.146 +fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
   2.147    let
   2.148 -    val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
   2.149 -    val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
   2.150 -    val head' = (map_types o map_atyps)
   2.151 -      (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
   2.152 -    val inst = (map2 (fn (v, sort) => fn sort' =>
   2.153 -      pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
   2.154 +    val ((vs, _), head) = get_head thy cert_thm;
   2.155 +    val subst = map2 (fn (v, sort) => fn sort' =>
   2.156 +      (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
   2.157 +    val head' = Thm.term_of head
   2.158 +      |> (map_types o map_atyps)
   2.159 +          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
   2.160 +      |> Thm.cterm_of thy;
   2.161 +    val inst = map2 (fn (v, sort) => fn (_, sort') =>
   2.162 +      (((v, 0), sort), TFree (v, sort'))) vs subst;
   2.163      val cert_thm' = cert_thm
   2.164 -      |> Thm.implies_intr (Thm.cterm_of thy head)
   2.165 +      |> Thm.implies_intr head
   2.166        |> Thm.varifyT
   2.167 -      |> Thm.instantiate inst
   2.168 -      |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
   2.169 -  in (cert_thm', propers) end;
   2.170 +      |> Thm.certify_instantiate (inst, [])
   2.171 +      |> Thm.elim_implies (Thm.assume head');
   2.172 +  in (Cert (cert_thm', propers)) end;
   2.173  
   2.174 -fun eqns_of_cert thy (cert_thm, []) = []
   2.175 -  | eqns_of_cert thy (cert_thm, propers) =
   2.176 -      let
   2.177 -        val (_, (head, _)) = head_cert thy cert_thm;
   2.178 -        val thms = cert_thm
   2.179 -          |> LocalDefs.expand [Thm.cterm_of thy head]
   2.180 -          |> Thm.varifyT
   2.181 -          |> Conjunction.elim_balanced (length propers)
   2.182 -      in thms ~~ propers end;
   2.183 +fun typscheme_cert thy (Cert (cert_thm, _)) =
   2.184 +  fst (get_head thy cert_thm);
   2.185  
   2.186 -fun dest_cert thy (cert as (cert_thm, propers)) =
   2.187 +fun equations_cert thy (cert as Cert (cert_thm, propers)) =
   2.188    let
   2.189 -    val eqns = eqns_of_cert thy cert;
   2.190 -    val ((_, vs_ty), _) = head_cert thy cert_thm;
   2.191 -    val equations = if null propers then [] else cert_thm
   2.192 -      |> Thm.prop_of
   2.193 +    val tyscm = typscheme_cert thy cert;
   2.194 +    val equations = if null propers then [] else
   2.195 +      Thm.prop_of cert_thm
   2.196        |> Logic.dest_conjunction_balanced (length propers)
   2.197        |> map Logic.dest_equals
   2.198        |> (map o apfst) (snd o strip_comb)
   2.199 -  in (vs_ty, equations ~~ eqns) end;
   2.200 +  in (tyscm, equations) end;
   2.201 +
   2.202 +fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
   2.203 +  let
   2.204 +    val (tyscm, equations) = equations_cert thy cert;
   2.205 +    val thms = if null propers then [] else
   2.206 +      cert_thm
   2.207 +      |> LocalDefs.expand [snd (get_head thy cert_thm)]
   2.208 +      |> Thm.varifyT
   2.209 +      |> Conjunction.elim_balanced (length propers)
   2.210 +  in (tyscm, equations ~~ (thms ~~ propers)) end;
   2.211 +
   2.212 +fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
   2.213 +  o snd o equations_thms_cert thy;
   2.214 +
   2.215 +end;
   2.216  
   2.217  
   2.218  (* code equation access *)
     3.1 --- a/src/Tools/Code/code_preproc.ML	Wed Jan 13 10:18:45 2010 +0100
     3.2 +++ b/src/Tools/Code/code_preproc.ML	Wed Jan 13 12:20:37 2010 +0100
     3.3 @@ -199,11 +199,7 @@
     3.4    AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
     3.5    |> (map o apfst) (Code.string_of_const thy)
     3.6    |> sort (string_ord o pairself fst)
     3.7 -  |> map (fn (s, cert) =>
     3.8 -       (Pretty.block o Pretty.fbreaks) (
     3.9 -         Pretty.str s
    3.10 -         :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
    3.11 -       ))
    3.12 +  |> map (fn (s, cert) => (Pretty.block o Pretty.fbreaks) (Pretty.str s :: Code.pretty_cert thy cert))
    3.13    |> Pretty.chunks;
    3.14  
    3.15  
    3.16 @@ -220,13 +216,13 @@
    3.17    map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
    3.18      o maps (#params o AxClass.get_info thy);
    3.19  
    3.20 -fun typscheme_rhss thy c cert =
    3.21 +fun typargs_rhss thy c cert =
    3.22    let
    3.23 -    val (tyscm, equations) = Code.dest_cert thy cert;
    3.24 +    val ((vs, _), equations) = Code.equations_cert thy cert;
    3.25      val rhss = [] |> (fold o fold o fold_aterms)
    3.26        (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
    3.27 -        (map (op :: o swap o fst) equations);
    3.28 -  in (tyscm, rhss) end;
    3.29 +        (map (op :: o swap) equations);
    3.30 +  in (vs, rhss) end;
    3.31  
    3.32  
    3.33  (* data structures *)
    3.34 @@ -266,7 +262,7 @@
    3.35     of SOME (lhs, cert) => ((lhs, []), cert)
    3.36      | NONE => let
    3.37          val cert = Code.get_cert thy (preprocess thy) c;
    3.38 -        val ((lhs, _), rhss) = typscheme_rhss thy c cert;
    3.39 +        val (lhs, rhss) = typargs_rhss thy c cert;
    3.40        in ((lhs, rhss), cert) end;
    3.41  
    3.42  fun obtain_instance thy arities (inst as (class, tyco)) =
    3.43 @@ -388,14 +384,6 @@
    3.44         handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
    3.45    end;
    3.46  
    3.47 -fun inst_cert thy lhs cert =
    3.48 -  let
    3.49 -    val ((vs, _), _) = Code.dest_cert thy cert;
    3.50 -    val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
    3.51 -     of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
    3.52 -      | NONE => sort) vs;
    3.53 -  in Code.constrain_cert thy sorts cert end;
    3.54 -
    3.55  fun add_arity thy vardeps (class, tyco) =
    3.56    AList.default (op =) ((class, tyco),
    3.57      map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
    3.58 @@ -406,8 +394,8 @@
    3.59    else let
    3.60      val lhs = map_index (fn (k, (v, _)) =>
    3.61        (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
    3.62 -    val cert = inst_cert thy lhs proto_cert;
    3.63 -    val ((vs, _), rhss') = typscheme_rhss thy c cert;
    3.64 +    val cert = Code.constrain_cert thy (map snd lhs) proto_cert;
    3.65 +    val (vs, rhss') = typargs_rhss thy c cert;
    3.66      val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
    3.67    in (map (pair c) rhss' @ rhss, eqngr') end;
    3.68  
     4.1 --- a/src/Tools/Code/code_thingol.ML	Wed Jan 13 10:18:45 2010 +0100
     4.2 +++ b/src/Tools/Code/code_thingol.ML	Wed Jan 13 12:20:37 2010 +0100
     4.3 @@ -86,8 +86,6 @@
     4.4      -> ((string * stmt) list * (string * stmt) list
     4.5        * ((string * stmt) list * (string * stmt) list)) list
     4.6  
     4.7 -  val expand_eta: theory -> int -> thm -> thm
     4.8 -  val canonize_thms: theory -> thm list -> thm list
     4.9    val read_const_exprs: theory -> string list -> string list * string list
    4.10    val consts_program: theory -> string list -> string list * (naming * program)
    4.11    val eval_conv: theory
    4.12 @@ -397,60 +395,6 @@
    4.13  end; (* local *)
    4.14  
    4.15  
    4.16 -(** technical transformations of code equations **)
    4.17 -
    4.18 -fun expand_eta thy k thm =
    4.19 -  let
    4.20 -    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
    4.21 -    val (_, args) = strip_comb lhs;
    4.22 -    val l = if k = ~1
    4.23 -      then (length o fst o strip_abs) rhs
    4.24 -      else Int.max (0, k - length args);
    4.25 -    val (raw_vars, _) = Term.strip_abs_eta l rhs;
    4.26 -    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
    4.27 -      raw_vars;
    4.28 -    fun expand (v, ty) thm = Drule.fun_cong_rule thm
    4.29 -      (Thm.cterm_of thy (Var ((v, 0), ty)));
    4.30 -  in
    4.31 -    thm
    4.32 -    |> fold expand vars
    4.33 -    |> Conv.fconv_rule Drule.beta_eta_conversion
    4.34 -  end;
    4.35 -
    4.36 -fun same_arity thy thms =
    4.37 -  let
    4.38 -    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
    4.39 -    val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
    4.40 -  in map (expand_eta thy k) thms end;
    4.41 -
    4.42 -fun mk_desymbolization pre post mk vs =
    4.43 -  let
    4.44 -    val names = map (pre o fst o fst) vs
    4.45 -      |> map (Name.desymbolize false)
    4.46 -      |> Name.variant_list []
    4.47 -      |> map post;
    4.48 -  in map_filter (fn (((v, i), x), v') =>
    4.49 -    if v = v' andalso i = 0 then NONE
    4.50 -    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
    4.51 -  end;
    4.52 -
    4.53 -fun desymbolize_tvars thy thms =
    4.54 -  let
    4.55 -    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
    4.56 -    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
    4.57 -  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
    4.58 -
    4.59 -fun desymbolize_vars thy thm =
    4.60 -  let
    4.61 -    val vs = Term.add_vars (Thm.prop_of thm) [];
    4.62 -    val var_subst = mk_desymbolization I I Var vs;
    4.63 -  in Thm.certify_instantiate ([], var_subst) thm end;
    4.64 -
    4.65 -fun canonize_thms thy = map (Thm.transfer thy)
    4.66 -  #> desymbolize_tvars thy
    4.67 -  #> same_arity thy #> map (desymbolize_vars thy);
    4.68 -
    4.69 -
    4.70  (** statements, abstract programs **)
    4.71  
    4.72  type typscheme = (vname * sort) list * itype;
    4.73 @@ -614,8 +558,8 @@
    4.74        #>> (fn class => Classparam (c, class));
    4.75      fun stmt_fun cert =
    4.76        let
    4.77 -        val ((vs, ty), raw_eqns) = Code.dest_cert thy cert;
    4.78 -        val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns);
    4.79 +        val ((vs, ty), raw_eqns) = Code.equations_thms_cert thy cert;
    4.80 +        val eqns = map snd raw_eqns;
    4.81        in
    4.82          fold_map (translate_tyvar_sort thy algbr eqngr) vs
    4.83          ##>> translate_typ thy algbr eqngr ty