src/HOL/Tools/Function/fun.ML
changeset 34232 36a2a3029fd3
parent 33957 e9afca2118d4
child 36519 46bf776a81e0
equal deleted inserted replaced
34231:da4d7d40f2f9 34232:36a2a3029fd3
     5 Command "fun" for fully automated function definitions
     5 Command "fun" for fully automated function definitions
     6 *)
     6 *)
     7 
     7 
     8 signature FUNCTION_FUN =
     8 signature FUNCTION_FUN =
     9 sig
     9 sig
    10     val add_fun : Function_Common.function_config ->
    10   val add_fun : Function_Common.function_config ->
    11       (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    11     (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    12       bool -> local_theory -> Proof.context
    12     bool -> local_theory -> Proof.context
    13     val add_fun_cmd : Function_Common.function_config ->
    13   val add_fun_cmd : Function_Common.function_config ->
    14       (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    14     (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    15       bool -> local_theory -> Proof.context
    15     bool -> local_theory -> Proof.context
    16 
    16 
    17     val setup : theory -> theory
    17   val setup : theory -> theory
    18 end
    18 end
    19 
    19 
    20 structure Function_Fun : FUNCTION_FUN =
    20 structure Function_Fun : FUNCTION_FUN =
    21 struct
    21 struct
    22 
    22 
    23 open Function_Lib
    23 open Function_Lib
    24 open Function_Common
    24 open Function_Common
    25 
    25 
    26 
    26 
    27 fun check_pats ctxt geq =
    27 fun check_pats ctxt geq =
    28     let 
    28   let
    29       fun err str = error (cat_lines ["Malformed definition:",
    29     fun err str = error (cat_lines ["Malformed definition:",
    30                                       str ^ " not allowed in sequential mode.",
    30       str ^ " not allowed in sequential mode.",
    31                                       Syntax.string_of_term ctxt geq])
    31       Syntax.string_of_term ctxt geq])
    32       val thy = ProofContext.theory_of ctxt
    32     val thy = ProofContext.theory_of ctxt
    33                 
    33 
    34       fun check_constr_pattern (Bound _) = ()
    34     fun check_constr_pattern (Bound _) = ()
    35         | check_constr_pattern t =
    35       | check_constr_pattern t =
    36           let
    36       let
    37             val (hd, args) = strip_comb t
    37         val (hd, args) = strip_comb t
    38           in
    38       in
    39             (((case Datatype.info_of_constr thy (dest_Const hd) of
    39         (((case Datatype.info_of_constr thy (dest_Const hd) of
    40                  SOME _ => ()
    40              SOME _ => ()
    41                | NONE => err "Non-constructor pattern")
    41            | NONE => err "Non-constructor pattern")
    42               handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    42           handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    43              map check_constr_pattern args; 
    43          map check_constr_pattern args;
    44              ())
    44          ())
    45           end
    45       end
    46           
    46 
    47       val (_, qs, gs, args, _) = split_def ctxt geq 
    47     val (_, qs, gs, args, _) = split_def ctxt geq
    48                                        
    48 
    49       val _ = if not (null gs) then err "Conditional equations" else ()
    49     val _ = if not (null gs) then err "Conditional equations" else ()
    50       val _ = map check_constr_pattern args
    50     val _ = map check_constr_pattern args
    51                   
    51 
    52                   (* just count occurrences to check linearity *)
    52     (* just count occurrences to check linearity *)
    53       val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    53     val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    54               then err "Nonlinear patterns" else ()
    54       then err "Nonlinear patterns" else ()
    55     in
    55   in
    56       ()
    56     ()
    57     end
    57   end
    58     
    58 
    59 val by_pat_completeness_auto =
    59 val by_pat_completeness_auto =
    60     Proof.global_future_terminal_proof
    60   Proof.global_future_terminal_proof
    61       (Method.Basic Pat_Completeness.pat_completeness,
    61     (Method.Basic Pat_Completeness.pat_completeness,
    62        SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    62      SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    63 
    63 
    64 fun termination_by method int =
    64 fun termination_by method int =
    65     Function.termination_proof NONE
    65   Function.termination_proof NONE
    66     #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
    66   #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
    67 
    67 
    68 fun mk_catchall fixes arity_of =
    68 fun mk_catchall fixes arity_of =
    69     let
    69   let
    70       fun mk_eqn ((fname, fT), _) =
    70     fun mk_eqn ((fname, fT), _) =
    71           let 
    71       let
    72             val n = arity_of fname
    72         val n = arity_of fname
    73             val (argTs, rT) = chop n (binder_types fT)
    73         val (argTs, rT) = chop n (binder_types fT)
    74                                    |> apsnd (fn Ts => Ts ---> body_type fT) 
    74           |> apsnd (fn Ts => Ts ---> body_type fT)
    75                               
    75 
    76             val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
    76         val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
    77           in
    77       in
    78             HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    78         HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    79                           Const ("HOL.undefined", rT))
    79           Const ("HOL.undefined", rT))
    80               |> HOLogic.mk_Trueprop
    80         |> HOLogic.mk_Trueprop
    81               |> fold_rev Logic.all qs
    81         |> fold_rev Logic.all qs
    82           end
    82       end
    83     in
    83   in
    84       map mk_eqn fixes
    84     map mk_eqn fixes
    85     end
    85   end
    86 
    86 
    87 fun add_catchall ctxt fixes spec =
    87 fun add_catchall ctxt fixes spec =
    88   let val fqgars = map (split_def ctxt) spec
    88   let val fqgars = map (split_def ctxt) spec
    89       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    89       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    90                      |> AList.lookup (op =) #> the
    90                      |> AList.lookup (op =) #> the
    91   in
    91   in
    92     spec @ mk_catchall fixes arity_of
    92     spec @ mk_catchall fixes arity_of
    93   end
    93   end
    94 
    94 
    95 fun warn_if_redundant ctxt origs tss =
    95 fun warn_if_redundant ctxt origs tss =
       
    96   let
       
    97     fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
       
    98 
       
    99     val (tss', _) = chop (length origs) tss
       
   100     fun check (t, []) = (warning (msg t); [])
       
   101       | check (t, s) = s
       
   102   in
       
   103     (map check (origs ~~ tss'); tss)
       
   104   end
       
   105 
       
   106 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
       
   107   if sequential then
    96     let
   108     let
    97         fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
   109       val (bnds, eqss) = split_list spec
    98                     
   110 
    99         val (tss', _) = chop (length origs) tss
   111       val eqs = map the_single eqss
   100         fun check (t, []) = (warning (msg t); [])
   112 
   101           | check (t, s) = s
   113       val feqs = eqs
   102     in
   114         |> tap (check_defs ctxt fixes) (* Standard checks *)
   103         (map check (origs ~~ tss'); tss)
   115         |> tap (map (check_pats ctxt)) (* More checks for sequential mode *)
   104     end
   116 
       
   117       val compleqs = add_catchall ctxt fixes feqs (* Completion *)
       
   118 
       
   119       val spliteqs = warn_if_redundant ctxt feqs
       
   120         (Function_Split.split_all_equations ctxt compleqs)
       
   121 
       
   122       fun restore_spec thms =
       
   123         bnds ~~ take (length bnds) (unflat spliteqs thms)
       
   124 
       
   125       val spliteqs' = flat (take (length bnds) spliteqs)
       
   126       val fnames = map (fst o fst) fixes
       
   127       val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
       
   128 
       
   129       fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
       
   130         |> map (map snd)
   105 
   131 
   106 
   132 
   107 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   133       val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   108       if sequential then
       
   109         let
       
   110           val (bnds, eqss) = split_list spec
       
   111                             
       
   112           val eqs = map the_single eqss
       
   113                     
       
   114           val feqs = eqs
       
   115                       |> tap (check_defs ctxt fixes) (* Standard checks *)
       
   116                       |> tap (map (check_pats ctxt))    (* More checks for sequential mode *)
       
   117 
   134 
   118           val compleqs = add_catchall ctxt fixes feqs   (* Completion *)
   135       (* using theorem names for case name currently disabled *)
   119 
   136       val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   120           val spliteqs = warn_if_redundant ctxt feqs
   137         (bnds' ~~ spliteqs) |> flat
   121                            (Function_Split.split_all_equations ctxt compleqs)
   138     in
   122 
   139       (flat spliteqs, restore_spec, sort, case_names)
   123           fun restore_spec thms =
   140     end
   124               bnds ~~ take (length bnds) (unflat spliteqs thms)
   141   else
   125               
   142     Function_Common.empty_preproc check_defs config ctxt fixes spec
   126           val spliteqs' = flat (take (length bnds) spliteqs)
       
   127           val fnames = map (fst o fst) fixes
       
   128           val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
       
   129 
       
   130           fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
       
   131                                        |> map (map snd)
       
   132 
       
   133 
       
   134           val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
       
   135 
       
   136           (* using theorem names for case name currently disabled *)
       
   137           val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
       
   138                                      (bnds' ~~ spliteqs)
       
   139                            |> flat
       
   140         in
       
   141           (flat spliteqs, restore_spec, sort, case_names)
       
   142         end
       
   143       else
       
   144         Function_Common.empty_preproc check_defs config ctxt fixes spec
       
   145 
   143 
   146 val setup =
   144 val setup =
   147   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   145   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   148 
   146 
   149 
   147 
   150 val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), 
   148 val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), 
   151   domintros=false, partials=false, tailrec=false }
   149   domintros=false, partials=false, tailrec=false }
   152 
   150 
   153 fun gen_fun add config fixes statements int lthy =
   151 fun gen_fun add config fixes statements int lthy =
   154   lthy
   152   lthy
   155     |> add fixes statements config
   153   |> add fixes statements config
   156     |> by_pat_completeness_auto int
   154   |> by_pat_completeness_auto int
   157     |> Local_Theory.restore
   155   |> Local_Theory.restore
   158     |> termination_by (Function_Common.get_termination_prover lthy) int
   156   |> termination_by (Function_Common.get_termination_prover lthy) int
   159 
   157 
   160 val add_fun = gen_fun Function.add_function
   158 val add_fun = gen_fun Function.add_function
   161 val add_fun_cmd = gen_fun Function.add_function_cmd
   159 val add_fun_cmd = gen_fun Function.add_function_cmd
   162 
   160 
   163 
   161