--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Thu Nov 12 09:10:07 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Thu Nov 12 09:10:16 2009 +0100
@@ -21,6 +21,7 @@
(fn (i, is) =>
string_of_int i ^ (case is of NONE => ""
| SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
+(* FIXME: remove! *)
fun string_of_mode (iss, is) = space_implode " -> " (map
(fn NONE => "X"
@@ -28,10 +29,102 @@
(iss @ [SOME is]));
fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
- "predmode: " ^ (string_of_mode predmode) ^
+ "predmode: " ^ (string_of_mode predmode) ^
(if null param_modes then "" else
"; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
+(* new datatype for mode *)
+
+datatype mode' = Bool | Input | Output | Pair of mode' * mode' | Fun of mode' * mode'
+
+(* name: binder_modes? *)
+fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
+ | strip_fun_mode Bool = []
+ | strip_fun_mode _ = error "Bad mode for strip_fun_mode"
+
+fun dest_fun_mode (Fun (mode, mode')) = mode :: dest_fun_mode mode'
+ | dest_fun_mode mode = [mode]
+
+fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
+ | dest_tuple_mode _ = []
+
+fun string_of_mode' mode' =
+ let
+ fun string_of_mode1 Input = "i"
+ | string_of_mode1 Output = "o"
+ | string_of_mode1 Bool = "bool"
+ | string_of_mode1 mode = "(" ^ (string_of_mode3 mode) ^ ")"
+ and string_of_mode2 (Pair (m1, m2)) = string_of_mode3 m1 ^ " * " ^ string_of_mode2 m2
+ | string_of_mode2 mode = string_of_mode1 mode
+ and string_of_mode3 (Fun (m1, m2)) = string_of_mode2 m1 ^ " => " ^ string_of_mode3 m2
+ | string_of_mode3 mode = string_of_mode2 mode
+ in string_of_mode3 mode' end
+
+
+fun translate_mode T (iss, is) =
+ let
+ val Ts = binder_types T
+ val (Ts1, Ts2) = chop (length iss) Ts
+ fun translate_smode Ts is =
+ let
+ fun translate_arg (i, T) =
+ case AList.lookup (op =) is (i + 1) of
+ SOME NONE => Input
+ | SOME (SOME its) =>
+ let
+ fun translate_tuple (i, T) = if member (op =) its (i + 1) then Input else Output
+ in
+ foldr1 Pair (map_index translate_tuple (HOLogic.strip_tupleT T))
+ end
+ | NONE => Output
+ in map_index translate_arg Ts end
+ fun mk_mode arg_modes = foldr1 Fun (arg_modes @ [Bool])
+ val param_modes =
+ map (fn (T, NONE) => Input | (T, SOME is) => mk_mode (translate_smode (binder_types T) is))
+ (Ts1 ~~ iss)
+ in
+ mk_mode (param_modes @ translate_smode Ts2 is)
+ end;
+
+fun string_of_mode thy constname mode =
+ string_of_mode' (translate_mode (Sign.the_const_type thy constname) mode)
+
+fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+ | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+ | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
+ | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
+ | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
+ | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
+ | eq_mode' (Input, Input) = true
+ | eq_mode' (Output, Output) = true
+ | eq_mode' (Bool, Bool) = true
+ | eq_mode' _ = false
+(* FIXME: remove! *)
+fun eq_mode'_mode (mode', (iss, is)) =
+ let
+ val arg_modes = strip_fun_mode mode'
+ val (arg_modes1, arg_modes2) = chop (length iss) arg_modes
+ fun eq_arg Input NONE = true
+ | eq_arg _ NONE = false
+ | eq_arg mode (SOME is) =
+ let
+ val modes = dest_tuple_mode mode
+ in
+ forall (fn i => nth modes (i - 1) = Input) is
+ andalso forall (fn i => nth modes (i - 1) = Output)
+ (subtract (op =) is (1 upto length modes))
+ end
+ fun eq_mode'_smode mode' is =
+ forall (fn (i, t) => eq_arg (nth mode' (i - 1)) t) is
+ andalso forall (fn i => (nth mode' (i - 1) = Output))
+ (subtract (op =) (map fst is) (1 upto length mode'))
+ in
+ forall (fn (m, NONE) => m = Input | (m, SOME is) => eq_mode'_smode (strip_fun_mode m) is)
+ (arg_modes1 ~~ iss)
+ andalso eq_mode'_smode arg_modes2 is
+ end
+
+
(* general syntactic functions *)
(*Like dest_conj, but flattens conjunctions however nested*)
@@ -70,7 +163,20 @@
fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
| is_predT _ = false
-
+(* guessing number of parameters *)
+fun find_indexes pred xs =
+ let
+ fun find is n [] = is
+ | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
+ in rev (find [] 0 xs) end;
+
+fun guess_nparams T =
+ let
+ val argTs = binder_types T
+ val nparams = fold Integer.max
+ (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
+ in nparams end;
+
(*** check if a term contains only constructor functions ***)
fun is_constrt thy =
let
@@ -159,7 +265,7 @@
(* Different options for compiler *)
datatype options = Options of {
- expected_modes : (string * mode list) option,
+ expected_modes : (string * mode' list) option,
show_steps : bool,
show_proof_trace : bool,
show_intermediate_results : bool,