moved all technical processing of code equations to code_thingol.ML
authorhaftmann
Mon Aug 10 12:24:49 2009 +0200 (2009-08-10)
changeset 323530ac26087464b
parent 32352 4839a704939a
child 32354 bb40e900e1f3
moved all technical processing of code equations to code_thingol.ML
src/HOL/Tools/recfun_codegen.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_thingol.ML
     1.1 --- a/src/HOL/Tools/recfun_codegen.ML	Mon Aug 10 12:24:47 2009 +0200
     1.2 +++ b/src/HOL/Tools/recfun_codegen.ML	Mon Aug 10 12:24:49 2009 +0200
     1.3 @@ -31,25 +31,18 @@
     1.4      |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
     1.5    end;
     1.6  
     1.7 -fun expand_eta thy [] = []
     1.8 -  | expand_eta thy (thms as thm :: _) =
     1.9 -      let
    1.10 -        val (_, ty) = Code.const_typ_eqn thy thm;
    1.11 -      in if null (Term.add_tvarsT ty []) orelse (null o fst o strip_type) ty
    1.12 -        then thms
    1.13 -        else map (Code.expand_eta thy 1) thms
    1.14 -      end;
    1.15 -
    1.16  fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
    1.17    let
    1.18      val c' = AxClass.unoverload_const thy (c, T);
    1.19      val opt_name = Symtab.lookup (ModuleData.get thy) c';
    1.20 -    val thms = Code.these_eqns thy c'
    1.21 -      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
    1.22 -      |> expand_eta thy
    1.23 -      |> Code.desymbolize_all_vars thy
    1.24 -      |> map (rpair opt_name)
    1.25 -  in if null thms then NONE else SOME thms end;
    1.26 +    val raw_thms = Code.these_eqns thy c'
    1.27 +      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
    1.28 +  in if null raw_thms then NONE else
    1.29 +    raw_thms
    1.30 +    |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
    1.31 +    |> map (rpair opt_name)
    1.32 +    |> SOME
    1.33 +  end;
    1.34  
    1.35  val dest_eqn = Logic.dest_equals;
    1.36  val const_of = dest_Const o head_of o fst o dest_eqn;
     2.1 --- a/src/Tools/Code/code_preproc.ML	Mon Aug 10 12:24:47 2009 +0200
     2.2 +++ b/src/Tools/Code/code_preproc.ML	Mon Aug 10 12:24:49 2009 +0200
     2.3 @@ -132,12 +132,6 @@
     2.4    #> Logic.dest_equals
     2.5    #> snd;
     2.6  
     2.7 -fun same_arity thy thms =
     2.8 -  let
     2.9 -    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
    2.10 -    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
    2.11 -  in map (Code.expand_eta thy k) thms end;
    2.12 -
    2.13  fun preprocess thy c eqns =
    2.14    let
    2.15      val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
    2.16 @@ -149,7 +143,6 @@
    2.17      |> (map o apfst) (rewrite_eqn pre)
    2.18      |> (map o apfst) (AxClass.unoverload thy)
    2.19      |> map (Code.assert_eqn thy)
    2.20 -    |> burrow_fst (same_arity thy)
    2.21    end;
    2.22  
    2.23  fun preprocess_conv thy ct =
     3.1 --- a/src/Tools/Code/code_thingol.ML	Mon Aug 10 12:24:47 2009 +0200
     3.2 +++ b/src/Tools/Code/code_thingol.ML	Mon Aug 10 12:24:49 2009 +0200
     3.3 @@ -79,6 +79,7 @@
     3.4    val is_cons: program -> string -> bool
     3.5    val contr_classparam_typs: program -> string -> itype option list
     3.6  
     3.7 +  val clean_thms: theory -> typ -> thm list -> thm list
     3.8    val read_const_exprs: theory -> string list -> string list * string list
     3.9    val consts_program: theory -> string list -> string list * (naming * program)
    3.10    val cached_program: theory -> naming * program
    3.11 @@ -376,6 +377,67 @@
    3.12  end; (* local *)
    3.13  
    3.14  
    3.15 +(** technical transformations of code equations **)
    3.16 +
    3.17 +fun expand_eta thy k thm =
    3.18 +  let
    3.19 +    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
    3.20 +    val (head, args) = strip_comb lhs;
    3.21 +    val l = if k = ~1
    3.22 +      then (length o fst o strip_abs) rhs
    3.23 +      else Int.max (0, k - length args);
    3.24 +    val (raw_vars, _) = Term.strip_abs_eta l rhs;
    3.25 +    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
    3.26 +      raw_vars;
    3.27 +    fun expand (v, ty) thm = Drule.fun_cong_rule thm
    3.28 +      (Thm.cterm_of thy (Var ((v, 0), ty)));
    3.29 +  in
    3.30 +    thm
    3.31 +    |> fold expand vars
    3.32 +    |> Conv.fconv_rule Drule.beta_eta_conversion
    3.33 +  end;
    3.34 +
    3.35 +fun same_arity thy thms =
    3.36 +  let
    3.37 +    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
    3.38 +    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
    3.39 +  in map (expand_eta thy k) thms end;
    3.40 +
    3.41 +fun avoid_value thy ty [thm] =
    3.42 +      if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
    3.43 +      then [thm]
    3.44 +      else [expand_eta thy 1 thm]
    3.45 +  | avoid_value thy _ thms = thms;
    3.46 +
    3.47 +fun mk_desymbolization pre post mk vs =
    3.48 +  let
    3.49 +    val names = map (pre o fst o fst) vs
    3.50 +      |> map (Name.desymbolize false)
    3.51 +      |> Name.variant_list []
    3.52 +      |> map post;
    3.53 +  in map_filter (fn (((v, i), x), v') =>
    3.54 +    if v = v' andalso i = 0 then NONE
    3.55 +    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
    3.56 +  end;
    3.57 +
    3.58 +fun desymbolize_tvars thy thms =
    3.59 +  let
    3.60 +    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
    3.61 +    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
    3.62 +  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
    3.63 +
    3.64 +fun desymbolize_vars thy thm =
    3.65 +  let
    3.66 +    val vs = Term.add_vars (Thm.prop_of thm) [];
    3.67 +    val var_subst = mk_desymbolization I I Var vs;
    3.68 +  in Thm.certify_instantiate ([], var_subst) thm end;
    3.69 +
    3.70 +fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
    3.71 +
    3.72 +fun clean_thms thy ty =
    3.73 +  same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
    3.74 +
    3.75 +
    3.76  (** statements, abstract programs **)
    3.77  
    3.78  type typscheme = (vname * sort) list * itype;
    3.79 @@ -498,17 +560,11 @@
    3.80      fun stmt_classparam class =
    3.81        ensure_class thy algbr funcgr class
    3.82        #>> (fn class => Classparam (c, class));
    3.83 -    fun stmt_fun ((vs, ty), raw_eqns) =
    3.84 -      let
    3.85 -        val eqns = if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
    3.86 -          then raw_eqns
    3.87 -          else (map o apfst) (Code.expand_eta thy 1) raw_eqns;
    3.88 -      in
    3.89 -        fold_map (translate_tyvar_sort thy algbr funcgr) vs
    3.90 -        ##>> translate_typ thy algbr funcgr ty
    3.91 -        ##>> translate_eqns thy algbr funcgr eqns
    3.92 -        #>> (fn info => Fun (c, info))
    3.93 -      end;
    3.94 +    fun stmt_fun ((vs, ty), eqns) =
    3.95 +      fold_map (translate_tyvar_sort thy algbr funcgr) vs
    3.96 +      ##>> translate_typ thy algbr funcgr ty
    3.97 +      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
    3.98 +      #>> (fn info => Fun (c, info));
    3.99      val stmt_const = case Code.get_datatype_of_constr thy c
   3.100       of SOME tyco => stmt_datatypecons tyco
   3.101        | NONE => (case AxClass.class_of_param thy c
   3.102 @@ -597,9 +653,6 @@
   3.103              translate_term thy algbr funcgr thm t'
   3.104              ##>> fold_map (translate_term thy algbr funcgr thm) ts
   3.105              #>> (fn (t, ts) => t `$$ ts)
   3.106 -and translate_eqns thy algbr funcgr eqns =
   3.107 -  fold_map (translate_eqn thy algbr funcgr)
   3.108 -    (burrow_fst (Code.desymbolize_all_vars thy) eqns)
   3.109  and translate_eqn thy algbr funcgr (thm, proper) =
   3.110    let
   3.111      val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals