src/HOL/Tools/Function/fun.ML
author wenzelm
Sat Mar 22 18:19:57 2014 +0100 (2014-03-22)
changeset 56254 a2dd9200854d
parent 54407 e95831757903
child 58826 2ed2eaabe3df
permissions -rw-r--r--
more antiquotations;
     1 (*  Title:      HOL/Tools/Function/fun.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Command "fun": Function definitions with pattern splitting/completion
     5 and automated termination proofs.
     6 *)
     7 
     8 signature FUNCTION_FUN =
     9 sig
    10   val fun_config : Function_Common.function_config
    11   val add_fun : (binding * typ option * mixfix) list ->
    12     (Attrib.binding * term) list -> Function_Common.function_config ->
    13     local_theory -> Proof.context
    14   val add_fun_cmd : (binding * string option * mixfix) list ->
    15     (Attrib.binding * string) list -> Function_Common.function_config ->
    16     bool -> local_theory -> Proof.context
    17 
    18   val setup : theory -> theory
    19 end
    20 
    21 structure Function_Fun : FUNCTION_FUN =
    22 struct
    23 
    24 open Function_Lib
    25 open Function_Common
    26 
    27 
    28 fun check_pats ctxt geq =
    29   let
    30     fun err str = error (cat_lines ["Malformed definition:",
    31       str ^ " not allowed in sequential mode.",
    32       Syntax.string_of_term ctxt geq])
    33     val thy = Proof_Context.theory_of ctxt
    34 
    35     fun check_constr_pattern (Bound _) = ()
    36       | check_constr_pattern t =
    37       let
    38         val (hd, args) = strip_comb t
    39       in
    40         (case hd of
    41           Const (hd_s, hd_T) =>
    42           (case body_type hd_T of
    43             Type (Tname, _) =>
    44             (case Ctr_Sugar.ctr_sugar_of ctxt Tname of
    45               SOME {ctrs, ...} => exists (fn Const (s, _) => s = hd_s) ctrs
    46             | NONE => false)
    47           | _ => false)
    48         | _ => false) orelse err "Non-constructor pattern";
    49         map check_constr_pattern args;
    50         ()
    51       end
    52 
    53     val (_, qs, gs, args, _) = split_def ctxt (K true) geq
    54 
    55     val _ = if not (null gs) then err "Conditional equations" else ()
    56     val _ = map check_constr_pattern args
    57 
    58     (* just count occurrences to check linearity *)
    59     val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    60       then err "Nonlinear patterns" else ()
    61   in
    62     ()
    63   end
    64 
    65 fun mk_catchall fixes arity_of =
    66   let
    67     fun mk_eqn ((fname, fT), _) =
    68       let
    69         val n = arity_of fname
    70         val (argTs, rT) = chop n (binder_types fT)
    71           |> apsnd (fn Ts => Ts ---> body_type fT)
    72 
    73         val qs = map Free (Name.invent Name.context "a" n ~~ argTs)
    74       in
    75         HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    76           Const (@{const_name undefined}, rT))
    77         |> HOLogic.mk_Trueprop
    78         |> fold_rev Logic.all qs
    79       end
    80   in
    81     map mk_eqn fixes
    82   end
    83 
    84 fun add_catchall ctxt fixes spec =
    85   let val fqgars = map (split_def ctxt (K true)) spec
    86       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    87                      |> AList.lookup (op =) #> the
    88   in
    89     spec @ mk_catchall fixes arity_of
    90   end
    91 
    92 fun further_checks ctxt origs tss =
    93   let
    94     fun fail_redundant t =
    95       error (cat_lines ["Equation is redundant (covered by preceding clauses):", Syntax.string_of_term ctxt t])
    96     fun warn_missing strs =
    97       warning (cat_lines ("Missing patterns in function definition:" :: strs))
    98 
    99     val (tss', added) = chop (length origs) tss
   100 
   101     val _ = case chop 3 (flat added) of
   102        ([], []) => ()
   103      | (eqs, []) => warn_missing (map (Syntax.string_of_term ctxt) eqs)
   104      | (eqs, rest) => warn_missing (map (Syntax.string_of_term ctxt) eqs
   105          @ ["(" ^ string_of_int (length rest) ^ " more)"])
   106 
   107     val _ = (origs ~~ tss')
   108       |> map (fn (t, ts) => if null ts then fail_redundant t else ())
   109   in
   110     ()
   111   end
   112 
   113 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   114   if sequential then
   115     let
   116       val (bnds, eqss) = split_list spec
   117 
   118       val eqs = map the_single eqss
   119 
   120       val feqs = eqs
   121         |> tap (check_defs ctxt fixes) (* Standard checks *)
   122         |> tap (map (check_pats ctxt)) (* More checks for sequential mode *)
   123 
   124       val compleqs = add_catchall ctxt fixes feqs (* Completion *)
   125 
   126       val spliteqs = Function_Split.split_all_equations ctxt compleqs
   127         |> tap (further_checks ctxt feqs)
   128 
   129       fun restore_spec thms =
   130         bnds ~~ take (length bnds) (unflat spliteqs thms)
   131 
   132       val spliteqs' = flat (take (length bnds) spliteqs)
   133       val fnames = map (fst o fst) fixes
   134       val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   135 
   136       fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   137         |> map (map snd)
   138 
   139 
   140       val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   141 
   142       (* using theorem names for case name currently disabled *)
   143       val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   144         (bnds' ~~ spliteqs) |> flat
   145     in
   146       (flat spliteqs, restore_spec, sort, case_names)
   147     end
   148   else
   149     Function_Common.empty_preproc check_defs config ctxt fixes spec
   150 
   151 val setup =
   152   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   153 
   154 
   155 val fun_config = FunctionConfig { sequential=true, default=NONE,
   156   domintros=false, partials=false }
   157 
   158 fun gen_add_fun add lthy =
   159   let
   160     fun pat_completeness_auto ctxt =
   161       Pat_Completeness.pat_completeness_tac ctxt 1
   162       THEN auto_tac ctxt
   163     fun prove_termination lthy =
   164       Function.prove_termination NONE
   165         (Function_Common.get_termination_prover lthy) lthy
   166   in
   167     lthy
   168     |> add pat_completeness_auto |> snd
   169     |> prove_termination |> snd
   170   end
   171 
   172 fun add_fun a b c = gen_add_fun (Function.add_function a b c)
   173 fun add_fun_cmd a b c int = gen_add_fun (fn tac => Function.add_function_cmd a b c tac int)
   174 
   175 
   176 
   177 val _ =
   178   Outer_Syntax.local_theory' @{command_spec "fun"}
   179     "define general recursive functions (short version)"
   180     (function_parser fun_config
   181       >> (fn ((config, fixes), statements) => add_fun_cmd fixes statements config))
   182 
   183 end