diff -r d1f7b6245a75 -r f5cafe803b55 src/HOL/Tools/primrec_package.ML --- a/src/HOL/Tools/primrec_package.ML Thu Jun 18 18:31:14 2009 -0700 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,332 +0,0 @@ -(* Title: HOL/Tools/primrec_package.ML - Author: Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen; - Florian Haftmann, TU Muenchen - -Package for defining functions on datatypes by primitive recursion. -*) - -signature PRIMREC_PACKAGE = -sig - val add_primrec: (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> local_theory -> thm list * local_theory - val add_primrec_cmd: (binding * string option * mixfix) list -> - (Attrib.binding * string) list -> local_theory -> thm list * local_theory - val add_primrec_global: (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> theory -> thm list * theory - val add_primrec_overloaded: (string * (string * typ) * bool) list -> - (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> theory -> thm list * theory - val add_primrec_simple: ((binding * typ) * mixfix) list -> term list -> - local_theory -> (string * thm list list) * local_theory -end; - -structure PrimrecPackage : PRIMREC_PACKAGE = -struct - -open DatatypeAux; - -exception PrimrecError of string * term option; - -fun primrec_error msg = raise PrimrecError (msg, NONE); -fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn); - -fun message s = if ! Toplevel.debug then tracing s else (); - - -(* preprocessing of equations *) - -fun process_eqn is_fixed spec rec_fns = - let - val (vs, Ts) = split_list (strip_qnt_vars "all" spec); - val body = strip_qnt_body "all" spec; - val (vs', _) = Name.variants vs (Name.make_context (fold_aterms - (fn Free (v, _) => insert (op =) v | _ => I) body [])); - val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body; - val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn) - handle TERM _ => primrec_error "not a proper equation"; - val (recfun, args) = strip_comb lhs; - val fname = case recfun of Free (v, _) => if is_fixed v then v - else primrec_error "illegal head of function equation" - | _ => primrec_error "illegal head of function equation"; - - val (ls', rest) = take_prefix is_Free args; - val (middle, rs') = take_suffix is_Free rest; - val rpos = length ls'; - - val (constr, cargs') = if null middle then primrec_error "constructor missing" - else strip_comb (hd middle); - val (cname, T) = dest_Const constr - handle TERM _ => primrec_error "ill-formed constructor"; - val (tname, _) = dest_Type (body_type T) handle TYPE _ => - primrec_error "cannot determine datatype associated with function" - - val (ls, cargs, rs) = - (map dest_Free ls', map dest_Free cargs', map dest_Free rs') - handle TERM _ => primrec_error "illegal argument in pattern"; - val lfrees = ls @ rs @ cargs; - - fun check_vars _ [] = () - | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn; - in - if length middle > 1 then - primrec_error "more than one non-variable in pattern" - else - (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees); - check_vars "extra variables on rhs: " - (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees - |> filter_out (is_fixed o fst)); - case AList.lookup (op =) rec_fns fname of - NONE => - (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns - | SOME (_, rpos', eqns) => - if AList.defined (op =) eqns cname then - primrec_error "constructor already occurred as pattern" - else if rpos <> rpos' then - primrec_error "position of recursive argument inconsistent" - else - AList.update (op =) - (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns)) - rec_fns) - end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec; - -fun process_fun descr eqns (i, fname) (fnames, fnss) = - let - val (_, (tname, _, constrs)) = nth descr i; - - (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) - - fun subst [] t fs = (t, fs) - | subst subs (Abs (a, T, t)) fs = - fs - |> subst subs t - |-> (fn t' => pair (Abs (a, T, t'))) - | subst subs (t as (_ $ _)) fs = - let - val (f, ts) = strip_comb t; - in - if is_Free f - andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then - let - val (fname', _) = dest_Free f; - val (_, rpos, _) = the (AList.lookup (op =) eqns fname'); - val (ls, rs) = chop rpos ts - val (x', rs') = case rs - of x' :: rs => (x', rs) - | [] => primrec_error ("not enough arguments in recursive application\n" - ^ "of function " ^ quote fname' ^ " on rhs"); - val (x, xs) = strip_comb x'; - in case AList.lookup (op =) subs x - of NONE => - fs - |> fold_map (subst subs) ts - |-> (fn ts' => pair (list_comb (f, ts'))) - | SOME (i', y) => - fs - |> fold_map (subst subs) (xs @ ls @ rs') - ||> process_fun descr eqns (i', fname') - |-> (fn ts' => pair (list_comb (y, ts'))) - end - else - fs - |> fold_map (subst subs) (f :: ts) - |-> (fn (f'::ts') => pair (list_comb (f', ts'))) - end - | subst _ t fs = (t, fs); - - (* translate rec equations into function arguments suitable for rec comb *) - - fun trans eqns (cname, cargs) (fnames', fnss', fns) = - (case AList.lookup (op =) eqns cname of - NONE => (warning ("No equation for constructor " ^ quote cname ^ - "\nin definition of function " ^ quote fname); - (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns)) - | SOME (ls, cargs', rs, rhs, eq) => - let - val recs = filter (is_rec_type o snd) (cargs' ~~ cargs); - val rargs = map fst recs; - val subs = map (rpair dummyT o fst) - (rev (Term.rename_wrt_term rhs rargs)); - val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z => - (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss') - handle PrimrecError (s, NONE) => primrec_error_eqn s eq - in (fnames'', fnss'', - (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) - end) - - in (case AList.lookup (op =) fnames i of - NONE => - if exists (fn (_, v) => fname = v) fnames then - primrec_error ("inconsistent functions for datatype " ^ quote tname) - else - let - val (_, _, eqns) = the (AList.lookup (op =) eqns fname); - val (fnames', fnss', fns) = fold_rev (trans eqns) constrs - ((i, fname)::fnames, fnss, []) - in - (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') - end - | SOME fname' => - if fname = fname' then (fnames, fnss) - else primrec_error ("inconsistent functions for datatype " ^ quote tname)) - end; - - -(* prepare functions needed for definitions *) - -fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) = - case AList.lookup (op =) fns i of - NONE => - let - val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined", - replicate ((length cargs) + (length (List.filter is_rec_type cargs))) - dummyT ---> HOLogic.unitT)) constrs; - val _ = warning ("No function definition for datatype " ^ quote tname) - in - (dummy_fns @ fs, defs) - end - | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs); - - -(* make definition *) - -fun make_def ctxt fixes fs (fname, ls, rec_name, tname) = - let - val SOME (var, varT) = get_first (fn ((b, T), mx) => - if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes; - val def_name = Thm.def_name (Long_Name.base_name fname); - val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT]) - (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1)))) - val rhs = singleton (Syntax.check_terms ctxt) - (TypeInfer.constrain varT raw_rhs); - in (var, ((Binding.name def_name, []), rhs)) end; - - -(* find datatypes which contain all datatypes in tnames' *) - -fun find_dts (dt_info : datatype_info Symtab.table) _ [] = [] - | find_dts dt_info tnames' (tname::tnames) = - (case Symtab.lookup dt_info tname of - NONE => primrec_error (quote tname ^ " is not a datatype") - | SOME dt => - if tnames' subset (map (#1 o snd) (#descr dt)) then - (tname, dt)::(find_dts dt_info tnames' tnames) - else find_dts dt_info tnames' tnames); - - -(* distill primitive definition(s) from primrec specification *) - -fun distill lthy fixes eqs = - let - val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v - orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs []; - val tnames = distinct (op =) (map (#1 o snd) eqns); - val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames; - val main_fns = map (fn (tname, {index, ...}) => - (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts; - val {descr, rec_names, rec_rewrites, ...} = - if null dts then primrec_error - ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") - else snd (hd dts); - val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); - val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); - val defs = map (make_def lthy fixes fs) raw_defs; - val names = map snd fnames; - val names_eqns = map fst eqns; - val _ = if gen_eq_set (op =) (names, names_eqns) then () - else primrec_error ("functions " ^ commas_quote names_eqns ^ - "\nare not mutually recursive"); - val rec_rewrites' = map mk_meta_eq rec_rewrites; - val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs); - fun prove lthy defs = - let - val rewrites = rec_rewrites' @ map (snd o snd) defs; - fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; - val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names); - in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end; - in ((prefix, (fs, defs)), prove) end - handle PrimrecError (msg, some_eqn) => - error ("Primrec definition error:\n" ^ msg ^ (case some_eqn - of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) - | NONE => "")); - - -(* primrec definition *) - -fun add_primrec_simple fixes ts lthy = - let - val ((prefix, (fs, defs)), prove) = distill lthy fixes ts; - in - lthy - |> fold_map (LocalTheory.define Thm.definitionK) defs - |-> (fn defs => `(fn lthy => (prefix, prove lthy defs))) - end; - -local - -fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy = - let - val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy); - fun attr_bindings prefix = map (fn ((b, attrs), _) => - (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec; - fun simp_attr_binding prefix = (Binding.qualify true prefix (Binding.name "simps"), - map (Attrib.internal o K) - [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]); - in - lthy - |> set_group ? LocalTheory.set_group (serial_string ()) - |> add_primrec_simple fixes (map snd spec) - |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK) - (attr_bindings prefix ~~ simps) - #-> (fn simps' => LocalTheory.note Thm.generatedK - (simp_attr_binding prefix, maps snd simps'))) - |>> snd - end; - -in - -val add_primrec = gen_primrec false Specification.check_spec; -val add_primrec_cmd = gen_primrec true Specification.read_spec; - -end; - -fun add_primrec_global fixes specs thy = - let - val lthy = TheoryTarget.init NONE thy; - val (simps, lthy') = add_primrec fixes specs lthy; - val simps' = ProofContext.export lthy' lthy simps; - in (simps', LocalTheory.exit_global lthy') end; - -fun add_primrec_overloaded ops fixes specs thy = - let - val lthy = TheoryTarget.overloading ops thy; - val (simps, lthy') = add_primrec fixes specs lthy; - val simps' = ProofContext.export lthy' lthy simps; - in (simps', LocalTheory.exit_global lthy') end; - - -(* outer syntax *) - -local structure P = OuterParse and K = OuterKeyword in - -val opt_unchecked_name = - Scan.optional (P.$$$ "(" |-- P.!!! - (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" || - P.name >> pair false) --| P.$$$ ")")) (false, ""); - -val old_primrec_decl = - opt_unchecked_name -- Scan.repeat1 ((SpecParse.opt_thm_name ":" >> apfst Binding.name_of) -- P.prop); - -val primrec_decl = P.opt_target -- P.fixes -- SpecParse.where_alt_specs; - -val _ = - OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl - ((primrec_decl >> (fn ((opt_target, fixes), specs) => - Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd))) - || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) => - Toplevel.theory (snd o - (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec) - alt_name (map P.triple_swap eqns))))); - -end; - -end;