src/Tools/nbe.ML
changeset 39392 7a0fcee7a2a3
parent 39388 fdbb2c55ffc2
child 39396 e9cad160aa0f
     1.1 --- a/src/Tools/nbe.ML	Wed Sep 15 11:30:32 2010 +0200
     1.2 +++ b/src/Tools/nbe.ML	Wed Sep 15 12:11:11 2010 +0200
     1.3 @@ -377,15 +377,16 @@
     1.4    in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
     1.5  
     1.6  
     1.7 -(* code compilation *)
     1.8 +(* compile equations *)
     1.9  
    1.10 -fun compile_eqnss ctxt gr raw_deps [] = []
    1.11 -  | compile_eqnss ctxt gr raw_deps eqnss =
    1.12 +fun compile_eqnss thy nbe_program raw_deps [] = []
    1.13 +  | compile_eqnss thy nbe_program raw_deps eqnss =
    1.14        let
    1.15 +        val ctxt = ProofContext.init_global thy;
    1.16          val (deps, deps_vals) = split_list (map_filter
    1.17 -          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
    1.18 +          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
    1.19          val idx_of = raw_deps
    1.20 -          |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
    1.21 +          |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
    1.22            |> AList.lookup (op =)
    1.23            |> (fn f => the o f);
    1.24          val s = assemble_eqnss idx_of deps eqnss;
    1.25 @@ -400,7 +401,7 @@
    1.26        end;
    1.27  
    1.28  
    1.29 -(* preparing function equations *)
    1.30 +(* extract equations from statements *)
    1.31  
    1.32  fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
    1.33        []
    1.34 @@ -428,7 +429,10 @@
    1.35          map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
    1.36          @ map (IConst o snd o fst) classparam_instances)]))];
    1.37  
    1.38 -fun compile_stmts ctxt stmts_deps =
    1.39 +
    1.40 +(* compile whole programs *)
    1.41 +
    1.42 +fun compile_stmts thy stmts_deps =
    1.43    let
    1.44      val names = map (fst o fst) stmts_deps;
    1.45      val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
    1.46 @@ -437,13 +441,13 @@
    1.47        |> maps snd
    1.48        |> distinct (op =)
    1.49        |> fold (insert (op =)) names;
    1.50 -    fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
    1.51 -      then (gr, (maxidx, idx_tab))
    1.52 -      else (Graph.new_node (name, (NONE, maxidx)) gr,
    1.53 +    fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name
    1.54 +      then (nbe_program, (maxidx, idx_tab))
    1.55 +      else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
    1.56          (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
    1.57 -    fun compile gr = eqnss
    1.58 -      |> compile_eqnss ctxt gr refl_deps
    1.59 -      |> rpair gr;
    1.60 +    fun compile nbe_program = eqnss
    1.61 +      |> compile_eqnss thy nbe_program refl_deps
    1.62 +      |> rpair nbe_program;
    1.63    in
    1.64      fold new_node refl_deps
    1.65      #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
    1.66 @@ -451,12 +455,12 @@
    1.67        #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
    1.68    end;
    1.69  
    1.70 -fun ensure_stmts ctxt program =
    1.71 +fun compile_program thy program =
    1.72    let
    1.73 -    fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
    1.74 -      then (gr, (maxidx, idx_tab))
    1.75 -      else (gr, (maxidx, idx_tab))
    1.76 -        |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name),
    1.77 +    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
    1.78 +      then (nbe_program, (maxidx, idx_tab))
    1.79 +      else (nbe_program, (maxidx, idx_tab))
    1.80 +        |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
    1.81            Graph.imm_succs program name)) names);
    1.82    in
    1.83      fold_rev add_stmts (Graph.strong_conn program)
    1.84 @@ -465,20 +469,20 @@
    1.85  
    1.86  (** evaluation **)
    1.87  
    1.88 -(* term evaluation *)
    1.89 +(* term evaluation by compilation *)
    1.90  
    1.91 -fun eval_term ctxt gr deps (vs : (string * sort) list, t) =
    1.92 +fun compile_term thy nbe_program deps (vs : (string * sort) list, t) =
    1.93    let 
    1.94      val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
    1.95    in
    1.96      ("", (vs, [([], t)]))
    1.97 -    |> singleton (compile_eqnss ctxt gr deps)
    1.98 +    |> singleton (compile_eqnss thy nbe_program deps)
    1.99      |> snd
   1.100      |> (fn t => apps t (rev dict_frees))
   1.101    end;
   1.102  
   1.103  
   1.104 -(* reification *)
   1.105 +(* reconstruction *)
   1.106  
   1.107  fun typ_of_itype program vs (ityco `%% itys) =
   1.108        let
   1.109 @@ -525,6 +529,29 @@
   1.110    in of_univ 0 t 0 |> fst end;
   1.111  
   1.112  
   1.113 +(* evaluation with type reconstruction *)
   1.114 +
   1.115 +fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   1.116 +  let
   1.117 +    val ctxt = Syntax.init_pretty_global thy;
   1.118 +    val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
   1.119 +    val ty' = typ_of_itype program vs0 ty;
   1.120 +    fun type_infer t = singleton
   1.121 +      (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
   1.122 +      (Type.constraint ty' t);
   1.123 +    fun check_tvars t =
   1.124 +      if null (Term.add_tvars t []) then t
   1.125 +      else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   1.126 +  in
   1.127 +    compile_term thy nbe_program deps (vs, t)
   1.128 +    |> term_of_univ thy program idx_tab
   1.129 +    |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   1.130 +    |> type_infer
   1.131 +    |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   1.132 +    |> check_tvars
   1.133 +    |> traced (fn _ => "---\n")
   1.134 +  end;
   1.135 +
   1.136  (* function store *)
   1.137  
   1.138  structure Nbe_Functions = Code_Data
   1.139 @@ -533,43 +560,11 @@
   1.140    val empty = (Graph.empty, (0, Inttab.empty));
   1.141  );
   1.142  
   1.143 -
   1.144 -(* compilation, evaluation and reification *)
   1.145 -
   1.146 -fun compile_eval thy program =
   1.147 -  let
   1.148 -    val ctxt = ProofContext.init_global thy;
   1.149 -    val (gr, (_, idx_tab)) =
   1.150 -      Nbe_Functions.change thy (ensure_stmts ctxt program);
   1.151 -  in fn vs_t => fn deps =>
   1.152 -    vs_t
   1.153 -    |> eval_term ctxt gr deps
   1.154 -    |> term_of_univ thy program idx_tab
   1.155 -  end;
   1.156 -
   1.157 -
   1.158 -(* evaluation with type reconstruction *)
   1.159 -
   1.160 -fun normalize thy program ((vs0, (vs, ty)), t) deps =
   1.161 +fun compile thy program =
   1.162    let
   1.163 -    val ctxt = Syntax.init_pretty_global thy;
   1.164 -    val ty' = typ_of_itype program vs0 ty;
   1.165 -    fun type_infer t =
   1.166 -      singleton
   1.167 -        (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
   1.168 -        (Type.constraint ty' t);
   1.169 -    val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
   1.170 -    fun check_tvars t =
   1.171 -      if null (Term.add_tvars t []) then t
   1.172 -      else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   1.173 -  in
   1.174 -    compile_eval thy program (vs, t) deps
   1.175 -    |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   1.176 -    |> type_infer
   1.177 -    |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   1.178 -    |> check_tvars
   1.179 -    |> traced (fn _ => "---\n")
   1.180 -  end;
   1.181 +    val (nbe_program, (_, idx_tab)) =
   1.182 +      Nbe_Functions.change thy (compile_program thy program);
   1.183 +  in (nbe_program, idx_tab) end;
   1.184  
   1.185  
   1.186  (* evaluation oracle *)
   1.187 @@ -583,7 +578,7 @@
   1.188  
   1.189  val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   1.190    (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
   1.191 -    mk_equals thy ct (normalize thy program vsp_ty_t deps))));
   1.192 +    mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps))));
   1.193  
   1.194  fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
   1.195  
   1.196 @@ -601,7 +596,8 @@
   1.197    (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
   1.198  
   1.199  fun dynamic_eval_value thy = lift_triv_classes_rew thy
   1.200 -  (no_frees_rew (Code_Thingol.dynamic_eval_value thy I (K (normalize thy))));
   1.201 +  (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
   1.202 +    (K (fn program => eval_term thy program (compile thy program)))));
   1.203  
   1.204  
   1.205  (* evaluation command *)