src/HOL/Tools/Function/function_common.ML
author krauss
Sat Jan 02 23:18:58 2010 +0100 (2010-01-02)
changeset 34232 36a2a3029fd3
parent 34231 da4d7d40f2f9
child 36521 73ed9f18fdd3
permissions -rw-r--r--
new year's resolution: reindented code in function package
     1 (*  Title:      HOL/Tools/Function/function_common.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A package for general recursive function definitions.
     5 Common definitions and other infrastructure.
     6 *)
     7 
     (* Interface exposing the record type that describes a function definition. *)
     signature FUNCTION_DATA =
     sig

     type info =
      {(* bookkeeping *)
       is_partial : bool,
       defname : string,

       (* the next two fields contain no logical entities:
          they are invariant under morphisms *)
       add_simps : (binding -> binding) -> string -> (binding -> binding) ->
         Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
       case_names : string list,

       (* logical content *)
       fs : term list,
       R : term,
       psimps : thm list,
       pinducts : thm list,
       termination : thm,

       (* results that may be absent *)
       simps : thm list option,
       inducts : thm list option}

     end
    27 
    (* Implementation of FUNCTION_DATA: only provides the record type `info`. *)
    structure Function_Data : FUNCTION_DATA =
    struct

    type info =
     {(* bookkeeping *)
      is_partial : bool,
      defname : string,

      (* the next two fields contain no logical entities:
         they are invariant under morphisms *)
      add_simps : (binding -> binding) -> string -> (binding -> binding) ->
        Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
      case_names : string list,

      (* logical content *)
      fs : term list,
      R : term,
      psimps : thm list,
      pinducts : thm list,
      termination : thm,

      (* results that may be absent *)
      simps : thm list option,
      inducts : thm list option}

    end
    47 
    48 structure Function_Common =
    49 struct
    50 
    51 open Function_Data
    52 
    53 local open Function_Lib in
    54 
    (* Profiling *)

    (* Global switch: when set, PROFILE wraps computations with timeap_msg. *)
    val profile = Unsynchronized.ref false;

    (* PROFILE msg f: behaves like f; the flag is read when PROFILE is applied
       to msg, and selects between timeap_msg msg and the identity. *)
    fun PROFILE msg =
      (case ! profile of
        true => timeap_msg msg
      | false => I)
    59 
    60 
    (* Name of the constant accp, and a builder applying it to a relation R
       over the domain type domT. *)
    val acc_const_name = @{const_name accp}

    fun mk_acc domT R =
      let
        val predT = domT --> HOLogic.boolT   (* predicates on the domain *)
        val relT = domT --> predT            (* binary relations on the domain *)
      in
        Const (acc_const_name, relT --> predT) $ R
      end
    64 
    (* Derived names: each helper appends a fixed suffix to the given base name. *)
    fun function_name base = suffix "C" base
    fun graph_name base = suffix "_graph" base
    fun rel_name base = suffix "_rel" base
    fun dom_name base = suffix "_dom" base
    69 
    (* Termination rules *)

    (* Generic context data holding a list of theorems;
       merging is delegated to Thm.merge_thms. *)
    structure TerminationRule = Generic_Data
    (
      type T = thm list
      val empty : T = []
      fun extend rules : T = rules
      fun merge (rules1, rules2) : T = Thm.merge_thms (rules1, rules2)
    );
    79 
    (* Read the stored termination rules from a generic context. *)
    fun get_termination_rules context = TerminationRule.get context

    (* Context update that puts the given rule in front of the stored ones. *)
    fun store_termination_rule rule = TerminationRule.map (cons rule)

    (* Resolution with the termination rules stored in a proof context. *)
    fun apply_termination_rule ctxt =
      resolve_tac (get_termination_rules (Context.Proof ctxt))
    83 
    84 
    (* Function definition result data *)

    datatype function_result = FunctionResult of
     {(* terms *)
      fs : term list,
      G : term,
      R : term,

      (* simplification rules; trsimps may be absent *)
      psimps : thm list,
      trsimps : thm list option,

      (* further theorems; domintros may be absent *)
      simple_pinducts : thm list,
      cases : thm,
      termination : thm,
      domintros : thm list option}
    99 
   (* Transport an info record along the morphism phi.  Terms, theorems and
      facts are mapped by the corresponding component of phi; defname is sent
      through the binding component; add_simps, case_names and is_partial are
      copied unchanged. *)
   fun morph_function_data (data : info) phi =
     let
       val {add_simps, case_names, fs, R, psimps, pinducts, simps, inducts,
         termination, defname, is_partial} = data

       val morph_term = Morphism.term phi
       val morph_thm = Morphism.thm phi
       val morph_fact = Morphism.fact phi
       val morph_binding = Morphism.binding phi

       (* round trip: string -> binding -> morphed binding -> string *)
       fun morph_name s = Binding.name_of (morph_binding (Binding.name s))
     in
       {add_simps = add_simps,
        case_names = case_names,
        fs = map morph_term fs,
        R = morph_term R,
        psimps = morph_fact psimps,
        pinducts = morph_fact pinducts,
        simps = Option.map morph_fact simps,
        inducts = Option.map morph_fact inducts,
        termination = morph_thm termination,
        defname = morph_name defname,
        is_partial = is_partial}
     end
   112 
   (* Generic context data: an item net of (term, info) pairs.  Entries are
      compared by their term component (up to aconv) and indexed by that term. *)
   structure FunctionData = Generic_Data
   (
     type T = (term * info) Item_Net.T;
     val empty : T =
       Item_Net.init (fn ((t, _), (u, _)) => t aconv u) (fn (t, _) => [t]);
     fun extend net : T = net;
     fun merge (net1, net2) : T = Item_Net.merge (net1, net2);
   )

   (* Look up the net of function data in a proof context. *)
   fun get_function ctxt = FunctionData.get (Context.Proof ctxt);
   122 
   123 
   (* Turn a theorem transformation f into a morphism acting on theorems,
      terms (via Drule.term_rule) and types (via Logic.type_map). *)
   fun lift_morphism thy f =
     let
       val on_term = Drule.term_rule thy f
       val thm_part = Morphism.thm_morphism f
       val term_part = Morphism.term_morphism on_term
       val typ_part = Morphism.typ_morphism (Logic.type_map on_term)
     in
       thm_part $> term_part $> typ_part
     end
   131 
   (* Find function data for the term t in ctxt.  Each candidate retrieved
      from the net is matched against t; the first one that matches is returned,
      instantiated accordingly.  Yields NONE if no candidate matches. *)
   fun import_function_data t ctxt =
     let
       val thy = ProofContext.theory_of ctxt
       val goal = cterm_of thy t

       (* a failed match (Pattern.MATCH) just rejects this candidate *)
       fun try_entry (pat, data) =
         let
           val inst = Thm.match (cterm_of thy pat, goal)
           val phi = lift_morphism thy (Thm.instantiate inst)
         in
           SOME (morph_function_data data phi)
         end
         handle Pattern.MATCH => NONE

       val candidates = Item_Net.retrieve (get_function ctxt) t
     in
       get_first try_entry candidates
     end
   144 
   (* Import the data belonging to the first entry of the net's content list;
      NONE if no function data is stored at all. *)
   fun import_last_function ctxt =
     let
       (* only the term is needed: the data is looked up again after import *)
       fun import_entry (t, _) =
         let
           val ([t'], ctxt') = Variable.import_terms true [t] ctxt
         in
           import_function_data t' ctxt'
         end
     in
       (case Item_Net.content (get_function ctxt) of
         [] => NONE
       | entry :: _ => import_entry entry)
     end
   154 
   (* All (term, info) entries stored in the proof context. *)
   fun all_function_data ctxt = Item_Net.content (get_function ctxt)

   (* Register data under each of its function terms fs, then store its
      termination theorem as a termination rule. *)
   fun add_function_data (data : info) =
     let
       val {fs, termination, ...} = data
       fun register f = Item_Net.update (f, data)
     in
       FunctionData.map (fold register fs)
       #> store_termination_rule termination
     end
   160 
   161 
   (* Simp rules for termination proofs *)

   (* Named theorem collection "termination_simp". *)
   structure Termination_Simps = Named_Thms
   (
     val description = "Simplification rule for termination proofs"
     val name = "termination_simp"
   )
   169 
   170 
   (* Default Termination Prover *)

   (* Generic context data holding the prover as a function from proof
      contexts to proof methods.  The initial value fails when invoked. *)
   structure TerminationProver = Generic_Data
   (
     type T = Proof.context -> Proof.method
     fun empty _ = error "Termination prover not configured"
     fun extend prover = prover
     fun merge (_, b) = b  (* FIXME ? *)
   )

   fun set_termination_prover prover = TerminationProver.put prover
   fun get_termination_prover ctxt = TerminationProver.get (Context.Proof ctxt)
   183 
   184 
   (* Configuration management *)

   (* A single option as it can be given to the function command. *)
   datatype function_opt
     = Sequential
     | Default of string
     | DomIntros
     | No_Partials
     | Tailrec

   (* The complete configuration; one field per option. *)
   datatype function_config = FunctionConfig of
    {sequential: bool,
     default: string,
     domintros: bool,
     partials: bool,
     tailrec: bool}

   (* apply_opt opt config: the configuration obtained from config by applying
      the single option opt.  Every option changes exactly one field:
        Sequential   sets sequential to true
        Default d    sets default to d
        DomIntros    sets domintros to true
        Tailrec      sets tailrec to true
        No_Partials  sets partials to false
      All other fields are copied from the given configuration.  (Previously
      No_Partials also forced tailrec to true, which looks like a copy-and-paste
      slip from the Tailrec clause.) *)
   fun apply_opt Sequential (FunctionConfig {default, domintros, partials, tailrec, ...}) =
       FunctionConfig {sequential=true, default=default, domintros=domintros,
         partials=partials, tailrec=tailrec}
     | apply_opt (Default d) (FunctionConfig {sequential, domintros, partials, tailrec, ...}) =
       FunctionConfig {sequential=sequential, default=d, domintros=domintros,
         partials=partials, tailrec=tailrec}
     | apply_opt DomIntros (FunctionConfig {sequential, default, partials, tailrec, ...}) =
       FunctionConfig {sequential=sequential, default=default, domintros=true,
         partials=partials, tailrec=tailrec}
     | apply_opt Tailrec (FunctionConfig {sequential, default, domintros, partials, ...}) =
       FunctionConfig {sequential=sequential, default=default, domintros=domintros,
         partials=partials, tailrec=true}
     | apply_opt No_Partials (FunctionConfig {sequential, default, domintros, tailrec, ...}) =
       FunctionConfig {sequential=sequential, default=default, domintros=domintros,
         partials=false, tailrec=tailrec}
   210 
   (* Configuration used when no options are given: every flag is off except
      partials. *)
   val default_config =
     FunctionConfig
      {sequential = false,
       default = "%x. undefined" (*FIXME dynamic scoping*),
       domintros = false,
       partials = true,
       tailrec = false}
   214 
   215 
   (* Analyzing function equations *)

   (* Decompose the equation geq into (fname, qs, gs, args, rhs):
        qs     variables of the outer "all" quantifiers
        gs     premises of the quantifier-free body
        fname  name of the free variable at the head of the left hand side
        args   arguments the head is applied to
        rhs    right hand side of the equation
      Fails with an error quoting geq if the conclusion is not an equation or
      if the head of the left hand side is not a free variable. *)
   fun split_def ctxt geq =
     let
       fun fail_with msg = error (cat_lines [msg, Syntax.string_of_term ctxt geq])

       val qs = Term.strip_qnt_vars "all" geq
       val (gs, concl) = Logic.strip_horn (Term.strip_qnt_body "all" geq)

       val (lhs, rhs) =
         HOLogic.dest_eq (HOLogic.dest_Trueprop concl)
           handle TERM _ => fail_with "Not an equation"

       val (head, args) = strip_comb lhs

       val (fname, _) =
         dest_Free head
           handle TERM _ => fail_with "Head symbol must not be a bound variable"
     in
       (fname, qs, gs, args, rhs)
     end
   235 
   (* Check for all sorts of errors in the input *)

   (* Validates the equations eqs against the declared functions fixes.
      Raises an error for: a head symbol that is not one of the declared
      functions, a function without arguments, variables occurring on the right
      hand side only, a defined function occurring in premises or arguments,
      differing argument counts for one function, a declared function without
      equations, and a function type that is not of sort "type".  Bound
      variables in function position only produce a warning.  Returns unit. *)
   fun check_defs ctxt fixes eqs =
     let
       val fnames = map (fst o fst) fixes

       (* checks a single equation; yields its function name and argument count *)
       fun check geq =
         let
           fun input_error msg = error (cat_lines [msg, Syntax.string_of_term ctxt geq])

           val (fname, qs, gs, args, rhs) = split_def ctxt geq
           val arity = length args

           val _ =
             if fname mem fnames then ()
             else input_error ("Head symbol of left hand side must be " ^
               plural "" "one out of " fnames ^ commas_quote fnames)

           val _ =
             if arity > 0 then ()
             else input_error "Function has no arguments:"

           (* names of bound variables that are loose in rhs but in no argument *)
           fun add_bvs t acc = add_loose_bnos (t, 0, acc)
           val arg_bvs = fold add_bvs args []
           val rhs_bvs = add_bvs rhs []
           val qs_rev = rev qs
           val rvs = map (fn i => fst (nth qs_rev i)) (subtract (op =) arg_bvs rhs_bvs)

           val _ =
             if null rvs then ()
             else input_error
               ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs ^
                " occur" ^ plural "s" "" rvs ^ " on right hand side only:")

           fun mentions_defined t =
             Term.exists_subterm (fn Free (n, _) => n mem fnames | _ => false) t

           val _ =
             if exists mentions_defined (gs @ args)
             then input_error "Defined function may not occur in premises or arguments"
             else ()

           (* quantified variables that are applied to something in an argument *)
           val free_qs = rev (map Free qs)
           val freeargs = map (fn t => subst_bounds (free_qs, t)) args
           fun in_fun_position q =
             exists (exists_subterm (fn (Free q') $ _ => q = q' | _ => false)) freeargs
           val funvars = filter in_fun_position qs

           val _ =
             if null funvars then ()
             else warning (cat_lines
               ["Bound variable" ^ plural " " "s " funvars ^
                commas_quote (map fst funvars) ^ " occur" ^ plural "s" "" funvars ^
                " in function position.", "Misspelled constructor???"])
         in
           (fname, arity)
         end

       val grouped_args = AList.group (op =) (map check eqs)

       fun check_arity (fname, ars) =
         if length (distinct (op =) ars) = 1 then ()
         else error ("Function " ^ quote fname ^
           " has different numbers of arguments in different equations")
       val _ = List.app check_arity grouped_args

       val not_defined = subtract (op =) (map fst grouped_args) fnames
       val _ =
         if null not_defined then ()
         else error ("No defining equations for function" ^
           plural " " "s " not_defined ^ commas_quote not_defined)

       fun check_sorts ((fname, fT), _) =
         if Sorts.of_sort (Sign.classes_of (ProofContext.theory_of ctxt)) (fT, HOLogic.typeS)
         then ()
         else error (cat_lines
           ["Type of " ^ quote fname ^ " is not of sort " ^ quote "type" ^ ":",
            setmp_CRITICAL show_sorts true (Syntax.string_of_typ ctxt) fT])
     in
       List.app check_sorts fixes
     end
   297 
   (* Preprocessors *)

   (* declared functions: (name, type) with mixfix syntax *)
   type fixes = ((string * typ) * mixfix) list

   (* a specification: items grouped under attributed bindings *)
   type 'a spec = (Attrib.binding * 'a list) list

   (* A preprocessor takes configuration, context, fixes and a term
      specification, and returns a 4-tuple: a list of terms, a function from
      theorems back to a theorem specification, a function splitting a theorem
      list into a list of lists, and a list of strings. *)
   type preproc =
     function_config -> Proof.context -> fixes -> term spec ->
       (term list *
        (thm list -> thm spec) *
        (thm list -> thm list list) *
        string list)
   304 
   (* Name of the free variable heading the left hand side of an equation:
      strips outer quantifiers (dest_all_all), premises and Trueprop first. *)
   fun fname_of eq =
     let
       val (_, body) = dest_all_all eq
       val concl = Logic.strip_imp_concl body
       val (lhs, _) = HOLogic.dest_eq (HOLogic.dest_Trueprop concl)
       val (head, _) = strip_comb lhs
       val (name, _) = dest_Free head
     in
       name
     end
   307 
   (* mk_case_names i n k: k case names derived from n.  An empty n is replaced
      by the decimal form of i + 1.  No name for k = 0, the base name itself for
      k = 1, otherwise base_1 ... base_k. *)
   fun mk_case_names i n k =
     let
       val base = if n = "" then string_of_int (i + 1) else n
     in
       (case k of
         0 => []
       | 1 => [base]
       | _ => map (fn j => base ^ "_" ^ string_of_int j) (1 upto k))
     end
   312 
   (* Preprocessor that leaves the equations as they are.  It runs check on
      the flattened equations and returns: the equations, a function regrouping
      a flat theorem list under the original bindings, a function sorting a list
      into one group per declared function (by the head of the corresponding
      equation), and the case names.  The configuration argument is ignored. *)
   fun empty_preproc check _ ctxt fixes spec =
     let
       val (bnds, tss) = split_list spec
       val ts = flat tss
       val _ = check ctxt fixes ts

       val fnames = map (fst o fst) fixes

       (* position in fnames of the function each equation belongs to *)
       fun index_of eq = find_index (curry op = (fname_of eq)) fnames
       val indices = map index_of ts

       fun sort xs =
         let
           val tagged = indices ~~ xs
           val groups =
             partition_list (fn i => fn (j, _) => i = j) 0 (length fnames - 1) tagged
         in
           map (map snd) groups
         end

       fun regroup thms = bnds ~~ Library.unflat tss thms

       (* using theorem names for case name currently disabled *)
       val cnames = flat (map_index (fn (i, _) => mk_case_names i "" 1) bnds)
     in
       (ts, regroup, sort, cnames)
     end
   329 
   (* Generic context data holding the current preprocessor; initially
      empty_preproc with check_defs as its checker.  Merging keeps the first
      argument. *)
   structure Preprocessor = Generic_Data
   (
     type T = preproc
     val empty : T = empty_preproc check_defs
     fun extend prep : T = prep
     fun merge (a, _) = a
   )

   fun get_preproc ctxt = Preprocessor.get (Context.Proof ctxt)
   fun set_preproc prep = Preprocessor.map (K prep)
   340 
   341 
   342 
   (* Parser for function specifications: an optional parenthesized,
      non-empty option list, followed by fixes and the defining equations.
      The parsed options are folded over the given default configuration
      with apply_opt. *)
   local
     structure P = OuterParse and K = OuterKeyword

     (* an option written as a bare reserved word *)
     fun flag kw opt = P.reserved kw >> K opt

     (* the only option with an argument: default <term> *)
     val default_opt = (P.reserved "default" |-- P.term) >> Default

     val option_parser =
       P.group "option"
         (flag "sequential" Sequential
           || default_opt
           || flag "domintros" DomIntros
           || flag "no_partials" No_Partials
           || flag "tailrec" Tailrec)

     val option_list = P.$$$ "(" |-- P.!!! (P.list1 option_parser) --| P.$$$ ")"

     fun config_parser default =
       Scan.optional option_list []
         >> (fn opts => fold apply_opt opts default)
   in
     fun function_parser default_cfg =
       config_parser default_cfg -- P.fixes -- SpecParse.where_alt_specs
   end
   360 
   361 
   362 end
   363 end