src/Tools/nbe.ML
changeset 28054 2b84d34c5d02
parent 27609 b23c9ad0fe7d
child 28227 77221ee0f7b9
equal deleted inserted replaced
28053:a2106c0d8c45 28054:2b84d34c5d02
   136 fun nbe_abss 0 f = f `$` ml_list []
   136 fun nbe_abss 0 f = f `$` ml_list []
   137   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
   137   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
   138 
   138 
   139 end;
   139 end;
   140 
   140 
   141 open BasicCodeThingol;
   141 open Basic_Code_Thingol;
   142 
   142 
   143 (* code generation *)
   143 (* code generation *)
   144 
   144 
   145 fun assemble_eqnss idx_of deps eqnss =
   145 fun assemble_eqnss idx_of deps eqnss =
   146   let
   146   let
   170 
   170 
   171     fun assemble_iterm constapp =
   171     fun assemble_iterm constapp =
   172       let
   172       let
   173         fun of_iterm t =
   173         fun of_iterm t =
   174           let
   174           let
   175             val (t', ts) = CodeThingol.unfold_app t
   175             val (t', ts) = Code_Thingol.unfold_app t
   176           in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
   176           in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
   177         and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
   177         and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
   178           | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   178           | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   179           | of_iapp ((v, _) `|-> t) ts =
   179           | of_iapp ((v, _) `|-> t) ts =
   180               nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   180               nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   227         |> (fn univs => cs ~~ univs)
   227         |> (fn univs => cs ~~ univs)
   228       end;
   228       end;
   229 
   229 
   230 (* preparing function equations *)
   230 (* preparing function equations *)
   231 
   231 
   232 fun eqns_of_stmt (_, CodeThingol.Fun (_, [])) =
   232 fun eqns_of_stmt (_, Code_Thingol.Fun (_, [])) =
   233       []
   233       []
   234   | eqns_of_stmt (const, CodeThingol.Fun ((vs, _), eqns)) =
   234   | eqns_of_stmt (const, Code_Thingol.Fun ((vs, _), eqns)) =
   235       [(const, (vs, map fst eqns))]
   235       [(const, (vs, map fst eqns))]
   236   | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
   236   | eqns_of_stmt (_, Code_Thingol.Datatypecons _) =
   237       []
   237       []
   238   | eqns_of_stmt (_, CodeThingol.Datatype _) =
   238   | eqns_of_stmt (_, Code_Thingol.Datatype _) =
   239       []
   239       []
   240   | eqns_of_stmt (class, CodeThingol.Class (v, (superclasses, classops))) =
   240   | eqns_of_stmt (class, Code_Thingol.Class (v, (superclasses, classops))) =
   241       let
   241       let
   242         val names = map snd superclasses @ map fst classops;
   242         val names = map snd superclasses @ map fst classops;
   243         val params = Name.invent_list [] "d" (length names);
   243         val params = Name.invent_list [] "d" (length names);
   244         fun mk (k, name) =
   244         fun mk (k, name) =
   245           (name, ([(v, [])],
   245           (name, ([(v, [])],
   246             [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
   246             [([IConst (class, ([], [])) `$$ map IVar params], IVar (nth params k))]));
   247       in map_index mk names end
   247       in map_index mk names end
   248   | eqns_of_stmt (_, CodeThingol.Classrel _) =
   248   | eqns_of_stmt (_, Code_Thingol.Classrel _) =
   249       []
   249       []
   250   | eqns_of_stmt (_, CodeThingol.Classparam _) =
   250   | eqns_of_stmt (_, Code_Thingol.Classparam _) =
   251       []
   251       []
   252   | eqns_of_stmt (inst, CodeThingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
   252   | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arities)), (superinsts, instops))) =
   253       [(inst, (arities, [([], IConst (class, ([], [])) `$$
   253       [(inst, (arities, [([], IConst (class, ([], [])) `$$
   254         map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
   254         map (fn (_, (_, (inst, dicts))) => IConst (inst, (dicts, []))) superinsts
   255         @ map (IConst o snd o fst) instops)]))];
   255         @ map (IConst o snd o fst) instops)]))];
   256 
   256 
   257 fun compile_stmts stmts_deps =
   257 fun compile_stmts stmts_deps =
   291 
   291 
   292 (* term evaluation *)
   292 (* term evaluation *)
   293 
   293 
   294 fun eval_term gr deps ((vs, ty), t) =
   294 fun eval_term gr deps ((vs, ty), t) =
   295   let 
   295   let 
   296     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []
   296     val frees = Code_Thingol.fold_unbound_varnames (insert (op =)) t []
   297     val frees' = map (fn v => Free (v, [])) frees;
   297     val frees' = map (fn v => Free (v, [])) frees;
   298     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   298     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   299   in
   299   in
   300     ("", (vs, [(map IVar frees, t)]))
   300     ("", (vs, [(map IVar frees, t)]))
   301     |> singleton (compile_eqnss gr deps)
   301     |> singleton (compile_eqnss gr deps)
   311       | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
   311       | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
   312     fun is_dict (Const (idx, _)) =
   312     fun is_dict (Const (idx, _)) =
   313           let
   313           let
   314             val c = the (Inttab.lookup idx_tab idx);
   314             val c = the (Inttab.lookup idx_tab idx);
   315           in
   315           in
   316             (is_some o CodeName.class_rev thy) c
   316             (is_some o Code_Name.class_rev thy) c
   317             orelse (is_some o CodeName.classrel_rev thy) c
   317             orelse (is_some o Code_Name.classrel_rev thy) c
   318             orelse (is_some o CodeName.instance_rev thy) c
   318             orelse (is_some o Code_Name.instance_rev thy) c
   319           end
   319           end
   320       | is_dict (DFree _) = true
   320       | is_dict (DFree _) = true
   321       | is_dict _ = false;
   321       | is_dict _ = false;
   322     fun of_apps bounds (t, ts) =
   322     fun of_apps bounds (t, ts) =
   323       fold_map (of_univ bounds) ts
   323       fold_map (of_univ bounds) ts
   324       #>> (fn ts' => list_comb (t, rev ts'))
   324       #>> (fn ts' => list_comb (t, rev ts'))
   325     and of_univ bounds (Const (idx, ts)) typidx =
   325     and of_univ bounds (Const (idx, ts)) typidx =
   326           let
   326           let
   327             val ts' = take_until is_dict ts;
   327             val ts' = take_until is_dict ts;
   328             val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
   328             val c = (the o Code_Name.const_rev thy o the o Inttab.lookup idx_tab) idx;
   329             val (_, T) = Code.default_typ thy c;
   329             val (_, T) = Code.default_typ thy c;
   330             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
   330             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
   331             val typidx' = typidx + maxidx_of_typ T' + 1;
   331             val typidx' = typidx + maxidx_of_typ T' + 1;
   332           in of_apps bounds (Term.Const (c, T'), ts') typidx' end
   332           in of_apps bounds (Term.Const (c, T'), ts') typidx' end
   333       | of_univ bounds (Free (name, ts)) typidx =
   333       | of_univ bounds (Free (name, ts)) typidx =
   347   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
   347   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
   348   val empty = (Graph.empty, (0, Inttab.empty));
   348   val empty = (Graph.empty, (0, Inttab.empty));
   349   fun purge thy cs (gr, (maxidx, idx_tab)) =
   349   fun purge thy cs (gr, (maxidx, idx_tab)) =
   350     let
   350     let
   351       val cs_exisiting =
   351       val cs_exisiting =
   352         map_filter (CodeName.const_rev thy) (Graph.keys gr);
   352         map_filter (Code_Name.const_rev thy) (Graph.keys gr);
   353       val dels = (Graph.all_preds gr
   353       val dels = (Graph.all_preds gr
   354           o map (CodeName.const thy)
   354           o map (Code_Name.const thy)
   355           o filter (member (op =) cs_exisiting)
   355           o filter (member (op =) cs_exisiting)
   356         ) cs;
   356         ) cs;
   357     in (Graph.del_nodes dels gr, (maxidx, idx_tab)) end;
   357     in (Graph.del_nodes dels gr, (maxidx, idx_tab)) end;
   358 );
   358 );
   359 
   359 
   372 
   372 
   373 fun eval thy t program vs_ty_t deps =
   373 fun eval thy t program vs_ty_t deps =
   374   let
   374   let
   375     fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => Term.Const (f c, ty)
   375     fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => Term.Const (f c, ty)
   376       | t => t);
   376       | t => t);
   377     val subst_triv_consts = subst_const (CodeUnit.resubst_alias thy);
   377     val subst_triv_consts = subst_const (Code_Unit.resubst_alias thy);
   378     val ty = type_of t;
   378     val ty = type_of t;
   379     val type_free = AList.lookup (op =)
   379     val type_free = AList.lookup (op =)
   380       (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
   380       (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
   381     val type_frees = Term.map_aterms
   381     val type_frees = Term.map_aterms
   382       (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
   382       (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
   400     |> tracing (fn t => "---\n")
   400     |> tracing (fn t => "---\n")
   401   end;
   401   end;
   402 
   402 
   403 (* evaluation oracle *)
   403 (* evaluation oracle *)
   404 
   404 
   405 exception Norm of term * CodeThingol.program
   405 exception Norm of term * Code_Thingol.program
   406   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   406   * (Code_Thingol.typscheme * Code_Thingol.iterm) * string list;
   407 
   407 
   408 fun norm_oracle (thy, Norm (t, program, vs_ty_t, deps)) =
   408 fun norm_oracle (thy, Norm (t, program, vs_ty_t, deps)) =
   409   Logic.mk_equals (t, eval thy t program vs_ty_t deps);
   409   Logic.mk_equals (t, eval thy t program vs_ty_t deps);
   410 
   410 
   411 fun norm_invoke thy t program vs_ty_t deps =
   411 fun norm_invoke thy t program vs_ty_t deps =
   413   (*FIXME get rid of hardwired theory name*)
   413   (*FIXME get rid of hardwired theory name*)
   414 
   414 
   415 fun add_triv_classes thy =
   415 fun add_triv_classes thy =
   416   let
   416   let
   417     val inters = curry (Sorts.inter_sort (Sign.classes_of thy))
   417     val inters = curry (Sorts.inter_sort (Sign.classes_of thy))
   418       (CodeUnit.triv_classes thy);
   418       (Code_Unit.triv_classes thy);
   419     fun map_sorts f = (map_types o map_atyps)
   419     fun map_sorts f = (map_types o map_atyps)
   420       (fn TVar (v, sort) => TVar (v, f sort)
   420       (fn TVar (v, sort) => TVar (v, f sort)
   421         | TFree (v, sort) => TFree (v, f sort));
   421         | TFree (v, sort) => TFree (v, f sort));
   422   in map_sorts inters end;
   422   in map_sorts inters end;
   423 
   423 
   424 fun norm_conv ct =
   424 fun norm_conv ct =
   425   let
   425   let
   426     val thy = Thm.theory_of_cterm ct;
   426     val thy = Thm.theory_of_cterm ct;
   427     fun evaluator' t program vs_ty_t deps = norm_invoke thy t program vs_ty_t deps;
   427     fun evaluator' t program vs_ty_t deps = norm_invoke thy t program vs_ty_t deps;
   428     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   428     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   429   in CodeThingol.eval_conv thy evaluator ct end;
   429   in Code_Thingol.eval_conv thy evaluator ct end;
   430 
   430 
   431 fun norm_term thy t =
   431 fun norm_term thy t =
   432   let
   432   let
   433     fun evaluator' t program vs_ty_t deps = eval thy t program vs_ty_t deps;
   433     fun evaluator' t program vs_ty_t deps = eval thy t program vs_ty_t deps;
   434     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   434     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   435   in (Code.postprocess_term thy o CodeThingol.eval_term thy evaluator) t end;
   435   in (Code.postprocess_term thy o Code_Thingol.eval_term thy evaluator) t end;
   436 
   436 
   437 (* evaluation command *)
   437 (* evaluation command *)
   438 
   438 
   439 fun norm_print_term ctxt modes t =
   439 fun norm_print_term ctxt modes t =
   440   let
   440   let