src/HOL/Tools/Predicate_Compile/pred_compile_fun.ML
changeset 33334 cba65e4bf565
parent 33325 7994994c4d8e
parent 33333 78faaec3209f
child 33335 1e189f96c393
child 33344 7a1f597f454e
equal deleted inserted replaced
33325:7994994c4d8e 33334:cba65e4bf565
     1 (* Author: Lukas Bulwahn, TU Muenchen
       
     2 
       
     3 Preprocessing functions to predicates
       
     4 *)
       
     5 
       
     6 signature PREDICATE_COMPILE_FUN =
       
     7 sig
       
     8 val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
       
     9   val rewrite_intro : theory -> thm -> thm list
       
    10   val setup_oracle : theory -> theory
       
    11   val pred_of_function : theory -> string -> string option
       
    12 end;
       
    13 
       
    14 structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN =
       
    15 struct
       
    16 
       
    17 
       
    18 (* Oracle for preprocessing  *)
       
    19 
       
    20 val (oracle : (string * (cterm -> thm)) option Unsynchronized.ref) = Unsynchronized.ref NONE;
       
    21 
       
    22 fun the_oracle () =
       
    23   case !oracle of
       
    24     NONE => error "Oracle is not setup"
       
    25   | SOME (_, oracle) => oracle
       
    26              
       
    27 val setup_oracle = Thm.add_oracle (Binding.name "pred_compile_preprocessing", I) #->
       
    28   (fn ora => fn thy => let val _ = (oracle := SOME ora) in thy end)
       
    29   
       
    30   
       
    31 fun is_funtype (Type ("fun", [_, _])) = true
       
    32   | is_funtype _ = false;
       
    33 
       
    34 fun is_Type (Type _) = true
       
    35   | is_Type _ = false
       
    36 
       
(* returns true if t is an application of a datatype constructor *)
(* (which consequently would be split), *)
(* else false *)
       
    40 (*
       
    41 fun is_constructor thy t =
       
    42   if (is_Type (fastype_of t)) then
       
    43     (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
       
    44       NONE => false
       
    45     | SOME info => (let
       
    46       val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
       
    47       val (c, _) = strip_comb t
       
    48       in (case c of
       
    49         Const (name, _) => name mem_string constr_consts
       
    50         | _ => false) end))
       
    51   else false
       
    52 *)
       
    53 
       
    54 (* must be exported in code.ML *)
       
    55 fun is_constr thy = is_some o Code.get_datatype_of_constr thy;
       
    56 
       
    57 (* Table from constant name (string) to term of inductive predicate *)
       
    58 structure Pred_Compile_Preproc = TheoryDataFun
       
    59 (
       
    60   type T = string Symtab.table;
       
    61   val empty = Symtab.empty;
       
    62   val copy = I;
       
    63   val extend = I;
       
    64   fun merge _ = Symtab.merge (op =);
       
    65 )
       
    66 
       
    67 fun pred_of_function thy name = Symtab.lookup (Pred_Compile_Preproc.get thy) name
       
    68 
       
    69 fun defined thy = Symtab.defined (Pred_Compile_Preproc.get thy) 
       
    70 
       
    71 
       
    72 fun transform_ho_typ (T as Type ("fun", _)) =
       
    73   let
       
    74     val (Ts, T') = strip_type T
       
    75   in if T' = @{typ "bool"} then T else (Ts @ [T']) ---> HOLogic.boolT end
       
    76 | transform_ho_typ t = t
       
    77 
       
    78 fun transform_ho_arg arg = 
       
    79   case (fastype_of arg) of
       
    80     (T as Type ("fun", _)) =>
       
    81       (case arg of
       
    82         Free (name, _) => Free (name, transform_ho_typ T)
       
    83       | _ => error "I am surprised")
       
    84 | _ => arg
       
    85 
       
    86 fun pred_type T =
       
    87   let
       
    88     val (Ts, T') = strip_type T
       
    89     val Ts' = map transform_ho_typ Ts
       
    90   in
       
    91     (Ts' @ [T']) ---> HOLogic.boolT
       
    92   end;
       
    93 
       
    94 (* FIXME: create new predicate name -- does not avoid nameclashing *)
       
    95 fun pred_of f =
       
    96   let
       
    97     val (name, T) = dest_Const f
       
    98   in
       
    99     if (body_type T = @{typ bool}) then
       
   100       (Free (Long_Name.base_name name ^ "P", T))
       
   101     else
       
   102       (Free (Long_Name.base_name name ^ "P", pred_type T))
       
   103   end
       
   104 
       
   105 fun mk_param thy lookup_pred (t as Free (v, _)) = lookup_pred t
       
   106   | mk_param thy lookup_pred t =
       
   107   let
       
   108   val _ = tracing ("called param with " ^ (Syntax.string_of_term_global thy t))
       
   109   in if Predicate_Compile_Aux.is_predT (fastype_of t) then
       
   110     t
       
   111   else
       
   112     let
       
   113       val (vs, body) = strip_abs t
       
   114       val names = Term.add_free_names body []
       
   115       val vs_names = Name.variant_list names (map fst vs)
       
   116       val vs' = map2 (curry Free) vs_names (map snd vs)
       
   117       val body' = subst_bounds (rev vs', body)
       
   118       val (f, args) = strip_comb body'
       
   119       val resname = Name.variant (vs_names @ names) "res"
       
   120       val resvar = Free (resname, body_type (fastype_of body'))
       
   121       (*val P = case try lookup_pred f of SOME P => P | NONE => error "mk_param"
       
   122       val pred_body = list_comb (P, args @ [resvar])
       
   123       *)
       
   124       val pred_body = HOLogic.mk_eq (body', resvar)
       
   125       val param = fold_rev lambda (vs' @ [resvar]) pred_body
       
   126     in param end
       
   127   end
       
   128 (* creates the list of premises for every intro rule *)
       
   129 (* theory -> term -> (string list, term list list) *)
       
   130 
       
   131 fun dest_code_eqn eqn = let
       
   132   val (lhs, rhs) = Logic.dest_equals (Logic.unvarify (Thm.prop_of eqn))
       
   133   val (func, args) = strip_comb lhs
       
   134 in ((func, args), rhs) end;
       
   135 
       
   136 fun string_of_typ T = Syntax.string_of_typ_global @{theory} T
       
   137 
       
   138 fun string_of_term t =
       
   139   case t of
       
   140     Const (c, T) => "Const (" ^ c ^ ", " ^ string_of_typ T ^ ")"
       
   141   | Free (c, T) => "Free (" ^ c ^ ", " ^ string_of_typ T ^ ")"
       
   142   | Var ((c, i), T) => "Var ((" ^ c ^ ", " ^ string_of_int i ^ "), " ^ string_of_typ T ^ ")"
       
   143   | Bound i => "Bound " ^ string_of_int i
       
   144   | Abs (x, T, t) => "Abs (" ^ x ^ ", " ^ string_of_typ T ^ ", " ^ string_of_term t ^ ")"
       
   145   | t1 $ t2 => "(" ^ string_of_term t1 ^ ") $ (" ^ string_of_term t2 ^ ")"
       
   146   
       
   147 fun ind_package_get_nparams thy name =
       
   148   case try (Inductive.the_inductive (ProofContext.init thy)) name of
       
   149     SOME (_, result) => length (Inductive.params_of (#raw_induct result))
       
   150   | NONE => error ("No such predicate: " ^ quote name) 
       
   151 
       
   152 (* TODO: does not work with higher order functions yet *)
       
   153 fun mk_rewr_eq (func, pred) =
       
   154   let
       
   155     val (argTs, resT) = (strip_type (fastype_of func))
       
   156     val nctxt =
       
   157       Name.make_context (Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) (func $ pred) [])
       
   158     val (argnames, nctxt') = Name.variants (replicate (length argTs) "a") nctxt
       
   159     val ([resname], nctxt'') = Name.variants ["r"] nctxt'
       
   160     val args = map Free (argnames ~~ argTs)
       
   161     val res = Free (resname, resT)
       
   162   in Logic.mk_equals
       
   163       (HOLogic.mk_eq (res, list_comb (func, args)), list_comb (pred, args @ [res]))
       
   164   end;
       
   165 
       
   166 fun has_split_rule_cname @{const_name "nat_case"} = true
       
   167   | has_split_rule_cname @{const_name "list_case"} = true
       
   168   | has_split_rule_cname _ = false
       
   169   
       
   170 fun has_split_rule_term thy (Const (@{const_name "nat_case"}, _)) = true 
       
   171   | has_split_rule_term thy (Const (@{const_name "list_case"}, _)) = true 
       
   172   | has_split_rule_term thy _ = false
       
   173 
       
   174 fun has_split_rule_term' thy (Const (@{const_name "If"}, _)) = true
       
   175   | has_split_rule_term' thy (Const (@{const_name "Let"}, _)) = true
       
   176   | has_split_rule_term' thy c = has_split_rule_term thy c
       
   177   
       
   178 fun prepare_split_thm ctxt split_thm =
       
   179     (split_thm RS @{thm iffD2})
       
   180     |> LocalDefs.unfold ctxt [@{thm atomize_conjL[symmetric]},
       
   181       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
       
   182 
       
   183 fun find_split_thm thy (Const (name, typ)) =
       
   184   let
       
   185     fun split_name str =
       
   186       case first_field "." str
       
   187         of (SOME (field, rest)) => field :: split_name rest
       
   188          | NONE => [str]
       
   189     val splitted_name = split_name name
       
   190   in
       
   191     if length splitted_name > 0 andalso
       
   192        String.isSuffix "_case" (List.last splitted_name)
       
   193     then
       
   194       (List.take (splitted_name, length splitted_name - 1)) @ ["split"]
       
   195       |> space_implode "."
       
   196       |> PureThy.get_thm thy
       
   197       |> SOME
       
   198       handle ERROR msg => NONE
       
   199     else NONE
       
   200   end
       
   201   | find_split_thm _ _ = NONE
       
   202 
       
   203 fun find_split_thm' thy (Const (@{const_name "If"}, _)) = SOME @{thm split_if}
       
   204   | find_split_thm' thy (Const (@{const_name "Let"}, _)) = SOME @{thm refl} (* TODO *)
       
   205   | find_split_thm' thy c = find_split_thm thy c
       
   206 
       
   207 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
       
   208 
       
   209 fun folds_map f xs y =
       
   210   let
       
   211     fun folds_map' acc [] y = [(rev acc, y)]
       
   212       | folds_map' acc (x :: xs) y =
       
   213         maps (fn (x, y) => folds_map' (x :: acc) xs y) (f x y)
       
   214     in
       
   215       folds_map' [] xs y
       
   216     end;
       
   217 
       
(* Flattens a term t into a value term plus premises.
   Result: every alternative as (value term, (used names, premises)), where
   new variable names are prepended to names and new premises to prems.
   - Constructor constants and constants for which lookup_pred fails are
     kept; other constants and free variables are mapped with lookup_pred.
   - Abstractions are kept if they have predicate type, otherwise error.
   - Constructor terms (Predicate_Compile_Aux.is_constrt) are kept.
   - Terms headed by If, Let, nat_case or list_case are split with the
     corresponding split rule, giving one alternative per rule assumption.
   - Any other application f args gets a fresh result variable "res" and the
     premise  res = f args'  (f a constructor or without predicate), or
     pred params' args' res  (f with a predicate), where args' are the
     flattened arguments. *)
fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) =
  let
    fun mk_prems' (t as Const (name, T)) (names, prems) =
      if is_constr thy name orelse (is_none (try lookup_pred t)) then
        [(t, (names, prems))]
      else [(lookup_pred t, (names, prems))]
    | mk_prems' (t as Free (f, T)) (names, prems) = 
      [(lookup_pred t, (names, prems))]
    | mk_prems' (t as Abs _) (names, prems) =
      if Predicate_Compile_Aux.is_predT (fastype_of t) then
      [(t, (names, prems))] else error "mk_prems': Abs "
      (* mk_param *)
    | mk_prems' t (names, prems) =
      if Predicate_Compile_Aux.is_constrt thy t then
        [(t, (names, prems))]
      else
        if has_split_rule_term' thy (fst (strip_comb t)) then
          (* case distinction: instantiate the split rule with t and flatten
             the inner term of each of its assumptions *)
          let
            val (f, args) = strip_comb t
            (* NOTE(review): `the` raises Option if no split rule is found;
               for Let the placeholder refl presumably fails the pattern
               binding below -- confirm *)
            val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
            (* TODO: contextify things - this line is to unvarify the split_thm *)
            (*val ((_, [isplit_thm]), _) = Variable.import true [split_thm] (ProofContext.init thy)*)
            val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
            (* the conclusion has the shape ?P split_t *)
            val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
            val subst = Pattern.match thy (split_t, t) (Vartab.empty, Vartab.empty)
            val (_, split_args) = strip_comb split_t
            val match = split_args ~~ args
            (* each assumption !!xs. prems' ==> ?P inner_t yields the
               alternatives of flattening inner_t, with the bound variables
               xs turned into fresh free variables and prems' added *)
            fun mk_prems_of_assm assm =
              let
                val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
                val var_names = Name.variant_list names (map fst vTs)
                val vars = map Free (var_names ~~ (map snd vTs))
                val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
                val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
              in
                mk_prems' inner_t (var_names @ names, prems' @ prems)
              end
          in
            maps mk_prems_of_assm assms
          end
        else
          (* plain application: introduce a fresh result variable *)
          let
            val (f, args) = strip_comb t
            (* TODO: special procedure for higher-order functions: split arguments in
              simple types and function types *)
            val resname = Name.variant names "res"
            val resvar = Free (resname, body_type (fastype_of t))
            val names' = resname :: names
            fun mk_prems'' (t as Const (c, _)) =
              if is_constr thy c orelse (is_none (try lookup_pred t)) then
                (* constructor or function without predicate: res = f args' *)
                folds_map mk_prems' args (names', prems) |>
                map
                  (fn (argvs, (names'', prems')) =>
                  let
                    val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
                  in (names'', prem :: prems') end)
              else
                (* function with predicate: pred params' args' res, where the
                   first nparams arguments are parameters converted by mk_param *)
                let
                  val pred = lookup_pred t
                  val nparams = get_nparams pred
                  val (params, args) = chop nparams args
                  val params' = map (mk_param thy lookup_pred) params
                in
                  folds_map mk_prems' args (names', prems)
                  |> map (fn (argvs, (names'', prems')) =>
                    let
                      val prem = HOLogic.mk_Trueprop (list_comb (pred, params' @ argvs @ [resvar]))
                    in (names'', prem :: prems') end)
                end
            | mk_prems'' (t as Free (_, _)) =
                let
                  (* higher order argument call *)
                  val pred = lookup_pred t
                in
                  folds_map mk_prems' args (resname :: names, prems)
                  |> map (fn (argvs, (names', prems')) =>
                     let
                       val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs @ [resvar]))
                     in (names', prem :: prems') end)
                end
            | mk_prems'' t =
              error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
          in
            map (pair resvar) (mk_prems'' f)
          end
  in
    mk_prems' t (names, prems)
  end;
       
   306 
       
   307 (* assumption: mutual recursive predicates all have the same parameters. *)  
       
   308 fun define_predicates specs thy =
       
   309   if forall (fn (const, _) => member (op =) (Symtab.keys (Pred_Compile_Preproc.get thy)) const) specs then
       
   310     ([], thy)
       
   311   else
       
   312   let
       
   313     val consts = map fst specs
       
   314     val eqns = maps snd specs
       
   315     (*val eqns = maps (Predicate_Compile_Preproc_Data.get_specification thy) consts*)
       
   316       (* create prednames *)
       
   317     val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
       
   318     val argss' = map (map transform_ho_arg) argss
       
   319     val pnames = map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss'))
       
   320     val preds = map pred_of funs
       
   321     val prednames = map (fst o dest_Free) preds
       
   322     val funnames = map (fst o dest_Const) funs
       
   323     val fun_pred_names = (funnames ~~ prednames)  
       
   324       (* mapping from term (Free or Const) to term *)
       
   325     fun lookup_pred (Const (name, T)) =
       
   326       (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
       
   327           SOME c => Const (c, pred_type T)
       
   328         | NONE =>
       
   329           (case AList.lookup op = fun_pred_names name of
       
   330             SOME f => Free (f, pred_type T)
       
   331           | NONE => Const (name, T)))
       
   332       | lookup_pred (Free (name, T)) =
       
   333         if member op = (map fst pnames) name then
       
   334           Free (name, transform_ho_typ T)
       
   335         else
       
   336           Free (name, T)
       
   337       | lookup_pred t =
       
   338          error ("lookup function is not defined for " ^ Syntax.string_of_term_global thy t)
       
   339      
       
   340         (* mapping from term (predicate term, not function term!) to int *)
       
   341     fun get_nparams (Const (name, _)) =
       
   342       the_default 0 (try (ind_package_get_nparams thy) name)
       
   343     | get_nparams (Free (name, _)) =
       
   344         (if member op = prednames name then
       
   345           length pnames
       
   346         else 0)
       
   347     | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
       
   348   
       
   349     (* create intro rules *)
       
   350   
       
   351     fun mk_intros ((func, pred), (args, rhs)) =
       
   352       if (body_type (fastype_of func) = @{typ bool}) then
       
   353        (*TODO: preprocess predicate definition of rhs *)
       
   354         [Logic.list_implies ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
       
   355       else
       
   356         let
       
   357           val names = Term.add_free_names rhs []
       
   358         in mk_prems thy (lookup_pred, get_nparams) rhs (names, [])
       
   359           |> map (fn (resultt, (names', prems)) =>
       
   360             Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
       
   361         end
       
   362     fun mk_rewr_thm (func, pred) = @{thm refl}
       
   363   in
       
   364     case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
       
   365       NONE => ([], thy) 
       
   366     | SOME intr_ts =>
       
   367         if is_some (try (map (cterm_of thy)) intr_ts) then
       
   368           let
       
   369             val (ind_result, thy') =
       
   370               Inductive.add_inductive_global (serial ())
       
   371                 {quiet_mode = false, verbose = false, kind = Thm.internalK,
       
   372                   alt_name = Binding.empty, coind = false, no_elim = false,
       
   373                   no_ind = false, skip_mono = false, fork_mono = false}
       
   374                 (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
       
   375                 pnames
       
   376                 (map (fn x => (Attrib.empty_binding, x)) intr_ts)
       
   377                 [] thy
       
   378             val prednames = map (fst o dest_Const) (#preds ind_result)
       
   379             (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
       
   380             (* add constants to my table *)
       
   381             val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname) (#intrs ind_result))) prednames
       
   382             val thy'' = Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy'
       
   383           in
       
   384             (specs, thy'')
       
   385           end
       
   386         else
       
   387           let
       
   388             val _ = tracing "Introduction rules of function_predicate are not welltyped"
       
   389           in ([], thy) end
       
   390   end
       
   391 
       
(* preprocessing intro rules - uses oracle *)

(* theory -> thm -> thm list *)
(* Rewrites an introduction rule so that the non-function-typed arguments of
   its (possibly negated) premises and of its conclusion are flattened with
   mk_prems, using the predicates registered in Pred_Compile_Preproc.  Every
   combination of alternatives yields one rule.  The rewritten rules are not
   proved: they are asserted via the preprocessing oracle (the_oracle fails
   if setup_oracle has not been run) and then passed through Drule.standard. *)
fun rewrite_intro thy intro =
  let
    (* function constants must already have a registered predicate *)
    fun lookup_pred (Const (name, T)) =
      (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
        SOME c => Const (c, pred_type T)
      | NONE => error ("Function " ^ name ^ " is not inductified"))
    | lookup_pred (Free (name, T)) = Free (name, T)
    | lookup_pred _ = error "lookup function is not defined!"

    (* parameters of an inductive predicate; 0 for free variables and for
       constants unknown to the inductive package *)
    fun get_nparams (Const (name, _)) =
      the_default 0 (try (ind_package_get_nparams thy) name)
    | get_nparams (Free _) = 0
    | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
    
    val intro_t = (Logic.unvarify o prop_of) intro
    val (prems, concl) = Logic.strip_horn intro_t
    val frees = map fst (Term.add_frees intro_t [])
    (* rewrites one (possibly negated) literal P args; yields every
       alternative as (rewritten literal :: generated premises, used names) *)
    fun rewrite prem names =
      let
        val t = (HOLogic.dest_Trueprop prem)
        val (lit, mk_lit) = case try HOLogic.dest_not t of
            SOME t => (t, HOLogic.mk_not)
          | NONE => (t, I)
        val (P, args) = (strip_comb lit) 
      in
        (* function-typed arguments are kept unchanged, all others flattened *)
        folds_map (
          fn t => if (is_funtype (fastype_of t)) then (fn x => [(t, x)])
            else mk_prems thy (lookup_pred, get_nparams) t) args (names, [])
        |> map (fn (resargs, (names', prems')) =>
          let
            val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))
          in (prem'::prems', names') end)
      end
    (* rewrite all premises, then the conclusion; the premises generated for
       the conclusion's arguments become additional premises of the rule *)
    val intro_ts' = folds_map rewrite prems frees
      |> maps (fn (prems', frees') =>
        rewrite concl frees'
        |> map (fn (concl'::conclprems, _) =>
          Logic.list_implies ((flat prems') @ conclprems, concl')))
  in
    map (Drule.standard o the_oracle () o cterm_of thy) intro_ts'
  end; 
       
   436 
       
   437 end;