src/HOL/Tools/Function/fun.ML
author wenzelm
Mon May 17 23:54:15 2010 +0200 (2010-05-17)
changeset 36960 01594f816e3a
parent 36547 2a9d0ec8c10d
child 39276 2ad95934521f
permissions -rw-r--r--
prefer structure Keyword, Parse, Parse_Spec, Outer_Syntax;
eliminated old-style structure aliases K = Keyword, P = Parse;
     1 (*  Title:      HOL/Tools/Function/fun.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Sequential mode for function definitions
     5 Command "fun" for fully automated function definitions
     6 *)
     7 
     (* Interface of the sequential-mode front end behind the "fun" command. *)
     signature FUNCTION_FUN =
     sig
       (* Define functions whose fixes carry optional typ values and whose
          equations are given as terms. *)
       val add_fun :
         (binding * typ option * mixfix) list ->
         (Attrib.binding * term) list ->
         Function_Common.function_config ->
         local_theory -> Proof.context

       (* Variant of add_fun where types and equations are given as strings. *)
       val add_fun_cmd :
         (binding * string option * mixfix) list ->
         (Attrib.binding * string) list ->
         Function_Common.function_config ->
         local_theory -> Proof.context

       (* Theory-level setup of this module. *)
       val setup : theory -> theory
     end
    19 
    20 structure Function_Fun : FUNCTION_FUN =
    21 struct
    22 
    23 open Function_Lib
    24 open Function_Common
    25 
    26 
    (* Extra checks that sequential mode imposes on a single equation geq.
       Fails via error (printing geq) if the equation has conditions, if one of
       its argument patterns has a head that is not accepted as a constructor,
       or if the patterns contain more bound-variable occurrences than there
       are variables in qs.  Returns () otherwise. *)
    fun check_pats ctxt geq =
      let
        val thy = ProofContext.theory_of ctxt

        (* Abort with a message naming the offending feature, followed by geq *)
        fun err str =
          error (cat_lines
            ["Malformed definition:",
             str ^ " not allowed in sequential mode.",
             Syntax.string_of_term ctxt geq])

        (* A pattern head is accepted iff it is a constant for which
           Datatype.info_of_constr has an entry.  A head that is no constant at
           all makes dest_Const raise TERM, which is reported separately. *)
        fun check_head hd =
          (case Datatype.info_of_constr thy (dest_Const hd) of
            NONE => err "Non-constructor pattern"
          | SOME _ => ())
          handle TERM ("dest_Const", _) => err "Non-constructor patterns"

        (* Bound variables are always fine; any other pattern is checked head
           first, then argument by argument *)
        fun check_constr_pattern (Bound _) = ()
          | check_constr_pattern t =
              let
                val (hd, args) = strip_comb t
              in
                (check_head hd; app check_constr_pattern args)
              end

        val (_, qs, gs, args, _) = split_def ctxt geq

        val _ = if null gs then () else err "Conditional equations"
        val _ = app check_constr_pattern args

        (* Linearity is checked by counting only: the total number of Bound
           leaves in all argument patterns is compared with length qs *)
        val bound_occurrences =
          fold (fold_aterms (fn Bound _ => (fn n => n + 1) | _ => I)) args 0
      in
        if bound_occurrences > length qs then err "Nonlinear patterns" else ()
      end
    58 
    (* For every entry ((fname, fT), _) of fixes, build the catch-all equation
         !!a... . fname a... = undefined
       where the number of arguments is arity_of fname and the right-hand side
       is the constant "HOL.undefined" at the remaining result type. *)
    fun mk_catchall fixes arity_of =
      let
        fun undefined_eqn ((fname, fT), _) =
          let
            val arity = arity_of fname

            (* The first arity binder types become arguments; the remaining
               ones are folded back into the type of the right-hand side *)
            val (argTs, restTs) = chop arity (binder_types fT)
            val resT = restTs ---> body_type fT

            val vars = map Free (Name.invent_list [] "a" arity ~~ argTs)
            val lhs = list_comb (Free (fname, fT), vars)
            val rhs = Const ("HOL.undefined", resT)
          in
            fold_rev Logic.all vars (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)))
          end
      in
        map undefined_eqn fixes
      end
    77 
    (* Append to spec the catch-all equations produced by mk_catchall for fixes.
       The arity of each function is the number of arguments found in the first
       equation of spec that defines it, as reported by split_def. *)
    fun add_catchall ctxt fixes spec =
      let
        fun name_and_arity (fname, _, _, args, _) = (fname, length args)
        val arities = map (name_and_arity o split_def ctxt) spec

        (* Raises Option if a function has no equation in spec *)
        fun arity_of fname = the (AList.lookup (op =) arities fname)
      in
        spec @ mk_catchall fixes arity_of
      end
    85 
    (* Emit a warning for every equation in origs whose positionally matching
       entry in tss is the empty list.  Only the first length origs entries of
       tss are inspected.  The result is tss itself, unchanged. *)
    fun warn_if_redundant ctxt origs tss =
      let
        fun warn_unused (orig, []) =
              warning ("Ignoring redundant equation: " ^
                quote (Syntax.string_of_term ctxt orig))
          | warn_unused _ = ()

        val orig_splits = take (length origs) tss
      in
        app warn_unused (origs ~~ orig_splits);
        tss
      end
    96 
    (* Preprocessor for function specifications.
       Without the sequential flag this simply delegates to
       Function_Common.empty_preproc.  In sequential mode the equations are
       checked, completed with catch-all equations and split by
       Function_Split.split_all_equations.  The result is the tuple
         (all split equations, restore_spec, sort, case_names). *)
    fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
      if not sequential then
        Function_Common.empty_preproc check_defs config ctxt fixes spec
      else
        let
          val (bnds, eqss) = split_list spec

          (* Sequential mode expects exactly one equation per binding *)
          val eqs = map the_single eqss

          (* Standard checks first, then the additional sequential-mode checks *)
          val _ = check_defs ctxt fixes eqs
          val _ = app (check_pats ctxt) eqs

          (* Completion with catch-all equations, then splitting *)
          val completed = add_catchall ctxt fixes eqs
          val split =
            warn_if_redundant ctxt eqs (Function_Split.split_all_equations ctxt completed)

          val n_user = length bnds

          (* Regroup thms like split and hand the groups of the user's own
             equations back to their bindings *)
          fun restore_spec thms = bnds ~~ take n_user (unflat split thms)

          val user_eqs = flat (take n_user split)
          val fnames = map (fst o fst) fixes

          (* Position in fnames of the function that each user equation defines *)
          fun fname_index eq =
            let val name = fname_of eq
            in find_index (fn f => name = f) fnames end
          val indices = map fname_index user_eqs

          (* Distribute xs into one list per function, in the order of fnames *)
          fun sort xs =
            map (map snd)
              (partition_list (fn i => fn (j, _) => i = j) 0 (length fnames - 1)
                (indices ~~ xs))

          (* Equations added by completion get empty bindings *)
          val bnds' = bnds @ replicate (length split - n_user) Attrib.empty_binding

          (* Case names are numbered only; naming them after theorems is
             currently disabled *)
          val case_names =
            flat (map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
              (bnds' ~~ split))
        in
          (flat split, restore_spec, sort, case_names)
        end
   134 
   (* Theory setup: registers sequential_preproc through
      Function_Common.set_preproc. *)
   val setup =
     Function_Common.set_preproc sequential_preproc
     |> Context.theory_map
   137 
   138 
   (* Configuration handed to function_parser for the "fun" command:
      sequential mode is switched on, all other flags are off. *)
   val fun_config =
     FunctionConfig {
       sequential = true,
       default = "%x. undefined" (*FIXME dynamic scoping*),
       domintros = false,
       partials = false,
       tailrec = false
     }
   141 
   (* Common driver behind add_fun and add_fun_cmd.
      add performs the actual definition; it is given a tactic that runs
      pattern completeness on the first subgoal followed by auto.  The
      resulting local theory is restored, and termination is then proved with
      the prover obtained from Function_Common.get_termination_prover.
      Returns the context after the termination proof. *)
   fun gen_add_fun add fixes statements config lthy =
     let
       fun pat_completeness_auto ctxt =
         Pat_Completeness.pat_completeness_tac ctxt 1
         THEN auto_tac (clasimpset_of ctxt)

       (* Definition step; only the resulting local theory is kept *)
       val (_, defined) = add fixes statements config pat_completeness_auto lthy
       val restored = Local_Theory.restore defined

       (* Termination step, using the prover registered in the restored context *)
       val termination_prover =
         Function_Common.get_termination_prover restored restored
       val (_, terminated) =
         Function.prove_termination NONE termination_prover restored
     in
       terminated
     end
   156 
   (* Entry point for fixes and equations given as typ/term values *)
   fun add_fun fixes statements config lthy =
     gen_add_fun Function.add_function fixes statements config lthy

   (* Entry point for fixes and equations given as strings *)
   fun add_fun_cmd fixes statements config lthy =
     gen_add_fun Function.add_function_cmd fixes statements config lthy
   159 
   160 
   161 
   (* Register the outer syntax command "fun": the specification is parsed by
      function_parser with fun_config as the default and passed to add_fun_cmd. *)
   val _ =
     let
       fun define ((config, fixes), statements) =
         add_fun_cmd fixes statements config
     in
       Outer_Syntax.local_theory "fun" "define general recursive functions (short version)"
         Keyword.thy_decl
         (function_parser fun_config >> define)
     end
   167 
   168 end