improving the compilation with higher-order arguments in the predicate compiler
authorbulwahn
Sat, 24 Oct 2009 16:55:43 +0200
changeset 33147 180dc60bd88c
parent 33146 bf852ef586f2
child 33148 0808f7d0d0d7
improving the compilation with higher-order arguments in the predicate compiler
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
src/HOL/ex/Predicate_Compile_ex.thy
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:43 2009 +0200
@@ -438,8 +438,9 @@
     SOME (s, ms) => (case AList.lookup (op =) modes s of
       SOME modes =>
         if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
-          error ("expected modes were not inferred:"
-            ^ "infered modes for " ^ s ^ ": " ^ commas (map (string_of_smode o snd) modes))
+          error ("expected modes were not inferred:\n"
+          ^ "inferred modes for " ^ s ^ ": "
+          ^ commas (map ((enclose "[" "]") o string_of_smode o snd) modes))
         else ()
       | NONE => ())
   | NONE => ()
@@ -1165,9 +1166,79 @@
       in (t' $ u', nvs'') end
   | distinct_v x nvs = (x, nvs);
 
-fun compile_match thy compfuns eqs eqs' out_ts success_t =
+(** specific rpred functions -- move them to the correct place in this file *)
+
+fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
+  | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
+	let
+    val Ts = binder_types T
+    (*val argnames = Name.variant_list names
+        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
+    val args = map Free (argnames ~~ Ts)
+    val (inargs, outargs) = split_smode mode args*)
+		fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
+			| mk_split_lambda [x] t = lambda x t
+			| mk_split_lambda xs t =
+			let
+				fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
+					| mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
+			in
+				mk_split_lambda' xs t
+			end;
+  	fun mk_arg (i, T) =
+		  let
+	  	  val vname = Name.variant names ("x" ^ string_of_int i)
+		    val default = Free (vname, T)
+		  in 
+		    case AList.lookup (op =) mode i of
+		      NONE => (([], [default]), [default])
+			  | SOME NONE => (([default], []), [default])
+			  | SOME (SOME pis) =>
+				  case HOLogic.strip_tupleT T of
+						[] => error "pair mode but unit tuple" (*(([default], []), [default])*)
+					| [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
+					| Ts =>
+					  let
+							val vnames = Name.variant_list names
+								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
+									(1 upto length Ts))
+							val args = map Free (vnames ~~ Ts)
+							fun split_args (i, arg) (ins, outs) =
+							  if member (op =) pis i then
+							    (arg::ins, outs)
+								else
+								  (ins, arg::outs)
+							val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
+							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
+						in ((tuple inargs, tuple outargs), args) end
+			end
+		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
+    val (inargs, outargs) = pairself flat (split_list inoutargs)
+		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ additional_arguments), mk_tuple outargs)
+    val t = fold_rev mk_split_lambda args r
+  in
+    (t, names)
+  end;
+
+fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg = 
+  let
+    fun map_params (t as Free (f, T)) =
+      if member (op =) param_vs f then
+        case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
+          SOME is =>
+            let
+              val T' = #funT_of compilation_modifiers compfuns ([], is) T
+            in fst (mk_Eval_of additional_arguments ((Free (f, T'), T), SOME is) []) end
+        | NONE => t
+      else t
+      | map_params t = t
+    in map_aterms map_params arg end
+
+fun compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy eqs eqs' out_ts success_t =
   let
     val eqs'' = maps mk_eq eqs @ eqs'
+    val eqs'' =
+      map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs''
     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
     val name = Name.variant names "x";
     val name' = Name.variant (name :: names) "y";
@@ -1228,76 +1299,9 @@
   | (Free (name, T), params) =>
     list_comb (Free (name, #funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
 
-(** specific rpred functions -- move them to the correct place in this file *)
-
-fun mk_Eval_of depth ((x, T), NONE) names = (x, names)
-  | mk_Eval_of depth ((x, T), SOME mode) names =
-	let
-    val Ts = binder_types T
-    (*val argnames = Name.variant_list names
-        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-    val args = map Free (argnames ~~ Ts)
-    val (inargs, outargs) = split_smode mode args*)
-		fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
-			| mk_split_lambda [x] t = lambda x t
-			| mk_split_lambda xs t =
-			let
-				fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
-					| mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
-			in
-				mk_split_lambda' xs t
-			end;
-  	fun mk_arg (i, T) =
-		  let
-	  	  val vname = Name.variant names ("x" ^ string_of_int i)
-		    val default = Free (vname, T)
-		  in 
-		    case AList.lookup (op =) mode i of
-		      NONE => (([], [default]), [default])
-			  | SOME NONE => (([default], []), [default])
-			  | SOME (SOME pis) =>
-				  case HOLogic.strip_tupleT T of
-						[] => error "pair mode but unit tuple" (*(([default], []), [default])*)
-					| [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
-					| Ts =>
-					  let
-							val vnames = Name.variant_list names
-								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
-									(1 upto length Ts))
-							val args = map Free (vnames ~~ Ts)
-							fun split_args (i, arg) (ins, outs) =
-							  if member (op =) pis i then
-							    (arg::ins, outs)
-								else
-								  (ins, arg::outs)
-							val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
-							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
-						in ((tuple inargs, tuple outargs), args) end
-			end
-		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
-    val (inargs, outargs) = pairself flat (split_list inoutargs)
-    val depth_t = case depth of NONE => [] | SOME (polarity, depth_t) => [polarity, depth_t]
-		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ depth_t), mk_tuple outargs)
-    val t = fold_rev mk_split_lambda args r
-  in
-    (t, names)
-  end;
-
-fun compile_arg depth thy param_vs iss arg = 
-  let
-    val funT_of = case depth of NONE => funT_of | SOME _ => depth_limited_funT_of
-    fun map_params (t as Free (f, T)) =
-      if member (op =) param_vs f then
-        case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
-          SOME is => let val T' = funT_of PredicateCompFuns.compfuns ([], is) T
-            in fst (mk_Eval_of depth ((Free (f, T'), T), SOME is) []) end
-        | NONE => t
-      else t
-      | map_params t = t
-    in map_aterms map_params arg end
-  
 fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments (iss, is) inp (ts, moded_ps) =
   let
+    val compile_match = compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy
     fun check_constrt t (names, eqs) =
       if is_constrt thy t then (t, (names, eqs)) else
         let
@@ -1316,7 +1320,7 @@
             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
               out_ts'' (names', map (rpair []) vs);
           in
-            compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
+            compile_match constr_vs (eqs @ eqs') out_ts'''
               (mk_single compfuns (mk_tuple out_ts))
           end
       | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
@@ -1332,13 +1336,8 @@
                Prem (us, t) =>
                  let
                    val (in_ts, out_ts''') = split_smode is us;
-                     (* TODO: add test case for compile_arg *)
-                   (*val in_ts = map (compile_arg depth thy param_vs iss) in_ts*)
-                     (* additional_arguments
-                   val args = case depth of
-                     NONE => in_ts
-                   | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
-                   *)
+                   val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
+                     thy param_vs iss) in_ts
                    val u =
                      compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments'
                    val rest = compile_prems out_ts''' vs' names'' ps
@@ -1348,11 +1347,8 @@
              | Negprem (us, t) =>
                  let
                    val (in_ts, out_ts''') = split_smode is us
-                     (* additional_arguments
-                   val args = case depth of
-                     NONE => in_ts
-                   | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
-                   *)
+                   val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
+                     thy param_vs iss) in_ts
                    val u = mk_not compfuns
                      (compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments')
                    val rest = compile_prems out_ts''' vs' names'' ps
@@ -1361,6 +1357,8 @@
                  end
              | Sidecond t =>
                  let
+                   val t = compile_arg compilation_modifiers compfuns additional_arguments
+                     thy param_vs iss t
                    val rest = compile_prems [] vs' names'' ps;
                  in
                    (mk_if compfuns t, rest)
@@ -1374,7 +1372,7 @@
                    (u, rest)
                  end
           in
-            compile_match thy compfuns constr_vs' eqs out_ts''
+            compile_match constr_vs' eqs out_ts''
               (mk_bind compfuns (compiled_clause, rest))
           end
     val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
@@ -1463,7 +1461,7 @@
   val param_names' = Name.variant_list (param_names @ argnames)
     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
   val param_vs = map Free (param_names' ~~ Ts1)
-  val (params', names) = fold_map (mk_Eval_of NONE) ((params ~~ Ts1) ~~ iss) []
+  val (params', names) = fold_map (mk_Eval_of []) ((params ~~ Ts1) ~~ iss) []
   val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
   val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
@@ -1566,7 +1564,7 @@
    	  val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
       val (xinout, xargs) = split_list xinoutargs
 			val (xins, xouts) = pairself flat (split_list xinout)
-			val (xparams', names') = fold_map (mk_Eval_of NONE) ((xparams ~~ Ts1) ~~ iss) names
+			val (xparams', names') = fold_map (mk_Eval_of []) ((xparams ~~ Ts1) ~~ iss) names
       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
         | mk_split_lambda [x] t = lambda x t
         | mk_split_lambda xs t =
--- a/src/HOL/ex/Predicate_Compile_ex.thy	Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile_ex.thy	Sat Oct 24 16:55:43 2009 +0200
@@ -135,6 +135,48 @@
 code_pred (mode: [1], [1, 2], [1, 2, 3], [1, 3]) map_ofP .
 thm map_ofP.equation
 
+inductive filter1
+for P
+where
+  "filter1 P [] []"
+| "P x ==> filter1 P xs ys ==> filter1 P (x#xs) (x#ys)"
+| "\<not> P x ==> filter1 P xs ys ==> filter1 P (x#xs) ys"
+
+code_pred (mode: [1], [1, 2]) filter1 .
+code_pred [depth_limited] filter1 .
+code_pred [rpred] filter1 .
+
+thm filter1.equation
+
+inductive filter2
+where
+  "filter2 P [] []"
+| "P x ==> filter2 P xs ys ==> filter2 P (x#xs) (x#ys)"
+| "\<not> P x ==> filter2 P xs ys ==> filter2 P (x#xs) ys"
+
+code_pred (mode: [1, 2, 3], [1, 2]) filter2 .
+code_pred [depth_limited] filter2 .
+code_pred [rpred] filter2 .
+thm filter2.equation
+thm filter2.rpred_equation
+
+inductive filter3
+for P
+where
+  "List.filter P xs = ys ==> filter3 P xs ys"
+
+code_pred filter3 .
+code_pred [depth_limited] filter3 .
+thm filter3.depth_limited_equation
+(*code_pred [rpred] filter3 .*)
+inductive filter4
+where
+  "List.filter P xs = ys ==> filter4 P xs ys"
+
+code_pred filter4 .
+code_pred [depth_limited] filter4 .
+code_pred [rpred] filter4 .
+
 section {* reverse *}
 
 inductive rev where