src/Tools/Code/code_preproc.ML
changeset 56968 d2b1d95eb722
parent 56967 c3746e999805
child 56970 a3f911785efa
--- a/src/Tools/Code/code_preproc.ML	Thu May 15 16:38:17 2014 +0200
+++ b/src/Tools/Code/code_preproc.ML	Thu May 15 16:38:28 2014 +0200
@@ -24,14 +24,14 @@
   val obtain: bool -> theory -> string list -> term list -> code_algebra * code_graph
   val dynamic_conv: Proof.context
     -> (code_algebra -> code_graph -> term -> conv) -> conv
-  val dynamic_value: Proof.context -> ((term -> term) -> 'a -> 'a)
-    -> (code_algebra -> code_graph -> term -> 'a) -> term -> 'a
+  val dynamic_value: Proof.context -> ((term -> term) -> 'a -> 'b)
+    -> (code_algebra -> code_graph -> term -> 'a) -> term -> 'b
   val static_conv: Proof.context -> string list
     -> (code_algebra -> code_graph -> Proof.context -> term -> conv)
     -> Proof.context -> conv
-  val static_value: Proof.context -> ((term -> term) -> 'a -> 'a) -> string list
+  val static_value: Proof.context -> ((term -> term) -> 'a -> 'b) -> string list
     -> (code_algebra -> code_graph -> Proof.context -> term -> 'a)
-    -> Proof.context -> term -> 'a
+    -> Proof.context -> term -> 'b
 
   val setup: theory -> theory
 end
@@ -107,26 +107,47 @@
   (delete_force "function transformer" name);
 
 
+(* algebra of sandwiches: cterm transformations with pending postprocessors *)
+
+fun trans_comb eq1 eq2 =
+  if Thm.is_reflexive eq1 then eq2
+  else if Thm.is_reflexive eq2 then eq1
+  else Thm.transitive eq1 eq2;
+
+fun trans_conv_rule conv eq = trans_comb eq (conv (Thm.rhs_of eq));
+
+type sandwich = Proof.context -> cterm -> (thm -> thm) * cterm;
+type conv_sandwich = Proof.context -> cterm -> conv * thm;
+
+fun chain sandwich2 sandwich1 ctxt =
+  sandwich1 ctxt
+  ##>> sandwich2 ctxt
+  #>> (op o);
+
+fun lift_conv_sandwich conv_sandwhich ctxt ct =
+  let
+    val (postproc_conv, eq) = conv_sandwhich ctxt ct;
+  in (trans_conv_rule postproc_conv o trans_comb eq, Thm.rhs_of eq) end;
+
+fun finalize sandwich conv ctxt ct =
+  let
+    val (postproc, ct') = sandwich ctxt ct;
+  in postproc (conv ctxt (term_of ct') ct') end;
+
+fun evaluation sandwich lift_postproc eval ctxt t =
+  let
+    val cert = Thm.cterm_of (Proof_Context.theory_of ctxt);
+    val (postproc, ct') = sandwich ctxt (cert t);
+  in
+    term_of ct'
+    |> eval ctxt
+    |> lift_postproc (term_of o Thm.rhs_of o postproc o Thm.reflexive o cert)
+  end;
+
+
 (* post- and preprocessing *)
 
-fun no_variables_conv ctxt conv ct =
-  let
-    val thy = Proof_Context.theory_of ctxt;
-    val cert = Thm.cterm_of thy;
-    val all_vars = fold_aterms (fn t as Free _ => insert (op aconvc) (cert t)
-      | t as Var _ => insert (op aconvc) (cert t)
-      | _ => I) (Thm.term_of ct) [];
-    fun apply_beta var thm = Thm.combination thm (Thm.reflexive var)
-      |> Conv.fconv_rule (Conv.arg_conv (Conv.try_conv (Thm.beta_conversion false)))
-      |> Conv.fconv_rule (Conv.arg1_conv (Thm.beta_conversion false));
-  in
-    ct
-    |> fold_rev Thm.lambda all_vars
-    |> conv
-    |> fold apply_beta all_vars
-  end;
-
-fun normalized_tfrees ctxt conv ct =
+fun normalized_tfrees_sandwich ctxt ct =
   let
     val cert = cterm_of (Proof_Context.theory_of ctxt);
     val t = term_of ct;
@@ -136,58 +157,41 @@
     val normalize = map_type_tfree (TFree o the o AList.lookup (op =) (vs_original ~~ vs_normalized));
     val normalization =
       map2 (fn (v, sort) => fn (v', _) => (((v', 0), sort), TFree (v, sort))) vs_original vs_normalized;
-    val ct_normalized = cert (map_types normalize t);
   in
-    ct_normalized
-    |> conv
-    |> Thm.varifyT_global
-    |> Thm.certify_instantiate (normalization, [])
+    (Thm.certify_instantiate (normalization, []) o Thm.varifyT_global, cert (map_types normalize t))
   end;
 
-fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
-
-fun term_of_conv ctxt conv =
-  Thm.cterm_of (Proof_Context.theory_of ctxt)
-  #> conv ctxt
-  #> Thm.prop_of
-  #> Logic.dest_equals
-  #> snd;
+fun no_variables_sandwich ctxt ct =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    val cert = Thm.cterm_of thy;
+    val all_vars = fold_aterms (fn t as Free _ => insert (op aconvc) (cert t)
+      | t as Var _ => insert (op aconvc) (cert t)
+      | _ => I) (Thm.term_of ct) [];
+    fun apply_beta var thm = Thm.combination thm (Thm.reflexive var)
+      |> Conv.fconv_rule (Conv.arg_conv (Conv.try_conv (Thm.beta_conversion false)))
+      |> Conv.fconv_rule (Conv.arg1_conv (Thm.beta_conversion false));
+  in (fold apply_beta all_vars, fold_rev Thm.lambda all_vars ct) end;
 
-fun term_of_conv_resubst ctxt conv t =
-  let
-    val all_vars = fold_aterms (fn t as Free _ => insert (op aconv) t
-      | t as Var _ => insert (op aconv) t
-      | _ => I) t [];
-    val resubst = curry (Term.betapplys o swap) all_vars;
-  in (resubst, term_of_conv ctxt conv (fold_rev lambda all_vars t)) end;
-
-fun preprocess_conv ctxt =
+fun simplifier_conv_sandwich ctxt =
   let
     val thy = Proof_Context.theory_of ctxt;
-    val ss = (#pre o the_thmproc) thy;
-  in fn ctxt' =>
-    Simplifier.rewrite (put_simpset ss ctxt')
-    #> trans_conv_rule (Axclass.unoverload_conv (Proof_Context.theory_of ctxt'))
-  end;
-
-fun preprocess_term ctxt =
-  let
-    val conv = preprocess_conv ctxt;
-  in fn ctxt' => term_of_conv_resubst ctxt' conv end;
+    val pre = (#pre o the_thmproc) thy;
+    val post = (#post o the_thmproc) thy;
+    fun pre_conv ctxt' =
+      Simplifier.rewrite (put_simpset pre ctxt')
+      #> trans_conv_rule (Axclass.unoverload_conv (Proof_Context.theory_of ctxt'))
+    fun post_conv ctxt' =
+      Axclass.overload_conv (Proof_Context.theory_of ctxt')
+      #> trans_conv_rule (Simplifier.rewrite (put_simpset post ctxt'))
+  in fn ctxt' => pre_conv ctxt' #> pair (post_conv ctxt') end;
 
-fun postprocess_conv ctxt =
-  let
-    val thy = Proof_Context.theory_of ctxt;
-    val ss = (#post o the_thmproc) thy;
-  in fn ctxt' =>
-    Axclass.overload_conv (Proof_Context.theory_of ctxt')
-    #> trans_conv_rule (Simplifier.rewrite (put_simpset ss ctxt'))
-  end;
+fun simplifier_sandwich ctxt = lift_conv_sandwich (simplifier_conv_sandwich ctxt);
 
-fun postprocess_term ctxt =
-  let
-    val conv = postprocess_conv ctxt;
-  in fn ctxt' => term_of_conv ctxt' conv end;
+fun value_sandwich ctxt =
+  normalized_tfrees_sandwich
+  |> chain no_variables_sandwich
+  |> chain (simplifier_sandwich ctxt);
 
 fun print_codeproc ctxt =
   let
@@ -477,57 +481,28 @@
   (Wellsorted.change_yield (if ignore_cache then NONE else SOME thy)
     (extend_arities_eqngr (Proof_Context.init_global thy) consts ts));
 
-fun dynamic_conv ctxt conv = normalized_tfrees ctxt (no_variables_conv ctxt (fn ct =>
+fun dynamic_evaluator eval ctxt t =
   let
-    val thm1 = preprocess_conv ctxt ctxt ct;
-    val ct' = Thm.rhs_of thm1;
-    val t' = Thm.term_of ct';
     val consts = fold_aterms
-      (fn Const (c, _) => insert (op =) c | _ => I) t' [];
-    val (algebra', eqngr') = obtain false (Proof_Context.theory_of ctxt) consts [t'];
-    val thm2 = conv algebra' eqngr' t' ct';
-    val thm3 = postprocess_conv ctxt ctxt (Thm.rhs_of thm2);
-  in
-    Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ =>
-      error ("could not construct evaluation proof:\n"
-      ^ (cat_lines o map (Display.string_of_thm ctxt)) [thm1, thm2, thm3])
-  end));
+      (fn Const (c, _) => insert (op =) c | _ => I) t [];
+    val (algebra, eqngr) = obtain false (Proof_Context.theory_of ctxt) consts [t];
+  in eval algebra eqngr t end;
 
-fun dynamic_value ctxt postproc evaluator t =
-  let
-    val (resubst, t') = preprocess_term ctxt ctxt t;
-    val consts = fold_aterms
-      (fn Const (c, _) => insert (op =) c | _ => I) t' [];
-    val (algebra', eqngr') = obtain false (Proof_Context.theory_of ctxt) consts [t'];
-  in
-    t'
-    |> evaluator algebra' eqngr'
-    |> postproc (postprocess_term ctxt ctxt o resubst)
-  end;
+fun dynamic_conv ctxt conv =
+  finalize (value_sandwich ctxt) (dynamic_evaluator conv) ctxt;
+
+fun dynamic_value ctxt lift_postproc evaluator =
+  evaluation (value_sandwich ctxt) lift_postproc (dynamic_evaluator evaluator) ctxt;
 
 fun static_conv ctxt consts conv =
   let
     val (algebra, eqngr) = obtain true (Proof_Context.theory_of ctxt) consts [];
-    val pre_conv = preprocess_conv ctxt;
-    val conv' = conv algebra eqngr;
-    val post_conv = postprocess_conv ctxt;
-  in fn ctxt' => normalized_tfrees ctxt' (no_variables_conv ctxt' ((pre_conv ctxt')
-    then_conv (fn ct => conv' ctxt' (Thm.term_of ct) ct)
-    then_conv (post_conv ctxt')))
-  end;
+  in finalize (value_sandwich ctxt) (conv algebra eqngr) end;
 
-fun static_value ctxt postproc consts evaluator =
+fun static_value ctxt lift_postproc consts evaluator =
   let
     val (algebra, eqngr) = obtain true (Proof_Context.theory_of ctxt) consts [];
-    val preproc = preprocess_term ctxt;
-    val evaluator' = evaluator algebra eqngr;
-    val postproc' = postprocess_term ctxt;
-  in fn ctxt' => 
-    preproc ctxt'
-    #-> (fn resubst => fn t => t
-      |> evaluator' ctxt'
-      |> postproc (postproc' ctxt' o resubst))
-  end;
+  in evaluation (value_sandwich ctxt) lift_postproc (evaluator algebra eqngr) end;
 
 
 (** setup **)