avoid frees and vars in terms to be evaluated by abstracting and applying
authorhaftmann
Tue, 21 Sep 2010 15:46:05 +0200
changeset 39604 f17fb9ccb836
parent 39603 eb0a51312752
child 39605 6dc866b9c548
avoid frees and vars in terms to be evaluated by abstracting and applying
src/Tools/Code/code_preproc.ML
--- a/src/Tools/Code/code_preproc.ML	Tue Sep 21 15:46:05 2010 +0200
+++ b/src/Tools/Code/code_preproc.ML	Tue Sep 21 15:46:05 2010 +0200
@@ -106,6 +106,22 @@
 
 (* post- and preprocessing *)
 
+fun no_variables_conv conv ct =
+  let
+    val cert = Thm.cterm_of (Thm.theory_of_cterm ct);
+    val all_vars = fold_aterms (fn t as Free _ => insert (op aconvc) (cert t)
+      | t as Var _ => insert (op aconvc) (cert t)
+      | _ => I) (Thm.term_of ct) [];
+    fun apply_beta var thm = Thm.combination thm (Thm.reflexive var)
+      |> Conv.fconv_rule (Conv.arg_conv (Conv.try_conv (Thm.beta_conversion false)))
+      |> Conv.fconv_rule (Conv.arg1_conv (Thm.beta_conversion false));
+  in
+    ct
+    |> fold_rev Thm.cabs all_vars
+    |> conv
+    |> fold apply_beta all_vars
+  end;
+
 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
 fun eqn_conv conv ct =
@@ -141,7 +157,6 @@
 fun preprocess_conv thy ct =
   let
     val ctxt = ProofContext.init_global thy;
-    val _ = (Sign.no_frees ctxt o map_types (K dummyT) o Sign.no_vars ctxt) (Thm.term_of ct);
     val pre = (Simplifier.global_context thy o #pre o the_thmproc) thy;
   in
     ct
@@ -149,7 +164,13 @@
     |> trans_conv_rule (AxClass.unoverload_conv thy)
   end;
 
-fun preprocess_term thy = term_of_conv thy (preprocess_conv thy);
+fun preprocess_term thy t =
+  let
+    val all_vars = fold_aterms (fn t as Free _ => insert (op aconv) t
+      | t as Var _ => insert (op aconv) t
+      | _ => I) t [];
+    val resubst = curry (Term.betapplys o swap) all_vars;
+  in (resubst, term_of_conv thy (preprocess_conv thy) (fold_rev lambda all_vars t)) end;
 
 fun postprocess_conv thy ct =
   let
@@ -198,8 +219,11 @@
 type code_algebra = (sort -> sort) * Sorts.algebra;
 type code_graph = ((string * sort) list * Code.cert) Graph.T;
 
-fun cert eqngr = snd o Graph.get_node eqngr;
-fun sortargs eqngr = map snd o fst o Graph.get_node eqngr;
+fun get_node eqngr const = Graph.get_node eqngr const
+  handle Graph.UNDEF _ => error ("No such constant in code equation graph: " ^ quote const);
+
+fun cert eqngr = snd o get_node eqngr;
+fun sortargs eqngr = map snd o fst o get_node eqngr;
 fun all eqngr = Graph.keys eqngr;
 
 fun pretty thy eqngr =
@@ -433,7 +457,7 @@
 
 fun dest_cterm ct = let val t = Thm.term_of ct in (Term.add_tfrees t [], t) end;
 
-fun dynamic_eval_conv thy conv ct =
+fun dynamic_eval_conv thy conv = no_variables_conv (fn ct =>
   let
     val thm1 = preprocess_conv thy ct;
     val ct' = Thm.rhs_of thm1;
@@ -447,11 +471,11 @@
     Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ =>
       error ("could not construct evaluation proof:\n"
       ^ (cat_lines o map (Display.string_of_thm_global thy)) [thm1, thm2, thm3])
-  end;
+  end);
 
 fun dynamic_eval_value thy postproc evaluator t =
   let
-    val t' = preprocess_term thy t;
+    val (resubst, t') = preprocess_term thy t;
     val vs' = Term.add_tfrees t' [];
     val consts = fold_aterms
       (fn Const (c, _) => insert (op =) c | _ => I) t' [];
@@ -459,7 +483,7 @@
     val result = evaluator algebra' eqngr' vs' t';
   in
     evaluator algebra' eqngr' vs' t'
-    |> postproc (postprocess_term thy)
+    |> postproc (postprocess_term thy o resubst)
   end;
 
 fun static_eval_conv thy consts conv =
@@ -467,9 +491,9 @@
     val (algebra, eqngr) = obtain true thy consts [];
     val conv' = conv algebra eqngr;
   in
-    Conv.tap_thy (fn thy => (preprocess_conv thy)
+    no_variables_conv (Conv.tap_thy (fn thy => (preprocess_conv thy)
       then_conv (fn ct => uncurry (conv' thy) (dest_cterm ct) ct)
-      then_conv (postprocess_conv thy))
+      then_conv (postprocess_conv thy)))
   end;
 
 fun static_eval_value thy postproc consts evaluator =
@@ -479,8 +503,9 @@
   in fn t =>
     t
     |> preprocess_term thy
-    |> (fn t => evaluator' thy (Term.add_tfrees t [])  t)
-    |> postproc (postprocess_term thy)
+    |-> (fn resubst => fn t => t
+    |> evaluator' thy (Term.add_tfrees t [])
+    |> postproc (postprocess_term thy o resubst))
   end;