src/Tools/nbe.ML
author haftmann
Tue Aug 07 09:40:34 2007 +0200 (2007-08-07)
changeset 24166 7b28dc69bdbb
parent 24155 d86867645f4f
child 24219 e558fe311376
permissions -rw-r--r--
new nbe implementation
     1 (*  Title:      Tools/nbe.ML
     2     ID:         $Id$
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
     4 
     5 Evaluation mechanisms for normalization by evaluation.
     6 *)
     7 
     8 (*
     9 FIXME:
    10 - get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
    11 - proper purge operation - preliminary for...
    12 - really incremental code generation
    13 *)
    14 
    15 signature NBE =
    16 sig
    17   datatype Univ = 
    18       Const of string * Univ list        (*named constructors*)
    19     | Free of string * Univ list
    20     | BVar of int * Univ list
    21     | Abs of (int * (Univ list -> Univ)) * Univ list;
    22   val free: string -> Univ list -> Univ       (*free (uninterpreted) variables*)
    23   val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
    24                                             (*abstractions as functions*)
    25   val app: Univ -> Univ -> Univ              (*explicit application*)
    26 
    27   val univs_ref: Univ list ref 
    28   val lookup_fun: CodegenNames.const -> Univ
    29 
    30   val normalization_conv: cterm -> thm
    31 
    32   val trace: bool ref
    33   val setup: theory -> theory
    34 end;
    35 
    36 structure Nbe: NBE =
    37 struct
    38 
    39 (* generic non-sense *)
    40 
    41 val trace = ref false;
    42 fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
    43 
    44 
    45 (** the semantical universe **)
    46 
    47 (*
    48    Functions are given by their semantical function value. To avoid
    49    trouble with the ML-type system, these functions have the most
    50    generic type, that is "Univ list -> Univ". The calling convention is
    51    that the arguments come as a list, the last argument first. In
    52    other words, a function call that usually would look like
    53 
    54    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
    55 
    56    would be in our convention called as
    57 
    58               f [x_n,..,x_2,x_1]
    59 
    60    Moreover, to handle functions that are still waiting for some
    61    arguments we additionally keep a list of the arguments collected so far
    62    and the number of arguments we're still waiting for.
    63 
    64    (?) Finally, it might happen that a function does not get all the
    65    arguments it needs.  In this case the function must provide means to
    66    present itself as a string. As this might be a heavyweight
    67    operation, we delay it. (?)
    68 *)
    69 
    70 datatype Univ = 
    71     Const of string * Univ list        (*named constructors*)
    72   | Free of string * Univ list         (*free variables*)
    73   | BVar of int * Univ list            (*bound named variables*)
    74   | Abs of (int * (Univ list -> Univ)) * Univ list
    75                                       (*functions*);
    76 
    77 (* constructor functions *)
    78 
    79 val free = curry Free;
    80 fun abs n f ts = Abs ((n, f), ts);
    81 fun app (Abs ((1, f), xs)) x = f (x :: xs)
    82   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
    83   | app (Const (name, args)) x = Const (name, x :: args)
    84   | app (Free (name, args)) x = Free (name, x :: args)
    85   | app (BVar (name, args)) x = BVar (name, x :: args);
    86 
    87 (* global functions store *)
    88 
    89 structure Nbe_Functions = CodeDataFun
    90 (struct
    91   type T = Univ Symtab.table;
    92   val empty = Symtab.empty;
    93   fun merge _ = Symtab.merge (K true);
    94   fun purge _ _ _ = Symtab.empty;
    95 end);
    96 
    97 (* sandbox communication *)
    98 
    99 val univs_ref = ref [] : Univ list ref;
   100 
   101 local
   102 
   103 val tab_ref = ref NONE : Univ Symtab.table option ref;
   104 
   105 in
   106 
   107 fun lookup_fun s = case ! tab_ref
   108  of NONE => error "compile_univs"
   109   | SOME tab => (the o Symtab.lookup tab) s;
   110 
   111 fun compile_univs tab ([], _) = []
   112   | compile_univs tab (cs, raw_s) =
   113       let
   114         val _ = univs_ref := [];
   115         val s = "Nbe.univs_ref := " ^ raw_s;
   116         val _ = tracing (fn () => "\n---generated code:\n" ^ s) ();
   117         val _ = tab_ref := SOME tab;
   118         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
   119           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   120           (!trace) s;
   121         val _ = tab_ref := NONE;
   122         val univs = case !univs_ref of [] => error "compile_univs" | univs => univs;
   123       in cs ~~ univs end;
   124 
   125 end; (*local*)
   126 
   127 
   128 (** assembling and compiling ML code from terms **)
   129 
   130 (* abstract ML syntax *)
   131 
   132 infix 9 `$` `$$`;
   133 fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
   134 fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
   135 fun ml_abs v e = "(fn" ^ v ^ " => " ^ e ^ ")";
   136 
   137 fun ml_Val v s = "val " ^ v ^ " = " ^ s;
   138 fun ml_cases t cs =
   139   "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
   140 fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
   141 
   142 fun ml_list es = "[" ^ commas es ^ "]";
   143 
   144 fun ml_fundefs ([(name, [([], e)])]) =
   145       "val " ^ name ^ " = " ^ e ^ "\n"
   146   | ml_fundefs (eqs :: eqss) =
   147       let
   148         fun fundef (name, eqs) =
   149           let
   150             fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
   151           in space_implode "\n  | " (map eqn eqs) end;
   152       in
   153         (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
   154         |> space_implode "\n"
   155         |> suffix "\n"
   156       end;
   157 
   158 (* nbe specific syntax *)
   159 
   160 local
   161   val prefix =          "Nbe.";
   162   val name_const =      prefix ^ "Const";
   163   val name_free =       prefix ^ "free";
   164   val name_abs =        prefix ^ "abs";
   165   val name_app =        prefix ^ "app";
   166   val name_lookup_fun = prefix ^ "lookup_fun";
   167 in
   168 
   169 fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
   170 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
   171 fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
   172 fun nbe_bound v = "v_" ^ v;
   173 
   174 fun nbe_apps e es =
   175   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
   176 
   177 fun nbe_abss 0 f = f `$` ml_list []
   178   | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
   179 
   180 fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
   181 
   182 val nbe_value = "value";
   183 
   184 end;
   185 
   186 open BasicCodegenThingol;
   187 
   188 (* greetings to Tarski *)
   189 
   190 fun assemble_iterm thy is_fun num_args =
   191   let
   192     fun of_iterm t =
   193       let
   194         val (t', ts) = CodegenThingol.unfold_app t
   195       in of_itermapp t' (fold (cons o of_iterm) ts []) end
   196     and of_itermapp (IConst (c, (dss, _))) ts =
   197           (case num_args c
   198            of SOME n => if n <= length ts
   199                 then let val (args2, args1) = chop (length ts - n) ts
   200                 in nbe_apps (nbe_fun c `$` ml_list args1) args2
   201                 end else nbe_const c ts
   202             | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
   203                 else nbe_const c ts)
   204       | of_itermapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   205       | of_itermapp ((v, _) `|-> t) ts =
   206           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   207       | of_itermapp (ICase (((t, _), cs), t0)) ts =
   208           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   209             @ [("_", of_iterm t0)])) ts
   210   in of_iterm end;
   211 
   212 fun assemble_fun thy is_fun num_args (c, eqns) =
   213   let
   214     val assemble_arg = assemble_iterm thy (K false) (K NONE);
   215     val assemble_rhs = assemble_iterm thy is_fun num_args;
   216     fun assemble_eqn (args, rhs) =
   217       ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
   218     val default_params = map nbe_bound
   219       (Name.invent_list [] "a" ((the o num_args) c));
   220     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   221   in map assemble_eqn eqns @ [default_eqn] end;
   222 
   223 fun assemble_eqnss thy is_fun [] = ([], "")
   224   | assemble_eqnss thy is_fun eqnss =
   225       let
   226         val cs = map fst eqnss;
   227         val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
   228         val funs = fold (fold (CodegenThingol.fold_constnames
   229           (insert (op =))) o map snd o snd) eqnss [];
   230         val bind_funs = map nbe_lookup (filter is_fun funs);
   231         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
   232           (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
   233         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
   234       in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
   235 
   236 fun assemble_eval thy is_fun t =
   237   let
   238     val funs = CodegenThingol.fold_constnames (insert (op =)) t [];
   239     val frees = CodegenThingol.fold_unbound_varnames (insert (op =)) t [];
   240     val bind_funs = map nbe_lookup (filter is_fun funs);
   241     val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
   242       assemble_iterm thy is_fun (K NONE) t)])];
   243     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)];
   244   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
   245 
   246 fun eqns_of_stmt (name, CodegenThingol.Fun ([], _)) =
   247       NONE
   248   | eqns_of_stmt (name, CodegenThingol.Fun (eqns, _)) =
   249       SOME (name, eqns)
   250   | eqns_of_stmt (_, CodegenThingol.Datatypecons _) =
   251       NONE
   252   | eqns_of_stmt (_, CodegenThingol.Datatype _) =
   253       NONE
   254   | eqns_of_stmt (_, CodegenThingol.Class _) =
   255       NONE
   256   | eqns_of_stmt (_, CodegenThingol.Classrel _) =
   257       NONE
   258   | eqns_of_stmt (_, CodegenThingol.Classop _) =
   259       NONE
   260   | eqns_of_stmt (_, CodegenThingol.Classinst _) =
   261       NONE;
   262 
   263 fun compile_stmts thy is_fun =
   264   map_filter eqns_of_stmt
   265   #> assemble_eqnss thy is_fun
   266   #> compile_univs (Nbe_Functions.get thy);
   267 
   268 fun eval_term thy is_fun =
   269   assemble_eval thy is_fun
   270   #> compile_univs (Nbe_Functions.get thy)
   271   #> the_single
   272   #> snd;
   273 
   274 
   275 (** compilation and evaluation **)
   276 
   277 (* ensure global functions *)
   278 
   279 fun ensure_funs thy code =
   280   let
   281     fun compile' stmts tab =
   282       let
   283         val compiled = compile_stmts thy (Symtab.defined tab) stmts;
   284       in Nbe_Functions.change thy (fold Symtab.update compiled) end;
   285     val nbe_tab = Nbe_Functions.get thy;
   286     val stmtss =
   287       map (AList.make (Graph.get_node code)) (rev (Graph.strong_conn code))
   288       |> (map o filter_out) (Symtab.defined nbe_tab o fst)
   289   in fold compile' stmtss nbe_tab end;
   290 
   291 (* re-conversion *)
   292 
   293 fun term_of_univ thy t =
   294   let
   295     fun of_apps bounds (t, ts) =
   296       fold_map (of_univ bounds) ts
   297       #>> (fn ts' => list_comb (t, rev ts'))
   298     and of_univ bounds (Const (name, ts)) typidx =
   299           let
   300             val SOME (const as (c, _)) = CodegenNames.const_rev thy name;
   301             val T = CodegenData.default_typ thy const;
   302             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
   303             val typidx' = typidx + maxidx_of_typ T' + 1;
   304           in of_apps bounds (Term.Const (c, T'), ts) typidx' end
   305       | of_univ bounds (Free (name, ts)) typidx =
   306           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   307       | of_univ bounds (BVar (name, ts)) typidx =
   308           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   309       | of_univ bounds (t as Abs _) typidx =
   310           typidx
   311           |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
   312           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   313   in of_univ 0 t 0 |> fst end;
   314 
   315 (* evaluation with type reconstruction *)
   316 
   317 fun eval thy code t t' =
   318   let
   319     fun subst_Frees [] = I
   320       | subst_Frees inst =
   321           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
   322                             | t => t);
   323     val anno_vars =
   324       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
   325       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
   326     fun check_tvars t = if null (Term.term_tvars t) then t else
   327       error ("Illegal schematic type variables in normalized term: "
   328         ^ setmp show_types true (Sign.string_of_term thy) t);
   329     val ty = type_of t;
   330     fun constrain t =
   331       singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
   332   in
   333     t'
   334     |> eval_term thy (Symtab.defined (ensure_funs thy code))
   335     |> term_of_univ thy
   336     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
   337     |> tracing (fn _ => "Term type:\n" ^ Display.raw_string_of_typ ty)
   338     |> anno_vars
   339     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
   340     |> tracing (fn t => setmp show_types true (Sign.string_of_term thy) t)
   341     |> constrain
   342     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
   343     |> check_tvars
   344   end;
   345 
   346 (* evaluation oracle *)
   347 
   348 exception Normalization of CodegenThingol.code * term * CodegenThingol.iterm;
   349 
   350 fun normalization_oracle (thy, Normalization (code, t, t')) =
   351   Logic.mk_equals (t, eval thy code t t');
   352 
   353 fun normalization_invoke thy code t t' =
   354   Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, t'));
   355   (*FIXME get rid of hardwired theory name "HOL"*)
   356 
   357 fun normalization_conv ct =
   358   let
   359     val thy = Thm.theory_of_cterm ct;
   360     fun conv code t' ct =
   361       let
   362         val t = Thm.term_of ct;
   363       in normalization_invoke thy code t t' end;
   364   in CodegenPackage.eval_conv thy conv ct end;
   365 
   366 (* evaluation command *)
   367 
   368 fun norm_print_term ctxt modes t =
   369   let
   370     val thy = ProofContext.theory_of ctxt;
   371     val ct = Thm.cterm_of thy t;
   372     val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
   373     val ty = Term.type_of t';
   374     val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
   375       Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
   376         Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
   377   in Pretty.writeln p end;
   378 
   379 
   380 (** Isar setup **)
   381 
   382 fun norm_print_term_cmd (modes, raw_t) state =
   383   let val ctxt = Toplevel.context_of state
   384   in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;
   385 
   386 val setup = Theory.add_oracle ("normalization", normalization_oracle)
   387 
   388 local structure P = OuterParse and K = OuterKeyword in
   389 
   390 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   391 
   392 val nbeP =
   393   OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
   394     (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));
   395 
   396 val _ = OuterSyntax.add_parsers [nbeP];
   397 
   398 end;
   399 
   400 end;