src/Pure/Tools/codegen_theorems.ML
author haftmann
Thu, 06 Apr 2006 16:08:25 +0200
changeset 19341 3414c04fbc39
parent 19280 5091dc43817b
child 19436 3f5835aac3ce
permissions -rw-r--r--
added definitional code generator module: codegen_theorems.ML

(*  Title:      Pure/Tools/codegen_theorems.ML
    ID:         $Id$
    Author:     Florian Haftmann, TU Muenchen

Theorems used for code generation.
*)

(*Interface of the store of theorems used for code generation:
  registration of hooks, preprocessors and extractors, declaration and
  removal of code theorems, and retrieval of preprocessed theorems.*)
signature CODEGEN_THEOREMS =
sig
  (*registration of notification hooks, theorem preprocessors and
    function / datatype extractors; each registration except add_notify
    also runs the notification hooks*)
  val add_notify: (string option -> theory -> theory) -> theory -> theory;
  val add_preproc: (theory -> thm list -> thm list) -> theory -> theory;
  val add_fun_extr: (theory -> string * typ -> thm list) -> theory -> theory;
  val add_datatype_extr: (theory -> string
     -> (((string * sort) list * (string * typ list) list) * tactic) option)
    -> theory -> theory;
  (*declaration and removal of function equations, predicate clauses and
    unfolding rules*)
  val add_fun: thm -> theory -> theory;
  val add_pred: thm -> theory -> theory;
  val add_unfold: thm -> theory -> theory;
  val del_def: thm -> theory -> theory;
  val del_unfold: thm -> theory -> theory;
  val purge_defs: string * typ -> theory -> theory;

  (*preprocessing and retrieval*)
  val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
  val preprocess: theory -> (thm -> typ) option -> thm list -> thm list;
  val preprocess_fun: theory -> thm list -> (typ * thm list) option;
  val preprocess_term: theory -> term -> term;
  val get_funs: theory -> string * typ -> (typ * thm list) option;
  val get_datatypes: theory -> string
    -> (((string * sort) list * (string * typ list) list) * thm list) option;

  (*debug output: when !debug is set, debug_msg emits its message*)
  val debug: bool ref;
  val debug_msg: ('a -> string) -> 'a -> 'a;

  (*diagnostics and one-time object logic setup
    (bool type name, then true / false / conjunction / equality constants,
    each with a proof tactic)*)
  val print_thms: theory -> unit;
  val init_obj: theory -> string -> string * (thm list -> tactic) -> string * (thm list -> tactic)
    -> string * (thm list -> tactic) -> string * (thm list -> tactic) -> unit;
end;

structure CodegenTheorems: CODEGEN_THEOREMS =
struct

(** auxiliary **)

val debug = ref false;

(*Identity on x; when debug is set, the message f x is additionally
  emitted via Output.debug before x is returned.*)
fun debug_msg f x =
  let
    val _ = if ! debug then Output.debug (f x) else ();
  in x end;


(** object logic **)

(*Object logic configuration, filled in once by init_obj; NONE until then.*)
val obj_bool_ref : string option ref = ref NONE;      (*name of the boolean type*)
val obj_true_ref : string option ref = ref NONE;      (*truth constant*)
val obj_false_ref : string option ref = ref NONE;     (*falsity constant*)
val obj_and_ref : string option ref = ref NONE;       (*conjunction constant*)
val obj_eq_ref : string option ref = ref NONE;        (*equality constant*)
val obj_eq_elim_ref : thm option ref = ref NONE;      (*rewrite: object eq to meta eq*)

(*Read a configuration reference; raises Option while it is still unset.*)
fun idem r = the (! r);

(*The object-logic truth constant (sel = true) or falsity constant
  (sel = false), typed with the object-logic boolean type.*)
fun mk_tf sel =
  let
    val boolT = Type (idem obj_bool_ref, []);
    val tf_ref = if sel then obj_true_ref else obj_false_ref;
  in Const (idem tf_ref, boolT) end
  handle Option => error "no object logic setup for code theorems";

(*Object-logic conjunction of the two boolean terms x and y.*)
fun mk_obj_conj (x, y) =
  let
    val boolT = Type (idem obj_bool_ref, []);
    val conj = Const (idem obj_and_ref, [boolT, boolT] ---> boolT);
  in conj $ x $ y end
  handle Option => error "no object logic setup for code theorems";

(*Object-logic equation x = y; the equality constant is typed from the
  types of x and y.*)
fun mk_obj_eq (x, y) =
  let
    val boolT = Type (idem obj_bool_ref, []);
    val eq_name = idem obj_eq_ref;
  in Const (eq_name, type_of x --> type_of y --> boolT) $ x $ y end
  handle Option => error "no object logic setup for code theorems";

(*Whether c names the object-logic equality constant; fails if no object
  logic has been configured yet.*)
fun is_obj_eq c =
  (case ! obj_eq_ref
    of SOME eq => c = eq
     | NONE => error "no object logic setup for code theorems");

(*The meta equation (x = y) == rhs, where x = y is the object-logic
  equation built by mk_obj_eq.  The former local bool_typ was unused:
  mk_obj_eq reads the boolean type itself and reports the same
  "no object logic setup" error when it is missing.*)
fun mk_bool_eq ((x, y), rhs) =
  Logic.mk_equals (mk_obj_eq (x, y), rhs);

(*Rewrite object-logic equations in thm to meta equations, using the rule
  stored by init_obj.*)
fun elim_obj_eq thm =
  let val eq_elim = idem obj_eq_elim_ref
  in rewrite_rule [eq_elim] thm end
  handle Option => error "no object logic setup for code theorems";

(*Configure the object logic for code theorems; fails if already done.
  bohl names the boolean type.  Each constant comes with a tactic used to
  prove a characteristic property before anything is stored:
    truh:  truh
    fals:  fals ==> P
    ant:   P ==> Q ==> ant P Q
    eq:    (x == y) == eq x y
  The symmetric form of the eq theorem becomes the rewrite rule used by
  elim_obj_eq.*)
fun init_obj thy bohl (truh, truh_tac) (fals, fals_tac) (ant, ant_tac) (eq, eq_tac) =
  let
    val _ = if is_some (! obj_bool_ref) then error "already set" else ();
    val mk_prop = ObjectLogic.ensure_propT thy;
    val boolT = Type (bohl, []);
    val aT = TFree ("'a", Sign.defaultS thy);
    val x = Free ("x", aT);
    val y = Free ("y", aT);
    val p = Free ("P", boolT);
    val q = Free ("Q", boolT);
    val _ = Goal.prove thy [] [] (mk_prop (Const (truh, boolT))) truh_tac;
    val _ = Goal.prove thy ["P"] [mk_prop (Const (fals, boolT))] (mk_prop p) fals_tac;
    val _ = Goal.prove thy ["P", "Q"] [mk_prop p, mk_prop q]
      (mk_prop (Const (ant, boolT --> boolT --> boolT) $ p $ q)) ant_tac;
    val atomize_eq = Goal.prove thy ["x", "y"] []
      (Logic.mk_equals (Logic.mk_equals (x, y),
        mk_prop (Const (eq, aT --> aT --> boolT) $ x $ y))) eq_tac;
  in
    obj_bool_ref := SOME bohl;
    obj_true_ref := SOME truh;
    obj_false_ref := SOME fals;
    obj_and_ref := SOME ant;
    obj_eq_ref := SOME eq;
    obj_eq_elim_ref := SOME (Thm.symmetric atomize_eq)
  end;


(** theorem destructors and auxiliary combinators **)

(*Recognise a function equation.  After elim_obj_eq and after dropping
  the object-logic judgment, thm must be a meta equation whose left-hand
  side has a constant head; then SOME ((name, type) of that constant,
  thm) is returned (thm itself, not the rewritten form), else NONE.*)
fun destr_fun thy thm =
  let
    val head_const =
      dest_Const o fst o strip_comb o fst o Logic.dest_equals
      o ObjectLogic.drop_judgment thy o prop_of;
  in
    Option.map (rpair thm) (try head_const (elim_obj_eq thm))
  end;

(*Like destr_fun, but fails with an error if thm is not a function equation.*)
fun dest_fun thy thm =
  let
    fun fail () = error ("not a function equation: " ^ string_of_thm thm);
  in
    (case destr_fun thy thm of NONE => fail () | SOME result => result)
  end;

(*Recognise a predicate clause A1 ==> ... ==> An ==> P t1 ... tm (n >= 0):
  answer SOME (P, thm), with P the name of the constant heading the
  conclusion, or NONE if that head is not a constant.

  The earlier version took only the first premise off (Logic.dest_implies)
  and kept the object-logic judgment.  Clauses with several premises were
  therefore keyed under "==>", all others under the judgment constant, and
  premise-less clauses were rejected.  The whole premise chain is now
  stripped and the judgment dropped, as destr_fun does for equations.*)
fun dest_pred thm =
  let
    val thy = Thm.theory_of_thm thm;
    val head_name =
      fst o dest_Const o fst o strip_comb o ObjectLogic.drop_judgment thy
      o Logic.strip_imp_concl o prop_of;
  in
    (case try head_name thm
      of SOME c => SOME (c, thm)
       | NONE => NONE)
  end;

(*Apply the functions fs to x in order; the first SOME result wins,
  NONE if every function answers NONE.*)
fun getf_first fs x =
  let
    fun first [] = NONE
      | first (f :: rest) = (case f x of NONE => first rest | found => found);
  in first fs end;

(*Apply the functions fs to x in order; the first non-empty result wins,
  [] if every function answers [].*)
fun getf_first_list fs x =
  let
    fun first [] = []
      | first (f :: rest) = (case f x of [] => first rest | found => found);
  in first fs end;
      

(** theory data **)

(*Theorem preprocessors and notification hooks, each tagged with a serial;
  merging keeps the union of both lists, keyed by serial.*)
datatype procs = Procs of {
  preprocs: (serial * (theory -> thm list -> thm list)) list,
  notify: (serial * (string option -> theory -> theory)) list
};

fun mk_procs (preprocs, notify) = Procs { preprocs = preprocs, notify = notify };
fun map_procs f (Procs { preprocs, notify }) = mk_procs (f (preprocs, notify));
fun merge_procs _ (Procs { preprocs = pre1, notify = hooks1 },
    Procs { preprocs = pre2, notify = hooks2 }) =
  let
    fun merge lists = AList.merge (op =) (K true) lists;
  in mk_procs (merge (pre1, pre2), merge (hooks1, hooks2)) end;

(*Function-equation and datatype extractors, each tagged with a serial;
  merging keeps the union of both lists, keyed by serial.*)
datatype extrs = Extrs of {
  funs: (serial * (theory -> string * typ -> thm list)) list,
  datatypes: (serial * (theory -> string -> (((string * sort) list * (string * typ list) list) * tactic) option)) list
};

fun mk_extrs (funs, datatypes) = Extrs { funs = funs, datatypes = datatypes };
fun map_extrs f (Extrs { funs, datatypes }) = mk_extrs (f (funs, datatypes));
fun merge_extrs _ (Extrs { funs = fun_extrs1, datatypes = dt_extrs1 },
    Extrs { funs = fun_extrs2, datatypes = dt_extrs2 }) =
  let
    fun merge lists = AList.merge (op =) (K true) lists;
  in mk_extrs (merge (fun_extrs1, fun_extrs2), merge (dt_extrs1, dt_extrs2)) end;

(*Declared code theorems: function equations and predicate clauses per
  constant name, plus the unfolding rules applied in preprocessing.
  Merging forms the union of theorem lists (modulo eq_thm).*)
datatype codethms = Codethms of {
  funs: thm list Symtab.table,
  preds: thm list Symtab.table,
  unfolds: thm list
};

fun mk_codethms ((funs, preds), unfolds) =
  Codethms { funs = funs, preds = preds, unfolds = unfolds };
fun map_codethms f (Codethms { funs, preds, unfolds }) =
  mk_codethms (f ((funs, preds), unfolds));
fun merge_codethms _ (Codethms { funs = funs1, preds = preds1, unfolds = unfolds1 },
    Codethms { funs = funs2, preds = preds2, unfolds = unfolds2 }) =
  let
    (*insert every theorem of the first list into the second*)
    fun merge_thms (thms1, thms2) = fold (insert eq_thm) thms1 thms2;
    val join = Symtab.join (K merge_thms);
  in
    mk_codethms ((join (funs1, funs2), join (preds1, preds2)),
      merge_thms (unfolds1, unfolds2))
  end;

(*Cache of extracted function equations and datatype constructors.
  (Not referenced elsewhere in this file.)

  map_codecache and merge_codecache previously matched the Extrs
  constructor instead of Codecache, and map_codecache applied the raw
  constructor to f's pair result instead of using mk_codecache.*)
datatype codecache = Codecache of {
  funs: thm list Symtab.table,
  datatypes: (string * typ list) list Symtab.table
};

fun mk_codecache (funs, datatypes) = Codecache { funs = funs, datatypes = datatypes };
fun map_codecache f (Codecache { funs, datatypes }) = mk_codecache (f (funs, datatypes));
(*merging yields an empty cache; presumably cached entries are not trusted
  across theory merges -- TODO confirm intent*)
fun merge_codecache _ (Codecache _, Codecache _) =
  mk_codecache (Symtab.empty, Symtab.empty);

(*Complete theory data: processors, extractors and code theorems.*)
datatype T = T of {
  procs: procs,
  extrs: extrs,
  codethms: codethms
};

fun mk_T (procs, (extrs, codethms)) = T { procs = procs, extrs = extrs, codethms = codethms };
fun map_T f (T { procs, extrs, codethms }) = mk_T (f (procs, (extrs, codethms)));
fun merge_T pp (T { procs = procs1, extrs = extrs1, codethms = codethms1 },
    T { procs = procs2, extrs = extrs2, codethms = codethms2 }) =
  let
    val procs = merge_procs pp (procs1, procs2);
    val extrs = merge_extrs pp (extrs1, extrs2);
    val codethms = merge_codethms pp (codethms1, codethms2);
  in mk_T (procs, (extrs, codethms)) end;

(*Theory data slot holding the code theorem store.*)
structure CodegenTheorems = TheoryDataFun
(struct
  val name = "Pure/CodegenTheorems";
  type T = T;
  val empty = mk_T (mk_procs ([], []),
    (mk_extrs ([], []), mk_codethms ((Symtab.empty, Symtab.empty), [])));
  val copy = I;
  val extend = I;
  val merge = merge_T;
  (*print function equations and predicate clauses grouped by constant,
    followed by the unfolding rules*)
  fun print (thy : theory) (data : T) =
    let
      val T { codethms = Codethms { funs, preds, unfolds }, ... } = data;
      fun pretty_entry (c, thms) =
        (Pretty.block o Pretty.fbreaks) (Pretty.str c :: map Display.pretty_thm thms);
      fun pretty_table tab = Pretty.fbreaks (map pretty_entry (Symtab.dest tab));
      val pretty_unfolds = (Pretty.block o Pretty.fbreaks o map Display.pretty_thm) unfolds;
    in
      (Pretty.writeln o Pretty.block o Pretty.fbreaks)
        ([Pretty.str "code generation theorems:", Pretty.str "function theorems:"]
          @ pretty_table funs
          @ [Pretty.str "predicate theorems:"]
          @ pretty_table preds
          @ [Pretty.str "unfolding theorems:", pretty_unfolds])
    end;
end);

(*install the data slot at setup time; expose its print function*)
val _ = CodegenTheorems.init |> Context.add_setup;
fun print_thms thy = CodegenTheorems.print thy;

(*accessors for the individual components of the theory data*)
local
  fun the_procs thy =
    let val T { procs = Procs procs, ... } = CodegenTheorems.get thy in procs end;
  fun the_extrs thy =
    let val T { extrs = Extrs extrs, ... } = CodegenTheorems.get thy in extrs end;
  fun the_codethms thy =
    let val T { codethms = Codethms codethms, ... } = CodegenTheorems.get thy in codethms end;
in
  fun the_preprocs thy = map snd (#preprocs (the_procs thy));
  fun the_notify thy = map snd (#notify (the_procs thy));
  fun the_funs_extrs thy = map snd (#funs (the_extrs thy));
  fun the_datatypes_extrs thy = map snd (#datatypes (the_extrs thy));
  fun the_funs thy = #funs (the_codethms thy);
  fun the_preds thy = #preds (the_codethms thy);
  fun the_unfolds thy = #unfolds (the_codethms thy);
end (*local*);

(*Register a notification hook, tagged with a fresh serial; hooks are run
  by notify_all with SOME c for a changed constant c, or NONE.*)
fun add_notify f =
  let
    fun add_hook (preprocs, notify) = (preprocs, (serial (), f) :: notify);
  in
    CodegenTheorems.map (map_T (fn (procs, rest) => (map_procs add_hook procs, rest)))
  end;

(*Thread thy through all registered hooks, in list order, passing c.*)
fun notify_all c thy =
  List.foldl (fn (hook, thy') => hook c thy') thy (the_notify thy);

(*Register a theorem preprocessor (tagged with a fresh serial) in front of
  the existing ones, then run all hooks with NONE.*)
fun add_preproc f =
  let
    fun register (preprocs, notify) = ((serial (), f) :: preprocs, notify);
  in
    CodegenTheorems.map (map_T (fn (procs, rest) => (map_procs register procs, rest)))
    #> notify_all NONE
  end;

(*Register a function-equation extractor (tagged with a fresh serial) in
  front of the existing ones, then run all hooks with NONE.*)
fun add_fun_extr f =
  let
    fun register (funs, datatypes) = ((serial (), f) :: funs, datatypes);
  in
    CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
      (procs, (map_extrs register extrs, codethms))))
    #> notify_all NONE
  end;

(*Register a datatype extractor (tagged with a fresh serial) in front of
  the existing ones, then run all hooks with NONE.*)
fun add_datatype_extr f =
  let
    fun register (funs, datatypes) = (funs, (serial (), f) :: datatypes);
  in
    CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
      (procs, (map_extrs register extrs, codethms))))
    #> notify_all NONE
  end;

(*Append thm to the function equations of its head constant c and run
  the hooks with SOME c; if thm is not a function equation, only warn.*)
fun add_fun thm thy =
  (case destr_fun thy thm
    of NONE => (warning ("not a function equation: " ^ string_of_thm thm); thy)
     | SOME ((c, _), _) =>
        let
          fun snoc_thm tab = Symtab.update (c, these (Symtab.lookup tab c) @ [thm]) tab;
        in
          thy
          |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
               (procs, (extrs, map_codethms (apfst (apfst snoc_thm)) codethms))))
          |> notify_all (SOME c)
        end);

(*Append thm to the predicate clauses of its head constant c and run the
  hooks with SOME c; if thm is not a predicate clause, only warn.*)
fun add_pred thm thy =
  (case dest_pred thm
    of NONE => (warning ("not a predicate clause: " ^ string_of_thm thm); thy)
     | SOME (c, _) =>
        let
          fun snoc_thm tab = Symtab.update (c, these (Symtab.lookup tab c) @ [thm]) tab;
        in
          thy
          |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
               (procs, (extrs, map_codethms (apfst (apsnd snoc_thm)) codethms))))
          |> notify_all (SOME c)
        end);

(*Prepend thm to the unfolding rules, then run all hooks with NONE.*)
fun add_unfold thm =
  let
    fun cons_unfold (defs, unfolds) = (defs, thm :: unfolds);
  in
    CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
      (procs, (extrs, map_codethms cons_unfold codethms))))
    #> notify_all NONE
  end;

(*Remove thm (modulo eq_thm) from the function equations of its head
  constant c, or else from the predicate clauses of c, then run the hooks
  with SOME c.  Fails if thm is neither kind of code theorem.*)
fun del_def thm thy =
  let
    fun del_from c tab = Symtab.map_entry c (remove eq_thm thm) tab;
    fun update_thms upd c =
      thy
      |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
           (procs, (extrs, map_codethms upd codethms))))
      |> notify_all (SOME c);
  in
    (case destr_fun thy thm
      of SOME ((c, _), _) => update_thms (apfst (apfst (del_from c))) c
       | NONE =>
          (case dest_pred thm
            of SOME (c, _) => update_thms (apfst (apsnd (del_from c))) c
             | NONE => error ("no code theorem to delete")))
  end;

(*Remove thm (modulo eq_thm) from the unfolding rules, then run all hooks
  with NONE.*)
fun del_unfold thm =
  CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
    (procs, (extrs, map_codethms (apsnd (remove eq_thm thm)) codethms))))
  #> notify_all NONE;

(*For constant c: keep only those function equations whose type is an
  instance of ty, drop all predicate clauses of c, then run the hooks with
  SOME c.
  NOTE(review): "purge" suggests removing the equations at type ty, but
  the filter keeps exactly those -- confirm the intended direction.*)
fun purge_defs (c, ty) thy =
  let
    fun keep thm = Sign.typ_instance thy ((snd o fst o dest_fun thy) thm, ty);
    fun purge ((funs, preds), unfolds) =
      ((Symtab.map_entry c (filter keep) funs, Symtab.update (c, []) preds), unfolds);
  in
    thy
    |> CodegenTheorems.map (map_T (fn (procs, (extrs, codethms)) =>
         (procs, (extrs, map_codethms purge codethms))))
    |> notify_all (SOME c)
  end;


(** preprocessing **)

(*Instantiate thms to a common type: rename their schematic variables
  apart, unify the types extract_typ yields for them, and instantiate
  every (renamed) theorem with the resulting type substitution.

  Fixes over the earlier version:
  - the renamed theorems thms' were computed but never used, so unification
    and instantiation ran on the original, possibly clashing, variables;
  - the next offset was computed from the type of the proposition (prop,
    whose maxidx is ~1), so every offset was 0;
  - the unifier's environment is not idempotent, so the instantiating
    types are now normalised with Envir.norm_type.*)
fun common_typ thy _ [] = []
  | common_typ thy _ [thm] = [thm]
  | common_typ thy extract_typ thms =
      let
        (*shift the indexes of thm by max, answering the next unused index*)
        fun incr_thm thm max =
          let
            val thm' = incr_indexes max thm;
            val max' = #maxidx (Thm.rep_thm thm') + 1;
          in (thm', max') end;
        val (thms', maxidx) = fold_map incr_thm thms 0;
        val (ty1 :: tys) = map extract_typ thms';
        fun unify ty = Type.unify (Sign.tsig_of thy) (ty1, ty);
        val (env, _) = fold unify tys (Vartab.empty, maxidx);
        val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
          cons (Thm.ctyp_of thy (TVar (x_i, sort)),
            Thm.ctyp_of thy (Envir.norm_type env ty))) env [];
      in map (Thm.instantiate (instT, [])) thms' end;

(*Preprocess code theorems:
  - transfer them to thy;
  - apply the registered preprocessors in list order (most recently added
    first), then rewrite with the unfolding rules;
  - optionally instantiate them to a common type (common_typ with
    extract_typ);
  - normalise schematic indexes jointly and turn schematic variables into
    free ones.
  An empty list is answered by [] directly.  Previously it went through
  Drule.conj_intr_list, which has no neutral element, so callers such as
  preprocess_fun could not observe the empty case.*)
fun preprocess thy extract_typ thms =
  let
    (*zero_var_indexes on all theorems at once, via a temporary meta-conjunction*)
    fun zero_var_indexes_list [] = []
      | zero_var_indexes_list ths =
          ths
          |> Drule.conj_intr_list
          |> Drule.zero_var_indexes
          |> Drule.conj_elim_list;
  in
    thms
    |> map (Thm.transfer thy)
    |> fold (fn f => f thy) (the_preprocs thy)
    |> map (rewrite_rule (the_unfolds thy))
    |> (case extract_typ of SOME f => common_typ thy f | NONE => I)
    |> zero_var_indexes_list
    |> map Drule.unvarifyT
    |> map Drule.unvarify
  end;

(*Preprocess function equations (object equations turned into meta
  equations first) at a common type; answer that type, taken from the first
  resulting equation, together with the equations, or NONE if none remain.*)
fun preprocess_fun thy thms =
  let
    val typ_of_eq = snd o fst o dest_fun thy;
  in
    (case thms |> map elim_obj_eq |> preprocess thy (SOME typ_of_eq)
      of [] => NONE
       | eqs as eq :: _ => SOME (typ_of_eq eq, eqs))
  end;

(*Preprocess a term t by preprocessing the fake definition a == t (a a
  fresh free variable, made via SkipProof.make_thm under quick_and_dirty)
  and reading back the right-hand side.  Fails if the preprocessors did not
  keep the shape a == t'.*)
fun preprocess_term thy t =
  let
    val lhs = Free (variant (add_term_names (t, [])) "a", fastype_of t);
    val fake_def = setmp quick_and_dirty true (SkipProof.make_thm thy)
      (Logic.mk_equals (lhs, t));
    fun bad () = error "preprocess_term: bad preprocessor";
  in
    (case preprocess thy NONE [fake_def]
      of [thm] =>
          (case prop_of thm
            of Const ("==", _) $ lhs' $ t' => if lhs = lhs' then t' else bad ()
             | _ => bad ())
       | _ => bad ())
  end;


(** retrieval **)

(*Code equations for constant c at type ty.  Use the declared equations of
  c whose type is an instance of ty or vice versa.  If there are none, use
  the equations of the first function extractor answering non-empty,
  followed by the theorems of c's definitions, filtered the same way.  The
  result is preprocessed by preprocess_fun.*)
fun get_funs thy (c, ty) =
  let
    fun compatible ty' =
      Sign.typ_instance thy (ty', ty) orelse Sign.typ_instance thy (ty, ty');
    fun select ((_, ty'), thm) =
      if compatible ty' then SOME thm
      else debug_msg (fn _ => "dropping " ^ string_of_thm thm) NONE;
    fun filter_typ eqs = Library.mapfilter select eqs;
    val declared =
      filter_typ (map (dest_fun thy) (these (Symtab.lookup (the_funs thy) c)));
    val thms =
      if null declared then
        let
          val def_thms = Library.flat (map (PureThy.get_thms thy o Name o fst o snd)
            (Defs.specifications_of (Theory.defs_of thy) c));
          val extr_thms =
            getf_first_list (map (fn f => f thy) (the_funs_extrs thy)) (c, ty);
        in filter_typ (map (dest_fun thy) (extr_thms @ def_thms)) end
      else declared;
  in preprocess_fun thy thms end;

(*Constructor information for datatype dtco from the first datatype
  extractor that answers, together with the equations
    (co x1 .. xn = co y1 .. yn) == (x1 = y1 & .. & xn = yn)   (True if n = 0)
    (co1 .. = co2 ..) == False                                 (co1 <> co2)
  for all ordered pairs of constructors (injectivity equations first), each
  proved with the extractor's tactic.  NONE if no extractor answers.*)
fun get_datatypes thy dtco =
  let
    val truh = mk_tf true;
    val fals = mk_tf false;
    (*fresh argument variables for both constructors and the two applications*)
    fun mk_lhs vs ((co1, tys1), (co2, tys2)) =
      let
        val dty = Type (dtco, map TFree vs);
        val names = Term.invent_names [] "x" (length tys1 + length tys2);
        val (xs1, xs2) = chop (length tys1) names;
        val frees1 = map2 (curry Free) xs1 tys1;
        val frees2 = map2 (curry Free) xs2 tys2;
        fun mk_app co tys frees = list_comb (Const (co, tys ---> dty), frees);
      in ((frees1, frees2), (mk_app co1 tys1 frees1, mk_app co2 tys2 frees2)) end;
    fun mk_rhs [] [] = truh
      | mk_rhs xs ys = foldr1 mk_obj_conj (map2 (curry mk_obj_eq) xs ys);
    fun mk_eq vs (args as ((co1, _), (co2, _))) (inj, dist) =
      let
        val ((fs1, fs2), lhs) = mk_lhs vs args;
      in
        if co1 = co2
        then (mk_bool_eq (lhs, mk_rhs fs1 fs2) :: inj, dist)
        else (inj, mk_bool_eq (lhs, fals) :: dist)
      end;
    fun mk_eqs (vs, cos) =
      let
        val rev_cos = rev cos;
        val (inj, dist) = fold (mk_eq vs) (product rev_cos rev_cos) ([], []);
      in inj @ dist end;
    fun mk_eq_thms tac vs_cos =
      mk_eqs vs_cos
      |> map (fn t => Goal.prove thy [] [] (ObjectLogic.ensure_propT thy t) (K tac));
  in
    getf_first (map (fn f => f thy) (the_datatypes_extrs thy)) dtco
    |> Option.map (fn (vs_cos, tac) => (vs_cos, mk_eq_thms tac vs_cos))
  end;

(*Function extractor for object-logic equality.  For c the equality
  constant, with a first argument type Type (dtco, _), answer the equations
  get_datatypes derives for dtco; otherwise answer [].

  get_funs consults this extractor for every constant without declared
  equations.  It therefore answers [] instead of failing when no object
  logic is configured yet (is_obj_eq would raise an error).  The same holds
  when the argument type is not a type constructor application, e.g. a
  type variable, where dest_Type raised TYPE.*)
fun get_eq thy (c, ty) =
  let
    fun is_eq_const c' = (case ! obj_eq_ref of SOME eq => c' = eq | NONE => false);
  in
    if is_eq_const c then
      (case fst (strip_type ty)
        of Type (dtco, _) :: _ =>
            (case get_datatypes thy dtco of SOME (_, thms) => thms | NONE => [])
         | _ => [])
    else []
  end;


(** code attributes and setup **)

(*Code attributes: each name is declared through Codegen.add_attribute with
  a declaration attribute that applies the given theory operation to the
  attributed theorem.
  NOTE(review): "unfolt" registers the rule with this module only, whereas
  "unfold" also calls Codegen.add_unfold; the name may be a typo or a
  transitional alias -- confirm.*)
local
  fun add_simple_attribute (name, f) =
    Codegen.add_attribute name
      (Scan.succeed (Thm.declaration_attribute (fn thm => Context.map_theory (f thm))));
in
  val _ = List.app (Context.add_setup o add_simple_attribute) [
    ("fun", add_fun),
    ("pred", add_pred),
    ("unfold", (fn thm => Codegen.add_unfold thm #> add_unfold thm)),
    ("unfolt", add_unfold),
    ("nofold", del_unfold)
  ]
end; (*local*)

(*register get_eq as a function extractor at setup time*)
val _ = get_eq |> add_fun_extr |> Context.add_setup;

end; (*struct*)