src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
author wenzelm
Fri Nov 13 19:57:46 2009 +0100 (2009-11-13)
changeset 33669 ae9a2ea9a989
parent 33643 b275f26a638b
child 33726 0878aecbf119
permissions -rw-r--r--
inductive: eliminated obsolete kind;
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_fun.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
      4 Preprocessing: translation of functions into inductive predicates.
     5 *)
     6 
     7 signature PREDICATE_COMPILE_FUN =
     8 sig
     9   val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory
    10   val rewrite_intro : theory -> thm -> thm list
    11   val pred_of_function : theory -> string -> string option
    12 end;
    13 
    14 structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN =
    15 struct
    16 
    17 fun is_funtype (Type ("fun", [_, _])) = true
    18   | is_funtype _ = false;
    19 
    20 fun is_Type (Type _) = true
    21   | is_Type _ = false
    22 
     23 (* returns true if t is an application of a datatype constructor *)
     24 (* (which consequently would be split), *)
     25 (* else false *)
    26 (*
    27 fun is_constructor thy t =
    28   if (is_Type (fastype_of t)) then
    29     (case DatatypePackage.get_datatype thy ((fst o dest_Type o fastype_of) t) of
    30       NONE => false
    31     | SOME info => (let
    32       val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
    33       val (c, _) = strip_comb t
    34       in (case c of
    35         Const (name, _) => name mem_string constr_consts
    36         | _ => false) end))
    37   else false
    38 *)
    39 
    40 (* must be exported in code.ML *)
    41 fun is_constr thy = is_some o Code.get_datatype_of_constr thy;
    42 
    43 (* Table from constant name (string) to term of inductive predicate *)
    44 structure Pred_Compile_Preproc = Theory_Data
    45 (
    46   type T = string Symtab.table;
    47   val empty = Symtab.empty;
    48   val extend = I;
    49   fun merge data : T = Symtab.merge (op =) data;   (* FIXME handle Symtab.DUP ?? *)
    50 )
    51 
    52 fun pred_of_function thy name = Symtab.lookup (Pred_Compile_Preproc.get thy) name
    53 
    54 fun defined thy = Symtab.defined (Pred_Compile_Preproc.get thy) 
    55 
    56 
    57 fun transform_ho_typ (T as Type ("fun", _)) =
    58   let
    59     val (Ts, T') = strip_type T
    60   in if T' = @{typ "bool"} then T else (Ts @ [T']) ---> HOLogic.boolT end
    61 | transform_ho_typ t = t
    62 
    63 fun transform_ho_arg arg = 
    64   case (fastype_of arg) of
    65     (T as Type ("fun", _)) =>
    66       (case arg of
    67         Free (name, _) => Free (name, transform_ho_typ T)
    68       | _ => error "I am surprised")
    69 | _ => arg
    70 
    71 fun pred_type T =
    72   let
    73     val (Ts, T') = strip_type T
    74     val Ts' = map transform_ho_typ Ts
    75   in
    76     (Ts' @ [T']) ---> HOLogic.boolT
    77   end;
    78 
    79 (* FIXME: create new predicate name -- does not avoid nameclashing *)
    80 fun pred_of f =
    81   let
    82     val (name, T) = dest_Const f
    83   in
    84     if (body_type T = @{typ bool}) then
    85       (Free (Long_Name.base_name name ^ "P", T))
    86     else
    87       (Free (Long_Name.base_name name ^ "P", pred_type T))
    88   end
    89 
    90 fun mk_param thy lookup_pred (t as Free (v, _)) = lookup_pred t
    91   | mk_param thy lookup_pred t =
    92   let
    93   val _ = tracing ("called param with " ^ (Syntax.string_of_term_global thy t))
    94   in if Predicate_Compile_Aux.is_predT (fastype_of t) then
    95     t
    96   else
    97     let
    98       val (vs, body) = strip_abs t
    99       val names = Term.add_free_names body []
   100       val vs_names = Name.variant_list names (map fst vs)
   101       val vs' = map2 (curry Free) vs_names (map snd vs)
   102       val body' = subst_bounds (rev vs', body)
   103       val (f, args) = strip_comb body'
   104       val resname = Name.variant (vs_names @ names) "res"
   105       val resvar = Free (resname, body_type (fastype_of body'))
   106       (*val P = case try lookup_pred f of SOME P => P | NONE => error "mk_param"
   107       val pred_body = list_comb (P, args @ [resvar])
   108       *)
   109       val pred_body = HOLogic.mk_eq (body', resvar)
   110       val param = fold_rev lambda (vs' @ [resvar]) pred_body
   111     in param end
   112   end
   113 (* creates the list of premises for every intro rule *)
   114 (* theory -> term -> (string list, term list list) *)
   115 
   116 fun dest_code_eqn eqn = let
   117   val (lhs, rhs) = Logic.dest_equals (Logic.unvarify (Thm.prop_of eqn))
   118   val (func, args) = strip_comb lhs
   119 in ((func, args), rhs) end;
   120 
   121 fun string_of_typ T = Syntax.string_of_typ_global @{theory} T
   122 
   123 fun string_of_term t =
   124   case t of
   125     Const (c, T) => "Const (" ^ c ^ ", " ^ string_of_typ T ^ ")"
   126   | Free (c, T) => "Free (" ^ c ^ ", " ^ string_of_typ T ^ ")"
   127   | Var ((c, i), T) => "Var ((" ^ c ^ ", " ^ string_of_int i ^ "), " ^ string_of_typ T ^ ")"
   128   | Bound i => "Bound " ^ string_of_int i
   129   | Abs (x, T, t) => "Abs (" ^ x ^ ", " ^ string_of_typ T ^ ", " ^ string_of_term t ^ ")"
   130   | t1 $ t2 => "(" ^ string_of_term t1 ^ ") $ (" ^ string_of_term t2 ^ ")"
   131   
   132 fun ind_package_get_nparams thy name =
   133   case try (Inductive.the_inductive (ProofContext.init thy)) name of
   134     SOME (_, result) => length (Inductive.params_of (#raw_induct result))
   135   | NONE => error ("No such predicate: " ^ quote name) 
   136 
   137 (* TODO: does not work with higher order functions yet *)
   138 fun mk_rewr_eq (func, pred) =
   139   let
   140     val (argTs, resT) = (strip_type (fastype_of func))
   141     val nctxt =
   142       Name.make_context (Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) (func $ pred) [])
   143     val (argnames, nctxt') = Name.variants (replicate (length argTs) "a") nctxt
   144     val ([resname], nctxt'') = Name.variants ["r"] nctxt'
   145     val args = map Free (argnames ~~ argTs)
   146     val res = Free (resname, resT)
   147   in Logic.mk_equals
   148       (HOLogic.mk_eq (res, list_comb (func, args)), list_comb (pred, args @ [res]))
   149   end;
   150 
   151 fun has_split_rule_cname @{const_name "nat_case"} = true
   152   | has_split_rule_cname @{const_name "list_case"} = true
   153   | has_split_rule_cname _ = false
   154   
   155 fun has_split_rule_term thy (Const (@{const_name "nat_case"}, _)) = true 
   156   | has_split_rule_term thy (Const (@{const_name "list_case"}, _)) = true 
   157   | has_split_rule_term thy _ = false
   158 
   159 fun has_split_rule_term' thy (Const (@{const_name "If"}, _)) = true
   160   | has_split_rule_term' thy (Const (@{const_name "Let"}, _)) = true
   161   | has_split_rule_term' thy c = has_split_rule_term thy c
   162   
   163 fun prepare_split_thm ctxt split_thm =
   164     (split_thm RS @{thm iffD2})
   165     |> LocalDefs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   166       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   167 
   168 fun find_split_thm thy (Const (name, typ)) =
   169   let
   170     fun split_name str =
   171       case first_field "." str
   172         of (SOME (field, rest)) => field :: split_name rest
   173          | NONE => [str]
   174     val splitted_name = split_name name
   175   in
   176     if length splitted_name > 0 andalso
   177        String.isSuffix "_case" (List.last splitted_name)
   178     then
   179       (List.take (splitted_name, length splitted_name - 1)) @ ["split"]
   180       |> space_implode "."
   181       |> PureThy.get_thm thy
   182       |> SOME
   183       handle ERROR msg => NONE
   184     else NONE
   185   end
   186   | find_split_thm _ _ = NONE
   187 
   188 fun find_split_thm' thy (Const (@{const_name "If"}, _)) = SOME @{thm split_if}
   189   | find_split_thm' thy (Const (@{const_name "Let"}, _)) = SOME @{thm refl} (* TODO *)
   190   | find_split_thm' thy c = find_split_thm thy c
   191 
   192 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
   193 
   194 fun folds_map f xs y =
   195   let
   196     fun folds_map' acc [] y = [(rev acc, y)]
   197       | folds_map' acc (x :: xs) y =
   198         maps (fn (x, y) => folds_map' (x :: acc) xs y) (f x y)
   199     in
   200       folds_map' [] xs y
   201     end;
   202 
    (* mk_prems thy (lookup_pred, get_nparams) t (names, prems):
       flattens the function-style term t.  Returns a list of alternatives
       (t', (names', prems')): t' is the term standing for t (for a general
       application a fresh result variable), names' extends the used names
       and prems' extends prems with the premises introduced on the way.
       Several alternatives arise from split rules and from combining the
       alternatives of the arguments (folds_map).
       lookup_pred maps a function Const/Free to its predicate term; for
       constants a failing lookup (caught with try) means the constant is
       kept.  get_nparams gives the number of parameters of a predicate term. *)
    203 fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) =
    204   let
            (* constructors and constants without predicate are kept; other
               constants are replaced by their predicate *)
    205     fun mk_prems' (t as Const (name, T)) (names, prems) =
    206       if is_constr thy name orelse (is_none (try lookup_pred t)) then
    207         [(t, (names, prems))]
    208       else [(lookup_pred t, (names, prems))]
            (* free variables are always translated by lookup_pred *)
    209     | mk_prems' (t as Free (f, T)) (names, prems) = 
    210       [(lookup_pred t, (names, prems))]
            (* abstractions are only accepted if they already have predicate type *)
    211     | mk_prems' (t as Abs _) (names, prems) =
    212       if Predicate_Compile_Aux.is_predT (fastype_of t) then
    213       [(t, (names, prems))] else error "mk_prems': Abs "
    214       (* mk_param *)
            (* applications *)
    215     | mk_prems' t (names, prems) =
    216       if Predicate_Compile_Aux.is_constrt thy t then
                (* constructor terms are kept unchanged *)
    217         [(t, (names, prems))]
    218       else
    219         if has_split_rule_term' thy (fst (strip_comb t)) then
                  (* the head (If, Let, nat_case or list_case) has a split rule:
                     match the prepared split rule against t; every assumption
                     of the instantiated rule yields its own alternatives *)
    220           let
    221             val (f, args) = strip_comb t
                    (* NOTE(review): the raises Option if find_split_thm' returns
                       NONE although has_split_rule_term' accepted the head *)
    222             val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f))
    223             (* TODO: contextify things - this line is to unvarify the split_thm *)
    224             (*val ((_, [isplit_thm]), _) = Variable.import true [split_thm] (ProofContext.init thy)*)
    225             val (assms, concl) = Logic.strip_horn (Thm.prop_of split_thm)
                    (* NOTE(review): P and match below are unused; match (~~) presumably
                       raises UnequalLengths if split_args and args differ in length *)
    226             val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
    227             val subst = Pattern.match thy (split_t, t) (Vartab.empty, Vartab.empty)
    228             val (_, split_args) = strip_comb split_t
    229             val match = split_args ~~ args
                    (* one assumption: instantiate it, replace its outermost
                       !!-bound variables by Frees with names avoiding names,
                       add its premises to prems and continue with the single
                       argument of its conclusion *)
    230             fun mk_prems_of_assm assm =
    231               let
    232                 val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
    233                 val var_names = Name.variant_list names (map fst vTs)
    234                 val vars = map Free (var_names ~~ (map snd vTs))
    235                 val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
    236                 val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
    237               in
    238                 mk_prems' inner_t (var_names @ names, prems' @ prems)
    239               end
    240           in
    241             maps mk_prems_of_assm assms
    242           end
    243         else
                  (* general application f args: introduce a result variable with a
                     fresh variant of the name res and a premise relating f args to
                     it; resvar is the result term of every alternative *)
    244           let
    245             val (f, args) = strip_comb t
    246             (* TODO: special procedure for higher-order functions: split arguments in
    247               simple types and function types *)
    248             val resname = Name.variant names "res"
    249             val resvar = Free (resname, body_type (fastype_of t))
    250             val names' = resname :: names
                    (* constructor or constant without predicate:
                       premise  resvar = f argvs *)
    251             fun mk_prems'' (t as Const (c, _)) =
    252               if is_constr thy c orelse (is_none (try lookup_pred t)) then
    253                 folds_map mk_prems' args (names', prems) |>
    254                 map
    255                   (fn (argvs, (names'', prems')) =>
    256                   let
    257                     val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs)))
    258                   in (names'', prem :: prems') end)
    259               else
                        (* constant with predicate: the first nparams arguments are
                           parameters, translated by mk_param; premise
                           pred params' argvs resvar *)
    260                 let
    261                   val pred = lookup_pred t
    262                   val nparams = get_nparams pred
    263                   val (params, args) = chop nparams args
    264                   val params' = map (mk_param thy lookup_pred) params
    265                 in
    266                   folds_map mk_prems' args (names', prems)
    267                   |> map (fn (argvs, (names'', prems')) =>
    268                     let
    269                       val prem = HOLogic.mk_Trueprop (list_comb (pred, params' @ argvs @ [resvar]))
    270                     in (names'', prem :: prems') end)
    271                 end
                    (* free head (higher-order argument): premise  pred argvs resvar;
                       get_nparams is not consulted here *)
    272             | mk_prems'' (t as Free (_, _)) =
    273                 let
    274                   (* higher order argument call *)
    275                   val pred = lookup_pred t
    276                 in
    277                   folds_map mk_prems' args (resname :: names, prems)
    278                   |> map (fn (argvs, (names', prems')) =>
    279                      let
    280                        val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs @ [resvar]))
    281                      in (names', prem :: prems') end)
    282                 end
    283             | mk_prems'' t =
    284               error ("Invalid term: " ^ Syntax.string_of_term_global thy t)
    285           in
    286             map (pair resvar) (mk_prems'' f)
    287           end
    288   in
    289     mk_prems' t (names, prems)
    290   end;
   291 
   292 (* assumption: mutual recursive predicates all have the same parameters. *)  
   293 fun define_predicates specs thy =
   294   if forall (fn (const, _) => member (op =) (Symtab.keys (Pred_Compile_Preproc.get thy)) const) specs then
   295     ([], thy)
   296   else
   297   let
   298     val consts = map fst specs
   299     val eqns = maps snd specs
   300     (*val eqns = maps (Predicate_Compile_Preproc_Data.get_specification thy) consts*)
   301       (* create prednames *)
   302     val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list
   303     val argss' = map (map transform_ho_arg) argss
   304     val pnames = map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss'))
   305     val preds = map pred_of funs
   306     val prednames = map (fst o dest_Free) preds
   307     val funnames = map (fst o dest_Const) funs
   308     val fun_pred_names = (funnames ~~ prednames)  
   309       (* mapping from term (Free or Const) to term *)
   310     fun lookup_pred (Const (name, T)) =
   311       (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
   312           SOME c => Const (c, pred_type T)
   313         | NONE =>
   314           (case AList.lookup op = fun_pred_names name of
   315             SOME f => Free (f, pred_type T)
   316           | NONE => Const (name, T)))
   317       | lookup_pred (Free (name, T)) =
   318         if member op = (map fst pnames) name then
   319           Free (name, transform_ho_typ T)
   320         else
   321           Free (name, T)
   322       | lookup_pred t =
   323          error ("lookup function is not defined for " ^ Syntax.string_of_term_global thy t)
   324      
   325         (* mapping from term (predicate term, not function term!) to int *)
   326     fun get_nparams (Const (name, _)) =
   327       the_default 0 (try (ind_package_get_nparams thy) name)
   328     | get_nparams (Free (name, _)) =
   329         (if member op = prednames name then
   330           length pnames
   331         else 0)
   332     | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
   333   
   334     (* create intro rules *)
   335   
   336     fun mk_intros ((func, pred), (args, rhs)) =
   337       if (body_type (fastype_of func) = @{typ bool}) then
   338        (*TODO: preprocess predicate definition of rhs *)
   339         [Logic.list_implies ([HOLogic.mk_Trueprop rhs], HOLogic.mk_Trueprop (list_comb (pred, args)))]
   340       else
   341         let
   342           val names = Term.add_free_names rhs []
   343         in mk_prems thy (lookup_pred, get_nparams) rhs (names, [])
   344           |> map (fn (resultt, (names', prems)) =>
   345             Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt]))))
   346         end
   347     fun mk_rewr_thm (func, pred) = @{thm refl}
   348   in
   349     case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of
   350       NONE => ([], thy) 
   351     | SOME intr_ts =>
   352         if is_some (try (map (cterm_of thy)) intr_ts) then
   353           let
   354             val (ind_result, thy') =
   355               thy
   356               |> Sign.map_naming Name_Space.conceal
   357               |> Inductive.add_inductive_global (serial ())
   358                 {quiet_mode = false, verbose = false, alt_name = Binding.empty, coind = false,
   359                   no_elim = false, no_ind = false, skip_mono = false, fork_mono = false}
   360                 (map (fn (s, T) =>
   361                   ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds)))
   362                 pnames
   363                 (map (fn x => (Attrib.empty_binding, x)) intr_ts)
   364                 []
   365               ||> Sign.restore_naming thy
   366             val prednames = map (fst o dest_Const) (#preds ind_result)
   367             (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *)
   368             (* add constants to my table *)
   369             val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname) (#intrs ind_result))) prednames
   370             val thy'' = Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy'
   371           in
   372             (specs, thy'')
   373           end
   374         else
   375           let
   376             val _ = tracing "Introduction rules of function_predicate are not welltyped"
   377           in ([], thy) end
   378   end
   379 
   380 fun rewrite_intro thy intro =
   381   let
   382     fun lookup_pred (Const (name, T)) =
   383       (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of
   384         SOME c => Const (c, pred_type T)
   385       | NONE => error ("Function " ^ name ^ " is not inductified"))
   386     | lookup_pred (Free (name, T)) = Free (name, T)
   387     | lookup_pred _ = error "lookup function is not defined!"
   388 
   389     fun get_nparams (Const (name, _)) =
   390       the_default 0 (try (ind_package_get_nparams thy) name)
   391     | get_nparams (Free _) = 0
   392     | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t))
   393     
   394     val intro_t = (Logic.unvarify o prop_of) intro
   395     val (prems, concl) = Logic.strip_horn intro_t
   396     val frees = map fst (Term.add_frees intro_t [])
   397     fun rewrite prem names =
   398       let
   399         val t = (HOLogic.dest_Trueprop prem)
   400         val (lit, mk_lit) = case try HOLogic.dest_not t of
   401             SOME t => (t, HOLogic.mk_not)
   402           | NONE => (t, I)
   403         val (P, args) = (strip_comb lit) 
   404       in
   405         folds_map (
   406           fn t => if (is_funtype (fastype_of t)) then (fn x => [(t, x)])
   407             else mk_prems thy (lookup_pred, get_nparams) t) args (names, [])
   408         |> map (fn (resargs, (names', prems')) =>
   409           let
   410             val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))
   411           in (prem'::prems', names') end)
   412       end
   413     val intro_ts' = folds_map rewrite prems frees
   414       |> maps (fn (prems', frees') =>
   415         rewrite concl frees'
   416         |> map (fn (concl'::conclprems, _) =>
   417           Logic.list_implies ((flat prems') @ conclprems, concl')))
   418   in
   419     map (Drule.standard o (Skip_Proof.make_thm thy)) intro_ts'
   420   end
   421 
   422 end;