src/Tools/nbe.ML
author wenzelm
Wed Jun 18 22:32:03 2008 +0200 (2008-06-18 ago)
changeset 27264 843472ae2116
parent 27103 d8549f4d900b
child 27499 150558266831
permissions -rw-r--r--
simplified TypeInfer.infer_types;
     1 (*  Title:      Tools/nbe.ML
     2     ID:         $Id$
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
     4 
     5 Normalization by evaluation, based on generic code generator.
     6 *)
     7 
     signature NBE =
     sig
       (*normalization entry points*)
       val norm_conv: cterm -> thm
       val norm_term: theory -> term -> term

       (*the semantical universe*)
       datatype Univ =
           Const of int * Univ list               (*uninterpreted constant, given by its index*)
         | Free of string * Univ list             (*uninterpreted free variable*)
         | DFree of string * int                  (*uninterpreted free dictionary parameter*)
         | BVar of int * Univ list                (*bound variable, given by its number*)
         | Abs of (int * (Univ list -> Univ)) * Univ list;
                                                  (*closure: missing argument count, function,
                                                    arguments collected so far*)
       val apps: Univ -> Univ list -> Univ        (*explicit applications*)
       val abss: int -> (Univ list -> Univ) -> Univ
                                                  (*abstractions as closures*)

       (*communication with evaluated code, and tracing switch*)
       val univs_ref: (unit -> Univ list -> Univ list) option ref
       val trace: bool ref

       val setup: theory -> theory
     end;
    28 
    29 structure Nbe: NBE =
    30 struct
    31 
    32 (* generic non-sense *)
    33 
    (*global switch for diagnostic output*)
    val trace = ref false;

    (*identity on x; when tracing is enabled, additionally emits the message f x*)
    fun tracing f x =
      let val _ = if ! trace then Output.tracing (f x) else ()
      in x end;
    36 
    37 
    38 (** the semantical universe **)
    39 
    40 (*
    41    Functions are given by their semantical function value. To avoid
    42    trouble with the ML-type system, these functions have the most
    43    generic type, that is "Univ list -> Univ". The calling convention is
    44    that the arguments come as a list, the last argument first. In
    45    other words, a function call that usually would look like
    46 
    47    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
    48 
    49    would be in our convention called as
    50 
    51               f [x_n,..,x_2,x_1]
    52 
    53    Moreover, to handle functions that are still waiting for some
    54    arguments we have additionally a list of arguments collected so far
    55    and the number of arguments we're still waiting for.
    56 *)
    57 
    datatype Univ =
        (*named (uninterpreted) constants, referred to by index*)
        Const of int * Univ list
        (*free variables*)
      | Free of string * Univ list
        (*free (uninterpreted) dictionary parameters*)
      | DFree of string * int
        (*bound named variables*)
      | BVar of int * Univ list
        (*abstractions as closures: number of arguments still missing,
          the function itself, and the arguments collected so far*)
      | Abs of (int * (Univ list -> Univ)) * Univ list;
    65 
    66 (* constructor functions *)
    67 
    (*closure expecting n arguments, none supplied yet*)
    fun abss n f = Abs ((n, f), []);

    (*apply a value to arguments given in reverse order (last argument first);
      there is no clause for DFree, so applying one raises Match*)
    fun apps (Abs ((arity, f), collected)) args =
          let val missing = arity - length args in
            if missing > 0 then
              (*still under-applied: just remember the new arguments*)
              Abs ((missing, f), args @ collected)
            else if missing = 0 then
              (*exactly saturated: run the function*)
              f (args @ collected)
            else
              (*over-applied: saturate first, then apply the result to the surplus*)
              let val (surplus, needed) = chop (~ missing) args
              in apps (f (needed @ collected)) surplus end
          end
      | apps (Const (idx, collected)) args = Const (idx, args @ collected)
      | apps (Free (v, collected)) args = Free (v, args @ collected)
      | apps (BVar (i, collected)) args = BVar (i, args @ collected);
    78 
    79 
    80 (** assembling and compiling ML code from terms **)
    81 
    82 (* abstract ML syntax *)
    83 
    infix 9 `$` `$$`;

    (*parenthesized application of one ML expression to another*)
    fun f `$` x = enclose "(" ")" (f ^ " " ^ x);

    (*parenthesized application to several arguments; no arguments leaves f untouched*)
    fun f `$$` xs =
      if null xs then f
      else enclose "(" ")" (space_implode " " (f :: xs));

    (*parenthesized ML lambda abstraction*)
    fun ml_abs v e = enclose "(" ")" ("fn " ^ v ^ " => " ^ e);
    89 
    (*parenthesized ML case expression over (pattern, body) clauses*)
    fun ml_cases t cs =
      let
        fun clause (pat, body) = pat ^ " => " ^ body;
        val alternatives = space_implode " | " (map clause cs);
      in enclose "(" ")" ("case " ^ t ^ " of " ^ alternatives) end;

    (*ML let expression with declarations d and body e*)
    fun ml_Let d e =
      let val header = "let\n" ^ d
      in header ^ " in " ^ e ^ " end" end;

    (*ML list literal*)
    fun ml_list es = enclose "[" "]" (commas es);
    95 
    (*ML text for a group of function definitions: a single definition with one
      argument-free equation becomes a plain "val", anything else becomes a
      "fun ... and ..." group; the empty group is not handled*)
    fun ml_fundefs [(name, [([], rhs)])] =
          "val " ^ name ^ " = " ^ rhs ^ "\n"
      | ml_fundefs (first :: rest) =
          let
            fun clause name (pats, rhs) =
              name ^ " " ^ space_implode " " pats ^ " = " ^ rhs;
            fun fundef keyword (name, clauses) =
              keyword ^ space_implode "\n  | " (map (clause name) clauses);
            val defs = fundef "fun " first :: map (fundef "and ") rest;
          in cat_lines defs ^ "\n" end;
   109 
   110 (* nbe specific syntax and sandbox communication *)
   111 
   (*slot through which evaluated code hands its result back*)
   val univs_ref: (unit -> Univ list -> Univ list) option ref = ref NONE;

   local
     (*qualified names of the entities that generated code refers to*)
     val struct_prefix = "Nbe.";
     val name_ref = struct_prefix ^ "univs_ref";
     val name_const = struct_prefix ^ "Const";
     val name_abss = struct_prefix ^ "abss";
     val name_apps = struct_prefix ^ "apps";
   in

   val univs_cookie = (name_ref, univs_ref);

   (*ML identifiers for constants, dictionary parameters and bound variables;
     the empty constant name stands for the value under evaluation*)
   fun nbe_fun c =
     if c = "" then "nbe_value"
     else "c_" ^ translate_string (fn s => if s = "." then "_" else s) c;
   fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
   fun nbe_bound v = "v_" ^ v;

   (*note: the following three reverse the argument list, i.e. this is where
     the "last argument first" calling convention is established*)
   fun nbe_apps t ts =
     if null ts then t
     else name_apps `$$` [t, ml_list (rev ts)];
   fun nbe_apps_local c ts = nbe_fun c `$` ml_list (rev ts);
   fun nbe_apps_constr idx ts =
     let val constr_arg = "(" ^ string_of_int idx ^ ", " ^ ml_list (rev ts) ^ ")"
     in name_const `$` constr_arg end;

   (*closure over f awaiting n arguments; with n = 0, f is called on the empty list*)
   fun nbe_abss n f =
     if n = 0 then f `$` ml_list []
     else name_abss `$$` [string_of_int n, f];

   end;
   140 
   141 open BasicCodeThingol;
   142 
   143 (* code generation *)
   144 
   (*ML source text for a group of equation systems: a function that takes the
     values of the dependencies deps as a list and returns the list of values
     defined by eqnss; idx_of maps constant names to their indices*)
   fun assemble_eqnss idx_of deps eqnss =
     let
       (*dictionary arguments induced by one sort-constrained type variable*)
       fun dict_args (v, sort) = map_index (fn (i, _) => nbe_dict v i) sort;

       (*attach dictionary arguments and overall arity to each equation system;
         the number of proper arguments is read off the first equation*)
       fun prep_eqns (c, (vs, eqns)) =
         let
           val dicts = maps dict_args vs;
           val arity = length dicts + length (fst (hd eqns));
         in (c, (arity, (dicts, eqns))) end;
       val prepared = map prep_eqns eqnss;

       (*application of constant c to dictionaries dss and arguments ts*)
       fun assemble_constapp c dss ts =
         let
           val args = (maps o map) assemble_idict dss @ ts;
         in
           (case AList.lookup (op =) prepared c of
             SOME (arity, _) =>
               (*defined in this group: call directly once enough arguments are present*)
               if arity <= length args then
                 let val (now, later) = chop arity args
                 in nbe_apps (nbe_apps_local c now) later end
               else nbe_apps (nbe_abss arity (nbe_fun c)) args
           | NONE =>
               (*compiled dependency, or else an uninterpreted constant*)
               if member (op =) deps c then nbe_apps (nbe_fun c) args
               else nbe_apps_constr (idx_of c) args)
         end
       and assemble_idict (DictConst (inst, dss)) = assemble_constapp inst dss []
         | assemble_idict (DictVar (supers, (v, (n, _)))) =
             fold_rev (fn super => fn dict => assemble_constapp super [] [dict])
               supers (nbe_dict v n);

       (*translation of terms, parameterized over the treatment of constants*)
       fun assemble_iterm constapp =
         let
           fun of_iterm t =
             let
               val (head, args) = CodeThingol.unfold_app t;
               val args' = fold_rev (fn arg => fn acc => of_iterm arg :: acc) args [];
             in of_iapp head args' end
           and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
             | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
             | of_iapp ((v, _) `|-> body) ts =
                 let val closure = ml_abs (ml_list [nbe_bound v]) (of_iterm body)
                 in nbe_apps (nbe_abss 1 closure) ts end
             | of_iapp (ICase (((scrutinee, _), clauses), default)) ts =
                 let
                   val scrutinee' = of_iterm scrutinee;
                   val alternatives =
                     map (pairself of_iterm) clauses @ [("_", of_iterm default)];
                 in nbe_apps (ml_cases scrutinee' alternatives) ts end
         in of_iterm end;

       (*function definition and closure value for one equation system*)
       fun assemble_eqns (c, (arity, (dicts, eqns))) =
         let
           (*on left-hand sides every constant is an uninterpreted constructor*)
           fun constr_app c' _ ts = nbe_apps_constr (idx_of c') ts;
           val assemble_arg = assemble_iterm constr_app;
           val assemble_rhs = assemble_iterm assemble_constapp;
           fun assemble_eqn (args, rhs) =
             ([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs);
           (*catch-all equation rebuilding the application as uninterpreted constant;
             omitted for the anonymous value ""*)
           val default_args = map nbe_bound (Name.invent_list [] "a" arity);
           val default_eqn =
             if c = "" then NONE
             else SOME ([ml_list (rev default_args)],
               nbe_apps_constr (idx_of c) default_args);
         in
           ((nbe_fun c, map assemble_eqn eqns @ the_list default_eqn),
             nbe_abss arity (nbe_fun c))
         end;

       val (fun_defs, fun_vals) = map_split assemble_eqns prepared;
       val deps_pattern = ml_list (map nbe_fun deps);
       val body = ml_Let (ml_fundefs fun_defs) (ml_list fun_vals);
     in ml_abs deps_pattern body end;
   205 
   206 (* code compilation *)
   207 
   (*compile a group of equation systems against graph gr; yields the defined
     constants paired with their values, or nothing for an empty group*)
   fun compile_eqnss gr raw_deps [] = []
     | compile_eqnss gr raw_deps eqnss =
         let
           (*only dependencies that already carry a value are passed to the code*)
           fun compiled dep =
             (case fst (Graph.get_node gr dep) of
               SOME univ => SOME (dep, univ)
             | NONE => NONE);
           val (deps, deps_vals) = split_list (map_filter compiled raw_deps);

           (*index lookup restricted to the given dependencies*)
           val idx_alist = map (fn dep => (dep, snd (Graph.get_node gr dep))) raw_deps;
           fun idx_of c = the (AList.lookup (op =) idx_alist c);

           val code = assemble_eqnss idx_of deps eqnss;
           val consts = map fst eqnss;

           val echo = Output.tracing o enclose "\n---compiler echo:\n" "\n---\n";
           val echo_error =
             Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n";
           val mk_univs =
             code
             |> tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
             |> ML_Context.evaluate (echo, echo_error) (!trace) univs_cookie;
         in consts ~~ mk_univs deps_vals end;
   229 
   230 (* preparing function equations *)
   231 
   (*equation systems contributed by a single named statement*)
   fun eqns_of_stmt (name, stmt) =
     (case stmt of
       CodeThingol.Fun (_, []) => []
     | CodeThingol.Fun ((vs, _), eqns) => [(name, (vs, map fst eqns))]
     | CodeThingol.Datatypecons _ => []
     | CodeThingol.Datatype _ => []
     | CodeThingol.Class (v, (superclasses, classops)) =>
         (*one projection per superclass relation and class operation,
           selecting the respective component of the dictionary*)
         let
           val selectors = map snd superclasses @ map fst classops;
           val params = Name.invent_list [] "d" (length selectors);
           fun projection (k, sel) =
             (sel, ([(v, [])],
               [([IConst (name, ([], [])) `$$ map IVar params], IVar (nth params k))]));
         in map_index projection selectors end
     | CodeThingol.Classrel _ => []
     | CodeThingol.Classparam _ => []
     | CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops)) =>
         (*an instance is the class constant applied to superinstances and operations*)
         let
           val super_dicts =
             map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts;
           val op_impls = map (IConst o snd o fst) instops;
         in
           [(name, (arities, [([], IConst (class, ([], [])) `$$ (super_dicts @ op_impls))]))]
         end);
   256 
   (*compile a group of statements (each paired with its dependencies) and
     record the results in the function store (gr, (maxidx, idx_tab))*)
   fun compile_stmts stmts_deps =
     let
       val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
       val names = map fst names_deps;
       val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
       (*all dependencies, including the statements of this group themselves*)
       val refl_deps =
         fold (insert (op =)) names (distinct (op =) (maps snd names_deps));

       (*register an unknown name under the next free index, without a value yet*)
       fun new_node name (store as (gr, (maxidx, idx_tab))) =
         if can (Graph.get_node gr) name then store
         else
           (Graph.new_node (name, (NONE, maxidx)) gr,
             (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
       fun add_edges (name, deps) = fold (curry Graph.add_edge name) deps;
       fun store_univ (name, univ) =
         Graph.map_node name (fn (_, idx) => (SOME univ, idx));
     in
       fn store =>
         let
           val (gr, idx_info) = fold new_node refl_deps store;
           val gr' = fold add_edges names_deps gr;
           val univs = compile_eqnss gr' refl_deps eqnss;
         in (fold store_univ univs gr', idx_info) end
     end;
   279 
   (*extend the function store by all statements of program, processing its
     strongly connected components via fold_rev; a component is skipped as
     soon as one of its names is already present in the store*)
   fun ensure_stmts program =
     let
       fun known gr name = can (Graph.get_node gr) name;
       fun stmt_with_deps name =
         ((name, Graph.get_node program name), Graph.imm_succs program name);
       fun add_stmts names (store as (gr, _)) =
         if exists (known gr) names then store
         else compile_stmts (map stmt_with_deps names) store;
     in fold_rev add_stmts (Graph.strong_conn program) end;
   288 
   289 
   290 (** evaluation **)
   291 
   292 (* term evaluation *)
   293 
   (*evaluate term t: compile it as the anonymous value "" abstracted over its
     unbound variables, then apply the result to uninterpreted dictionary
     parameters and free variables*)
   fun eval_term gr deps ((vs, _), t) =
     let
       val free_names = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
       val free_args = map (fn v => Free (v, [])) free_names;
       fun dict_args (v, sort) = map_index (fn (i, _) => DFree (v, i)) sort;
       val dict_frees = maps dict_args vs;
       val (_, value) =
         singleton (compile_eqnss gr deps) ("", (vs, [(map IVar free_names, t)]));
     in apps value (rev (dict_frees @ free_args)) end;
   305 
   306 (* reification *)
   307 
   (*turn a semantic value back into a term; idx_tab maps constant indices
     to names; constants get fresh type inference parameters, other atoms dummyT*)
   fun term_of_univ thy idx_tab t =
     let
       (*longest prefix whose elements do not satisfy stop*)
       fun take_until _ [] = []
         | take_until stop (x :: xs) =
             if stop x then [] else x :: take_until stop xs;

       fun name_of idx = the (Inttab.lookup idx_tab idx);

       (*dictionary arguments: class, class relation and instance constants,
         as well as free dictionary parameters*)
       fun is_dict (Const (idx, _)) =
             let val c = name_of idx in
               is_some (CodeName.class_rev thy c)
               orelse is_some (CodeName.classrel_rev thy c)
               orelse is_some (CodeName.instance_rev thy c)
             end
         | is_dict (DFree _) = true
         | is_dict _ = false;

       (*typidx is the next unused index for type parameters, threaded through*)
       fun of_apps bounds (head, args) typidx =
         let val (args', typidx') = fold_map (of_univ bounds) args typidx
         in (list_comb (head, rev args'), typidx') end
       and of_univ bounds (Const (idx, ts)) typidx =
             let
               (*arguments from the first dictionary onwards are dropped*)
               val ts' = take_until is_dict ts;
               val c = the (CodeName.const_rev thy (name_of idx));
               val (_, T) = Code.default_typ thy c;
               val T' =
                 map_type_tvar (fn ((v, i), _) => TypeInfer.param (typidx + i) (v, [])) T;
               val typidx' = typidx + maxidx_of_typ T' + 1;
             in of_apps bounds (Term.Const (c, T'), ts') typidx' end
         | of_univ bounds (Free (name, ts)) typidx =
             of_apps bounds (Term.Free (name, dummyT), ts) typidx
         | of_univ bounds (BVar (name, ts)) typidx =
             of_apps bounds (Bound (bounds - name - 1), ts) typidx
         | of_univ bounds (t as Abs _) typidx =
             (*apply the closure to a fresh bound variable and abstract over it*)
             let
               val (body, typidx') =
                 of_univ (bounds + 1) (apps t [BVar (bounds, [])]) typidx;
             in (Term.Abs ("u", dummyT, body), typidx') end
     in fst (of_univ 0 t 0) end;
   342 
   343 (* function store *)
   344 
   (*store of compiled functions: a graph whose nodes carry an optional value
     and an index, the next index to be assigned, and the table from indices
     back to names*)
   structure Nbe_Functions = CodeDataFun
   (
     type T = (Univ option * int) Graph.T * (int * string Inttab.table);
     val empty = (Graph.empty, (0, Inttab.empty));
     fun merge _ ((gr1, (maxidx1, idx_tab1)), (gr2, (maxidx2, idx_tab2))) =
       let
         val gr = Graph.merge (K true) (gr1, gr2);
         val maxidx = IntInf.max (maxidx1, maxidx2);
         val idx_tab = Inttab.merge (K true) (idx_tab1, idx_tab2);
       in (gr, (maxidx, idx_tab)) end;
     (*without theory or without constant list everything is discarded; otherwise
       the given constants known to the graph are deleted with all predecessors*)
     fun purge thy_opt cs_opt (gr, (maxidx, idx_tab)) =
       (case (thy_opt, cs_opt) of
         (SOME thy, SOME cs) =>
           let
             val known_consts = map_filter (CodeName.const_rev thy) (Graph.keys gr);
             val affected = filter (member (op =) known_consts) cs;
             val dels = Graph.all_preds gr (map (CodeName.const thy) affected);
           in (Graph.del_nodes dels gr, (maxidx, idx_tab)) end
       | _ => empty);
   );
   364 
   365 (* compilation, evaluation and reification *)
   366 
   (*bring the function store up to date with program, evaluate the term
     and reify the resulting value*)
   fun compile_eval thy program vs_ty_t deps =
     let
       val (gr, (_, idx_tab)) = Nbe_Functions.change thy (ensure_stmts program);
       val univ = eval_term gr deps vs_ty_t;
     in term_of_univ thy idx_tab univ end;
   375 
   376 (* evaluation with type reconstruction *)
   377 
   (*normalize t by evaluation, then reconstruct types: free variables regain
     their types from t, the remainder is inferred against the type of t;
     schematic type variables left in the result are an error*)
   fun eval thy t program vs_ty_t deps =
     let
       val string_of_term = setmp show_types true (Syntax.string_of_term_global thy);
       fun trace_term header = tracing (fn t' => header ^ string_of_term t');

       fun subst_const f =
         map_aterms (fn Term.Const (c, T) => Term.Const (f c, T) | a => a);
       val subst_triv_consts = subst_const (CodeUnit.resubst_alias thy);

       val ty = type_of t;
       val typed_frees =
         map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []);
       val type_frees = Term.map_aterms
         (fn a as Term.Free (s, _) => the_default a (AList.lookup (op =) typed_frees s)
           | a => a);

       fun type_infer t' =
         singleton (TypeInfer.infer_types (Syntax.pp_global thy) (Sign.tsig_of thy) I
           (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE) Name.context 0)
         (TypeInfer.constrain ty t');

       fun check_tvars t' =
         if null (Term.term_tvars t') then t'
         else
           error ("Illegal schematic type variables in normalized term: "
             ^ string_of_term t');
     in
       compile_eval thy program vs_ty_t deps
       |> trace_term "Normalized:\n"
       |> subst_triv_consts
       |> type_frees
       |> trace_term "Vars typed:\n"
       |> type_infer
       |> trace_term "Types inferred:\n"
       |> check_tvars
       |> tracing (fn _ => "---\n")
     end;
   407 
   408 (* evaluation oracle *)
   409 
   (*oracle argument: the term together with everything eval needs*)
   exception Norm of term * CodeThingol.program
     * (CodeThingol.typscheme * CodeThingol.iterm) * string list;

   (*the oracle states that t equals its normal form*)
   fun norm_oracle (thy, Norm (t, program, vs_ty_t, deps)) =
     let val normal_form = eval thy t program vs_ty_t deps
     in Logic.mk_equals (t, normal_form) end;

   fun norm_invoke thy t program vs_ty_t deps =
     let val request = Norm (t, program, vs_ty_t, deps)
     in Thm.invoke_oracle_i thy "HOL.norm" (thy, request) end;
     (*FIXME get rid of hardwired theory name*)
   419 
   (*intersect the sort of every type variable of a term with the
     trivial classes of the theory*)
   fun add_triv_classes thy =
     let
       val algebra = Sign.classes_of thy;
       val triv_classes = CodeUnit.triv_classes thy;
       fun inters sort = Sorts.inter_sort algebra (triv_classes, sort);
       fun map_sorts f = (map_types o map_atyps)
         (fn TVar (v, sort) => TVar (v, f sort)
           | TFree (v, sort) => TFree (v, f sort));
     in map_sorts inters end;
   428 
   (*normalization as conversion, going through the oracle*)
   fun norm_conv ct =
     let
       val thy = Thm.theory_of_cterm ct;
       fun evaluator t = (add_triv_classes thy t, norm_invoke thy t);
     in CodeThingol.eval_conv thy evaluator ct end;

   (*normalization on plain terms, followed by code postprocessing*)
   fun norm_term thy t =
     let
       fun evaluator t' = (add_triv_classes thy t', eval thy t');
       val normalized = CodeThingol.eval_term thy evaluator t;
     in Code.postprocess_term thy normalized end;
   441 
   442 (* evaluation command *)
   443 
   (*normalize t and print the quoted result with its type under the given print modes*)
   fun norm_print_term ctxt modes t =
     let
       val thy = ProofContext.theory_of ctxt;
       val t' = norm_term thy t;
       val ty' = Term.type_of t';
       val ctxt' = Variable.auto_fixes t ctxt;
       fun pretty_result () =
         Pretty.block
           [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
            Pretty.str "::", Pretty.brk 1,
            Pretty.quote (Syntax.pretty_typ ctxt' ty')];
     in Pretty.writeln (PrintMode.with_modes modes pretty_result ()) end;
   454 
   455 
   456 (** Isar setup **)
   457 
   (*toplevel action: read the term in the context of the current state
     and print its normal form*)
   fun norm_print_term_cmd (modes, s) state =
     let
       val ctxt = Toplevel.context_of state;
       val t = Syntax.read_term ctxt s;
     in norm_print_term ctxt modes t end;

   (*theory setup: register the normalization oracle under the name "norm"*)
   val setup = Theory.add_oracle ("norm", norm_oracle);
   463 
   local structure P = OuterParse and K = OuterKeyword in

   (*optional print modes, given as a parenthesized non-empty list of names*)
   val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

   (*diagnostic command "normal_form": the argument is handed to
     norm_print_term_cmd, which reads it as a term, hence it is parsed in the
     term category (not the type category) of outer syntax*)
   val _ =
     OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
       (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_cmd));

   end;
   473 
   474 end;