src/HOL/Tools/Function/fun.ML
changeset 34232 36a2a3029fd3
parent 33957 e9afca2118d4
child 36519 46bf776a81e0
     1.1 --- a/src/HOL/Tools/Function/fun.ML	Sat Jan 02 23:18:58 2010 +0100
     1.2 +++ b/src/HOL/Tools/Function/fun.ML	Sat Jan 02 23:18:58 2010 +0100
     1.3 @@ -7,14 +7,14 @@
     1.4  
     1.5  signature FUNCTION_FUN =
     1.6  sig
     1.7 -    val add_fun : Function_Common.function_config ->
     1.8 -      (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
     1.9 -      bool -> local_theory -> Proof.context
    1.10 -    val add_fun_cmd : Function_Common.function_config ->
    1.11 -      (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    1.12 -      bool -> local_theory -> Proof.context
    1.13 +  val add_fun : Function_Common.function_config ->
    1.14 +    (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    1.15 +    bool -> local_theory -> Proof.context
    1.16 +  val add_fun_cmd : Function_Common.function_config ->
    1.17 +    (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    1.18 +    bool -> local_theory -> Proof.context
    1.19  
    1.20 -    val setup : theory -> theory
    1.21 +  val setup : theory -> theory
    1.22  end
    1.23  
    1.24  structure Function_Fun : FUNCTION_FUN =
    1.25 @@ -25,64 +25,64 @@
    1.26  
    1.27  
    1.28  fun check_pats ctxt geq =
    1.29 -    let 
    1.30 -      fun err str = error (cat_lines ["Malformed definition:",
    1.31 -                                      str ^ " not allowed in sequential mode.",
    1.32 -                                      Syntax.string_of_term ctxt geq])
    1.33 -      val thy = ProofContext.theory_of ctxt
    1.34 -                
    1.35 -      fun check_constr_pattern (Bound _) = ()
    1.36 -        | check_constr_pattern t =
    1.37 -          let
    1.38 -            val (hd, args) = strip_comb t
    1.39 -          in
    1.40 -            (((case Datatype.info_of_constr thy (dest_Const hd) of
    1.41 -                 SOME _ => ()
    1.42 -               | NONE => err "Non-constructor pattern")
    1.43 -              handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    1.44 -             map check_constr_pattern args; 
    1.45 -             ())
    1.46 -          end
    1.47 -          
    1.48 -      val (_, qs, gs, args, _) = split_def ctxt geq 
    1.49 -                                       
    1.50 -      val _ = if not (null gs) then err "Conditional equations" else ()
    1.51 -      val _ = map check_constr_pattern args
    1.52 -                  
    1.53 -                  (* just count occurrences to check linearity *)
    1.54 -      val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    1.55 -              then err "Nonlinear patterns" else ()
    1.56 -    in
    1.57 -      ()
    1.58 -    end
    1.59 -    
    1.60 +  let
    1.61 +    fun err str = error (cat_lines ["Malformed definition:",
    1.62 +      str ^ " not allowed in sequential mode.",
    1.63 +      Syntax.string_of_term ctxt geq])
    1.64 +    val thy = ProofContext.theory_of ctxt
    1.65 +
    1.66 +    fun check_constr_pattern (Bound _) = ()
    1.67 +      | check_constr_pattern t =
    1.68 +      let
    1.69 +        val (hd, args) = strip_comb t
    1.70 +      in
    1.71 +        (((case Datatype.info_of_constr thy (dest_Const hd) of
    1.72 +             SOME _ => ()
    1.73 +           | NONE => err "Non-constructor pattern")
    1.74 +          handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    1.75 +         map check_constr_pattern args;
    1.76 +         ())
    1.77 +      end
    1.78 +
    1.79 +    val (_, qs, gs, args, _) = split_def ctxt geq
    1.80 +
    1.81 +    val _ = if not (null gs) then err "Conditional equations" else ()
    1.82 +    val _ = map check_constr_pattern args
    1.83 +
    1.84 +    (* just count occurrences to check linearity *)
    1.85 +    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    1.86 +      then err "Nonlinear patterns" else ()
    1.87 +  in
    1.88 +    ()
    1.89 +  end
    1.90 +
    1.91  val by_pat_completeness_auto =
    1.92 -    Proof.global_future_terminal_proof
    1.93 -      (Method.Basic Pat_Completeness.pat_completeness,
    1.94 -       SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    1.95 +  Proof.global_future_terminal_proof
    1.96 +    (Method.Basic Pat_Completeness.pat_completeness,
    1.97 +     SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    1.98  
    1.99  fun termination_by method int =
   1.100 -    Function.termination_proof NONE
   1.101 -    #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
   1.102 +  Function.termination_proof NONE
   1.103 +  #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
   1.104  
   1.105  fun mk_catchall fixes arity_of =
   1.106 -    let
   1.107 -      fun mk_eqn ((fname, fT), _) =
   1.108 -          let 
   1.109 -            val n = arity_of fname
   1.110 -            val (argTs, rT) = chop n (binder_types fT)
   1.111 -                                   |> apsnd (fn Ts => Ts ---> body_type fT) 
   1.112 -                              
   1.113 -            val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
   1.114 -          in
   1.115 -            HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
   1.116 -                          Const ("HOL.undefined", rT))
   1.117 -              |> HOLogic.mk_Trueprop
   1.118 -              |> fold_rev Logic.all qs
   1.119 -          end
   1.120 -    in
   1.121 -      map mk_eqn fixes
   1.122 -    end
   1.123 +  let
   1.124 +    fun mk_eqn ((fname, fT), _) =
   1.125 +      let
   1.126 +        val n = arity_of fname
   1.127 +        val (argTs, rT) = chop n (binder_types fT)
   1.128 +          |> apsnd (fn Ts => Ts ---> body_type fT)
   1.129 +
   1.130 +        val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
   1.131 +      in
   1.132 +        HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
   1.133 +          Const ("HOL.undefined", rT))
   1.134 +        |> HOLogic.mk_Trueprop
   1.135 +        |> fold_rev Logic.all qs
   1.136 +      end
   1.137 +  in
   1.138 +    map mk_eqn fixes
   1.139 +  end
   1.140  
   1.141  fun add_catchall ctxt fixes spec =
   1.142    let val fqgars = map (split_def ctxt) spec
   1.143 @@ -93,55 +93,53 @@
   1.144    end
   1.145  
   1.146  fun warn_if_redundant ctxt origs tss =
   1.147 -    let
   1.148 -        fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
   1.149 -                    
   1.150 -        val (tss', _) = chop (length origs) tss
   1.151 -        fun check (t, []) = (warning (msg t); [])
   1.152 -          | check (t, s) = s
   1.153 -    in
   1.154 -        (map check (origs ~~ tss'); tss)
   1.155 -    end
   1.156 +  let
   1.157 +    fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
   1.158  
   1.159 +    val (tss', _) = chop (length origs) tss
   1.160 +    fun check (t, []) = (warning (msg t); [])
   1.161 +      | check (t, s) = s
   1.162 +  in
   1.163 +    (map check (origs ~~ tss'); tss)
   1.164 +  end
   1.165  
   1.166  fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   1.167 -      if sequential then
   1.168 -        let
   1.169 -          val (bnds, eqss) = split_list spec
   1.170 -                            
   1.171 -          val eqs = map the_single eqss
   1.172 -                    
   1.173 -          val feqs = eqs
   1.174 -                      |> tap (check_defs ctxt fixes) (* Standard checks *)
   1.175 -                      |> tap (map (check_pats ctxt))    (* More checks for sequential mode *)
   1.176 +  if sequential then
   1.177 +    let
   1.178 +      val (bnds, eqss) = split_list spec
   1.179 +
   1.180 +      val eqs = map the_single eqss
   1.181  
   1.182 -          val compleqs = add_catchall ctxt fixes feqs   (* Completion *)
   1.183 +      val feqs = eqs
   1.184 +        |> tap (check_defs ctxt fixes) (* Standard checks *)
   1.185 +        |> tap (map (check_pats ctxt)) (* More checks for sequential mode *)
   1.186 +
   1.187 +      val compleqs = add_catchall ctxt fixes feqs (* Completion *)
   1.188  
   1.189 -          val spliteqs = warn_if_redundant ctxt feqs
   1.190 -                           (Function_Split.split_all_equations ctxt compleqs)
   1.191 +      val spliteqs = warn_if_redundant ctxt feqs
   1.192 +        (Function_Split.split_all_equations ctxt compleqs)
   1.193 +
   1.194 +      fun restore_spec thms =
   1.195 +        bnds ~~ take (length bnds) (unflat spliteqs thms)
   1.196  
   1.197 -          fun restore_spec thms =
   1.198 -              bnds ~~ take (length bnds) (unflat spliteqs thms)
   1.199 -              
   1.200 -          val spliteqs' = flat (take (length bnds) spliteqs)
   1.201 -          val fnames = map (fst o fst) fixes
   1.202 -          val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   1.203 +      val spliteqs' = flat (take (length bnds) spliteqs)
   1.204 +      val fnames = map (fst o fst) fixes
   1.205 +      val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   1.206  
   1.207 -          fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   1.208 -                                       |> map (map snd)
   1.209 +      fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   1.210 +        |> map (map snd)
   1.211  
   1.212  
   1.213 -          val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   1.214 +      val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   1.215  
   1.216 -          (* using theorem names for case name currently disabled *)
   1.217 -          val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   1.218 -                                     (bnds' ~~ spliteqs)
   1.219 -                           |> flat
   1.220 -        in
   1.221 -          (flat spliteqs, restore_spec, sort, case_names)
   1.222 -        end
   1.223 -      else
   1.224 -        Function_Common.empty_preproc check_defs config ctxt fixes spec
   1.225 +      (* using theorem names for case name currently disabled *)
   1.226 +      val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   1.227 +        (bnds' ~~ spliteqs) |> flat
   1.228 +    in
   1.229 +      (flat spliteqs, restore_spec, sort, case_names)
   1.230 +    end
   1.231 +  else
   1.232 +    Function_Common.empty_preproc check_defs config ctxt fixes spec
   1.233  
   1.234  val setup =
   1.235    Context.theory_map (Function_Common.set_preproc sequential_preproc)
   1.236 @@ -152,10 +150,10 @@
   1.237  
   1.238  fun gen_fun add config fixes statements int lthy =
   1.239    lthy
   1.240 -    |> add fixes statements config
   1.241 -    |> by_pat_completeness_auto int
   1.242 -    |> Local_Theory.restore
   1.243 -    |> termination_by (Function_Common.get_termination_prover lthy) int
   1.244 +  |> add fixes statements config
   1.245 +  |> by_pat_completeness_auto int
   1.246 +  |> Local_Theory.restore
   1.247 +  |> termination_by (Function_Common.get_termination_prover lthy) int
   1.248  
   1.249  val add_fun = gen_fun Function.add_function
   1.250  val add_fun_cmd = gen_fun Function.add_function_cmd