rewriting function mk_Eval_of in predicate compiler
authorbulwahn
Thu, 23 Sep 2010 14:50:13 +0200
changeset 39648 655307cb8489
parent 39647 7bf0c7f0f24c
child 39649 7186d338f2e1
rewriting function mk_Eval_of in predicate compiler
src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy
src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy	Thu Sep 23 10:39:25 2010 +0200
+++ b/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy	Thu Sep 23 14:50:13 2010 +0200
@@ -522,7 +522,6 @@
 thm filter2.equation
 thm filter2.random_dseq_equation
 
-(*
 inductive filter3
 for P
 where
@@ -530,9 +529,9 @@
 
 code_pred (expected_modes: (o => bool) => i => o => bool, (o => bool) => i => i => bool , (i => bool) => i => o => bool, (i => bool) => i => i => bool) [skip_proof] filter3 .
 
-code_pred [dseq] filter3 .
-thm filter3.dseq_equation
-*)
+code_pred filter3 .
+thm filter3.equation
+
 (*
 inductive filter4
 where
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML	Thu Sep 23 10:39:25 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML	Thu Sep 23 14:50:13 2010 +0200
@@ -47,7 +47,7 @@
 
 fun mk_Eval (f, x) =
   let
-    val T = fastype_of x
+    val T = dest_predT (fastype_of f)
   in
     Const (@{const_name Predicate.eval}, mk_predT T --> T --> HOLogic.boolT) $ f $ x
   end;
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Sep 23 10:39:25 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Sep 23 14:50:13 2010 +0200
@@ -1402,7 +1402,7 @@
           val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes'
         in (modes, modes) end
     val (in_ts, out_ts) = split_mode mode ts
-    val in_vs = maps (vars_of_destructable_term ctxt) in_ts
+    val in_vs = union (op =) param_vs (maps (vars_of_destructable_term ctxt) in_ts)
     val out_vs = terms_vs out_ts
     fun known_vs_after p vs = (case p of
         Prem t => union (op =) vs (term_vs t)
@@ -1590,80 +1590,54 @@
       in (t' $ u', nvs'') end
   | distinct_v x nvs = (x, nvs);
 
-(** specific rpred functions -- move them to the correct place in this file *)
 
-fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
-  | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
-  let
-    val Ts = binder_types T
-    fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
-      | mk_split_lambda [x] t = lambda x t
-      | mk_split_lambda xs t =
-      let
-        fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
-          | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
-      in
-        mk_split_lambda' xs t
-      end;
-    fun mk_arg (i, T) =
-      let
-        val vname = Name.variant names ("x" ^ string_of_int i)
-        val default = Free (vname, T)
-      in 
-        case AList.lookup (op =) mode i of
-          NONE => (([], [default]), [default])
-        | SOME NONE => (([default], []), [default])
-        | SOME (SOME pis) =>
-          case HOLogic.strip_tupleT T of
-            [] => error "pair mode but unit tuple" (*(([default], []), [default])*)
-          | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
-          | Ts =>
-            let
-              val vnames = Name.variant_list names
-                (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
-                  (1 upto length Ts))
-              val args = map2 (curry Free) vnames Ts
-              fun split_args (i, arg) (ins, outs) =
-                if member (op =) pis i then
-                  (arg::ins, outs)
-                else
-                  (ins, arg::outs)
-              val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
-              fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
-            in ((tuple inargs, tuple outargs), args) end
-      end
-    val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
-    val (inargs, outargs) = pairself flat (split_list inoutargs)
-    val r = PredicateCompFuns.mk_Eval 
-      (list_comb (x, inargs @ additional_arguments), HOLogic.mk_tuple outargs)
-    val t = fold_rev mk_split_lambda args r
-  in
-    (t, names)
-  end;
+(** specific rpred functions -- move them to the correct place in this file *)
+fun mk_Eval_of (P as (Free (f, _)), T) mode =
+let
+  fun mk_bounds (Type (@{type_name Product_Type.prod}, [T1, T2])) i =
+    let
+      val (bs2, i') = mk_bounds T2 i 
+      val (bs1, i'') = mk_bounds T1 i'
+    in
+      (HOLogic.pair_const T1 T2 $ bs1 $ bs2, i'' + 1)
+    end
+    | mk_bounds T i = (Bound i, i + 1)
+  fun mk_prod ((t1, T1), (t2, T2)) = (HOLogic.pair_const T1 T2 $ t1 $ t2, HOLogic.mk_prodT (T1, T2))
+  fun mk_tuple [] = (HOLogic.unit, HOLogic.unitT)
+    | mk_tuple tTs = foldr1 mk_prod tTs;
+  fun mk_split_abs (T as Type (@{type_name Product_Type.prod}, [T1, T2])) t = absdummy (T, HOLogic.split_const (T1, T2, @{typ bool}) $ (mk_split_abs T1 (mk_split_abs T2 t)))
+    | mk_split_abs T t = absdummy (T, t)
+  val args = rev (fst (fold_map mk_bounds (rev (binder_types T)) 0))
+  val (inargs, outargs) = split_mode mode args
+  val (inTs, outTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) mode (binder_types T)
+  val inner_term = PredicateCompFuns.mk_Eval (list_comb (P, inargs), fst (mk_tuple (outargs ~~ outTs)))
+in
+  fold_rev mk_split_abs (binder_types T) inner_term  
+end
 
-(* TODO: uses param_vs -- change necessary for compilation with new modes *)
-fun compile_arg compilation_modifiers additional_arguments ctxt param_vs iss arg = 
+fun compile_arg compilation_modifiers additional_arguments ctxt param_modes arg = 
   let
     fun map_params (t as Free (f, T)) =
-      if member (op =) param_vs f then
-        case (AList.lookup (op =) (param_vs ~~ iss) f) of
-          SOME is =>
+      (case (AList.lookup (op =) param_modes f) of
+          SOME mode =>
             let
-              val _ = error "compile_arg: A parameter in a input position -- do we have a test case?"
-              val T' = Comp_Mod.funT_of compilation_modifiers is T
-            in t(*fst (mk_Eval_of additional_arguments ((Free (f, T'), T), is) [])*) end
-        | NONE => t
-      else t
+              val T' = Comp_Mod.funT_of compilation_modifiers mode T
+            in
+              mk_Eval_of (Free (f, T'), T) mode
+            end
+        | NONE => t)
       | map_params t = t
-    in map_aterms map_params arg end
+  in
+    map_aterms map_params arg
+  end
 
-fun compile_match compilation_modifiers additional_arguments
-  param_vs iss ctxt eqs eqs' out_ts success_t =
+fun compile_match compilation_modifiers additional_arguments ctxt param_modes
+      eqs eqs' out_ts success_t =
   let
     val compfuns = Comp_Mod.compfuns compilation_modifiers
     val eqs'' = maps mk_eq eqs @ eqs'
     val eqs'' =
-      map (compile_arg compilation_modifiers additional_arguments ctxt param_vs iss) eqs''
+      map (compile_arg compilation_modifiers additional_arguments ctxt param_modes) eqs''
     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
     val name = Name.variant names "x";
     val name' = Name.variant (name :: names) "y";
@@ -1692,12 +1666,12 @@
   | (t, Term Output) => Syntax.string_of_term ctxt t ^ "[Output]"
   | (t, Context m) => Syntax.string_of_term ctxt t ^ "[" ^ string_of_mode m ^ "]")
 
-fun compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments =
+fun compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments =
   let
     val compfuns = Comp_Mod.compfuns compilation_modifiers
     fun expr_of (t, deriv) =
       (case (t, deriv) of
-        (t, Term Input) => SOME t
+        (t, Term Input) => SOME (compile_arg compilation_modifiers additional_arguments ctxt param_modes t)
       | (t, Term Output) => NONE
       | (Const (name, T), Context mode) =>
         (case alternative_compilation_of ctxt name mode of
@@ -1728,13 +1702,12 @@
     list_comb (the (expr_of (t, deriv)), additional_arguments)
   end
 
-fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
-  mode inp (in_ts, out_ts) moded_ps =
+fun compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+  inp (in_ts, out_ts) moded_ps =
   let
     val compfuns = Comp_Mod.compfuns compilation_modifiers
-    val iss = ho_arg_modes_of mode (* FIXME! *)
     val compile_match = compile_match compilation_modifiers
-      additional_arguments param_vs iss ctxt
+      additional_arguments ctxt param_modes
     val (in_ts', (all_vs', eqs)) =
       fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
     fun compile_prems out_ts' vs names [] =
@@ -1761,7 +1734,7 @@
                Prem t =>
                  let
                    val u =
-                     compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments'
+                     compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments'
                    val (_, out_ts''') = split_mode mode (snd (strip_comb t))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
@@ -1772,7 +1745,7 @@
                    val neg_compilation_modifiers =
                      negative_comp_modifiers_of compilation_modifiers
                    val u = mk_not compfuns
-                     (compile_expr neg_compilation_modifiers ctxt (t, deriv) additional_arguments')
+                     (compile_expr neg_compilation_modifiers ctxt (t, deriv) param_modes additional_arguments')
                    val (_, out_ts''') = split_mode mode (snd (strip_comb t))
                    val rest = compile_prems out_ts''' vs' names'' ps
                  in
@@ -1781,7 +1754,7 @@
              | Sidecond t =>
                  let
                    val t = compile_arg compilation_modifiers additional_arguments
-                     ctxt param_vs iss t
+                     ctxt param_modes t
                    val rest = compile_prems [] vs' names'' ps;
                  in
                    (mk_if compfuns t, rest)
@@ -1797,7 +1770,7 @@
             compile_match constr_vs' eqs out_ts''
               (mk_bind compfuns (compiled_clause, rest))
           end
-    val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
+    val prem_t = compile_prems in_ts' (map fst param_modes) all_vs' moded_ps;
   in
     mk_bind compfuns (mk_single compfuns inp, prem_t)
   end
@@ -1909,7 +1882,7 @@
   | _ => raise Fail "unexpected pattern")
 
 
-fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode
+fun compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
   in_ts' outTs switch_tree =
   let
     val compfuns = Comp_Mod.compfuns compilation_modifiers
@@ -1929,8 +1902,8 @@
               val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat')
               val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
             in
-              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
-                mode inp (in_ts', out_ts') moded_ps'
+              compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+                inp (in_ts', out_ts') moded_ps'
             end
         in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
     | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
@@ -1991,17 +1964,18 @@
         (param_vs, (all_vs @ param_vs))
     val in_ts' = map_filter (map_filter_prod
       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
+    val param_modes = param_vs ~~ ho_arg_modes_of mode
     val compilation =
       if detect_switches options then
         the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
-          (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
-            mode in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
+          (compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
+            in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
       else
         let
           val cl_ts =
             map (fn (ts, moded_prems) => 
-              compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
-              mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
+              compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+                (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
         in
           Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
             (if null cl_ts then
@@ -3236,7 +3210,7 @@
           | Pos_Random_DSeq => pos_random_dseq_comp_modifiers
           | New_Pos_Random_DSeq => new_pos_random_dseq_comp_modifiers
         val t_pred = compile_expr comp_modifiers ctxt
-          (body, deriv) additional_arguments;
+          (body, deriv) [] additional_arguments;
         val T_pred = dest_predT compfuns (fastype_of t_pred)
         val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple
       in