some experiements towards user interface for predicate compiler
authorhaftmann
Fri, 24 Apr 2009 17:45:16 +0200
changeset 30972 5b65835ccc92
parent 30971 7fbebf75b3ef
child 30973 304ab57afa6e
some experiements towards user interface for predicate compiler
src/HOL/ex/Predicate_Compile.thy
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/Predicate_Compile.thy	Fri Apr 24 17:45:15 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile.thy	Fri Apr 24 17:45:16 2009 +0200
@@ -1,20 +1,17 @@
 theory Predicate_Compile
-imports Complex_Main Code_Index Lattice_Syntax
+imports Complex_Main Lattice_Syntax Code_Eval
 uses "predicate_compile.ML"
 begin
 
+text {* Package setup *}
+
 setup {* Predicate_Compile.setup *}
 
-primrec "next" :: "('a Predicate.pred \<Rightarrow> ('a \<times> 'a Predicate.pred) option)
-  \<Rightarrow> 'a Predicate.seq \<Rightarrow> ('a \<times> 'a Predicate.pred) option" where
-    "next yield Predicate.Empty = None"
-  | "next yield (Predicate.Insert x P) = Some (x, P)"
-  | "next yield (Predicate.Join P xq) = (case yield P
-   of None \<Rightarrow> next yield xq | Some (x, Q) \<Rightarrow> Some (x, Predicate.Seq (\<lambda>_. Predicate.Join Q xq)))"
+
+text {* Experimental code *}
 
-fun anamorph :: "('b \<Rightarrow> ('a \<times> 'b) option) \<Rightarrow> index \<Rightarrow> 'b \<Rightarrow> 'a list \<times> 'b" where
-  "anamorph f k x = (if k = 0 then ([], x)
-    else case f x of None \<Rightarrow> ([], x) | Some (v, y) \<Rightarrow> let (vs, z) = anamorph f (k - 1) y in (v # vs, z))"
+definition pred_map :: "('a \<Rightarrow> 'b) \<Rightarrow> 'a Predicate.pred \<Rightarrow> 'b Predicate.pred" where
+  "pred_map f P = Predicate.bind P (Predicate.single o f)"
 
 ML {*
 structure Predicate =
@@ -22,11 +19,68 @@
 
 open Predicate;
 
-fun yield (Predicate.Seq f) = @{code next} yield (f ());
+val pred_ref = ref (NONE : (unit -> term Predicate.pred) option);
+
+fun eval_pred thy t =
+  t 
+  |> Eval.mk_term_of (fastype_of t)
+  |> (fn t => Code_ML.eval NONE ("Predicate.pred_ref", pred_ref) @{code pred_map} thy t []);
+
+fun eval_pred_elems thy t T length =
+  t |> eval_pred thy |> yieldn length |> fst |> HOLogic.mk_list T;
 
-fun yieldn k = @{code anamorph} yield k;
+fun analyze_compr thy t =
+  let
+    val split = case t of (Const (@{const_name Collect}, _) $ t') => t'
+      | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t);
+    val (body, Ts, fp) = HOLogic.strip_split split;
+    val (t_pred, args) = strip_comb body;
+    val pred = case t_pred of Const (pred, _) => pred
+      | _ => error ("Not a constant: " ^ Syntax.string_of_term_global thy t_pred);
+    val mode = map is_Bound args; (*FIXME what about higher-order modes?*)
+    val args' = filter_out is_Bound args;
+    val T = HOLogic.mk_tupleT fp Ts;
+    val mk = HOLogic.mk_tuple' fp T;
+  in (((pred, mode), args), (mk, T)) end;
 
 end;
 *}
 
+
+text {* Example(s) *}
+
+inductive even :: "nat \<Rightarrow> bool" and odd :: "nat \<Rightarrow> bool" where
+    "even 0"
+  | "even n \<Longrightarrow> odd (Suc n)"
+  | "odd n \<Longrightarrow> even (Suc n)"
+
+setup {* pred_compile "even" *}
+thm even_codegen
+
+
+inductive append :: "'a list \<Rightarrow> 'a list \<Rightarrow> 'a list \<Rightarrow> bool" where
+    append_Nil: "append [] xs xs"
+  | append_Cons: "append xs ys zs \<Longrightarrow> append (x # xs) ys (x # zs)"
+
+setup {* pred_compile "append" *}
+thm append_codegen
+
+
+inductive partition :: "('a \<Rightarrow> bool) \<Rightarrow> 'a list \<Rightarrow> 'a list \<Rightarrow> 'a list \<Rightarrow> bool"
+  for f where
+    "partition f [] [] []"
+  | "f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) (x # ys) zs"
+  | "\<not> f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) ys (x # zs)"
+
+setup {* pred_compile "partition" *}
+thm partition_codegen
+
+setup {* pred_compile "tranclp" *}
+thm tranclp_codegen
+
+ML_val {* Predicate_Compile.modes_of @{theory} @{const_name partition} *}
+ML_val {* Predicate_Compile.modes_of @{theory} @{const_name tranclp} *}
+
+ML_val {* Predicate.analyze_compr @{theory} @{term "{n. odd n}"} *}
+
 end
\ No newline at end of file
--- a/src/HOL/ex/predicate_compile.ML	Fri Apr 24 17:45:15 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Fri Apr 24 17:45:16 2009 +0200
@@ -6,38 +6,119 @@
 
 signature PREDICATE_COMPILE =
 sig
-  val create_def_equation': string -> (int list option list * int list) option -> theory -> theory
+  type mode = int list option list * int list
+  val create_def_equation': string -> mode option -> theory -> theory
   val create_def_equation: string -> theory -> theory
-  val intro_rule: theory -> string -> (int list option list * int list) -> thm
-  val elim_rule: theory -> string -> (int list option list * int list) -> thm
+  val intro_rule: theory -> string -> mode -> thm
+  val elim_rule: theory -> string -> mode -> thm
   val strip_intro_concl : term -> int -> (term * (term list * term list))
   val code_ind_intros_attrib : attribute
   val code_ind_cases_attrib : attribute
+  val print_alternative_rules : theory -> theory
+  val modename_of: theory -> string -> mode -> string
+  val modes_of: theory -> string -> mode list
   val setup : theory -> theory
-  val print_alternative_rules : theory -> theory
   val do_proofs: bool ref
 end;
 
 structure Predicate_Compile: PREDICATE_COMPILE =
 struct
 
+(** auxiliary **)
+
+(* debug stuff *)
+
+fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
+
+fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
+fun debug_tac msg = (fn st => (tracing msg; Seq.single st));
+
+val do_proofs = ref true;
+
+
+(** fundamentals **)
+
+(* syntactic operations *)
+
+fun mk_eq (x, xs) =
+  let fun mk_eqs _ [] = []
+        | mk_eqs a (b::cs) =
+            HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
+  in mk_eqs x xs end;
+
+fun mk_tupleT [] = HOLogic.unitT
+  | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
+
+fun mk_tuple [] = HOLogic.unit
+  | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
+
+fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
+  | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
+  | dest_tuple t = [t]
+
+fun mk_pred_enumT T = Type ("Predicate.pred", [T])
+
+fun dest_pred_enumT (Type ("Predicate.pred", [T])) = T
+  | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
+
+fun mk_Enum f =
+  let val T as Type ("fun", [T', _]) = fastype_of f
+  in
+    Const (@{const_name Predicate.Pred}, T --> mk_pred_enumT T') $ f    
+  end;
+
+fun mk_Eval (f, x) =
+  let val T = fastype_of x
+  in
+    Const (@{const_name Predicate.eval}, mk_pred_enumT T --> T --> HOLogic.boolT) $ f $ x
+  end;
+
+fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);
+
+fun mk_single t =
+  let val T = fastype_of t
+  in Const(@{const_name Predicate.single}, T --> mk_pred_enumT T) $ t end;
+
+fun mk_bind (x, f) =
+  let val T as Type ("fun", [_, U]) = fastype_of f
+  in
+    Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
+  end;
+
+val mk_sup = HOLogic.mk_binop @{const_name sup};
+
+fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
+  HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond;
+
+fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
+  in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
+
+
+(* data structures *)
+
+type mode = int list option list * int list;
+
+val mode_ord = prod_ord (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord);
+
 structure PredModetab = TableFun(
-  type key = (string * (int list option list * int list))
-  val ord = prod_ord fast_string_ord (prod_ord
-            (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord)))
+  type key = string * mode
+  val ord = prod_ord fast_string_ord mode_ord
+);
 
 
+(*FIXME scrap boilerplate*)
+
 structure IndCodegenData = TheoryDataFun
 (
   type T = {names : string PredModetab.table,
-            modes : ((int list option list * int list) list) Symtab.table,
+            modes : mode list Symtab.table,
             function_defs : Thm.thm Symtab.table,
             function_intros : Thm.thm Symtab.table,
             function_elims : Thm.thm Symtab.table,
-            intro_rules : (Thm.thm list) Symtab.table,
+            intro_rules : Thm.thm list Symtab.table,
             elim_rules : Thm.thm Symtab.table,
             nparams : int Symtab.table
-           };
+           }; (*FIXME: better group tables according to key*)
       (* names: map from inductive predicate and mode to function name (string).
          modes: map from inductive predicates to modes
          function_defs: map from function name to definition
@@ -115,26 +196,12 @@
             intro_rules = #intro_rules x, elim_rules = #elim_rules x,
             nparams = f (#nparams x)}) thy
 
-(* Debug stuff and tactics ***********************************************************)
-
-fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
-fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single);
-
-fun debug_tac msg = (fn st =>
-     (tracing msg; Seq.single st));
-
 (* removes first subgoal *)
 fun mycheat_tac thy i st =
   (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st
 
-val (do_proofs : bool ref) = ref true;
-
 (* Lightweight mode analysis **********************************************)
 
-(* Hack for message from old code generator *)
-val message = tracing;
-
-
 (**************************************************************************)
 (* source code from old code generator ************************************)
 
@@ -153,7 +220,8 @@
       | _ => false)
   in check end;
 
-(**** check if a type is an equality type (i.e. doesn't contain fun) ****)
+(**** check if a type is an equality type (i.e. doesn't contain fun)
+  FIXME this is only an approximation ****)
 
 fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts
   | is_eqT _ = true;
@@ -165,7 +233,7 @@
     | SOME js => enclose "[" "]" (commas (map string_of_int js)))
        (iss @ [SOME is]));
 
-fun print_modes modes = message ("Inferred modes:\n" ^
+fun print_modes modes = tracing ("Inferred modes:\n" ^
   cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
     string_of_mode ms)) modes));
 
@@ -182,6 +250,7 @@
         (get_args' is (i+1) ts)
 in get_args' is 1 ts end
 
+(*FIXME this function should not be named merge... make it local instead*)
 fun merge xs [] = xs
   | merge [] ys = ys
   | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys)
@@ -197,7 +266,8 @@
 
 fun cprods xss = foldr (map op :: o cprod) [[]] xss;
 
-datatype mode = Mode of (int list option list * int list) * int list * mode option list;
+datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand
+  why there is another mode type!?*)
 
 fun modes_of modes t =
   let
@@ -285,11 +355,11 @@
   in (p, List.filter (fn m => case find_index
     (not o check_mode_clause thy param_vs modes m) rs of
       ~1 => true
-    | i => (message ("Clause " ^ string_of_int (i+1) ^ " of " ^
+    | i => (tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
       p ^ " violates mode " ^ string_of_mode m); false)) ms)
   end;
 
-fun fixp f (x : (string * (int list option list * int list) list) list) =
+fun fixp f (x : (string * mode list) list) =
   let val y = f x
   in if x = y then x else fixp f y end;
 
@@ -306,66 +376,6 @@
 (*****************************************************************************************)
 (**** term construction ****)
 
-fun mk_eq (x, xs) =
-  let fun mk_eqs _ [] = []
-        | mk_eqs a (b::cs) =
-            HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
-  in mk_eqs x xs end;
-
-fun mk_tuple [] = HOLogic.unit
-  | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
-
-fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
-  | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
-  | dest_tuple t = [t]
-
-fun mk_tupleT [] = HOLogic.unitT
-  | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
-
-fun mk_pred_enumT T = Type ("Predicate.pred", [T])
-
-fun dest_pred_enumT (Type ("Predicate.pred", [T])) = T
-  | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
-
-fun mk_single t =
-  let val T = fastype_of t
-  in Const(@{const_name Predicate.single}, T --> mk_pred_enumT T) $ t end;
-
-fun mk_empty T = Const (@{const_name Orderings.bot}, mk_pred_enumT T);
-
-fun mk_if_predenum cond = Const (@{const_name Predicate.if_pred},
-                          HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) 
-                         $ cond
-
-fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
-  in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
-
-fun mk_bind (x, f) =
-  let val T as Type ("fun", [_, U]) = fastype_of f
-  in
-    Const (@{const_name Predicate.bind}, fastype_of x --> T --> U) $ x $ f
-  end;
-
-fun mk_Enum f =
-  let val T as Type ("fun", [T', _]) = fastype_of f
-  in
-    Const (@{const_name Predicate.Pred}, T --> mk_pred_enumT T') $ f    
-  end;
-
-fun mk_Eval (f, x) =
-  let val T = fastype_of x
-  in
-    Const (@{const_name Predicate.eval}, mk_pred_enumT T --> T --> HOLogic.boolT) $ f $ x
-  end;
-
-fun mk_Eval' f =
-  let val T = fastype_of f
-  in
-    Const (@{const_name Predicate.eval}, T --> dest_pred_enumT T --> HOLogic.boolT) $ f
-  end; 
-
-val mk_sup = HOLogic.mk_binop @{const_name sup};
-
 (* for simple modes (e.g. parameters) only: better call it param_funT *)
 (* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
 fun funT_of T NONE = T
@@ -424,13 +434,16 @@
        (v', mk_empty U')]))
   end;
 
-fun modename thy name mode = let
+fun modename_of thy name mode = let
     val v = (PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode))
-  in if (is_some v) then the v
-     else error ("fun modename - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode))
+  in if (is_some v) then the v (*FIXME use case here*)
+     else error ("fun modename_of - definition not found: name: " ^ name ^ " mode: " ^  (makestring mode))
   end
 
-(* function can be removed *)
+fun modes_of thy =
+  these o Symtab.lookup ((#modes o IndCodegenData.get) thy);
+
+(*FIXME function can be removed*)
 fun mk_funcomp f t =
   let
     val names = Term.add_free_names t [];
@@ -449,7 +462,7 @@
     val f' = case f of
         Const (name, T) =>
           if AList.defined op = modes name then
-            Const (modename thy name (iss, is'), funT'_of (iss, is') T)
+            Const (modename_of thy name (iss, is'), funT'_of (iss, is') T)
           else error "compile param: Not an inductive predicate with correct mode"
       | Free (name, T) => Free (name, funT_of T (SOME is'))
     in list_comb (f', params' @ args') end
@@ -463,7 +476,7 @@
                val (Ts, Us) = get_args is
                  (curry Library.drop (length ms) (fst (strip_type T)))
                val params' = map (compile_param thy modes) (ms ~~ params)
-               val mode_id = modename thy name mode
+               val mode_id = modename_of thy name mode
              in list_comb (Const (mode_id, ((map fastype_of params') @ Ts) --->
                mk_pred_enumT (mk_tupleT Us)), params')
              end
@@ -556,7 +569,7 @@
     val cl_ts =
       map (fn cl => compile_clause thy
         all_vs param_vs modes mode cl (mk_tuple xs)) cls;
-    val mode_id = modename thy s mode
+    val mode_id = modename_of thy s mode
   in
     HOLogic.mk_Trueprop (HOLogic.mk_eq
       (list_comb (Const (mode_id, (Ts1' @ Us1) --->
@@ -591,7 +604,7 @@
     fold Term.add_consts intrs [] |> map fst
     |> filter_out (member (op =) preds) |> filter (is_ind_pred thy)
 
-fun print_arities arities = message ("Arities:\n" ^
+fun print_arities arities = tracing ("Arities:\n" ^
   cat_lines (map (fn (s, (ks, k)) => s ^ ": " ^
     space_implode " -> " (map
       (fn NONE => "X" | SOME k' => string_of_int k')
@@ -691,10 +704,10 @@
 (* Proving equivalence of term *)
 
 
-fun intro_rule thy pred mode = modename thy pred mode
+fun intro_rule thy pred mode = modename_of thy pred mode
     |> Symtab.lookup (#function_intros (IndCodegenData.get thy)) |> the
 
-fun elim_rule thy pred mode = modename thy pred mode
+fun elim_rule thy pred mode = modename_of thy pred mode
     |> Symtab.lookup (#function_elims (IndCodegenData.get thy)) |> the
 
 fun pred_intros thy predname = let
@@ -711,7 +724,7 @@
   end
 
 fun function_definition thy pred mode =
-  modename thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the
+  modename_of thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the
 
 fun is_Type (Type _) = true
   | is_Type _ = false
@@ -973,7 +986,7 @@
     in nth (#elims (snd ind_result)) index end)
 
 fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let
-  val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename thy pred mode))
+  val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename_of thy pred mode))
 (*  val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred
   val index = find_index (fn s => s = pred) (#names (fst ind_result))
   val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *)
@@ -1225,7 +1238,7 @@
 (* main function *********************************************************************)
 (*************************************************************************************)
 
-fun create_def_equation' ind_name (mode : (int list option list * int list) option) thy =
+fun create_def_equation' ind_name (mode : mode option) thy =
 let
   val _ = tracing ("starting create_def_equation' with " ^ ind_name)
   val (prednames, preds) = 
@@ -1249,6 +1262,7 @@
   val _ = tracing ("calling preds: " ^ makestring name_of_calls)
   val _ = tracing "starting recursive compilations"
   fun rec_call name thy = 
+    (*FIXME use member instead of infix mem*)
     if not (name mem (Symtab.keys (#modes (IndCodegenData.get thy)))) then
       create_def_equation name thy else thy
   val thy'' = fold rec_call name_of_calls thy'