src/HOL/Tools/Function/fun.ML
author krauss
Wed Apr 28 11:52:04 2010 +0200 (2010-04-28)
changeset 36521 73ed9f18fdd3
parent 36519 46bf776a81e0
child 36523 a294e4ebe0a3
permissions -rw-r--r--
default termination prover as plain tactic
     1 (*  Title:      HOL/Tools/Function/fun.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Sequential mode for function definitions
     5 Command "fun" for fully automated function definitions
     6 *)
     7 
     8 signature FUNCTION_FUN =
     9 sig
         (* add_fun config fixes specs int lthy: define the fixed functions by the
            given equations, using already-parsed types and terms.  The
            definition is followed by automatic proofs of pattern completeness
            and termination.  NOTE(review): the bool is handed to the proof
            steps, presumably the interactive-mode flag -- confirm. *)
    10   val add_fun : Function_Common.function_config ->
    11     (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    12     bool -> local_theory -> Proof.context
         (* add_fun_cmd: like add_fun, but types and equations are given as
            unparsed strings; this is the variant used by the "fun" command. *)
    13   val add_fun_cmd : Function_Common.function_config ->
    14     (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    15     bool -> local_theory -> Proof.context
    16 
         (* setup: install the sequential-mode preprocessor of "fun" in the
            function package. *)
    17   val setup : theory -> theory
    18 end
    19 
    20 structure Function_Fun : FUNCTION_FUN =
    21 struct
    22 
    23 open Function_Lib
    24 open Function_Common
    25 
    26 
       (* check_pats ctxt geq: reject an equation that is not admissible in
          sequential mode.  The equation must satisfy three conditions:
            - it is unconditional;
            - every argument is built only from datatype constructors and
              pattern variables;
            - no variable occurs twice.
          Any violation is reported through [error] together with the printed
          equation geq. *)
    27 fun check_pats ctxt geq =
    28   let
    29     fun err str = error (cat_lines ["Malformed definition:",
    30       str ^ " not allowed in sequential mode.",
    31       Syntax.string_of_term ctxt geq])
    32     val thy = ProofContext.theory_of ctxt
    33 
           (* A Bound is accepted as a pattern variable.  Any other pattern needs
              a constant head that Datatype.info_of_constr recognises as a
              constructor, and its arguments are checked recursively.  The head
              is matched directly rather than by catching the TERM raised by
              dest_Const, so the check does not depend on that exception's
              message text. *)
    34     fun check_constr_pattern (Bound _) = ()
    35       | check_constr_pattern t =
    36       let
    37         val (hd, args) = strip_comb t
    38       in
               ((case hd of
                   Const c =>
                     (case Datatype.info_of_constr thy c of
                        SOME _ => ()
                      | NONE => err "Non-constructor pattern")
                 | _ => err "Non-constructor patterns");
                List.app check_constr_pattern args)
    45       end
    46 
    47     val (_, qs, gs, args, _) = split_def ctxt geq
    48 
    49     val _ = if not (null gs) then err "Conditional equations" else ()
           val _ = List.app check_constr_pattern args
    51 
           (* Linearity by counting: all Bound occurrences in the arguments are
              counted, and more occurrences than quantified variables qs means
              some variable is repeated.  NOTE(review): this presumes split_def
              represents the variables of qs as Bound indices in args --
              confirm against split_def. *)
    53     val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    54       then err "Nonlinear patterns" else ()
    55   in
    56     ()
    57   end
    58 
       (* Close the pattern-completeness goal left by a function definition.
          Pat_Completeness.pat_completeness is the initial method and HOL.auto
          the terminal method. *)
    59 val by_pat_completeness_auto =
    60   Proof.global_future_terminal_proof
    61     (Method.Basic Pat_Completeness.pat_completeness,
    62      SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    63 
       (* termination_by method int: open the termination proof
          (Function.termination NONE) and close it with the given method. *)
    64 fun termination_by method int =
    65   Function.termination NONE
    66   #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
    67 
       (* mk_catchall fixes arity_of: build one catch-all equation for every
          fixed function f of type fT:
            !!a1 ... an. f a1 ... an = undefined     where n = arity_of f.
          Argument types beyond the first n are folded back into the type of
          the undefined right-hand side. *)
    68 fun mk_catchall fixes arity_of =
    69   let
    70     fun mk_eqn ((fname, fT), _) =
    71       let
    72         val n = arity_of fname
    73         val (argTs, rT) = chop n (binder_types fT)
    74           |> apsnd (fn Ts => Ts ---> body_type fT)
    75 
    76         val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
    77       in
    78         HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    79           Const ("HOL.undefined", rT))
    80         |> HOLogic.mk_Trueprop
    81         |> fold_rev Logic.all qs
    82       end
    83   in
    84     map mk_eqn fixes
    85   end
    86 
       (* add_catchall ctxt fixes spec: append one catch-all equation per fixed
          function to spec.  A function's arity is the number of arguments in
          its first equation in spec (AList.lookup returns the first match).
          If a fixed function has no equation in spec, `the` raises Option. *)
    87 fun add_catchall ctxt fixes spec =
    88   let val fqgars = map (split_def ctxt) spec
    89       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    90                      |> AList.lookup (op =) #> the
    91   in
    92     spec @ mk_catchall fixes arity_of
    93   end
    94 
       (* warn_if_redundant ctxt origs tss: tss holds the groups of split
          equations, and its first (length origs) groups belong to the original
          equations origs.  Every original equation whose group is empty is
          reported as redundant with a warning.  tss is returned unchanged. *)
    95 fun warn_if_redundant ctxt origs tss =
    96   let
    97     fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
    98 
    99     val (tss', _) = chop (length origs) tss
           (* Only the warnings matter, so iterate with List.app instead of
              building a list that is thrown away. *)
           fun check (t, []) = warning (msg t)
             | check _ = ()
   102   in
           (List.app check (origs ~~ tss'); tss)
   104   end
   105 
       (* sequential_preproc config ctxt fixes spec: the preprocessor installed
          by setup.

          In sequential mode each specification must contain exactly one
          equation (the_single).  The equations are then checked, completed
          with catch-all equations, and split with
          Function_Split.split_all_equations.  The result is a 4-tuple:
            - the flat list of split equations;
            - restore_spec, which regroups proven theorems under the user's
              bindings;
            - sort, which groups values aligned with the user's split equations
              by the index of their function in fixes;
            - the case names.

          When the sequential option is off, the call is delegated to
          Function_Common.empty_preproc check_defs. *)
   106 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   107   if sequential then
   108     let
   109       val (bnds, eqss) = split_list spec
   110 
   111       val eqs = map the_single eqss
   112 
   113       val feqs = eqs
   114         |> tap (check_defs ctxt fixes) (* Standard checks *)
             |> tap (List.app (check_pats ctxt)) (* More checks for sequential mode *)
   116 
   117       val compleqs = add_catchall ctxt fixes feqs (* Completion *)
   118 
   119       val spliteqs = warn_if_redundant ctxt feqs
   120         (Function_Split.split_all_equations ctxt compleqs)
   121 
             (* Regroup the flat theorem list along spliteqs.  Keep only the first
                (length bnds) groups, i.e. those of the user's equations, which
                come before the catch-all ones in compleqs, and pair them with
                their bindings. *)
   122       fun restore_spec thms =
   123         bnds ~~ take (length bnds) (unflat spliteqs thms)
   124 
   125       val spliteqs' = flat (take (length bnds) spliteqs)
   126       val fnames = map (fst o fst) fixes
   127       val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   128 
             (* Group xs (positionally aligned with spliteqs') by the index of the
                defined function's name in fixes. *)
   129       fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   130         |> map (map snd)
   131 
   132 
             (* Groups beyond the user's equations (presumably the catch-all
                ones) get empty bindings. *)
   133       val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   134 
   135       (* using theorem names for case name currently disabled *)
   136       val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
   137         (bnds' ~~ spliteqs) |> flat
   138     in
   139       (flat spliteqs, restore_spec, sort, case_names)
   140     end
   141   else
   142     Function_Common.empty_preproc check_defs config ctxt fixes spec
   143 
       (* Install sequential_preproc as the function-package preprocessor. *)
   144 val setup =
   145   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   146 
   147 
       (* Configuration of the "fun" command: sequential mode, "%x. undefined"
          as default, and domintros, partials and tailrec switched off. *)
   148 val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*),
   149   domintros=false, partials=false, tailrec=false }
   150 
       (* gen_fun add config fixes statements int lthy: define the functions with
          add (Function.function or Function.function_cmd) and prove pattern
          completeness by pat_completeness + auto.  Then restore the local
          theory and prove termination.  The termination prover is looked up
          in lthy, i.e. the context before the definition. *)
   151 fun gen_fun add config fixes statements int lthy =
   152   lthy
   153   |> add fixes statements config
   154   |> by_pat_completeness_auto int
   155   |> Local_Theory.restore
   156   |> termination_by (SIMPLE_METHOD o Function_Common.get_termination_prover lthy) int
   157 
   158 val add_fun = gen_fun Function.function
   159 val add_fun_cmd = gen_fun Function.function_cmd
   160 
   161 
   162 
   163 local structure P = OuterParse and K = OuterKeyword in
   164 
       (* Outer-syntax command "fun": parsed with function_parser fun_config and
          executed by add_fun_cmd. *)
   165 val _ =
   166   OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl
   167   (function_parser fun_config
   168      >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements));
   169 
   170 end
   171 
   172 end