src/Tools/nbe.ML
changeset 55757 9fc71814b8c1
parent 55167 f3ac344284ff
child 56245 84fc7dfa3cd4
     1.1 --- a/src/Tools/nbe.ML	Wed Feb 26 10:10:38 2014 +0100
     1.2 +++ b/src/Tools/nbe.ML	Wed Feb 26 11:57:52 2014 +0100
     1.3 @@ -6,10 +6,10 @@
     1.4  
     1.5  signature NBE =
     1.6  sig
     1.7 -  val dynamic_conv: theory -> conv
     1.8 -  val dynamic_value: theory -> term -> term
     1.9 -  val static_conv: theory -> string list -> conv
    1.10 -  val static_value: theory -> string list -> term -> term
    1.11 +  val dynamic_conv: Proof.context -> conv
    1.12 +  val dynamic_value: Proof.context -> term -> term
    1.13 +  val static_conv: Proof.context -> string list -> Proof.context -> conv
    1.14 +  val static_value: Proof.context -> string list -> Proof.context -> term -> term
    1.15  
    1.16    datatype Univ =
    1.17        Const of int * Univ list               (*named (uninterpreted) constants*)
    1.18 @@ -83,8 +83,9 @@
    1.19  
    1.20  in
    1.21  
    1.22 -fun lift_triv_classes_conv thy conv ct =
    1.23 +fun lift_triv_classes_conv ctxt conv ct =
    1.24    let
    1.25 +    val thy = Proof_Context.theory_of ctxt;
    1.26      val algebra = Sign.classes_of thy;
    1.27      val certT = Thm.ctyp_of thy;
    1.28      val triv_classes = get_triv_classes thy;
    1.29 @@ -128,8 +129,9 @@
    1.30      |> strip_of_class
    1.31    end;
    1.32  
    1.33 -fun lift_triv_classes_rew thy rew t =
    1.34 +fun lift_triv_classes_rew ctxt rew t =
    1.35    let
    1.36 +    val thy = Proof_Context.theory_of ctxt;
    1.37      val algebra = Sign.classes_of thy;
    1.38      val triv_classes = get_triv_classes thy;
    1.39      val vs = Term.add_tfrees t [];
    1.40 @@ -388,10 +390,9 @@
    1.41  
    1.42  (* compile equations *)
    1.43  
    1.44 -fun compile_eqnss thy nbe_program raw_deps [] = []
    1.45 -  | compile_eqnss thy nbe_program raw_deps eqnss =
    1.46 +fun compile_eqnss ctxt nbe_program raw_deps [] = []
    1.47 +  | compile_eqnss ctxt nbe_program raw_deps eqnss =
    1.48        let
    1.49 -        val ctxt = Proof_Context.init_global thy;
    1.50          val (deps, deps_vals) = split_list (map_filter
    1.51            (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Code_Symbol.Graph.get_node nbe_program dep)))) raw_deps);
    1.52          val idx_of = raw_deps
    1.53 @@ -453,7 +454,7 @@
    1.54    else (Code_Symbol.Graph.new_node (name, (NONE, maxidx)) nbe_program,
    1.55      (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
    1.56  
    1.57 -fun compile_stmts thy stmts_deps =
    1.58 +fun compile_stmts ctxt stmts_deps =
    1.59    let
    1.60      val names = map (fst o fst) stmts_deps;
    1.61      val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
    1.62 @@ -463,7 +464,7 @@
    1.63        |> distinct (op =)
    1.64        |> fold (insert (op =)) names;
    1.65      fun compile nbe_program = eqnss
    1.66 -      |> compile_eqnss thy nbe_program refl_deps
    1.67 +      |> compile_eqnss ctxt nbe_program refl_deps
    1.68        |> rpair nbe_program;
    1.69    in
    1.70      fold ensure_const_idx refl_deps
    1.71 @@ -472,12 +473,12 @@
    1.72        #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
    1.73    end;
    1.74  
    1.75 -fun compile_program thy program =
    1.76 +fun compile_program ctxt program =
    1.77    let
    1.78      fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
    1.79        then (nbe_program, (maxidx, idx_tab))
    1.80        else (nbe_program, (maxidx, idx_tab))
    1.81 -        |> compile_stmts thy (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
    1.82 +        |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
    1.83            Code_Symbol.Graph.immediate_succs program name)) names);
    1.84    in
    1.85      fold_rev add_stmts (Code_Symbol.Graph.strong_conn program)
    1.86 @@ -488,12 +489,12 @@
    1.87  
    1.88  (* term evaluation by compilation *)
    1.89  
    1.90 -fun compile_term thy nbe_program deps (vs : (string * sort) list, t) =
    1.91 +fun compile_term ctxt nbe_program deps (vs : (string * sort) list, t) =
    1.92    let 
    1.93      val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
    1.94    in
    1.95      (Code_Symbol.value, (vs, [([], t)]))
    1.96 -    |> singleton (compile_eqnss thy nbe_program deps)
    1.97 +    |> singleton (compile_eqnss ctxt nbe_program deps)
    1.98      |> snd
    1.99      |> (fn t => apps t (rev dict_frees))
   1.100    end;
   1.101 @@ -506,7 +507,7 @@
   1.102    | typ_of_itype vs (ITyVar v) =
   1.103        TFree ("'" ^ v, (the o AList.lookup (op =) vs) v);
   1.104  
   1.105 -fun term_of_univ thy (idx_tab : Code_Symbol.T Inttab.table) t =
   1.106 +fun term_of_univ ctxt (idx_tab : Code_Symbol.T Inttab.table) t =
   1.107    let
   1.108      fun take_until f [] = []
   1.109        | take_until f (x :: xs) = if f x then [] else x :: take_until f xs;
   1.110 @@ -527,7 +528,7 @@
   1.111              val const = const_of_idx idx;
   1.112              val T = map_type_tvar (fn ((v, i), _) =>
   1.113                Type_Infer.param typidx (v ^ string_of_int i, []))
   1.114 -                (Sign.the_const_type thy const);
   1.115 +                (Sign.the_const_type (Proof_Context.theory_of ctxt) const);
   1.116              val typidx' = typidx + 1;
   1.117            in of_apps bounds (Term.Const (const, T), ts') typidx' end
   1.118        | of_univ bounds (BVar (n, ts)) typidx =
   1.119 @@ -541,9 +542,9 @@
   1.120  
   1.121  (* evaluation with type reconstruction *)
   1.122  
   1.123 -fun eval_term thy (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   1.124 +fun eval_term raw_ctxt (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   1.125    let
   1.126 -    val ctxt = Syntax.init_pretty_global thy;
   1.127 +    val ctxt = Syntax.init_pretty_global (Proof_Context.theory_of raw_ctxt);
   1.128      val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
   1.129      val ty' = typ_of_itype vs0 ty;
   1.130      fun type_infer t =
   1.131 @@ -553,8 +554,8 @@
   1.132        if null (Term.add_tvars t []) then t
   1.133        else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   1.134    in
   1.135 -    compile_term thy nbe_program deps (vs, t)
   1.136 -    |> term_of_univ thy idx_tab
   1.137 +    compile_term ctxt nbe_program deps (vs, t)
   1.138 +    |> term_of_univ ctxt idx_tab
   1.139      |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   1.140      |> type_infer
   1.141      |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   1.142 @@ -571,18 +572,19 @@
   1.143    val empty = (Code_Symbol.Graph.empty, (0, Inttab.empty));
   1.144  );
   1.145  
   1.146 -fun compile ignore_cache thy program =
   1.147 +fun compile ignore_cache ctxt program =
   1.148    let
   1.149      val (nbe_program, (_, idx_tab)) =
   1.150 -      Nbe_Functions.change (if ignore_cache then NONE else SOME thy)
   1.151 -        (compile_program thy program);
   1.152 +      Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
   1.153 +        (compile_program ctxt program);
   1.154    in (nbe_program, idx_tab) end;
   1.155  
   1.156  
   1.157  (* evaluation oracle *)
   1.158  
   1.159 -fun mk_equals thy lhs raw_rhs =
   1.160 +fun mk_equals ctxt lhs raw_rhs =
   1.161    let
   1.162 +    val thy = Proof_Context.theory_of ctxt;
   1.163      val ty = Thm.typ_of (Thm.ctyp_of_term lhs);
   1.164      val eq = Thm.cterm_of thy (Term.Const ("==", ty --> ty --> propT));
   1.165      val rhs = Thm.cterm_of thy raw_rhs;
   1.166 @@ -590,28 +592,33 @@
   1.167  
   1.168  val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   1.169    (Thm.add_oracle (@{binding normalization_by_evaluation},
   1.170 -    fn (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
   1.171 -      mk_equals thy ct (eval_term thy nbe_program_idx_tab vsp_ty_t deps))));
   1.172 +    fn (ctxt, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
   1.173 +      mk_equals ctxt ct (eval_term ctxt nbe_program_idx_tab vsp_ty_t deps))));
   1.174 +
   1.175 +fun oracle ctxt nbe_program_idx_tab vsp_ty_t deps ct =
   1.176 +  raw_oracle (ctxt, nbe_program_idx_tab, vsp_ty_t, deps, ct);
   1.177  
   1.178 -fun oracle thy nbe_program_idx_tab vsp_ty_t deps ct =
   1.179 -  raw_oracle (thy, nbe_program_idx_tab, vsp_ty_t, deps, ct);
   1.180 +fun dynamic_conv ctxt = lift_triv_classes_conv ctxt
   1.181 +  (Code_Thingol.dynamic_conv ctxt (oracle ctxt o compile false ctxt));
   1.182  
   1.183 -fun dynamic_conv thy = lift_triv_classes_conv thy
   1.184 -  (Code_Thingol.dynamic_conv thy (oracle thy o compile false thy));
   1.185 +fun dynamic_value ctxt = lift_triv_classes_rew ctxt
   1.186 +  (Code_Thingol.dynamic_value ctxt I (eval_term ctxt o compile false ctxt));
   1.187  
   1.188 -fun dynamic_value thy = lift_triv_classes_rew thy
   1.189 -  (Code_Thingol.dynamic_value thy I (eval_term thy o compile false thy));
   1.190 +fun static_conv ctxt consts =
   1.191 +  let
   1.192 +    val evaluator = Code_Thingol.static_conv ctxt consts
   1.193 +      (fn program => fn _ => K (oracle ctxt (compile true ctxt program)));
   1.194 +  in fn ctxt' => lift_triv_classes_conv ctxt' (evaluator ctxt') end;
   1.195  
   1.196 -fun static_conv thy consts = lift_triv_classes_conv thy
   1.197 -  (Code_Thingol.static_conv thy consts (K o oracle thy o compile true thy));
   1.198 -
   1.199 -fun static_value thy consts = lift_triv_classes_rew thy
   1.200 -  (Code_Thingol.static_value thy I consts (K o eval_term thy o compile true thy));
   1.201 +fun static_value ctxt consts =
   1.202 +  let
   1.203 +    val evaluator = Code_Thingol.static_value ctxt I consts
   1.204 +      (fn program => fn _ => K (eval_term ctxt (compile true ctxt program)));
   1.205 +  in fn ctxt' => lift_triv_classes_rew ctxt' (evaluator ctxt') end;
   1.206  
   1.207  
   1.208  (** setup **)
   1.209  
   1.210 -val setup = Value.add_evaluator ("nbe", dynamic_value o Proof_Context.theory_of);
   1.211 +val setup = Value.add_evaluator ("nbe", dynamic_value);
   1.212  
   1.213  end;
   1.214 - 
   1.215 \ No newline at end of file