src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
changeset 35324 c9f428269b38
parent 35021 c839a4c670c6
child 35411 cafb74a131da
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Tue Feb 23 10:02:14 2010 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML	Tue Feb 23 13:36:15 2010 +0100
@@ -9,6 +9,8 @@
   val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
   val rewrite_intro : theory -> thm -> thm list
   val pred_of_function : theory -> string -> string option
+  
+  val add_function_predicate_translation : (term * term) -> theory -> theory
 end;
 
 structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN =
@@ -16,19 +18,36 @@
 
 open Predicate_Compile_Aux;
 
-(* Table from constant name (string) to term of inductive predicate *)
-structure Pred_Compile_Preproc = Theory_Data
+(* Table from function to inductive predicate *)
+structure Fun_Pred = Theory_Data
 (
-  type T = string Symtab.table;
-  val empty = Symtab.empty;
+  type T = (term * term) Item_Net.T;
+  val empty = Item_Net.init (op aconv o pairself fst) (single o fst);
   val extend = I;
-  fun merge data : T = Symtab.merge (op =) data;   (* FIXME handle Symtab.DUP ?? *)
+  val merge = Item_Net.merge;
 )
 
-fun pred_of_function thy name = Symtab.lookup (Pred_Compile_Preproc.get thy) name
+fun lookup thy net t =
+  case Item_Net.retrieve net t of
+    [] => NONE
+  | [(f, p)] =>
+    let
+      val subst = Pattern.match thy (f, t) (Vartab.empty, Vartab.empty)
+    in
+      SOME (Envir.subst_term subst p)
+    end
+  | _ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t)
 
-fun defined thy = Symtab.defined (Pred_Compile_Preproc.get thy) 
+fun pred_of_function thy name =
+  case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT)) of
+    [] => NONE
+  | [(f, p)] => SOME (fst (dest_Const p))
+  | _ => error ("Multiple matches possible for lookup of constant " ^ name)
 
+fun defined_const thy name = is_some (pred_of_function thy name)
+
+fun add_function_predicate_translation (f, p) =
+  Fun_Pred.map (Item_Net.update (f, p))
 
 fun transform_ho_typ (T as Type ("fun", _)) =
   let
@@ -63,27 +82,6 @@
       (Free (Long_Name.base_name name ^ "P", pred_type T))
   end
 
-fun mk_param thy lookup_pred (t as Free (v, _)) = lookup_pred t
-  | mk_param thy lookup_pred t =
-  if Predicate_Compile_Aux.is_predT (fastype_of t) then
-    t
-  else
-    let
-      val (vs, body) = strip_abs t
-      val names = Term.add_free_names body []
-      val vs_names = Name.variant_list names (map fst vs)
-      val vs' = map2 (curry Free) vs_names (map snd vs)
-      val body' = subst_bounds (rev vs', body)
-      val (f, args) = strip_comb body'
-      val resname = Name.variant (vs_names @ names) "res"
-      val resvar = Free (resname, body_type (fastype_of body'))
-      (*val P = case try lookup_pred f of SOME P => P | NONE => error "mk_param"
-      val pred_body = list_comb (P, args @ [resvar])
-      *)
-      val pred_body = HOLogic.mk_eq (body', resvar)
-      val param = fold_rev lambda (vs' @ [resvar]) pred_body
-    in param end
-    
 (* creates the list of premises for every intro rule *)
 (* theory -> term -> (string list, term list list) *)
 
@@ -92,22 +90,6 @@
   val (func, args) = strip_comb lhs
 in ((func, args), rhs) end;
 
-fun string_of_typ T = Syntax.string_of_typ_global @{theory} T
-
-fun string_of_term t =
-  case t of
-    Const (c, T) => "Const (" ^ c ^ ", " ^ string_of_typ T ^ ")"
-  | Free (c, T) => "Free (" ^ c ^ ", " ^ string_of_typ T ^ ")"
-  | Var ((c, i), T) => "Var ((" ^ c ^ ", " ^ string_of_int i ^ "), " ^ string_of_typ T ^ ")"
-  | Bound i => "Bound " ^ string_of_int i
-  | Abs (x, T, t) => "Abs (" ^ x ^ ", " ^ string_of_typ T ^ ", " ^ string_of_term t ^ ")"
-  | t1 $ t2 => "(" ^ string_of_term t1 ^ ") $ (" ^ string_of_term t2 ^ ")"
-  
-fun ind_package_get_nparams thy name =
-  case try (Inductive.the_inductive (ProofContext.init thy)) name of
-    SOME (_, result) => length (Inductive.params_of (#raw_induct result))
-  | NONE => error ("No such predicate: " ^ quote name) 
-
 (* TODO: does not work with higher order functions yet *)
 fun mk_rewr_eq (func, pred) =
   let
@@ -122,49 +104,6 @@
       (HOLogic.mk_eq (res, list_comb (func, args)), list_comb (pred, args @ [res]))
   end;
 
-fun has_split_rule_cname @{const_name "nat_case"} = true
-  | has_split_rule_cname @{const_name "list_case"} = true
-  | has_split_rule_cname _ = false
-  
-fun has_split_rule_term thy (Const (@{const_name "nat_case"}, _)) = true 
-  | has_split_rule_term thy (Const (@{const_name "list_case"}, _)) = true 
-  | has_split_rule_term thy _ = false
-
-fun has_split_rule_term' thy (Const (@{const_name "If"}, _)) = true
-  | has_split_rule_term' thy (Const (@{const_name "Let"}, _)) = true
-  | has_split_rule_term' thy c = has_split_rule_term thy c
-  
-fun prepare_split_thm ctxt split_thm =
-    (split_thm RS @{thm iffD2})
-    |> LocalDefs.unfold ctxt [@{thm atomize_conjL[symmetric]},
-      @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
-
-fun find_split_thm thy (Const (name, typ)) =
-  let
-    fun split_name str =
-      case first_field "." str
-        of (SOME (field, rest)) => field :: split_name rest
-         | NONE => [str]
-    val splitted_name = split_name name
-  in
-    if length splitted_name > 0 andalso
-       String.isSuffix "_case" (List.last splitted_name)
-    then
-      (List.take (splitted_name, length splitted_name - 1)) @ ["split"]
-      |> space_implode "."
-      |> PureThy.get_thm thy
-      |> SOME
-      handle ERROR msg => NONE
-    else NONE
-  end
-  | find_split_thm _ _ = NONE
-
-fun find_split_thm' thy (Const (@{const_name "If"}, _)) = SOME @{thm split_if}
-  | find_split_thm' thy (Const (@{const_name "Let"}, _)) = SOME @{thm refl} (* TODO *)
-  | find_split_thm' thy c = find_split_thm thy c
-
-fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
-
 fun folds_map f xs y =
   let
     fun folds_map' acc [] y = [(rev acc, y)]
@@ -174,23 +113,91 @@
       folds_map' [] xs y
     end;
 
-fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) =
+fun keep_functions thy t =
+  case try dest_Const (fst (strip_comb t)) of
+    SOME (c, _) => Predicate_Compile_Data.keep_function thy c
+  | _ => false
+
+fun mk_prems thy lookup_pred t (names, prems) =
   let
     fun mk_prems' (t as Const (name, T)) (names, prems) =
-      if is_constr thy name orelse (is_none (try lookup_pred t)) then
+      (if is_constr thy name orelse (is_none (lookup_pred t)) then
         [(t, (names, prems))]
-      else [(lookup_pred t, (names, prems))]
+      else
+       (*(if is_none (try lookup_pred t) then
+          [(Abs ("uu", fastype_of t, HOLogic.mk_eq (t, Bound 0)), (names, prems))]
+        else*) [(the (lookup_pred t), (names, prems))])
     | mk_prems' (t as Free (f, T)) (names, prems) = 
-      [(lookup_pred t, (names, prems))]
+      (case lookup_pred t of
+        SOME t' => [(t', (names, prems))]
+      | NONE => [(t, (names, prems))])
     | mk_prems' (t as Abs _) (names, prems) =
       if Predicate_Compile_Aux.is_predT (fastype_of t) then
-      [(t, (names, prems))] else error "mk_prems': Abs "
-      (* mk_param *)
+        ([(Envir.eta_contract t, (names, prems))])
+      else
+        let
+          val (vars, body) = strip_abs t
+          val _ = assert (fastype_of body = body_type (fastype_of body))
+          val absnames = Name.variant_list names (map fst vars)
+          val frees = map2 (curry Free) absnames (map snd vars)
+          val body' = subst_bounds (rev frees, body)
+          val resname = Name.variant (absnames @ names) "res"
+          val resvar = Free (resname, fastype_of body)
+          val t = mk_prems' body' ([], [])
+            |> map (fn (res, (inner_names, inner_prems)) =>
+              let
+                fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t)
+                val vTs = 
+                  fold Term.add_frees inner_prems []
+                  |> filter (fn (x, T) => member (op =) inner_names x)
+                val t = 
+                  fold mk_exists vTs
+                  (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (resvar, res) ::
+                    map HOLogic.dest_Trueprop inner_prems))
+              in
+                t
+              end)
+              |> foldr1 HOLogic.mk_disj
+              |> fold lambda (resvar :: rev frees)
+        in
+          [(t, (names, prems))]
+        end
     | mk_prems' t (names, prems) =
-      if Predicate_Compile_Aux.is_constrt thy t then
+      if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then
         [(t, (names, prems))]
       else
-        if has_split_rule_term' thy (fst (strip_comb t)) then
+        case (fst (strip_comb t)) of
+          Const (@{const_name "If"}, _) =>
+            (let
+              val (_, [B, x, y]) = strip_comb t
+            in
+              (mk_prems' x (names, prems)
+              |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B) :: prems))))
+              @ (mk_prems' y (names, prems)
+              |> map (fn (res, (names, prems)) =>
+                (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B)) :: prems))))
+            end)
+        | Const (@{const_name "Let"}, _) => 
+            (let
+              val (_, [f, g]) = strip_comb t
+            in
+              mk_prems' f (names, prems)
+              |> maps (fn (res, (names, prems)) =>
+                mk_prems' (betapply (g, res)) (names, prems))
+            end)
+        | Const (@{const_name "split"}, _) => 
+            (let
+              val (_, [g, res]) = strip_comb t
+              val [res1, res2] = Name.variant_list names ["res1", "res2"]
+              val (T1, T2) = HOLogic.dest_prodT (fastype_of res)
+              val (resv1, resv2) = (Free (res1, T1), Free (res2, T2))
+            in
+              mk_prems' (betapplys (g, [resv1, resv2]))
+              (res1 :: res2 :: names,
+              HOLogic.mk_Trueprop (HOLogic.mk_eq (res, HOLogic.mk_prod (resv1, resv2))) :: prems)
+            end)
+        | _ =>
+        if has_split_thm thy (fst (strip_comb t)) then
           let
             val (f, args) = strip_comb t
             val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
@@ -208,8 +215,15 @@
                 val vars = map Free (var_names ~~ (map snd vTs))
                 val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
                 val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
+                val (lhss : term list, rhss) =
+                  split_list (map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems')
               in
-                mk_prems' inner_t (var_names @ names, prems' @ prems)
+                folds_map mk_prems' lhss (var_names @ names, prems)
+                |> map (fn (ress, (names, prems)) =>
+                  let
+                    val prems' = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (ress ~~ rhss)
+                  in (names, prems' @ prems) end)
+                |> maps (mk_prems' inner_t)
               end
           in
             maps mk_prems_of_assm assms
@@ -219,53 +233,77 @@
             val (f, args) = strip_comb t
             (* TODO: special procedure for higher-order functions: split arguments in
               simple types and function types *)
+            val args = map (Pattern.eta_long []) args
             val resname = Name.variant names "res"
             val resvar = Free (resname, body_type (fastype_of t))
+            val _ = assert (fastype_of t = body_type (fastype_of t))
             val names' = resname :: names
             fun mk_prems'' (t as Const (c, _)) =
-              if is_constr thy c orelse (is_none (try lookup_pred t)) then
+              if is_constr thy c orelse (is_none (lookup_pred t)) then
+                let
+                  val _ = ()(*tracing ("not translating function " ^ Syntax.string_of_term_global thy t)*)
+                in
                 folds_map mk_prems' args (names', prems) |>
                 map
                   (fn (argvs, (names'', prems')) =>
                   let
                     val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
                   in (names'', prem :: prems') end)
+                end
               else
                 let
-                  val pred = lookup_pred t
-                  val nparams = get_nparams pred
-                  val (params, args) = chop nparams args
-                  val params' = map (mk_param thy lookup_pred) params
+                  (* lookup_pred is falsch für polymorphe Argumente und bool. *)
+                  val pred = the (lookup_pred t)
+                  val Ts = binder_types (fastype_of pred)
                 in
                   folds_map mk_prems' args (names', prems)
                   |> map (fn (argvs, (names'', prems')) =>
                     let
-                      val prem = HOLogic.mk_Trueprop (list_comb (pred, params' @ argvs @ [resvar]))
+                      fun lift_arg T t =
+                        if (fastype_of t) = T then t
+                        else
+                          let
+                            val _ = assert (T =
+                              (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool}))
+                            fun mk_if T (b, t, e) =
+                              Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
+                            val Ts = binder_types (fastype_of t)
+                            val t = 
+                            list_abs (map (pair "x") Ts @ [("b", @{typ bool})],
+                              mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
+                              HOLogic.mk_eq (@{term True}, Bound 0),
+                              HOLogic.mk_eq (@{term False}, Bound 0)))
+                          in
+                            t
+                          end
+                      (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts))
+                      val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*)
+                      val argvs' = map2 lift_arg (fst (split_last Ts)) argvs
+                      val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar]))
                     in (names'', prem :: prems') end)
                 end
             | mk_prems'' (t as Free (_, _)) =
-                let
-                  (* higher order argument call *)
-                  val pred = lookup_pred t
-                in
-                  folds_map mk_prems' args (resname :: names, prems)
-                  |> map (fn (argvs, (names', prems')) =>
-                     let
-                       val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs @ [resvar]))
-                     in (names', prem :: prems') end)
-                end
+              folds_map mk_prems' args (names', prems) |>
+                map
+                  (fn (argvs, (names'', prems')) =>
+                  let
+                    val prem = 
+                      case lookup_pred t of
+                        NONE => HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
+                      | SOME p => HOLogic.mk_Trueprop (list_comb (p, argvs @ [resvar]))
+                  in (names'', prem :: prems') end)
             | mk_prems'' t =
               error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
           in
             map (pair resvar) (mk_prems'' f)
           end
   in
-    mk_prems' t (names, prems)
+    mk_prems' (Pattern.eta_long [] t) (names, prems)
   end;
 
 (* assumption: mutual recursive predicates all have the same parameters. *)  
 fun define_predicates specs thy =
-  if forall (fn (const, _) => member (op =) (Symtab.keys (Pred_Compile_Preproc.get thy)) const) specs then
+  if forall (fn (const, _) => defined_const thy const) specs then
     ([], thy)
   else
   let
@@ -275,36 +313,20 @@
       (* create prednames *)
     val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
     val argss' = map (map transform_ho_arg) argss
-    val pnames = map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss'))
+    (* TODO: higher order arguments also occur in tuples! *)
+    val ho_argss = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss)
+    val params = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss')
+    val pnames = map dest_Free params
     val preds = map pred_of funs
     val prednames = map (fst o dest_Free) preds
     val funnames = map (fst o dest_Const) funs
     val fun_pred_names = (funnames ~~ prednames)  
       (* mapping from term (Free or Const) to term *)
-    fun lookup_pred (Const (name, T)) =
-      (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
-          SOME c => Const (c, pred_type T)
-        | NONE =>
-          (case AList.lookup op = fun_pred_names name of
-            SOME f => Free (f, pred_type T)
-          | NONE => Const (name, T)))
-      | lookup_pred (Free (name, T)) =
-        if member op = (map fst pnames) name then
-          Free (name, transform_ho_typ T)
-        else
-          Free (name, T)
-      | lookup_pred t =
-         error ("lookup function is not defined for " ^ Syntax.string_of_term_global thy t)
-     
-        (* mapping from term (predicate term, not function term!) to int *)
-    fun get_nparams (Const (name, _)) =
-      the_default 0 (try (ind_package_get_nparams thy) name)
-    | get_nparams (Free (name, _)) =
-        (if member op = prednames name then
-          length pnames
-        else 0)
-    | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
-  
+    fun map_Free f = Free o f o dest_Free
+    val net = fold Item_Net.update
+      ((funs ~~ preds) @ (ho_argss ~~ params))
+        (Fun_Pred.get thy)
+    fun lookup_pred t = lookup thy net t
     (* create intro rules *)
   
     fun mk_intros ((func, pred), (args, rhs)) =
@@ -314,14 +336,15 @@
       else
         let
           val names = Term.add_free_names rhs []
-        in mk_prems thy (lookup_pred, get_nparams) rhs (names, [])
+        in mk_prems thy lookup_pred rhs (names, [])
           |> map (fn (resultt, (names', prems)) =>
             Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
         end
     fun mk_rewr_thm (func, pred) = @{thm refl}
   in
-    case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
-      NONE => ([], thy) 
+    case (*try *)SOME (maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))) of
+      NONE =>
+        let val _ = tracing "error occured!" in ([], thy) end
     | SOME intr_ts =>
         if is_some (try (map (cterm_of thy)) intr_ts) then
           let
@@ -333,53 +356,59 @@
                   no_elim = false, no_ind = false, skip_mono = false, fork_mono = false}
                 (map (fn (s, T) =>
                   ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
-                pnames
+                []
                 (map (fn x => (Attrib.empty_binding, x)) intr_ts)
                 []
               ||> Sign.restore_naming thy
             val prednames = map (fst o dest_Const) (#preds ind_result)
             (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
             (* add constants to my table *)
+            
             val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname)
               (#intrs ind_result))) prednames
+            (*
             val thy'' = Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy'
+            *)
+            
+            val thy'' = Fun_Pred.map
+              (fold Item_Net.update (map (apfst Logic.varify)
+                (distinct (op =) funs ~~ (#preds ind_result)))) thy'
+            (*val _ = print_specs thy'' specs*)
           in
             (specs, thy'')
           end
         else
           let
-            val _ = tracing "Introduction rules of function_predicate are not welltyped"
+            val _ = Output.tracing (
+            "Introduction rules of function_predicate are not welltyped: " ^
+              commas (map (Syntax.string_of_term_global thy) intr_ts))
           in ([], thy) end
   end
 
 fun rewrite_intro thy intro =
   let
-    fun lookup_pred (Const (name, T)) =
+    (*val _ = tracing ("Rewriting intro with registered mapping for: " ^
+      commas (Symtab.keys (Pred_Compile_Preproc.get thy)))*)
+    (*fun lookup_pred (Const (name, T)) =
       (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
-        SOME c => Const (c, pred_type T)
-      | NONE => error ("Function " ^ name ^ " is not inductified"))
-    | lookup_pred (Free (name, T)) = Free (name, T)
-    | lookup_pred _ = error "lookup function is not defined!"
-
-    fun get_nparams (Const (name, _)) =
-      the_default 0 (try (ind_package_get_nparams thy) name)
-    | get_nparams (Free _) = 0
-    | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
-    
+        SOME c => SOME (Const (c, pred_type T))
+      | NONE => NONE)
+    | lookup_pred _ = NONE
+    *)
+    fun lookup_pred t = lookup thy (Fun_Pred.get thy) t
     val intro_t = (Logic.unvarify o prop_of) intro
     val (prems, concl) = Logic.strip_horn intro_t
     val frees = map fst (Term.add_frees intro_t [])
     fun rewrite prem names =
       let
+        (*val _ = tracing ("Rewriting premise " ^ Syntax.string_of_term_global thy prem ^ "...")*)
         val t = (HOLogic.dest_Trueprop prem)
         val (lit, mk_lit) = case try HOLogic.dest_not t of
             SOME t => (t, HOLogic.mk_not)
           | NONE => (t, I)
-        val (P, args) = (strip_comb lit) 
+        val (P, args) = (strip_comb lit)
       in
-        folds_map (
-          fn t => if (is_funtype (fastype_of t)) then (fn x => [(t, x)])
-            else mk_prems thy (lookup_pred, get_nparams) t) args (names, [])
+        folds_map (mk_prems thy lookup_pred) args (names, [])
         |> map (fn (resargs, (names', prems')) =>
           let
             val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))