src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
author wenzelm
Fri May 24 17:00:46 2013 +0200 (2013-05-24)
changeset 52131 366fa32ee2a3
parent 51317 0e70cc4e94e8
child 54229 ca638d713ff8
permissions -rw-r--r--
tuned signature;
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Preprocessing functions to predicates.
     5 *)
     6 
     7 signature PREDICATE_COMPILE_FUN =
     8 sig
     9   val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
    10   val rewrite_intro : theory -> thm -> thm list
    11   val pred_of_function : theory -> string -> string option
    12   val add_function_predicate_translation : (term * term) -> theory -> theory
    13 end;
    14 
    15 structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN =
    16 struct
    17 
    18 open Predicate_Compile_Aux;
    19 
    20 (* Table from function to inductive predicate *)
    21 structure Fun_Pred = Theory_Data
    22 (
    23   type T = (term * term) Item_Net.T;
    24   val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst);
    25   val extend = I;
    26   val merge = Item_Net.merge;
    27 )
    28 
    29 fun lookup thy net t =
    30   let
    31     val poss_preds = map_filter (fn (f, p) =>
    32     SOME (Envir.subst_term (Pattern.match thy (f, t) (Vartab.empty, Vartab.empty)) p)
    33     handle Pattern.MATCH => NONE) (Item_Net.retrieve net t)
    34   in
    35     case poss_preds of
    36       [p] => SOME p
    37     | _ => NONE
    38   end
    39 
    40 fun pred_of_function thy name =
    41   case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, dummyT)) of
    42     [] => NONE
    43   | [(_, p)] => SOME (fst (dest_Const p))
    44   | _ => error ("Multiple matches possible for lookup of constant " ^ name)
    45 
    46 fun defined_const thy name = is_some (pred_of_function thy name)
    47 
    48 fun add_function_predicate_translation (f, p) =
    49   Fun_Pred.map (Item_Net.update (f, p))
    50 
    51 fun transform_ho_typ (T as Type ("fun", _)) =
    52   let
    53     val (Ts, T') = strip_type T
    54   in if T' = HOLogic.boolT then T else (Ts @ [T']) ---> HOLogic.boolT end
    55 | transform_ho_typ t = t
    56 
    57 fun transform_ho_arg arg = 
    58   case (fastype_of arg) of
    59     (T as Type ("fun", _)) =>
    60       (case arg of
    61         Free (name, _) => Free (name, transform_ho_typ T)
    62       | _ => raise Fail "A non-variable term at a higher-order position")
    63   | _ => arg
    64 
    65 fun pred_type T =
    66   let
    67     val (Ts, T') = strip_type T
    68     val Ts' = map transform_ho_typ Ts
    69   in
    70     (Ts' @ [T']) ---> HOLogic.boolT
    71   end;
    72 
(* The premises of each intro rule are created by flatten (below), of type
   theory -> (term -> term option) -> term -> string list * term list ->
   (term * (string list * term list)) list *)
    75 
    76 fun dest_code_eqn eqn = let
    77   val (lhs, rhs) = Logic.dest_equals (Logic.unvarify_global (Thm.prop_of eqn))
    78   val (func, args) = strip_comb lhs
    79 in ((func, args), rhs) end;
    80 
    81 fun folds_map f xs y =
    82   let
    83     fun folds_map' acc [] y = [(rev acc, y)]
    84       | folds_map' acc (x :: xs) y =
    85         maps (fn (x, y) => folds_map' (x :: acc) xs y) (f x y)
    86     in
    87       folds_map' [] xs y
    88     end;
    89 
    90 fun keep_functions thy t =
    91   case try dest_Const (fst (strip_comb t)) of
    92     SOME (c, _) => Predicate_Compile_Data.keep_function thy c
    93   | _ => false
    94 
    95 fun flatten thy lookup_pred t (names, prems) =
    96   let
    97     fun lift t (names, prems) =
    98       case lookup_pred (Envir.eta_contract t) of
    99         SOME pred => [(pred, (names, prems))]
   100       | NONE =>
   101         let
   102           val (vars, body) = strip_abs t
   103           val _ = @{assert} (fastype_of body = body_type (fastype_of body))
   104           val absnames = Name.variant_list names (map fst vars)
   105           val frees = map2 (curry Free) absnames (map snd vars)
   106           val body' = subst_bounds (rev frees, body)
   107           val resname = singleton (Name.variant_list (absnames @ names)) "res"
   108           val resvar = Free (resname, fastype_of body)
   109           val t = flatten' body' ([], [])
   110             |> map (fn (res, (inner_names, inner_prems)) =>
   111               let
   112                 fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t)
   113                 val vTs = 
   114                   fold Term.add_frees inner_prems []
   115                   |> filter (fn (x, _) => member (op =) inner_names x)
   116                 val t = 
   117                   fold mk_exists vTs
   118                   (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (res, resvar) ::
   119                     map HOLogic.dest_Trueprop inner_prems))
   120               in
   121                 t
   122               end)
   123               |> foldr1 HOLogic.mk_disj
   124               |> fold lambda (resvar :: rev frees)
   125         in
   126           [(t, (names, prems))]
   127         end
   128     and flatten_or_lift (t, T) (names, prems) =
   129       if fastype_of t = T then
   130         flatten' t (names, prems)
   131       else
   132         (* note pred_type might be to general! *)
   133         if (pred_type (fastype_of t) = T) then
   134           lift t (names, prems)
   135         else
   136           error ("unexpected input for flatten or lift" ^ Syntax.string_of_term_global thy t ^
   137           ", " ^  Syntax.string_of_typ_global thy T)
   138     and flatten' (t as Const _) (names, prems) = [(t, (names, prems))]
   139       | flatten' (t as Free _) (names, prems) = [(t, (names, prems))]
   140       | flatten' (t as Abs _) (names, prems) = [(t, (names, prems))]
   141       | flatten' (t as _ $ _) (names, prems) =
   142       if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then
   143         [(t, (names, prems))]
   144       else
   145         case (fst (strip_comb t)) of
   146           Const (@{const_name "If"}, _) =>
   147             (let
   148               val (_, [B, x, y]) = strip_comb t
   149             in
   150               flatten' B (names, prems)
   151               |> maps (fn (B', (names, prems)) =>
   152                 (flatten' x (names, prems)
   153                 |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B') :: prems))))
   154                 @ (flatten' y (names, prems)
   155                 |> map (fn (res, (names, prems)) =>
   156                   (* in general unsound! *)
   157                   (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B')) :: prems)))))
   158             end)
   159         | Const (@{const_name "Let"}, _) => 
   160             (let
   161               val (_, [f, g]) = strip_comb t
   162             in
   163               flatten' f (names, prems)
   164               |> maps (fn (res, (names, prems)) =>
   165                 flatten' (betapply (g, res)) (names, prems))
   166             end)
   167         | _ =>
   168         case find_split_thm thy (fst (strip_comb t)) of
   169           SOME raw_split_thm =>
   170           let
   171             val split_thm = prepare_split_thm (Proof_Context.init_global thy) raw_split_thm
   172             val (assms, concl) = Logic.strip_horn (prop_of split_thm)
   173             val (_, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl)
   174             val t' = case_betapply thy t
   175             val subst = Pattern.match thy (split_t, t') (Vartab.empty, Vartab.empty)
   176             fun flatten_of_assm assm =
   177               let
   178                 val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
   179                 val var_names = Name.variant_list names (map fst vTs)
   180                 val vars = map Free (var_names ~~ (map snd vTs))
   181                 val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
   182                 val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
   183                 val (lhss : term list, rhss) =
   184                   split_list (map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems')
   185               in
   186                 folds_map flatten' lhss (var_names @ names, prems)
   187                 |> map (fn (ress, (names, prems)) =>
   188                   let
   189                     val prems' = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (ress ~~ rhss)
   190                   in (names, prems' @ prems) end)
   191                 |> maps (flatten' inner_t)
   192               end
   193           in
   194             maps flatten_of_assm assms
   195           end
   196       | NONE =>
   197           let
   198             val (f, args) = strip_comb t
   199             val args = map (Envir.eta_long []) args
   200             val _ = @{assert} (fastype_of t = body_type (fastype_of t))
   201             val f' = lookup_pred f
   202             val Ts = case f' of
   203               SOME pred => (fst (split_last (binder_types (fastype_of pred))))
   204             | NONE => binder_types (fastype_of f)
   205           in
   206             folds_map flatten_or_lift (args ~~ Ts) (names, prems) |>
   207             (case f' of
   208               NONE =>
   209                 map (fn (argvs, (names', prems')) => (list_comb (f, argvs), (names', prems')))
   210             | SOME pred =>
   211                 map (fn (argvs, (names', prems')) =>
   212                   let
   213                     fun lift_arg T t =
   214                       if (fastype_of t) = T then t
   215                       else
   216                         let
   217                           val _ = @{assert} (T =
   218                             (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool}))
   219                           fun mk_if T (b, t, e) =
   220                             Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
   221                           val Ts = binder_types (fastype_of t)
   222                         in
   223                           fold_rev Term.abs (map (pair "x") Ts @ [("b", @{typ bool})])
   224                             (mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
   225                               HOLogic.mk_eq (@{term True}, Bound 0),
   226                               HOLogic.mk_eq (@{term False}, Bound 0)))
   227                         end
   228                     val argvs' = map2 lift_arg Ts argvs
   229                     val resname = singleton (Name.variant_list names') "res"
   230                     val resvar = Free (resname, body_type (fastype_of t))
   231                     val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar]))
   232                   in (resvar, (resname :: names', prem :: prems')) end))
   233           end
   234   in
   235     map (apfst Envir.eta_contract) (flatten' (Envir.eta_long [] t) (names, prems))
   236   end;
   237 
   238 (* FIXME: create new predicate name -- does not avoid nameclashing *)
   239 fun pred_of thy f =
   240   let
   241     val (name, T) = dest_Const f
   242     val base_name' = (Long_Name.base_name name ^ "P")
   243     val name' = Sign.full_bname thy base_name'
   244     val T' = if (body_type T = @{typ bool}) then T else pred_type T
   245   in
   246     (name', Const (name', T'))
   247   end
   248 
   249 (* assumption: mutual recursive predicates all have the same parameters. *)
   250 fun define_predicates specs thy =
   251   if forall (fn (const, _) => defined_const thy const) specs then
   252     ([], thy)
   253   else
   254     let
   255       val eqns = maps snd specs
   256       (* create prednames *)
   257       val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
   258       val dst_funs = distinct (op =) funs
   259       val argss' = map (map transform_ho_arg) argss
   260       fun is_lifted (t1, t2) = (fastype_of t2 = pred_type (fastype_of t1))
   261       (* FIXME: higher order arguments also occur in tuples! *)
   262       val lifted_args = distinct (op =) (filter is_lifted (flat argss ~~ flat argss'))
   263       val (prednames, preds) = split_list (map (pred_of thy) funs)
   264       val dst_preds = distinct (op =) preds
   265       val dst_prednames = distinct (op =) prednames
   266       (* mapping from term (Free or Const) to term *)
   267       val net = fold Item_Net.update
   268         ((dst_funs ~~ dst_preds) @ lifted_args)
   269           (Fun_Pred.get thy)
   270       fun lookup_pred t = lookup thy net t
   271       (* create intro rules *)
   272       fun mk_intros ((func, pred), (args, rhs)) =
   273         if (body_type (fastype_of func) = @{typ bool}) then
   274          (* TODO: preprocess predicate definition of rhs *)
   275           [Logic.list_implies ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
   276         else
   277           let
   278             val names = Term.add_free_names rhs []
   279           in flatten thy lookup_pred rhs (names, [])
   280             |> map (fn (resultt, (_, prems)) =>
   281               Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
   282           end
   283       val intr_ts = maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))
   284       val (intrs, thy') = thy
   285         |> Sign.add_consts_i
   286           (map (fn Const (name, T) => (Binding.name (Long_Name.base_name name), T, NoSyn))
   287            dst_preds)
   288         |> fold_map Specification.axiom
   289             (map (fn t => ((Binding.name ("unnamed_axiom_" ^ serial_string ()), []), t)) intr_ts)
   290       val specs = map (fn predname => (predname,
   291           map Drule.export_without_context (filter (Predicate_Compile_Aux.is_intro predname) intrs)))
   292         dst_prednames
   293       val thy'' = Fun_Pred.map
   294         (fold Item_Net.update (map (pairself Logic.varify_global)
   295           (dst_funs ~~ dst_preds))) thy'
   296       fun functional_mode_of T =
   297         list_fun_mode (replicate (length (binder_types T)) Input @ [Output])
   298       val thy''' = fold
   299         (fn (predname, Const (name, T)) => Core_Data.register_alternative_function
   300           predname (functional_mode_of T) name)
   301       (dst_prednames ~~ dst_funs) thy''
   302     in
   303       (specs, thy''')
   304     end
   305 
   306 fun rewrite_intro thy intro =
   307   let
   308     fun lookup_pred t = lookup thy (Fun_Pred.get thy) t
   309     (*val _ = tracing ("Rewriting intro " ^ Display.string_of_thm_global thy intro)*)
   310     val intro_t = Logic.unvarify_global (prop_of intro)
   311     val (prems, concl) = Logic.strip_horn intro_t
   312     val frees = map fst (Term.add_frees intro_t [])
   313     fun rewrite prem names =
   314       let
   315         (*val _ = tracing ("Rewriting premise " ^ Syntax.string_of_term_global thy prem ^ "...")*)
   316         val t = HOLogic.dest_Trueprop prem
   317         val (lit, mk_lit) = case try HOLogic.dest_not t of
   318             SOME t => (t, HOLogic.mk_not)
   319           | NONE => (t, I)
   320         val (P, args) = strip_comb lit
   321       in
   322         folds_map (flatten thy lookup_pred) args (names, [])
   323         |> map (fn (resargs, (names', prems')) =>
   324           let
   325             val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))
   326           in (prems' @ [prem'], names') end)
   327       end
   328     val intro_ts' = folds_map rewrite prems frees
   329       |> maps (fn (prems', frees') =>
   330         rewrite concl frees'
   331         |> map (fn (conclprems, _) =>
   332           let
   333             val (conclprems', concl') = split_last conclprems
   334           in
   335             Logic.list_implies ((flat prems') @ conclprems', concl')
   336           end))
   337     (*val _ = tracing ("Rewritten intro to " ^
   338       commas (map (Syntax.string_of_term_global thy) intro_ts'))*)
   339   in
   340     map (Drule.export_without_context o Skip_Proof.make_thm thy) intro_ts'
   341   end
   342 
   343 end;