src/Pure/Tools/nbe.ML
author nipkow
Mon, 09 Oct 2006 12:16:29 +0200
changeset 20920 07f279940664
parent 20856 9f7f0bf89e7d
child 20937 4297a44e26ae
permissions -rw-r--r--
added pre/post-processor equations

(*  Title:      Pure/Tools/nbe.ML
    ID:         $Id$
    Author:     Tobias Nipkow, Florian Haftmann, TU Muenchen

Toplevel theory interface for "normalization by evaluation"
*)

signature NBE =
sig
  (*diagnostics: when set, intermediate stages are traced*)
  val trace: bool ref

  (*normalization proper;
    preconditions: no Vars/TVars in term*)
  val norm_term: theory -> term -> term
  val normalization_conv: cterm -> thm

  (*table through which dynamically compiled code is stored and retrieved*)
  val update: string * NBE_Eval.Univ -> unit
  val lookup: string -> NBE_Eval.Univ
end;

structure NBE: NBE =
struct

(*when true, norm_term, compile_term and use_code emit intermediate
  results and the generated ML code via Output.tracing*)
val trace = ref false;

(*f is a thunk, so the message is only built when tracing is enabled*)
fun tracing f = if !trace then Output.tracing (f ()) else ();


(* theory data setup *)

(*pre- and post-processing equations, stored as the pair (pres, posts);
  on merge, each list is merged up to eq_thm*)
structure NBE_Rewrite = TheoryDataFun
(struct
  val name = "Pure/nbe";
  type T = thm list * thm list

  val empty = ([],[])
  val copy = I;
  val extend = I;

  fun merge _ ((pres1,posts1), (pres2,posts2)) =
    (Library.merge eq_thm (pres1,pres2), Library.merge eq_thm (posts1,posts2))

  fun print _ _ = ()
end);

val _ = Context.add_setup NBE_Rewrite.init;

(*constants (normalized by CodegenConsts.norm_of_typ) occurring on the
  right-hand sides of the pre-equations, without duplicates.  compile_term
  adds these to the constants to be compiled, because apply_pres rewrites
  code equations with the pre-equations and may thus introduce them*)
fun consts_of_pres thy =
  let
    val pres = fst (NBE_Rewrite.get thy);
    val rhss = map (snd o Logic.dest_equals o prop_of) pres;
    fun add_const (Const c) = insert (op =) (CodegenConsts.norm_of_typ thy c)
      | add_const _ = I;
  in (fold o fold_aterms) add_const rhss [] end;


local

(*attribute adding thm to one component of NBE_Rewrite: sel is apfst for
  the pre-equations and apsnd for the post-equations*)
fun add_rewrite sel (context, thm) =
  ((Context.map_theory o NBE_Rewrite.map o sel) (insert eq_thm thm) context, thm);

fun attr_pre x = add_rewrite apfst x;
fun attr_post x = add_rewrite apsnd x;

in

val _ = Context.add_setup
  (Attrib.add_attributes
     [("normal_pre", Attrib.no_args attr_pre, "declare pre-theorems for normalization"),
      ("normal_post", Attrib.no_args attr_post, "declare post-theorems for normalization")]);

end;

(*rewrite every code equation in a list with the pre-equations*)
fun apply_pres thy =
  let val pres = fst (NBE_Rewrite.get thy)
  in map (CodegenData.rewrite_func pres) end;

(*rewrite a cterm with the post-equations; the result is an equation
  whose right-hand side norm_term extracts*)
fun apply_posts thy =
  let val posts = snd (NBE_Rewrite.get thy)
  in Tactic.rewrite false posts end;


(*compiled functions, keyed by their CodegenNames.const name;
  purged to the empty table*)
structure NBE_Data = CodeDataFun
(struct
  val name = "Pure/NBE"
  type T = NBE_Eval.Univ Symtab.table
  val empty = Symtab.empty
  fun merge _ = Symtab.merge (K true)
  fun purge _ _ = Symtab.empty
end);

val _ = Context.add_setup NBE_Data.init;


(* sandbox communication *)

(*global table that is loaded from NBE_Data before code generation and
  written back afterwards (see generate); presumably the generated ML code
  registers its functions through update -- confirm in NBE_Codegen*)
val tab : NBE_Eval.Univ Symtab.table ref = ref Symtab.empty;

(*raises Option if s has no entry in tab*)
fun lookup s = (the o Symtab.lookup (!tab)) s;
fun update sx = (tab := Symtab.update sx (!tab));


(* norm by eval *)

local

(* FIXME better turn this into a function
    NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
    with implicit side effect *)
(*compile a piece of generated ML code, if any; compiler output is only
  echoed when tracing is enabled*)
fun use_code NONE = ()
  | use_code (SOME s) =
      (tracing (fn () => "\n---generated code:\n" ^ s);
       use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
            Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
        (!trace) s);

(*load the compiled-function table of thy into tab, generate and compile
  code for each group in funs (the generator is told which names are
  already defined in tab), then store tab back into the code data of thy*)
fun generate thy funs =
  let
    val _ = tab := NBE_Data.get thy;
    val _ = Library.seq (use_code o NBE_Codegen.generate thy
      (fn s => Symtab.defined (!tab) s)) funs;
  in NBE_Data.change thy (K (!tab)) end;

(*preprocess t, compile all code equations it depends on that are not yet
  compiled, and evaluate the preprocessed term*)
fun compile_term thy t =
  let
    (*FIXME: proper interfaces in codegen_*)
    val (consts, cs) = CodegenConsts.consts_of thy t;
    val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
    (*FIXME: proper interfaces in codegen_*)
    (*type of a constant: taken from its first code equation if there is
      one, otherwise its declared type*)
    fun const_typ (c, ty) =
      let
        val const = CodegenConsts.norm_of_typ thy (c, ty);
      in case CodegenFuncgr.get_funcs funcgr const
       of (thm :: _) => CodegenData.typ_func thy thm
        | [] => Sign.the_const_type thy c
      end;
    val (_, ct) = CodegenData.preprocess_cterm thy const_typ (Thm.cterm_of thy t)
    val t' = Thm.term_of ct;
    (*the function graph is rebuilt for the preprocessed term, including
      the constants introduced by the pre-equations*)
    val (consts, cs) = CodegenConsts.consts_of thy t';
    val pre_consts = consts_of_pres thy;
    val consts' = pre_consts @ consts;
    val funcgr = CodegenFuncgr.mk_funcgr thy consts' cs;
    val nbe_tab = NBE_Data.get thy;
    (*dependency groups, minus constants already compiled; groups that
      become empty are dropped*)
    val all_consts =
      (pre_consts :: CodegenFuncgr.all_deps_of funcgr consts')
      |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
      |> filter_out null;
    val funs = (map o map)
      (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.get_funcs funcgr c))) all_consts;
    val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
    val _ = generate thy funs;
    val nt = NBE_Eval.eval thy (!tab) t';
  in nt end;

(*replace Frees by name according to the association list inst*)
fun subst_Frees [] = I
  | subst_Frees inst =
      Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
                  | t => t);

(*the Frees and Vars of t, with their types*)
fun var_tab t = (Term.add_frees t [], Term.add_vars t []);

(*re-attach the original types to Frees and Vars of a term*)
fun anno_vars (Ftab, Vtab) =
  subst_Vars  (map (fn (ixn, T) => (ixn, Var(ixn,T))) Vtab) o
  subst_Frees (map (fn (s, T) =>   (s,   Free(s,T)))  Ftab)

in

(*normalize t by evaluation: evaluate, convert back to a term, restore the
  types of Frees/Vars, type-check against the type of t (rejecting any
  remaining schematic type variables), and rewrite with the post-equations*)
fun norm_term thy t =
  let
    val _ = tracing (fn () => "Input:\n" ^ Display.raw_string_of_term t);
    val nt = compile_term thy t;
    val vtab = var_tab t;
    val ty = type_of t;
    fun constrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
        (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
    val _ = tracing (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
    val t1 = NBE_Codegen.nterm_to_term thy nt;
    val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t1);
    val t2 = anno_vars vtab t1;
    val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t2);
    val t3 = constrain ty t2;
    val _ = if null (Term.term_tvars t3) then () else
      error ("Illegal schematic type variables in normalized term: "
        ^ setmp show_types true (Sign.string_of_term thy) t3);
    val eq = apply_posts thy (Thm.cterm_of thy t3);
    val t4 = snd(Logic.dest_equals(prop_of eq))
  in t4 end;

(*print the normal form of t together with its type, in the given print modes*)
fun norm_print_term ctxt modes t =
  let
    val thy = ProofContext.theory_of ctxt;
    val t' = norm_term thy t;
    val ty = Term.type_of t';
    val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
      Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
  in Pretty.writeln p end;

(*toplevel entry: read raw_t in the context of state and print its normal form*)
fun norm_print_term_e (modes, raw_t) state =
  let
    val ctxt = (Proof.context_of o Toplevel.enter_forward_proof) state;
  in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;

end; (*local*)


(* normalization oracle *)

exception Normalization of term;

(*the oracle asserts t == norm_term thy t*)
fun normalization_oracle (thy, Normalization t) =
  Logic.mk_equals (t, norm_term thy t);

fun normalization_conv ct =
  let val {sign, t, ...} = rep_cterm ct
  in Thm.invoke_oracle_i sign "Pure.normalization" (sign, Normalization t) end;

val _ = Context.add_setup
  (Theory.add_oracle ("normalization", normalization_oracle));


(* Isar setup *)

local structure P = OuterParse and K = OuterKeyword in

val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

(*the argument is read as a term by norm_print_term_e, so it is parsed
  with P.term (not P.typ)*)
val nbeP =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_e));

end;

val _ = OuterSyntax.add_parsers [nbeP];

end;