contextifying the compilation of the predicate compiler
authorbulwahn
Mon, 22 Mar 2010 08:30:13 +0100
changeset 35891 3122bdd95275
parent 35890 14a0993fe64b
child 35892 5ed2e9a545ac
contextifying the compilation of the predicate compiler
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 22 08:30:13 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 22 08:30:13 2010 +0100
@@ -325,7 +325,7 @@
   else false
 *)
 
-val is_constr = Code.is_constr;
+val is_constr = Code.is_constr o ProofContext.theory_of;
 
 fun strip_ex (Const ("Ex", _) $ Abs (x, T, t)) =
   let
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 22 08:30:13 2010 +0100
@@ -1027,24 +1027,27 @@
     fold_rev (curry Fun) (map (K Output) Ts) Output
   end
 
-fun is_invertible_function thy (Const (f, _)) = is_constr thy f
-  | is_invertible_function thy _ = false
+fun is_invertible_function ctxt (Const (f, _)) = is_constr ctxt f
+  | is_invertible_function ctxt _ = false
 
-fun non_invertible_subterms thy (t as Free _) = []
-  | non_invertible_subterms thy t = 
-  case (strip_comb t) of (f, args) =>
-    if is_invertible_function thy f then
-      maps (non_invertible_subterms thy) args
+fun non_invertible_subterms ctxt (t as Free _) = []
+  | non_invertible_subterms ctxt t = 
+  let
+    val (f, args) = strip_comb t
+  in
+    if is_invertible_function ctxt f then
+      maps (non_invertible_subterms ctxt) args
     else
       [t]
+  end
 
-fun collect_non_invertible_subterms thy (f as Free _) (names, eqs) = (f, (names, eqs))
-  | collect_non_invertible_subterms thy t (names, eqs) =
+fun collect_non_invertible_subterms ctxt (f as Free _) (names, eqs) = (f, (names, eqs))
+  | collect_non_invertible_subterms ctxt t (names, eqs) =
     case (strip_comb t) of (f, args) =>
-      if is_invertible_function thy f then
+      if is_invertible_function ctxt f then
           let
             val (args', (names', eqs')) =
-              fold_map (collect_non_invertible_subterms thy) args (names, eqs)
+              fold_map (collect_non_invertible_subterms ctxt) args (names, eqs)
           in
             (list_comb (f, args'), (names', eqs'))
           end
@@ -1066,18 +1069,21 @@
 fun is_possible_output thy vs t =
   forall
     (fn t => is_eqT (fastype_of t) andalso forall (member (op =) vs) (term_vs t))
-      (non_invertible_subterms thy t)
+      (non_invertible_subterms (ProofContext.init thy) t)
   andalso
     (forall (is_eqT o snd)
       (inter (fn ((f', _), f) => f = f') vs (Term.add_frees t [])))
 
-fun vars_of_destructable_term thy (Free (x, _)) = [x]
-  | vars_of_destructable_term thy t =
-  case (strip_comb t) of (f, args) =>
-    if is_invertible_function thy f then
-      maps (vars_of_destructable_term thy) args
+fun vars_of_destructable_term ctxt (Free (x, _)) = [x]
+  | vars_of_destructable_term ctxt t =
+  let
+    val (f, args) = strip_comb t
+  in
+    if is_invertible_function ctxt f then
+      maps (vars_of_destructable_term ctxt) args
     else
       []
+  end
 
 fun is_constructable thy vs t = forall (member (op =) vs) (term_vs t)
 
@@ -1099,7 +1105,7 @@
       SOME ms => SOME (map (fn m => (Context m , [])) ms)
     | NONE => NONE)
 
-fun derivations_of thy modes vs (Const ("Pair", _) $ t1 $ t2) (Pair (m1, m2)) =
+fun derivations_of (thy : theory) modes vs (Const ("Pair", _) $ t1 $ t2) (Pair (m1, m2)) =
     map_product
       (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2))
         (derivations_of thy modes vs t1 m1) (derivations_of thy modes vs t2 m2)
@@ -1215,7 +1221,7 @@
   tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^
     commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes)))
 
-fun select_mode_prem (mode_analysis_options : mode_analysis_options) thy pol (modes, (pos_modes, neg_modes)) vs ps =
+fun select_mode_prem (mode_analysis_options : mode_analysis_options) (thy : theory) pol (modes, (pos_modes, neg_modes)) vs ps =
   let
     fun choose_mode_of_prem (Prem t) = partial_hd
         (sort (deriv_ord2 thy modes t) (all_derivations_of thy pos_modes vs t))
@@ -1246,7 +1252,7 @@
           val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes'
         in (modes, modes) end
     val (in_ts, out_ts) = split_mode mode ts
-    val in_vs = maps (vars_of_destructable_term thy) in_ts
+    val in_vs = maps (vars_of_destructable_term (ProofContext.init thy)) in_ts
     val out_vs = terms_vs out_ts
     fun known_vs_after p vs = (case p of
         Prem t => union (op =) vs (term_vs t)
@@ -1509,7 +1515,7 @@
 end;
 
 (* TODO: uses param_vs -- change necessary for compilation with new modes *)
-fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg = 
+fun compile_arg compilation_modifiers compfuns additional_arguments ctxt param_vs iss arg = 
   let
     fun map_params (t as Free (f, T)) =
       if member (op =) param_vs f then
@@ -1525,11 +1531,11 @@
     in map_aterms map_params arg end
 
 fun compile_match compilation_modifiers compfuns additional_arguments
-  param_vs iss thy eqs eqs' out_ts success_t =
+  param_vs iss ctxt eqs eqs' out_ts success_t =
   let
     val eqs'' = maps mk_eq eqs @ eqs'
     val eqs'' =
-      map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs''
+      map (compile_arg compilation_modifiers compfuns additional_arguments ctxt param_vs iss) eqs''
     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
     val name = Name.variant names "x";
     val name' = Name.variant (name :: names) "y";
@@ -1539,8 +1545,7 @@
     val v = Free (name, T);
     val v' = Free (name', T);
   in
-    lambda v (fst (Datatype.make_case
-      (ProofContext.init thy) Datatype_Case.Quiet [] v
+    lambda v (fst (Datatype.make_case ctxt Datatype_Case.Quiet [] v
       [(HOLogic.mk_tuple out_ts,
         if null eqs'' then success_t
         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
@@ -1549,25 +1554,25 @@
        (v', mk_bot compfuns U')]))
   end;
 
-fun string_of_tderiv thy (t, deriv) = 
+fun string_of_tderiv ctxt (t, deriv) = 
   (case (t, deriv) of
     (t1 $ t2, Mode_App (deriv1, deriv2)) =>
-      string_of_tderiv thy (t1, deriv1) ^ " $ " ^ string_of_tderiv thy (t2, deriv2)
+      string_of_tderiv ctxt (t1, deriv1) ^ " $ " ^ string_of_tderiv ctxt (t2, deriv2)
   | (Const ("Pair", _) $ t1 $ t2, Mode_Pair (deriv1, deriv2)) =>
-    "(" ^ string_of_tderiv thy (t1, deriv1) ^ ", " ^ string_of_tderiv thy (t2, deriv2) ^ ")"
-  | (t, Term Input) => Syntax.string_of_term_global thy t ^ "[Input]"
-  | (t, Term Output) => Syntax.string_of_term_global thy t ^ "[Output]"
-  | (t, Context m) => Syntax.string_of_term_global thy t ^ "[" ^ string_of_mode m ^ "]")
+    "(" ^ string_of_tderiv ctxt (t1, deriv1) ^ ", " ^ string_of_tderiv ctxt (t2, deriv2) ^ ")"
+  | (t, Term Input) => Syntax.string_of_term ctxt t ^ "[Input]"
+  | (t, Term Output) => Syntax.string_of_term ctxt t ^ "[Output]"
+  | (t, Context m) => Syntax.string_of_term ctxt t ^ "[" ^ string_of_mode m ^ "]")
 
-fun compile_expr compilation_modifiers compfuns thy pol (t, deriv) additional_arguments =
+fun compile_expr compilation_modifiers compfuns ctxt pol (t, deriv) additional_arguments =
   let
     fun expr_of (t, deriv) =
       (case (t, deriv) of
         (t, Term Input) => SOME t
       | (t, Term Output) => NONE
       | (Const (name, T), Context mode) =>
-        SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy name
-          (pol, mode),
+        SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
+          (ProofContext.theory_of ctxt) name (pol, mode),
           Comp_Mod.funT_of compilation_modifiers mode T))
       | (Free (s, T), Context m) =>
         SOME (Free (s, Comp_Mod.funT_of compilation_modifiers m T))
@@ -1591,19 +1596,19 @@
     list_comb (the (expr_of (t, deriv)), additional_arguments)
   end
 
-fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments
+fun compile_clause compilation_modifiers compfuns ctxt all_vs param_vs additional_arguments
   (pol, mode) inp (ts, moded_ps) =
   let
     val iss = ho_arg_modes_of mode
     val compile_match = compile_match compilation_modifiers compfuns
-      additional_arguments param_vs iss thy
+      additional_arguments param_vs iss ctxt
     val (in_ts, out_ts) = split_mode mode ts;
     val (in_ts', (all_vs', eqs)) =
-      fold_map (collect_non_invertible_subterms thy) in_ts (all_vs, []);
+      fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
     fun compile_prems out_ts' vs names [] =
           let
             val (out_ts'', (names', eqs')) =
-              fold_map (collect_non_invertible_subterms thy) out_ts' (names, []);
+              fold_map (collect_non_invertible_subterms ctxt) out_ts' (names, []);
             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
               out_ts'' (names', map (rpair []) vs);
           in
@@ -1614,7 +1619,7 @@
           let
             val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
             val (out_ts', (names', eqs)) =
-              fold_map (collect_non_invertible_subterms thy) out_ts (names, [])
+              fold_map (collect_non_invertible_subterms ctxt) out_ts (names, [])
             val (out_ts'', (names'', constr_vs')) = fold_map distinct_v
               out_ts' ((names', map (rpair []) vs))
             val mode = head_mode_of deriv
@@ -1624,7 +1629,7 @@
                Prem t =>
                  let
                    val u =
-                     compile_expr compilation_modifiers compfuns thy
+                     compile_expr compilation_modifiers compfuns ctxt
                        pol (t, deriv) additional_arguments'
                    val (_, out_ts''') = split_mode mode (snd (strip_comb t))
                    val rest = compile_prems out_ts''' vs' names'' ps
@@ -1634,7 +1639,7 @@
              | Negprem t =>
                  let
                    val u = mk_not compfuns
-                     (compile_expr compilation_modifiers compfuns thy
+                     (compile_expr compilation_modifiers compfuns ctxt
                        (not pol) (t, deriv) additional_arguments')
                    val (_, out_ts''') = split_mode mode (snd (strip_comb t))
                    val rest = compile_prems out_ts''' vs' names'' ps
@@ -1644,7 +1649,7 @@
              | Sidecond t =>
                  let
                    val t = compile_arg compilation_modifiers compfuns additional_arguments
-                     thy param_vs iss t
+                     ctxt param_vs iss t
                    val rest = compile_prems [] vs' names'' ps;
                  in
                    (mk_if compfuns t, rest)
@@ -1667,6 +1672,7 @@
 
 fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
   let
+    val ctxt = ProofContext.init thy
     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
       (all_vs @ param_vs)
     val compfuns = Comp_Mod.compfuns compilation_modifiers
@@ -1691,14 +1697,15 @@
       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
     val cl_ts =
       map (compile_clause compilation_modifiers compfuns
-        thy all_vs param_vs additional_arguments (pol, mode) (HOLogic.mk_tuple in_ts')) moded_cls;
+        ctxt all_vs param_vs additional_arguments (pol, mode) (HOLogic.mk_tuple in_ts')) moded_cls;
     val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns
       s T mode additional_arguments
       (if null cl_ts then
         mk_bot compfuns (HOLogic.mk_tupleT outTs)
       else foldr1 (mk_sup compfuns) cl_ts)
     val fun_const =
-      Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy s (pol, mode), funT)
+      Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
+      (ProofContext.theory_of ctxt) s (pol, mode), funT)
   in
     HOLogic.mk_Trueprop
       (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation))
@@ -3032,7 +3039,8 @@
           (*| Annotated => annotated_comp_modifiers*)
           | DSeq => dseq_comp_modifiers
           | Pos_Random_DSeq => pos_random_dseq_comp_modifiers
-        val t_pred = compile_expr comp_modifiers compfuns thy true (body, deriv) additional_arguments;
+        val t_pred = compile_expr comp_modifiers compfuns (ProofContext.init thy)
+          true (body, deriv) additional_arguments;
         val T_pred = dest_predT compfuns (fastype_of t_pred)
         val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple
       in