src/Tools/nbe.ML
changeset 26739 947b6013e863
parent 26064 65585de05a66
child 26747 f32fa5f5bdd1
equal deleted inserted replaced
26738:615e1a86787b 26739:947b6013e863
    21                                             (*abstractions as closures*)
    21                                             (*abstractions as closures*)
    22 
    22 
    23   val univs_ref: (unit -> Univ list -> Univ list) option ref
    23   val univs_ref: (unit -> Univ list -> Univ list) option ref
    24   val trace: bool ref
    24   val trace: bool ref
    25 
    25 
    26   val setup: theory -> theory
    26   val setup: class list -> (string * string) list -> theory -> theory
    27 end;
    27 end;
    28 
    28 
    29 structure Nbe: NBE =
    29 structure Nbe: NBE =
    30 struct
    30 struct
    31 
    31 
   325     and of_univ bounds (Const (idx, ts)) typidx =
   325     and of_univ bounds (Const (idx, ts)) typidx =
   326           let
   326           let
   327             val ts' = take_until is_dict ts;
   327             val ts' = take_until is_dict ts;
   328             val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
   328             val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
   329             val T = Code.default_typ thy c;
   329             val T = Code.default_typ thy c;
   330             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
   330             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
   331             val typidx' = typidx + maxidx_of_typ T' + 1;
   331             val typidx' = typidx + maxidx_of_typ T' + 1;
   332           in of_apps bounds (Term.Const (c, T'), ts') typidx' end
   332           in of_apps bounds (Term.Const (c, T'), ts') typidx' end
   333       | of_univ bounds (Free (name, ts)) typidx =
   333       | of_univ bounds (Free (name, ts)) typidx =
   334           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   334           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   335       | of_univ bounds (BVar (name, ts)) typidx =
   335       | of_univ bounds (BVar (name, ts)) typidx =
   371     vs_ty_t
   371     vs_ty_t
   372     |> eval_term gr deps
   372     |> eval_term gr deps
   373     |> term_of_univ thy idx_tab
   373     |> term_of_univ thy idx_tab
   374   end;
   374   end;
   375 
   375 
       
   376 (* trivial type classes *)
       
   377 
       
   378 structure Nbe_Triv_Classes = TheoryDataFun
       
   379 (
       
   380   type T = class list * (string * string) list;
       
   381   val empty = ([], []);
       
   382   val copy = I;
       
   383   val extend = I;
       
   384   fun merge _ ((classes1, consts1), (classes2, consts2)) =
       
   385     (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2));
       
   386 )
       
   387 
       
   388 fun add_triv_classes thy =
       
   389   let
       
   390     val (trivs, _) = Nbe_Triv_Classes.get thy;
       
   391     val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs;
       
   392     fun map_sorts f = (map_types o map_atyps)
       
   393       (fn TVar (v, sort) => TVar (v, f sort)
       
   394         | TFree (v, sort) => TFree (v, f sort));
       
   395   in map_sorts inters end;
       
   396 
       
   397 fun subst_triv_consts thy =
       
   398   let
       
   399     fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c
       
   400          of SOME c' => Term.Const (c', ty)
       
   401           | NONE => t)
       
   402       | t => t);
       
   403     val (_, consts) = Nbe_Triv_Classes.get thy;
       
   404     val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy);
       
   405   in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end;
       
   406 
   376 (* evaluation with type reconstruction *)
   407 (* evaluation with type reconstruction *)
   377 
   408 
   378 fun eval thy code t vs_ty_t deps =
   409 fun eval thy t code vs_ty_t deps =
   379   let
   410   let
   380     val ty = type_of t;
   411     val ty = type_of t;
   381     fun subst_Frees [] = I
   412     val type_free = AList.lookup (op =)
   382       | subst_Frees inst =
   413       (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
   383           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
   414     val type_frees = Term.map_aterms
   384                             | t => t);
   415       (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
   385     val anno_vars =
   416     fun type_infer t = [(t, ty)]
   386       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
   417       |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I
   387       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
   418            (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)
   388     fun constrain t =
   419            Name.context 0 NONE
   389       singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
   420       |> fst |> the_single |> fst;
   390     fun check_tvars t = if null (Term.term_tvars t) then t else
   421     fun check_tvars t = if null (Term.term_tvars t) then t else
   391       error ("Illegal schematic type variables in normalized term: "
   422       error ("Illegal schematic type variables in normalized term: "
   392         ^ setmp show_types true (Sign.string_of_term thy) t);
   423         ^ setmp show_types true (Sign.string_of_term thy) t);
   393     val string_of_term = setmp show_types true (Sign.string_of_term thy);
   424     val string_of_term = setmp show_types true (Sign.string_of_term thy);
   394   in
   425   in
   395     compile_eval thy code vs_ty_t deps
   426     compile_eval thy code vs_ty_t deps
   396     |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
   427     |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
   397     |> anno_vars
   428     |> subst_triv_consts thy
       
   429     |> type_frees
   398     |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
   430     |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
   399     |> constrain
   431     |> type_infer
   400     |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
   432     |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
       
   433     |> check_tvars
   401     |> tracing (fn t => "---\n")
   434     |> tracing (fn t => "---\n")
   402     |> check_tvars
       
   403   end;
   435   end;
   404 
   436 
   405 (* evaluation oracle *)
   437 (* evaluation oracle *)
   406 
   438 
   407 exception Norm of CodeThingol.code * term
   439 exception Norm of term * CodeThingol.code
   408   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   440   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   409 
   441 
   410 fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
   442 fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) =
   411   Logic.mk_equals (t, eval thy code t vs_ty_t deps);
   443   Logic.mk_equals (t, eval thy t code vs_ty_t deps);
   412 
   444 
   413 fun norm_invoke thy code t vs_ty_t deps =
   445 fun norm_invoke thy t code vs_ty_t deps =
   414   Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
   446   Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps));
   415   (*FIXME get rid of hardwired theory name*)
   447   (*FIXME get rid of hardwired theory name*)
   416 
   448 
   417 fun norm_conv ct =
   449 fun norm_conv ct =
   418   let
   450   let
   419     val thy = Thm.theory_of_cterm ct;
   451     val thy = Thm.theory_of_cterm ct;
   420     fun conv code vs_ty_t deps ct =
   452     fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps;
   421       let
   453     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   422         val t = Thm.term_of ct;
   454   in CodePackage.evaluate_conv thy evaluator ct end;
   423       in norm_invoke thy code t vs_ty_t deps end;
   455 
   424   in CodePackage.evaluate_conv thy conv ct end;
   456 fun norm_term thy t =
   425 
   457   let
   426 fun norm_term thy =
   458     fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps;
   427   let
   459     fun evaluator t = (add_triv_classes thy t, evaluator' t);
   428     fun invoke code vs_ty_t deps t =
   460   in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end;
   429       eval thy code t vs_ty_t deps;
       
   430   in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end;
       
   431 
   461 
   432 (* evaluation command *)
   462 (* evaluation command *)
   433 
   463 
   434 fun norm_print_term ctxt modes t =
   464 fun norm_print_term ctxt modes t =
   435   let
   465   let
   446 
   476 
   447 fun norm_print_term_cmd (modes, s) state =
   477 fun norm_print_term_cmd (modes, s) state =
   448   let val ctxt = Toplevel.context_of state
   478   let val ctxt = Toplevel.context_of state
   449   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
   479   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
   450 
   480 
   451 val setup = Theory.add_oracle ("norm", norm_oracle)
   481 fun setup nbe_classes nbe_consts =
       
   482   Theory.add_oracle ("norm", norm_oracle)
       
   483   #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts));
   452 
   484 
   453 local structure P = OuterParse and K = OuterKeyword in
   485 local structure P = OuterParse and K = OuterKeyword in
   454 
   486 
   455 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   487 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   456 
   488