(*  Title:      HOL/Tools/old_primrec.ML
    Author:     Norbert Voelker, FernUni Hagen
    Author:     Stefan Berghofer, TU Muenchen

Package for defining functions on datatypes by primitive recursion.
*)

signature OLD_PRIMREC =
sig
  val unify_consts: theory -> term list -> term list -> term list * term list
  val add_primrec: string -> ((bstring * string) * Attrib.src list) list
    -> theory -> thm list * theory
  val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
    -> theory -> thm list * theory
  val add_primrec_i: string -> ((bstring * term) * attribute list) list
    -> theory -> thm list * theory
  val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
    -> theory -> thm list * theory
end;

structure OldPrimrec : OLD_PRIMREC =
struct

open DatatypeAux;

(* Raised internally while analysing equations; converted into a user-facing
   error (together with the offending equation) by primrec_eq_err. *)
exception RecError of string;

(* Abort with a "Primrec definition error" message. *)
fun primrec_err s = error ("Primrec definition error:\n" ^ s);

(* Abort with message s, followed by the pretty-printed equation eq. *)
fun primrec_eq_err thy s eq =
  primrec_err (s ^ "\nin\n" ^ quote (Syntax.string_of_term_global thy eq));


(*the following code ensures that each recursive set always has the
  same type in all introduction rules*)
(* cs are the recursive constants and intr_ts the terms they occur in.
   All type variables are made schematic with disjoint indices, the type of
   every occurrence of a constant in intr_ts is unified with its type in cs,
   and the unified results are frozen again.  If unification fails, a warning
   is printed and the inputs are returned unchanged. *)
fun unify_consts thy cs intr_ts =
  (let
    (* Varify t, shifting its type variables above index i so that different
       terms do not share schematic type variables. *)
    fun varify t (i, ts) =
      let val t' = map_types (Logic.incr_tvar (i + 1)) (snd (Type.varify [] t))
      in (maxidx_of_term t', t' :: ts) end;
    val (i, cs') = fold_rev varify cs (~1, []);
    val (i', intr_ts') = fold_rev varify intr_ts (i, []);
    val rec_consts = fold Term.add_consts cs' [];
    val intr_consts = fold Term.add_consts intr_ts' [];
    (* Unify the type cT of constant cname with every occurrence of cname. *)
    fun unify (cname, cT) =
      let val consts = map snd (filter (fn (c, _) => c = cname) intr_consts)
      in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
    val (env, _) = fold unify rec_consts (Vartab.empty, i');
    val subst = Type.legacy_freeze o map_types (Envir.norm_type env)
  in (map subst cs', map subst intr_ts')
  end) handle Type.TUNIFY =>
    (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));


(* preprocessing of equations *)

(* Analyse one equation eq of the shape  f ls (C cargs) rs = rhs  and add it to
   the accumulator rec_fns, an association list
     (f, T) |-> (datatype name, position of the constructor argument,
                 [(C, (ls, cargs, rs, rhs, eq)), ...]).
   Checks: no schematic variables, head is a constant, exactly one non-variable
   (constructor) pattern, no repeated pattern variables, no extra variables on
   the rhs, no constructor used twice, consistent recursive-argument position. *)
fun process_eqn thy eq rec_fns =
  let
    val (lhs, rhs) =
      if null (Term.add_vars eq []) then
        HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
          handle TERM _ => raise RecError "not a proper equation"
      else raise RecError "illegal schematic variable(s)";

    val (recfun, args) = strip_comb lhs;
    val fnameT = dest_Const recfun handle TERM _ =>
      raise RecError "function is not declared as constant in theory";

    (* Free variables before and after the single constructor pattern. *)
    val (ls', rest) = take_prefix is_Free args;
    val (middle, rs') = take_suffix is_Free rest;
    val rpos = length ls';

    val (constr, cargs') = if null middle then raise RecError "constructor missing"
      else strip_comb (hd middle);
    val (cname, T) = dest_Const constr
      handle TERM _ => raise RecError "ill-formed constructor";
    val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
      raise RecError "cannot determine datatype associated with function"

    val (ls, cargs, rs) =
      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
      handle TERM _ => raise RecError "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    fun check_vars _ [] = ()
      | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
  in
    if length middle > 1 then
      raise RecError "more than one non-variable in pattern"
    else
     (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
      check_vars "extra variables on rhs: "
        (subtract (op =) lfrees (map dest_Free (OldTerm.term_frees rhs)));
      case AList.lookup (op =) rec_fns fnameT of
        NONE =>
          (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))])) :: rec_fns
      | SOME (_, rpos', eqns) =>
          if AList.defined (op =) eqns cname then
            raise RecError "constructor already occurred as pattern"
          else if rpos <> rpos' then
            raise RecError "position of recursive argument inconsistent"
          else
            AList.update (op =)
              (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq)) :: eqns))
              rec_fns)
  end
  handle RecError s => primrec_eq_err thy s eq;

(* Compute the recursor arguments for function fnameT, which is assigned to
   datatype number i of descr.
   fnameTs records which function has been assigned to each datatype index.
   fnss accumulates (index, (fname, left parameters, recursor arguments)).
   Functions that are called recursively on recursive constructor arguments are
   processed on demand from within subst. *)
fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
  let
    val (_, (tname, _, constrs)) = List.nth (descr, i);

    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    fun subst [] t fs = (t, fs)
      | subst subs (Abs (a, T, t)) fs =
          fs
          |> subst subs t
          |-> (fn t' => pair (Abs (a, T, t')))
      | subst subs (t as (_ $ _)) fs =
          let
            val (f, ts) = strip_comb t;
          in
            if is_Const f andalso member (op =) (map fst rec_eqns) (dest_Const f) then
              let
                val fnameT' as (fname', _) = dest_Const f;
                val (_, rpos, _) = the (AList.lookup (op =) rec_eqns fnameT');
                val ls = take rpos ts;
                val rest = drop rpos ts;
                val (x', rs) = (hd rest, tl rest)
                  handle Empty => raise RecError ("not enough arguments in recursive application\nof function "
                    ^ quote fname' ^ " on rhs");
                val (x, xs) = strip_comb x'
              in
                case AList.lookup (op =) subs x of
                  NONE =>
                    fs
                    |> fold_map (subst subs) ts
                    |-> (fn ts' => pair (list_comb (f, ts')))
                | SOME (i', y) =>
                    (* recursive call on a constructor argument: replace it by
                       the corresponding recursor-result variable y and make
                       sure the called function is processed as well *)
                    fs
                    |> fold_map (subst subs) (xs @ ls @ rs)
                    ||> process_fun thy descr rec_eqns (i', fnameT')
                    |-> (fn ts' => pair (list_comb (y, ts')))
              end
            else
              (* substitute into the head first, then into the arguments;
                 written without a partial list pattern *)
              let
                val (f', fs') = subst subs f fs;
                val (ts', fs'') = fold_map (subst subs) ts fs';
              in (list_comb (f', ts'), fs'') end
          end
      | subst _ t fs = (t, fs);

    (* translate rec equations into function arguments suitable for rec comb *)
    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
      (case AList.lookup (op =) eqns cname of
        NONE => (warning ("No equation for constructor " ^ quote cname ^
          "\nin definition of function " ^ quote fname);
            (fnameTs', fnss', (Const ("HOL.undefined", dummyT)) :: fns))
      | SOME (ls, cargs', rs, rhs, eq) =>
          let
            val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
            val rargs = map fst recs;
            (* fresh names (w.r.t. rhs) for the results of recursive calls *)
            val subs = map (rpair dummyT o fst)
              (rev (Term.rename_wrt_term rhs rargs));
            val (rhs', (fnameTs'', fnss'')) =
              (subst (map (fn ((x, y), z) =>
                          (Free x, (body_index y, Free z)))
                     (recs ~~ subs)) rhs (fnameTs', fnss'))
              handle RecError s => primrec_eq_err thy s eq
          in (fnameTs'', fnss'',
              (list_abs_free (cargs' @ subs @ ls @ rs, rhs')) :: fns)
          end)
  in
    (case AList.lookup (op =) fnameTs i of
      NONE =>
        if exists (equal fnameT o snd) fnameTs then
          raise RecError ("inconsistent functions for datatype " ^ quote tname)
        else
          let
            val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
              ((i, fnameT) :: fnameTs, fnss, [])
          in
            (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
          end
    | SOME fnameT' =>
        if fnameT = fnameT' then (fnameTs, fnss)
        else raise RecError ("inconsistent functions for datatype " ^ quote tname))
  end;


(* prepare functions needed for definitions *)

(* For datatype i (with recursor rec_name), prepend its recursor arguments to fs.
   If a function is defined on it, also record (fname, ls, rec_name, tname) in defs.
   Otherwise use HOL.undefined placeholders of a matching arity and warn. *)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
  case AList.lookup (op =) fns i of
    NONE =>
      let
        val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
          replicate (length cargs + length (filter is_rec_type cargs))
            dummyT ---> HOLogic.unitT)) constrs;
        val _ = warning ("No function definition for datatype " ^ quote tname)
      in
        (dummy_fns @ fs, defs)
      end
  | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);


(* make definition *)

(* Build the definition  fname == %l1 ... ln x. rec_name fs x l1 ... ln,
   type-checked in the context of thy, named "<fname>_<tname>_def"
   (base names). *)
fun make_def thy fs (fname, ls, rec_name, tname) =
  let
    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
                ((map snd ls) @ [dummyT])
                (list_comb (Const (rec_name, dummyT),
                            fs @ map Bound (0 :: (length ls downto 1))))
    val def_name = Long_Name.base_name fname ^ "_" ^ Long_Name.base_name tname ^ "_def";
    val def_prop =
      singleton (Syntax.check_terms (ProofContext.init thy))
        (Logic.mk_equals (Const (fname, dummyT), rhs));
  in (def_name, def_prop) end;


(* find datatypes which contain all datatypes in tnames' *)

(* Return those (tname, info) for tname in the last argument whose mutual
   datatype block covers all of tnames'; fail if some name is not a datatype. *)
fun find_dts (dt_info : info Symtab.table) _ [] = []
  | find_dts dt_info tnames' (tname :: tnames) =
      (case Symtab.lookup dt_info tname of
        NONE => primrec_err (quote tname ^ " is not a datatype")
      | SOME dt =>
          if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
            (tname, dt) :: (find_dts dt_info tnames' tnames)
          else find_dts dt_info tnames' tnames);

(* Rename the parameters of the datatype induction rule after the constructor
   argument names used in the primrec equations; constructors without an
   equation get no names. *)
fun prepare_induct ({descr, induct, ...}: info) rec_eqns =
  let
    fun constrs_of (_, (_, _, cs)) =
      map (fn (cname : string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
    val params_of = these o AList.lookup (op =) (maps constrs_of rec_eqns);
  in
    induct
    |> Rule_Cases.rename_params (map params_of (maps (map #1 o #3 o #2) descr))
    |> Rule_Cases.save induct
  end;

local

(* Main entry point on parsed equations.  The steps are:
   1. Analyse the equations and locate the mutual datatype block.
   2. Build the recursor arguments and define each function via the recursor
      (using def, under the path primrec_name).
   3. Prove each equation by rewriting with the recursor rules and the new
      definitions.
   4. Note the results (using note) as individual equations, "simps" and
      "induct".
   Returns the proved equations. *)
fun gen_primrec_i note def alt_name eqns_atts thy =
  let
    val (eqns, atts) = split_list eqns_atts;
    val dt_info = Datatype.get_all thy;
    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [];
    val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
    val dts = find_dts dt_info tnames tnames;
    val main_fns =
      map (fn (tname, {index, ...}) =>
        (index,
          (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns))
      dts;
    val {descr, rec_names, rec_rewrites, ...} =
      if null dts then
        primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      else snd (hd dts);
    val (fnameTs, fnss) =
      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    val defs' = map (make_def thy fs) defs;
    val nameTs1 = map snd fnameTs;
    val nameTs2 = map fst rec_eqns;
    val _ = if eq_set (op =) (nameTs1, nameTs2) then ()
            else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
              "\nare not mutually recursive");
    val primrec_name =
      if alt_name = "" then (space_implode "_" (map (Long_Name.base_name o #1) defs))
      else alt_name;
    val (defs_thms', thy') =
      thy
      |> Sign.add_path primrec_name
      |> fold_map def (map (fn (name, t) => ((name, []), t)) defs');
    val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms';
    val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
        (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
    val (simps', thy'') =
      thy'
      |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps);
    val simps'' = maps snd simps';
  in
    thy''
    |> note (("simps",
        [Simplifier.simp_add, Nitpick_Simps.add, Code.add_default_eqn_attribute]), simps'')
    |> snd
    |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns])
    |> snd
    |> Sign.parent_path
    |> pair simps''
  end;

(* Entry point on unparsed input: read the equations and their attributes,
   unify the types of the recursive constants across all equations, then
   delegate to gen_primrec_i. *)
fun gen_primrec note def alt_name eqns thy =
  let
    val ((names, strings), srcss) = apfst split_list (split_list eqns);
    val atts = map (map (Attrib.attribute thy)) srcss;
    val eqn_ts = map (fn s => Syntax.read_prop_global thy s
      handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings;
    val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
      handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
    val (_, eqn_ts') = unify_consts thy rec_ts eqn_ts
  in
    gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy
  end;

(* Store a list of theorems under name in the theory. *)
fun thy_note ((name, atts), thms) =
  PureThy.add_thmss [((Binding.name name, thms), atts)] #-> (fn [thms] => pair (name, thms));

(* Add a definition; the boolean selects PureThy.add_defs_unchecked (true)
   instead of PureThy.add_defs (false). *)
fun thy_def false ((name, atts), t) =
      PureThy.add_defs false [((Binding.name name, t), atts)] #-> (fn [thm] => pair (name, thm))
  | thy_def true ((name, atts), t) =
      PureThy.add_defs_unchecked false [((Binding.name name, t), atts)]
      #-> (fn [thm] => pair (name, thm));

in

val add_primrec = gen_primrec thy_note (thy_def false);
val add_primrec_unchecked = gen_primrec thy_note (thy_def true);
val add_primrec_i = gen_primrec_i thy_note (thy_def false);
val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);

end;

end;