src/HOL/ex/predicate_compile.ML
changeset 31108 0ce5f53fc65d
parent 31107 657386d94f14
parent 30972 5b65835ccc92
child 31111 ae2b24698695
--- a/src/HOL/ex/predicate_compile.ML	Mon May 11 09:39:53 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Mon May 11 17:20:52 2009 +0200
@@ -6,13 +6,17 @@
 
 signature PREDICATE_COMPILE =
 sig
-  val create_def_equation': string -> (int list option list * int list) option -> theory -> theory
+  type mode = int list option list * int list
+  val create_def_equation': string -> mode option -> theory -> theory
   val create_def_equation: string -> theory -> theory
-  val intro_rule: theory -> string -> (int list option list * int list) -> thm
-  val elim_rule: theory -> string -> (int list option list * int list) -> thm
+  val intro_rule: theory -> string -> mode -> thm
+  val elim_rule: theory -> string -> mode -> thm
   val strip_intro_concl : term -> int -> (term * (term list * term list))
   val code_ind_intros_attrib : attribute
   val code_ind_cases_attrib : attribute
+  val print_alternative_rules : theory -> theory
+  val modename_of: theory -> string -> mode -> string
+  val modes_of: theory -> string -> mode list
   val setup : theory -> theory
   val code_pred : string -> Proof.context -> Proof.state
   val code_pred_cmd : string -> Proof.context -> Proof.state
@@ -25,23 +29,101 @@
 structure Predicate_Compile: PREDICATE_COMPILE =
 struct
 
+(** auxiliary **)
+
+(* debug stuff *)
+
+fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
+
+fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
+fun debug_tac msg = (fn st => (tracing msg; Seq.single st));
+
+val do_proofs = ref true;
+
+
+(** fundamentals **)
+
+(* syntactic operations *)
+
+fun mk_eq (x, xs) =
+  let fun mk_eqs _ [] = []
+        | mk_eqs a (b::cs) =
+            HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
+  in mk_eqs x xs end;
+
+fun mk_tupleT [] = HOLogic.unitT
+  | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
+
+fun mk_tuple [] = HOLogic.unit
+  | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
+
+fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
+  | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
+  | dest_tuple t = [t]
+
+fun mk_pred_enumT T = Type ("Predicate.pred", [T])
+
+fun dest_pred_enumT (Type ("Predicate.pred", [T])) = T
+  | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
+
+fun mk_Enum f =
+  let val T as Type ("fun", [T', _]) = fastype_of f
+  in
+    Const (@{const_name Predicate.Pred}, T --> mk_pred_enumT T') $ f    
+  end;
+
+fun mk_Eval (f, x) =
+  let val T = fastype_of x
+  in
+    Const (@{const_name Predicate.eval}, mk_pred_enumT T --> T --> HOLogic.boolT) $ f $ x
+  end;
+
+fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);
+
+fun mk_single t =
+  let val T = fastype_of t
+  in Const(@{const_name Predicate.single}, T --> mk_pred_enumT T) $ t end;
+
+fun mk_bind (x, f) =
+  let val T as Type ("fun", [_, U]) = fastype_of f
+  in
+    Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
+  end;
+
+val mk_sup = HOLogic.mk_binop @{const_name sup};
+
+fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
+  HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond;
+
+fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
+  in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
+
+
+(* data structures *)
+
+type mode = int list option list * int list;
+
+val mode_ord = prod_ord (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord);
+
 structure PredModetab = TableFun(
-  type key = (string * (int list option list * int list))
-  val ord = prod_ord fast_string_ord (prod_ord
-            (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord)))
+  type key = string * mode
+  val ord = prod_ord fast_string_ord mode_ord
+);
 
 
+(*FIXME scrap boilerplate*)
+
 structure IndCodegenData = TheoryDataFun
 (
   type T = {names : string PredModetab.table,
-            modes : ((int list option list * int list) list) Symtab.table,
+            modes : mode list Symtab.table,
             function_defs : Thm.thm Symtab.table,
             function_intros : Thm.thm Symtab.table,
             function_elims : Thm.thm Symtab.table,
-            intro_rules : (Thm.thm list) Symtab.table,
+            intro_rules : Thm.thm list Symtab.table,
             elim_rules : Thm.thm Symtab.table,
             nparams : int Symtab.table
-           };
+           }; (*FIXME: better group tables according to key*)
       (* names: map from inductive predicate and mode to function name (string).
          modes: map from inductive predicates to modes
          function_defs: map from function name to definition
@@ -119,26 +201,12 @@
             intro_rules = #intro_rules x, elim_rules = #elim_rules x,
             nparams = f (#nparams x)}) thy
 
-(* Debug stuff and tactics ***********************************************************)
-
-fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
-fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
-
-fun debug_tac msg = (fn st =>
-     (tracing msg; Seq.single st));
-
 (* removes first subgoal *)
 fun mycheat_tac thy i st =
   (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st
 
-val (do_proofs : bool ref) = ref true;
-
 (* Lightweight mode analysis **********************************************)
 
-(* Hack for message from old code generator *)
-val message = tracing;
-
-
 (**************************************************************************)
 (* source code from old code generator ************************************)
 
@@ -157,7 +225,8 @@
       | _ => false)
   in check end;
 
-(**** check if a type is an equality type (i.e. doesn't contain fun) ****)
+(**** check if a type is an equality type (i.e. doesn't contain fun)
+  FIXME this is only an approximation ****)
 
 fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
   | is_eqT _ = true;
@@ -169,7 +238,7 @@
     | SOME js => enclose "[" "]" (commas (map string_of_int js)))
        (iss @ [SOME is]));
 
-fun print_modes modes = message ("Inferred modes:\n" ^
+fun print_modes modes = tracing ("Inferred modes:\n" ^
   cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
     string_of_mode ms)) modes));
 
@@ -186,6 +255,7 @@
         (get_args' is (i+1) ts)
 in get_args' is 1 ts end
 
+(*FIXME this function should not be named merge... make it local instead*)
 fun merge xs [] = xs
   | merge [] ys = ys
   | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
@@ -201,7 +271,8 @@
 
 fun cprods xss = foldr (map op :: o cprod) [[]] xss;
 
-datatype mode = Mode of (int list option list * int list) * int list * mode option list;
+datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
+  why there is another mode type!?*)
 
 fun modes_of modes t =
   let
@@ -289,11 +360,11 @@
   in (p, List.filter (fn m => case find_index
     (not o check_mode_clause thy param_vs modes m) rs of
       ~1 => true
-    | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
+    | i => (tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
       p ^ " violates mode " ^ string_of_mode m); false)) ms)
   end;
 
-fun fixp f (x : (string * (int list option list * int list) list) list) =
+fun fixp f (x : (string * mode list) list) =
   let val y = f x
   in if x = y then x else fixp f y end;
 
@@ -310,66 +381,6 @@
 (*****************************************************************************************)
 (**** term construction ****)
 
-fun mk_eq (x, xs) =
-  let fun mk_eqs _ [] = []
-        | mk_eqs a (b::cs) =
-            HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
-  in mk_eqs x xs end;
-
-fun mk_tuple [] = HOLogic.unit
-  | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
-
-fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
-  | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
-  | dest_tuple t = [t]
-
-fun mk_tupleT [] = HOLogic.unitT
-  | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
-
-fun mk_pred_enumT T = Type ("Predicate.pred", [T])
-
-fun dest_pred_enumT (Type ("Predicate.pred", [T])) = T
-  | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
-
-fun mk_single t =
-  let val T = fastype_of t
-  in Const(@{const_name Predicate.single}, T --> mk_pred_enumT T) $ t end;
-
-fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);
-
-fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
-                          HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) 
-                         $ cond
-
-fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
-  in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
-
-fun mk_bind (x, f) =
-  let val T as Type ("fun", [_, U]) = fastype_of f
-  in
-    Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
-  end;
-
-fun mk_Enum f =
-  let val T as Type ("fun", [T', _]) = fastype_of f
-  in
-    Const (@{const_name Predicate.Pred}, T --> mk_pred_enumT T') $ f    
-  end;
-
-fun mk_Eval (f, x) =
-  let val T = fastype_of x
-  in
-    Const (@{const_name Predicate.eval}, mk_pred_enumT T --> T --> HOLogic.boolT) $ f $ x
-  end;
-
-fun mk_Eval' f =
-  let val T = fastype_of f
-  in
-    Const (@{const_name Predicate.eval}, T --> dest_pred_enumT T --> HOLogic.boolT) $ f
-  end; 
-
-val mk_sup = HOLogic.mk_binop @{const_name sup};
-
 (* for simple modes (e.g. parameters) only: better call it param_funT *)
 (* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
 fun funT_of T NONE = T
@@ -428,13 +439,16 @@
        (v', mk_empty U')]))
   end;
 
-fun modename thy name mode = let
+fun modename_of thy name mode = let
     val v = (PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode))
-  in if (is_some v) then the v
-     else error ("fun modename - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode))
+  in if (is_some v) then the v (*FIXME use case here*)
+     else error ("fun modename_of - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode))
   end
 
-(* function can be removed *)
+fun modes_of thy =
+  these o Symtab.lookup ((#modes o IndCodegenData.get) thy);
+
+(*FIXME function can be removed*)
 fun mk_funcomp f t =
   let
     val names = Term.add_free_names t [];
@@ -453,7 +467,7 @@
     val f' = case f of
         Const (name, T) =>
           if AList.defined op = modes name then
-            Const (modename thy name (iss, is'), funT'_of (iss, is') T)
+            Const (modename_of thy name (iss, is'), funT'_of (iss, is') T)
           else error "compile param: Not an inductive predicate with correct mode"
       | Free (name, T) => Free (name, funT_of T (SOME is'))
     in list_comb (f', params' @ args') end
@@ -467,7 +481,7 @@
                val (Ts, Us) = get_args is
                  (curry Library.drop (length ms) (fst (strip_type T)))
                val params' = map (compile_param thy modes) (ms ~~ params)
-               val mode_id = modename thy name mode
+               val mode_id = modename_of thy name mode
              in list_comb (Const (mode_id, ((map fastype_of params') @ Ts) --->
                mk_pred_enumT (mk_tupleT Us)), params')
              end
@@ -560,7 +574,7 @@
     val cl_ts =
       map (fn cl => compile_clause thy
         all_vs param_vs modes mode cl (mk_tuple xs)) cls;
-    val mode_id = modename thy s mode
+    val mode_id = modename_of thy s mode
   in
     HOLogic.mk_Trueprop (HOLogic.mk_eq
       (list_comb (Const (mode_id, (Ts1' @ Us1) --->
@@ -595,7 +609,7 @@
     fold Term.add_consts intrs [] |> map fst
     |> filter_out (member (op =) preds) |> filter (is_ind_pred thy)
 
-fun print_arities arities = message ("Arities:\n" ^
+fun print_arities arities = tracing ("Arities:\n" ^
   cat_lines (map (fn (s, (ks, k)) => s ^ ": " ^
     space_implode " -> " (map
       (fn NONE => "X" | SOME k' => string_of_int k')
@@ -695,10 +709,10 @@
 (* Proving equivalence of term *)
 
 
-fun intro_rule thy pred mode = modename thy pred mode
+fun intro_rule thy pred mode = modename_of thy pred mode
     |> Symtab.lookup (#function_intros (IndCodegenData.get thy)) |> the
 
-fun elim_rule thy pred mode = modename thy pred mode
+fun elim_rule thy pred mode = modename_of thy pred mode
     |> Symtab.lookup (#function_elims (IndCodegenData.get thy)) |> the
 
 fun pred_intros thy predname = let
@@ -715,7 +729,7 @@
   end
 
 fun function_definition thy pred mode =
-  modename thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the
+  modename_of thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the
 
 fun is_Type (Type _) = true
   | is_Type _ = false
@@ -977,7 +991,7 @@
     in nth (#elims (snd ind_result)) index end)
 
 fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let
-  val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename thy pred mode))
+  val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename_of thy pred mode))
 (*  val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred
   val index = find_index (fn s => s = pred) (#names (fst ind_result))
   val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *)
@@ -1229,7 +1243,7 @@
 (* main function *********************************************************************)
 (*************************************************************************************)
 
-fun create_def_equation' ind_name (mode : (int list option list * int list) option) thy =
+fun create_def_equation' ind_name (mode : mode option) thy =
 let
   val _ = tracing ("starting create_def_equation' with " ^ ind_name)
   val (prednames, preds) = 
@@ -1253,6 +1267,7 @@
   val _ = tracing ("calling preds: " ^ makestring name_of_calls)
   val _ = tracing "starting recursive compilations"
   fun rec_call name thy = 
+    (*FIXME use member instead of infix mem*)
     if not (name mem (Symtab.keys (#modes (IndCodegenData.get thy)))) then
       create_def_equation name thy else thy
   val thy'' = fold rec_call name_of_calls thy'