src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
author wenzelm
Sat, 20 Mar 2010 17:33:11 +0100
changeset 35845 e5980f0ad025
parent 35411 cafb74a131da
child 35873 d09a58c890e3
permissions -rw-r--r--
renamed varify/unvarify operations to varify_global/unvarify_global to emphasize that these only work in a global situation;

(*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
    Author:     Lukas Bulwahn, TU Muenchen

Preprocessing functions to predicates.
*)

signature PREDICATE_COMPILE_FUN =
sig
  (* Turns function specifications (constant name, equations) into inductive predicates;
     returns (predicate name, introduction rules) pairs together with the extended theory,
     or an empty list and the unchanged theory when no predicates are defined. *)
  val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
  (* Rewrites an introduction rule using the registered function-to-predicate
     translations; one input rule may yield several rules. *)
  val rewrite_intro : theory -> thm -> thm list
  (* Name of the predicate registered for the given function constant, if any. *)
  val pred_of_function : theory -> string -> string option
  
  (* Registers a (function term, predicate term) pair as a translation in the theory. *)
  val add_function_predicate_translation : (term * term) -> theory -> theory
end;

structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN =
struct

open Predicate_Compile_Aux;

(* Theory data: table from function to inductive predicate.
   A discrimination net of (function term, predicate term) pairs, indexed by the
   function term; two entries count as equal when their function terms are
   alpha-equivalent (aconv). *)
structure Fun_Pred = Theory_Data
(
  type T = (term * term) Item_Net.T;
  val empty : T = Item_Net.init (fn ((f1, _), (f2, _)) => f1 aconv f2) (fn (f, _) => [f]);
  val extend = I;
  val merge = Item_Net.merge;
)

(* Looks up the translation of t in net.
   - no retrieved entry: NONE
   - exactly one entry (f, p): f is matched against t and p is returned with the
     resulting instantiation applied
   - several entries: error
   NOTE(review): Pattern.match raises Pattern.MATCH if the retrieved candidate does
   not actually match t (net retrieval is only an approximation) -- confirm callers
   never hit this. *)
fun lookup thy net t =
  (case Item_Net.retrieve net t of
    [(pat, rhs)] =>
      SOME (Envir.subst_term (Pattern.match thy (pat, t) (Vartab.empty, Vartab.empty)) rhs)
  | [] => NONE
  | _ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t))

(* Name of the predicate registered in Fun_Pred for the function constant `name`
   (looked up irrespective of type), NONE if there is none; error if the lookup
   is ambiguous. *)
fun pred_of_function thy name =
  let
    val candidates = Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT))
  in
    (case candidates of
      [(_, p)] => SOME (fst (dest_Const p))
    | [] => NONE
    | _ => error ("Multiple matches possible for lookup of constant " ^ name))
  end

(* true iff a predicate has already been registered for the function constant *)
fun defined_const thy = is_some o pred_of_function thy

(* Records the (function term, predicate term) pair in the Fun_Pred theory data. *)
fun add_function_predicate_translation entry =
  Fun_Pred.map (Item_Net.update entry)

(* Function types whose final result is not bool become predicate types
   Ts => res => bool (the former result type turns into the last argument);
   bool-valued function types and non-function types are returned unchanged. *)
fun transform_ho_typ (T as Type ("fun", _)) =
      let
        val (argTs, resT) = strip_type T
      in
        if resT = @{typ "bool"} then T
        else (argTs @ [resT]) ---> HOLogic.boolT
      end
  | transform_ho_typ T = T

(* Gives a function-typed argument the predicate type computed by transform_ho_typ.
   Function-typed arguments must be free variables (anything else is an error);
   arguments of non-function type are returned unchanged. *)
fun transform_ho_arg arg =
  (case (fastype_of arg, arg) of
    (T as Type ("fun", _), Free (name, _)) => Free (name, transform_ho_typ T)
  | (Type ("fun", _), _) => error "I am surprised"
  | _ => arg)

(* Predicate type for the function type T: each argument type is passed through
   transform_ho_typ and the result type becomes an additional last argument of a
   bool-valued type. *)
fun pred_type T =
  let
    val (argTs, resT) = strip_type T
  in
    (map transform_ho_typ argTs @ [resT]) ---> HOLogic.boolT
  end;

(* Free predicate variable for the function constant f, named after the constant's
   base name with suffix "P".  Bool-valued functions keep their own type; all other
   functions get the type computed by pred_type.
   FIXME: create new predicate name -- does not avoid nameclashing *)
fun pred_of f =
  let
    val (name, T) = dest_Const f
    val predT = if body_type T = @{typ bool} then T else pred_type T
  in
    Free (Long_Name.base_name name ^ "P", predT)
  end

(* mk_prems (defined below) creates the list of premises for every intro rule *)
(* mk_prems: theory -> (term -> term option) -> term -> string list * term list
     -> (term * (string list * term list)) list *)

(* Splits an equation theorem  lhs == rhs  into ((head of lhs, arguments of lhs), rhs);
   schematic variables are first turned into free ones by Logic.unvarify_global. *)
fun dest_code_eqn eqn =
  let
    val (lhs, rhs) = eqn |> Thm.prop_of |> Logic.unvarify_global |> Logic.dest_equals
  in
    (strip_comb lhs, rhs)
  end;

(* Builds the meta-equation  (r = func a1 ... an) == pred a1 ... an r  relating a
   function to its predicate; the names a1 ... an and r are chosen fresh with
   respect to the free variables of func and pred.
   TODO: does not work with higher order functions yet *)
fun mk_rewr_eq (func, pred) =
  let
    val (argTs, resT) = strip_type (fastype_of func)
    val used = Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) (func $ pred) []
    val (argnames, ctxt) = Name.variants (replicate (length argTs) "a") (Name.make_context used)
    val ([resname], _) = Name.variants ["r"] ctxt
    val args = map Free (argnames ~~ argTs)
    val res = Free (resname, resT)
    val lhs = HOLogic.mk_eq (res, list_comb (func, args))
    val rhs = list_comb (pred, args @ [res])
  in
    Logic.mk_equals (lhs, rhs)
  end;

(* Nondeterministic fold-map: f x y returns a list of alternatives (x', y').
   Threads the state y through xs from left to right and returns every
   combination (mapped list, final state), in the order the alternatives arise. *)
fun folds_map f xs y =
  let
    (* done_rev holds the already mapped prefix in reverse order *)
    fun go done_rev [] st = [(rev done_rev, st)]
      | go done_rev (x :: rest) st =
          List.concat (map (fn (x', st') => go (x' :: done_rev) rest st') (f x st))
  in
    go [] xs y
  end;

(* true iff the head of t is a constant that Predicate_Compile_Data.keep_function
   marks as to be kept (i.e. not translated into a predicate) *)
fun keep_functions thy t =
  (case fst (strip_comb t) of
    Const (c, _) => Predicate_Compile_Data.keep_function thy c
  | _ => false)

(* Translates a functional term t into a list of alternatives
   (result term, (names, premises)).  Nested function applications are replaced by
   fresh result variables ("res", ...) that are constrained by added premises.
   `names` holds the variable names already in use, so fresh names avoid them;
   `prems` accumulates the premises built so far.  `lookup_pred` maps a function
   term to its predicate counterpart (NONE if there is none).
   If-expressions and split theorems give rise to several alternatives. *)
fun mk_prems thy lookup_pred t (names, prems) =
  let
    (* constants: constructors and constants without a predicate stay as they are;
       otherwise the constant is replaced by its predicate *)
    fun mk_prems' (t as Const (name, T)) (names, prems) =
      (if is_constr thy name orelse (is_none (lookup_pred t)) then
        [(t, (names, prems))]
      else
       (*(if is_none (try lookup_pred t) then
          [(Abs ("uu", fastype_of t, HOLogic.mk_eq (t, Bound 0)), (names, prems))]
        else*) [(the (lookup_pred t), (names, prems))])
    (* free variables are replaced by their predicate when one is registered *)
    | mk_prems' (t as Free (f, T)) (names, prems) = 
      (case lookup_pred t of
        SOME t' => [(t', (names, prems))]
      | NONE => [(t, (names, prems))])
    (* abstractions: predicate-valued ones are kept (eta-contracted); others become
       the predicate  %x1 ... xn res. (EX ... res = r1 & ...) | (EX ... res = r2 & ...) | ...
       built from the alternatives of the translated body *)
    | mk_prems' (t as Abs _) (names, prems) =
      if Predicate_Compile_Aux.is_predT (fastype_of t) then
        ([(Envir.eta_contract t, (names, prems))])
      else
        let
          val (vars, body) = strip_abs t
          val _ = assert (fastype_of body = body_type (fastype_of body))
          val absnames = Name.variant_list names (map fst vars)
          val frees = map2 (curry Free) absnames (map snd vars)
          val body' = subst_bounds (rev frees, body)
          val resname = Name.variant (absnames @ names) "res"
          val resvar = Free (resname, fastype_of body)
          (* NOTE(review): the body is translated with an empty name context ([]), so
             fresh names chosen inside might coincide with absnames/resname -- confirm *)
          val t = mk_prems' body' ([], [])
            |> map (fn (res, (inner_names, inner_prems)) =>
              let
                fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t)
                (* quantify existentially over the variables introduced inside the body *)
                val vTs = 
                  fold Term.add_frees inner_prems []
                  |> filter (fn (x, T) => member (op =) inner_names x)
                val t = 
                  fold mk_exists vTs
                  (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (resvar, res) ::
                    map HOLogic.dest_Trueprop inner_prems))
              in
                t
              end)
              |> foldr1 HOLogic.mk_disj
              |> fold lambda (resvar :: rev frees)
        in
          [(t, (names, prems))]
        end
    (* all remaining terms, in particular applications *)
    | mk_prems' t (names, prems) =
      if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then
        [(t, (names, prems))]
      else
        case (fst (strip_comb t)) of
          (* If B x y: one alternative per branch, guarded by the premise B resp. ~B *)
          Const (@{const_name "If"}, _) =>
            (let
              val (_, [B, x, y]) = strip_comb t
            in
              (mk_prems' x (names, prems)
              |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B) :: prems))))
              @ (mk_prems' y (names, prems)
              |> map (fn (res, (names, prems)) =>
                (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B)) :: prems))))
            end)
        (* Let f g: translate f, then continue with g applied to each result *)
        | Const (@{const_name "Let"}, _) => 
            (let
              val (_, [f, g]) = strip_comb t
            in
              mk_prems' f (names, prems)
              |> maps (fn (res, (names, prems)) =>
                mk_prems' (betapply (g, res)) (names, prems))
            end)
        (* split g res: introduce fresh res1, res2 with the premise res = (res1, res2) *)
        | Const (@{const_name "split"}, _) => 
            (let
              val (_, [g, res]) = strip_comb t
              val [res1, res2] = Name.variant_list names ["res1", "res2"]
              val (T1, T2) = HOLogic.dest_prodT (fastype_of res)
              val (resv1, resv2) = (Free (res1, T1), Free (res2, T2))
            in
              mk_prems' (betapplys (g, [resv1, resv2]))
              (res1 :: res2 :: names,
              HOLogic.mk_Trueprop (HOLogic.mk_eq (res, HOLogic.mk_prod (resv1, resv2))) :: prems)
            end)
        | _ =>
        (* heads with a split theorem (presumably case combinators): one group of
           alternatives per assumption of the split theorem *)
        if has_split_thm thy (fst (strip_comb t)) then
          let
            val (f, args) = strip_comb t
            val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
            (* TODO: contextify things - this line is to unvarify the split_thm *)
            (*val ((_, [isplit_thm]), _) = Variable.import true [split_thm] (ProofContext.init thy)*)
            val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
            val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
            (* instantiation of the split theorem's term split_t to t *)
            val subst = Pattern.match thy (split_t, t) (Vartab.empty, Vartab.empty)
            val (_, split_args) = strip_comb split_t
            val match = split_args ~~ args
            (* an assumption  !!vs. lhs1 = rhs1 ==> ... ==> P inner_t  yields premises
               equating the translated lhss with the rhss; inner_t is translated last *)
            fun mk_prems_of_assm assm =
              let
                val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
                val var_names = Name.variant_list names (map fst vTs)
                val vars = map Free (var_names ~~ (map snd vTs))
                val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
                val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
                val (lhss : term list, rhss) =
                  split_list (map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems')
              in
                folds_map mk_prems' lhss (var_names @ names, prems)
                |> map (fn (ress, (names, prems)) =>
                  let
                    val prems' = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (ress ~~ rhss)
                  in (names, prems' @ prems) end)
                |> maps (mk_prems' inner_t)
              end
          in
            maps mk_prems_of_assm assms
          end
        else
          (* general application f args: the result is a fresh variable "res",
             constrained by a premise depending on the head f *)
          let
            val (f, args) = strip_comb t
            (* TODO: special procedure for higher-order functions: split arguments in
              simple types and function types *)
            val args = map (Pattern.eta_long []) args
            val resname = Name.variant names "res"
            val resvar = Free (resname, body_type (fastype_of t))
            val _ = assert (fastype_of t = body_type (fastype_of t))
            val names' = resname :: names
            (* constant head that is a constructor or has no predicate:
               premise  res = f args'  *)
            fun mk_prems'' (t as Const (c, _)) =
              if is_constr thy c orelse (is_none (lookup_pred t)) then
                let
                  val _ = ()(*tracing ("not translating function " ^ Syntax.string_of_term_global thy t)*)
                in
                folds_map mk_prems' args (names', prems) |>
                map
                  (fn (argvs, (names'', prems')) =>
                  let
                    val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
                  in (names'', prem :: prems') end)
                end
              else
                (* constant head with a predicate: premise  pred args' res  *)
                let
                  (* lookup_pred is wrong for polymorphic arguments and bool. *)
                  val pred = the (lookup_pred t)
                  val Ts = binder_types (fastype_of pred)
                in
                  folds_map mk_prems' args (names', prems)
                  |> map (fn (argvs, (names'', prems')) =>
                    let
                      (* an argument whose type differs from the predicate's argument type T
                         (asserted to be  binder_types (fastype_of t) @ [bool] ---> bool)
                         is lifted to  %xs b. if t xs then True = b else False = b  *)
                      fun lift_arg T t =
                        if (fastype_of t) = T then t
                        else
                          let
                            val _ = assert (T =
                              (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool}))
                            fun mk_if T (b, t, e) =
                              Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e
                            val Ts = binder_types (fastype_of t)
                            val t = 
                            list_abs (map (pair "x") Ts @ [("b", @{typ bool})],
                              mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)),
                              HOLogic.mk_eq (@{term True}, Bound 0),
                              HOLogic.mk_eq (@{term False}, Bound 0)))
                          in
                            t
                          end
                      (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts))
                      val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*)
                      val argvs' = map2 lift_arg (fst (split_last Ts)) argvs
                      val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar]))
                    in (names'', prem :: prems') end)
                end
            (* free-variable head: predicate premise if one is registered,
               otherwise the equation  res = f args'  *)
            | mk_prems'' (t as Free (_, _)) =
              folds_map mk_prems' args (names', prems) |>
                map
                  (fn (argvs, (names'', prems')) =>
                  let
                    val prem = 
                      case lookup_pred t of
                        NONE => HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
                      | SOME p => HOLogic.mk_Trueprop (list_comb (p, argvs @ [resvar]))
                  in (names'', prem :: prems') end)
            | mk_prems'' t =
              error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
          in
            map (pair resvar) (mk_prems'' f)
          end
  in
    mk_prems' (Pattern.eta_long [] t) (names, prems)
  end;

(* Defines inductive predicates for function specifications.

   specs: (constant name, equation theorems) pairs.  If every constant already has a
   registered predicate, ([], thy) is returned unchanged.  Otherwise each equation is
   turned into introduction rules (bool-valued functions directly, all others via
   mk_prems), the predicates are defined inductively with concealed names, and the
   function -> predicate translations are recorded in Fun_Pred.  Returns
   (predicate name, introduction rules) pairs and the new theory.  If the generated
   rules cannot be certified, a tracing message is printed and ([], thy) is returned.

   assumption: mutual recursive predicates all have the same parameters. *)
fun define_predicates specs thy =
  if forall (fn (const, _) => defined_const thy const) specs then
    ([], thy)
  else
  let
    val eqns = maps snd specs
    val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
    (* function-typed arguments become predicate-typed parameters *)
    val argss' = map (map transform_ho_arg) argss
    (* TODO: higher order arguments also occur in tuples! *)
    val ho_argss = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss)
    val params = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss')
    val preds = map pred_of funs
    (* local translation table from term (Free or Const) to term: the functions and
       higher-order arguments being defined, on top of the registered translations *)
    val net = fold Item_Net.update
      ((funs ~~ preds) @ (ho_argss ~~ params))
        (Fun_Pred.get thy)
    fun lookup_pred t = lookup thy net t
    (* create intro rules *)
    fun mk_intros ((func, pred), (args, rhs)) =
      if (body_type (fastype_of func) = @{typ bool}) then
       (*TODO: preprocess predicate definition of rhs *)
        [Logic.list_implies ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
      else
        let
          val names = Term.add_free_names rhs []
        in mk_prems thy lookup_pred rhs (names, [])
          |> map (fn (resultt, (names', prems)) =>
            Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
        end
    val intr_ts = maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))
  in
    if is_some (try (map (cterm_of thy)) intr_ts) then
      let
        val (ind_result, thy') =
          thy
          |> Sign.map_naming Name_Space.conceal
          |> Inductive.add_inductive_global
            {quiet_mode = false, verbose = false, alt_name = Binding.empty, coind = false,
              no_elim = false, no_ind = false, skip_mono = false, fork_mono = false}
            (map (fn (s, T) =>
              ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
            []
            (map (fn x => (Attrib.empty_binding, x)) intr_ts)
            []
          ||> Sign.restore_naming thy
        val prednames = map (fst o dest_Const) (#preds ind_result)
        (* group the generated introduction rules by predicate *)
        val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname)
          (#intrs ind_result))) prednames
        (* record the function -> predicate translations (function terms made schematic) *)
        val thy'' = Fun_Pred.map
          (fold Item_Net.update (map (apfst Logic.varify_global)
            (distinct (op =) funs ~~ (#preds ind_result)))) thy'
      in
        (specs, thy'')
      end
    else
      let
        val _ = Output.tracing (
        "Introduction rules of function_predicate are not welltyped: " ^
          commas (map (Syntax.string_of_term_global thy) intr_ts))
      in ([], thy) end
  end

(* Rewrites an introduction rule with the translations registered in Fun_Pred:
   the arguments of every (possibly negated) premise literal and of the conclusion
   are translated with mk_prems, and the premises generated along the way are added
   to the rule.  Every combination of alternatives yields one rule; the resulting
   propositions are turned into theorems with Skip_Proof.make_thm and
   Drule.export_without_context. *)
fun rewrite_intro thy intro =
  let
    fun lookup_pred t = lookup thy (Fun_Pred.get thy) t
    val intro_t = Logic.unvarify_global (prop_of intro)
    val (prems, concl) = Logic.strip_horn intro_t
    val frees = map fst (Term.add_frees intro_t [])
    (* translates the arguments of one literal; each alternative is
       (rewritten literal :: generated premises, names in use) *)
    fun rewrite prem names =
      let
        val body = HOLogic.dest_Trueprop prem
        val (lit, mk_lit) =
          (case try HOLogic.dest_not body of
            NONE => (body, I)
          | SOME atom => (atom, HOLogic.mk_not))
        val (head, args) = strip_comb lit
        fun rebuild (resargs, (names', prems')) =
          (HOLogic.mk_Trueprop (mk_lit (list_comb (head, resargs))) :: prems', names')
      in
        map rebuild (folds_map (mk_prems thy lookup_pred) args (names, []))
      end
    (* combines rewritten premises with every rewriting of the conclusion *)
    fun intros_of (prems', frees') =
      map (fn (concl' :: conclprems, _) => Logic.list_implies (flat prems' @ conclprems, concl'))
        (rewrite concl frees')
    val intro_ts' = maps intros_of (folds_map rewrite prems frees)
  in
    map (Drule.export_without_context o Skip_Proof.make_thm thy) intro_ts'
  end

end;