src/HOL/Tools/Function/fun.ML
author wenzelm
Fri Mar 16 18:20:12 2012 +0100 (2012-03-16)
changeset 46961 5c6955f487e5
parent 45639 efddd75c741e
child 48099 e7e647949c95
permissions -rw-r--r--
outer syntax command definitions based on formal command_spec derived from theory header declarations;
(*  Title:      HOL/Tools/Function/fun.ML
    Author:     Alexander Krauss, TU Muenchen

Command "fun": Function definitions with pattern splitting/completion
and automated termination proofs.
*)

     8 signature FUNCTION_FUN =
     9 sig
    10   val fun_config : Function_Common.function_config
    11   val add_fun : (binding * typ option * mixfix) list ->
    12     (Attrib.binding * term) list -> Function_Common.function_config ->
    13     local_theory -> Proof.context
    14   val add_fun_cmd : (binding * string option * mixfix) list ->
    15     (Attrib.binding * string) list -> Function_Common.function_config ->
    16     bool -> local_theory -> Proof.context
    17 
    18   val setup : theory -> theory
    19 end
    20 
    21 structure Function_Fun : FUNCTION_FUN =
    22 struct
    23 
    24 open Function_Lib
    25 open Function_Common
    26 
    27 
    28 fun check_pats ctxt geq =
    29   let
    30     fun err str = error (cat_lines ["Malformed definition:",
    31       str ^ " not allowed in sequential mode.",
    32       Syntax.string_of_term ctxt geq])
    33     val thy = Proof_Context.theory_of ctxt
    34 
    35     fun check_constr_pattern (Bound _) = ()
    36       | check_constr_pattern t =
    37       let
    38         val (hd, args) = strip_comb t
    39       in
    40         (((case Datatype.info_of_constr thy (dest_Const hd) of
    41              SOME _ => ()
    42            | NONE => err "Non-constructor pattern")
    43           handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    44          map check_constr_pattern args;
    45          ())
    46       end
    47 
    48     val (_, qs, gs, args, _) = split_def ctxt (K true) geq
    49 
    50     val _ = if not (null gs) then err "Conditional equations" else ()
    51     val _ = map check_constr_pattern args
    52 
    53     (* just count occurrences to check linearity *)
    54     val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    55       then err "Nonlinear patterns" else ()
    56   in
    57     ()
    58   end
    59 
    60 fun mk_catchall fixes arity_of =
    61   let
    62     fun mk_eqn ((fname, fT), _) =
    63       let
    64         val n = arity_of fname
    65         val (argTs, rT) = chop n (binder_types fT)
    66           |> apsnd (fn Ts => Ts ---> body_type fT)
    67 
    68         val qs = map Free (Name.invent Name.context "a" n ~~ argTs)
    69       in
    70         HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    71           Const ("HOL.undefined", rT))
    72         |> HOLogic.mk_Trueprop
    73         |> fold_rev Logic.all qs
    74       end
    75   in
    76     map mk_eqn fixes
    77   end
    78 
    79 fun add_catchall ctxt fixes spec =
    80   let val fqgars = map (split_def ctxt (K true)) spec
    81       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    82                      |> AList.lookup (op =) #> the
    83   in
    84     spec @ mk_catchall fixes arity_of
    85   end
    86 
    87 fun warnings ctxt origs tss =
    88   let
    89     fun warn_redundant t =
    90       warning ("Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t))
    91     fun warn_missing strs =
    92       warning (cat_lines ("Missing patterns in function definition:" :: strs))
    93 
    94     val (tss', added) = chop (length origs) tss
    95 
    96     val _ = case chop 3 (flat added) of
    97        ([], []) => ()
    98      | (eqs, []) => warn_missing (map (Syntax.string_of_term ctxt) eqs)
    99      | (eqs, rest) => warn_missing (map (Syntax.string_of_term ctxt) eqs
   100          @ ["(" ^ string_of_int (length rest) ^ " more)"])
   101 
   102     val _ = (origs ~~ tss')
   103       |> map (fn (t, ts) => if null ts then warn_redundant t else ())
   104   in
   105     ()
   106   end
   107 
   108 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   109   if sequential then
   110     let
   111       val (bnds, eqss) = split_list spec
   112 
   113       val eqs = map the_single eqss
   114 
   115       val feqs = eqs
   116         |> tap (check_defs ctxt fixes) (* Standard checks *)
   117         |> tap (map (check_pats ctxt)) (* More checks for sequential mode *)
   118 
   119       val compleqs = add_catchall ctxt fixes feqs (* Completion *)
   120 
   121       val spliteqs = Function_Split.split_all_equations ctxt compleqs
   122         |> tap (warnings ctxt feqs)
   123 
   124       fun restore_spec thms =
   125         bnds ~~ take (length bnds) (unflat spliteqs thms)
   126 
   127       val spliteqs' = flat (take (length bnds) spliteqs)
   128       val fnames = map (fst o fst) fixes
   129       val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   130 
   131       fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   132         |> map (map snd)
   133 
   134 
   135       val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   136 
   137       (* using theorem names for case name currently disabled *)
   138       val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   139         (bnds' ~~ spliteqs) |> flat
   140     in
   141       (flat spliteqs, restore_spec, sort, case_names)
   142     end
   143   else
   144     Function_Common.empty_preproc check_defs config ctxt fixes spec
   145 
   146 val setup =
   147   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   148 
   149 
   150 val fun_config = FunctionConfig { sequential=true, default=NONE,
   151   domintros=false, partials=false }
   152 
   153 fun gen_add_fun add lthy =
   154   let
   155     fun pat_completeness_auto ctxt =
   156       Pat_Completeness.pat_completeness_tac ctxt 1
   157       THEN auto_tac ctxt
   158     fun prove_termination lthy =
   159       Function.prove_termination NONE
   160         (Function_Common.get_termination_prover lthy lthy) lthy
   161   in
   162     lthy
   163     |> add pat_completeness_auto |> snd
   164     |> prove_termination |> snd
   165   end
   166 
   167 fun add_fun a b c = gen_add_fun (Function.add_function a b c)
   168 fun add_fun_cmd a b c int = gen_add_fun (fn tac => Function.add_function_cmd a b c tac int)
   169 
   170 
   171 
   172 val _ =
   173   Outer_Syntax.local_theory' @{command_spec "fun"}
   174     "define general recursive functions (short version)"
   175     (function_parser fun_config
   176       >> (fn ((config, fixes), statements) => add_fun_cmd fixes statements config))
   177 
   178 end