src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 34948 2d5f2a9f7601
parent 33752 9aa8e961f850
child 35224 1c9866c5f6fb
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Sat Jan 16 21:14:15 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Wed Jan 20 11:56:45 2010 +0100
@@ -6,49 +6,45 @@
 
 (* FIXME proper signature *)
 
+structure TermGraph = Graph(type key = term val ord = TermOrd.fast_term_ord);
+
 structure Predicate_Compile_Aux =
 struct
 
+(* general functions *)
+
+fun apfst3 f (x, y, z) = (f x, y, z)
+fun apsnd3 f (x, y, z) = (x, f y, z)
+fun aptrd3 f (x, y, z) = (x, y, f z)
+
+fun comb_option f (SOME x1, SOME x2) = SOME (f (x1, x2))
+  | comb_option f (NONE, SOME x2) = SOME x2
+  | comb_option f (SOME x1, NONE) = SOME x1
+  | comb_option f (NONE, NONE) = NONE
+
+fun map2_optional f (x :: xs) (y :: ys) = (f x (SOME y)) :: (map2_optional f xs ys)
+  | map2_optional f (x :: xs) [] = (f x NONE) :: (map2_optional f xs [])
+  | map2_optional f [] [] = []
+
+fun find_indices f xs =
+  map_filter (fn (i, true) => SOME i | (i, false) => NONE) (map_index (apsnd f) xs)
 
 (* mode *)
 
-type smode = (int * int list option) list
-type mode = smode option list * smode
-datatype tmode = Mode of mode * smode * tmode option list;
-
-fun string_of_smode js =
-    commas (map
-      (fn (i, is) =>
-        string_of_int i ^ (case is of NONE => ""
-    | SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
-(* FIXME: remove! *)
-
-fun string_of_mode (iss, is) = space_implode " -> " (map
-  (fn NONE => "X"
-    | SOME js => enclose "[" "]" (string_of_smode js))
-       (iss @ [SOME is]));
-
-fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
-  "predmode: " ^ (string_of_mode predmode) ^
-  (if null param_modes then "" else
-    "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
-
-(* new datatype for mode *)
-
-datatype mode' = Bool | Input | Output | Pair of mode' * mode' | Fun of mode' * mode'
+datatype mode = Bool | Input | Output | Pair of mode * mode | Fun of mode * mode
 
 (* equality of instantiatedness with respect to equivalences:
   Pair Input Input == Input and Pair Output Output == Output *)
-fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
-  | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
-  | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
-  | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
-  | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
-  | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
-  | eq_mode' (Input, Input) = true
-  | eq_mode' (Output, Output) = true
-  | eq_mode' (Bool, Bool) = true
-  | eq_mode' _ = false
+fun eq_mode (Fun (m1, m2), Fun (m3, m4)) = eq_mode (m1, m3) andalso eq_mode (m2, m4)
+  | eq_mode (Pair (m1, m2), Pair (m3, m4)) = eq_mode (m1, m3) andalso eq_mode (m2, m4)
+  | eq_mode (Pair (m1, m2), Input) = eq_mode (m1, Input) andalso eq_mode (m2, Input)
+  | eq_mode (Pair (m1, m2), Output) = eq_mode (m1, Output) andalso eq_mode (m2, Output)
+  | eq_mode (Input, Pair (m1, m2)) = eq_mode (Input, m1) andalso eq_mode (Input, m2)
+  | eq_mode (Output, Pair (m1, m2)) = eq_mode (Output, m1) andalso eq_mode (Output, m2)
+  | eq_mode (Input, Input) = true
+  | eq_mode (Output, Output) = true
+  | eq_mode (Bool, Bool) = true
+  | eq_mode _ = false
 
 (* name: binder_modes? *)
 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
@@ -61,7 +57,153 @@
 fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
   | dest_tuple_mode _ = []
 
-fun string_of_mode' mode' =
+fun all_modes_of_typ (T as Type ("fun", _)) = 
+  let
+    val (S, U) = strip_type T
+  in
+    if U = HOLogic.boolT then
+      fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2)
+        (map all_modes_of_typ S) [Bool]
+    else
+      [Input, Output]
+  end
+  | all_modes_of_typ (Type ("*", [T1, T2])) = 
+    map_product (curry Pair) (all_modes_of_typ T1) (all_modes_of_typ T2)
+  | all_modes_of_typ (Type ("bool", [])) = [Bool]
+  | all_modes_of_typ _ = [Input, Output]
+
+fun extract_params arg =
+  case fastype_of arg of
+    (T as Type ("fun", _)) =>
+      (if (body_type T = HOLogic.boolT) then
+        (case arg of
+          Free _ => [arg] | _ => error "extract_params: Unexpected term")
+      else [])
+  | (Type ("*", [T1, T2])) =>
+      let
+        val (t1, t2) = HOLogic.dest_prod arg
+      in
+        extract_params t1 @ extract_params t2
+      end
+  | _ => []
+
+fun ho_arg_modes_of mode =
+  let
+    fun ho_arg_mode (m as Fun _) =  [m]
+      | ho_arg_mode (Pair (m1, m2)) = ho_arg_mode m1 @ ho_arg_mode m2
+      | ho_arg_mode _ = []
+  in
+    maps ho_arg_mode (strip_fun_mode mode)
+  end
+
+fun ho_args_of mode ts =
+  let
+    fun ho_arg (Fun _) (SOME t) = [t]
+      | ho_arg (Fun _) NONE = error "ho_arg_of"
+      | ho_arg (Pair (m1, m2)) (SOME (Const ("Pair", _) $ t1 $ t2)) =
+          ho_arg m1 (SOME t1) @ ho_arg m2 (SOME t2)
+      | ho_arg (Pair (m1, m2)) NONE = ho_arg m1 NONE @ ho_arg m2 NONE
+      | ho_arg _ _ = []
+  in
+    flat (map2_optional ho_arg (strip_fun_mode mode) ts)
+  end
+
+(* temporary function should be replaced by unsplit_input or so? *)
+fun replace_ho_args mode hoargs ts =
+  let
+    fun replace (Fun _, _) (arg' :: hoargs') = (arg', hoargs')
+      | replace (Pair (m1, m2), Const ("Pair", T) $ t1 $ t2) hoargs =
+        let
+          val (t1', hoargs') = replace (m1, t1) hoargs
+          val (t2', hoargs'') = replace (m2, t2) hoargs'
+        in
+          (Const ("Pair", T) $ t1' $ t2', hoargs'')
+        end
+      | replace (_, t) hoargs = (t, hoargs)
+  in
+    fst (fold_map replace ((strip_fun_mode mode) ~~ ts) hoargs)
+  end
+
+fun ho_argsT_of mode Ts =
+  let
+    fun ho_arg (Fun _) T = [T]
+      | ho_arg (Pair (m1, m2)) (Type ("*", [T1, T2])) = ho_arg m1 T1 @ ho_arg m2 T2
+      | ho_arg _ _ = []
+  in
+    flat (map2 ho_arg (strip_fun_mode mode) Ts)
+  end
+
+(* splits mode and maps function to higher-order argument types *)
+fun split_map_mode f mode ts =
+  let
+    fun split_arg_mode' (m as Fun _) t = f m t
+      | split_arg_mode' (Pair (m1, m2)) (Const ("Pair", _) $ t1 $ t2) =
+        let
+          val (i1, o1) = split_arg_mode' m1 t1
+          val (i2, o2) = split_arg_mode' m2 t2
+        in
+          (comb_option HOLogic.mk_prod (i1, i2), comb_option HOLogic.mk_prod (o1, o2))
+        end
+      | split_arg_mode' Input t = (SOME t, NONE)
+      | split_arg_mode' Output t = (NONE,  SOME t)
+      | split_arg_mode' _ _ = error "split_map_mode: mode and term do not match"
+  in
+    (pairself (map_filter I) o split_list) (map2 split_arg_mode' (strip_fun_mode mode) ts)
+  end
+
+(* splits mode and maps function to higher-order argument types *)
+fun split_map_modeT f mode Ts =
+  let
+    fun split_arg_mode' (m as Fun _) T = f m T
+      | split_arg_mode' (Pair (m1, m2)) (Type ("*", [T1, T2])) =
+        let
+          val (i1, o1) = split_arg_mode' m1 T1
+          val (i2, o2) = split_arg_mode' m2 T2
+        in
+          (comb_option HOLogic.mk_prodT (i1, i2), comb_option HOLogic.mk_prodT (o1, o2))
+        end
+      | split_arg_mode' Input T = (SOME T, NONE)
+      | split_arg_mode' Output T = (NONE,  SOME T)
+      | split_arg_mode' _ _ = error "split_modeT': mode and type do not match"
+  in
+    (pairself (map_filter I) o split_list) (map2 split_arg_mode' (strip_fun_mode mode) Ts)
+  end
+
+fun split_mode mode ts = split_map_mode (fn _ => fn _ => (NONE, NONE)) mode ts
+
+fun fold_map_aterms_prodT comb f (Type ("*", [T1, T2])) s =
+  let
+    val (x1, s') = fold_map_aterms_prodT comb f T1 s
+    val (x2, s'') = fold_map_aterms_prodT comb f T2 s'
+  in
+    (comb x1 x2, s'')
+  end
+  | fold_map_aterms_prodT comb f T s = f T s
+
+fun map_filter_prod f (Const ("Pair", _) $ t1 $ t2) =
+  comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
+  | map_filter_prod f t = f t
+
+(* obviously, split_mode' and split_modeT' do not match? where does that cause problems? *)
+  
+fun split_modeT' mode Ts =
+  let
+    fun split_arg_mode' (Fun _) T = ([], [])
+      | split_arg_mode' (Pair (m1, m2)) (Type ("*", [T1, T2])) =
+        let
+          val (i1, o1) = split_arg_mode' m1 T1
+          val (i2, o2) = split_arg_mode' m2 T2
+        in
+          (i1 @ i2, o1 @ o2)
+        end
+      | split_arg_mode' Input T = ([T], [])
+      | split_arg_mode' Output T = ([], [T])
+      | split_arg_mode' _ _ = error "split_modeT': mode and type do not match"
+  in
+    (pairself flat o split_list) (map2 split_arg_mode' (strip_fun_mode mode) Ts)
+  end
+
+fun string_of_mode mode =
   let
     fun string_of_mode1 Input = "i"
       | string_of_mode1 Output = "o"
@@ -71,9 +213,9 @@
       | string_of_mode2 mode = string_of_mode1 mode
     and string_of_mode3 (Fun (m1, m2)) = string_of_mode2 m1 ^ " => " ^ string_of_mode3 m2
       | string_of_mode3 mode = string_of_mode2 mode
-  in string_of_mode3 mode' end
+  in string_of_mode3 mode end
 
-fun ascii_string_of_mode' mode' =
+fun ascii_string_of_mode mode' =
   let
     fun ascii_string_of_mode' Input = "i"
       | ascii_string_of_mode' Output = "o"
@@ -91,55 +233,10 @@
       | ascii_string_of_mode'_Pair m = ascii_string_of_mode' m
   in ascii_string_of_mode'_Fun mode' end
 
-fun translate_mode T (iss, is) =
-  let
-    val Ts = binder_types T
-    val (Ts1, Ts2) = chop (length iss) Ts
-    fun translate_smode Ts is =
-      let
-        fun translate_arg (i, T) =
-          case AList.lookup (op =) is (i + 1) of
-            SOME NONE => Input
-          | SOME (SOME its) =>
-            let
-              fun translate_tuple (i, T) = if member (op =) its (i + 1) then Input else Output
-            in 
-              foldr1 Pair (map_index translate_tuple (HOLogic.strip_tupleT T))
-            end
-          | NONE => Output
-      in map_index translate_arg Ts end
-    fun mk_mode arg_modes = foldr1 Fun (arg_modes @ [Bool])
-    val param_modes =
-      map (fn (T, NONE) => Input | (T, SOME is) => mk_mode (translate_smode (binder_types T) is))
-        (Ts1 ~~ iss)
-  in
-    mk_mode (param_modes @ translate_smode Ts2 is)
-  end;
+(* premises *)
 
-fun translate_mode' nparams mode' =
-  let
-    fun err () = error "translate_mode': given mode cannot be translated"
-    val (m1, m2) = chop nparams (strip_fun_mode mode')
-    val translate_to_tupled_mode =
-      (map_filter I) o (map_index (fn (i, m) =>
-        if eq_mode' (m, Input) then SOME (i + 1)
-        else if eq_mode' (m, Output) then NONE
-        else err ()))
-    val translate_to_smode =
-      (map_filter I) o (map_index (fn (i, m) =>
-        if eq_mode' (m, Input) then SOME (i + 1, NONE)
-        else if eq_mode' (m, Output) then NONE
-        else SOME (i + 1, SOME (translate_to_tupled_mode (dest_tuple_mode m)))))
-    fun translate_to_param_mode m =
-      case rev (dest_fun_mode m) of
-        Bool :: _ :: _ => SOME (translate_to_smode (strip_fun_mode m))
-      | _ => if eq_mode' (m, Input) then NONE else err ()
-  in
-    (map translate_to_param_mode m1, translate_to_smode m2)
-  end
-
-fun string_of_mode thy constname mode =
-  string_of_mode' (translate_mode (Sign.the_const_type thy constname) mode)
+datatype indprem = Prem of term | Negprem of term | Sidecond of term
+  | Generator of (string * typ);
 
 (* general syntactic functions *)
 
@@ -162,9 +259,9 @@
 val is_pred_equation = is_pred_equation_term o prop_of 
 
 fun is_intro_term constname t =
-  case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
+  the_default false (try (fn t => case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
     Const (c, _) => c = constname
-  | _ => false
+  | _ => false) t)
   
 fun is_intro constname t = is_intro_term constname (prop_of t)
 
@@ -177,21 +274,8 @@
 fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
   | is_predT _ = false
 
-(* guessing number of parameters *)
-fun find_indexes pred xs =
-  let
-    fun find is n [] = is
-      | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
-  in rev (find [] 0 xs) end;
-
-fun guess_nparams T =
-  let
-    val argTs = binder_types T
-    val nparams = fold Integer.max
-      (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
-  in nparams end;
-
 (*** check if a term contains only constructor functions ***)
+(* TODO: another copy in the core! *)
 (* FIXME: constructor terms are supposed to be seen in the way the code generator
   sees constructors.*)
 fun is_constrt thy =
@@ -206,7 +290,34 @@
           | _ => false)
       | _ => false)
   in check end;  
-  
+
+fun is_funtype (Type ("fun", [_, _])) = true
+  | is_funtype _ = false;
+
+fun is_Type (Type _) = true
+  | is_Type _ = false
+
+(* returns true if t is an application of an datatype constructor *)
+(* which then consequently would be splitted *)
+(* else false *)
+(*
+fun is_constructor thy t =
+  if (is_Type (fastype_of t)) then
+    (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
+      NONE => false
+    | SOME info => (let
+      val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
+      val (c, _) = strip_comb t
+      in (case c of
+        Const (name, _) => name mem_string constr_consts
+        | _ => false) end))
+  else false
+*)
+
+(* must be exported in code.ML *)
+(* TODO: is there copy in the core? *)
+fun is_constr thy = is_some o Code.get_datatype_of_constr thy;
+
 fun strip_ex (Const ("Ex", _) $ Abs (x, T, t)) =
   let
     val (xTs, t') = strip_ex t
@@ -224,7 +335,6 @@
     val t'' = Term.subst_bounds (rev vs, t');
   in ((ps', t''), nctxt') end;
 
-
 (* introduction rule combinators *)
 
 (* combinators to apply a function to all literals of an introduction rules *)
@@ -280,10 +390,23 @@
 
 (* Different options for compiler *)
 
+datatype compilation = Pred | Random | Depth_Limited | DSeq | Annotated | Random_DSeq
+
+fun string_of_compilation c = case c of
+    Pred => ""
+  | Random => "random"
+  | Depth_Limited => "depth limited"
+  | DSeq => "dseq"
+  | Annotated => "annotated"
+  | Random_DSeq => "random dseq"
+
+(*datatype compilation_options =
+  Pred | Random of int | Depth_Limited of int | DSeq of int | Annotated*)
+
 datatype options = Options of {  
-  expected_modes : (string * mode' list) option,
-  proposed_modes : (string * mode' list) option,
-  proposed_names : ((string * mode') * string) list,
+  expected_modes : (string * mode list) option,
+  proposed_modes : (string * mode list) option,
+  proposed_names : ((string * mode) * string) list,
   show_steps : bool,
   show_proof_trace : bool,
   show_intermediate_results : bool,
@@ -293,14 +416,12 @@
   skip_proof : bool,
 
   inductify : bool,
-  random : bool,
-  depth_limited : bool,
-  annotated : bool
+  compilation : compilation
 };
 
 fun expected_modes (Options opt) = #expected_modes opt
 fun proposed_modes (Options opt) = #proposed_modes opt
-fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode')
+fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode)
   (#proposed_names opt) (name, mode)
 
 fun show_steps (Options opt) = #show_steps opt
@@ -312,9 +433,8 @@
 fun skip_proof (Options opt) = #skip_proof opt
 
 fun is_inductify (Options opt) = #inductify opt
-fun is_random (Options opt) = #random opt
-fun is_depth_limited (Options opt) = #depth_limited opt
-fun is_annotated (Options opt) = #annotated opt
+
+fun compilation (Options opt) = #compilation opt
 
 val default_options = Options {
   expected_modes = NONE,
@@ -326,14 +446,18 @@
   show_modes = false,
   show_mode_inference = false,
   show_compilation = false,
-  skip_proof = false,
+  skip_proof = true,
   
   inductify = false,
-  random = false,
-  depth_limited = false,
-  annotated = false
+  compilation = Pred
 }
 
+val bool_options = ["show_steps", "show_intermediate_results", "show_proof_trace", "show_modes",
+  "show_mode_inference", "show_compilation", "skip_proof", "inductify"]
+
+val compilation_names = [("pred", Pred),
+  (*("random", Random), ("depth_limited", Depth_Limited), ("annotated", Annotated),*)
+  ("dseq", DSeq), ("random_dseq", Random_DSeq)]
 
 fun print_step options s =
   if show_steps options then tracing s else ()