# HG changeset patch # User haftmann # Date 1245751770 -7200 # Node ID 2b04504fcb690b38a26f9b8ee1a1fbeb2642f793 # Parent 5c8cfaed32e644b4b48a8fc62671d2959c56e16d uniformly capitialized names for subdirectories diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Fun.thy --- a/src/HOL/Fun.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/HOL/Fun.thy Tue Jun 23 12:09:30 2009 +0200 @@ -133,7 +133,7 @@ shows "inj f" using assms unfolding inj_on_def by auto -text{*For Proofs in @{text "Tools/datatype_package/datatype_rep_proofs"}*} +text{*For Proofs in @{text "Tools/Datatype/datatype_rep_proofs"}*} lemma datatype_injI: "(!! x. ALL y. f(x) = f(y) --> x=y) ==> inj(f)" by (simp add: inj_on_def) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/FunDef.thy --- a/src/HOL/FunDef.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/HOL/FunDef.thy Tue Jun 23 12:09:30 2009 +0200 @@ -9,25 +9,25 @@ uses "Tools/prop_logic.ML" "Tools/sat_solver.ML" - ("Tools/function_package/fundef_lib.ML") - ("Tools/function_package/fundef_common.ML") - ("Tools/function_package/inductive_wrap.ML") - ("Tools/function_package/context_tree.ML") - ("Tools/function_package/fundef_core.ML") - ("Tools/function_package/sum_tree.ML") - ("Tools/function_package/mutual.ML") - ("Tools/function_package/pattern_split.ML") - ("Tools/function_package/fundef.ML") - ("Tools/function_package/auto_term.ML") - ("Tools/function_package/measure_functions.ML") - ("Tools/function_package/lexicographic_order.ML") - ("Tools/function_package/fundef_datatype.ML") - ("Tools/function_package/induction_scheme.ML") - ("Tools/function_package/termination.ML") - ("Tools/function_package/decompose.ML") - ("Tools/function_package/descent.ML") - ("Tools/function_package/scnp_solve.ML") - ("Tools/function_package/scnp_reconstruct.ML") + ("Tools/Function/fundef_lib.ML") + ("Tools/Function/fundef_common.ML") + ("Tools/Function/inductive_wrap.ML") + ("Tools/Function/context_tree.ML") + ("Tools/Function/fundef_core.ML") + ("Tools/Function/sum_tree.ML") + ("Tools/Function/mutual.ML") + ("Tools/Function/pattern_split.ML") + ("Tools/Function/fundef.ML") + ("Tools/Function/auto_term.ML") + ("Tools/Function/measure_functions.ML") + ("Tools/Function/lexicographic_order.ML") + ("Tools/Function/fundef_datatype.ML") + ("Tools/Function/induction_scheme.ML") + ("Tools/Function/termination.ML") + ("Tools/Function/decompose.ML") + ("Tools/Function/descent.ML") + ("Tools/Function/scnp_solve.ML") + ("Tools/Function/scnp_reconstruct.ML") begin subsection {* Definitions with default value. *} @@ -103,18 +103,18 @@ "wf R \ wfP (in_rel R)" by (simp add: wfP_def) -use "Tools/function_package/fundef_lib.ML" -use "Tools/function_package/fundef_common.ML" -use "Tools/function_package/inductive_wrap.ML" -use "Tools/function_package/context_tree.ML" -use "Tools/function_package/fundef_core.ML" -use "Tools/function_package/sum_tree.ML" -use "Tools/function_package/mutual.ML" -use "Tools/function_package/pattern_split.ML" -use "Tools/function_package/auto_term.ML" -use "Tools/function_package/fundef.ML" -use "Tools/function_package/fundef_datatype.ML" -use "Tools/function_package/induction_scheme.ML" +use "Tools/Function/fundef_lib.ML" +use "Tools/Function/fundef_common.ML" +use "Tools/Function/inductive_wrap.ML" +use "Tools/Function/context_tree.ML" +use "Tools/Function/fundef_core.ML" +use "Tools/Function/sum_tree.ML" +use "Tools/Function/mutual.ML" +use "Tools/Function/pattern_split.ML" +use "Tools/Function/auto_term.ML" +use "Tools/Function/fundef.ML" +use "Tools/Function/fundef_datatype.ML" +use "Tools/Function/induction_scheme.ML" setup {* Fundef.setup @@ -127,7 +127,7 @@ inductive is_measure :: "('a \ nat) \ bool" where is_measure_trivial: "is_measure f" -use "Tools/function_package/measure_functions.ML" +use "Tools/Function/measure_functions.ML" setup MeasureFunctions.setup lemma measure_size[measure_function]: "is_measure size" @@ -138,7 +138,7 @@ lemma measure_snd[measure_function]: "is_measure f \ is_measure (\p. f (snd p))" by (rule is_measure_trivial) -use "Tools/function_package/lexicographic_order.ML" +use "Tools/Function/lexicographic_order.ML" setup LexicographicOrder.setup @@ -307,11 +307,11 @@ subsection {* Tool setup *} -use "Tools/function_package/termination.ML" -use "Tools/function_package/decompose.ML" -use "Tools/function_package/descent.ML" -use "Tools/function_package/scnp_solve.ML" -use "Tools/function_package/scnp_reconstruct.ML" +use "Tools/Function/termination.ML" +use "Tools/Function/decompose.ML" +use "Tools/Function/descent.ML" +use "Tools/Function/scnp_solve.ML" +use "Tools/Function/scnp_reconstruct.ML" setup {* ScnpReconstruct.setup *} diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Inductive.thy --- a/src/HOL/Inductive.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/HOL/Inductive.thy Tue Jun 23 12:09:30 2009 +0200 @@ -10,15 +10,15 @@ ("Tools/inductive.ML") "Tools/dseq.ML" ("Tools/inductive_codegen.ML") - ("Tools/datatype_package/datatype_aux.ML") - ("Tools/datatype_package/datatype_prop.ML") - ("Tools/datatype_package/datatype_rep_proofs.ML") - ("Tools/datatype_package/datatype_abs_proofs.ML") - ("Tools/datatype_package/datatype_case.ML") - ("Tools/datatype_package/datatype.ML") + ("Tools/Datatype/datatype_aux.ML") + ("Tools/Datatype/datatype_prop.ML") + ("Tools/Datatype/datatype_rep_proofs.ML") + ("Tools/Datatype/datatype_abs_proofs.ML") + ("Tools/Datatype/datatype_case.ML") + ("Tools/Datatype/datatype.ML") ("Tools/old_primrec.ML") ("Tools/primrec.ML") - ("Tools/datatype_package/datatype_codegen.ML") + ("Tools/Datatype/datatype_codegen.ML") begin subsection {* Least and greatest fixed points *} @@ -335,18 +335,18 @@ text {* Package setup. *} -use "Tools/datatype_package/datatype_aux.ML" -use "Tools/datatype_package/datatype_prop.ML" -use "Tools/datatype_package/datatype_rep_proofs.ML" -use "Tools/datatype_package/datatype_abs_proofs.ML" -use "Tools/datatype_package/datatype_case.ML" -use "Tools/datatype_package/datatype.ML" +use "Tools/Datatype/datatype_aux.ML" +use "Tools/Datatype/datatype_prop.ML" +use "Tools/Datatype/datatype_rep_proofs.ML" +use "Tools/Datatype/datatype_abs_proofs.ML" +use "Tools/Datatype/datatype_case.ML" +use "Tools/Datatype/datatype.ML" setup Datatype.setup use "Tools/old_primrec.ML" use "Tools/primrec.ML" -use "Tools/datatype_package/datatype_codegen.ML" +use "Tools/Datatype/datatype_codegen.ML" setup DatatypeCodegen.setup use "Tools/inductive_codegen.ML" diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Product_Type.thy --- a/src/HOL/Product_Type.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/HOL/Product_Type.thy Tue Jun 23 12:09:30 2009 +0200 @@ -11,7 +11,7 @@ ("Tools/split_rule.ML") ("Tools/inductive_set.ML") ("Tools/inductive_realizer.ML") - ("Tools/datatype_package/datatype_realizer.ML") + ("Tools/Datatype/datatype_realizer.ML") begin subsection {* @{typ bool} is a datatype *} @@ -399,7 +399,7 @@ by (simp add: split_def id_def) lemma split_eta: "(\(x, y). f (x, y)) = f" - -- {* Subsumes the old @{text split_Pair} when @{term f} is the identity function. *} + -- {* Subsumes the old @{text split_Pair} when @{term f} is the identity Datatype. *} by (rule ext) auto lemma split_comp: "split (f \ g) x = f (g (fst x)) (snd x)" @@ -734,7 +734,7 @@ text {* @{term prod_fun} --- action of the product functor upon - functions. + Datatypes. *} definition prod_fun :: "('a \ 'c) \ ('b \ 'd) \ 'a \ 'b \ 'c \ 'd" where @@ -1154,7 +1154,7 @@ use "Tools/inductive_set.ML" setup Inductive_Set.setup -use "Tools/datatype_package/datatype_realizer.ML" +use "Tools/Datatype/datatype_realizer.ML" setup DatatypeRealizer.setup end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,704 @@ +(* Title: HOL/Tools/datatype.ML + Author: Stefan Berghofer, TU Muenchen + +Datatype package for Isabelle/HOL. +*) + +signature DATATYPE = +sig + include DATATYPE_COMMON + type rules = {distinct : thm list list, + inject : thm list list, + exhaustion : thm list, + rec_thms : thm list, + case_thms : thm list list, + split_thms : (thm * thm) list, + induction : thm, + simps : thm list} + val add_datatype : config -> string list -> (string list * binding * mixfix * + (binding * typ list * mixfix) list) list -> theory -> rules * theory + val datatype_cmd : string list -> (string list * binding * mixfix * + (binding * string list * mixfix) list) list -> theory -> theory + val rep_datatype : config -> (rules -> Proof.context -> Proof.context) + -> string list option -> term list -> theory -> Proof.state + val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state + val get_datatypes : theory -> info Symtab.table + val get_datatype : theory -> string -> info option + val the_datatype : theory -> string -> info + val datatype_of_constr : theory -> string -> info option + val datatype_of_case : theory -> string -> info option + val the_datatype_spec : theory -> string -> (string * sort) list * (string * typ list) list + val the_datatype_descr : theory -> string list + -> descr * (string * sort) list * string list + * (string list * string list) * (typ list * typ list) + val get_datatype_constrs : theory -> string -> (string * typ) list option + val interpretation : (config -> string list -> theory -> theory) -> theory -> theory + val distinct_simproc : simproc + val make_case : Proof.context -> bool -> string list -> term -> + (term * term) list -> term * (term * (int * bool)) list + val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option + val read_typ: theory -> + (typ list * (string * sort) list) * string -> typ list * (string * sort) list + val setup: theory -> theory +end; + +structure Datatype : DATATYPE = +struct + +open DatatypeAux; + + +(* theory data *) + +structure DatatypesData = TheoryDataFun +( + type T = + {types: info Symtab.table, + constrs: info Symtab.table, + cases: info Symtab.table}; + + val empty = + {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty}; + val copy = I; + val extend = I; + fun merge _ + ({types = types1, constrs = constrs1, cases = cases1}, + {types = types2, constrs = constrs2, cases = cases2}) = + {types = Symtab.merge (K true) (types1, types2), + constrs = Symtab.merge (K true) (constrs1, constrs2), + cases = Symtab.merge (K true) (cases1, cases2)}; +); + +val get_datatypes = #types o DatatypesData.get; +val map_datatypes = DatatypesData.map; + + +(** theory information about datatypes **) + +fun put_dt_infos (dt_infos : (string * info) list) = + map_datatypes (fn {types, constrs, cases} => + {types = fold Symtab.update dt_infos types, + constrs = fold Symtab.default (*conservative wrt. overloaded constructors*) + (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst) + (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs, + cases = fold Symtab.update + (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos) + cases}); + +val get_datatype = Symtab.lookup o get_datatypes; + +fun the_datatype thy name = (case get_datatype thy name of + SOME info => info + | NONE => error ("Unknown datatype " ^ quote name)); + +val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get; +val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get; + +fun get_datatype_descr thy dtco = + get_datatype thy dtco + |> Option.map (fn info as { descr, index, ... } => + (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index))); + +fun the_datatype_spec thy dtco = + let + val info as { descr, index, sorts = raw_sorts, ... } = the_datatype thy dtco; + val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr index; + val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v)) + o DatatypeAux.dest_DtTFree) dtys; + val cos = map + (fn (co, tys) => (co, map (DatatypeAux.typ_of_dtyp descr sorts) tys)) raw_cos; + in (sorts, cos) end; + +fun the_datatype_descr thy (raw_tycos as raw_tyco :: _) = + let + val info = the_datatype thy raw_tyco; + val descr = #descr info; + + val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr (#index info); + val vs = map ((fn v => (v, (the o AList.lookup (op =) (#sorts info)) v)) + o dest_DtTFree) dtys; + + fun is_DtTFree (DtTFree _) = true + | is_DtTFree _ = false + val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr; + val protoTs as (dataTs, _) = chop k descr + |> (pairself o map) (fn (_, (tyco, dTs, _)) => (tyco, map (typ_of_dtyp descr vs) dTs)); + + val tycos = map fst dataTs; + val _ = if gen_eq_set (op =) (tycos, raw_tycos) then () + else error ("Type constructors " ^ commas (map quote raw_tycos) + ^ "do not belong exhaustively to one mutual recursive datatype"); + + val (Ts, Us) = (pairself o map) Type protoTs; + + val names = map Long_Name.base_name (the_default tycos (#alt_names info)); + val (auxnames, _) = Name.make_context names + |> fold_map (yield_singleton Name.variants o name_of_typ) Us + + in (descr, vs, tycos, (names, auxnames), (Ts, Us)) end; + +fun get_datatype_constrs thy dtco = + case try (the_datatype_spec thy) dtco + of SOME (sorts, cos) => + let + fun subst (v, sort) = TVar ((v, 0), sort); + fun subst_ty (TFree v) = subst v + | subst_ty ty = ty; + val dty = Type (dtco, map subst sorts); + fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty); + in SOME (map mk_co cos) end + | NONE => NONE; + + +(** induct method setup **) + +(* case names *) + +local + +fun dt_recs (DtTFree _) = [] + | dt_recs (DtType (_, dts)) = maps dt_recs dts + | dt_recs (DtRec i) = [i]; + +fun dt_cases (descr: descr) (_, args, constrs) = + let + fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i))); + val bnames = map the_bname (distinct (op =) (maps dt_recs args)); + in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end; + + +fun induct_cases descr = + DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr)); + +fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i)); + +in + +fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr); + +fun mk_case_names_exhausts descr new = + map (RuleCases.case_names o exhaust_cases descr o #1) + (filter (fn ((_, (name, _, _))) => member (op =) new name) descr); + +end; + +fun add_rules simps case_thms rec_thms inject distinct + weak_case_congs cong_att = + PureThy.add_thmss [((Binding.name "simps", simps), []), + ((Binding.empty, flat case_thms @ + flat distinct @ rec_thms), [Simplifier.simp_add]), + ((Binding.empty, rec_thms), [Code.add_default_eqn_attribute]), + ((Binding.empty, flat inject), [iff_add]), + ((Binding.empty, map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]), + ((Binding.empty, weak_case_congs), [cong_att])] + #> snd; + + +(* add_cases_induct *) + +fun add_cases_induct infos induction thy = + let + val inducts = ProjectRule.projections (ProofContext.init thy) induction; + + fun named_rules (name, {index, exhaustion, ...}: info) = + [((Binding.empty, nth inducts index), [Induct.induct_type name]), + ((Binding.empty, exhaustion), [Induct.cases_type name])]; + fun unnamed_rule i = + ((Binding.empty, nth inducts i), [Thm.kind_internal, Induct.induct_type ""]); + in + thy |> PureThy.add_thms + (maps named_rules infos @ + map unnamed_rule (length infos upto length inducts - 1)) |> snd + |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd + end; + + + +(**** simplification procedure for showing distinctness of constructors ****) + +fun stripT (i, Type ("fun", [_, T])) = stripT (i + 1, T) + | stripT p = p; + +fun stripC (i, f $ x) = stripC (i + 1, f) + | stripC p = p; + +val distinctN = "constr_distinct"; + +fun distinct_rule thy ss tname eq_t = case #distinct (the_datatype thy tname) of + FewConstrs thms => Goal.prove (Simplifier.the_context ss) [] [] eq_t (K + (EVERY [rtac eq_reflection 1, rtac iffI 1, rtac notE 1, + atac 2, resolve_tac thms 1, etac FalseE 1])) + | ManyConstrs (thm, simpset) => + let + val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] = + map (PureThy.get_thm (ThyInfo.the_theory "Datatype" thy)) + ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"]; + in + Goal.prove (Simplifier.the_context ss) [] [] eq_t (K + (EVERY [rtac eq_reflection 1, rtac iffI 1, dtac thm 1, + full_simp_tac (Simplifier.inherit_context ss simpset) 1, + REPEAT (dresolve_tac [In0_inject, In1_inject] 1), + eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1, + etac FalseE 1])) + end; + +fun distinct_proc thy ss (t as Const ("op =", _) $ t1 $ t2) = + (case (stripC (0, t1), stripC (0, t2)) of + ((i, Const (cname1, T1)), (j, Const (cname2, T2))) => + (case (stripT (0, T1), stripT (0, T2)) of + ((i', Type (tname1, _)), (j', Type (tname2, _))) => + if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then + (case (get_datatype_descr thy) tname1 of + SOME (_, (_, constrs)) => let val cnames = map fst constrs + in if cname1 mem cnames andalso cname2 mem cnames then + SOME (distinct_rule thy ss tname1 + (Logic.mk_equals (t, Const ("False", HOLogic.boolT)))) + else NONE + end + | NONE => NONE) + else NONE + | _ => NONE) + | _ => NONE) + | distinct_proc _ _ _ = NONE; + +val distinct_simproc = + Simplifier.simproc @{theory HOL} distinctN ["s = t"] distinct_proc; + +val dist_ss = HOL_ss addsimprocs [distinct_simproc]; + +val simproc_setup = + Simplifier.map_simpset (fn ss => ss addsimprocs [distinct_simproc]); + + +(**** translation rules for case ****) + +fun make_case ctxt = DatatypeCase.make_case + (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt; + +fun strip_case ctxt = DatatypeCase.strip_case + (datatype_of_case (ProofContext.theory_of ctxt)); + +fun add_case_tr' case_names thy = + Sign.add_advanced_trfuns ([], [], + map (fn case_name => + let val case_name' = Sign.const_syntax_name thy case_name + in (case_name', DatatypeCase.case_tr' datatype_of_case case_name') + end) case_names, []) thy; + +val trfun_setup = + Sign.add_advanced_trfuns ([], + [("_case_syntax", DatatypeCase.case_tr true datatype_of_constr)], + [], []); + + +(* prepare types *) + +fun read_typ thy ((Ts, sorts), str) = + let + val ctxt = ProofContext.init thy + |> fold (Variable.declare_typ o TFree) sorts; + val T = Syntax.read_typ ctxt str; + in (Ts @ [T], Term.add_tfreesT T sorts) end; + +fun cert_typ sign ((Ts, sorts), raw_T) = + let + val T = Type.no_tvars (Sign.certify_typ sign raw_T) handle + TYPE (msg, _, _) => error msg; + val sorts' = Term.add_tfreesT T sorts; + in (Ts @ [T], + case duplicates (op =) (map fst sorts') of + [] => sorts' + | dups => error ("Inconsistent sort constraints for " ^ commas dups)) + end; + + +(**** make datatype info ****) + +fun make_dt_info alt_names descr sorts induct reccomb_names rec_thms + (((((((((i, (_, (tname, _, _))), case_name), case_thms), + exhaustion_thm), distinct_thm), inject), nchotomy), case_cong), weak_case_cong) = + (tname, + {index = i, + alt_names = alt_names, + descr = descr, + sorts = sorts, + rec_names = reccomb_names, + rec_rewrites = rec_thms, + case_name = case_name, + case_rewrites = case_thms, + induction = induct, + exhaustion = exhaustion_thm, + distinct = distinct_thm, + inject = inject, + nchotomy = nchotomy, + case_cong = case_cong, + weak_case_cong = weak_case_cong}); + +type rules = {distinct : thm list list, + inject : thm list list, + exhaustion : thm list, + rec_thms : thm list, + case_thms : thm list list, + split_thms : (thm * thm) list, + induction : thm, + simps : thm list} + +structure DatatypeInterpretation = InterpretationFun + (type T = config * string list val eq: T * T -> bool = eq_snd op =); +fun interpretation f = DatatypeInterpretation.interpretation (uncurry f); + + +(******************* definitional introduction of datatypes *******************) + +fun add_datatype_def (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info + case_names_induct case_names_exhausts thy = + let + val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names); + + val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |> + DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts + types_syntax constr_syntax case_names_induct; + + val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr + sorts induct case_names_exhausts thy2; + val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms + config new_type_names descr sorts dt_info inject dist_rewrites + (Simplifier.theory_context thy3 dist_ss) induct thy3; + val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms + config new_type_names descr sorts reccomb_names rec_thms thy4; + val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names + descr sorts inject dist_rewrites casedist_thms case_thms thy6; + val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names + descr sorts casedist_thms thy7; + val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names + descr sorts nchotomys case_thms thy8; + val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names + descr sorts thy9; + + val dt_infos = map + (make_dt_info (SOME new_type_names) (flat descr) sorts induct reccomb_names rec_thms) + ((0 upto length (hd descr) - 1) ~~ (hd descr) ~~ case_names ~~ case_thms ~~ + casedist_thms ~~ simproc_dists ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs); + + val simps = flat (distinct @ inject @ case_thms) @ rec_thms; + + val thy12 = + thy10 + |> add_case_tr' case_names + |> Sign.add_path (space_implode "_" new_type_names) + |> add_rules simps case_thms rec_thms inject distinct + weak_case_congs (Simplifier.attrib (op addcongs)) + |> put_dt_infos dt_infos + |> add_cases_induct dt_infos induct + |> Sign.parent_path + |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd + |> DatatypeInterpretation.data (config, map fst dt_infos); + in + ({distinct = distinct, + inject = inject, + exhaustion = casedist_thms, + rec_thms = rec_thms, + case_thms = case_thms, + split_thms = split_thms, + induction = induct, + simps = simps}, thy12) + end; + + +(*********************** declare existing type as datatype *********************) + +fun prove_rep_datatype (config : config) alt_names new_type_names descr sorts induct inject half_distinct thy = + let + val ((_, [induct']), _) = + Variable.importT_thms [induct] (Variable.thm_context induct); + + fun err t = error ("Ill-formed predicate in induction rule: " ^ + Syntax.string_of_term_global thy t); + + fun get_typ (t as _ $ Var (_, Type (tname, Ts))) = + ((tname, map (fst o dest_TFree) Ts) handle TERM _ => err t) + | get_typ t = err t; + val dtnames = map get_typ (HOLogic.dest_conj (HOLogic.dest_Trueprop (Thm.concl_of induct'))); + + val dt_info = get_datatypes thy; + + val distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct; + val (case_names_induct, case_names_exhausts) = + (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames)); + + val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names); + + val (casedist_thms, thy2) = thy |> + DatatypeAbsProofs.prove_casedist_thms config new_type_names [descr] sorts induct + case_names_exhausts; + val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms + config new_type_names [descr] sorts dt_info inject distinct + (Simplifier.theory_context thy2 dist_ss) induct thy2; + val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms + config new_type_names [descr] sorts reccomb_names rec_thms thy3; + val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms + config new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4; + val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys config new_type_names + [descr] sorts casedist_thms thy5; + val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names + [descr] sorts nchotomys case_thms thy6; + val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names + [descr] sorts thy7; + + val ((_, [induct']), thy10) = + thy8 + |> store_thmss "inject" new_type_names inject + ||>> store_thmss "distinct" new_type_names distinct + ||> Sign.add_path (space_implode "_" new_type_names) + ||>> PureThy.add_thms [((Binding.name "induct", induct), [case_names_induct])]; + + val dt_infos = map (make_dt_info alt_names descr sorts induct' reccomb_names rec_thms) + ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_thms ~~ casedist_thms ~~ + map FewConstrs distinct ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs); + + val simps = flat (distinct @ inject @ case_thms) @ rec_thms; + + val thy11 = + thy10 + |> add_case_tr' case_names + |> add_rules simps case_thms rec_thms inject distinct + weak_case_congs (Simplifier.attrib (op addcongs)) + |> put_dt_infos dt_infos + |> add_cases_induct dt_infos induct' + |> Sign.parent_path + |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) + |> snd + |> DatatypeInterpretation.data (config, map fst dt_infos); + in + ({distinct = distinct, + inject = inject, + exhaustion = casedist_thms, + rec_thms = rec_thms, + case_thms = case_thms, + split_thms = split_thms, + induction = induct', + simps = simps}, thy11) + end; + +fun gen_rep_datatype prep_term (config : config) after_qed alt_names raw_ts thy = + let + fun constr_of_term (Const (c, T)) = (c, T) + | constr_of_term t = + error ("Not a constant: " ^ Syntax.string_of_term_global thy t); + fun no_constr (c, T) = error ("Bad constructor: " + ^ Sign.extern_const thy c ^ "::" + ^ Syntax.string_of_typ_global thy T); + fun type_of_constr (cT as (_, T)) = + let + val frees = OldTerm.typ_tfrees T; + val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) T + handle TYPE _ => no_constr cT + val _ = if has_duplicates (eq_fst (op =)) vs then no_constr cT else (); + val _ = if length frees <> length vs then no_constr cT else (); + in (tyco, (vs, cT)) end; + + val raw_cs = AList.group (op =) (map (type_of_constr o constr_of_term o prep_term thy) raw_ts); + val _ = case map_filter (fn (tyco, _) => + if Symtab.defined (get_datatypes thy) tyco then SOME tyco else NONE) raw_cs + of [] => () + | tycos => error ("Type(s) " ^ commas (map quote tycos) + ^ " already represented inductivly"); + val raw_vss = maps (map (map snd o fst) o snd) raw_cs; + val ms = case distinct (op =) (map length raw_vss) + of [n] => 0 upto n - 1 + | _ => error ("Different types in given constructors"); + fun inter_sort m = map (fn xs => nth xs m) raw_vss + |> Library.foldr1 (Sorts.inter_sort (Sign.classes_of thy)) + val sorts = map inter_sort ms; + val vs = Name.names Name.context Name.aT sorts; + + fun norm_constr (raw_vs, (c, T)) = (c, map_atyps + (TFree o (the o AList.lookup (op =) (map fst raw_vs ~~ vs)) o fst o dest_TFree) T); + + val cs = map (apsnd (map norm_constr)) raw_cs; + val dtyps_of_typ = map (dtyp_of_typ (map (rpair (map fst vs) o fst) cs)) + o fst o strip_type; + val new_type_names = map Long_Name.base_name (the_default (map fst cs) alt_names); + + fun mk_spec (i, (tyco, constr)) = (i, (tyco, + map (DtTFree o fst) vs, + (map o apsnd) dtyps_of_typ constr)) + val descr = map_index mk_spec cs; + val injs = DatatypeProp.make_injs [descr] vs; + val half_distincts = map snd (DatatypeProp.make_distincts [descr] vs); + val ind = DatatypeProp.make_ind [descr] vs; + val rules = (map o map o map) Logic.close_form [[[ind]], injs, half_distincts]; + + fun after_qed' raw_thms = + let + val [[[induct]], injs, half_distincts] = + unflat rules (map Drule.zero_var_indexes_list raw_thms); + (*FIXME somehow dubious*) + in + ProofContext.theory_result + (prove_rep_datatype config alt_names new_type_names descr vs induct injs half_distincts) + #-> after_qed + end; + in + thy + |> ProofContext.init + |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules)) + end; + +val rep_datatype = gen_rep_datatype Sign.cert_term; +val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global default_config (K I); + + + +(******************************** add datatype ********************************) + +fun gen_add_datatype prep_typ (config : config) new_type_names dts thy = + let + val _ = Theory.requires thy "Datatype" "datatype definitions"; + + (* this theory is used just for parsing *) + + val tmp_thy = thy |> + Theory.copy |> + Sign.add_types (map (fn (tvs, tname, mx, _) => + (tname, length tvs, mx)) dts); + + val (tyvars, _, _, _)::_ = dts; + val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) => + let val full_tname = Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname) + in (case duplicates (op =) tvs of + [] => if eq_set (tyvars, tvs) then ((full_tname, tvs), (tname, mx)) + else error ("Mutually recursive datatypes must have same type parameters") + | dups => error ("Duplicate parameter(s) for datatype " ^ quote (Binding.str_of tname) ^ + " : " ^ commas dups)) + end) dts); + + val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of + [] => () | dups => error ("Duplicate datatypes: " ^ commas dups)); + + fun prep_dt_spec ((tvs, tname, mx, constrs), tname') (dts', constr_syntax, sorts, i) = + let + fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') = + let + val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs); + val _ = (case fold (curry OldTerm.add_typ_tfree_names) cargs' [] \\ tvs of + [] => () + | vs => error ("Extra type variables on rhs: " ^ commas vs)) + in (constrs @ [((if #flat_names config then Sign.full_name tmp_thy else + Sign.full_name_path tmp_thy tname') + (Binding.map_name (Syntax.const_name mx') cname), + map (dtyp_of_typ new_dts) cargs')], + constr_syntax' @ [(cname, mx')], sorts'') + end handle ERROR msg => cat_error msg + ("The error above occured in constructor " ^ quote (Binding.str_of cname) ^ + " of datatype " ^ quote (Binding.str_of tname)); + + val (constrs', constr_syntax', sorts') = + fold prep_constr constrs ([], [], sorts) + + in + case duplicates (op =) (map fst constrs') of + [] => + (dts' @ [(i, (Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname), + map DtTFree tvs, constrs'))], + constr_syntax @ [constr_syntax'], sorts', i + 1) + | dups => error ("Duplicate constructors " ^ commas dups ^ + " in datatype " ^ quote (Binding.str_of tname)) + end; + + val (dts', constr_syntax, sorts', i) = + fold prep_dt_spec (dts ~~ new_type_names) ([], [], [], 0); + val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts')); + val dt_info = get_datatypes thy; + val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i; + val _ = check_nonempty descr handle (exn as Datatype_Empty s) => + if #strict config then error ("Nonemptiness check failed for datatype " ^ s) + else raise exn; + + val descr' = flat descr; + val case_names_induct = mk_case_names_induct descr'; + val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts); + in + add_datatype_def + (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info + case_names_induct case_names_exhausts thy + end; + +val add_datatype = gen_add_datatype cert_typ; +val datatype_cmd = snd ooo gen_add_datatype read_typ default_config; + + + +(** package setup **) + +(* setup theory *) + +val setup = + DatatypeRepProofs.distinctness_limit_setup #> + simproc_setup #> + trfun_setup #> + DatatypeInterpretation.init; + + +(* outer syntax *) + +local + +structure P = OuterParse and K = OuterKeyword + +fun prep_datatype_decls args = + let + val names = map + (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args; + val specs = map (fn ((((_, vs), t), mx), cons) => + (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args; + in (names, specs) end; + +val parse_datatype_decl = + (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") -- P.type_args -- P.binding -- P.opt_infix -- + (P.$$$ "=" |-- P.enum1 "|" (P.binding -- Scan.repeat P.typ -- P.opt_mixfix))); + +val parse_datatype_decls = P.and_list1 parse_datatype_decl >> prep_datatype_decls; + +in + +val _ = + OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl + (parse_datatype_decls >> (fn (names, specs) => Toplevel.theory (datatype_cmd names specs))); + +val _ = + OuterSyntax.command "rep_datatype" "represent existing types inductively" K.thy_goal + (Scan.option (P.$$$ "(" |-- Scan.repeat1 P.name --| P.$$$ ")") -- Scan.repeat1 P.term + >> (fn (alt_names, ts) => Toplevel.print + o Toplevel.theory_to_proof (rep_datatype_cmd alt_names ts))); + +end; + + +(* document antiquotation *) + +val _ = ThyOutput.antiquotation "datatype" Args.tyname + (fn {source = src, context = ctxt, ...} => fn dtco => + let + val thy = ProofContext.theory_of ctxt; + val (vs, cos) = the_datatype_spec thy dtco; + val ty = Type (dtco, map TFree vs); + fun pretty_typ_bracket (ty as Type (_, _ :: _)) = + Pretty.enclose "(" ")" [Syntax.pretty_typ ctxt ty] + | pretty_typ_bracket ty = + Syntax.pretty_typ ctxt ty; + fun pretty_constr (co, tys) = + (Pretty.block o Pretty.breaks) + (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) :: + map pretty_typ_bracket tys); + val pretty_datatype = + Pretty.block + (Pretty.command "datatype" :: Pretty.brk 1 :: + Syntax.pretty_typ ctxt ty :: + Pretty.str " =" :: Pretty.brk 1 :: + flat (separate [Pretty.brk 1, Pretty.str "| "] + (map (single o pretty_constr) cos))); + in ThyOutput.output (ThyOutput.maybe_pretty_source (K pretty_datatype) src [()]) end); + +end; + diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_abs_proofs.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_abs_proofs.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,447 @@ +(* Title: HOL/Tools/datatype_abs_proofs.ML + Author: Stefan Berghofer, TU Muenchen + +Proofs and defintions independent of concrete representation +of datatypes (i.e. requiring only abstract properties such as +injectivity / distinctness of constructors and induction) + + - case distinction (exhaustion) theorems + - characteristic equations for primrec combinators + - characteristic equations for case combinators + - equations for splitting "P (case ...)" expressions + - "nchotomy" and "case_cong" theorems for TFL +*) + +signature DATATYPE_ABS_PROOFS = +sig + include DATATYPE_COMMON + val prove_casedist_thms : config -> string list -> + descr list -> (string * sort) list -> thm -> + attribute list -> theory -> thm list * theory + val prove_primrec_thms : config -> string list -> + descr list -> (string * sort) list -> + info Symtab.table -> thm list list -> thm list list -> + simpset -> thm -> theory -> (string list * thm list) * theory + val prove_case_thms : config -> string list -> + descr list -> (string * sort) list -> + string list -> thm list -> theory -> (thm list list * string list) * theory + val prove_split_thms : config -> string list -> + descr list -> (string * sort) list -> + thm list list -> thm list list -> thm list -> thm list list -> theory -> + (thm * thm) list * theory + val prove_nchotomys : config -> string list -> descr list -> + (string * sort) list -> thm list -> theory -> thm list * theory + val prove_weak_case_congs : string list -> descr list -> + (string * sort) list -> theory -> thm list * theory + val prove_case_congs : string list -> + descr list -> (string * sort) list -> + thm list -> thm list list -> theory -> thm list * theory +end; + +structure DatatypeAbsProofs: DATATYPE_ABS_PROOFS = +struct + +open DatatypeAux; + +(************************ case distinction theorems ***************************) + +fun prove_casedist_thms (config : config) new_type_names descr sorts induct case_names_exhausts thy = + let + val _ = message config "Proving case distinction theorems ..."; + + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + + val {maxidx, ...} = rep_thm induct; + val induct_Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of induct))); + + fun prove_casedist_thm ((i, t), T) = + let + val dummyPs = map (fn (Var (_, Type (_, [T', T'']))) => + Abs ("z", T', Const ("True", T''))) induct_Ps; + val P = Abs ("z", T, HOLogic.imp $ HOLogic.mk_eq (Var (("a", maxidx+1), T), Bound 0) $ + Var (("P", 0), HOLogic.boolT)) + val insts = Library.take (i, dummyPs) @ (P::(Library.drop (i + 1, dummyPs))); + val cert = cterm_of thy; + val insts' = (map cert induct_Ps) ~~ (map cert insts); + val induct' = refl RS ((List.nth + (split_conj_thm (cterm_instantiate insts' induct), i)) RSN (2, rev_mp)) + + in + SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) + (fn {prems, ...} => EVERY + [rtac induct' 1, + REPEAT (rtac TrueI 1), + REPEAT ((rtac impI 1) THEN (eresolve_tac prems 1)), + REPEAT (rtac TrueI 1)]) + end; + + val casedist_thms = map prove_casedist_thm ((0 upto (length newTs - 1)) ~~ + (DatatypeProp.make_casedists descr sorts) ~~ newTs) + in + thy + |> store_thms_atts "exhaust" new_type_names (map single case_names_exhausts) casedist_thms + end; + + +(*************************** primrec combinators ******************************) + +fun prove_primrec_thms (config : config) new_type_names descr sorts + (dt_info : info Symtab.table) constr_inject dist_rewrites dist_ss induct thy = + let + val _ = message config "Constructing primrec combinators ..."; + + val big_name = space_implode "_" new_type_names; + val thy0 = add_path (#flat_names config) big_name thy; + + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; + val newTs = Library.take (length (hd descr), recTs); + + val induct_Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of induct))); + + val big_rec_name' = big_name ^ "_rec_set"; + val rec_set_names' = + if length descr' = 1 then [big_rec_name'] else + map ((curry (op ^) (big_rec_name' ^ "_")) o string_of_int) + (1 upto (length descr')); + val rec_set_names = map (Sign.full_bname thy0) rec_set_names'; + + val (rec_result_Ts, reccomb_fn_Ts) = DatatypeProp.make_primrec_Ts descr sorts used; + + val rec_set_Ts = map (fn (T1, T2) => + reccomb_fn_Ts @ [T1, T2] ---> HOLogic.boolT) (recTs ~~ rec_result_Ts); + + val rec_fns = map (uncurry (mk_Free "f")) + (reccomb_fn_Ts ~~ (1 upto (length reccomb_fn_Ts))); + val rec_sets' = map (fn c => list_comb (Free c, rec_fns)) + (rec_set_names' ~~ rec_set_Ts); + val rec_sets = map (fn c => list_comb (Const c, rec_fns)) + (rec_set_names ~~ rec_set_Ts); + + (* introduction rules for graph of primrec function *) + + fun make_rec_intr T rec_set ((rec_intr_ts, l), (cname, cargs)) = + let + fun mk_prem ((dt, U), (j, k, prems, t1s, t2s)) = + let val free1 = mk_Free "x" U j + in (case (strip_dtyp dt, strip_type U) of + ((_, DtRec m), (Us, _)) => + let + val free2 = mk_Free "y" (Us ---> List.nth (rec_result_Ts, m)) k; + val i = length Us + in (j + 1, k + 1, HOLogic.mk_Trueprop (HOLogic.list_all + (map (pair "x") Us, List.nth (rec_sets', m) $ + app_bnds free1 i $ app_bnds free2 i)) :: prems, + free1::t1s, free2::t2s) + end + | _ => (j + 1, k, prems, free1::t1s, t2s)) + end; + + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val (_, _, prems, t1s, t2s) = List.foldr mk_prem (1, 1, [], [], []) (cargs ~~ Ts) + + in (rec_intr_ts @ [Logic.list_implies (prems, HOLogic.mk_Trueprop + (rec_set $ list_comb (Const (cname, Ts ---> T), t1s) $ + list_comb (List.nth (rec_fns, l), t1s @ t2s)))], l + 1) + end; + + val (rec_intr_ts, _) = Library.foldl (fn (x, ((d, T), set_name)) => + Library.foldl (make_rec_intr T set_name) (x, #3 (snd d))) + (([], 0), descr' ~~ recTs ~~ rec_sets'); + + val ({intrs = rec_intrs, elims = rec_elims, ...}, thy1) = + Inductive.add_inductive_global (serial_string ()) + {quiet_mode = #quiet config, verbose = false, kind = Thm.internalK, + alt_name = Binding.name big_rec_name', coind = false, no_elim = false, no_ind = true, + skip_mono = true, fork_mono = false} + (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (rec_set_names' ~~ rec_set_Ts)) + (map dest_Free rec_fns) + (map (fn x => (Attrib.empty_binding, x)) rec_intr_ts) [] thy0; + + (* prove uniqueness and termination of primrec combinators *) + + val _ = message config "Proving termination and uniqueness of primrec functions ..."; + + fun mk_unique_tac ((tac, intrs), ((((i, (tname, _, constrs)), elim), T), T')) = + let + val distinct_tac = + (if i < length newTs then + full_simp_tac (HOL_ss addsimps (List.nth (dist_rewrites, i))) 1 + else full_simp_tac dist_ss 1); + + val inject = map (fn r => r RS iffD1) + (if i < length newTs then List.nth (constr_inject, i) + else #inject (the (Symtab.lookup dt_info tname))); + + fun mk_unique_constr_tac n ((tac, intr::intrs, j), (cname, cargs)) = + let + val k = length (List.filter is_rec_type cargs) + + in (EVERY [DETERM tac, + REPEAT (etac ex1E 1), rtac ex1I 1, + DEPTH_SOLVE_1 (ares_tac [intr] 1), + REPEAT_DETERM_N k (etac thin_rl 1 THEN rotate_tac 1 1), + etac elim 1, + REPEAT_DETERM_N j distinct_tac, + TRY (dresolve_tac inject 1), + REPEAT (etac conjE 1), hyp_subst_tac 1, + REPEAT (EVERY [etac allE 1, dtac mp 1, atac 1]), + TRY (hyp_subst_tac 1), + rtac refl 1, + REPEAT_DETERM_N (n - j - 1) distinct_tac], + intrs, j + 1) + end; + + val (tac', intrs', _) = Library.foldl (mk_unique_constr_tac (length constrs)) + ((tac, intrs, 0), constrs); + + in (tac', intrs') end; + + val rec_unique_thms = + let + val rec_unique_ts = map (fn (((set_t, T1), T2), i) => + Const ("Ex1", (T2 --> HOLogic.boolT) --> HOLogic.boolT) $ + absfree ("y", T2, set_t $ mk_Free "x" T1 i $ Free ("y", T2))) + (rec_sets ~~ recTs ~~ rec_result_Ts ~~ (1 upto length recTs)); + val cert = cterm_of thy1 + val insts = map (fn ((i, T), t) => absfree ("x" ^ (string_of_int i), T, t)) + ((1 upto length recTs) ~~ recTs ~~ rec_unique_ts); + val induct' = cterm_instantiate ((map cert induct_Ps) ~~ + (map cert insts)) induct; + val (tac, _) = Library.foldl mk_unique_tac + (((rtac induct' THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1 + THEN rewrite_goals_tac [mk_meta_eq choice_eq], rec_intrs), + descr' ~~ rec_elims ~~ recTs ~~ rec_result_Ts); + + in split_conj_thm (SkipProof.prove_global thy1 [] [] + (HOLogic.mk_Trueprop (mk_conj rec_unique_ts)) (K tac)) + end; + + val rec_total_thms = map (fn r => r RS theI') rec_unique_thms; + + (* define primrec combinators *) + + val big_reccomb_name = (space_implode "_" new_type_names) ^ "_rec"; + val reccomb_names = map (Sign.full_bname thy1) + (if length descr' = 1 then [big_reccomb_name] else + (map ((curry (op ^) (big_reccomb_name ^ "_")) o string_of_int) + (1 upto (length descr')))); + val reccombs = map (fn ((name, T), T') => list_comb + (Const (name, reccomb_fn_Ts @ [T] ---> T'), rec_fns)) + (reccomb_names ~~ recTs ~~ rec_result_Ts); + + val (reccomb_defs, thy2) = + thy1 + |> Sign.add_consts_i (map (fn ((name, T), T') => + (Binding.name (Long_Name.base_name name), reccomb_fn_Ts @ [T] ---> T', NoSyn)) + (reccomb_names ~~ recTs ~~ rec_result_Ts)) + |> (PureThy.add_defs false o map Thm.no_attributes) (map (fn ((((name, comb), set), T), T') => + (Binding.name (Long_Name.base_name name ^ "_def"), Logic.mk_equals (comb, absfree ("x", T, + Const ("The", (T' --> HOLogic.boolT) --> T') $ absfree ("y", T', + set $ Free ("x", T) $ Free ("y", T')))))) + (reccomb_names ~~ reccombs ~~ rec_sets ~~ recTs ~~ rec_result_Ts)) + ||> parent_path (#flat_names config) + ||> Theory.checkpoint; + + + (* prove characteristic equations for primrec combinators *) + + val _ = message config "Proving characteristic theorems for primrec combinators ..." + + val rec_thms = map (fn t => SkipProof.prove_global thy2 [] [] t + (fn _ => EVERY + [rewrite_goals_tac reccomb_defs, + rtac the1_equality 1, + resolve_tac rec_unique_thms 1, + resolve_tac rec_intrs 1, + REPEAT (rtac allI 1 ORELSE resolve_tac rec_total_thms 1)])) + (DatatypeProp.make_primrecs new_type_names descr sorts thy2) + + in + thy2 + |> Sign.add_path (space_implode "_" new_type_names) + |> PureThy.add_thmss [((Binding.name "recs", rec_thms), + [Nitpick_Const_Simp_Thms.add])] + ||> Sign.parent_path + ||> Theory.checkpoint + |-> (fn thms => pair (reccomb_names, Library.flat thms)) + end; + + +(***************************** case combinators *******************************) + +fun prove_case_thms (config : config) new_type_names descr sorts reccomb_names primrec_thms thy = + let + val _ = message config "Proving characteristic theorems for case combinators ..."; + + val thy1 = add_path (#flat_names config) (space_implode "_" new_type_names) thy; + + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; + val newTs = Library.take (length (hd descr), recTs); + val T' = TFree (Name.variant used "'t", HOLogic.typeS); + + fun mk_dummyT dt = binder_types (typ_of_dtyp descr' sorts dt) ---> T'; + + val case_dummy_fns = map (fn (_, (_, _, constrs)) => map (fn (_, cargs) => + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val Ts' = map mk_dummyT (List.filter is_rec_type cargs) + in Const (@{const_name undefined}, Ts @ Ts' ---> T') + end) constrs) descr'; + + val case_names = map (fn s => Sign.full_bname thy1 (s ^ "_case")) new_type_names; + + (* define case combinators via primrec combinators *) + + val (case_defs, thy2) = Library.foldl (fn ((defs, thy), + ((((i, (_, _, constrs)), T), name), recname)) => + let + val (fns1, fns2) = ListPair.unzip (map (fn ((_, cargs), j) => + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val Ts' = Ts @ map mk_dummyT (List.filter is_rec_type cargs); + val frees' = map (uncurry (mk_Free "x")) (Ts' ~~ (1 upto length Ts')); + val frees = Library.take (length cargs, frees'); + val free = mk_Free "f" (Ts ---> T') j + in + (free, list_abs_free (map dest_Free frees', + list_comb (free, frees))) + end) (constrs ~~ (1 upto length constrs))); + + val caseT = (map (snd o dest_Free) fns1) @ [T] ---> T'; + val fns = (List.concat (Library.take (i, case_dummy_fns))) @ + fns2 @ (List.concat (Library.drop (i + 1, case_dummy_fns))); + val reccomb = Const (recname, (map fastype_of fns) @ [T] ---> T'); + val decl = ((Binding.name (Long_Name.base_name name), caseT), NoSyn); + val def = (Binding.name (Long_Name.base_name name ^ "_def"), + Logic.mk_equals (list_comb (Const (name, caseT), fns1), + list_comb (reccomb, (List.concat (Library.take (i, case_dummy_fns))) @ + fns2 @ (List.concat (Library.drop (i + 1, case_dummy_fns))) ))); + val ([def_thm], thy') = + thy + |> Sign.declare_const [] decl |> snd + |> (PureThy.add_defs false o map Thm.no_attributes) [def]; + + in (defs @ [def_thm], thy') + end) (([], thy1), (hd descr) ~~ newTs ~~ case_names ~~ + (Library.take (length newTs, reccomb_names))) + ||> Theory.checkpoint; + + val case_thms = map (map (fn t => SkipProof.prove_global thy2 [] [] t + (fn _ => EVERY [rewrite_goals_tac (case_defs @ map mk_meta_eq primrec_thms), rtac refl 1]))) + (DatatypeProp.make_cases new_type_names descr sorts thy2) + in + thy2 + |> Context.the_theory o fold (fold Nitpick_Const_Simp_Thms.add_thm) case_thms + o Context.Theory + |> parent_path (#flat_names config) + |> store_thmss "cases" new_type_names case_thms + |-> (fn thmss => pair (thmss, case_names)) + end; + + +(******************************* case splitting *******************************) + +fun prove_split_thms (config : config) new_type_names descr sorts constr_inject dist_rewrites + casedist_thms case_thms thy = + let + val _ = message config "Proving equations for case splitting ..."; + + val descr' = flat descr; + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + + fun prove_split_thms ((((((t1, t2), inject), dist_rewrites'), + exhaustion), case_thms'), T) = + let + val cert = cterm_of thy; + val _ $ (_ $ lhs $ _) = hd (Logic.strip_assums_hyp (hd (prems_of exhaustion))); + val exhaustion' = cterm_instantiate + [(cert lhs, cert (Free ("x", T)))] exhaustion; + val tacf = K (EVERY [rtac exhaustion' 1, ALLGOALS (asm_simp_tac + (HOL_ss addsimps (dist_rewrites' @ inject @ case_thms')))]) + in + (SkipProof.prove_global thy [] [] t1 tacf, + SkipProof.prove_global thy [] [] t2 tacf) + end; + + val split_thm_pairs = map prove_split_thms + ((DatatypeProp.make_splits new_type_names descr sorts thy) ~~ constr_inject ~~ + dist_rewrites ~~ casedist_thms ~~ case_thms ~~ newTs); + + val (split_thms, split_asm_thms) = ListPair.unzip split_thm_pairs + + in + thy + |> store_thms "split" new_type_names split_thms + ||>> store_thms "split_asm" new_type_names split_asm_thms + |-> (fn (thms1, thms2) => pair (thms1 ~~ thms2)) + end; + +fun prove_weak_case_congs new_type_names descr sorts thy = + let + fun prove_weak_case_cong t = + SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) + (fn {prems, ...} => EVERY [rtac ((hd prems) RS arg_cong) 1]) + + val weak_case_congs = map prove_weak_case_cong (DatatypeProp.make_weak_case_congs + new_type_names descr sorts thy) + + in thy |> store_thms "weak_case_cong" new_type_names weak_case_congs end; + +(************************* additional theorems for TFL ************************) + +fun prove_nchotomys (config : config) new_type_names descr sorts casedist_thms thy = + let + val _ = message config "Proving additional theorems for TFL ..."; + + fun prove_nchotomy (t, exhaustion) = + let + (* For goal i, select the correct disjunct to attack, then prove it *) + fun tac i 0 = EVERY [TRY (rtac disjI1 i), + hyp_subst_tac i, REPEAT (rtac exI i), rtac refl i] + | tac i n = rtac disjI2 i THEN tac i (n - 1) + in + SkipProof.prove_global thy [] [] t (fn _ => + EVERY [rtac allI 1, + exh_tac (K exhaustion) 1, + ALLGOALS (fn i => tac i (i-1))]) + end; + + val nchotomys = + map prove_nchotomy (DatatypeProp.make_nchotomys descr sorts ~~ casedist_thms) + + in thy |> store_thms "nchotomy" new_type_names nchotomys end; + +fun prove_case_congs new_type_names descr sorts nchotomys case_thms thy = + let + fun prove_case_cong ((t, nchotomy), case_rewrites) = + let + val (Const ("==>", _) $ tm $ _) = t; + val (Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ Ma)) = tm; + val cert = cterm_of thy; + val nchotomy' = nchotomy RS spec; + val [v] = Term.add_vars (concl_of nchotomy') []; + val nchotomy'' = cterm_instantiate [(cert (Var v), cert Ma)] nchotomy' + in + SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) + (fn {prems, ...} => + let val simplify = asm_simp_tac (HOL_ss addsimps (prems @ case_rewrites)) + in EVERY [simp_tac (HOL_ss addsimps [hd prems]) 1, + cut_facts_tac [nchotomy''] 1, + REPEAT (etac disjE 1 THEN REPEAT (etac exE 1) THEN simplify 1), + REPEAT (etac exE 1) THEN simplify 1 (* Get last disjunct *)] + end) + end; + + val case_congs = map prove_case_cong (DatatypeProp.make_case_congs + new_type_names descr sorts thy ~~ nchotomys ~~ case_thms) + + in thy |> store_thms "case_cong" new_type_names case_congs end; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_aux.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_aux.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,381 @@ +(* Title: HOL/Tools/datatype_aux.ML + Author: Stefan Berghofer, TU Muenchen + +Auxiliary functions for defining datatypes. +*) + +signature DATATYPE_COMMON = +sig + type config + val default_config : config + datatype dtyp = + DtTFree of string + | DtType of string * (dtyp list) + | DtRec of int; + type descr + type info +end + +signature DATATYPE_AUX = +sig + include DATATYPE_COMMON + + val message : config -> string -> unit + + val add_path : bool -> string -> theory -> theory + val parent_path : bool -> theory -> theory + + val store_thmss_atts : string -> string list -> attribute list list -> thm list list + -> theory -> thm list list * theory + val store_thmss : string -> string list -> thm list list -> theory -> thm list list * theory + val store_thms_atts : string -> string list -> attribute list list -> thm list + -> theory -> thm list * theory + val store_thms : string -> string list -> thm list -> theory -> thm list * theory + + val split_conj_thm : thm -> thm list + val mk_conj : term list -> term + val mk_disj : term list -> term + + val app_bnds : term -> int -> term + + val cong_tac : int -> tactic + val indtac : thm -> string list -> int -> tactic + val exh_tac : (string -> thm) -> int -> tactic + + datatype simproc_dist = FewConstrs of thm list + | ManyConstrs of thm * simpset; + + + exception Datatype + exception Datatype_Empty of string + val name_of_typ : typ -> string + val dtyp_of_typ : (string * string list) list -> typ -> dtyp + val mk_Free : string -> typ -> int -> term + val is_rec_type : dtyp -> bool + val typ_of_dtyp : descr -> (string * sort) list -> dtyp -> typ + val dest_DtTFree : dtyp -> string + val dest_DtRec : dtyp -> int + val strip_dtyp : dtyp -> dtyp list * dtyp + val body_index : dtyp -> int + val mk_fun_dtyp : dtyp list -> dtyp -> dtyp + val get_nonrec_types : descr -> (string * sort) list -> typ list + val get_branching_types : descr -> (string * sort) list -> typ list + val get_arities : descr -> int list + val get_rec_types : descr -> (string * sort) list -> typ list + val interpret_construction : descr -> (string * sort) list + -> { atyp: typ -> 'a, dtyp: typ list -> int * bool -> string * typ list -> 'a } + -> ((string * Term.typ list) * (string * 'a list) list) list + val check_nonempty : descr list -> unit + val unfold_datatypes : + theory -> descr -> (string * sort) list -> info Symtab.table -> + descr -> int -> descr list * int +end; + +structure DatatypeAux : DATATYPE_AUX = +struct + +(* datatype option flags *) + +type config = { strict: bool, flat_names: bool, quiet: bool }; +val default_config : config = + { strict = true, flat_names = false, quiet = false }; +fun message ({ quiet, ...} : config) s = + if quiet then () else writeln s; + +fun add_path flat_names s = if flat_names then I else Sign.add_path s; +fun parent_path flat_names = if flat_names then I else Sign.parent_path; + + +(* store theorems in theory *) + +fun store_thmss_atts label tnames attss thmss = + fold_map (fn ((tname, atts), thms) => + Sign.add_path tname + #> PureThy.add_thmss [((Binding.name label, thms), atts)] + #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss) + ##> Theory.checkpoint; + +fun store_thmss label tnames = store_thmss_atts label tnames (replicate (length tnames) []); + +fun store_thms_atts label tnames attss thmss = + fold_map (fn ((tname, atts), thms) => + Sign.add_path tname + #> PureThy.add_thms [((Binding.name label, thms), atts)] + #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss) + ##> Theory.checkpoint; + +fun store_thms label tnames = store_thms_atts label tnames (replicate (length tnames) []); + + +(* split theorem thm_1 & ... & thm_n into n theorems *) + +fun split_conj_thm th = + ((th RS conjunct1)::(split_conj_thm (th RS conjunct2))) handle THM _ => [th]; + +val mk_conj = foldr1 (HOLogic.mk_binop "op &"); +val mk_disj = foldr1 (HOLogic.mk_binop "op |"); + +fun app_bnds t i = list_comb (t, map Bound (i - 1 downto 0)); + + +fun cong_tac i st = (case Logic.strip_assums_concl + (List.nth (prems_of st, i - 1)) of + _ $ (_ $ (f $ x) $ (g $ y)) => + let + val cong' = Thm.lift_rule (Thm.cprem_of st i) cong; + val _ $ (_ $ (f' $ x') $ (g' $ y')) = + Logic.strip_assums_concl (prop_of cong'); + val insts = map (pairself (cterm_of (Thm.theory_of_thm st)) o + apsnd (curry list_abs (Logic.strip_params (concl_of cong'))) o + apfst head_of) [(f', f), (g', g), (x', x), (y', y)] + in compose_tac (false, cterm_instantiate insts cong', 2) i st + handle THM _ => no_tac st + end + | _ => no_tac st); + +(* instantiate induction rule *) + +fun indtac indrule indnames i st = + let + val ts = HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule)); + val ts' = HOLogic.dest_conj (HOLogic.dest_Trueprop + (Logic.strip_imp_concl (List.nth (prems_of st, i - 1)))); + val getP = if can HOLogic.dest_imp (hd ts) then + (apfst SOME) o HOLogic.dest_imp else pair NONE; + val flt = if null indnames then I else + filter (fn Free (s, _) => s mem indnames | _ => false); + fun abstr (t1, t2) = (case t1 of + NONE => (case flt (OldTerm.term_frees t2) of + [Free (s, T)] => SOME (absfree (s, T, t2)) + | _ => NONE) + | SOME (_ $ t') => SOME (Abs ("x", fastype_of t', abstract_over (t', t2)))) + val cert = cterm_of (Thm.theory_of_thm st); + val insts = List.mapPartial (fn (t, u) => case abstr (getP u) of + NONE => NONE + | SOME u' => SOME (t |> getP |> snd |> head_of |> cert, cert u')) (ts ~~ ts'); + val indrule' = cterm_instantiate insts indrule + in + rtac indrule' i st + end; + +(* perform exhaustive case analysis on last parameter of subgoal i *) + +fun exh_tac exh_thm_of i state = + let + val thy = Thm.theory_of_thm state; + val prem = nth (prems_of state) (i - 1); + val params = Logic.strip_params prem; + val (_, Type (tname, _)) = hd (rev params); + val exhaustion = Thm.lift_rule (Thm.cprem_of state i) (exh_thm_of tname); + val prem' = hd (prems_of exhaustion); + val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem')); + val exhaustion' = cterm_instantiate [(cterm_of thy (head_of lhs), + cterm_of thy (List.foldr (fn ((_, T), t) => Abs ("z", T, t)) + (Bound 0) params))] exhaustion + in compose_tac (false, exhaustion', nprems_of exhaustion) i state + end; + +(* handling of distinctness theorems *) + +datatype simproc_dist = FewConstrs of thm list + | ManyConstrs of thm * simpset; + +(********************** Internal description of datatypes *********************) + +datatype dtyp = + DtTFree of string + | DtType of string * (dtyp list) + | DtRec of int; + +(* information about datatypes *) + +(* index, datatype name, type arguments, constructor name, types of constructor's arguments *) +type descr = (int * (string * dtyp list * (string * dtyp list) list)) list; + +type info = + {index : int, + alt_names : string list option, + descr : descr, + sorts : (string * sort) list, + rec_names : string list, + rec_rewrites : thm list, + case_name : string, + case_rewrites : thm list, + induction : thm, + exhaustion : thm, + distinct : simproc_dist, + inject : thm list, + nchotomy : thm, + case_cong : thm, + weak_case_cong : thm}; + +fun mk_Free s T i = Free (s ^ (string_of_int i), T); + +fun subst_DtTFree _ substs (T as (DtTFree name)) = + AList.lookup (op =) substs name |> the_default T + | subst_DtTFree i substs (DtType (name, ts)) = + DtType (name, map (subst_DtTFree i substs) ts) + | subst_DtTFree i _ (DtRec j) = DtRec (i + j); + +exception Datatype; +exception Datatype_Empty of string; + +fun dest_DtTFree (DtTFree a) = a + | dest_DtTFree _ = raise Datatype; + +fun dest_DtRec (DtRec i) = i + | dest_DtRec _ = raise Datatype; + +fun is_rec_type (DtType (_, dts)) = exists is_rec_type dts + | is_rec_type (DtRec _) = true + | is_rec_type _ = false; + +fun strip_dtyp (DtType ("fun", [T, U])) = apfst (cons T) (strip_dtyp U) + | strip_dtyp T = ([], T); + +val body_index = dest_DtRec o snd o strip_dtyp; + +fun mk_fun_dtyp [] U = U + | mk_fun_dtyp (T :: Ts) U = DtType ("fun", [T, mk_fun_dtyp Ts U]); + +fun name_of_typ (Type (s, Ts)) = + let val s' = Long_Name.base_name s + in space_implode "_" (List.filter (not o equal "") (map name_of_typ Ts) @ + [if Syntax.is_identifier s' then s' else "x"]) + end + | name_of_typ _ = ""; + +fun dtyp_of_typ _ (TFree (n, _)) = DtTFree n + | dtyp_of_typ _ (TVar _) = error "Illegal schematic type variable(s)" + | dtyp_of_typ new_dts (Type (tname, Ts)) = + (case AList.lookup (op =) new_dts tname of + NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts) + | SOME vs => if map (try (fst o dest_TFree)) Ts = map SOME vs then + DtRec (find_index (curry op = tname o fst) new_dts) + else error ("Illegal occurrence of recursive type " ^ tname)); + +fun typ_of_dtyp descr sorts (DtTFree a) = TFree (a, (the o AList.lookup (op =) sorts) a) + | typ_of_dtyp descr sorts (DtRec i) = + let val (s, ds, _) = (the o AList.lookup (op =) descr) i + in Type (s, map (typ_of_dtyp descr sorts) ds) end + | typ_of_dtyp descr sorts (DtType (s, ds)) = + Type (s, map (typ_of_dtyp descr sorts) ds); + +(* find all non-recursive types in datatype description *) + +fun get_nonrec_types descr sorts = + map (typ_of_dtyp descr sorts) (Library.foldl (fn (Ts, (_, (_, _, constrs))) => + Library.foldl (fn (Ts', (_, cargs)) => + filter_out is_rec_type cargs union Ts') (Ts, constrs)) ([], descr)); + +(* get all recursive types in datatype description *) + +fun get_rec_types descr sorts = map (fn (_ , (s, ds, _)) => + Type (s, map (typ_of_dtyp descr sorts) ds)) descr; + +(* get all branching types *) + +fun get_branching_types descr sorts = + map (typ_of_dtyp descr sorts) (fold (fn (_, (_, _, constrs)) => + fold (fn (_, cargs) => fold (strip_dtyp #> fst #> fold (insert op =)) cargs) + constrs) descr []); + +fun get_arities descr = fold (fn (_, (_, _, constrs)) => + fold (fn (_, cargs) => fold (insert op =) (map (length o fst o strip_dtyp) + (List.filter is_rec_type cargs))) constrs) descr []; + +(* interpret construction of datatype *) + +fun interpret_construction descr vs { atyp, dtyp } = + let + val typ_of_dtyp = typ_of_dtyp descr vs; + fun interpT dT = case strip_dtyp dT + of (dTs, DtRec l) => + let + val (tyco, dTs', _) = (the o AList.lookup (op =) descr) l; + val Ts = map typ_of_dtyp dTs; + val Ts' = map typ_of_dtyp dTs'; + val is_proper = forall (can dest_TFree) Ts'; + in dtyp Ts (l, is_proper) (tyco, Ts') end + | _ => atyp (typ_of_dtyp dT); + fun interpC (c, dTs) = (c, map interpT dTs); + fun interpD (_, (tyco, dTs, cs)) = ((tyco, map typ_of_dtyp dTs), map interpC cs); + in map interpD descr end; + +(* nonemptiness check for datatypes *) + +fun check_nonempty descr = + let + val descr' = List.concat descr; + fun is_nonempty_dt is i = + let + val (_, _, constrs) = (the o AList.lookup (op =) descr') i; + fun arg_nonempty (_, DtRec i) = if i mem is then false + else is_nonempty_dt (i::is) i + | arg_nonempty _ = true; + in exists ((forall (arg_nonempty o strip_dtyp)) o snd) constrs + end + in assert_all (fn (i, _) => is_nonempty_dt [i] i) (hd descr) + (fn (_, (s, _, _)) => raise Datatype_Empty s) + end; + +(* unfold a list of mutually recursive datatype specifications *) +(* all types of the form DtType (dt_name, [..., DtRec _, ...]) *) +(* need to be unfolded *) + +fun unfold_datatypes sign orig_descr sorts (dt_info : info Symtab.table) descr i = + let + fun typ_error T msg = error ("Non-admissible type expression\n" ^ + Syntax.string_of_typ_global sign (typ_of_dtyp (orig_descr @ descr) sorts T) ^ "\n" ^ msg); + + fun get_dt_descr T i tname dts = + (case Symtab.lookup dt_info tname of + NONE => typ_error T (tname ^ " is not a datatype - can't use it in\ + \ nested recursion") + | (SOME {index, descr, ...}) => + let val (_, vars, _) = (the o AList.lookup (op =) descr) index; + val subst = ((map dest_DtTFree vars) ~~ dts) handle Library.UnequalLengths => + typ_error T ("Type constructor " ^ tname ^ " used with wrong\ + \ number of arguments") + in (i + index, map (fn (j, (tn, args, cs)) => (i + j, + (tn, map (subst_DtTFree i subst) args, + map (apsnd (map (subst_DtTFree i subst))) cs))) descr) + end); + + (* unfold a single constructor argument *) + + fun unfold_arg ((i, Ts, descrs), T) = + if is_rec_type T then + let val (Us, U) = strip_dtyp T + in if exists is_rec_type Us then + typ_error T "Non-strictly positive recursive occurrence of type" + else (case U of + DtType (tname, dts) => + let + val (index, descr) = get_dt_descr T i tname dts; + val (descr', i') = unfold_datatypes sign orig_descr sorts + dt_info descr (i + length descr) + in (i', Ts @ [mk_fun_dtyp Us (DtRec index)], descrs @ descr') end + | _ => (i, Ts @ [T], descrs)) + end + else (i, Ts @ [T], descrs); + + (* unfold a constructor *) + + fun unfold_constr ((i, constrs, descrs), (cname, cargs)) = + let val (i', cargs', descrs') = Library.foldl unfold_arg ((i, [], descrs), cargs) + in (i', constrs @ [(cname, cargs')], descrs') end; + + (* unfold a single datatype *) + + fun unfold_datatype ((i, dtypes, descrs), (j, (tname, tvars, constrs))) = + let val (i', constrs', descrs') = + Library.foldl unfold_constr ((i, [], descrs), constrs) + in (i', dtypes @ [(j, (tname, tvars, constrs'))], descrs') + end; + + val (i', descr', descrs) = Library.foldl unfold_datatype ((i, [],[]), descr); + + in (descr' :: descrs, i') end; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_case.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_case.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,469 @@ +(* Title: HOL/Tools/datatype_case.ML + Author: Konrad Slind, Cambridge University Computer Laboratory + Author: Stefan Berghofer, TU Muenchen + +Nested case expressions on datatypes. +*) + +signature DATATYPE_CASE = +sig + val make_case: (string -> DatatypeAux.info option) -> + Proof.context -> bool -> string list -> term -> (term * term) list -> + term * (term * (int * bool)) list + val dest_case: (string -> DatatypeAux.info option) -> bool -> + string list -> term -> (term * (term * term) list) option + val strip_case: (string -> DatatypeAux.info option) -> bool -> + term -> (term * (term * term) list) option + val case_tr: bool -> (theory -> string -> DatatypeAux.info option) + -> Proof.context -> term list -> term + val case_tr': (theory -> string -> DatatypeAux.info option) -> + string -> Proof.context -> term list -> term +end; + +structure DatatypeCase : DATATYPE_CASE = +struct + +exception CASE_ERROR of string * int; + +fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty; + +(*--------------------------------------------------------------------------- + * Get information about datatypes + *---------------------------------------------------------------------------*) + +fun ty_info (tab : string -> DatatypeAux.info option) s = + case tab s of + SOME {descr, case_name, index, sorts, ...} => + let + val (_, (tname, dts, constrs)) = nth descr index; + val mk_ty = DatatypeAux.typ_of_dtyp descr sorts; + val T = Type (tname, map mk_ty dts) + in + SOME {case_name = case_name, + constructors = map (fn (cname, dts') => + Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs} + end + | NONE => NONE; + + +(*--------------------------------------------------------------------------- + * Each pattern carries with it a tag (i,b) where + * i is the clause it came from and + * b=true indicates that clause was given by the user + * (or is an instantiation of a user supplied pattern) + * b=false --> i = ~1 + *---------------------------------------------------------------------------*) + +fun pattern_subst theta (tm, x) = (subst_free theta tm, x); + +fun row_of_pat x = fst (snd x); + +fun add_row_used ((prfx, pats), (tm, tag)) = + fold Term.add_free_names (tm :: pats @ prfx); + +(* try to preserve names given by user *) +fun default_names names ts = + map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts); + +fun strip_constraints (Const ("_constrain", _) $ t $ tT) = + strip_constraints t ||> cons tT + | strip_constraints t = (t, []); + +fun mk_fun_constrain tT t = Syntax.const "_constrain" $ t $ + (Syntax.free "fun" $ tT $ Syntax.free "dummy"); + + +(*--------------------------------------------------------------------------- + * Produce an instance of a constructor, plus genvars for its arguments. + *---------------------------------------------------------------------------*) +fun fresh_constr ty_match ty_inst colty used c = + let + val (_, Ty) = dest_Const c + val Ts = binder_types Ty; + val names = Name.variant_list used + (DatatypeProp.make_tnames (map Logic.unvarifyT Ts)); + val ty = body_type Ty; + val ty_theta = ty_match ty colty handle Type.TYPE_MATCH => + raise CASE_ERROR ("type mismatch", ~1) + val c' = ty_inst ty_theta c + val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts) + in (c', gvars) + end; + + +(*--------------------------------------------------------------------------- + * Goes through a list of rows and picks out the ones beginning with a + * pattern with constructor = name. + *---------------------------------------------------------------------------*) +fun mk_group (name, T) rows = + let val k = length (binder_types T) + in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) => + fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of + (Const (name', _), args) => + if name = name' then + if length args = k then + let val (args', cnstrts') = split_list (map strip_constraints args) + in + ((((prfx, args' @ rst), rhs) :: in_group, not_in_group), + (default_names names args', map2 append cnstrts cnstrts')) + end + else raise CASE_ERROR + ("Wrong number of arguments for constructor " ^ name, i) + else ((in_group, row :: not_in_group), (names, cnstrts)) + | _ => raise CASE_ERROR ("Not a constructor pattern", i))) + rows (([], []), (replicate k "", replicate k [])) |>> pairself rev + end; + +(*--------------------------------------------------------------------------- + * Partition the rows. Not efficient: we should use hashing. + *---------------------------------------------------------------------------*) +fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) + | partition ty_match ty_inst type_of used constructors colty res_ty + (rows as (((prfx, _ :: rstp), _) :: _)) = + let + fun part {constrs = [], rows = [], A} = rev A + | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} = + raise CASE_ERROR ("Not a constructor pattern", i) + | part {constrs = c :: crst, rows, A} = + let + val ((in_group, not_in_group), (names, cnstrts)) = + mk_group (dest_Const c) rows; + val used' = fold add_row_used in_group used; + val (c', gvars) = fresh_constr ty_match ty_inst colty used' c; + val in_group' = + if null in_group (* Constructor not given *) + then + let + val Ts = map type_of rstp; + val xs = Name.variant_list + (fold Term.add_free_names gvars used') + (replicate (length rstp) "x") + in + [((prfx, gvars @ map Free (xs ~~ Ts)), + (Const ("HOL.undefined", res_ty), (~1, false)))] + end + else in_group + in + part{constrs = crst, + rows = not_in_group, + A = {constructor = c', + new_formals = gvars, + names = names, + constraints = cnstrts, + group = in_group'} :: A} + end + in part {constrs = constructors, rows = rows, A = []} + end; + +(*--------------------------------------------------------------------------- + * Misc. routines used in mk_case + *---------------------------------------------------------------------------*) + +fun mk_pat ((c, c'), l) = + let + val L = length (binder_types (fastype_of c)) + fun build (prfx, tag, plist) = + let val (args, plist') = chop L plist + in (prfx, tag, list_comb (c', args) :: plist') end + in map build l end; + +fun v_to_prfx (prfx, v::pats) = (v::prfx,pats) + | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); + +fun v_to_pats (v::prfx,tag, pats) = (prfx, tag, v::pats) + | v_to_pats _ = raise CASE_ERROR ("mk_case: v_to_pats", ~1); + + +(*---------------------------------------------------------------------------- + * Translation of pattern terms into nested case expressions. + * + * This performs the translation and also builds the full set of patterns. + * Thus it supports the construction of induction theorems even when an + * incomplete set of patterns is given. + *---------------------------------------------------------------------------*) + +fun mk_case tab ctxt ty_match ty_inst type_of used range_ty = + let + val name = Name.variant used "a"; + fun expand constructors used ty ((_, []), _) = + raise CASE_ERROR ("mk_case: expand_var_row", ~1) + | expand constructors used ty (row as ((prfx, p :: rst), rhs)) = + if is_Free p then + let + val used' = add_row_used row used; + fun expnd c = + let val capp = + list_comb (fresh_constr ty_match ty_inst ty used' c) + in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs) + end + in map expnd constructors end + else [row] + fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1) + | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} = (* Done *) + ([(prfx, tag, [])], tm) + | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} = + mk {path = path, rows = [row]} + | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} = + let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows + in case Option.map (apfst head_of) + (find_first (not o is_Free o fst) col0) of + NONE => + let + val rows' = map (fn ((v, _), row) => row ||> + pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows); + val (pref_patl, tm) = mk {path = rstp, rows = rows'} + in (map v_to_pats pref_patl, tm) end + | SOME (Const (cname, cT), i) => (case ty_info tab cname of + NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) + | SOME {case_name, constructors} => + let + val pty = body_type cT; + val used' = fold Term.add_free_names rstp used; + val nrows = maps (expand constructors used' pty) rows; + val subproblems = partition ty_match ty_inst type_of used' + constructors pty range_ty nrows; + val new_formals = map #new_formals subproblems + val constructors' = map #constructor subproblems + val news = map (fn {new_formals, group, ...} => + {path = new_formals @ rstp, rows = group}) subproblems; + val (pat_rect, dtrees) = split_list (map mk news); + val case_functions = map2 + (fn {new_formals, names, constraints, ...} => + fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t => + Abs (if s = "" then name else s, T, + abstract_over (x, t)) |> + fold mk_fun_constrain cnstrts) + (new_formals ~~ names ~~ constraints)) + subproblems dtrees; + val types = map type_of (case_functions @ [u]); + val case_const = Const (case_name, types ---> range_ty) + val tree = list_comb (case_const, case_functions @ [u]) + val pat_rect1 = flat (map mk_pat + (constructors ~~ constructors' ~~ pat_rect)) + in (pat_rect1, tree) + end) + | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^ + Syntax.string_of_term ctxt t, i) + end + | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1) + in mk + end; + +fun case_error s = error ("Error in case expression:\n" ^ s); + +(* Repeated variable occurrences in a pattern are not allowed. *) +fun no_repeat_vars ctxt pat = fold_aterms + (fn x as Free (s, _) => (fn xs => + if member op aconv xs x then + case_error (quote s ^ " occurs repeatedly in the pattern " ^ + quote (Syntax.string_of_term ctxt pat)) + else x :: xs) + | _ => I) pat []; + +fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses = + let + fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt + (Syntax.const "_case1" $ pat $ rhs); + val _ = map (no_repeat_vars ctxt o fst) clauses; + val rows = map_index (fn (i, (pat, rhs)) => + (([], [pat]), (rhs, (i, true)))) clauses; + val rangeT = (case distinct op = (map (type_of o snd) clauses) of + [] => case_error "no clauses given" + | [T] => T + | _ => case_error "all cases must have the same result type"); + val used' = fold add_row_used rows used; + val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of + used' rangeT {path = [x], rows = rows} + handle CASE_ERROR (msg, i) => case_error (msg ^ + (if i < 0 then "" + else "\nIn clause\n" ^ string_of_clause (nth clauses i))); + val patts1 = map + (fn (_, tag, [pat]) => (pat, tag) + | _ => case_error "error in pattern-match translation") patts; + val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1 + val finals = map row_of_pat patts2 + val originals = map (row_of_pat o #2) rows + val _ = case originals \\ finals of + [] => () + | is => (if err then case_error else warning) + ("The following clauses are redundant (covered by preceding clauses):\n" ^ + cat_lines (map (string_of_clause o nth clauses) is)); + in + (case_tm, patts2) + end; + +fun make_case tab ctxt = gen_make_case + (match_type (ProofContext.theory_of ctxt)) Envir.subst_TVars fastype_of tab ctxt; +val make_case_untyped = gen_make_case (K (K Vartab.empty)) + (K (Term.map_types (K dummyT))) (K dummyT); + + +(* parse translation *) + +fun case_tr err tab_of ctxt [t, u] = + let + val thy = ProofContext.theory_of ctxt; + (* replace occurrences of dummy_pattern by distinct variables *) + (* internalize constant names *) + fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used = + let val (t', used') = prep_pat t used + in (c $ t' $ tT, used') end + | prep_pat (Const ("dummy_pattern", T)) used = + let val x = Name.variant used "x" + in (Free (x, T), x :: used) end + | prep_pat (Const (s, T)) used = + (case try (unprefix Syntax.constN) s of + SOME c => (Const (c, T), used) + | NONE => (Const (Sign.intern_const thy s, T), used)) + | prep_pat (v as Free (s, T)) used = + let val s' = Sign.intern_const thy s + in + if Sign.declared_const thy s' then + (Const (s', T), used) + else (v, used) + end + | prep_pat (t $ u) used = + let + val (t', used') = prep_pat t used; + val (u', used'') = prep_pat u used' + in + (t' $ u', used'') + end + | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); + fun dest_case1 (t as Const ("_case1", _) $ l $ r) = + let val (l', cnstrts) = strip_constraints l + in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) + end + | dest_case1 t = case_error "dest_case1"; + fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u + | dest_case2 t = [t]; + val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); + val (case_tm, _) = make_case_untyped (tab_of thy) ctxt err [] + (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT) + (flat cnstrts) t) cases; + in case_tm end + | case_tr _ _ _ ts = case_error "case_tr"; + + +(*--------------------------------------------------------------------------- + * Pretty printing of nested case expressions + *---------------------------------------------------------------------------*) + +(* destruct one level of pattern matching *) + +fun gen_dest_case name_of type_of tab d used t = + case apfst name_of (strip_comb t) of + (SOME cname, ts as _ :: _) => + let + val (fs, x) = split_last ts; + fun strip_abs i t = + let + val zs = strip_abs_vars t; + val _ = if length zs < i then raise CASE_ERROR ("", 0) else (); + val (xs, ys) = chop i zs; + val u = list_abs (ys, strip_abs_body t); + val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used)) + (map fst xs) ~~ map snd xs) + in (xs', subst_bounds (rev xs', u)) end; + fun is_dependent i t = + let val k = length (strip_abs_vars t) - i + in k < 0 orelse exists (fn j => j >= k) + (loose_bnos (strip_abs_body t)) + end; + fun count_cases (_, _, true) = I + | count_cases (c, (_, body), false) = + AList.map_default op aconv (body, []) (cons c); + val is_undefined = name_of #> equal (SOME "HOL.undefined"); + fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body) + in case ty_info tab cname of + SOME {constructors, case_name} => + if length fs = length constructors then + let + val cases = map (fn (Const (s, U), t) => + let + val k = length (binder_types U); + val p as (xs, _) = strip_abs k t + in + (Const (s, map type_of xs ---> type_of x), + p, is_dependent k t) + end) (constructors ~~ fs); + val cases' = sort (int_ord o swap o pairself (length o snd)) + (fold_rev count_cases cases []); + val R = type_of t; + val dummy = if d then Const ("dummy_pattern", R) + else Free (Name.variant used "x", R) + in + SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of + SOME (_, cs) => + if length cs = length constructors then [hd cases] + else filter_out (fn (_, (_, body), _) => is_undefined body) cases + | NONE => case cases' of + [] => cases + | (default, cs) :: _ => + if length cs = 1 then cases + else if length cs = length constructors then + [hd cases, (dummy, ([], default), false)] + else + filter_out (fn (c, _, _) => member op aconv cs c) cases @ + [(dummy, ([], default), false)])) + end handle CASE_ERROR _ => NONE + else NONE + | _ => NONE + end + | _ => NONE; + +val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of; +val dest_case' = gen_dest_case + (try (dest_Const #> fst #> unprefix Syntax.constN)) (K dummyT); + + +(* destruct nested patterns *) + +fun strip_case'' dest (pat, rhs) = + case dest (Term.add_free_names pat []) rhs of + SOME (exp as Free _, clauses) => + if member op aconv (OldTerm.term_frees pat) exp andalso + not (exists (fn (_, rhs') => + member op aconv (OldTerm.term_frees rhs') exp) clauses) + then + maps (strip_case'' dest) (map (fn (pat', rhs') => + (subst_free [(exp, pat')] pat, rhs')) clauses) + else [(pat, rhs)] + | _ => [(pat, rhs)]; + +fun gen_strip_case dest t = case dest [] t of + SOME (x, clauses) => + SOME (x, maps (strip_case'' dest) clauses) + | NONE => NONE; + +val strip_case = gen_strip_case oo dest_case; +val strip_case' = gen_strip_case oo dest_case'; + + +(* print translation *) + +fun case_tr' tab_of cname ctxt ts = + let + val thy = ProofContext.theory_of ctxt; + val consts = ProofContext.consts_of ctxt; + fun mk_clause (pat, rhs) = + let val xs = Term.add_frees pat [] + in + Syntax.const "_case1" $ + map_aterms + (fn Free p => Syntax.mark_boundT p + | Const (s, _) => Const (Consts.extern_early consts s, dummyT) + | t => t) pat $ + map_aterms + (fn x as Free (s, T) => + if member (op =) xs (s, T) then Syntax.mark_bound s else x + | t => t) rhs + end + in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of + SOME (x, clauses) => Syntax.const "_case_syntax" $ x $ + foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) + (map mk_clause clauses) + | NONE => raise Match + end; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_codegen.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_codegen.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,455 @@ +(* Title: HOL/Tools/datatype_codegen.ML + Author: Stefan Berghofer and Florian Haftmann, TU Muenchen + +Code generator facilities for inductive datatypes. +*) + +signature DATATYPE_CODEGEN = +sig + val find_shortest_path: Datatype.descr -> int -> (string * int) option + val mk_eq_eqns: theory -> string -> (thm * bool) list + val mk_case_cert: theory -> string -> thm + val setup: theory -> theory +end; + +structure DatatypeCodegen : DATATYPE_CODEGEN = +struct + +(** find shortest path to constructor with no recursive arguments **) + +fun find_nonempty (descr: Datatype.descr) is i = + let + val (_, _, constrs) = the (AList.lookup (op =) descr i); + fun arg_nonempty (_, DatatypeAux.DtRec i) = if member (op =) is i + then NONE + else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i) + | arg_nonempty _ = SOME 0; + fun max xs = Library.foldl + (fn (NONE, _) => NONE + | (SOME i, SOME j) => SOME (Int.max (i, j)) + | (_, NONE) => NONE) (SOME 0, xs); + val xs = sort (int_ord o pairself snd) + (map_filter (fn (s, dts) => Option.map (pair s) + (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs) + in case xs of [] => NONE | x :: _ => SOME x end; + +fun find_shortest_path descr i = find_nonempty descr [i] i; + + +(** SML code generator **) + +open Codegen; + +(* datatype definition *) + +fun add_dt_defs thy defs dep module (descr: Datatype.descr) sorts gr = + let + val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr; + val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) => + exists (exists DatatypeAux.is_rec_type o snd) cs) descr'); + + val (_, (tname, _, _)) :: _ = descr'; + val node_id = tname ^ " (type)"; + val module' = if_library (thyname_of_type thy tname) module; + + fun mk_dtdef prfx [] gr = ([], gr) + | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr = + let + val tvs = map DatatypeAux.dest_DtTFree dts; + val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; + val ((_, type_id), gr') = mk_type_id module' tname gr; + val (ps, gr'') = gr' |> + fold_map (fn (cname, cargs) => + fold_map (invoke_tycodegen thy defs node_id module' false) + cargs ##>> + mk_const_id module' cname) cs'; + val (rest, gr''') = mk_dtdef "and " xs gr'' + in + (Pretty.block (str prfx :: + (if null tvs then [] else + [mk_tuple (map str tvs), str " "]) @ + [str (type_id ^ " ="), Pretty.brk 1] @ + List.concat (separate [Pretty.brk 1, str "| "] + (map (fn (ps', (_, cname)) => [Pretty.block + (str cname :: + (if null ps' then [] else + List.concat ([str " of", Pretty.brk 1] :: + separate [str " *", Pretty.brk 1] + (map single ps'))))]) ps))) :: rest, gr''') + end; + + fun mk_constr_term cname Ts T ps = + List.concat (separate [str " $", Pretty.brk 1] + ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1, + mk_type false (Ts ---> T), str ")"] :: ps)); + + fun mk_term_of_def gr prfx [] = [] + | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) = + let + val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; + val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts; + val T = Type (tname, dts'); + val rest = mk_term_of_def gr "and " xs; + val (eqs, _) = fold_map (fn (cname, Ts) => fn prfx => + let val args = map (fn i => + str ("x" ^ string_of_int i)) (1 upto length Ts) + in (Pretty.blk (4, + [str prfx, mk_term_of gr module' false T, Pretty.brk 1, + if null Ts then str (snd (get_const_id gr cname)) + else parens (Pretty.block + [str (snd (get_const_id gr cname)), + Pretty.brk 1, mk_tuple args]), + str " =", Pretty.brk 1] @ + mk_constr_term cname Ts T + (map2 (fn x => fn U => [Pretty.block [mk_term_of gr module' false U, + Pretty.brk 1, x]]) args Ts)), " | ") + end) cs' prfx + in eqs @ rest end; + + fun mk_gen_of_def gr prfx [] = [] + | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) = + let + val tvs = map DatatypeAux.dest_DtTFree dts; + val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts; + val T = Type (tname, Us); + val (cs1, cs2) = + List.partition (exists DatatypeAux.is_rec_type o snd) cs; + val SOME (cname, _) = find_shortest_path descr i; + + fun mk_delay p = Pretty.block + [str "fn () =>", Pretty.brk 1, p]; + + fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"]; + + fun mk_constr s b (cname, dts) = + let + val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s + (DatatypeAux.typ_of_dtyp descr sorts dt)) + [str (if b andalso DatatypeAux.is_rec_type dt then "0" + else "j")]) dts; + val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts; + val xs = map str + (DatatypeProp.indexify_names (replicate (length dts) "x")); + val ts = map str + (DatatypeProp.indexify_names (replicate (length dts) "t")); + val (_, id) = get_const_id gr cname + in + mk_let + (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs) + (mk_tuple + [case xs of + _ :: _ :: _ => Pretty.block + [str id, Pretty.brk 1, mk_tuple xs] + | _ => mk_app false (str id) xs, + mk_delay (Pretty.block (mk_constr_term cname Ts T + (map (single o mk_force) ts)))]) + end; + + fun mk_choice [c] = mk_constr "(i-1)" false c + | mk_choice cs = Pretty.block [str "one_of", + Pretty.brk 1, Pretty.blk (1, str "[" :: + List.concat (separate [str ",", Pretty.fbrk] + (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @ + [str "]"]), Pretty.brk 1, str "()"]; + + val gs = maps (fn s => + let val s' = strip_tname s + in [str (s' ^ "G"), str (s' ^ "T")] end) tvs; + val gen_name = "gen_" ^ snd (get_type_id gr tname) + + in + Pretty.blk (4, separate (Pretty.brk 1) + (str (prfx ^ gen_name ^ + (if null cs1 then "" else "'")) :: gs @ + (if null cs1 then [] else [str "i"]) @ + [str "j"]) @ + [str " =", Pretty.brk 1] @ + (if not (null cs1) andalso not (null cs2) + then [str "frequency", Pretty.brk 1, + Pretty.blk (1, [str "[", + mk_tuple [str "i", mk_delay (mk_choice cs1)], + str ",", Pretty.fbrk, + mk_tuple [str "1", mk_delay (mk_choice cs2)], + str "]"]), Pretty.brk 1, str "()"] + else if null cs2 then + [Pretty.block [str "(case", Pretty.brk 1, + str "i", Pretty.brk 1, str "of", + Pretty.brk 1, str "0 =>", Pretty.brk 1, + mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)), + Pretty.brk 1, str "| _ =>", Pretty.brk 1, + mk_choice cs1, str ")"]] + else [mk_choice cs2])) :: + (if null cs1 then [] + else [Pretty.blk (4, separate (Pretty.brk 1) + (str ("and " ^ gen_name) :: gs @ [str "i"]) @ + [str " =", Pretty.brk 1] @ + separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @ + [str "i", str "i"]))]) @ + mk_gen_of_def gr "and " xs + end + + in + (module', (add_edge_acyclic (node_id, dep) gr + handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ => + let + val gr1 = add_edge (node_id, dep) + (new_node (node_id, (NONE, "", "")) gr); + val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ; + in + map_node node_id (K (NONE, module', + string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @ + [str ";"])) ^ "\n\n" ^ + (if "term_of" mem !mode then + string_of (Pretty.blk (0, separate Pretty.fbrk + (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n" + else "") ^ + (if "test" mem !mode then + string_of (Pretty.blk (0, separate Pretty.fbrk + (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n" + else ""))) gr2 + end) + end; + + +(* case expressions *) + +fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr = + let val i = length constrs + in if length ts <= i then + invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr + else + let + val ts1 = Library.take (i, ts); + val t :: ts2 = Library.drop (i, ts); + val names = List.foldr OldTerm.add_term_names + (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1; + val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T))); + + fun pcase [] [] [] gr = ([], gr) + | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr = + let + val j = length cargs; + val xs = Name.variant_list names (replicate j "x"); + val Us' = Library.take (j, fst (strip_type U)); + val frees = map Free (xs ~~ Us'); + val (cp, gr0) = invoke_codegen thy defs dep module false + (list_comb (Const (cname, Us' ---> dT), frees)) gr; + val t' = Envir.beta_norm (list_comb (t, frees)); + val (p, gr1) = invoke_codegen thy defs dep module false t' gr0; + val (ps, gr2) = pcase cs ts Us gr1; + in + ([Pretty.block [cp, str " =>", Pretty.brk 1, p]] :: ps, gr2) + end; + + val (ps1, gr1) = pcase constrs ts1 Ts gr ; + val ps = List.concat (separate [Pretty.brk 1, str "| "] ps1); + val (p, gr2) = invoke_codegen thy defs dep module false t gr1; + val (ps2, gr3) = fold_map (invoke_codegen thy defs dep module true) ts2 gr2; + in ((if not (null ts2) andalso brack then parens else I) + (Pretty.block (separate (Pretty.brk 1) + (Pretty.block ([str "(case ", p, str " of", + Pretty.brk 1] @ ps @ [str ")"]) :: ps2))), gr3) + end + end; + + +(* constructors *) + +fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr = + let val i = length args + in if i > 1 andalso length ts < i then + invoke_codegen thy defs dep module brack (eta_expand c ts i) gr + else + let + val id = mk_qual_id module (get_const_id gr s); + val (ps, gr') = fold_map + (invoke_codegen thy defs dep module (i = 1)) ts gr; + in (case args of + _ :: _ :: _ => (if brack then parens else I) + (Pretty.block [str id, Pretty.brk 1, mk_tuple ps]) + | _ => (mk_app brack (str id) ps), gr') + end + end; + + +(* code generators for terms and types *) + +fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of + (c as Const (s, T), ts) => + (case Datatype.datatype_of_case thy s of + SOME {index, descr, ...} => + if is_some (get_assoc_code thy (s, T)) then NONE else + SOME (pretty_case thy defs dep module brack + (#3 (the (AList.lookup op = descr index))) c ts gr ) + | NONE => case (Datatype.datatype_of_constr thy s, strip_type T) of + (SOME {index, descr, ...}, (_, U as Type (tyname, _))) => + if is_some (get_assoc_code thy (s, T)) then NONE else + let + val SOME (tyname', _, constrs) = AList.lookup op = descr index; + val SOME args = AList.lookup op = constrs s + in + if tyname <> tyname' then NONE + else SOME (pretty_constr thy defs + dep module brack args c ts (snd (invoke_tycodegen thy defs dep module false U gr))) + end + | _ => NONE) + | _ => NONE); + +fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr = + (case Datatype.get_datatype thy s of + NONE => NONE + | SOME {descr, sorts, ...} => + if is_some (get_assoc_type thy s) then NONE else + let + val (ps, gr') = fold_map + (invoke_tycodegen thy defs dep module false) Ts gr; + val (module', gr'') = add_dt_defs thy defs dep module descr sorts gr' ; + val (tyid, gr''') = mk_type_id module' s gr'' + in SOME (Pretty.block ((if null Ts then [] else + [mk_tuple ps, str " "]) @ + [str (mk_qual_id module tyid)]), gr''') + end) + | datatype_tycodegen _ _ _ _ _ _ _ = NONE; + + +(** generic code generator **) + +(* liberal addition of code data for datatypes *) + +fun mk_constr_consts thy vs dtco cos = + let + val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos; + val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs; + in if is_some (try (Code.constrset_of_consts thy) cs') + then SOME cs + else NONE + end; + + +(* case certificates *) + +fun mk_case_cert thy tyco = + let + val raw_thms = + (#case_rewrites o Datatype.the_datatype thy) tyco; + val thms as hd_thm :: _ = raw_thms + |> Conjunction.intr_balanced + |> Thm.unvarify + |> Conjunction.elim_balanced (length raw_thms) + |> map Simpdata.mk_meta_eq + |> map Drule.zero_var_indexes + val params = fold_aterms (fn (Free (v, _)) => insert (op =) v + | _ => I) (Thm.prop_of hd_thm) []; + val rhs = hd_thm + |> Thm.prop_of + |> Logic.dest_equals + |> fst + |> Term.strip_comb + |> apsnd (fst o split_last) + |> list_comb; + val lhs = Free (Name.variant params "case", Term.fastype_of rhs); + val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs); + in + thms + |> Conjunction.intr_balanced + |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm] + |> Thm.implies_intr asm + |> Thm.generalize ([], params) 0 + |> AxClass.unoverload thy + |> Thm.varifyT + end; + + +(* equality *) + +fun mk_eq_eqns thy dtco = + let + val (vs, cos) = Datatype.the_datatype_spec thy dtco; + val { descr, index, inject = inject_thms, ... } = Datatype.the_datatype thy dtco; + val ty = Type (dtco, map TFree vs); + fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT) + $ t1 $ t2; + fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const); + fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const); + val triv_injects = map_filter + (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty)))) + | _ => NONE) cos; + fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) = + trueprop $ (equiv $ mk_eq (t1, t2) $ rhs); + val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index); + fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) = + [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)]; + val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index)); + val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty))); + val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss + addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms)) + addsimprocs [Datatype.distinct_simproc]); + fun prove prop = SkipProof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset))) + |> Simpdata.mk_eq; + in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end; + +fun add_equality vs dtcos thy = + let + fun add_def dtco lthy = + let + val ty = Type (dtco, map TFree vs); + fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT) + $ Free ("x", ty) $ Free ("y", ty); + val def = HOLogic.mk_Trueprop (HOLogic.mk_eq + (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="})); + val def' = Syntax.check_term lthy def; + val ((_, (_, thm)), lthy') = Specification.definition + (NONE, (Attrib.empty_binding, def')) lthy; + val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy); + val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; + in (thm', lthy') end; + fun tac thms = Class.intro_classes_tac [] + THEN ALLGOALS (ProofContext.fact_tac thms); + fun add_eq_thms dtco thy = + let + val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco); + val thy_ref = Theory.check_thy thy; + fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco)); + in + Code.add_eqnl (const, Lazy.lazy mk_thms) thy + end; + in + thy + |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq]) + |> fold_map add_def dtcos + |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm) + (fn _ => fn def_thms => tac def_thms) def_thms) + |-> (fn def_thms => fold Code.del_eqn def_thms) + |> fold add_eq_thms dtcos + end; + + +(* register a datatype etc. *) + +fun add_all_code config dtcos thy = + let + val (vs :: _, coss) = (split_list o map (Datatype.the_datatype_spec thy)) dtcos; + val any_css = map2 (mk_constr_consts thy vs) dtcos coss; + val css = if exists is_none any_css then [] + else map_filter I any_css; + val case_rewrites = maps (#case_rewrites o Datatype.the_datatype thy) dtcos; + val certs = map (mk_case_cert thy) dtcos; + in + if null css then thy + else thy + |> tap (fn _ => DatatypeAux.message config "Registering datatype for code generator ...") + |> fold Code.add_datatype css + |> fold_rev Code.add_default_eqn case_rewrites + |> fold Code.add_case certs + |> add_equality vs dtcos + end; + + +(** theory setup **) + +val setup = + add_codegen "datatype" datatype_codegen + #> add_tycodegen "datatype" datatype_tycodegen + #> Datatype.interpretation add_all_code + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_prop.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_prop.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,435 @@ +(* Title: HOL/Tools/datatype_prop.ML + Author: Stefan Berghofer, TU Muenchen + +Characteristic properties of datatypes. +*) + +signature DATATYPE_PROP = +sig + val indexify_names: string list -> string list + val make_tnames: typ list -> string list + val make_injs : DatatypeAux.descr list -> (string * sort) list -> term list list + val make_distincts : DatatypeAux.descr list -> + (string * sort) list -> (int * term list) list (*no symmetric inequalities*) + val make_ind : DatatypeAux.descr list -> (string * sort) list -> term + val make_casedists : DatatypeAux.descr list -> (string * sort) list -> term list + val make_primrec_Ts : DatatypeAux.descr list -> (string * sort) list -> + string list -> typ list * typ list + val make_primrecs : string list -> DatatypeAux.descr list -> + (string * sort) list -> theory -> term list + val make_cases : string list -> DatatypeAux.descr list -> + (string * sort) list -> theory -> term list list + val make_splits : string list -> DatatypeAux.descr list -> + (string * sort) list -> theory -> (term * term) list + val make_weak_case_congs : string list -> DatatypeAux.descr list -> + (string * sort) list -> theory -> term list + val make_case_congs : string list -> DatatypeAux.descr list -> + (string * sort) list -> theory -> term list + val make_nchotomys : DatatypeAux.descr list -> + (string * sort) list -> term list +end; + +structure DatatypeProp : DATATYPE_PROP = +struct + +open DatatypeAux; + +fun indexify_names names = + let + fun index (x :: xs) tab = + (case AList.lookup (op =) tab x of + NONE => if member (op =) xs x then (x ^ "1") :: index xs ((x, 2) :: tab) else x :: index xs tab + | SOME i => (x ^ string_of_int i) :: index xs ((x, i + 1) :: tab)) + | index [] _ = []; + in index names [] end; + +fun make_tnames Ts = + let + fun type_name (TFree (name, _)) = implode (tl (explode name)) + | type_name (Type (name, _)) = + let val name' = Long_Name.base_name name + in if Syntax.is_identifier name' then name' else "x" end; + in indexify_names (map type_name Ts) end; + + +(************************* injectivity of constructors ************************) + +fun make_injs descr sorts = + let + val descr' = flat descr; + fun make_inj T (cname, cargs) = + if null cargs then I else + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val constr_t = Const (cname, Ts ---> T); + val tnames = make_tnames Ts; + val frees = map Free (tnames ~~ Ts); + val frees' = map Free ((map ((op ^) o (rpair "'")) tnames) ~~ Ts); + in cons (HOLogic.mk_Trueprop (HOLogic.mk_eq + (HOLogic.mk_eq (list_comb (constr_t, frees), list_comb (constr_t, frees')), + foldr1 (HOLogic.mk_binop "op &") + (map HOLogic.mk_eq (frees ~~ frees'))))) + end; + in + map2 (fn d => fn T => fold_rev (make_inj T) (#3 (snd d)) []) + (hd descr) (Library.take (length (hd descr), get_rec_types descr' sorts)) + end; + + +(************************* distinctness of constructors ***********************) + +fun make_distincts descr sorts = + let + val descr' = flat descr; + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + + fun prep_constr (cname, cargs) = (cname, map (typ_of_dtyp descr' sorts) cargs); + + fun make_distincts' _ [] = [] + | make_distincts' T ((cname, cargs)::constrs) = + let + val frees = map Free ((make_tnames cargs) ~~ cargs); + val t = list_comb (Const (cname, cargs ---> T), frees); + + fun make_distincts'' (cname', cargs') = + let + val frees' = map Free ((map ((op ^) o (rpair "'")) + (make_tnames cargs')) ~~ cargs'); + val t' = list_comb (Const (cname', cargs' ---> T), frees') + in + HOLogic.mk_Trueprop (HOLogic.Not $ HOLogic.mk_eq (t, t')) + end + + in map make_distincts'' constrs @ make_distincts' T constrs end; + + in + map2 (fn ((_, (_, _, constrs))) => fn T => + (length constrs, make_distincts' T (map prep_constr constrs))) (hd descr) newTs + end; + + +(********************************* induction **********************************) + +fun make_ind descr sorts = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val pnames = if length descr' = 1 then ["P"] + else map (fn i => "P" ^ string_of_int i) (1 upto length descr'); + + fun make_pred i T = + let val T' = T --> HOLogic.boolT + in Free (List.nth (pnames, i), T') end; + + fun make_ind_prem k T (cname, cargs) = + let + fun mk_prem ((dt, s), T) = + let val (Us, U) = strip_type T + in list_all (map (pair "x") Us, HOLogic.mk_Trueprop + (make_pred (body_index dt) U $ app_bnds (Free (s, T)) (length Us))) + end; + + val recs = List.filter is_rec_type cargs; + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val recTs' = map (typ_of_dtyp descr' sorts) recs; + val tnames = Name.variant_list pnames (make_tnames Ts); + val rec_tnames = map fst (List.filter (is_rec_type o snd) (tnames ~~ cargs)); + val frees = tnames ~~ Ts; + val prems = map mk_prem (recs ~~ rec_tnames ~~ recTs'); + + in list_all_free (frees, Logic.list_implies (prems, + HOLogic.mk_Trueprop (make_pred k T $ + list_comb (Const (cname, Ts ---> T), map Free frees)))) + end; + + val prems = List.concat (map (fn ((i, (_, _, constrs)), T) => + map (make_ind_prem i T) constrs) (descr' ~~ recTs)); + val tnames = make_tnames recTs; + val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &") + (map (fn (((i, _), T), tname) => make_pred i T $ Free (tname, T)) + (descr' ~~ recTs ~~ tnames))) + + in Logic.list_implies (prems, concl) end; + +(******************************* case distinction *****************************) + +fun make_casedists descr sorts = + let + val descr' = List.concat descr; + + fun make_casedist_prem T (cname, cargs) = + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val frees = Name.variant_list ["P", "y"] (make_tnames Ts) ~~ Ts; + val free_ts = map Free frees + in list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop + (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))), + HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)))) + end; + + fun make_casedist ((_, (_, _, constrs)), T) = + let val prems = map (make_casedist_prem T) constrs + in Logic.list_implies (prems, HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))) + end + + in map make_casedist + ((hd descr) ~~ Library.take (length (hd descr), get_rec_types descr' sorts)) + end; + +(*************** characteristic equations for primrec combinator **************) + +fun make_primrec_Ts descr sorts used = + let + val descr' = List.concat descr; + + val rec_result_Ts = map TFree (Name.variant_list used (replicate (length descr') "'t") ~~ + replicate (length descr') HOLogic.typeS); + + val reccomb_fn_Ts = List.concat (map (fn (i, (_, _, constrs)) => + map (fn (_, cargs) => + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val recs = List.filter (is_rec_type o fst) (cargs ~~ Ts); + + fun mk_argT (dt, T) = + binder_types T ---> List.nth (rec_result_Ts, body_index dt); + + val argTs = Ts @ map mk_argT recs + in argTs ---> List.nth (rec_result_Ts, i) + end) constrs) descr'); + + in (rec_result_Ts, reccomb_fn_Ts) end; + +fun make_primrecs new_type_names descr sorts thy = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; + + val (rec_result_Ts, reccomb_fn_Ts) = make_primrec_Ts descr sorts used; + + val rec_fns = map (uncurry (mk_Free "f")) + (reccomb_fn_Ts ~~ (1 upto (length reccomb_fn_Ts))); + + val big_reccomb_name = (space_implode "_" new_type_names) ^ "_rec"; + val reccomb_names = map (Sign.intern_const thy) + (if length descr' = 1 then [big_reccomb_name] else + (map ((curry (op ^) (big_reccomb_name ^ "_")) o string_of_int) + (1 upto (length descr')))); + val reccombs = map (fn ((name, T), T') => list_comb + (Const (name, reccomb_fn_Ts @ [T] ---> T'), rec_fns)) + (reccomb_names ~~ recTs ~~ rec_result_Ts); + + fun make_primrec T comb_t ((ts, f::fs), (cname, cargs)) = + let + val recs = List.filter is_rec_type cargs; + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val recTs' = map (typ_of_dtyp descr' sorts) recs; + val tnames = make_tnames Ts; + val rec_tnames = map fst (List.filter (is_rec_type o snd) (tnames ~~ cargs)); + val frees = map Free (tnames ~~ Ts); + val frees' = map Free (rec_tnames ~~ recTs'); + + fun mk_reccomb ((dt, T), t) = + let val (Us, U) = strip_type T + in list_abs (map (pair "x") Us, + List.nth (reccombs, body_index dt) $ app_bnds t (length Us)) + end; + + val reccombs' = map mk_reccomb (recs ~~ recTs' ~~ frees') + + in (ts @ [HOLogic.mk_Trueprop (HOLogic.mk_eq + (comb_t $ list_comb (Const (cname, Ts ---> T), frees), + list_comb (f, frees @ reccombs')))], fs) + end + + in fst (Library.foldl (fn (x, ((dt, T), comb_t)) => + Library.foldl (make_primrec T comb_t) (x, #3 (snd dt))) + (([], rec_fns), descr' ~~ recTs ~~ reccombs)) + end; + +(****************** make terms of form t_case f1 ... fn *********************) + +fun make_case_combs new_type_names descr sorts thy fname = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; + val newTs = Library.take (length (hd descr), recTs); + val T' = TFree (Name.variant used "'t", HOLogic.typeS); + + val case_fn_Ts = map (fn (i, (_, _, constrs)) => + map (fn (_, cargs) => + let val Ts = map (typ_of_dtyp descr' sorts) cargs + in Ts ---> T' end) constrs) (hd descr); + + val case_names = map (fn s => + Sign.intern_const thy (s ^ "_case")) new_type_names + in + map (fn ((name, Ts), T) => list_comb + (Const (name, Ts @ [T] ---> T'), + map (uncurry (mk_Free fname)) (Ts ~~ (1 upto length Ts)))) + (case_names ~~ case_fn_Ts ~~ newTs) + end; + +(**************** characteristic equations for case combinator ****************) + +fun make_cases new_type_names descr sorts thy = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + + fun make_case T comb_t ((cname, cargs), f) = + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val frees = map Free ((make_tnames Ts) ~~ Ts) + in HOLogic.mk_Trueprop (HOLogic.mk_eq + (comb_t $ list_comb (Const (cname, Ts ---> T), frees), + list_comb (f, frees))) + end + + in map (fn (((_, (_, _, constrs)), T), comb_t) => + map (make_case T comb_t) (constrs ~~ (snd (strip_comb comb_t)))) + ((hd descr) ~~ newTs ~~ (make_case_combs new_type_names descr sorts thy "f")) + end; + + +(*************************** the "split" - equations **************************) + +fun make_splits new_type_names descr sorts thy = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val used' = List.foldr OldTerm.add_typ_tfree_names [] recTs; + val newTs = Library.take (length (hd descr), recTs); + val T' = TFree (Name.variant used' "'t", HOLogic.typeS); + val P = Free ("P", T' --> HOLogic.boolT); + + fun make_split (((_, (_, _, constrs)), T), comb_t) = + let + val (_, fs) = strip_comb comb_t; + val used = ["P", "x"] @ (map (fst o dest_Free) fs); + + fun process_constr (((cname, cargs), f), (t1s, t2s)) = + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val frees = map Free (Name.variant_list used (make_tnames Ts) ~~ Ts); + val eqn = HOLogic.mk_eq (Free ("x", T), + list_comb (Const (cname, Ts ---> T), frees)); + val P' = P $ list_comb (f, frees) + in ((List.foldr (fn (Free (s, T), t) => HOLogic.mk_all (s, T, t)) + (HOLogic.imp $ eqn $ P') frees)::t1s, + (List.foldr (fn (Free (s, T), t) => HOLogic.mk_exists (s, T, t)) + (HOLogic.conj $ eqn $ (HOLogic.Not $ P')) frees)::t2s) + end; + + val (t1s, t2s) = List.foldr process_constr ([], []) (constrs ~~ fs); + val lhs = P $ (comb_t $ Free ("x", T)) + in + (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, mk_conj t1s)), + HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, HOLogic.Not $ mk_disj t2s))) + end + + in map make_split ((hd descr) ~~ newTs ~~ + (make_case_combs new_type_names descr sorts thy "f")) + end; + +(************************* additional rules for TFL ***************************) + +fun make_weak_case_congs new_type_names descr sorts thy = + let + val case_combs = make_case_combs new_type_names descr sorts thy "f"; + + fun mk_case_cong comb = + let + val Type ("fun", [T, _]) = fastype_of comb; + val M = Free ("M", T); + val M' = Free ("M'", T); + in + Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (M, M')), + HOLogic.mk_Trueprop (HOLogic.mk_eq (comb $ M, comb $ M'))) + end + in + map mk_case_cong case_combs + end; + + +(*--------------------------------------------------------------------------- + * Structure of case congruence theorem looks like this: + * + * (M = M') + * ==> (!!x1,...,xk. (M' = C1 x1..xk) ==> (f1 x1..xk = g1 x1..xk)) + * ==> ... + * ==> (!!x1,...,xj. (M' = Cn x1..xj) ==> (fn x1..xj = gn x1..xj)) + * ==> + * (ty_case f1..fn M = ty_case g1..gn M') + *---------------------------------------------------------------------------*) + +fun make_case_congs new_type_names descr sorts thy = + let + val case_combs = make_case_combs new_type_names descr sorts thy "f"; + val case_combs' = make_case_combs new_type_names descr sorts thy "g"; + + fun mk_case_cong ((comb, comb'), (_, (_, _, constrs))) = + let + val Type ("fun", [T, _]) = fastype_of comb; + val (_, fs) = strip_comb comb; + val (_, gs) = strip_comb comb'; + val used = ["M", "M'"] @ map (fst o dest_Free) (fs @ gs); + val M = Free ("M", T); + val M' = Free ("M'", T); + + fun mk_clause ((f, g), (cname, _)) = + let + val (Ts, _) = strip_type (fastype_of f); + val tnames = Name.variant_list used (make_tnames Ts); + val frees = map Free (tnames ~~ Ts) + in + list_all_free (tnames ~~ Ts, Logic.mk_implies + (HOLogic.mk_Trueprop + (HOLogic.mk_eq (M', list_comb (Const (cname, Ts ---> T), frees))), + HOLogic.mk_Trueprop + (HOLogic.mk_eq (list_comb (f, frees), list_comb (g, frees))))) + end + + in + Logic.list_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (M, M')) :: + map mk_clause (fs ~~ gs ~~ constrs), + HOLogic.mk_Trueprop (HOLogic.mk_eq (comb $ M, comb' $ M'))) + end + + in + map mk_case_cong (case_combs ~~ case_combs' ~~ hd descr) + end; + +(*--------------------------------------------------------------------------- + * Structure of exhaustion theorem looks like this: + * + * !v. (? y1..yi. v = C1 y1..yi) | ... | (? y1..yj. v = Cn y1..yj) + *---------------------------------------------------------------------------*) + +fun make_nchotomys descr sorts = + let + val descr' = List.concat descr; + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + + fun mk_eqn T (cname, cargs) = + let + val Ts = map (typ_of_dtyp descr' sorts) cargs; + val tnames = Name.variant_list ["v"] (make_tnames Ts); + val frees = tnames ~~ Ts + in + List.foldr (fn ((s, T'), t) => HOLogic.mk_exists (s, T', t)) + (HOLogic.mk_eq (Free ("v", T), + list_comb (Const (cname, Ts ---> T), map Free frees))) frees + end + + in map (fn ((_, (_, _, constrs)), T) => + HOLogic.mk_Trueprop (HOLogic.mk_all ("v", T, mk_disj (map (mk_eqn T) constrs)))) + (hd descr ~~ newTs) + end; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_realizer.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_realizer.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,230 @@ +(* Title: HOL/Tools/datatype_realizer.ML + Author: Stefan Berghofer, TU Muenchen + +Porgram extraction from proofs involving datatypes: +Realizers for induction and case analysis +*) + +signature DATATYPE_REALIZER = +sig + val add_dt_realizers: Datatype.config -> string list -> theory -> theory + val setup: theory -> theory +end; + +structure DatatypeRealizer : DATATYPE_REALIZER = +struct + +open DatatypeAux; + +fun subsets i j = if i <= j then + let val is = subsets (i+1) j + in map (fn ks => i::ks) is @ is end + else [[]]; + +fun forall_intr_prf (t, prf) = + let val (a, T) = (case t of Var ((a, _), T) => (a, T) | Free p => p) + in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end; + +fun prf_of thm = + Reconstruct.reconstruct_proof (Thm.theory_of_thm thm) (Thm.prop_of thm) (Thm.proof_of thm); + +fun prf_subst_vars inst = + Proofterm.map_proof_terms (subst_vars ([], inst)) I; + +fun is_unit t = snd (strip_type (fastype_of t)) = HOLogic.unitT; + +fun tname_of (Type (s, _)) = s + | tname_of _ = ""; + +fun mk_realizes T = Const ("realizes", T --> HOLogic.boolT --> HOLogic.boolT); + +fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : info) is thy = + let + val recTs = get_rec_types descr sorts; + val pnames = if length descr = 1 then ["P"] + else map (fn i => "P" ^ string_of_int i) (1 upto length descr); + + val rec_result_Ts = map (fn ((i, _), P) => + if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT) + (descr ~~ pnames); + + fun make_pred i T U r x = + if i mem is then + Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x + else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x; + + fun mk_all i s T t = + if i mem is then list_all_free ([(s, T)], t) else t; + + val (prems, rec_fns) = split_list (flat (fst (fold_map + (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j => + let + val Ts = map (typ_of_dtyp descr sorts) cargs; + val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts); + val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts); + val frees = tnames ~~ Ts; + + fun mk_prems vs [] = + let + val rT = nth (rec_result_Ts) i; + val vs' = filter_out is_unit vs; + val f = mk_Free "f" (map fastype_of vs' ---> rT) j; + val f' = Envir.eta_contract (list_abs_free + (map dest_Free vs, if i mem is then list_comb (f, vs') + else HOLogic.unit)); + in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs')) + (list_comb (Const (cname, Ts ---> T), map Free frees))), f') + end + | mk_prems vs (((dt, s), T) :: ds) = + let + val k = body_index dt; + val (Us, U) = strip_type T; + val i = length Us; + val rT = nth (rec_result_Ts) k; + val r = Free ("r" ^ s, Us ---> rT); + val (p, f) = mk_prems (vs @ [r]) ds + in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies + (list_all (map (pair "x") Us, HOLogic.mk_Trueprop + (make_pred k rT U (app_bnds r i) + (app_bnds (Free (s, T)) i))), p)), f) + end + + in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end) + constrs) (descr ~~ recTs) 1))); + + fun mk_proj j [] t = t + | mk_proj j (i :: is) t = if null is then t else + if (j: int) = i then HOLogic.mk_fst t + else mk_proj j is (HOLogic.mk_snd t); + + val tnames = DatatypeProp.make_tnames recTs; + val fTs = map fastype_of rec_fns; + val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T + (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0))) + (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names); + val r = if null is then Extraction.nullt else + foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) => + if i mem is then SOME + (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T)) + else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames)); + val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &") + (map (fn ((((i, _), T), U), tname) => + make_pred i U T (mk_proj i is r) (Free (tname, T))) + (descr ~~ recTs ~~ rec_result_Ts ~~ tnames))); + val cert = cterm_of thy; + val inst = map (pairself cert) (map head_of (HOLogic.dest_conj + (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps); + + val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl))) + (fn prems => + [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]), + rtac (cterm_instantiate inst induction) 1, + ALLGOALS ObjectLogic.atomize_prems_tac, + rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites), + REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i => + REPEAT (etac allE i) THEN atac i)) 1)]); + + val ind_name = Thm.get_name induction; + val vs = map (fn i => List.nth (pnames, i)) is; + val (thm', thy') = thy + |> Sign.root_path + |> PureThy.store_thm + (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm) + ||> Sign.restore_naming thy; + + val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []); + val rvs = rev (Thm.fold_terms Term.add_vars thm' []); + val ivs1 = map Var (filter_out (fn (_, T) => + tname_of (body_type T) mem ["set", "bool"]) ivs); + val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs; + + val prf = List.foldr forall_intr_prf + (List.foldr (fn ((f, p), prf) => + (case head_of (strip_abs_body f) of + Free (s, T) => + let val T' = Logic.varifyT T + in Abst (s, SOME T', Proofterm.prf_abstract_over + (Var ((s, 0), T')) (AbsP ("H", SOME p, prf))) + end + | _ => AbsP ("H", SOME p, prf))) + (Proofterm.proof_combP + (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2; + + val r' = if null is then r else Logic.varify (List.foldr (uncurry lambda) + r (map Logic.unvarify ivs1 @ filter_out is_unit + (map (head_of o strip_abs_body) rec_fns))); + + in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end; + + +fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : info) thy = + let + val cert = cterm_of thy; + val rT = TFree ("'P", HOLogic.typeS); + val rT' = TVar (("'P", 0), HOLogic.typeS); + + fun make_casedist_prem T (cname, cargs) = + let + val Ts = map (typ_of_dtyp descr sorts) cargs; + val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts; + val free_ts = map Free frees; + val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT) + in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop + (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))), + HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ + list_comb (r, free_ts))))) + end; + + val SOME (_, _, constrs) = AList.lookup (op =) descr index; + val T = List.nth (get_rec_types descr sorts, index); + val (rs, prems) = split_list (map (make_casedist_prem T) constrs); + val r = Const (case_name, map fastype_of rs ---> T --> rT); + + val y = Var (("y", 0), Logic.legacy_varifyT T); + val y' = Free ("y", T); + + val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems, + HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ + list_comb (r, rs @ [y']))))) + (fn prems => + [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1, + ALLGOALS (EVERY' + [asm_simp_tac (HOL_basic_ss addsimps case_rewrites), + resolve_tac prems, asm_simp_tac HOL_basic_ss])]); + + val exh_name = Thm.get_name exhaustion; + val (thm', thy') = thy + |> Sign.root_path + |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm) + ||> Sign.restore_naming thy; + + val P = Var (("P", 0), rT' --> HOLogic.boolT); + val prf = forall_intr_prf (y, forall_intr_prf (P, + List.foldr (fn ((p, r), prf) => + forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p), + prf))) (Proofterm.proof_combP (prf_of thm', + map PBound (length prems - 1 downto 0))) (prems ~~ rs))); + val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T, + list_abs (map dest_Free rs, list_comb (r, + map Bound ((length rs - 1 downto 0) @ [length rs]))))); + + in Extraction.add_realizers_i + [(exh_name, (["P"], r', prf)), + (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy' + end; + +fun add_dt_realizers config names thy = + if ! Proofterm.proofs < 2 then thy + else let + val _ = message config "Adding realizers for induction and case analysis ..." + val infos = map (Datatype.the_datatype thy) names; + val info :: _ = infos; + in + thy + |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1)) + |> fold_rev (make_casedists (#sorts info)) infos + end; + +val setup = Datatype.interpretation add_dt_realizers; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Datatype/datatype_rep_proofs.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Datatype/datatype_rep_proofs.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,643 @@ +(* Title: HOL/Tools/datatype_rep_proofs.ML + Author: Stefan Berghofer, TU Muenchen + +Definitional introduction of datatypes +Proof of characteristic theorems: + + - injectivity of constructors + - distinctness of constructors + - induction theorem +*) + +signature DATATYPE_REP_PROOFS = +sig + include DATATYPE_COMMON + val distinctness_limit : int Config.T + val distinctness_limit_setup : theory -> theory + val representation_proofs : config -> info Symtab.table -> + string list -> descr list -> (string * sort) list -> + (binding * mixfix) list -> (binding * mixfix) list list -> attribute + -> theory -> (thm list list * thm list list * thm list list * + DatatypeAux.simproc_dist list * thm) * theory +end; + +structure DatatypeRepProofs : DATATYPE_REP_PROOFS = +struct + +open DatatypeAux; + +(*the kind of distinctiveness axioms depends on number of constructors*) +val (distinctness_limit, distinctness_limit_setup) = + Attrib.config_int "datatype_distinctness_limit" 7; + +val (_ $ (_ $ (_ $ (distinct_f $ _) $ _))) = hd (prems_of distinct_lemma); + +val collect_simp = rewrite_rule [mk_meta_eq mem_Collect_eq]; + + +(** theory context references **) + +val f_myinv_f = thm "f_myinv_f"; +val myinv_f_f = thm "myinv_f_f"; + + +fun exh_thm_of (dt_info : info Symtab.table) tname = + #exhaustion (the (Symtab.lookup dt_info tname)); + +(******************************************************************************) + +fun representation_proofs (config : config) (dt_info : info Symtab.table) + new_type_names descr sorts types_syntax constr_syntax case_names_induct thy = + let + val Datatype_thy = ThyInfo.the_theory "Datatype" thy; + val node_name = "Datatype.node"; + val In0_name = "Datatype.In0"; + val In1_name = "Datatype.In1"; + val Scons_name = "Datatype.Scons"; + val Leaf_name = "Datatype.Leaf"; + val Numb_name = "Datatype.Numb"; + val Lim_name = "Datatype.Lim"; + val Suml_name = "Datatype.Suml"; + val Sumr_name = "Datatype.Sumr"; + + val [In0_inject, In1_inject, Scons_inject, Leaf_inject, + In0_eq, In1_eq, In0_not_In1, In1_not_In0, + Lim_inject, Suml_inject, Sumr_inject] = map (PureThy.get_thm Datatype_thy) + ["In0_inject", "In1_inject", "Scons_inject", "Leaf_inject", + "In0_eq", "In1_eq", "In0_not_In1", "In1_not_In0", + "Lim_inject", "Suml_inject", "Sumr_inject"]; + + val descr' = flat descr; + + val big_name = space_implode "_" new_type_names; + val thy1 = add_path (#flat_names config) big_name thy; + val big_rec_name = big_name ^ "_rep_set"; + val rep_set_names' = + (if length descr' = 1 then [big_rec_name] else + (map ((curry (op ^) (big_rec_name ^ "_")) o string_of_int) + (1 upto (length descr')))); + val rep_set_names = map (Sign.full_bname thy1) rep_set_names'; + + val tyvars = map (fn (_, (_, Ts, _)) => map dest_DtTFree Ts) (hd descr); + val leafTs' = get_nonrec_types descr' sorts; + val branchTs = get_branching_types descr' sorts; + val branchT = if null branchTs then HOLogic.unitT + else BalancedTree.make (fn (T, U) => Type ("+", [T, U])) branchTs; + val arities = get_arities descr' \ 0; + val unneeded_vars = hd tyvars \\ List.foldr OldTerm.add_typ_tfree_names [] (leafTs' @ branchTs); + val leafTs = leafTs' @ (map (fn n => TFree (n, (the o AList.lookup (op =) sorts) n)) unneeded_vars); + val recTs = get_rec_types descr' sorts; + val newTs = Library.take (length (hd descr), recTs); + val oldTs = Library.drop (length (hd descr), recTs); + val sumT = if null leafTs then HOLogic.unitT + else BalancedTree.make (fn (T, U) => Type ("+", [T, U])) leafTs; + val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT, branchT])); + val UnivT = HOLogic.mk_setT Univ_elT; + val UnivT' = Univ_elT --> HOLogic.boolT; + val Collect = Const ("Collect", UnivT' --> UnivT); + + val In0 = Const (In0_name, Univ_elT --> Univ_elT); + val In1 = Const (In1_name, Univ_elT --> Univ_elT); + val Leaf = Const (Leaf_name, sumT --> Univ_elT); + val Lim = Const (Lim_name, (branchT --> Univ_elT) --> Univ_elT); + + (* make injections needed for embedding types in leaves *) + + fun mk_inj T' x = + let + fun mk_inj' T n i = + if n = 1 then x else + let val n2 = n div 2; + val Type (_, [T1, T2]) = T + in + if i <= n2 then + Const ("Sum_Type.Inl", T1 --> T) $ (mk_inj' T1 n2 i) + else + Const ("Sum_Type.Inr", T2 --> T) $ (mk_inj' T2 (n - n2) (i - n2)) + end + in mk_inj' sumT (length leafTs) (1 + find_index_eq T' leafTs) + end; + + (* make injections for constructors *) + + fun mk_univ_inj ts = BalancedTree.access + {left = fn t => In0 $ t, + right = fn t => In1 $ t, + init = + if ts = [] then Const (@{const_name undefined}, Univ_elT) + else foldr1 (HOLogic.mk_binop Scons_name) ts}; + + (* function spaces *) + + fun mk_fun_inj T' x = + let + fun mk_inj T n i = + if n = 1 then x else + let + val n2 = n div 2; + val Type (_, [T1, T2]) = T; + fun mkT U = (U --> Univ_elT) --> T --> Univ_elT + in + if i <= n2 then Const (Suml_name, mkT T1) $ mk_inj T1 n2 i + else Const (Sumr_name, mkT T2) $ mk_inj T2 (n - n2) (i - n2) + end + in mk_inj branchT (length branchTs) (1 + find_index_eq T' branchTs) + end; + + val mk_lim = List.foldr (fn (T, t) => Lim $ mk_fun_inj T (Abs ("x", T, t))); + + (************** generate introduction rules for representing set **********) + + val _ = message config "Constructing representing sets ..."; + + (* make introduction rule for a single constructor *) + + fun make_intr s n (i, (_, cargs)) = + let + fun mk_prem (dt, (j, prems, ts)) = (case strip_dtyp dt of + (dts, DtRec k) => + let + val Ts = map (typ_of_dtyp descr' sorts) dts; + val free_t = + app_bnds (mk_Free "x" (Ts ---> Univ_elT) j) (length Ts) + in (j + 1, list_all (map (pair "x") Ts, + HOLogic.mk_Trueprop + (Free (List.nth (rep_set_names', k), UnivT') $ free_t)) :: prems, + mk_lim free_t Ts :: ts) + end + | _ => + let val T = typ_of_dtyp descr' sorts dt + in (j + 1, prems, (Leaf $ mk_inj T (mk_Free "x" T j))::ts) + end); + + val (_, prems, ts) = List.foldr mk_prem (1, [], []) cargs; + val concl = HOLogic.mk_Trueprop + (Free (s, UnivT') $ mk_univ_inj ts n i) + in Logic.list_implies (prems, concl) + end; + + val intr_ts = maps (fn ((_, (_, _, constrs)), rep_set_name) => + map (make_intr rep_set_name (length constrs)) + ((1 upto (length constrs)) ~~ constrs)) (descr' ~~ rep_set_names'); + + val ({raw_induct = rep_induct, intrs = rep_intrs, ...}, thy2) = + Inductive.add_inductive_global (serial_string ()) + {quiet_mode = #quiet config, verbose = false, kind = Thm.internalK, + alt_name = Binding.name big_rec_name, coind = false, no_elim = true, no_ind = false, + skip_mono = true, fork_mono = false} + (map (fn s => ((Binding.name s, UnivT'), NoSyn)) rep_set_names') [] + (map (fn x => (Attrib.empty_binding, x)) intr_ts) [] thy1; + + (********************************* typedef ********************************) + + val (typedefs, thy3) = thy2 |> + parent_path (#flat_names config) |> + fold_map (fn ((((name, mx), tvs), c), name') => + Typedef.add_typedef false (SOME (Binding.name name')) (name, tvs, mx) + (Collect $ Const (c, UnivT')) NONE + (rtac exI 1 THEN rtac CollectI 1 THEN + QUIET_BREADTH_FIRST (has_fewer_prems 1) + (resolve_tac rep_intrs 1))) + (types_syntax ~~ tyvars ~~ + (Library.take (length newTs, rep_set_names)) ~~ new_type_names) ||> + add_path (#flat_names config) big_name; + + (*********************** definition of constructors ***********************) + + val big_rep_name = (space_implode "_" new_type_names) ^ "_Rep_"; + val rep_names = map (curry op ^ "Rep_") new_type_names; + val rep_names' = map (fn i => big_rep_name ^ (string_of_int i)) + (1 upto (length (flat (tl descr)))); + val all_rep_names = map (Sign.intern_const thy3) rep_names @ + map (Sign.full_bname thy3) rep_names'; + + (* isomorphism declarations *) + + val iso_decls = map (fn (T, s) => (Binding.name s, T --> Univ_elT, NoSyn)) + (oldTs ~~ rep_names'); + + (* constructor definitions *) + + fun make_constr_def tname T n ((thy, defs, eqns, i), ((cname, cargs), (cname', mx))) = + let + fun constr_arg (dt, (j, l_args, r_args)) = + let val T = typ_of_dtyp descr' sorts dt; + val free_t = mk_Free "x" T j + in (case (strip_dtyp dt, strip_type T) of + ((_, DtRec m), (Us, U)) => (j + 1, free_t :: l_args, mk_lim + (Const (List.nth (all_rep_names, m), U --> Univ_elT) $ + app_bnds free_t (length Us)) Us :: r_args) + | _ => (j + 1, free_t::l_args, (Leaf $ mk_inj T free_t)::r_args)) + end; + + val (_, l_args, r_args) = List.foldr constr_arg (1, [], []) cargs; + val constrT = (map (typ_of_dtyp descr' sorts) cargs) ---> T; + val abs_name = Sign.intern_const thy ("Abs_" ^ tname); + val rep_name = Sign.intern_const thy ("Rep_" ^ tname); + val lhs = list_comb (Const (cname, constrT), l_args); + val rhs = mk_univ_inj r_args n i; + val def = Logic.mk_equals (lhs, Const (abs_name, Univ_elT --> T) $ rhs); + val def_name = Long_Name.base_name cname ^ "_def"; + val eqn = HOLogic.mk_Trueprop (HOLogic.mk_eq + (Const (rep_name, T --> Univ_elT) $ lhs, rhs)); + val ([def_thm], thy') = + thy + |> Sign.add_consts_i [(cname', constrT, mx)] + |> (PureThy.add_defs false o map Thm.no_attributes) [(Binding.name def_name, def)]; + + in (thy', defs @ [def_thm], eqns @ [eqn], i + 1) end; + + (* constructor definitions for datatype *) + + fun dt_constr_defs ((thy, defs, eqns, rep_congs, dist_lemmas), + ((((_, (_, _, constrs)), tname), T), constr_syntax)) = + let + val _ $ (_ $ (cong_f $ _) $ _) = concl_of arg_cong; + val rep_const = cterm_of thy + (Const (Sign.intern_const thy ("Rep_" ^ tname), T --> Univ_elT)); + val cong' = standard (cterm_instantiate [(cterm_of thy cong_f, rep_const)] arg_cong); + val dist = standard (cterm_instantiate [(cterm_of thy distinct_f, rep_const)] distinct_lemma); + val (thy', defs', eqns', _) = Library.foldl ((make_constr_def tname T) (length constrs)) + ((add_path (#flat_names config) tname thy, defs, [], 1), constrs ~~ constr_syntax) + in + (parent_path (#flat_names config) thy', defs', eqns @ [eqns'], + rep_congs @ [cong'], dist_lemmas @ [dist]) + end; + + val (thy4, constr_defs, constr_rep_eqns, rep_congs, dist_lemmas) = Library.foldl dt_constr_defs + ((thy3 |> Sign.add_consts_i iso_decls |> parent_path (#flat_names config), [], [], [], []), + hd descr ~~ new_type_names ~~ newTs ~~ constr_syntax); + + (*********** isomorphisms for new types (introduced by typedef) ***********) + + val _ = message config "Proving isomorphism properties ..."; + + val newT_iso_axms = map (fn (_, td) => + (collect_simp (#Abs_inverse td), #Rep_inverse td, + collect_simp (#Rep td))) typedefs; + + val newT_iso_inj_thms = map (fn (_, td) => + (collect_simp (#Abs_inject td) RS iffD1, #Rep_inject td RS iffD1)) typedefs; + + (********* isomorphisms between existing types and "unfolded" types *******) + + (*---------------------------------------------------------------------*) + (* isomorphisms are defined using primrec-combinators: *) + (* generate appropriate functions for instantiating primrec-combinator *) + (* *) + (* e.g. dt_Rep_i = list_rec ... (%h t y. In1 (Scons (Leaf h) y)) *) + (* *) + (* also generate characteristic equations for isomorphisms *) + (* *) + (* e.g. dt_Rep_i (cons h t) = In1 (Scons (dt_Rep_j h) (dt_Rep_i t)) *) + (*---------------------------------------------------------------------*) + + fun make_iso_def k ks n ((fs, eqns, i), (cname, cargs)) = + let + val argTs = map (typ_of_dtyp descr' sorts) cargs; + val T = List.nth (recTs, k); + val rep_name = List.nth (all_rep_names, k); + val rep_const = Const (rep_name, T --> Univ_elT); + val constr = Const (cname, argTs ---> T); + + fun process_arg ks' ((i2, i2', ts, Ts), dt) = + let + val T' = typ_of_dtyp descr' sorts dt; + val (Us, U) = strip_type T' + in (case strip_dtyp dt of + (_, DtRec j) => if j mem ks' then + (i2 + 1, i2' + 1, ts @ [mk_lim (app_bnds + (mk_Free "y" (Us ---> Univ_elT) i2') (length Us)) Us], + Ts @ [Us ---> Univ_elT]) + else + (i2 + 1, i2', ts @ [mk_lim + (Const (List.nth (all_rep_names, j), U --> Univ_elT) $ + app_bnds (mk_Free "x" T' i2) (length Us)) Us], Ts) + | _ => (i2 + 1, i2', ts @ [Leaf $ mk_inj T' (mk_Free "x" T' i2)], Ts)) + end; + + val (i2, i2', ts, Ts) = Library.foldl (process_arg ks) ((1, 1, [], []), cargs); + val xs = map (uncurry (mk_Free "x")) (argTs ~~ (1 upto (i2 - 1))); + val ys = map (uncurry (mk_Free "y")) (Ts ~~ (1 upto (i2' - 1))); + val f = list_abs_free (map dest_Free (xs @ ys), mk_univ_inj ts n i); + + val (_, _, ts', _) = Library.foldl (process_arg []) ((1, 1, [], []), cargs); + val eqn = HOLogic.mk_Trueprop (HOLogic.mk_eq + (rep_const $ list_comb (constr, xs), mk_univ_inj ts' n i)) + + in (fs @ [f], eqns @ [eqn], i + 1) end; + + (* define isomorphisms for all mutually recursive datatypes in list ds *) + + fun make_iso_defs (ds, (thy, char_thms)) = + let + val ks = map fst ds; + val (_, (tname, _, _)) = hd ds; + val {rec_rewrites, rec_names, ...} = the (Symtab.lookup dt_info tname); + + fun process_dt ((fs, eqns, isos), (k, (tname, _, constrs))) = + let + val (fs', eqns', _) = Library.foldl (make_iso_def k ks (length constrs)) + ((fs, eqns, 1), constrs); + val iso = (List.nth (recTs, k), List.nth (all_rep_names, k)) + in (fs', eqns', isos @ [iso]) end; + + val (fs, eqns, isos) = Library.foldl process_dt (([], [], []), ds); + val fTs = map fastype_of fs; + val defs = map (fn (rec_name, (T, iso_name)) => (Binding.name (Long_Name.base_name iso_name ^ "_def"), + Logic.mk_equals (Const (iso_name, T --> Univ_elT), + list_comb (Const (rec_name, fTs @ [T] ---> Univ_elT), fs)))) (rec_names ~~ isos); + val (def_thms, thy') = + apsnd Theory.checkpoint ((PureThy.add_defs false o map Thm.no_attributes) defs thy); + + (* prove characteristic equations *) + + val rewrites = def_thms @ (map mk_meta_eq rec_rewrites); + val char_thms' = map (fn eqn => SkipProof.prove_global thy' [] [] eqn + (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns; + + in (thy', char_thms' @ char_thms) end; + + val (thy5, iso_char_thms) = apfst Theory.checkpoint (List.foldr make_iso_defs + (add_path (#flat_names config) big_name thy4, []) (tl descr)); + + (* prove isomorphism properties *) + + fun mk_funs_inv thy thm = + let + val prop = Thm.prop_of thm; + val _ $ (_ $ ((S as Const (_, Type (_, [U, _]))) $ _ )) $ + (_ $ (_ $ (r $ (a $ _)) $ _)) = Type.freeze prop; + val used = OldTerm.add_term_tfree_names (a, []); + + fun mk_thm i = + let + val Ts = map (TFree o rpair HOLogic.typeS) + (Name.variant_list used (replicate i "'t")); + val f = Free ("f", Ts ---> U) + in SkipProof.prove_global thy [] [] (Logic.mk_implies + (HOLogic.mk_Trueprop (HOLogic.list_all + (map (pair "x") Ts, S $ app_bnds f i)), + HOLogic.mk_Trueprop (HOLogic.mk_eq (list_abs (map (pair "x") Ts, + r $ (a $ app_bnds f i)), f)))) + (fn _ => EVERY [REPEAT_DETERM_N i (rtac ext 1), + REPEAT (etac allE 1), rtac thm 1, atac 1]) + end + in map (fn r => r RS subst) (thm :: map mk_thm arities) end; + + (* prove inj dt_Rep_i and dt_Rep_i x : dt_rep_set_i *) + + val fun_congs = map (fn T => make_elim (Drule.instantiate' + [SOME (ctyp_of thy5 T)] [] fun_cong)) branchTs; + + fun prove_iso_thms (ds, (inj_thms, elem_thms)) = + let + val (_, (tname, _, _)) = hd ds; + val {induction, ...} = the (Symtab.lookup dt_info tname); + + fun mk_ind_concl (i, _) = + let + val T = List.nth (recTs, i); + val Rep_t = Const (List.nth (all_rep_names, i), T --> Univ_elT); + val rep_set_name = List.nth (rep_set_names, i) + in (HOLogic.all_const T $ Abs ("y", T, HOLogic.imp $ + HOLogic.mk_eq (Rep_t $ mk_Free "x" T i, Rep_t $ Bound 0) $ + HOLogic.mk_eq (mk_Free "x" T i, Bound 0)), + Const (rep_set_name, UnivT') $ (Rep_t $ mk_Free "x" T i)) + end; + + val (ind_concl1, ind_concl2) = ListPair.unzip (map mk_ind_concl ds); + + val rewrites = map mk_meta_eq iso_char_thms; + val inj_thms' = map snd newT_iso_inj_thms @ + map (fn r => r RS @{thm injD}) inj_thms; + + val inj_thm = SkipProof.prove_global thy5 [] [] + (HOLogic.mk_Trueprop (mk_conj ind_concl1)) (fn _ => EVERY + [(indtac induction [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, + REPEAT (EVERY + [rtac allI 1, rtac impI 1, + exh_tac (exh_thm_of dt_info) 1, + REPEAT (EVERY + [hyp_subst_tac 1, + rewrite_goals_tac rewrites, + REPEAT (dresolve_tac [In0_inject, In1_inject] 1), + (eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1) + ORELSE (EVERY + [REPEAT (eresolve_tac (Scons_inject :: + map make_elim [Leaf_inject, Inl_inject, Inr_inject]) 1), + REPEAT (cong_tac 1), rtac refl 1, + REPEAT (atac 1 ORELSE (EVERY + [REPEAT (rtac ext 1), + REPEAT (eresolve_tac (mp :: allE :: + map make_elim (Suml_inject :: Sumr_inject :: + Lim_inject :: inj_thms') @ fun_congs) 1), + atac 1]))])])])]); + + val inj_thms'' = map (fn r => r RS @{thm datatype_injI}) + (split_conj_thm inj_thm); + + val elem_thm = + SkipProof.prove_global thy5 [] [] (HOLogic.mk_Trueprop (mk_conj ind_concl2)) + (fn _ => + EVERY [(indtac induction [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, + rewrite_goals_tac rewrites, + REPEAT ((resolve_tac rep_intrs THEN_ALL_NEW + ((REPEAT o etac allE) THEN' ares_tac elem_thms)) 1)]); + + in (inj_thms'' @ inj_thms, elem_thms @ (split_conj_thm elem_thm)) + end; + + val (iso_inj_thms_unfolded, iso_elem_thms) = List.foldr prove_iso_thms + ([], map #3 newT_iso_axms) (tl descr); + val iso_inj_thms = map snd newT_iso_inj_thms @ + map (fn r => r RS @{thm injD}) iso_inj_thms_unfolded; + + (* prove dt_rep_set_i x --> x : range dt_Rep_i *) + + fun mk_iso_t (((set_name, iso_name), i), T) = + let val isoT = T --> Univ_elT + in HOLogic.imp $ + (Const (set_name, UnivT') $ mk_Free "x" Univ_elT i) $ + (if i < length newTs then HOLogic.true_const + else HOLogic.mk_mem (mk_Free "x" Univ_elT i, + Const (@{const_name image}, isoT --> HOLogic.mk_setT T --> UnivT) $ + Const (iso_name, isoT) $ Const (@{const_name UNIV}, HOLogic.mk_setT T))) + end; + + val iso_t = HOLogic.mk_Trueprop (mk_conj (map mk_iso_t + (rep_set_names ~~ all_rep_names ~~ (0 upto (length descr' - 1)) ~~ recTs))); + + (* all the theorems are proved by one single simultaneous induction *) + + val range_eqs = map (fn r => mk_meta_eq (r RS @{thm range_ex1_eq})) + iso_inj_thms_unfolded; + + val iso_thms = if length descr = 1 then [] else + Library.drop (length newTs, split_conj_thm + (SkipProof.prove_global thy5 [] [] iso_t (fn _ => EVERY + [(indtac rep_induct [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, + REPEAT (rtac TrueI 1), + rewrite_goals_tac (mk_meta_eq choice_eq :: + symmetric (mk_meta_eq @{thm expand_fun_eq}) :: range_eqs), + rewrite_goals_tac (map symmetric range_eqs), + REPEAT (EVERY + [REPEAT (eresolve_tac ([rangeE, ex1_implies_ex RS exE] @ + maps (mk_funs_inv thy5 o #1) newT_iso_axms) 1), + TRY (hyp_subst_tac 1), + rtac (sym RS range_eqI) 1, + resolve_tac iso_char_thms 1])]))); + + val Abs_inverse_thms' = + map #1 newT_iso_axms @ + map2 (fn r_inj => fn r => f_myinv_f OF [r_inj, r RS mp]) + iso_inj_thms_unfolded iso_thms; + + val Abs_inverse_thms = maps (mk_funs_inv thy5) Abs_inverse_thms'; + + (******************* freeness theorems for constructors *******************) + + val _ = message config "Proving freeness of constructors ..."; + + (* prove theorem Rep_i (Constr_j ...) = Inj_j ... *) + + fun prove_constr_rep_thm eqn = + let + val inj_thms = map fst newT_iso_inj_thms; + val rewrites = @{thm o_def} :: constr_defs @ (map (mk_meta_eq o #2) newT_iso_axms) + in SkipProof.prove_global thy5 [] [] eqn (fn _ => EVERY + [resolve_tac inj_thms 1, + rewrite_goals_tac rewrites, + rtac refl 3, + resolve_tac rep_intrs 2, + REPEAT (resolve_tac iso_elem_thms 1)]) + end; + + (*--------------------------------------------------------------*) + (* constr_rep_thms and rep_congs are used to prove distinctness *) + (* of constructors. *) + (*--------------------------------------------------------------*) + + val constr_rep_thms = map (map prove_constr_rep_thm) constr_rep_eqns; + + val dist_rewrites = map (fn (rep_thms, dist_lemma) => + dist_lemma::(rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0])) + (constr_rep_thms ~~ dist_lemmas); + + fun prove_distinct_thms _ _ (_, []) = [] + | prove_distinct_thms lim dist_rewrites' (k, ts as _ :: _) = + if k >= lim then [] else let + (*number of constructors < distinctness_limit : C_i ... ~= C_j ...*) + fun prove [] = [] + | prove (t :: ts) = + let + val dist_thm = SkipProof.prove_global thy5 [] [] t (fn _ => + EVERY [simp_tac (HOL_ss addsimps dist_rewrites') 1]) + in dist_thm :: standard (dist_thm RS not_sym) :: prove ts end; + in prove ts end; + + val distinct_thms = DatatypeProp.make_distincts descr sorts + |> map2 (prove_distinct_thms + (Config.get_thy thy5 distinctness_limit)) dist_rewrites; + + val simproc_dists = map (fn ((((_, (_, _, constrs)), rep_thms), congr), dists) => + if length constrs < Config.get_thy thy5 distinctness_limit + then FewConstrs dists + else ManyConstrs (congr, HOL_basic_ss addsimps rep_thms)) (hd descr ~~ + constr_rep_thms ~~ rep_congs ~~ distinct_thms); + + (* prove injectivity of constructors *) + + fun prove_constr_inj_thm rep_thms t = + let val inj_thms = Scons_inject :: (map make_elim + (iso_inj_thms @ + [In0_inject, In1_inject, Leaf_inject, Inl_inject, Inr_inject, + Lim_inject, Suml_inject, Sumr_inject])) + in SkipProof.prove_global thy5 [] [] t (fn _ => EVERY + [rtac iffI 1, + REPEAT (etac conjE 2), hyp_subst_tac 2, rtac refl 2, + dresolve_tac rep_congs 1, dtac box_equals 1, + REPEAT (resolve_tac rep_thms 1), + REPEAT (eresolve_tac inj_thms 1), + REPEAT (ares_tac [conjI] 1 ORELSE (EVERY [REPEAT (rtac ext 1), + REPEAT (eresolve_tac (make_elim fun_cong :: inj_thms) 1), + atac 1]))]) + end; + + val constr_inject = map (fn (ts, thms) => map (prove_constr_inj_thm thms) ts) + ((DatatypeProp.make_injs descr sorts) ~~ constr_rep_thms); + + val ((constr_inject', distinct_thms'), thy6) = + thy5 + |> parent_path (#flat_names config) + |> store_thmss "inject" new_type_names constr_inject + ||>> store_thmss "distinct" new_type_names distinct_thms; + + (*************************** induction theorem ****************************) + + val _ = message config "Proving induction rule for datatypes ..."; + + val Rep_inverse_thms = (map (fn (_, iso, _) => iso RS subst) newT_iso_axms) @ + (map (fn r => r RS myinv_f_f RS subst) iso_inj_thms_unfolded); + val Rep_inverse_thms' = map (fn r => r RS myinv_f_f) iso_inj_thms_unfolded; + + fun mk_indrule_lemma ((prems, concls), ((i, _), T)) = + let + val Rep_t = Const (List.nth (all_rep_names, i), T --> Univ_elT) $ + mk_Free "x" T i; + + val Abs_t = if i < length newTs then + Const (Sign.intern_const thy6 + ("Abs_" ^ (List.nth (new_type_names, i))), Univ_elT --> T) + else Const ("Inductive.myinv", [T --> Univ_elT, Univ_elT] ---> T) $ + Const (List.nth (all_rep_names, i), T --> Univ_elT) + + in (prems @ [HOLogic.imp $ + (Const (List.nth (rep_set_names, i), UnivT') $ Rep_t) $ + (mk_Free "P" (T --> HOLogic.boolT) (i + 1) $ (Abs_t $ Rep_t))], + concls @ [mk_Free "P" (T --> HOLogic.boolT) (i + 1) $ mk_Free "x" T i]) + end; + + val (indrule_lemma_prems, indrule_lemma_concls) = + Library.foldl mk_indrule_lemma (([], []), (descr' ~~ recTs)); + + val cert = cterm_of thy6; + + val indrule_lemma = SkipProof.prove_global thy6 [] [] + (Logic.mk_implies + (HOLogic.mk_Trueprop (mk_conj indrule_lemma_prems), + HOLogic.mk_Trueprop (mk_conj indrule_lemma_concls))) (fn _ => EVERY + [REPEAT (etac conjE 1), + REPEAT (EVERY + [TRY (rtac conjI 1), resolve_tac Rep_inverse_thms 1, + etac mp 1, resolve_tac iso_elem_thms 1])]); + + val Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule_lemma))); + val frees = if length Ps = 1 then [Free ("P", snd (dest_Var (hd Ps)))] else + map (Free o apfst fst o dest_Var) Ps; + val indrule_lemma' = cterm_instantiate (map cert Ps ~~ map cert frees) indrule_lemma; + + val dt_induct_prop = DatatypeProp.make_ind descr sorts; + val dt_induct = SkipProof.prove_global thy6 [] + (Logic.strip_imp_prems dt_induct_prop) (Logic.strip_imp_concl dt_induct_prop) + (fn {prems, ...} => EVERY + [rtac indrule_lemma' 1, + (indtac rep_induct [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, + EVERY (map (fn (prem, r) => (EVERY + [REPEAT (eresolve_tac Abs_inverse_thms 1), + simp_tac (HOL_basic_ss addsimps ((symmetric r)::Rep_inverse_thms')) 1, + DEPTH_SOLVE_1 (ares_tac [prem] 1 ORELSE etac allE 1)])) + (prems ~~ (constr_defs @ (map mk_meta_eq iso_char_thms))))]); + + val ([dt_induct'], thy7) = + thy6 + |> Sign.add_path big_name + |> PureThy.add_thms [((Binding.name "induct", dt_induct), [case_names_induct])] + ||> Sign.parent_path + ||> Theory.checkpoint; + + in + ((constr_inject', distinct_thms', dist_rewrites, simproc_dists, dt_induct'), thy7) + end; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/auto_term.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/auto_term.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,36 @@ +(* Title: HOL/Tools/Function/auto_term.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Method "relation" to commence a termination proof using a user-specified relation. +*) + +signature FUNDEF_RELATION = +sig + val relation_tac: Proof.context -> term -> int -> tactic + val setup: theory -> theory +end + +structure FundefRelation : FUNDEF_RELATION = +struct + +fun inst_thm ctxt rel st = + let + val cert = Thm.cterm_of (ProofContext.theory_of ctxt) + val rel' = cert (singleton (Variable.polymorphic ctxt) rel) + val st' = Thm.incr_indexes (#maxidx (Thm.rep_cterm rel') + 1) st + val Rvar = cert (Var (the_single (Term.add_vars (prop_of st') []))) + in + Drule.cterm_instantiate [(Rvar, rel')] st' + end + +fun relation_tac ctxt rel i = + TRY (FundefCommon.apply_termination_rule ctxt i) + THEN PRIMITIVE (inst_thm ctxt rel) + +val setup = + Method.setup @{binding relation} + (Args.term >> (fn rel => fn ctxt => SIMPLE_METHOD' (relation_tac ctxt rel))) + "proves termination using a user-specified wellfounded relation" + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/context_tree.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/context_tree.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,278 @@ +(* Title: HOL/Tools/Function/context_tree.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Builds and traverses trees of nested contexts along a term. +*) + +signature FUNDEF_CTXTREE = +sig + type ctxt = (string * typ) list * thm list (* poor man's contexts: fixes + assumes *) + type ctx_tree + + (* FIXME: This interface is a mess and needs to be cleaned up! *) + val get_fundef_congs : Proof.context -> thm list + val add_fundef_cong : thm -> Context.generic -> Context.generic + val map_fundef_congs : (thm list -> thm list) -> Context.generic -> Context.generic + + val cong_add: attribute + val cong_del: attribute + + val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree + + val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree + + val export_term : ctxt -> term -> term + val export_thm : theory -> ctxt -> thm -> thm + val import_thm : theory -> ctxt -> thm -> thm + + val traverse_tree : + (ctxt -> term -> + (ctxt * thm) list -> + (ctxt * thm) list * 'b -> + (ctxt * thm) list * 'b) + -> ctx_tree -> 'b -> 'b + + val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list +end + +structure FundefCtxTree : FUNDEF_CTXTREE = +struct + +type ctxt = (string * typ) list * thm list + +open FundefCommon +open FundefLib + +structure FundefCongs = GenericDataFun +( + type T = thm list + val empty = [] + val extend = I + fun merge _ = Thm.merge_thms +); + +val get_fundef_congs = FundefCongs.get o Context.Proof +val map_fundef_congs = FundefCongs.map +val add_fundef_cong = FundefCongs.map o Thm.add_thm + +(* congruence rules *) + +val cong_add = Thm.declaration_attribute (map_fundef_congs o Thm.add_thm o safe_mk_meta_eq); +val cong_del = Thm.declaration_attribute (map_fundef_congs o Thm.del_thm o safe_mk_meta_eq); + + +type depgraph = int IntGraph.T + +datatype ctx_tree + = Leaf of term + | Cong of (thm * depgraph * (ctxt * ctx_tree) list) + | RCall of (term * ctx_tree) + + +(* Maps "Trueprop A = B" to "A" *) +val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop + + +(*** Dependency analysis for congruence rules ***) + +fun branch_vars t = + let + val t' = snd (dest_all_all t) + val (assumes, concl) = Logic.strip_horn t' + in (fold Term.add_vars assumes [], Term.add_vars concl []) + end + +fun cong_deps crule = + let + val num_branches = map_index (apsnd branch_vars) (prems_of crule) + in + IntGraph.empty + |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches + |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) => + if i = j orelse null (c1 inter t2) + then I else IntGraph.add_edge_acyclic (i,j)) + num_branches num_branches + end + +val default_congs = map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}] + + + +(* Called on the INSTANTIATED branches of the congruence rule *) +fun mk_branch ctx t = + let + val (ctx', fixes, impl) = dest_all_all_ctx ctx t + val (assms, concl) = Logic.strip_horn impl + in + (ctx', fixes, assms, rhs_of concl) + end + +fun find_cong_rule ctx fvar h ((r,dep)::rs) t = + (let + val thy = ProofContext.theory_of ctx + + val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t) + val (c, subs) = (concl_of r, prems_of r) + + val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty) + val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_vars subst) subs + val inst = map (fn v => (cterm_of thy (Var v), cterm_of thy (Envir.subst_vars subst (Var v)))) (Term.add_vars c []) + in + (cterm_instantiate inst r, dep, branches) + end + handle Pattern.MATCH => find_cong_rule ctx fvar h rs t) + | find_cong_rule _ _ _ [] _ = sys_error "Function/context_tree.ML: No cong rule found!" + + +fun mk_tree fvar h ctxt t = + let + val congs = get_fundef_congs ctxt + val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs) (* FIXME: Save in theory *) + + fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE + | matchcall _ = NONE + + fun mk_tree' ctx t = + case matchcall t of + SOME arg => RCall (t, mk_tree' ctx arg) + | NONE => + if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t + else + let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in + Cong (r, dep, + map (fn (ctx', fixes, assumes, st) => + ((fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes), + mk_tree' ctx' st)) branches) + end + in + mk_tree' ctxt t + end + + +fun inst_tree thy fvar f tr = + let + val cfvar = cterm_of thy fvar + val cf = cterm_of thy f + + fun inst_term t = + subst_bound(f, abstract_over (fvar, t)) + + val inst_thm = forall_elim cf o forall_intr cfvar + + fun inst_tree_aux (Leaf t) = Leaf t + | inst_tree_aux (Cong (crule, deps, branches)) = + Cong (inst_thm crule, deps, map inst_branch branches) + | inst_tree_aux (RCall (t, str)) = + RCall (inst_term t, inst_tree_aux str) + and inst_branch ((fxs, assms), str) = + ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms), inst_tree_aux str) + in + inst_tree_aux tr + end + + +(* Poor man's contexts: Only fixes and assumes *) +fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2) + +fun export_term (fixes, assumes) = + fold_rev (curry Logic.mk_implies o prop_of) assumes + #> fold_rev (Logic.all o Free) fixes + +fun export_thm thy (fixes, assumes) = + fold_rev (implies_intr o cprop_of) assumes + #> fold_rev (forall_intr o cterm_of thy o Free) fixes + +fun import_thm thy (fixes, athms) = + fold (forall_elim o cterm_of thy o Free) fixes + #> fold Thm.elim_implies athms + + +(* folds in the order of the dependencies of a graph. *) +fun fold_deps G f x = + let + fun fill_table i (T, x) = + case Inttab.lookup T i of + SOME _ => (T, x) + | NONE => + let + val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x) + val (v, x'') = f (the o Inttab.lookup T') i x' + in + (Inttab.update (i, v) T', x'') + end + + val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x) + in + (Inttab.fold (cons o snd) T [], x) + end + +fun traverse_tree rcOp tr = + let + fun traverse_help ctx (Leaf _) _ x = ([], x) + | traverse_help ctx (RCall (t, st)) u x = + rcOp ctx t u (traverse_help ctx st u x) + | traverse_help ctx (Cong (_, deps, branches)) u x = + let + fun sub_step lu i x = + let + val (ctx', subtree) = nth branches i + val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u + val (subs, x') = traverse_help (compose ctx ctx') subtree used x + val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *) + in + (exported_subs, x') + end + in + fold_deps deps sub_step x + |> apfst flat + end + in + snd o traverse_help ([], []) tr [] + end + +fun rewrite_by_tree thy h ih x tr = + let + fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x) + | rewrite_help fix h_as x (RCall (_ $ arg, st)) = + let + val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *) + + val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *) + |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner)))) + (* (a, h a) : G *) + val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih + val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *) + + val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner + val h_a_eq_f_a = eq RS eq_reflection + val result = transitive h_a'_eq_h_a h_a_eq_f_a + in + (result, x') + end + | rewrite_help fix h_as x (Cong (crule, deps, branches)) = + let + fun sub_step lu i x = + let + val ((fixes, assumes), st) = nth branches i + val used = map lu (IntGraph.imm_succs deps i) + |> map (fn u_eq => (u_eq RS sym) RS eq_reflection) + |> filter_out Thm.is_reflexive + + val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes + + val (subeq, x') = rewrite_help (fix @ fixes) (h_as @ assumes') x st + val subeq_exp = export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq) + in + (subeq_exp, x') + end + + val (subthms, x') = fold_deps deps sub_step x + in + (fold_rev (curry op COMP) subthms crule, x') + end + in + rewrite_help [] [] x tr + end + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/decompose.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/decompose.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,105 @@ +(* Title: HOL/Tools/Function/decompose.ML + Author: Alexander Krauss, TU Muenchen + +Graph decomposition using "Shallow Dependency Pairs". +*) + +signature DECOMPOSE = +sig + + val derive_chains : Proof.context -> tactic + -> (Termination.data -> int -> tactic) + -> Termination.data -> int -> tactic + + val decompose_tac : Proof.context -> tactic + -> Termination.ttac + +end + +structure Decompose : DECOMPOSE = +struct + +structure TermGraph = GraphFun(type key = term val ord = TermOrd.fast_term_ord); + + +fun derive_chains ctxt chain_tac cont D = Termination.CALLS (fn (cs, i) => + let + val thy = ProofContext.theory_of ctxt + + fun prove_chain c1 c2 D = + if is_some (Termination.get_chain D c1 c2) then D else + let + val goal = HOLogic.mk_eq (HOLogic.mk_binop @{const_name "Relation.rel_comp"} (c1, c2), + Const (@{const_name Set.empty}, fastype_of c1)) + |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *) + + val chain = case FundefLib.try_proof (cterm_of thy goal) chain_tac of + FundefLib.Solved thm => SOME thm + | _ => NONE + in + Termination.note_chain c1 c2 chain D + end + in + cont (fold_product prove_chain cs cs D) i + end) + + +fun mk_dgraph D cs = + TermGraph.empty + |> fold (fn c => TermGraph.new_node (c,())) cs + |> fold_product (fn c1 => fn c2 => + if is_none (Termination.get_chain D c1 c2 |> the_default NONE) + then TermGraph.add_edge (c1, c2) else I) + cs cs + + +fun ucomp_empty_tac T = + REPEAT_ALL_NEW (rtac @{thm union_comp_emptyR} + ORELSE' rtac @{thm union_comp_emptyL} + ORELSE' SUBGOAL (fn (_ $ (_ $ (_ $ c1 $ c2) $ _), i) => rtac (T c1 c2) i)) + +fun regroup_calls_tac cs = Termination.CALLS (fn (cs', i) => + let + val is = map (fn c => find_index (curry op aconv c) cs') cs + in + CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv is))) i + end) + + +fun solve_trivial_tac D = Termination.CALLS +(fn ([c], i) => + (case Termination.get_chain D c c of + SOME (SOME thm) => rtac @{thm wf_no_loop} i + THEN rtac thm i + | _ => no_tac) + | _ => no_tac) + +fun decompose_tac' ctxt cont err_cont D = Termination.CALLS (fn (cs, i) => + let + val G = mk_dgraph D cs + val sccs = TermGraph.strong_conn G + + fun split [SCC] i = (solve_trivial_tac D i ORELSE cont D i) + | split (SCC::rest) i = + regroup_calls_tac SCC i + THEN rtac @{thm wf_union_compatible} i + THEN rtac @{thm less_by_empty} (i + 2) + THEN ucomp_empty_tac (the o the oo Termination.get_chain D) (i + 2) + THEN split rest (i + 1) + THEN (solve_trivial_tac D i ORELSE cont D i) + in + if length sccs > 1 then split sccs i + else solve_trivial_tac D i ORELSE err_cont D i + end) + +fun decompose_tac ctxt chain_tac cont err_cont = + derive_chains ctxt chain_tac + (decompose_tac' ctxt cont err_cont) + +fun auto_decompose_tac ctxt = + Termination.TERMINATION ctxt + (decompose_tac ctxt (auto_tac (local_clasimpset_of ctxt)) + (K (K all_tac)) (K (K no_tac))) + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/descent.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/descent.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,44 @@ +(* Title: HOL/Tools/Function/descent.ML + Author: Alexander Krauss, TU Muenchen + +Descent proofs for termination +*) + + +signature DESCENT = +sig + + val derive_diag : Proof.context -> tactic -> (Termination.data -> int -> tactic) + -> Termination.data -> int -> tactic + + val derive_all : Proof.context -> tactic -> (Termination.data -> int -> tactic) + -> Termination.data -> int -> tactic + +end + + +structure Descent : DESCENT = +struct + +fun gen_descent diag ctxt tac cont D = Termination.CALLS (fn (cs, i) => + let + val thy = ProofContext.theory_of ctxt + val measures_of = Termination.get_measures D + + fun derive c D = + let + val (_, p, _, q, _, _) = Termination.dest_call D c + in + if diag andalso p = q + then fold (fn m => Termination.derive_descent thy tac c m m) (measures_of p) D + else fold_product (Termination.derive_descent thy tac c) + (measures_of p) (measures_of q) D + end + in + cont (FundefCommon.PROFILE "deriving descents" (fold derive cs) D) i + end) + +val derive_diag = gen_descent true +val derive_all = gen_descent false + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/fundef.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/fundef.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,226 @@ +(* Title: HOL/Tools/Function/fundef.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Isar commands. +*) + +signature FUNDEF = +sig + val add_fundef : (binding * typ option * mixfix) list + -> (Attrib.binding * term) list + -> FundefCommon.fundef_config + -> local_theory + -> Proof.state + val add_fundef_cmd : (binding * string option * mixfix) list + -> (Attrib.binding * string) list + -> FundefCommon.fundef_config + -> local_theory + -> Proof.state + + val termination_proof : term option -> local_theory -> Proof.state + val termination_proof_cmd : string option -> local_theory -> Proof.state + val termination : term option -> local_theory -> Proof.state + val termination_cmd : string option -> local_theory -> Proof.state + + val setup : theory -> theory + val get_congs : Proof.context -> thm list +end + + +structure Fundef : FUNDEF = +struct + +open FundefLib +open FundefCommon + +val simp_attribs = map (Attrib.internal o K) + [Simplifier.simp_add, + Code.add_default_eqn_attribute, + Nitpick_Const_Simp_Thms.add, + Quickcheck_RecFun_Simp_Thms.add] + +val psimp_attribs = map (Attrib.internal o K) + [Simplifier.simp_add, + Nitpick_Const_Psimp_Thms.add] + +fun note_theorem ((name, atts), ths) = + LocalTheory.note Thm.generatedK ((Binding.qualified_name name, atts), ths) + +fun mk_defname fixes = fixes |> map (fst o fst) |> space_implode "_" + +fun add_simps fnames post sort extra_qualify label moreatts simps lthy = + let + val spec = post simps + |> map (apfst (apsnd (fn ats => moreatts @ ats))) + |> map (apfst (apfst extra_qualify)) + + val (saved_spec_simps, lthy) = + fold_map (LocalTheory.note Thm.generatedK) spec lthy + + val saved_simps = flat (map snd saved_spec_simps) + val simps_by_f = sort saved_simps + + fun add_for_f fname simps = + note_theorem ((Long_Name.qualify fname label, []), simps) #> snd + in + (saved_simps, + fold2 add_for_f fnames simps_by_f lthy) + end + +fun gen_add_fundef is_external prep default_constraint fixspec eqns config lthy = + let + val constrn_fxs = map (fn (b, T, mx) => (b, SOME (the_default default_constraint T), mx)) + val ((fixes0, spec0), ctxt') = prep (constrn_fxs fixspec) eqns lthy + val fixes = map (apfst (apfst Binding.name_of)) fixes0; + val spec = map (fn (bnd, prop) => (bnd, [prop])) spec0; + val (eqs, post, sort_cont, cnames) = FundefCommon.get_preproc lthy config ctxt' fixes spec + + val defname = mk_defname fixes + + val ((goalstate, cont), lthy) = + FundefMutual.prepare_fundef_mutual config defname fixes eqs lthy + + fun afterqed [[proof]] lthy = + let + val FundefResult {fs, R, psimps, trsimps, simple_pinducts, termination, + domintros, cases, ...} = + cont (Thm.close_derivation proof) + + val fnames = map (fst o fst) fixes + val qualify = Long_Name.qualify defname + val addsmps = add_simps fnames post sort_cont + + val (((psimps', pinducts'), (_, [termination'])), lthy) = + lthy + |> addsmps (Binding.qualify false "partial") "psimps" + psimp_attribs psimps + ||> fold_option (snd oo addsmps I "simps" simp_attribs) trsimps + ||>> note_theorem ((qualify "pinduct", + [Attrib.internal (K (RuleCases.case_names cnames)), + Attrib.internal (K (RuleCases.consumes 1)), + Attrib.internal (K (Induct.induct_pred ""))]), simple_pinducts) + ||>> note_theorem ((qualify "termination", []), [termination]) + ||> (snd o note_theorem ((qualify "cases", + [Attrib.internal (K (RuleCases.case_names cnames))]), [cases])) + ||> fold_option (snd oo curry note_theorem (qualify "domintros", [])) domintros + + val cdata = FundefCtxData { add_simps=addsmps, case_names=cnames, psimps=psimps', + pinducts=snd pinducts', termination=termination', + fs=fs, R=R, defname=defname } + val _ = + if not is_external then () + else Specification.print_consts lthy (K false) (map fst fixes) + in + lthy + |> LocalTheory.declaration (add_fundef_data o morph_fundef_data cdata) + end + in + lthy + |> is_external ? LocalTheory.set_group (serial_string ()) + |> Proof.theorem_i NONE afterqed [[(Logic.unprotect (concl_of goalstate), [])]] + |> Proof.refine (Method.primitive_text (fn _ => goalstate)) |> Seq.hd + end + +val add_fundef = gen_add_fundef false Specification.check_spec (TypeInfer.anyT HOLogic.typeS) +val add_fundef_cmd = gen_add_fundef true Specification.read_spec "_::type" + +fun gen_termination_proof prep_term raw_term_opt lthy = + let + val term_opt = Option.map (prep_term lthy) raw_term_opt + val data = the (case term_opt of + SOME t => (import_fundef_data t lthy + handle Option.Option => + error ("Not a function: " ^ quote (Syntax.string_of_term lthy t))) + | NONE => (import_last_fundef lthy handle Option.Option => error "Not a function")) + + val FundefCtxData { termination, R, add_simps, case_names, psimps, + pinducts, defname, ...} = data + val domT = domain_type (fastype_of R) + val goal = HOLogic.mk_Trueprop + (HOLogic.mk_all ("x", domT, mk_acc domT R $ Free ("x", domT))) + fun afterqed [[totality]] lthy = + let + val totality = Thm.close_derivation totality + val remove_domain_condition = + full_simplify (HOL_basic_ss addsimps [totality, True_implies_equals]) + val tsimps = map remove_domain_condition psimps + val tinduct = map remove_domain_condition pinducts + val qualify = Long_Name.qualify defname; + in + lthy + |> add_simps I "simps" simp_attribs tsimps |> snd + |> note_theorem + ((qualify "induct", + [Attrib.internal (K (RuleCases.case_names case_names))]), + tinduct) |> snd + end + in + lthy + |> ProofContext.note_thmss "" + [((Binding.empty, [ContextRules.rule_del]), [([allI], [])])] |> snd + |> ProofContext.note_thmss "" + [((Binding.empty, [ContextRules.intro_bang (SOME 1)]), [([allI], [])])] |> snd + |> ProofContext.note_thmss "" + [((Binding.name "termination", [ContextRules.intro_bang (SOME 0)]), + [([Goal.norm_result termination], [])])] |> snd + |> Proof.theorem_i NONE afterqed [[(goal, [])]] + end + +val termination_proof = gen_termination_proof Syntax.check_term; +val termination_proof_cmd = gen_termination_proof Syntax.read_term; + +fun termination term_opt lthy = + lthy + |> LocalTheory.set_group (serial_string ()) + |> termination_proof term_opt; + +fun termination_cmd term_opt lthy = + lthy + |> LocalTheory.set_group (serial_string ()) + |> termination_proof_cmd term_opt; + + +(* Datatype hook to declare datatype congs as "fundef_congs" *) + + +fun add_case_cong n thy = + Context.theory_map (FundefCtxTree.map_fundef_congs (Thm.add_thm + (Datatype.get_datatype thy n |> the + |> #case_cong + |> safe_mk_meta_eq))) + thy + +val setup_case_cong = Datatype.interpretation (K (fold add_case_cong)) + + +(* setup *) + +val setup = + Attrib.setup @{binding fundef_cong} + (Attrib.add_del FundefCtxTree.cong_add FundefCtxTree.cong_del) + "declaration of congruence rule for function definitions" + #> setup_case_cong + #> FundefRelation.setup + #> FundefCommon.TerminationSimps.setup + +val get_congs = FundefCtxTree.get_fundef_congs + + +(* outer syntax *) + +local structure P = OuterParse and K = OuterKeyword in + +val _ = + OuterSyntax.local_theory_to_proof "function" "define general recursive functions" K.thy_goal + (fundef_parser default_config + >> (fn ((config, fixes), statements) => add_fundef_cmd fixes statements config)); + +val _ = + OuterSyntax.local_theory_to_proof "termination" "prove termination of a recursive function" K.thy_goal + (Scan.option P.term >> termination_cmd); + +end; + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/fundef_common.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/fundef_common.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,343 @@ +(* Title: HOL/Tools/Function/fundef_common.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Common definitions and other infrastructure. +*) + +structure FundefCommon = +struct + +local open FundefLib in + +(* Profiling *) +val profile = ref false; + +fun PROFILE msg = if !profile then timeap_msg msg else I + + +val acc_const_name = @{const_name "accp"} +fun mk_acc domT R = + Const (acc_const_name, (domT --> domT --> HOLogic.boolT) --> domT --> HOLogic.boolT) $ R + +val function_name = suffix "C" +val graph_name = suffix "_graph" +val rel_name = suffix "_rel" +val dom_name = suffix "_dom" + +(* Termination rules *) + +structure TerminationRule = GenericDataFun +( + type T = thm list + val empty = [] + val extend = I + fun merge _ = Thm.merge_thms +); + +val get_termination_rules = TerminationRule.get +val store_termination_rule = TerminationRule.map o cons +val apply_termination_rule = resolve_tac o get_termination_rules o Context.Proof + + +(* Function definition result data *) + +datatype fundef_result = + FundefResult of + { + fs: term list, + G: term, + R: term, + + psimps : thm list, + trsimps : thm list option, + + simple_pinducts : thm list, + cases : thm, + termination : thm, + domintros : thm list option + } + + +datatype fundef_context_data = + FundefCtxData of + { + defname : string, + + (* contains no logical entities: invariant under morphisms *) + add_simps : (binding -> binding) -> string -> Attrib.src list -> thm list + -> local_theory -> thm list * local_theory, + case_names : string list, + + fs : term list, + R : term, + + psimps: thm list, + pinducts: thm list, + termination: thm + } + +fun morph_fundef_data (FundefCtxData {add_simps, case_names, fs, R, + psimps, pinducts, termination, defname}) phi = + let + val term = Morphism.term phi val thm = Morphism.thm phi val fact = Morphism.fact phi + val name = Binding.name_of o Morphism.binding phi o Binding.name + in + FundefCtxData { add_simps = add_simps, case_names = case_names, + fs = map term fs, R = term R, psimps = fact psimps, + pinducts = fact pinducts, termination = thm termination, + defname = name defname } + end + +structure FundefData = GenericDataFun +( + type T = (term * fundef_context_data) Item_Net.T; + val empty = Item_Net.init + (op aconv o pairself fst : (term * fundef_context_data) * (term * fundef_context_data) -> bool) + fst; + val copy = I; + val extend = I; + fun merge _ (tab1, tab2) = Item_Net.merge (tab1, tab2) +); + +val get_fundef = FundefData.get o Context.Proof; + + +(* Generally useful?? *) +fun lift_morphism thy f = + let + val term = Drule.term_rule thy f + in + Morphism.thm_morphism f $> Morphism.term_morphism term + $> Morphism.typ_morphism (Logic.type_map term) + end + +fun import_fundef_data t ctxt = + let + val thy = ProofContext.theory_of ctxt + val ct = cterm_of thy t + val inst_morph = lift_morphism thy o Thm.instantiate + + fun match (trm, data) = + SOME (morph_fundef_data data (inst_morph (Thm.match (cterm_of thy trm, ct)))) + handle Pattern.MATCH => NONE + in + get_first match (Item_Net.retrieve (get_fundef ctxt) t) + end + +fun import_last_fundef ctxt = + case Item_Net.content (get_fundef ctxt) of + [] => NONE + | (t, data) :: _ => + let + val ([t'], ctxt') = Variable.import_terms true [t] ctxt + in + import_fundef_data t' ctxt' + end + +val all_fundef_data = Item_Net.content o get_fundef + +fun add_fundef_data (data as FundefCtxData {fs, termination, ...}) = + FundefData.map (fold (fn f => Item_Net.insert (f, data)) fs) + #> store_termination_rule termination + + +(* Simp rules for termination proofs *) + +structure TerminationSimps = NamedThmsFun +( + val name = "termination_simp" + val description = "Simplification rule for termination proofs" +); + + +(* Default Termination Prover *) + +structure TerminationProver = GenericDataFun +( + type T = Proof.context -> Proof.method + val empty = (fn _ => error "Termination prover not configured") + val extend = I + fun merge _ (a,b) = b (* FIXME *) +); + +val set_termination_prover = TerminationProver.put +val get_termination_prover = TerminationProver.get o Context.Proof + + +(* Configuration management *) +datatype fundef_opt + = Sequential + | Default of string + | DomIntros + | Tailrec + +datatype fundef_config + = FundefConfig of + { + sequential: bool, + default: string, + domintros: bool, + tailrec: bool + } + +fun apply_opt Sequential (FundefConfig {sequential, default, domintros,tailrec}) = + FundefConfig {sequential=true, default=default, domintros=domintros, tailrec=tailrec} + | apply_opt (Default d) (FundefConfig {sequential, default, domintros,tailrec}) = + FundefConfig {sequential=sequential, default=d, domintros=domintros, tailrec=tailrec} + | apply_opt DomIntros (FundefConfig {sequential, default, domintros,tailrec}) = + FundefConfig {sequential=sequential, default=default, domintros=true,tailrec=tailrec} + | apply_opt Tailrec (FundefConfig {sequential, default, domintros,tailrec}) = + FundefConfig {sequential=sequential, default=default, domintros=domintros,tailrec=true} + +val default_config = + FundefConfig { sequential=false, default="%x. undefined" (*FIXME dynamic scoping*), + domintros=false, tailrec=false } + + +(* Analyzing function equations *) + +fun split_def ctxt geq = + let + fun input_error msg = cat_lines [msg, Syntax.string_of_term ctxt geq] + val qs = Term.strip_qnt_vars "all" geq + val imp = Term.strip_qnt_body "all" geq + val (gs, eq) = Logic.strip_horn imp + + val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq) + handle TERM _ => error (input_error "Not an equation") + + val (head, args) = strip_comb f_args + + val fname = fst (dest_Free head) + handle TERM _ => error (input_error "Head symbol must not be a bound variable") + in + (fname, qs, gs, args, rhs) + end + +(* Check for all sorts of errors in the input *) +fun check_defs ctxt fixes eqs = + let + val fnames = map (fst o fst) fixes + + fun check geq = + let + fun input_error msg = error (cat_lines [msg, Syntax.string_of_term ctxt geq]) + + val fqgar as (fname, qs, gs, args, rhs) = split_def ctxt geq + + val _ = fname mem fnames + orelse input_error + ("Head symbol of left hand side must be " + ^ plural "" "one out of " fnames ^ commas_quote fnames) + + val _ = length args > 0 orelse input_error "Function has no arguments:" + + fun add_bvs t is = add_loose_bnos (t, 0, is) + val rvs = (add_bvs rhs [] \\ fold add_bvs args []) + |> map (fst o nth (rev qs)) + + val _ = null rvs orelse input_error + ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs + ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:") + + val _ = forall (not o Term.exists_subterm + (fn Free (n, _) => n mem fnames | _ => false)) (gs @ args) + orelse input_error "Defined function may not occur in premises or arguments" + + val freeargs = map (fn t => subst_bounds (rev (map Free qs), t)) args + val funvars = filter (fn q => exists (exists_subterm (fn (Free q') $ _ => q = q' | _ => false)) freeargs) qs + val _ = null funvars + orelse (warning (cat_lines + ["Bound variable" ^ plural " " "s " funvars + ^ commas_quote (map fst funvars) ^ + " occur" ^ plural "s" "" funvars ^ " in function position.", + "Misspelled constructor???"]); true) + in + (fname, length args) + end + + val _ = AList.group (op =) (map check eqs) + |> map (fn (fname, ars) => + length (distinct (op =) ars) = 1 + orelse error ("Function " ^ quote fname ^ + " has different numbers of arguments in different equations")) + + fun check_sorts ((fname, fT), _) = + Sorts.of_sort (Sign.classes_of (ProofContext.theory_of ctxt)) (fT, HOLogic.typeS) + orelse error (cat_lines + ["Type of " ^ quote fname ^ " is not of sort " ^ quote "type" ^ ":", + setmp show_sorts true (Syntax.string_of_typ ctxt) fT]) + + val _ = map check_sorts fixes + in + () + end + +(* Preprocessors *) + +type fixes = ((string * typ) * mixfix) list +type 'a spec = (Attrib.binding * 'a list) list +type preproc = fundef_config -> Proof.context -> fixes -> term spec + -> (term list * (thm list -> thm spec) * (thm list -> thm list list) * string list) + +val fname_of = fst o dest_Free o fst o strip_comb o fst + o HOLogic.dest_eq o HOLogic.dest_Trueprop o Logic.strip_imp_concl o snd o dest_all_all + +fun mk_case_names i "" k = mk_case_names i (string_of_int (i + 1)) k + | mk_case_names _ n 0 = [] + | mk_case_names _ n 1 = [n] + | mk_case_names _ n k = map (fn i => n ^ "_" ^ string_of_int i) (1 upto k) + +fun empty_preproc check _ ctxt fixes spec = + let + val (bnds, tss) = split_list spec + val ts = flat tss + val _ = check ctxt fixes ts + val fnames = map (fst o fst) fixes + val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) ts + + fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) + (indices ~~ xs) + |> map (map snd) + + (* using theorem names for case name currently disabled *) + val cnames = map_index (fn (i, _) => mk_case_names i "" 1) bnds |> flat + in + (ts, curry op ~~ bnds o Library.unflat tss, sort, cnames) + end + +structure Preprocessor = GenericDataFun +( + type T = preproc + val empty : T = empty_preproc check_defs + val extend = I + fun merge _ (a, _) = a +); + +val get_preproc = Preprocessor.get o Context.Proof +val set_preproc = Preprocessor.map o K + + + +local + structure P = OuterParse and K = OuterKeyword + + val option_parser = + P.group "option" ((P.reserved "sequential" >> K Sequential) + || ((P.reserved "default" |-- P.term) >> Default) + || (P.reserved "domintros" >> K DomIntros) + || (P.reserved "tailrec" >> K Tailrec)) + + fun config_parser default = + (Scan.optional (P.$$$ "(" |-- P.!!! (P.list1 option_parser) --| P.$$$ ")") []) + >> (fn opts => fold apply_opt opts default) +in + fun fundef_parser default_cfg = + config_parser default_cfg -- P.fixes -- SpecParse.where_alt_specs +end + + +end +end + diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/fundef_core.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/fundef_core.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,954 @@ +(* Title: HOL/Tools/Function/fundef_core.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions: +Main functionality. +*) + +signature FUNDEF_CORE = +sig + val prepare_fundef : FundefCommon.fundef_config + -> string (* defname *) + -> ((bstring * typ) * mixfix) list (* defined symbol *) + -> ((bstring * typ) list * term list * term * term) list (* specification *) + -> local_theory + + -> (term (* f *) + * thm (* goalstate *) + * (thm -> FundefCommon.fundef_result) (* continuation *) + ) * local_theory + +end + +structure FundefCore : FUNDEF_CORE = +struct + +val boolT = HOLogic.boolT +val mk_eq = HOLogic.mk_eq + +open FundefLib +open FundefCommon + +datatype globals = + Globals of { + fvar: term, + domT: typ, + ranT: typ, + h: term, + y: term, + x: term, + z: term, + a: term, + P: term, + D: term, + Pbool:term +} + + +datatype rec_call_info = + RCInfo of + { + RIvs: (string * typ) list, (* Call context: fixes and assumes *) + CCas: thm list, + rcarg: term, (* The recursive argument *) + + llRI: thm, + h_assum: term + } + + +datatype clause_context = + ClauseContext of + { + ctxt : Proof.context, + + qs : term list, + gs : term list, + lhs: term, + rhs: term, + + cqs: cterm list, + ags: thm list, + case_hyp : thm + } + + +fun transfer_clause_ctx thy (ClauseContext { ctxt, qs, gs, lhs, rhs, cqs, ags, case_hyp }) = + ClauseContext { ctxt = ProofContext.transfer thy ctxt, + qs = qs, gs = gs, lhs = lhs, rhs = rhs, cqs = cqs, ags = ags, case_hyp = case_hyp } + + +datatype clause_info = + ClauseInfo of + { + no: int, + qglr : ((string * typ) list * term list * term * term), + cdata : clause_context, + + tree: FundefCtxTree.ctx_tree, + lGI: thm, + RCs: rec_call_info list + } + + +(* Theory dependencies. *) +val Pair_inject = @{thm Product_Type.Pair_inject}; + +val acc_induct_rule = @{thm accp_induct_rule}; + +val ex1_implies_ex = @{thm FunDef.fundef_ex1_existence}; +val ex1_implies_un = @{thm FunDef.fundef_ex1_uniqueness}; +val ex1_implies_iff = @{thm FunDef.fundef_ex1_iff}; + +val acc_downward = @{thm accp_downward}; +val accI = @{thm accp.accI}; +val case_split = @{thm HOL.case_split}; +val fundef_default_value = @{thm FunDef.fundef_default_value}; +val not_acc_down = @{thm not_accp_down}; + + + +fun find_calls tree = + let + fun add_Ri (fixes,assumes) (_ $ arg) _ (_, xs) = ([], (fixes, assumes, arg) :: xs) + | add_Ri _ _ _ _ = raise Match + in + rev (FundefCtxTree.traverse_tree add_Ri tree []) + end + + +(** building proof obligations *) + +fun mk_compat_proof_obligations domT ranT fvar f glrs = + let + fun mk_impl ((qs, gs, lhs, rhs),(qs', gs', lhs', rhs')) = + let + val shift = incr_boundvars (length qs') + in + Logic.mk_implies + (HOLogic.mk_Trueprop (HOLogic.eq_const domT $ shift lhs $ lhs'), + HOLogic.mk_Trueprop (HOLogic.eq_const ranT $ shift rhs $ rhs')) + |> fold_rev (curry Logic.mk_implies) (map shift gs @ gs') + |> fold_rev (fn (n,T) => fn b => Term.all T $ Abs(n,T,b)) (qs @ qs') + |> curry abstract_over fvar + |> curry subst_bound f + end + in + map mk_impl (unordered_pairs glrs) + end + + +fun mk_completeness (Globals {x, Pbool, ...}) clauses qglrs = + let + fun mk_case (ClauseContext {qs, gs, lhs, ...}, (oqs, _, _, _)) = + HOLogic.mk_Trueprop Pbool + |> curry Logic.mk_implies (HOLogic.mk_Trueprop (mk_eq (x, lhs))) + |> fold_rev (curry Logic.mk_implies) gs + |> fold_rev mk_forall_rename (map fst oqs ~~ qs) + in + HOLogic.mk_Trueprop Pbool + |> fold_rev (curry Logic.mk_implies o mk_case) (clauses ~~ qglrs) + |> mk_forall_rename ("x", x) + |> mk_forall_rename ("P", Pbool) + end + +(** making a context with it's own local bindings **) + +fun mk_clause_context x ctxt (pre_qs,pre_gs,pre_lhs,pre_rhs) = + let + val (qs, ctxt') = Variable.variant_fixes (map fst pre_qs) ctxt + |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs + + val thy = ProofContext.theory_of ctxt' + + fun inst t = subst_bounds (rev qs, t) + val gs = map inst pre_gs + val lhs = inst pre_lhs + val rhs = inst pre_rhs + + val cqs = map (cterm_of thy) qs + val ags = map (assume o cterm_of thy) gs + + val case_hyp = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (x, lhs)))) + in + ClauseContext { ctxt = ctxt', qs = qs, gs = gs, lhs = lhs, rhs = rhs, + cqs = cqs, ags = ags, case_hyp = case_hyp } + end + + +(* lowlevel term function *) +fun abstract_over_list vs body = + let + exception SAME; + fun abs lev v tm = + if v aconv tm then Bound lev + else + (case tm of + Abs (a, T, t) => Abs (a, T, abs (lev + 1) v t) + | t $ u => (abs lev v t $ (abs lev v u handle SAME => u) handle SAME => t $ abs lev v u) + | _ => raise SAME); + in + fold_index (fn (i,v) => fn t => abs i v t handle SAME => t) vs body + end + + + +fun mk_clause_info globals G f no cdata qglr tree RCs GIntro_thm RIntro_thms = + let + val Globals {h, fvar, x, ...} = globals + + val ClauseContext { ctxt, qs, cqs, ags, ... } = cdata + val cert = Thm.cterm_of (ProofContext.theory_of ctxt) + + (* Instantiate the GIntro thm with "f" and import into the clause context. *) + val lGI = GIntro_thm + |> forall_elim (cert f) + |> fold forall_elim cqs + |> fold Thm.elim_implies ags + + fun mk_call_info (rcfix, rcassm, rcarg) RI = + let + val llRI = RI + |> fold forall_elim cqs + |> fold (forall_elim o cert o Free) rcfix + |> fold Thm.elim_implies ags + |> fold Thm.elim_implies rcassm + + val h_assum = + HOLogic.mk_Trueprop (G $ rcarg $ (h $ rcarg)) + |> fold_rev (curry Logic.mk_implies o prop_of) rcassm + |> fold_rev (Logic.all o Free) rcfix + |> Pattern.rewrite_term (ProofContext.theory_of ctxt) [(f, h)] [] + |> abstract_over_list (rev qs) + in + RCInfo {RIvs=rcfix, rcarg=rcarg, CCas=rcassm, llRI=llRI, h_assum=h_assum} + end + + val RC_infos = map2 mk_call_info RCs RIntro_thms + in + ClauseInfo + { + no=no, + cdata=cdata, + qglr=qglr, + + lGI=lGI, + RCs=RC_infos, + tree=tree + } + end + + + + + + + +(* replace this by a table later*) +fun store_compat_thms 0 thms = [] + | store_compat_thms n thms = + let + val (thms1, thms2) = chop n thms + in + (thms1 :: store_compat_thms (n - 1) thms2) + end + +(* expects i <= j *) +fun lookup_compat_thm i j cts = + nth (nth cts (i - 1)) (j - i) + +(* Returns "Gsi, Gsj, lhs_i = lhs_j |-- rhs_j_f = rhs_i_f" *) +(* if j < i, then turn around *) +fun get_compat_thm thy cts i j ctxi ctxj = + let + val ClauseContext {cqs=cqsi,ags=agsi,lhs=lhsi,...} = ctxi + val ClauseContext {cqs=cqsj,ags=agsj,lhs=lhsj,...} = ctxj + + val lhsi_eq_lhsj = cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj))) + in if j < i then + let + val compat = lookup_compat_thm j i cts + in + compat (* "!!qj qi. Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *) + |> fold forall_elim (cqsj @ cqsi) (* "Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *) + |> fold Thm.elim_implies agsj + |> fold Thm.elim_implies agsi + |> Thm.elim_implies ((assume lhsi_eq_lhsj) RS sym) (* "Gsj, Gsi, lhsi = lhsj |-- rhsj = rhsi" *) + end + else + let + val compat = lookup_compat_thm i j cts + in + compat (* "!!qi qj. Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *) + |> fold forall_elim (cqsi @ cqsj) (* "Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *) + |> fold Thm.elim_implies agsi + |> fold Thm.elim_implies agsj + |> Thm.elim_implies (assume lhsi_eq_lhsj) + |> (fn thm => thm RS sym) (* "Gsi, Gsj, lhsi = lhsj |-- rhsj = rhsi" *) + end + end + + + + +(* Generates the replacement lemma in fully quantified form. *) +fun mk_replacement_lemma thy h ih_elim clause = + let + val ClauseInfo {cdata=ClauseContext {qs, lhs, rhs, cqs, ags, case_hyp, ...}, RCs, tree, ...} = clause + local open Conv in + val ih_conv = arg1_conv o arg_conv o arg_conv + end + + val ih_elim_case = Conv.fconv_rule (ih_conv (K (case_hyp RS eq_reflection))) ih_elim + + val Ris = map (fn RCInfo {llRI, ...} => llRI) RCs + val h_assums = map (fn RCInfo {h_assum, ...} => assume (cterm_of thy (subst_bounds (rev qs, h_assum)))) RCs + + val (eql, _) = FundefCtxTree.rewrite_by_tree thy h ih_elim_case (Ris ~~ h_assums) tree + + val replace_lemma = (eql RS meta_eq_to_obj_eq) + |> implies_intr (cprop_of case_hyp) + |> fold_rev (implies_intr o cprop_of) h_assums + |> fold_rev (implies_intr o cprop_of) ags + |> fold_rev forall_intr cqs + |> Thm.close_derivation + in + replace_lemma + end + + +fun mk_uniqueness_clause thy globals f compat_store clausei clausej RLj = + let + val Globals {h, y, x, fvar, ...} = globals + val ClauseInfo {no=i, cdata=cctxi as ClauseContext {ctxt=ctxti, lhs=lhsi, case_hyp, ...}, ...} = clausei + val ClauseInfo {no=j, qglr=cdescj, RCs=RCsj, ...} = clausej + + val cctxj as ClauseContext {ags = agsj', lhs = lhsj', rhs = rhsj', qs = qsj', cqs = cqsj', ...} + = mk_clause_context x ctxti cdescj + + val rhsj'h = Pattern.rewrite_term thy [(fvar,h)] [] rhsj' + val compat = get_compat_thm thy compat_store i j cctxi cctxj + val Ghsj' = map (fn RCInfo {h_assum, ...} => assume (cterm_of thy (subst_bounds (rev qsj', h_assum)))) RCsj + + val RLj_import = + RLj |> fold forall_elim cqsj' + |> fold Thm.elim_implies agsj' + |> fold Thm.elim_implies Ghsj' + + val y_eq_rhsj'h = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (y, rhsj'h)))) + val lhsi_eq_lhsj' = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj')))) (* lhs_i = lhs_j' |-- lhs_i = lhs_j' *) + in + (trans OF [case_hyp, lhsi_eq_lhsj']) (* lhs_i = lhs_j' |-- x = lhs_j' *) + |> implies_elim RLj_import (* Rj1' ... Rjk', lhs_i = lhs_j' |-- rhs_j'_h = rhs_j'_f *) + |> (fn it => trans OF [it, compat]) (* lhs_i = lhs_j', Gj', Rj1' ... Rjk' |-- rhs_j'_h = rhs_i_f *) + |> (fn it => trans OF [y_eq_rhsj'h, it]) (* lhs_i = lhs_j', Gj', Rj1' ... Rjk', y = rhs_j_h' |-- y = rhs_i_f *) + |> fold_rev (implies_intr o cprop_of) Ghsj' + |> fold_rev (implies_intr o cprop_of) agsj' (* lhs_i = lhs_j' , y = rhs_j_h' |-- Gj', Rj1'...Rjk' ==> y = rhs_i_f *) + |> implies_intr (cprop_of y_eq_rhsj'h) + |> implies_intr (cprop_of lhsi_eq_lhsj') + |> fold_rev forall_intr (cterm_of thy h :: cqsj') + end + + + +fun mk_uniqueness_case ctxt thy globals G f ihyp ih_intro G_cases compat_store clauses rep_lemmas clausei = + let + val Globals {x, y, ranT, fvar, ...} = globals + val ClauseInfo {cdata = ClauseContext {lhs, rhs, qs, cqs, ags, case_hyp, ...}, lGI, RCs, ...} = clausei + val rhsC = Pattern.rewrite_term thy [(fvar, f)] [] rhs + + val ih_intro_case = full_simplify (HOL_basic_ss addsimps [case_hyp]) ih_intro + + fun prep_RC (RCInfo {llRI, RIvs, CCas, ...}) = (llRI RS ih_intro_case) + |> fold_rev (implies_intr o cprop_of) CCas + |> fold_rev (forall_intr o cterm_of thy o Free) RIvs + + val existence = fold (curry op COMP o prep_RC) RCs lGI + + val P = cterm_of thy (mk_eq (y, rhsC)) + val G_lhs_y = assume (cterm_of thy (HOLogic.mk_Trueprop (G $ lhs $ y))) + + val unique_clauses = map2 (mk_uniqueness_clause thy globals f compat_store clausei) clauses rep_lemmas + + val uniqueness = G_cases + |> forall_elim (cterm_of thy lhs) + |> forall_elim (cterm_of thy y) + |> forall_elim P + |> Thm.elim_implies G_lhs_y + |> fold Thm.elim_implies unique_clauses + |> implies_intr (cprop_of G_lhs_y) + |> forall_intr (cterm_of thy y) + + val P2 = cterm_of thy (lambda y (G $ lhs $ y)) (* P2 y := (lhs, y): G *) + + val exactly_one = + ex1I |> instantiate' [SOME (ctyp_of thy ranT)] [SOME P2, SOME (cterm_of thy rhsC)] + |> curry (op COMP) existence + |> curry (op COMP) uniqueness + |> simplify (HOL_basic_ss addsimps [case_hyp RS sym]) + |> implies_intr (cprop_of case_hyp) + |> fold_rev (implies_intr o cprop_of) ags + |> fold_rev forall_intr cqs + + val function_value = + existence + |> implies_intr ihyp + |> implies_intr (cprop_of case_hyp) + |> forall_intr (cterm_of thy x) + |> forall_elim (cterm_of thy lhs) + |> curry (op RS) refl + in + (exactly_one, function_value) + end + + + + +fun prove_stuff ctxt globals G f R clauses complete compat compat_store G_elim f_def = + let + val Globals {h, domT, ranT, x, ...} = globals + val thy = ProofContext.theory_of ctxt + + (* Inductive Hypothesis: !!z. (z,x):R ==> EX!y. (z,y):G *) + val ihyp = Term.all domT $ Abs ("z", domT, + Logic.mk_implies (HOLogic.mk_Trueprop (R $ Bound 0 $ x), + HOLogic.mk_Trueprop (Const ("Ex1", (ranT --> boolT) --> boolT) $ + Abs ("y", ranT, G $ Bound 1 $ Bound 0)))) + |> cterm_of thy + + val ihyp_thm = assume ihyp |> Thm.forall_elim_vars 0 + val ih_intro = ihyp_thm RS (f_def RS ex1_implies_ex) + val ih_elim = ihyp_thm RS (f_def RS ex1_implies_un) + |> instantiate' [] [NONE, SOME (cterm_of thy h)] + + val _ = Output.debug (K "Proving Replacement lemmas...") + val repLemmas = map (mk_replacement_lemma thy h ih_elim) clauses + + val _ = Output.debug (K "Proving cases for unique existence...") + val (ex1s, values) = + split_list (map (mk_uniqueness_case ctxt thy globals G f ihyp ih_intro G_elim compat_store clauses repLemmas) clauses) + + val _ = Output.debug (K "Proving: Graph is a function") + val graph_is_function = complete + |> Thm.forall_elim_vars 0 + |> fold (curry op COMP) ex1s + |> implies_intr (ihyp) + |> implies_intr (cterm_of thy (HOLogic.mk_Trueprop (mk_acc domT R $ x))) + |> forall_intr (cterm_of thy x) + |> (fn it => Drule.compose_single (it, 2, acc_induct_rule)) (* "EX! y. (?x,y):G" *) + |> (fn it => fold (forall_intr o cterm_of thy o Var) (Term.add_vars (prop_of it) []) it) + + val goalstate = Conjunction.intr graph_is_function complete + |> Thm.close_derivation + |> Goal.protect + |> fold_rev (implies_intr o cprop_of) compat + |> implies_intr (cprop_of complete) + in + (goalstate, values) + end + + +fun define_graph Gname fvar domT ranT clauses RCss lthy = + let + val GT = domT --> ranT --> boolT + val Gvar = Free (the_single (Variable.variant_frees lthy [] [(Gname, GT)])) + + fun mk_GIntro (ClauseContext {qs, gs, lhs, rhs, ...}) RCs = + let + fun mk_h_assm (rcfix, rcassm, rcarg) = + HOLogic.mk_Trueprop (Gvar $ rcarg $ (fvar $ rcarg)) + |> fold_rev (curry Logic.mk_implies o prop_of) rcassm + |> fold_rev (Logic.all o Free) rcfix + in + HOLogic.mk_Trueprop (Gvar $ lhs $ rhs) + |> fold_rev (curry Logic.mk_implies o mk_h_assm) RCs + |> fold_rev (curry Logic.mk_implies) gs + |> fold_rev Logic.all (fvar :: qs) + end + + val G_intros = map2 mk_GIntro clauses RCss + + val (GIntro_thms, (G, G_elim, G_induct, lthy)) = + FundefInductiveWrap.inductive_def G_intros ((dest_Free Gvar, NoSyn), lthy) + in + ((G, GIntro_thms, G_elim, G_induct), lthy) + end + + + +fun define_function fdefname (fname, mixfix) domT ranT G default lthy = + let + val f_def = + Abs ("x", domT, Const ("FunDef.THE_default", ranT --> (ranT --> boolT) --> ranT) $ (default $ Bound 0) $ + Abs ("y", ranT, G $ Bound 1 $ Bound 0)) + |> Syntax.check_term lthy + + val ((f, (_, f_defthm)), lthy) = + LocalTheory.define Thm.internalK ((Binding.name (function_name fname), mixfix), ((Binding.name fdefname, []), f_def)) lthy + in + ((f, f_defthm), lthy) + end + + +fun define_recursion_relation Rname domT ranT fvar f qglrs clauses RCss lthy = + let + + val RT = domT --> domT --> boolT + val Rvar = Free (the_single (Variable.variant_frees lthy [] [(Rname, RT)])) + + fun mk_RIntro (ClauseContext {qs, gs, lhs, ...}, (oqs, _, _, _)) (rcfix, rcassm, rcarg) = + HOLogic.mk_Trueprop (Rvar $ rcarg $ lhs) + |> fold_rev (curry Logic.mk_implies o prop_of) rcassm + |> fold_rev (curry Logic.mk_implies) gs + |> fold_rev (Logic.all o Free) rcfix + |> fold_rev mk_forall_rename (map fst oqs ~~ qs) + (* "!!qs xs. CS ==> G => (r, lhs) : R" *) + + val R_intross = map2 (map o mk_RIntro) (clauses ~~ qglrs) RCss + + val (RIntro_thmss, (R, R_elim, _, lthy)) = + fold_burrow FundefInductiveWrap.inductive_def R_intross ((dest_Free Rvar, NoSyn), lthy) + in + ((R, RIntro_thmss, R_elim), lthy) + end + + +fun fix_globals domT ranT fvar ctxt = + let + val ([h, y, x, z, a, D, P, Pbool],ctxt') = + Variable.variant_fixes ["h_fd", "y_fd", "x_fd", "z_fd", "a_fd", "D_fd", "P_fd", "Pb_fd"] ctxt + in + (Globals {h = Free (h, domT --> ranT), + y = Free (y, ranT), + x = Free (x, domT), + z = Free (z, domT), + a = Free (a, domT), + D = Free (D, domT --> boolT), + P = Free (P, domT --> boolT), + Pbool = Free (Pbool, boolT), + fvar = fvar, + domT = domT, + ranT = ranT + }, + ctxt') + end + + + +fun inst_RC thy fvar f (rcfix, rcassm, rcarg) = + let + fun inst_term t = subst_bound(f, abstract_over (fvar, t)) + in + (rcfix, map (assume o cterm_of thy o inst_term o prop_of) rcassm, inst_term rcarg) + end + + + +(********************************************************** + * PROVING THE RULES + **********************************************************) + +fun mk_psimps thy globals R clauses valthms f_iff graph_is_function = + let + val Globals {domT, z, ...} = globals + + fun mk_psimp (ClauseInfo {qglr = (oqs, _, _, _), cdata = ClauseContext {cqs, lhs, ags, ...}, ...}) valthm = + let + val lhs_acc = cterm_of thy (HOLogic.mk_Trueprop (mk_acc domT R $ lhs)) (* "acc R lhs" *) + val z_smaller = cterm_of thy (HOLogic.mk_Trueprop (R $ z $ lhs)) (* "R z lhs" *) + in + ((assume z_smaller) RS ((assume lhs_acc) RS acc_downward)) + |> (fn it => it COMP graph_is_function) + |> implies_intr z_smaller + |> forall_intr (cterm_of thy z) + |> (fn it => it COMP valthm) + |> implies_intr lhs_acc + |> asm_simplify (HOL_basic_ss addsimps [f_iff]) + |> fold_rev (implies_intr o cprop_of) ags + |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) + end + in + map2 mk_psimp clauses valthms + end + + +(** Induction rule **) + + +val acc_subset_induct = @{thm Orderings.predicate1I} RS @{thm accp_subset_induct} + + +fun binder_conv cv ctxt = Conv.arg_conv (Conv.abs_conv (K cv) ctxt); + +fun mk_partial_induct_rule thy globals R complete_thm clauses = + let + val Globals {domT, x, z, a, P, D, ...} = globals + val acc_R = mk_acc domT R + + val x_D = assume (cterm_of thy (HOLogic.mk_Trueprop (D $ x))) + val a_D = cterm_of thy (HOLogic.mk_Trueprop (D $ a)) + + val D_subset = cterm_of thy (Logic.all x + (Logic.mk_implies (HOLogic.mk_Trueprop (D $ x), HOLogic.mk_Trueprop (acc_R $ x)))) + + val D_dcl = (* "!!x z. [| x: D; (z,x):R |] ==> z:D" *) + Logic.all x + (Logic.all z (Logic.mk_implies (HOLogic.mk_Trueprop (D $ x), + Logic.mk_implies (HOLogic.mk_Trueprop (R $ z $ x), + HOLogic.mk_Trueprop (D $ z))))) + |> cterm_of thy + + + (* Inductive Hypothesis: !!z. (z,x):R ==> P z *) + val ihyp = Term.all domT $ Abs ("z", domT, + Logic.mk_implies (HOLogic.mk_Trueprop (R $ Bound 0 $ x), + HOLogic.mk_Trueprop (P $ Bound 0))) + |> cterm_of thy + + val aihyp = assume ihyp + + fun prove_case clause = + let + val ClauseInfo {cdata = ClauseContext {ctxt, qs, cqs, ags, gs, lhs, case_hyp, ...}, RCs, + qglr = (oqs, _, _, _), ...} = clause + + val case_hyp_conv = K (case_hyp RS eq_reflection) + local open Conv in + val lhs_D = fconv_rule (arg_conv (arg_conv (case_hyp_conv))) x_D + val sih = fconv_rule (binder_conv (arg1_conv (arg_conv (arg_conv case_hyp_conv))) ctxt) aihyp + end + + fun mk_Prec (RCInfo {llRI, RIvs, CCas, rcarg, ...}) = + sih |> forall_elim (cterm_of thy rcarg) + |> Thm.elim_implies llRI + |> fold_rev (implies_intr o cprop_of) CCas + |> fold_rev (forall_intr o cterm_of thy o Free) RIvs + + val P_recs = map mk_Prec RCs (* [P rec1, P rec2, ... ] *) + + val step = HOLogic.mk_Trueprop (P $ lhs) + |> fold_rev (curry Logic.mk_implies o prop_of) P_recs + |> fold_rev (curry Logic.mk_implies) gs + |> curry Logic.mk_implies (HOLogic.mk_Trueprop (D $ lhs)) + |> fold_rev mk_forall_rename (map fst oqs ~~ qs) + |> cterm_of thy + + val P_lhs = assume step + |> fold forall_elim cqs + |> Thm.elim_implies lhs_D + |> fold Thm.elim_implies ags + |> fold Thm.elim_implies P_recs + + val res = cterm_of thy (HOLogic.mk_Trueprop (P $ x)) + |> Conv.arg_conv (Conv.arg_conv case_hyp_conv) + |> symmetric (* P lhs == P x *) + |> (fn eql => equal_elim eql P_lhs) (* "P x" *) + |> implies_intr (cprop_of case_hyp) + |> fold_rev (implies_intr o cprop_of) ags + |> fold_rev forall_intr cqs + in + (res, step) + end + + val (cases, steps) = split_list (map prove_case clauses) + + val istep = complete_thm + |> Thm.forall_elim_vars 0 + |> fold (curry op COMP) cases (* P x *) + |> implies_intr ihyp + |> implies_intr (cprop_of x_D) + |> forall_intr (cterm_of thy x) + + val subset_induct_rule = + acc_subset_induct + |> (curry op COMP) (assume D_subset) + |> (curry op COMP) (assume D_dcl) + |> (curry op COMP) (assume a_D) + |> (curry op COMP) istep + |> fold_rev implies_intr steps + |> implies_intr a_D + |> implies_intr D_dcl + |> implies_intr D_subset + + val subset_induct_all = fold_rev (forall_intr o cterm_of thy) [P, a, D] subset_induct_rule + + val simple_induct_rule = + subset_induct_rule + |> forall_intr (cterm_of thy D) + |> forall_elim (cterm_of thy acc_R) + |> assume_tac 1 |> Seq.hd + |> (curry op COMP) (acc_downward + |> (instantiate' [SOME (ctyp_of thy domT)] + (map (SOME o cterm_of thy) [R, x, z])) + |> forall_intr (cterm_of thy z) + |> forall_intr (cterm_of thy x)) + |> forall_intr (cterm_of thy a) + |> forall_intr (cterm_of thy P) + in + simple_induct_rule + end + + + +(* FIXME: This should probably use fixed goals, to be more reliable and faster *) +fun mk_domain_intro ctxt (Globals {domT, ...}) R R_cases clause = + let + val thy = ProofContext.theory_of ctxt + val ClauseInfo {cdata = ClauseContext {qs, gs, lhs, rhs, cqs, ...}, + qglr = (oqs, _, _, _), ...} = clause + val goal = HOLogic.mk_Trueprop (mk_acc domT R $ lhs) + |> fold_rev (curry Logic.mk_implies) gs + |> cterm_of thy + in + Goal.init goal + |> (SINGLE (resolve_tac [accI] 1)) |> the + |> (SINGLE (eresolve_tac [Thm.forall_elim_vars 0 R_cases] 1)) |> the + |> (SINGLE (auto_tac (local_clasimpset_of ctxt))) |> the + |> Goal.conclude + |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) + end + + + +(** Termination rule **) + +val wf_induct_rule = @{thm Wellfounded.wfP_induct_rule}; +val wf_in_rel = @{thm FunDef.wf_in_rel}; +val in_rel_def = @{thm FunDef.in_rel_def}; + +fun mk_nest_term_case thy globals R' ihyp clause = + let + val Globals {x, z, ...} = globals + val ClauseInfo {cdata = ClauseContext {qs,cqs,ags,lhs,rhs,case_hyp,...},tree, + qglr=(oqs, _, _, _), ...} = clause + + val ih_case = full_simplify (HOL_basic_ss addsimps [case_hyp]) ihyp + + fun step (fixes, assumes) (_ $ arg) u (sub,(hyps,thms)) = + let + val used = map (fn (ctx,thm) => FundefCtxTree.export_thm thy ctx thm) (u @ sub) + + val hyp = HOLogic.mk_Trueprop (R' $ arg $ lhs) + |> fold_rev (curry Logic.mk_implies o prop_of) used (* additional hyps *) + |> FundefCtxTree.export_term (fixes, assumes) + |> fold_rev (curry Logic.mk_implies o prop_of) ags + |> fold_rev mk_forall_rename (map fst oqs ~~ qs) + |> cterm_of thy + + val thm = assume hyp + |> fold forall_elim cqs + |> fold Thm.elim_implies ags + |> FundefCtxTree.import_thm thy (fixes, assumes) + |> fold Thm.elim_implies used (* "(arg, lhs) : R'" *) + + val z_eq_arg = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (z, arg)))) + + val acc = thm COMP ih_case + val z_acc_local = acc + |> Conv.fconv_rule (Conv.arg_conv (Conv.arg_conv (K (symmetric (z_eq_arg RS eq_reflection))))) + + val ethm = z_acc_local + |> FundefCtxTree.export_thm thy (fixes, + z_eq_arg :: case_hyp :: ags @ assumes) + |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) + + val sub' = sub @ [(([],[]), acc)] + in + (sub', (hyp :: hyps, ethm :: thms)) + end + | step _ _ _ _ = raise Match + in + FundefCtxTree.traverse_tree step tree + end + + +fun mk_nest_term_rule thy globals R R_cases clauses = + let + val Globals { domT, x, z, ... } = globals + val acc_R = mk_acc domT R + + val R' = Free ("R", fastype_of R) + + val Rrel = Free ("R", HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT))) + val inrel_R = Const ("FunDef.in_rel", HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)) --> fastype_of R) $ Rrel + + val wfR' = cterm_of thy (HOLogic.mk_Trueprop (Const (@{const_name "Wellfounded.wfP"}, (domT --> domT --> boolT) --> boolT) $ R')) (* "wf R'" *) + + (* Inductive Hypothesis: !!z. (z,x):R' ==> z : acc R *) + val ihyp = Term.all domT $ Abs ("z", domT, + Logic.mk_implies (HOLogic.mk_Trueprop (R' $ Bound 0 $ x), + HOLogic.mk_Trueprop (acc_R $ Bound 0))) + |> cterm_of thy + + val ihyp_a = assume ihyp |> Thm.forall_elim_vars 0 + + val R_z_x = cterm_of thy (HOLogic.mk_Trueprop (R $ z $ x)) + + val (hyps,cases) = fold (mk_nest_term_case thy globals R' ihyp_a) clauses ([],[]) + in + R_cases + |> forall_elim (cterm_of thy z) + |> forall_elim (cterm_of thy x) + |> forall_elim (cterm_of thy (acc_R $ z)) + |> curry op COMP (assume R_z_x) + |> fold_rev (curry op COMP) cases + |> implies_intr R_z_x + |> forall_intr (cterm_of thy z) + |> (fn it => it COMP accI) + |> implies_intr ihyp + |> forall_intr (cterm_of thy x) + |> (fn it => Drule.compose_single(it,2,wf_induct_rule)) + |> curry op RS (assume wfR') + |> forall_intr_vars + |> (fn it => it COMP allI) + |> fold implies_intr hyps + |> implies_intr wfR' + |> forall_intr (cterm_of thy R') + |> forall_elim (cterm_of thy (inrel_R)) + |> curry op RS wf_in_rel + |> full_simplify (HOL_basic_ss addsimps [in_rel_def]) + |> forall_intr (cterm_of thy Rrel) + end + + + +(* Tail recursion (probably very fragile) + * + * FIXME: + * - Need to do forall_elim_vars on psimps: Unneccesary, if psimps would be taken from the same context. + * - Must we really replace the fvar by f here? + * - Splitting is not configured automatically: Problems with case? + *) +fun mk_trsimps octxt globals f G R f_def R_cases G_induct clauses psimps = + let + val Globals {domT, ranT, fvar, ...} = globals + + val R_cases = Thm.forall_elim_vars 0 R_cases (* FIXME: Should be already in standard form. *) + + val graph_implies_dom = (* "G ?x ?y ==> dom ?x" *) + Goal.prove octxt ["x", "y"] [HOLogic.mk_Trueprop (G $ Free ("x", domT) $ Free ("y", ranT))] + (HOLogic.mk_Trueprop (mk_acc domT R $ Free ("x", domT))) + (fn {prems=[a], ...} => + ((rtac (G_induct OF [a])) + THEN_ALL_NEW (rtac accI) + THEN_ALL_NEW (etac R_cases) + THEN_ALL_NEW (asm_full_simp_tac (local_simpset_of octxt))) 1) + + val default_thm = (forall_intr_vars graph_implies_dom) COMP (f_def COMP fundef_default_value) + + fun mk_trsimp clause psimp = + let + val ClauseInfo {qglr = (oqs, _, _, _), cdata = ClauseContext {ctxt, cqs, qs, gs, lhs, rhs, ...}, ...} = clause + val thy = ProofContext.theory_of ctxt + val rhs_f = Pattern.rewrite_term thy [(fvar, f)] [] rhs + + val trsimp = Logic.list_implies(gs, HOLogic.mk_Trueprop (HOLogic.mk_eq(f $ lhs, rhs_f))) (* "f lhs = rhs" *) + val lhs_acc = (mk_acc domT R $ lhs) (* "acc R lhs" *) + fun simp_default_tac ss = asm_full_simp_tac (ss addsimps [default_thm, Let_def]) + in + Goal.prove ctxt [] [] trsimp + (fn _ => + rtac (instantiate' [] [SOME (cterm_of thy lhs_acc)] case_split) 1 + THEN (rtac (Thm.forall_elim_vars 0 psimp) THEN_ALL_NEW assume_tac) 1 + THEN (simp_default_tac (local_simpset_of ctxt) 1) + THEN (etac not_acc_down 1) + THEN ((etac R_cases) THEN_ALL_NEW (simp_default_tac (local_simpset_of ctxt))) 1) + |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) + end + in + map2 mk_trsimp clauses psimps + end + + +fun prepare_fundef config defname [((fname, fT), mixfix)] abstract_qglrs lthy = + let + val FundefConfig {domintros, tailrec, default=default_str, ...} = config + + val fvar = Free (fname, fT) + val domT = domain_type fT + val ranT = range_type fT + + val default = Syntax.parse_term lthy default_str + |> TypeInfer.constrain fT |> Syntax.check_term lthy + + val (globals, ctxt') = fix_globals domT ranT fvar lthy + + val Globals { x, h, ... } = globals + + val clauses = map (mk_clause_context x ctxt') abstract_qglrs + + val n = length abstract_qglrs + + fun build_tree (ClauseContext { ctxt, rhs, ...}) = + FundefCtxTree.mk_tree (fname, fT) h ctxt rhs + + val trees = map build_tree clauses + val RCss = map find_calls trees + + val ((G, GIntro_thms, G_elim, G_induct), lthy) = + PROFILE "def_graph" (define_graph (graph_name defname) fvar domT ranT clauses RCss) lthy + + val ((f, f_defthm), lthy) = + PROFILE "def_fun" (define_function (defname ^ "_sumC_def") (fname, mixfix) domT ranT G default) lthy + + val RCss = map (map (inst_RC (ProofContext.theory_of lthy) fvar f)) RCss + val trees = map (FundefCtxTree.inst_tree (ProofContext.theory_of lthy) fvar f) trees + + val ((R, RIntro_thmss, R_elim), lthy) = + PROFILE "def_rel" (define_recursion_relation (rel_name defname) domT ranT fvar f abstract_qglrs clauses RCss) lthy + + val (_, lthy) = + LocalTheory.abbrev Syntax.mode_default ((Binding.name (dom_name defname), NoSyn), mk_acc domT R) lthy + + val newthy = ProofContext.theory_of lthy + val clauses = map (transfer_clause_ctx newthy) clauses + + val cert = cterm_of (ProofContext.theory_of lthy) + + val xclauses = PROFILE "xclauses" (map7 (mk_clause_info globals G f) (1 upto n) clauses abstract_qglrs trees RCss GIntro_thms) RIntro_thmss + + val complete = mk_completeness globals clauses abstract_qglrs |> cert |> assume + val compat = mk_compat_proof_obligations domT ranT fvar f abstract_qglrs |> map (cert #> assume) + + val compat_store = store_compat_thms n compat + + val (goalstate, values) = PROFILE "prove_stuff" (prove_stuff lthy globals G f R xclauses complete compat compat_store G_elim) f_defthm + + val mk_trsimps = mk_trsimps lthy globals f G R f_defthm R_elim G_induct xclauses + + fun mk_partial_rules provedgoal = + let + val newthy = theory_of_thm provedgoal (*FIXME*) + + val (graph_is_function, complete_thm) = + provedgoal + |> Conjunction.elim + |> apfst (Thm.forall_elim_vars 0) + + val f_iff = graph_is_function RS (f_defthm RS ex1_implies_iff) + + val psimps = PROFILE "Proving simplification rules" (mk_psimps newthy globals R xclauses values f_iff) graph_is_function + + val simple_pinduct = PROFILE "Proving partial induction rule" + (mk_partial_induct_rule newthy globals R complete_thm) xclauses + + + val total_intro = PROFILE "Proving nested termination rule" (mk_nest_term_rule newthy globals R R_elim) xclauses + + val dom_intros = if domintros + then SOME (PROFILE "Proving domain introduction rules" (map (mk_domain_intro lthy globals R R_elim)) xclauses) + else NONE + val trsimps = if tailrec then SOME (mk_trsimps psimps) else NONE + + in + FundefResult {fs=[f], G=G, R=R, cases=complete_thm, + psimps=psimps, simple_pinducts=[simple_pinduct], + termination=total_intro, trsimps=trsimps, + domintros=dom_intros} + end + in + ((f, goalstate, mk_partial_rules), lthy) + end + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/fundef_datatype.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/fundef_datatype.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,330 @@ +(* Title: HOL/Tools/Function/fundef_datatype.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +A tactic to prove completeness of datatype patterns. +*) + +signature FUNDEF_DATATYPE = +sig + val pat_completeness_tac: Proof.context -> int -> tactic + val pat_completeness: Proof.context -> Proof.method + val prove_completeness : theory -> term list -> term -> term list list -> term list list -> thm + + val setup : theory -> theory + + val add_fun : FundefCommon.fundef_config -> + (binding * typ option * mixfix) list -> (Attrib.binding * term) list -> + bool -> local_theory -> Proof.context + val add_fun_cmd : FundefCommon.fundef_config -> + (binding * string option * mixfix) list -> (Attrib.binding * string) list -> + bool -> local_theory -> Proof.context +end + +structure FundefDatatype : FUNDEF_DATATYPE = +struct + +open FundefLib +open FundefCommon + + +fun check_pats ctxt geq = + let + fun err str = error (cat_lines ["Malformed definition:", + str ^ " not allowed in sequential mode.", + Syntax.string_of_term ctxt geq]) + val thy = ProofContext.theory_of ctxt + + fun check_constr_pattern (Bound _) = () + | check_constr_pattern t = + let + val (hd, args) = strip_comb t + in + (((case Datatype.datatype_of_constr thy (fst (dest_Const hd)) of + SOME _ => () + | NONE => err "Non-constructor pattern") + handle TERM ("dest_Const", _) => err "Non-constructor patterns"); + map check_constr_pattern args; + ()) + end + + val (fname, qs, gs, args, rhs) = split_def ctxt geq + + val _ = if not (null gs) then err "Conditional equations" else () + val _ = map check_constr_pattern args + + (* just count occurrences to check linearity *) + val _ = if fold (fold_aterms (fn Bound _ => curry (op +) 1 | _ => I)) args 0 > length qs + then err "Nonlinear patterns" else () + in + () + end + + +fun mk_argvar i T = Free ("_av" ^ (string_of_int i), T) +fun mk_patvar i T = Free ("_pv" ^ (string_of_int i), T) + +fun inst_free var inst thm = + forall_elim inst (forall_intr var thm) + + +fun inst_case_thm thy x P thm = + let + val [Pv, xv] = Term.add_vars (prop_of thm) [] + in + cterm_instantiate [(cterm_of thy (Var xv), cterm_of thy x), + (cterm_of thy (Var Pv), cterm_of thy P)] thm + end + + +fun invent_vars constr i = + let + val Ts = binder_types (fastype_of constr) + val j = i + length Ts + val is = i upto (j - 1) + val avs = map2 mk_argvar is Ts + val pvs = map2 mk_patvar is Ts + in + (avs, pvs, j) + end + + +fun filter_pats thy cons pvars [] = [] + | filter_pats thy cons pvars (([], thm) :: pts) = raise Match + | filter_pats thy cons pvars ((pat :: pats, thm) :: pts) = + case pat of + Free _ => let val inst = list_comb (cons, pvars) + in (inst :: pats, inst_free (cterm_of thy pat) (cterm_of thy inst) thm) + :: (filter_pats thy cons pvars pts) end + | _ => if fst (strip_comb pat) = cons + then (pat :: pats, thm) :: (filter_pats thy cons pvars pts) + else filter_pats thy cons pvars pts + + +fun inst_constrs_of thy (T as Type (name, _)) = + map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT))) + (the (Datatype.get_datatype_constrs thy name)) + | inst_constrs_of thy _ = raise Match + + +fun transform_pat thy avars c_assum ([] , thm) = raise Match + | transform_pat thy avars c_assum (pat :: pats, thm) = + let + val (_, subps) = strip_comb pat + val eqs = map (cterm_of thy o HOLogic.mk_Trueprop o HOLogic.mk_eq) (avars ~~ subps) + val a_eqs = map assume eqs + val c_eq_pat = simplify (HOL_basic_ss addsimps a_eqs) c_assum + in + (subps @ pats, fold_rev implies_intr eqs + (implies_elim thm c_eq_pat)) + end + + +exception COMPLETENESS + +fun constr_case thy P idx (v :: vs) pats cons = + let + val (avars, pvars, newidx) = invent_vars cons idx + val c_hyp = cterm_of thy (HOLogic.mk_Trueprop (HOLogic.mk_eq (v, list_comb (cons, avars)))) + val c_assum = assume c_hyp + val newpats = map (transform_pat thy avars c_assum) (filter_pats thy cons pvars pats) + in + o_alg thy P newidx (avars @ vs) newpats + |> implies_intr c_hyp + |> fold_rev (forall_intr o cterm_of thy) avars + end + | constr_case _ _ _ _ _ _ = raise Match +and o_alg thy P idx [] (([], Pthm) :: _) = Pthm + | o_alg thy P idx (v :: vs) [] = raise COMPLETENESS + | o_alg thy P idx (v :: vs) pts = + if forall (is_Free o hd o fst) pts (* Var case *) + then o_alg thy P idx vs (map (fn (pv :: pats, thm) => + (pats, refl RS (inst_free (cterm_of thy pv) (cterm_of thy v) thm))) pts) + else (* Cons case *) + let + val T = fastype_of v + val (tname, _) = dest_Type T + val {exhaustion=case_thm, ...} = Datatype.the_datatype thy tname + val constrs = inst_constrs_of thy T + val c_cases = map (constr_case thy P idx (v :: vs) pts) constrs + in + inst_case_thm thy v P case_thm + |> fold (curry op COMP) c_cases + end + | o_alg _ _ _ _ _ = raise Match + + +fun prove_completeness thy xs P qss patss = + let + fun mk_assum qs pats = + HOLogic.mk_Trueprop P + |> fold_rev (curry Logic.mk_implies o HOLogic.mk_Trueprop o HOLogic.mk_eq) (xs ~~ pats) + |> fold_rev Logic.all qs + |> cterm_of thy + + val hyps = map2 mk_assum qss patss + + fun inst_hyps hyp qs = fold (forall_elim o cterm_of thy) qs (assume hyp) + + val assums = map2 inst_hyps hyps qss + in + o_alg thy P 2 xs (patss ~~ assums) + |> fold_rev implies_intr hyps + end + + + +fun pat_completeness_tac ctxt = SUBGOAL (fn (subgoal, i) => + let + val thy = ProofContext.theory_of ctxt + val (vs, subgf) = dest_all_all subgoal + val (cases, _ $ thesis) = Logic.strip_horn subgf + handle Bind => raise COMPLETENESS + + fun pat_of assum = + let + val (qs, imp) = dest_all_all assum + val prems = Logic.strip_imp_prems imp + in + (qs, map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems) + end + + val (qss, x_pats) = split_list (map pat_of cases) + val xs = map fst (hd x_pats) + handle Empty => raise COMPLETENESS + + val patss = map (map snd) x_pats + + val complete_thm = prove_completeness thy xs thesis qss patss + |> fold_rev (forall_intr o cterm_of thy) vs + in + PRIMITIVE (fn st => Drule.compose_single(complete_thm, i, st)) + end + handle COMPLETENESS => no_tac) + + +fun pat_completeness ctxt = SIMPLE_METHOD' (pat_completeness_tac ctxt) + +val by_pat_completeness_auto = + Proof.global_future_terminal_proof + (Method.Basic (pat_completeness, Position.none), + SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none)))) + +fun termination_by method int = + Fundef.termination_proof NONE + #> Proof.global_future_terminal_proof + (Method.Basic (method, Position.none), NONE) int + +fun mk_catchall fixes arity_of = + let + fun mk_eqn ((fname, fT), _) = + let + val n = arity_of fname + val (argTs, rT) = chop n (binder_types fT) + |> apsnd (fn Ts => Ts ---> body_type fT) + + val qs = map Free (Name.invent_list [] "a" n ~~ argTs) + in + HOLogic.mk_eq(list_comb (Free (fname, fT), qs), + Const ("HOL.undefined", rT)) + |> HOLogic.mk_Trueprop + |> fold_rev Logic.all qs + end + in + map mk_eqn fixes + end + +fun add_catchall ctxt fixes spec = + let val fqgars = map (split_def ctxt) spec + val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars + |> AList.lookup (op =) #> the + in + spec @ mk_catchall fixes arity_of + end + +fun warn_if_redundant ctxt origs tss = + let + fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t) + + val (tss', _) = chop (length origs) tss + fun check (t, []) = (Output.warning (msg t); []) + | check (t, s) = s + in + (map check (origs ~~ tss'); tss) + end + + +fun sequential_preproc (config as FundefConfig {sequential, ...}) ctxt fixes spec = + if sequential then + let + val (bnds, eqss) = split_list spec + + val eqs = map the_single eqss + + val feqs = eqs + |> tap (check_defs ctxt fixes) (* Standard checks *) + |> tap (map (check_pats ctxt)) (* More checks for sequential mode *) + + val compleqs = add_catchall ctxt fixes feqs (* Completion *) + + val spliteqs = warn_if_redundant ctxt feqs + (FundefSplit.split_all_equations ctxt compleqs) + + fun restore_spec thms = + bnds ~~ Library.take (length bnds, Library.unflat spliteqs thms) + + val spliteqs' = flat (Library.take (length bnds, spliteqs)) + val fnames = map (fst o fst) fixes + val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs' + + fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs) + |> map (map snd) + + + val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding + + (* using theorem names for case name currently disabled *) + val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) + (bnds' ~~ spliteqs) + |> flat + in + (flat spliteqs, restore_spec, sort, case_names) + end + else + FundefCommon.empty_preproc check_defs config ctxt fixes spec + +val setup = + Method.setup @{binding pat_completeness} (Scan.succeed pat_completeness) + "Completeness prover for datatype patterns" + #> Context.theory_map (FundefCommon.set_preproc sequential_preproc) + + +val fun_config = FundefConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), + domintros=false, tailrec=false } + +fun gen_fun add config fixes statements int lthy = + let val group = serial_string () in + lthy + |> LocalTheory.set_group group + |> add fixes statements config + |> by_pat_completeness_auto int + |> LocalTheory.restore + |> LocalTheory.set_group group + |> termination_by (FundefCommon.get_termination_prover lthy) int + end; + +val add_fun = gen_fun Fundef.add_fundef +val add_fun_cmd = gen_fun Fundef.add_fundef_cmd + + + +local structure P = OuterParse and K = OuterKeyword in + +val _ = + OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl + (fundef_parser fun_config + >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements)); + +end + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/fundef_lib.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/fundef_lib.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,176 @@ +(* Title: HOL/Tools/Function/fundef_lib.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Some fairly general functions that should probably go somewhere else... +*) + +structure FundefLib = struct + +fun map_option f NONE = NONE + | map_option f (SOME x) = SOME (f x); + +fun fold_option f NONE y = y + | fold_option f (SOME x) y = f x y; + +fun fold_map_option f NONE y = (NONE, y) + | fold_map_option f (SOME x) y = apfst SOME (f x y); + +(* Ex: "The variable" ^ plural " is" "s are" vs *) +fun plural sg pl [x] = sg + | plural sg pl _ = pl + +(* lambda-abstracts over an arbitrarily nested tuple + ==> hologic.ML? *) +fun tupled_lambda vars t = + case vars of + (Free v) => lambda (Free v) t + | (Var v) => lambda (Var v) t + | (Const ("Pair", Type ("fun", [Ta, Type ("fun", [Tb, _])]))) $ us $ vs => + (HOLogic.split_const (Ta,Tb, fastype_of t)) $ (tupled_lambda us (tupled_lambda vs t)) + | _ => raise Match + + +fun dest_all (Const ("all", _) $ Abs (a as (_,T,_))) = + let + val (n, body) = Term.dest_abs a + in + (Free (n, T), body) + end + | dest_all _ = raise Match + + +(* Removes all quantifiers from a term, replacing bound variables by frees. *) +fun dest_all_all (t as (Const ("all",_) $ _)) = + let + val (v,b) = dest_all t + val (vs, b') = dest_all_all b + in + (v :: vs, b') + end + | dest_all_all t = ([],t) + + +(* FIXME: similar to Variable.focus *) +fun dest_all_all_ctx ctx (Const ("all", _) $ Abs (a as (n,T,b))) = + let + val [(n', _)] = Variable.variant_frees ctx [] [(n,T)] + val (_, ctx') = ProofContext.add_fixes [(Binding.name n', SOME T, NoSyn)] ctx + + val (n'', body) = Term.dest_abs (n', T, b) + val _ = (n' = n'') orelse error "dest_all_ctx" + (* Note: We assume that n' does not occur in the body. Otherwise it would be fixed. *) + + val (ctx'', vs, bd) = dest_all_all_ctx ctx' body + in + (ctx'', (n', T) :: vs, bd) + end + | dest_all_all_ctx ctx t = + (ctx, [], t) + + +fun map3 _ [] [] [] = [] + | map3 f (x :: xs) (y :: ys) (z :: zs) = f x y z :: map3 f xs ys zs + | map3 _ _ _ _ = raise Library.UnequalLengths; + +fun map4 _ [] [] [] [] = [] + | map4 f (x :: xs) (y :: ys) (z :: zs) (u :: us) = f x y z u :: map4 f xs ys zs us + | map4 _ _ _ _ _ = raise Library.UnequalLengths; + +fun map6 _ [] [] [] [] [] [] = [] + | map6 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (w :: ws) = f x y z u v w :: map6 f xs ys zs us vs ws + | map6 _ _ _ _ _ _ _ = raise Library.UnequalLengths; + +fun map7 _ [] [] [] [] [] [] [] = [] + | map7 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (w :: ws) (b :: bs) = f x y z u v w b :: map7 f xs ys zs us vs ws bs + | map7 _ _ _ _ _ _ _ _ = raise Library.UnequalLengths; + + + +(* forms all "unordered pairs": [1, 2, 3] ==> [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] *) +(* ==> library *) +fun unordered_pairs [] = [] + | unordered_pairs (x::xs) = map (pair x) (x::xs) @ unordered_pairs xs + + +(* Replaces Frees by name. Works with loose Bounds. *) +fun replace_frees assoc = + map_aterms (fn c as Free (n, _) => the_default c (AList.lookup (op =) assoc n) + | t => t) + + +fun rename_bound n (Q $ Abs(_, T, b)) = (Q $ Abs(n, T, b)) + | rename_bound n _ = raise Match + +fun mk_forall_rename (n, v) = + rename_bound n o Logic.all v + +fun forall_intr_rename (n, cv) thm = + let + val allthm = forall_intr cv thm + val (_ $ abs) = prop_of allthm + in + Thm.rename_boundvars abs (Abs (n, dummyT, Term.dummy_pattern dummyT)) allthm + end + + +(* Returns the frees in a term in canonical order, excluding the fixes from the context *) +fun frees_in_term ctxt t = + Term.add_frees t [] + |> filter_out (Variable.is_fixed ctxt o fst) + |> rev + + +datatype proof_attempt = Solved of thm | Stuck of thm | Fail + +fun try_proof cgoal tac = + case SINGLE tac (Goal.init cgoal) of + NONE => Fail + | SOME st => if Thm.no_prems st then Solved (Goal.finish st) else Stuck st + + +fun dest_binop_list cn (t as (Const (n, _) $ a $ b)) = + if cn = n then dest_binop_list cn a @ dest_binop_list cn b else [ t ] + | dest_binop_list _ t = [ t ] + + +(* separate two parts in a +-expression: + "a + b + c + d + e" --> "(a + b + d) + (c + e)" + + Here, + can be any binary operation that is AC. + + cn - The name of the binop-constructor (e.g. @{const_name Un}) + ac - the AC rewrite rules for cn + is - the list of indices of the expressions that should become the first part + (e.g. [0,1,3] in the above example) +*) + +fun regroup_conv neu cn ac is ct = + let + val mk = HOLogic.mk_binop cn + val t = term_of ct + val xs = dest_binop_list cn t + val js = 0 upto (length xs) - 1 \\ is + val ty = fastype_of t + val thy = theory_of_cterm ct + in + Goal.prove_internal [] + (cterm_of thy + (Logic.mk_equals (t, + if is = [] + then mk (Const (neu, ty), foldr1 mk (map (nth xs) js)) + else if js = [] + then mk (foldr1 mk (map (nth xs) is), Const (neu, ty)) + else mk (foldr1 mk (map (nth xs) is), foldr1 mk (map (nth xs) js))))) + (K (rewrite_goals_tac ac + THEN rtac Drule.reflexive_thm 1)) + end + +(* instance for unions *) +fun regroup_union_conv t = regroup_conv @{const_name Set.empty} @{const_name Un} + (map (fn t => t RS eq_reflection) (@{thms "Un_ac"} @ + @{thms "Un_empty_right"} @ + @{thms "Un_empty_left"})) t + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/induction_scheme.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/induction_scheme.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,405 @@ +(* Title: HOL/Tools/Function/induction_scheme.ML + Author: Alexander Krauss, TU Muenchen + +A method to prove induction schemes. +*) + +signature INDUCTION_SCHEME = +sig + val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic) + -> Proof.context -> thm list -> tactic + val induct_scheme_tac : Proof.context -> thm list -> tactic + val setup : theory -> theory +end + + +structure InductionScheme : INDUCTION_SCHEME = +struct + +open FundefLib + + +type rec_call_info = int * (string * typ) list * term list * term list + +datatype scheme_case = + SchemeCase of + { + bidx : int, + qs: (string * typ) list, + oqnames: string list, + gs: term list, + lhs: term list, + rs: rec_call_info list + } + +datatype scheme_branch = + SchemeBranch of + { + P : term, + xs: (string * typ) list, + ws: (string * typ) list, + Cs: term list + } + +datatype ind_scheme = + IndScheme of + { + T: typ, (* sum of products *) + branches: scheme_branch list, + cases: scheme_case list + } + +val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize} +val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify} + +fun meta thm = thm RS eq_reflection + +val sum_prod_conv = MetaSimplifier.rewrite true + (map meta (@{thm split_conv} :: @{thms sum.cases})) + +fun term_conv thy cv t = + cv (cterm_of thy t) + |> prop_of |> Logic.dest_equals |> snd + +fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T)) + +fun dest_hhf ctxt t = + let + val (ctxt', vars, imp) = dest_all_all_ctx ctxt t + in + (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp) + end + + +fun mk_scheme' ctxt cases concl = + let + fun mk_branch concl = + let + val (ctxt', ws, Cs, _ $ Pxs) = dest_hhf ctxt concl + val (P, xs) = strip_comb Pxs + in + SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs } + end + + val (branches, cases') = (* correction *) + case Logic.dest_conjunction_list concl of + [conc] => + let + val _ $ Pxs = Logic.strip_assums_concl conc + val (P, _) = strip_comb Pxs + val (cases', conds) = take_prefix (Term.exists_subterm (curry op aconv P)) cases + val concl' = fold_rev (curry Logic.mk_implies) conds conc + in + ([mk_branch concl'], cases') + end + | concls => (map mk_branch concls, cases) + + fun mk_case premise = + let + val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise + val (P, lhs) = strip_comb Plhs + + fun bidx Q = find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches + + fun mk_rcinfo pr = + let + val (ctxt'', Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr + val (P', rcs) = strip_comb Phyp + in + (bidx P', Gvs, Gas, rcs) + end + + fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches + + val (gs, rcprs) = + take_prefix (not o Term.exists_subterm is_pred) prems + in + SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*), gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs} + end + + fun PT_of (SchemeBranch { xs, ...}) = + foldr1 HOLogic.mk_prodT (map snd xs) + + val ST = BalancedTree.make (uncurry SumTree.mk_sumT) (map PT_of branches) + in + IndScheme {T=ST, cases=map mk_case cases', branches=branches } + end + + + +fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx = + let + val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx + val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases + + val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases [] + val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs)) + val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs + + fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) = + HOLogic.mk_Trueprop Pbool + |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l))) + (xs' ~~ lhs) + |> fold_rev (curry Logic.mk_implies) gs + |> fold_rev mk_forall_rename (oqnames ~~ map Free qs) + in + HOLogic.mk_Trueprop Pbool + |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases + |> fold_rev (curry Logic.mk_implies) Cs' + |> fold_rev (Logic.all o Free) ws + |> fold_rev mk_forall_rename (map fst xs ~~ xs') + |> mk_forall_rename ("P", Pbool) + end + +fun mk_wf ctxt R (IndScheme {T, ...}) = + HOLogic.Trueprop $ (Const (@{const_name "wf"}, mk_relT T --> HOLogic.boolT) $ R) + +fun mk_ineqs R (IndScheme {T, cases, branches}) = + let + fun inject i ts = + SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts) + + val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *) + + fun mk_pres bdx args = + let + val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx + fun replace (x, v) t = betapply (lambda (Free x) t, v) + val Cs' = map (fold replace (xs ~~ args)) Cs + val cse = + HOLogic.mk_Trueprop thesis + |> fold_rev (curry Logic.mk_implies) Cs' + |> fold_rev (Logic.all o Free) ws + in + Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis) + end + + fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) = + let + fun g (bidx', Gvs, Gas, rcarg) = + let val export = + fold_rev (curry Logic.mk_implies) Gas + #> fold_rev (curry Logic.mk_implies) gs + #> fold_rev (Logic.all o Free) Gvs + #> fold_rev mk_forall_rename (oqnames ~~ map Free qs) + in + (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R) + |> HOLogic.mk_Trueprop + |> export, + mk_pres bidx' rcarg + |> export + |> Logic.all thesis) + end + in + map g rs + end + in + map f cases + end + + +fun mk_hol_imp a b = HOLogic.imp $ a $ b + +fun mk_ind_goal thy branches = + let + fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) = + HOLogic.mk_Trueprop (list_comb (P, map Free xs)) + |> fold_rev (curry Logic.mk_implies) Cs + |> fold_rev (Logic.all o Free) ws + |> term_conv thy ind_atomize + |> ObjectLogic.drop_judgment thy + |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs)) + in + SumTree.mk_sumcases HOLogic.boolT (map brnch branches) + end + + +fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss (IndScheme {T, cases=scases, branches}) = + let + val n = length branches + + val scases_idx = map_index I scases + + fun inject i ts = + SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts) + val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches) + + val thy = ProofContext.theory_of ctxt + val cert = cterm_of thy + + val P_comp = mk_ind_goal thy branches + + (* Inductive Hypothesis: !!z. (z,x):R ==> P z *) + val ihyp = Term.all T $ Abs ("z", T, + Logic.mk_implies + (HOLogic.mk_Trueprop ( + Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) + $ (HOLogic.pair_const T T $ Bound 0 $ x) + $ R), + HOLogic.mk_Trueprop (P_comp $ Bound 0))) + |> cert + + val aihyp = assume ihyp + + (* Rule for case splitting along the sum types *) + val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches + val pats = map_index (uncurry inject) xss + val sum_split_rule = FundefDatatype.prove_completeness thy [x] (P_comp $ x) xss (map single pats) + + fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) = + let + val fxs = map Free xs + val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat)))) + + val C_hyps = map (cert #> assume) Cs + + val (relevant_cases, ineqss') = filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx) (scases_idx ~~ ineqss) + |> split_list + + fun prove_case (cidx, SchemeCase {qs, oqnames, gs, lhs, rs, ...}) ineq_press = + let + val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs) + + val cqs = map (cert o Free) qs + val ags = map (assume o cert) gs + + val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps) + val sih = full_simplify replace_x_ss aihyp + + fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) = + let + val cGas = map (assume o cert) Gas + val cGvs = map (cert o Free) Gvs + val import = fold forall_elim (cqs @ cGvs) + #> fold Thm.elim_implies (ags @ cGas) + val ipres = pres + |> forall_elim (cert (list_comb (P_of idx, rcargs))) + |> import + in + sih |> forall_elim (cert (inject idx rcargs)) + |> Thm.elim_implies (import ineq) (* Psum rcargs *) + |> Conv.fconv_rule sum_prod_conv + |> Conv.fconv_rule ind_rulify + |> (fn th => th COMP ipres) (* P rs *) + |> fold_rev (implies_intr o cprop_of) cGas + |> fold_rev forall_intr cGvs + end + + val P_recs = map2 mk_Prec rs ineq_press (* [P rec1, P rec2, ... ] *) + + val step = HOLogic.mk_Trueprop (list_comb (P, lhs)) + |> fold_rev (curry Logic.mk_implies o prop_of) P_recs + |> fold_rev (curry Logic.mk_implies) gs + |> fold_rev (Logic.all o Free) qs + |> cert + + val Plhs_to_Pxs_conv = + foldl1 (uncurry Conv.combination_conv) + (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps) + + val res = assume step + |> fold forall_elim cqs + |> fold Thm.elim_implies ags + |> fold Thm.elim_implies P_recs (* P lhs *) + |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *) + |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps) + |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *) + in + (res, (cidx, step)) + end + + val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss') + + val bstep = complete_thm + |> forall_elim (cert (list_comb (P, fxs))) + |> fold (forall_elim o cert) (fxs @ map Free ws) + |> fold Thm.elim_implies C_hyps (* FIXME: optimization using rotate_prems *) + |> fold Thm.elim_implies cases (* P xs *) + |> fold_rev (implies_intr o cprop_of) C_hyps + |> fold_rev (forall_intr o cert o Free) ws + + val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x)) + |> Goal.init + |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.cases})) + THEN CONVERSION ind_rulify 1) + |> Seq.hd + |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep) + |> Goal.finish + |> implies_intr (cprop_of branch_hyp) + |> fold_rev (forall_intr o cert) fxs + in + (Pxs, steps) + end + + val (branches, steps) = split_list (map_index prove_branch (branches ~~ (complete_thms ~~ pats))) + |> apsnd flat + + val istep = sum_split_rule + |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches + |> implies_intr ihyp + |> forall_intr (cert x) (* "!!x. (!!y P x" *) + + val induct_rule = + @{thm "wf_induct_rule"} + |> (curry op COMP) wf_thm + |> (curry op COMP) istep + + val steps_sorted = map snd (sort (int_ord o pairself fst) steps) + in + (steps_sorted, induct_rule) + end + + +fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL +(SUBGOAL (fn (t, i) => + let + val (ctxt', _, cases, concl) = dest_hhf ctxt t + val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl +(* val _ = Output.tracing (makestring scheme)*) + val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt' + val R = Free (Rn, mk_relT ST) + val x = Free (xn, ST) + val cert = cterm_of (ProofContext.theory_of ctxt) + + val ineqss = mk_ineqs R scheme + |> map (map (pairself (assume o cert))) + val complete = map (mk_completeness ctxt scheme #> cert #> assume) (0 upto (length branches - 1)) + val wf_thm = mk_wf ctxt R scheme |> cert |> assume + + val (descent, pres) = split_list (flat ineqss) + val newgoals = complete @ pres @ wf_thm :: descent + + val (steps, indthm) = mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme + + fun project (i, SchemeBranch {xs, ...}) = + let + val inst = cert (SumTree.mk_inj ST (length branches) (i + 1) (foldr1 HOLogic.mk_prod (map Free xs))) + in + indthm |> Drule.instantiate' [] [SOME inst] + |> simplify SumTree.sumcase_split_ss + |> Conv.fconv_rule ind_rulify +(* |> (fn thm => (Output.tracing (makestring thm); thm))*) + end + + val res = Conjunction.intr_balanced (map_index project branches) + |> fold_rev implies_intr (map cprop_of newgoals @ steps) + |> (fn thm => Thm.generalize ([], [Rn]) (Thm.maxidx_of thm + 1) thm) + + val nbranches = length branches + val npres = length pres + in + Thm.compose_no_flatten false (res, length newgoals) i + THEN term_tac (i + nbranches + npres) + THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches)))) + THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i))) + end)) + + +fun induct_scheme_tac ctxt = + mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt; + +val setup = + Method.setup @{binding induct_scheme} (Scan.succeed (RAW_METHOD o induct_scheme_tac)) + "proves an induction principle" + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/inductive_wrap.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/inductive_wrap.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,67 @@ +(* Title: HOL/Tools/Function/inductive_wrap.ML + Author: Alexander Krauss, TU Muenchen + + +A wrapper around the inductive package, restoring the quantifiers in +the introduction and elimination rules. +*) + +signature FUNDEF_INDUCTIVE_WRAP = +sig + val inductive_def : term list + -> ((bstring * typ) * mixfix) * local_theory + -> thm list * (term * thm * thm * local_theory) +end + +structure FundefInductiveWrap: FUNDEF_INDUCTIVE_WRAP = +struct + +open FundefLib + +fun requantify ctxt lfix orig_def thm = + let + val (qs, t) = dest_all_all orig_def + val thy = theory_of_thm thm + val frees = frees_in_term ctxt t + |> remove (op =) lfix + val vars = Term.add_vars (prop_of thm) [] |> rev + + val varmap = frees ~~ vars + in + fold_rev (fn Free (n, T) => + forall_intr_rename (n, cterm_of thy (Var (the_default (("",0), T) (AList.lookup (op =) varmap (n, T)))))) + qs + thm + end + + + +fun inductive_def defs (((R, T), mixfix), lthy) = + let + val ({intrs = intrs_gen, elims = [elim_gen], preds = [ Rdef ], induct, ...}, lthy) = + Inductive.add_inductive_i + {quiet_mode = false, + verbose = ! Toplevel.debug, + kind = Thm.internalK, + alt_name = Binding.empty, + coind = false, + no_elim = false, + no_ind = false, + skip_mono = true, + fork_mono = false} + [((Binding.name R, T), NoSyn)] (* the relation *) + [] (* no parameters *) + (map (fn t => (Attrib.empty_binding, t)) defs) (* the intros *) + [] (* no special monos *) + lthy + + val intrs = map2 (requantify lthy (R, T)) defs intrs_gen + + val elim = elim_gen + |> forall_intr_vars (* FIXME... *) + + in + (intrs, (Rdef, elim, induct, lthy)) + end + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/lexicographic_order.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/lexicographic_order.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,230 @@ +(* Title: HOL/Tools/Function/lexicographic_order.ML + Author: Lukas Bulwahn, TU Muenchen + +Method for termination proofs with lexicographic orderings. +*) + +signature LEXICOGRAPHIC_ORDER = +sig + val lex_order_tac : Proof.context -> tactic -> tactic + val lexicographic_order_tac : Proof.context -> tactic + val lexicographic_order : Proof.context -> Proof.method + + val setup: theory -> theory +end + +structure LexicographicOrder : LEXICOGRAPHIC_ORDER = +struct + +open FundefLib + +(** General stuff **) + +fun mk_measures domT mfuns = + let + val relT = HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)) + val mlexT = (domT --> HOLogic.natT) --> relT --> relT + fun mk_ms [] = Const (@{const_name Set.empty}, relT) + | mk_ms (f::fs) = + Const (@{const_name "mlex_prod"}, mlexT) $ f $ mk_ms fs + in + mk_ms mfuns + end + +fun del_index n [] = [] + | del_index n (x :: xs) = + if n > 0 then x :: del_index (n - 1) xs else xs + +fun transpose ([]::_) = [] + | transpose xss = map hd xss :: transpose (map tl xss) + +(** Matrix cell datatype **) + +datatype cell = Less of thm| LessEq of (thm * thm) | None of (thm * thm) | False of thm; + +fun is_Less (Less _) = true + | is_Less _ = false + +fun is_LessEq (LessEq _) = true + | is_LessEq _ = false + +fun pr_cell (Less _ ) = " < " + | pr_cell (LessEq _) = " <=" + | pr_cell (None _) = " ? " + | pr_cell (False _) = " F " + + +(** Proof attempts to build the matrix **) + +fun dest_term (t : term) = + let + val (vars, prop) = FundefLib.dest_all_all t + val (prems, concl) = Logic.strip_horn prop + val (lhs, rhs) = concl + |> HOLogic.dest_Trueprop + |> HOLogic.dest_mem |> fst + |> HOLogic.dest_prod + in + (vars, prems, lhs, rhs) + end + +fun mk_goal (vars, prems, lhs, rhs) rel = + let + val concl = HOLogic.mk_binrel rel (lhs, rhs) |> HOLogic.mk_Trueprop + in + fold_rev Logic.all vars (Logic.list_implies (prems, concl)) + end + +fun prove thy solve_tac t = + cterm_of thy t |> Goal.init + |> SINGLE solve_tac |> the + +fun mk_cell (thy : theory) solve_tac (vars, prems, lhs, rhs) mfun = + let + val goals = cterm_of thy o mk_goal (vars, prems, mfun $ lhs, mfun $ rhs) + in + case try_proof (goals @{const_name HOL.less}) solve_tac of + Solved thm => Less thm + | Stuck thm => + (case try_proof (goals @{const_name HOL.less_eq}) solve_tac of + Solved thm2 => LessEq (thm2, thm) + | Stuck thm2 => + if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const] then False thm2 + else None (thm2, thm) + | _ => raise Match) (* FIXME *) + | _ => raise Match + end + + +(** Search algorithms **) + +fun check_col ls = forall (fn c => is_Less c orelse is_LessEq c) ls andalso not (forall (is_LessEq) ls) + +fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (del_index col) + +fun transform_order col order = map (fn x => if x >= col then x + 1 else x) order + +(* simple depth-first search algorithm for the table *) +fun search_table table = + case table of + [] => SOME [] + | _ => + let + val col = find_index (check_col) (transpose table) + in case col of + ~1 => NONE + | _ => + let + val order_opt = (table, col) |-> transform_table |> search_table + in case order_opt of + NONE => NONE + | SOME order =>SOME (col :: transform_order col order) + end + end + +(** Proof Reconstruction **) + +(* prove row :: cell list -> tactic *) +fun prove_row (Less less_thm :: _) = + (rtac @{thm "mlex_less"} 1) + THEN PRIMITIVE (Thm.elim_implies less_thm) + | prove_row (LessEq (lesseq_thm, _) :: tail) = + (rtac @{thm "mlex_leq"} 1) + THEN PRIMITIVE (Thm.elim_implies lesseq_thm) + THEN prove_row tail + | prove_row _ = sys_error "lexicographic_order" + + +(** Error reporting **) + +fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table)) + +fun pr_goals ctxt st = + Display.pretty_goals_aux (Syntax.pp ctxt) Markup.none (true, false) (Thm.nprems_of st) st + |> Pretty.chunks + |> Pretty.string_of + +fun row_index i = chr (i + 97) +fun col_index j = string_of_int (j + 1) + +fun pr_unprovable_cell _ ((i,j), Less _) = "" + | pr_unprovable_cell ctxt ((i,j), LessEq (_, st)) = + "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st + | pr_unprovable_cell ctxt ((i,j), None (st_leq, st_less)) = + "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st_less + ^ "\n(" ^ row_index i ^ ", " ^ col_index j ^ ", <=):\n" ^ pr_goals ctxt st_leq + | pr_unprovable_cell ctxt ((i,j), False st) = + "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st + +fun pr_unprovable_subgoals ctxt table = + table + |> map_index (fn (i,cs) => map_index (fn (j,x) => ((i,j), x)) cs) + |> flat + |> map (pr_unprovable_cell ctxt) + +fun no_order_msg ctxt table tl measure_funs = + let + val prterm = Syntax.string_of_term ctxt + fun pr_fun t i = string_of_int i ^ ") " ^ prterm t + + fun pr_goal t i = + let + val (_, _, lhs, rhs) = dest_term t + in (* also show prems? *) + i ^ ") " ^ prterm rhs ^ " ~> " ^ prterm lhs + end + + val gc = map (fn i => chr (i + 96)) (1 upto length table) + val mc = 1 upto length measure_funs + val tstr = "Result matrix:" :: (" " ^ concat (map (enclose " " " " o string_of_int) mc)) + :: map2 (fn r => fn i => i ^ ": " ^ concat (map pr_cell r)) table gc + val gstr = "Calls:" :: map2 (prefix " " oo pr_goal) tl gc + val mstr = "Measures:" :: map2 (prefix " " oo pr_fun) measure_funs mc + val ustr = "Unfinished subgoals:" :: pr_unprovable_subgoals ctxt table + in + cat_lines (ustr @ gstr @ mstr @ tstr @ ["", "Could not find lexicographic termination order."]) + end + +(** The Main Function **) + +fun lex_order_tac ctxt solve_tac (st: thm) = + let + val thy = ProofContext.theory_of ctxt + val ((trueprop $ (wf $ rel)) :: tl) = prems_of st + + val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel)) + + val measure_funs = MeasureFunctions.get_measure_functions ctxt domT (* 1: generate measures *) + + (* 2: create table *) + val table = map (fn t => map (mk_cell thy solve_tac (dest_term t)) measure_funs) tl + + val order = the (search_table table) (* 3: search table *) + handle Option => error (no_order_msg ctxt table tl measure_funs) + + val clean_table = map (fn x => map (nth x) order) table + + val relation = mk_measures domT (map (nth measure_funs) order) + val _ = writeln ("Found termination order: " ^ quote (Syntax.string_of_term ctxt relation)) + + in (* 4: proof reconstruction *) + st |> (PRIMITIVE (cterm_instantiate [(cterm_of thy rel, cterm_of thy relation)]) + THEN (REPEAT (rtac @{thm "wf_mlex"} 1)) + THEN (rtac @{thm "wf_empty"} 1) + THEN EVERY (map prove_row clean_table)) + end + +fun lexicographic_order_tac ctxt = + TRY (FundefCommon.apply_termination_rule ctxt 1) + THEN lex_order_tac ctxt (auto_tac (local_clasimpset_of ctxt addsimps2 FundefCommon.TerminationSimps.get ctxt)) + +val lexicographic_order = SIMPLE_METHOD o lexicographic_order_tac + +val setup = + Method.setup @{binding lexicographic_order} + (Method.sections clasimp_modifiers >> (K lexicographic_order)) + "termination prover for lexicographic orderings" + #> Context.theory_map (FundefCommon.set_termination_prover lexicographic_order) + +end; + diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/measure_functions.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/measure_functions.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,58 @@ +(* Title: HOL/Tools/Function/measure_functions.ML + Author: Alexander Krauss, TU Muenchen + +Measure functions, generated heuristically +*) + +signature MEASURE_FUNCTIONS = +sig + + val get_measure_functions : Proof.context -> typ -> term list + val setup : theory -> theory + +end + +structure MeasureFunctions : MEASURE_FUNCTIONS = +struct + +(** User-declared size functions **) +structure MeasureHeuristicRules = NamedThmsFun( + val name = "measure_function" + val description = "Rules that guide the heuristic generation of measure functions" +); + +fun mk_is_measures t = Const (@{const_name "is_measure"}, fastype_of t --> HOLogic.boolT) $ t + +fun find_measures ctxt T = + DEPTH_SOLVE (resolve_tac (MeasureHeuristicRules.get ctxt) 1) + (HOLogic.mk_Trueprop (mk_is_measures (Var (("f",0), T --> HOLogic.natT))) + |> cterm_of (ProofContext.theory_of ctxt) |> Goal.init) + |> Seq.map (prop_of #> (fn _ $ (_ $ (_ $ f)) => f)) + |> Seq.list_of + + +(** Generating Measure Functions **) + +fun constant_0 T = Abs ("x", T, HOLogic.zero) +fun constant_1 T = Abs ("x", T, HOLogic.Suc_zero) + +fun mk_funorder_funs (Type ("+", [fT, sT])) = + map (fn m => SumTree.mk_sumcase fT sT HOLogic.natT m (constant_0 sT)) (mk_funorder_funs fT) + @ map (fn m => SumTree.mk_sumcase fT sT HOLogic.natT (constant_0 fT) m) (mk_funorder_funs sT) + | mk_funorder_funs T = [ constant_1 T ] + +fun mk_ext_base_funs ctxt (Type("+", [fT, sT])) = + map_product (SumTree.mk_sumcase fT sT HOLogic.natT) + (mk_ext_base_funs ctxt fT) (mk_ext_base_funs ctxt sT) + | mk_ext_base_funs ctxt T = find_measures ctxt T + +fun mk_all_measure_funs ctxt (T as Type ("+", _)) = + mk_ext_base_funs ctxt T @ mk_funorder_funs T + | mk_all_measure_funs ctxt T = find_measures ctxt T + +val get_measure_functions = mk_all_measure_funs + +val setup = MeasureHeuristicRules.setup + +end + diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/mutual.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/mutual.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,314 @@ +(* Title: HOL/Tools/Function/mutual.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. +Tools for mutual recursive definitions. +*) + +signature FUNDEF_MUTUAL = +sig + + val prepare_fundef_mutual : FundefCommon.fundef_config + -> string (* defname *) + -> ((string * typ) * mixfix) list + -> term list + -> local_theory + -> ((thm (* goalstate *) + * (thm -> FundefCommon.fundef_result) (* proof continuation *) + ) * local_theory) + +end + + +structure FundefMutual: FUNDEF_MUTUAL = +struct + +open FundefLib +open FundefCommon + + + + +type qgar = string * (string * typ) list * term list * term list * term + +fun name_of_fqgar ((f, _, _, _, _): qgar) = f + +datatype mutual_part = + MutualPart of + { + i : int, + i' : int, + fvar : string * typ, + cargTs: typ list, + f_def: term, + + f: term option, + f_defthm : thm option + } + + +datatype mutual_info = + Mutual of + { + n : int, + n' : int, + fsum_var : string * typ, + + ST: typ, + RST: typ, + + parts: mutual_part list, + fqgars: qgar list, + qglrs: ((string * typ) list * term list * term * term) list, + + fsum : term option + } + +fun mutual_induct_Pnames n = + if n < 5 then fst (chop n ["P","Q","R","S"]) + else map (fn i => "P" ^ string_of_int i) (1 upto n) + +fun get_part fname = + the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname) + +(* FIXME *) +fun mk_prod_abs e (t1, t2) = + let + val bTs = rev (map snd e) + val T1 = fastype_of1 (bTs, t1) + val T2 = fastype_of1 (bTs, t2) + in + HOLogic.pair_const T1 T2 $ t1 $ t2 + end; + + +fun analyze_eqs ctxt defname fs eqs = + let + val num = length fs + val fnames = map fst fs + val fqgars = map (split_def ctxt) eqs + val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars + |> AList.lookup (op =) #> the + + fun curried_types (fname, fT) = + let + val (caTs, uaTs) = chop (arity_of fname) (binder_types fT) + in + (caTs, uaTs ---> body_type fT) + end + + val (caTss, resultTs) = split_list (map curried_types fs) + val argTs = map (foldr1 HOLogic.mk_prodT) caTss + + val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs + val n' = length dresultTs + + val RST = BalancedTree.make (uncurry SumTree.mk_sumT) dresultTs + val ST = BalancedTree.make (uncurry SumTree.mk_sumT) argTs + + val fsum_type = ST --> RST + + val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt + val fsum_var = (fsum_var_name, fsum_type) + + fun define (fvar as (n, T)) caTs resultT i = + let + val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *) + val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1 + + val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars)) + val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) + + val rew = (n, fold_rev lambda vars f_exp) + in + (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew) + end + + val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num)) + + fun convert_eqs (f, qs, gs, args, rhs) = + let + val MutualPart {i, i', ...} = get_part f parts + in + (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args), + SumTree.mk_inj RST n' i' (replace_frees rews rhs) + |> Envir.beta_norm) + end + + val qglrs = map convert_eqs fqgars + in + Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, + parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE} + end + + + + +fun define_projections fixes mutual fsum lthy = + let + fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy = + let + val ((f, (_, f_defthm)), lthy') = + LocalTheory.define Thm.internalK ((Binding.name fname, mixfix), + ((Binding.name (fname ^ "_def"), []), Term.subst_bound (fsum, f_def))) + lthy + in + (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def, + f=SOME f, f_defthm=SOME f_defthm }, + lthy') + end + + val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual + val (parts', lthy') = fold_map def (parts ~~ fixes) lthy + in + (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts', + fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum }, + lthy') + end + + +fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F = + let + val thy = ProofContext.theory_of ctxt + + val oqnames = map fst pre_qs + val (qs, ctxt') = Variable.variant_fixes oqnames ctxt + |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs + + fun inst t = subst_bounds (rev qs, t) + val gs = map inst pre_gs + val args = map inst pre_args + val rhs = inst pre_rhs + + val cqs = map (cterm_of thy) qs + val ags = map (assume o cterm_of thy) gs + + val import = fold forall_elim cqs + #> fold Thm.elim_implies ags + + val export = fold_rev (implies_intr o cprop_of) ags + #> fold_rev forall_intr_rename (oqnames ~~ cqs) + in + F ctxt (f, qs, gs, args, rhs) import export + end + +fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq = + let + val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts + + val psimp = import sum_psimp_eq + val (simp, restore_cond) = case cprems_of psimp of + [] => (psimp, I) + | [cond] => (implies_elim psimp (assume cond), implies_intr cond) + | _ => sys_error "Too many conditions" + in + Goal.prove ctxt [] [] + (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)) + (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs) + THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1 + THEN (simp_tac (local_simpset_of ctxt addsimps SumTree.proj_in_rules)) 1) + |> restore_cond + |> export + end + + +(* FIXME HACK *) +fun mk_applied_form ctxt caTs thm = + let + val thy = ProofContext.theory_of ctxt + val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *) + in + fold (fn x => fn thm => combination thm (reflexive x)) xs thm + |> Conv.fconv_rule (Thm.beta_conversion true) + |> fold_rev forall_intr xs + |> Thm.forall_elim_vars 0 + end + + +fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) = + let + val cert = cterm_of (ProofContext.theory_of lthy) + val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} => + Free (Pname, cargTs ---> HOLogic.boolT)) + (mutual_induct_Pnames (length parts)) + parts + + fun mk_P (MutualPart {cargTs, ...}) P = + let + val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs + val atup = foldr1 HOLogic.mk_prod avars + in + tupled_lambda atup (list_comb (P, avars)) + end + + val Ps = map2 mk_P parts newPs + val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps + + val induct_inst = + forall_elim (cert case_exp) induct + |> full_simplify SumTree.sumcase_split_ss + |> full_simplify (HOL_basic_ss addsimps all_f_defs) + + fun project rule (MutualPart {cargTs, i, ...}) k = + let + val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *) + val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs) + in + (rule + |> forall_elim (cert inj) + |> full_simplify SumTree.sumcase_split_ss + |> fold_rev (forall_intr o cert) (afs @ newPs), + k + length cargTs) + end + in + fst (fold_map (project induct_inst) parts 0) + end + + +fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof = + let + val result = inner_cont proof + val FundefResult {fs=[f], G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct], + termination,domintros} = result + + val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} => + (mk_applied_form lthy cargTs (symmetric f_def), f)) + parts + |> split_list + + val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts + + fun mk_mpsimp fqgar sum_psimp = + in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp + + val rew_ss = HOL_basic_ss addsimps all_f_defs + val mpsimps = map2 mk_mpsimp fqgars psimps + val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps + val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m + val mtermination = full_simplify rew_ss termination + val mdomintros = map_option (map (full_simplify rew_ss)) domintros + in + FundefResult { fs=fs, G=G, R=R, + psimps=mpsimps, simple_pinducts=minducts, + cases=cases, termination=mtermination, + domintros=mdomintros, + trsimps=mtrsimps} + end + +fun prepare_fundef_mutual config defname fixes eqss lthy = + let + val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss) + val Mutual {fsum_var=(n, T), qglrs, ...} = mutual + + val ((fsum, goalstate, cont), lthy') = + FundefCore.prepare_fundef config defname [((n, T), NoSyn)] qglrs lthy + + val (mutual', lthy'') = define_projections fixes mutual fsum lthy' + + val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual' + in + ((goalstate, mutual_cont), lthy'') + end + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/pattern_split.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/pattern_split.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,137 @@ +(* Title: HOL/Tools/Function/pattern_split.ML + Author: Alexander Krauss, TU Muenchen + +A package for general recursive function definitions. + +Automatic splitting of overlapping constructor patterns. This is a preprocessing step which +turns a specification with overlaps into an overlap-free specification. + +*) + +signature FUNDEF_SPLIT = +sig + val split_some_equations : + Proof.context -> (bool * term) list -> term list list + + val split_all_equations : + Proof.context -> term list -> term list list +end + +structure FundefSplit : FUNDEF_SPLIT = +struct + +open FundefLib + +(* We use proof context for the variable management *) +(* FIXME: no __ *) + +fun new_var ctx vs T = + let + val [v] = Variable.variant_frees ctx vs [("v", T)] + in + (Free v :: vs, Free v) + end + +fun saturate ctx vs t = + fold (fn T => fn (vs, t) => new_var ctx vs T |> apsnd (curry op $ t)) + (binder_types (fastype_of t)) (vs, t) + + +(* This is copied from "fundef_datatype.ML" *) +fun inst_constrs_of thy (T as Type (name, _)) = + map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT))) + (the (Datatype.get_datatype_constrs thy name)) + | inst_constrs_of thy T = raise TYPE ("inst_constrs_of", [T], []) + + + + +fun join ((vs1,sub1), (vs2,sub2)) = (merge (op aconv) (vs1,vs2), sub1 @ sub2) +fun join_product (xs, ys) = map_product (curry join) xs ys + +fun join_list [] = [] + | join_list xs = foldr1 (join_product) xs + + +exception DISJ + +fun pattern_subtract_subst ctx vs t t' = + let + exception DISJ + fun pattern_subtract_subst_aux vs _ (Free v2) = [] + | pattern_subtract_subst_aux vs (v as (Free (_, T))) t' = + let + fun foo constr = + let + val (vs', t) = saturate ctx vs constr + val substs = pattern_subtract_subst ctx vs' t t' + in + map (fn (vs, subst) => (vs, (v,t)::subst)) substs + end + in + flat (map foo (inst_constrs_of (ProofContext.theory_of ctx) T)) + end + | pattern_subtract_subst_aux vs t t' = + let + val (C, ps) = strip_comb t + val (C', qs) = strip_comb t' + in + if C = C' + then flat (map2 (pattern_subtract_subst_aux vs) ps qs) + else raise DISJ + end + in + pattern_subtract_subst_aux vs t t' + handle DISJ => [(vs, [])] + end + + +(* p - q *) +fun pattern_subtract ctx eq2 eq1 = + let + val thy = ProofContext.theory_of ctx + + val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1 + val (_, _ $ (_ $ lhs2 $ _)) = dest_all_all eq2 + + val substs = pattern_subtract_subst ctx vs lhs1 lhs2 + + fun instantiate (vs', sigma) = + let + val t = Pattern.rewrite_term thy sigma [] feq1 + in + fold_rev Logic.all (map Free (frees_in_term ctx t) inter vs') t + end + in + map instantiate substs + end + + +(* ps - p' *) +fun pattern_subtract_from_many ctx p'= + flat o map (pattern_subtract ctx p') + +(* in reverse order *) +fun pattern_subtract_many ctx ps' = + fold_rev (pattern_subtract_from_many ctx) ps' + + + +fun split_some_equations ctx eqns = + let + fun split_aux prev [] = [] + | split_aux prev ((true, eq) :: es) = pattern_subtract_many ctx prev [eq] + :: split_aux (eq :: prev) es + | split_aux prev ((false, eq) :: es) = [eq] + :: split_aux (eq :: prev) es + in + split_aux [] eqns + end + +fun split_all_equations ctx = + split_some_equations ctx o map (pair true) + + + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/scnp_reconstruct.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/scnp_reconstruct.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,429 @@ +(* Title: HOL/Tools/Function/scnp_reconstruct.ML + Author: Armin Heller, TU Muenchen + Author: Alexander Krauss, TU Muenchen + +Proof reconstruction for SCNP +*) + +signature SCNP_RECONSTRUCT = +sig + + val sizechange_tac : Proof.context -> tactic -> tactic + + val decomp_scnp : ScnpSolve.label list -> Proof.context -> Proof.method + + val setup : theory -> theory + + datatype multiset_setup = + Multiset of + { + msetT : typ -> typ, + mk_mset : typ -> term list -> term, + mset_regroup_conv : int list -> conv, + mset_member_tac : int -> int -> tactic, + mset_nonempty_tac : int -> tactic, + mset_pwleq_tac : int -> tactic, + set_of_simps : thm list, + smsI' : thm, + wmsI2'' : thm, + wmsI1 : thm, + reduction_pair : thm + } + + + val multiset_setup : multiset_setup -> theory -> theory + +end + +structure ScnpReconstruct : SCNP_RECONSTRUCT = +struct + +val PROFILE = FundefCommon.PROFILE +fun TRACE x = if ! FundefCommon.profile then Output.tracing x else () + +open ScnpSolve + +val natT = HOLogic.natT +val nat_pairT = HOLogic.mk_prodT (natT, natT) + +(* Theory dependencies *) + +datatype multiset_setup = + Multiset of + { + msetT : typ -> typ, + mk_mset : typ -> term list -> term, + mset_regroup_conv : int list -> conv, + mset_member_tac : int -> int -> tactic, + mset_nonempty_tac : int -> tactic, + mset_pwleq_tac : int -> tactic, + set_of_simps : thm list, + smsI' : thm, + wmsI2'' : thm, + wmsI1 : thm, + reduction_pair : thm + } + +structure MultisetSetup = TheoryDataFun +( + type T = multiset_setup option + val empty = NONE + val copy = I; + val extend = I; + fun merge _ (v1, v2) = if is_some v2 then v2 else v1 +) + +val multiset_setup = MultisetSetup.put o SOME + +fun undef x = error "undef" +fun get_multiset_setup thy = MultisetSetup.get thy + |> the_default (Multiset +{ msetT = undef, mk_mset=undef, + mset_regroup_conv=undef, mset_member_tac = undef, + mset_nonempty_tac = undef, mset_pwleq_tac = undef, + set_of_simps = [],reduction_pair = refl, + smsI'=refl, wmsI2''=refl, wmsI1=refl }) + +fun order_rpair _ MAX = @{thm max_rpair_set} + | order_rpair msrp MS = msrp + | order_rpair _ MIN = @{thm min_rpair_set} + +fun ord_intros_max true = + (@{thm smax_emptyI}, @{thm smax_insertI}) + | ord_intros_max false = + (@{thm wmax_emptyI}, @{thm wmax_insertI}) +fun ord_intros_min true = + (@{thm smin_emptyI}, @{thm smin_insertI}) + | ord_intros_min false = + (@{thm wmin_emptyI}, @{thm wmin_insertI}) + +fun gen_probl D cs = + let + val n = Termination.get_num_points D + val arity = length o Termination.get_measures D + fun measure p i = nth (Termination.get_measures D p) i + + fun mk_graph c = + let + val (_, p, _, q, _, _) = Termination.dest_call D c + + fun add_edge i j = + case Termination.get_descent D c (measure p i) (measure q j) + of SOME (Termination.Less _) => cons (i, GTR, j) + | SOME (Termination.LessEq _) => cons (i, GEQ, j) + | _ => I + + val edges = + fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) [] + in + G (p, q, edges) + end + in + GP (map arity (0 upto n - 1), map mk_graph cs) + end + +(* General reduction pair application *) +fun rem_inv_img ctxt = + let + val unfold_tac = LocalDefs.unfold_tac ctxt + in + rtac @{thm subsetI} 1 + THEN etac @{thm CollectE} 1 + THEN REPEAT (etac @{thm exE} 1) + THEN unfold_tac @{thms inv_image_def} + THEN rtac @{thm CollectI} 1 + THEN etac @{thm conjE} 1 + THEN etac @{thm ssubst} 1 + THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality} + @ @{thms sum.cases}) + end + +(* Sets *) + +val setT = HOLogic.mk_setT + +fun set_member_tac m i = + if m = 0 then rtac @{thm insertI1} i + else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i + +val set_nonempty_tac = rtac @{thm insert_not_empty} + +fun set_finite_tac i = + rtac @{thm finite.emptyI} i + ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st)) + + +(* Reconstruction *) + +fun reconstruct_tac ctxt D cs (gp as GP (_, gs)) certificate = + let + val thy = ProofContext.theory_of ctxt + val Multiset + { msetT, mk_mset, + mset_regroup_conv, mset_member_tac, + mset_nonempty_tac, mset_pwleq_tac, set_of_simps, + smsI', wmsI2'', wmsI1, reduction_pair=ms_rp } + = get_multiset_setup thy + + fun measure_fn p = nth (Termination.get_measures D p) + + fun get_desc_thm cidx m1 m2 bStrict = + case Termination.get_descent D (nth cs cidx) m1 m2 + of SOME (Termination.Less thm) => + if bStrict then thm + else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le})) + | SOME (Termination.LessEq (thm, _)) => + if not bStrict then thm + else sys_error "get_desc_thm" + | _ => sys_error "get_desc_thm" + + val (label, lev, sl, covering) = certificate + + fun prove_lev strict g = + let + val G (p, q, el) = nth gs g + + fun less_proof strict (j, b) (i, a) = + let + val tag_flag = b < a orelse (not strict andalso b <= a) + + val stored_thm = + get_desc_thm g (measure_fn p i) (measure_fn q j) + (not tag_flag) + |> Conv.fconv_rule (Thm.beta_conversion true) + + val rule = if strict + then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1} + else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1} + in + rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm) + THEN (if tag_flag then Arith_Data.verbose_arith_tac ctxt 1 else all_tac) + end + + fun steps_tac MAX strict lq lp = + let + val (empty, step) = ord_intros_max strict + in + if length lq = 0 + then rtac empty 1 THEN set_finite_tac 1 + THEN (if strict then set_nonempty_tac 1 else all_tac) + else + let + val (j, b) :: rest = lq + val (i, a) = the (covering g strict j) + fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1 + val solve_tac = choose lp THEN less_proof strict (j, b) (i, a) + in + rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp + end + end + | steps_tac MIN strict lq lp = + let + val (empty, step) = ord_intros_min strict + in + if length lp = 0 + then rtac empty 1 + THEN (if strict then set_nonempty_tac 1 else all_tac) + else + let + val (i, a) :: rest = lp + val (j, b) = the (covering g strict i) + fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1 + val solve_tac = choose lq THEN less_proof strict (j, b) (i, a) + in + rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest + end + end + | steps_tac MS strict lq lp = + let + fun get_str_cover (j, b) = + if is_some (covering g true j) then SOME (j, b) else NONE + fun get_wk_cover (j, b) = the (covering g false j) + + val qs = lq \\ map_filter get_str_cover lq + val ps = map get_wk_cover qs + + fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys + val iqs = indices lq qs + val ips = indices lp ps + + local open Conv in + fun t_conv a C = + params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt + val goal_rewrite = + t_conv arg1_conv (mset_regroup_conv iqs) + then_conv t_conv arg_conv (mset_regroup_conv ips) + end + in + CONVERSION goal_rewrite 1 + THEN (if strict then rtac smsI' 1 + else if qs = lq then rtac wmsI2'' 1 + else rtac wmsI1 1) + THEN mset_pwleq_tac 1 + THEN EVERY (map2 (less_proof false) qs ps) + THEN (if strict orelse qs <> lq + then LocalDefs.unfold_tac ctxt set_of_simps + THEN steps_tac MAX true (lq \\ qs) (lp \\ ps) + else all_tac) + end + in + rem_inv_img ctxt + THEN steps_tac label strict (nth lev q) (nth lev p) + end + + val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT) + + fun tag_pair p (i, tag) = + HOLogic.pair_const natT natT $ + (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag + + fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p, + mk_set nat_pairT (map (tag_pair p) lm)) + + val level_mapping = + map_index pt_lev lev + |> Termination.mk_sumcases D (setT nat_pairT) + |> cterm_of thy + in + PROFILE "Proof Reconstruction" + (CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv sl))) 1 + THEN (rtac @{thm reduction_pair_lemma} 1) + THEN (rtac @{thm rp_inv_image_rp} 1) + THEN (rtac (order_rpair ms_rp label) 1) + THEN PRIMITIVE (instantiate' [] [SOME level_mapping]) + THEN unfold_tac @{thms rp_inv_image_def} (local_simpset_of ctxt) + THEN LocalDefs.unfold_tac ctxt + (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv}) + THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}])) + THEN EVERY (map (prove_lev true) sl) + THEN EVERY (map (prove_lev false) ((0 upto length cs - 1) \\ sl))) + end + + + +local open Termination in +fun print_cell (SOME (Less _)) = "<" + | print_cell (SOME (LessEq _)) = "\" + | print_cell (SOME (None _)) = "-" + | print_cell (SOME (False _)) = "-" + | print_cell (NONE) = "?" + +fun print_error ctxt D = CALLS (fn (cs, i) => + let + val np = get_num_points D + val ms = map (get_measures D) (0 upto np - 1) + val tys = map (get_types D) (0 upto np - 1) + fun index xs = (1 upto length xs) ~~ xs + fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs + val ims = index (map index ms) + val _ = Output.tracing (concat (outp "fn #" ":\n" (concat o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims)) + fun print_call (k, c) = + let + val (_, p, _, q, _, _) = dest_call D c + val _ = Output.tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ + Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1)) + val caller_ms = nth ms p + val callee_ms = nth ms q + val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms) + fun print_ln (i : int, l) = concat (Int.toString i :: " " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l) + val _ = Output.tracing (concat (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ + " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" + ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries))) + in + true + end + fun list_call (k, c) = + let + val (_, p, _, q, _, _) = dest_call D c + val _ = Output.tracing ("call #" ^ (Int.toString k) ^ ": fn " ^ + Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ + (Syntax.string_of_term ctxt c)) + in true end + val _ = forall list_call ((1 upto length cs) ~~ cs) + val _ = forall print_call ((1 upto length cs) ~~ cs) + in + all_tac + end) +end + + +fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) => + let + val ms_configured = is_some (MultisetSetup.get (ProofContext.theory_of ctxt)) + val orders' = if ms_configured then orders + else filter_out (curry op = MS) orders + val gp = gen_probl D cs +(* val _ = TRACE ("SCNP instance: " ^ makestring gp)*) + val certificate = generate_certificate use_tags orders' gp +(* val _ = TRACE ("Certificate: " ^ makestring certificate)*) + + in + case certificate + of NONE => err_cont D i + | SOME cert => + SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i + THEN (rtac @{thm wf_empty} i ORELSE cont D i) + end) + +fun gen_decomp_scnp_tac orders autom_tac ctxt err_cont = + let + open Termination + val derive_diag = Descent.derive_diag ctxt autom_tac + val derive_all = Descent.derive_all ctxt autom_tac + val decompose = Decompose.decompose_tac ctxt autom_tac + val scnp_no_tags = single_scnp_tac false orders ctxt + val scnp_full = single_scnp_tac true orders ctxt + + fun first_round c e = + derive_diag (REPEAT scnp_no_tags c e) + + val second_round = + REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e) + + val third_round = + derive_all oo + REPEAT (fn c => fn e => + scnp_full (decompose c c) e) + + fun Then s1 s2 c e = s1 (s2 c c) (s2 c e) + + val strategy = Then (Then first_round second_round) third_round + + in + TERMINATION ctxt (strategy err_cont err_cont) + end + +fun gen_sizechange_tac orders autom_tac ctxt err_cont = + TRY (FundefCommon.apply_termination_rule ctxt 1) + THEN TRY (Termination.wf_union_tac ctxt) + THEN + (rtac @{thm wf_empty} 1 + ORELSE gen_decomp_scnp_tac orders autom_tac ctxt err_cont 1) + +fun sizechange_tac ctxt autom_tac = + gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt (K (K all_tac)) + +fun decomp_scnp orders ctxt = + let + val extra_simps = FundefCommon.TerminationSimps.get ctxt + val autom_tac = auto_tac (local_clasimpset_of ctxt addsimps2 extra_simps) + in + SIMPLE_METHOD + (gen_sizechange_tac orders autom_tac ctxt (print_error ctxt)) + end + + +(* Method setup *) + +val orders = + Scan.repeat1 + ((Args.$$$ "max" >> K MAX) || + (Args.$$$ "min" >> K MIN) || + (Args.$$$ "ms" >> K MS)) + || Scan.succeed [MAX, MS, MIN] + +val setup = Method.setup @{binding sizechange} + (Scan.lift orders --| Method.sections clasimp_modifiers >> decomp_scnp) + "termination prover with graph decomposition and the NP subset of size change termination" + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/scnp_solve.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/scnp_solve.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,257 @@ +(* Title: HOL/Tools/Function/scnp_solve.ML + Author: Armin Heller, TU Muenchen + Author: Alexander Krauss, TU Muenchen + +Generate certificates for SCNP using a SAT solver +*) + + +signature SCNP_SOLVE = +sig + + datatype edge = GTR | GEQ + datatype graph = G of int * int * (int * edge * int) list + datatype graph_problem = GP of int list * graph list + + datatype label = MIN | MAX | MS + + type certificate = + label (* which order *) + * (int * int) list list (* (multi)sets *) + * int list (* strictly ordered calls *) + * (int -> bool -> int -> (int * int) option) (* covering function *) + + val generate_certificate : bool -> label list -> graph_problem -> certificate option + + val solver : string ref +end + +structure ScnpSolve : SCNP_SOLVE = +struct + +(** Graph problems **) + +datatype edge = GTR | GEQ ; +datatype graph = G of int * int * (int * edge * int) list ; +datatype graph_problem = GP of int list * graph list ; + +datatype label = MIN | MAX | MS ; +type certificate = + label + * (int * int) list list + * int list + * (int -> bool -> int -> (int * int) option) + +fun graph_at (GP (_, gs), i) = nth gs i ; +fun num_prog_pts (GP (arities, _)) = length arities ; +fun num_graphs (GP (_, gs)) = length gs ; +fun arity (GP (arities, gl)) i = nth arities i ; +fun ndigits (GP (arities, _)) = IntInf.log2 (List.foldl (op +) 0 arities) + 1 + + +(** Propositional formulas **) + +val Not = PropLogic.Not and And = PropLogic.And and Or = PropLogic.Or +val BoolVar = PropLogic.BoolVar +fun Implies (p, q) = Or (Not p, q) +fun Equiv (p, q) = And (Implies (p, q), Implies (q, p)) +val all = PropLogic.all + +(* finite indexed quantifiers: + +iforall n f <==> /\ + / \ f i + 0<=i Equiv (TAG x i, TAG y i))) + + fun encode_graph (g, p, q, n, m, edges) = + let + fun encode_edge i j = + if exists (fn x => x = (i, GTR, j)) edges then + And (ES (g, i, j), EW (g, i, j)) + else if not (exists (fn x => x = (i, GEQ, j)) edges) then + And (Not (ES (g, i, j)), Not (EW (g, i, j))) + else + And ( + Equiv (ES (g, i, j), + encode_constraint_strict bits ((p, i), (q, j))), + Equiv (EW (g, i, j), + encode_constraint_weak bits ((p, i), (q, j)))) + in + iforall2 n m encode_edge + end + in + iforall ng (encode_graph o graph_info gp) + end + + +(* Order-specific part of encoding *) + +fun encode bits gp mu = + let + val ng = num_graphs gp + val (ES,EW,WEAK,STRICT,P,GAM,EPS,_) = var_constrs gp + + fun encode_graph MAX (g, p, q, n, m, _) = + And ( + Equiv (WEAK g, + iforall m (fn j => + Implies (P (q, j), + iexists n (fn i => + And (P (p, i), EW (g, i, j)))))), + Equiv (STRICT g, + And ( + iforall m (fn j => + Implies (P (q, j), + iexists n (fn i => + And (P (p, i), ES (g, i, j))))), + iexists n (fn i => P (p, i))))) + | encode_graph MIN (g, p, q, n, m, _) = + And ( + Equiv (WEAK g, + iforall n (fn i => + Implies (P (p, i), + iexists m (fn j => + And (P (q, j), EW (g, i, j)))))), + Equiv (STRICT g, + And ( + iforall n (fn i => + Implies (P (p, i), + iexists m (fn j => + And (P (q, j), ES (g, i, j))))), + iexists m (fn j => P (q, j))))) + | encode_graph MS (g, p, q, n, m, _) = + all [ + Equiv (WEAK g, + iforall m (fn j => + Implies (P (q, j), + iexists n (fn i => GAM (g, i, j))))), + Equiv (STRICT g, + iexists n (fn i => + And (P (p, i), Not (EPS (g, i))))), + iforall2 n m (fn i => fn j => + Implies (GAM (g, i, j), + all [ + P (p, i), + P (q, j), + EW (g, i, j), + Equiv (Not (EPS (g, i)), ES (g, i, j))])), + iforall n (fn i => + Implies (And (P (p, i), EPS (g, i)), + exactly_one m (fn j => GAM (g, i, j)))) + ] + in + all [ + encode_graphs bits gp, + iforall ng (encode_graph mu o graph_info gp), + iforall ng (fn x => WEAK x), + iexists ng (fn x => STRICT x) + ] + end + + +(*Generieren des level-mapping und diverser output*) +fun mk_certificate bits label gp f = + let + val (ES,EW,WEAK,STRICT,P,GAM,EPS,TAG) = var_constrs gp + fun assign (PropLogic.BoolVar v) = the_default false (f v) + fun assignTag i j = + (fold (fn x => fn y => 2 * y + (if assign (TAG (i, j) x) then 1 else 0)) + (bits - 1 downto 0) 0) + + val level_mapping = + let fun prog_pt_mapping p = + map_filter (fn x => if assign (P(p, x)) then SOME (x, assignTag p x) else NONE) + (0 upto (arity gp p) - 1) + in map prog_pt_mapping (0 upto num_prog_pts gp - 1) end + + val strict_list = filter (assign o STRICT) (0 upto num_graphs gp - 1) + + fun covering_pair g bStrict j = + let + val (_, p, q, n, m, _) = graph_info gp g + + fun cover MAX j = find_index (fn i => assign (P (p, i)) andalso assign (EW (g, i, j))) (0 upto n - 1) + | cover MS k = find_index (fn i => assign (GAM (g, i, k))) (0 upto n - 1) + | cover MIN i = find_index (fn j => assign (P (q, j)) andalso assign (EW (g, i, j))) (0 upto m - 1) + fun cover_strict MAX j = find_index (fn i => assign (P (p, i)) andalso assign (ES (g, i, j))) (0 upto n - 1) + | cover_strict MS k = find_index (fn i => assign (GAM (g, i, k)) andalso not (assign (EPS (g, i) ))) (0 upto n - 1) + | cover_strict MIN i = find_index (fn j => assign (P (q, j)) andalso assign (ES (g, i, j))) (0 upto m - 1) + val i = if bStrict then cover_strict label j else cover label j + in + find_first (fn x => fst x = i) (nth level_mapping (if label = MIN then q else p)) + end + in + (label, level_mapping, strict_list, covering_pair) + end + +(*interface for the proof reconstruction*) +fun generate_certificate use_tags labels gp = + let + val bits = if use_tags then ndigits gp else 0 + in + get_first + (fn l => case sat_solver (encode bits gp l) of + SatSolver.SATISFIABLE f => SOME (mk_certificate bits l gp f) + | _ => NONE) + labels + end +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/size.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/size.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,242 @@ +(* Title: HOL/Tools/Function/size.ML + Author: Stefan Berghofer, Florian Haftmann & Alexander Krauss, TU Muenchen + +Size functions for datatypes. +*) + +signature SIZE = +sig + val size_thms: theory -> string -> thm list + val setup: theory -> theory +end; + +structure Size: SIZE = +struct + +open DatatypeAux; + +structure SizeData = TheoryDataFun +( + type T = (string * thm list) Symtab.table; + val empty = Symtab.empty; + val copy = I + val extend = I + fun merge _ = Symtab.merge (K true); +); + +val lookup_size = SizeData.get #> Symtab.lookup; + +fun plus (t1, t2) = Const ("HOL.plus_class.plus", + HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2; + +fun size_of_type f g h (T as Type (s, Ts)) = + (case f s of + SOME t => SOME t + | NONE => (case g s of + SOME size_name => + SOME (list_comb (Const (size_name, + map (fn U => U --> HOLogic.natT) Ts @ [T] ---> HOLogic.natT), + map (size_of_type' f g h) Ts)) + | NONE => NONE)) + | size_of_type f g h (TFree (s, _)) = h s +and size_of_type' f g h T = (case size_of_type f g h T of + NONE => Abs ("x", T, HOLogic.zero) + | SOME t => t); + +fun is_poly thy (DtType (name, dts)) = + (case Datatype.get_datatype thy name of + NONE => false + | SOME _ => exists (is_poly thy) dts) + | is_poly _ _ = true; + +fun constrs_of thy name = + let + val {descr, index, ...} = Datatype.the_datatype thy name + val SOME (_, _, constrs) = AList.lookup op = descr index + in constrs end; + +val app = curry (list_comb o swap); + +fun prove_size_thms (info : info) new_type_names thy = + let + val {descr, alt_names, sorts, rec_names, rec_rewrites, induction, ...} = info; + val l = length new_type_names; + val alt_names' = (case alt_names of + NONE => replicate l NONE | SOME names => map SOME names); + val descr' = List.take (descr, l); + val (rec_names1, rec_names2) = chop l rec_names; + val recTs = get_rec_types descr sorts; + val (recTs1, recTs2) = chop l recTs; + val (_, (_, paramdts, _)) :: _ = descr; + val paramTs = map (typ_of_dtyp descr sorts) paramdts; + val ((param_size_fs, param_size_fTs), f_names) = paramTs |> + map (fn T as TFree (s, _) => + let + val name = "f" ^ implode (tl (explode s)); + val U = T --> HOLogic.natT + in + (((s, Free (name, U)), U), name) + end) |> split_list |>> split_list; + val param_size = AList.lookup op = param_size_fs; + + val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |> + map_filter (Option.map snd o lookup_size thy) |> flat; + val extra_size = Option.map fst o lookup_size thy; + + val (((size_names, size_fns), def_names), def_names') = + recTs1 ~~ alt_names' |> + map (fn (T as Type (s, _), optname) => + let + val s' = the_default (Long_Name.base_name s) optname ^ "_size"; + val s'' = Sign.full_bname thy s' + in + (s'', + (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT), + map snd param_size_fs), + (s' ^ "_def", s' ^ "_overloaded_def"))) + end) |> split_list ||>> split_list ||>> split_list; + val overloaded_size_fns = map HOLogic.size_const recTs1; + + (* instantiation for primrec combinator *) + fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) = + let + val Ts = map (typ_of_dtyp descr sorts) cargs; + val k = length (filter is_rec_type cargs); + val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) => + if is_rec_type dt then (Bound i :: us, i + 1, j + 1) + else + (if b andalso is_poly thy dt' then + case size_of_type (K NONE) extra_size size_ofp T of + NONE => us | SOME sz => sz $ Bound j :: us + else us, i, j + 1)) + (cargs ~~ cargs' ~~ Ts) ([], 0, k); + val t = + if null ts andalso (not b orelse not (exists (is_poly thy) cargs')) + then HOLogic.zero + else foldl1 plus (ts @ [HOLogic.Suc_zero]) + in + List.foldr (fn (T, t') => Abs ("x", T, t')) t (Ts @ replicate k HOLogic.natT) + end; + + val fs = maps (fn (_, (name, _, constrs)) => + map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr; + val fs' = maps (fn (n, (name, _, constrs)) => + map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr; + val fTs = map fastype_of fs; + + val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) => + Const (rec_name, fTs @ [T] ---> HOLogic.natT)) + (recTs ~~ rec_names)); + + fun define_overloaded (def_name, eq) lthy = + let + val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq; + val ((_, (_, thm)), lthy') = lthy |> LocalTheory.define Thm.definitionK + ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs)); + val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy'); + val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; + in (thm', lthy') end; + + val ((size_def_thms, size_def_thms'), thy') = + thy + |> Sign.add_consts_i (map (fn (s, T) => + (Binding.name (Long_Name.base_name s), param_size_fTs @ [T] ---> HOLogic.natT, NoSyn)) + (size_names ~~ recTs1)) + |> PureThy.add_defs false + (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs))) + (map Binding.name def_names ~~ (size_fns ~~ rec_combs1))) + ||> TheoryTarget.instantiation + (map (#1 o snd) descr', map dest_TFree paramTs, [HOLogic.class_size]) + ||>> fold_map define_overloaded + (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1)) + ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac [])) + ||> LocalTheory.exit_global; + + val ctxt = ProofContext.init thy'; + + val simpset1 = HOL_basic_ss addsimps @{thm add_0} :: @{thm add_0_right} :: + size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites; + val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2); + + fun mk_unfolded_size_eq tab size_ofp fs (p as (x, T), r) = + HOLogic.mk_eq (app fs r $ Free p, + the (size_of_type tab extra_size size_ofp T) $ Free p); + + fun prove_unfolded_size_eqs size_ofp fs = + if null recTs2 then [] + else split_conj_thm (SkipProof.prove ctxt xs [] + (HOLogic.mk_Trueprop (mk_conj (replicate l HOLogic.true_const @ + map (mk_unfolded_size_eq (AList.lookup op = + (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs) + (xs ~~ recTs2 ~~ rec_combs2)))) + (fn _ => (indtac induction xs THEN_ALL_NEW asm_simp_tac simpset1) 1)); + + val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs; + val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs'; + + (* characteristic equations for size functions *) + fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) = + let + val Ts = map (typ_of_dtyp descr sorts) cargs; + val tnames = Name.variant_list f_names (DatatypeProp.make_tnames Ts); + val ts = map_filter (fn (sT as (s, T), dt) => + Option.map (fn sz => sz $ Free sT) + (if p dt then size_of_type size_of extra_size size_ofp T + else NONE)) (tnames ~~ Ts ~~ cargs) + in + HOLogic.mk_Trueprop (HOLogic.mk_eq + (size_const $ list_comb (Const (cname, Ts ---> T), + map2 (curry Free) tnames Ts), + if null ts then HOLogic.zero + else foldl1 plus (ts @ [HOLogic.Suc_zero]))) + end; + + val simpset2 = HOL_basic_ss addsimps + rec_rewrites @ size_def_thms @ unfolded_size_eqs1; + val simpset3 = HOL_basic_ss addsimps + rec_rewrites @ size_def_thms' @ unfolded_size_eqs2; + + fun prove_size_eqs p size_fns size_ofp simpset = + maps (fn (((_, (_, _, constrs)), size_const), T) => + map (fn constr => standard (SkipProof.prove ctxt [] [] + (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns)) + size_ofp size_const T constr) + (fn _ => simp_tac simpset 1))) constrs) + (descr' ~~ size_fns ~~ recTs1); + + val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @ + prove_size_eqs is_rec_type overloaded_size_fns (K NONE) simpset3; + + val ([size_thms], thy'') = PureThy.add_thmss + [((Binding.name "size", size_eqns), + [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, + Thm.declaration_attribute + (fn thm => Context.mapping (Code.add_default_eqn thm) I)])] thy' + + in + SizeData.map (fold (Symtab.update_new o apsnd (rpair size_thms)) + (new_type_names ~~ size_names)) thy'' + end; + +fun add_size_thms config (new_type_names as name :: _) thy = + let + val info as {descr, alt_names, ...} = Datatype.the_datatype thy name; + val prefix = Long_Name.map_base_name (K (space_implode "_" + (the_default (map Long_Name.base_name new_type_names) alt_names))) name; + val no_size = exists (fn (_, (_, _, constrs)) => exists (fn (_, cargs) => exists (fn dt => + is_rec_type dt andalso not (null (fst (strip_dtyp dt)))) cargs) constrs) descr + in if no_size then thy + else + thy + |> Sign.root_path + |> Sign.add_path prefix + |> Theory.checkpoint + |> prove_size_thms info new_type_names + |> Sign.restore_naming thy + end; + +val size_thms = snd oo (the oo lookup_size); + +val setup = Datatype.interpretation add_size_thms; + +end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/sum_tree.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/sum_tree.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,43 @@ +(* Title: HOL/Tools/Function/sum_tree.ML + Author: Alexander Krauss, TU Muenchen + +Some common tools for working with sum types in balanced tree form. +*) + +structure SumTree = +struct + +(* Theory dependencies *) +val proj_in_rules = [@{thm "Datatype.Projl_Inl"}, @{thm "Datatype.Projr_Inr"}] +val sumcase_split_ss = HOL_basic_ss addsimps (@{thm "Product_Type.split"} :: @{thms "sum.cases"}) + +(* top-down access in balanced tree *) +fun access_top_down {left, right, init} len i = + BalancedTree.access {left = (fn f => f o left), right = (fn f => f o right), init = I} len i init + +(* Sum types *) +fun mk_sumT LT RT = Type ("+", [LT, RT]) +fun mk_sumcase TL TR T l r = Const (@{const_name "sum.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r + +val App = curry op $ + +fun mk_inj ST n i = + access_top_down + { init = (ST, I : term -> term), + left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, inj o App (Const (@{const_name "Inl"}, LT --> T)))), + right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, inj o App (Const (@{const_name "Inr"}, RT --> T))))} n i + |> snd + +fun mk_proj ST n i = + access_top_down + { init = (ST, I : term -> term), + left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, App (Const (@{const_name "Datatype.Projl"}, T --> LT)) o proj)), + right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, App (Const (@{const_name "Datatype.Projr"}, T --> RT)) o proj))} n i + |> snd + +fun mk_sumcases T fs = + BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT)) + (map (fn f => (f, domain_type (fastype_of f))) fs) + |> fst + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/Function/termination.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/Function/termination.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,324 @@ +(* Title: HOL/Tools/Function/termination.ML + Author: Alexander Krauss, TU Muenchen + +Context data for termination proofs +*) + + +signature TERMINATION = +sig + + type data + datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm + + val mk_sumcases : data -> typ -> term list -> term + + val note_measure : int -> term -> data -> data + val note_chain : term -> term -> thm option -> data -> data + val note_descent : term -> term -> term -> cell -> data -> data + + val get_num_points : data -> int + val get_types : data -> int -> typ + val get_measures : data -> int -> term list + + (* read from cache *) + val get_chain : data -> term -> term -> thm option option + val get_descent : data -> term -> term -> term -> cell option + + (* writes *) + val derive_descent : theory -> tactic -> term -> term -> term -> data -> data + val derive_descents : theory -> tactic -> term -> data -> data + + val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term) + + val CALLS : (term list * int -> tactic) -> int -> tactic + + (* Termination tactics. Sequential composition via continuations. (2nd argument is the error continuation) *) + type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic + + val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic + + val REPEAT : ttac -> ttac + + val wf_union_tac : Proof.context -> tactic +end + + + +structure Termination : TERMINATION = +struct + +open FundefLib + +val term2_ord = prod_ord TermOrd.fast_term_ord TermOrd.fast_term_ord +structure Term2tab = TableFun(type key = term * term val ord = term2_ord); +structure Term3tab = TableFun(type key = term * (term * term) val ord = prod_ord TermOrd.fast_term_ord term2_ord); + +(** Analyzing binary trees **) + +(* Skeleton of a tree structure *) + +datatype skel = + SLeaf of int (* index *) +| SBranch of (skel * skel) + + +(* abstract make and dest functions *) +fun mk_tree leaf branch = + let fun mk (SLeaf i) = leaf i + | mk (SBranch (s, t)) = branch (mk s, mk t) + in mk end + + +fun dest_tree split = + let fun dest (SLeaf i) x = [(i, x)] + | dest (SBranch (s, t)) x = + let val (l, r) = split x + in dest s l @ dest t r end + in dest end + + +(* concrete versions for sum types *) +fun is_inj (Const ("Sum_Type.Inl", _) $ _) = true + | is_inj (Const ("Sum_Type.Inr", _) $ _) = true + | is_inj _ = false + +fun dest_inl (Const ("Sum_Type.Inl", _) $ t) = SOME t + | dest_inl _ = NONE + +fun dest_inr (Const ("Sum_Type.Inr", _) $ t) = SOME t + | dest_inr _ = NONE + + +fun mk_skel ps = + let + fun skel i ps = + if forall is_inj ps andalso not (null ps) + then let + val (j, s) = skel i (map_filter dest_inl ps) + val (k, t) = skel j (map_filter dest_inr ps) + in (k, SBranch (s, t)) end + else (i + 1, SLeaf i) + in + snd (skel 0 ps) + end + +(* compute list of types for nodes *) +fun node_types sk T = dest_tree (fn Type ("+", [LT, RT]) => (LT, RT)) sk T |> map snd + +(* find index and raw term *) +fun dest_inj (SLeaf i) trm = (i, trm) + | dest_inj (SBranch (s, t)) trm = + case dest_inl trm of + SOME trm' => dest_inj s trm' + | _ => dest_inj t (the (dest_inr trm)) + + + +(** Matrix cell datatype **) + +datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm; + + +type data = + skel (* structure of the sum type encoding "program points" *) + * (int -> typ) (* types of program points *) + * (term list Inttab.table) (* measures for program points *) + * (thm option Term2tab.table) (* which calls form chains? *) + * (cell Term3tab.table) (* local descents *) + + +fun map_measures f (p, T, M, C, D) = (p, T, f M, C, D) +fun map_chains f (p, T, M, C, D) = (p, T, M, f C, D) +fun map_descent f (p, T, M, C, D) = (p, T, M, C, f D) + +fun note_measure p m = map_measures (Inttab.insert_list (op aconv) (p, m)) +fun note_chain c1 c2 res = map_chains (Term2tab.update ((c1, c2), res)) +fun note_descent c m1 m2 res = map_descent (Term3tab.update ((c,(m1, m2)), res)) + +(* Build case expression *) +fun mk_sumcases (sk, _, _, _, _) T fs = + mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i)))) + (fn ((f, fT), (g, gT)) => (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT)) + sk + |> fst + +fun mk_sum_skel rel = + let + val cs = FundefLib.dest_binop_list @{const_name Un} rel + fun collect_pats (Const ("Collect", _) $ Abs (_, _, c)) = + let + val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam) + = Term.strip_qnt_body "Ex" c + in cons r o cons l end + in + mk_skel (fold collect_pats cs []) + end + +fun create ctxt T rel = + let + val sk = mk_sum_skel rel + val Ts = node_types sk T + val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts) + in + (sk, nth Ts, M, Term2tab.empty, Term3tab.empty) + end + +fun get_num_points (sk, _, _, _, _) = + let + fun num (SLeaf i) = i + 1 + | num (SBranch (s, t)) = num t + in num sk end + +fun get_types (_, T, _, _, _) = T +fun get_measures (_, _, M, _, _) = Inttab.lookup_list M + +fun get_chain (_, _, _, C, _) c1 c2 = + Term2tab.lookup C (c1, c2) + +fun get_descent (_, _, _, _, D) c m1 m2 = + Term3tab.lookup D (c, (m1, m2)) + +fun dest_call D (Const ("Collect", _) $ Abs (_, _, c)) = + let + val n = get_num_points D + val (sk, _, _, _, _) = D + val vs = Term.strip_qnt_vars "Ex" c + + (* FIXME: throw error "dest_call" for malformed terms *) + val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam) + = Term.strip_qnt_body "Ex" c + val (p, l') = dest_inj sk l + val (q, r') = dest_inj sk r + in + (vs, p, l', q, r', Gam) + end + | dest_call D t = error "dest_call" + + +fun derive_desc_aux thy tac c (vs, p, l', q, r', Gam) m1 m2 D = + case get_descent D c m1 m2 of + SOME _ => D + | NONE => let + fun cgoal rel = + Term.list_all (vs, + Logic.mk_implies (HOLogic.mk_Trueprop Gam, + HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"}) + $ (m2 $ r') $ (m1 $ l')))) + |> cterm_of thy + in + note_descent c m1 m2 + (case try_proof (cgoal @{const_name HOL.less}) tac of + Solved thm => Less thm + | Stuck thm => + (case try_proof (cgoal @{const_name HOL.less_eq}) tac of + Solved thm2 => LessEq (thm2, thm) + | Stuck thm2 => + if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const] + then False thm2 else None (thm2, thm) + | _ => raise Match) (* FIXME *) + | _ => raise Match) D + end + +fun derive_descent thy tac c m1 m2 D = + derive_desc_aux thy tac c (dest_call D c) m1 m2 D + +(* all descents in one go *) +fun derive_descents thy tac c D = + let val cdesc as (vs, p, l', q, r', Gam) = dest_call D c + in fold_product (derive_desc_aux thy tac c cdesc) + (get_measures D p) (get_measures D q) D + end + +fun CALLS tac i st = + if Thm.no_prems st then all_tac st + else case Thm.term_of (Thm.cprem_of st i) of + (_ $ (_ $ rel)) => tac (FundefLib.dest_binop_list @{const_name Un} rel, i) st + |_ => no_tac st + +type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic + +fun TERMINATION ctxt tac = + SUBGOAL (fn (_ $ (Const (@{const_name "wf"}, wfT) $ rel), i) => + let + val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT)) + in + tac (create ctxt T rel) i + end) + + +(* A tactic to convert open to closed termination goals *) +local +fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *) + let + val (vars, prop) = FundefLib.dest_all_all t + val (prems, concl) = Logic.strip_horn prop + val (lhs, rhs) = concl + |> HOLogic.dest_Trueprop + |> HOLogic.dest_mem |> fst + |> HOLogic.dest_prod + in + (vars, prems, lhs, rhs) + end + +fun mk_pair_compr (T, qs, l, r, conds) = + let + val pT = HOLogic.mk_prodT (T, T) + val n = length qs + val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r) + val conds' = if null conds then [HOLogic.true_const] else conds + in + HOLogic.Collect_const pT $ + Abs ("uu_", pT, + (foldr1 HOLogic.mk_conj (peq :: conds') + |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs)) + end + +in + +fun wf_union_tac ctxt st = + let + val thy = ProofContext.theory_of ctxt + val cert = cterm_of (theory_of_thm st) + val ((trueprop $ (wf $ rel)) :: ineqs) = prems_of st + + fun mk_compr ineq = + let + val (vars, prems, lhs, rhs) = dest_term ineq + in + mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (ObjectLogic.atomize_term thy) prems) + end + + val relation = + if null ineqs then + Const (@{const_name Set.empty}, fastype_of rel) + else + foldr1 (HOLogic.mk_binop @{const_name Un}) (map mk_compr ineqs) + + fun solve_membership_tac i = + (EVERY' (replicate (i - 2) (rtac @{thm UnI2})) (* pick the right component of the union *) + THEN' (fn j => TRY (rtac @{thm UnI1} j)) + THEN' (rtac @{thm CollectI}) (* unfold comprehension *) + THEN' (fn i => REPEAT (rtac @{thm exI} i)) (* Turn existentials into schematic Vars *) + THEN' ((rtac @{thm refl}) (* unification instantiates all Vars *) + ORELSE' ((rtac @{thm conjI}) + THEN' (rtac @{thm refl}) + THEN' (blast_tac (local_claset_of ctxt)))) (* Solve rest of context... not very elegant *) + ) i + in + ((PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)]) + THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i))) st + end + + +end + + +(* continuation passing repeat combinator *) +fun REPEAT ttac cont err_cont = + ttac (fn D => fn i => (REPEAT ttac cont cont D i)) err_cont + + + + +end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype.ML --- a/src/HOL/Tools/datatype_package/datatype.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,704 +0,0 @@ -(* Title: HOL/Tools/datatype.ML - Author: Stefan Berghofer, TU Muenchen - -Datatype package for Isabelle/HOL. -*) - -signature DATATYPE = -sig - include DATATYPE_COMMON - type rules = {distinct : thm list list, - inject : thm list list, - exhaustion : thm list, - rec_thms : thm list, - case_thms : thm list list, - split_thms : (thm * thm) list, - induction : thm, - simps : thm list} - val add_datatype : config -> string list -> (string list * binding * mixfix * - (binding * typ list * mixfix) list) list -> theory -> rules * theory - val datatype_cmd : string list -> (string list * binding * mixfix * - (binding * string list * mixfix) list) list -> theory -> theory - val rep_datatype : config -> (rules -> Proof.context -> Proof.context) - -> string list option -> term list -> theory -> Proof.state - val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state - val get_datatypes : theory -> info Symtab.table - val get_datatype : theory -> string -> info option - val the_datatype : theory -> string -> info - val datatype_of_constr : theory -> string -> info option - val datatype_of_case : theory -> string -> info option - val the_datatype_spec : theory -> string -> (string * sort) list * (string * typ list) list - val the_datatype_descr : theory -> string list - -> descr * (string * sort) list * string list - * (string list * string list) * (typ list * typ list) - val get_datatype_constrs : theory -> string -> (string * typ) list option - val interpretation : (config -> string list -> theory -> theory) -> theory -> theory - val distinct_simproc : simproc - val make_case : Proof.context -> bool -> string list -> term -> - (term * term) list -> term * (term * (int * bool)) list - val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option - val read_typ: theory -> - (typ list * (string * sort) list) * string -> typ list * (string * sort) list - val setup: theory -> theory -end; - -structure Datatype : DATATYPE = -struct - -open DatatypeAux; - - -(* theory data *) - -structure DatatypesData = TheoryDataFun -( - type T = - {types: info Symtab.table, - constrs: info Symtab.table, - cases: info Symtab.table}; - - val empty = - {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty}; - val copy = I; - val extend = I; - fun merge _ - ({types = types1, constrs = constrs1, cases = cases1}, - {types = types2, constrs = constrs2, cases = cases2}) = - {types = Symtab.merge (K true) (types1, types2), - constrs = Symtab.merge (K true) (constrs1, constrs2), - cases = Symtab.merge (K true) (cases1, cases2)}; -); - -val get_datatypes = #types o DatatypesData.get; -val map_datatypes = DatatypesData.map; - - -(** theory information about datatypes **) - -fun put_dt_infos (dt_infos : (string * info) list) = - map_datatypes (fn {types, constrs, cases} => - {types = fold Symtab.update dt_infos types, - constrs = fold Symtab.default (*conservative wrt. overloaded constructors*) - (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst) - (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs, - cases = fold Symtab.update - (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos) - cases}); - -val get_datatype = Symtab.lookup o get_datatypes; - -fun the_datatype thy name = (case get_datatype thy name of - SOME info => info - | NONE => error ("Unknown datatype " ^ quote name)); - -val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get; -val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get; - -fun get_datatype_descr thy dtco = - get_datatype thy dtco - |> Option.map (fn info as { descr, index, ... } => - (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index))); - -fun the_datatype_spec thy dtco = - let - val info as { descr, index, sorts = raw_sorts, ... } = the_datatype thy dtco; - val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr index; - val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v)) - o DatatypeAux.dest_DtTFree) dtys; - val cos = map - (fn (co, tys) => (co, map (DatatypeAux.typ_of_dtyp descr sorts) tys)) raw_cos; - in (sorts, cos) end; - -fun the_datatype_descr thy (raw_tycos as raw_tyco :: _) = - let - val info = the_datatype thy raw_tyco; - val descr = #descr info; - - val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr (#index info); - val vs = map ((fn v => (v, (the o AList.lookup (op =) (#sorts info)) v)) - o dest_DtTFree) dtys; - - fun is_DtTFree (DtTFree _) = true - | is_DtTFree _ = false - val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr; - val protoTs as (dataTs, _) = chop k descr - |> (pairself o map) (fn (_, (tyco, dTs, _)) => (tyco, map (typ_of_dtyp descr vs) dTs)); - - val tycos = map fst dataTs; - val _ = if gen_eq_set (op =) (tycos, raw_tycos) then () - else error ("Type constructors " ^ commas (map quote raw_tycos) - ^ "do not belong exhaustively to one mutual recursive datatype"); - - val (Ts, Us) = (pairself o map) Type protoTs; - - val names = map Long_Name.base_name (the_default tycos (#alt_names info)); - val (auxnames, _) = Name.make_context names - |> fold_map (yield_singleton Name.variants o name_of_typ) Us - - in (descr, vs, tycos, (names, auxnames), (Ts, Us)) end; - -fun get_datatype_constrs thy dtco = - case try (the_datatype_spec thy) dtco - of SOME (sorts, cos) => - let - fun subst (v, sort) = TVar ((v, 0), sort); - fun subst_ty (TFree v) = subst v - | subst_ty ty = ty; - val dty = Type (dtco, map subst sorts); - fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty); - in SOME (map mk_co cos) end - | NONE => NONE; - - -(** induct method setup **) - -(* case names *) - -local - -fun dt_recs (DtTFree _) = [] - | dt_recs (DtType (_, dts)) = maps dt_recs dts - | dt_recs (DtRec i) = [i]; - -fun dt_cases (descr: descr) (_, args, constrs) = - let - fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i))); - val bnames = map the_bname (distinct (op =) (maps dt_recs args)); - in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end; - - -fun induct_cases descr = - DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr)); - -fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i)); - -in - -fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr); - -fun mk_case_names_exhausts descr new = - map (RuleCases.case_names o exhaust_cases descr o #1) - (filter (fn ((_, (name, _, _))) => member (op =) new name) descr); - -end; - -fun add_rules simps case_thms rec_thms inject distinct - weak_case_congs cong_att = - PureThy.add_thmss [((Binding.name "simps", simps), []), - ((Binding.empty, flat case_thms @ - flat distinct @ rec_thms), [Simplifier.simp_add]), - ((Binding.empty, rec_thms), [Code.add_default_eqn_attribute]), - ((Binding.empty, flat inject), [iff_add]), - ((Binding.empty, map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]), - ((Binding.empty, weak_case_congs), [cong_att])] - #> snd; - - -(* add_cases_induct *) - -fun add_cases_induct infos induction thy = - let - val inducts = ProjectRule.projections (ProofContext.init thy) induction; - - fun named_rules (name, {index, exhaustion, ...}: info) = - [((Binding.empty, nth inducts index), [Induct.induct_type name]), - ((Binding.empty, exhaustion), [Induct.cases_type name])]; - fun unnamed_rule i = - ((Binding.empty, nth inducts i), [Thm.kind_internal, Induct.induct_type ""]); - in - thy |> PureThy.add_thms - (maps named_rules infos @ - map unnamed_rule (length infos upto length inducts - 1)) |> snd - |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd - end; - - - -(**** simplification procedure for showing distinctness of constructors ****) - -fun stripT (i, Type ("fun", [_, T])) = stripT (i + 1, T) - | stripT p = p; - -fun stripC (i, f $ x) = stripC (i + 1, f) - | stripC p = p; - -val distinctN = "constr_distinct"; - -fun distinct_rule thy ss tname eq_t = case #distinct (the_datatype thy tname) of - FewConstrs thms => Goal.prove (Simplifier.the_context ss) [] [] eq_t (K - (EVERY [rtac eq_reflection 1, rtac iffI 1, rtac notE 1, - atac 2, resolve_tac thms 1, etac FalseE 1])) - | ManyConstrs (thm, simpset) => - let - val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] = - map (PureThy.get_thm (ThyInfo.the_theory "Datatype" thy)) - ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"]; - in - Goal.prove (Simplifier.the_context ss) [] [] eq_t (K - (EVERY [rtac eq_reflection 1, rtac iffI 1, dtac thm 1, - full_simp_tac (Simplifier.inherit_context ss simpset) 1, - REPEAT (dresolve_tac [In0_inject, In1_inject] 1), - eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1, - etac FalseE 1])) - end; - -fun distinct_proc thy ss (t as Const ("op =", _) $ t1 $ t2) = - (case (stripC (0, t1), stripC (0, t2)) of - ((i, Const (cname1, T1)), (j, Const (cname2, T2))) => - (case (stripT (0, T1), stripT (0, T2)) of - ((i', Type (tname1, _)), (j', Type (tname2, _))) => - if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then - (case (get_datatype_descr thy) tname1 of - SOME (_, (_, constrs)) => let val cnames = map fst constrs - in if cname1 mem cnames andalso cname2 mem cnames then - SOME (distinct_rule thy ss tname1 - (Logic.mk_equals (t, Const ("False", HOLogic.boolT)))) - else NONE - end - | NONE => NONE) - else NONE - | _ => NONE) - | _ => NONE) - | distinct_proc _ _ _ = NONE; - -val distinct_simproc = - Simplifier.simproc @{theory HOL} distinctN ["s = t"] distinct_proc; - -val dist_ss = HOL_ss addsimprocs [distinct_simproc]; - -val simproc_setup = - Simplifier.map_simpset (fn ss => ss addsimprocs [distinct_simproc]); - - -(**** translation rules for case ****) - -fun make_case ctxt = DatatypeCase.make_case - (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt; - -fun strip_case ctxt = DatatypeCase.strip_case - (datatype_of_case (ProofContext.theory_of ctxt)); - -fun add_case_tr' case_names thy = - Sign.add_advanced_trfuns ([], [], - map (fn case_name => - let val case_name' = Sign.const_syntax_name thy case_name - in (case_name', DatatypeCase.case_tr' datatype_of_case case_name') - end) case_names, []) thy; - -val trfun_setup = - Sign.add_advanced_trfuns ([], - [("_case_syntax", DatatypeCase.case_tr true datatype_of_constr)], - [], []); - - -(* prepare types *) - -fun read_typ thy ((Ts, sorts), str) = - let - val ctxt = ProofContext.init thy - |> fold (Variable.declare_typ o TFree) sorts; - val T = Syntax.read_typ ctxt str; - in (Ts @ [T], Term.add_tfreesT T sorts) end; - -fun cert_typ sign ((Ts, sorts), raw_T) = - let - val T = Type.no_tvars (Sign.certify_typ sign raw_T) handle - TYPE (msg, _, _) => error msg; - val sorts' = Term.add_tfreesT T sorts; - in (Ts @ [T], - case duplicates (op =) (map fst sorts') of - [] => sorts' - | dups => error ("Inconsistent sort constraints for " ^ commas dups)) - end; - - -(**** make datatype info ****) - -fun make_dt_info alt_names descr sorts induct reccomb_names rec_thms - (((((((((i, (_, (tname, _, _))), case_name), case_thms), - exhaustion_thm), distinct_thm), inject), nchotomy), case_cong), weak_case_cong) = - (tname, - {index = i, - alt_names = alt_names, - descr = descr, - sorts = sorts, - rec_names = reccomb_names, - rec_rewrites = rec_thms, - case_name = case_name, - case_rewrites = case_thms, - induction = induct, - exhaustion = exhaustion_thm, - distinct = distinct_thm, - inject = inject, - nchotomy = nchotomy, - case_cong = case_cong, - weak_case_cong = weak_case_cong}); - -type rules = {distinct : thm list list, - inject : thm list list, - exhaustion : thm list, - rec_thms : thm list, - case_thms : thm list list, - split_thms : (thm * thm) list, - induction : thm, - simps : thm list} - -structure DatatypeInterpretation = InterpretationFun - (type T = config * string list val eq: T * T -> bool = eq_snd op =); -fun interpretation f = DatatypeInterpretation.interpretation (uncurry f); - - -(******************* definitional introduction of datatypes *******************) - -fun add_datatype_def (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info - case_names_induct case_names_exhausts thy = - let - val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names); - - val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |> - DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts - types_syntax constr_syntax case_names_induct; - - val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr - sorts induct case_names_exhausts thy2; - val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms - config new_type_names descr sorts dt_info inject dist_rewrites - (Simplifier.theory_context thy3 dist_ss) induct thy3; - val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms - config new_type_names descr sorts reccomb_names rec_thms thy4; - val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names - descr sorts inject dist_rewrites casedist_thms case_thms thy6; - val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names - descr sorts casedist_thms thy7; - val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names - descr sorts nchotomys case_thms thy8; - val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names - descr sorts thy9; - - val dt_infos = map - (make_dt_info (SOME new_type_names) (flat descr) sorts induct reccomb_names rec_thms) - ((0 upto length (hd descr) - 1) ~~ (hd descr) ~~ case_names ~~ case_thms ~~ - casedist_thms ~~ simproc_dists ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs); - - val simps = flat (distinct @ inject @ case_thms) @ rec_thms; - - val thy12 = - thy10 - |> add_case_tr' case_names - |> Sign.add_path (space_implode "_" new_type_names) - |> add_rules simps case_thms rec_thms inject distinct - weak_case_congs (Simplifier.attrib (op addcongs)) - |> put_dt_infos dt_infos - |> add_cases_induct dt_infos induct - |> Sign.parent_path - |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd - |> DatatypeInterpretation.data (config, map fst dt_infos); - in - ({distinct = distinct, - inject = inject, - exhaustion = casedist_thms, - rec_thms = rec_thms, - case_thms = case_thms, - split_thms = split_thms, - induction = induct, - simps = simps}, thy12) - end; - - -(*********************** declare existing type as datatype *********************) - -fun prove_rep_datatype (config : config) alt_names new_type_names descr sorts induct inject half_distinct thy = - let - val ((_, [induct']), _) = - Variable.importT_thms [induct] (Variable.thm_context induct); - - fun err t = error ("Ill-formed predicate in induction rule: " ^ - Syntax.string_of_term_global thy t); - - fun get_typ (t as _ $ Var (_, Type (tname, Ts))) = - ((tname, map (fst o dest_TFree) Ts) handle TERM _ => err t) - | get_typ t = err t; - val dtnames = map get_typ (HOLogic.dest_conj (HOLogic.dest_Trueprop (Thm.concl_of induct'))); - - val dt_info = get_datatypes thy; - - val distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct; - val (case_names_induct, case_names_exhausts) = - (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames)); - - val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names); - - val (casedist_thms, thy2) = thy |> - DatatypeAbsProofs.prove_casedist_thms config new_type_names [descr] sorts induct - case_names_exhausts; - val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms - config new_type_names [descr] sorts dt_info inject distinct - (Simplifier.theory_context thy2 dist_ss) induct thy2; - val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms - config new_type_names [descr] sorts reccomb_names rec_thms thy3; - val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms - config new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4; - val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys config new_type_names - [descr] sorts casedist_thms thy5; - val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names - [descr] sorts nchotomys case_thms thy6; - val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names - [descr] sorts thy7; - - val ((_, [induct']), thy10) = - thy8 - |> store_thmss "inject" new_type_names inject - ||>> store_thmss "distinct" new_type_names distinct - ||> Sign.add_path (space_implode "_" new_type_names) - ||>> PureThy.add_thms [((Binding.name "induct", induct), [case_names_induct])]; - - val dt_infos = map (make_dt_info alt_names descr sorts induct' reccomb_names rec_thms) - ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_thms ~~ casedist_thms ~~ - map FewConstrs distinct ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs); - - val simps = flat (distinct @ inject @ case_thms) @ rec_thms; - - val thy11 = - thy10 - |> add_case_tr' case_names - |> add_rules simps case_thms rec_thms inject distinct - weak_case_congs (Simplifier.attrib (op addcongs)) - |> put_dt_infos dt_infos - |> add_cases_induct dt_infos induct' - |> Sign.parent_path - |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) - |> snd - |> DatatypeInterpretation.data (config, map fst dt_infos); - in - ({distinct = distinct, - inject = inject, - exhaustion = casedist_thms, - rec_thms = rec_thms, - case_thms = case_thms, - split_thms = split_thms, - induction = induct', - simps = simps}, thy11) - end; - -fun gen_rep_datatype prep_term (config : config) after_qed alt_names raw_ts thy = - let - fun constr_of_term (Const (c, T)) = (c, T) - | constr_of_term t = - error ("Not a constant: " ^ Syntax.string_of_term_global thy t); - fun no_constr (c, T) = error ("Bad constructor: " - ^ Sign.extern_const thy c ^ "::" - ^ Syntax.string_of_typ_global thy T); - fun type_of_constr (cT as (_, T)) = - let - val frees = OldTerm.typ_tfrees T; - val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) T - handle TYPE _ => no_constr cT - val _ = if has_duplicates (eq_fst (op =)) vs then no_constr cT else (); - val _ = if length frees <> length vs then no_constr cT else (); - in (tyco, (vs, cT)) end; - - val raw_cs = AList.group (op =) (map (type_of_constr o constr_of_term o prep_term thy) raw_ts); - val _ = case map_filter (fn (tyco, _) => - if Symtab.defined (get_datatypes thy) tyco then SOME tyco else NONE) raw_cs - of [] => () - | tycos => error ("Type(s) " ^ commas (map quote tycos) - ^ " already represented inductivly"); - val raw_vss = maps (map (map snd o fst) o snd) raw_cs; - val ms = case distinct (op =) (map length raw_vss) - of [n] => 0 upto n - 1 - | _ => error ("Different types in given constructors"); - fun inter_sort m = map (fn xs => nth xs m) raw_vss - |> Library.foldr1 (Sorts.inter_sort (Sign.classes_of thy)) - val sorts = map inter_sort ms; - val vs = Name.names Name.context Name.aT sorts; - - fun norm_constr (raw_vs, (c, T)) = (c, map_atyps - (TFree o (the o AList.lookup (op =) (map fst raw_vs ~~ vs)) o fst o dest_TFree) T); - - val cs = map (apsnd (map norm_constr)) raw_cs; - val dtyps_of_typ = map (dtyp_of_typ (map (rpair (map fst vs) o fst) cs)) - o fst o strip_type; - val new_type_names = map Long_Name.base_name (the_default (map fst cs) alt_names); - - fun mk_spec (i, (tyco, constr)) = (i, (tyco, - map (DtTFree o fst) vs, - (map o apsnd) dtyps_of_typ constr)) - val descr = map_index mk_spec cs; - val injs = DatatypeProp.make_injs [descr] vs; - val half_distincts = map snd (DatatypeProp.make_distincts [descr] vs); - val ind = DatatypeProp.make_ind [descr] vs; - val rules = (map o map o map) Logic.close_form [[[ind]], injs, half_distincts]; - - fun after_qed' raw_thms = - let - val [[[induct]], injs, half_distincts] = - unflat rules (map Drule.zero_var_indexes_list raw_thms); - (*FIXME somehow dubious*) - in - ProofContext.theory_result - (prove_rep_datatype config alt_names new_type_names descr vs induct injs half_distincts) - #-> after_qed - end; - in - thy - |> ProofContext.init - |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules)) - end; - -val rep_datatype = gen_rep_datatype Sign.cert_term; -val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global default_config (K I); - - - -(******************************** add datatype ********************************) - -fun gen_add_datatype prep_typ (config : config) new_type_names dts thy = - let - val _ = Theory.requires thy "Datatype" "datatype definitions"; - - (* this theory is used just for parsing *) - - val tmp_thy = thy |> - Theory.copy |> - Sign.add_types (map (fn (tvs, tname, mx, _) => - (tname, length tvs, mx)) dts); - - val (tyvars, _, _, _)::_ = dts; - val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) => - let val full_tname = Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname) - in (case duplicates (op =) tvs of - [] => if eq_set (tyvars, tvs) then ((full_tname, tvs), (tname, mx)) - else error ("Mutually recursive datatypes must have same type parameters") - | dups => error ("Duplicate parameter(s) for datatype " ^ quote (Binding.str_of tname) ^ - " : " ^ commas dups)) - end) dts); - - val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of - [] => () | dups => error ("Duplicate datatypes: " ^ commas dups)); - - fun prep_dt_spec ((tvs, tname, mx, constrs), tname') (dts', constr_syntax, sorts, i) = - let - fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') = - let - val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs); - val _ = (case fold (curry OldTerm.add_typ_tfree_names) cargs' [] \\ tvs of - [] => () - | vs => error ("Extra type variables on rhs: " ^ commas vs)) - in (constrs @ [((if #flat_names config then Sign.full_name tmp_thy else - Sign.full_name_path tmp_thy tname') - (Binding.map_name (Syntax.const_name mx') cname), - map (dtyp_of_typ new_dts) cargs')], - constr_syntax' @ [(cname, mx')], sorts'') - end handle ERROR msg => cat_error msg - ("The error above occured in constructor " ^ quote (Binding.str_of cname) ^ - " of datatype " ^ quote (Binding.str_of tname)); - - val (constrs', constr_syntax', sorts') = - fold prep_constr constrs ([], [], sorts) - - in - case duplicates (op =) (map fst constrs') of - [] => - (dts' @ [(i, (Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname), - map DtTFree tvs, constrs'))], - constr_syntax @ [constr_syntax'], sorts', i + 1) - | dups => error ("Duplicate constructors " ^ commas dups ^ - " in datatype " ^ quote (Binding.str_of tname)) - end; - - val (dts', constr_syntax, sorts', i) = - fold prep_dt_spec (dts ~~ new_type_names) ([], [], [], 0); - val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts')); - val dt_info = get_datatypes thy; - val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i; - val _ = check_nonempty descr handle (exn as Datatype_Empty s) => - if #strict config then error ("Nonemptiness check failed for datatype " ^ s) - else raise exn; - - val descr' = flat descr; - val case_names_induct = mk_case_names_induct descr'; - val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts); - in - add_datatype_def - (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info - case_names_induct case_names_exhausts thy - end; - -val add_datatype = gen_add_datatype cert_typ; -val datatype_cmd = snd ooo gen_add_datatype read_typ default_config; - - - -(** package setup **) - -(* setup theory *) - -val setup = - DatatypeRepProofs.distinctness_limit_setup #> - simproc_setup #> - trfun_setup #> - DatatypeInterpretation.init; - - -(* outer syntax *) - -local - -structure P = OuterParse and K = OuterKeyword - -fun prep_datatype_decls args = - let - val names = map - (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args; - val specs = map (fn ((((_, vs), t), mx), cons) => - (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args; - in (names, specs) end; - -val parse_datatype_decl = - (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") -- P.type_args -- P.binding -- P.opt_infix -- - (P.$$$ "=" |-- P.enum1 "|" (P.binding -- Scan.repeat P.typ -- P.opt_mixfix))); - -val parse_datatype_decls = P.and_list1 parse_datatype_decl >> prep_datatype_decls; - -in - -val _ = - OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl - (parse_datatype_decls >> (fn (names, specs) => Toplevel.theory (datatype_cmd names specs))); - -val _ = - OuterSyntax.command "rep_datatype" "represent existing types inductively" K.thy_goal - (Scan.option (P.$$$ "(" |-- Scan.repeat1 P.name --| P.$$$ ")") -- Scan.repeat1 P.term - >> (fn (alt_names, ts) => Toplevel.print - o Toplevel.theory_to_proof (rep_datatype_cmd alt_names ts))); - -end; - - -(* document antiquotation *) - -val _ = ThyOutput.antiquotation "datatype" Args.tyname - (fn {source = src, context = ctxt, ...} => fn dtco => - let - val thy = ProofContext.theory_of ctxt; - val (vs, cos) = the_datatype_spec thy dtco; - val ty = Type (dtco, map TFree vs); - fun pretty_typ_bracket (ty as Type (_, _ :: _)) = - Pretty.enclose "(" ")" [Syntax.pretty_typ ctxt ty] - | pretty_typ_bracket ty = - Syntax.pretty_typ ctxt ty; - fun pretty_constr (co, tys) = - (Pretty.block o Pretty.breaks) - (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) :: - map pretty_typ_bracket tys); - val pretty_datatype = - Pretty.block - (Pretty.command "datatype" :: Pretty.brk 1 :: - Syntax.pretty_typ ctxt ty :: - Pretty.str " =" :: Pretty.brk 1 :: - flat (separate [Pretty.brk 1, Pretty.str "| "] - (map (single o pretty_constr) cos))); - in ThyOutput.output (ThyOutput.maybe_pretty_source (K pretty_datatype) src [()]) end); - -end; - diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_abs_proofs.ML --- a/src/HOL/Tools/datatype_package/datatype_abs_proofs.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,447 +0,0 @@ -(* Title: HOL/Tools/datatype_abs_proofs.ML - Author: Stefan Berghofer, TU Muenchen - -Proofs and defintions independent of concrete representation -of datatypes (i.e. requiring only abstract properties such as -injectivity / distinctness of constructors and induction) - - - case distinction (exhaustion) theorems - - characteristic equations for primrec combinators - - characteristic equations for case combinators - - equations for splitting "P (case ...)" expressions - - "nchotomy" and "case_cong" theorems for TFL -*) - -signature DATATYPE_ABS_PROOFS = -sig - include DATATYPE_COMMON - val prove_casedist_thms : config -> string list -> - descr list -> (string * sort) list -> thm -> - attribute list -> theory -> thm list * theory - val prove_primrec_thms : config -> string list -> - descr list -> (string * sort) list -> - info Symtab.table -> thm list list -> thm list list -> - simpset -> thm -> theory -> (string list * thm list) * theory - val prove_case_thms : config -> string list -> - descr list -> (string * sort) list -> - string list -> thm list -> theory -> (thm list list * string list) * theory - val prove_split_thms : config -> string list -> - descr list -> (string * sort) list -> - thm list list -> thm list list -> thm list -> thm list list -> theory -> - (thm * thm) list * theory - val prove_nchotomys : config -> string list -> descr list -> - (string * sort) list -> thm list -> theory -> thm list * theory - val prove_weak_case_congs : string list -> descr list -> - (string * sort) list -> theory -> thm list * theory - val prove_case_congs : string list -> - descr list -> (string * sort) list -> - thm list -> thm list list -> theory -> thm list * theory -end; - -structure DatatypeAbsProofs: DATATYPE_ABS_PROOFS = -struct - -open DatatypeAux; - -(************************ case distinction theorems ***************************) - -fun prove_casedist_thms (config : config) new_type_names descr sorts induct case_names_exhausts thy = - let - val _ = message config "Proving case distinction theorems ..."; - - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - - val {maxidx, ...} = rep_thm induct; - val induct_Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of induct))); - - fun prove_casedist_thm ((i, t), T) = - let - val dummyPs = map (fn (Var (_, Type (_, [T', T'']))) => - Abs ("z", T', Const ("True", T''))) induct_Ps; - val P = Abs ("z", T, HOLogic.imp $ HOLogic.mk_eq (Var (("a", maxidx+1), T), Bound 0) $ - Var (("P", 0), HOLogic.boolT)) - val insts = Library.take (i, dummyPs) @ (P::(Library.drop (i + 1, dummyPs))); - val cert = cterm_of thy; - val insts' = (map cert induct_Ps) ~~ (map cert insts); - val induct' = refl RS ((List.nth - (split_conj_thm (cterm_instantiate insts' induct), i)) RSN (2, rev_mp)) - - in - SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) - (fn {prems, ...} => EVERY - [rtac induct' 1, - REPEAT (rtac TrueI 1), - REPEAT ((rtac impI 1) THEN (eresolve_tac prems 1)), - REPEAT (rtac TrueI 1)]) - end; - - val casedist_thms = map prove_casedist_thm ((0 upto (length newTs - 1)) ~~ - (DatatypeProp.make_casedists descr sorts) ~~ newTs) - in - thy - |> store_thms_atts "exhaust" new_type_names (map single case_names_exhausts) casedist_thms - end; - - -(*************************** primrec combinators ******************************) - -fun prove_primrec_thms (config : config) new_type_names descr sorts - (dt_info : info Symtab.table) constr_inject dist_rewrites dist_ss induct thy = - let - val _ = message config "Constructing primrec combinators ..."; - - val big_name = space_implode "_" new_type_names; - val thy0 = add_path (#flat_names config) big_name thy; - - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; - val newTs = Library.take (length (hd descr), recTs); - - val induct_Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of induct))); - - val big_rec_name' = big_name ^ "_rec_set"; - val rec_set_names' = - if length descr' = 1 then [big_rec_name'] else - map ((curry (op ^) (big_rec_name' ^ "_")) o string_of_int) - (1 upto (length descr')); - val rec_set_names = map (Sign.full_bname thy0) rec_set_names'; - - val (rec_result_Ts, reccomb_fn_Ts) = DatatypeProp.make_primrec_Ts descr sorts used; - - val rec_set_Ts = map (fn (T1, T2) => - reccomb_fn_Ts @ [T1, T2] ---> HOLogic.boolT) (recTs ~~ rec_result_Ts); - - val rec_fns = map (uncurry (mk_Free "f")) - (reccomb_fn_Ts ~~ (1 upto (length reccomb_fn_Ts))); - val rec_sets' = map (fn c => list_comb (Free c, rec_fns)) - (rec_set_names' ~~ rec_set_Ts); - val rec_sets = map (fn c => list_comb (Const c, rec_fns)) - (rec_set_names ~~ rec_set_Ts); - - (* introduction rules for graph of primrec function *) - - fun make_rec_intr T rec_set ((rec_intr_ts, l), (cname, cargs)) = - let - fun mk_prem ((dt, U), (j, k, prems, t1s, t2s)) = - let val free1 = mk_Free "x" U j - in (case (strip_dtyp dt, strip_type U) of - ((_, DtRec m), (Us, _)) => - let - val free2 = mk_Free "y" (Us ---> List.nth (rec_result_Ts, m)) k; - val i = length Us - in (j + 1, k + 1, HOLogic.mk_Trueprop (HOLogic.list_all - (map (pair "x") Us, List.nth (rec_sets', m) $ - app_bnds free1 i $ app_bnds free2 i)) :: prems, - free1::t1s, free2::t2s) - end - | _ => (j + 1, k, prems, free1::t1s, t2s)) - end; - - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val (_, _, prems, t1s, t2s) = List.foldr mk_prem (1, 1, [], [], []) (cargs ~~ Ts) - - in (rec_intr_ts @ [Logic.list_implies (prems, HOLogic.mk_Trueprop - (rec_set $ list_comb (Const (cname, Ts ---> T), t1s) $ - list_comb (List.nth (rec_fns, l), t1s @ t2s)))], l + 1) - end; - - val (rec_intr_ts, _) = Library.foldl (fn (x, ((d, T), set_name)) => - Library.foldl (make_rec_intr T set_name) (x, #3 (snd d))) - (([], 0), descr' ~~ recTs ~~ rec_sets'); - - val ({intrs = rec_intrs, elims = rec_elims, ...}, thy1) = - Inductive.add_inductive_global (serial_string ()) - {quiet_mode = #quiet config, verbose = false, kind = Thm.internalK, - alt_name = Binding.name big_rec_name', coind = false, no_elim = false, no_ind = true, - skip_mono = true, fork_mono = false} - (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (rec_set_names' ~~ rec_set_Ts)) - (map dest_Free rec_fns) - (map (fn x => (Attrib.empty_binding, x)) rec_intr_ts) [] thy0; - - (* prove uniqueness and termination of primrec combinators *) - - val _ = message config "Proving termination and uniqueness of primrec functions ..."; - - fun mk_unique_tac ((tac, intrs), ((((i, (tname, _, constrs)), elim), T), T')) = - let - val distinct_tac = - (if i < length newTs then - full_simp_tac (HOL_ss addsimps (List.nth (dist_rewrites, i))) 1 - else full_simp_tac dist_ss 1); - - val inject = map (fn r => r RS iffD1) - (if i < length newTs then List.nth (constr_inject, i) - else #inject (the (Symtab.lookup dt_info tname))); - - fun mk_unique_constr_tac n ((tac, intr::intrs, j), (cname, cargs)) = - let - val k = length (List.filter is_rec_type cargs) - - in (EVERY [DETERM tac, - REPEAT (etac ex1E 1), rtac ex1I 1, - DEPTH_SOLVE_1 (ares_tac [intr] 1), - REPEAT_DETERM_N k (etac thin_rl 1 THEN rotate_tac 1 1), - etac elim 1, - REPEAT_DETERM_N j distinct_tac, - TRY (dresolve_tac inject 1), - REPEAT (etac conjE 1), hyp_subst_tac 1, - REPEAT (EVERY [etac allE 1, dtac mp 1, atac 1]), - TRY (hyp_subst_tac 1), - rtac refl 1, - REPEAT_DETERM_N (n - j - 1) distinct_tac], - intrs, j + 1) - end; - - val (tac', intrs', _) = Library.foldl (mk_unique_constr_tac (length constrs)) - ((tac, intrs, 0), constrs); - - in (tac', intrs') end; - - val rec_unique_thms = - let - val rec_unique_ts = map (fn (((set_t, T1), T2), i) => - Const ("Ex1", (T2 --> HOLogic.boolT) --> HOLogic.boolT) $ - absfree ("y", T2, set_t $ mk_Free "x" T1 i $ Free ("y", T2))) - (rec_sets ~~ recTs ~~ rec_result_Ts ~~ (1 upto length recTs)); - val cert = cterm_of thy1 - val insts = map (fn ((i, T), t) => absfree ("x" ^ (string_of_int i), T, t)) - ((1 upto length recTs) ~~ recTs ~~ rec_unique_ts); - val induct' = cterm_instantiate ((map cert induct_Ps) ~~ - (map cert insts)) induct; - val (tac, _) = Library.foldl mk_unique_tac - (((rtac induct' THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1 - THEN rewrite_goals_tac [mk_meta_eq choice_eq], rec_intrs), - descr' ~~ rec_elims ~~ recTs ~~ rec_result_Ts); - - in split_conj_thm (SkipProof.prove_global thy1 [] [] - (HOLogic.mk_Trueprop (mk_conj rec_unique_ts)) (K tac)) - end; - - val rec_total_thms = map (fn r => r RS theI') rec_unique_thms; - - (* define primrec combinators *) - - val big_reccomb_name = (space_implode "_" new_type_names) ^ "_rec"; - val reccomb_names = map (Sign.full_bname thy1) - (if length descr' = 1 then [big_reccomb_name] else - (map ((curry (op ^) (big_reccomb_name ^ "_")) o string_of_int) - (1 upto (length descr')))); - val reccombs = map (fn ((name, T), T') => list_comb - (Const (name, reccomb_fn_Ts @ [T] ---> T'), rec_fns)) - (reccomb_names ~~ recTs ~~ rec_result_Ts); - - val (reccomb_defs, thy2) = - thy1 - |> Sign.add_consts_i (map (fn ((name, T), T') => - (Binding.name (Long_Name.base_name name), reccomb_fn_Ts @ [T] ---> T', NoSyn)) - (reccomb_names ~~ recTs ~~ rec_result_Ts)) - |> (PureThy.add_defs false o map Thm.no_attributes) (map (fn ((((name, comb), set), T), T') => - (Binding.name (Long_Name.base_name name ^ "_def"), Logic.mk_equals (comb, absfree ("x", T, - Const ("The", (T' --> HOLogic.boolT) --> T') $ absfree ("y", T', - set $ Free ("x", T) $ Free ("y", T')))))) - (reccomb_names ~~ reccombs ~~ rec_sets ~~ recTs ~~ rec_result_Ts)) - ||> parent_path (#flat_names config) - ||> Theory.checkpoint; - - - (* prove characteristic equations for primrec combinators *) - - val _ = message config "Proving characteristic theorems for primrec combinators ..." - - val rec_thms = map (fn t => SkipProof.prove_global thy2 [] [] t - (fn _ => EVERY - [rewrite_goals_tac reccomb_defs, - rtac the1_equality 1, - resolve_tac rec_unique_thms 1, - resolve_tac rec_intrs 1, - REPEAT (rtac allI 1 ORELSE resolve_tac rec_total_thms 1)])) - (DatatypeProp.make_primrecs new_type_names descr sorts thy2) - - in - thy2 - |> Sign.add_path (space_implode "_" new_type_names) - |> PureThy.add_thmss [((Binding.name "recs", rec_thms), - [Nitpick_Const_Simp_Thms.add])] - ||> Sign.parent_path - ||> Theory.checkpoint - |-> (fn thms => pair (reccomb_names, Library.flat thms)) - end; - - -(***************************** case combinators *******************************) - -fun prove_case_thms (config : config) new_type_names descr sorts reccomb_names primrec_thms thy = - let - val _ = message config "Proving characteristic theorems for case combinators ..."; - - val thy1 = add_path (#flat_names config) (space_implode "_" new_type_names) thy; - - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; - val newTs = Library.take (length (hd descr), recTs); - val T' = TFree (Name.variant used "'t", HOLogic.typeS); - - fun mk_dummyT dt = binder_types (typ_of_dtyp descr' sorts dt) ---> T'; - - val case_dummy_fns = map (fn (_, (_, _, constrs)) => map (fn (_, cargs) => - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val Ts' = map mk_dummyT (List.filter is_rec_type cargs) - in Const (@{const_name undefined}, Ts @ Ts' ---> T') - end) constrs) descr'; - - val case_names = map (fn s => Sign.full_bname thy1 (s ^ "_case")) new_type_names; - - (* define case combinators via primrec combinators *) - - val (case_defs, thy2) = Library.foldl (fn ((defs, thy), - ((((i, (_, _, constrs)), T), name), recname)) => - let - val (fns1, fns2) = ListPair.unzip (map (fn ((_, cargs), j) => - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val Ts' = Ts @ map mk_dummyT (List.filter is_rec_type cargs); - val frees' = map (uncurry (mk_Free "x")) (Ts' ~~ (1 upto length Ts')); - val frees = Library.take (length cargs, frees'); - val free = mk_Free "f" (Ts ---> T') j - in - (free, list_abs_free (map dest_Free frees', - list_comb (free, frees))) - end) (constrs ~~ (1 upto length constrs))); - - val caseT = (map (snd o dest_Free) fns1) @ [T] ---> T'; - val fns = (List.concat (Library.take (i, case_dummy_fns))) @ - fns2 @ (List.concat (Library.drop (i + 1, case_dummy_fns))); - val reccomb = Const (recname, (map fastype_of fns) @ [T] ---> T'); - val decl = ((Binding.name (Long_Name.base_name name), caseT), NoSyn); - val def = (Binding.name (Long_Name.base_name name ^ "_def"), - Logic.mk_equals (list_comb (Const (name, caseT), fns1), - list_comb (reccomb, (List.concat (Library.take (i, case_dummy_fns))) @ - fns2 @ (List.concat (Library.drop (i + 1, case_dummy_fns))) ))); - val ([def_thm], thy') = - thy - |> Sign.declare_const [] decl |> snd - |> (PureThy.add_defs false o map Thm.no_attributes) [def]; - - in (defs @ [def_thm], thy') - end) (([], thy1), (hd descr) ~~ newTs ~~ case_names ~~ - (Library.take (length newTs, reccomb_names))) - ||> Theory.checkpoint; - - val case_thms = map (map (fn t => SkipProof.prove_global thy2 [] [] t - (fn _ => EVERY [rewrite_goals_tac (case_defs @ map mk_meta_eq primrec_thms), rtac refl 1]))) - (DatatypeProp.make_cases new_type_names descr sorts thy2) - in - thy2 - |> Context.the_theory o fold (fold Nitpick_Const_Simp_Thms.add_thm) case_thms - o Context.Theory - |> parent_path (#flat_names config) - |> store_thmss "cases" new_type_names case_thms - |-> (fn thmss => pair (thmss, case_names)) - end; - - -(******************************* case splitting *******************************) - -fun prove_split_thms (config : config) new_type_names descr sorts constr_inject dist_rewrites - casedist_thms case_thms thy = - let - val _ = message config "Proving equations for case splitting ..."; - - val descr' = flat descr; - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - - fun prove_split_thms ((((((t1, t2), inject), dist_rewrites'), - exhaustion), case_thms'), T) = - let - val cert = cterm_of thy; - val _ $ (_ $ lhs $ _) = hd (Logic.strip_assums_hyp (hd (prems_of exhaustion))); - val exhaustion' = cterm_instantiate - [(cert lhs, cert (Free ("x", T)))] exhaustion; - val tacf = K (EVERY [rtac exhaustion' 1, ALLGOALS (asm_simp_tac - (HOL_ss addsimps (dist_rewrites' @ inject @ case_thms')))]) - in - (SkipProof.prove_global thy [] [] t1 tacf, - SkipProof.prove_global thy [] [] t2 tacf) - end; - - val split_thm_pairs = map prove_split_thms - ((DatatypeProp.make_splits new_type_names descr sorts thy) ~~ constr_inject ~~ - dist_rewrites ~~ casedist_thms ~~ case_thms ~~ newTs); - - val (split_thms, split_asm_thms) = ListPair.unzip split_thm_pairs - - in - thy - |> store_thms "split" new_type_names split_thms - ||>> store_thms "split_asm" new_type_names split_asm_thms - |-> (fn (thms1, thms2) => pair (thms1 ~~ thms2)) - end; - -fun prove_weak_case_congs new_type_names descr sorts thy = - let - fun prove_weak_case_cong t = - SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) - (fn {prems, ...} => EVERY [rtac ((hd prems) RS arg_cong) 1]) - - val weak_case_congs = map prove_weak_case_cong (DatatypeProp.make_weak_case_congs - new_type_names descr sorts thy) - - in thy |> store_thms "weak_case_cong" new_type_names weak_case_congs end; - -(************************* additional theorems for TFL ************************) - -fun prove_nchotomys (config : config) new_type_names descr sorts casedist_thms thy = - let - val _ = message config "Proving additional theorems for TFL ..."; - - fun prove_nchotomy (t, exhaustion) = - let - (* For goal i, select the correct disjunct to attack, then prove it *) - fun tac i 0 = EVERY [TRY (rtac disjI1 i), - hyp_subst_tac i, REPEAT (rtac exI i), rtac refl i] - | tac i n = rtac disjI2 i THEN tac i (n - 1) - in - SkipProof.prove_global thy [] [] t (fn _ => - EVERY [rtac allI 1, - exh_tac (K exhaustion) 1, - ALLGOALS (fn i => tac i (i-1))]) - end; - - val nchotomys = - map prove_nchotomy (DatatypeProp.make_nchotomys descr sorts ~~ casedist_thms) - - in thy |> store_thms "nchotomy" new_type_names nchotomys end; - -fun prove_case_congs new_type_names descr sorts nchotomys case_thms thy = - let - fun prove_case_cong ((t, nchotomy), case_rewrites) = - let - val (Const ("==>", _) $ tm $ _) = t; - val (Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ Ma)) = tm; - val cert = cterm_of thy; - val nchotomy' = nchotomy RS spec; - val [v] = Term.add_vars (concl_of nchotomy') []; - val nchotomy'' = cterm_instantiate [(cert (Var v), cert Ma)] nchotomy' - in - SkipProof.prove_global thy [] (Logic.strip_imp_prems t) (Logic.strip_imp_concl t) - (fn {prems, ...} => - let val simplify = asm_simp_tac (HOL_ss addsimps (prems @ case_rewrites)) - in EVERY [simp_tac (HOL_ss addsimps [hd prems]) 1, - cut_facts_tac [nchotomy''] 1, - REPEAT (etac disjE 1 THEN REPEAT (etac exE 1) THEN simplify 1), - REPEAT (etac exE 1) THEN simplify 1 (* Get last disjunct *)] - end) - end; - - val case_congs = map prove_case_cong (DatatypeProp.make_case_congs - new_type_names descr sorts thy ~~ nchotomys ~~ case_thms) - - in thy |> store_thms "case_cong" new_type_names case_congs end; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_aux.ML --- a/src/HOL/Tools/datatype_package/datatype_aux.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,381 +0,0 @@ -(* Title: HOL/Tools/datatype_aux.ML - Author: Stefan Berghofer, TU Muenchen - -Auxiliary functions for defining datatypes. -*) - -signature DATATYPE_COMMON = -sig - type config - val default_config : config - datatype dtyp = - DtTFree of string - | DtType of string * (dtyp list) - | DtRec of int; - type descr - type info -end - -signature DATATYPE_AUX = -sig - include DATATYPE_COMMON - - val message : config -> string -> unit - - val add_path : bool -> string -> theory -> theory - val parent_path : bool -> theory -> theory - - val store_thmss_atts : string -> string list -> attribute list list -> thm list list - -> theory -> thm list list * theory - val store_thmss : string -> string list -> thm list list -> theory -> thm list list * theory - val store_thms_atts : string -> string list -> attribute list list -> thm list - -> theory -> thm list * theory - val store_thms : string -> string list -> thm list -> theory -> thm list * theory - - val split_conj_thm : thm -> thm list - val mk_conj : term list -> term - val mk_disj : term list -> term - - val app_bnds : term -> int -> term - - val cong_tac : int -> tactic - val indtac : thm -> string list -> int -> tactic - val exh_tac : (string -> thm) -> int -> tactic - - datatype simproc_dist = FewConstrs of thm list - | ManyConstrs of thm * simpset; - - - exception Datatype - exception Datatype_Empty of string - val name_of_typ : typ -> string - val dtyp_of_typ : (string * string list) list -> typ -> dtyp - val mk_Free : string -> typ -> int -> term - val is_rec_type : dtyp -> bool - val typ_of_dtyp : descr -> (string * sort) list -> dtyp -> typ - val dest_DtTFree : dtyp -> string - val dest_DtRec : dtyp -> int - val strip_dtyp : dtyp -> dtyp list * dtyp - val body_index : dtyp -> int - val mk_fun_dtyp : dtyp list -> dtyp -> dtyp - val get_nonrec_types : descr -> (string * sort) list -> typ list - val get_branching_types : descr -> (string * sort) list -> typ list - val get_arities : descr -> int list - val get_rec_types : descr -> (string * sort) list -> typ list - val interpret_construction : descr -> (string * sort) list - -> { atyp: typ -> 'a, dtyp: typ list -> int * bool -> string * typ list -> 'a } - -> ((string * Term.typ list) * (string * 'a list) list) list - val check_nonempty : descr list -> unit - val unfold_datatypes : - theory -> descr -> (string * sort) list -> info Symtab.table -> - descr -> int -> descr list * int -end; - -structure DatatypeAux : DATATYPE_AUX = -struct - -(* datatype option flags *) - -type config = { strict: bool, flat_names: bool, quiet: bool }; -val default_config : config = - { strict = true, flat_names = false, quiet = false }; -fun message ({ quiet, ...} : config) s = - if quiet then () else writeln s; - -fun add_path flat_names s = if flat_names then I else Sign.add_path s; -fun parent_path flat_names = if flat_names then I else Sign.parent_path; - - -(* store theorems in theory *) - -fun store_thmss_atts label tnames attss thmss = - fold_map (fn ((tname, atts), thms) => - Sign.add_path tname - #> PureThy.add_thmss [((Binding.name label, thms), atts)] - #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss) - ##> Theory.checkpoint; - -fun store_thmss label tnames = store_thmss_atts label tnames (replicate (length tnames) []); - -fun store_thms_atts label tnames attss thmss = - fold_map (fn ((tname, atts), thms) => - Sign.add_path tname - #> PureThy.add_thms [((Binding.name label, thms), atts)] - #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss) - ##> Theory.checkpoint; - -fun store_thms label tnames = store_thms_atts label tnames (replicate (length tnames) []); - - -(* split theorem thm_1 & ... & thm_n into n theorems *) - -fun split_conj_thm th = - ((th RS conjunct1)::(split_conj_thm (th RS conjunct2))) handle THM _ => [th]; - -val mk_conj = foldr1 (HOLogic.mk_binop "op &"); -val mk_disj = foldr1 (HOLogic.mk_binop "op |"); - -fun app_bnds t i = list_comb (t, map Bound (i - 1 downto 0)); - - -fun cong_tac i st = (case Logic.strip_assums_concl - (List.nth (prems_of st, i - 1)) of - _ $ (_ $ (f $ x) $ (g $ y)) => - let - val cong' = Thm.lift_rule (Thm.cprem_of st i) cong; - val _ $ (_ $ (f' $ x') $ (g' $ y')) = - Logic.strip_assums_concl (prop_of cong'); - val insts = map (pairself (cterm_of (Thm.theory_of_thm st)) o - apsnd (curry list_abs (Logic.strip_params (concl_of cong'))) o - apfst head_of) [(f', f), (g', g), (x', x), (y', y)] - in compose_tac (false, cterm_instantiate insts cong', 2) i st - handle THM _ => no_tac st - end - | _ => no_tac st); - -(* instantiate induction rule *) - -fun indtac indrule indnames i st = - let - val ts = HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule)); - val ts' = HOLogic.dest_conj (HOLogic.dest_Trueprop - (Logic.strip_imp_concl (List.nth (prems_of st, i - 1)))); - val getP = if can HOLogic.dest_imp (hd ts) then - (apfst SOME) o HOLogic.dest_imp else pair NONE; - val flt = if null indnames then I else - filter (fn Free (s, _) => s mem indnames | _ => false); - fun abstr (t1, t2) = (case t1 of - NONE => (case flt (OldTerm.term_frees t2) of - [Free (s, T)] => SOME (absfree (s, T, t2)) - | _ => NONE) - | SOME (_ $ t') => SOME (Abs ("x", fastype_of t', abstract_over (t', t2)))) - val cert = cterm_of (Thm.theory_of_thm st); - val insts = List.mapPartial (fn (t, u) => case abstr (getP u) of - NONE => NONE - | SOME u' => SOME (t |> getP |> snd |> head_of |> cert, cert u')) (ts ~~ ts'); - val indrule' = cterm_instantiate insts indrule - in - rtac indrule' i st - end; - -(* perform exhaustive case analysis on last parameter of subgoal i *) - -fun exh_tac exh_thm_of i state = - let - val thy = Thm.theory_of_thm state; - val prem = nth (prems_of state) (i - 1); - val params = Logic.strip_params prem; - val (_, Type (tname, _)) = hd (rev params); - val exhaustion = Thm.lift_rule (Thm.cprem_of state i) (exh_thm_of tname); - val prem' = hd (prems_of exhaustion); - val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem')); - val exhaustion' = cterm_instantiate [(cterm_of thy (head_of lhs), - cterm_of thy (List.foldr (fn ((_, T), t) => Abs ("z", T, t)) - (Bound 0) params))] exhaustion - in compose_tac (false, exhaustion', nprems_of exhaustion) i state - end; - -(* handling of distinctness theorems *) - -datatype simproc_dist = FewConstrs of thm list - | ManyConstrs of thm * simpset; - -(********************** Internal description of datatypes *********************) - -datatype dtyp = - DtTFree of string - | DtType of string * (dtyp list) - | DtRec of int; - -(* information about datatypes *) - -(* index, datatype name, type arguments, constructor name, types of constructor's arguments *) -type descr = (int * (string * dtyp list * (string * dtyp list) list)) list; - -type info = - {index : int, - alt_names : string list option, - descr : descr, - sorts : (string * sort) list, - rec_names : string list, - rec_rewrites : thm list, - case_name : string, - case_rewrites : thm list, - induction : thm, - exhaustion : thm, - distinct : simproc_dist, - inject : thm list, - nchotomy : thm, - case_cong : thm, - weak_case_cong : thm}; - -fun mk_Free s T i = Free (s ^ (string_of_int i), T); - -fun subst_DtTFree _ substs (T as (DtTFree name)) = - AList.lookup (op =) substs name |> the_default T - | subst_DtTFree i substs (DtType (name, ts)) = - DtType (name, map (subst_DtTFree i substs) ts) - | subst_DtTFree i _ (DtRec j) = DtRec (i + j); - -exception Datatype; -exception Datatype_Empty of string; - -fun dest_DtTFree (DtTFree a) = a - | dest_DtTFree _ = raise Datatype; - -fun dest_DtRec (DtRec i) = i - | dest_DtRec _ = raise Datatype; - -fun is_rec_type (DtType (_, dts)) = exists is_rec_type dts - | is_rec_type (DtRec _) = true - | is_rec_type _ = false; - -fun strip_dtyp (DtType ("fun", [T, U])) = apfst (cons T) (strip_dtyp U) - | strip_dtyp T = ([], T); - -val body_index = dest_DtRec o snd o strip_dtyp; - -fun mk_fun_dtyp [] U = U - | mk_fun_dtyp (T :: Ts) U = DtType ("fun", [T, mk_fun_dtyp Ts U]); - -fun name_of_typ (Type (s, Ts)) = - let val s' = Long_Name.base_name s - in space_implode "_" (List.filter (not o equal "") (map name_of_typ Ts) @ - [if Syntax.is_identifier s' then s' else "x"]) - end - | name_of_typ _ = ""; - -fun dtyp_of_typ _ (TFree (n, _)) = DtTFree n - | dtyp_of_typ _ (TVar _) = error "Illegal schematic type variable(s)" - | dtyp_of_typ new_dts (Type (tname, Ts)) = - (case AList.lookup (op =) new_dts tname of - NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts) - | SOME vs => if map (try (fst o dest_TFree)) Ts = map SOME vs then - DtRec (find_index (curry op = tname o fst) new_dts) - else error ("Illegal occurrence of recursive type " ^ tname)); - -fun typ_of_dtyp descr sorts (DtTFree a) = TFree (a, (the o AList.lookup (op =) sorts) a) - | typ_of_dtyp descr sorts (DtRec i) = - let val (s, ds, _) = (the o AList.lookup (op =) descr) i - in Type (s, map (typ_of_dtyp descr sorts) ds) end - | typ_of_dtyp descr sorts (DtType (s, ds)) = - Type (s, map (typ_of_dtyp descr sorts) ds); - -(* find all non-recursive types in datatype description *) - -fun get_nonrec_types descr sorts = - map (typ_of_dtyp descr sorts) (Library.foldl (fn (Ts, (_, (_, _, constrs))) => - Library.foldl (fn (Ts', (_, cargs)) => - filter_out is_rec_type cargs union Ts') (Ts, constrs)) ([], descr)); - -(* get all recursive types in datatype description *) - -fun get_rec_types descr sorts = map (fn (_ , (s, ds, _)) => - Type (s, map (typ_of_dtyp descr sorts) ds)) descr; - -(* get all branching types *) - -fun get_branching_types descr sorts = - map (typ_of_dtyp descr sorts) (fold (fn (_, (_, _, constrs)) => - fold (fn (_, cargs) => fold (strip_dtyp #> fst #> fold (insert op =)) cargs) - constrs) descr []); - -fun get_arities descr = fold (fn (_, (_, _, constrs)) => - fold (fn (_, cargs) => fold (insert op =) (map (length o fst o strip_dtyp) - (List.filter is_rec_type cargs))) constrs) descr []; - -(* interpret construction of datatype *) - -fun interpret_construction descr vs { atyp, dtyp } = - let - val typ_of_dtyp = typ_of_dtyp descr vs; - fun interpT dT = case strip_dtyp dT - of (dTs, DtRec l) => - let - val (tyco, dTs', _) = (the o AList.lookup (op =) descr) l; - val Ts = map typ_of_dtyp dTs; - val Ts' = map typ_of_dtyp dTs'; - val is_proper = forall (can dest_TFree) Ts'; - in dtyp Ts (l, is_proper) (tyco, Ts') end - | _ => atyp (typ_of_dtyp dT); - fun interpC (c, dTs) = (c, map interpT dTs); - fun interpD (_, (tyco, dTs, cs)) = ((tyco, map typ_of_dtyp dTs), map interpC cs); - in map interpD descr end; - -(* nonemptiness check for datatypes *) - -fun check_nonempty descr = - let - val descr' = List.concat descr; - fun is_nonempty_dt is i = - let - val (_, _, constrs) = (the o AList.lookup (op =) descr') i; - fun arg_nonempty (_, DtRec i) = if i mem is then false - else is_nonempty_dt (i::is) i - | arg_nonempty _ = true; - in exists ((forall (arg_nonempty o strip_dtyp)) o snd) constrs - end - in assert_all (fn (i, _) => is_nonempty_dt [i] i) (hd descr) - (fn (_, (s, _, _)) => raise Datatype_Empty s) - end; - -(* unfold a list of mutually recursive datatype specifications *) -(* all types of the form DtType (dt_name, [..., DtRec _, ...]) *) -(* need to be unfolded *) - -fun unfold_datatypes sign orig_descr sorts (dt_info : info Symtab.table) descr i = - let - fun typ_error T msg = error ("Non-admissible type expression\n" ^ - Syntax.string_of_typ_global sign (typ_of_dtyp (orig_descr @ descr) sorts T) ^ "\n" ^ msg); - - fun get_dt_descr T i tname dts = - (case Symtab.lookup dt_info tname of - NONE => typ_error T (tname ^ " is not a datatype - can't use it in\ - \ nested recursion") - | (SOME {index, descr, ...}) => - let val (_, vars, _) = (the o AList.lookup (op =) descr) index; - val subst = ((map dest_DtTFree vars) ~~ dts) handle Library.UnequalLengths => - typ_error T ("Type constructor " ^ tname ^ " used with wrong\ - \ number of arguments") - in (i + index, map (fn (j, (tn, args, cs)) => (i + j, - (tn, map (subst_DtTFree i subst) args, - map (apsnd (map (subst_DtTFree i subst))) cs))) descr) - end); - - (* unfold a single constructor argument *) - - fun unfold_arg ((i, Ts, descrs), T) = - if is_rec_type T then - let val (Us, U) = strip_dtyp T - in if exists is_rec_type Us then - typ_error T "Non-strictly positive recursive occurrence of type" - else (case U of - DtType (tname, dts) => - let - val (index, descr) = get_dt_descr T i tname dts; - val (descr', i') = unfold_datatypes sign orig_descr sorts - dt_info descr (i + length descr) - in (i', Ts @ [mk_fun_dtyp Us (DtRec index)], descrs @ descr') end - | _ => (i, Ts @ [T], descrs)) - end - else (i, Ts @ [T], descrs); - - (* unfold a constructor *) - - fun unfold_constr ((i, constrs, descrs), (cname, cargs)) = - let val (i', cargs', descrs') = Library.foldl unfold_arg ((i, [], descrs), cargs) - in (i', constrs @ [(cname, cargs')], descrs') end; - - (* unfold a single datatype *) - - fun unfold_datatype ((i, dtypes, descrs), (j, (tname, tvars, constrs))) = - let val (i', constrs', descrs') = - Library.foldl unfold_constr ((i, [], descrs), constrs) - in (i', dtypes @ [(j, (tname, tvars, constrs'))], descrs') - end; - - val (i', descr', descrs) = Library.foldl unfold_datatype ((i, [],[]), descr); - - in (descr' :: descrs, i') end; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_case.ML --- a/src/HOL/Tools/datatype_package/datatype_case.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,469 +0,0 @@ -(* Title: HOL/Tools/datatype_case.ML - Author: Konrad Slind, Cambridge University Computer Laboratory - Author: Stefan Berghofer, TU Muenchen - -Nested case expressions on datatypes. -*) - -signature DATATYPE_CASE = -sig - val make_case: (string -> DatatypeAux.info option) -> - Proof.context -> bool -> string list -> term -> (term * term) list -> - term * (term * (int * bool)) list - val dest_case: (string -> DatatypeAux.info option) -> bool -> - string list -> term -> (term * (term * term) list) option - val strip_case: (string -> DatatypeAux.info option) -> bool -> - term -> (term * (term * term) list) option - val case_tr: bool -> (theory -> string -> DatatypeAux.info option) - -> Proof.context -> term list -> term - val case_tr': (theory -> string -> DatatypeAux.info option) -> - string -> Proof.context -> term list -> term -end; - -structure DatatypeCase : DATATYPE_CASE = -struct - -exception CASE_ERROR of string * int; - -fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty; - -(*--------------------------------------------------------------------------- - * Get information about datatypes - *---------------------------------------------------------------------------*) - -fun ty_info (tab : string -> DatatypeAux.info option) s = - case tab s of - SOME {descr, case_name, index, sorts, ...} => - let - val (_, (tname, dts, constrs)) = nth descr index; - val mk_ty = DatatypeAux.typ_of_dtyp descr sorts; - val T = Type (tname, map mk_ty dts) - in - SOME {case_name = case_name, - constructors = map (fn (cname, dts') => - Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs} - end - | NONE => NONE; - - -(*--------------------------------------------------------------------------- - * Each pattern carries with it a tag (i,b) where - * i is the clause it came from and - * b=true indicates that clause was given by the user - * (or is an instantiation of a user supplied pattern) - * b=false --> i = ~1 - *---------------------------------------------------------------------------*) - -fun pattern_subst theta (tm, x) = (subst_free theta tm, x); - -fun row_of_pat x = fst (snd x); - -fun add_row_used ((prfx, pats), (tm, tag)) = - fold Term.add_free_names (tm :: pats @ prfx); - -(* try to preserve names given by user *) -fun default_names names ts = - map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts); - -fun strip_constraints (Const ("_constrain", _) $ t $ tT) = - strip_constraints t ||> cons tT - | strip_constraints t = (t, []); - -fun mk_fun_constrain tT t = Syntax.const "_constrain" $ t $ - (Syntax.free "fun" $ tT $ Syntax.free "dummy"); - - -(*--------------------------------------------------------------------------- - * Produce an instance of a constructor, plus genvars for its arguments. - *---------------------------------------------------------------------------*) -fun fresh_constr ty_match ty_inst colty used c = - let - val (_, Ty) = dest_Const c - val Ts = binder_types Ty; - val names = Name.variant_list used - (DatatypeProp.make_tnames (map Logic.unvarifyT Ts)); - val ty = body_type Ty; - val ty_theta = ty_match ty colty handle Type.TYPE_MATCH => - raise CASE_ERROR ("type mismatch", ~1) - val c' = ty_inst ty_theta c - val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts) - in (c', gvars) - end; - - -(*--------------------------------------------------------------------------- - * Goes through a list of rows and picks out the ones beginning with a - * pattern with constructor = name. - *---------------------------------------------------------------------------*) -fun mk_group (name, T) rows = - let val k = length (binder_types T) - in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) => - fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of - (Const (name', _), args) => - if name = name' then - if length args = k then - let val (args', cnstrts') = split_list (map strip_constraints args) - in - ((((prfx, args' @ rst), rhs) :: in_group, not_in_group), - (default_names names args', map2 append cnstrts cnstrts')) - end - else raise CASE_ERROR - ("Wrong number of arguments for constructor " ^ name, i) - else ((in_group, row :: not_in_group), (names, cnstrts)) - | _ => raise CASE_ERROR ("Not a constructor pattern", i))) - rows (([], []), (replicate k "", replicate k [])) |>> pairself rev - end; - -(*--------------------------------------------------------------------------- - * Partition the rows. Not efficient: we should use hashing. - *---------------------------------------------------------------------------*) -fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) - | partition ty_match ty_inst type_of used constructors colty res_ty - (rows as (((prfx, _ :: rstp), _) :: _)) = - let - fun part {constrs = [], rows = [], A} = rev A - | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} = - raise CASE_ERROR ("Not a constructor pattern", i) - | part {constrs = c :: crst, rows, A} = - let - val ((in_group, not_in_group), (names, cnstrts)) = - mk_group (dest_Const c) rows; - val used' = fold add_row_used in_group used; - val (c', gvars) = fresh_constr ty_match ty_inst colty used' c; - val in_group' = - if null in_group (* Constructor not given *) - then - let - val Ts = map type_of rstp; - val xs = Name.variant_list - (fold Term.add_free_names gvars used') - (replicate (length rstp) "x") - in - [((prfx, gvars @ map Free (xs ~~ Ts)), - (Const ("HOL.undefined", res_ty), (~1, false)))] - end - else in_group - in - part{constrs = crst, - rows = not_in_group, - A = {constructor = c', - new_formals = gvars, - names = names, - constraints = cnstrts, - group = in_group'} :: A} - end - in part {constrs = constructors, rows = rows, A = []} - end; - -(*--------------------------------------------------------------------------- - * Misc. routines used in mk_case - *---------------------------------------------------------------------------*) - -fun mk_pat ((c, c'), l) = - let - val L = length (binder_types (fastype_of c)) - fun build (prfx, tag, plist) = - let val (args, plist') = chop L plist - in (prfx, tag, list_comb (c', args) :: plist') end - in map build l end; - -fun v_to_prfx (prfx, v::pats) = (v::prfx,pats) - | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); - -fun v_to_pats (v::prfx,tag, pats) = (prfx, tag, v::pats) - | v_to_pats _ = raise CASE_ERROR ("mk_case: v_to_pats", ~1); - - -(*---------------------------------------------------------------------------- - * Translation of pattern terms into nested case expressions. - * - * This performs the translation and also builds the full set of patterns. - * Thus it supports the construction of induction theorems even when an - * incomplete set of patterns is given. - *---------------------------------------------------------------------------*) - -fun mk_case tab ctxt ty_match ty_inst type_of used range_ty = - let - val name = Name.variant used "a"; - fun expand constructors used ty ((_, []), _) = - raise CASE_ERROR ("mk_case: expand_var_row", ~1) - | expand constructors used ty (row as ((prfx, p :: rst), rhs)) = - if is_Free p then - let - val used' = add_row_used row used; - fun expnd c = - let val capp = - list_comb (fresh_constr ty_match ty_inst ty used' c) - in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs) - end - in map expnd constructors end - else [row] - fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1) - | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} = (* Done *) - ([(prfx, tag, [])], tm) - | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} = - mk {path = path, rows = [row]} - | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} = - let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows - in case Option.map (apfst head_of) - (find_first (not o is_Free o fst) col0) of - NONE => - let - val rows' = map (fn ((v, _), row) => row ||> - pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows); - val (pref_patl, tm) = mk {path = rstp, rows = rows'} - in (map v_to_pats pref_patl, tm) end - | SOME (Const (cname, cT), i) => (case ty_info tab cname of - NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) - | SOME {case_name, constructors} => - let - val pty = body_type cT; - val used' = fold Term.add_free_names rstp used; - val nrows = maps (expand constructors used' pty) rows; - val subproblems = partition ty_match ty_inst type_of used' - constructors pty range_ty nrows; - val new_formals = map #new_formals subproblems - val constructors' = map #constructor subproblems - val news = map (fn {new_formals, group, ...} => - {path = new_formals @ rstp, rows = group}) subproblems; - val (pat_rect, dtrees) = split_list (map mk news); - val case_functions = map2 - (fn {new_formals, names, constraints, ...} => - fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t => - Abs (if s = "" then name else s, T, - abstract_over (x, t)) |> - fold mk_fun_constrain cnstrts) - (new_formals ~~ names ~~ constraints)) - subproblems dtrees; - val types = map type_of (case_functions @ [u]); - val case_const = Const (case_name, types ---> range_ty) - val tree = list_comb (case_const, case_functions @ [u]) - val pat_rect1 = flat (map mk_pat - (constructors ~~ constructors' ~~ pat_rect)) - in (pat_rect1, tree) - end) - | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^ - Syntax.string_of_term ctxt t, i) - end - | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1) - in mk - end; - -fun case_error s = error ("Error in case expression:\n" ^ s); - -(* Repeated variable occurrences in a pattern are not allowed. *) -fun no_repeat_vars ctxt pat = fold_aterms - (fn x as Free (s, _) => (fn xs => - if member op aconv xs x then - case_error (quote s ^ " occurs repeatedly in the pattern " ^ - quote (Syntax.string_of_term ctxt pat)) - else x :: xs) - | _ => I) pat []; - -fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses = - let - fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt - (Syntax.const "_case1" $ pat $ rhs); - val _ = map (no_repeat_vars ctxt o fst) clauses; - val rows = map_index (fn (i, (pat, rhs)) => - (([], [pat]), (rhs, (i, true)))) clauses; - val rangeT = (case distinct op = (map (type_of o snd) clauses) of - [] => case_error "no clauses given" - | [T] => T - | _ => case_error "all cases must have the same result type"); - val used' = fold add_row_used rows used; - val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of - used' rangeT {path = [x], rows = rows} - handle CASE_ERROR (msg, i) => case_error (msg ^ - (if i < 0 then "" - else "\nIn clause\n" ^ string_of_clause (nth clauses i))); - val patts1 = map - (fn (_, tag, [pat]) => (pat, tag) - | _ => case_error "error in pattern-match translation") patts; - val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1 - val finals = map row_of_pat patts2 - val originals = map (row_of_pat o #2) rows - val _ = case originals \\ finals of - [] => () - | is => (if err then case_error else warning) - ("The following clauses are redundant (covered by preceding clauses):\n" ^ - cat_lines (map (string_of_clause o nth clauses) is)); - in - (case_tm, patts2) - end; - -fun make_case tab ctxt = gen_make_case - (match_type (ProofContext.theory_of ctxt)) Envir.subst_TVars fastype_of tab ctxt; -val make_case_untyped = gen_make_case (K (K Vartab.empty)) - (K (Term.map_types (K dummyT))) (K dummyT); - - -(* parse translation *) - -fun case_tr err tab_of ctxt [t, u] = - let - val thy = ProofContext.theory_of ctxt; - (* replace occurrences of dummy_pattern by distinct variables *) - (* internalize constant names *) - fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used = - let val (t', used') = prep_pat t used - in (c $ t' $ tT, used') end - | prep_pat (Const ("dummy_pattern", T)) used = - let val x = Name.variant used "x" - in (Free (x, T), x :: used) end - | prep_pat (Const (s, T)) used = - (case try (unprefix Syntax.constN) s of - SOME c => (Const (c, T), used) - | NONE => (Const (Sign.intern_const thy s, T), used)) - | prep_pat (v as Free (s, T)) used = - let val s' = Sign.intern_const thy s - in - if Sign.declared_const thy s' then - (Const (s', T), used) - else (v, used) - end - | prep_pat (t $ u) used = - let - val (t', used') = prep_pat t used; - val (u', used'') = prep_pat u used' - in - (t' $ u', used'') - end - | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); - fun dest_case1 (t as Const ("_case1", _) $ l $ r) = - let val (l', cnstrts) = strip_constraints l - in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) - end - | dest_case1 t = case_error "dest_case1"; - fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u - | dest_case2 t = [t]; - val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); - val (case_tm, _) = make_case_untyped (tab_of thy) ctxt err [] - (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT) - (flat cnstrts) t) cases; - in case_tm end - | case_tr _ _ _ ts = case_error "case_tr"; - - -(*--------------------------------------------------------------------------- - * Pretty printing of nested case expressions - *---------------------------------------------------------------------------*) - -(* destruct one level of pattern matching *) - -fun gen_dest_case name_of type_of tab d used t = - case apfst name_of (strip_comb t) of - (SOME cname, ts as _ :: _) => - let - val (fs, x) = split_last ts; - fun strip_abs i t = - let - val zs = strip_abs_vars t; - val _ = if length zs < i then raise CASE_ERROR ("", 0) else (); - val (xs, ys) = chop i zs; - val u = list_abs (ys, strip_abs_body t); - val xs' = map Free (Name.variant_list (OldTerm.add_term_names (u, used)) - (map fst xs) ~~ map snd xs) - in (xs', subst_bounds (rev xs', u)) end; - fun is_dependent i t = - let val k = length (strip_abs_vars t) - i - in k < 0 orelse exists (fn j => j >= k) - (loose_bnos (strip_abs_body t)) - end; - fun count_cases (_, _, true) = I - | count_cases (c, (_, body), false) = - AList.map_default op aconv (body, []) (cons c); - val is_undefined = name_of #> equal (SOME "HOL.undefined"); - fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body) - in case ty_info tab cname of - SOME {constructors, case_name} => - if length fs = length constructors then - let - val cases = map (fn (Const (s, U), t) => - let - val k = length (binder_types U); - val p as (xs, _) = strip_abs k t - in - (Const (s, map type_of xs ---> type_of x), - p, is_dependent k t) - end) (constructors ~~ fs); - val cases' = sort (int_ord o swap o pairself (length o snd)) - (fold_rev count_cases cases []); - val R = type_of t; - val dummy = if d then Const ("dummy_pattern", R) - else Free (Name.variant used "x", R) - in - SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of - SOME (_, cs) => - if length cs = length constructors then [hd cases] - else filter_out (fn (_, (_, body), _) => is_undefined body) cases - | NONE => case cases' of - [] => cases - | (default, cs) :: _ => - if length cs = 1 then cases - else if length cs = length constructors then - [hd cases, (dummy, ([], default), false)] - else - filter_out (fn (c, _, _) => member op aconv cs c) cases @ - [(dummy, ([], default), false)])) - end handle CASE_ERROR _ => NONE - else NONE - | _ => NONE - end - | _ => NONE; - -val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of; -val dest_case' = gen_dest_case - (try (dest_Const #> fst #> unprefix Syntax.constN)) (K dummyT); - - -(* destruct nested patterns *) - -fun strip_case'' dest (pat, rhs) = - case dest (Term.add_free_names pat []) rhs of - SOME (exp as Free _, clauses) => - if member op aconv (OldTerm.term_frees pat) exp andalso - not (exists (fn (_, rhs') => - member op aconv (OldTerm.term_frees rhs') exp) clauses) - then - maps (strip_case'' dest) (map (fn (pat', rhs') => - (subst_free [(exp, pat')] pat, rhs')) clauses) - else [(pat, rhs)] - | _ => [(pat, rhs)]; - -fun gen_strip_case dest t = case dest [] t of - SOME (x, clauses) => - SOME (x, maps (strip_case'' dest) clauses) - | NONE => NONE; - -val strip_case = gen_strip_case oo dest_case; -val strip_case' = gen_strip_case oo dest_case'; - - -(* print translation *) - -fun case_tr' tab_of cname ctxt ts = - let - val thy = ProofContext.theory_of ctxt; - val consts = ProofContext.consts_of ctxt; - fun mk_clause (pat, rhs) = - let val xs = Term.add_frees pat [] - in - Syntax.const "_case1" $ - map_aterms - (fn Free p => Syntax.mark_boundT p - | Const (s, _) => Const (Consts.extern_early consts s, dummyT) - | t => t) pat $ - map_aterms - (fn x as Free (s, T) => - if member (op =) xs (s, T) then Syntax.mark_bound s else x - | t => t) rhs - end - in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of - SOME (x, clauses) => Syntax.const "_case_syntax" $ x $ - foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) - (map mk_clause clauses) - | NONE => raise Match - end; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_codegen.ML --- a/src/HOL/Tools/datatype_package/datatype_codegen.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,455 +0,0 @@ -(* Title: HOL/Tools/datatype_codegen.ML - Author: Stefan Berghofer and Florian Haftmann, TU Muenchen - -Code generator facilities for inductive datatypes. -*) - -signature DATATYPE_CODEGEN = -sig - val find_shortest_path: Datatype.descr -> int -> (string * int) option - val mk_eq_eqns: theory -> string -> (thm * bool) list - val mk_case_cert: theory -> string -> thm - val setup: theory -> theory -end; - -structure DatatypeCodegen : DATATYPE_CODEGEN = -struct - -(** find shortest path to constructor with no recursive arguments **) - -fun find_nonempty (descr: Datatype.descr) is i = - let - val (_, _, constrs) = the (AList.lookup (op =) descr i); - fun arg_nonempty (_, DatatypeAux.DtRec i) = if member (op =) is i - then NONE - else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i) - | arg_nonempty _ = SOME 0; - fun max xs = Library.foldl - (fn (NONE, _) => NONE - | (SOME i, SOME j) => SOME (Int.max (i, j)) - | (_, NONE) => NONE) (SOME 0, xs); - val xs = sort (int_ord o pairself snd) - (map_filter (fn (s, dts) => Option.map (pair s) - (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs) - in case xs of [] => NONE | x :: _ => SOME x end; - -fun find_shortest_path descr i = find_nonempty descr [i] i; - - -(** SML code generator **) - -open Codegen; - -(* datatype definition *) - -fun add_dt_defs thy defs dep module (descr: Datatype.descr) sorts gr = - let - val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr; - val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) => - exists (exists DatatypeAux.is_rec_type o snd) cs) descr'); - - val (_, (tname, _, _)) :: _ = descr'; - val node_id = tname ^ " (type)"; - val module' = if_library (thyname_of_type thy tname) module; - - fun mk_dtdef prfx [] gr = ([], gr) - | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr = - let - val tvs = map DatatypeAux.dest_DtTFree dts; - val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; - val ((_, type_id), gr') = mk_type_id module' tname gr; - val (ps, gr'') = gr' |> - fold_map (fn (cname, cargs) => - fold_map (invoke_tycodegen thy defs node_id module' false) - cargs ##>> - mk_const_id module' cname) cs'; - val (rest, gr''') = mk_dtdef "and " xs gr'' - in - (Pretty.block (str prfx :: - (if null tvs then [] else - [mk_tuple (map str tvs), str " "]) @ - [str (type_id ^ " ="), Pretty.brk 1] @ - List.concat (separate [Pretty.brk 1, str "| "] - (map (fn (ps', (_, cname)) => [Pretty.block - (str cname :: - (if null ps' then [] else - List.concat ([str " of", Pretty.brk 1] :: - separate [str " *", Pretty.brk 1] - (map single ps'))))]) ps))) :: rest, gr''') - end; - - fun mk_constr_term cname Ts T ps = - List.concat (separate [str " $", Pretty.brk 1] - ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1, - mk_type false (Ts ---> T), str ")"] :: ps)); - - fun mk_term_of_def gr prfx [] = [] - | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) = - let - val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; - val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts; - val T = Type (tname, dts'); - val rest = mk_term_of_def gr "and " xs; - val (eqs, _) = fold_map (fn (cname, Ts) => fn prfx => - let val args = map (fn i => - str ("x" ^ string_of_int i)) (1 upto length Ts) - in (Pretty.blk (4, - [str prfx, mk_term_of gr module' false T, Pretty.brk 1, - if null Ts then str (snd (get_const_id gr cname)) - else parens (Pretty.block - [str (snd (get_const_id gr cname)), - Pretty.brk 1, mk_tuple args]), - str " =", Pretty.brk 1] @ - mk_constr_term cname Ts T - (map2 (fn x => fn U => [Pretty.block [mk_term_of gr module' false U, - Pretty.brk 1, x]]) args Ts)), " | ") - end) cs' prfx - in eqs @ rest end; - - fun mk_gen_of_def gr prfx [] = [] - | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) = - let - val tvs = map DatatypeAux.dest_DtTFree dts; - val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts; - val T = Type (tname, Us); - val (cs1, cs2) = - List.partition (exists DatatypeAux.is_rec_type o snd) cs; - val SOME (cname, _) = find_shortest_path descr i; - - fun mk_delay p = Pretty.block - [str "fn () =>", Pretty.brk 1, p]; - - fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"]; - - fun mk_constr s b (cname, dts) = - let - val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s - (DatatypeAux.typ_of_dtyp descr sorts dt)) - [str (if b andalso DatatypeAux.is_rec_type dt then "0" - else "j")]) dts; - val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts; - val xs = map str - (DatatypeProp.indexify_names (replicate (length dts) "x")); - val ts = map str - (DatatypeProp.indexify_names (replicate (length dts) "t")); - val (_, id) = get_const_id gr cname - in - mk_let - (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs) - (mk_tuple - [case xs of - _ :: _ :: _ => Pretty.block - [str id, Pretty.brk 1, mk_tuple xs] - | _ => mk_app false (str id) xs, - mk_delay (Pretty.block (mk_constr_term cname Ts T - (map (single o mk_force) ts)))]) - end; - - fun mk_choice [c] = mk_constr "(i-1)" false c - | mk_choice cs = Pretty.block [str "one_of", - Pretty.brk 1, Pretty.blk (1, str "[" :: - List.concat (separate [str ",", Pretty.fbrk] - (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @ - [str "]"]), Pretty.brk 1, str "()"]; - - val gs = maps (fn s => - let val s' = strip_tname s - in [str (s' ^ "G"), str (s' ^ "T")] end) tvs; - val gen_name = "gen_" ^ snd (get_type_id gr tname) - - in - Pretty.blk (4, separate (Pretty.brk 1) - (str (prfx ^ gen_name ^ - (if null cs1 then "" else "'")) :: gs @ - (if null cs1 then [] else [str "i"]) @ - [str "j"]) @ - [str " =", Pretty.brk 1] @ - (if not (null cs1) andalso not (null cs2) - then [str "frequency", Pretty.brk 1, - Pretty.blk (1, [str "[", - mk_tuple [str "i", mk_delay (mk_choice cs1)], - str ",", Pretty.fbrk, - mk_tuple [str "1", mk_delay (mk_choice cs2)], - str "]"]), Pretty.brk 1, str "()"] - else if null cs2 then - [Pretty.block [str "(case", Pretty.brk 1, - str "i", Pretty.brk 1, str "of", - Pretty.brk 1, str "0 =>", Pretty.brk 1, - mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)), - Pretty.brk 1, str "| _ =>", Pretty.brk 1, - mk_choice cs1, str ")"]] - else [mk_choice cs2])) :: - (if null cs1 then [] - else [Pretty.blk (4, separate (Pretty.brk 1) - (str ("and " ^ gen_name) :: gs @ [str "i"]) @ - [str " =", Pretty.brk 1] @ - separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @ - [str "i", str "i"]))]) @ - mk_gen_of_def gr "and " xs - end - - in - (module', (add_edge_acyclic (node_id, dep) gr - handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ => - let - val gr1 = add_edge (node_id, dep) - (new_node (node_id, (NONE, "", "")) gr); - val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ; - in - map_node node_id (K (NONE, module', - string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @ - [str ";"])) ^ "\n\n" ^ - (if "term_of" mem !mode then - string_of (Pretty.blk (0, separate Pretty.fbrk - (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n" - else "") ^ - (if "test" mem !mode then - string_of (Pretty.blk (0, separate Pretty.fbrk - (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n" - else ""))) gr2 - end) - end; - - -(* case expressions *) - -fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr = - let val i = length constrs - in if length ts <= i then - invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr - else - let - val ts1 = Library.take (i, ts); - val t :: ts2 = Library.drop (i, ts); - val names = List.foldr OldTerm.add_term_names - (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1; - val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T))); - - fun pcase [] [] [] gr = ([], gr) - | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr = - let - val j = length cargs; - val xs = Name.variant_list names (replicate j "x"); - val Us' = Library.take (j, fst (strip_type U)); - val frees = map Free (xs ~~ Us'); - val (cp, gr0) = invoke_codegen thy defs dep module false - (list_comb (Const (cname, Us' ---> dT), frees)) gr; - val t' = Envir.beta_norm (list_comb (t, frees)); - val (p, gr1) = invoke_codegen thy defs dep module false t' gr0; - val (ps, gr2) = pcase cs ts Us gr1; - in - ([Pretty.block [cp, str " =>", Pretty.brk 1, p]] :: ps, gr2) - end; - - val (ps1, gr1) = pcase constrs ts1 Ts gr ; - val ps = List.concat (separate [Pretty.brk 1, str "| "] ps1); - val (p, gr2) = invoke_codegen thy defs dep module false t gr1; - val (ps2, gr3) = fold_map (invoke_codegen thy defs dep module true) ts2 gr2; - in ((if not (null ts2) andalso brack then parens else I) - (Pretty.block (separate (Pretty.brk 1) - (Pretty.block ([str "(case ", p, str " of", - Pretty.brk 1] @ ps @ [str ")"]) :: ps2))), gr3) - end - end; - - -(* constructors *) - -fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr = - let val i = length args - in if i > 1 andalso length ts < i then - invoke_codegen thy defs dep module brack (eta_expand c ts i) gr - else - let - val id = mk_qual_id module (get_const_id gr s); - val (ps, gr') = fold_map - (invoke_codegen thy defs dep module (i = 1)) ts gr; - in (case args of - _ :: _ :: _ => (if brack then parens else I) - (Pretty.block [str id, Pretty.brk 1, mk_tuple ps]) - | _ => (mk_app brack (str id) ps), gr') - end - end; - - -(* code generators for terms and types *) - -fun datatype_codegen thy defs dep module brack t gr = (case strip_comb t of - (c as Const (s, T), ts) => - (case Datatype.datatype_of_case thy s of - SOME {index, descr, ...} => - if is_some (get_assoc_code thy (s, T)) then NONE else - SOME (pretty_case thy defs dep module brack - (#3 (the (AList.lookup op = descr index))) c ts gr ) - | NONE => case (Datatype.datatype_of_constr thy s, strip_type T) of - (SOME {index, descr, ...}, (_, U as Type (tyname, _))) => - if is_some (get_assoc_code thy (s, T)) then NONE else - let - val SOME (tyname', _, constrs) = AList.lookup op = descr index; - val SOME args = AList.lookup op = constrs s - in - if tyname <> tyname' then NONE - else SOME (pretty_constr thy defs - dep module brack args c ts (snd (invoke_tycodegen thy defs dep module false U gr))) - end - | _ => NONE) - | _ => NONE); - -fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr = - (case Datatype.get_datatype thy s of - NONE => NONE - | SOME {descr, sorts, ...} => - if is_some (get_assoc_type thy s) then NONE else - let - val (ps, gr') = fold_map - (invoke_tycodegen thy defs dep module false) Ts gr; - val (module', gr'') = add_dt_defs thy defs dep module descr sorts gr' ; - val (tyid, gr''') = mk_type_id module' s gr'' - in SOME (Pretty.block ((if null Ts then [] else - [mk_tuple ps, str " "]) @ - [str (mk_qual_id module tyid)]), gr''') - end) - | datatype_tycodegen _ _ _ _ _ _ _ = NONE; - - -(** generic code generator **) - -(* liberal addition of code data for datatypes *) - -fun mk_constr_consts thy vs dtco cos = - let - val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos; - val cs' = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) cs; - in if is_some (try (Code.constrset_of_consts thy) cs') - then SOME cs - else NONE - end; - - -(* case certificates *) - -fun mk_case_cert thy tyco = - let - val raw_thms = - (#case_rewrites o Datatype.the_datatype thy) tyco; - val thms as hd_thm :: _ = raw_thms - |> Conjunction.intr_balanced - |> Thm.unvarify - |> Conjunction.elim_balanced (length raw_thms) - |> map Simpdata.mk_meta_eq - |> map Drule.zero_var_indexes - val params = fold_aterms (fn (Free (v, _)) => insert (op =) v - | _ => I) (Thm.prop_of hd_thm) []; - val rhs = hd_thm - |> Thm.prop_of - |> Logic.dest_equals - |> fst - |> Term.strip_comb - |> apsnd (fst o split_last) - |> list_comb; - val lhs = Free (Name.variant params "case", Term.fastype_of rhs); - val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs); - in - thms - |> Conjunction.intr_balanced - |> MetaSimplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm] - |> Thm.implies_intr asm - |> Thm.generalize ([], params) 0 - |> AxClass.unoverload thy - |> Thm.varifyT - end; - - -(* equality *) - -fun mk_eq_eqns thy dtco = - let - val (vs, cos) = Datatype.the_datatype_spec thy dtco; - val { descr, index, inject = inject_thms, ... } = Datatype.the_datatype thy dtco; - val ty = Type (dtco, map TFree vs); - fun mk_eq (t1, t2) = Const (@{const_name eq_class.eq}, ty --> ty --> HOLogic.boolT) - $ t1 $ t2; - fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const); - fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const); - val triv_injects = map_filter - (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty)))) - | _ => NONE) cos; - fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) = - trueprop $ (equiv $ mk_eq (t1, t2) $ rhs); - val injects = map prep_inject (nth (DatatypeProp.make_injs [descr] vs) index); - fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) = - [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)]; - val distincts = maps prep_distinct (snd (nth (DatatypeProp.make_distincts [descr] vs) index)); - val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty))); - val simpset = Simplifier.context (ProofContext.init thy) (HOL_basic_ss - addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms)) - addsimprocs [Datatype.distinct_simproc]); - fun prove prop = SkipProof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset))) - |> Simpdata.mk_eq; - in map (rpair true o prove) (triv_injects @ injects @ distincts) @ [(prove refl, false)] end; - -fun add_equality vs dtcos thy = - let - fun add_def dtco lthy = - let - val ty = Type (dtco, map TFree vs); - fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT) - $ Free ("x", ty) $ Free ("y", ty); - val def = HOLogic.mk_Trueprop (HOLogic.mk_eq - (mk_side @{const_name eq_class.eq}, mk_side @{const_name "op ="})); - val def' = Syntax.check_term lthy def; - val ((_, (_, thm)), lthy') = Specification.definition - (NONE, (Attrib.empty_binding, def')) lthy; - val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy); - val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; - in (thm', lthy') end; - fun tac thms = Class.intro_classes_tac [] - THEN ALLGOALS (ProofContext.fact_tac thms); - fun add_eq_thms dtco thy = - let - val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco); - val thy_ref = Theory.check_thy thy; - fun mk_thms () = rev ((mk_eq_eqns (Theory.deref thy_ref) dtco)); - in - Code.add_eqnl (const, Lazy.lazy mk_thms) thy - end; - in - thy - |> TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq]) - |> fold_map add_def dtcos - |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm) - (fn _ => fn def_thms => tac def_thms) def_thms) - |-> (fn def_thms => fold Code.del_eqn def_thms) - |> fold add_eq_thms dtcos - end; - - -(* register a datatype etc. *) - -fun add_all_code config dtcos thy = - let - val (vs :: _, coss) = (split_list o map (Datatype.the_datatype_spec thy)) dtcos; - val any_css = map2 (mk_constr_consts thy vs) dtcos coss; - val css = if exists is_none any_css then [] - else map_filter I any_css; - val case_rewrites = maps (#case_rewrites o Datatype.the_datatype thy) dtcos; - val certs = map (mk_case_cert thy) dtcos; - in - if null css then thy - else thy - |> tap (fn _ => DatatypeAux.message config "Registering datatype for code generator ...") - |> fold Code.add_datatype css - |> fold_rev Code.add_default_eqn case_rewrites - |> fold Code.add_case certs - |> add_equality vs dtcos - end; - - -(** theory setup **) - -val setup = - add_codegen "datatype" datatype_codegen - #> add_tycodegen "datatype" datatype_tycodegen - #> Datatype.interpretation add_all_code - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_prop.ML --- a/src/HOL/Tools/datatype_package/datatype_prop.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,435 +0,0 @@ -(* Title: HOL/Tools/datatype_prop.ML - Author: Stefan Berghofer, TU Muenchen - -Characteristic properties of datatypes. -*) - -signature DATATYPE_PROP = -sig - val indexify_names: string list -> string list - val make_tnames: typ list -> string list - val make_injs : DatatypeAux.descr list -> (string * sort) list -> term list list - val make_distincts : DatatypeAux.descr list -> - (string * sort) list -> (int * term list) list (*no symmetric inequalities*) - val make_ind : DatatypeAux.descr list -> (string * sort) list -> term - val make_casedists : DatatypeAux.descr list -> (string * sort) list -> term list - val make_primrec_Ts : DatatypeAux.descr list -> (string * sort) list -> - string list -> typ list * typ list - val make_primrecs : string list -> DatatypeAux.descr list -> - (string * sort) list -> theory -> term list - val make_cases : string list -> DatatypeAux.descr list -> - (string * sort) list -> theory -> term list list - val make_splits : string list -> DatatypeAux.descr list -> - (string * sort) list -> theory -> (term * term) list - val make_weak_case_congs : string list -> DatatypeAux.descr list -> - (string * sort) list -> theory -> term list - val make_case_congs : string list -> DatatypeAux.descr list -> - (string * sort) list -> theory -> term list - val make_nchotomys : DatatypeAux.descr list -> - (string * sort) list -> term list -end; - -structure DatatypeProp : DATATYPE_PROP = -struct - -open DatatypeAux; - -fun indexify_names names = - let - fun index (x :: xs) tab = - (case AList.lookup (op =) tab x of - NONE => if member (op =) xs x then (x ^ "1") :: index xs ((x, 2) :: tab) else x :: index xs tab - | SOME i => (x ^ string_of_int i) :: index xs ((x, i + 1) :: tab)) - | index [] _ = []; - in index names [] end; - -fun make_tnames Ts = - let - fun type_name (TFree (name, _)) = implode (tl (explode name)) - | type_name (Type (name, _)) = - let val name' = Long_Name.base_name name - in if Syntax.is_identifier name' then name' else "x" end; - in indexify_names (map type_name Ts) end; - - -(************************* injectivity of constructors ************************) - -fun make_injs descr sorts = - let - val descr' = flat descr; - fun make_inj T (cname, cargs) = - if null cargs then I else - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val constr_t = Const (cname, Ts ---> T); - val tnames = make_tnames Ts; - val frees = map Free (tnames ~~ Ts); - val frees' = map Free ((map ((op ^) o (rpair "'")) tnames) ~~ Ts); - in cons (HOLogic.mk_Trueprop (HOLogic.mk_eq - (HOLogic.mk_eq (list_comb (constr_t, frees), list_comb (constr_t, frees')), - foldr1 (HOLogic.mk_binop "op &") - (map HOLogic.mk_eq (frees ~~ frees'))))) - end; - in - map2 (fn d => fn T => fold_rev (make_inj T) (#3 (snd d)) []) - (hd descr) (Library.take (length (hd descr), get_rec_types descr' sorts)) - end; - - -(************************* distinctness of constructors ***********************) - -fun make_distincts descr sorts = - let - val descr' = flat descr; - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - - fun prep_constr (cname, cargs) = (cname, map (typ_of_dtyp descr' sorts) cargs); - - fun make_distincts' _ [] = [] - | make_distincts' T ((cname, cargs)::constrs) = - let - val frees = map Free ((make_tnames cargs) ~~ cargs); - val t = list_comb (Const (cname, cargs ---> T), frees); - - fun make_distincts'' (cname', cargs') = - let - val frees' = map Free ((map ((op ^) o (rpair "'")) - (make_tnames cargs')) ~~ cargs'); - val t' = list_comb (Const (cname', cargs' ---> T), frees') - in - HOLogic.mk_Trueprop (HOLogic.Not $ HOLogic.mk_eq (t, t')) - end - - in map make_distincts'' constrs @ make_distincts' T constrs end; - - in - map2 (fn ((_, (_, _, constrs))) => fn T => - (length constrs, make_distincts' T (map prep_constr constrs))) (hd descr) newTs - end; - - -(********************************* induction **********************************) - -fun make_ind descr sorts = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val pnames = if length descr' = 1 then ["P"] - else map (fn i => "P" ^ string_of_int i) (1 upto length descr'); - - fun make_pred i T = - let val T' = T --> HOLogic.boolT - in Free (List.nth (pnames, i), T') end; - - fun make_ind_prem k T (cname, cargs) = - let - fun mk_prem ((dt, s), T) = - let val (Us, U) = strip_type T - in list_all (map (pair "x") Us, HOLogic.mk_Trueprop - (make_pred (body_index dt) U $ app_bnds (Free (s, T)) (length Us))) - end; - - val recs = List.filter is_rec_type cargs; - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val recTs' = map (typ_of_dtyp descr' sorts) recs; - val tnames = Name.variant_list pnames (make_tnames Ts); - val rec_tnames = map fst (List.filter (is_rec_type o snd) (tnames ~~ cargs)); - val frees = tnames ~~ Ts; - val prems = map mk_prem (recs ~~ rec_tnames ~~ recTs'); - - in list_all_free (frees, Logic.list_implies (prems, - HOLogic.mk_Trueprop (make_pred k T $ - list_comb (Const (cname, Ts ---> T), map Free frees)))) - end; - - val prems = List.concat (map (fn ((i, (_, _, constrs)), T) => - map (make_ind_prem i T) constrs) (descr' ~~ recTs)); - val tnames = make_tnames recTs; - val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &") - (map (fn (((i, _), T), tname) => make_pred i T $ Free (tname, T)) - (descr' ~~ recTs ~~ tnames))) - - in Logic.list_implies (prems, concl) end; - -(******************************* case distinction *****************************) - -fun make_casedists descr sorts = - let - val descr' = List.concat descr; - - fun make_casedist_prem T (cname, cargs) = - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val frees = Name.variant_list ["P", "y"] (make_tnames Ts) ~~ Ts; - val free_ts = map Free frees - in list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop - (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))), - HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)))) - end; - - fun make_casedist ((_, (_, _, constrs)), T) = - let val prems = map (make_casedist_prem T) constrs - in Logic.list_implies (prems, HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))) - end - - in map make_casedist - ((hd descr) ~~ Library.take (length (hd descr), get_rec_types descr' sorts)) - end; - -(*************** characteristic equations for primrec combinator **************) - -fun make_primrec_Ts descr sorts used = - let - val descr' = List.concat descr; - - val rec_result_Ts = map TFree (Name.variant_list used (replicate (length descr') "'t") ~~ - replicate (length descr') HOLogic.typeS); - - val reccomb_fn_Ts = List.concat (map (fn (i, (_, _, constrs)) => - map (fn (_, cargs) => - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val recs = List.filter (is_rec_type o fst) (cargs ~~ Ts); - - fun mk_argT (dt, T) = - binder_types T ---> List.nth (rec_result_Ts, body_index dt); - - val argTs = Ts @ map mk_argT recs - in argTs ---> List.nth (rec_result_Ts, i) - end) constrs) descr'); - - in (rec_result_Ts, reccomb_fn_Ts) end; - -fun make_primrecs new_type_names descr sorts thy = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; - - val (rec_result_Ts, reccomb_fn_Ts) = make_primrec_Ts descr sorts used; - - val rec_fns = map (uncurry (mk_Free "f")) - (reccomb_fn_Ts ~~ (1 upto (length reccomb_fn_Ts))); - - val big_reccomb_name = (space_implode "_" new_type_names) ^ "_rec"; - val reccomb_names = map (Sign.intern_const thy) - (if length descr' = 1 then [big_reccomb_name] else - (map ((curry (op ^) (big_reccomb_name ^ "_")) o string_of_int) - (1 upto (length descr')))); - val reccombs = map (fn ((name, T), T') => list_comb - (Const (name, reccomb_fn_Ts @ [T] ---> T'), rec_fns)) - (reccomb_names ~~ recTs ~~ rec_result_Ts); - - fun make_primrec T comb_t ((ts, f::fs), (cname, cargs)) = - let - val recs = List.filter is_rec_type cargs; - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val recTs' = map (typ_of_dtyp descr' sorts) recs; - val tnames = make_tnames Ts; - val rec_tnames = map fst (List.filter (is_rec_type o snd) (tnames ~~ cargs)); - val frees = map Free (tnames ~~ Ts); - val frees' = map Free (rec_tnames ~~ recTs'); - - fun mk_reccomb ((dt, T), t) = - let val (Us, U) = strip_type T - in list_abs (map (pair "x") Us, - List.nth (reccombs, body_index dt) $ app_bnds t (length Us)) - end; - - val reccombs' = map mk_reccomb (recs ~~ recTs' ~~ frees') - - in (ts @ [HOLogic.mk_Trueprop (HOLogic.mk_eq - (comb_t $ list_comb (Const (cname, Ts ---> T), frees), - list_comb (f, frees @ reccombs')))], fs) - end - - in fst (Library.foldl (fn (x, ((dt, T), comb_t)) => - Library.foldl (make_primrec T comb_t) (x, #3 (snd dt))) - (([], rec_fns), descr' ~~ recTs ~~ reccombs)) - end; - -(****************** make terms of form t_case f1 ... fn *********************) - -fun make_case_combs new_type_names descr sorts thy fname = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val used = List.foldr OldTerm.add_typ_tfree_names [] recTs; - val newTs = Library.take (length (hd descr), recTs); - val T' = TFree (Name.variant used "'t", HOLogic.typeS); - - val case_fn_Ts = map (fn (i, (_, _, constrs)) => - map (fn (_, cargs) => - let val Ts = map (typ_of_dtyp descr' sorts) cargs - in Ts ---> T' end) constrs) (hd descr); - - val case_names = map (fn s => - Sign.intern_const thy (s ^ "_case")) new_type_names - in - map (fn ((name, Ts), T) => list_comb - (Const (name, Ts @ [T] ---> T'), - map (uncurry (mk_Free fname)) (Ts ~~ (1 upto length Ts)))) - (case_names ~~ case_fn_Ts ~~ newTs) - end; - -(**************** characteristic equations for case combinator ****************) - -fun make_cases new_type_names descr sorts thy = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - - fun make_case T comb_t ((cname, cargs), f) = - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val frees = map Free ((make_tnames Ts) ~~ Ts) - in HOLogic.mk_Trueprop (HOLogic.mk_eq - (comb_t $ list_comb (Const (cname, Ts ---> T), frees), - list_comb (f, frees))) - end - - in map (fn (((_, (_, _, constrs)), T), comb_t) => - map (make_case T comb_t) (constrs ~~ (snd (strip_comb comb_t)))) - ((hd descr) ~~ newTs ~~ (make_case_combs new_type_names descr sorts thy "f")) - end; - - -(*************************** the "split" - equations **************************) - -fun make_splits new_type_names descr sorts thy = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val used' = List.foldr OldTerm.add_typ_tfree_names [] recTs; - val newTs = Library.take (length (hd descr), recTs); - val T' = TFree (Name.variant used' "'t", HOLogic.typeS); - val P = Free ("P", T' --> HOLogic.boolT); - - fun make_split (((_, (_, _, constrs)), T), comb_t) = - let - val (_, fs) = strip_comb comb_t; - val used = ["P", "x"] @ (map (fst o dest_Free) fs); - - fun process_constr (((cname, cargs), f), (t1s, t2s)) = - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val frees = map Free (Name.variant_list used (make_tnames Ts) ~~ Ts); - val eqn = HOLogic.mk_eq (Free ("x", T), - list_comb (Const (cname, Ts ---> T), frees)); - val P' = P $ list_comb (f, frees) - in ((List.foldr (fn (Free (s, T), t) => HOLogic.mk_all (s, T, t)) - (HOLogic.imp $ eqn $ P') frees)::t1s, - (List.foldr (fn (Free (s, T), t) => HOLogic.mk_exists (s, T, t)) - (HOLogic.conj $ eqn $ (HOLogic.Not $ P')) frees)::t2s) - end; - - val (t1s, t2s) = List.foldr process_constr ([], []) (constrs ~~ fs); - val lhs = P $ (comb_t $ Free ("x", T)) - in - (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, mk_conj t1s)), - HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, HOLogic.Not $ mk_disj t2s))) - end - - in map make_split ((hd descr) ~~ newTs ~~ - (make_case_combs new_type_names descr sorts thy "f")) - end; - -(************************* additional rules for TFL ***************************) - -fun make_weak_case_congs new_type_names descr sorts thy = - let - val case_combs = make_case_combs new_type_names descr sorts thy "f"; - - fun mk_case_cong comb = - let - val Type ("fun", [T, _]) = fastype_of comb; - val M = Free ("M", T); - val M' = Free ("M'", T); - in - Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (M, M')), - HOLogic.mk_Trueprop (HOLogic.mk_eq (comb $ M, comb $ M'))) - end - in - map mk_case_cong case_combs - end; - - -(*--------------------------------------------------------------------------- - * Structure of case congruence theorem looks like this: - * - * (M = M') - * ==> (!!x1,...,xk. (M' = C1 x1..xk) ==> (f1 x1..xk = g1 x1..xk)) - * ==> ... - * ==> (!!x1,...,xj. (M' = Cn x1..xj) ==> (fn x1..xj = gn x1..xj)) - * ==> - * (ty_case f1..fn M = ty_case g1..gn M') - *---------------------------------------------------------------------------*) - -fun make_case_congs new_type_names descr sorts thy = - let - val case_combs = make_case_combs new_type_names descr sorts thy "f"; - val case_combs' = make_case_combs new_type_names descr sorts thy "g"; - - fun mk_case_cong ((comb, comb'), (_, (_, _, constrs))) = - let - val Type ("fun", [T, _]) = fastype_of comb; - val (_, fs) = strip_comb comb; - val (_, gs) = strip_comb comb'; - val used = ["M", "M'"] @ map (fst o dest_Free) (fs @ gs); - val M = Free ("M", T); - val M' = Free ("M'", T); - - fun mk_clause ((f, g), (cname, _)) = - let - val (Ts, _) = strip_type (fastype_of f); - val tnames = Name.variant_list used (make_tnames Ts); - val frees = map Free (tnames ~~ Ts) - in - list_all_free (tnames ~~ Ts, Logic.mk_implies - (HOLogic.mk_Trueprop - (HOLogic.mk_eq (M', list_comb (Const (cname, Ts ---> T), frees))), - HOLogic.mk_Trueprop - (HOLogic.mk_eq (list_comb (f, frees), list_comb (g, frees))))) - end - - in - Logic.list_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (M, M')) :: - map mk_clause (fs ~~ gs ~~ constrs), - HOLogic.mk_Trueprop (HOLogic.mk_eq (comb $ M, comb' $ M'))) - end - - in - map mk_case_cong (case_combs ~~ case_combs' ~~ hd descr) - end; - -(*--------------------------------------------------------------------------- - * Structure of exhaustion theorem looks like this: - * - * !v. (? y1..yi. v = C1 y1..yi) | ... | (? y1..yj. v = Cn y1..yj) - *---------------------------------------------------------------------------*) - -fun make_nchotomys descr sorts = - let - val descr' = List.concat descr; - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - - fun mk_eqn T (cname, cargs) = - let - val Ts = map (typ_of_dtyp descr' sorts) cargs; - val tnames = Name.variant_list ["v"] (make_tnames Ts); - val frees = tnames ~~ Ts - in - List.foldr (fn ((s, T'), t) => HOLogic.mk_exists (s, T', t)) - (HOLogic.mk_eq (Free ("v", T), - list_comb (Const (cname, Ts ---> T), map Free frees))) frees - end - - in map (fn ((_, (_, _, constrs)), T) => - HOLogic.mk_Trueprop (HOLogic.mk_all ("v", T, mk_disj (map (mk_eqn T) constrs)))) - (hd descr ~~ newTs) - end; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_realizer.ML --- a/src/HOL/Tools/datatype_package/datatype_realizer.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,230 +0,0 @@ -(* Title: HOL/Tools/datatype_realizer.ML - Author: Stefan Berghofer, TU Muenchen - -Porgram extraction from proofs involving datatypes: -Realizers for induction and case analysis -*) - -signature DATATYPE_REALIZER = -sig - val add_dt_realizers: Datatype.config -> string list -> theory -> theory - val setup: theory -> theory -end; - -structure DatatypeRealizer : DATATYPE_REALIZER = -struct - -open DatatypeAux; - -fun subsets i j = if i <= j then - let val is = subsets (i+1) j - in map (fn ks => i::ks) is @ is end - else [[]]; - -fun forall_intr_prf (t, prf) = - let val (a, T) = (case t of Var ((a, _), T) => (a, T) | Free p => p) - in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end; - -fun prf_of thm = - Reconstruct.reconstruct_proof (Thm.theory_of_thm thm) (Thm.prop_of thm) (Thm.proof_of thm); - -fun prf_subst_vars inst = - Proofterm.map_proof_terms (subst_vars ([], inst)) I; - -fun is_unit t = snd (strip_type (fastype_of t)) = HOLogic.unitT; - -fun tname_of (Type (s, _)) = s - | tname_of _ = ""; - -fun mk_realizes T = Const ("realizes", T --> HOLogic.boolT --> HOLogic.boolT); - -fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : info) is thy = - let - val recTs = get_rec_types descr sorts; - val pnames = if length descr = 1 then ["P"] - else map (fn i => "P" ^ string_of_int i) (1 upto length descr); - - val rec_result_Ts = map (fn ((i, _), P) => - if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT) - (descr ~~ pnames); - - fun make_pred i T U r x = - if i mem is then - Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x - else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x; - - fun mk_all i s T t = - if i mem is then list_all_free ([(s, T)], t) else t; - - val (prems, rec_fns) = split_list (flat (fst (fold_map - (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j => - let - val Ts = map (typ_of_dtyp descr sorts) cargs; - val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts); - val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts); - val frees = tnames ~~ Ts; - - fun mk_prems vs [] = - let - val rT = nth (rec_result_Ts) i; - val vs' = filter_out is_unit vs; - val f = mk_Free "f" (map fastype_of vs' ---> rT) j; - val f' = Envir.eta_contract (list_abs_free - (map dest_Free vs, if i mem is then list_comb (f, vs') - else HOLogic.unit)); - in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs')) - (list_comb (Const (cname, Ts ---> T), map Free frees))), f') - end - | mk_prems vs (((dt, s), T) :: ds) = - let - val k = body_index dt; - val (Us, U) = strip_type T; - val i = length Us; - val rT = nth (rec_result_Ts) k; - val r = Free ("r" ^ s, Us ---> rT); - val (p, f) = mk_prems (vs @ [r]) ds - in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies - (list_all (map (pair "x") Us, HOLogic.mk_Trueprop - (make_pred k rT U (app_bnds r i) - (app_bnds (Free (s, T)) i))), p)), f) - end - - in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end) - constrs) (descr ~~ recTs) 1))); - - fun mk_proj j [] t = t - | mk_proj j (i :: is) t = if null is then t else - if (j: int) = i then HOLogic.mk_fst t - else mk_proj j is (HOLogic.mk_snd t); - - val tnames = DatatypeProp.make_tnames recTs; - val fTs = map fastype_of rec_fns; - val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T - (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0))) - (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names); - val r = if null is then Extraction.nullt else - foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) => - if i mem is then SOME - (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T)) - else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames)); - val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &") - (map (fn ((((i, _), T), U), tname) => - make_pred i U T (mk_proj i is r) (Free (tname, T))) - (descr ~~ recTs ~~ rec_result_Ts ~~ tnames))); - val cert = cterm_of thy; - val inst = map (pairself cert) (map head_of (HOLogic.dest_conj - (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps); - - val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl))) - (fn prems => - [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]), - rtac (cterm_instantiate inst induction) 1, - ALLGOALS ObjectLogic.atomize_prems_tac, - rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites), - REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i => - REPEAT (etac allE i) THEN atac i)) 1)]); - - val ind_name = Thm.get_name induction; - val vs = map (fn i => List.nth (pnames, i)) is; - val (thm', thy') = thy - |> Sign.root_path - |> PureThy.store_thm - (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm) - ||> Sign.restore_naming thy; - - val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []); - val rvs = rev (Thm.fold_terms Term.add_vars thm' []); - val ivs1 = map Var (filter_out (fn (_, T) => - tname_of (body_type T) mem ["set", "bool"]) ivs); - val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs; - - val prf = List.foldr forall_intr_prf - (List.foldr (fn ((f, p), prf) => - (case head_of (strip_abs_body f) of - Free (s, T) => - let val T' = Logic.varifyT T - in Abst (s, SOME T', Proofterm.prf_abstract_over - (Var ((s, 0), T')) (AbsP ("H", SOME p, prf))) - end - | _ => AbsP ("H", SOME p, prf))) - (Proofterm.proof_combP - (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2; - - val r' = if null is then r else Logic.varify (List.foldr (uncurry lambda) - r (map Logic.unvarify ivs1 @ filter_out is_unit - (map (head_of o strip_abs_body) rec_fns))); - - in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end; - - -fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : info) thy = - let - val cert = cterm_of thy; - val rT = TFree ("'P", HOLogic.typeS); - val rT' = TVar (("'P", 0), HOLogic.typeS); - - fun make_casedist_prem T (cname, cargs) = - let - val Ts = map (typ_of_dtyp descr sorts) cargs; - val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts; - val free_ts = map Free frees; - val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT) - in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop - (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))), - HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ - list_comb (r, free_ts))))) - end; - - val SOME (_, _, constrs) = AList.lookup (op =) descr index; - val T = List.nth (get_rec_types descr sorts, index); - val (rs, prems) = split_list (map (make_casedist_prem T) constrs); - val r = Const (case_name, map fastype_of rs ---> T --> rT); - - val y = Var (("y", 0), Logic.legacy_varifyT T); - val y' = Free ("y", T); - - val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems, - HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ - list_comb (r, rs @ [y']))))) - (fn prems => - [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1, - ALLGOALS (EVERY' - [asm_simp_tac (HOL_basic_ss addsimps case_rewrites), - resolve_tac prems, asm_simp_tac HOL_basic_ss])]); - - val exh_name = Thm.get_name exhaustion; - val (thm', thy') = thy - |> Sign.root_path - |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm) - ||> Sign.restore_naming thy; - - val P = Var (("P", 0), rT' --> HOLogic.boolT); - val prf = forall_intr_prf (y, forall_intr_prf (P, - List.foldr (fn ((p, r), prf) => - forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p), - prf))) (Proofterm.proof_combP (prf_of thm', - map PBound (length prems - 1 downto 0))) (prems ~~ rs))); - val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T, - list_abs (map dest_Free rs, list_comb (r, - map Bound ((length rs - 1 downto 0) @ [length rs]))))); - - in Extraction.add_realizers_i - [(exh_name, (["P"], r', prf)), - (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy' - end; - -fun add_dt_realizers config names thy = - if ! Proofterm.proofs < 2 then thy - else let - val _ = message config "Adding realizers for induction and case analysis ..." - val infos = map (Datatype.the_datatype thy) names; - val info :: _ = infos; - in - thy - |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1)) - |> fold_rev (make_casedists (#sorts info)) infos - end; - -val setup = Datatype.interpretation add_dt_realizers; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/datatype_package/datatype_rep_proofs.ML --- a/src/HOL/Tools/datatype_package/datatype_rep_proofs.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,643 +0,0 @@ -(* Title: HOL/Tools/datatype_rep_proofs.ML - Author: Stefan Berghofer, TU Muenchen - -Definitional introduction of datatypes -Proof of characteristic theorems: - - - injectivity of constructors - - distinctness of constructors - - induction theorem -*) - -signature DATATYPE_REP_PROOFS = -sig - include DATATYPE_COMMON - val distinctness_limit : int Config.T - val distinctness_limit_setup : theory -> theory - val representation_proofs : config -> info Symtab.table -> - string list -> descr list -> (string * sort) list -> - (binding * mixfix) list -> (binding * mixfix) list list -> attribute - -> theory -> (thm list list * thm list list * thm list list * - DatatypeAux.simproc_dist list * thm) * theory -end; - -structure DatatypeRepProofs : DATATYPE_REP_PROOFS = -struct - -open DatatypeAux; - -(*the kind of distinctiveness axioms depends on number of constructors*) -val (distinctness_limit, distinctness_limit_setup) = - Attrib.config_int "datatype_distinctness_limit" 7; - -val (_ $ (_ $ (_ $ (distinct_f $ _) $ _))) = hd (prems_of distinct_lemma); - -val collect_simp = rewrite_rule [mk_meta_eq mem_Collect_eq]; - - -(** theory context references **) - -val f_myinv_f = thm "f_myinv_f"; -val myinv_f_f = thm "myinv_f_f"; - - -fun exh_thm_of (dt_info : info Symtab.table) tname = - #exhaustion (the (Symtab.lookup dt_info tname)); - -(******************************************************************************) - -fun representation_proofs (config : config) (dt_info : info Symtab.table) - new_type_names descr sorts types_syntax constr_syntax case_names_induct thy = - let - val Datatype_thy = ThyInfo.the_theory "Datatype" thy; - val node_name = "Datatype.node"; - val In0_name = "Datatype.In0"; - val In1_name = "Datatype.In1"; - val Scons_name = "Datatype.Scons"; - val Leaf_name = "Datatype.Leaf"; - val Numb_name = "Datatype.Numb"; - val Lim_name = "Datatype.Lim"; - val Suml_name = "Datatype.Suml"; - val Sumr_name = "Datatype.Sumr"; - - val [In0_inject, In1_inject, Scons_inject, Leaf_inject, - In0_eq, In1_eq, In0_not_In1, In1_not_In0, - Lim_inject, Suml_inject, Sumr_inject] = map (PureThy.get_thm Datatype_thy) - ["In0_inject", "In1_inject", "Scons_inject", "Leaf_inject", - "In0_eq", "In1_eq", "In0_not_In1", "In1_not_In0", - "Lim_inject", "Suml_inject", "Sumr_inject"]; - - val descr' = flat descr; - - val big_name = space_implode "_" new_type_names; - val thy1 = add_path (#flat_names config) big_name thy; - val big_rec_name = big_name ^ "_rep_set"; - val rep_set_names' = - (if length descr' = 1 then [big_rec_name] else - (map ((curry (op ^) (big_rec_name ^ "_")) o string_of_int) - (1 upto (length descr')))); - val rep_set_names = map (Sign.full_bname thy1) rep_set_names'; - - val tyvars = map (fn (_, (_, Ts, _)) => map dest_DtTFree Ts) (hd descr); - val leafTs' = get_nonrec_types descr' sorts; - val branchTs = get_branching_types descr' sorts; - val branchT = if null branchTs then HOLogic.unitT - else BalancedTree.make (fn (T, U) => Type ("+", [T, U])) branchTs; - val arities = get_arities descr' \ 0; - val unneeded_vars = hd tyvars \\ List.foldr OldTerm.add_typ_tfree_names [] (leafTs' @ branchTs); - val leafTs = leafTs' @ (map (fn n => TFree (n, (the o AList.lookup (op =) sorts) n)) unneeded_vars); - val recTs = get_rec_types descr' sorts; - val newTs = Library.take (length (hd descr), recTs); - val oldTs = Library.drop (length (hd descr), recTs); - val sumT = if null leafTs then HOLogic.unitT - else BalancedTree.make (fn (T, U) => Type ("+", [T, U])) leafTs; - val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT, branchT])); - val UnivT = HOLogic.mk_setT Univ_elT; - val UnivT' = Univ_elT --> HOLogic.boolT; - val Collect = Const ("Collect", UnivT' --> UnivT); - - val In0 = Const (In0_name, Univ_elT --> Univ_elT); - val In1 = Const (In1_name, Univ_elT --> Univ_elT); - val Leaf = Const (Leaf_name, sumT --> Univ_elT); - val Lim = Const (Lim_name, (branchT --> Univ_elT) --> Univ_elT); - - (* make injections needed for embedding types in leaves *) - - fun mk_inj T' x = - let - fun mk_inj' T n i = - if n = 1 then x else - let val n2 = n div 2; - val Type (_, [T1, T2]) = T - in - if i <= n2 then - Const ("Sum_Type.Inl", T1 --> T) $ (mk_inj' T1 n2 i) - else - Const ("Sum_Type.Inr", T2 --> T) $ (mk_inj' T2 (n - n2) (i - n2)) - end - in mk_inj' sumT (length leafTs) (1 + find_index_eq T' leafTs) - end; - - (* make injections for constructors *) - - fun mk_univ_inj ts = BalancedTree.access - {left = fn t => In0 $ t, - right = fn t => In1 $ t, - init = - if ts = [] then Const (@{const_name undefined}, Univ_elT) - else foldr1 (HOLogic.mk_binop Scons_name) ts}; - - (* function spaces *) - - fun mk_fun_inj T' x = - let - fun mk_inj T n i = - if n = 1 then x else - let - val n2 = n div 2; - val Type (_, [T1, T2]) = T; - fun mkT U = (U --> Univ_elT) --> T --> Univ_elT - in - if i <= n2 then Const (Suml_name, mkT T1) $ mk_inj T1 n2 i - else Const (Sumr_name, mkT T2) $ mk_inj T2 (n - n2) (i - n2) - end - in mk_inj branchT (length branchTs) (1 + find_index_eq T' branchTs) - end; - - val mk_lim = List.foldr (fn (T, t) => Lim $ mk_fun_inj T (Abs ("x", T, t))); - - (************** generate introduction rules for representing set **********) - - val _ = message config "Constructing representing sets ..."; - - (* make introduction rule for a single constructor *) - - fun make_intr s n (i, (_, cargs)) = - let - fun mk_prem (dt, (j, prems, ts)) = (case strip_dtyp dt of - (dts, DtRec k) => - let - val Ts = map (typ_of_dtyp descr' sorts) dts; - val free_t = - app_bnds (mk_Free "x" (Ts ---> Univ_elT) j) (length Ts) - in (j + 1, list_all (map (pair "x") Ts, - HOLogic.mk_Trueprop - (Free (List.nth (rep_set_names', k), UnivT') $ free_t)) :: prems, - mk_lim free_t Ts :: ts) - end - | _ => - let val T = typ_of_dtyp descr' sorts dt - in (j + 1, prems, (Leaf $ mk_inj T (mk_Free "x" T j))::ts) - end); - - val (_, prems, ts) = List.foldr mk_prem (1, [], []) cargs; - val concl = HOLogic.mk_Trueprop - (Free (s, UnivT') $ mk_univ_inj ts n i) - in Logic.list_implies (prems, concl) - end; - - val intr_ts = maps (fn ((_, (_, _, constrs)), rep_set_name) => - map (make_intr rep_set_name (length constrs)) - ((1 upto (length constrs)) ~~ constrs)) (descr' ~~ rep_set_names'); - - val ({raw_induct = rep_induct, intrs = rep_intrs, ...}, thy2) = - Inductive.add_inductive_global (serial_string ()) - {quiet_mode = #quiet config, verbose = false, kind = Thm.internalK, - alt_name = Binding.name big_rec_name, coind = false, no_elim = true, no_ind = false, - skip_mono = true, fork_mono = false} - (map (fn s => ((Binding.name s, UnivT'), NoSyn)) rep_set_names') [] - (map (fn x => (Attrib.empty_binding, x)) intr_ts) [] thy1; - - (********************************* typedef ********************************) - - val (typedefs, thy3) = thy2 |> - parent_path (#flat_names config) |> - fold_map (fn ((((name, mx), tvs), c), name') => - Typedef.add_typedef false (SOME (Binding.name name')) (name, tvs, mx) - (Collect $ Const (c, UnivT')) NONE - (rtac exI 1 THEN rtac CollectI 1 THEN - QUIET_BREADTH_FIRST (has_fewer_prems 1) - (resolve_tac rep_intrs 1))) - (types_syntax ~~ tyvars ~~ - (Library.take (length newTs, rep_set_names)) ~~ new_type_names) ||> - add_path (#flat_names config) big_name; - - (*********************** definition of constructors ***********************) - - val big_rep_name = (space_implode "_" new_type_names) ^ "_Rep_"; - val rep_names = map (curry op ^ "Rep_") new_type_names; - val rep_names' = map (fn i => big_rep_name ^ (string_of_int i)) - (1 upto (length (flat (tl descr)))); - val all_rep_names = map (Sign.intern_const thy3) rep_names @ - map (Sign.full_bname thy3) rep_names'; - - (* isomorphism declarations *) - - val iso_decls = map (fn (T, s) => (Binding.name s, T --> Univ_elT, NoSyn)) - (oldTs ~~ rep_names'); - - (* constructor definitions *) - - fun make_constr_def tname T n ((thy, defs, eqns, i), ((cname, cargs), (cname', mx))) = - let - fun constr_arg (dt, (j, l_args, r_args)) = - let val T = typ_of_dtyp descr' sorts dt; - val free_t = mk_Free "x" T j - in (case (strip_dtyp dt, strip_type T) of - ((_, DtRec m), (Us, U)) => (j + 1, free_t :: l_args, mk_lim - (Const (List.nth (all_rep_names, m), U --> Univ_elT) $ - app_bnds free_t (length Us)) Us :: r_args) - | _ => (j + 1, free_t::l_args, (Leaf $ mk_inj T free_t)::r_args)) - end; - - val (_, l_args, r_args) = List.foldr constr_arg (1, [], []) cargs; - val constrT = (map (typ_of_dtyp descr' sorts) cargs) ---> T; - val abs_name = Sign.intern_const thy ("Abs_" ^ tname); - val rep_name = Sign.intern_const thy ("Rep_" ^ tname); - val lhs = list_comb (Const (cname, constrT), l_args); - val rhs = mk_univ_inj r_args n i; - val def = Logic.mk_equals (lhs, Const (abs_name, Univ_elT --> T) $ rhs); - val def_name = Long_Name.base_name cname ^ "_def"; - val eqn = HOLogic.mk_Trueprop (HOLogic.mk_eq - (Const (rep_name, T --> Univ_elT) $ lhs, rhs)); - val ([def_thm], thy') = - thy - |> Sign.add_consts_i [(cname', constrT, mx)] - |> (PureThy.add_defs false o map Thm.no_attributes) [(Binding.name def_name, def)]; - - in (thy', defs @ [def_thm], eqns @ [eqn], i + 1) end; - - (* constructor definitions for datatype *) - - fun dt_constr_defs ((thy, defs, eqns, rep_congs, dist_lemmas), - ((((_, (_, _, constrs)), tname), T), constr_syntax)) = - let - val _ $ (_ $ (cong_f $ _) $ _) = concl_of arg_cong; - val rep_const = cterm_of thy - (Const (Sign.intern_const thy ("Rep_" ^ tname), T --> Univ_elT)); - val cong' = standard (cterm_instantiate [(cterm_of thy cong_f, rep_const)] arg_cong); - val dist = standard (cterm_instantiate [(cterm_of thy distinct_f, rep_const)] distinct_lemma); - val (thy', defs', eqns', _) = Library.foldl ((make_constr_def tname T) (length constrs)) - ((add_path (#flat_names config) tname thy, defs, [], 1), constrs ~~ constr_syntax) - in - (parent_path (#flat_names config) thy', defs', eqns @ [eqns'], - rep_congs @ [cong'], dist_lemmas @ [dist]) - end; - - val (thy4, constr_defs, constr_rep_eqns, rep_congs, dist_lemmas) = Library.foldl dt_constr_defs - ((thy3 |> Sign.add_consts_i iso_decls |> parent_path (#flat_names config), [], [], [], []), - hd descr ~~ new_type_names ~~ newTs ~~ constr_syntax); - - (*********** isomorphisms for new types (introduced by typedef) ***********) - - val _ = message config "Proving isomorphism properties ..."; - - val newT_iso_axms = map (fn (_, td) => - (collect_simp (#Abs_inverse td), #Rep_inverse td, - collect_simp (#Rep td))) typedefs; - - val newT_iso_inj_thms = map (fn (_, td) => - (collect_simp (#Abs_inject td) RS iffD1, #Rep_inject td RS iffD1)) typedefs; - - (********* isomorphisms between existing types and "unfolded" types *******) - - (*---------------------------------------------------------------------*) - (* isomorphisms are defined using primrec-combinators: *) - (* generate appropriate functions for instantiating primrec-combinator *) - (* *) - (* e.g. dt_Rep_i = list_rec ... (%h t y. In1 (Scons (Leaf h) y)) *) - (* *) - (* also generate characteristic equations for isomorphisms *) - (* *) - (* e.g. dt_Rep_i (cons h t) = In1 (Scons (dt_Rep_j h) (dt_Rep_i t)) *) - (*---------------------------------------------------------------------*) - - fun make_iso_def k ks n ((fs, eqns, i), (cname, cargs)) = - let - val argTs = map (typ_of_dtyp descr' sorts) cargs; - val T = List.nth (recTs, k); - val rep_name = List.nth (all_rep_names, k); - val rep_const = Const (rep_name, T --> Univ_elT); - val constr = Const (cname, argTs ---> T); - - fun process_arg ks' ((i2, i2', ts, Ts), dt) = - let - val T' = typ_of_dtyp descr' sorts dt; - val (Us, U) = strip_type T' - in (case strip_dtyp dt of - (_, DtRec j) => if j mem ks' then - (i2 + 1, i2' + 1, ts @ [mk_lim (app_bnds - (mk_Free "y" (Us ---> Univ_elT) i2') (length Us)) Us], - Ts @ [Us ---> Univ_elT]) - else - (i2 + 1, i2', ts @ [mk_lim - (Const (List.nth (all_rep_names, j), U --> Univ_elT) $ - app_bnds (mk_Free "x" T' i2) (length Us)) Us], Ts) - | _ => (i2 + 1, i2', ts @ [Leaf $ mk_inj T' (mk_Free "x" T' i2)], Ts)) - end; - - val (i2, i2', ts, Ts) = Library.foldl (process_arg ks) ((1, 1, [], []), cargs); - val xs = map (uncurry (mk_Free "x")) (argTs ~~ (1 upto (i2 - 1))); - val ys = map (uncurry (mk_Free "y")) (Ts ~~ (1 upto (i2' - 1))); - val f = list_abs_free (map dest_Free (xs @ ys), mk_univ_inj ts n i); - - val (_, _, ts', _) = Library.foldl (process_arg []) ((1, 1, [], []), cargs); - val eqn = HOLogic.mk_Trueprop (HOLogic.mk_eq - (rep_const $ list_comb (constr, xs), mk_univ_inj ts' n i)) - - in (fs @ [f], eqns @ [eqn], i + 1) end; - - (* define isomorphisms for all mutually recursive datatypes in list ds *) - - fun make_iso_defs (ds, (thy, char_thms)) = - let - val ks = map fst ds; - val (_, (tname, _, _)) = hd ds; - val {rec_rewrites, rec_names, ...} = the (Symtab.lookup dt_info tname); - - fun process_dt ((fs, eqns, isos), (k, (tname, _, constrs))) = - let - val (fs', eqns', _) = Library.foldl (make_iso_def k ks (length constrs)) - ((fs, eqns, 1), constrs); - val iso = (List.nth (recTs, k), List.nth (all_rep_names, k)) - in (fs', eqns', isos @ [iso]) end; - - val (fs, eqns, isos) = Library.foldl process_dt (([], [], []), ds); - val fTs = map fastype_of fs; - val defs = map (fn (rec_name, (T, iso_name)) => (Binding.name (Long_Name.base_name iso_name ^ "_def"), - Logic.mk_equals (Const (iso_name, T --> Univ_elT), - list_comb (Const (rec_name, fTs @ [T] ---> Univ_elT), fs)))) (rec_names ~~ isos); - val (def_thms, thy') = - apsnd Theory.checkpoint ((PureThy.add_defs false o map Thm.no_attributes) defs thy); - - (* prove characteristic equations *) - - val rewrites = def_thms @ (map mk_meta_eq rec_rewrites); - val char_thms' = map (fn eqn => SkipProof.prove_global thy' [] [] eqn - (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns; - - in (thy', char_thms' @ char_thms) end; - - val (thy5, iso_char_thms) = apfst Theory.checkpoint (List.foldr make_iso_defs - (add_path (#flat_names config) big_name thy4, []) (tl descr)); - - (* prove isomorphism properties *) - - fun mk_funs_inv thy thm = - let - val prop = Thm.prop_of thm; - val _ $ (_ $ ((S as Const (_, Type (_, [U, _]))) $ _ )) $ - (_ $ (_ $ (r $ (a $ _)) $ _)) = Type.freeze prop; - val used = OldTerm.add_term_tfree_names (a, []); - - fun mk_thm i = - let - val Ts = map (TFree o rpair HOLogic.typeS) - (Name.variant_list used (replicate i "'t")); - val f = Free ("f", Ts ---> U) - in SkipProof.prove_global thy [] [] (Logic.mk_implies - (HOLogic.mk_Trueprop (HOLogic.list_all - (map (pair "x") Ts, S $ app_bnds f i)), - HOLogic.mk_Trueprop (HOLogic.mk_eq (list_abs (map (pair "x") Ts, - r $ (a $ app_bnds f i)), f)))) - (fn _ => EVERY [REPEAT_DETERM_N i (rtac ext 1), - REPEAT (etac allE 1), rtac thm 1, atac 1]) - end - in map (fn r => r RS subst) (thm :: map mk_thm arities) end; - - (* prove inj dt_Rep_i and dt_Rep_i x : dt_rep_set_i *) - - val fun_congs = map (fn T => make_elim (Drule.instantiate' - [SOME (ctyp_of thy5 T)] [] fun_cong)) branchTs; - - fun prove_iso_thms (ds, (inj_thms, elem_thms)) = - let - val (_, (tname, _, _)) = hd ds; - val {induction, ...} = the (Symtab.lookup dt_info tname); - - fun mk_ind_concl (i, _) = - let - val T = List.nth (recTs, i); - val Rep_t = Const (List.nth (all_rep_names, i), T --> Univ_elT); - val rep_set_name = List.nth (rep_set_names, i) - in (HOLogic.all_const T $ Abs ("y", T, HOLogic.imp $ - HOLogic.mk_eq (Rep_t $ mk_Free "x" T i, Rep_t $ Bound 0) $ - HOLogic.mk_eq (mk_Free "x" T i, Bound 0)), - Const (rep_set_name, UnivT') $ (Rep_t $ mk_Free "x" T i)) - end; - - val (ind_concl1, ind_concl2) = ListPair.unzip (map mk_ind_concl ds); - - val rewrites = map mk_meta_eq iso_char_thms; - val inj_thms' = map snd newT_iso_inj_thms @ - map (fn r => r RS @{thm injD}) inj_thms; - - val inj_thm = SkipProof.prove_global thy5 [] [] - (HOLogic.mk_Trueprop (mk_conj ind_concl1)) (fn _ => EVERY - [(indtac induction [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, - REPEAT (EVERY - [rtac allI 1, rtac impI 1, - exh_tac (exh_thm_of dt_info) 1, - REPEAT (EVERY - [hyp_subst_tac 1, - rewrite_goals_tac rewrites, - REPEAT (dresolve_tac [In0_inject, In1_inject] 1), - (eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1) - ORELSE (EVERY - [REPEAT (eresolve_tac (Scons_inject :: - map make_elim [Leaf_inject, Inl_inject, Inr_inject]) 1), - REPEAT (cong_tac 1), rtac refl 1, - REPEAT (atac 1 ORELSE (EVERY - [REPEAT (rtac ext 1), - REPEAT (eresolve_tac (mp :: allE :: - map make_elim (Suml_inject :: Sumr_inject :: - Lim_inject :: inj_thms') @ fun_congs) 1), - atac 1]))])])])]); - - val inj_thms'' = map (fn r => r RS @{thm datatype_injI}) - (split_conj_thm inj_thm); - - val elem_thm = - SkipProof.prove_global thy5 [] [] (HOLogic.mk_Trueprop (mk_conj ind_concl2)) - (fn _ => - EVERY [(indtac induction [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, - rewrite_goals_tac rewrites, - REPEAT ((resolve_tac rep_intrs THEN_ALL_NEW - ((REPEAT o etac allE) THEN' ares_tac elem_thms)) 1)]); - - in (inj_thms'' @ inj_thms, elem_thms @ (split_conj_thm elem_thm)) - end; - - val (iso_inj_thms_unfolded, iso_elem_thms) = List.foldr prove_iso_thms - ([], map #3 newT_iso_axms) (tl descr); - val iso_inj_thms = map snd newT_iso_inj_thms @ - map (fn r => r RS @{thm injD}) iso_inj_thms_unfolded; - - (* prove dt_rep_set_i x --> x : range dt_Rep_i *) - - fun mk_iso_t (((set_name, iso_name), i), T) = - let val isoT = T --> Univ_elT - in HOLogic.imp $ - (Const (set_name, UnivT') $ mk_Free "x" Univ_elT i) $ - (if i < length newTs then HOLogic.true_const - else HOLogic.mk_mem (mk_Free "x" Univ_elT i, - Const (@{const_name image}, isoT --> HOLogic.mk_setT T --> UnivT) $ - Const (iso_name, isoT) $ Const (@{const_name UNIV}, HOLogic.mk_setT T))) - end; - - val iso_t = HOLogic.mk_Trueprop (mk_conj (map mk_iso_t - (rep_set_names ~~ all_rep_names ~~ (0 upto (length descr' - 1)) ~~ recTs))); - - (* all the theorems are proved by one single simultaneous induction *) - - val range_eqs = map (fn r => mk_meta_eq (r RS @{thm range_ex1_eq})) - iso_inj_thms_unfolded; - - val iso_thms = if length descr = 1 then [] else - Library.drop (length newTs, split_conj_thm - (SkipProof.prove_global thy5 [] [] iso_t (fn _ => EVERY - [(indtac rep_induct [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, - REPEAT (rtac TrueI 1), - rewrite_goals_tac (mk_meta_eq choice_eq :: - symmetric (mk_meta_eq @{thm expand_fun_eq}) :: range_eqs), - rewrite_goals_tac (map symmetric range_eqs), - REPEAT (EVERY - [REPEAT (eresolve_tac ([rangeE, ex1_implies_ex RS exE] @ - maps (mk_funs_inv thy5 o #1) newT_iso_axms) 1), - TRY (hyp_subst_tac 1), - rtac (sym RS range_eqI) 1, - resolve_tac iso_char_thms 1])]))); - - val Abs_inverse_thms' = - map #1 newT_iso_axms @ - map2 (fn r_inj => fn r => f_myinv_f OF [r_inj, r RS mp]) - iso_inj_thms_unfolded iso_thms; - - val Abs_inverse_thms = maps (mk_funs_inv thy5) Abs_inverse_thms'; - - (******************* freeness theorems for constructors *******************) - - val _ = message config "Proving freeness of constructors ..."; - - (* prove theorem Rep_i (Constr_j ...) = Inj_j ... *) - - fun prove_constr_rep_thm eqn = - let - val inj_thms = map fst newT_iso_inj_thms; - val rewrites = @{thm o_def} :: constr_defs @ (map (mk_meta_eq o #2) newT_iso_axms) - in SkipProof.prove_global thy5 [] [] eqn (fn _ => EVERY - [resolve_tac inj_thms 1, - rewrite_goals_tac rewrites, - rtac refl 3, - resolve_tac rep_intrs 2, - REPEAT (resolve_tac iso_elem_thms 1)]) - end; - - (*--------------------------------------------------------------*) - (* constr_rep_thms and rep_congs are used to prove distinctness *) - (* of constructors. *) - (*--------------------------------------------------------------*) - - val constr_rep_thms = map (map prove_constr_rep_thm) constr_rep_eqns; - - val dist_rewrites = map (fn (rep_thms, dist_lemma) => - dist_lemma::(rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0])) - (constr_rep_thms ~~ dist_lemmas); - - fun prove_distinct_thms _ _ (_, []) = [] - | prove_distinct_thms lim dist_rewrites' (k, ts as _ :: _) = - if k >= lim then [] else let - (*number of constructors < distinctness_limit : C_i ... ~= C_j ...*) - fun prove [] = [] - | prove (t :: ts) = - let - val dist_thm = SkipProof.prove_global thy5 [] [] t (fn _ => - EVERY [simp_tac (HOL_ss addsimps dist_rewrites') 1]) - in dist_thm :: standard (dist_thm RS not_sym) :: prove ts end; - in prove ts end; - - val distinct_thms = DatatypeProp.make_distincts descr sorts - |> map2 (prove_distinct_thms - (Config.get_thy thy5 distinctness_limit)) dist_rewrites; - - val simproc_dists = map (fn ((((_, (_, _, constrs)), rep_thms), congr), dists) => - if length constrs < Config.get_thy thy5 distinctness_limit - then FewConstrs dists - else ManyConstrs (congr, HOL_basic_ss addsimps rep_thms)) (hd descr ~~ - constr_rep_thms ~~ rep_congs ~~ distinct_thms); - - (* prove injectivity of constructors *) - - fun prove_constr_inj_thm rep_thms t = - let val inj_thms = Scons_inject :: (map make_elim - (iso_inj_thms @ - [In0_inject, In1_inject, Leaf_inject, Inl_inject, Inr_inject, - Lim_inject, Suml_inject, Sumr_inject])) - in SkipProof.prove_global thy5 [] [] t (fn _ => EVERY - [rtac iffI 1, - REPEAT (etac conjE 2), hyp_subst_tac 2, rtac refl 2, - dresolve_tac rep_congs 1, dtac box_equals 1, - REPEAT (resolve_tac rep_thms 1), - REPEAT (eresolve_tac inj_thms 1), - REPEAT (ares_tac [conjI] 1 ORELSE (EVERY [REPEAT (rtac ext 1), - REPEAT (eresolve_tac (make_elim fun_cong :: inj_thms) 1), - atac 1]))]) - end; - - val constr_inject = map (fn (ts, thms) => map (prove_constr_inj_thm thms) ts) - ((DatatypeProp.make_injs descr sorts) ~~ constr_rep_thms); - - val ((constr_inject', distinct_thms'), thy6) = - thy5 - |> parent_path (#flat_names config) - |> store_thmss "inject" new_type_names constr_inject - ||>> store_thmss "distinct" new_type_names distinct_thms; - - (*************************** induction theorem ****************************) - - val _ = message config "Proving induction rule for datatypes ..."; - - val Rep_inverse_thms = (map (fn (_, iso, _) => iso RS subst) newT_iso_axms) @ - (map (fn r => r RS myinv_f_f RS subst) iso_inj_thms_unfolded); - val Rep_inverse_thms' = map (fn r => r RS myinv_f_f) iso_inj_thms_unfolded; - - fun mk_indrule_lemma ((prems, concls), ((i, _), T)) = - let - val Rep_t = Const (List.nth (all_rep_names, i), T --> Univ_elT) $ - mk_Free "x" T i; - - val Abs_t = if i < length newTs then - Const (Sign.intern_const thy6 - ("Abs_" ^ (List.nth (new_type_names, i))), Univ_elT --> T) - else Const ("Inductive.myinv", [T --> Univ_elT, Univ_elT] ---> T) $ - Const (List.nth (all_rep_names, i), T --> Univ_elT) - - in (prems @ [HOLogic.imp $ - (Const (List.nth (rep_set_names, i), UnivT') $ Rep_t) $ - (mk_Free "P" (T --> HOLogic.boolT) (i + 1) $ (Abs_t $ Rep_t))], - concls @ [mk_Free "P" (T --> HOLogic.boolT) (i + 1) $ mk_Free "x" T i]) - end; - - val (indrule_lemma_prems, indrule_lemma_concls) = - Library.foldl mk_indrule_lemma (([], []), (descr' ~~ recTs)); - - val cert = cterm_of thy6; - - val indrule_lemma = SkipProof.prove_global thy6 [] [] - (Logic.mk_implies - (HOLogic.mk_Trueprop (mk_conj indrule_lemma_prems), - HOLogic.mk_Trueprop (mk_conj indrule_lemma_concls))) (fn _ => EVERY - [REPEAT (etac conjE 1), - REPEAT (EVERY - [TRY (rtac conjI 1), resolve_tac Rep_inverse_thms 1, - etac mp 1, resolve_tac iso_elem_thms 1])]); - - val Ps = map head_of (HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule_lemma))); - val frees = if length Ps = 1 then [Free ("P", snd (dest_Var (hd Ps)))] else - map (Free o apfst fst o dest_Var) Ps; - val indrule_lemma' = cterm_instantiate (map cert Ps ~~ map cert frees) indrule_lemma; - - val dt_induct_prop = DatatypeProp.make_ind descr sorts; - val dt_induct = SkipProof.prove_global thy6 [] - (Logic.strip_imp_prems dt_induct_prop) (Logic.strip_imp_concl dt_induct_prop) - (fn {prems, ...} => EVERY - [rtac indrule_lemma' 1, - (indtac rep_induct [] THEN_ALL_NEW ObjectLogic.atomize_prems_tac) 1, - EVERY (map (fn (prem, r) => (EVERY - [REPEAT (eresolve_tac Abs_inverse_thms 1), - simp_tac (HOL_basic_ss addsimps ((symmetric r)::Rep_inverse_thms')) 1, - DEPTH_SOLVE_1 (ares_tac [prem] 1 ORELSE etac allE 1)])) - (prems ~~ (constr_defs @ (map mk_meta_eq iso_char_thms))))]); - - val ([dt_induct'], thy7) = - thy6 - |> Sign.add_path big_name - |> PureThy.add_thms [((Binding.name "induct", dt_induct), [case_names_induct])] - ||> Sign.parent_path - ||> Theory.checkpoint; - - in - ((constr_inject', distinct_thms', dist_rewrites, simproc_dists, dt_induct'), thy7) - end; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/auto_term.ML --- a/src/HOL/Tools/function_package/auto_term.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,37 +0,0 @@ -(* Title: HOL/Tools/function_package/auto_term.ML - ID: $Id$ - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Method "relation" to commence a termination proof using a user-specified relation. -*) - -signature FUNDEF_RELATION = -sig - val relation_tac: Proof.context -> term -> int -> tactic - val setup: theory -> theory -end - -structure FundefRelation : FUNDEF_RELATION = -struct - -fun inst_thm ctxt rel st = - let - val cert = Thm.cterm_of (ProofContext.theory_of ctxt) - val rel' = cert (singleton (Variable.polymorphic ctxt) rel) - val st' = Thm.incr_indexes (#maxidx (Thm.rep_cterm rel') + 1) st - val Rvar = cert (Var (the_single (Term.add_vars (prop_of st') []))) - in - Drule.cterm_instantiate [(Rvar, rel')] st' - end - -fun relation_tac ctxt rel i = - TRY (FundefCommon.apply_termination_rule ctxt i) - THEN PRIMITIVE (inst_thm ctxt rel) - -val setup = - Method.setup @{binding relation} - (Args.term >> (fn rel => fn ctxt => SIMPLE_METHOD' (relation_tac ctxt rel))) - "proves termination using a user-specified wellfounded relation" - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/context_tree.ML --- a/src/HOL/Tools/function_package/context_tree.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,278 +0,0 @@ -(* Title: HOL/Tools/function_package/context_tree.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Builds and traverses trees of nested contexts along a term. -*) - -signature FUNDEF_CTXTREE = -sig - type ctxt = (string * typ) list * thm list (* poor man's contexts: fixes + assumes *) - type ctx_tree - - (* FIXME: This interface is a mess and needs to be cleaned up! *) - val get_fundef_congs : Proof.context -> thm list - val add_fundef_cong : thm -> Context.generic -> Context.generic - val map_fundef_congs : (thm list -> thm list) -> Context.generic -> Context.generic - - val cong_add: attribute - val cong_del: attribute - - val mk_tree: (string * typ) -> term -> Proof.context -> term -> ctx_tree - - val inst_tree: theory -> term -> term -> ctx_tree -> ctx_tree - - val export_term : ctxt -> term -> term - val export_thm : theory -> ctxt -> thm -> thm - val import_thm : theory -> ctxt -> thm -> thm - - val traverse_tree : - (ctxt -> term -> - (ctxt * thm) list -> - (ctxt * thm) list * 'b -> - (ctxt * thm) list * 'b) - -> ctx_tree -> 'b -> 'b - - val rewrite_by_tree : theory -> term -> thm -> (thm * thm) list -> ctx_tree -> thm * (thm * thm) list -end - -structure FundefCtxTree : FUNDEF_CTXTREE = -struct - -type ctxt = (string * typ) list * thm list - -open FundefCommon -open FundefLib - -structure FundefCongs = GenericDataFun -( - type T = thm list - val empty = [] - val extend = I - fun merge _ = Thm.merge_thms -); - -val get_fundef_congs = FundefCongs.get o Context.Proof -val map_fundef_congs = FundefCongs.map -val add_fundef_cong = FundefCongs.map o Thm.add_thm - -(* congruence rules *) - -val cong_add = Thm.declaration_attribute (map_fundef_congs o Thm.add_thm o safe_mk_meta_eq); -val cong_del = Thm.declaration_attribute (map_fundef_congs o Thm.del_thm o safe_mk_meta_eq); - - -type depgraph = int IntGraph.T - -datatype ctx_tree - = Leaf of term - | Cong of (thm * depgraph * (ctxt * ctx_tree) list) - | RCall of (term * ctx_tree) - - -(* Maps "Trueprop A = B" to "A" *) -val rhs_of = snd o HOLogic.dest_eq o HOLogic.dest_Trueprop - - -(*** Dependency analysis for congruence rules ***) - -fun branch_vars t = - let - val t' = snd (dest_all_all t) - val (assumes, concl) = Logic.strip_horn t' - in (fold Term.add_vars assumes [], Term.add_vars concl []) - end - -fun cong_deps crule = - let - val num_branches = map_index (apsnd branch_vars) (prems_of crule) - in - IntGraph.empty - |> fold (fn (i,_)=> IntGraph.new_node (i,i)) num_branches - |> fold_product (fn (i, (c1, _)) => fn (j, (_, t2)) => - if i = j orelse null (c1 inter t2) - then I else IntGraph.add_edge_acyclic (i,j)) - num_branches num_branches - end - -val default_congs = map (fn c => c RS eq_reflection) [@{thm "cong"}, @{thm "ext"}] - - - -(* Called on the INSTANTIATED branches of the congruence rule *) -fun mk_branch ctx t = - let - val (ctx', fixes, impl) = dest_all_all_ctx ctx t - val (assms, concl) = Logic.strip_horn impl - in - (ctx', fixes, assms, rhs_of concl) - end - -fun find_cong_rule ctx fvar h ((r,dep)::rs) t = - (let - val thy = ProofContext.theory_of ctx - - val tt' = Logic.mk_equals (Pattern.rewrite_term thy [(Free fvar, h)] [] t, t) - val (c, subs) = (concl_of r, prems_of r) - - val subst = Pattern.match (ProofContext.theory_of ctx) (c, tt') (Vartab.empty, Vartab.empty) - val branches = map (mk_branch ctx o Envir.beta_norm o Envir.subst_vars subst) subs - val inst = map (fn v => (cterm_of thy (Var v), cterm_of thy (Envir.subst_vars subst (Var v)))) (Term.add_vars c []) - in - (cterm_instantiate inst r, dep, branches) - end - handle Pattern.MATCH => find_cong_rule ctx fvar h rs t) - | find_cong_rule _ _ _ [] _ = sys_error "function_package/context_tree.ML: No cong rule found!" - - -fun mk_tree fvar h ctxt t = - let - val congs = get_fundef_congs ctxt - val congs_deps = map (fn c => (c, cong_deps c)) (congs @ default_congs) (* FIXME: Save in theory *) - - fun matchcall (a $ b) = if a = Free fvar then SOME b else NONE - | matchcall _ = NONE - - fun mk_tree' ctx t = - case matchcall t of - SOME arg => RCall (t, mk_tree' ctx arg) - | NONE => - if not (exists_subterm (fn Free v => v = fvar | _ => false) t) then Leaf t - else - let val (r, dep, branches) = find_cong_rule ctx fvar h congs_deps t in - Cong (r, dep, - map (fn (ctx', fixes, assumes, st) => - ((fixes, map (assume o cterm_of (ProofContext.theory_of ctx)) assumes), - mk_tree' ctx' st)) branches) - end - in - mk_tree' ctxt t - end - - -fun inst_tree thy fvar f tr = - let - val cfvar = cterm_of thy fvar - val cf = cterm_of thy f - - fun inst_term t = - subst_bound(f, abstract_over (fvar, t)) - - val inst_thm = forall_elim cf o forall_intr cfvar - - fun inst_tree_aux (Leaf t) = Leaf t - | inst_tree_aux (Cong (crule, deps, branches)) = - Cong (inst_thm crule, deps, map inst_branch branches) - | inst_tree_aux (RCall (t, str)) = - RCall (inst_term t, inst_tree_aux str) - and inst_branch ((fxs, assms), str) = - ((fxs, map (assume o cterm_of thy o inst_term o prop_of) assms), inst_tree_aux str) - in - inst_tree_aux tr - end - - -(* Poor man's contexts: Only fixes and assumes *) -fun compose (fs1, as1) (fs2, as2) = (fs1 @ fs2, as1 @ as2) - -fun export_term (fixes, assumes) = - fold_rev (curry Logic.mk_implies o prop_of) assumes - #> fold_rev (Logic.all o Free) fixes - -fun export_thm thy (fixes, assumes) = - fold_rev (implies_intr o cprop_of) assumes - #> fold_rev (forall_intr o cterm_of thy o Free) fixes - -fun import_thm thy (fixes, athms) = - fold (forall_elim o cterm_of thy o Free) fixes - #> fold Thm.elim_implies athms - - -(* folds in the order of the dependencies of a graph. *) -fun fold_deps G f x = - let - fun fill_table i (T, x) = - case Inttab.lookup T i of - SOME _ => (T, x) - | NONE => - let - val (T', x') = fold fill_table (IntGraph.imm_succs G i) (T, x) - val (v, x'') = f (the o Inttab.lookup T') i x' - in - (Inttab.update (i, v) T', x'') - end - - val (T, x) = fold fill_table (IntGraph.keys G) (Inttab.empty, x) - in - (Inttab.fold (cons o snd) T [], x) - end - -fun traverse_tree rcOp tr = - let - fun traverse_help ctx (Leaf _) _ x = ([], x) - | traverse_help ctx (RCall (t, st)) u x = - rcOp ctx t u (traverse_help ctx st u x) - | traverse_help ctx (Cong (_, deps, branches)) u x = - let - fun sub_step lu i x = - let - val (ctx', subtree) = nth branches i - val used = fold_rev (append o lu) (IntGraph.imm_succs deps i) u - val (subs, x') = traverse_help (compose ctx ctx') subtree used x - val exported_subs = map (apfst (compose ctx')) subs (* FIXME: Right order of composition? *) - in - (exported_subs, x') - end - in - fold_deps deps sub_step x - |> apfst flat - end - in - snd o traverse_help ([], []) tr [] - end - -fun rewrite_by_tree thy h ih x tr = - let - fun rewrite_help _ _ x (Leaf t) = (reflexive (cterm_of thy t), x) - | rewrite_help fix h_as x (RCall (_ $ arg, st)) = - let - val (inner, (lRi,ha)::x') = rewrite_help fix h_as x st (* "a' = a" *) - - val iha = import_thm thy (fix, h_as) ha (* (a', h a') : G *) - |> Conv.fconv_rule (Conv.arg_conv (Conv.comb_conv (Conv.arg_conv (K inner)))) - (* (a, h a) : G *) - val inst_ih = instantiate' [] [SOME (cterm_of thy arg)] ih - val eq = implies_elim (implies_elim inst_ih lRi) iha (* h a = f a *) - - val h_a'_eq_h_a = combination (reflexive (cterm_of thy h)) inner - val h_a_eq_f_a = eq RS eq_reflection - val result = transitive h_a'_eq_h_a h_a_eq_f_a - in - (result, x') - end - | rewrite_help fix h_as x (Cong (crule, deps, branches)) = - let - fun sub_step lu i x = - let - val ((fixes, assumes), st) = nth branches i - val used = map lu (IntGraph.imm_succs deps i) - |> map (fn u_eq => (u_eq RS sym) RS eq_reflection) - |> filter_out Thm.is_reflexive - - val assumes' = map (simplify (HOL_basic_ss addsimps used)) assumes - - val (subeq, x') = rewrite_help (fix @ fixes) (h_as @ assumes') x st - val subeq_exp = export_thm thy (fixes, assumes) (subeq RS meta_eq_to_obj_eq) - in - (subeq_exp, x') - end - - val (subthms, x') = fold_deps deps sub_step x - in - (fold_rev (curry op COMP) subthms crule, x') - end - in - rewrite_help [] [] x tr - end - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/decompose.ML --- a/src/HOL/Tools/function_package/decompose.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,105 +0,0 @@ -(* Title: HOL/Tools/function_package/decompose.ML - Author: Alexander Krauss, TU Muenchen - -Graph decomposition using "Shallow Dependency Pairs". -*) - -signature DECOMPOSE = -sig - - val derive_chains : Proof.context -> tactic - -> (Termination.data -> int -> tactic) - -> Termination.data -> int -> tactic - - val decompose_tac : Proof.context -> tactic - -> Termination.ttac - -end - -structure Decompose : DECOMPOSE = -struct - -structure TermGraph = GraphFun(type key = term val ord = TermOrd.fast_term_ord); - - -fun derive_chains ctxt chain_tac cont D = Termination.CALLS (fn (cs, i) => - let - val thy = ProofContext.theory_of ctxt - - fun prove_chain c1 c2 D = - if is_some (Termination.get_chain D c1 c2) then D else - let - val goal = HOLogic.mk_eq (HOLogic.mk_binop @{const_name "Relation.rel_comp"} (c1, c2), - Const (@{const_name Set.empty}, fastype_of c1)) - |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *) - - val chain = case FundefLib.try_proof (cterm_of thy goal) chain_tac of - FundefLib.Solved thm => SOME thm - | _ => NONE - in - Termination.note_chain c1 c2 chain D - end - in - cont (fold_product prove_chain cs cs D) i - end) - - -fun mk_dgraph D cs = - TermGraph.empty - |> fold (fn c => TermGraph.new_node (c,())) cs - |> fold_product (fn c1 => fn c2 => - if is_none (Termination.get_chain D c1 c2 |> the_default NONE) - then TermGraph.add_edge (c1, c2) else I) - cs cs - - -fun ucomp_empty_tac T = - REPEAT_ALL_NEW (rtac @{thm union_comp_emptyR} - ORELSE' rtac @{thm union_comp_emptyL} - ORELSE' SUBGOAL (fn (_ $ (_ $ (_ $ c1 $ c2) $ _), i) => rtac (T c1 c2) i)) - -fun regroup_calls_tac cs = Termination.CALLS (fn (cs', i) => - let - val is = map (fn c => find_index (curry op aconv c) cs') cs - in - CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv is))) i - end) - - -fun solve_trivial_tac D = Termination.CALLS -(fn ([c], i) => - (case Termination.get_chain D c c of - SOME (SOME thm) => rtac @{thm wf_no_loop} i - THEN rtac thm i - | _ => no_tac) - | _ => no_tac) - -fun decompose_tac' ctxt cont err_cont D = Termination.CALLS (fn (cs, i) => - let - val G = mk_dgraph D cs - val sccs = TermGraph.strong_conn G - - fun split [SCC] i = (solve_trivial_tac D i ORELSE cont D i) - | split (SCC::rest) i = - regroup_calls_tac SCC i - THEN rtac @{thm wf_union_compatible} i - THEN rtac @{thm less_by_empty} (i + 2) - THEN ucomp_empty_tac (the o the oo Termination.get_chain D) (i + 2) - THEN split rest (i + 1) - THEN (solve_trivial_tac D i ORELSE cont D i) - in - if length sccs > 1 then split sccs i - else solve_trivial_tac D i ORELSE err_cont D i - end) - -fun decompose_tac ctxt chain_tac cont err_cont = - derive_chains ctxt chain_tac - (decompose_tac' ctxt cont err_cont) - -fun auto_decompose_tac ctxt = - Termination.TERMINATION ctxt - (decompose_tac ctxt (auto_tac (local_clasimpset_of ctxt)) - (K (K all_tac)) (K (K no_tac))) - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/descent.ML --- a/src/HOL/Tools/function_package/descent.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,44 +0,0 @@ -(* Title: HOL/Tools/function_package/descent.ML - Author: Alexander Krauss, TU Muenchen - -Descent proofs for termination -*) - - -signature DESCENT = -sig - - val derive_diag : Proof.context -> tactic -> (Termination.data -> int -> tactic) - -> Termination.data -> int -> tactic - - val derive_all : Proof.context -> tactic -> (Termination.data -> int -> tactic) - -> Termination.data -> int -> tactic - -end - - -structure Descent : DESCENT = -struct - -fun gen_descent diag ctxt tac cont D = Termination.CALLS (fn (cs, i) => - let - val thy = ProofContext.theory_of ctxt - val measures_of = Termination.get_measures D - - fun derive c D = - let - val (_, p, _, q, _, _) = Termination.dest_call D c - in - if diag andalso p = q - then fold (fn m => Termination.derive_descent thy tac c m m) (measures_of p) D - else fold_product (Termination.derive_descent thy tac c) - (measures_of p) (measures_of q) D - end - in - cont (FundefCommon.PROFILE "deriving descents" (fold derive cs) D) i - end) - -val derive_diag = gen_descent true -val derive_all = gen_descent false - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/fundef.ML --- a/src/HOL/Tools/function_package/fundef.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,226 +0,0 @@ -(* Title: HOL/Tools/function_package/fundef.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Isar commands. -*) - -signature FUNDEF = -sig - val add_fundef : (binding * typ option * mixfix) list - -> (Attrib.binding * term) list - -> FundefCommon.fundef_config - -> local_theory - -> Proof.state - val add_fundef_cmd : (binding * string option * mixfix) list - -> (Attrib.binding * string) list - -> FundefCommon.fundef_config - -> local_theory - -> Proof.state - - val termination_proof : term option -> local_theory -> Proof.state - val termination_proof_cmd : string option -> local_theory -> Proof.state - val termination : term option -> local_theory -> Proof.state - val termination_cmd : string option -> local_theory -> Proof.state - - val setup : theory -> theory - val get_congs : Proof.context -> thm list -end - - -structure Fundef : FUNDEF = -struct - -open FundefLib -open FundefCommon - -val simp_attribs = map (Attrib.internal o K) - [Simplifier.simp_add, - Code.add_default_eqn_attribute, - Nitpick_Const_Simp_Thms.add, - Quickcheck_RecFun_Simp_Thms.add] - -val psimp_attribs = map (Attrib.internal o K) - [Simplifier.simp_add, - Nitpick_Const_Psimp_Thms.add] - -fun note_theorem ((name, atts), ths) = - LocalTheory.note Thm.generatedK ((Binding.qualified_name name, atts), ths) - -fun mk_defname fixes = fixes |> map (fst o fst) |> space_implode "_" - -fun add_simps fnames post sort extra_qualify label moreatts simps lthy = - let - val spec = post simps - |> map (apfst (apsnd (fn ats => moreatts @ ats))) - |> map (apfst (apfst extra_qualify)) - - val (saved_spec_simps, lthy) = - fold_map (LocalTheory.note Thm.generatedK) spec lthy - - val saved_simps = flat (map snd saved_spec_simps) - val simps_by_f = sort saved_simps - - fun add_for_f fname simps = - note_theorem ((Long_Name.qualify fname label, []), simps) #> snd - in - (saved_simps, - fold2 add_for_f fnames simps_by_f lthy) - end - -fun gen_add_fundef is_external prep default_constraint fixspec eqns config lthy = - let - val constrn_fxs = map (fn (b, T, mx) => (b, SOME (the_default default_constraint T), mx)) - val ((fixes0, spec0), ctxt') = prep (constrn_fxs fixspec) eqns lthy - val fixes = map (apfst (apfst Binding.name_of)) fixes0; - val spec = map (fn (bnd, prop) => (bnd, [prop])) spec0; - val (eqs, post, sort_cont, cnames) = FundefCommon.get_preproc lthy config ctxt' fixes spec - - val defname = mk_defname fixes - - val ((goalstate, cont), lthy) = - FundefMutual.prepare_fundef_mutual config defname fixes eqs lthy - - fun afterqed [[proof]] lthy = - let - val FundefResult {fs, R, psimps, trsimps, simple_pinducts, termination, - domintros, cases, ...} = - cont (Thm.close_derivation proof) - - val fnames = map (fst o fst) fixes - val qualify = Long_Name.qualify defname - val addsmps = add_simps fnames post sort_cont - - val (((psimps', pinducts'), (_, [termination'])), lthy) = - lthy - |> addsmps (Binding.qualify false "partial") "psimps" - psimp_attribs psimps - ||> fold_option (snd oo addsmps I "simps" simp_attribs) trsimps - ||>> note_theorem ((qualify "pinduct", - [Attrib.internal (K (RuleCases.case_names cnames)), - Attrib.internal (K (RuleCases.consumes 1)), - Attrib.internal (K (Induct.induct_pred ""))]), simple_pinducts) - ||>> note_theorem ((qualify "termination", []), [termination]) - ||> (snd o note_theorem ((qualify "cases", - [Attrib.internal (K (RuleCases.case_names cnames))]), [cases])) - ||> fold_option (snd oo curry note_theorem (qualify "domintros", [])) domintros - - val cdata = FundefCtxData { add_simps=addsmps, case_names=cnames, psimps=psimps', - pinducts=snd pinducts', termination=termination', - fs=fs, R=R, defname=defname } - val _ = - if not is_external then () - else Specification.print_consts lthy (K false) (map fst fixes) - in - lthy - |> LocalTheory.declaration (add_fundef_data o morph_fundef_data cdata) - end - in - lthy - |> is_external ? LocalTheory.set_group (serial_string ()) - |> Proof.theorem_i NONE afterqed [[(Logic.unprotect (concl_of goalstate), [])]] - |> Proof.refine (Method.primitive_text (fn _ => goalstate)) |> Seq.hd - end - -val add_fundef = gen_add_fundef false Specification.check_spec (TypeInfer.anyT HOLogic.typeS) -val add_fundef_cmd = gen_add_fundef true Specification.read_spec "_::type" - -fun gen_termination_proof prep_term raw_term_opt lthy = - let - val term_opt = Option.map (prep_term lthy) raw_term_opt - val data = the (case term_opt of - SOME t => (import_fundef_data t lthy - handle Option.Option => - error ("Not a function: " ^ quote (Syntax.string_of_term lthy t))) - | NONE => (import_last_fundef lthy handle Option.Option => error "Not a function")) - - val FundefCtxData { termination, R, add_simps, case_names, psimps, - pinducts, defname, ...} = data - val domT = domain_type (fastype_of R) - val goal = HOLogic.mk_Trueprop - (HOLogic.mk_all ("x", domT, mk_acc domT R $ Free ("x", domT))) - fun afterqed [[totality]] lthy = - let - val totality = Thm.close_derivation totality - val remove_domain_condition = - full_simplify (HOL_basic_ss addsimps [totality, True_implies_equals]) - val tsimps = map remove_domain_condition psimps - val tinduct = map remove_domain_condition pinducts - val qualify = Long_Name.qualify defname; - in - lthy - |> add_simps I "simps" simp_attribs tsimps |> snd - |> note_theorem - ((qualify "induct", - [Attrib.internal (K (RuleCases.case_names case_names))]), - tinduct) |> snd - end - in - lthy - |> ProofContext.note_thmss "" - [((Binding.empty, [ContextRules.rule_del]), [([allI], [])])] |> snd - |> ProofContext.note_thmss "" - [((Binding.empty, [ContextRules.intro_bang (SOME 1)]), [([allI], [])])] |> snd - |> ProofContext.note_thmss "" - [((Binding.name "termination", [ContextRules.intro_bang (SOME 0)]), - [([Goal.norm_result termination], [])])] |> snd - |> Proof.theorem_i NONE afterqed [[(goal, [])]] - end - -val termination_proof = gen_termination_proof Syntax.check_term; -val termination_proof_cmd = gen_termination_proof Syntax.read_term; - -fun termination term_opt lthy = - lthy - |> LocalTheory.set_group (serial_string ()) - |> termination_proof term_opt; - -fun termination_cmd term_opt lthy = - lthy - |> LocalTheory.set_group (serial_string ()) - |> termination_proof_cmd term_opt; - - -(* Datatype hook to declare datatype congs as "fundef_congs" *) - - -fun add_case_cong n thy = - Context.theory_map (FundefCtxTree.map_fundef_congs (Thm.add_thm - (Datatype.get_datatype thy n |> the - |> #case_cong - |> safe_mk_meta_eq))) - thy - -val setup_case_cong = Datatype.interpretation (K (fold add_case_cong)) - - -(* setup *) - -val setup = - Attrib.setup @{binding fundef_cong} - (Attrib.add_del FundefCtxTree.cong_add FundefCtxTree.cong_del) - "declaration of congruence rule for function definitions" - #> setup_case_cong - #> FundefRelation.setup - #> FundefCommon.TerminationSimps.setup - -val get_congs = FundefCtxTree.get_fundef_congs - - -(* outer syntax *) - -local structure P = OuterParse and K = OuterKeyword in - -val _ = - OuterSyntax.local_theory_to_proof "function" "define general recursive functions" K.thy_goal - (fundef_parser default_config - >> (fn ((config, fixes), statements) => add_fundef_cmd fixes statements config)); - -val _ = - OuterSyntax.local_theory_to_proof "termination" "prove termination of a recursive function" K.thy_goal - (Scan.option P.term >> termination_cmd); - -end; - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/fundef_common.ML --- a/src/HOL/Tools/function_package/fundef_common.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,343 +0,0 @@ -(* Title: HOL/Tools/function_package/fundef_common.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Common definitions and other infrastructure. -*) - -structure FundefCommon = -struct - -local open FundefLib in - -(* Profiling *) -val profile = ref false; - -fun PROFILE msg = if !profile then timeap_msg msg else I - - -val acc_const_name = @{const_name "accp"} -fun mk_acc domT R = - Const (acc_const_name, (domT --> domT --> HOLogic.boolT) --> domT --> HOLogic.boolT) $ R - -val function_name = suffix "C" -val graph_name = suffix "_graph" -val rel_name = suffix "_rel" -val dom_name = suffix "_dom" - -(* Termination rules *) - -structure TerminationRule = GenericDataFun -( - type T = thm list - val empty = [] - val extend = I - fun merge _ = Thm.merge_thms -); - -val get_termination_rules = TerminationRule.get -val store_termination_rule = TerminationRule.map o cons -val apply_termination_rule = resolve_tac o get_termination_rules o Context.Proof - - -(* Function definition result data *) - -datatype fundef_result = - FundefResult of - { - fs: term list, - G: term, - R: term, - - psimps : thm list, - trsimps : thm list option, - - simple_pinducts : thm list, - cases : thm, - termination : thm, - domintros : thm list option - } - - -datatype fundef_context_data = - FundefCtxData of - { - defname : string, - - (* contains no logical entities: invariant under morphisms *) - add_simps : (binding -> binding) -> string -> Attrib.src list -> thm list - -> local_theory -> thm list * local_theory, - case_names : string list, - - fs : term list, - R : term, - - psimps: thm list, - pinducts: thm list, - termination: thm - } - -fun morph_fundef_data (FundefCtxData {add_simps, case_names, fs, R, - psimps, pinducts, termination, defname}) phi = - let - val term = Morphism.term phi val thm = Morphism.thm phi val fact = Morphism.fact phi - val name = Binding.name_of o Morphism.binding phi o Binding.name - in - FundefCtxData { add_simps = add_simps, case_names = case_names, - fs = map term fs, R = term R, psimps = fact psimps, - pinducts = fact pinducts, termination = thm termination, - defname = name defname } - end - -structure FundefData = GenericDataFun -( - type T = (term * fundef_context_data) Item_Net.T; - val empty = Item_Net.init - (op aconv o pairself fst : (term * fundef_context_data) * (term * fundef_context_data) -> bool) - fst; - val copy = I; - val extend = I; - fun merge _ (tab1, tab2) = Item_Net.merge (tab1, tab2) -); - -val get_fundef = FundefData.get o Context.Proof; - - -(* Generally useful?? *) -fun lift_morphism thy f = - let - val term = Drule.term_rule thy f - in - Morphism.thm_morphism f $> Morphism.term_morphism term - $> Morphism.typ_morphism (Logic.type_map term) - end - -fun import_fundef_data t ctxt = - let - val thy = ProofContext.theory_of ctxt - val ct = cterm_of thy t - val inst_morph = lift_morphism thy o Thm.instantiate - - fun match (trm, data) = - SOME (morph_fundef_data data (inst_morph (Thm.match (cterm_of thy trm, ct)))) - handle Pattern.MATCH => NONE - in - get_first match (Item_Net.retrieve (get_fundef ctxt) t) - end - -fun import_last_fundef ctxt = - case Item_Net.content (get_fundef ctxt) of - [] => NONE - | (t, data) :: _ => - let - val ([t'], ctxt') = Variable.import_terms true [t] ctxt - in - import_fundef_data t' ctxt' - end - -val all_fundef_data = Item_Net.content o get_fundef - -fun add_fundef_data (data as FundefCtxData {fs, termination, ...}) = - FundefData.map (fold (fn f => Item_Net.insert (f, data)) fs) - #> store_termination_rule termination - - -(* Simp rules for termination proofs *) - -structure TerminationSimps = NamedThmsFun -( - val name = "termination_simp" - val description = "Simplification rule for termination proofs" -); - - -(* Default Termination Prover *) - -structure TerminationProver = GenericDataFun -( - type T = Proof.context -> Proof.method - val empty = (fn _ => error "Termination prover not configured") - val extend = I - fun merge _ (a,b) = b (* FIXME *) -); - -val set_termination_prover = TerminationProver.put -val get_termination_prover = TerminationProver.get o Context.Proof - - -(* Configuration management *) -datatype fundef_opt - = Sequential - | Default of string - | DomIntros - | Tailrec - -datatype fundef_config - = FundefConfig of - { - sequential: bool, - default: string, - domintros: bool, - tailrec: bool - } - -fun apply_opt Sequential (FundefConfig {sequential, default, domintros,tailrec}) = - FundefConfig {sequential=true, default=default, domintros=domintros, tailrec=tailrec} - | apply_opt (Default d) (FundefConfig {sequential, default, domintros,tailrec}) = - FundefConfig {sequential=sequential, default=d, domintros=domintros, tailrec=tailrec} - | apply_opt DomIntros (FundefConfig {sequential, default, domintros,tailrec}) = - FundefConfig {sequential=sequential, default=default, domintros=true,tailrec=tailrec} - | apply_opt Tailrec (FundefConfig {sequential, default, domintros,tailrec}) = - FundefConfig {sequential=sequential, default=default, domintros=domintros,tailrec=true} - -val default_config = - FundefConfig { sequential=false, default="%x. undefined" (*FIXME dynamic scoping*), - domintros=false, tailrec=false } - - -(* Analyzing function equations *) - -fun split_def ctxt geq = - let - fun input_error msg = cat_lines [msg, Syntax.string_of_term ctxt geq] - val qs = Term.strip_qnt_vars "all" geq - val imp = Term.strip_qnt_body "all" geq - val (gs, eq) = Logic.strip_horn imp - - val (f_args, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq) - handle TERM _ => error (input_error "Not an equation") - - val (head, args) = strip_comb f_args - - val fname = fst (dest_Free head) - handle TERM _ => error (input_error "Head symbol must not be a bound variable") - in - (fname, qs, gs, args, rhs) - end - -(* Check for all sorts of errors in the input *) -fun check_defs ctxt fixes eqs = - let - val fnames = map (fst o fst) fixes - - fun check geq = - let - fun input_error msg = error (cat_lines [msg, Syntax.string_of_term ctxt geq]) - - val fqgar as (fname, qs, gs, args, rhs) = split_def ctxt geq - - val _ = fname mem fnames - orelse input_error - ("Head symbol of left hand side must be " - ^ plural "" "one out of " fnames ^ commas_quote fnames) - - val _ = length args > 0 orelse input_error "Function has no arguments:" - - fun add_bvs t is = add_loose_bnos (t, 0, is) - val rvs = (add_bvs rhs [] \\ fold add_bvs args []) - |> map (fst o nth (rev qs)) - - val _ = null rvs orelse input_error - ("Variable" ^ plural " " "s " rvs ^ commas_quote rvs - ^ " occur" ^ plural "s" "" rvs ^ " on right hand side only:") - - val _ = forall (not o Term.exists_subterm - (fn Free (n, _) => n mem fnames | _ => false)) (gs @ args) - orelse input_error "Defined function may not occur in premises or arguments" - - val freeargs = map (fn t => subst_bounds (rev (map Free qs), t)) args - val funvars = filter (fn q => exists (exists_subterm (fn (Free q') $ _ => q = q' | _ => false)) freeargs) qs - val _ = null funvars - orelse (warning (cat_lines - ["Bound variable" ^ plural " " "s " funvars - ^ commas_quote (map fst funvars) ^ - " occur" ^ plural "s" "" funvars ^ " in function position.", - "Misspelled constructor???"]); true) - in - (fname, length args) - end - - val _ = AList.group (op =) (map check eqs) - |> map (fn (fname, ars) => - length (distinct (op =) ars) = 1 - orelse error ("Function " ^ quote fname ^ - " has different numbers of arguments in different equations")) - - fun check_sorts ((fname, fT), _) = - Sorts.of_sort (Sign.classes_of (ProofContext.theory_of ctxt)) (fT, HOLogic.typeS) - orelse error (cat_lines - ["Type of " ^ quote fname ^ " is not of sort " ^ quote "type" ^ ":", - setmp show_sorts true (Syntax.string_of_typ ctxt) fT]) - - val _ = map check_sorts fixes - in - () - end - -(* Preprocessors *) - -type fixes = ((string * typ) * mixfix) list -type 'a spec = (Attrib.binding * 'a list) list -type preproc = fundef_config -> Proof.context -> fixes -> term spec - -> (term list * (thm list -> thm spec) * (thm list -> thm list list) * string list) - -val fname_of = fst o dest_Free o fst o strip_comb o fst - o HOLogic.dest_eq o HOLogic.dest_Trueprop o Logic.strip_imp_concl o snd o dest_all_all - -fun mk_case_names i "" k = mk_case_names i (string_of_int (i + 1)) k - | mk_case_names _ n 0 = [] - | mk_case_names _ n 1 = [n] - | mk_case_names _ n k = map (fn i => n ^ "_" ^ string_of_int i) (1 upto k) - -fun empty_preproc check _ ctxt fixes spec = - let - val (bnds, tss) = split_list spec - val ts = flat tss - val _ = check ctxt fixes ts - val fnames = map (fst o fst) fixes - val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) ts - - fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) - (indices ~~ xs) - |> map (map snd) - - (* using theorem names for case name currently disabled *) - val cnames = map_index (fn (i, _) => mk_case_names i "" 1) bnds |> flat - in - (ts, curry op ~~ bnds o Library.unflat tss, sort, cnames) - end - -structure Preprocessor = GenericDataFun -( - type T = preproc - val empty : T = empty_preproc check_defs - val extend = I - fun merge _ (a, _) = a -); - -val get_preproc = Preprocessor.get o Context.Proof -val set_preproc = Preprocessor.map o K - - - -local - structure P = OuterParse and K = OuterKeyword - - val option_parser = - P.group "option" ((P.reserved "sequential" >> K Sequential) - || ((P.reserved "default" |-- P.term) >> Default) - || (P.reserved "domintros" >> K DomIntros) - || (P.reserved "tailrec" >> K Tailrec)) - - fun config_parser default = - (Scan.optional (P.$$$ "(" |-- P.!!! (P.list1 option_parser) --| P.$$$ ")") []) - >> (fn opts => fold apply_opt opts default) -in - fun fundef_parser default_cfg = - config_parser default_cfg -- P.fixes -- SpecParse.where_alt_specs -end - - -end -end - diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/fundef_core.ML --- a/src/HOL/Tools/function_package/fundef_core.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,954 +0,0 @@ -(* Title: HOL/Tools/function_package/fundef_core.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions: -Main functionality. -*) - -signature FUNDEF_CORE = -sig - val prepare_fundef : FundefCommon.fundef_config - -> string (* defname *) - -> ((bstring * typ) * mixfix) list (* defined symbol *) - -> ((bstring * typ) list * term list * term * term) list (* specification *) - -> local_theory - - -> (term (* f *) - * thm (* goalstate *) - * (thm -> FundefCommon.fundef_result) (* continuation *) - ) * local_theory - -end - -structure FundefCore : FUNDEF_CORE = -struct - -val boolT = HOLogic.boolT -val mk_eq = HOLogic.mk_eq - -open FundefLib -open FundefCommon - -datatype globals = - Globals of { - fvar: term, - domT: typ, - ranT: typ, - h: term, - y: term, - x: term, - z: term, - a: term, - P: term, - D: term, - Pbool:term -} - - -datatype rec_call_info = - RCInfo of - { - RIvs: (string * typ) list, (* Call context: fixes and assumes *) - CCas: thm list, - rcarg: term, (* The recursive argument *) - - llRI: thm, - h_assum: term - } - - -datatype clause_context = - ClauseContext of - { - ctxt : Proof.context, - - qs : term list, - gs : term list, - lhs: term, - rhs: term, - - cqs: cterm list, - ags: thm list, - case_hyp : thm - } - - -fun transfer_clause_ctx thy (ClauseContext { ctxt, qs, gs, lhs, rhs, cqs, ags, case_hyp }) = - ClauseContext { ctxt = ProofContext.transfer thy ctxt, - qs = qs, gs = gs, lhs = lhs, rhs = rhs, cqs = cqs, ags = ags, case_hyp = case_hyp } - - -datatype clause_info = - ClauseInfo of - { - no: int, - qglr : ((string * typ) list * term list * term * term), - cdata : clause_context, - - tree: FundefCtxTree.ctx_tree, - lGI: thm, - RCs: rec_call_info list - } - - -(* Theory dependencies. *) -val Pair_inject = @{thm Product_Type.Pair_inject}; - -val acc_induct_rule = @{thm accp_induct_rule}; - -val ex1_implies_ex = @{thm FunDef.fundef_ex1_existence}; -val ex1_implies_un = @{thm FunDef.fundef_ex1_uniqueness}; -val ex1_implies_iff = @{thm FunDef.fundef_ex1_iff}; - -val acc_downward = @{thm accp_downward}; -val accI = @{thm accp.accI}; -val case_split = @{thm HOL.case_split}; -val fundef_default_value = @{thm FunDef.fundef_default_value}; -val not_acc_down = @{thm not_accp_down}; - - - -fun find_calls tree = - let - fun add_Ri (fixes,assumes) (_ $ arg) _ (_, xs) = ([], (fixes, assumes, arg) :: xs) - | add_Ri _ _ _ _ = raise Match - in - rev (FundefCtxTree.traverse_tree add_Ri tree []) - end - - -(** building proof obligations *) - -fun mk_compat_proof_obligations domT ranT fvar f glrs = - let - fun mk_impl ((qs, gs, lhs, rhs),(qs', gs', lhs', rhs')) = - let - val shift = incr_boundvars (length qs') - in - Logic.mk_implies - (HOLogic.mk_Trueprop (HOLogic.eq_const domT $ shift lhs $ lhs'), - HOLogic.mk_Trueprop (HOLogic.eq_const ranT $ shift rhs $ rhs')) - |> fold_rev (curry Logic.mk_implies) (map shift gs @ gs') - |> fold_rev (fn (n,T) => fn b => Term.all T $ Abs(n,T,b)) (qs @ qs') - |> curry abstract_over fvar - |> curry subst_bound f - end - in - map mk_impl (unordered_pairs glrs) - end - - -fun mk_completeness (Globals {x, Pbool, ...}) clauses qglrs = - let - fun mk_case (ClauseContext {qs, gs, lhs, ...}, (oqs, _, _, _)) = - HOLogic.mk_Trueprop Pbool - |> curry Logic.mk_implies (HOLogic.mk_Trueprop (mk_eq (x, lhs))) - |> fold_rev (curry Logic.mk_implies) gs - |> fold_rev mk_forall_rename (map fst oqs ~~ qs) - in - HOLogic.mk_Trueprop Pbool - |> fold_rev (curry Logic.mk_implies o mk_case) (clauses ~~ qglrs) - |> mk_forall_rename ("x", x) - |> mk_forall_rename ("P", Pbool) - end - -(** making a context with it's own local bindings **) - -fun mk_clause_context x ctxt (pre_qs,pre_gs,pre_lhs,pre_rhs) = - let - val (qs, ctxt') = Variable.variant_fixes (map fst pre_qs) ctxt - |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs - - val thy = ProofContext.theory_of ctxt' - - fun inst t = subst_bounds (rev qs, t) - val gs = map inst pre_gs - val lhs = inst pre_lhs - val rhs = inst pre_rhs - - val cqs = map (cterm_of thy) qs - val ags = map (assume o cterm_of thy) gs - - val case_hyp = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (x, lhs)))) - in - ClauseContext { ctxt = ctxt', qs = qs, gs = gs, lhs = lhs, rhs = rhs, - cqs = cqs, ags = ags, case_hyp = case_hyp } - end - - -(* lowlevel term function *) -fun abstract_over_list vs body = - let - exception SAME; - fun abs lev v tm = - if v aconv tm then Bound lev - else - (case tm of - Abs (a, T, t) => Abs (a, T, abs (lev + 1) v t) - | t $ u => (abs lev v t $ (abs lev v u handle SAME => u) handle SAME => t $ abs lev v u) - | _ => raise SAME); - in - fold_index (fn (i,v) => fn t => abs i v t handle SAME => t) vs body - end - - - -fun mk_clause_info globals G f no cdata qglr tree RCs GIntro_thm RIntro_thms = - let - val Globals {h, fvar, x, ...} = globals - - val ClauseContext { ctxt, qs, cqs, ags, ... } = cdata - val cert = Thm.cterm_of (ProofContext.theory_of ctxt) - - (* Instantiate the GIntro thm with "f" and import into the clause context. *) - val lGI = GIntro_thm - |> forall_elim (cert f) - |> fold forall_elim cqs - |> fold Thm.elim_implies ags - - fun mk_call_info (rcfix, rcassm, rcarg) RI = - let - val llRI = RI - |> fold forall_elim cqs - |> fold (forall_elim o cert o Free) rcfix - |> fold Thm.elim_implies ags - |> fold Thm.elim_implies rcassm - - val h_assum = - HOLogic.mk_Trueprop (G $ rcarg $ (h $ rcarg)) - |> fold_rev (curry Logic.mk_implies o prop_of) rcassm - |> fold_rev (Logic.all o Free) rcfix - |> Pattern.rewrite_term (ProofContext.theory_of ctxt) [(f, h)] [] - |> abstract_over_list (rev qs) - in - RCInfo {RIvs=rcfix, rcarg=rcarg, CCas=rcassm, llRI=llRI, h_assum=h_assum} - end - - val RC_infos = map2 mk_call_info RCs RIntro_thms - in - ClauseInfo - { - no=no, - cdata=cdata, - qglr=qglr, - - lGI=lGI, - RCs=RC_infos, - tree=tree - } - end - - - - - - - -(* replace this by a table later*) -fun store_compat_thms 0 thms = [] - | store_compat_thms n thms = - let - val (thms1, thms2) = chop n thms - in - (thms1 :: store_compat_thms (n - 1) thms2) - end - -(* expects i <= j *) -fun lookup_compat_thm i j cts = - nth (nth cts (i - 1)) (j - i) - -(* Returns "Gsi, Gsj, lhs_i = lhs_j |-- rhs_j_f = rhs_i_f" *) -(* if j < i, then turn around *) -fun get_compat_thm thy cts i j ctxi ctxj = - let - val ClauseContext {cqs=cqsi,ags=agsi,lhs=lhsi,...} = ctxi - val ClauseContext {cqs=cqsj,ags=agsj,lhs=lhsj,...} = ctxj - - val lhsi_eq_lhsj = cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj))) - in if j < i then - let - val compat = lookup_compat_thm j i cts - in - compat (* "!!qj qi. Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *) - |> fold forall_elim (cqsj @ cqsi) (* "Gsj => Gsi => lhsj = lhsi ==> rhsj = rhsi" *) - |> fold Thm.elim_implies agsj - |> fold Thm.elim_implies agsi - |> Thm.elim_implies ((assume lhsi_eq_lhsj) RS sym) (* "Gsj, Gsi, lhsi = lhsj |-- rhsj = rhsi" *) - end - else - let - val compat = lookup_compat_thm i j cts - in - compat (* "!!qi qj. Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *) - |> fold forall_elim (cqsi @ cqsj) (* "Gsi => Gsj => lhsi = lhsj ==> rhsi = rhsj" *) - |> fold Thm.elim_implies agsi - |> fold Thm.elim_implies agsj - |> Thm.elim_implies (assume lhsi_eq_lhsj) - |> (fn thm => thm RS sym) (* "Gsi, Gsj, lhsi = lhsj |-- rhsj = rhsi" *) - end - end - - - - -(* Generates the replacement lemma in fully quantified form. *) -fun mk_replacement_lemma thy h ih_elim clause = - let - val ClauseInfo {cdata=ClauseContext {qs, lhs, rhs, cqs, ags, case_hyp, ...}, RCs, tree, ...} = clause - local open Conv in - val ih_conv = arg1_conv o arg_conv o arg_conv - end - - val ih_elim_case = Conv.fconv_rule (ih_conv (K (case_hyp RS eq_reflection))) ih_elim - - val Ris = map (fn RCInfo {llRI, ...} => llRI) RCs - val h_assums = map (fn RCInfo {h_assum, ...} => assume (cterm_of thy (subst_bounds (rev qs, h_assum)))) RCs - - val (eql, _) = FundefCtxTree.rewrite_by_tree thy h ih_elim_case (Ris ~~ h_assums) tree - - val replace_lemma = (eql RS meta_eq_to_obj_eq) - |> implies_intr (cprop_of case_hyp) - |> fold_rev (implies_intr o cprop_of) h_assums - |> fold_rev (implies_intr o cprop_of) ags - |> fold_rev forall_intr cqs - |> Thm.close_derivation - in - replace_lemma - end - - -fun mk_uniqueness_clause thy globals f compat_store clausei clausej RLj = - let - val Globals {h, y, x, fvar, ...} = globals - val ClauseInfo {no=i, cdata=cctxi as ClauseContext {ctxt=ctxti, lhs=lhsi, case_hyp, ...}, ...} = clausei - val ClauseInfo {no=j, qglr=cdescj, RCs=RCsj, ...} = clausej - - val cctxj as ClauseContext {ags = agsj', lhs = lhsj', rhs = rhsj', qs = qsj', cqs = cqsj', ...} - = mk_clause_context x ctxti cdescj - - val rhsj'h = Pattern.rewrite_term thy [(fvar,h)] [] rhsj' - val compat = get_compat_thm thy compat_store i j cctxi cctxj - val Ghsj' = map (fn RCInfo {h_assum, ...} => assume (cterm_of thy (subst_bounds (rev qsj', h_assum)))) RCsj - - val RLj_import = - RLj |> fold forall_elim cqsj' - |> fold Thm.elim_implies agsj' - |> fold Thm.elim_implies Ghsj' - - val y_eq_rhsj'h = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (y, rhsj'h)))) - val lhsi_eq_lhsj' = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (lhsi, lhsj')))) (* lhs_i = lhs_j' |-- lhs_i = lhs_j' *) - in - (trans OF [case_hyp, lhsi_eq_lhsj']) (* lhs_i = lhs_j' |-- x = lhs_j' *) - |> implies_elim RLj_import (* Rj1' ... Rjk', lhs_i = lhs_j' |-- rhs_j'_h = rhs_j'_f *) - |> (fn it => trans OF [it, compat]) (* lhs_i = lhs_j', Gj', Rj1' ... Rjk' |-- rhs_j'_h = rhs_i_f *) - |> (fn it => trans OF [y_eq_rhsj'h, it]) (* lhs_i = lhs_j', Gj', Rj1' ... Rjk', y = rhs_j_h' |-- y = rhs_i_f *) - |> fold_rev (implies_intr o cprop_of) Ghsj' - |> fold_rev (implies_intr o cprop_of) agsj' (* lhs_i = lhs_j' , y = rhs_j_h' |-- Gj', Rj1'...Rjk' ==> y = rhs_i_f *) - |> implies_intr (cprop_of y_eq_rhsj'h) - |> implies_intr (cprop_of lhsi_eq_lhsj') - |> fold_rev forall_intr (cterm_of thy h :: cqsj') - end - - - -fun mk_uniqueness_case ctxt thy globals G f ihyp ih_intro G_cases compat_store clauses rep_lemmas clausei = - let - val Globals {x, y, ranT, fvar, ...} = globals - val ClauseInfo {cdata = ClauseContext {lhs, rhs, qs, cqs, ags, case_hyp, ...}, lGI, RCs, ...} = clausei - val rhsC = Pattern.rewrite_term thy [(fvar, f)] [] rhs - - val ih_intro_case = full_simplify (HOL_basic_ss addsimps [case_hyp]) ih_intro - - fun prep_RC (RCInfo {llRI, RIvs, CCas, ...}) = (llRI RS ih_intro_case) - |> fold_rev (implies_intr o cprop_of) CCas - |> fold_rev (forall_intr o cterm_of thy o Free) RIvs - - val existence = fold (curry op COMP o prep_RC) RCs lGI - - val P = cterm_of thy (mk_eq (y, rhsC)) - val G_lhs_y = assume (cterm_of thy (HOLogic.mk_Trueprop (G $ lhs $ y))) - - val unique_clauses = map2 (mk_uniqueness_clause thy globals f compat_store clausei) clauses rep_lemmas - - val uniqueness = G_cases - |> forall_elim (cterm_of thy lhs) - |> forall_elim (cterm_of thy y) - |> forall_elim P - |> Thm.elim_implies G_lhs_y - |> fold Thm.elim_implies unique_clauses - |> implies_intr (cprop_of G_lhs_y) - |> forall_intr (cterm_of thy y) - - val P2 = cterm_of thy (lambda y (G $ lhs $ y)) (* P2 y := (lhs, y): G *) - - val exactly_one = - ex1I |> instantiate' [SOME (ctyp_of thy ranT)] [SOME P2, SOME (cterm_of thy rhsC)] - |> curry (op COMP) existence - |> curry (op COMP) uniqueness - |> simplify (HOL_basic_ss addsimps [case_hyp RS sym]) - |> implies_intr (cprop_of case_hyp) - |> fold_rev (implies_intr o cprop_of) ags - |> fold_rev forall_intr cqs - - val function_value = - existence - |> implies_intr ihyp - |> implies_intr (cprop_of case_hyp) - |> forall_intr (cterm_of thy x) - |> forall_elim (cterm_of thy lhs) - |> curry (op RS) refl - in - (exactly_one, function_value) - end - - - - -fun prove_stuff ctxt globals G f R clauses complete compat compat_store G_elim f_def = - let - val Globals {h, domT, ranT, x, ...} = globals - val thy = ProofContext.theory_of ctxt - - (* Inductive Hypothesis: !!z. (z,x):R ==> EX!y. (z,y):G *) - val ihyp = Term.all domT $ Abs ("z", domT, - Logic.mk_implies (HOLogic.mk_Trueprop (R $ Bound 0 $ x), - HOLogic.mk_Trueprop (Const ("Ex1", (ranT --> boolT) --> boolT) $ - Abs ("y", ranT, G $ Bound 1 $ Bound 0)))) - |> cterm_of thy - - val ihyp_thm = assume ihyp |> Thm.forall_elim_vars 0 - val ih_intro = ihyp_thm RS (f_def RS ex1_implies_ex) - val ih_elim = ihyp_thm RS (f_def RS ex1_implies_un) - |> instantiate' [] [NONE, SOME (cterm_of thy h)] - - val _ = Output.debug (K "Proving Replacement lemmas...") - val repLemmas = map (mk_replacement_lemma thy h ih_elim) clauses - - val _ = Output.debug (K "Proving cases for unique existence...") - val (ex1s, values) = - split_list (map (mk_uniqueness_case ctxt thy globals G f ihyp ih_intro G_elim compat_store clauses repLemmas) clauses) - - val _ = Output.debug (K "Proving: Graph is a function") - val graph_is_function = complete - |> Thm.forall_elim_vars 0 - |> fold (curry op COMP) ex1s - |> implies_intr (ihyp) - |> implies_intr (cterm_of thy (HOLogic.mk_Trueprop (mk_acc domT R $ x))) - |> forall_intr (cterm_of thy x) - |> (fn it => Drule.compose_single (it, 2, acc_induct_rule)) (* "EX! y. (?x,y):G" *) - |> (fn it => fold (forall_intr o cterm_of thy o Var) (Term.add_vars (prop_of it) []) it) - - val goalstate = Conjunction.intr graph_is_function complete - |> Thm.close_derivation - |> Goal.protect - |> fold_rev (implies_intr o cprop_of) compat - |> implies_intr (cprop_of complete) - in - (goalstate, values) - end - - -fun define_graph Gname fvar domT ranT clauses RCss lthy = - let - val GT = domT --> ranT --> boolT - val Gvar = Free (the_single (Variable.variant_frees lthy [] [(Gname, GT)])) - - fun mk_GIntro (ClauseContext {qs, gs, lhs, rhs, ...}) RCs = - let - fun mk_h_assm (rcfix, rcassm, rcarg) = - HOLogic.mk_Trueprop (Gvar $ rcarg $ (fvar $ rcarg)) - |> fold_rev (curry Logic.mk_implies o prop_of) rcassm - |> fold_rev (Logic.all o Free) rcfix - in - HOLogic.mk_Trueprop (Gvar $ lhs $ rhs) - |> fold_rev (curry Logic.mk_implies o mk_h_assm) RCs - |> fold_rev (curry Logic.mk_implies) gs - |> fold_rev Logic.all (fvar :: qs) - end - - val G_intros = map2 mk_GIntro clauses RCss - - val (GIntro_thms, (G, G_elim, G_induct, lthy)) = - FundefInductiveWrap.inductive_def G_intros ((dest_Free Gvar, NoSyn), lthy) - in - ((G, GIntro_thms, G_elim, G_induct), lthy) - end - - - -fun define_function fdefname (fname, mixfix) domT ranT G default lthy = - let - val f_def = - Abs ("x", domT, Const ("FunDef.THE_default", ranT --> (ranT --> boolT) --> ranT) $ (default $ Bound 0) $ - Abs ("y", ranT, G $ Bound 1 $ Bound 0)) - |> Syntax.check_term lthy - - val ((f, (_, f_defthm)), lthy) = - LocalTheory.define Thm.internalK ((Binding.name (function_name fname), mixfix), ((Binding.name fdefname, []), f_def)) lthy - in - ((f, f_defthm), lthy) - end - - -fun define_recursion_relation Rname domT ranT fvar f qglrs clauses RCss lthy = - let - - val RT = domT --> domT --> boolT - val Rvar = Free (the_single (Variable.variant_frees lthy [] [(Rname, RT)])) - - fun mk_RIntro (ClauseContext {qs, gs, lhs, ...}, (oqs, _, _, _)) (rcfix, rcassm, rcarg) = - HOLogic.mk_Trueprop (Rvar $ rcarg $ lhs) - |> fold_rev (curry Logic.mk_implies o prop_of) rcassm - |> fold_rev (curry Logic.mk_implies) gs - |> fold_rev (Logic.all o Free) rcfix - |> fold_rev mk_forall_rename (map fst oqs ~~ qs) - (* "!!qs xs. CS ==> G => (r, lhs) : R" *) - - val R_intross = map2 (map o mk_RIntro) (clauses ~~ qglrs) RCss - - val (RIntro_thmss, (R, R_elim, _, lthy)) = - fold_burrow FundefInductiveWrap.inductive_def R_intross ((dest_Free Rvar, NoSyn), lthy) - in - ((R, RIntro_thmss, R_elim), lthy) - end - - -fun fix_globals domT ranT fvar ctxt = - let - val ([h, y, x, z, a, D, P, Pbool],ctxt') = - Variable.variant_fixes ["h_fd", "y_fd", "x_fd", "z_fd", "a_fd", "D_fd", "P_fd", "Pb_fd"] ctxt - in - (Globals {h = Free (h, domT --> ranT), - y = Free (y, ranT), - x = Free (x, domT), - z = Free (z, domT), - a = Free (a, domT), - D = Free (D, domT --> boolT), - P = Free (P, domT --> boolT), - Pbool = Free (Pbool, boolT), - fvar = fvar, - domT = domT, - ranT = ranT - }, - ctxt') - end - - - -fun inst_RC thy fvar f (rcfix, rcassm, rcarg) = - let - fun inst_term t = subst_bound(f, abstract_over (fvar, t)) - in - (rcfix, map (assume o cterm_of thy o inst_term o prop_of) rcassm, inst_term rcarg) - end - - - -(********************************************************** - * PROVING THE RULES - **********************************************************) - -fun mk_psimps thy globals R clauses valthms f_iff graph_is_function = - let - val Globals {domT, z, ...} = globals - - fun mk_psimp (ClauseInfo {qglr = (oqs, _, _, _), cdata = ClauseContext {cqs, lhs, ags, ...}, ...}) valthm = - let - val lhs_acc = cterm_of thy (HOLogic.mk_Trueprop (mk_acc domT R $ lhs)) (* "acc R lhs" *) - val z_smaller = cterm_of thy (HOLogic.mk_Trueprop (R $ z $ lhs)) (* "R z lhs" *) - in - ((assume z_smaller) RS ((assume lhs_acc) RS acc_downward)) - |> (fn it => it COMP graph_is_function) - |> implies_intr z_smaller - |> forall_intr (cterm_of thy z) - |> (fn it => it COMP valthm) - |> implies_intr lhs_acc - |> asm_simplify (HOL_basic_ss addsimps [f_iff]) - |> fold_rev (implies_intr o cprop_of) ags - |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) - end - in - map2 mk_psimp clauses valthms - end - - -(** Induction rule **) - - -val acc_subset_induct = @{thm Orderings.predicate1I} RS @{thm accp_subset_induct} - - -fun binder_conv cv ctxt = Conv.arg_conv (Conv.abs_conv (K cv) ctxt); - -fun mk_partial_induct_rule thy globals R complete_thm clauses = - let - val Globals {domT, x, z, a, P, D, ...} = globals - val acc_R = mk_acc domT R - - val x_D = assume (cterm_of thy (HOLogic.mk_Trueprop (D $ x))) - val a_D = cterm_of thy (HOLogic.mk_Trueprop (D $ a)) - - val D_subset = cterm_of thy (Logic.all x - (Logic.mk_implies (HOLogic.mk_Trueprop (D $ x), HOLogic.mk_Trueprop (acc_R $ x)))) - - val D_dcl = (* "!!x z. [| x: D; (z,x):R |] ==> z:D" *) - Logic.all x - (Logic.all z (Logic.mk_implies (HOLogic.mk_Trueprop (D $ x), - Logic.mk_implies (HOLogic.mk_Trueprop (R $ z $ x), - HOLogic.mk_Trueprop (D $ z))))) - |> cterm_of thy - - - (* Inductive Hypothesis: !!z. (z,x):R ==> P z *) - val ihyp = Term.all domT $ Abs ("z", domT, - Logic.mk_implies (HOLogic.mk_Trueprop (R $ Bound 0 $ x), - HOLogic.mk_Trueprop (P $ Bound 0))) - |> cterm_of thy - - val aihyp = assume ihyp - - fun prove_case clause = - let - val ClauseInfo {cdata = ClauseContext {ctxt, qs, cqs, ags, gs, lhs, case_hyp, ...}, RCs, - qglr = (oqs, _, _, _), ...} = clause - - val case_hyp_conv = K (case_hyp RS eq_reflection) - local open Conv in - val lhs_D = fconv_rule (arg_conv (arg_conv (case_hyp_conv))) x_D - val sih = fconv_rule (binder_conv (arg1_conv (arg_conv (arg_conv case_hyp_conv))) ctxt) aihyp - end - - fun mk_Prec (RCInfo {llRI, RIvs, CCas, rcarg, ...}) = - sih |> forall_elim (cterm_of thy rcarg) - |> Thm.elim_implies llRI - |> fold_rev (implies_intr o cprop_of) CCas - |> fold_rev (forall_intr o cterm_of thy o Free) RIvs - - val P_recs = map mk_Prec RCs (* [P rec1, P rec2, ... ] *) - - val step = HOLogic.mk_Trueprop (P $ lhs) - |> fold_rev (curry Logic.mk_implies o prop_of) P_recs - |> fold_rev (curry Logic.mk_implies) gs - |> curry Logic.mk_implies (HOLogic.mk_Trueprop (D $ lhs)) - |> fold_rev mk_forall_rename (map fst oqs ~~ qs) - |> cterm_of thy - - val P_lhs = assume step - |> fold forall_elim cqs - |> Thm.elim_implies lhs_D - |> fold Thm.elim_implies ags - |> fold Thm.elim_implies P_recs - - val res = cterm_of thy (HOLogic.mk_Trueprop (P $ x)) - |> Conv.arg_conv (Conv.arg_conv case_hyp_conv) - |> symmetric (* P lhs == P x *) - |> (fn eql => equal_elim eql P_lhs) (* "P x" *) - |> implies_intr (cprop_of case_hyp) - |> fold_rev (implies_intr o cprop_of) ags - |> fold_rev forall_intr cqs - in - (res, step) - end - - val (cases, steps) = split_list (map prove_case clauses) - - val istep = complete_thm - |> Thm.forall_elim_vars 0 - |> fold (curry op COMP) cases (* P x *) - |> implies_intr ihyp - |> implies_intr (cprop_of x_D) - |> forall_intr (cterm_of thy x) - - val subset_induct_rule = - acc_subset_induct - |> (curry op COMP) (assume D_subset) - |> (curry op COMP) (assume D_dcl) - |> (curry op COMP) (assume a_D) - |> (curry op COMP) istep - |> fold_rev implies_intr steps - |> implies_intr a_D - |> implies_intr D_dcl - |> implies_intr D_subset - - val subset_induct_all = fold_rev (forall_intr o cterm_of thy) [P, a, D] subset_induct_rule - - val simple_induct_rule = - subset_induct_rule - |> forall_intr (cterm_of thy D) - |> forall_elim (cterm_of thy acc_R) - |> assume_tac 1 |> Seq.hd - |> (curry op COMP) (acc_downward - |> (instantiate' [SOME (ctyp_of thy domT)] - (map (SOME o cterm_of thy) [R, x, z])) - |> forall_intr (cterm_of thy z) - |> forall_intr (cterm_of thy x)) - |> forall_intr (cterm_of thy a) - |> forall_intr (cterm_of thy P) - in - simple_induct_rule - end - - - -(* FIXME: This should probably use fixed goals, to be more reliable and faster *) -fun mk_domain_intro ctxt (Globals {domT, ...}) R R_cases clause = - let - val thy = ProofContext.theory_of ctxt - val ClauseInfo {cdata = ClauseContext {qs, gs, lhs, rhs, cqs, ...}, - qglr = (oqs, _, _, _), ...} = clause - val goal = HOLogic.mk_Trueprop (mk_acc domT R $ lhs) - |> fold_rev (curry Logic.mk_implies) gs - |> cterm_of thy - in - Goal.init goal - |> (SINGLE (resolve_tac [accI] 1)) |> the - |> (SINGLE (eresolve_tac [Thm.forall_elim_vars 0 R_cases] 1)) |> the - |> (SINGLE (auto_tac (local_clasimpset_of ctxt))) |> the - |> Goal.conclude - |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) - end - - - -(** Termination rule **) - -val wf_induct_rule = @{thm Wellfounded.wfP_induct_rule}; -val wf_in_rel = @{thm FunDef.wf_in_rel}; -val in_rel_def = @{thm FunDef.in_rel_def}; - -fun mk_nest_term_case thy globals R' ihyp clause = - let - val Globals {x, z, ...} = globals - val ClauseInfo {cdata = ClauseContext {qs,cqs,ags,lhs,rhs,case_hyp,...},tree, - qglr=(oqs, _, _, _), ...} = clause - - val ih_case = full_simplify (HOL_basic_ss addsimps [case_hyp]) ihyp - - fun step (fixes, assumes) (_ $ arg) u (sub,(hyps,thms)) = - let - val used = map (fn (ctx,thm) => FundefCtxTree.export_thm thy ctx thm) (u @ sub) - - val hyp = HOLogic.mk_Trueprop (R' $ arg $ lhs) - |> fold_rev (curry Logic.mk_implies o prop_of) used (* additional hyps *) - |> FundefCtxTree.export_term (fixes, assumes) - |> fold_rev (curry Logic.mk_implies o prop_of) ags - |> fold_rev mk_forall_rename (map fst oqs ~~ qs) - |> cterm_of thy - - val thm = assume hyp - |> fold forall_elim cqs - |> fold Thm.elim_implies ags - |> FundefCtxTree.import_thm thy (fixes, assumes) - |> fold Thm.elim_implies used (* "(arg, lhs) : R'" *) - - val z_eq_arg = assume (cterm_of thy (HOLogic.mk_Trueprop (mk_eq (z, arg)))) - - val acc = thm COMP ih_case - val z_acc_local = acc - |> Conv.fconv_rule (Conv.arg_conv (Conv.arg_conv (K (symmetric (z_eq_arg RS eq_reflection))))) - - val ethm = z_acc_local - |> FundefCtxTree.export_thm thy (fixes, - z_eq_arg :: case_hyp :: ags @ assumes) - |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) - - val sub' = sub @ [(([],[]), acc)] - in - (sub', (hyp :: hyps, ethm :: thms)) - end - | step _ _ _ _ = raise Match - in - FundefCtxTree.traverse_tree step tree - end - - -fun mk_nest_term_rule thy globals R R_cases clauses = - let - val Globals { domT, x, z, ... } = globals - val acc_R = mk_acc domT R - - val R' = Free ("R", fastype_of R) - - val Rrel = Free ("R", HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT))) - val inrel_R = Const ("FunDef.in_rel", HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)) --> fastype_of R) $ Rrel - - val wfR' = cterm_of thy (HOLogic.mk_Trueprop (Const (@{const_name "Wellfounded.wfP"}, (domT --> domT --> boolT) --> boolT) $ R')) (* "wf R'" *) - - (* Inductive Hypothesis: !!z. (z,x):R' ==> z : acc R *) - val ihyp = Term.all domT $ Abs ("z", domT, - Logic.mk_implies (HOLogic.mk_Trueprop (R' $ Bound 0 $ x), - HOLogic.mk_Trueprop (acc_R $ Bound 0))) - |> cterm_of thy - - val ihyp_a = assume ihyp |> Thm.forall_elim_vars 0 - - val R_z_x = cterm_of thy (HOLogic.mk_Trueprop (R $ z $ x)) - - val (hyps,cases) = fold (mk_nest_term_case thy globals R' ihyp_a) clauses ([],[]) - in - R_cases - |> forall_elim (cterm_of thy z) - |> forall_elim (cterm_of thy x) - |> forall_elim (cterm_of thy (acc_R $ z)) - |> curry op COMP (assume R_z_x) - |> fold_rev (curry op COMP) cases - |> implies_intr R_z_x - |> forall_intr (cterm_of thy z) - |> (fn it => it COMP accI) - |> implies_intr ihyp - |> forall_intr (cterm_of thy x) - |> (fn it => Drule.compose_single(it,2,wf_induct_rule)) - |> curry op RS (assume wfR') - |> forall_intr_vars - |> (fn it => it COMP allI) - |> fold implies_intr hyps - |> implies_intr wfR' - |> forall_intr (cterm_of thy R') - |> forall_elim (cterm_of thy (inrel_R)) - |> curry op RS wf_in_rel - |> full_simplify (HOL_basic_ss addsimps [in_rel_def]) - |> forall_intr (cterm_of thy Rrel) - end - - - -(* Tail recursion (probably very fragile) - * - * FIXME: - * - Need to do forall_elim_vars on psimps: Unneccesary, if psimps would be taken from the same context. - * - Must we really replace the fvar by f here? - * - Splitting is not configured automatically: Problems with case? - *) -fun mk_trsimps octxt globals f G R f_def R_cases G_induct clauses psimps = - let - val Globals {domT, ranT, fvar, ...} = globals - - val R_cases = Thm.forall_elim_vars 0 R_cases (* FIXME: Should be already in standard form. *) - - val graph_implies_dom = (* "G ?x ?y ==> dom ?x" *) - Goal.prove octxt ["x", "y"] [HOLogic.mk_Trueprop (G $ Free ("x", domT) $ Free ("y", ranT))] - (HOLogic.mk_Trueprop (mk_acc domT R $ Free ("x", domT))) - (fn {prems=[a], ...} => - ((rtac (G_induct OF [a])) - THEN_ALL_NEW (rtac accI) - THEN_ALL_NEW (etac R_cases) - THEN_ALL_NEW (asm_full_simp_tac (local_simpset_of octxt))) 1) - - val default_thm = (forall_intr_vars graph_implies_dom) COMP (f_def COMP fundef_default_value) - - fun mk_trsimp clause psimp = - let - val ClauseInfo {qglr = (oqs, _, _, _), cdata = ClauseContext {ctxt, cqs, qs, gs, lhs, rhs, ...}, ...} = clause - val thy = ProofContext.theory_of ctxt - val rhs_f = Pattern.rewrite_term thy [(fvar, f)] [] rhs - - val trsimp = Logic.list_implies(gs, HOLogic.mk_Trueprop (HOLogic.mk_eq(f $ lhs, rhs_f))) (* "f lhs = rhs" *) - val lhs_acc = (mk_acc domT R $ lhs) (* "acc R lhs" *) - fun simp_default_tac ss = asm_full_simp_tac (ss addsimps [default_thm, Let_def]) - in - Goal.prove ctxt [] [] trsimp - (fn _ => - rtac (instantiate' [] [SOME (cterm_of thy lhs_acc)] case_split) 1 - THEN (rtac (Thm.forall_elim_vars 0 psimp) THEN_ALL_NEW assume_tac) 1 - THEN (simp_default_tac (local_simpset_of ctxt) 1) - THEN (etac not_acc_down 1) - THEN ((etac R_cases) THEN_ALL_NEW (simp_default_tac (local_simpset_of ctxt))) 1) - |> fold_rev forall_intr_rename (map fst oqs ~~ cqs) - end - in - map2 mk_trsimp clauses psimps - end - - -fun prepare_fundef config defname [((fname, fT), mixfix)] abstract_qglrs lthy = - let - val FundefConfig {domintros, tailrec, default=default_str, ...} = config - - val fvar = Free (fname, fT) - val domT = domain_type fT - val ranT = range_type fT - - val default = Syntax.parse_term lthy default_str - |> TypeInfer.constrain fT |> Syntax.check_term lthy - - val (globals, ctxt') = fix_globals domT ranT fvar lthy - - val Globals { x, h, ... } = globals - - val clauses = map (mk_clause_context x ctxt') abstract_qglrs - - val n = length abstract_qglrs - - fun build_tree (ClauseContext { ctxt, rhs, ...}) = - FundefCtxTree.mk_tree (fname, fT) h ctxt rhs - - val trees = map build_tree clauses - val RCss = map find_calls trees - - val ((G, GIntro_thms, G_elim, G_induct), lthy) = - PROFILE "def_graph" (define_graph (graph_name defname) fvar domT ranT clauses RCss) lthy - - val ((f, f_defthm), lthy) = - PROFILE "def_fun" (define_function (defname ^ "_sumC_def") (fname, mixfix) domT ranT G default) lthy - - val RCss = map (map (inst_RC (ProofContext.theory_of lthy) fvar f)) RCss - val trees = map (FundefCtxTree.inst_tree (ProofContext.theory_of lthy) fvar f) trees - - val ((R, RIntro_thmss, R_elim), lthy) = - PROFILE "def_rel" (define_recursion_relation (rel_name defname) domT ranT fvar f abstract_qglrs clauses RCss) lthy - - val (_, lthy) = - LocalTheory.abbrev Syntax.mode_default ((Binding.name (dom_name defname), NoSyn), mk_acc domT R) lthy - - val newthy = ProofContext.theory_of lthy - val clauses = map (transfer_clause_ctx newthy) clauses - - val cert = cterm_of (ProofContext.theory_of lthy) - - val xclauses = PROFILE "xclauses" (map7 (mk_clause_info globals G f) (1 upto n) clauses abstract_qglrs trees RCss GIntro_thms) RIntro_thmss - - val complete = mk_completeness globals clauses abstract_qglrs |> cert |> assume - val compat = mk_compat_proof_obligations domT ranT fvar f abstract_qglrs |> map (cert #> assume) - - val compat_store = store_compat_thms n compat - - val (goalstate, values) = PROFILE "prove_stuff" (prove_stuff lthy globals G f R xclauses complete compat compat_store G_elim) f_defthm - - val mk_trsimps = mk_trsimps lthy globals f G R f_defthm R_elim G_induct xclauses - - fun mk_partial_rules provedgoal = - let - val newthy = theory_of_thm provedgoal (*FIXME*) - - val (graph_is_function, complete_thm) = - provedgoal - |> Conjunction.elim - |> apfst (Thm.forall_elim_vars 0) - - val f_iff = graph_is_function RS (f_defthm RS ex1_implies_iff) - - val psimps = PROFILE "Proving simplification rules" (mk_psimps newthy globals R xclauses values f_iff) graph_is_function - - val simple_pinduct = PROFILE "Proving partial induction rule" - (mk_partial_induct_rule newthy globals R complete_thm) xclauses - - - val total_intro = PROFILE "Proving nested termination rule" (mk_nest_term_rule newthy globals R R_elim) xclauses - - val dom_intros = if domintros - then SOME (PROFILE "Proving domain introduction rules" (map (mk_domain_intro lthy globals R R_elim)) xclauses) - else NONE - val trsimps = if tailrec then SOME (mk_trsimps psimps) else NONE - - in - FundefResult {fs=[f], G=G, R=R, cases=complete_thm, - psimps=psimps, simple_pinducts=[simple_pinduct], - termination=total_intro, trsimps=trsimps, - domintros=dom_intros} - end - in - ((f, goalstate, mk_partial_rules), lthy) - end - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/fundef_datatype.ML --- a/src/HOL/Tools/function_package/fundef_datatype.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,330 +0,0 @@ -(* Title: HOL/Tools/function_package/fundef_datatype.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -A tactic to prove completeness of datatype patterns. -*) - -signature FUNDEF_DATATYPE = -sig - val pat_completeness_tac: Proof.context -> int -> tactic - val pat_completeness: Proof.context -> Proof.method - val prove_completeness : theory -> term list -> term -> term list list -> term list list -> thm - - val setup : theory -> theory - - val add_fun : FundefCommon.fundef_config -> - (binding * typ option * mixfix) list -> (Attrib.binding * term) list -> - bool -> local_theory -> Proof.context - val add_fun_cmd : FundefCommon.fundef_config -> - (binding * string option * mixfix) list -> (Attrib.binding * string) list -> - bool -> local_theory -> Proof.context -end - -structure FundefDatatype : FUNDEF_DATATYPE = -struct - -open FundefLib -open FundefCommon - - -fun check_pats ctxt geq = - let - fun err str = error (cat_lines ["Malformed definition:", - str ^ " not allowed in sequential mode.", - Syntax.string_of_term ctxt geq]) - val thy = ProofContext.theory_of ctxt - - fun check_constr_pattern (Bound _) = () - | check_constr_pattern t = - let - val (hd, args) = strip_comb t - in - (((case Datatype.datatype_of_constr thy (fst (dest_Const hd)) of - SOME _ => () - | NONE => err "Non-constructor pattern") - handle TERM ("dest_Const", _) => err "Non-constructor patterns"); - map check_constr_pattern args; - ()) - end - - val (fname, qs, gs, args, rhs) = split_def ctxt geq - - val _ = if not (null gs) then err "Conditional equations" else () - val _ = map check_constr_pattern args - - (* just count occurrences to check linearity *) - val _ = if fold (fold_aterms (fn Bound _ => curry (op +) 1 | _ => I)) args 0 > length qs - then err "Nonlinear patterns" else () - in - () - end - - -fun mk_argvar i T = Free ("_av" ^ (string_of_int i), T) -fun mk_patvar i T = Free ("_pv" ^ (string_of_int i), T) - -fun inst_free var inst thm = - forall_elim inst (forall_intr var thm) - - -fun inst_case_thm thy x P thm = - let - val [Pv, xv] = Term.add_vars (prop_of thm) [] - in - cterm_instantiate [(cterm_of thy (Var xv), cterm_of thy x), - (cterm_of thy (Var Pv), cterm_of thy P)] thm - end - - -fun invent_vars constr i = - let - val Ts = binder_types (fastype_of constr) - val j = i + length Ts - val is = i upto (j - 1) - val avs = map2 mk_argvar is Ts - val pvs = map2 mk_patvar is Ts - in - (avs, pvs, j) - end - - -fun filter_pats thy cons pvars [] = [] - | filter_pats thy cons pvars (([], thm) :: pts) = raise Match - | filter_pats thy cons pvars ((pat :: pats, thm) :: pts) = - case pat of - Free _ => let val inst = list_comb (cons, pvars) - in (inst :: pats, inst_free (cterm_of thy pat) (cterm_of thy inst) thm) - :: (filter_pats thy cons pvars pts) end - | _ => if fst (strip_comb pat) = cons - then (pat :: pats, thm) :: (filter_pats thy cons pvars pts) - else filter_pats thy cons pvars pts - - -fun inst_constrs_of thy (T as Type (name, _)) = - map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT))) - (the (Datatype.get_datatype_constrs thy name)) - | inst_constrs_of thy _ = raise Match - - -fun transform_pat thy avars c_assum ([] , thm) = raise Match - | transform_pat thy avars c_assum (pat :: pats, thm) = - let - val (_, subps) = strip_comb pat - val eqs = map (cterm_of thy o HOLogic.mk_Trueprop o HOLogic.mk_eq) (avars ~~ subps) - val a_eqs = map assume eqs - val c_eq_pat = simplify (HOL_basic_ss addsimps a_eqs) c_assum - in - (subps @ pats, fold_rev implies_intr eqs - (implies_elim thm c_eq_pat)) - end - - -exception COMPLETENESS - -fun constr_case thy P idx (v :: vs) pats cons = - let - val (avars, pvars, newidx) = invent_vars cons idx - val c_hyp = cterm_of thy (HOLogic.mk_Trueprop (HOLogic.mk_eq (v, list_comb (cons, avars)))) - val c_assum = assume c_hyp - val newpats = map (transform_pat thy avars c_assum) (filter_pats thy cons pvars pats) - in - o_alg thy P newidx (avars @ vs) newpats - |> implies_intr c_hyp - |> fold_rev (forall_intr o cterm_of thy) avars - end - | constr_case _ _ _ _ _ _ = raise Match -and o_alg thy P idx [] (([], Pthm) :: _) = Pthm - | o_alg thy P idx (v :: vs) [] = raise COMPLETENESS - | o_alg thy P idx (v :: vs) pts = - if forall (is_Free o hd o fst) pts (* Var case *) - then o_alg thy P idx vs (map (fn (pv :: pats, thm) => - (pats, refl RS (inst_free (cterm_of thy pv) (cterm_of thy v) thm))) pts) - else (* Cons case *) - let - val T = fastype_of v - val (tname, _) = dest_Type T - val {exhaustion=case_thm, ...} = Datatype.the_datatype thy tname - val constrs = inst_constrs_of thy T - val c_cases = map (constr_case thy P idx (v :: vs) pts) constrs - in - inst_case_thm thy v P case_thm - |> fold (curry op COMP) c_cases - end - | o_alg _ _ _ _ _ = raise Match - - -fun prove_completeness thy xs P qss patss = - let - fun mk_assum qs pats = - HOLogic.mk_Trueprop P - |> fold_rev (curry Logic.mk_implies o HOLogic.mk_Trueprop o HOLogic.mk_eq) (xs ~~ pats) - |> fold_rev Logic.all qs - |> cterm_of thy - - val hyps = map2 mk_assum qss patss - - fun inst_hyps hyp qs = fold (forall_elim o cterm_of thy) qs (assume hyp) - - val assums = map2 inst_hyps hyps qss - in - o_alg thy P 2 xs (patss ~~ assums) - |> fold_rev implies_intr hyps - end - - - -fun pat_completeness_tac ctxt = SUBGOAL (fn (subgoal, i) => - let - val thy = ProofContext.theory_of ctxt - val (vs, subgf) = dest_all_all subgoal - val (cases, _ $ thesis) = Logic.strip_horn subgf - handle Bind => raise COMPLETENESS - - fun pat_of assum = - let - val (qs, imp) = dest_all_all assum - val prems = Logic.strip_imp_prems imp - in - (qs, map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems) - end - - val (qss, x_pats) = split_list (map pat_of cases) - val xs = map fst (hd x_pats) - handle Empty => raise COMPLETENESS - - val patss = map (map snd) x_pats - - val complete_thm = prove_completeness thy xs thesis qss patss - |> fold_rev (forall_intr o cterm_of thy) vs - in - PRIMITIVE (fn st => Drule.compose_single(complete_thm, i, st)) - end - handle COMPLETENESS => no_tac) - - -fun pat_completeness ctxt = SIMPLE_METHOD' (pat_completeness_tac ctxt) - -val by_pat_completeness_auto = - Proof.global_future_terminal_proof - (Method.Basic (pat_completeness, Position.none), - SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none)))) - -fun termination_by method int = - Fundef.termination_proof NONE - #> Proof.global_future_terminal_proof - (Method.Basic (method, Position.none), NONE) int - -fun mk_catchall fixes arity_of = - let - fun mk_eqn ((fname, fT), _) = - let - val n = arity_of fname - val (argTs, rT) = chop n (binder_types fT) - |> apsnd (fn Ts => Ts ---> body_type fT) - - val qs = map Free (Name.invent_list [] "a" n ~~ argTs) - in - HOLogic.mk_eq(list_comb (Free (fname, fT), qs), - Const ("HOL.undefined", rT)) - |> HOLogic.mk_Trueprop - |> fold_rev Logic.all qs - end - in - map mk_eqn fixes - end - -fun add_catchall ctxt fixes spec = - let val fqgars = map (split_def ctxt) spec - val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars - |> AList.lookup (op =) #> the - in - spec @ mk_catchall fixes arity_of - end - -fun warn_if_redundant ctxt origs tss = - let - fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t) - - val (tss', _) = chop (length origs) tss - fun check (t, []) = (Output.warning (msg t); []) - | check (t, s) = s - in - (map check (origs ~~ tss'); tss) - end - - -fun sequential_preproc (config as FundefConfig {sequential, ...}) ctxt fixes spec = - if sequential then - let - val (bnds, eqss) = split_list spec - - val eqs = map the_single eqss - - val feqs = eqs - |> tap (check_defs ctxt fixes) (* Standard checks *) - |> tap (map (check_pats ctxt)) (* More checks for sequential mode *) - - val compleqs = add_catchall ctxt fixes feqs (* Completion *) - - val spliteqs = warn_if_redundant ctxt feqs - (FundefSplit.split_all_equations ctxt compleqs) - - fun restore_spec thms = - bnds ~~ Library.take (length bnds, Library.unflat spliteqs thms) - - val spliteqs' = flat (Library.take (length bnds, spliteqs)) - val fnames = map (fst o fst) fixes - val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs' - - fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs) - |> map (map snd) - - - val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding - - (* using theorem names for case name currently disabled *) - val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) - (bnds' ~~ spliteqs) - |> flat - in - (flat spliteqs, restore_spec, sort, case_names) - end - else - FundefCommon.empty_preproc check_defs config ctxt fixes spec - -val setup = - Method.setup @{binding pat_completeness} (Scan.succeed pat_completeness) - "Completeness prover for datatype patterns" - #> Context.theory_map (FundefCommon.set_preproc sequential_preproc) - - -val fun_config = FundefConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), - domintros=false, tailrec=false } - -fun gen_fun add config fixes statements int lthy = - let val group = serial_string () in - lthy - |> LocalTheory.set_group group - |> add fixes statements config - |> by_pat_completeness_auto int - |> LocalTheory.restore - |> LocalTheory.set_group group - |> termination_by (FundefCommon.get_termination_prover lthy) int - end; - -val add_fun = gen_fun Fundef.add_fundef -val add_fun_cmd = gen_fun Fundef.add_fundef_cmd - - - -local structure P = OuterParse and K = OuterKeyword in - -val _ = - OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl - (fundef_parser fun_config - >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements)); - -end - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/fundef_lib.ML --- a/src/HOL/Tools/function_package/fundef_lib.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,176 +0,0 @@ -(* Title: HOL/Tools/function_package/fundef_lib.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Some fairly general functions that should probably go somewhere else... -*) - -structure FundefLib = struct - -fun map_option f NONE = NONE - | map_option f (SOME x) = SOME (f x); - -fun fold_option f NONE y = y - | fold_option f (SOME x) y = f x y; - -fun fold_map_option f NONE y = (NONE, y) - | fold_map_option f (SOME x) y = apfst SOME (f x y); - -(* Ex: "The variable" ^ plural " is" "s are" vs *) -fun plural sg pl [x] = sg - | plural sg pl _ = pl - -(* lambda-abstracts over an arbitrarily nested tuple - ==> hologic.ML? *) -fun tupled_lambda vars t = - case vars of - (Free v) => lambda (Free v) t - | (Var v) => lambda (Var v) t - | (Const ("Pair", Type ("fun", [Ta, Type ("fun", [Tb, _])]))) $ us $ vs => - (HOLogic.split_const (Ta,Tb, fastype_of t)) $ (tupled_lambda us (tupled_lambda vs t)) - | _ => raise Match - - -fun dest_all (Const ("all", _) $ Abs (a as (_,T,_))) = - let - val (n, body) = Term.dest_abs a - in - (Free (n, T), body) - end - | dest_all _ = raise Match - - -(* Removes all quantifiers from a term, replacing bound variables by frees. *) -fun dest_all_all (t as (Const ("all",_) $ _)) = - let - val (v,b) = dest_all t - val (vs, b') = dest_all_all b - in - (v :: vs, b') - end - | dest_all_all t = ([],t) - - -(* FIXME: similar to Variable.focus *) -fun dest_all_all_ctx ctx (Const ("all", _) $ Abs (a as (n,T,b))) = - let - val [(n', _)] = Variable.variant_frees ctx [] [(n,T)] - val (_, ctx') = ProofContext.add_fixes [(Binding.name n', SOME T, NoSyn)] ctx - - val (n'', body) = Term.dest_abs (n', T, b) - val _ = (n' = n'') orelse error "dest_all_ctx" - (* Note: We assume that n' does not occur in the body. Otherwise it would be fixed. *) - - val (ctx'', vs, bd) = dest_all_all_ctx ctx' body - in - (ctx'', (n', T) :: vs, bd) - end - | dest_all_all_ctx ctx t = - (ctx, [], t) - - -fun map3 _ [] [] [] = [] - | map3 f (x :: xs) (y :: ys) (z :: zs) = f x y z :: map3 f xs ys zs - | map3 _ _ _ _ = raise Library.UnequalLengths; - -fun map4 _ [] [] [] [] = [] - | map4 f (x :: xs) (y :: ys) (z :: zs) (u :: us) = f x y z u :: map4 f xs ys zs us - | map4 _ _ _ _ _ = raise Library.UnequalLengths; - -fun map6 _ [] [] [] [] [] [] = [] - | map6 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (w :: ws) = f x y z u v w :: map6 f xs ys zs us vs ws - | map6 _ _ _ _ _ _ _ = raise Library.UnequalLengths; - -fun map7 _ [] [] [] [] [] [] [] = [] - | map7 f (x :: xs) (y :: ys) (z :: zs) (u :: us) (v :: vs) (w :: ws) (b :: bs) = f x y z u v w b :: map7 f xs ys zs us vs ws bs - | map7 _ _ _ _ _ _ _ _ = raise Library.UnequalLengths; - - - -(* forms all "unordered pairs": [1, 2, 3] ==> [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] *) -(* ==> library *) -fun unordered_pairs [] = [] - | unordered_pairs (x::xs) = map (pair x) (x::xs) @ unordered_pairs xs - - -(* Replaces Frees by name. Works with loose Bounds. *) -fun replace_frees assoc = - map_aterms (fn c as Free (n, _) => the_default c (AList.lookup (op =) assoc n) - | t => t) - - -fun rename_bound n (Q $ Abs(_, T, b)) = (Q $ Abs(n, T, b)) - | rename_bound n _ = raise Match - -fun mk_forall_rename (n, v) = - rename_bound n o Logic.all v - -fun forall_intr_rename (n, cv) thm = - let - val allthm = forall_intr cv thm - val (_ $ abs) = prop_of allthm - in - Thm.rename_boundvars abs (Abs (n, dummyT, Term.dummy_pattern dummyT)) allthm - end - - -(* Returns the frees in a term in canonical order, excluding the fixes from the context *) -fun frees_in_term ctxt t = - Term.add_frees t [] - |> filter_out (Variable.is_fixed ctxt o fst) - |> rev - - -datatype proof_attempt = Solved of thm | Stuck of thm | Fail - -fun try_proof cgoal tac = - case SINGLE tac (Goal.init cgoal) of - NONE => Fail - | SOME st => if Thm.no_prems st then Solved (Goal.finish st) else Stuck st - - -fun dest_binop_list cn (t as (Const (n, _) $ a $ b)) = - if cn = n then dest_binop_list cn a @ dest_binop_list cn b else [ t ] - | dest_binop_list _ t = [ t ] - - -(* separate two parts in a +-expression: - "a + b + c + d + e" --> "(a + b + d) + (c + e)" - - Here, + can be any binary operation that is AC. - - cn - The name of the binop-constructor (e.g. @{const_name Un}) - ac - the AC rewrite rules for cn - is - the list of indices of the expressions that should become the first part - (e.g. [0,1,3] in the above example) -*) - -fun regroup_conv neu cn ac is ct = - let - val mk = HOLogic.mk_binop cn - val t = term_of ct - val xs = dest_binop_list cn t - val js = 0 upto (length xs) - 1 \\ is - val ty = fastype_of t - val thy = theory_of_cterm ct - in - Goal.prove_internal [] - (cterm_of thy - (Logic.mk_equals (t, - if is = [] - then mk (Const (neu, ty), foldr1 mk (map (nth xs) js)) - else if js = [] - then mk (foldr1 mk (map (nth xs) is), Const (neu, ty)) - else mk (foldr1 mk (map (nth xs) is), foldr1 mk (map (nth xs) js))))) - (K (rewrite_goals_tac ac - THEN rtac Drule.reflexive_thm 1)) - end - -(* instance for unions *) -fun regroup_union_conv t = regroup_conv @{const_name Set.empty} @{const_name Un} - (map (fn t => t RS eq_reflection) (@{thms "Un_ac"} @ - @{thms "Un_empty_right"} @ - @{thms "Un_empty_left"})) t - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/induction_scheme.ML --- a/src/HOL/Tools/function_package/induction_scheme.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,405 +0,0 @@ -(* Title: HOL/Tools/function_package/induction_scheme.ML - Author: Alexander Krauss, TU Muenchen - -A method to prove induction schemes. -*) - -signature INDUCTION_SCHEME = -sig - val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic) - -> Proof.context -> thm list -> tactic - val induct_scheme_tac : Proof.context -> thm list -> tactic - val setup : theory -> theory -end - - -structure InductionScheme : INDUCTION_SCHEME = -struct - -open FundefLib - - -type rec_call_info = int * (string * typ) list * term list * term list - -datatype scheme_case = - SchemeCase of - { - bidx : int, - qs: (string * typ) list, - oqnames: string list, - gs: term list, - lhs: term list, - rs: rec_call_info list - } - -datatype scheme_branch = - SchemeBranch of - { - P : term, - xs: (string * typ) list, - ws: (string * typ) list, - Cs: term list - } - -datatype ind_scheme = - IndScheme of - { - T: typ, (* sum of products *) - branches: scheme_branch list, - cases: scheme_case list - } - -val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize} -val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify} - -fun meta thm = thm RS eq_reflection - -val sum_prod_conv = MetaSimplifier.rewrite true - (map meta (@{thm split_conv} :: @{thms sum.cases})) - -fun term_conv thy cv t = - cv (cterm_of thy t) - |> prop_of |> Logic.dest_equals |> snd - -fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T)) - -fun dest_hhf ctxt t = - let - val (ctxt', vars, imp) = dest_all_all_ctx ctxt t - in - (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp) - end - - -fun mk_scheme' ctxt cases concl = - let - fun mk_branch concl = - let - val (ctxt', ws, Cs, _ $ Pxs) = dest_hhf ctxt concl - val (P, xs) = strip_comb Pxs - in - SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs } - end - - val (branches, cases') = (* correction *) - case Logic.dest_conjunction_list concl of - [conc] => - let - val _ $ Pxs = Logic.strip_assums_concl conc - val (P, _) = strip_comb Pxs - val (cases', conds) = take_prefix (Term.exists_subterm (curry op aconv P)) cases - val concl' = fold_rev (curry Logic.mk_implies) conds conc - in - ([mk_branch concl'], cases') - end - | concls => (map mk_branch concls, cases) - - fun mk_case premise = - let - val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise - val (P, lhs) = strip_comb Plhs - - fun bidx Q = find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches - - fun mk_rcinfo pr = - let - val (ctxt'', Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr - val (P', rcs) = strip_comb Phyp - in - (bidx P', Gvs, Gas, rcs) - end - - fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches - - val (gs, rcprs) = - take_prefix (not o Term.exists_subterm is_pred) prems - in - SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*), gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs} - end - - fun PT_of (SchemeBranch { xs, ...}) = - foldr1 HOLogic.mk_prodT (map snd xs) - - val ST = BalancedTree.make (uncurry SumTree.mk_sumT) (map PT_of branches) - in - IndScheme {T=ST, cases=map mk_case cases', branches=branches } - end - - - -fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx = - let - val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx - val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases - - val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases [] - val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs)) - val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs - - fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) = - HOLogic.mk_Trueprop Pbool - |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l))) - (xs' ~~ lhs) - |> fold_rev (curry Logic.mk_implies) gs - |> fold_rev mk_forall_rename (oqnames ~~ map Free qs) - in - HOLogic.mk_Trueprop Pbool - |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases - |> fold_rev (curry Logic.mk_implies) Cs' - |> fold_rev (Logic.all o Free) ws - |> fold_rev mk_forall_rename (map fst xs ~~ xs') - |> mk_forall_rename ("P", Pbool) - end - -fun mk_wf ctxt R (IndScheme {T, ...}) = - HOLogic.Trueprop $ (Const (@{const_name "wf"}, mk_relT T --> HOLogic.boolT) $ R) - -fun mk_ineqs R (IndScheme {T, cases, branches}) = - let - fun inject i ts = - SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts) - - val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *) - - fun mk_pres bdx args = - let - val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx - fun replace (x, v) t = betapply (lambda (Free x) t, v) - val Cs' = map (fold replace (xs ~~ args)) Cs - val cse = - HOLogic.mk_Trueprop thesis - |> fold_rev (curry Logic.mk_implies) Cs' - |> fold_rev (Logic.all o Free) ws - in - Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis) - end - - fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) = - let - fun g (bidx', Gvs, Gas, rcarg) = - let val export = - fold_rev (curry Logic.mk_implies) Gas - #> fold_rev (curry Logic.mk_implies) gs - #> fold_rev (Logic.all o Free) Gvs - #> fold_rev mk_forall_rename (oqnames ~~ map Free qs) - in - (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R) - |> HOLogic.mk_Trueprop - |> export, - mk_pres bidx' rcarg - |> export - |> Logic.all thesis) - end - in - map g rs - end - in - map f cases - end - - -fun mk_hol_imp a b = HOLogic.imp $ a $ b - -fun mk_ind_goal thy branches = - let - fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) = - HOLogic.mk_Trueprop (list_comb (P, map Free xs)) - |> fold_rev (curry Logic.mk_implies) Cs - |> fold_rev (Logic.all o Free) ws - |> term_conv thy ind_atomize - |> ObjectLogic.drop_judgment thy - |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs)) - in - SumTree.mk_sumcases HOLogic.boolT (map brnch branches) - end - - -fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss (IndScheme {T, cases=scases, branches}) = - let - val n = length branches - - val scases_idx = map_index I scases - - fun inject i ts = - SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts) - val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches) - - val thy = ProofContext.theory_of ctxt - val cert = cterm_of thy - - val P_comp = mk_ind_goal thy branches - - (* Inductive Hypothesis: !!z. (z,x):R ==> P z *) - val ihyp = Term.all T $ Abs ("z", T, - Logic.mk_implies - (HOLogic.mk_Trueprop ( - Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) - $ (HOLogic.pair_const T T $ Bound 0 $ x) - $ R), - HOLogic.mk_Trueprop (P_comp $ Bound 0))) - |> cert - - val aihyp = assume ihyp - - (* Rule for case splitting along the sum types *) - val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches - val pats = map_index (uncurry inject) xss - val sum_split_rule = FundefDatatype.prove_completeness thy [x] (P_comp $ x) xss (map single pats) - - fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) = - let - val fxs = map Free xs - val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat)))) - - val C_hyps = map (cert #> assume) Cs - - val (relevant_cases, ineqss') = filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx) (scases_idx ~~ ineqss) - |> split_list - - fun prove_case (cidx, SchemeCase {qs, oqnames, gs, lhs, rs, ...}) ineq_press = - let - val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs) - - val cqs = map (cert o Free) qs - val ags = map (assume o cert) gs - - val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps) - val sih = full_simplify replace_x_ss aihyp - - fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) = - let - val cGas = map (assume o cert) Gas - val cGvs = map (cert o Free) Gvs - val import = fold forall_elim (cqs @ cGvs) - #> fold Thm.elim_implies (ags @ cGas) - val ipres = pres - |> forall_elim (cert (list_comb (P_of idx, rcargs))) - |> import - in - sih |> forall_elim (cert (inject idx rcargs)) - |> Thm.elim_implies (import ineq) (* Psum rcargs *) - |> Conv.fconv_rule sum_prod_conv - |> Conv.fconv_rule ind_rulify - |> (fn th => th COMP ipres) (* P rs *) - |> fold_rev (implies_intr o cprop_of) cGas - |> fold_rev forall_intr cGvs - end - - val P_recs = map2 mk_Prec rs ineq_press (* [P rec1, P rec2, ... ] *) - - val step = HOLogic.mk_Trueprop (list_comb (P, lhs)) - |> fold_rev (curry Logic.mk_implies o prop_of) P_recs - |> fold_rev (curry Logic.mk_implies) gs - |> fold_rev (Logic.all o Free) qs - |> cert - - val Plhs_to_Pxs_conv = - foldl1 (uncurry Conv.combination_conv) - (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps) - - val res = assume step - |> fold forall_elim cqs - |> fold Thm.elim_implies ags - |> fold Thm.elim_implies P_recs (* P lhs *) - |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *) - |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps) - |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *) - in - (res, (cidx, step)) - end - - val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss') - - val bstep = complete_thm - |> forall_elim (cert (list_comb (P, fxs))) - |> fold (forall_elim o cert) (fxs @ map Free ws) - |> fold Thm.elim_implies C_hyps (* FIXME: optimization using rotate_prems *) - |> fold Thm.elim_implies cases (* P xs *) - |> fold_rev (implies_intr o cprop_of) C_hyps - |> fold_rev (forall_intr o cert o Free) ws - - val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x)) - |> Goal.init - |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.cases})) - THEN CONVERSION ind_rulify 1) - |> Seq.hd - |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep) - |> Goal.finish - |> implies_intr (cprop_of branch_hyp) - |> fold_rev (forall_intr o cert) fxs - in - (Pxs, steps) - end - - val (branches, steps) = split_list (map_index prove_branch (branches ~~ (complete_thms ~~ pats))) - |> apsnd flat - - val istep = sum_split_rule - |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches - |> implies_intr ihyp - |> forall_intr (cert x) (* "!!x. (!!y P x" *) - - val induct_rule = - @{thm "wf_induct_rule"} - |> (curry op COMP) wf_thm - |> (curry op COMP) istep - - val steps_sorted = map snd (sort (int_ord o pairself fst) steps) - in - (steps_sorted, induct_rule) - end - - -fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL -(SUBGOAL (fn (t, i) => - let - val (ctxt', _, cases, concl) = dest_hhf ctxt t - val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl -(* val _ = Output.tracing (makestring scheme)*) - val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt' - val R = Free (Rn, mk_relT ST) - val x = Free (xn, ST) - val cert = cterm_of (ProofContext.theory_of ctxt) - - val ineqss = mk_ineqs R scheme - |> map (map (pairself (assume o cert))) - val complete = map (mk_completeness ctxt scheme #> cert #> assume) (0 upto (length branches - 1)) - val wf_thm = mk_wf ctxt R scheme |> cert |> assume - - val (descent, pres) = split_list (flat ineqss) - val newgoals = complete @ pres @ wf_thm :: descent - - val (steps, indthm) = mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme - - fun project (i, SchemeBranch {xs, ...}) = - let - val inst = cert (SumTree.mk_inj ST (length branches) (i + 1) (foldr1 HOLogic.mk_prod (map Free xs))) - in - indthm |> Drule.instantiate' [] [SOME inst] - |> simplify SumTree.sumcase_split_ss - |> Conv.fconv_rule ind_rulify -(* |> (fn thm => (Output.tracing (makestring thm); thm))*) - end - - val res = Conjunction.intr_balanced (map_index project branches) - |> fold_rev implies_intr (map cprop_of newgoals @ steps) - |> (fn thm => Thm.generalize ([], [Rn]) (Thm.maxidx_of thm + 1) thm) - - val nbranches = length branches - val npres = length pres - in - Thm.compose_no_flatten false (res, length newgoals) i - THEN term_tac (i + nbranches + npres) - THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches)))) - THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i))) - end)) - - -fun induct_scheme_tac ctxt = - mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt; - -val setup = - Method.setup @{binding induct_scheme} (Scan.succeed (RAW_METHOD o induct_scheme_tac)) - "proves an induction principle" - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/inductive_wrap.ML --- a/src/HOL/Tools/function_package/inductive_wrap.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,67 +0,0 @@ -(* Title: HOL/Tools/function_package/inductive_wrap.ML - Author: Alexander Krauss, TU Muenchen - - -A wrapper around the inductive package, restoring the quantifiers in -the introduction and elimination rules. -*) - -signature FUNDEF_INDUCTIVE_WRAP = -sig - val inductive_def : term list - -> ((bstring * typ) * mixfix) * local_theory - -> thm list * (term * thm * thm * local_theory) -end - -structure FundefInductiveWrap: FUNDEF_INDUCTIVE_WRAP = -struct - -open FundefLib - -fun requantify ctxt lfix orig_def thm = - let - val (qs, t) = dest_all_all orig_def - val thy = theory_of_thm thm - val frees = frees_in_term ctxt t - |> remove (op =) lfix - val vars = Term.add_vars (prop_of thm) [] |> rev - - val varmap = frees ~~ vars - in - fold_rev (fn Free (n, T) => - forall_intr_rename (n, cterm_of thy (Var (the_default (("",0), T) (AList.lookup (op =) varmap (n, T)))))) - qs - thm - end - - - -fun inductive_def defs (((R, T), mixfix), lthy) = - let - val ({intrs = intrs_gen, elims = [elim_gen], preds = [ Rdef ], induct, ...}, lthy) = - Inductive.add_inductive_i - {quiet_mode = false, - verbose = ! Toplevel.debug, - kind = Thm.internalK, - alt_name = Binding.empty, - coind = false, - no_elim = false, - no_ind = false, - skip_mono = true, - fork_mono = false} - [((Binding.name R, T), NoSyn)] (* the relation *) - [] (* no parameters *) - (map (fn t => (Attrib.empty_binding, t)) defs) (* the intros *) - [] (* no special monos *) - lthy - - val intrs = map2 (requantify lthy (R, T)) defs intrs_gen - - val elim = elim_gen - |> forall_intr_vars (* FIXME... *) - - in - (intrs, (Rdef, elim, induct, lthy)) - end - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/lexicographic_order.ML --- a/src/HOL/Tools/function_package/lexicographic_order.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,230 +0,0 @@ -(* Title: HOL/Tools/function_package/lexicographic_order.ML - Author: Lukas Bulwahn, TU Muenchen - -Method for termination proofs with lexicographic orderings. -*) - -signature LEXICOGRAPHIC_ORDER = -sig - val lex_order_tac : Proof.context -> tactic -> tactic - val lexicographic_order_tac : Proof.context -> tactic - val lexicographic_order : Proof.context -> Proof.method - - val setup: theory -> theory -end - -structure LexicographicOrder : LEXICOGRAPHIC_ORDER = -struct - -open FundefLib - -(** General stuff **) - -fun mk_measures domT mfuns = - let - val relT = HOLogic.mk_setT (HOLogic.mk_prodT (domT, domT)) - val mlexT = (domT --> HOLogic.natT) --> relT --> relT - fun mk_ms [] = Const (@{const_name Set.empty}, relT) - | mk_ms (f::fs) = - Const (@{const_name "mlex_prod"}, mlexT) $ f $ mk_ms fs - in - mk_ms mfuns - end - -fun del_index n [] = [] - | del_index n (x :: xs) = - if n > 0 then x :: del_index (n - 1) xs else xs - -fun transpose ([]::_) = [] - | transpose xss = map hd xss :: transpose (map tl xss) - -(** Matrix cell datatype **) - -datatype cell = Less of thm| LessEq of (thm * thm) | None of (thm * thm) | False of thm; - -fun is_Less (Less _) = true - | is_Less _ = false - -fun is_LessEq (LessEq _) = true - | is_LessEq _ = false - -fun pr_cell (Less _ ) = " < " - | pr_cell (LessEq _) = " <=" - | pr_cell (None _) = " ? " - | pr_cell (False _) = " F " - - -(** Proof attempts to build the matrix **) - -fun dest_term (t : term) = - let - val (vars, prop) = FundefLib.dest_all_all t - val (prems, concl) = Logic.strip_horn prop - val (lhs, rhs) = concl - |> HOLogic.dest_Trueprop - |> HOLogic.dest_mem |> fst - |> HOLogic.dest_prod - in - (vars, prems, lhs, rhs) - end - -fun mk_goal (vars, prems, lhs, rhs) rel = - let - val concl = HOLogic.mk_binrel rel (lhs, rhs) |> HOLogic.mk_Trueprop - in - fold_rev Logic.all vars (Logic.list_implies (prems, concl)) - end - -fun prove thy solve_tac t = - cterm_of thy t |> Goal.init - |> SINGLE solve_tac |> the - -fun mk_cell (thy : theory) solve_tac (vars, prems, lhs, rhs) mfun = - let - val goals = cterm_of thy o mk_goal (vars, prems, mfun $ lhs, mfun $ rhs) - in - case try_proof (goals @{const_name HOL.less}) solve_tac of - Solved thm => Less thm - | Stuck thm => - (case try_proof (goals @{const_name HOL.less_eq}) solve_tac of - Solved thm2 => LessEq (thm2, thm) - | Stuck thm2 => - if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const] then False thm2 - else None (thm2, thm) - | _ => raise Match) (* FIXME *) - | _ => raise Match - end - - -(** Search algorithms **) - -fun check_col ls = forall (fn c => is_Less c orelse is_LessEq c) ls andalso not (forall (is_LessEq) ls) - -fun transform_table table col = table |> filter_out (fn x => is_Less (nth x col)) |> map (del_index col) - -fun transform_order col order = map (fn x => if x >= col then x + 1 else x) order - -(* simple depth-first search algorithm for the table *) -fun search_table table = - case table of - [] => SOME [] - | _ => - let - val col = find_index (check_col) (transpose table) - in case col of - ~1 => NONE - | _ => - let - val order_opt = (table, col) |-> transform_table |> search_table - in case order_opt of - NONE => NONE - | SOME order =>SOME (col :: transform_order col order) - end - end - -(** Proof Reconstruction **) - -(* prove row :: cell list -> tactic *) -fun prove_row (Less less_thm :: _) = - (rtac @{thm "mlex_less"} 1) - THEN PRIMITIVE (Thm.elim_implies less_thm) - | prove_row (LessEq (lesseq_thm, _) :: tail) = - (rtac @{thm "mlex_leq"} 1) - THEN PRIMITIVE (Thm.elim_implies lesseq_thm) - THEN prove_row tail - | prove_row _ = sys_error "lexicographic_order" - - -(** Error reporting **) - -fun pr_table table = writeln (cat_lines (map (fn r => concat (map pr_cell r)) table)) - -fun pr_goals ctxt st = - Display.pretty_goals_aux (Syntax.pp ctxt) Markup.none (true, false) (Thm.nprems_of st) st - |> Pretty.chunks - |> Pretty.string_of - -fun row_index i = chr (i + 97) -fun col_index j = string_of_int (j + 1) - -fun pr_unprovable_cell _ ((i,j), Less _) = "" - | pr_unprovable_cell ctxt ((i,j), LessEq (_, st)) = - "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st - | pr_unprovable_cell ctxt ((i,j), None (st_leq, st_less)) = - "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st_less - ^ "\n(" ^ row_index i ^ ", " ^ col_index j ^ ", <=):\n" ^ pr_goals ctxt st_leq - | pr_unprovable_cell ctxt ((i,j), False st) = - "(" ^ row_index i ^ ", " ^ col_index j ^ ", <):\n" ^ pr_goals ctxt st - -fun pr_unprovable_subgoals ctxt table = - table - |> map_index (fn (i,cs) => map_index (fn (j,x) => ((i,j), x)) cs) - |> flat - |> map (pr_unprovable_cell ctxt) - -fun no_order_msg ctxt table tl measure_funs = - let - val prterm = Syntax.string_of_term ctxt - fun pr_fun t i = string_of_int i ^ ") " ^ prterm t - - fun pr_goal t i = - let - val (_, _, lhs, rhs) = dest_term t - in (* also show prems? *) - i ^ ") " ^ prterm rhs ^ " ~> " ^ prterm lhs - end - - val gc = map (fn i => chr (i + 96)) (1 upto length table) - val mc = 1 upto length measure_funs - val tstr = "Result matrix:" :: (" " ^ concat (map (enclose " " " " o string_of_int) mc)) - :: map2 (fn r => fn i => i ^ ": " ^ concat (map pr_cell r)) table gc - val gstr = "Calls:" :: map2 (prefix " " oo pr_goal) tl gc - val mstr = "Measures:" :: map2 (prefix " " oo pr_fun) measure_funs mc - val ustr = "Unfinished subgoals:" :: pr_unprovable_subgoals ctxt table - in - cat_lines (ustr @ gstr @ mstr @ tstr @ ["", "Could not find lexicographic termination order."]) - end - -(** The Main Function **) - -fun lex_order_tac ctxt solve_tac (st: thm) = - let - val thy = ProofContext.theory_of ctxt - val ((trueprop $ (wf $ rel)) :: tl) = prems_of st - - val (domT, _) = HOLogic.dest_prodT (HOLogic.dest_setT (fastype_of rel)) - - val measure_funs = MeasureFunctions.get_measure_functions ctxt domT (* 1: generate measures *) - - (* 2: create table *) - val table = map (fn t => map (mk_cell thy solve_tac (dest_term t)) measure_funs) tl - - val order = the (search_table table) (* 3: search table *) - handle Option => error (no_order_msg ctxt table tl measure_funs) - - val clean_table = map (fn x => map (nth x) order) table - - val relation = mk_measures domT (map (nth measure_funs) order) - val _ = writeln ("Found termination order: " ^ quote (Syntax.string_of_term ctxt relation)) - - in (* 4: proof reconstruction *) - st |> (PRIMITIVE (cterm_instantiate [(cterm_of thy rel, cterm_of thy relation)]) - THEN (REPEAT (rtac @{thm "wf_mlex"} 1)) - THEN (rtac @{thm "wf_empty"} 1) - THEN EVERY (map prove_row clean_table)) - end - -fun lexicographic_order_tac ctxt = - TRY (FundefCommon.apply_termination_rule ctxt 1) - THEN lex_order_tac ctxt (auto_tac (local_clasimpset_of ctxt addsimps2 FundefCommon.TerminationSimps.get ctxt)) - -val lexicographic_order = SIMPLE_METHOD o lexicographic_order_tac - -val setup = - Method.setup @{binding lexicographic_order} - (Method.sections clasimp_modifiers >> (K lexicographic_order)) - "termination prover for lexicographic orderings" - #> Context.theory_map (FundefCommon.set_termination_prover lexicographic_order) - -end; - diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/measure_functions.ML --- a/src/HOL/Tools/function_package/measure_functions.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,59 +0,0 @@ -(* Title: HOL/Tools/function_package/measure_functions.ML - ID: $Id$ - Author: Alexander Krauss, TU Muenchen - -Measure functions, generated heuristically -*) - -signature MEASURE_FUNCTIONS = -sig - - val get_measure_functions : Proof.context -> typ -> term list - val setup : theory -> theory - -end - -structure MeasureFunctions : MEASURE_FUNCTIONS = -struct - -(** User-declared size functions **) -structure MeasureHeuristicRules = NamedThmsFun( - val name = "measure_function" - val description = "Rules that guide the heuristic generation of measure functions" -); - -fun mk_is_measures t = Const (@{const_name "is_measure"}, fastype_of t --> HOLogic.boolT) $ t - -fun find_measures ctxt T = - DEPTH_SOLVE (resolve_tac (MeasureHeuristicRules.get ctxt) 1) - (HOLogic.mk_Trueprop (mk_is_measures (Var (("f",0), T --> HOLogic.natT))) - |> cterm_of (ProofContext.theory_of ctxt) |> Goal.init) - |> Seq.map (prop_of #> (fn _ $ (_ $ (_ $ f)) => f)) - |> Seq.list_of - - -(** Generating Measure Functions **) - -fun constant_0 T = Abs ("x", T, HOLogic.zero) -fun constant_1 T = Abs ("x", T, HOLogic.Suc_zero) - -fun mk_funorder_funs (Type ("+", [fT, sT])) = - map (fn m => SumTree.mk_sumcase fT sT HOLogic.natT m (constant_0 sT)) (mk_funorder_funs fT) - @ map (fn m => SumTree.mk_sumcase fT sT HOLogic.natT (constant_0 fT) m) (mk_funorder_funs sT) - | mk_funorder_funs T = [ constant_1 T ] - -fun mk_ext_base_funs ctxt (Type("+", [fT, sT])) = - map_product (SumTree.mk_sumcase fT sT HOLogic.natT) - (mk_ext_base_funs ctxt fT) (mk_ext_base_funs ctxt sT) - | mk_ext_base_funs ctxt T = find_measures ctxt T - -fun mk_all_measure_funs ctxt (T as Type ("+", _)) = - mk_ext_base_funs ctxt T @ mk_funorder_funs T - | mk_all_measure_funs ctxt T = find_measures ctxt T - -val get_measure_functions = mk_all_measure_funs - -val setup = MeasureHeuristicRules.setup - -end - diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/mutual.ML --- a/src/HOL/Tools/function_package/mutual.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,314 +0,0 @@ -(* Title: HOL/Tools/function_package/mutual.ML - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. -Tools for mutual recursive definitions. -*) - -signature FUNDEF_MUTUAL = -sig - - val prepare_fundef_mutual : FundefCommon.fundef_config - -> string (* defname *) - -> ((string * typ) * mixfix) list - -> term list - -> local_theory - -> ((thm (* goalstate *) - * (thm -> FundefCommon.fundef_result) (* proof continuation *) - ) * local_theory) - -end - - -structure FundefMutual: FUNDEF_MUTUAL = -struct - -open FundefLib -open FundefCommon - - - - -type qgar = string * (string * typ) list * term list * term list * term - -fun name_of_fqgar ((f, _, _, _, _): qgar) = f - -datatype mutual_part = - MutualPart of - { - i : int, - i' : int, - fvar : string * typ, - cargTs: typ list, - f_def: term, - - f: term option, - f_defthm : thm option - } - - -datatype mutual_info = - Mutual of - { - n : int, - n' : int, - fsum_var : string * typ, - - ST: typ, - RST: typ, - - parts: mutual_part list, - fqgars: qgar list, - qglrs: ((string * typ) list * term list * term * term) list, - - fsum : term option - } - -fun mutual_induct_Pnames n = - if n < 5 then fst (chop n ["P","Q","R","S"]) - else map (fn i => "P" ^ string_of_int i) (1 upto n) - -fun get_part fname = - the o find_first (fn (MutualPart {fvar=(n,_), ...}) => n = fname) - -(* FIXME *) -fun mk_prod_abs e (t1, t2) = - let - val bTs = rev (map snd e) - val T1 = fastype_of1 (bTs, t1) - val T2 = fastype_of1 (bTs, t2) - in - HOLogic.pair_const T1 T2 $ t1 $ t2 - end; - - -fun analyze_eqs ctxt defname fs eqs = - let - val num = length fs - val fnames = map fst fs - val fqgars = map (split_def ctxt) eqs - val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars - |> AList.lookup (op =) #> the - - fun curried_types (fname, fT) = - let - val (caTs, uaTs) = chop (arity_of fname) (binder_types fT) - in - (caTs, uaTs ---> body_type fT) - end - - val (caTss, resultTs) = split_list (map curried_types fs) - val argTs = map (foldr1 HOLogic.mk_prodT) caTss - - val dresultTs = distinct (Type.eq_type Vartab.empty) resultTs - val n' = length dresultTs - - val RST = BalancedTree.make (uncurry SumTree.mk_sumT) dresultTs - val ST = BalancedTree.make (uncurry SumTree.mk_sumT) argTs - - val fsum_type = ST --> RST - - val ([fsum_var_name], _) = Variable.add_fixes [ defname ^ "_sum" ] ctxt - val fsum_var = (fsum_var_name, fsum_type) - - fun define (fvar as (n, T)) caTs resultT i = - let - val vars = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) caTs (* FIXME: Bind xs properly *) - val i' = find_index (fn Ta => Type.eq_type Vartab.empty (Ta, resultT)) dresultTs + 1 - - val f_exp = SumTree.mk_proj RST n' i' (Free fsum_var $ SumTree.mk_inj ST num i (foldr1 HOLogic.mk_prod vars)) - val def = Term.abstract_over (Free fsum_var, fold_rev lambda vars f_exp) - - val rew = (n, fold_rev lambda vars f_exp) - in - (MutualPart {i=i, i'=i', fvar=fvar,cargTs=caTs,f_def=def,f=NONE,f_defthm=NONE}, rew) - end - - val (parts, rews) = split_list (map4 define fs caTss resultTs (1 upto num)) - - fun convert_eqs (f, qs, gs, args, rhs) = - let - val MutualPart {i, i', ...} = get_part f parts - in - (qs, gs, SumTree.mk_inj ST num i (foldr1 (mk_prod_abs qs) args), - SumTree.mk_inj RST n' i' (replace_frees rews rhs) - |> Envir.beta_norm) - end - - val qglrs = map convert_eqs fqgars - in - Mutual {n=num, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, - parts=parts, fqgars=fqgars, qglrs=qglrs, fsum=NONE} - end - - - - -fun define_projections fixes mutual fsum lthy = - let - fun def ((MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs, f_def, ...}), (_, mixfix)) lthy = - let - val ((f, (_, f_defthm)), lthy') = - LocalTheory.define Thm.internalK ((Binding.name fname, mixfix), - ((Binding.name (fname ^ "_def"), []), Term.subst_bound (fsum, f_def))) - lthy - in - (MutualPart {i=i, i'=i', fvar=(fname, fT), cargTs=cargTs, f_def=f_def, - f=SOME f, f_defthm=SOME f_defthm }, - lthy') - end - - val Mutual { n, n', fsum_var, ST, RST, parts, fqgars, qglrs, ... } = mutual - val (parts', lthy') = fold_map def (parts ~~ fixes) lthy - in - (Mutual { n=n, n'=n', fsum_var=fsum_var, ST=ST, RST=RST, parts=parts', - fqgars=fqgars, qglrs=qglrs, fsum=SOME fsum }, - lthy') - end - - -fun in_context ctxt (f, pre_qs, pre_gs, pre_args, pre_rhs) F = - let - val thy = ProofContext.theory_of ctxt - - val oqnames = map fst pre_qs - val (qs, ctxt') = Variable.variant_fixes oqnames ctxt - |>> map2 (fn (_, T) => fn n => Free (n, T)) pre_qs - - fun inst t = subst_bounds (rev qs, t) - val gs = map inst pre_gs - val args = map inst pre_args - val rhs = inst pre_rhs - - val cqs = map (cterm_of thy) qs - val ags = map (assume o cterm_of thy) gs - - val import = fold forall_elim cqs - #> fold Thm.elim_implies ags - - val export = fold_rev (implies_intr o cprop_of) ags - #> fold_rev forall_intr_rename (oqnames ~~ cqs) - in - F ctxt (f, qs, gs, args, rhs) import export - end - -fun recover_mutual_psimp all_orig_fdefs parts ctxt (fname, _, _, args, rhs) import (export : thm -> thm) sum_psimp_eq = - let - val (MutualPart {f=SOME f, f_defthm=SOME f_def, ...}) = get_part fname parts - - val psimp = import sum_psimp_eq - val (simp, restore_cond) = case cprems_of psimp of - [] => (psimp, I) - | [cond] => (implies_elim psimp (assume cond), implies_intr cond) - | _ => sys_error "Too many conditions" - in - Goal.prove ctxt [] [] - (HOLogic.Trueprop $ HOLogic.mk_eq (list_comb (f, args), rhs)) - (fn _ => (LocalDefs.unfold_tac ctxt all_orig_fdefs) - THEN EqSubst.eqsubst_tac ctxt [0] [simp] 1 - THEN (simp_tac (local_simpset_of ctxt addsimps SumTree.proj_in_rules)) 1) - |> restore_cond - |> export - end - - -(* FIXME HACK *) -fun mk_applied_form ctxt caTs thm = - let - val thy = ProofContext.theory_of ctxt - val xs = map_index (fn (i,T) => cterm_of thy (Free ("x" ^ string_of_int i, T))) caTs (* FIXME: Bind xs properly *) - in - fold (fn x => fn thm => combination thm (reflexive x)) xs thm - |> Conv.fconv_rule (Thm.beta_conversion true) - |> fold_rev forall_intr xs - |> Thm.forall_elim_vars 0 - end - - -fun mutual_induct_rules lthy induct all_f_defs (Mutual {n, ST, RST, parts, ...}) = - let - val cert = cterm_of (ProofContext.theory_of lthy) - val newPs = map2 (fn Pname => fn MutualPart {cargTs, ...} => - Free (Pname, cargTs ---> HOLogic.boolT)) - (mutual_induct_Pnames (length parts)) - parts - - fun mk_P (MutualPart {cargTs, ...}) P = - let - val avars = map_index (fn (i,T) => Var (("a", i), T)) cargTs - val atup = foldr1 HOLogic.mk_prod avars - in - tupled_lambda atup (list_comb (P, avars)) - end - - val Ps = map2 mk_P parts newPs - val case_exp = SumTree.mk_sumcases HOLogic.boolT Ps - - val induct_inst = - forall_elim (cert case_exp) induct - |> full_simplify SumTree.sumcase_split_ss - |> full_simplify (HOL_basic_ss addsimps all_f_defs) - - fun project rule (MutualPart {cargTs, i, ...}) k = - let - val afs = map_index (fn (j,T) => Free ("a" ^ string_of_int (j + k), T)) cargTs (* FIXME! *) - val inj = SumTree.mk_inj ST n i (foldr1 HOLogic.mk_prod afs) - in - (rule - |> forall_elim (cert inj) - |> full_simplify SumTree.sumcase_split_ss - |> fold_rev (forall_intr o cert) (afs @ newPs), - k + length cargTs) - end - in - fst (fold_map (project induct_inst) parts 0) - end - - -fun mk_partial_rules_mutual lthy inner_cont (m as Mutual {parts, fqgars, ...}) proof = - let - val result = inner_cont proof - val FundefResult {fs=[f], G, R, cases, psimps, trsimps, simple_pinducts=[simple_pinduct], - termination,domintros} = result - - val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} => - (mk_applied_form lthy cargTs (symmetric f_def), f)) - parts - |> split_list - - val all_orig_fdefs = map (fn MutualPart {f_defthm = SOME f_def, ...} => f_def) parts - - fun mk_mpsimp fqgar sum_psimp = - in_context lthy fqgar (recover_mutual_psimp all_orig_fdefs parts) sum_psimp - - val rew_ss = HOL_basic_ss addsimps all_f_defs - val mpsimps = map2 mk_mpsimp fqgars psimps - val mtrsimps = map_option (map2 mk_mpsimp fqgars) trsimps - val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m - val mtermination = full_simplify rew_ss termination - val mdomintros = map_option (map (full_simplify rew_ss)) domintros - in - FundefResult { fs=fs, G=G, R=R, - psimps=mpsimps, simple_pinducts=minducts, - cases=cases, termination=mtermination, - domintros=mdomintros, - trsimps=mtrsimps} - end - -fun prepare_fundef_mutual config defname fixes eqss lthy = - let - val mutual = analyze_eqs lthy defname (map fst fixes) (map Envir.beta_eta_contract eqss) - val Mutual {fsum_var=(n, T), qglrs, ...} = mutual - - val ((fsum, goalstate, cont), lthy') = - FundefCore.prepare_fundef config defname [((n, T), NoSyn)] qglrs lthy - - val (mutual', lthy'') = define_projections fixes mutual fsum lthy' - - val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual' - in - ((goalstate, mutual_cont), lthy'') - end - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/pattern_split.ML --- a/src/HOL/Tools/function_package/pattern_split.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,138 +0,0 @@ -(* Title: HOL/Tools/function_package/pattern_split.ML - ID: $Id$ - Author: Alexander Krauss, TU Muenchen - -A package for general recursive function definitions. - -Automatic splitting of overlapping constructor patterns. This is a preprocessing step which -turns a specification with overlaps into an overlap-free specification. - -*) - -signature FUNDEF_SPLIT = -sig - val split_some_equations : - Proof.context -> (bool * term) list -> term list list - - val split_all_equations : - Proof.context -> term list -> term list list -end - -structure FundefSplit : FUNDEF_SPLIT = -struct - -open FundefLib - -(* We use proof context for the variable management *) -(* FIXME: no __ *) - -fun new_var ctx vs T = - let - val [v] = Variable.variant_frees ctx vs [("v", T)] - in - (Free v :: vs, Free v) - end - -fun saturate ctx vs t = - fold (fn T => fn (vs, t) => new_var ctx vs T |> apsnd (curry op $ t)) - (binder_types (fastype_of t)) (vs, t) - - -(* This is copied from "fundef_datatype.ML" *) -fun inst_constrs_of thy (T as Type (name, _)) = - map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT))) - (the (Datatype.get_datatype_constrs thy name)) - | inst_constrs_of thy T = raise TYPE ("inst_constrs_of", [T], []) - - - - -fun join ((vs1,sub1), (vs2,sub2)) = (merge (op aconv) (vs1,vs2), sub1 @ sub2) -fun join_product (xs, ys) = map_product (curry join) xs ys - -fun join_list [] = [] - | join_list xs = foldr1 (join_product) xs - - -exception DISJ - -fun pattern_subtract_subst ctx vs t t' = - let - exception DISJ - fun pattern_subtract_subst_aux vs _ (Free v2) = [] - | pattern_subtract_subst_aux vs (v as (Free (_, T))) t' = - let - fun foo constr = - let - val (vs', t) = saturate ctx vs constr - val substs = pattern_subtract_subst ctx vs' t t' - in - map (fn (vs, subst) => (vs, (v,t)::subst)) substs - end - in - flat (map foo (inst_constrs_of (ProofContext.theory_of ctx) T)) - end - | pattern_subtract_subst_aux vs t t' = - let - val (C, ps) = strip_comb t - val (C', qs) = strip_comb t' - in - if C = C' - then flat (map2 (pattern_subtract_subst_aux vs) ps qs) - else raise DISJ - end - in - pattern_subtract_subst_aux vs t t' - handle DISJ => [(vs, [])] - end - - -(* p - q *) -fun pattern_subtract ctx eq2 eq1 = - let - val thy = ProofContext.theory_of ctx - - val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1 - val (_, _ $ (_ $ lhs2 $ _)) = dest_all_all eq2 - - val substs = pattern_subtract_subst ctx vs lhs1 lhs2 - - fun instantiate (vs', sigma) = - let - val t = Pattern.rewrite_term thy sigma [] feq1 - in - fold_rev Logic.all (map Free (frees_in_term ctx t) inter vs') t - end - in - map instantiate substs - end - - -(* ps - p' *) -fun pattern_subtract_from_many ctx p'= - flat o map (pattern_subtract ctx p') - -(* in reverse order *) -fun pattern_subtract_many ctx ps' = - fold_rev (pattern_subtract_from_many ctx) ps' - - - -fun split_some_equations ctx eqns = - let - fun split_aux prev [] = [] - | split_aux prev ((true, eq) :: es) = pattern_subtract_many ctx prev [eq] - :: split_aux (eq :: prev) es - | split_aux prev ((false, eq) :: es) = [eq] - :: split_aux (eq :: prev) es - in - split_aux [] eqns - end - -fun split_all_equations ctx = - split_some_equations ctx o map (pair true) - - - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/scnp_reconstruct.ML --- a/src/HOL/Tools/function_package/scnp_reconstruct.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,429 +0,0 @@ -(* Title: HOL/Tools/function_package/scnp_reconstruct.ML - Author: Armin Heller, TU Muenchen - Author: Alexander Krauss, TU Muenchen - -Proof reconstruction for SCNP -*) - -signature SCNP_RECONSTRUCT = -sig - - val sizechange_tac : Proof.context -> tactic -> tactic - - val decomp_scnp : ScnpSolve.label list -> Proof.context -> Proof.method - - val setup : theory -> theory - - datatype multiset_setup = - Multiset of - { - msetT : typ -> typ, - mk_mset : typ -> term list -> term, - mset_regroup_conv : int list -> conv, - mset_member_tac : int -> int -> tactic, - mset_nonempty_tac : int -> tactic, - mset_pwleq_tac : int -> tactic, - set_of_simps : thm list, - smsI' : thm, - wmsI2'' : thm, - wmsI1 : thm, - reduction_pair : thm - } - - - val multiset_setup : multiset_setup -> theory -> theory - -end - -structure ScnpReconstruct : SCNP_RECONSTRUCT = -struct - -val PROFILE = FundefCommon.PROFILE -fun TRACE x = if ! FundefCommon.profile then Output.tracing x else () - -open ScnpSolve - -val natT = HOLogic.natT -val nat_pairT = HOLogic.mk_prodT (natT, natT) - -(* Theory dependencies *) - -datatype multiset_setup = - Multiset of - { - msetT : typ -> typ, - mk_mset : typ -> term list -> term, - mset_regroup_conv : int list -> conv, - mset_member_tac : int -> int -> tactic, - mset_nonempty_tac : int -> tactic, - mset_pwleq_tac : int -> tactic, - set_of_simps : thm list, - smsI' : thm, - wmsI2'' : thm, - wmsI1 : thm, - reduction_pair : thm - } - -structure MultisetSetup = TheoryDataFun -( - type T = multiset_setup option - val empty = NONE - val copy = I; - val extend = I; - fun merge _ (v1, v2) = if is_some v2 then v2 else v1 -) - -val multiset_setup = MultisetSetup.put o SOME - -fun undef x = error "undef" -fun get_multiset_setup thy = MultisetSetup.get thy - |> the_default (Multiset -{ msetT = undef, mk_mset=undef, - mset_regroup_conv=undef, mset_member_tac = undef, - mset_nonempty_tac = undef, mset_pwleq_tac = undef, - set_of_simps = [],reduction_pair = refl, - smsI'=refl, wmsI2''=refl, wmsI1=refl }) - -fun order_rpair _ MAX = @{thm max_rpair_set} - | order_rpair msrp MS = msrp - | order_rpair _ MIN = @{thm min_rpair_set} - -fun ord_intros_max true = - (@{thm smax_emptyI}, @{thm smax_insertI}) - | ord_intros_max false = - (@{thm wmax_emptyI}, @{thm wmax_insertI}) -fun ord_intros_min true = - (@{thm smin_emptyI}, @{thm smin_insertI}) - | ord_intros_min false = - (@{thm wmin_emptyI}, @{thm wmin_insertI}) - -fun gen_probl D cs = - let - val n = Termination.get_num_points D - val arity = length o Termination.get_measures D - fun measure p i = nth (Termination.get_measures D p) i - - fun mk_graph c = - let - val (_, p, _, q, _, _) = Termination.dest_call D c - - fun add_edge i j = - case Termination.get_descent D c (measure p i) (measure q j) - of SOME (Termination.Less _) => cons (i, GTR, j) - | SOME (Termination.LessEq _) => cons (i, GEQ, j) - | _ => I - - val edges = - fold_product add_edge (0 upto arity p - 1) (0 upto arity q - 1) [] - in - G (p, q, edges) - end - in - GP (map arity (0 upto n - 1), map mk_graph cs) - end - -(* General reduction pair application *) -fun rem_inv_img ctxt = - let - val unfold_tac = LocalDefs.unfold_tac ctxt - in - rtac @{thm subsetI} 1 - THEN etac @{thm CollectE} 1 - THEN REPEAT (etac @{thm exE} 1) - THEN unfold_tac @{thms inv_image_def} - THEN rtac @{thm CollectI} 1 - THEN etac @{thm conjE} 1 - THEN etac @{thm ssubst} 1 - THEN unfold_tac (@{thms split_conv} @ @{thms triv_forall_equality} - @ @{thms sum.cases}) - end - -(* Sets *) - -val setT = HOLogic.mk_setT - -fun set_member_tac m i = - if m = 0 then rtac @{thm insertI1} i - else rtac @{thm insertI2} i THEN set_member_tac (m - 1) i - -val set_nonempty_tac = rtac @{thm insert_not_empty} - -fun set_finite_tac i = - rtac @{thm finite.emptyI} i - ORELSE (rtac @{thm finite.insertI} i THEN (fn st => set_finite_tac i st)) - - -(* Reconstruction *) - -fun reconstruct_tac ctxt D cs (gp as GP (_, gs)) certificate = - let - val thy = ProofContext.theory_of ctxt - val Multiset - { msetT, mk_mset, - mset_regroup_conv, mset_member_tac, - mset_nonempty_tac, mset_pwleq_tac, set_of_simps, - smsI', wmsI2'', wmsI1, reduction_pair=ms_rp } - = get_multiset_setup thy - - fun measure_fn p = nth (Termination.get_measures D p) - - fun get_desc_thm cidx m1 m2 bStrict = - case Termination.get_descent D (nth cs cidx) m1 m2 - of SOME (Termination.Less thm) => - if bStrict then thm - else (thm COMP (Thm.lift_rule (cprop_of thm) @{thm less_imp_le})) - | SOME (Termination.LessEq (thm, _)) => - if not bStrict then thm - else sys_error "get_desc_thm" - | _ => sys_error "get_desc_thm" - - val (label, lev, sl, covering) = certificate - - fun prove_lev strict g = - let - val G (p, q, el) = nth gs g - - fun less_proof strict (j, b) (i, a) = - let - val tag_flag = b < a orelse (not strict andalso b <= a) - - val stored_thm = - get_desc_thm g (measure_fn p i) (measure_fn q j) - (not tag_flag) - |> Conv.fconv_rule (Thm.beta_conversion true) - - val rule = if strict - then if b < a then @{thm pair_lessI2} else @{thm pair_lessI1} - else if b <= a then @{thm pair_leqI2} else @{thm pair_leqI1} - in - rtac rule 1 THEN PRIMITIVE (Thm.elim_implies stored_thm) - THEN (if tag_flag then Arith_Data.verbose_arith_tac ctxt 1 else all_tac) - end - - fun steps_tac MAX strict lq lp = - let - val (empty, step) = ord_intros_max strict - in - if length lq = 0 - then rtac empty 1 THEN set_finite_tac 1 - THEN (if strict then set_nonempty_tac 1 else all_tac) - else - let - val (j, b) :: rest = lq - val (i, a) = the (covering g strict j) - fun choose xs = set_member_tac (Library.find_index (curry op = (i, a)) xs) 1 - val solve_tac = choose lp THEN less_proof strict (j, b) (i, a) - in - rtac step 1 THEN solve_tac THEN steps_tac MAX strict rest lp - end - end - | steps_tac MIN strict lq lp = - let - val (empty, step) = ord_intros_min strict - in - if length lp = 0 - then rtac empty 1 - THEN (if strict then set_nonempty_tac 1 else all_tac) - else - let - val (i, a) :: rest = lp - val (j, b) = the (covering g strict i) - fun choose xs = set_member_tac (Library.find_index (curry op = (j, b)) xs) 1 - val solve_tac = choose lq THEN less_proof strict (j, b) (i, a) - in - rtac step 1 THEN solve_tac THEN steps_tac MIN strict lq rest - end - end - | steps_tac MS strict lq lp = - let - fun get_str_cover (j, b) = - if is_some (covering g true j) then SOME (j, b) else NONE - fun get_wk_cover (j, b) = the (covering g false j) - - val qs = lq \\ map_filter get_str_cover lq - val ps = map get_wk_cover qs - - fun indices xs ys = map (fn y => Library.find_index (curry op = y) xs) ys - val iqs = indices lq qs - val ips = indices lp ps - - local open Conv in - fun t_conv a C = - params_conv ~1 (K ((concl_conv ~1 o arg_conv o arg1_conv o a) C)) ctxt - val goal_rewrite = - t_conv arg1_conv (mset_regroup_conv iqs) - then_conv t_conv arg_conv (mset_regroup_conv ips) - end - in - CONVERSION goal_rewrite 1 - THEN (if strict then rtac smsI' 1 - else if qs = lq then rtac wmsI2'' 1 - else rtac wmsI1 1) - THEN mset_pwleq_tac 1 - THEN EVERY (map2 (less_proof false) qs ps) - THEN (if strict orelse qs <> lq - then LocalDefs.unfold_tac ctxt set_of_simps - THEN steps_tac MAX true (lq \\ qs) (lp \\ ps) - else all_tac) - end - in - rem_inv_img ctxt - THEN steps_tac label strict (nth lev q) (nth lev p) - end - - val (mk_set, setT) = if label = MS then (mk_mset, msetT) else (HOLogic.mk_set, setT) - - fun tag_pair p (i, tag) = - HOLogic.pair_const natT natT $ - (measure_fn p i $ Bound 0) $ HOLogic.mk_number natT tag - - fun pt_lev (p, lm) = Abs ("x", Termination.get_types D p, - mk_set nat_pairT (map (tag_pair p) lm)) - - val level_mapping = - map_index pt_lev lev - |> Termination.mk_sumcases D (setT nat_pairT) - |> cterm_of thy - in - PROFILE "Proof Reconstruction" - (CONVERSION (Conv.arg_conv (Conv.arg_conv (FundefLib.regroup_union_conv sl))) 1 - THEN (rtac @{thm reduction_pair_lemma} 1) - THEN (rtac @{thm rp_inv_image_rp} 1) - THEN (rtac (order_rpair ms_rp label) 1) - THEN PRIMITIVE (instantiate' [] [SOME level_mapping]) - THEN unfold_tac @{thms rp_inv_image_def} (local_simpset_of ctxt) - THEN LocalDefs.unfold_tac ctxt - (@{thms split_conv} @ @{thms fst_conv} @ @{thms snd_conv}) - THEN REPEAT (SOMEGOAL (resolve_tac [@{thm Un_least}, @{thm empty_subsetI}])) - THEN EVERY (map (prove_lev true) sl) - THEN EVERY (map (prove_lev false) ((0 upto length cs - 1) \\ sl))) - end - - - -local open Termination in -fun print_cell (SOME (Less _)) = "<" - | print_cell (SOME (LessEq _)) = "\" - | print_cell (SOME (None _)) = "-" - | print_cell (SOME (False _)) = "-" - | print_cell (NONE) = "?" - -fun print_error ctxt D = CALLS (fn (cs, i) => - let - val np = get_num_points D - val ms = map (get_measures D) (0 upto np - 1) - val tys = map (get_types D) (0 upto np - 1) - fun index xs = (1 upto length xs) ~~ xs - fun outp s t f xs = map (fn (x, y) => s ^ Int.toString x ^ t ^ f y ^ "\n") xs - val ims = index (map index ms) - val _ = Output.tracing (concat (outp "fn #" ":\n" (concat o outp "\tmeasure #" ": " (Syntax.string_of_term ctxt)) ims)) - fun print_call (k, c) = - let - val (_, p, _, q, _, _) = dest_call D c - val _ = Output.tracing ("call table for call #" ^ Int.toString k ^ ": fn " ^ - Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1)) - val caller_ms = nth ms p - val callee_ms = nth ms q - val entries = map (fn x => map (pair x) (callee_ms)) (caller_ms) - fun print_ln (i : int, l) = concat (Int.toString i :: " " :: map (enclose " " " " o print_cell o (uncurry (get_descent D c))) l) - val _ = Output.tracing (concat (Int.toString (p + 1) ^ "|" ^ Int.toString (q + 1) ^ - " " :: map (enclose " " " " o Int.toString) (1 upto length callee_ms)) ^ "\n" - ^ cat_lines (map print_ln ((1 upto (length entries)) ~~ entries))) - in - true - end - fun list_call (k, c) = - let - val (_, p, _, q, _, _) = dest_call D c - val _ = Output.tracing ("call #" ^ (Int.toString k) ^ ": fn " ^ - Int.toString (p + 1) ^ " ~> fn " ^ Int.toString (q + 1) ^ "\n" ^ - (Syntax.string_of_term ctxt c)) - in true end - val _ = forall list_call ((1 upto length cs) ~~ cs) - val _ = forall print_call ((1 upto length cs) ~~ cs) - in - all_tac - end) -end - - -fun single_scnp_tac use_tags orders ctxt cont err_cont D = Termination.CALLS (fn (cs, i) => - let - val ms_configured = is_some (MultisetSetup.get (ProofContext.theory_of ctxt)) - val orders' = if ms_configured then orders - else filter_out (curry op = MS) orders - val gp = gen_probl D cs -(* val _ = TRACE ("SCNP instance: " ^ makestring gp)*) - val certificate = generate_certificate use_tags orders' gp -(* val _ = TRACE ("Certificate: " ^ makestring certificate)*) - - in - case certificate - of NONE => err_cont D i - | SOME cert => - SELECT_GOAL (reconstruct_tac ctxt D cs gp cert) i - THEN (rtac @{thm wf_empty} i ORELSE cont D i) - end) - -fun gen_decomp_scnp_tac orders autom_tac ctxt err_cont = - let - open Termination - val derive_diag = Descent.derive_diag ctxt autom_tac - val derive_all = Descent.derive_all ctxt autom_tac - val decompose = Decompose.decompose_tac ctxt autom_tac - val scnp_no_tags = single_scnp_tac false orders ctxt - val scnp_full = single_scnp_tac true orders ctxt - - fun first_round c e = - derive_diag (REPEAT scnp_no_tags c e) - - val second_round = - REPEAT (fn c => fn e => decompose (scnp_no_tags c c) e) - - val third_round = - derive_all oo - REPEAT (fn c => fn e => - scnp_full (decompose c c) e) - - fun Then s1 s2 c e = s1 (s2 c c) (s2 c e) - - val strategy = Then (Then first_round second_round) third_round - - in - TERMINATION ctxt (strategy err_cont err_cont) - end - -fun gen_sizechange_tac orders autom_tac ctxt err_cont = - TRY (FundefCommon.apply_termination_rule ctxt 1) - THEN TRY (Termination.wf_union_tac ctxt) - THEN - (rtac @{thm wf_empty} 1 - ORELSE gen_decomp_scnp_tac orders autom_tac ctxt err_cont 1) - -fun sizechange_tac ctxt autom_tac = - gen_sizechange_tac [MAX, MS, MIN] autom_tac ctxt (K (K all_tac)) - -fun decomp_scnp orders ctxt = - let - val extra_simps = FundefCommon.TerminationSimps.get ctxt - val autom_tac = auto_tac (local_clasimpset_of ctxt addsimps2 extra_simps) - in - SIMPLE_METHOD - (gen_sizechange_tac orders autom_tac ctxt (print_error ctxt)) - end - - -(* Method setup *) - -val orders = - Scan.repeat1 - ((Args.$$$ "max" >> K MAX) || - (Args.$$$ "min" >> K MIN) || - (Args.$$$ "ms" >> K MS)) - || Scan.succeed [MAX, MS, MIN] - -val setup = Method.setup @{binding sizechange} - (Scan.lift orders --| Method.sections clasimp_modifiers >> decomp_scnp) - "termination prover with graph decomposition and the NP subset of size change termination" - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/scnp_solve.ML --- a/src/HOL/Tools/function_package/scnp_solve.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,257 +0,0 @@ -(* Title: HOL/Tools/function_package/scnp_solve.ML - Author: Armin Heller, TU Muenchen - Author: Alexander Krauss, TU Muenchen - -Generate certificates for SCNP using a SAT solver -*) - - -signature SCNP_SOLVE = -sig - - datatype edge = GTR | GEQ - datatype graph = G of int * int * (int * edge * int) list - datatype graph_problem = GP of int list * graph list - - datatype label = MIN | MAX | MS - - type certificate = - label (* which order *) - * (int * int) list list (* (multi)sets *) - * int list (* strictly ordered calls *) - * (int -> bool -> int -> (int * int) option) (* covering function *) - - val generate_certificate : bool -> label list -> graph_problem -> certificate option - - val solver : string ref -end - -structure ScnpSolve : SCNP_SOLVE = -struct - -(** Graph problems **) - -datatype edge = GTR | GEQ ; -datatype graph = G of int * int * (int * edge * int) list ; -datatype graph_problem = GP of int list * graph list ; - -datatype label = MIN | MAX | MS ; -type certificate = - label - * (int * int) list list - * int list - * (int -> bool -> int -> (int * int) option) - -fun graph_at (GP (_, gs), i) = nth gs i ; -fun num_prog_pts (GP (arities, _)) = length arities ; -fun num_graphs (GP (_, gs)) = length gs ; -fun arity (GP (arities, gl)) i = nth arities i ; -fun ndigits (GP (arities, _)) = IntInf.log2 (List.foldl (op +) 0 arities) + 1 - - -(** Propositional formulas **) - -val Not = PropLogic.Not and And = PropLogic.And and Or = PropLogic.Or -val BoolVar = PropLogic.BoolVar -fun Implies (p, q) = Or (Not p, q) -fun Equiv (p, q) = And (Implies (p, q), Implies (q, p)) -val all = PropLogic.all - -(* finite indexed quantifiers: - -iforall n f <==> /\ - / \ f i - 0<=i Equiv (TAG x i, TAG y i))) - - fun encode_graph (g, p, q, n, m, edges) = - let - fun encode_edge i j = - if exists (fn x => x = (i, GTR, j)) edges then - And (ES (g, i, j), EW (g, i, j)) - else if not (exists (fn x => x = (i, GEQ, j)) edges) then - And (Not (ES (g, i, j)), Not (EW (g, i, j))) - else - And ( - Equiv (ES (g, i, j), - encode_constraint_strict bits ((p, i), (q, j))), - Equiv (EW (g, i, j), - encode_constraint_weak bits ((p, i), (q, j)))) - in - iforall2 n m encode_edge - end - in - iforall ng (encode_graph o graph_info gp) - end - - -(* Order-specific part of encoding *) - -fun encode bits gp mu = - let - val ng = num_graphs gp - val (ES,EW,WEAK,STRICT,P,GAM,EPS,_) = var_constrs gp - - fun encode_graph MAX (g, p, q, n, m, _) = - And ( - Equiv (WEAK g, - iforall m (fn j => - Implies (P (q, j), - iexists n (fn i => - And (P (p, i), EW (g, i, j)))))), - Equiv (STRICT g, - And ( - iforall m (fn j => - Implies (P (q, j), - iexists n (fn i => - And (P (p, i), ES (g, i, j))))), - iexists n (fn i => P (p, i))))) - | encode_graph MIN (g, p, q, n, m, _) = - And ( - Equiv (WEAK g, - iforall n (fn i => - Implies (P (p, i), - iexists m (fn j => - And (P (q, j), EW (g, i, j)))))), - Equiv (STRICT g, - And ( - iforall n (fn i => - Implies (P (p, i), - iexists m (fn j => - And (P (q, j), ES (g, i, j))))), - iexists m (fn j => P (q, j))))) - | encode_graph MS (g, p, q, n, m, _) = - all [ - Equiv (WEAK g, - iforall m (fn j => - Implies (P (q, j), - iexists n (fn i => GAM (g, i, j))))), - Equiv (STRICT g, - iexists n (fn i => - And (P (p, i), Not (EPS (g, i))))), - iforall2 n m (fn i => fn j => - Implies (GAM (g, i, j), - all [ - P (p, i), - P (q, j), - EW (g, i, j), - Equiv (Not (EPS (g, i)), ES (g, i, j))])), - iforall n (fn i => - Implies (And (P (p, i), EPS (g, i)), - exactly_one m (fn j => GAM (g, i, j)))) - ] - in - all [ - encode_graphs bits gp, - iforall ng (encode_graph mu o graph_info gp), - iforall ng (fn x => WEAK x), - iexists ng (fn x => STRICT x) - ] - end - - -(*Generieren des level-mapping und diverser output*) -fun mk_certificate bits label gp f = - let - val (ES,EW,WEAK,STRICT,P,GAM,EPS,TAG) = var_constrs gp - fun assign (PropLogic.BoolVar v) = the_default false (f v) - fun assignTag i j = - (fold (fn x => fn y => 2 * y + (if assign (TAG (i, j) x) then 1 else 0)) - (bits - 1 downto 0) 0) - - val level_mapping = - let fun prog_pt_mapping p = - map_filter (fn x => if assign (P(p, x)) then SOME (x, assignTag p x) else NONE) - (0 upto (arity gp p) - 1) - in map prog_pt_mapping (0 upto num_prog_pts gp - 1) end - - val strict_list = filter (assign o STRICT) (0 upto num_graphs gp - 1) - - fun covering_pair g bStrict j = - let - val (_, p, q, n, m, _) = graph_info gp g - - fun cover MAX j = find_index (fn i => assign (P (p, i)) andalso assign (EW (g, i, j))) (0 upto n - 1) - | cover MS k = find_index (fn i => assign (GAM (g, i, k))) (0 upto n - 1) - | cover MIN i = find_index (fn j => assign (P (q, j)) andalso assign (EW (g, i, j))) (0 upto m - 1) - fun cover_strict MAX j = find_index (fn i => assign (P (p, i)) andalso assign (ES (g, i, j))) (0 upto n - 1) - | cover_strict MS k = find_index (fn i => assign (GAM (g, i, k)) andalso not (assign (EPS (g, i) ))) (0 upto n - 1) - | cover_strict MIN i = find_index (fn j => assign (P (q, j)) andalso assign (ES (g, i, j))) (0 upto m - 1) - val i = if bStrict then cover_strict label j else cover label j - in - find_first (fn x => fst x = i) (nth level_mapping (if label = MIN then q else p)) - end - in - (label, level_mapping, strict_list, covering_pair) - end - -(*interface for the proof reconstruction*) -fun generate_certificate use_tags labels gp = - let - val bits = if use_tags then ndigits gp else 0 - in - get_first - (fn l => case sat_solver (encode bits gp l) of - SatSolver.SATISFIABLE f => SOME (mk_certificate bits l gp f) - | _ => NONE) - labels - end -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/size.ML --- a/src/HOL/Tools/function_package/size.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,242 +0,0 @@ -(* Title: HOL/Tools/function_package/size.ML - Author: Stefan Berghofer, Florian Haftmann & Alexander Krauss, TU Muenchen - -Size functions for datatypes. -*) - -signature SIZE = -sig - val size_thms: theory -> string -> thm list - val setup: theory -> theory -end; - -structure Size: SIZE = -struct - -open DatatypeAux; - -structure SizeData = TheoryDataFun -( - type T = (string * thm list) Symtab.table; - val empty = Symtab.empty; - val copy = I - val extend = I - fun merge _ = Symtab.merge (K true); -); - -val lookup_size = SizeData.get #> Symtab.lookup; - -fun plus (t1, t2) = Const ("HOL.plus_class.plus", - HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2; - -fun size_of_type f g h (T as Type (s, Ts)) = - (case f s of - SOME t => SOME t - | NONE => (case g s of - SOME size_name => - SOME (list_comb (Const (size_name, - map (fn U => U --> HOLogic.natT) Ts @ [T] ---> HOLogic.natT), - map (size_of_type' f g h) Ts)) - | NONE => NONE)) - | size_of_type f g h (TFree (s, _)) = h s -and size_of_type' f g h T = (case size_of_type f g h T of - NONE => Abs ("x", T, HOLogic.zero) - | SOME t => t); - -fun is_poly thy (DtType (name, dts)) = - (case Datatype.get_datatype thy name of - NONE => false - | SOME _ => exists (is_poly thy) dts) - | is_poly _ _ = true; - -fun constrs_of thy name = - let - val {descr, index, ...} = Datatype.the_datatype thy name - val SOME (_, _, constrs) = AList.lookup op = descr index - in constrs end; - -val app = curry (list_comb o swap); - -fun prove_size_thms (info : info) new_type_names thy = - let - val {descr, alt_names, sorts, rec_names, rec_rewrites, induction, ...} = info; - val l = length new_type_names; - val alt_names' = (case alt_names of - NONE => replicate l NONE | SOME names => map SOME names); - val descr' = List.take (descr, l); - val (rec_names1, rec_names2) = chop l rec_names; - val recTs = get_rec_types descr sorts; - val (recTs1, recTs2) = chop l recTs; - val (_, (_, paramdts, _)) :: _ = descr; - val paramTs = map (typ_of_dtyp descr sorts) paramdts; - val ((param_size_fs, param_size_fTs), f_names) = paramTs |> - map (fn T as TFree (s, _) => - let - val name = "f" ^ implode (tl (explode s)); - val U = T --> HOLogic.natT - in - (((s, Free (name, U)), U), name) - end) |> split_list |>> split_list; - val param_size = AList.lookup op = param_size_fs; - - val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |> - map_filter (Option.map snd o lookup_size thy) |> flat; - val extra_size = Option.map fst o lookup_size thy; - - val (((size_names, size_fns), def_names), def_names') = - recTs1 ~~ alt_names' |> - map (fn (T as Type (s, _), optname) => - let - val s' = the_default (Long_Name.base_name s) optname ^ "_size"; - val s'' = Sign.full_bname thy s' - in - (s'', - (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT), - map snd param_size_fs), - (s' ^ "_def", s' ^ "_overloaded_def"))) - end) |> split_list ||>> split_list ||>> split_list; - val overloaded_size_fns = map HOLogic.size_const recTs1; - - (* instantiation for primrec combinator *) - fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) = - let - val Ts = map (typ_of_dtyp descr sorts) cargs; - val k = length (filter is_rec_type cargs); - val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) => - if is_rec_type dt then (Bound i :: us, i + 1, j + 1) - else - (if b andalso is_poly thy dt' then - case size_of_type (K NONE) extra_size size_ofp T of - NONE => us | SOME sz => sz $ Bound j :: us - else us, i, j + 1)) - (cargs ~~ cargs' ~~ Ts) ([], 0, k); - val t = - if null ts andalso (not b orelse not (exists (is_poly thy) cargs')) - then HOLogic.zero - else foldl1 plus (ts @ [HOLogic.Suc_zero]) - in - List.foldr (fn (T, t') => Abs ("x", T, t')) t (Ts @ replicate k HOLogic.natT) - end; - - val fs = maps (fn (_, (name, _, constrs)) => - map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr; - val fs' = maps (fn (n, (name, _, constrs)) => - map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr; - val fTs = map fastype_of fs; - - val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) => - Const (rec_name, fTs @ [T] ---> HOLogic.natT)) - (recTs ~~ rec_names)); - - fun define_overloaded (def_name, eq) lthy = - let - val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq; - val ((_, (_, thm)), lthy') = lthy |> LocalTheory.define Thm.definitionK - ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs)); - val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy'); - val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; - in (thm', lthy') end; - - val ((size_def_thms, size_def_thms'), thy') = - thy - |> Sign.add_consts_i (map (fn (s, T) => - (Binding.name (Long_Name.base_name s), param_size_fTs @ [T] ---> HOLogic.natT, NoSyn)) - (size_names ~~ recTs1)) - |> PureThy.add_defs false - (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs))) - (map Binding.name def_names ~~ (size_fns ~~ rec_combs1))) - ||> TheoryTarget.instantiation - (map (#1 o snd) descr', map dest_TFree paramTs, [HOLogic.class_size]) - ||>> fold_map define_overloaded - (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1)) - ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac [])) - ||> LocalTheory.exit_global; - - val ctxt = ProofContext.init thy'; - - val simpset1 = HOL_basic_ss addsimps @{thm add_0} :: @{thm add_0_right} :: - size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites; - val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2); - - fun mk_unfolded_size_eq tab size_ofp fs (p as (x, T), r) = - HOLogic.mk_eq (app fs r $ Free p, - the (size_of_type tab extra_size size_ofp T) $ Free p); - - fun prove_unfolded_size_eqs size_ofp fs = - if null recTs2 then [] - else split_conj_thm (SkipProof.prove ctxt xs [] - (HOLogic.mk_Trueprop (mk_conj (replicate l HOLogic.true_const @ - map (mk_unfolded_size_eq (AList.lookup op = - (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs) - (xs ~~ recTs2 ~~ rec_combs2)))) - (fn _ => (indtac induction xs THEN_ALL_NEW asm_simp_tac simpset1) 1)); - - val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs; - val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs'; - - (* characteristic equations for size functions *) - fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) = - let - val Ts = map (typ_of_dtyp descr sorts) cargs; - val tnames = Name.variant_list f_names (DatatypeProp.make_tnames Ts); - val ts = map_filter (fn (sT as (s, T), dt) => - Option.map (fn sz => sz $ Free sT) - (if p dt then size_of_type size_of extra_size size_ofp T - else NONE)) (tnames ~~ Ts ~~ cargs) - in - HOLogic.mk_Trueprop (HOLogic.mk_eq - (size_const $ list_comb (Const (cname, Ts ---> T), - map2 (curry Free) tnames Ts), - if null ts then HOLogic.zero - else foldl1 plus (ts @ [HOLogic.Suc_zero]))) - end; - - val simpset2 = HOL_basic_ss addsimps - rec_rewrites @ size_def_thms @ unfolded_size_eqs1; - val simpset3 = HOL_basic_ss addsimps - rec_rewrites @ size_def_thms' @ unfolded_size_eqs2; - - fun prove_size_eqs p size_fns size_ofp simpset = - maps (fn (((_, (_, _, constrs)), size_const), T) => - map (fn constr => standard (SkipProof.prove ctxt [] [] - (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns)) - size_ofp size_const T constr) - (fn _ => simp_tac simpset 1))) constrs) - (descr' ~~ size_fns ~~ recTs1); - - val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @ - prove_size_eqs is_rec_type overloaded_size_fns (K NONE) simpset3; - - val ([size_thms], thy'') = PureThy.add_thmss - [((Binding.name "size", size_eqns), - [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, - Thm.declaration_attribute - (fn thm => Context.mapping (Code.add_default_eqn thm) I)])] thy' - - in - SizeData.map (fold (Symtab.update_new o apsnd (rpair size_thms)) - (new_type_names ~~ size_names)) thy'' - end; - -fun add_size_thms config (new_type_names as name :: _) thy = - let - val info as {descr, alt_names, ...} = Datatype.the_datatype thy name; - val prefix = Long_Name.map_base_name (K (space_implode "_" - (the_default (map Long_Name.base_name new_type_names) alt_names))) name; - val no_size = exists (fn (_, (_, _, constrs)) => exists (fn (_, cargs) => exists (fn dt => - is_rec_type dt andalso not (null (fst (strip_dtyp dt)))) cargs) constrs) descr - in if no_size then thy - else - thy - |> Sign.root_path - |> Sign.add_path prefix - |> Theory.checkpoint - |> prove_size_thms info new_type_names - |> Sign.restore_naming thy - end; - -val size_thms = snd oo (the oo lookup_size); - -val setup = Datatype.interpretation add_size_thms; - -end; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/sum_tree.ML --- a/src/HOL/Tools/function_package/sum_tree.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,44 +0,0 @@ -(* Title: HOL/Tools/function_package/sum_tree.ML - ID: $Id$ - Author: Alexander Krauss, TU Muenchen - -Some common tools for working with sum types in balanced tree form. -*) - -structure SumTree = -struct - -(* Theory dependencies *) -val proj_in_rules = [@{thm "Datatype.Projl_Inl"}, @{thm "Datatype.Projr_Inr"}] -val sumcase_split_ss = HOL_basic_ss addsimps (@{thm "Product_Type.split"} :: @{thms "sum.cases"}) - -(* top-down access in balanced tree *) -fun access_top_down {left, right, init} len i = - BalancedTree.access {left = (fn f => f o left), right = (fn f => f o right), init = I} len i init - -(* Sum types *) -fun mk_sumT LT RT = Type ("+", [LT, RT]) -fun mk_sumcase TL TR T l r = Const (@{const_name "sum.sum_case"}, (TL --> T) --> (TR --> T) --> mk_sumT TL TR --> T) $ l $ r - -val App = curry op $ - -fun mk_inj ST n i = - access_top_down - { init = (ST, I : term -> term), - left = (fn (T as Type ("+", [LT, RT]), inj) => (LT, inj o App (Const (@{const_name "Inl"}, LT --> T)))), - right =(fn (T as Type ("+", [LT, RT]), inj) => (RT, inj o App (Const (@{const_name "Inr"}, RT --> T))))} n i - |> snd - -fun mk_proj ST n i = - access_top_down - { init = (ST, I : term -> term), - left = (fn (T as Type ("+", [LT, RT]), proj) => (LT, App (Const (@{const_name "Datatype.Projl"}, T --> LT)) o proj)), - right =(fn (T as Type ("+", [LT, RT]), proj) => (RT, App (Const (@{const_name "Datatype.Projr"}, T --> RT)) o proj))} n i - |> snd - -fun mk_sumcases T fs = - BalancedTree.make (fn ((f, fT), (g, gT)) => (mk_sumcase fT gT T f g, mk_sumT fT gT)) - (map (fn f => (f, domain_type (fastype_of f))) fs) - |> fst - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Tools/function_package/termination.ML --- a/src/HOL/Tools/function_package/termination.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,324 +0,0 @@ -(* Title: HOL/Tools/function_package/termination.ML - Author: Alexander Krauss, TU Muenchen - -Context data for termination proofs -*) - - -signature TERMINATION = -sig - - type data - datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm - - val mk_sumcases : data -> typ -> term list -> term - - val note_measure : int -> term -> data -> data - val note_chain : term -> term -> thm option -> data -> data - val note_descent : term -> term -> term -> cell -> data -> data - - val get_num_points : data -> int - val get_types : data -> int -> typ - val get_measures : data -> int -> term list - - (* read from cache *) - val get_chain : data -> term -> term -> thm option option - val get_descent : data -> term -> term -> term -> cell option - - (* writes *) - val derive_descent : theory -> tactic -> term -> term -> term -> data -> data - val derive_descents : theory -> tactic -> term -> data -> data - - val dest_call : data -> term -> ((string * typ) list * int * term * int * term * term) - - val CALLS : (term list * int -> tactic) -> int -> tactic - - (* Termination tactics. Sequential composition via continuations. (2nd argument is the error continuation) *) - type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic - - val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic - - val REPEAT : ttac -> ttac - - val wf_union_tac : Proof.context -> tactic -end - - - -structure Termination : TERMINATION = -struct - -open FundefLib - -val term2_ord = prod_ord TermOrd.fast_term_ord TermOrd.fast_term_ord -structure Term2tab = TableFun(type key = term * term val ord = term2_ord); -structure Term3tab = TableFun(type key = term * (term * term) val ord = prod_ord TermOrd.fast_term_ord term2_ord); - -(** Analyzing binary trees **) - -(* Skeleton of a tree structure *) - -datatype skel = - SLeaf of int (* index *) -| SBranch of (skel * skel) - - -(* abstract make and dest functions *) -fun mk_tree leaf branch = - let fun mk (SLeaf i) = leaf i - | mk (SBranch (s, t)) = branch (mk s, mk t) - in mk end - - -fun dest_tree split = - let fun dest (SLeaf i) x = [(i, x)] - | dest (SBranch (s, t)) x = - let val (l, r) = split x - in dest s l @ dest t r end - in dest end - - -(* concrete versions for sum types *) -fun is_inj (Const ("Sum_Type.Inl", _) $ _) = true - | is_inj (Const ("Sum_Type.Inr", _) $ _) = true - | is_inj _ = false - -fun dest_inl (Const ("Sum_Type.Inl", _) $ t) = SOME t - | dest_inl _ = NONE - -fun dest_inr (Const ("Sum_Type.Inr", _) $ t) = SOME t - | dest_inr _ = NONE - - -fun mk_skel ps = - let - fun skel i ps = - if forall is_inj ps andalso not (null ps) - then let - val (j, s) = skel i (map_filter dest_inl ps) - val (k, t) = skel j (map_filter dest_inr ps) - in (k, SBranch (s, t)) end - else (i + 1, SLeaf i) - in - snd (skel 0 ps) - end - -(* compute list of types for nodes *) -fun node_types sk T = dest_tree (fn Type ("+", [LT, RT]) => (LT, RT)) sk T |> map snd - -(* find index and raw term *) -fun dest_inj (SLeaf i) trm = (i, trm) - | dest_inj (SBranch (s, t)) trm = - case dest_inl trm of - SOME trm' => dest_inj s trm' - | _ => dest_inj t (the (dest_inr trm)) - - - -(** Matrix cell datatype **) - -datatype cell = Less of thm | LessEq of (thm * thm) | None of (thm * thm) | False of thm; - - -type data = - skel (* structure of the sum type encoding "program points" *) - * (int -> typ) (* types of program points *) - * (term list Inttab.table) (* measures for program points *) - * (thm option Term2tab.table) (* which calls form chains? *) - * (cell Term3tab.table) (* local descents *) - - -fun map_measures f (p, T, M, C, D) = (p, T, f M, C, D) -fun map_chains f (p, T, M, C, D) = (p, T, M, f C, D) -fun map_descent f (p, T, M, C, D) = (p, T, M, C, f D) - -fun note_measure p m = map_measures (Inttab.insert_list (op aconv) (p, m)) -fun note_chain c1 c2 res = map_chains (Term2tab.update ((c1, c2), res)) -fun note_descent c m1 m2 res = map_descent (Term3tab.update ((c,(m1, m2)), res)) - -(* Build case expression *) -fun mk_sumcases (sk, _, _, _, _) T fs = - mk_tree (fn i => (nth fs i, domain_type (fastype_of (nth fs i)))) - (fn ((f, fT), (g, gT)) => (SumTree.mk_sumcase fT gT T f g, SumTree.mk_sumT fT gT)) - sk - |> fst - -fun mk_sum_skel rel = - let - val cs = FundefLib.dest_binop_list @{const_name Un} rel - fun collect_pats (Const ("Collect", _) $ Abs (_, _, c)) = - let - val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam) - = Term.strip_qnt_body "Ex" c - in cons r o cons l end - in - mk_skel (fold collect_pats cs []) - end - -fun create ctxt T rel = - let - val sk = mk_sum_skel rel - val Ts = node_types sk T - val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts) - in - (sk, nth Ts, M, Term2tab.empty, Term3tab.empty) - end - -fun get_num_points (sk, _, _, _, _) = - let - fun num (SLeaf i) = i + 1 - | num (SBranch (s, t)) = num t - in num sk end - -fun get_types (_, T, _, _, _) = T -fun get_measures (_, _, M, _, _) = Inttab.lookup_list M - -fun get_chain (_, _, _, C, _) c1 c2 = - Term2tab.lookup C (c1, c2) - -fun get_descent (_, _, _, _, D) c m1 m2 = - Term3tab.lookup D (c, (m1, m2)) - -fun dest_call D (Const ("Collect", _) $ Abs (_, _, c)) = - let - val n = get_num_points D - val (sk, _, _, _, _) = D - val vs = Term.strip_qnt_vars "Ex" c - - (* FIXME: throw error "dest_call" for malformed terms *) - val (Const ("op &", _) $ (Const ("op =", _) $ _ $ (Const ("Pair", _) $ r $ l)) $ Gam) - = Term.strip_qnt_body "Ex" c - val (p, l') = dest_inj sk l - val (q, r') = dest_inj sk r - in - (vs, p, l', q, r', Gam) - end - | dest_call D t = error "dest_call" - - -fun derive_desc_aux thy tac c (vs, p, l', q, r', Gam) m1 m2 D = - case get_descent D c m1 m2 of - SOME _ => D - | NONE => let - fun cgoal rel = - Term.list_all (vs, - Logic.mk_implies (HOLogic.mk_Trueprop Gam, - HOLogic.mk_Trueprop (Const (rel, @{typ "nat => nat => bool"}) - $ (m2 $ r') $ (m1 $ l')))) - |> cterm_of thy - in - note_descent c m1 m2 - (case try_proof (cgoal @{const_name HOL.less}) tac of - Solved thm => Less thm - | Stuck thm => - (case try_proof (cgoal @{const_name HOL.less_eq}) tac of - Solved thm2 => LessEq (thm2, thm) - | Stuck thm2 => - if prems_of thm2 = [HOLogic.Trueprop $ HOLogic.false_const] - then False thm2 else None (thm2, thm) - | _ => raise Match) (* FIXME *) - | _ => raise Match) D - end - -fun derive_descent thy tac c m1 m2 D = - derive_desc_aux thy tac c (dest_call D c) m1 m2 D - -(* all descents in one go *) -fun derive_descents thy tac c D = - let val cdesc as (vs, p, l', q, r', Gam) = dest_call D c - in fold_product (derive_desc_aux thy tac c cdesc) - (get_measures D p) (get_measures D q) D - end - -fun CALLS tac i st = - if Thm.no_prems st then all_tac st - else case Thm.term_of (Thm.cprem_of st i) of - (_ $ (_ $ rel)) => tac (FundefLib.dest_binop_list @{const_name Un} rel, i) st - |_ => no_tac st - -type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic - -fun TERMINATION ctxt tac = - SUBGOAL (fn (_ $ (Const (@{const_name "wf"}, wfT) $ rel), i) => - let - val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT)) - in - tac (create ctxt T rel) i - end) - - -(* A tactic to convert open to closed termination goals *) -local -fun dest_term (t : term) = (* FIXME, cf. Lexicographic order *) - let - val (vars, prop) = FundefLib.dest_all_all t - val (prems, concl) = Logic.strip_horn prop - val (lhs, rhs) = concl - |> HOLogic.dest_Trueprop - |> HOLogic.dest_mem |> fst - |> HOLogic.dest_prod - in - (vars, prems, lhs, rhs) - end - -fun mk_pair_compr (T, qs, l, r, conds) = - let - val pT = HOLogic.mk_prodT (T, T) - val n = length qs - val peq = HOLogic.eq_const pT $ Bound n $ (HOLogic.pair_const T T $ l $ r) - val conds' = if null conds then [HOLogic.true_const] else conds - in - HOLogic.Collect_const pT $ - Abs ("uu_", pT, - (foldr1 HOLogic.mk_conj (peq :: conds') - |> fold_rev (fn v => fn t => HOLogic.exists_const (fastype_of v) $ lambda v t) qs)) - end - -in - -fun wf_union_tac ctxt st = - let - val thy = ProofContext.theory_of ctxt - val cert = cterm_of (theory_of_thm st) - val ((trueprop $ (wf $ rel)) :: ineqs) = prems_of st - - fun mk_compr ineq = - let - val (vars, prems, lhs, rhs) = dest_term ineq - in - mk_pair_compr (fastype_of lhs, vars, lhs, rhs, map (ObjectLogic.atomize_term thy) prems) - end - - val relation = - if null ineqs then - Const (@{const_name Set.empty}, fastype_of rel) - else - foldr1 (HOLogic.mk_binop @{const_name Un}) (map mk_compr ineqs) - - fun solve_membership_tac i = - (EVERY' (replicate (i - 2) (rtac @{thm UnI2})) (* pick the right component of the union *) - THEN' (fn j => TRY (rtac @{thm UnI1} j)) - THEN' (rtac @{thm CollectI}) (* unfold comprehension *) - THEN' (fn i => REPEAT (rtac @{thm exI} i)) (* Turn existentials into schematic Vars *) - THEN' ((rtac @{thm refl}) (* unification instantiates all Vars *) - ORELSE' ((rtac @{thm conjI}) - THEN' (rtac @{thm refl}) - THEN' (blast_tac (local_claset_of ctxt)))) (* Solve rest of context... not very elegant *) - ) i - in - ((PRIMITIVE (Drule.cterm_instantiate [(cert rel, cert relation)]) - THEN ALLGOALS (fn i => if i = 1 then all_tac else solve_membership_tac i))) st - end - - -end - - -(* continuation passing repeat combinator *) -fun REPEAT ttac cont err_cont = - ttac (fn D => fn i => (REPEAT ttac cont cont D i)) err_cont - - - - -end diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/HOL/Wellfounded.thy --- a/src/HOL/Wellfounded.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/HOL/Wellfounded.thy Tue Jun 23 12:09:30 2009 +0200 @@ -8,7 +8,7 @@ theory Wellfounded imports Finite_Set Transitive_Closure -uses ("Tools/function_package/size.ML") +uses ("Tools/Function/size.ML") begin subsection {* Basic Definitions *} @@ -693,7 +693,7 @@ lemma in_inv_image[simp]: "((x,y) : inv_image r f) = ((f x, f y) : r)" by (auto simp:inv_image_def) -text {* Measure functions into @{typ nat} *} +text {* Measure Datatypes into @{typ nat} *} definition measure :: "('a => nat) => ('a * 'a)set" where "measure == inv_image less_than" @@ -733,7 +733,7 @@ "[| trans R1; trans R2 |] ==> trans (R1 <*lex*> R2)" by (unfold trans_def lex_prod_def, blast) -text {* lexicographic combinations with measure functions *} +text {* lexicographic combinations with measure Datatypes *} definition mlex_prod :: "('a \ nat) \ ('a \ 'a) set \ ('a \ 'a) set" (infixr "<*mlex*>" 80) @@ -948,7 +948,7 @@ subsection {* size of a datatype value *} -use "Tools/function_package/size.ML" +use "Tools/Function/size.ML" setup Size.setup diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_haskell.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_haskell.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,567 @@ +(* Title: Tools/code/code_haskell.ML + Author: Florian Haftmann, TU Muenchen + +Serializer for Haskell. +*) + +signature CODE_HASKELL = +sig + val setup: theory -> theory +end; + +structure Code_Haskell : CODE_HASKELL = +struct + +val target = "Haskell"; + +open Basic_Code_Thingol; +open Code_Printer; + +infixr 5 @@; +infixr 5 @|; + + +(** Haskell serializer **) + +fun pr_haskell_bind pr_term = + let + fun pr_bind ((NONE, NONE), _) = str "_" + | pr_bind ((SOME v, NONE), _) = str v + | pr_bind ((NONE, SOME p), _) = p + | pr_bind ((SOME v, SOME p), _) = brackets [str v, str "@", p]; + in gen_pr_bind pr_bind pr_term end; + +fun pr_haskell_stmt labelled_name syntax_class syntax_tyco syntax_const + init_syms deresolve is_cons contr_classparam_typs deriving_show = + let + val deresolve_base = Long_Name.base_name o deresolve; + fun class_name class = case syntax_class class + of NONE => deresolve class + | SOME class => class; + fun pr_typcontext tyvars vs = case maps (fn (v, sort) => map (pair v) sort) vs + of [] => [] + | classbinds => Pretty.enum "," "(" ")" ( + map (fn (v, class) => + str (class_name class ^ " " ^ Code_Printer.lookup_var tyvars v)) classbinds) + @@ str " => "; + fun pr_typforall tyvars vs = case map fst vs + of [] => [] + | vnames => str "forall " :: Pretty.breaks + (map (str o Code_Printer.lookup_var tyvars) vnames) @ str "." @@ Pretty.brk 1; + fun pr_tycoexpr tyvars fxy (tyco, tys) = + brackify fxy (str tyco :: map (pr_typ tyvars BR) tys) + and pr_typ tyvars fxy (tycoexpr as tyco `%% tys) = (case syntax_tyco tyco + of NONE => pr_tycoexpr tyvars fxy (deresolve tyco, tys) + | SOME (i, pr) => pr (pr_typ tyvars) fxy tys) + | pr_typ tyvars fxy (ITyVar v) = (str o Code_Printer.lookup_var tyvars) v; + fun pr_typdecl tyvars (vs, tycoexpr) = + Pretty.block (pr_typcontext tyvars vs @| pr_tycoexpr tyvars NOBR tycoexpr); + fun pr_typscheme tyvars (vs, ty) = + Pretty.block (pr_typforall tyvars vs @ pr_typcontext tyvars vs @| pr_typ tyvars NOBR ty); + fun pr_term tyvars thm vars fxy (IConst c) = + pr_app tyvars thm vars fxy (c, []) + | pr_term tyvars thm vars fxy (t as (t1 `$ t2)) = + (case Code_Thingol.unfold_const_app t + of SOME app => pr_app tyvars thm vars fxy app + | _ => + brackify fxy [ + pr_term tyvars thm vars NOBR t1, + pr_term tyvars thm vars BR t2 + ]) + | pr_term tyvars thm vars fxy (IVar v) = + (str o Code_Printer.lookup_var vars) v + | pr_term tyvars thm vars fxy (t as _ `|=> _) = + let + val (binds, t') = Code_Thingol.unfold_abs t; + fun pr ((v, pat), ty) = pr_bind tyvars thm BR ((SOME v, pat), ty); + val (ps, vars') = fold_map pr binds vars; + in brackets (str "\\" :: ps @ str "->" @@ pr_term tyvars thm vars' NOBR t') end + | pr_term tyvars thm vars fxy (ICase (cases as (_, t0))) = + (case Code_Thingol.unfold_const_app t0 + of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) + then pr_case tyvars thm vars fxy cases + else pr_app tyvars thm vars fxy c_ts + | NONE => pr_case tyvars thm vars fxy cases) + and pr_app' tyvars thm vars ((c, (_, tys)), ts) = case contr_classparam_typs c + of [] => (str o deresolve) c :: map (pr_term tyvars thm vars BR) ts + | fingerprint => let + val ts_fingerprint = ts ~~ curry Library.take (length ts) fingerprint; + val needs_annotation = forall (fn (_, NONE) => true | (t, SOME _) => + (not o Code_Thingol.locally_monomorphic) t) ts_fingerprint; + fun pr_term_anno (t, NONE) _ = pr_term tyvars thm vars BR t + | pr_term_anno (t, SOME _) ty = + brackets [pr_term tyvars thm vars NOBR t, str "::", pr_typ tyvars NOBR ty]; + in + if needs_annotation then + (str o deresolve) c :: map2 pr_term_anno ts_fingerprint (curry Library.take (length ts) tys) + else (str o deresolve) c :: map (pr_term tyvars thm vars BR) ts + end + and pr_app tyvars = gen_pr_app (pr_app' tyvars) (pr_term tyvars) syntax_const + and pr_bind tyvars = pr_haskell_bind (pr_term tyvars) + and pr_case tyvars thm vars fxy (cases as ((_, [_]), _)) = + let + val (binds, body) = Code_Thingol.unfold_let (ICase cases); + fun pr ((pat, ty), t) vars = + vars + |> pr_bind tyvars thm BR ((NONE, SOME pat), ty) + |>> (fn p => semicolon [p, str "=", pr_term tyvars thm vars NOBR t]) + val (ps, vars') = fold_map pr binds vars; + in brackify_block fxy (str "let {") + ps + (concat [str "}", str "in", pr_term tyvars thm vars' NOBR body]) + end + | pr_case tyvars thm vars fxy (((t, ty), clauses as _ :: _), _) = + let + fun pr (pat, body) = + let + val (p, vars') = pr_bind tyvars thm NOBR ((NONE, SOME pat), ty) vars; + in semicolon [p, str "->", pr_term tyvars thm vars' NOBR body] end; + in brackify_block fxy + (concat [str "case", pr_term tyvars thm vars NOBR t, str "of", str "{"]) + (map pr clauses) + (str "}") + end + | pr_case tyvars thm vars fxy ((_, []), _) = + (brackify fxy o Pretty.breaks o map str) ["error", "\"empty case\""]; + fun pr_stmt (name, Code_Thingol.Fun (_, ((vs, ty), []))) = + let + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + val n = (length o fst o Code_Thingol.unfold_fun) ty; + in + Pretty.chunks [ + Pretty.block [ + (str o suffix " ::" o deresolve_base) name, + Pretty.brk 1, + pr_typscheme tyvars (vs, ty), + str ";" + ], + concat ( + (str o deresolve_base) name + :: map str (replicate n "_") + @ str "=" + :: str "error" + @@ (str o (fn s => s ^ ";") o ML_Syntax.print_string + o Long_Name.base_name o Long_Name.qualifier) name + ) + ] + end + | pr_stmt (name, Code_Thingol.Fun (_, ((vs, ty), raw_eqs))) = + let + val eqs = filter (snd o snd) raw_eqs; + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + fun pr_eq ((ts, t), (thm, _)) = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); + val vars = init_syms + |> Code_Printer.intro_vars consts + |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) + (insert (op =)) ts []); + in + semicolon ( + (str o deresolve_base) name + :: map (pr_term tyvars thm vars BR) ts + @ str "=" + @@ pr_term tyvars thm vars NOBR t + ) + end; + in + Pretty.chunks ( + Pretty.block [ + (str o suffix " ::" o deresolve_base) name, + Pretty.brk 1, + pr_typscheme tyvars (vs, ty), + str ";" + ] + :: map pr_eq eqs + ) + end + | pr_stmt (name, Code_Thingol.Datatype (_, (vs, []))) = + let + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + in + semicolon [ + str "data", + pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) + ] + end + | pr_stmt (name, Code_Thingol.Datatype (_, (vs, [(co, [ty])]))) = + let + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + in + semicolon ( + str "newtype" + :: pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) + :: str "=" + :: (str o deresolve_base) co + :: pr_typ tyvars BR ty + :: (if deriving_show name then [str "deriving (Read, Show)"] else []) + ) + end + | pr_stmt (name, Code_Thingol.Datatype (_, (vs, co :: cos))) = + let + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + fun pr_co (co, tys) = + concat ( + (str o deresolve_base) co + :: map (pr_typ tyvars BR) tys + ) + in + semicolon ( + str "data" + :: pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) + :: str "=" + :: pr_co co + :: map ((fn p => Pretty.block [str "| ", p]) o pr_co) cos + @ (if deriving_show name then [str "deriving (Read, Show)"] else []) + ) + end + | pr_stmt (name, Code_Thingol.Class (_, (v, (superclasses, classparams)))) = + let + val tyvars = Code_Printer.intro_vars [v] init_syms; + fun pr_classparam (classparam, ty) = + semicolon [ + (str o deresolve_base) classparam, + str "::", + pr_typ tyvars NOBR ty + ] + in + Pretty.block_enclose ( + Pretty.block [ + str "class ", + Pretty.block (pr_typcontext tyvars [(v, map fst superclasses)]), + str (deresolve_base name ^ " " ^ Code_Printer.lookup_var tyvars v), + str " where {" + ], + str "};" + ) (map pr_classparam classparams) + end + | pr_stmt (_, Code_Thingol.Classinst ((class, (tyco, vs)), (_, classparam_insts))) = + let + val split_abs_pure = (fn (v, _) `|=> t => SOME (v, t) | _ => NONE); + val unfold_abs_pure = Code_Thingol.unfoldr split_abs_pure; + val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; + fun pr_instdef ((classparam, c_inst), (thm, _)) = case syntax_const classparam + of NONE => semicolon [ + (str o deresolve_base) classparam, + str "=", + pr_app tyvars thm init_syms NOBR (c_inst, []) + ] + | SOME (k, pr) => + let + val (c_inst_name, (_, tys)) = c_inst; + val const = if (is_some o syntax_const) c_inst_name + then NONE else (SOME o Long_Name.base_name o deresolve) c_inst_name; + val proto_rhs = Code_Thingol.eta_expand k (c_inst, []); + val (vs, rhs) = unfold_abs_pure proto_rhs; + val vars = init_syms + |> Code_Printer.intro_vars (the_list const) + |> Code_Printer.intro_vars vs; + val lhs = IConst (classparam, (([], []), tys)) `$$ map IVar vs; + (*dictionaries are not relevant at this late stage*) + in + semicolon [ + pr_term tyvars thm vars NOBR lhs, + str "=", + pr_term tyvars thm vars NOBR rhs + ] + end; + in + Pretty.block_enclose ( + Pretty.block [ + str "instance ", + Pretty.block (pr_typcontext tyvars vs), + str (class_name class ^ " "), + pr_typ tyvars BR (tyco `%% map (ITyVar o fst) vs), + str " where {" + ], + str "};" + ) (map pr_instdef classparam_insts) + end; + in pr_stmt end; + +fun haskell_program_of_program labelled_name module_name module_prefix reserved_names raw_module_alias program = + let + val module_alias = if is_some module_name then K module_name else raw_module_alias; + val reserved_names = Name.make_context reserved_names; + val mk_name_module = Code_Printer.mk_name_module reserved_names module_prefix module_alias program; + fun add_stmt (name, (stmt, deps)) = + let + val (module_name, base) = Code_Printer.dest_name name; + val module_name' = mk_name_module module_name; + val mk_name_stmt = yield_singleton Name.variants; + fun add_fun upper (nsp_fun, nsp_typ) = + let + val (base', nsp_fun') = + mk_name_stmt (if upper then Code_Printer.first_upper base else base) nsp_fun + in (base', (nsp_fun', nsp_typ)) end; + fun add_typ (nsp_fun, nsp_typ) = + let + val (base', nsp_typ') = mk_name_stmt (Code_Printer.first_upper base) nsp_typ + in (base', (nsp_fun, nsp_typ')) end; + val add_name = case stmt + of Code_Thingol.Fun _ => add_fun false + | Code_Thingol.Datatype _ => add_typ + | Code_Thingol.Datatypecons _ => add_fun true + | Code_Thingol.Class _ => add_typ + | Code_Thingol.Classrel _ => pair base + | Code_Thingol.Classparam _ => add_fun false + | Code_Thingol.Classinst _ => pair base; + fun add_stmt' base' = case stmt + of Code_Thingol.Datatypecons _ => + cons (name, (Long_Name.append module_name' base', NONE)) + | Code_Thingol.Classrel _ => I + | Code_Thingol.Classparam _ => + cons (name, (Long_Name.append module_name' base', NONE)) + | _ => cons (name, (Long_Name.append module_name' base', SOME stmt)); + in + Symtab.map_default (module_name', ([], ([], (reserved_names, reserved_names)))) + (apfst (fold (insert (op = : string * string -> bool)) deps)) + #> `(fn program => add_name ((snd o snd o the o Symtab.lookup program) module_name')) + #-> (fn (base', names) => + (Symtab.map_entry module_name' o apsnd) (fn (stmts, _) => + (add_stmt' base' stmts, names))) + end; + val hs_program = fold add_stmt (AList.make (fn name => + (Graph.get_node program name, Graph.imm_succs program name)) + (Graph.strong_conn program |> flat)) Symtab.empty; + fun deresolver name = (fst o the o AList.lookup (op =) ((fst o snd o the + o Symtab.lookup hs_program) ((mk_name_module o fst o Code_Printer.dest_name) name))) name + handle Option => error ("Unknown statement name: " ^ labelled_name name); + in (deresolver, hs_program) end; + +fun serialize_haskell module_prefix raw_module_name string_classes labelled_name + raw_reserved_names includes raw_module_alias + syntax_class syntax_tyco syntax_const program cs destination = + let + val stmt_names = Code_Target.stmt_names_of_destination destination; + val module_name = if null stmt_names then raw_module_name else SOME "Code"; + val reserved_names = fold (insert (op =) o fst) includes raw_reserved_names; + val (deresolver, hs_program) = haskell_program_of_program labelled_name + module_name module_prefix reserved_names raw_module_alias program; + val is_cons = Code_Thingol.is_cons program; + val contr_classparam_typs = Code_Thingol.contr_classparam_typs program; + fun deriving_show tyco = + let + fun deriv _ "fun" = false + | deriv tycos tyco = member (op =) tycos tyco orelse + case try (Graph.get_node program) tyco + of SOME (Code_Thingol.Datatype (_, (_, cs))) => forall (deriv' (tyco :: tycos)) + (maps snd cs) + | NONE => true + and deriv' tycos (tyco `%% tys) = deriv tycos tyco + andalso forall (deriv' tycos) tys + | deriv' _ (ITyVar _) = true + in deriv [] tyco end; + val reserved_names = Code_Printer.make_vars reserved_names; + fun pr_stmt qualified = pr_haskell_stmt labelled_name + syntax_class syntax_tyco syntax_const reserved_names + (if qualified then deresolver else Long_Name.base_name o deresolver) + is_cons contr_classparam_typs + (if string_classes then deriving_show else K false); + fun pr_module name content = + (name, Pretty.chunks [ + str ("module " ^ name ^ " where {"), + str "", + content, + str "", + str "}" + ]); + fun serialize_module1 (module_name', (deps, (stmts, _))) = + let + val stmt_names = map fst stmts; + val deps' = subtract (op =) stmt_names deps + |> distinct (op =) + |> map_filter (try deresolver); + val qualified = is_none module_name andalso + map deresolver stmt_names @ deps' + |> map Long_Name.base_name + |> has_duplicates (op =); + val imports = deps' + |> map Long_Name.qualifier + |> distinct (op =); + fun pr_import_include (name, _) = str ("import qualified " ^ name ^ ";"); + val pr_import_module = str o (if qualified + then prefix "import qualified " + else prefix "import ") o suffix ";"; + val content = Pretty.chunks ( + map pr_import_include includes + @ map pr_import_module imports + @ str "" + :: separate (str "") (map_filter + (fn (name, (_, SOME stmt)) => SOME (pr_stmt qualified (name, stmt)) + | (_, (_, NONE)) => NONE) stmts) + ) + in pr_module module_name' content end; + fun serialize_module2 (_, (_, (stmts, _))) = Pretty.chunks ( + separate (str "") (map_filter + (fn (name, (_, SOME stmt)) => if null stmt_names + orelse member (op =) stmt_names name + then SOME (pr_stmt false (name, stmt)) + else NONE + | (_, (_, NONE)) => NONE) stmts)); + val serialize_module = + if null stmt_names then serialize_module1 else pair "" o serialize_module2; + fun check_destination destination = + (File.check destination; destination); + fun write_module destination (modlname, content) = + let + val filename = case modlname + of "" => Path.explode "Main.hs" + | _ => (Path.ext "hs" o Path.explode o implode o separate "/" + o Long_Name.explode) modlname; + val pathname = Path.append destination filename; + val _ = File.mkdir (Path.dir pathname); + in File.write pathname + ("{-# OPTIONS_GHC -fglasgow-exts #-}\n\n" + ^ Code_Target.code_of_pretty content) + end + in + Code_Target.mk_serialization target NONE + (fn NONE => K () o map (Code_Target.code_writeln o snd) | SOME file => K () o map + (write_module (check_destination file))) + (rpair [] o cat_lines o map (Code_Target.code_of_pretty o snd)) + (map (uncurry pr_module) includes + @ map serialize_module (Symtab.dest hs_program)) + destination + end; + +val literals = let + fun char_haskell c = + let + val s = ML_Syntax.print_char c; + in if s = "'" then "\\'" else s end; +in Literals { + literal_char = enclose "'" "'" o char_haskell, + literal_string = quote o translate_string char_haskell, + literal_numeral = fn unbounded => fn k => if k >= 0 then string_of_int k + else enclose "(" ")" (signed_string_of_int k), + literal_list = Pretty.enum "," "[" "]", + infix_cons = (5, ":") +} end; + + +(** optional monad syntax **) + +fun pretty_haskell_monad c_bind = + let + fun dest_bind t1 t2 = case Code_Thingol.split_abs t2 + of SOME (((v, pat), ty), t') => + SOME ((SOME (((SOME v, pat), ty), true), t1), t') + | NONE => NONE; + fun dest_monad c_bind_name (IConst (c, _) `$ t1 `$ t2) = + if c = c_bind_name then dest_bind t1 t2 + else NONE + | dest_monad _ t = case Code_Thingol.split_let t + of SOME (((pat, ty), tbind), t') => + SOME ((SOME (((NONE, SOME pat), ty), false), tbind), t') + | NONE => NONE; + fun implode_monad c_bind_name = Code_Thingol.unfoldr (dest_monad c_bind_name); + fun pr_monad pr_bind pr (NONE, t) vars = + (semicolon [pr vars NOBR t], vars) + | pr_monad pr_bind pr (SOME (bind, true), t) vars = vars + |> pr_bind NOBR bind + |>> (fn p => semicolon [p, str "<-", pr vars NOBR t]) + | pr_monad pr_bind pr (SOME (bind, false), t) vars = vars + |> pr_bind NOBR bind + |>> (fn p => semicolon [str "let", p, str "=", pr vars NOBR t]); + fun pretty _ [c_bind'] pr thm vars fxy [(t1, _), (t2, _)] = case dest_bind t1 t2 + of SOME (bind, t') => let + val (binds, t'') = implode_monad c_bind' t' + val (ps, vars') = fold_map (pr_monad (pr_haskell_bind (K pr) thm) pr) (bind :: binds) vars; + in (brackify fxy o single o Pretty.enclose "do {" "}" o Pretty.breaks) (ps @| pr vars' NOBR t'') end + | NONE => brackify_infix (1, L) fxy + [pr vars (INFX (1, L)) t1, str ">>=", pr vars (INFX (1, X)) t2] + in (2, ([c_bind], pretty)) end; + +fun add_monad target' raw_c_bind thy = + let + val c_bind = Code.read_const thy raw_c_bind; + in if target = target' then + thy + |> Code_Target.add_syntax_const target c_bind + (SOME (pretty_haskell_monad c_bind)) + else error "Only Haskell target allows for monad syntax" end; + + +(** Isar setup **) + +fun isar_seri_haskell module = + Code_Target.parse_args (Scan.option (Args.$$$ "root" -- Args.colon |-- Args.name) + -- Scan.optional (Args.$$$ "string_classes" >> K true) false + >> (fn (module_prefix, string_classes) => + serialize_haskell module_prefix module string_classes)); + +val _ = + OuterSyntax.command "code_monad" "define code syntax for monads" OuterKeyword.thy_decl ( + OuterParse.term_group -- OuterParse.name >> (fn (raw_bind, target) => + Toplevel.theory (add_monad target raw_bind)) + ); + +val setup = + Code_Target.add_target (target, (isar_seri_haskell, literals)) + #> Code_Target.add_syntax_tyco target "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => + brackify_infix (1, R) fxy [ + pr_typ (INFX (1, X)) ty1, + str "->", + pr_typ (INFX (1, R)) ty2 + ])) + #> fold (Code_Target.add_reserved target) [ + "hiding", "deriving", "where", "case", "of", "infix", "infixl", "infixr", + "import", "default", "forall", "let", "in", "class", "qualified", "data", + "newtype", "instance", "if", "then", "else", "type", "as", "do", "module" + ] + #> fold (Code_Target.add_reserved target) [ + "Prelude", "Main", "Bool", "Maybe", "Either", "Ordering", "Char", "String", "Int", + "Integer", "Float", "Double", "Rational", "IO", "Eq", "Ord", "Enum", "Bounded", + "Num", "Real", "Integral", "Fractional", "Floating", "RealFloat", "Monad", "Functor", + "AlreadyExists", "ArithException", "ArrayException", "AssertionFailed", "AsyncException", + "BlockedOnDeadMVar", "Deadlock", "Denormal", "DivideByZero", "DotNetException", "DynException", + "Dynamic", "EOF", "EQ", "EmptyRec", "ErrorCall", "ExitException", "ExitFailure", + "ExitSuccess", "False", "GT", "HeapOverflow", + "IOError", "IOException", "IllegalOperation", + "IndexOutOfBounds", "Just", "Key", "LT", "Left", "LossOfPrecision", "NoMethodError", + "NoSuchThing", "NonTermination", "Nothing", "Obj", "OtherError", "Overflow", + "PatternMatchFail", "PermissionDenied", "ProtocolError", "RecConError", "RecSelError", + "RecUpdError", "ResourceBusy", "ResourceExhausted", "Right", "StackOverflow", + "ThreadKilled", "True", "TyCon", "TypeRep", "UndefinedElement", "Underflow", + "UnsupportedOperation", "UserError", "abs", "absReal", "acos", "acosh", "all", + "and", "any", "appendFile", "asTypeOf", "asciiTab", "asin", "asinh", "atan", + "atan2", "atanh", "basicIORun", "blockIO", "boundedEnumFrom", "boundedEnumFromThen", + "boundedEnumFromThenTo", "boundedEnumFromTo", "boundedPred", "boundedSucc", "break", + "catch", "catchException", "ceiling", "compare", "concat", "concatMap", "const", + "cos", "cosh", "curry", "cycle", "decodeFloat", "denominator", "div", "divMod", + "doubleToRatio", "doubleToRational", "drop", "dropWhile", "either", "elem", + "emptyRec", "encodeFloat", "enumFrom", "enumFromThen", "enumFromThenTo", + "enumFromTo", "error", "even", "exp", "exponent", "fail", "filter", "flip", + "floatDigits", "floatProperFraction", "floatRadix", "floatRange", "floatToRational", + "floor", "fmap", "foldl", "foldl'", "foldl1", "foldr", "foldr1", "fromDouble", + "fromEnum", "fromEnum_0", "fromInt", "fromInteger", "fromIntegral", "fromObj", + "fromRational", "fst", "gcd", "getChar", "getContents", "getLine", "head", + "id", "inRange", "index", "init", "intToRatio", "interact", "ioError", "isAlpha", + "isAlphaNum", "isDenormalized", "isDigit", "isHexDigit", "isIEEE", "isInfinite", + "isLower", "isNaN", "isNegativeZero", "isOctDigit", "isSpace", "isUpper", "iterate", "iterate'", + "last", "lcm", "length", "lex", "lexDigits", "lexLitChar", "lexmatch", "lines", "log", + "logBase", "lookup", "loop", "map", "mapM", "mapM_", "max", "maxBound", "maximum", + "maybe", "min", "minBound", "minimum", "mod", "negate", "nonnull", "not", "notElem", + "null", "numerator", "numericEnumFrom", "numericEnumFromThen", "numericEnumFromThenTo", + "numericEnumFromTo", "odd", "or", "otherwise", "pi", "pred", + "print", "product", "properFraction", "protectEsc", "putChar", "putStr", "putStrLn", + "quot", "quotRem", "range", "rangeSize", "rationalToDouble", "rationalToFloat", + "rationalToRealFloat", "read", "readDec", "readField", "readFieldName", "readFile", + "readFloat", "readHex", "readIO", "readInt", "readList", "readLitChar", "readLn", + "readOct", "readParen", "readSigned", "reads", "readsPrec", "realFloatToRational", + "realToFrac", "recip", "reduce", "rem", "repeat", "replicate", "return", "reverse", + "round", "scaleFloat", "scanl", "scanl1", "scanr", "scanr1", "seq", "sequence", + "sequence_", "show", "showChar", "showException", "showField", "showList", + "showLitChar", "showParen", "showString", "shows", "showsPrec", "significand", + "signum", "signumReal", "sin", "sinh", "snd", "span", "splitAt", "sqrt", "subtract", + "succ", "sum", "tail", "take", "takeWhile", "takeWhile1", "tan", "tanh", "threadToIOResult", + "throw", "toEnum", "toInt", "toInteger", "toObj", "toRational", "truncate", "uncurry", + "undefined", "unlines", "unsafeCoerce", "unsafeIndex", "unsafeRangeSize", "until", "unwords", + "unzip", "unzip3", "userError", "words", "writeFile", "zip", "zip3", "zipWith", "zipWith3" + ] (*due to weird handling of ':', we can't do anything else than to import *all* prelude symbols*); + +end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_ml.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_ml.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,1122 @@ +(* Title: Tools/code/code_ml.ML + Author: Florian Haftmann, TU Muenchen + +Serializer for SML and OCaml. +*) + +signature CODE_ML = +sig + val eval: string option -> string * (unit -> 'a) option ref + -> ((term -> term) -> 'a -> 'a) -> theory -> term -> string list -> 'a + val target_Eval: string + val setup: theory -> theory +end; + +structure Code_ML : CODE_ML = +struct + +open Basic_Code_Thingol; +open Code_Printer; + +infixr 5 @@; +infixr 5 @|; + +val target_SML = "SML"; +val target_OCaml = "OCaml"; +val target_Eval = "Eval"; + +datatype ml_stmt = + MLExc of string * int + | MLVal of string * ((typscheme * iterm) * (thm * bool)) + | MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list * string list + | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list + | MLClass of string * (vname * ((class * string) list * (string * itype) list)) + | MLClassinst of string * ((class * (string * (vname * sort) list)) + * ((class * (string * (string * dict list list))) list + * ((string * const) * (thm * bool)) list)); + +fun stmt_names_of (MLExc (name, _)) = [name] + | stmt_names_of (MLVal (name, _)) = [name] + | stmt_names_of (MLFuns (fs, _)) = map fst fs + | stmt_names_of (MLDatas ds) = map fst ds + | stmt_names_of (MLClass (name, _)) = [name] + | stmt_names_of (MLClassinst (name, _)) = [name]; + + +(** SML serailizer **) + +fun pr_sml_stmt labelled_name syntax_tyco syntax_const reserved_names deresolve is_cons = + let + fun pr_dicts fxy ds = + let + fun pr_dictvar (v, (_, 1)) = Code_Printer.first_upper v ^ "_" + | pr_dictvar (v, (i, _)) = Code_Printer.first_upper v ^ string_of_int (i+1) ^ "_"; + fun pr_proj [] p = + p + | pr_proj [p'] p = + brackets [p', p] + | pr_proj (ps as _ :: _) p = + brackets [Pretty.enum " o" "(" ")" ps, p]; + fun pr_dict fxy (DictConst (inst, dss)) = + brackify fxy ((str o deresolve) inst :: map (pr_dicts BR) dss) + | pr_dict fxy (DictVar (classrels, v)) = + pr_proj (map (str o deresolve) classrels) ((str o pr_dictvar) v) + in case ds + of [] => str "()" + | [d] => pr_dict fxy d + | _ :: _ => (Pretty.list "(" ")" o map (pr_dict NOBR)) ds + end; + fun pr_tyvar_dicts vs = + vs + |> map (fn (v, sort) => map_index (fn (i, _) => + DictVar ([], (v, (i, length sort)))) sort) + |> map (pr_dicts BR); + fun pr_tycoexpr fxy (tyco, tys) = + let + val tyco' = (str o deresolve) tyco + in case map (pr_typ BR) tys + of [] => tyco' + | [p] => Pretty.block [p, Pretty.brk 1, tyco'] + | (ps as _::_) => Pretty.block [Pretty.list "(" ")" ps, Pretty.brk 1, tyco'] + end + and pr_typ fxy (tyco `%% tys) = (case syntax_tyco tyco + of NONE => pr_tycoexpr fxy (tyco, tys) + | SOME (i, pr) => pr pr_typ fxy tys) + | pr_typ fxy (ITyVar v) = str ("'" ^ v); + fun pr_term is_closure thm vars fxy (IConst c) = + pr_app is_closure thm vars fxy (c, []) + | pr_term is_closure thm vars fxy (IVar v) = + str (Code_Printer.lookup_var vars v) + | pr_term is_closure thm vars fxy (t as t1 `$ t2) = + (case Code_Thingol.unfold_const_app t + of SOME c_ts => pr_app is_closure thm vars fxy c_ts + | NONE => brackify fxy + [pr_term is_closure thm vars NOBR t1, pr_term is_closure thm vars BR t2]) + | pr_term is_closure thm vars fxy (t as _ `|=> _) = + let + val (binds, t') = Code_Thingol.unfold_abs t; + fun pr ((v, pat), ty) = + pr_bind is_closure thm NOBR ((SOME v, pat), ty) + #>> (fn p => concat [str "fn", p, str "=>"]); + val (ps, vars') = fold_map pr binds vars; + in brackets (ps @ [pr_term is_closure thm vars' NOBR t']) end + | pr_term is_closure thm vars fxy (ICase (cases as (_, t0))) = + (case Code_Thingol.unfold_const_app t0 + of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) + then pr_case is_closure thm vars fxy cases + else pr_app is_closure thm vars fxy c_ts + | NONE => pr_case is_closure thm vars fxy cases) + and pr_app' is_closure thm vars (app as ((c, ((_, iss), tys)), ts)) = + if is_cons c then + let + val k = length tys + in if k < 2 then + (str o deresolve) c :: map (pr_term is_closure thm vars BR) ts + else if k = length ts then + [(str o deresolve) c, Pretty.enum "," "(" ")" (map (pr_term is_closure thm vars NOBR) ts)] + else [pr_term is_closure thm vars BR (Code_Thingol.eta_expand k app)] end + else if is_closure c + then (str o deresolve) c @@ str "()" + else + (str o deresolve) c + :: (map (pr_dicts BR) o filter_out null) iss @ map (pr_term is_closure thm vars BR) ts + and pr_app is_closure thm vars = gen_pr_app (pr_app' is_closure) (pr_term is_closure) + syntax_const thm vars + and pr_bind' ((NONE, NONE), _) = str "_" + | pr_bind' ((SOME v, NONE), _) = str v + | pr_bind' ((NONE, SOME p), _) = p + | pr_bind' ((SOME v, SOME p), _) = concat [str v, str "as", p] + and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure) + and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) = + let + val (binds, body) = Code_Thingol.unfold_let (ICase cases); + fun pr ((pat, ty), t) vars = + vars + |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) + |>> (fn p => semicolon [str "val", p, str "=", pr_term is_closure thm vars NOBR t]) + val (ps, vars') = fold_map pr binds vars; + in + Pretty.chunks [ + [str ("let"), Pretty.fbrk, Pretty.chunks ps] |> Pretty.block, + [str ("in"), Pretty.fbrk, pr_term is_closure thm vars' NOBR body] |> Pretty.block, + str ("end") + ] + end + | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) = + let + fun pr delim (pat, body) = + let + val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars; + in + concat [str delim, p, str "=>", pr_term is_closure thm vars' NOBR body] + end; + in + brackets ( + str "case" + :: pr_term is_closure thm vars NOBR t + :: pr "of" clause + :: map (pr "|") clauses + ) + end + | pr_case is_closure thm vars fxy ((_, []), _) = + (concat o map str) ["raise", "Fail", "\"empty case\""]; + fun pr_stmt (MLExc (name, n)) = + let + val exc_str = + (ML_Syntax.print_string o Long_Name.base_name o Long_Name.qualifier) name; + in + (concat o map str) ( + (if n = 0 then "val" else "fun") + :: deresolve name + :: replicate n "_" + @ "=" + :: "raise" + :: "Fail" + @@ exc_str + ) + end + | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + (Code_Thingol.fold_constnames (insert (op =)) t []); + val vars = reserved_names + |> Code_Printer.intro_vars consts; + in + concat [ + str "val", + (str o deresolve) name, + str ":", + pr_typ NOBR ty, + str "=", + pr_term (K false) thm vars NOBR t + ] + end + | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) = + let + fun pr_funn definer (name, ((vs, ty), eqs as eq :: eqs')) = + let + val vs_dict = filter_out (null o snd) vs; + val shift = if null eqs' then I else + map (Pretty.block o single o Pretty.block o single); + fun pr_eq definer ((ts, t), (thm, _)) = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); + val vars = reserved_names + |> Code_Printer.intro_vars consts + |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) + (insert (op =)) ts []); + in + concat ( + str definer + :: (str o deresolve) name + :: (if member (op =) pseudo_funs name then [str "()"] + else pr_tyvar_dicts vs_dict + @ map (pr_term (member (op =) pseudo_funs) thm vars BR) ts) + @ str "=" + @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t + ) + end + in + (Pretty.block o Pretty.fbreaks o shift) ( + pr_eq definer eq + :: map (pr_eq "|") eqs' + ) + end; + fun pr_pseudo_fun name = concat [ + str "val", + (str o deresolve) name, + str "=", + (str o deresolve) name, + str "();" + ]; + val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns); + val pseudo_ps = map pr_pseudo_fun pseudo_funs; + in Pretty.chunks (ps @ Pretty.block ([p, str ";"]) :: pseudo_ps) end + | pr_stmt (MLDatas (datas as (data :: datas'))) = + let + fun pr_co (co, []) = + str (deresolve co) + | pr_co (co, tys) = + concat [ + str (deresolve co), + str "of", + Pretty.enum " *" "" "" (map (pr_typ (INFX (2, X))) tys) + ]; + fun pr_data definer (tyco, (vs, [])) = + concat ( + str definer + :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) + :: str "=" + @@ str "EMPTY__" + ) + | pr_data definer (tyco, (vs, cos)) = + concat ( + str definer + :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) + :: str "=" + :: separate (str "|") (map pr_co cos) + ); + val (ps, p) = split_last + (pr_data "datatype" data :: map (pr_data "and") datas'); + in Pretty.chunks (ps @| Pretty.block ([p, str ";"])) end + | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) = + let + val w = Code_Printer.first_upper v ^ "_"; + fun pr_superclass_field (class, classrel) = + (concat o map str) [ + deresolve classrel, ":", "'" ^ v, deresolve class + ]; + fun pr_classparam_field (classparam, ty) = + concat [ + (str o deresolve) classparam, str ":", pr_typ NOBR ty + ]; + fun pr_classparam_proj (classparam, _) = + semicolon [ + str "fun", + (str o deresolve) classparam, + Pretty.enclose "(" ")" [str (w ^ ":'" ^ v ^ " " ^ deresolve class)], + str "=", + str ("#" ^ deresolve classparam), + str w + ]; + fun pr_superclass_proj (_, classrel) = + semicolon [ + str "fun", + (str o deresolve) classrel, + Pretty.enclose "(" ")" [str (w ^ ":'" ^ v ^ " " ^ deresolve class)], + str "=", + str ("#" ^ deresolve classrel), + str w + ]; + in + Pretty.chunks ( + concat [ + str ("type '" ^ v), + (str o deresolve) class, + str "=", + Pretty.enum "," "{" "};" ( + map pr_superclass_field superclasses @ map pr_classparam_field classparams + ) + ] + :: map pr_superclass_proj superclasses + @ map pr_classparam_proj classparams + ) + end + | pr_stmt (MLClassinst (inst, ((class, (tyco, arity)), (superarities, classparam_insts)))) = + let + fun pr_superclass (_, (classrel, dss)) = + concat [ + (str o Long_Name.base_name o deresolve) classrel, + str "=", + pr_dicts NOBR [DictConst dss] + ]; + fun pr_classparam ((classparam, c_inst), (thm, _)) = + concat [ + (str o Long_Name.base_name o deresolve) classparam, + str "=", + pr_app (K false) thm reserved_names NOBR (c_inst, []) + ]; + in + semicolon ([ + str (if null arity then "val" else "fun"), + (str o deresolve) inst ] @ + pr_tyvar_dicts arity @ [ + str "=", + Pretty.enum "," "{" "}" + (map pr_superclass superarities @ map pr_classparam classparam_insts), + str ":", + pr_tycoexpr NOBR (class, [tyco `%% map (ITyVar o fst) arity]) + ]) + end; + in pr_stmt end; + +fun pr_sml_module name content = + Pretty.chunks ( + str ("structure " ^ name ^ " = ") + :: str "struct" + :: str "" + :: content + @ str "" + @@ str ("end; (*struct " ^ name ^ "*)") + ); + +val literals_sml = Literals { + literal_char = prefix "#" o quote o ML_Syntax.print_char, + literal_string = quote o translate_string ML_Syntax.print_char, + literal_numeral = fn unbounded => fn k => + if unbounded then "(" ^ string_of_int k ^ " : IntInf.int)" + else string_of_int k, + literal_list = Pretty.enum "," "[" "]", + infix_cons = (7, "::") +}; + + +(** OCaml serializer **) + +fun pr_ocaml_stmt labelled_name syntax_tyco syntax_const reserved_names deresolve is_cons = + let + fun pr_dicts fxy ds = + let + fun pr_dictvar (v, (_, 1)) = "_" ^ Code_Printer.first_upper v + | pr_dictvar (v, (i, _)) = "_" ^ Code_Printer.first_upper v ^ string_of_int (i+1); + fun pr_proj ps p = + fold_rev (fn p2 => fn p1 => Pretty.block [p1, str ".", str p2]) ps p + fun pr_dict fxy (DictConst (inst, dss)) = + brackify fxy ((str o deresolve) inst :: map (pr_dicts BR) dss) + | pr_dict fxy (DictVar (classrels, v)) = + pr_proj (map deresolve classrels) ((str o pr_dictvar) v) + in case ds + of [] => str "()" + | [d] => pr_dict fxy d + | _ :: _ => (Pretty.list "(" ")" o map (pr_dict NOBR)) ds + end; + fun pr_tyvar_dicts vs = + vs + |> map (fn (v, sort) => map_index (fn (i, _) => + DictVar ([], (v, (i, length sort)))) sort) + |> map (pr_dicts BR); + fun pr_tycoexpr fxy (tyco, tys) = + let + val tyco' = (str o deresolve) tyco + in case map (pr_typ BR) tys + of [] => tyco' + | [p] => Pretty.block [p, Pretty.brk 1, tyco'] + | (ps as _::_) => Pretty.block [Pretty.list "(" ")" ps, Pretty.brk 1, tyco'] + end + and pr_typ fxy (tyco `%% tys) = (case syntax_tyco tyco + of NONE => pr_tycoexpr fxy (tyco, tys) + | SOME (i, pr) => pr pr_typ fxy tys) + | pr_typ fxy (ITyVar v) = str ("'" ^ v); + fun pr_term is_closure thm vars fxy (IConst c) = + pr_app is_closure thm vars fxy (c, []) + | pr_term is_closure thm vars fxy (IVar v) = + str (Code_Printer.lookup_var vars v) + | pr_term is_closure thm vars fxy (t as t1 `$ t2) = + (case Code_Thingol.unfold_const_app t + of SOME c_ts => pr_app is_closure thm vars fxy c_ts + | NONE => + brackify fxy [pr_term is_closure thm vars NOBR t1, pr_term is_closure thm vars BR t2]) + | pr_term is_closure thm vars fxy (t as _ `|=> _) = + let + val (binds, t') = Code_Thingol.unfold_abs t; + fun pr ((v, pat), ty) = pr_bind is_closure thm BR ((SOME v, pat), ty); + val (ps, vars') = fold_map pr binds vars; + in brackets (str "fun" :: ps @ str "->" @@ pr_term is_closure thm vars' NOBR t') end + | pr_term is_closure thm vars fxy (ICase (cases as (_, t0))) = (case Code_Thingol.unfold_const_app t0 + of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) + then pr_case is_closure thm vars fxy cases + else pr_app is_closure thm vars fxy c_ts + | NONE => pr_case is_closure thm vars fxy cases) + and pr_app' is_closure thm vars (app as ((c, ((_, iss), tys)), ts)) = + if is_cons c then + if length tys = length ts + then case ts + of [] => [(str o deresolve) c] + | [t] => [(str o deresolve) c, pr_term is_closure thm vars BR t] + | _ => [(str o deresolve) c, Pretty.enum "," "(" ")" + (map (pr_term is_closure thm vars NOBR) ts)] + else [pr_term is_closure thm vars BR (Code_Thingol.eta_expand (length tys) app)] + else if is_closure c + then (str o deresolve) c @@ str "()" + else (str o deresolve) c + :: ((map (pr_dicts BR) o filter_out null) iss @ map (pr_term is_closure thm vars BR) ts) + and pr_app is_closure = gen_pr_app (pr_app' is_closure) (pr_term is_closure) + syntax_const + and pr_bind' ((NONE, NONE), _) = str "_" + | pr_bind' ((SOME v, NONE), _) = str v + | pr_bind' ((NONE, SOME p), _) = p + | pr_bind' ((SOME v, SOME p), _) = brackets [p, str "as", str v] + and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure) + and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) = + let + val (binds, body) = Code_Thingol.unfold_let (ICase cases); + fun pr ((pat, ty), t) vars = + vars + |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) + |>> (fn p => concat + [str "let", p, str "=", pr_term is_closure thm vars NOBR t, str "in"]) + val (ps, vars') = fold_map pr binds vars; + in + brackify_block fxy (Pretty.chunks ps) [] + (pr_term is_closure thm vars' NOBR body) + end + | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) = + let + fun pr delim (pat, body) = + let + val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars; + in concat [str delim, p, str "->", pr_term is_closure thm vars' NOBR body] end; + in + brackets ( + str "match" + :: pr_term is_closure thm vars NOBR t + :: pr "with" clause + :: map (pr "|") clauses + ) + end + | pr_case is_closure thm vars fxy ((_, []), _) = + (concat o map str) ["failwith", "\"empty case\""]; + fun fish_params vars eqs = + let + fun fish_param _ (w as SOME _) = w + | fish_param (IVar v) NONE = SOME v + | fish_param _ NONE = NONE; + fun fillup_param _ (_, SOME v) = v + | fillup_param x (i, NONE) = x ^ string_of_int i; + val fished1 = fold (map2 fish_param) eqs (replicate (length (hd eqs)) NONE); + val x = Name.variant (map_filter I fished1) "x"; + val fished2 = map_index (fillup_param x) fished1; + val (fished3, _) = Name.variants fished2 Name.context; + val vars' = Code_Printer.intro_vars fished3 vars; + in map (Code_Printer.lookup_var vars') fished3 end; + fun pr_stmt (MLExc (name, n)) = + let + val exc_str = + (ML_Syntax.print_string o Long_Name.base_name o Long_Name.qualifier) name; + in + (concat o map str) ( + "let" + :: deresolve name + :: replicate n "_" + @ "=" + :: "failwith" + @@ exc_str + ) + end + | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + (Code_Thingol.fold_constnames (insert (op =)) t []); + val vars = reserved_names + |> Code_Printer.intro_vars consts; + in + concat [ + str "let", + (str o deresolve) name, + str ":", + pr_typ NOBR ty, + str "=", + pr_term (K false) thm vars NOBR t + ] + end + | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) = + let + fun pr_eq ((ts, t), (thm, _)) = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); + val vars = reserved_names + |> Code_Printer.intro_vars consts + |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) + (insert (op =)) ts []); + in concat [ + (Pretty.block o Pretty.commas) + (map (pr_term (member (op =) pseudo_funs) thm vars NOBR) ts), + str "->", + pr_term (member (op =) pseudo_funs) thm vars NOBR t + ] end; + fun pr_eqs is_pseudo [((ts, t), (thm, _))] = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); + val vars = reserved_names + |> Code_Printer.intro_vars consts + |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) + (insert (op =)) ts []); + in + concat ( + (if is_pseudo then [str "()"] + else map (pr_term (member (op =) pseudo_funs) thm vars BR) ts) + @ str "=" + @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t + ) + end + | pr_eqs _ (eqs as (eq as (([_], _), _)) :: eqs') = + Pretty.block ( + str "=" + :: Pretty.brk 1 + :: str "function" + :: Pretty.brk 1 + :: pr_eq eq + :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1] + o single o pr_eq) eqs' + ) + | pr_eqs _ (eqs as eq :: eqs') = + let + val consts = map_filter + (fn c => if (is_some o syntax_const) c + then NONE else (SOME o Long_Name.base_name o deresolve) c) + ((fold o Code_Thingol.fold_constnames) + (insert (op =)) (map (snd o fst) eqs) []); + val vars = reserved_names + |> Code_Printer.intro_vars consts; + val dummy_parms = (map str o fish_params vars o map (fst o fst)) eqs; + in + Pretty.block ( + Pretty.breaks dummy_parms + @ Pretty.brk 1 + :: str "=" + :: Pretty.brk 1 + :: str "match" + :: Pretty.brk 1 + :: (Pretty.block o Pretty.commas) dummy_parms + :: Pretty.brk 1 + :: str "with" + :: Pretty.brk 1 + :: pr_eq eq + :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1] + o single o pr_eq) eqs' + ) + end; + fun pr_funn definer (name, ((vs, ty), eqs)) = + concat ( + str definer + :: (str o deresolve) name + :: pr_tyvar_dicts (filter_out (null o snd) vs) + @| pr_eqs (member (op =) pseudo_funs name) eqs + ); + fun pr_pseudo_fun name = concat [ + str "let", + (str o deresolve) name, + str "=", + (str o deresolve) name, + str "();;" + ]; + val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns); + val (ps, p) = split_last + (pr_funn "let rec" funn :: map (pr_funn "and") funns); + val pseudo_ps = map pr_pseudo_fun pseudo_funs; + in Pretty.chunks (ps @ Pretty.block ([p, str ";;"]) :: pseudo_ps) end + | pr_stmt (MLDatas (datas as (data :: datas'))) = + let + fun pr_co (co, []) = + str (deresolve co) + | pr_co (co, tys) = + concat [ + str (deresolve co), + str "of", + Pretty.enum " *" "" "" (map (pr_typ (INFX (2, X))) tys) + ]; + fun pr_data definer (tyco, (vs, [])) = + concat ( + str definer + :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) + :: str "=" + @@ str "EMPTY_" + ) + | pr_data definer (tyco, (vs, cos)) = + concat ( + str definer + :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) + :: str "=" + :: separate (str "|") (map pr_co cos) + ); + val (ps, p) = split_last + (pr_data "type" data :: map (pr_data "and") datas'); + in Pretty.chunks (ps @| Pretty.block ([p, str ";;"])) end + | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) = + let + val w = "_" ^ Code_Printer.first_upper v; + fun pr_superclass_field (class, classrel) = + (concat o map str) [ + deresolve classrel, ":", "'" ^ v, deresolve class + ]; + fun pr_classparam_field (classparam, ty) = + concat [ + (str o deresolve) classparam, str ":", pr_typ NOBR ty + ]; + fun pr_classparam_proj (classparam, _) = + concat [ + str "let", + (str o deresolve) classparam, + str w, + str "=", + str (w ^ "." ^ deresolve classparam ^ ";;") + ]; + in Pretty.chunks ( + concat [ + str ("type '" ^ v), + (str o deresolve) class, + str "=", + enum_default "unit;;" ";" "{" "};;" ( + map pr_superclass_field superclasses + @ map pr_classparam_field classparams + ) + ] + :: map pr_classparam_proj classparams + ) end + | pr_stmt (MLClassinst (inst, ((class, (tyco, arity)), (superarities, classparam_insts)))) = + let + fun pr_superclass (_, (classrel, dss)) = + concat [ + (str o deresolve) classrel, + str "=", + pr_dicts NOBR [DictConst dss] + ]; + fun pr_classparam_inst ((classparam, c_inst), (thm, _)) = + concat [ + (str o deresolve) classparam, + str "=", + pr_app (K false) thm reserved_names NOBR (c_inst, []) + ]; + in + concat ( + str "let" + :: (str o deresolve) inst + :: pr_tyvar_dicts arity + @ str "=" + @@ (Pretty.enclose "(" ");;" o Pretty.breaks) [ + enum_default "()" ";" "{" "}" (map pr_superclass superarities + @ map pr_classparam_inst classparam_insts), + str ":", + pr_tycoexpr NOBR (class, [tyco `%% map (ITyVar o fst) arity]) + ] + ) + end; + in pr_stmt end; + +fun pr_ocaml_module name content = + Pretty.chunks ( + str ("module " ^ name ^ " = ") + :: str "struct" + :: str "" + :: content + @ str "" + @@ str ("end;; (*struct " ^ name ^ "*)") + ); + +val literals_ocaml = let + fun chr i = + let + val xs = string_of_int i; + val ys = replicate_string (3 - length (explode xs)) "0"; + in "\\" ^ ys ^ xs end; + fun char_ocaml c = + let + val i = ord c; + val s = if i < 32 orelse i = 34 orelse i = 39 orelse i = 92 orelse i > 126 + then chr i else c + in s end; + fun bignum_ocaml k = if k <= 1073741823 + then "(Big_int.big_int_of_int " ^ string_of_int k ^ ")" + else "(Big_int.big_int_of_string " ^ quote (string_of_int k) ^ ")" +in Literals { + literal_char = enclose "'" "'" o char_ocaml, + literal_string = quote o translate_string char_ocaml, + literal_numeral = fn unbounded => fn k => if k >= 0 then + if unbounded then bignum_ocaml k + else string_of_int k + else + if unbounded then "(Big_int.minus_big_int " ^ bignum_ocaml (~ k) ^ ")" + else (enclose "(" ")" o prefix "-" o string_of_int o op ~) k, + literal_list = Pretty.enum ";" "[" "]", + infix_cons = (6, "::") +} end; + + + +(** SML/OCaml generic part **) + +local + +datatype ml_node = + Dummy of string + | Stmt of string * ml_stmt + | Module of string * ((Name.context * Name.context) * ml_node Graph.T); + +in + +fun ml_node_of_program labelled_name module_name reserved_names raw_module_alias program = + let + val module_alias = if is_some module_name then K module_name else raw_module_alias; + val reserved_names = Name.make_context reserved_names; + val empty_module = ((reserved_names, reserved_names), Graph.empty); + fun map_node [] f = f + | map_node (m::ms) f = + Graph.default_node (m, Module (m, empty_module)) + #> Graph.map_node m (fn (Module (module_name, (nsp, nodes))) => + Module (module_name, (nsp, map_node ms f nodes))); + fun map_nsp_yield [] f (nsp, nodes) = + let + val (x, nsp') = f nsp + in (x, (nsp', nodes)) end + | map_nsp_yield (m::ms) f (nsp, nodes) = + let + val (x, nodes') = + nodes + |> Graph.default_node (m, Module (m, empty_module)) + |> Graph.map_node_yield m (fn Module (d_module_name, nsp_nodes) => + let + val (x, nsp_nodes') = map_nsp_yield ms f nsp_nodes + in (x, Module (d_module_name, nsp_nodes')) end) + in (x, (nsp, nodes')) end; + fun map_nsp_fun_yield f (nsp_fun, nsp_typ) = + let + val (x, nsp_fun') = f nsp_fun + in (x, (nsp_fun', nsp_typ)) end; + fun map_nsp_typ_yield f (nsp_fun, nsp_typ) = + let + val (x, nsp_typ') = f nsp_typ + in (x, (nsp_fun, nsp_typ')) end; + val mk_name_module = Code_Printer.mk_name_module reserved_names NONE module_alias program; + fun mk_name_stmt upper name nsp = + let + val (_, base) = Code_Printer.dest_name name; + val base' = if upper then Code_Printer.first_upper base else base; + val ([base''], nsp') = Name.variants [base'] nsp; + in (base'', nsp') end; + fun rearrange_fun name (tysm as (vs, ty), raw_eqs) = + let + val eqs = filter (snd o snd) raw_eqs; + val (eqs', is_value) = if null (filter_out (null o snd) vs) then case eqs + of [(([], t), thm)] => if (not o null o fst o Code_Thingol.unfold_fun) ty + then ([(([IVar "x"], t `$ IVar "x"), thm)], false) + else (eqs, not (Code_Thingol.fold_constnames + (fn name' => fn b => b orelse name = name') t false)) + | _ => (eqs, false) + else (eqs, false) + in ((name, (tysm, eqs')), is_value) end; + fun check_kind [((name, (tysm, [(([], t), thm)])), true)] = MLVal (name, ((tysm, t), thm)) + | check_kind [((name, ((vs, ty), [])), _)] = + MLExc (name, (length o filter_out (null o snd)) vs + (length o fst o Code_Thingol.unfold_fun) ty) + | check_kind funns = + MLFuns (map fst funns, map_filter + (fn ((name, ((vs, _), [(([], _), _)])), _) => + if null (filter_out (null o snd) vs) then SOME name else NONE + | _ => NONE) funns); + fun add_funs stmts = fold_map + (fn (name, Code_Thingol.Fun (_, stmt)) => + map_nsp_fun_yield (mk_name_stmt false name) + #>> rpair (rearrange_fun name stmt) + | (name, _) => + error ("Function block containing illegal statement: " ^ labelled_name name) + ) stmts + #>> (split_list #> apsnd check_kind); + fun add_datatypes stmts = + fold_map + (fn (name, Code_Thingol.Datatype (_, stmt)) => + map_nsp_typ_yield (mk_name_stmt false name) #>> rpair (SOME (name, stmt)) + | (name, Code_Thingol.Datatypecons _) => + map_nsp_fun_yield (mk_name_stmt true name) #>> rpair NONE + | (name, _) => + error ("Datatype block containing illegal statement: " ^ labelled_name name) + ) stmts + #>> (split_list #> apsnd (map_filter I + #> (fn [] => error ("Datatype block without data statement: " + ^ (commas o map (labelled_name o fst)) stmts) + | stmts => MLDatas stmts))); + fun add_class stmts = + fold_map + (fn (name, Code_Thingol.Class (_, stmt)) => + map_nsp_typ_yield (mk_name_stmt false name) #>> rpair (SOME (name, stmt)) + | (name, Code_Thingol.Classrel _) => + map_nsp_fun_yield (mk_name_stmt false name) #>> rpair NONE + | (name, Code_Thingol.Classparam _) => + map_nsp_fun_yield (mk_name_stmt false name) #>> rpair NONE + | (name, _) => + error ("Class block containing illegal statement: " ^ labelled_name name) + ) stmts + #>> (split_list #> apsnd (map_filter I + #> (fn [] => error ("Class block without class statement: " + ^ (commas o map (labelled_name o fst)) stmts) + | [stmt] => MLClass stmt))); + fun add_inst [(name, Code_Thingol.Classinst stmt)] = + map_nsp_fun_yield (mk_name_stmt false name) + #>> (fn base => ([base], MLClassinst (name, stmt))); + fun add_stmts ((stmts as (_, Code_Thingol.Fun _)::_)) = + add_funs stmts + | add_stmts ((stmts as (_, Code_Thingol.Datatypecons _)::_)) = + add_datatypes stmts + | add_stmts ((stmts as (_, Code_Thingol.Datatype _)::_)) = + add_datatypes stmts + | add_stmts ((stmts as (_, Code_Thingol.Class _)::_)) = + add_class stmts + | add_stmts ((stmts as (_, Code_Thingol.Classrel _)::_)) = + add_class stmts + | add_stmts ((stmts as (_, Code_Thingol.Classparam _)::_)) = + add_class stmts + | add_stmts ((stmts as [(_, Code_Thingol.Classinst _)])) = + add_inst stmts + | add_stmts stmts = error ("Illegal mutual dependencies: " ^ + (commas o map (labelled_name o fst)) stmts); + fun add_stmts' stmts nsp_nodes = + let + val names as (name :: names') = map fst stmts; + val deps = + [] + |> fold (fold (insert (op =)) o Graph.imm_succs program) names + |> subtract (op =) names; + val (module_names, _) = (split_list o map Code_Printer.dest_name) names; + val module_name = (the_single o distinct (op =) o map mk_name_module) module_names + handle Empty => + error ("Different namespace prefixes for mutual dependencies:\n" + ^ commas (map labelled_name names) + ^ "\n" + ^ commas module_names); + val module_name_path = Long_Name.explode module_name; + fun add_dep name name' = + let + val module_name' = (mk_name_module o fst o Code_Printer.dest_name) name'; + in if module_name = module_name' then + map_node module_name_path (Graph.add_edge (name, name')) + else let + val (common, (diff1 :: _, diff2 :: _)) = chop_prefix (op =) + (module_name_path, Long_Name.explode module_name'); + in + map_node common + (fn node => Graph.add_edge_acyclic (diff1, diff2) node + handle Graph.CYCLES _ => error ("Dependency " + ^ quote name ^ " -> " ^ quote name' + ^ " would result in module dependency cycle")) + end end; + in + nsp_nodes + |> map_nsp_yield module_name_path (add_stmts stmts) + |-> (fn (base' :: bases', stmt') => + apsnd (map_node module_name_path (Graph.new_node (name, (Stmt (base', stmt'))) + #> fold2 (fn name' => fn base' => + Graph.new_node (name', (Dummy base'))) names' bases'))) + |> apsnd (fold (fn name => fold (add_dep name) deps) names) + |> apsnd (fold_product (curry (map_node module_name_path o Graph.add_edge)) names names) + end; + val (_, nodes) = empty_module + |> fold add_stmts' (map (AList.make (Graph.get_node program)) + (rev (Graph.strong_conn program))); + fun deresolver prefix name = + let + val module_name = (fst o Code_Printer.dest_name) name; + val module_name' = (Long_Name.explode o mk_name_module) module_name; + val (_, (_, remainder)) = chop_prefix (op =) (prefix, module_name'); + val stmt_name = + nodes + |> fold (fn name => fn node => case Graph.get_node node name + of Module (_, (_, node)) => node) module_name' + |> (fn node => case Graph.get_node node name of Stmt (stmt_name, _) => stmt_name + | Dummy stmt_name => stmt_name); + in + Long_Name.implode (remainder @ [stmt_name]) + end handle Graph.UNDEF _ => + error ("Unknown statement name: " ^ labelled_name name); + in (deresolver, nodes) end; + +fun serialize_ml target compile pr_module pr_stmt raw_module_name labelled_name reserved_names includes raw_module_alias + _ syntax_tyco syntax_const program stmt_names destination = + let + val is_cons = Code_Thingol.is_cons program; + val present_stmt_names = Code_Target.stmt_names_of_destination destination; + val is_present = not (null present_stmt_names); + val module_name = if is_present then SOME "Code" else raw_module_name; + val (deresolver, nodes) = ml_node_of_program labelled_name module_name + reserved_names raw_module_alias program; + val reserved_names = Code_Printer.make_vars reserved_names; + fun pr_node prefix (Dummy _) = + NONE + | pr_node prefix (Stmt (_, stmt)) = if is_present andalso + (null o filter (member (op =) present_stmt_names) o stmt_names_of) stmt + then NONE + else SOME + (pr_stmt labelled_name syntax_tyco syntax_const reserved_names + (deresolver prefix) is_cons stmt) + | pr_node prefix (Module (module_name, (_, nodes))) = + separate (str "") + ((map_filter (pr_node (prefix @ [module_name]) o Graph.get_node nodes) + o rev o flat o Graph.strong_conn) nodes) + |> (if is_present then Pretty.chunks else pr_module module_name) + |> SOME; + val stmt_names' = (map o try) + (deresolver (if is_some module_name then the_list module_name else [])) stmt_names; + val p = Pretty.chunks (separate (str "") (map snd includes @ (map_filter + (pr_node [] o Graph.get_node nodes) o rev o flat o Graph.strong_conn) nodes)); + in + Code_Target.mk_serialization target + (case compile of SOME compile => SOME (compile o Code_Target.code_of_pretty) | NONE => NONE) + (fn NONE => Code_Target.code_writeln | SOME file => File.write file o Code_Target.code_of_pretty) + (rpair stmt_names' o Code_Target.code_of_pretty) p destination + end; + +end; (*local*) + + +(** ML (system language) code for evaluation and instrumentalization **) + +fun eval_code_of some_target thy = Code_Target.serialize_custom thy (the_default target_Eval some_target, + (fn _ => fn [] => serialize_ml target_SML (SOME (K ())) (K Pretty.chunks) pr_sml_stmt (SOME ""), + literals_sml)); + + +(* evaluation *) + +fun eval some_target reff postproc thy t args = + let + val ctxt = ProofContext.init thy; + fun evaluator naming program ((_, (_, ty)), t) deps = + let + val _ = if Code_Thingol.contains_dictvar t then + error "Term to be evaluated contains free dictionaries" else (); + val value_name = "Value.VALUE.value" + val program' = program + |> Graph.new_node (value_name, + Code_Thingol.Fun (Term.dummy_patternN, (([], ty), [(([], t), (Drule.dummy_thm, true))]))) + |> fold (curry Graph.add_edge value_name) deps; + val (value_code, [SOME value_name']) = eval_code_of some_target thy naming program' [value_name]; + val sml_code = "let\n" ^ value_code ^ "\nin " ^ value_name' + ^ space_implode " " (map (enclose "(" ")") args) ^ " end"; + in ML_Context.evaluate ctxt false reff sml_code end; + in Code_Thingol.eval thy I postproc evaluator t end; + + +(* instrumentalization by antiquotation *) + +local + +structure CodeAntiqData = ProofDataFun +( + type T = (string list * string list) * (bool * (string + * (string * ((string * string) list * (string * string) list)) lazy)); + fun init _ = (([], []), (true, ("", Lazy.value ("", ([], []))))); +); + +val is_first_occ = fst o snd o CodeAntiqData.get; + +fun delayed_code thy tycos consts () = + let + val (consts', (naming, program)) = Code_Thingol.consts_program thy consts; + val tycos' = map (the o Code_Thingol.lookup_tyco naming) tycos; + val (ml_code, target_names) = eval_code_of NONE thy naming program (consts' @ tycos'); + val (consts'', tycos'') = chop (length consts') target_names; + val consts_map = map2 (fn const => fn NONE => + error ("Constant " ^ (quote o Code.string_of_const thy) const + ^ "\nhas a user-defined serialization") + | SOME const'' => (const, const'')) consts consts'' + val tycos_map = map2 (fn tyco => fn NONE => + error ("Type " ^ (quote o Sign.extern_type thy) tyco + ^ "\nhas a user-defined serialization") + | SOME tyco'' => (tyco, tyco'')) tycos tycos''; + in (ml_code, (tycos_map, consts_map)) end; + +fun register_code new_tycos new_consts ctxt = + let + val ((tycos, consts), (_, (struct_name, _))) = CodeAntiqData.get ctxt; + val tycos' = fold (insert (op =)) new_tycos tycos; + val consts' = fold (insert (op =)) new_consts consts; + val (struct_name', ctxt') = if struct_name = "" + then ML_Antiquote.variant "Code" ctxt + else (struct_name, ctxt); + val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) tycos' consts'); + in CodeAntiqData.put ((tycos', consts'), (false, (struct_name', acc_code))) ctxt' end; + +fun register_const const = register_code [] [const]; + +fun register_datatype tyco constrs = register_code [tyco] constrs; + +fun print_const const all_struct_name tycos_map consts_map = + (Long_Name.append all_struct_name o the o AList.lookup (op =) consts_map) const; + +fun print_datatype tyco constrs all_struct_name tycos_map consts_map = + let + val upperize = implode o nth_map 0 Symbol.to_ascii_upper o explode; + fun check_base name name'' = + if upperize (Long_Name.base_name name) = upperize name'' + then () else error ("Name as printed " ^ quote name'' + ^ "\ndiffers from logical base name " ^ quote (Long_Name.base_name name) ^ "; sorry."); + val tyco'' = (the o AList.lookup (op =) tycos_map) tyco; + val constrs'' = map (the o AList.lookup (op =) consts_map) constrs; + val _ = check_base tyco tyco''; + val _ = map2 check_base constrs constrs''; + in "datatype " ^ tyco'' ^ " = datatype " ^ Long_Name.append all_struct_name tyco'' end; + +fun print_code struct_name is_first print_it ctxt = + let + val (_, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt; + val (raw_ml_code, (tycos_map, consts_map)) = Lazy.force acc_code; + val ml_code = if is_first then "\nstructure " ^ struct_code_name + ^ " =\nstruct\n\n" ^ raw_ml_code ^ "\nend;\n\n" + else ""; + val all_struct_name = Long_Name.append struct_name struct_code_name; + in (ml_code, print_it all_struct_name tycos_map consts_map) end; + +in + +fun ml_code_antiq raw_const {struct_name, background} = + let + val const = Code.check_const (ProofContext.theory_of background) raw_const; + val is_first = is_first_occ background; + val background' = register_const const background; + in (print_code struct_name is_first (print_const const), background') end; + +fun ml_code_datatype_antiq (raw_tyco, raw_constrs) {struct_name, background} = + let + val thy = ProofContext.theory_of background; + val tyco = Sign.intern_type thy raw_tyco; + val constrs = map (Code.check_const thy) raw_constrs; + val constrs' = (map fst o snd o Code.get_datatype thy) tyco; + val _ = if gen_eq_set (op =) (constrs, constrs') then () + else error ("Type " ^ quote tyco ^ ": given constructors diverge from real constructors") + val is_first = is_first_occ background; + val background' = register_datatype tyco constrs background; + in (print_code struct_name is_first (print_datatype tyco constrs), background') end; + +end; (*local*) + + +(** Isar setup **) + +val _ = ML_Context.add_antiq "code" (fn _ => Args.term >> ml_code_antiq); +val _ = ML_Context.add_antiq "code_datatype" (fn _ => + (Args.tyname --| Scan.lift (Args.$$$ "=") + -- (Args.term ::: Scan.repeat (Scan.lift (Args.$$$ "|") |-- Args.term))) + >> ml_code_datatype_antiq); + +fun isar_seri_sml module_name = + Code_Target.parse_args (Scan.succeed ()) + #> (fn () => serialize_ml target_SML + (SOME (use_text ML_Env.local_context (1, "generated code") false)) + pr_sml_module pr_sml_stmt module_name); + +fun isar_seri_ocaml module_name = + Code_Target.parse_args (Scan.succeed ()) + #> (fn () => serialize_ml target_OCaml NONE + pr_ocaml_module pr_ocaml_stmt module_name); + +val setup = + Code_Target.add_target (target_SML, (isar_seri_sml, literals_sml)) + #> Code_Target.add_target (target_OCaml, (isar_seri_ocaml, literals_ocaml)) + #> Code_Target.extend_target (target_Eval, (target_SML, K I)) + #> Code_Target.add_syntax_tyco target_SML "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => + brackify_infix (1, R) fxy [ + pr_typ (INFX (1, X)) ty1, + str "->", + pr_typ (INFX (1, R)) ty2 + ])) + #> Code_Target.add_syntax_tyco target_OCaml "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => + brackify_infix (1, R) fxy [ + pr_typ (INFX (1, X)) ty1, + str "->", + pr_typ (INFX (1, R)) ty2 + ])) + #> fold (Code_Target.add_reserved target_SML) ML_Syntax.reserved_names + #> fold (Code_Target.add_reserved target_SML) + ["o" (*dictionary projections use it already*), "Fail", "div", "mod" (*standard infixes*)] + #> fold (Code_Target.add_reserved target_OCaml) [ + "and", "as", "assert", "begin", "class", + "constraint", "do", "done", "downto", "else", "end", "exception", + "external", "false", "for", "fun", "function", "functor", "if", + "in", "include", "inherit", "initializer", "lazy", "let", "match", "method", + "module", "mutable", "new", "object", "of", "open", "or", "private", "rec", + "sig", "struct", "then", "to", "true", "try", "type", "val", + "virtual", "when", "while", "with" + ] + #> fold (Code_Target.add_reserved target_OCaml) ["failwith", "mod"]; + +end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_preproc.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_preproc.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,515 @@ +(* Title: Tools/code/code_preproc.ML + Author: Florian Haftmann, TU Muenchen + +Preprocessing code equations into a well-sorted system +in a graph with explicit dependencies. +*) + +signature CODE_PREPROC = +sig + val map_pre: (simpset -> simpset) -> theory -> theory + val map_post: (simpset -> simpset) -> theory -> theory + val add_inline: thm -> theory -> theory + val add_functrans: string * (theory -> (thm * bool) list -> (thm * bool) list option) -> theory -> theory + val del_functrans: string -> theory -> theory + val simple_functrans: (theory -> thm list -> thm list option) + -> theory -> (thm * bool) list -> (thm * bool) list option + val print_codeproc: theory -> unit + + type code_algebra + type code_graph + val eqns: code_graph -> string -> (thm * bool) list + val typ: code_graph -> string -> (string * sort) list * typ + val all: code_graph -> string list + val pretty: theory -> code_graph -> Pretty.T + val obtain: theory -> string list -> term list -> code_algebra * code_graph + val eval_conv: theory -> (sort -> sort) + -> (code_algebra -> code_graph -> (string * sort) list -> term -> cterm -> thm) -> cterm -> thm + val eval: theory -> (sort -> sort) -> ((term -> term) -> 'a -> 'a) + -> (code_algebra -> code_graph -> (string * sort) list -> term -> 'a) -> term -> 'a + + val setup: theory -> theory +end + +structure Code_Preproc : CODE_PREPROC = +struct + +(** preprocessor administration **) + +(* theory data *) + +datatype thmproc = Thmproc of { + pre: simpset, + post: simpset, + functrans: (string * (serial * (theory -> (thm * bool) list -> (thm * bool) list option))) list +}; + +fun make_thmproc ((pre, post), functrans) = + Thmproc { pre = pre, post = post, functrans = functrans }; +fun map_thmproc f (Thmproc { pre, post, functrans }) = + make_thmproc (f ((pre, post), functrans)); +fun merge_thmproc (Thmproc { pre = pre1, post = post1, functrans = functrans1 }, + Thmproc { pre = pre2, post = post2, functrans = functrans2 }) = + let + val pre = Simplifier.merge_ss (pre1, pre2); + val post = Simplifier.merge_ss (post1, post2); + val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2); + in make_thmproc ((pre, post), functrans) end; + +structure Code_Preproc_Data = TheoryDataFun +( + type T = thmproc; + val empty = make_thmproc ((Simplifier.empty_ss, Simplifier.empty_ss), []); + fun copy spec = spec; + val extend = copy; + fun merge pp = merge_thmproc; +); + +fun the_thmproc thy = case Code_Preproc_Data.get thy + of Thmproc x => x; + +fun delete_force msg key xs = + if AList.defined (op =) xs key then AList.delete (op =) key xs + else error ("No such " ^ msg ^ ": " ^ quote key); + +fun map_data f thy = + thy + |> Code.purge_data + |> (Code_Preproc_Data.map o map_thmproc) f; + +val map_pre = map_data o apfst o apfst; +val map_post = map_data o apfst o apsnd; + +val add_inline = map_pre o MetaSimplifier.add_simp; +val del_inline = map_pre o MetaSimplifier.del_simp; +val add_post = map_post o MetaSimplifier.add_simp; +val del_post = map_post o MetaSimplifier.del_simp; + +fun add_functrans (name, f) = (map_data o apsnd) + (AList.update (op =) (name, (serial (), f))); + +fun del_functrans name = (map_data o apsnd) + (delete_force "function transformer" name); + + +(* post- and preprocessing *) + +fun apply_functrans thy c _ [] = [] + | apply_functrans thy c [] eqns = eqns + | apply_functrans thy c functrans eqns = eqns + |> perhaps (perhaps_loop (perhaps_apply functrans)) + |> Code.assert_eqns_const thy c; + +fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); + +fun term_of_conv thy f = + Thm.cterm_of thy + #> f + #> Thm.prop_of + #> Logic.dest_equals + #> snd; + +fun preprocess thy c eqns = + let + val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; + val functrans = (map (fn (_, (_, f)) => f thy) o #functrans + o the_thmproc) thy; + in + eqns + |> apply_functrans thy c functrans + |> (map o apfst) (Code.rewrite_eqn pre) + |> (map o apfst) (AxClass.unoverload thy) + |> map (Code.assert_eqn thy) + |> burrow_fst (Code.norm_args thy) + |> burrow_fst (Code.norm_varnames thy) + end; + +fun preprocess_conv thy ct = + let + val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; + in + ct + |> Simplifier.rewrite pre + |> rhs_conv (AxClass.unoverload_conv thy) + end; + +fun postprocess_conv thy ct = + let + val post = (Simplifier.theory_context thy o #post o the_thmproc) thy; + in + ct + |> AxClass.overload_conv thy + |> rhs_conv (Simplifier.rewrite post) + end; + +fun postprocess_term thy = term_of_conv thy (postprocess_conv thy); + +fun print_codeproc thy = + let + val ctxt = ProofContext.init thy; + val pre = (#pre o the_thmproc) thy; + val post = (#post o the_thmproc) thy; + val functrans = (map fst o #functrans o the_thmproc) thy; + in + (Pretty.writeln o Pretty.chunks) [ + Pretty.block [ + Pretty.str "preprocessing simpset:", + Pretty.fbrk, + Simplifier.pretty_ss ctxt pre + ], + Pretty.block [ + Pretty.str "postprocessing simpset:", + Pretty.fbrk, + Simplifier.pretty_ss ctxt post + ], + Pretty.block ( + Pretty.str "function transformers:" + :: Pretty.fbrk + :: (Pretty.fbreaks o map Pretty.str) functrans + ) + ] + end; + +fun simple_functrans f thy eqns = case f thy (map fst eqns) + of SOME thms' => SOME (map (rpair (forall snd eqns)) thms') + | NONE => NONE; + + +(** sort algebra and code equation graph types **) + +type code_algebra = (sort -> sort) * Sorts.algebra; +type code_graph = (((string * sort) list * typ) * (thm * bool) list) Graph.T; + +fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr); +fun typ eqngr = fst o Graph.get_node eqngr; +fun all eqngr = Graph.keys eqngr; + +fun pretty thy eqngr = + AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr) + |> (map o apfst) (Code.string_of_const thy) + |> sort (string_ord o pairself fst) + |> map (fn (s, thms) => + (Pretty.block o Pretty.fbreaks) ( + Pretty.str s + :: map (Display.pretty_thm o fst) thms + )) + |> Pretty.chunks; + + +(** the Waisenhaus algorithm **) + +(* auxiliary *) + +fun is_proper_class thy = can (AxClass.get_info thy); + +fun complete_proper_sort thy = + Sign.complete_sort thy #> filter (is_proper_class thy); + +fun inst_params thy tyco = + map (fn (c, _) => AxClass.param_of_inst thy (c, tyco)) + o maps (#params o AxClass.get_info thy); + +fun consts_of thy eqns = [] |> (fold o fold o fold_aterms) + (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I) + (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns); + +fun tyscm_rhss_of thy c eqns = + let + val tyscm = case eqns of [] => Code.default_typscheme thy c + | ((thm, _) :: _) => Code.typscheme_eqn thy thm; + val rhss = consts_of thy eqns; + in (tyscm, rhss) end; + + +(* data structures *) + +datatype const = Fun of string | Inst of class * string; + +fun const_ord (Fun c1, Fun c2) = fast_string_ord (c1, c2) + | const_ord (Inst class_tyco1, Inst class_tyco2) = + prod_ord fast_string_ord fast_string_ord (class_tyco1, class_tyco2) + | const_ord (Fun _, Inst _) = LESS + | const_ord (Inst _, Fun _) = GREATER; + +type var = const * int; + +structure Vargraph = + GraphFun(type key = var val ord = prod_ord const_ord int_ord); + +datatype styp = Tyco of string * styp list | Var of var | Free; + +fun styp_of c_lhs (Type (tyco, tys)) = Tyco (tyco, map (styp_of c_lhs) tys) + | styp_of c_lhs (TFree (v, _)) = case c_lhs + of SOME (c, lhs) => Var (Fun c, find_index (fn (v', _) => v = v') lhs) + | NONE => Free; + +type vardeps_data = ((string * styp list) list * class list) Vargraph.T + * (((string * sort) list * (thm * bool) list) Symtab.table + * (class * string) list); + +val empty_vardeps_data : vardeps_data = + (Vargraph.empty, (Symtab.empty, [])); + + +(* retrieving equations and instances from the background context *) + +fun obtain_eqns thy eqngr c = + case try (Graph.get_node eqngr) c + of SOME ((lhs, _), eqns) => ((lhs, []), []) + | NONE => let + val eqns = Code.these_eqns thy c + |> preprocess thy c; + val ((lhs, _), rhss) = tyscm_rhss_of thy c eqns; + in ((lhs, rhss), eqns) end; + +fun obtain_instance thy arities (inst as (class, tyco)) = + case AList.lookup (op =) arities inst + of SOME classess => (classess, ([], [])) + | NONE => let + val all_classes = complete_proper_sort thy [class]; + val superclasses = remove (op =) class all_classes + val classess = map (complete_proper_sort thy) + (Sign.arity_sorts thy tyco [class]); + val inst_params = inst_params thy tyco all_classes; + in (classess, (superclasses, inst_params)) end; + + +(* computing instantiations *) + +fun add_classes thy arities eqngr c_k new_classes vardeps_data = + let + val (styps, old_classes) = Vargraph.get_node (fst vardeps_data) c_k; + val diff_classes = new_classes |> subtract (op =) old_classes; + in if null diff_classes then vardeps_data + else let + val c_ks = Vargraph.imm_succs (fst vardeps_data) c_k |> insert (op =) c_k; + in + vardeps_data + |> (apfst o Vargraph.map_node c_k o apsnd) (append diff_classes) + |> fold (fn styp => fold (ensure_typmatch_inst thy arities eqngr styp) new_classes) styps + |> fold (fn c_k => add_classes thy arities eqngr c_k diff_classes) c_ks + end end +and add_styp thy arities eqngr c_k tyco_styps vardeps_data = + let + val (old_styps, classes) = Vargraph.get_node (fst vardeps_data) c_k; + in if member (op =) old_styps tyco_styps then vardeps_data + else + vardeps_data + |> (apfst o Vargraph.map_node c_k o apfst) (cons tyco_styps) + |> fold (ensure_typmatch_inst thy arities eqngr tyco_styps) classes + end +and add_dep thy arities eqngr c_k c_k' vardeps_data = + let + val (_, classes) = Vargraph.get_node (fst vardeps_data) c_k; + in + vardeps_data + |> add_classes thy arities eqngr c_k' classes + |> apfst (Vargraph.add_edge (c_k, c_k')) + end +and ensure_typmatch_inst thy arities eqngr (tyco, styps) class vardeps_data = + if can (Sign.arity_sorts thy tyco) [class] + then vardeps_data + |> ensure_inst thy arities eqngr (class, tyco) + |> fold_index (fn (k, styp) => + ensure_typmatch thy arities eqngr styp (Inst (class, tyco), k)) styps + else vardeps_data (*permissive!*) +and ensure_inst thy arities eqngr (inst as (class, tyco)) (vardeps_data as (_, (_, insts))) = + if member (op =) insts inst then vardeps_data + else let + val (classess, (superclasses, inst_params)) = + obtain_instance thy arities inst; + in + vardeps_data + |> (apsnd o apsnd) (insert (op =) inst) + |> fold_index (fn (k, _) => + apfst (Vargraph.new_node ((Inst (class, tyco), k), ([] ,[])))) classess + |> fold (fn superclass => ensure_inst thy arities eqngr (superclass, tyco)) superclasses + |> fold (ensure_fun thy arities eqngr) inst_params + |> fold_index (fn (k, classes) => + add_classes thy arities eqngr (Inst (class, tyco), k) classes + #> fold (fn superclass => + add_dep thy arities eqngr (Inst (superclass, tyco), k) + (Inst (class, tyco), k)) superclasses + #> fold (fn inst_param => + add_dep thy arities eqngr (Fun inst_param, k) + (Inst (class, tyco), k) + ) inst_params + ) classess + end +and ensure_typmatch thy arities eqngr (Tyco tyco_styps) c_k vardeps_data = + vardeps_data + |> add_styp thy arities eqngr c_k tyco_styps + | ensure_typmatch thy arities eqngr (Var c_k') c_k vardeps_data = + vardeps_data + |> add_dep thy arities eqngr c_k c_k' + | ensure_typmatch thy arities eqngr Free c_k vardeps_data = + vardeps_data +and ensure_rhs thy arities eqngr (c', styps) vardeps_data = + vardeps_data + |> ensure_fun thy arities eqngr c' + |> fold_index (fn (k, styp) => + ensure_typmatch thy arities eqngr styp (Fun c', k)) styps +and ensure_fun thy arities eqngr c (vardeps_data as (_, (eqntab, _))) = + if Symtab.defined eqntab c then vardeps_data + else let + val ((lhs, rhss), eqns) = obtain_eqns thy eqngr c; + val rhss' = (map o apsnd o map) (styp_of (SOME (c, lhs))) rhss; + in + vardeps_data + |> (apsnd o apfst) (Symtab.update_new (c, (lhs, eqns))) + |> fold_index (fn (k, _) => + apfst (Vargraph.new_node ((Fun c, k), ([] ,[])))) lhs + |> fold_index (fn (k, (_, sort)) => + add_classes thy arities eqngr (Fun c, k) (complete_proper_sort thy sort)) lhs + |> fold (ensure_rhs thy arities eqngr) rhss' + end; + + +(* applying instantiations *) + +fun dicts_of thy (proj_sort, algebra) (T, sort) = + let + fun class_relation (x, _) _ = x; + fun type_constructor tyco xs class = + inst_params thy tyco (Sorts.complete_sort algebra [class]) + @ (maps o maps) fst xs; + fun type_variable (TFree (_, sort)) = map (pair []) (proj_sort sort); + in + flat (Sorts.of_sort_derivation (Syntax.pp_global thy) algebra + { class_relation = class_relation, type_constructor = type_constructor, + type_variable = type_variable } (T, proj_sort sort) + handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) + end; + +fun add_arity thy vardeps (class, tyco) = + AList.default (op =) + ((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) + (0 upto Sign.arity_number thy tyco - 1)); + +fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) = + if can (Graph.get_node eqngr) c then (rhss, eqngr) + else let + val lhs = map_index (fn (k, (v, _)) => + (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; + val inst_tab = Vartab.empty |> fold (fn (v, sort) => + Vartab.update ((v, 0), sort)) lhs; + val eqns = proto_eqns + |> (map o apfst) (Code.inst_thm thy inst_tab); + val (tyscm, rhss') = tyscm_rhss_of thy c eqns; + val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr; + in (map (pair c) rhss' @ rhss, eqngr') end; + +fun extend_arities_eqngr thy cs ts (arities, eqngr) = + let + val cs_rhss = (fold o fold_aterms) (fn Const (c_ty as (c, _)) => + insert (op =) (c, (map (styp_of NONE) o Sign.const_typargs thy) c_ty) | _ => I) ts []; + val (vardeps, (eqntab, insts)) = empty_vardeps_data + |> fold (ensure_fun thy arities eqngr) cs + |> fold (ensure_rhs thy arities eqngr) cs_rhss; + val arities' = fold (add_arity thy vardeps) insts arities; + val pp = Syntax.pp_global thy; + val algebra = Sorts.subalgebra pp (is_proper_class thy) + (AList.lookup (op =) arities') (Sign.classes_of thy); + val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr); + fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra) + (rhs ~~ (map snd o fst o fst o Graph.get_node eqngr') c); + val eqngr'' = fold (fn (c, rhs) => fold + (curry Graph.add_edge c) (deps_of rhs)) rhss eqngr'; + in (algebra, (arities', eqngr'')) end; + + +(** store for preprocessed arities and code equations **) + +structure Wellsorted = CodeDataFun +( + type T = ((string * class) * sort list) list * code_graph; + val empty = ([], Graph.empty); + fun purge thy cs (arities, eqngr) = + let + val del_cs = ((Graph.all_preds eqngr + o filter (can (Graph.get_node eqngr))) cs); + val del_arities = del_cs + |> map_filter (AxClass.inst_of_param thy) + |> maps (fn (c, tyco) => + (map (rpair tyco) o Sign.complete_sort thy o the_list + o AxClass.class_of_param thy) c); + val arities' = fold (AList.delete (op =)) del_arities arities; + val eqngr' = Graph.del_nodes del_cs eqngr; + in (arities', eqngr') end; +); + + +(** retrieval and evaluation interfaces **) + +fun obtain thy cs ts = apsnd snd + (Wellsorted.change_yield thy (extend_arities_eqngr thy cs ts)); + +fun prepare_sorts_typ prep_sort + = map_type_tfree (fn (v, sort) => TFree (v, prep_sort sort)); + +fun prepare_sorts prep_sort (Const (c, ty)) = + Const (c, prepare_sorts_typ prep_sort ty) + | prepare_sorts prep_sort (t1 $ t2) = + prepare_sorts prep_sort t1 $ prepare_sorts prep_sort t2 + | prepare_sorts prep_sort (Abs (v, ty, t)) = + Abs (v, prepare_sorts_typ prep_sort ty, prepare_sorts prep_sort t) + | prepare_sorts _ (t as Bound _) = t; + +fun gen_eval thy cterm_of conclude_evaluation prep_sort evaluator proto_ct = + let + val pp = Syntax.pp_global thy; + val ct = cterm_of proto_ct; + val _ = (Sign.no_frees pp o map_types (K dummyT) o Sign.no_vars pp) + (Thm.term_of ct); + val thm = preprocess_conv thy ct; + val ct' = Thm.rhs_of thm; + val t' = Thm.term_of ct'; + val vs = Term.add_tfrees t' []; + val consts = fold_aterms + (fn Const (c, _) => insert (op =) c | _ => I) t' []; + + val t'' = prepare_sorts prep_sort t'; + val (algebra', eqngr') = obtain thy consts [t'']; + in conclude_evaluation (evaluator algebra' eqngr' vs t'' ct') thm end; + +fun simple_evaluator evaluator algebra eqngr vs t ct = + evaluator algebra eqngr vs t; + +fun eval_conv thy = + let + fun conclude_evaluation thm2 thm1 = + let + val thm3 = postprocess_conv thy (Thm.rhs_of thm2); + in + Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ => + error ("could not construct evaluation proof:\n" + ^ (cat_lines o map Display.string_of_thm) [thm1, thm2, thm3]) + end; + in gen_eval thy I conclude_evaluation end; + +fun eval thy prep_sort postproc evaluator = gen_eval thy (Thm.cterm_of thy) + (K o postproc (postprocess_term thy)) prep_sort (simple_evaluator evaluator); + + +(** setup **) + +val setup = + let + fun mk_attribute f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I); + fun add_del_attribute (name, (add, del)) = + Code.add_attribute (name, Args.del |-- Scan.succeed (mk_attribute del) + || Scan.succeed (mk_attribute add)) + in + add_del_attribute ("inline", (add_inline, del_inline)) + #> add_del_attribute ("post", (add_post, del_post)) + #> Code.add_attribute ("unfold", Scan.succeed (Thm.declaration_attribute + (fn thm => Context.mapping (Codegen.add_unfold thm #> add_inline thm) I))) + end; + +val _ = + OuterSyntax.improper_command "print_codeproc" "print code preprocessor setup" + OuterKeyword.diag (Scan.succeed + (Toplevel.no_timing o Toplevel.unknown_theory o Toplevel.keep + (print_codeproc o Toplevel.theory_of))); + +end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_printer.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_printer.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,318 @@ +(* Title: Tools/code/code_printer.ML + Author: Florian Haftmann, TU Muenchen + +Generic operations for pretty printing of target language code. +*) + +signature CODE_PRINTER = +sig + val nerror: thm -> string -> 'a + + val @@ : 'a * 'a -> 'a list + val @| : 'a list * 'a -> 'a list + val str: string -> Pretty.T + val concat: Pretty.T list -> Pretty.T + val brackets: Pretty.T list -> Pretty.T + val semicolon: Pretty.T list -> Pretty.T + val enum_default: string -> string -> string -> string -> Pretty.T list -> Pretty.T + + val first_upper: string -> string + val first_lower: string -> string + type var_ctxt + val make_vars: string list -> var_ctxt + val intro_vars: string list -> var_ctxt -> var_ctxt + val lookup_var: var_ctxt -> string -> string + + type literals + val Literals: { literal_char: string -> string, literal_string: string -> string, + literal_numeral: bool -> int -> string, + literal_list: Pretty.T list -> Pretty.T, infix_cons: int * string } + -> literals + val literal_char: literals -> string -> string + val literal_string: literals -> string -> string + val literal_numeral: literals -> bool -> int -> string + val literal_list: literals -> Pretty.T list -> Pretty.T + val infix_cons: literals -> int * string + + type lrx + val L: lrx + val R: lrx + val X: lrx + type fixity + val BR: fixity + val NOBR: fixity + val INFX: int * lrx -> fixity + val APP: fixity + val brackify: fixity -> Pretty.T list -> Pretty.T + val brackify_infix: int * lrx -> fixity -> Pretty.T list -> Pretty.T + val brackify_block: fixity -> Pretty.T -> Pretty.T list -> Pretty.T -> Pretty.T + + type itype = Code_Thingol.itype + type iterm = Code_Thingol.iterm + type const = Code_Thingol.const + type dict = Code_Thingol.dict + type tyco_syntax + type const_syntax + type proto_const_syntax + val parse_infix: ('a -> 'b) -> lrx * int -> string + -> int * ((fixity -> 'b -> Pretty.T) + -> fixity -> 'a list -> Pretty.T) + val parse_syntax: ('a -> 'b) -> OuterParse.token list + -> (int * ((fixity -> 'b -> Pretty.T) + -> fixity -> 'a list -> Pretty.T)) option * OuterParse.token list + val simple_const_syntax: (int * ((fixity -> iterm -> Pretty.T) + -> fixity -> (iterm * itype) list -> Pretty.T)) option -> proto_const_syntax option + val activate_const_syntax: theory -> literals + -> proto_const_syntax -> Code_Thingol.naming -> const_syntax * Code_Thingol.naming + val gen_pr_app: (thm -> var_ctxt -> const * iterm list -> Pretty.T list) + -> (thm -> var_ctxt -> fixity -> iterm -> Pretty.T) + -> (string -> const_syntax option) + -> thm -> var_ctxt -> fixity -> const * iterm list -> Pretty.T + val gen_pr_bind: ((string option * Pretty.T option) * itype -> Pretty.T) + -> (thm -> var_ctxt -> fixity -> iterm -> Pretty.T) + -> thm -> fixity + -> (string option * iterm option) * itype -> var_ctxt -> Pretty.T * var_ctxt + + val mk_name_module: Name.context -> string option -> (string -> string option) + -> 'a Graph.T -> string -> string + val dest_name: string -> string * string +end; + +structure Code_Printer : CODE_PRINTER = +struct + +open Code_Thingol; + +fun nerror thm s = error (s ^ ",\nin equation " ^ Display.string_of_thm thm); + +(** assembling text pieces **) + +infixr 5 @@; +infixr 5 @|; +fun x @@ y = [x, y]; +fun xs @| y = xs @ [y]; +val str = PrintMode.setmp [] Pretty.str; +val concat = Pretty.block o Pretty.breaks; +val brackets = Pretty.enclose "(" ")" o Pretty.breaks; +fun semicolon ps = Pretty.block [concat ps, str ";"]; +fun enum_default default sep opn cls [] = str default + | enum_default default sep opn cls xs = Pretty.enum sep opn cls xs; + + +(** names and variable name contexts **) + +type var_ctxt = string Symtab.table * Name.context; + +fun make_vars names = (fold (fn name => Symtab.update_new (name, name)) names Symtab.empty, + Name.make_context names); + +fun intro_vars names (namemap, namectxt) = + let + val (names', namectxt') = Name.variants names namectxt; + val namemap' = fold2 (curry Symtab.update) names names' namemap; + in (namemap', namectxt') end; + +fun lookup_var (namemap, _) name = case Symtab.lookup namemap name + of SOME name' => name' + | NONE => error ("Invalid name in context: " ^ quote name); + +val first_upper = implode o nth_map 0 Symbol.to_ascii_upper o explode; +val first_lower = implode o nth_map 0 Symbol.to_ascii_lower o explode; + + +(** pretty literals **) + +datatype literals = Literals of { + literal_char: string -> string, + literal_string: string -> string, + literal_numeral: bool -> int -> string, + literal_list: Pretty.T list -> Pretty.T, + infix_cons: int * string +}; + +fun dest_Literals (Literals lits) = lits; + +val literal_char = #literal_char o dest_Literals; +val literal_string = #literal_string o dest_Literals; +val literal_numeral = #literal_numeral o dest_Literals; +val literal_list = #literal_list o dest_Literals; +val infix_cons = #infix_cons o dest_Literals; + + +(** syntax printer **) + +(* binding priorities *) + +datatype lrx = L | R | X; + +datatype fixity = + BR + | NOBR + | INFX of (int * lrx); + +val APP = INFX (~1, L); + +fun fixity_lrx L L = false + | fixity_lrx R R = false + | fixity_lrx _ _ = true; + +fun fixity NOBR _ = false + | fixity _ NOBR = false + | fixity (INFX (pr, lr)) (INFX (pr_ctxt, lr_ctxt)) = + pr < pr_ctxt + orelse pr = pr_ctxt + andalso fixity_lrx lr lr_ctxt + orelse pr_ctxt = ~1 + | fixity BR (INFX _) = false + | fixity _ _ = true; + +fun gen_brackify _ [p] = p + | gen_brackify true (ps as _::_) = Pretty.enclose "(" ")" ps + | gen_brackify false (ps as _::_) = Pretty.block ps; + +fun brackify fxy_ctxt = + gen_brackify (fixity BR fxy_ctxt) o Pretty.breaks; + +fun brackify_infix infx fxy_ctxt = + gen_brackify (fixity (INFX infx) fxy_ctxt) o Pretty.breaks; + +fun brackify_block fxy_ctxt p1 ps p2 = + let val p = Pretty.block_enclose (p1, p2) ps + in if fixity BR fxy_ctxt + then Pretty.enclose "(" ")" [p] + else p + end; + + +(* generic syntax *) + +type tyco_syntax = int * ((fixity -> itype -> Pretty.T) + -> fixity -> itype list -> Pretty.T); +type const_syntax = int * ((var_ctxt -> fixity -> iterm -> Pretty.T) + -> thm -> var_ctxt -> fixity -> (iterm * itype) list -> Pretty.T); +type proto_const_syntax = int * (string list * (literals -> string list + -> (var_ctxt -> fixity -> iterm -> Pretty.T) + -> thm -> var_ctxt -> fixity -> (iterm * itype) list -> Pretty.T)); + +fun simple_const_syntax (SOME (n, f)) = SOME (n, + ([], (fn _ => fn _ => fn pr => fn thm => fn vars => f (pr vars)))) + | simple_const_syntax NONE = NONE; + +fun activate_const_syntax thy literals (n, (cs, f)) naming = + fold_map (Code_Thingol.ensure_declared_const thy) cs naming + |-> (fn cs' => pair (n, f literals cs')); + +fun gen_pr_app pr_app pr_term syntax_const thm vars fxy (app as ((c, (_, tys)), ts)) = + case syntax_const c + of NONE => brackify fxy (pr_app thm vars app) + | SOME (k, pr) => + let + fun pr' fxy ts = pr (pr_term thm) thm vars fxy (ts ~~ curry Library.take k tys); + in if k = length ts + then pr' fxy ts + else if k < length ts + then case chop k ts of (ts1, ts2) => + brackify fxy (pr' APP ts1 :: map (pr_term thm vars BR) ts2) + else pr_term thm vars fxy (Code_Thingol.eta_expand k app) + end; + +fun gen_pr_bind pr_bind pr_term thm (fxy : fixity) ((v, pat), ty : itype) vars = + let + val vs = case pat + of SOME pat => Code_Thingol.fold_varnames (insert (op =)) pat [] + | NONE => []; + val vars' = intro_vars (the_list v) vars; + val vars'' = intro_vars vs vars'; + val v' = Option.map (lookup_var vars') v; + val pat' = Option.map (pr_term thm vars'' fxy) pat; + in (pr_bind ((v', pat'), ty), vars'') end; + + +(* mixfix syntax *) + +datatype 'a mixfix = + Arg of fixity + | Pretty of Pretty.T; + +fun mk_mixfix prep_arg (fixity_this, mfx) = + let + fun is_arg (Arg _) = true + | is_arg _ = false; + val i = (length o filter is_arg) mfx; + fun fillin _ [] [] = + [] + | fillin pr (Arg fxy :: mfx) (a :: args) = + (pr fxy o prep_arg) a :: fillin pr mfx args + | fillin pr (Pretty p :: mfx) args = + p :: fillin pr mfx args; + in + (i, fn pr => fn fixity_ctxt => fn args => + gen_brackify (fixity fixity_this fixity_ctxt) (fillin pr mfx args)) + end; + +fun parse_infix prep_arg (x, i) s = + let + val l = case x of L => INFX (i, L) | _ => INFX (i, X); + val r = case x of R => INFX (i, R) | _ => INFX (i, X); + in + mk_mixfix prep_arg (INFX (i, x), + [Arg l, (Pretty o Pretty.brk) 1, (Pretty o str) s, (Pretty o Pretty.brk) 1, Arg r]) + end; + +fun parse_mixfix prep_arg s = + let + val sym_any = Scan.one Symbol.is_regular; + val parse = Scan.optional ($$ "!" >> K true) false -- Scan.repeat ( + ($$ "(" -- $$ "_" -- $$ ")" >> K (Arg NOBR)) + || ($$ "_" >> K (Arg BR)) + || ($$ "/" |-- Scan.repeat ($$ " ") >> (Pretty o Pretty.brk o length)) + || (Scan.repeat1 + ( $$ "'" |-- sym_any + || Scan.unless ($$ "_" || $$ "/" || $$ "(" |-- $$ "_" |-- $$ ")") + sym_any) >> (Pretty o str o implode))); + in case Scan.finite Symbol.stopper parse (Symbol.explode s) + of ((_, p as [_]), []) => mk_mixfix prep_arg (NOBR, p) + | ((b, p as _ :: _ :: _), []) => mk_mixfix prep_arg (if b then NOBR else BR, p) + | _ => Scan.!! + (the_default ("malformed mixfix annotation: " ^ quote s) o snd) Scan.fail () + end; + +val (infixK, infixlK, infixrK) = ("infix", "infixl", "infixr"); + +fun parse_syntax prep_arg xs = + Scan.option (( + ((OuterParse.$$$ infixK >> K X) + || (OuterParse.$$$ infixlK >> K L) + || (OuterParse.$$$ infixrK >> K R)) + -- OuterParse.nat >> parse_infix prep_arg + || Scan.succeed (parse_mixfix prep_arg)) + -- OuterParse.string + >> (fn (parse, s) => parse s)) xs; + +val _ = List.app OuterKeyword.keyword [infixK, infixlK, infixrK]; + + +(** module name spaces **) + +val dest_name = + apfst Long_Name.implode o split_last o fst o split_last o Long_Name.explode; + +fun mk_name_module reserved_names module_prefix module_alias program = + let + fun mk_alias name = case module_alias name + of SOME name' => name' + | NONE => name + |> Long_Name.explode + |> map (fn name => (the_single o fst) (Name.variants [name] reserved_names)) + |> Long_Name.implode; + fun mk_prefix name = case module_prefix + of SOME module_prefix => Long_Name.append module_prefix name + | NONE => name; + val tab = + Symtab.empty + |> Graph.fold ((fn name => Symtab.default (name, (mk_alias #> mk_prefix) name)) + o fst o dest_name o fst) + program + in the o Symtab.lookup tab end; + +end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_target.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_target.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,629 @@ +(* Title: Tools/code/code_target.ML + Author: Florian Haftmann, TU Muenchen + +Serializer from intermediate language ("Thin-gol") to target languages. +*) + +signature CODE_TARGET = +sig + include CODE_PRINTER + + type serializer + val add_target: string * (serializer * literals) -> theory -> theory + val extend_target: string * + (string * (Code_Thingol.naming -> Code_Thingol.program -> Code_Thingol.program)) + -> theory -> theory + val assert_target: theory -> string -> string + + type destination + type serialization + val parse_args: (OuterLex.token list -> 'a * OuterLex.token list) + -> OuterLex.token list -> 'a + val stmt_names_of_destination: destination -> string list + val code_of_pretty: Pretty.T -> string + val code_writeln: Pretty.T -> unit + val mk_serialization: string -> ('a -> unit) option + -> (Path.T option -> 'a -> unit) + -> ('a -> string * string option list) + -> 'a -> serialization + val serialize: theory -> string -> string option -> OuterLex.token list + -> Code_Thingol.naming -> Code_Thingol.program -> string list -> serialization + val serialize_custom: theory -> string * (serializer * literals) + -> Code_Thingol.naming -> Code_Thingol.program -> string list -> string * string option list + val the_literals: theory -> string -> literals + val compile: serialization -> unit + val export: serialization -> unit + val file: Path.T -> serialization -> unit + val string: string list -> serialization -> string + val code_of: theory -> string -> string + -> string list -> (Code_Thingol.naming -> string list) -> string + val shell_command: string (*theory name*) -> string (*export_code expr*) -> unit + val code_width: int ref + + val allow_abort: string -> theory -> theory + val add_syntax_class: string -> class -> string option -> theory -> theory + val add_syntax_inst: string -> string * class -> bool -> theory -> theory + val add_syntax_tyco: string -> string -> tyco_syntax option -> theory -> theory + val add_syntax_const: string -> string -> proto_const_syntax option -> theory -> theory + val add_reserved: string -> string -> theory -> theory +end; + +structure Code_Target : CODE_TARGET = +struct + +open Basic_Code_Thingol; +open Code_Printer; + +(** basics **) + +datatype destination = Compile | Export | File of Path.T | String of string list; +type serialization = destination -> (string * string option list) option; + +val code_width = ref 80; (*FIXME after Pretty module no longer depends on print mode*) +fun code_setmp f = PrintMode.setmp [] (Pretty.setmp_margin (!code_width) f); +fun code_of_pretty p = code_setmp Pretty.string_of p ^ "\n"; +fun code_writeln p = Pretty.setmp_margin (!code_width) Pretty.writeln p; + +(*FIXME why another code_setmp?*) +fun compile f = (code_setmp f Compile; ()); +fun export f = (code_setmp f Export; ()); +fun file p f = (code_setmp f (File p); ()); +fun string stmts f = fst (the (code_setmp f (String stmts))); + +fun stmt_names_of_destination (String stmts) = stmts + | stmt_names_of_destination _ = []; + +fun mk_serialization target (SOME comp) _ _ code Compile = (comp code; NONE) + | mk_serialization target NONE _ _ _ Compile = error (target ^ ": no internal compilation") + | mk_serialization target _ output _ code Export = (output NONE code ; NONE) + | mk_serialization target _ output _ code (File file) = (output (SOME file) code; NONE) + | mk_serialization target _ _ string code (String _) = SOME (string code); + + +(** theory data **) + +datatype name_syntax_table = NameSyntaxTable of { + class: string Symtab.table, + instance: unit Symreltab.table, + tyco: tyco_syntax Symtab.table, + const: proto_const_syntax Symtab.table +}; + +fun mk_name_syntax_table ((class, instance), (tyco, const)) = + NameSyntaxTable { class = class, instance = instance, tyco = tyco, const = const }; +fun map_name_syntax_table f (NameSyntaxTable { class, instance, tyco, const }) = + mk_name_syntax_table (f ((class, instance), (tyco, const))); +fun merge_name_syntax_table (NameSyntaxTable { class = class1, instance = instance1, tyco = tyco1, const = const1 }, + NameSyntaxTable { class = class2, instance = instance2, tyco = tyco2, const = const2 }) = + mk_name_syntax_table ( + (Symtab.join (K snd) (class1, class2), + Symreltab.join (K snd) (instance1, instance2)), + (Symtab.join (K snd) (tyco1, tyco2), + Symtab.join (K snd) (const1, const2)) + ); + +type serializer = + string option (*module name*) + -> OuterLex.token list (*arguments*) + -> (string -> string) (*labelled_name*) + -> string list (*reserved symbols*) + -> (string * Pretty.T) list (*includes*) + -> (string -> string option) (*module aliasses*) + -> (string -> string option) (*class syntax*) + -> (string -> tyco_syntax option) + -> (string -> const_syntax option) + -> Code_Thingol.program + -> string list (*selected statements*) + -> serialization; + +datatype serializer_entry = Serializer of serializer * literals + | Extends of string * (Code_Thingol.naming -> Code_Thingol.program -> Code_Thingol.program); + +datatype target = Target of { + serial: serial, + serializer: serializer_entry, + reserved: string list, + includes: (Pretty.T * string list) Symtab.table, + name_syntax_table: name_syntax_table, + module_alias: string Symtab.table +}; + +fun make_target ((serial, serializer), ((reserved, includes), (name_syntax_table, module_alias))) = + Target { serial = serial, serializer = serializer, reserved = reserved, + includes = includes, name_syntax_table = name_syntax_table, module_alias = module_alias }; +fun map_target f ( Target { serial, serializer, reserved, includes, name_syntax_table, module_alias } ) = + make_target (f ((serial, serializer), ((reserved, includes), (name_syntax_table, module_alias)))); +fun merge_target strict target (Target { serial = serial1, serializer = serializer, + reserved = reserved1, includes = includes1, + name_syntax_table = name_syntax_table1, module_alias = module_alias1 }, + Target { serial = serial2, serializer = _, + reserved = reserved2, includes = includes2, + name_syntax_table = name_syntax_table2, module_alias = module_alias2 }) = + if serial1 = serial2 orelse not strict then + make_target ((serial1, serializer), + ((merge (op =) (reserved1, reserved2), Symtab.merge (op =) (includes1, includes2)), + (merge_name_syntax_table (name_syntax_table1, name_syntax_table2), + Symtab.join (K snd) (module_alias1, module_alias2)) + )) + else + error ("Incompatible serializers: " ^ quote target); + +structure CodeTargetData = TheoryDataFun +( + type T = target Symtab.table * string list; + val empty = (Symtab.empty, []); + val copy = I; + val extend = I; + fun merge _ ((target1, exc1) : T, (target2, exc2)) = + (Symtab.join (merge_target true) (target1, target2), Library.merge (op =) (exc1, exc2)); +); + +fun the_serializer (Target { serializer, ... }) = serializer; +fun the_reserved (Target { reserved, ... }) = reserved; +fun the_includes (Target { includes, ... }) = includes; +fun the_name_syntax (Target { name_syntax_table = NameSyntaxTable x, ... }) = x; +fun the_module_alias (Target { module_alias , ... }) = module_alias; + +val abort_allowed = snd o CodeTargetData.get; + +fun assert_target thy target = + case Symtab.lookup (fst (CodeTargetData.get thy)) target + of SOME data => target + | NONE => error ("Unknown code target language: " ^ quote target); + +fun put_target (target, seri) thy = + let + val lookup_target = Symtab.lookup (fst (CodeTargetData.get thy)); + val _ = case seri + of Extends (super, _) => if is_some (lookup_target super) then () + else error ("Unknown code target language: " ^ quote super) + | _ => (); + val overwriting = case (Option.map the_serializer o lookup_target) target + of NONE => false + | SOME (Extends _) => true + | SOME (Serializer _) => (case seri + of Extends _ => error ("Will not overwrite existing target " ^ quote target) + | _ => true); + val _ = if overwriting + then warning ("Overwriting existing target " ^ quote target) + else (); + in + thy + |> (CodeTargetData.map o apfst oo Symtab.map_default) + (target, make_target ((serial (), seri), (([], Symtab.empty), + (mk_name_syntax_table ((Symtab.empty, Symreltab.empty), (Symtab.empty, Symtab.empty)), + Symtab.empty)))) + ((map_target o apfst o apsnd o K) seri) + end; + +fun add_target (target, seri) = put_target (target, Serializer seri); +fun extend_target (target, (super, modify)) = + put_target (target, Extends (super, modify)); + +fun map_target_data target f thy = + let + val _ = assert_target thy target; + in + thy + |> (CodeTargetData.map o apfst o Symtab.map_entry target o map_target) f + end; + +fun map_reserved target = + map_target_data target o apsnd o apfst o apfst; +fun map_includes target = + map_target_data target o apsnd o apfst o apsnd; +fun map_name_syntax target = + map_target_data target o apsnd o apsnd o apfst o map_name_syntax_table; +fun map_module_alias target = + map_target_data target o apsnd o apsnd o apsnd; + + +(** serializer configuration **) + +(* data access *) + +local + +fun cert_class thy class = + let + val _ = AxClass.get_info thy class; + in class end; + +fun read_class thy = cert_class thy o Sign.intern_class thy; + +fun cert_tyco thy tyco = + let + val _ = if Sign.declared_tyname thy tyco then () + else error ("No such type constructor: " ^ quote tyco); + in tyco end; + +fun read_tyco thy = cert_tyco thy o Sign.intern_type thy; + +fun gen_add_syntax_class prep_class prep_const target raw_class raw_syn thy = + let + val class = prep_class thy raw_class; + in case raw_syn + of SOME syntax => + thy + |> (map_name_syntax target o apfst o apfst) + (Symtab.update (class, syntax)) + | NONE => + thy + |> (map_name_syntax target o apfst o apfst) + (Symtab.delete_safe class) + end; + +fun gen_add_syntax_inst prep_class prep_tyco target (raw_tyco, raw_class) add_del thy = + let + val inst = (prep_class thy raw_class, prep_tyco thy raw_tyco); + in if add_del then + thy + |> (map_name_syntax target o apfst o apsnd) + (Symreltab.update (inst, ())) + else + thy + |> (map_name_syntax target o apfst o apsnd) + (Symreltab.delete_safe inst) + end; + +fun gen_add_syntax_tyco prep_tyco target raw_tyco raw_syn thy = + let + val tyco = prep_tyco thy raw_tyco; + fun check_args (syntax as (n, _)) = if n <> Sign.arity_number thy tyco + then error ("Number of arguments mismatch in syntax for type constructor " ^ quote tyco) + else syntax + in case raw_syn + of SOME syntax => + thy + |> (map_name_syntax target o apsnd o apfst) + (Symtab.update (tyco, check_args syntax)) + | NONE => + thy + |> (map_name_syntax target o apsnd o apfst) + (Symtab.delete_safe tyco) + end; + +fun gen_add_syntax_const prep_const target raw_c raw_syn thy = + let + val c = prep_const thy raw_c; + fun check_args (syntax as (n, _)) = if n > Code.no_args thy c + then error ("Too many arguments in syntax for constant " ^ quote c) + else syntax; + in case raw_syn + of SOME syntax => + thy + |> (map_name_syntax target o apsnd o apsnd) + (Symtab.update (c, check_args syntax)) + | NONE => + thy + |> (map_name_syntax target o apsnd o apsnd) + (Symtab.delete_safe c) + end; + +fun add_reserved target = + let + fun add sym syms = if member (op =) syms sym + then error ("Reserved symbol " ^ quote sym ^ " already declared") + else insert (op =) sym syms + in map_reserved target o add end; + +fun gen_add_include read_const target args thy = + let + fun add (name, SOME (content, raw_cs)) incls = + let + val _ = if Symtab.defined incls name + then warning ("Overwriting existing include " ^ name) + else (); + val cs = map (read_const thy) raw_cs; + in Symtab.update (name, (str content, cs)) incls end + | add (name, NONE) incls = Symtab.delete name incls; + in map_includes target (add args) thy end; + +val add_include = gen_add_include Code.check_const; +val add_include_cmd = gen_add_include Code.read_const; + +fun add_module_alias target (thyname, modlname) = + let + val xs = Long_Name.explode modlname; + val xs' = map (Name.desymbolize true) xs; + in if xs' = xs + then map_module_alias target (Symtab.update (thyname, modlname)) + else error ("Invalid module name: " ^ quote modlname ^ "\n" + ^ "perhaps try " ^ quote (Long_Name.implode xs')) + end; + +fun gen_allow_abort prep_const raw_c thy = + let + val c = prep_const thy raw_c; + in thy |> (CodeTargetData.map o apsnd) (insert (op =) c) end; + +fun zip_list (x::xs) f g = + f + #-> (fn y => + fold_map (fn x => g |-- f >> pair x) xs + #-> (fn xys => pair ((x, y) :: xys))); + + +(* concrete syntax *) + +structure P = OuterParse +and K = OuterKeyword + +fun parse_multi_syntax parse_thing parse_syntax = + P.and_list1 parse_thing + #-> (fn things => Scan.repeat1 (P.$$$ "(" |-- P.name -- + (zip_list things parse_syntax (P.$$$ "and")) --| P.$$$ ")")); + +in + +val add_syntax_class = gen_add_syntax_class cert_class (K I); +val add_syntax_inst = gen_add_syntax_inst cert_class cert_tyco; +val add_syntax_tyco = gen_add_syntax_tyco cert_tyco; +val add_syntax_const = gen_add_syntax_const (K I); +val allow_abort = gen_allow_abort (K I); +val add_reserved = add_reserved; + +val add_syntax_class_cmd = gen_add_syntax_class read_class Code.read_const; +val add_syntax_inst_cmd = gen_add_syntax_inst read_class read_tyco; +val add_syntax_tyco_cmd = gen_add_syntax_tyco read_tyco; +val add_syntax_const_cmd = gen_add_syntax_const Code.read_const; +val allow_abort_cmd = gen_allow_abort Code.read_const; + +fun the_literals thy = + let + val (targets, _) = CodeTargetData.get thy; + fun literals target = case Symtab.lookup targets target + of SOME data => (case the_serializer data + of Serializer (_, literals) => literals + | Extends (super, _) => literals super) + | NONE => error ("Unknown code target language: " ^ quote target); + in literals end; + + +(** serializer usage **) + +(* montage *) + +local + +fun labelled_name thy program name = case Graph.get_node program name + of Code_Thingol.Fun (c, _) => quote (Code.string_of_const thy c) + | Code_Thingol.Datatype (tyco, _) => "type " ^ quote (Sign.extern_type thy tyco) + | Code_Thingol.Datatypecons (c, _) => quote (Code.string_of_const thy c) + | Code_Thingol.Class (class, _) => "class " ^ quote (Sign.extern_class thy class) + | Code_Thingol.Classrel (sub, super) => let + val Code_Thingol.Class (sub, _) = Graph.get_node program sub + val Code_Thingol.Class (super, _) = Graph.get_node program super + in quote (Sign.extern_class thy sub ^ " < " ^ Sign.extern_class thy super) end + | Code_Thingol.Classparam (c, _) => quote (Code.string_of_const thy c) + | Code_Thingol.Classinst ((class, (tyco, _)), _) => let + val Code_Thingol.Class (class, _) = Graph.get_node program class + val Code_Thingol.Datatype (tyco, _) = Graph.get_node program tyco + in quote (Sign.extern_type thy tyco ^ " :: " ^ Sign.extern_class thy class) end + +fun activate_syntax lookup_name src_tab = Symtab.empty + |> fold_map (fn thing_identifier => fn tab => case lookup_name thing_identifier + of SOME name => (SOME name, + Symtab.update_new (name, the (Symtab.lookup src_tab thing_identifier)) tab) + | NONE => (NONE, tab)) (Symtab.keys src_tab) + |>> map_filter I; + +fun activate_const_syntax thy literals src_tab naming = (Symtab.empty, naming) + |> fold_map (fn thing_identifier => fn (tab, naming) => + case Code_Thingol.lookup_const naming thing_identifier + of SOME name => let + val (syn, naming') = Code_Printer.activate_const_syntax thy + literals (the (Symtab.lookup src_tab thing_identifier)) naming + in (SOME name, (Symtab.update_new (name, syn) tab, naming')) end + | NONE => (NONE, (tab, naming))) (Symtab.keys src_tab) + |>> map_filter I; + +fun invoke_serializer thy abortable serializer literals reserved abs_includes + module_alias class instance tyco const module args naming program2 names1 = + let + val (names_class, class') = + activate_syntax (Code_Thingol.lookup_class naming) class; + val names_inst = map_filter (Code_Thingol.lookup_instance naming) + (Symreltab.keys instance); + val (names_tyco, tyco') = + activate_syntax (Code_Thingol.lookup_tyco naming) tyco; + val (names_const, (const', _)) = + activate_const_syntax thy literals const naming; + val names_hidden = names_class @ names_inst @ names_tyco @ names_const; + val names2 = subtract (op =) names_hidden names1; + val program3 = Graph.subgraph (not o member (op =) names_hidden) program2; + val names_all = Graph.all_succs program3 names2; + val includes = abs_includes names_all; + val program4 = Graph.subgraph (member (op =) names_all) program3; + val empty_funs = filter_out (member (op =) abortable) + (Code_Thingol.empty_funs program3); + val _ = if null empty_funs then () else error ("No code equations for " + ^ commas (map (Sign.extern_const thy) empty_funs)); + in + serializer module args (labelled_name thy program2) reserved includes + (Symtab.lookup module_alias) (Symtab.lookup class') + (Symtab.lookup tyco') (Symtab.lookup const') + program4 names2 + end; + +fun mount_serializer thy alt_serializer target module args naming program names = + let + val (targets, abortable) = CodeTargetData.get thy; + fun collapse_hierarchy target = + let + val data = case Symtab.lookup targets target + of SOME data => data + | NONE => error ("Unknown code target language: " ^ quote target); + in case the_serializer data + of Serializer _ => (I, data) + | Extends (super, modify) => let + val (modify', data') = collapse_hierarchy super + in (modify' #> modify naming, merge_target false target (data', data)) end + end; + val (modify, data) = collapse_hierarchy target; + val (serializer, _) = the_default (case the_serializer data + of Serializer seri => seri) alt_serializer; + val reserved = the_reserved data; + fun select_include names_all (name, (content, cs)) = + if null cs then SOME (name, content) + else if exists (fn c => case Code_Thingol.lookup_const naming c + of SOME name => member (op =) names_all name + | NONE => false) cs + then SOME (name, content) else NONE; + fun includes names_all = map_filter (select_include names_all) + ((Symtab.dest o the_includes) data); + val module_alias = the_module_alias data; + val { class, instance, tyco, const } = the_name_syntax data; + val literals = the_literals thy target; + in + invoke_serializer thy abortable serializer literals reserved + includes module_alias class instance tyco const module args naming (modify program) names + end; + +in + +fun serialize thy = mount_serializer thy NONE; + +fun serialize_custom thy (target_name, seri) naming program names = + mount_serializer thy (SOME seri) target_name NONE [] naming program names (String []) + |> the; + +end; (* local *) + +fun parse_args f args = + case Scan.read OuterLex.stopper f args + of SOME x => x + | NONE => error "Bad serializer arguments"; + + +(* code presentation *) + +fun code_of thy target module_name cs names_stmt = + let + val (names_cs, (naming, program)) = Code_Thingol.consts_program thy cs; + in + string (names_stmt naming) (serialize thy target (SOME module_name) [] + naming program names_cs) + end; + + +(* code generation *) + +fun transitivly_non_empty_funs thy naming program = + let + val cs = subtract (op =) (abort_allowed thy) (Code_Thingol.empty_funs program); + val names = map_filter (Code_Thingol.lookup_const naming) cs; + in subtract (op =) (Graph.all_preds program names) (Graph.keys program) end; + +fun read_const_exprs thy cs = + let + val (cs1, cs2) = Code_Thingol.read_const_exprs thy cs; + val (names3, (naming, program)) = Code_Thingol.consts_program thy cs2; + val names4 = transitivly_non_empty_funs thy naming program; + val cs5 = map_filter + (fn (c, name) => if member (op =) names4 name then SOME c else NONE) (cs2 ~~ names3); + in fold (insert (op =)) cs5 cs1 end; + +fun cached_program thy = + let + val (naming, program) = Code_Thingol.cached_program thy; + in (transitivly_non_empty_funs thy naming program, (naming, program)) end + +fun export_code thy cs seris = + let + val (cs', (naming, program)) = if null cs then cached_program thy + else Code_Thingol.consts_program thy cs; + fun mk_seri_dest dest = case dest + of NONE => compile + | SOME "-" => export + | SOME f => file (Path.explode f) + val _ = map (fn (((target, module), dest), args) => + (mk_seri_dest dest (serialize thy target module args naming program cs'))) seris; + in () end; + +fun export_code_cmd raw_cs seris thy = export_code thy (read_const_exprs thy raw_cs) seris; + + +(** Isar setup **) + +val (inK, module_nameK, fileK) = ("in", "module_name", "file"); + +val code_exprP = + (Scan.repeat P.term_group + -- Scan.repeat (P.$$$ inK |-- P.name + -- Scan.option (P.$$$ module_nameK |-- P.name) + -- Scan.option (P.$$$ fileK |-- P.name) + -- Scan.optional (P.$$$ "(" |-- Args.parse --| P.$$$ ")") [] + ) >> (fn (raw_cs, seris) => export_code_cmd raw_cs seris)); + +val _ = List.app OuterKeyword.keyword [inK, module_nameK, fileK]; + +val _ = + OuterSyntax.command "code_class" "define code syntax for class" K.thy_decl ( + parse_multi_syntax P.xname (Scan.option P.string) + >> (Toplevel.theory oo fold) (fn (target, syns) => + fold (fn (raw_class, syn) => add_syntax_class_cmd target raw_class syn) syns) + ); + +val _ = + OuterSyntax.command "code_instance" "define code syntax for instance" K.thy_decl ( + parse_multi_syntax (P.xname --| P.$$$ "::" -- P.xname) + ((P.minus >> K true) || Scan.succeed false) + >> (Toplevel.theory oo fold) (fn (target, syns) => + fold (fn (raw_inst, add_del) => add_syntax_inst_cmd target raw_inst add_del) syns) + ); + +val _ = + OuterSyntax.command "code_type" "define code syntax for type constructor" K.thy_decl ( + parse_multi_syntax P.xname (parse_syntax I) + >> (Toplevel.theory oo fold) (fn (target, syns) => + fold (fn (raw_tyco, syn) => add_syntax_tyco_cmd target raw_tyco syn) syns) + ); + +val _ = + OuterSyntax.command "code_const" "define code syntax for constant" K.thy_decl ( + parse_multi_syntax P.term_group (parse_syntax fst) + >> (Toplevel.theory oo fold) (fn (target, syns) => + fold (fn (raw_const, syn) => add_syntax_const_cmd target raw_const + (Code_Printer.simple_const_syntax syn)) syns) + ); + +val _ = + OuterSyntax.command "code_reserved" "declare words as reserved for target language" K.thy_decl ( + P.name -- Scan.repeat1 P.name + >> (fn (target, reserveds) => (Toplevel.theory o fold (add_reserved target)) reserveds) + ); + +val _ = + OuterSyntax.command "code_include" "declare piece of code to be included in generated code" K.thy_decl ( + P.name -- P.name -- (P.text :|-- (fn "-" => Scan.succeed NONE + | s => Scan.optional (P.$$$ "attach" |-- Scan.repeat1 P.term) [] >> pair s >> SOME)) + >> (fn ((target, name), content_consts) => + (Toplevel.theory o add_include_cmd target) (name, content_consts)) + ); + +val _ = + OuterSyntax.command "code_modulename" "alias module to other name" K.thy_decl ( + P.name -- Scan.repeat1 (P.name -- P.name) + >> (fn (target, modlnames) => (Toplevel.theory o fold (add_module_alias target)) modlnames) + ); + +val _ = + OuterSyntax.command "code_abort" "permit constant to be implemented as program abort" K.thy_decl ( + Scan.repeat1 P.term_group >> (Toplevel.theory o fold allow_abort_cmd) + ); + +val _ = + OuterSyntax.command "export_code" "generate executable code for constants" + K.diag (P.!!! code_exprP >> (fn f => Toplevel.keep (f o Toplevel.theory_of))); + +fun shell_command thyname cmd = Toplevel.program (fn _ => + (use_thy thyname; case Scan.read OuterLex.stopper (P.!!! code_exprP) + ((filter OuterLex.is_proper o OuterSyntax.scan Position.none) cmd) + of SOME f => (writeln "Now generating code..."; f (theory thyname)) + | NONE => error ("Bad directive " ^ quote cmd))) + handle TOPLEVEL_ERROR => OS.Process.exit OS.Process.failure; + +end; (*local*) + +end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code/code_thingol.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/Tools/Code/code_thingol.ML Tue Jun 23 12:09:30 2009 +0200 @@ -0,0 +1,876 @@ +(* Title: Tools/code/code_thingol.ML + Author: Florian Haftmann, TU Muenchen + +Intermediate language ("Thin-gol") representing executable code. +Representation and translation. +*) + +infix 8 `%%; +infix 4 `$; +infix 4 `$$; +infixr 3 `|=>; +infixr 3 `|==>; + +signature BASIC_CODE_THINGOL = +sig + type vname = string; + datatype dict = + DictConst of string * dict list list + | DictVar of string list * (vname * (int * int)); + datatype itype = + `%% of string * itype list + | ITyVar of vname; + type const = string * ((itype list * dict list list) * itype list (*types of arguments*)) + datatype iterm = + IConst of const + | IVar of vname + | `$ of iterm * iterm + | `|=> of (vname * itype) * iterm + | ICase of ((iterm * itype) * (iterm * iterm) list) * iterm; + (*((term, type), [(selector pattern, body term )]), primitive term)*) + val `$$ : iterm * iterm list -> iterm; + val `|==> : (vname * itype) list * iterm -> iterm; + type typscheme = (vname * sort) list * itype; +end; + +signature CODE_THINGOL = +sig + include BASIC_CODE_THINGOL + val unfoldl: ('a -> ('a * 'b) option) -> 'a -> 'a * 'b list + val unfoldr: ('a -> ('b * 'a) option) -> 'a -> 'b list * 'a + val unfold_fun: itype -> itype list * itype + val unfold_app: iterm -> iterm * iterm list + val split_abs: iterm -> (((vname * iterm option) * itype) * iterm) option + val unfold_abs: iterm -> ((vname * iterm option) * itype) list * iterm + val split_let: iterm -> (((iterm * itype) * iterm) * iterm) option + val unfold_let: iterm -> ((iterm * itype) * iterm) list * iterm + val unfold_const_app: iterm -> (const * iterm list) option + val collapse_let: ((vname * itype) * iterm) * iterm + -> (iterm * itype) * (iterm * iterm) list + val eta_expand: int -> const * iterm list -> iterm + val contains_dictvar: iterm -> bool + val locally_monomorphic: iterm -> bool + val fold_constnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a + val fold_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a + val fold_unbound_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a + + type naming + val empty_naming: naming + val lookup_class: naming -> class -> string option + val lookup_classrel: naming -> class * class -> string option + val lookup_tyco: naming -> string -> string option + val lookup_instance: naming -> class * string -> string option + val lookup_const: naming -> string -> string option + val ensure_declared_const: theory -> string -> naming -> string * naming + + datatype stmt = + NoStmt + | Fun of string * (typscheme * ((iterm list * iterm) * (thm * bool)) list) + | Datatype of string * ((vname * sort) list * (string * itype list) list) + | Datatypecons of string * string + | Class of class * (vname * ((class * string) list * (string * itype) list)) + | Classrel of class * class + | Classparam of string * class + | Classinst of (class * (string * (vname * sort) list)) + * ((class * (string * (string * dict list list))) list + * ((string * const) * (thm * bool)) list) + type program = stmt Graph.T + val empty_funs: program -> string list + val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm + val map_terms_stmt: (iterm -> iterm) -> stmt -> stmt + val is_cons: program -> string -> bool + val contr_classparam_typs: program -> string -> itype option list + + val read_const_exprs: theory -> string list -> string list * string list + val consts_program: theory -> string list -> string list * (naming * program) + val cached_program: theory -> naming * program + val eval_conv: theory -> (sort -> sort) + -> (naming -> program -> ((string * sort) list * typscheme) * iterm -> string list -> cterm -> thm) + -> cterm -> thm + val eval: theory -> (sort -> sort) -> ((term -> term) -> 'a -> 'a) + -> (naming -> program -> ((string * sort) list * typscheme) * iterm -> string list -> 'a) + -> term -> 'a +end; + +structure Code_Thingol: CODE_THINGOL = +struct + +(** auxiliary **) + +fun unfoldl dest x = + case dest x + of NONE => (x, []) + | SOME (x1, x2) => + let val (x', xs') = unfoldl dest x1 in (x', xs' @ [x2]) end; + +fun unfoldr dest x = + case dest x + of NONE => ([], x) + | SOME (x1, x2) => + let val (xs', x') = unfoldr dest x2 in (x1::xs', x') end; + + +(** language core - types, terms **) + +type vname = string; + +datatype dict = + DictConst of string * dict list list + | DictVar of string list * (vname * (int * int)); + +datatype itype = + `%% of string * itype list + | ITyVar of vname; + +type const = string * ((itype list * dict list list) * itype list (*types of arguments*)) + +datatype iterm = + IConst of const + | IVar of vname + | `$ of iterm * iterm + | `|=> of (vname * itype) * iterm + | ICase of ((iterm * itype) * (iterm * iterm) list) * iterm; + (*see also signature*) + +val op `$$ = Library.foldl (op `$); +val op `|==> = Library.foldr (op `|=>); + +val unfold_app = unfoldl + (fn op `$ t => SOME t + | _ => NONE); + +val split_abs = + (fn (v, ty) `|=> (t as ICase (((IVar w, _), [(p, t')]), _)) => + if v = w then SOME (((v, SOME p), ty), t') else SOME (((v, NONE), ty), t) + | (v, ty) `|=> t => SOME (((v, NONE), ty), t) + | _ => NONE); + +val unfold_abs = unfoldr split_abs; + +val split_let = + (fn ICase (((td, ty), [(p, t)]), _) => SOME (((p, ty), td), t) + | _ => NONE); + +val unfold_let = unfoldr split_let; + +fun unfold_const_app t = + case unfold_app t + of (IConst c, ts) => SOME (c, ts) + | _ => NONE; + +fun fold_aiterms f (t as IConst _) = f t + | fold_aiterms f (t as IVar _) = f t + | fold_aiterms f (t1 `$ t2) = fold_aiterms f t1 #> fold_aiterms f t2 + | fold_aiterms f (t as _ `|=> t') = f t #> fold_aiterms f t' + | fold_aiterms f (ICase (_, t)) = fold_aiterms f t; + +fun fold_constnames f = + let + fun add (IConst (c, _)) = f c + | add _ = I; + in fold_aiterms add end; + +fun fold_varnames f = + let + fun add (IVar v) = f v + | add ((v, _) `|=> _) = f v + | add _ = I; + in fold_aiterms add end; + +fun fold_unbound_varnames f = + let + fun add _ (IConst _) = I + | add vs (IVar v) = if not (member (op =) vs v) then f v else I + | add vs (t1 `$ t2) = add vs t1 #> add vs t2 + | add vs ((v, _) `|=> t) = add (insert (op =) v vs) t + | add vs (ICase (_, t)) = add vs t; + in add [] end; + +fun collapse_let (((v, ty), se), be as ICase (((IVar w, _), ds), _)) = + let + fun exists_v t = fold_unbound_varnames (fn w => fn b => + b orelse v = w) t false; + in if v = w andalso forall (fn (t1, t2) => + exists_v t1 orelse not (exists_v t2)) ds + then ((se, ty), ds) + else ((se, ty), [(IVar v, be)]) + end + | collapse_let (((v, ty), se), be) = + ((se, ty), [(IVar v, be)]) + +fun eta_expand k (c as (_, (_, tys)), ts) = + let + val j = length ts; + val l = k - j; + val ctxt = (fold o fold_varnames) Name.declare ts Name.context; + val vs_tys = Name.names ctxt "a" ((curry Library.take l o curry Library.drop j) tys); + in vs_tys `|==> IConst c `$$ ts @ map (fn (v, _) => IVar v) vs_tys end; + +fun contains_dictvar t = + let + fun contains (DictConst (_, dss)) = (fold o fold) contains dss + | contains (DictVar _) = K true; + in + fold_aiterms + (fn IConst (_, ((_, dss), _)) => (fold o fold) contains dss | _ => I) t false + end; + +fun locally_monomorphic (IConst _) = false + | locally_monomorphic (IVar _) = true + | locally_monomorphic (t `$ _) = locally_monomorphic t + | locally_monomorphic (_ `|=> t) = locally_monomorphic t + | locally_monomorphic (ICase ((_, ds), _)) = exists (locally_monomorphic o snd) ds; + + +(** namings **) + +(* policies *) + +local + fun thyname_of thy f x = the (AList.lookup (op =) (f x) Markup.theory_nameN); + fun thyname_of_class thy = + thyname_of thy (ProofContext.query_class (ProofContext.init thy)); + fun thyname_of_tyco thy = + thyname_of thy (Type.the_tags (Sign.tsig_of thy)); + fun thyname_of_instance thy inst = case AxClass.arity_property thy inst Markup.theory_nameN + of [] => error ("no such instance: " ^ quote (snd inst ^ " :: " ^ fst inst)) + | thyname :: _ => thyname; + fun thyname_of_const thy c = case AxClass.class_of_param thy c + of SOME class => thyname_of_class thy class + | NONE => (case Code.get_datatype_of_constr thy c + of SOME dtco => thyname_of_tyco thy dtco + | NONE => thyname_of thy (Consts.the_tags (Sign.consts_of thy)) c); + fun purify_base "op &" = "and" + | purify_base "op |" = "or" + | purify_base "op -->" = "implies" + | purify_base "op :" = "member" + | purify_base "op =" = "eq" + | purify_base "*" = "product" + | purify_base "+" = "sum" + | purify_base s = Name.desymbolize false s; + fun namify thy get_basename get_thyname name = + let + val prefix = get_thyname thy name; + val base = (purify_base o get_basename) name; + in Long_Name.append prefix base end; +in + +fun namify_class thy = namify thy Long_Name.base_name thyname_of_class; +fun namify_classrel thy = namify thy (fn (class1, class2) => + Long_Name.base_name class2 ^ "_" ^ Long_Name.base_name class1) (fn thy => thyname_of_class thy o fst); + (*order fits nicely with composed projections*) +fun namify_tyco thy "fun" = "Pure.fun" + | namify_tyco thy tyco = namify thy Long_Name.base_name thyname_of_tyco tyco; +fun namify_instance thy = namify thy (fn (class, tyco) => + Long_Name.base_name class ^ "_" ^ Long_Name.base_name tyco) thyname_of_instance; +fun namify_const thy = namify thy Long_Name.base_name thyname_of_const; + +end; (* local *) + + +(* data *) + +datatype naming = Naming of { + class: class Symtab.table * Name.context, + classrel: string Symreltab.table * Name.context, + tyco: string Symtab.table * Name.context, + instance: string Symreltab.table * Name.context, + const: string Symtab.table * Name.context +} + +fun dest_Naming (Naming naming) = naming; + +val empty_naming = Naming { + class = (Symtab.empty, Name.context), + classrel = (Symreltab.empty, Name.context), + tyco = (Symtab.empty, Name.context), + instance = (Symreltab.empty, Name.context), + const = (Symtab.empty, Name.context) +}; + +local + fun mk_naming (class, classrel, tyco, instance, const) = + Naming { class = class, classrel = classrel, + tyco = tyco, instance = instance, const = const }; + fun map_naming f (Naming { class, classrel, tyco, instance, const }) = + mk_naming (f (class, classrel, tyco, instance, const)); +in + fun map_class f = map_naming + (fn (class, classrel, tyco, inst, const) => + (f class, classrel, tyco, inst, const)); + fun map_classrel f = map_naming + (fn (class, classrel, tyco, inst, const) => + (class, f classrel, tyco, inst, const)); + fun map_tyco f = map_naming + (fn (class, classrel, tyco, inst, const) => + (class, classrel, f tyco, inst, const)); + fun map_instance f = map_naming + (fn (class, classrel, tyco, inst, const) => + (class, classrel, tyco, f inst, const)); + fun map_const f = map_naming + (fn (class, classrel, tyco, inst, const) => + (class, classrel, tyco, inst, f const)); +end; (*local*) + +fun add_variant update (thing, name) (tab, used) = + let + val (name', used') = yield_singleton Name.variants name used; + val tab' = update (thing, name') tab; + in (tab', used') end; + +fun declare thy mapp lookup update namify thing = + mapp (add_variant update (thing, namify thy thing)) + #> `(fn naming => the (lookup naming thing)); + + +(* lookup and declare *) + +local + +val suffix_class = "class"; +val suffix_classrel = "classrel" +val suffix_tyco = "tyco"; +val suffix_instance = "inst"; +val suffix_const = "const"; + +fun add_suffix nsp NONE = NONE + | add_suffix nsp (SOME name) = SOME (Long_Name.append name nsp); + +in + +val lookup_class = add_suffix suffix_class + oo Symtab.lookup o fst o #class o dest_Naming; +val lookup_classrel = add_suffix suffix_classrel + oo Symreltab.lookup o fst o #classrel o dest_Naming; +val lookup_tyco = add_suffix suffix_tyco + oo Symtab.lookup o fst o #tyco o dest_Naming; +val lookup_instance = add_suffix suffix_instance + oo Symreltab.lookup o fst o #instance o dest_Naming; +val lookup_const = add_suffix suffix_const + oo Symtab.lookup o fst o #const o dest_Naming; + +fun declare_class thy = declare thy map_class + lookup_class Symtab.update_new namify_class; +fun declare_classrel thy = declare thy map_classrel + lookup_classrel Symreltab.update_new namify_classrel; +fun declare_tyco thy = declare thy map_tyco + lookup_tyco Symtab.update_new namify_tyco; +fun declare_instance thy = declare thy map_instance + lookup_instance Symreltab.update_new namify_instance; +fun declare_const thy = declare thy map_const + lookup_const Symtab.update_new namify_const; + +fun ensure_declared_const thy const naming = + case lookup_const naming const + of SOME const' => (const', naming) + | NONE => declare_const thy const naming; + +val unfold_fun = unfoldr + (fn "Pure.fun.tyco" `%% [ty1, ty2] => SOME (ty1, ty2) + | _ => NONE); (*depends on suffix_tyco and namify_tyco!*) + +end; (* local *) + + +(** statements, abstract programs **) + +type typscheme = (vname * sort) list * itype; +datatype stmt = + NoStmt + | Fun of string * (typscheme * ((iterm list * iterm) * (thm * bool)) list) + | Datatype of string * ((vname * sort) list * (string * itype list) list) + | Datatypecons of string * string + | Class of class * (vname * ((class * string) list * (string * itype) list)) + | Classrel of class * class + | Classparam of string * class + | Classinst of (class * (string * (vname * sort) list)) + * ((class * (string * (string * dict list list))) list + * ((string * const) * (thm * bool)) list); + +type program = stmt Graph.T; + +fun empty_funs program = + Graph.fold (fn (name, (Fun (c, (_, [])), _)) => cons c + | _ => I) program []; + +fun map_terms_bottom_up f (t as IConst _) = f t + | map_terms_bottom_up f (t as IVar _) = f t + | map_terms_bottom_up f (t1 `$ t2) = f + (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2) + | map_terms_bottom_up f ((v, ty) `|=> t) = f + ((v, ty) `|=> map_terms_bottom_up f t) + | map_terms_bottom_up f (ICase (((t, ty), ps), t0)) = f + (ICase (((map_terms_bottom_up f t, ty), (map o pairself) + (map_terms_bottom_up f) ps), map_terms_bottom_up f t0)); + +fun map_terms_stmt f NoStmt = NoStmt + | map_terms_stmt f (Fun (c, (tysm, eqs))) = Fun (c, (tysm, (map o apfst) + (fn (ts, t) => (map f ts, f t)) eqs)) + | map_terms_stmt f (stmt as Datatype _) = stmt + | map_terms_stmt f (stmt as Datatypecons _) = stmt + | map_terms_stmt f (stmt as Class _) = stmt + | map_terms_stmt f (stmt as Classrel _) = stmt + | map_terms_stmt f (stmt as Classparam _) = stmt + | map_terms_stmt f (Classinst (arity, (superarities, classparms))) = + Classinst (arity, (superarities, (map o apfst o apsnd) (fn const => + case f (IConst const) of IConst const' => const') classparms)); + +fun is_cons program name = case Graph.get_node program name + of Datatypecons _ => true + | _ => false; + +fun contr_classparam_typs program name = case Graph.get_node program name + of Classparam (_, class) => let + val Class (_, (_, (_, params))) = Graph.get_node program class; + val SOME ty = AList.lookup (op =) params name; + val (tys, res_ty) = unfold_fun ty; + fun no_tyvar (_ `%% tys) = forall no_tyvar tys + | no_tyvar (ITyVar _) = false; + in if no_tyvar res_ty + then map (fn ty => if no_tyvar ty then NONE else SOME ty) tys + else [] + end + | _ => []; + + +(** translation kernel **) + +(* generic mechanisms *) + +fun ensure_stmt lookup declare generate thing (dep, (naming, program)) = + let + fun add_dep name = case dep of NONE => I + | SOME dep => Graph.add_edge (dep, name); + val (name, naming') = case lookup naming thing + of SOME name => (name, naming) + | NONE => declare thing naming; + in case try (Graph.get_node program) name + of SOME stmt => program + |> add_dep name + |> pair naming' + |> pair dep + |> pair name + | NONE => program + |> Graph.default_node (name, NoStmt) + |> add_dep name + |> pair naming' + |> curry generate (SOME name) + ||> snd + |-> (fn stmt => (apsnd o Graph.map_node name) (K stmt)) + |> pair dep + |> pair name + end; + +fun not_wellsorted thy thm ty sort e = + let + val err_class = Sorts.class_error (Syntax.pp_global thy) e; + val err_thm = case thm + of SOME thm => "\n(in code equation " ^ Display.string_of_thm thm ^ ")" | NONE => ""; + val err_typ = "Type " ^ Syntax.string_of_typ_global thy ty ^ " not of sort " + ^ Syntax.string_of_sort_global thy sort; + in error ("Wellsortedness error" ^ err_thm ^ ":\n" ^ err_typ ^ "\n" ^ err_class) end; + + +(* translation *) + +fun ensure_tyco thy algbr funcgr tyco = + let + val stmt_datatype = + let + val (vs, cos) = Code.get_datatype thy tyco; + in + fold_map (translate_tyvar_sort thy algbr funcgr) vs + ##>> fold_map (fn (c, tys) => + ensure_const thy algbr funcgr c + ##>> fold_map (translate_typ thy algbr funcgr) tys) cos + #>> (fn info => Datatype (tyco, info)) + end; + in ensure_stmt lookup_tyco (declare_tyco thy) stmt_datatype tyco end +and ensure_const thy algbr funcgr c = + let + fun stmt_datatypecons tyco = + ensure_tyco thy algbr funcgr tyco + #>> (fn tyco => Datatypecons (c, tyco)); + fun stmt_classparam class = + ensure_class thy algbr funcgr class + #>> (fn class => Classparam (c, class)); + fun stmt_fun ((vs, ty), raw_thms) = + let + val thms = if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty + then raw_thms + else (map o apfst) (Code.expand_eta thy 1) raw_thms; + in + fold_map (translate_tyvar_sort thy algbr funcgr) vs + ##>> translate_typ thy algbr funcgr ty + ##>> fold_map (translate_eq thy algbr funcgr) thms + #>> (fn info => Fun (c, info)) + end; + val stmt_const = case Code.get_datatype_of_constr thy c + of SOME tyco => stmt_datatypecons tyco + | NONE => (case AxClass.class_of_param thy c + of SOME class => stmt_classparam class + | NONE => stmt_fun (Code_Preproc.typ funcgr c, Code_Preproc.eqns funcgr c)) + in ensure_stmt lookup_const (declare_const thy) stmt_const c end +and ensure_class thy (algbr as (_, algebra)) funcgr class = + let + val superclasses = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class; + val cs = #params (AxClass.get_info thy class); + val stmt_class = + fold_map (fn superclass => ensure_class thy algbr funcgr superclass + ##>> ensure_classrel thy algbr funcgr (class, superclass)) superclasses + ##>> fold_map (fn (c, ty) => ensure_const thy algbr funcgr c + ##>> translate_typ thy algbr funcgr ty) cs + #>> (fn info => Class (class, (unprefix "'" Name.aT, info))) + in ensure_stmt lookup_class (declare_class thy) stmt_class class end +and ensure_classrel thy algbr funcgr (subclass, superclass) = + let + val stmt_classrel = + ensure_class thy algbr funcgr subclass + ##>> ensure_class thy algbr funcgr superclass + #>> Classrel; + in ensure_stmt lookup_classrel (declare_classrel thy) stmt_classrel (subclass, superclass) end +and ensure_inst thy (algbr as (_, algebra)) funcgr (class, tyco) = + let + val superclasses = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class; + val classparams = these (try (#params o AxClass.get_info thy) class); + val vs = Name.names Name.context "'a" (Sorts.mg_domain algebra tyco [class]); + val sorts' = Sorts.mg_domain (Sign.classes_of thy) tyco [class]; + val vs' = map2 (fn (v, sort1) => fn sort2 => (v, + Sorts.inter_sort (Sign.classes_of thy) (sort1, sort2))) vs sorts'; + val arity_typ = Type (tyco, map TFree vs); + val arity_typ' = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vs'); + fun translate_superarity superclass = + ensure_class thy algbr funcgr superclass + ##>> ensure_classrel thy algbr funcgr (class, superclass) + ##>> translate_dicts thy algbr funcgr NONE (arity_typ, [superclass]) + #>> (fn ((superclass, classrel), [DictConst (inst, dss)]) => + (superclass, (classrel, (inst, dss)))); + fun translate_classparam_inst (c, ty) = + let + val c_inst = Const (c, map_type_tfree (K arity_typ') ty); + val thm = AxClass.unoverload_conv thy (Thm.cterm_of thy c_inst); + val c_ty = (apsnd Logic.unvarifyT o dest_Const o snd + o Logic.dest_equals o Thm.prop_of) thm; + in + ensure_const thy algbr funcgr c + ##>> translate_const thy algbr funcgr (SOME thm) c_ty + #>> (fn (c, IConst c_inst) => ((c, c_inst), (thm, true))) + end; + val stmt_inst = + ensure_class thy algbr funcgr class + ##>> ensure_tyco thy algbr funcgr tyco + ##>> fold_map (translate_tyvar_sort thy algbr funcgr) vs + ##>> fold_map translate_superarity superclasses + ##>> fold_map translate_classparam_inst classparams + #>> (fn ((((class, tyco), arity), superarities), classparams) => + Classinst ((class, (tyco, arity)), (superarities, classparams))); + in ensure_stmt lookup_instance (declare_instance thy) stmt_inst (class, tyco) end +and translate_typ thy algbr funcgr (TFree (v, _)) = + pair (ITyVar (unprefix "'" v)) + | translate_typ thy algbr funcgr (Type (tyco, tys)) = + ensure_tyco thy algbr funcgr tyco + ##>> fold_map (translate_typ thy algbr funcgr) tys + #>> (fn (tyco, tys) => tyco `%% tys) +and translate_term thy algbr funcgr thm (Const (c, ty)) = + translate_app thy algbr funcgr thm ((c, ty), []) + | translate_term thy algbr funcgr thm (Free (v, _)) = + pair (IVar v) + | translate_term thy algbr funcgr thm (Abs (abs as (_, ty, _))) = + let + val (v, t) = Syntax.variant_abs abs; + in + translate_typ thy algbr funcgr ty + ##>> translate_term thy algbr funcgr thm t + #>> (fn (ty, t) => (v, ty) `|=> t) + end + | translate_term thy algbr funcgr thm (t as _ $ _) = + case strip_comb t + of (Const (c, ty), ts) => + translate_app thy algbr funcgr thm ((c, ty), ts) + | (t', ts) => + translate_term thy algbr funcgr thm t' + ##>> fold_map (translate_term thy algbr funcgr thm) ts + #>> (fn (t, ts) => t `$$ ts) +and translate_eq thy algbr funcgr (thm, proper) = + let + val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals + o Logic.unvarify o prop_of) thm; + in + fold_map (translate_term thy algbr funcgr (SOME thm)) args + ##>> translate_term thy algbr funcgr (SOME thm) rhs + #>> rpair (thm, proper) + end +and translate_const thy algbr funcgr thm (c, ty) = + let + val tys = Sign.const_typargs thy (c, ty); + val sorts = (map snd o fst o Code_Preproc.typ funcgr) c; + val tys_args = (fst o Term.strip_type) ty; + in + ensure_const thy algbr funcgr c + ##>> fold_map (translate_typ thy algbr funcgr) tys + ##>> fold_map (translate_dicts thy algbr funcgr thm) (tys ~~ sorts) + ##>> fold_map (translate_typ thy algbr funcgr) tys_args + #>> (fn (((c, tys), iss), tys_args) => IConst (c, ((tys, iss), tys_args))) + end +and translate_app_const thy algbr funcgr thm (c_ty, ts) = + translate_const thy algbr funcgr thm c_ty + ##>> fold_map (translate_term thy algbr funcgr thm) ts + #>> (fn (t, ts) => t `$$ ts) +and translate_case thy algbr funcgr thm (num_args, (t_pos, case_pats)) (c_ty, ts) = + let + val (tys, _) = (chop num_args o fst o strip_type o snd) c_ty; + val t = nth ts t_pos; + val ty = nth tys t_pos; + val ts_clause = nth_drop t_pos ts; + fun mk_clause (co, num_co_args) t = + let + val (vs, body) = Term.strip_abs_eta num_co_args t; + val not_undefined = case body + of (Const (c, _)) => not (Code.is_undefined thy c) + | _ => true; + val pat = list_comb (Const (co, map snd vs ---> ty), map Free vs); + in (not_undefined, (pat, body)) end; + val clauses = if null case_pats then let val ([v_ty], body) = + Term.strip_abs_eta 1 (the_single ts_clause) + in [(true, (Free v_ty, body))] end + else map (uncurry mk_clause) + (AList.make (Code.no_args thy) case_pats ~~ ts_clause); + fun retermify ty (_, (IVar x, body)) = + (x, ty) `|=> body + | retermify _ (_, (pat, body)) = + let + val (IConst (_, (_, tys)), ts) = unfold_app pat; + val vs = map2 (fn IVar x => fn ty => (x, ty)) ts tys; + in vs `|==> body end; + fun mk_icase const t ty clauses = + let + val (ts1, ts2) = chop t_pos (map (retermify ty) clauses); + in + ICase (((t, ty), map_filter (fn (b, d) => if b then SOME d else NONE) clauses), + const `$$ (ts1 @ t :: ts2)) + end; + in + translate_const thy algbr funcgr thm c_ty + ##>> translate_term thy algbr funcgr thm t + ##>> translate_typ thy algbr funcgr ty + ##>> fold_map (fn (b, (pat, body)) => translate_term thy algbr funcgr thm pat + ##>> translate_term thy algbr funcgr thm body + #>> pair b) clauses + #>> (fn (((const, t), ty), ds) => mk_icase const t ty ds) + end +and translate_app_case thy algbr funcgr thm (case_scheme as (num_args, _)) ((c, ty), ts) = + if length ts < num_args then + let + val k = length ts; + val tys = (curry Library.take (num_args - k) o curry Library.drop k o fst o strip_type) ty; + val ctxt = (fold o fold_aterms) Term.declare_term_frees ts Name.context; + val vs = Name.names ctxt "a" tys; + in + fold_map (translate_typ thy algbr funcgr) tys + ##>> translate_case thy algbr funcgr thm case_scheme ((c, ty), ts @ map Free vs) + #>> (fn (tys, t) => map2 (fn (v, _) => pair v) vs tys `|==> t) + end + else if length ts > num_args then + translate_case thy algbr funcgr thm case_scheme ((c, ty), Library.take (num_args, ts)) + ##>> fold_map (translate_term thy algbr funcgr thm) (Library.drop (num_args, ts)) + #>> (fn (t, ts) => t `$$ ts) + else + translate_case thy algbr funcgr thm case_scheme ((c, ty), ts) +and translate_app thy algbr funcgr thm (c_ty_ts as ((c, _), _)) = + case Code.get_case_scheme thy c + of SOME case_scheme => translate_app_case thy algbr funcgr thm case_scheme c_ty_ts + | NONE => translate_app_const thy algbr funcgr thm c_ty_ts +and translate_tyvar_sort thy (algbr as (proj_sort, _)) funcgr (v, sort) = + fold_map (ensure_class thy algbr funcgr) (proj_sort sort) + #>> (fn sort => (unprefix "'" v, sort)) +and translate_dicts thy (algbr as (proj_sort, algebra)) funcgr thm (ty, sort) = + let + val pp = Syntax.pp_global thy; + datatype typarg = + Global of (class * string) * typarg list list + | Local of (class * class) list * (string * (int * sort)); + fun class_relation (Global ((_, tyco), yss), _) class = + Global ((class, tyco), yss) + | class_relation (Local (classrels, v), subclass) superclass = + Local ((subclass, superclass) :: classrels, v); + fun type_constructor tyco yss class = + Global ((class, tyco), (map o map) fst yss); + fun type_variable (TFree (v, sort)) = + let + val sort' = proj_sort sort; + in map_index (fn (n, class) => (Local ([], (v, (n, sort'))), class)) sort' end; + val typargs = Sorts.of_sort_derivation pp algebra + {class_relation = class_relation, type_constructor = type_constructor, + type_variable = type_variable} (ty, proj_sort sort) + handle Sorts.CLASS_ERROR e => not_wellsorted thy thm ty sort e; + fun mk_dict (Global (inst, yss)) = + ensure_inst thy algbr funcgr inst + ##>> (fold_map o fold_map) mk_dict yss + #>> (fn (inst, dss) => DictConst (inst, dss)) + | mk_dict (Local (classrels, (v, (k, sort)))) = + fold_map (ensure_classrel thy algbr funcgr) classrels + #>> (fn classrels => DictVar (classrels, (unprefix "'" v, (k, length sort)))) + in fold_map mk_dict typargs end; + + +(* store *) + +structure Program = CodeDataFun +( + type T = naming * program; + val empty = (empty_naming, Graph.empty); + fun purge thy cs (naming, program) = + let + val names_delete = cs + |> map_filter (lookup_const naming) + |> filter (can (Graph.get_node program)) + |> Graph.all_preds program; + val program' = Graph.del_nodes names_delete program; + in (naming, program') end; +); + +val cached_program = Program.get; + +fun invoke_generation thy (algebra, funcgr) f name = + Program.change_yield thy (fn naming_program => (NONE, naming_program) + |> f thy algebra funcgr name + |-> (fn name => fn (_, naming_program) => (name, naming_program))); + + +(* program generation *) + +fun consts_program thy cs = + let + fun project_consts cs (naming, program) = + let + val cs_all = Graph.all_succs program cs; + in (cs, (naming, Graph.subgraph (member (op =) cs_all) program)) end; + fun generate_consts thy algebra funcgr = + fold_map (ensure_const thy algebra funcgr); + in + invoke_generation thy (Code_Preproc.obtain thy cs []) generate_consts cs + |-> project_consts + end; + + +(* value evaluation *) + +fun ensure_value thy algbr funcgr t = + let + val ty = fastype_of t; + val vs = fold_term_types (K (fold_atyps (insert (eq_fst op =) + o dest_TFree))) t []; + val stmt_value = + fold_map (translate_tyvar_sort thy algbr funcgr) vs + ##>> translate_typ thy algbr funcgr ty + ##>> translate_term thy algbr funcgr NONE t + #>> (fn ((vs, ty), t) => Fun + (Term.dummy_patternN, ((vs, ty), [(([], t), (Drule.dummy_thm, true))]))); + fun term_value (dep, (naming, program1)) = + let + val Fun (_, (vs_ty, [(([], t), _)])) = + Graph.get_node program1 Term.dummy_patternN; + val deps = Graph.imm_succs program1 Term.dummy_patternN; + val program2 = Graph.del_nodes [Term.dummy_patternN] program1; + val deps_all = Graph.all_succs program2 deps; + val program3 = Graph.subgraph (member (op =) deps_all) program2; + in (((naming, program3), ((vs_ty, t), deps)), (dep, (naming, program2))) end; + in + ensure_stmt ((K o K) NONE) pair stmt_value Term.dummy_patternN + #> snd + #> term_value + end; + +fun base_evaluator thy evaluator algebra funcgr vs t = + let + val (((naming, program), (((vs', ty'), t'), deps)), _) = + invoke_generation thy (algebra, funcgr) ensure_value t; + val vs'' = map (fn (v, _) => (v, (the o AList.lookup (op =) vs o prefix "'") v)) vs'; + in evaluator naming program ((vs'', (vs', ty')), t') deps end; + +fun eval_conv thy prep_sort = Code_Preproc.eval_conv thy prep_sort o base_evaluator thy; +fun eval thy prep_sort postproc = Code_Preproc.eval thy prep_sort postproc o base_evaluator thy; + + +(** diagnostic commands **) + +fun read_const_exprs thy = + let + fun consts_of some_thyname = + let + val thy' = case some_thyname + of SOME thyname => ThyInfo.the_theory thyname thy + | NONE => thy; + val cs = Symtab.fold (fn (c, (_, NONE)) => cons c | _ => I) + ((snd o #constants o Consts.dest o #consts o Sign.rep_sg) thy') []; + fun belongs_here c = + not (exists (fn thy'' => Sign.declared_const thy'' c) (Theory.parents_of thy')) + in case some_thyname + of NONE => cs + | SOME thyname => filter belongs_here cs + end; + fun read_const_expr "*" = ([], consts_of NONE) + | read_const_expr s = if String.isSuffix ".*" s + then ([], consts_of (SOME (unsuffix ".*" s))) + else ([Code.read_const thy s], []); + in pairself flat o split_list o map read_const_expr end; + +fun code_depgr thy consts = + let + val (_, eqngr) = Code_Preproc.obtain thy consts []; + val select = Graph.all_succs eqngr consts; + in + eqngr + |> not (null consts) ? Graph.subgraph (member (op =) select) + |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy)) + end; + +fun code_thms thy = Pretty.writeln o Code_Preproc.pretty thy o code_depgr thy; + +fun code_deps thy consts = + let + val eqngr = code_depgr thy consts; + val constss = Graph.strong_conn eqngr; + val mapping = Symtab.empty |> fold (fn consts => fold (fn const => + Symtab.update (const, consts)) consts) constss; + fun succs consts = consts + |> maps (Graph.imm_succs eqngr) + |> subtract (op =) consts + |> map (the o Symtab.lookup mapping) + |> distinct (op =); + val conn = [] |> fold (fn consts => cons (consts, succs consts)) constss; + fun namify consts = map (Code.string_of_const thy) consts + |> commas; + val prgr = map (fn (consts, constss) => + { name = namify consts, ID = namify consts, dir = "", unfold = true, + path = "", parents = map namify constss }) conn; + in Present.display_graph prgr end; + +local + +structure P = OuterParse +and K = OuterKeyword + +fun code_thms_cmd thy = code_thms thy o op @ o read_const_exprs thy; +fun code_deps_cmd thy = code_deps thy o op @ o read_const_exprs thy; + +in + +val _ = + OuterSyntax.improper_command "code_thms" "print system of code equations for code" OuterKeyword.diag + (Scan.repeat P.term_group + >> (fn cs => Toplevel.no_timing o Toplevel.unknown_theory + o Toplevel.keep ((fn thy => code_thms_cmd thy cs) o Toplevel.theory_of))); + +val _ = + OuterSyntax.improper_command "code_deps" "visualize dependencies of code equations for code" OuterKeyword.diag + (Scan.repeat P.term_group + >> (fn cs => Toplevel.no_timing o Toplevel.unknown_theory + o Toplevel.keep ((fn thy => code_deps_cmd thy cs) o Toplevel.theory_of))); + +end; + +end; (*struct*) + + +structure Basic_Code_Thingol: BASIC_CODE_THINGOL = Code_Thingol; diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/Code_Generator.thy --- a/src/Tools/Code_Generator.thy Tue Jun 23 12:09:14 2009 +0200 +++ b/src/Tools/Code_Generator.thy Tue Jun 23 12:09:30 2009 +0200 @@ -9,12 +9,12 @@ uses "~~/src/Tools/value.ML" "~~/src/Tools/quickcheck.ML" - "~~/src/Tools/code/code_preproc.ML" - "~~/src/Tools/code/code_thingol.ML" - "~~/src/Tools/code/code_printer.ML" - "~~/src/Tools/code/code_target.ML" - "~~/src/Tools/code/code_ml.ML" - "~~/src/Tools/code/code_haskell.ML" + "~~/src/Tools/Code/code_preproc.ML" + "~~/src/Tools/Code/code_thingol.ML" + "~~/src/Tools/Code/code_printer.ML" + "~~/src/Tools/Code/code_target.ML" + "~~/src/Tools/Code/code_ml.ML" + "~~/src/Tools/Code/code_haskell.ML" "~~/src/Tools/nbe.ML" begin diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_haskell.ML --- a/src/Tools/code/code_haskell.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,567 +0,0 @@ -(* Title: Tools/code/code_haskell.ML - Author: Florian Haftmann, TU Muenchen - -Serializer for Haskell. -*) - -signature CODE_HASKELL = -sig - val setup: theory -> theory -end; - -structure Code_Haskell : CODE_HASKELL = -struct - -val target = "Haskell"; - -open Basic_Code_Thingol; -open Code_Printer; - -infixr 5 @@; -infixr 5 @|; - - -(** Haskell serializer **) - -fun pr_haskell_bind pr_term = - let - fun pr_bind ((NONE, NONE), _) = str "_" - | pr_bind ((SOME v, NONE), _) = str v - | pr_bind ((NONE, SOME p), _) = p - | pr_bind ((SOME v, SOME p), _) = brackets [str v, str "@", p]; - in gen_pr_bind pr_bind pr_term end; - -fun pr_haskell_stmt labelled_name syntax_class syntax_tyco syntax_const - init_syms deresolve is_cons contr_classparam_typs deriving_show = - let - val deresolve_base = Long_Name.base_name o deresolve; - fun class_name class = case syntax_class class - of NONE => deresolve class - | SOME class => class; - fun pr_typcontext tyvars vs = case maps (fn (v, sort) => map (pair v) sort) vs - of [] => [] - | classbinds => Pretty.enum "," "(" ")" ( - map (fn (v, class) => - str (class_name class ^ " " ^ Code_Printer.lookup_var tyvars v)) classbinds) - @@ str " => "; - fun pr_typforall tyvars vs = case map fst vs - of [] => [] - | vnames => str "forall " :: Pretty.breaks - (map (str o Code_Printer.lookup_var tyvars) vnames) @ str "." @@ Pretty.brk 1; - fun pr_tycoexpr tyvars fxy (tyco, tys) = - brackify fxy (str tyco :: map (pr_typ tyvars BR) tys) - and pr_typ tyvars fxy (tycoexpr as tyco `%% tys) = (case syntax_tyco tyco - of NONE => pr_tycoexpr tyvars fxy (deresolve tyco, tys) - | SOME (i, pr) => pr (pr_typ tyvars) fxy tys) - | pr_typ tyvars fxy (ITyVar v) = (str o Code_Printer.lookup_var tyvars) v; - fun pr_typdecl tyvars (vs, tycoexpr) = - Pretty.block (pr_typcontext tyvars vs @| pr_tycoexpr tyvars NOBR tycoexpr); - fun pr_typscheme tyvars (vs, ty) = - Pretty.block (pr_typforall tyvars vs @ pr_typcontext tyvars vs @| pr_typ tyvars NOBR ty); - fun pr_term tyvars thm vars fxy (IConst c) = - pr_app tyvars thm vars fxy (c, []) - | pr_term tyvars thm vars fxy (t as (t1 `$ t2)) = - (case Code_Thingol.unfold_const_app t - of SOME app => pr_app tyvars thm vars fxy app - | _ => - brackify fxy [ - pr_term tyvars thm vars NOBR t1, - pr_term tyvars thm vars BR t2 - ]) - | pr_term tyvars thm vars fxy (IVar v) = - (str o Code_Printer.lookup_var vars) v - | pr_term tyvars thm vars fxy (t as _ `|=> _) = - let - val (binds, t') = Code_Thingol.unfold_abs t; - fun pr ((v, pat), ty) = pr_bind tyvars thm BR ((SOME v, pat), ty); - val (ps, vars') = fold_map pr binds vars; - in brackets (str "\\" :: ps @ str "->" @@ pr_term tyvars thm vars' NOBR t') end - | pr_term tyvars thm vars fxy (ICase (cases as (_, t0))) = - (case Code_Thingol.unfold_const_app t0 - of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) - then pr_case tyvars thm vars fxy cases - else pr_app tyvars thm vars fxy c_ts - | NONE => pr_case tyvars thm vars fxy cases) - and pr_app' tyvars thm vars ((c, (_, tys)), ts) = case contr_classparam_typs c - of [] => (str o deresolve) c :: map (pr_term tyvars thm vars BR) ts - | fingerprint => let - val ts_fingerprint = ts ~~ curry Library.take (length ts) fingerprint; - val needs_annotation = forall (fn (_, NONE) => true | (t, SOME _) => - (not o Code_Thingol.locally_monomorphic) t) ts_fingerprint; - fun pr_term_anno (t, NONE) _ = pr_term tyvars thm vars BR t - | pr_term_anno (t, SOME _) ty = - brackets [pr_term tyvars thm vars NOBR t, str "::", pr_typ tyvars NOBR ty]; - in - if needs_annotation then - (str o deresolve) c :: map2 pr_term_anno ts_fingerprint (curry Library.take (length ts) tys) - else (str o deresolve) c :: map (pr_term tyvars thm vars BR) ts - end - and pr_app tyvars = gen_pr_app (pr_app' tyvars) (pr_term tyvars) syntax_const - and pr_bind tyvars = pr_haskell_bind (pr_term tyvars) - and pr_case tyvars thm vars fxy (cases as ((_, [_]), _)) = - let - val (binds, body) = Code_Thingol.unfold_let (ICase cases); - fun pr ((pat, ty), t) vars = - vars - |> pr_bind tyvars thm BR ((NONE, SOME pat), ty) - |>> (fn p => semicolon [p, str "=", pr_term tyvars thm vars NOBR t]) - val (ps, vars') = fold_map pr binds vars; - in brackify_block fxy (str "let {") - ps - (concat [str "}", str "in", pr_term tyvars thm vars' NOBR body]) - end - | pr_case tyvars thm vars fxy (((t, ty), clauses as _ :: _), _) = - let - fun pr (pat, body) = - let - val (p, vars') = pr_bind tyvars thm NOBR ((NONE, SOME pat), ty) vars; - in semicolon [p, str "->", pr_term tyvars thm vars' NOBR body] end; - in brackify_block fxy - (concat [str "case", pr_term tyvars thm vars NOBR t, str "of", str "{"]) - (map pr clauses) - (str "}") - end - | pr_case tyvars thm vars fxy ((_, []), _) = - (brackify fxy o Pretty.breaks o map str) ["error", "\"empty case\""]; - fun pr_stmt (name, Code_Thingol.Fun (_, ((vs, ty), []))) = - let - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - val n = (length o fst o Code_Thingol.unfold_fun) ty; - in - Pretty.chunks [ - Pretty.block [ - (str o suffix " ::" o deresolve_base) name, - Pretty.brk 1, - pr_typscheme tyvars (vs, ty), - str ";" - ], - concat ( - (str o deresolve_base) name - :: map str (replicate n "_") - @ str "=" - :: str "error" - @@ (str o (fn s => s ^ ";") o ML_Syntax.print_string - o Long_Name.base_name o Long_Name.qualifier) name - ) - ] - end - | pr_stmt (name, Code_Thingol.Fun (_, ((vs, ty), raw_eqs))) = - let - val eqs = filter (snd o snd) raw_eqs; - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - fun pr_eq ((ts, t), (thm, _)) = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); - val vars = init_syms - |> Code_Printer.intro_vars consts - |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) - (insert (op =)) ts []); - in - semicolon ( - (str o deresolve_base) name - :: map (pr_term tyvars thm vars BR) ts - @ str "=" - @@ pr_term tyvars thm vars NOBR t - ) - end; - in - Pretty.chunks ( - Pretty.block [ - (str o suffix " ::" o deresolve_base) name, - Pretty.brk 1, - pr_typscheme tyvars (vs, ty), - str ";" - ] - :: map pr_eq eqs - ) - end - | pr_stmt (name, Code_Thingol.Datatype (_, (vs, []))) = - let - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - in - semicolon [ - str "data", - pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) - ] - end - | pr_stmt (name, Code_Thingol.Datatype (_, (vs, [(co, [ty])]))) = - let - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - in - semicolon ( - str "newtype" - :: pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) - :: str "=" - :: (str o deresolve_base) co - :: pr_typ tyvars BR ty - :: (if deriving_show name then [str "deriving (Read, Show)"] else []) - ) - end - | pr_stmt (name, Code_Thingol.Datatype (_, (vs, co :: cos))) = - let - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - fun pr_co (co, tys) = - concat ( - (str o deresolve_base) co - :: map (pr_typ tyvars BR) tys - ) - in - semicolon ( - str "data" - :: pr_typdecl tyvars (vs, (deresolve_base name, map (ITyVar o fst) vs)) - :: str "=" - :: pr_co co - :: map ((fn p => Pretty.block [str "| ", p]) o pr_co) cos - @ (if deriving_show name then [str "deriving (Read, Show)"] else []) - ) - end - | pr_stmt (name, Code_Thingol.Class (_, (v, (superclasses, classparams)))) = - let - val tyvars = Code_Printer.intro_vars [v] init_syms; - fun pr_classparam (classparam, ty) = - semicolon [ - (str o deresolve_base) classparam, - str "::", - pr_typ tyvars NOBR ty - ] - in - Pretty.block_enclose ( - Pretty.block [ - str "class ", - Pretty.block (pr_typcontext tyvars [(v, map fst superclasses)]), - str (deresolve_base name ^ " " ^ Code_Printer.lookup_var tyvars v), - str " where {" - ], - str "};" - ) (map pr_classparam classparams) - end - | pr_stmt (_, Code_Thingol.Classinst ((class, (tyco, vs)), (_, classparam_insts))) = - let - val split_abs_pure = (fn (v, _) `|=> t => SOME (v, t) | _ => NONE); - val unfold_abs_pure = Code_Thingol.unfoldr split_abs_pure; - val tyvars = Code_Printer.intro_vars (map fst vs) init_syms; - fun pr_instdef ((classparam, c_inst), (thm, _)) = case syntax_const classparam - of NONE => semicolon [ - (str o deresolve_base) classparam, - str "=", - pr_app tyvars thm init_syms NOBR (c_inst, []) - ] - | SOME (k, pr) => - let - val (c_inst_name, (_, tys)) = c_inst; - val const = if (is_some o syntax_const) c_inst_name - then NONE else (SOME o Long_Name.base_name o deresolve) c_inst_name; - val proto_rhs = Code_Thingol.eta_expand k (c_inst, []); - val (vs, rhs) = unfold_abs_pure proto_rhs; - val vars = init_syms - |> Code_Printer.intro_vars (the_list const) - |> Code_Printer.intro_vars vs; - val lhs = IConst (classparam, (([], []), tys)) `$$ map IVar vs; - (*dictionaries are not relevant at this late stage*) - in - semicolon [ - pr_term tyvars thm vars NOBR lhs, - str "=", - pr_term tyvars thm vars NOBR rhs - ] - end; - in - Pretty.block_enclose ( - Pretty.block [ - str "instance ", - Pretty.block (pr_typcontext tyvars vs), - str (class_name class ^ " "), - pr_typ tyvars BR (tyco `%% map (ITyVar o fst) vs), - str " where {" - ], - str "};" - ) (map pr_instdef classparam_insts) - end; - in pr_stmt end; - -fun haskell_program_of_program labelled_name module_name module_prefix reserved_names raw_module_alias program = - let - val module_alias = if is_some module_name then K module_name else raw_module_alias; - val reserved_names = Name.make_context reserved_names; - val mk_name_module = Code_Printer.mk_name_module reserved_names module_prefix module_alias program; - fun add_stmt (name, (stmt, deps)) = - let - val (module_name, base) = Code_Printer.dest_name name; - val module_name' = mk_name_module module_name; - val mk_name_stmt = yield_singleton Name.variants; - fun add_fun upper (nsp_fun, nsp_typ) = - let - val (base', nsp_fun') = - mk_name_stmt (if upper then Code_Printer.first_upper base else base) nsp_fun - in (base', (nsp_fun', nsp_typ)) end; - fun add_typ (nsp_fun, nsp_typ) = - let - val (base', nsp_typ') = mk_name_stmt (Code_Printer.first_upper base) nsp_typ - in (base', (nsp_fun, nsp_typ')) end; - val add_name = case stmt - of Code_Thingol.Fun _ => add_fun false - | Code_Thingol.Datatype _ => add_typ - | Code_Thingol.Datatypecons _ => add_fun true - | Code_Thingol.Class _ => add_typ - | Code_Thingol.Classrel _ => pair base - | Code_Thingol.Classparam _ => add_fun false - | Code_Thingol.Classinst _ => pair base; - fun add_stmt' base' = case stmt - of Code_Thingol.Datatypecons _ => - cons (name, (Long_Name.append module_name' base', NONE)) - | Code_Thingol.Classrel _ => I - | Code_Thingol.Classparam _ => - cons (name, (Long_Name.append module_name' base', NONE)) - | _ => cons (name, (Long_Name.append module_name' base', SOME stmt)); - in - Symtab.map_default (module_name', ([], ([], (reserved_names, reserved_names)))) - (apfst (fold (insert (op = : string * string -> bool)) deps)) - #> `(fn program => add_name ((snd o snd o the o Symtab.lookup program) module_name')) - #-> (fn (base', names) => - (Symtab.map_entry module_name' o apsnd) (fn (stmts, _) => - (add_stmt' base' stmts, names))) - end; - val hs_program = fold add_stmt (AList.make (fn name => - (Graph.get_node program name, Graph.imm_succs program name)) - (Graph.strong_conn program |> flat)) Symtab.empty; - fun deresolver name = (fst o the o AList.lookup (op =) ((fst o snd o the - o Symtab.lookup hs_program) ((mk_name_module o fst o Code_Printer.dest_name) name))) name - handle Option => error ("Unknown statement name: " ^ labelled_name name); - in (deresolver, hs_program) end; - -fun serialize_haskell module_prefix raw_module_name string_classes labelled_name - raw_reserved_names includes raw_module_alias - syntax_class syntax_tyco syntax_const program cs destination = - let - val stmt_names = Code_Target.stmt_names_of_destination destination; - val module_name = if null stmt_names then raw_module_name else SOME "Code"; - val reserved_names = fold (insert (op =) o fst) includes raw_reserved_names; - val (deresolver, hs_program) = haskell_program_of_program labelled_name - module_name module_prefix reserved_names raw_module_alias program; - val is_cons = Code_Thingol.is_cons program; - val contr_classparam_typs = Code_Thingol.contr_classparam_typs program; - fun deriving_show tyco = - let - fun deriv _ "fun" = false - | deriv tycos tyco = member (op =) tycos tyco orelse - case try (Graph.get_node program) tyco - of SOME (Code_Thingol.Datatype (_, (_, cs))) => forall (deriv' (tyco :: tycos)) - (maps snd cs) - | NONE => true - and deriv' tycos (tyco `%% tys) = deriv tycos tyco - andalso forall (deriv' tycos) tys - | deriv' _ (ITyVar _) = true - in deriv [] tyco end; - val reserved_names = Code_Printer.make_vars reserved_names; - fun pr_stmt qualified = pr_haskell_stmt labelled_name - syntax_class syntax_tyco syntax_const reserved_names - (if qualified then deresolver else Long_Name.base_name o deresolver) - is_cons contr_classparam_typs - (if string_classes then deriving_show else K false); - fun pr_module name content = - (name, Pretty.chunks [ - str ("module " ^ name ^ " where {"), - str "", - content, - str "", - str "}" - ]); - fun serialize_module1 (module_name', (deps, (stmts, _))) = - let - val stmt_names = map fst stmts; - val deps' = subtract (op =) stmt_names deps - |> distinct (op =) - |> map_filter (try deresolver); - val qualified = is_none module_name andalso - map deresolver stmt_names @ deps' - |> map Long_Name.base_name - |> has_duplicates (op =); - val imports = deps' - |> map Long_Name.qualifier - |> distinct (op =); - fun pr_import_include (name, _) = str ("import qualified " ^ name ^ ";"); - val pr_import_module = str o (if qualified - then prefix "import qualified " - else prefix "import ") o suffix ";"; - val content = Pretty.chunks ( - map pr_import_include includes - @ map pr_import_module imports - @ str "" - :: separate (str "") (map_filter - (fn (name, (_, SOME stmt)) => SOME (pr_stmt qualified (name, stmt)) - | (_, (_, NONE)) => NONE) stmts) - ) - in pr_module module_name' content end; - fun serialize_module2 (_, (_, (stmts, _))) = Pretty.chunks ( - separate (str "") (map_filter - (fn (name, (_, SOME stmt)) => if null stmt_names - orelse member (op =) stmt_names name - then SOME (pr_stmt false (name, stmt)) - else NONE - | (_, (_, NONE)) => NONE) stmts)); - val serialize_module = - if null stmt_names then serialize_module1 else pair "" o serialize_module2; - fun check_destination destination = - (File.check destination; destination); - fun write_module destination (modlname, content) = - let - val filename = case modlname - of "" => Path.explode "Main.hs" - | _ => (Path.ext "hs" o Path.explode o implode o separate "/" - o Long_Name.explode) modlname; - val pathname = Path.append destination filename; - val _ = File.mkdir (Path.dir pathname); - in File.write pathname - ("{-# OPTIONS_GHC -fglasgow-exts #-}\n\n" - ^ Code_Target.code_of_pretty content) - end - in - Code_Target.mk_serialization target NONE - (fn NONE => K () o map (Code_Target.code_writeln o snd) | SOME file => K () o map - (write_module (check_destination file))) - (rpair [] o cat_lines o map (Code_Target.code_of_pretty o snd)) - (map (uncurry pr_module) includes - @ map serialize_module (Symtab.dest hs_program)) - destination - end; - -val literals = let - fun char_haskell c = - let - val s = ML_Syntax.print_char c; - in if s = "'" then "\\'" else s end; -in Literals { - literal_char = enclose "'" "'" o char_haskell, - literal_string = quote o translate_string char_haskell, - literal_numeral = fn unbounded => fn k => if k >= 0 then string_of_int k - else enclose "(" ")" (signed_string_of_int k), - literal_list = Pretty.enum "," "[" "]", - infix_cons = (5, ":") -} end; - - -(** optional monad syntax **) - -fun pretty_haskell_monad c_bind = - let - fun dest_bind t1 t2 = case Code_Thingol.split_abs t2 - of SOME (((v, pat), ty), t') => - SOME ((SOME (((SOME v, pat), ty), true), t1), t') - | NONE => NONE; - fun dest_monad c_bind_name (IConst (c, _) `$ t1 `$ t2) = - if c = c_bind_name then dest_bind t1 t2 - else NONE - | dest_monad _ t = case Code_Thingol.split_let t - of SOME (((pat, ty), tbind), t') => - SOME ((SOME (((NONE, SOME pat), ty), false), tbind), t') - | NONE => NONE; - fun implode_monad c_bind_name = Code_Thingol.unfoldr (dest_monad c_bind_name); - fun pr_monad pr_bind pr (NONE, t) vars = - (semicolon [pr vars NOBR t], vars) - | pr_monad pr_bind pr (SOME (bind, true), t) vars = vars - |> pr_bind NOBR bind - |>> (fn p => semicolon [p, str "<-", pr vars NOBR t]) - | pr_monad pr_bind pr (SOME (bind, false), t) vars = vars - |> pr_bind NOBR bind - |>> (fn p => semicolon [str "let", p, str "=", pr vars NOBR t]); - fun pretty _ [c_bind'] pr thm vars fxy [(t1, _), (t2, _)] = case dest_bind t1 t2 - of SOME (bind, t') => let - val (binds, t'') = implode_monad c_bind' t' - val (ps, vars') = fold_map (pr_monad (pr_haskell_bind (K pr) thm) pr) (bind :: binds) vars; - in (brackify fxy o single o Pretty.enclose "do {" "}" o Pretty.breaks) (ps @| pr vars' NOBR t'') end - | NONE => brackify_infix (1, L) fxy - [pr vars (INFX (1, L)) t1, str ">>=", pr vars (INFX (1, X)) t2] - in (2, ([c_bind], pretty)) end; - -fun add_monad target' raw_c_bind thy = - let - val c_bind = Code.read_const thy raw_c_bind; - in if target = target' then - thy - |> Code_Target.add_syntax_const target c_bind - (SOME (pretty_haskell_monad c_bind)) - else error "Only Haskell target allows for monad syntax" end; - - -(** Isar setup **) - -fun isar_seri_haskell module = - Code_Target.parse_args (Scan.option (Args.$$$ "root" -- Args.colon |-- Args.name) - -- Scan.optional (Args.$$$ "string_classes" >> K true) false - >> (fn (module_prefix, string_classes) => - serialize_haskell module_prefix module string_classes)); - -val _ = - OuterSyntax.command "code_monad" "define code syntax for monads" OuterKeyword.thy_decl ( - OuterParse.term_group -- OuterParse.name >> (fn (raw_bind, target) => - Toplevel.theory (add_monad target raw_bind)) - ); - -val setup = - Code_Target.add_target (target, (isar_seri_haskell, literals)) - #> Code_Target.add_syntax_tyco target "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => - brackify_infix (1, R) fxy [ - pr_typ (INFX (1, X)) ty1, - str "->", - pr_typ (INFX (1, R)) ty2 - ])) - #> fold (Code_Target.add_reserved target) [ - "hiding", "deriving", "where", "case", "of", "infix", "infixl", "infixr", - "import", "default", "forall", "let", "in", "class", "qualified", "data", - "newtype", "instance", "if", "then", "else", "type", "as", "do", "module" - ] - #> fold (Code_Target.add_reserved target) [ - "Prelude", "Main", "Bool", "Maybe", "Either", "Ordering", "Char", "String", "Int", - "Integer", "Float", "Double", "Rational", "IO", "Eq", "Ord", "Enum", "Bounded", - "Num", "Real", "Integral", "Fractional", "Floating", "RealFloat", "Monad", "Functor", - "AlreadyExists", "ArithException", "ArrayException", "AssertionFailed", "AsyncException", - "BlockedOnDeadMVar", "Deadlock", "Denormal", "DivideByZero", "DotNetException", "DynException", - "Dynamic", "EOF", "EQ", "EmptyRec", "ErrorCall", "ExitException", "ExitFailure", - "ExitSuccess", "False", "GT", "HeapOverflow", - "IOError", "IOException", "IllegalOperation", - "IndexOutOfBounds", "Just", "Key", "LT", "Left", "LossOfPrecision", "NoMethodError", - "NoSuchThing", "NonTermination", "Nothing", "Obj", "OtherError", "Overflow", - "PatternMatchFail", "PermissionDenied", "ProtocolError", "RecConError", "RecSelError", - "RecUpdError", "ResourceBusy", "ResourceExhausted", "Right", "StackOverflow", - "ThreadKilled", "True", "TyCon", "TypeRep", "UndefinedElement", "Underflow", - "UnsupportedOperation", "UserError", "abs", "absReal", "acos", "acosh", "all", - "and", "any", "appendFile", "asTypeOf", "asciiTab", "asin", "asinh", "atan", - "atan2", "atanh", "basicIORun", "blockIO", "boundedEnumFrom", "boundedEnumFromThen", - "boundedEnumFromThenTo", "boundedEnumFromTo", "boundedPred", "boundedSucc", "break", - "catch", "catchException", "ceiling", "compare", "concat", "concatMap", "const", - "cos", "cosh", "curry", "cycle", "decodeFloat", "denominator", "div", "divMod", - "doubleToRatio", "doubleToRational", "drop", "dropWhile", "either", "elem", - "emptyRec", "encodeFloat", "enumFrom", "enumFromThen", "enumFromThenTo", - "enumFromTo", "error", "even", "exp", "exponent", "fail", "filter", "flip", - "floatDigits", "floatProperFraction", "floatRadix", "floatRange", "floatToRational", - "floor", "fmap", "foldl", "foldl'", "foldl1", "foldr", "foldr1", "fromDouble", - "fromEnum", "fromEnum_0", "fromInt", "fromInteger", "fromIntegral", "fromObj", - "fromRational", "fst", "gcd", "getChar", "getContents", "getLine", "head", - "id", "inRange", "index", "init", "intToRatio", "interact", "ioError", "isAlpha", - "isAlphaNum", "isDenormalized", "isDigit", "isHexDigit", "isIEEE", "isInfinite", - "isLower", "isNaN", "isNegativeZero", "isOctDigit", "isSpace", "isUpper", "iterate", "iterate'", - "last", "lcm", "length", "lex", "lexDigits", "lexLitChar", "lexmatch", "lines", "log", - "logBase", "lookup", "loop", "map", "mapM", "mapM_", "max", "maxBound", "maximum", - "maybe", "min", "minBound", "minimum", "mod", "negate", "nonnull", "not", "notElem", - "null", "numerator", "numericEnumFrom", "numericEnumFromThen", "numericEnumFromThenTo", - "numericEnumFromTo", "odd", "or", "otherwise", "pi", "pred", - "print", "product", "properFraction", "protectEsc", "putChar", "putStr", "putStrLn", - "quot", "quotRem", "range", "rangeSize", "rationalToDouble", "rationalToFloat", - "rationalToRealFloat", "read", "readDec", "readField", "readFieldName", "readFile", - "readFloat", "readHex", "readIO", "readInt", "readList", "readLitChar", "readLn", - "readOct", "readParen", "readSigned", "reads", "readsPrec", "realFloatToRational", - "realToFrac", "recip", "reduce", "rem", "repeat", "replicate", "return", "reverse", - "round", "scaleFloat", "scanl", "scanl1", "scanr", "scanr1", "seq", "sequence", - "sequence_", "show", "showChar", "showException", "showField", "showList", - "showLitChar", "showParen", "showString", "shows", "showsPrec", "significand", - "signum", "signumReal", "sin", "sinh", "snd", "span", "splitAt", "sqrt", "subtract", - "succ", "sum", "tail", "take", "takeWhile", "takeWhile1", "tan", "tanh", "threadToIOResult", - "throw", "toEnum", "toInt", "toInteger", "toObj", "toRational", "truncate", "uncurry", - "undefined", "unlines", "unsafeCoerce", "unsafeIndex", "unsafeRangeSize", "until", "unwords", - "unzip", "unzip3", "userError", "words", "writeFile", "zip", "zip3", "zipWith", "zipWith3" - ] (*due to weird handling of ':', we can't do anything else than to import *all* prelude symbols*); - -end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_ml.ML --- a/src/Tools/code/code_ml.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,1122 +0,0 @@ -(* Title: Tools/code/code_ml.ML - Author: Florian Haftmann, TU Muenchen - -Serializer for SML and OCaml. -*) - -signature CODE_ML = -sig - val eval: string option -> string * (unit -> 'a) option ref - -> ((term -> term) -> 'a -> 'a) -> theory -> term -> string list -> 'a - val target_Eval: string - val setup: theory -> theory -end; - -structure Code_ML : CODE_ML = -struct - -open Basic_Code_Thingol; -open Code_Printer; - -infixr 5 @@; -infixr 5 @|; - -val target_SML = "SML"; -val target_OCaml = "OCaml"; -val target_Eval = "Eval"; - -datatype ml_stmt = - MLExc of string * int - | MLVal of string * ((typscheme * iterm) * (thm * bool)) - | MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list * string list - | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list - | MLClass of string * (vname * ((class * string) list * (string * itype) list)) - | MLClassinst of string * ((class * (string * (vname * sort) list)) - * ((class * (string * (string * dict list list))) list - * ((string * const) * (thm * bool)) list)); - -fun stmt_names_of (MLExc (name, _)) = [name] - | stmt_names_of (MLVal (name, _)) = [name] - | stmt_names_of (MLFuns (fs, _)) = map fst fs - | stmt_names_of (MLDatas ds) = map fst ds - | stmt_names_of (MLClass (name, _)) = [name] - | stmt_names_of (MLClassinst (name, _)) = [name]; - - -(** SML serailizer **) - -fun pr_sml_stmt labelled_name syntax_tyco syntax_const reserved_names deresolve is_cons = - let - fun pr_dicts fxy ds = - let - fun pr_dictvar (v, (_, 1)) = Code_Printer.first_upper v ^ "_" - | pr_dictvar (v, (i, _)) = Code_Printer.first_upper v ^ string_of_int (i+1) ^ "_"; - fun pr_proj [] p = - p - | pr_proj [p'] p = - brackets [p', p] - | pr_proj (ps as _ :: _) p = - brackets [Pretty.enum " o" "(" ")" ps, p]; - fun pr_dict fxy (DictConst (inst, dss)) = - brackify fxy ((str o deresolve) inst :: map (pr_dicts BR) dss) - | pr_dict fxy (DictVar (classrels, v)) = - pr_proj (map (str o deresolve) classrels) ((str o pr_dictvar) v) - in case ds - of [] => str "()" - | [d] => pr_dict fxy d - | _ :: _ => (Pretty.list "(" ")" o map (pr_dict NOBR)) ds - end; - fun pr_tyvar_dicts vs = - vs - |> map (fn (v, sort) => map_index (fn (i, _) => - DictVar ([], (v, (i, length sort)))) sort) - |> map (pr_dicts BR); - fun pr_tycoexpr fxy (tyco, tys) = - let - val tyco' = (str o deresolve) tyco - in case map (pr_typ BR) tys - of [] => tyco' - | [p] => Pretty.block [p, Pretty.brk 1, tyco'] - | (ps as _::_) => Pretty.block [Pretty.list "(" ")" ps, Pretty.brk 1, tyco'] - end - and pr_typ fxy (tyco `%% tys) = (case syntax_tyco tyco - of NONE => pr_tycoexpr fxy (tyco, tys) - | SOME (i, pr) => pr pr_typ fxy tys) - | pr_typ fxy (ITyVar v) = str ("'" ^ v); - fun pr_term is_closure thm vars fxy (IConst c) = - pr_app is_closure thm vars fxy (c, []) - | pr_term is_closure thm vars fxy (IVar v) = - str (Code_Printer.lookup_var vars v) - | pr_term is_closure thm vars fxy (t as t1 `$ t2) = - (case Code_Thingol.unfold_const_app t - of SOME c_ts => pr_app is_closure thm vars fxy c_ts - | NONE => brackify fxy - [pr_term is_closure thm vars NOBR t1, pr_term is_closure thm vars BR t2]) - | pr_term is_closure thm vars fxy (t as _ `|=> _) = - let - val (binds, t') = Code_Thingol.unfold_abs t; - fun pr ((v, pat), ty) = - pr_bind is_closure thm NOBR ((SOME v, pat), ty) - #>> (fn p => concat [str "fn", p, str "=>"]); - val (ps, vars') = fold_map pr binds vars; - in brackets (ps @ [pr_term is_closure thm vars' NOBR t']) end - | pr_term is_closure thm vars fxy (ICase (cases as (_, t0))) = - (case Code_Thingol.unfold_const_app t0 - of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) - then pr_case is_closure thm vars fxy cases - else pr_app is_closure thm vars fxy c_ts - | NONE => pr_case is_closure thm vars fxy cases) - and pr_app' is_closure thm vars (app as ((c, ((_, iss), tys)), ts)) = - if is_cons c then - let - val k = length tys - in if k < 2 then - (str o deresolve) c :: map (pr_term is_closure thm vars BR) ts - else if k = length ts then - [(str o deresolve) c, Pretty.enum "," "(" ")" (map (pr_term is_closure thm vars NOBR) ts)] - else [pr_term is_closure thm vars BR (Code_Thingol.eta_expand k app)] end - else if is_closure c - then (str o deresolve) c @@ str "()" - else - (str o deresolve) c - :: (map (pr_dicts BR) o filter_out null) iss @ map (pr_term is_closure thm vars BR) ts - and pr_app is_closure thm vars = gen_pr_app (pr_app' is_closure) (pr_term is_closure) - syntax_const thm vars - and pr_bind' ((NONE, NONE), _) = str "_" - | pr_bind' ((SOME v, NONE), _) = str v - | pr_bind' ((NONE, SOME p), _) = p - | pr_bind' ((SOME v, SOME p), _) = concat [str v, str "as", p] - and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure) - and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) = - let - val (binds, body) = Code_Thingol.unfold_let (ICase cases); - fun pr ((pat, ty), t) vars = - vars - |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) - |>> (fn p => semicolon [str "val", p, str "=", pr_term is_closure thm vars NOBR t]) - val (ps, vars') = fold_map pr binds vars; - in - Pretty.chunks [ - [str ("let"), Pretty.fbrk, Pretty.chunks ps] |> Pretty.block, - [str ("in"), Pretty.fbrk, pr_term is_closure thm vars' NOBR body] |> Pretty.block, - str ("end") - ] - end - | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) = - let - fun pr delim (pat, body) = - let - val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars; - in - concat [str delim, p, str "=>", pr_term is_closure thm vars' NOBR body] - end; - in - brackets ( - str "case" - :: pr_term is_closure thm vars NOBR t - :: pr "of" clause - :: map (pr "|") clauses - ) - end - | pr_case is_closure thm vars fxy ((_, []), _) = - (concat o map str) ["raise", "Fail", "\"empty case\""]; - fun pr_stmt (MLExc (name, n)) = - let - val exc_str = - (ML_Syntax.print_string o Long_Name.base_name o Long_Name.qualifier) name; - in - (concat o map str) ( - (if n = 0 then "val" else "fun") - :: deresolve name - :: replicate n "_" - @ "=" - :: "raise" - :: "Fail" - @@ exc_str - ) - end - | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - (Code_Thingol.fold_constnames (insert (op =)) t []); - val vars = reserved_names - |> Code_Printer.intro_vars consts; - in - concat [ - str "val", - (str o deresolve) name, - str ":", - pr_typ NOBR ty, - str "=", - pr_term (K false) thm vars NOBR t - ] - end - | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) = - let - fun pr_funn definer (name, ((vs, ty), eqs as eq :: eqs')) = - let - val vs_dict = filter_out (null o snd) vs; - val shift = if null eqs' then I else - map (Pretty.block o single o Pretty.block o single); - fun pr_eq definer ((ts, t), (thm, _)) = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); - val vars = reserved_names - |> Code_Printer.intro_vars consts - |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) - (insert (op =)) ts []); - in - concat ( - str definer - :: (str o deresolve) name - :: (if member (op =) pseudo_funs name then [str "()"] - else pr_tyvar_dicts vs_dict - @ map (pr_term (member (op =) pseudo_funs) thm vars BR) ts) - @ str "=" - @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t - ) - end - in - (Pretty.block o Pretty.fbreaks o shift) ( - pr_eq definer eq - :: map (pr_eq "|") eqs' - ) - end; - fun pr_pseudo_fun name = concat [ - str "val", - (str o deresolve) name, - str "=", - (str o deresolve) name, - str "();" - ]; - val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns); - val pseudo_ps = map pr_pseudo_fun pseudo_funs; - in Pretty.chunks (ps @ Pretty.block ([p, str ";"]) :: pseudo_ps) end - | pr_stmt (MLDatas (datas as (data :: datas'))) = - let - fun pr_co (co, []) = - str (deresolve co) - | pr_co (co, tys) = - concat [ - str (deresolve co), - str "of", - Pretty.enum " *" "" "" (map (pr_typ (INFX (2, X))) tys) - ]; - fun pr_data definer (tyco, (vs, [])) = - concat ( - str definer - :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) - :: str "=" - @@ str "EMPTY__" - ) - | pr_data definer (tyco, (vs, cos)) = - concat ( - str definer - :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) - :: str "=" - :: separate (str "|") (map pr_co cos) - ); - val (ps, p) = split_last - (pr_data "datatype" data :: map (pr_data "and") datas'); - in Pretty.chunks (ps @| Pretty.block ([p, str ";"])) end - | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) = - let - val w = Code_Printer.first_upper v ^ "_"; - fun pr_superclass_field (class, classrel) = - (concat o map str) [ - deresolve classrel, ":", "'" ^ v, deresolve class - ]; - fun pr_classparam_field (classparam, ty) = - concat [ - (str o deresolve) classparam, str ":", pr_typ NOBR ty - ]; - fun pr_classparam_proj (classparam, _) = - semicolon [ - str "fun", - (str o deresolve) classparam, - Pretty.enclose "(" ")" [str (w ^ ":'" ^ v ^ " " ^ deresolve class)], - str "=", - str ("#" ^ deresolve classparam), - str w - ]; - fun pr_superclass_proj (_, classrel) = - semicolon [ - str "fun", - (str o deresolve) classrel, - Pretty.enclose "(" ")" [str (w ^ ":'" ^ v ^ " " ^ deresolve class)], - str "=", - str ("#" ^ deresolve classrel), - str w - ]; - in - Pretty.chunks ( - concat [ - str ("type '" ^ v), - (str o deresolve) class, - str "=", - Pretty.enum "," "{" "};" ( - map pr_superclass_field superclasses @ map pr_classparam_field classparams - ) - ] - :: map pr_superclass_proj superclasses - @ map pr_classparam_proj classparams - ) - end - | pr_stmt (MLClassinst (inst, ((class, (tyco, arity)), (superarities, classparam_insts)))) = - let - fun pr_superclass (_, (classrel, dss)) = - concat [ - (str o Long_Name.base_name o deresolve) classrel, - str "=", - pr_dicts NOBR [DictConst dss] - ]; - fun pr_classparam ((classparam, c_inst), (thm, _)) = - concat [ - (str o Long_Name.base_name o deresolve) classparam, - str "=", - pr_app (K false) thm reserved_names NOBR (c_inst, []) - ]; - in - semicolon ([ - str (if null arity then "val" else "fun"), - (str o deresolve) inst ] @ - pr_tyvar_dicts arity @ [ - str "=", - Pretty.enum "," "{" "}" - (map pr_superclass superarities @ map pr_classparam classparam_insts), - str ":", - pr_tycoexpr NOBR (class, [tyco `%% map (ITyVar o fst) arity]) - ]) - end; - in pr_stmt end; - -fun pr_sml_module name content = - Pretty.chunks ( - str ("structure " ^ name ^ " = ") - :: str "struct" - :: str "" - :: content - @ str "" - @@ str ("end; (*struct " ^ name ^ "*)") - ); - -val literals_sml = Literals { - literal_char = prefix "#" o quote o ML_Syntax.print_char, - literal_string = quote o translate_string ML_Syntax.print_char, - literal_numeral = fn unbounded => fn k => - if unbounded then "(" ^ string_of_int k ^ " : IntInf.int)" - else string_of_int k, - literal_list = Pretty.enum "," "[" "]", - infix_cons = (7, "::") -}; - - -(** OCaml serializer **) - -fun pr_ocaml_stmt labelled_name syntax_tyco syntax_const reserved_names deresolve is_cons = - let - fun pr_dicts fxy ds = - let - fun pr_dictvar (v, (_, 1)) = "_" ^ Code_Printer.first_upper v - | pr_dictvar (v, (i, _)) = "_" ^ Code_Printer.first_upper v ^ string_of_int (i+1); - fun pr_proj ps p = - fold_rev (fn p2 => fn p1 => Pretty.block [p1, str ".", str p2]) ps p - fun pr_dict fxy (DictConst (inst, dss)) = - brackify fxy ((str o deresolve) inst :: map (pr_dicts BR) dss) - | pr_dict fxy (DictVar (classrels, v)) = - pr_proj (map deresolve classrels) ((str o pr_dictvar) v) - in case ds - of [] => str "()" - | [d] => pr_dict fxy d - | _ :: _ => (Pretty.list "(" ")" o map (pr_dict NOBR)) ds - end; - fun pr_tyvar_dicts vs = - vs - |> map (fn (v, sort) => map_index (fn (i, _) => - DictVar ([], (v, (i, length sort)))) sort) - |> map (pr_dicts BR); - fun pr_tycoexpr fxy (tyco, tys) = - let - val tyco' = (str o deresolve) tyco - in case map (pr_typ BR) tys - of [] => tyco' - | [p] => Pretty.block [p, Pretty.brk 1, tyco'] - | (ps as _::_) => Pretty.block [Pretty.list "(" ")" ps, Pretty.brk 1, tyco'] - end - and pr_typ fxy (tyco `%% tys) = (case syntax_tyco tyco - of NONE => pr_tycoexpr fxy (tyco, tys) - | SOME (i, pr) => pr pr_typ fxy tys) - | pr_typ fxy (ITyVar v) = str ("'" ^ v); - fun pr_term is_closure thm vars fxy (IConst c) = - pr_app is_closure thm vars fxy (c, []) - | pr_term is_closure thm vars fxy (IVar v) = - str (Code_Printer.lookup_var vars v) - | pr_term is_closure thm vars fxy (t as t1 `$ t2) = - (case Code_Thingol.unfold_const_app t - of SOME c_ts => pr_app is_closure thm vars fxy c_ts - | NONE => - brackify fxy [pr_term is_closure thm vars NOBR t1, pr_term is_closure thm vars BR t2]) - | pr_term is_closure thm vars fxy (t as _ `|=> _) = - let - val (binds, t') = Code_Thingol.unfold_abs t; - fun pr ((v, pat), ty) = pr_bind is_closure thm BR ((SOME v, pat), ty); - val (ps, vars') = fold_map pr binds vars; - in brackets (str "fun" :: ps @ str "->" @@ pr_term is_closure thm vars' NOBR t') end - | pr_term is_closure thm vars fxy (ICase (cases as (_, t0))) = (case Code_Thingol.unfold_const_app t0 - of SOME (c_ts as ((c, _), _)) => if is_none (syntax_const c) - then pr_case is_closure thm vars fxy cases - else pr_app is_closure thm vars fxy c_ts - | NONE => pr_case is_closure thm vars fxy cases) - and pr_app' is_closure thm vars (app as ((c, ((_, iss), tys)), ts)) = - if is_cons c then - if length tys = length ts - then case ts - of [] => [(str o deresolve) c] - | [t] => [(str o deresolve) c, pr_term is_closure thm vars BR t] - | _ => [(str o deresolve) c, Pretty.enum "," "(" ")" - (map (pr_term is_closure thm vars NOBR) ts)] - else [pr_term is_closure thm vars BR (Code_Thingol.eta_expand (length tys) app)] - else if is_closure c - then (str o deresolve) c @@ str "()" - else (str o deresolve) c - :: ((map (pr_dicts BR) o filter_out null) iss @ map (pr_term is_closure thm vars BR) ts) - and pr_app is_closure = gen_pr_app (pr_app' is_closure) (pr_term is_closure) - syntax_const - and pr_bind' ((NONE, NONE), _) = str "_" - | pr_bind' ((SOME v, NONE), _) = str v - | pr_bind' ((NONE, SOME p), _) = p - | pr_bind' ((SOME v, SOME p), _) = brackets [p, str "as", str v] - and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure) - and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) = - let - val (binds, body) = Code_Thingol.unfold_let (ICase cases); - fun pr ((pat, ty), t) vars = - vars - |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) - |>> (fn p => concat - [str "let", p, str "=", pr_term is_closure thm vars NOBR t, str "in"]) - val (ps, vars') = fold_map pr binds vars; - in - brackify_block fxy (Pretty.chunks ps) [] - (pr_term is_closure thm vars' NOBR body) - end - | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) = - let - fun pr delim (pat, body) = - let - val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars; - in concat [str delim, p, str "->", pr_term is_closure thm vars' NOBR body] end; - in - brackets ( - str "match" - :: pr_term is_closure thm vars NOBR t - :: pr "with" clause - :: map (pr "|") clauses - ) - end - | pr_case is_closure thm vars fxy ((_, []), _) = - (concat o map str) ["failwith", "\"empty case\""]; - fun fish_params vars eqs = - let - fun fish_param _ (w as SOME _) = w - | fish_param (IVar v) NONE = SOME v - | fish_param _ NONE = NONE; - fun fillup_param _ (_, SOME v) = v - | fillup_param x (i, NONE) = x ^ string_of_int i; - val fished1 = fold (map2 fish_param) eqs (replicate (length (hd eqs)) NONE); - val x = Name.variant (map_filter I fished1) "x"; - val fished2 = map_index (fillup_param x) fished1; - val (fished3, _) = Name.variants fished2 Name.context; - val vars' = Code_Printer.intro_vars fished3 vars; - in map (Code_Printer.lookup_var vars') fished3 end; - fun pr_stmt (MLExc (name, n)) = - let - val exc_str = - (ML_Syntax.print_string o Long_Name.base_name o Long_Name.qualifier) name; - in - (concat o map str) ( - "let" - :: deresolve name - :: replicate n "_" - @ "=" - :: "failwith" - @@ exc_str - ) - end - | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - (Code_Thingol.fold_constnames (insert (op =)) t []); - val vars = reserved_names - |> Code_Printer.intro_vars consts; - in - concat [ - str "let", - (str o deresolve) name, - str ":", - pr_typ NOBR ty, - str "=", - pr_term (K false) thm vars NOBR t - ] - end - | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) = - let - fun pr_eq ((ts, t), (thm, _)) = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); - val vars = reserved_names - |> Code_Printer.intro_vars consts - |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) - (insert (op =)) ts []); - in concat [ - (Pretty.block o Pretty.commas) - (map (pr_term (member (op =) pseudo_funs) thm vars NOBR) ts), - str "->", - pr_term (member (op =) pseudo_funs) thm vars NOBR t - ] end; - fun pr_eqs is_pseudo [((ts, t), (thm, _))] = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []); - val vars = reserved_names - |> Code_Printer.intro_vars consts - |> Code_Printer.intro_vars ((fold o Code_Thingol.fold_unbound_varnames) - (insert (op =)) ts []); - in - concat ( - (if is_pseudo then [str "()"] - else map (pr_term (member (op =) pseudo_funs) thm vars BR) ts) - @ str "=" - @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t - ) - end - | pr_eqs _ (eqs as (eq as (([_], _), _)) :: eqs') = - Pretty.block ( - str "=" - :: Pretty.brk 1 - :: str "function" - :: Pretty.brk 1 - :: pr_eq eq - :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1] - o single o pr_eq) eqs' - ) - | pr_eqs _ (eqs as eq :: eqs') = - let - val consts = map_filter - (fn c => if (is_some o syntax_const) c - then NONE else (SOME o Long_Name.base_name o deresolve) c) - ((fold o Code_Thingol.fold_constnames) - (insert (op =)) (map (snd o fst) eqs) []); - val vars = reserved_names - |> Code_Printer.intro_vars consts; - val dummy_parms = (map str o fish_params vars o map (fst o fst)) eqs; - in - Pretty.block ( - Pretty.breaks dummy_parms - @ Pretty.brk 1 - :: str "=" - :: Pretty.brk 1 - :: str "match" - :: Pretty.brk 1 - :: (Pretty.block o Pretty.commas) dummy_parms - :: Pretty.brk 1 - :: str "with" - :: Pretty.brk 1 - :: pr_eq eq - :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1] - o single o pr_eq) eqs' - ) - end; - fun pr_funn definer (name, ((vs, ty), eqs)) = - concat ( - str definer - :: (str o deresolve) name - :: pr_tyvar_dicts (filter_out (null o snd) vs) - @| pr_eqs (member (op =) pseudo_funs name) eqs - ); - fun pr_pseudo_fun name = concat [ - str "let", - (str o deresolve) name, - str "=", - (str o deresolve) name, - str "();;" - ]; - val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns); - val (ps, p) = split_last - (pr_funn "let rec" funn :: map (pr_funn "and") funns); - val pseudo_ps = map pr_pseudo_fun pseudo_funs; - in Pretty.chunks (ps @ Pretty.block ([p, str ";;"]) :: pseudo_ps) end - | pr_stmt (MLDatas (datas as (data :: datas'))) = - let - fun pr_co (co, []) = - str (deresolve co) - | pr_co (co, tys) = - concat [ - str (deresolve co), - str "of", - Pretty.enum " *" "" "" (map (pr_typ (INFX (2, X))) tys) - ]; - fun pr_data definer (tyco, (vs, [])) = - concat ( - str definer - :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) - :: str "=" - @@ str "EMPTY_" - ) - | pr_data definer (tyco, (vs, cos)) = - concat ( - str definer - :: pr_tycoexpr NOBR (tyco, map (ITyVar o fst) vs) - :: str "=" - :: separate (str "|") (map pr_co cos) - ); - val (ps, p) = split_last - (pr_data "type" data :: map (pr_data "and") datas'); - in Pretty.chunks (ps @| Pretty.block ([p, str ";;"])) end - | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) = - let - val w = "_" ^ Code_Printer.first_upper v; - fun pr_superclass_field (class, classrel) = - (concat o map str) [ - deresolve classrel, ":", "'" ^ v, deresolve class - ]; - fun pr_classparam_field (classparam, ty) = - concat [ - (str o deresolve) classparam, str ":", pr_typ NOBR ty - ]; - fun pr_classparam_proj (classparam, _) = - concat [ - str "let", - (str o deresolve) classparam, - str w, - str "=", - str (w ^ "." ^ deresolve classparam ^ ";;") - ]; - in Pretty.chunks ( - concat [ - str ("type '" ^ v), - (str o deresolve) class, - str "=", - enum_default "unit;;" ";" "{" "};;" ( - map pr_superclass_field superclasses - @ map pr_classparam_field classparams - ) - ] - :: map pr_classparam_proj classparams - ) end - | pr_stmt (MLClassinst (inst, ((class, (tyco, arity)), (superarities, classparam_insts)))) = - let - fun pr_superclass (_, (classrel, dss)) = - concat [ - (str o deresolve) classrel, - str "=", - pr_dicts NOBR [DictConst dss] - ]; - fun pr_classparam_inst ((classparam, c_inst), (thm, _)) = - concat [ - (str o deresolve) classparam, - str "=", - pr_app (K false) thm reserved_names NOBR (c_inst, []) - ]; - in - concat ( - str "let" - :: (str o deresolve) inst - :: pr_tyvar_dicts arity - @ str "=" - @@ (Pretty.enclose "(" ");;" o Pretty.breaks) [ - enum_default "()" ";" "{" "}" (map pr_superclass superarities - @ map pr_classparam_inst classparam_insts), - str ":", - pr_tycoexpr NOBR (class, [tyco `%% map (ITyVar o fst) arity]) - ] - ) - end; - in pr_stmt end; - -fun pr_ocaml_module name content = - Pretty.chunks ( - str ("module " ^ name ^ " = ") - :: str "struct" - :: str "" - :: content - @ str "" - @@ str ("end;; (*struct " ^ name ^ "*)") - ); - -val literals_ocaml = let - fun chr i = - let - val xs = string_of_int i; - val ys = replicate_string (3 - length (explode xs)) "0"; - in "\\" ^ ys ^ xs end; - fun char_ocaml c = - let - val i = ord c; - val s = if i < 32 orelse i = 34 orelse i = 39 orelse i = 92 orelse i > 126 - then chr i else c - in s end; - fun bignum_ocaml k = if k <= 1073741823 - then "(Big_int.big_int_of_int " ^ string_of_int k ^ ")" - else "(Big_int.big_int_of_string " ^ quote (string_of_int k) ^ ")" -in Literals { - literal_char = enclose "'" "'" o char_ocaml, - literal_string = quote o translate_string char_ocaml, - literal_numeral = fn unbounded => fn k => if k >= 0 then - if unbounded then bignum_ocaml k - else string_of_int k - else - if unbounded then "(Big_int.minus_big_int " ^ bignum_ocaml (~ k) ^ ")" - else (enclose "(" ")" o prefix "-" o string_of_int o op ~) k, - literal_list = Pretty.enum ";" "[" "]", - infix_cons = (6, "::") -} end; - - - -(** SML/OCaml generic part **) - -local - -datatype ml_node = - Dummy of string - | Stmt of string * ml_stmt - | Module of string * ((Name.context * Name.context) * ml_node Graph.T); - -in - -fun ml_node_of_program labelled_name module_name reserved_names raw_module_alias program = - let - val module_alias = if is_some module_name then K module_name else raw_module_alias; - val reserved_names = Name.make_context reserved_names; - val empty_module = ((reserved_names, reserved_names), Graph.empty); - fun map_node [] f = f - | map_node (m::ms) f = - Graph.default_node (m, Module (m, empty_module)) - #> Graph.map_node m (fn (Module (module_name, (nsp, nodes))) => - Module (module_name, (nsp, map_node ms f nodes))); - fun map_nsp_yield [] f (nsp, nodes) = - let - val (x, nsp') = f nsp - in (x, (nsp', nodes)) end - | map_nsp_yield (m::ms) f (nsp, nodes) = - let - val (x, nodes') = - nodes - |> Graph.default_node (m, Module (m, empty_module)) - |> Graph.map_node_yield m (fn Module (d_module_name, nsp_nodes) => - let - val (x, nsp_nodes') = map_nsp_yield ms f nsp_nodes - in (x, Module (d_module_name, nsp_nodes')) end) - in (x, (nsp, nodes')) end; - fun map_nsp_fun_yield f (nsp_fun, nsp_typ) = - let - val (x, nsp_fun') = f nsp_fun - in (x, (nsp_fun', nsp_typ)) end; - fun map_nsp_typ_yield f (nsp_fun, nsp_typ) = - let - val (x, nsp_typ') = f nsp_typ - in (x, (nsp_fun, nsp_typ')) end; - val mk_name_module = Code_Printer.mk_name_module reserved_names NONE module_alias program; - fun mk_name_stmt upper name nsp = - let - val (_, base) = Code_Printer.dest_name name; - val base' = if upper then Code_Printer.first_upper base else base; - val ([base''], nsp') = Name.variants [base'] nsp; - in (base'', nsp') end; - fun rearrange_fun name (tysm as (vs, ty), raw_eqs) = - let - val eqs = filter (snd o snd) raw_eqs; - val (eqs', is_value) = if null (filter_out (null o snd) vs) then case eqs - of [(([], t), thm)] => if (not o null o fst o Code_Thingol.unfold_fun) ty - then ([(([IVar "x"], t `$ IVar "x"), thm)], false) - else (eqs, not (Code_Thingol.fold_constnames - (fn name' => fn b => b orelse name = name') t false)) - | _ => (eqs, false) - else (eqs, false) - in ((name, (tysm, eqs')), is_value) end; - fun check_kind [((name, (tysm, [(([], t), thm)])), true)] = MLVal (name, ((tysm, t), thm)) - | check_kind [((name, ((vs, ty), [])), _)] = - MLExc (name, (length o filter_out (null o snd)) vs + (length o fst o Code_Thingol.unfold_fun) ty) - | check_kind funns = - MLFuns (map fst funns, map_filter - (fn ((name, ((vs, _), [(([], _), _)])), _) => - if null (filter_out (null o snd) vs) then SOME name else NONE - | _ => NONE) funns); - fun add_funs stmts = fold_map - (fn (name, Code_Thingol.Fun (_, stmt)) => - map_nsp_fun_yield (mk_name_stmt false name) - #>> rpair (rearrange_fun name stmt) - | (name, _) => - error ("Function block containing illegal statement: " ^ labelled_name name) - ) stmts - #>> (split_list #> apsnd check_kind); - fun add_datatypes stmts = - fold_map - (fn (name, Code_Thingol.Datatype (_, stmt)) => - map_nsp_typ_yield (mk_name_stmt false name) #>> rpair (SOME (name, stmt)) - | (name, Code_Thingol.Datatypecons _) => - map_nsp_fun_yield (mk_name_stmt true name) #>> rpair NONE - | (name, _) => - error ("Datatype block containing illegal statement: " ^ labelled_name name) - ) stmts - #>> (split_list #> apsnd (map_filter I - #> (fn [] => error ("Datatype block without data statement: " - ^ (commas o map (labelled_name o fst)) stmts) - | stmts => MLDatas stmts))); - fun add_class stmts = - fold_map - (fn (name, Code_Thingol.Class (_, stmt)) => - map_nsp_typ_yield (mk_name_stmt false name) #>> rpair (SOME (name, stmt)) - | (name, Code_Thingol.Classrel _) => - map_nsp_fun_yield (mk_name_stmt false name) #>> rpair NONE - | (name, Code_Thingol.Classparam _) => - map_nsp_fun_yield (mk_name_stmt false name) #>> rpair NONE - | (name, _) => - error ("Class block containing illegal statement: " ^ labelled_name name) - ) stmts - #>> (split_list #> apsnd (map_filter I - #> (fn [] => error ("Class block without class statement: " - ^ (commas o map (labelled_name o fst)) stmts) - | [stmt] => MLClass stmt))); - fun add_inst [(name, Code_Thingol.Classinst stmt)] = - map_nsp_fun_yield (mk_name_stmt false name) - #>> (fn base => ([base], MLClassinst (name, stmt))); - fun add_stmts ((stmts as (_, Code_Thingol.Fun _)::_)) = - add_funs stmts - | add_stmts ((stmts as (_, Code_Thingol.Datatypecons _)::_)) = - add_datatypes stmts - | add_stmts ((stmts as (_, Code_Thingol.Datatype _)::_)) = - add_datatypes stmts - | add_stmts ((stmts as (_, Code_Thingol.Class _)::_)) = - add_class stmts - | add_stmts ((stmts as (_, Code_Thingol.Classrel _)::_)) = - add_class stmts - | add_stmts ((stmts as (_, Code_Thingol.Classparam _)::_)) = - add_class stmts - | add_stmts ((stmts as [(_, Code_Thingol.Classinst _)])) = - add_inst stmts - | add_stmts stmts = error ("Illegal mutual dependencies: " ^ - (commas o map (labelled_name o fst)) stmts); - fun add_stmts' stmts nsp_nodes = - let - val names as (name :: names') = map fst stmts; - val deps = - [] - |> fold (fold (insert (op =)) o Graph.imm_succs program) names - |> subtract (op =) names; - val (module_names, _) = (split_list o map Code_Printer.dest_name) names; - val module_name = (the_single o distinct (op =) o map mk_name_module) module_names - handle Empty => - error ("Different namespace prefixes for mutual dependencies:\n" - ^ commas (map labelled_name names) - ^ "\n" - ^ commas module_names); - val module_name_path = Long_Name.explode module_name; - fun add_dep name name' = - let - val module_name' = (mk_name_module o fst o Code_Printer.dest_name) name'; - in if module_name = module_name' then - map_node module_name_path (Graph.add_edge (name, name')) - else let - val (common, (diff1 :: _, diff2 :: _)) = chop_prefix (op =) - (module_name_path, Long_Name.explode module_name'); - in - map_node common - (fn node => Graph.add_edge_acyclic (diff1, diff2) node - handle Graph.CYCLES _ => error ("Dependency " - ^ quote name ^ " -> " ^ quote name' - ^ " would result in module dependency cycle")) - end end; - in - nsp_nodes - |> map_nsp_yield module_name_path (add_stmts stmts) - |-> (fn (base' :: bases', stmt') => - apsnd (map_node module_name_path (Graph.new_node (name, (Stmt (base', stmt'))) - #> fold2 (fn name' => fn base' => - Graph.new_node (name', (Dummy base'))) names' bases'))) - |> apsnd (fold (fn name => fold (add_dep name) deps) names) - |> apsnd (fold_product (curry (map_node module_name_path o Graph.add_edge)) names names) - end; - val (_, nodes) = empty_module - |> fold add_stmts' (map (AList.make (Graph.get_node program)) - (rev (Graph.strong_conn program))); - fun deresolver prefix name = - let - val module_name = (fst o Code_Printer.dest_name) name; - val module_name' = (Long_Name.explode o mk_name_module) module_name; - val (_, (_, remainder)) = chop_prefix (op =) (prefix, module_name'); - val stmt_name = - nodes - |> fold (fn name => fn node => case Graph.get_node node name - of Module (_, (_, node)) => node) module_name' - |> (fn node => case Graph.get_node node name of Stmt (stmt_name, _) => stmt_name - | Dummy stmt_name => stmt_name); - in - Long_Name.implode (remainder @ [stmt_name]) - end handle Graph.UNDEF _ => - error ("Unknown statement name: " ^ labelled_name name); - in (deresolver, nodes) end; - -fun serialize_ml target compile pr_module pr_stmt raw_module_name labelled_name reserved_names includes raw_module_alias - _ syntax_tyco syntax_const program stmt_names destination = - let - val is_cons = Code_Thingol.is_cons program; - val present_stmt_names = Code_Target.stmt_names_of_destination destination; - val is_present = not (null present_stmt_names); - val module_name = if is_present then SOME "Code" else raw_module_name; - val (deresolver, nodes) = ml_node_of_program labelled_name module_name - reserved_names raw_module_alias program; - val reserved_names = Code_Printer.make_vars reserved_names; - fun pr_node prefix (Dummy _) = - NONE - | pr_node prefix (Stmt (_, stmt)) = if is_present andalso - (null o filter (member (op =) present_stmt_names) o stmt_names_of) stmt - then NONE - else SOME - (pr_stmt labelled_name syntax_tyco syntax_const reserved_names - (deresolver prefix) is_cons stmt) - | pr_node prefix (Module (module_name, (_, nodes))) = - separate (str "") - ((map_filter (pr_node (prefix @ [module_name]) o Graph.get_node nodes) - o rev o flat o Graph.strong_conn) nodes) - |> (if is_present then Pretty.chunks else pr_module module_name) - |> SOME; - val stmt_names' = (map o try) - (deresolver (if is_some module_name then the_list module_name else [])) stmt_names; - val p = Pretty.chunks (separate (str "") (map snd includes @ (map_filter - (pr_node [] o Graph.get_node nodes) o rev o flat o Graph.strong_conn) nodes)); - in - Code_Target.mk_serialization target - (case compile of SOME compile => SOME (compile o Code_Target.code_of_pretty) | NONE => NONE) - (fn NONE => Code_Target.code_writeln | SOME file => File.write file o Code_Target.code_of_pretty) - (rpair stmt_names' o Code_Target.code_of_pretty) p destination - end; - -end; (*local*) - - -(** ML (system language) code for evaluation and instrumentalization **) - -fun eval_code_of some_target thy = Code_Target.serialize_custom thy (the_default target_Eval some_target, - (fn _ => fn [] => serialize_ml target_SML (SOME (K ())) (K Pretty.chunks) pr_sml_stmt (SOME ""), - literals_sml)); - - -(* evaluation *) - -fun eval some_target reff postproc thy t args = - let - val ctxt = ProofContext.init thy; - fun evaluator naming program ((_, (_, ty)), t) deps = - let - val _ = if Code_Thingol.contains_dictvar t then - error "Term to be evaluated contains free dictionaries" else (); - val value_name = "Value.VALUE.value" - val program' = program - |> Graph.new_node (value_name, - Code_Thingol.Fun (Term.dummy_patternN, (([], ty), [(([], t), (Drule.dummy_thm, true))]))) - |> fold (curry Graph.add_edge value_name) deps; - val (value_code, [SOME value_name']) = eval_code_of some_target thy naming program' [value_name]; - val sml_code = "let\n" ^ value_code ^ "\nin " ^ value_name' - ^ space_implode " " (map (enclose "(" ")") args) ^ " end"; - in ML_Context.evaluate ctxt false reff sml_code end; - in Code_Thingol.eval thy I postproc evaluator t end; - - -(* instrumentalization by antiquotation *) - -local - -structure CodeAntiqData = ProofDataFun -( - type T = (string list * string list) * (bool * (string - * (string * ((string * string) list * (string * string) list)) lazy)); - fun init _ = (([], []), (true, ("", Lazy.value ("", ([], []))))); -); - -val is_first_occ = fst o snd o CodeAntiqData.get; - -fun delayed_code thy tycos consts () = - let - val (consts', (naming, program)) = Code_Thingol.consts_program thy consts; - val tycos' = map (the o Code_Thingol.lookup_tyco naming) tycos; - val (ml_code, target_names) = eval_code_of NONE thy naming program (consts' @ tycos'); - val (consts'', tycos'') = chop (length consts') target_names; - val consts_map = map2 (fn const => fn NONE => - error ("Constant " ^ (quote o Code.string_of_const thy) const - ^ "\nhas a user-defined serialization") - | SOME const'' => (const, const'')) consts consts'' - val tycos_map = map2 (fn tyco => fn NONE => - error ("Type " ^ (quote o Sign.extern_type thy) tyco - ^ "\nhas a user-defined serialization") - | SOME tyco'' => (tyco, tyco'')) tycos tycos''; - in (ml_code, (tycos_map, consts_map)) end; - -fun register_code new_tycos new_consts ctxt = - let - val ((tycos, consts), (_, (struct_name, _))) = CodeAntiqData.get ctxt; - val tycos' = fold (insert (op =)) new_tycos tycos; - val consts' = fold (insert (op =)) new_consts consts; - val (struct_name', ctxt') = if struct_name = "" - then ML_Antiquote.variant "Code" ctxt - else (struct_name, ctxt); - val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) tycos' consts'); - in CodeAntiqData.put ((tycos', consts'), (false, (struct_name', acc_code))) ctxt' end; - -fun register_const const = register_code [] [const]; - -fun register_datatype tyco constrs = register_code [tyco] constrs; - -fun print_const const all_struct_name tycos_map consts_map = - (Long_Name.append all_struct_name o the o AList.lookup (op =) consts_map) const; - -fun print_datatype tyco constrs all_struct_name tycos_map consts_map = - let - val upperize = implode o nth_map 0 Symbol.to_ascii_upper o explode; - fun check_base name name'' = - if upperize (Long_Name.base_name name) = upperize name'' - then () else error ("Name as printed " ^ quote name'' - ^ "\ndiffers from logical base name " ^ quote (Long_Name.base_name name) ^ "; sorry."); - val tyco'' = (the o AList.lookup (op =) tycos_map) tyco; - val constrs'' = map (the o AList.lookup (op =) consts_map) constrs; - val _ = check_base tyco tyco''; - val _ = map2 check_base constrs constrs''; - in "datatype " ^ tyco'' ^ " = datatype " ^ Long_Name.append all_struct_name tyco'' end; - -fun print_code struct_name is_first print_it ctxt = - let - val (_, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt; - val (raw_ml_code, (tycos_map, consts_map)) = Lazy.force acc_code; - val ml_code = if is_first then "\nstructure " ^ struct_code_name - ^ " =\nstruct\n\n" ^ raw_ml_code ^ "\nend;\n\n" - else ""; - val all_struct_name = Long_Name.append struct_name struct_code_name; - in (ml_code, print_it all_struct_name tycos_map consts_map) end; - -in - -fun ml_code_antiq raw_const {struct_name, background} = - let - val const = Code.check_const (ProofContext.theory_of background) raw_const; - val is_first = is_first_occ background; - val background' = register_const const background; - in (print_code struct_name is_first (print_const const), background') end; - -fun ml_code_datatype_antiq (raw_tyco, raw_constrs) {struct_name, background} = - let - val thy = ProofContext.theory_of background; - val tyco = Sign.intern_type thy raw_tyco; - val constrs = map (Code.check_const thy) raw_constrs; - val constrs' = (map fst o snd o Code.get_datatype thy) tyco; - val _ = if gen_eq_set (op =) (constrs, constrs') then () - else error ("Type " ^ quote tyco ^ ": given constructors diverge from real constructors") - val is_first = is_first_occ background; - val background' = register_datatype tyco constrs background; - in (print_code struct_name is_first (print_datatype tyco constrs), background') end; - -end; (*local*) - - -(** Isar setup **) - -val _ = ML_Context.add_antiq "code" (fn _ => Args.term >> ml_code_antiq); -val _ = ML_Context.add_antiq "code_datatype" (fn _ => - (Args.tyname --| Scan.lift (Args.$$$ "=") - -- (Args.term ::: Scan.repeat (Scan.lift (Args.$$$ "|") |-- Args.term))) - >> ml_code_datatype_antiq); - -fun isar_seri_sml module_name = - Code_Target.parse_args (Scan.succeed ()) - #> (fn () => serialize_ml target_SML - (SOME (use_text ML_Env.local_context (1, "generated code") false)) - pr_sml_module pr_sml_stmt module_name); - -fun isar_seri_ocaml module_name = - Code_Target.parse_args (Scan.succeed ()) - #> (fn () => serialize_ml target_OCaml NONE - pr_ocaml_module pr_ocaml_stmt module_name); - -val setup = - Code_Target.add_target (target_SML, (isar_seri_sml, literals_sml)) - #> Code_Target.add_target (target_OCaml, (isar_seri_ocaml, literals_ocaml)) - #> Code_Target.extend_target (target_Eval, (target_SML, K I)) - #> Code_Target.add_syntax_tyco target_SML "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => - brackify_infix (1, R) fxy [ - pr_typ (INFX (1, X)) ty1, - str "->", - pr_typ (INFX (1, R)) ty2 - ])) - #> Code_Target.add_syntax_tyco target_OCaml "fun" (SOME (2, fn pr_typ => fn fxy => fn [ty1, ty2] => - brackify_infix (1, R) fxy [ - pr_typ (INFX (1, X)) ty1, - str "->", - pr_typ (INFX (1, R)) ty2 - ])) - #> fold (Code_Target.add_reserved target_SML) ML_Syntax.reserved_names - #> fold (Code_Target.add_reserved target_SML) - ["o" (*dictionary projections use it already*), "Fail", "div", "mod" (*standard infixes*)] - #> fold (Code_Target.add_reserved target_OCaml) [ - "and", "as", "assert", "begin", "class", - "constraint", "do", "done", "downto", "else", "end", "exception", - "external", "false", "for", "fun", "function", "functor", "if", - "in", "include", "inherit", "initializer", "lazy", "let", "match", "method", - "module", "mutable", "new", "object", "of", "open", "or", "private", "rec", - "sig", "struct", "then", "to", "true", "try", "type", "val", - "virtual", "when", "while", "with" - ] - #> fold (Code_Target.add_reserved target_OCaml) ["failwith", "mod"]; - -end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_preproc.ML --- a/src/Tools/code/code_preproc.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,515 +0,0 @@ -(* Title: Tools/code/code_preproc.ML - Author: Florian Haftmann, TU Muenchen - -Preprocessing code equations into a well-sorted system -in a graph with explicit dependencies. -*) - -signature CODE_PREPROC = -sig - val map_pre: (simpset -> simpset) -> theory -> theory - val map_post: (simpset -> simpset) -> theory -> theory - val add_inline: thm -> theory -> theory - val add_functrans: string * (theory -> (thm * bool) list -> (thm * bool) list option) -> theory -> theory - val del_functrans: string -> theory -> theory - val simple_functrans: (theory -> thm list -> thm list option) - -> theory -> (thm * bool) list -> (thm * bool) list option - val print_codeproc: theory -> unit - - type code_algebra - type code_graph - val eqns: code_graph -> string -> (thm * bool) list - val typ: code_graph -> string -> (string * sort) list * typ - val all: code_graph -> string list - val pretty: theory -> code_graph -> Pretty.T - val obtain: theory -> string list -> term list -> code_algebra * code_graph - val eval_conv: theory -> (sort -> sort) - -> (code_algebra -> code_graph -> (string * sort) list -> term -> cterm -> thm) -> cterm -> thm - val eval: theory -> (sort -> sort) -> ((term -> term) -> 'a -> 'a) - -> (code_algebra -> code_graph -> (string * sort) list -> term -> 'a) -> term -> 'a - - val setup: theory -> theory -end - -structure Code_Preproc : CODE_PREPROC = -struct - -(** preprocessor administration **) - -(* theory data *) - -datatype thmproc = Thmproc of { - pre: simpset, - post: simpset, - functrans: (string * (serial * (theory -> (thm * bool) list -> (thm * bool) list option))) list -}; - -fun make_thmproc ((pre, post), functrans) = - Thmproc { pre = pre, post = post, functrans = functrans }; -fun map_thmproc f (Thmproc { pre, post, functrans }) = - make_thmproc (f ((pre, post), functrans)); -fun merge_thmproc (Thmproc { pre = pre1, post = post1, functrans = functrans1 }, - Thmproc { pre = pre2, post = post2, functrans = functrans2 }) = - let - val pre = Simplifier.merge_ss (pre1, pre2); - val post = Simplifier.merge_ss (post1, post2); - val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2); - in make_thmproc ((pre, post), functrans) end; - -structure Code_Preproc_Data = TheoryDataFun -( - type T = thmproc; - val empty = make_thmproc ((Simplifier.empty_ss, Simplifier.empty_ss), []); - fun copy spec = spec; - val extend = copy; - fun merge pp = merge_thmproc; -); - -fun the_thmproc thy = case Code_Preproc_Data.get thy - of Thmproc x => x; - -fun delete_force msg key xs = - if AList.defined (op =) xs key then AList.delete (op =) key xs - else error ("No such " ^ msg ^ ": " ^ quote key); - -fun map_data f thy = - thy - |> Code.purge_data - |> (Code_Preproc_Data.map o map_thmproc) f; - -val map_pre = map_data o apfst o apfst; -val map_post = map_data o apfst o apsnd; - -val add_inline = map_pre o MetaSimplifier.add_simp; -val del_inline = map_pre o MetaSimplifier.del_simp; -val add_post = map_post o MetaSimplifier.add_simp; -val del_post = map_post o MetaSimplifier.del_simp; - -fun add_functrans (name, f) = (map_data o apsnd) - (AList.update (op =) (name, (serial (), f))); - -fun del_functrans name = (map_data o apsnd) - (delete_force "function transformer" name); - - -(* post- and preprocessing *) - -fun apply_functrans thy c _ [] = [] - | apply_functrans thy c [] eqns = eqns - | apply_functrans thy c functrans eqns = eqns - |> perhaps (perhaps_loop (perhaps_apply functrans)) - |> Code.assert_eqns_const thy c; - -fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); - -fun term_of_conv thy f = - Thm.cterm_of thy - #> f - #> Thm.prop_of - #> Logic.dest_equals - #> snd; - -fun preprocess thy c eqns = - let - val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; - val functrans = (map (fn (_, (_, f)) => f thy) o #functrans - o the_thmproc) thy; - in - eqns - |> apply_functrans thy c functrans - |> (map o apfst) (Code.rewrite_eqn pre) - |> (map o apfst) (AxClass.unoverload thy) - |> map (Code.assert_eqn thy) - |> burrow_fst (Code.norm_args thy) - |> burrow_fst (Code.norm_varnames thy) - end; - -fun preprocess_conv thy ct = - let - val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; - in - ct - |> Simplifier.rewrite pre - |> rhs_conv (AxClass.unoverload_conv thy) - end; - -fun postprocess_conv thy ct = - let - val post = (Simplifier.theory_context thy o #post o the_thmproc) thy; - in - ct - |> AxClass.overload_conv thy - |> rhs_conv (Simplifier.rewrite post) - end; - -fun postprocess_term thy = term_of_conv thy (postprocess_conv thy); - -fun print_codeproc thy = - let - val ctxt = ProofContext.init thy; - val pre = (#pre o the_thmproc) thy; - val post = (#post o the_thmproc) thy; - val functrans = (map fst o #functrans o the_thmproc) thy; - in - (Pretty.writeln o Pretty.chunks) [ - Pretty.block [ - Pretty.str "preprocessing simpset:", - Pretty.fbrk, - Simplifier.pretty_ss ctxt pre - ], - Pretty.block [ - Pretty.str "postprocessing simpset:", - Pretty.fbrk, - Simplifier.pretty_ss ctxt post - ], - Pretty.block ( - Pretty.str "function transformers:" - :: Pretty.fbrk - :: (Pretty.fbreaks o map Pretty.str) functrans - ) - ] - end; - -fun simple_functrans f thy eqns = case f thy (map fst eqns) - of SOME thms' => SOME (map (rpair (forall snd eqns)) thms') - | NONE => NONE; - - -(** sort algebra and code equation graph types **) - -type code_algebra = (sort -> sort) * Sorts.algebra; -type code_graph = (((string * sort) list * typ) * (thm * bool) list) Graph.T; - -fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr); -fun typ eqngr = fst o Graph.get_node eqngr; -fun all eqngr = Graph.keys eqngr; - -fun pretty thy eqngr = - AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr) - |> (map o apfst) (Code.string_of_const thy) - |> sort (string_ord o pairself fst) - |> map (fn (s, thms) => - (Pretty.block o Pretty.fbreaks) ( - Pretty.str s - :: map (Display.pretty_thm o fst) thms - )) - |> Pretty.chunks; - - -(** the Waisenhaus algorithm **) - -(* auxiliary *) - -fun is_proper_class thy = can (AxClass.get_info thy); - -fun complete_proper_sort thy = - Sign.complete_sort thy #> filter (is_proper_class thy); - -fun inst_params thy tyco = - map (fn (c, _) => AxClass.param_of_inst thy (c, tyco)) - o maps (#params o AxClass.get_info thy); - -fun consts_of thy eqns = [] |> (fold o fold o fold_aterms) - (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I) - (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns); - -fun tyscm_rhss_of thy c eqns = - let - val tyscm = case eqns of [] => Code.default_typscheme thy c - | ((thm, _) :: _) => Code.typscheme_eqn thy thm; - val rhss = consts_of thy eqns; - in (tyscm, rhss) end; - - -(* data structures *) - -datatype const = Fun of string | Inst of class * string; - -fun const_ord (Fun c1, Fun c2) = fast_string_ord (c1, c2) - | const_ord (Inst class_tyco1, Inst class_tyco2) = - prod_ord fast_string_ord fast_string_ord (class_tyco1, class_tyco2) - | const_ord (Fun _, Inst _) = LESS - | const_ord (Inst _, Fun _) = GREATER; - -type var = const * int; - -structure Vargraph = - GraphFun(type key = var val ord = prod_ord const_ord int_ord); - -datatype styp = Tyco of string * styp list | Var of var | Free; - -fun styp_of c_lhs (Type (tyco, tys)) = Tyco (tyco, map (styp_of c_lhs) tys) - | styp_of c_lhs (TFree (v, _)) = case c_lhs - of SOME (c, lhs) => Var (Fun c, find_index (fn (v', _) => v = v') lhs) - | NONE => Free; - -type vardeps_data = ((string * styp list) list * class list) Vargraph.T - * (((string * sort) list * (thm * bool) list) Symtab.table - * (class * string) list); - -val empty_vardeps_data : vardeps_data = - (Vargraph.empty, (Symtab.empty, [])); - - -(* retrieving equations and instances from the background context *) - -fun obtain_eqns thy eqngr c = - case try (Graph.get_node eqngr) c - of SOME ((lhs, _), eqns) => ((lhs, []), []) - | NONE => let - val eqns = Code.these_eqns thy c - |> preprocess thy c; - val ((lhs, _), rhss) = tyscm_rhss_of thy c eqns; - in ((lhs, rhss), eqns) end; - -fun obtain_instance thy arities (inst as (class, tyco)) = - case AList.lookup (op =) arities inst - of SOME classess => (classess, ([], [])) - | NONE => let - val all_classes = complete_proper_sort thy [class]; - val superclasses = remove (op =) class all_classes - val classess = map (complete_proper_sort thy) - (Sign.arity_sorts thy tyco [class]); - val inst_params = inst_params thy tyco all_classes; - in (classess, (superclasses, inst_params)) end; - - -(* computing instantiations *) - -fun add_classes thy arities eqngr c_k new_classes vardeps_data = - let - val (styps, old_classes) = Vargraph.get_node (fst vardeps_data) c_k; - val diff_classes = new_classes |> subtract (op =) old_classes; - in if null diff_classes then vardeps_data - else let - val c_ks = Vargraph.imm_succs (fst vardeps_data) c_k |> insert (op =) c_k; - in - vardeps_data - |> (apfst o Vargraph.map_node c_k o apsnd) (append diff_classes) - |> fold (fn styp => fold (ensure_typmatch_inst thy arities eqngr styp) new_classes) styps - |> fold (fn c_k => add_classes thy arities eqngr c_k diff_classes) c_ks - end end -and add_styp thy arities eqngr c_k tyco_styps vardeps_data = - let - val (old_styps, classes) = Vargraph.get_node (fst vardeps_data) c_k; - in if member (op =) old_styps tyco_styps then vardeps_data - else - vardeps_data - |> (apfst o Vargraph.map_node c_k o apfst) (cons tyco_styps) - |> fold (ensure_typmatch_inst thy arities eqngr tyco_styps) classes - end -and add_dep thy arities eqngr c_k c_k' vardeps_data = - let - val (_, classes) = Vargraph.get_node (fst vardeps_data) c_k; - in - vardeps_data - |> add_classes thy arities eqngr c_k' classes - |> apfst (Vargraph.add_edge (c_k, c_k')) - end -and ensure_typmatch_inst thy arities eqngr (tyco, styps) class vardeps_data = - if can (Sign.arity_sorts thy tyco) [class] - then vardeps_data - |> ensure_inst thy arities eqngr (class, tyco) - |> fold_index (fn (k, styp) => - ensure_typmatch thy arities eqngr styp (Inst (class, tyco), k)) styps - else vardeps_data (*permissive!*) -and ensure_inst thy arities eqngr (inst as (class, tyco)) (vardeps_data as (_, (_, insts))) = - if member (op =) insts inst then vardeps_data - else let - val (classess, (superclasses, inst_params)) = - obtain_instance thy arities inst; - in - vardeps_data - |> (apsnd o apsnd) (insert (op =) inst) - |> fold_index (fn (k, _) => - apfst (Vargraph.new_node ((Inst (class, tyco), k), ([] ,[])))) classess - |> fold (fn superclass => ensure_inst thy arities eqngr (superclass, tyco)) superclasses - |> fold (ensure_fun thy arities eqngr) inst_params - |> fold_index (fn (k, classes) => - add_classes thy arities eqngr (Inst (class, tyco), k) classes - #> fold (fn superclass => - add_dep thy arities eqngr (Inst (superclass, tyco), k) - (Inst (class, tyco), k)) superclasses - #> fold (fn inst_param => - add_dep thy arities eqngr (Fun inst_param, k) - (Inst (class, tyco), k) - ) inst_params - ) classess - end -and ensure_typmatch thy arities eqngr (Tyco tyco_styps) c_k vardeps_data = - vardeps_data - |> add_styp thy arities eqngr c_k tyco_styps - | ensure_typmatch thy arities eqngr (Var c_k') c_k vardeps_data = - vardeps_data - |> add_dep thy arities eqngr c_k c_k' - | ensure_typmatch thy arities eqngr Free c_k vardeps_data = - vardeps_data -and ensure_rhs thy arities eqngr (c', styps) vardeps_data = - vardeps_data - |> ensure_fun thy arities eqngr c' - |> fold_index (fn (k, styp) => - ensure_typmatch thy arities eqngr styp (Fun c', k)) styps -and ensure_fun thy arities eqngr c (vardeps_data as (_, (eqntab, _))) = - if Symtab.defined eqntab c then vardeps_data - else let - val ((lhs, rhss), eqns) = obtain_eqns thy eqngr c; - val rhss' = (map o apsnd o map) (styp_of (SOME (c, lhs))) rhss; - in - vardeps_data - |> (apsnd o apfst) (Symtab.update_new (c, (lhs, eqns))) - |> fold_index (fn (k, _) => - apfst (Vargraph.new_node ((Fun c, k), ([] ,[])))) lhs - |> fold_index (fn (k, (_, sort)) => - add_classes thy arities eqngr (Fun c, k) (complete_proper_sort thy sort)) lhs - |> fold (ensure_rhs thy arities eqngr) rhss' - end; - - -(* applying instantiations *) - -fun dicts_of thy (proj_sort, algebra) (T, sort) = - let - fun class_relation (x, _) _ = x; - fun type_constructor tyco xs class = - inst_params thy tyco (Sorts.complete_sort algebra [class]) - @ (maps o maps) fst xs; - fun type_variable (TFree (_, sort)) = map (pair []) (proj_sort sort); - in - flat (Sorts.of_sort_derivation (Syntax.pp_global thy) algebra - { class_relation = class_relation, type_constructor = type_constructor, - type_variable = type_variable } (T, proj_sort sort) - handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) - end; - -fun add_arity thy vardeps (class, tyco) = - AList.default (op =) - ((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) - (0 upto Sign.arity_number thy tyco - 1)); - -fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) = - if can (Graph.get_node eqngr) c then (rhss, eqngr) - else let - val lhs = map_index (fn (k, (v, _)) => - (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; - val inst_tab = Vartab.empty |> fold (fn (v, sort) => - Vartab.update ((v, 0), sort)) lhs; - val eqns = proto_eqns - |> (map o apfst) (Code.inst_thm thy inst_tab); - val (tyscm, rhss') = tyscm_rhss_of thy c eqns; - val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr; - in (map (pair c) rhss' @ rhss, eqngr') end; - -fun extend_arities_eqngr thy cs ts (arities, eqngr) = - let - val cs_rhss = (fold o fold_aterms) (fn Const (c_ty as (c, _)) => - insert (op =) (c, (map (styp_of NONE) o Sign.const_typargs thy) c_ty) | _ => I) ts []; - val (vardeps, (eqntab, insts)) = empty_vardeps_data - |> fold (ensure_fun thy arities eqngr) cs - |> fold (ensure_rhs thy arities eqngr) cs_rhss; - val arities' = fold (add_arity thy vardeps) insts arities; - val pp = Syntax.pp_global thy; - val algebra = Sorts.subalgebra pp (is_proper_class thy) - (AList.lookup (op =) arities') (Sign.classes_of thy); - val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr); - fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra) - (rhs ~~ (map snd o fst o fst o Graph.get_node eqngr') c); - val eqngr'' = fold (fn (c, rhs) => fold - (curry Graph.add_edge c) (deps_of rhs)) rhss eqngr'; - in (algebra, (arities', eqngr'')) end; - - -(** store for preprocessed arities and code equations **) - -structure Wellsorted = CodeDataFun -( - type T = ((string * class) * sort list) list * code_graph; - val empty = ([], Graph.empty); - fun purge thy cs (arities, eqngr) = - let - val del_cs = ((Graph.all_preds eqngr - o filter (can (Graph.get_node eqngr))) cs); - val del_arities = del_cs - |> map_filter (AxClass.inst_of_param thy) - |> maps (fn (c, tyco) => - (map (rpair tyco) o Sign.complete_sort thy o the_list - o AxClass.class_of_param thy) c); - val arities' = fold (AList.delete (op =)) del_arities arities; - val eqngr' = Graph.del_nodes del_cs eqngr; - in (arities', eqngr') end; -); - - -(** retrieval and evaluation interfaces **) - -fun obtain thy cs ts = apsnd snd - (Wellsorted.change_yield thy (extend_arities_eqngr thy cs ts)); - -fun prepare_sorts_typ prep_sort - = map_type_tfree (fn (v, sort) => TFree (v, prep_sort sort)); - -fun prepare_sorts prep_sort (Const (c, ty)) = - Const (c, prepare_sorts_typ prep_sort ty) - | prepare_sorts prep_sort (t1 $ t2) = - prepare_sorts prep_sort t1 $ prepare_sorts prep_sort t2 - | prepare_sorts prep_sort (Abs (v, ty, t)) = - Abs (v, prepare_sorts_typ prep_sort ty, prepare_sorts prep_sort t) - | prepare_sorts _ (t as Bound _) = t; - -fun gen_eval thy cterm_of conclude_evaluation prep_sort evaluator proto_ct = - let - val pp = Syntax.pp_global thy; - val ct = cterm_of proto_ct; - val _ = (Sign.no_frees pp o map_types (K dummyT) o Sign.no_vars pp) - (Thm.term_of ct); - val thm = preprocess_conv thy ct; - val ct' = Thm.rhs_of thm; - val t' = Thm.term_of ct'; - val vs = Term.add_tfrees t' []; - val consts = fold_aterms - (fn Const (c, _) => insert (op =) c | _ => I) t' []; - - val t'' = prepare_sorts prep_sort t'; - val (algebra', eqngr') = obtain thy consts [t'']; - in conclude_evaluation (evaluator algebra' eqngr' vs t'' ct') thm end; - -fun simple_evaluator evaluator algebra eqngr vs t ct = - evaluator algebra eqngr vs t; - -fun eval_conv thy = - let - fun conclude_evaluation thm2 thm1 = - let - val thm3 = postprocess_conv thy (Thm.rhs_of thm2); - in - Thm.transitive thm1 (Thm.transitive thm2 thm3) handle THM _ => - error ("could not construct evaluation proof:\n" - ^ (cat_lines o map Display.string_of_thm) [thm1, thm2, thm3]) - end; - in gen_eval thy I conclude_evaluation end; - -fun eval thy prep_sort postproc evaluator = gen_eval thy (Thm.cterm_of thy) - (K o postproc (postprocess_term thy)) prep_sort (simple_evaluator evaluator); - - -(** setup **) - -val setup = - let - fun mk_attribute f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I); - fun add_del_attribute (name, (add, del)) = - Code.add_attribute (name, Args.del |-- Scan.succeed (mk_attribute del) - || Scan.succeed (mk_attribute add)) - in - add_del_attribute ("inline", (add_inline, del_inline)) - #> add_del_attribute ("post", (add_post, del_post)) - #> Code.add_attribute ("unfold", Scan.succeed (Thm.declaration_attribute - (fn thm => Context.mapping (Codegen.add_unfold thm #> add_inline thm) I))) - end; - -val _ = - OuterSyntax.improper_command "print_codeproc" "print code preprocessor setup" - OuterKeyword.diag (Scan.succeed - (Toplevel.no_timing o Toplevel.unknown_theory o Toplevel.keep - (print_codeproc o Toplevel.theory_of))); - -end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_printer.ML --- a/src/Tools/code/code_printer.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,318 +0,0 @@ -(* Title: Tools/code/code_printer.ML - Author: Florian Haftmann, TU Muenchen - -Generic operations for pretty printing of target language code. -*) - -signature CODE_PRINTER = -sig - val nerror: thm -> string -> 'a - - val @@ : 'a * 'a -> 'a list - val @| : 'a list * 'a -> 'a list - val str: string -> Pretty.T - val concat: Pretty.T list -> Pretty.T - val brackets: Pretty.T list -> Pretty.T - val semicolon: Pretty.T list -> Pretty.T - val enum_default: string -> string -> string -> string -> Pretty.T list -> Pretty.T - - val first_upper: string -> string - val first_lower: string -> string - type var_ctxt - val make_vars: string list -> var_ctxt - val intro_vars: string list -> var_ctxt -> var_ctxt - val lookup_var: var_ctxt -> string -> string - - type literals - val Literals: { literal_char: string -> string, literal_string: string -> string, - literal_numeral: bool -> int -> string, - literal_list: Pretty.T list -> Pretty.T, infix_cons: int * string } - -> literals - val literal_char: literals -> string -> string - val literal_string: literals -> string -> string - val literal_numeral: literals -> bool -> int -> string - val literal_list: literals -> Pretty.T list -> Pretty.T - val infix_cons: literals -> int * string - - type lrx - val L: lrx - val R: lrx - val X: lrx - type fixity - val BR: fixity - val NOBR: fixity - val INFX: int * lrx -> fixity - val APP: fixity - val brackify: fixity -> Pretty.T list -> Pretty.T - val brackify_infix: int * lrx -> fixity -> Pretty.T list -> Pretty.T - val brackify_block: fixity -> Pretty.T -> Pretty.T list -> Pretty.T -> Pretty.T - - type itype = Code_Thingol.itype - type iterm = Code_Thingol.iterm - type const = Code_Thingol.const - type dict = Code_Thingol.dict - type tyco_syntax - type const_syntax - type proto_const_syntax - val parse_infix: ('a -> 'b) -> lrx * int -> string - -> int * ((fixity -> 'b -> Pretty.T) - -> fixity -> 'a list -> Pretty.T) - val parse_syntax: ('a -> 'b) -> OuterParse.token list - -> (int * ((fixity -> 'b -> Pretty.T) - -> fixity -> 'a list -> Pretty.T)) option * OuterParse.token list - val simple_const_syntax: (int * ((fixity -> iterm -> Pretty.T) - -> fixity -> (iterm * itype) list -> Pretty.T)) option -> proto_const_syntax option - val activate_const_syntax: theory -> literals - -> proto_const_syntax -> Code_Thingol.naming -> const_syntax * Code_Thingol.naming - val gen_pr_app: (thm -> var_ctxt -> const * iterm list -> Pretty.T list) - -> (thm -> var_ctxt -> fixity -> iterm -> Pretty.T) - -> (string -> const_syntax option) - -> thm -> var_ctxt -> fixity -> const * iterm list -> Pretty.T - val gen_pr_bind: ((string option * Pretty.T option) * itype -> Pretty.T) - -> (thm -> var_ctxt -> fixity -> iterm -> Pretty.T) - -> thm -> fixity - -> (string option * iterm option) * itype -> var_ctxt -> Pretty.T * var_ctxt - - val mk_name_module: Name.context -> string option -> (string -> string option) - -> 'a Graph.T -> string -> string - val dest_name: string -> string * string -end; - -structure Code_Printer : CODE_PRINTER = -struct - -open Code_Thingol; - -fun nerror thm s = error (s ^ ",\nin equation " ^ Display.string_of_thm thm); - -(** assembling text pieces **) - -infixr 5 @@; -infixr 5 @|; -fun x @@ y = [x, y]; -fun xs @| y = xs @ [y]; -val str = PrintMode.setmp [] Pretty.str; -val concat = Pretty.block o Pretty.breaks; -val brackets = Pretty.enclose "(" ")" o Pretty.breaks; -fun semicolon ps = Pretty.block [concat ps, str ";"]; -fun enum_default default sep opn cls [] = str default - | enum_default default sep opn cls xs = Pretty.enum sep opn cls xs; - - -(** names and variable name contexts **) - -type var_ctxt = string Symtab.table * Name.context; - -fun make_vars names = (fold (fn name => Symtab.update_new (name, name)) names Symtab.empty, - Name.make_context names); - -fun intro_vars names (namemap, namectxt) = - let - val (names', namectxt') = Name.variants names namectxt; - val namemap' = fold2 (curry Symtab.update) names names' namemap; - in (namemap', namectxt') end; - -fun lookup_var (namemap, _) name = case Symtab.lookup namemap name - of SOME name' => name' - | NONE => error ("Invalid name in context: " ^ quote name); - -val first_upper = implode o nth_map 0 Symbol.to_ascii_upper o explode; -val first_lower = implode o nth_map 0 Symbol.to_ascii_lower o explode; - - -(** pretty literals **) - -datatype literals = Literals of { - literal_char: string -> string, - literal_string: string -> string, - literal_numeral: bool -> int -> string, - literal_list: Pretty.T list -> Pretty.T, - infix_cons: int * string -}; - -fun dest_Literals (Literals lits) = lits; - -val literal_char = #literal_char o dest_Literals; -val literal_string = #literal_string o dest_Literals; -val literal_numeral = #literal_numeral o dest_Literals; -val literal_list = #literal_list o dest_Literals; -val infix_cons = #infix_cons o dest_Literals; - - -(** syntax printer **) - -(* binding priorities *) - -datatype lrx = L | R | X; - -datatype fixity = - BR - | NOBR - | INFX of (int * lrx); - -val APP = INFX (~1, L); - -fun fixity_lrx L L = false - | fixity_lrx R R = false - | fixity_lrx _ _ = true; - -fun fixity NOBR _ = false - | fixity _ NOBR = false - | fixity (INFX (pr, lr)) (INFX (pr_ctxt, lr_ctxt)) = - pr < pr_ctxt - orelse pr = pr_ctxt - andalso fixity_lrx lr lr_ctxt - orelse pr_ctxt = ~1 - | fixity BR (INFX _) = false - | fixity _ _ = true; - -fun gen_brackify _ [p] = p - | gen_brackify true (ps as _::_) = Pretty.enclose "(" ")" ps - | gen_brackify false (ps as _::_) = Pretty.block ps; - -fun brackify fxy_ctxt = - gen_brackify (fixity BR fxy_ctxt) o Pretty.breaks; - -fun brackify_infix infx fxy_ctxt = - gen_brackify (fixity (INFX infx) fxy_ctxt) o Pretty.breaks; - -fun brackify_block fxy_ctxt p1 ps p2 = - let val p = Pretty.block_enclose (p1, p2) ps - in if fixity BR fxy_ctxt - then Pretty.enclose "(" ")" [p] - else p - end; - - -(* generic syntax *) - -type tyco_syntax = int * ((fixity -> itype -> Pretty.T) - -> fixity -> itype list -> Pretty.T); -type const_syntax = int * ((var_ctxt -> fixity -> iterm -> Pretty.T) - -> thm -> var_ctxt -> fixity -> (iterm * itype) list -> Pretty.T); -type proto_const_syntax = int * (string list * (literals -> string list - -> (var_ctxt -> fixity -> iterm -> Pretty.T) - -> thm -> var_ctxt -> fixity -> (iterm * itype) list -> Pretty.T)); - -fun simple_const_syntax (SOME (n, f)) = SOME (n, - ([], (fn _ => fn _ => fn pr => fn thm => fn vars => f (pr vars)))) - | simple_const_syntax NONE = NONE; - -fun activate_const_syntax thy literals (n, (cs, f)) naming = - fold_map (Code_Thingol.ensure_declared_const thy) cs naming - |-> (fn cs' => pair (n, f literals cs')); - -fun gen_pr_app pr_app pr_term syntax_const thm vars fxy (app as ((c, (_, tys)), ts)) = - case syntax_const c - of NONE => brackify fxy (pr_app thm vars app) - | SOME (k, pr) => - let - fun pr' fxy ts = pr (pr_term thm) thm vars fxy (ts ~~ curry Library.take k tys); - in if k = length ts - then pr' fxy ts - else if k < length ts - then case chop k ts of (ts1, ts2) => - brackify fxy (pr' APP ts1 :: map (pr_term thm vars BR) ts2) - else pr_term thm vars fxy (Code_Thingol.eta_expand k app) - end; - -fun gen_pr_bind pr_bind pr_term thm (fxy : fixity) ((v, pat), ty : itype) vars = - let - val vs = case pat - of SOME pat => Code_Thingol.fold_varnames (insert (op =)) pat [] - | NONE => []; - val vars' = intro_vars (the_list v) vars; - val vars'' = intro_vars vs vars'; - val v' = Option.map (lookup_var vars') v; - val pat' = Option.map (pr_term thm vars'' fxy) pat; - in (pr_bind ((v', pat'), ty), vars'') end; - - -(* mixfix syntax *) - -datatype 'a mixfix = - Arg of fixity - | Pretty of Pretty.T; - -fun mk_mixfix prep_arg (fixity_this, mfx) = - let - fun is_arg (Arg _) = true - | is_arg _ = false; - val i = (length o filter is_arg) mfx; - fun fillin _ [] [] = - [] - | fillin pr (Arg fxy :: mfx) (a :: args) = - (pr fxy o prep_arg) a :: fillin pr mfx args - | fillin pr (Pretty p :: mfx) args = - p :: fillin pr mfx args; - in - (i, fn pr => fn fixity_ctxt => fn args => - gen_brackify (fixity fixity_this fixity_ctxt) (fillin pr mfx args)) - end; - -fun parse_infix prep_arg (x, i) s = - let - val l = case x of L => INFX (i, L) | _ => INFX (i, X); - val r = case x of R => INFX (i, R) | _ => INFX (i, X); - in - mk_mixfix prep_arg (INFX (i, x), - [Arg l, (Pretty o Pretty.brk) 1, (Pretty o str) s, (Pretty o Pretty.brk) 1, Arg r]) - end; - -fun parse_mixfix prep_arg s = - let - val sym_any = Scan.one Symbol.is_regular; - val parse = Scan.optional ($$ "!" >> K true) false -- Scan.repeat ( - ($$ "(" -- $$ "_" -- $$ ")" >> K (Arg NOBR)) - || ($$ "_" >> K (Arg BR)) - || ($$ "/" |-- Scan.repeat ($$ " ") >> (Pretty o Pretty.brk o length)) - || (Scan.repeat1 - ( $$ "'" |-- sym_any - || Scan.unless ($$ "_" || $$ "/" || $$ "(" |-- $$ "_" |-- $$ ")") - sym_any) >> (Pretty o str o implode))); - in case Scan.finite Symbol.stopper parse (Symbol.explode s) - of ((_, p as [_]), []) => mk_mixfix prep_arg (NOBR, p) - | ((b, p as _ :: _ :: _), []) => mk_mixfix prep_arg (if b then NOBR else BR, p) - | _ => Scan.!! - (the_default ("malformed mixfix annotation: " ^ quote s) o snd) Scan.fail () - end; - -val (infixK, infixlK, infixrK) = ("infix", "infixl", "infixr"); - -fun parse_syntax prep_arg xs = - Scan.option (( - ((OuterParse.$$$ infixK >> K X) - || (OuterParse.$$$ infixlK >> K L) - || (OuterParse.$$$ infixrK >> K R)) - -- OuterParse.nat >> parse_infix prep_arg - || Scan.succeed (parse_mixfix prep_arg)) - -- OuterParse.string - >> (fn (parse, s) => parse s)) xs; - -val _ = List.app OuterKeyword.keyword [infixK, infixlK, infixrK]; - - -(** module name spaces **) - -val dest_name = - apfst Long_Name.implode o split_last o fst o split_last o Long_Name.explode; - -fun mk_name_module reserved_names module_prefix module_alias program = - let - fun mk_alias name = case module_alias name - of SOME name' => name' - | NONE => name - |> Long_Name.explode - |> map (fn name => (the_single o fst) (Name.variants [name] reserved_names)) - |> Long_Name.implode; - fun mk_prefix name = case module_prefix - of SOME module_prefix => Long_Name.append module_prefix name - | NONE => name; - val tab = - Symtab.empty - |> Graph.fold ((fn name => Symtab.default (name, (mk_alias #> mk_prefix) name)) - o fst o dest_name o fst) - program - in the o Symtab.lookup tab end; - -end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_target.ML --- a/src/Tools/code/code_target.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,629 +0,0 @@ -(* Title: Tools/code/code_target.ML - Author: Florian Haftmann, TU Muenchen - -Serializer from intermediate language ("Thin-gol") to target languages. -*) - -signature CODE_TARGET = -sig - include CODE_PRINTER - - type serializer - val add_target: string * (serializer * literals) -> theory -> theory - val extend_target: string * - (string * (Code_Thingol.naming -> Code_Thingol.program -> Code_Thingol.program)) - -> theory -> theory - val assert_target: theory -> string -> string - - type destination - type serialization - val parse_args: (OuterLex.token list -> 'a * OuterLex.token list) - -> OuterLex.token list -> 'a - val stmt_names_of_destination: destination -> string list - val code_of_pretty: Pretty.T -> string - val code_writeln: Pretty.T -> unit - val mk_serialization: string -> ('a -> unit) option - -> (Path.T option -> 'a -> unit) - -> ('a -> string * string option list) - -> 'a -> serialization - val serialize: theory -> string -> string option -> OuterLex.token list - -> Code_Thingol.naming -> Code_Thingol.program -> string list -> serialization - val serialize_custom: theory -> string * (serializer * literals) - -> Code_Thingol.naming -> Code_Thingol.program -> string list -> string * string option list - val the_literals: theory -> string -> literals - val compile: serialization -> unit - val export: serialization -> unit - val file: Path.T -> serialization -> unit - val string: string list -> serialization -> string - val code_of: theory -> string -> string - -> string list -> (Code_Thingol.naming -> string list) -> string - val shell_command: string (*theory name*) -> string (*export_code expr*) -> unit - val code_width: int ref - - val allow_abort: string -> theory -> theory - val add_syntax_class: string -> class -> string option -> theory -> theory - val add_syntax_inst: string -> string * class -> bool -> theory -> theory - val add_syntax_tyco: string -> string -> tyco_syntax option -> theory -> theory - val add_syntax_const: string -> string -> proto_const_syntax option -> theory -> theory - val add_reserved: string -> string -> theory -> theory -end; - -structure Code_Target : CODE_TARGET = -struct - -open Basic_Code_Thingol; -open Code_Printer; - -(** basics **) - -datatype destination = Compile | Export | File of Path.T | String of string list; -type serialization = destination -> (string * string option list) option; - -val code_width = ref 80; (*FIXME after Pretty module no longer depends on print mode*) -fun code_setmp f = PrintMode.setmp [] (Pretty.setmp_margin (!code_width) f); -fun code_of_pretty p = code_setmp Pretty.string_of p ^ "\n"; -fun code_writeln p = Pretty.setmp_margin (!code_width) Pretty.writeln p; - -(*FIXME why another code_setmp?*) -fun compile f = (code_setmp f Compile; ()); -fun export f = (code_setmp f Export; ()); -fun file p f = (code_setmp f (File p); ()); -fun string stmts f = fst (the (code_setmp f (String stmts))); - -fun stmt_names_of_destination (String stmts) = stmts - | stmt_names_of_destination _ = []; - -fun mk_serialization target (SOME comp) _ _ code Compile = (comp code; NONE) - | mk_serialization target NONE _ _ _ Compile = error (target ^ ": no internal compilation") - | mk_serialization target _ output _ code Export = (output NONE code ; NONE) - | mk_serialization target _ output _ code (File file) = (output (SOME file) code; NONE) - | mk_serialization target _ _ string code (String _) = SOME (string code); - - -(** theory data **) - -datatype name_syntax_table = NameSyntaxTable of { - class: string Symtab.table, - instance: unit Symreltab.table, - tyco: tyco_syntax Symtab.table, - const: proto_const_syntax Symtab.table -}; - -fun mk_name_syntax_table ((class, instance), (tyco, const)) = - NameSyntaxTable { class = class, instance = instance, tyco = tyco, const = const }; -fun map_name_syntax_table f (NameSyntaxTable { class, instance, tyco, const }) = - mk_name_syntax_table (f ((class, instance), (tyco, const))); -fun merge_name_syntax_table (NameSyntaxTable { class = class1, instance = instance1, tyco = tyco1, const = const1 }, - NameSyntaxTable { class = class2, instance = instance2, tyco = tyco2, const = const2 }) = - mk_name_syntax_table ( - (Symtab.join (K snd) (class1, class2), - Symreltab.join (K snd) (instance1, instance2)), - (Symtab.join (K snd) (tyco1, tyco2), - Symtab.join (K snd) (const1, const2)) - ); - -type serializer = - string option (*module name*) - -> OuterLex.token list (*arguments*) - -> (string -> string) (*labelled_name*) - -> string list (*reserved symbols*) - -> (string * Pretty.T) list (*includes*) - -> (string -> string option) (*module aliasses*) - -> (string -> string option) (*class syntax*) - -> (string -> tyco_syntax option) - -> (string -> const_syntax option) - -> Code_Thingol.program - -> string list (*selected statements*) - -> serialization; - -datatype serializer_entry = Serializer of serializer * literals - | Extends of string * (Code_Thingol.naming -> Code_Thingol.program -> Code_Thingol.program); - -datatype target = Target of { - serial: serial, - serializer: serializer_entry, - reserved: string list, - includes: (Pretty.T * string list) Symtab.table, - name_syntax_table: name_syntax_table, - module_alias: string Symtab.table -}; - -fun make_target ((serial, serializer), ((reserved, includes), (name_syntax_table, module_alias))) = - Target { serial = serial, serializer = serializer, reserved = reserved, - includes = includes, name_syntax_table = name_syntax_table, module_alias = module_alias }; -fun map_target f ( Target { serial, serializer, reserved, includes, name_syntax_table, module_alias } ) = - make_target (f ((serial, serializer), ((reserved, includes), (name_syntax_table, module_alias)))); -fun merge_target strict target (Target { serial = serial1, serializer = serializer, - reserved = reserved1, includes = includes1, - name_syntax_table = name_syntax_table1, module_alias = module_alias1 }, - Target { serial = serial2, serializer = _, - reserved = reserved2, includes = includes2, - name_syntax_table = name_syntax_table2, module_alias = module_alias2 }) = - if serial1 = serial2 orelse not strict then - make_target ((serial1, serializer), - ((merge (op =) (reserved1, reserved2), Symtab.merge (op =) (includes1, includes2)), - (merge_name_syntax_table (name_syntax_table1, name_syntax_table2), - Symtab.join (K snd) (module_alias1, module_alias2)) - )) - else - error ("Incompatible serializers: " ^ quote target); - -structure CodeTargetData = TheoryDataFun -( - type T = target Symtab.table * string list; - val empty = (Symtab.empty, []); - val copy = I; - val extend = I; - fun merge _ ((target1, exc1) : T, (target2, exc2)) = - (Symtab.join (merge_target true) (target1, target2), Library.merge (op =) (exc1, exc2)); -); - -fun the_serializer (Target { serializer, ... }) = serializer; -fun the_reserved (Target { reserved, ... }) = reserved; -fun the_includes (Target { includes, ... }) = includes; -fun the_name_syntax (Target { name_syntax_table = NameSyntaxTable x, ... }) = x; -fun the_module_alias (Target { module_alias , ... }) = module_alias; - -val abort_allowed = snd o CodeTargetData.get; - -fun assert_target thy target = - case Symtab.lookup (fst (CodeTargetData.get thy)) target - of SOME data => target - | NONE => error ("Unknown code target language: " ^ quote target); - -fun put_target (target, seri) thy = - let - val lookup_target = Symtab.lookup (fst (CodeTargetData.get thy)); - val _ = case seri - of Extends (super, _) => if is_some (lookup_target super) then () - else error ("Unknown code target language: " ^ quote super) - | _ => (); - val overwriting = case (Option.map the_serializer o lookup_target) target - of NONE => false - | SOME (Extends _) => true - | SOME (Serializer _) => (case seri - of Extends _ => error ("Will not overwrite existing target " ^ quote target) - | _ => true); - val _ = if overwriting - then warning ("Overwriting existing target " ^ quote target) - else (); - in - thy - |> (CodeTargetData.map o apfst oo Symtab.map_default) - (target, make_target ((serial (), seri), (([], Symtab.empty), - (mk_name_syntax_table ((Symtab.empty, Symreltab.empty), (Symtab.empty, Symtab.empty)), - Symtab.empty)))) - ((map_target o apfst o apsnd o K) seri) - end; - -fun add_target (target, seri) = put_target (target, Serializer seri); -fun extend_target (target, (super, modify)) = - put_target (target, Extends (super, modify)); - -fun map_target_data target f thy = - let - val _ = assert_target thy target; - in - thy - |> (CodeTargetData.map o apfst o Symtab.map_entry target o map_target) f - end; - -fun map_reserved target = - map_target_data target o apsnd o apfst o apfst; -fun map_includes target = - map_target_data target o apsnd o apfst o apsnd; -fun map_name_syntax target = - map_target_data target o apsnd o apsnd o apfst o map_name_syntax_table; -fun map_module_alias target = - map_target_data target o apsnd o apsnd o apsnd; - - -(** serializer configuration **) - -(* data access *) - -local - -fun cert_class thy class = - let - val _ = AxClass.get_info thy class; - in class end; - -fun read_class thy = cert_class thy o Sign.intern_class thy; - -fun cert_tyco thy tyco = - let - val _ = if Sign.declared_tyname thy tyco then () - else error ("No such type constructor: " ^ quote tyco); - in tyco end; - -fun read_tyco thy = cert_tyco thy o Sign.intern_type thy; - -fun gen_add_syntax_class prep_class prep_const target raw_class raw_syn thy = - let - val class = prep_class thy raw_class; - in case raw_syn - of SOME syntax => - thy - |> (map_name_syntax target o apfst o apfst) - (Symtab.update (class, syntax)) - | NONE => - thy - |> (map_name_syntax target o apfst o apfst) - (Symtab.delete_safe class) - end; - -fun gen_add_syntax_inst prep_class prep_tyco target (raw_tyco, raw_class) add_del thy = - let - val inst = (prep_class thy raw_class, prep_tyco thy raw_tyco); - in if add_del then - thy - |> (map_name_syntax target o apfst o apsnd) - (Symreltab.update (inst, ())) - else - thy - |> (map_name_syntax target o apfst o apsnd) - (Symreltab.delete_safe inst) - end; - -fun gen_add_syntax_tyco prep_tyco target raw_tyco raw_syn thy = - let - val tyco = prep_tyco thy raw_tyco; - fun check_args (syntax as (n, _)) = if n <> Sign.arity_number thy tyco - then error ("Number of arguments mismatch in syntax for type constructor " ^ quote tyco) - else syntax - in case raw_syn - of SOME syntax => - thy - |> (map_name_syntax target o apsnd o apfst) - (Symtab.update (tyco, check_args syntax)) - | NONE => - thy - |> (map_name_syntax target o apsnd o apfst) - (Symtab.delete_safe tyco) - end; - -fun gen_add_syntax_const prep_const target raw_c raw_syn thy = - let - val c = prep_const thy raw_c; - fun check_args (syntax as (n, _)) = if n > Code.no_args thy c - then error ("Too many arguments in syntax for constant " ^ quote c) - else syntax; - in case raw_syn - of SOME syntax => - thy - |> (map_name_syntax target o apsnd o apsnd) - (Symtab.update (c, check_args syntax)) - | NONE => - thy - |> (map_name_syntax target o apsnd o apsnd) - (Symtab.delete_safe c) - end; - -fun add_reserved target = - let - fun add sym syms = if member (op =) syms sym - then error ("Reserved symbol " ^ quote sym ^ " already declared") - else insert (op =) sym syms - in map_reserved target o add end; - -fun gen_add_include read_const target args thy = - let - fun add (name, SOME (content, raw_cs)) incls = - let - val _ = if Symtab.defined incls name - then warning ("Overwriting existing include " ^ name) - else (); - val cs = map (read_const thy) raw_cs; - in Symtab.update (name, (str content, cs)) incls end - | add (name, NONE) incls = Symtab.delete name incls; - in map_includes target (add args) thy end; - -val add_include = gen_add_include Code.check_const; -val add_include_cmd = gen_add_include Code.read_const; - -fun add_module_alias target (thyname, modlname) = - let - val xs = Long_Name.explode modlname; - val xs' = map (Name.desymbolize true) xs; - in if xs' = xs - then map_module_alias target (Symtab.update (thyname, modlname)) - else error ("Invalid module name: " ^ quote modlname ^ "\n" - ^ "perhaps try " ^ quote (Long_Name.implode xs')) - end; - -fun gen_allow_abort prep_const raw_c thy = - let - val c = prep_const thy raw_c; - in thy |> (CodeTargetData.map o apsnd) (insert (op =) c) end; - -fun zip_list (x::xs) f g = - f - #-> (fn y => - fold_map (fn x => g |-- f >> pair x) xs - #-> (fn xys => pair ((x, y) :: xys))); - - -(* concrete syntax *) - -structure P = OuterParse -and K = OuterKeyword - -fun parse_multi_syntax parse_thing parse_syntax = - P.and_list1 parse_thing - #-> (fn things => Scan.repeat1 (P.$$$ "(" |-- P.name -- - (zip_list things parse_syntax (P.$$$ "and")) --| P.$$$ ")")); - -in - -val add_syntax_class = gen_add_syntax_class cert_class (K I); -val add_syntax_inst = gen_add_syntax_inst cert_class cert_tyco; -val add_syntax_tyco = gen_add_syntax_tyco cert_tyco; -val add_syntax_const = gen_add_syntax_const (K I); -val allow_abort = gen_allow_abort (K I); -val add_reserved = add_reserved; - -val add_syntax_class_cmd = gen_add_syntax_class read_class Code.read_const; -val add_syntax_inst_cmd = gen_add_syntax_inst read_class read_tyco; -val add_syntax_tyco_cmd = gen_add_syntax_tyco read_tyco; -val add_syntax_const_cmd = gen_add_syntax_const Code.read_const; -val allow_abort_cmd = gen_allow_abort Code.read_const; - -fun the_literals thy = - let - val (targets, _) = CodeTargetData.get thy; - fun literals target = case Symtab.lookup targets target - of SOME data => (case the_serializer data - of Serializer (_, literals) => literals - | Extends (super, _) => literals super) - | NONE => error ("Unknown code target language: " ^ quote target); - in literals end; - - -(** serializer usage **) - -(* montage *) - -local - -fun labelled_name thy program name = case Graph.get_node program name - of Code_Thingol.Fun (c, _) => quote (Code.string_of_const thy c) - | Code_Thingol.Datatype (tyco, _) => "type " ^ quote (Sign.extern_type thy tyco) - | Code_Thingol.Datatypecons (c, _) => quote (Code.string_of_const thy c) - | Code_Thingol.Class (class, _) => "class " ^ quote (Sign.extern_class thy class) - | Code_Thingol.Classrel (sub, super) => let - val Code_Thingol.Class (sub, _) = Graph.get_node program sub - val Code_Thingol.Class (super, _) = Graph.get_node program super - in quote (Sign.extern_class thy sub ^ " < " ^ Sign.extern_class thy super) end - | Code_Thingol.Classparam (c, _) => quote (Code.string_of_const thy c) - | Code_Thingol.Classinst ((class, (tyco, _)), _) => let - val Code_Thingol.Class (class, _) = Graph.get_node program class - val Code_Thingol.Datatype (tyco, _) = Graph.get_node program tyco - in quote (Sign.extern_type thy tyco ^ " :: " ^ Sign.extern_class thy class) end - -fun activate_syntax lookup_name src_tab = Symtab.empty - |> fold_map (fn thing_identifier => fn tab => case lookup_name thing_identifier - of SOME name => (SOME name, - Symtab.update_new (name, the (Symtab.lookup src_tab thing_identifier)) tab) - | NONE => (NONE, tab)) (Symtab.keys src_tab) - |>> map_filter I; - -fun activate_const_syntax thy literals src_tab naming = (Symtab.empty, naming) - |> fold_map (fn thing_identifier => fn (tab, naming) => - case Code_Thingol.lookup_const naming thing_identifier - of SOME name => let - val (syn, naming') = Code_Printer.activate_const_syntax thy - literals (the (Symtab.lookup src_tab thing_identifier)) naming - in (SOME name, (Symtab.update_new (name, syn) tab, naming')) end - | NONE => (NONE, (tab, naming))) (Symtab.keys src_tab) - |>> map_filter I; - -fun invoke_serializer thy abortable serializer literals reserved abs_includes - module_alias class instance tyco const module args naming program2 names1 = - let - val (names_class, class') = - activate_syntax (Code_Thingol.lookup_class naming) class; - val names_inst = map_filter (Code_Thingol.lookup_instance naming) - (Symreltab.keys instance); - val (names_tyco, tyco') = - activate_syntax (Code_Thingol.lookup_tyco naming) tyco; - val (names_const, (const', _)) = - activate_const_syntax thy literals const naming; - val names_hidden = names_class @ names_inst @ names_tyco @ names_const; - val names2 = subtract (op =) names_hidden names1; - val program3 = Graph.subgraph (not o member (op =) names_hidden) program2; - val names_all = Graph.all_succs program3 names2; - val includes = abs_includes names_all; - val program4 = Graph.subgraph (member (op =) names_all) program3; - val empty_funs = filter_out (member (op =) abortable) - (Code_Thingol.empty_funs program3); - val _ = if null empty_funs then () else error ("No code equations for " - ^ commas (map (Sign.extern_const thy) empty_funs)); - in - serializer module args (labelled_name thy program2) reserved includes - (Symtab.lookup module_alias) (Symtab.lookup class') - (Symtab.lookup tyco') (Symtab.lookup const') - program4 names2 - end; - -fun mount_serializer thy alt_serializer target module args naming program names = - let - val (targets, abortable) = CodeTargetData.get thy; - fun collapse_hierarchy target = - let - val data = case Symtab.lookup targets target - of SOME data => data - | NONE => error ("Unknown code target language: " ^ quote target); - in case the_serializer data - of Serializer _ => (I, data) - | Extends (super, modify) => let - val (modify', data') = collapse_hierarchy super - in (modify' #> modify naming, merge_target false target (data', data)) end - end; - val (modify, data) = collapse_hierarchy target; - val (serializer, _) = the_default (case the_serializer data - of Serializer seri => seri) alt_serializer; - val reserved = the_reserved data; - fun select_include names_all (name, (content, cs)) = - if null cs then SOME (name, content) - else if exists (fn c => case Code_Thingol.lookup_const naming c - of SOME name => member (op =) names_all name - | NONE => false) cs - then SOME (name, content) else NONE; - fun includes names_all = map_filter (select_include names_all) - ((Symtab.dest o the_includes) data); - val module_alias = the_module_alias data; - val { class, instance, tyco, const } = the_name_syntax data; - val literals = the_literals thy target; - in - invoke_serializer thy abortable serializer literals reserved - includes module_alias class instance tyco const module args naming (modify program) names - end; - -in - -fun serialize thy = mount_serializer thy NONE; - -fun serialize_custom thy (target_name, seri) naming program names = - mount_serializer thy (SOME seri) target_name NONE [] naming program names (String []) - |> the; - -end; (* local *) - -fun parse_args f args = - case Scan.read OuterLex.stopper f args - of SOME x => x - | NONE => error "Bad serializer arguments"; - - -(* code presentation *) - -fun code_of thy target module_name cs names_stmt = - let - val (names_cs, (naming, program)) = Code_Thingol.consts_program thy cs; - in - string (names_stmt naming) (serialize thy target (SOME module_name) [] - naming program names_cs) - end; - - -(* code generation *) - -fun transitivly_non_empty_funs thy naming program = - let - val cs = subtract (op =) (abort_allowed thy) (Code_Thingol.empty_funs program); - val names = map_filter (Code_Thingol.lookup_const naming) cs; - in subtract (op =) (Graph.all_preds program names) (Graph.keys program) end; - -fun read_const_exprs thy cs = - let - val (cs1, cs2) = Code_Thingol.read_const_exprs thy cs; - val (names3, (naming, program)) = Code_Thingol.consts_program thy cs2; - val names4 = transitivly_non_empty_funs thy naming program; - val cs5 = map_filter - (fn (c, name) => if member (op =) names4 name then SOME c else NONE) (cs2 ~~ names3); - in fold (insert (op =)) cs5 cs1 end; - -fun cached_program thy = - let - val (naming, program) = Code_Thingol.cached_program thy; - in (transitivly_non_empty_funs thy naming program, (naming, program)) end - -fun export_code thy cs seris = - let - val (cs', (naming, program)) = if null cs then cached_program thy - else Code_Thingol.consts_program thy cs; - fun mk_seri_dest dest = case dest - of NONE => compile - | SOME "-" => export - | SOME f => file (Path.explode f) - val _ = map (fn (((target, module), dest), args) => - (mk_seri_dest dest (serialize thy target module args naming program cs'))) seris; - in () end; - -fun export_code_cmd raw_cs seris thy = export_code thy (read_const_exprs thy raw_cs) seris; - - -(** Isar setup **) - -val (inK, module_nameK, fileK) = ("in", "module_name", "file"); - -val code_exprP = - (Scan.repeat P.term_group - -- Scan.repeat (P.$$$ inK |-- P.name - -- Scan.option (P.$$$ module_nameK |-- P.name) - -- Scan.option (P.$$$ fileK |-- P.name) - -- Scan.optional (P.$$$ "(" |-- Args.parse --| P.$$$ ")") [] - ) >> (fn (raw_cs, seris) => export_code_cmd raw_cs seris)); - -val _ = List.app OuterKeyword.keyword [inK, module_nameK, fileK]; - -val _ = - OuterSyntax.command "code_class" "define code syntax for class" K.thy_decl ( - parse_multi_syntax P.xname (Scan.option P.string) - >> (Toplevel.theory oo fold) (fn (target, syns) => - fold (fn (raw_class, syn) => add_syntax_class_cmd target raw_class syn) syns) - ); - -val _ = - OuterSyntax.command "code_instance" "define code syntax for instance" K.thy_decl ( - parse_multi_syntax (P.xname --| P.$$$ "::" -- P.xname) - ((P.minus >> K true) || Scan.succeed false) - >> (Toplevel.theory oo fold) (fn (target, syns) => - fold (fn (raw_inst, add_del) => add_syntax_inst_cmd target raw_inst add_del) syns) - ); - -val _ = - OuterSyntax.command "code_type" "define code syntax for type constructor" K.thy_decl ( - parse_multi_syntax P.xname (parse_syntax I) - >> (Toplevel.theory oo fold) (fn (target, syns) => - fold (fn (raw_tyco, syn) => add_syntax_tyco_cmd target raw_tyco syn) syns) - ); - -val _ = - OuterSyntax.command "code_const" "define code syntax for constant" K.thy_decl ( - parse_multi_syntax P.term_group (parse_syntax fst) - >> (Toplevel.theory oo fold) (fn (target, syns) => - fold (fn (raw_const, syn) => add_syntax_const_cmd target raw_const - (Code_Printer.simple_const_syntax syn)) syns) - ); - -val _ = - OuterSyntax.command "code_reserved" "declare words as reserved for target language" K.thy_decl ( - P.name -- Scan.repeat1 P.name - >> (fn (target, reserveds) => (Toplevel.theory o fold (add_reserved target)) reserveds) - ); - -val _ = - OuterSyntax.command "code_include" "declare piece of code to be included in generated code" K.thy_decl ( - P.name -- P.name -- (P.text :|-- (fn "-" => Scan.succeed NONE - | s => Scan.optional (P.$$$ "attach" |-- Scan.repeat1 P.term) [] >> pair s >> SOME)) - >> (fn ((target, name), content_consts) => - (Toplevel.theory o add_include_cmd target) (name, content_consts)) - ); - -val _ = - OuterSyntax.command "code_modulename" "alias module to other name" K.thy_decl ( - P.name -- Scan.repeat1 (P.name -- P.name) - >> (fn (target, modlnames) => (Toplevel.theory o fold (add_module_alias target)) modlnames) - ); - -val _ = - OuterSyntax.command "code_abort" "permit constant to be implemented as program abort" K.thy_decl ( - Scan.repeat1 P.term_group >> (Toplevel.theory o fold allow_abort_cmd) - ); - -val _ = - OuterSyntax.command "export_code" "generate executable code for constants" - K.diag (P.!!! code_exprP >> (fn f => Toplevel.keep (f o Toplevel.theory_of))); - -fun shell_command thyname cmd = Toplevel.program (fn _ => - (use_thy thyname; case Scan.read OuterLex.stopper (P.!!! code_exprP) - ((filter OuterLex.is_proper o OuterSyntax.scan Position.none) cmd) - of SOME f => (writeln "Now generating code..."; f (theory thyname)) - | NONE => error ("Bad directive " ^ quote cmd))) - handle TOPLEVEL_ERROR => OS.Process.exit OS.Process.failure; - -end; (*local*) - -end; (*struct*) diff -r 5c8cfaed32e6 -r 2b04504fcb69 src/Tools/code/code_thingol.ML --- a/src/Tools/code/code_thingol.ML Tue Jun 23 12:09:14 2009 +0200 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,876 +0,0 @@ -(* Title: Tools/code/code_thingol.ML - Author: Florian Haftmann, TU Muenchen - -Intermediate language ("Thin-gol") representing executable code. -Representation and translation. -*) - -infix 8 `%%; -infix 4 `$; -infix 4 `$$; -infixr 3 `|=>; -infixr 3 `|==>; - -signature BASIC_CODE_THINGOL = -sig - type vname = string; - datatype dict = - DictConst of string * dict list list - | DictVar of string list * (vname * (int * int)); - datatype itype = - `%% of string * itype list - | ITyVar of vname; - type const = string * ((itype list * dict list list) * itype list (*types of arguments*)) - datatype iterm = - IConst of const - | IVar of vname - | `$ of iterm * iterm - | `|=> of (vname * itype) * iterm - | ICase of ((iterm * itype) * (iterm * iterm) list) * iterm; - (*((term, type), [(selector pattern, body term )]), primitive term)*) - val `$$ : iterm * iterm list -> iterm; - val `|==> : (vname * itype) list * iterm -> iterm; - type typscheme = (vname * sort) list * itype; -end; - -signature CODE_THINGOL = -sig - include BASIC_CODE_THINGOL - val unfoldl: ('a -> ('a * 'b) option) -> 'a -> 'a * 'b list - val unfoldr: ('a -> ('b * 'a) option) -> 'a -> 'b list * 'a - val unfold_fun: itype -> itype list * itype - val unfold_app: iterm -> iterm * iterm list - val split_abs: iterm -> (((vname * iterm option) * itype) * iterm) option - val unfold_abs: iterm -> ((vname * iterm option) * itype) list * iterm - val split_let: iterm -> (((iterm * itype) * iterm) * iterm) option - val unfold_let: iterm -> ((iterm * itype) * iterm) list * iterm - val unfold_const_app: iterm -> (const * iterm list) option - val collapse_let: ((vname * itype) * iterm) * iterm - -> (iterm * itype) * (iterm * iterm) list - val eta_expand: int -> const * iterm list -> iterm - val contains_dictvar: iterm -> bool - val locally_monomorphic: iterm -> bool - val fold_constnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a - val fold_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a - val fold_unbound_varnames: (string -> 'a -> 'a) -> iterm -> 'a -> 'a - - type naming - val empty_naming: naming - val lookup_class: naming -> class -> string option - val lookup_classrel: naming -> class * class -> string option - val lookup_tyco: naming -> string -> string option - val lookup_instance: naming -> class * string -> string option - val lookup_const: naming -> string -> string option - val ensure_declared_const: theory -> string -> naming -> string * naming - - datatype stmt = - NoStmt - | Fun of string * (typscheme * ((iterm list * iterm) * (thm * bool)) list) - | Datatype of string * ((vname * sort) list * (string * itype list) list) - | Datatypecons of string * string - | Class of class * (vname * ((class * string) list * (string * itype) list)) - | Classrel of class * class - | Classparam of string * class - | Classinst of (class * (string * (vname * sort) list)) - * ((class * (string * (string * dict list list))) list - * ((string * const) * (thm * bool)) list) - type program = stmt Graph.T - val empty_funs: program -> string list - val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm - val map_terms_stmt: (iterm -> iterm) -> stmt -> stmt - val is_cons: program -> string -> bool - val contr_classparam_typs: program -> string -> itype option list - - val read_const_exprs: theory -> string list -> string list * string list - val consts_program: theory -> string list -> string list * (naming * program) - val cached_program: theory -> naming * program - val eval_conv: theory -> (sort -> sort) - -> (naming -> program -> ((string * sort) list * typscheme) * iterm -> string list -> cterm -> thm) - -> cterm -> thm - val eval: theory -> (sort -> sort) -> ((term -> term) -> 'a -> 'a) - -> (naming -> program -> ((string * sort) list * typscheme) * iterm -> string list -> 'a) - -> term -> 'a -end; - -structure Code_Thingol: CODE_THINGOL = -struct - -(** auxiliary **) - -fun unfoldl dest x = - case dest x - of NONE => (x, []) - | SOME (x1, x2) => - let val (x', xs') = unfoldl dest x1 in (x', xs' @ [x2]) end; - -fun unfoldr dest x = - case dest x - of NONE => ([], x) - | SOME (x1, x2) => - let val (xs', x') = unfoldr dest x2 in (x1::xs', x') end; - - -(** language core - types, terms **) - -type vname = string; - -datatype dict = - DictConst of string * dict list list - | DictVar of string list * (vname * (int * int)); - -datatype itype = - `%% of string * itype list - | ITyVar of vname; - -type const = string * ((itype list * dict list list) * itype list (*types of arguments*)) - -datatype iterm = - IConst of const - | IVar of vname - | `$ of iterm * iterm - | `|=> of (vname * itype) * iterm - | ICase of ((iterm * itype) * (iterm * iterm) list) * iterm; - (*see also signature*) - -val op `$$ = Library.foldl (op `$); -val op `|==> = Library.foldr (op `|=>); - -val unfold_app = unfoldl - (fn op `$ t => SOME t - | _ => NONE); - -val split_abs = - (fn (v, ty) `|=> (t as ICase (((IVar w, _), [(p, t')]), _)) => - if v = w then SOME (((v, SOME p), ty), t') else SOME (((v, NONE), ty), t) - | (v, ty) `|=> t => SOME (((v, NONE), ty), t) - | _ => NONE); - -val unfold_abs = unfoldr split_abs; - -val split_let = - (fn ICase (((td, ty), [(p, t)]), _) => SOME (((p, ty), td), t) - | _ => NONE); - -val unfold_let = unfoldr split_let; - -fun unfold_const_app t = - case unfold_app t - of (IConst c, ts) => SOME (c, ts) - | _ => NONE; - -fun fold_aiterms f (t as IConst _) = f t - | fold_aiterms f (t as IVar _) = f t - | fold_aiterms f (t1 `$ t2) = fold_aiterms f t1 #> fold_aiterms f t2 - | fold_aiterms f (t as _ `|=> t') = f t #> fold_aiterms f t' - | fold_aiterms f (ICase (_, t)) = fold_aiterms f t; - -fun fold_constnames f = - let - fun add (IConst (c, _)) = f c - | add _ = I; - in fold_aiterms add end; - -fun fold_varnames f = - let - fun add (IVar v) = f v - | add ((v, _) `|=> _) = f v - | add _ = I; - in fold_aiterms add end; - -fun fold_unbound_varnames f = - let - fun add _ (IConst _) = I - | add vs (IVar v) = if not (member (op =) vs v) then f v else I - | add vs (t1 `$ t2) = add vs t1 #> add vs t2 - | add vs ((v, _) `|=> t) = add (insert (op =) v vs) t - | add vs (ICase (_, t)) = add vs t; - in add [] end; - -fun collapse_let (((v, ty), se), be as ICase (((IVar w, _), ds), _)) = - let - fun exists_v t = fold_unbound_varnames (fn w => fn b => - b orelse v = w) t false; - in if v = w andalso forall (fn (t1, t2) => - exists_v t1 orelse not (exists_v t2)) ds - then ((se, ty), ds) - else ((se, ty), [(IVar v, be)]) - end - | collapse_let (((v, ty), se), be) = - ((se, ty), [(IVar v, be)]) - -fun eta_expand k (c as (_, (_, tys)), ts) = - let - val j = length ts; - val l = k - j; - val ctxt = (fold o fold_varnames) Name.declare ts Name.context; - val vs_tys = Name.names ctxt "a" ((curry Library.take l o curry Library.drop j) tys); - in vs_tys `|==> IConst c `$$ ts @ map (fn (v, _) => IVar v) vs_tys end; - -fun contains_dictvar t = - let - fun contains (DictConst (_, dss)) = (fold o fold) contains dss - | contains (DictVar _) = K true; - in - fold_aiterms - (fn IConst (_, ((_, dss), _)) => (fold o fold) contains dss | _ => I) t false - end; - -fun locally_monomorphic (IConst _) = false - | locally_monomorphic (IVar _) = true - | locally_monomorphic (t `$ _) = locally_monomorphic t - | locally_monomorphic (_ `|=> t) = locally_monomorphic t - | locally_monomorphic (ICase ((_, ds), _)) = exists (locally_monomorphic o snd) ds; - - -(** namings **) - -(* policies *) - -local - fun thyname_of thy f x = the (AList.lookup (op =) (f x) Markup.theory_nameN); - fun thyname_of_class thy = - thyname_of thy (ProofContext.query_class (ProofContext.init thy)); - fun thyname_of_tyco thy = - thyname_of thy (Type.the_tags (Sign.tsig_of thy)); - fun thyname_of_instance thy inst = case AxClass.arity_property thy inst Markup.theory_nameN - of [] => error ("no such instance: " ^ quote (snd inst ^ " :: " ^ fst inst)) - | thyname :: _ => thyname; - fun thyname_of_const thy c = case AxClass.class_of_param thy c - of SOME class => thyname_of_class thy class - | NONE => (case Code.get_datatype_of_constr thy c - of SOME dtco => thyname_of_tyco thy dtco - | NONE => thyname_of thy (Consts.the_tags (Sign.consts_of thy)) c); - fun purify_base "op &" = "and" - | purify_base "op |" = "or" - | purify_base "op -->" = "implies" - | purify_base "op :" = "member" - | purify_base "op =" = "eq" - | purify_base "*" = "product" - | purify_base "+" = "sum" - | purify_base s = Name.desymbolize false s; - fun namify thy get_basename get_thyname name = - let - val prefix = get_thyname thy name; - val base = (purify_base o get_basename) name; - in Long_Name.append prefix base end; -in - -fun namify_class thy = namify thy Long_Name.base_name thyname_of_class; -fun namify_classrel thy = namify thy (fn (class1, class2) => - Long_Name.base_name class2 ^ "_" ^ Long_Name.base_name class1) (fn thy => thyname_of_class thy o fst); - (*order fits nicely with composed projections*) -fun namify_tyco thy "fun" = "Pure.fun" - | namify_tyco thy tyco = namify thy Long_Name.base_name thyname_of_tyco tyco; -fun namify_instance thy = namify thy (fn (class, tyco) => - Long_Name.base_name class ^ "_" ^ Long_Name.base_name tyco) thyname_of_instance; -fun namify_const thy = namify thy Long_Name.base_name thyname_of_const; - -end; (* local *) - - -(* data *) - -datatype naming = Naming of { - class: class Symtab.table * Name.context, - classrel: string Symreltab.table * Name.context, - tyco: string Symtab.table * Name.context, - instance: string Symreltab.table * Name.context, - const: string Symtab.table * Name.context -} - -fun dest_Naming (Naming naming) = naming; - -val empty_naming = Naming { - class = (Symtab.empty, Name.context), - classrel = (Symreltab.empty, Name.context), - tyco = (Symtab.empty, Name.context), - instance = (Symreltab.empty, Name.context), - const = (Symtab.empty, Name.context) -}; - -local - fun mk_naming (class, classrel, tyco, instance, const) = - Naming { class = class, classrel = classrel, - tyco = tyco, instance = instance, const = const }; - fun map_naming f (Naming { class, classrel, tyco, instance, const }) = - mk_naming (f (class, classrel, tyco, instance, const)); -in - fun map_class f = map_naming - (fn (class, classrel, tyco, inst, const) => - (f class, classrel, tyco, inst, const)); - fun map_classrel f = map_naming - (fn (class, classrel, tyco, inst, const) => - (class, f classrel, tyco, inst, const)); - fun map_tyco f = map_naming - (fn (class, classrel, tyco, inst, const) => - (class, classrel, f tyco, inst, const)); - fun map_instance f = map_naming - (fn (class, classrel, tyco, inst, const) => - (class, classrel, tyco, f inst, const)); - fun map_const f = map_naming - (fn (class, classrel, tyco, inst, const) => - (class, classrel, tyco, inst, f const)); -end; (*local*) - -fun add_variant update (thing, name) (tab, used) = - let - val (name', used') = yield_singleton Name.variants name used; - val tab' = update (thing, name') tab; - in (tab', used') end; - -fun declare thy mapp lookup update namify thing = - mapp (add_variant update (thing, namify thy thing)) - #> `(fn naming => the (lookup naming thing)); - - -(* lookup and declare *) - -local - -val suffix_class = "class"; -val suffix_classrel = "classrel" -val suffix_tyco = "tyco"; -val suffix_instance = "inst"; -val suffix_const = "const"; - -fun add_suffix nsp NONE = NONE - | add_suffix nsp (SOME name) = SOME (Long_Name.append name nsp); - -in - -val lookup_class = add_suffix suffix_class - oo Symtab.lookup o fst o #class o dest_Naming; -val lookup_classrel = add_suffix suffix_classrel - oo Symreltab.lookup o fst o #classrel o dest_Naming; -val lookup_tyco = add_suffix suffix_tyco - oo Symtab.lookup o fst o #tyco o dest_Naming; -val lookup_instance = add_suffix suffix_instance - oo Symreltab.lookup o fst o #instance o dest_Naming; -val lookup_const = add_suffix suffix_const - oo Symtab.lookup o fst o #const o dest_Naming; - -fun declare_class thy = declare thy map_class - lookup_class Symtab.update_new namify_class; -fun declare_classrel thy = declare thy map_classrel - lookup_classrel Symreltab.update_new namify_classrel; -fun declare_tyco thy = declare thy map_tyco - lookup_tyco Symtab.update_new namify_tyco; -fun declare_instance thy = declare thy map_instance - lookup_instance Symreltab.update_new namify_instance; -fun declare_const thy = declare thy map_const - lookup_const Symtab.update_new namify_const; - -fun ensure_declared_const thy const naming = - case lookup_const naming const - of SOME const' => (const', naming) - | NONE => declare_const thy const naming; - -val unfold_fun = unfoldr - (fn "Pure.fun.tyco" `%% [ty1, ty2] => SOME (ty1, ty2) - | _ => NONE); (*depends on suffix_tyco and namify_tyco!*) - -end; (* local *) - - -(** statements, abstract programs **) - -type typscheme = (vname * sort) list * itype; -datatype stmt = - NoStmt - | Fun of string * (typscheme * ((iterm list * iterm) * (thm * bool)) list) - | Datatype of string * ((vname * sort) list * (string * itype list) list) - | Datatypecons of string * string - | Class of class * (vname * ((class * string) list * (string * itype) list)) - | Classrel of class * class - | Classparam of string * class - | Classinst of (class * (string * (vname * sort) list)) - * ((class * (string * (string * dict list list))) list - * ((string * const) * (thm * bool)) list); - -type program = stmt Graph.T; - -fun empty_funs program = - Graph.fold (fn (name, (Fun (c, (_, [])), _)) => cons c - | _ => I) program []; - -fun map_terms_bottom_up f (t as IConst _) = f t - | map_terms_bottom_up f (t as IVar _) = f t - | map_terms_bottom_up f (t1 `$ t2) = f - (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2) - | map_terms_bottom_up f ((v, ty) `|=> t) = f - ((v, ty) `|=> map_terms_bottom_up f t) - | map_terms_bottom_up f (ICase (((t, ty), ps), t0)) = f - (ICase (((map_terms_bottom_up f t, ty), (map o pairself) - (map_terms_bottom_up f) ps), map_terms_bottom_up f t0)); - -fun map_terms_stmt f NoStmt = NoStmt - | map_terms_stmt f (Fun (c, (tysm, eqs))) = Fun (c, (tysm, (map o apfst) - (fn (ts, t) => (map f ts, f t)) eqs)) - | map_terms_stmt f (stmt as Datatype _) = stmt - | map_terms_stmt f (stmt as Datatypecons _) = stmt - | map_terms_stmt f (stmt as Class _) = stmt - | map_terms_stmt f (stmt as Classrel _) = stmt - | map_terms_stmt f (stmt as Classparam _) = stmt - | map_terms_stmt f (Classinst (arity, (superarities, classparms))) = - Classinst (arity, (superarities, (map o apfst o apsnd) (fn const => - case f (IConst const) of IConst const' => const') classparms)); - -fun is_cons program name = case Graph.get_node program name - of Datatypecons _ => true - | _ => false; - -fun contr_classparam_typs program name = case Graph.get_node program name - of Classparam (_, class) => let - val Class (_, (_, (_, params))) = Graph.get_node program class; - val SOME ty = AList.lookup (op =) params name; - val (tys, res_ty) = unfold_fun ty; - fun no_tyvar (_ `%% tys) = forall no_tyvar tys - | no_tyvar (ITyVar _) = false; - in if no_tyvar res_ty - then map (fn ty => if no_tyvar ty then NONE else SOME ty) tys - else [] - end - | _ => []; - - -(** translation kernel **) - -(* generic mechanisms *) - -fun ensure_stmt lookup declare generate thing (dep, (naming, program)) = - let - fun add_dep name = case dep of NONE => I - | SOME dep => Graph.add_edge (dep, name); - val (name, naming') = case lookup naming thing - of SOME name => (name, naming) - | NONE => declare thing naming; - in case try (Graph.get_node program) name - of SOME stmt => program - |> add_dep name - |> pair naming' - |> pair dep - |> pair name - | NONE => program - |> Graph.default_node (name, NoStmt) - |> add_dep name - |> pair naming' - |> curry generate (SOME name) - ||> snd - |-> (fn stmt => (apsnd o Graph.map_node name) (K stmt)) - |> pair dep - |> pair name - end; - -fun not_wellsorted thy thm ty sort e = - let - val err_class = Sorts.class_error (Syntax.pp_global thy) e; - val err_thm = case thm - of SOME thm => "\n(in code equation " ^ Display.string_of_thm thm ^ ")" | NONE => ""; - val err_typ = "Type " ^ Syntax.string_of_typ_global thy ty ^ " not of sort " - ^ Syntax.string_of_sort_global thy sort; - in error ("Wellsortedness error" ^ err_thm ^ ":\n" ^ err_typ ^ "\n" ^ err_class) end; - - -(* translation *) - -fun ensure_tyco thy algbr funcgr tyco = - let - val stmt_datatype = - let - val (vs, cos) = Code.get_datatype thy tyco; - in - fold_map (translate_tyvar_sort thy algbr funcgr) vs - ##>> fold_map (fn (c, tys) => - ensure_const thy algbr funcgr c - ##>> fold_map (translate_typ thy algbr funcgr) tys) cos - #>> (fn info => Datatype (tyco, info)) - end; - in ensure_stmt lookup_tyco (declare_tyco thy) stmt_datatype tyco end -and ensure_const thy algbr funcgr c = - let - fun stmt_datatypecons tyco = - ensure_tyco thy algbr funcgr tyco - #>> (fn tyco => Datatypecons (c, tyco)); - fun stmt_classparam class = - ensure_class thy algbr funcgr class - #>> (fn class => Classparam (c, class)); - fun stmt_fun ((vs, ty), raw_thms) = - let - val thms = if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty - then raw_thms - else (map o apfst) (Code.expand_eta thy 1) raw_thms; - in - fold_map (translate_tyvar_sort thy algbr funcgr) vs - ##>> translate_typ thy algbr funcgr ty - ##>> fold_map (translate_eq thy algbr funcgr) thms - #>> (fn info => Fun (c, info)) - end; - val stmt_const = case Code.get_datatype_of_constr thy c - of SOME tyco => stmt_datatypecons tyco - | NONE => (case AxClass.class_of_param thy c - of SOME class => stmt_classparam class - | NONE => stmt_fun (Code_Preproc.typ funcgr c, Code_Preproc.eqns funcgr c)) - in ensure_stmt lookup_const (declare_const thy) stmt_const c end -and ensure_class thy (algbr as (_, algebra)) funcgr class = - let - val superclasses = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class; - val cs = #params (AxClass.get_info thy class); - val stmt_class = - fold_map (fn superclass => ensure_class thy algbr funcgr superclass - ##>> ensure_classrel thy algbr funcgr (class, superclass)) superclasses - ##>> fold_map (fn (c, ty) => ensure_const thy algbr funcgr c - ##>> translate_typ thy algbr funcgr ty) cs - #>> (fn info => Class (class, (unprefix "'" Name.aT, info))) - in ensure_stmt lookup_class (declare_class thy) stmt_class class end -and ensure_classrel thy algbr funcgr (subclass, superclass) = - let - val stmt_classrel = - ensure_class thy algbr funcgr subclass - ##>> ensure_class thy algbr funcgr superclass - #>> Classrel; - in ensure_stmt lookup_classrel (declare_classrel thy) stmt_classrel (subclass, superclass) end -and ensure_inst thy (algbr as (_, algebra)) funcgr (class, tyco) = - let - val superclasses = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class; - val classparams = these (try (#params o AxClass.get_info thy) class); - val vs = Name.names Name.context "'a" (Sorts.mg_domain algebra tyco [class]); - val sorts' = Sorts.mg_domain (Sign.classes_of thy) tyco [class]; - val vs' = map2 (fn (v, sort1) => fn sort2 => (v, - Sorts.inter_sort (Sign.classes_of thy) (sort1, sort2))) vs sorts'; - val arity_typ = Type (tyco, map TFree vs); - val arity_typ' = Type (tyco, map (fn (v, sort) => TVar ((v, 0), sort)) vs'); - fun translate_superarity superclass = - ensure_class thy algbr funcgr superclass - ##>> ensure_classrel thy algbr funcgr (class, superclass) - ##>> translate_dicts thy algbr funcgr NONE (arity_typ, [superclass]) - #>> (fn ((superclass, classrel), [DictConst (inst, dss)]) => - (superclass, (classrel, (inst, dss)))); - fun translate_classparam_inst (c, ty) = - let - val c_inst = Const (c, map_type_tfree (K arity_typ') ty); - val thm = AxClass.unoverload_conv thy (Thm.cterm_of thy c_inst); - val c_ty = (apsnd Logic.unvarifyT o dest_Const o snd - o Logic.dest_equals o Thm.prop_of) thm; - in - ensure_const thy algbr funcgr c - ##>> translate_const thy algbr funcgr (SOME thm) c_ty - #>> (fn (c, IConst c_inst) => ((c, c_inst), (thm, true))) - end; - val stmt_inst = - ensure_class thy algbr funcgr class - ##>> ensure_tyco thy algbr funcgr tyco - ##>> fold_map (translate_tyvar_sort thy algbr funcgr) vs - ##>> fold_map translate_superarity superclasses - ##>> fold_map translate_classparam_inst classparams - #>> (fn ((((class, tyco), arity), superarities), classparams) => - Classinst ((class, (tyco, arity)), (superarities, classparams))); - in ensure_stmt lookup_instance (declare_instance thy) stmt_inst (class, tyco) end -and translate_typ thy algbr funcgr (TFree (v, _)) = - pair (ITyVar (unprefix "'" v)) - | translate_typ thy algbr funcgr (Type (tyco, tys)) = - ensure_tyco thy algbr funcgr tyco - ##>> fold_map (translate_typ thy algbr funcgr) tys - #>> (fn (tyco, tys) => tyco `%% tys) -and translate_term thy algbr funcgr thm (Const (c, ty)) = - translate_app thy algbr funcgr thm ((c, ty), []) - | translate_term thy algbr funcgr thm (Free (v, _)) = - pair (IVar v) - | translate_term thy algbr funcgr thm (Abs (abs as (_, ty, _))) = - let - val (v, t) = Syntax.variant_abs abs; - in - translate_typ thy algbr funcgr ty - ##>> translate_term thy algbr funcgr thm t - #>> (fn (ty, t) => (v, ty) `|=> t) - end - | translate_term thy algbr funcgr thm (t as _ $ _) = - case strip_comb t - of (Const (c, ty), ts) => - translate_app thy algbr funcgr thm ((c, ty), ts) - | (t', ts) => - translate_term thy algbr funcgr thm t' - ##>> fold_map (translate_term thy algbr funcgr thm) ts - #>> (fn (t, ts) => t `$$ ts) -and translate_eq thy algbr funcgr (thm, proper) = - let - val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals - o Logic.unvarify o prop_of) thm; - in - fold_map (translate_term thy algbr funcgr (SOME thm)) args - ##>> translate_term thy algbr funcgr (SOME thm) rhs - #>> rpair (thm, proper) - end -and translate_const thy algbr funcgr thm (c, ty) = - let - val tys = Sign.const_typargs thy (c, ty); - val sorts = (map snd o fst o Code_Preproc.typ funcgr) c; - val tys_args = (fst o Term.strip_type) ty; - in - ensure_const thy algbr funcgr c - ##>> fold_map (translate_typ thy algbr funcgr) tys - ##>> fold_map (translate_dicts thy algbr funcgr thm) (tys ~~ sorts) - ##>> fold_map (translate_typ thy algbr funcgr) tys_args - #>> (fn (((c, tys), iss), tys_args) => IConst (c, ((tys, iss), tys_args))) - end -and translate_app_const thy algbr funcgr thm (c_ty, ts) = - translate_const thy algbr funcgr thm c_ty - ##>> fold_map (translate_term thy algbr funcgr thm) ts - #>> (fn (t, ts) => t `$$ ts) -and translate_case thy algbr funcgr thm (num_args, (t_pos, case_pats)) (c_ty, ts) = - let - val (tys, _) = (chop num_args o fst o strip_type o snd) c_ty; - val t = nth ts t_pos; - val ty = nth tys t_pos; - val ts_clause = nth_drop t_pos ts; - fun mk_clause (co, num_co_args) t = - let - val (vs, body) = Term.strip_abs_eta num_co_args t; - val not_undefined = case body - of (Const (c, _)) => not (Code.is_undefined thy c) - | _ => true; - val pat = list_comb (Const (co, map snd vs ---> ty), map Free vs); - in (not_undefined, (pat, body)) end; - val clauses = if null case_pats then let val ([v_ty], body) = - Term.strip_abs_eta 1 (the_single ts_clause) - in [(true, (Free v_ty, body))] end - else map (uncurry mk_clause) - (AList.make (Code.no_args thy) case_pats ~~ ts_clause); - fun retermify ty (_, (IVar x, body)) = - (x, ty) `|=> body - | retermify _ (_, (pat, body)) = - let - val (IConst (_, (_, tys)), ts) = unfold_app pat; - val vs = map2 (fn IVar x => fn ty => (x, ty)) ts tys; - in vs `|==> body end; - fun mk_icase const t ty clauses = - let - val (ts1, ts2) = chop t_pos (map (retermify ty) clauses); - in - ICase (((t, ty), map_filter (fn (b, d) => if b then SOME d else NONE) clauses), - const `$$ (ts1 @ t :: ts2)) - end; - in - translate_const thy algbr funcgr thm c_ty - ##>> translate_term thy algbr funcgr thm t - ##>> translate_typ thy algbr funcgr ty - ##>> fold_map (fn (b, (pat, body)) => translate_term thy algbr funcgr thm pat - ##>> translate_term thy algbr funcgr thm body - #>> pair b) clauses - #>> (fn (((const, t), ty), ds) => mk_icase const t ty ds) - end -and translate_app_case thy algbr funcgr thm (case_scheme as (num_args, _)) ((c, ty), ts) = - if length ts < num_args then - let - val k = length ts; - val tys = (curry Library.take (num_args - k) o curry Library.drop k o fst o strip_type) ty; - val ctxt = (fold o fold_aterms) Term.declare_term_frees ts Name.context; - val vs = Name.names ctxt "a" tys; - in - fold_map (translate_typ thy algbr funcgr) tys - ##>> translate_case thy algbr funcgr thm case_scheme ((c, ty), ts @ map Free vs) - #>> (fn (tys, t) => map2 (fn (v, _) => pair v) vs tys `|==> t) - end - else if length ts > num_args then - translate_case thy algbr funcgr thm case_scheme ((c, ty), Library.take (num_args, ts)) - ##>> fold_map (translate_term thy algbr funcgr thm) (Library.drop (num_args, ts)) - #>> (fn (t, ts) => t `$$ ts) - else - translate_case thy algbr funcgr thm case_scheme ((c, ty), ts) -and translate_app thy algbr funcgr thm (c_ty_ts as ((c, _), _)) = - case Code.get_case_scheme thy c - of SOME case_scheme => translate_app_case thy algbr funcgr thm case_scheme c_ty_ts - | NONE => translate_app_const thy algbr funcgr thm c_ty_ts -and translate_tyvar_sort thy (algbr as (proj_sort, _)) funcgr (v, sort) = - fold_map (ensure_class thy algbr funcgr) (proj_sort sort) - #>> (fn sort => (unprefix "'" v, sort)) -and translate_dicts thy (algbr as (proj_sort, algebra)) funcgr thm (ty, sort) = - let - val pp = Syntax.pp_global thy; - datatype typarg = - Global of (class * string) * typarg list list - | Local of (class * class) list * (string * (int * sort)); - fun class_relation (Global ((_, tyco), yss), _) class = - Global ((class, tyco), yss) - | class_relation (Local (classrels, v), subclass) superclass = - Local ((subclass, superclass) :: classrels, v); - fun type_constructor tyco yss class = - Global ((class, tyco), (map o map) fst yss); - fun type_variable (TFree (v, sort)) = - let - val sort' = proj_sort sort; - in map_index (fn (n, class) => (Local ([], (v, (n, sort'))), class)) sort' end; - val typargs = Sorts.of_sort_derivation pp algebra - {class_relation = class_relation, type_constructor = type_constructor, - type_variable = type_variable} (ty, proj_sort sort) - handle Sorts.CLASS_ERROR e => not_wellsorted thy thm ty sort e; - fun mk_dict (Global (inst, yss)) = - ensure_inst thy algbr funcgr inst - ##>> (fold_map o fold_map) mk_dict yss - #>> (fn (inst, dss) => DictConst (inst, dss)) - | mk_dict (Local (classrels, (v, (k, sort)))) = - fold_map (ensure_classrel thy algbr funcgr) classrels - #>> (fn classrels => DictVar (classrels, (unprefix "'" v, (k, length sort)))) - in fold_map mk_dict typargs end; - - -(* store *) - -structure Program = CodeDataFun -( - type T = naming * program; - val empty = (empty_naming, Graph.empty); - fun purge thy cs (naming, program) = - let - val names_delete = cs - |> map_filter (lookup_const naming) - |> filter (can (Graph.get_node program)) - |> Graph.all_preds program; - val program' = Graph.del_nodes names_delete program; - in (naming, program') end; -); - -val cached_program = Program.get; - -fun invoke_generation thy (algebra, funcgr) f name = - Program.change_yield thy (fn naming_program => (NONE, naming_program) - |> f thy algebra funcgr name - |-> (fn name => fn (_, naming_program) => (name, naming_program))); - - -(* program generation *) - -fun consts_program thy cs = - let - fun project_consts cs (naming, program) = - let - val cs_all = Graph.all_succs program cs; - in (cs, (naming, Graph.subgraph (member (op =) cs_all) program)) end; - fun generate_consts thy algebra funcgr = - fold_map (ensure_const thy algebra funcgr); - in - invoke_generation thy (Code_Preproc.obtain thy cs []) generate_consts cs - |-> project_consts - end; - - -(* value evaluation *) - -fun ensure_value thy algbr funcgr t = - let - val ty = fastype_of t; - val vs = fold_term_types (K (fold_atyps (insert (eq_fst op =) - o dest_TFree))) t []; - val stmt_value = - fold_map (translate_tyvar_sort thy algbr funcgr) vs - ##>> translate_typ thy algbr funcgr ty - ##>> translate_term thy algbr funcgr NONE t - #>> (fn ((vs, ty), t) => Fun - (Term.dummy_patternN, ((vs, ty), [(([], t), (Drule.dummy_thm, true))]))); - fun term_value (dep, (naming, program1)) = - let - val Fun (_, (vs_ty, [(([], t), _)])) = - Graph.get_node program1 Term.dummy_patternN; - val deps = Graph.imm_succs program1 Term.dummy_patternN; - val program2 = Graph.del_nodes [Term.dummy_patternN] program1; - val deps_all = Graph.all_succs program2 deps; - val program3 = Graph.subgraph (member (op =) deps_all) program2; - in (((naming, program3), ((vs_ty, t), deps)), (dep, (naming, program2))) end; - in - ensure_stmt ((K o K) NONE) pair stmt_value Term.dummy_patternN - #> snd - #> term_value - end; - -fun base_evaluator thy evaluator algebra funcgr vs t = - let - val (((naming, program), (((vs', ty'), t'), deps)), _) = - invoke_generation thy (algebra, funcgr) ensure_value t; - val vs'' = map (fn (v, _) => (v, (the o AList.lookup (op =) vs o prefix "'") v)) vs'; - in evaluator naming program ((vs'', (vs', ty')), t') deps end; - -fun eval_conv thy prep_sort = Code_Preproc.eval_conv thy prep_sort o base_evaluator thy; -fun eval thy prep_sort postproc = Code_Preproc.eval thy prep_sort postproc o base_evaluator thy; - - -(** diagnostic commands **) - -fun read_const_exprs thy = - let - fun consts_of some_thyname = - let - val thy' = case some_thyname - of SOME thyname => ThyInfo.the_theory thyname thy - | NONE => thy; - val cs = Symtab.fold (fn (c, (_, NONE)) => cons c | _ => I) - ((snd o #constants o Consts.dest o #consts o Sign.rep_sg) thy') []; - fun belongs_here c = - not (exists (fn thy'' => Sign.declared_const thy'' c) (Theory.parents_of thy')) - in case some_thyname - of NONE => cs - | SOME thyname => filter belongs_here cs - end; - fun read_const_expr "*" = ([], consts_of NONE) - | read_const_expr s = if String.isSuffix ".*" s - then ([], consts_of (SOME (unsuffix ".*" s))) - else ([Code.read_const thy s], []); - in pairself flat o split_list o map read_const_expr end; - -fun code_depgr thy consts = - let - val (_, eqngr) = Code_Preproc.obtain thy consts []; - val select = Graph.all_succs eqngr consts; - in - eqngr - |> not (null consts) ? Graph.subgraph (member (op =) select) - |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy)) - end; - -fun code_thms thy = Pretty.writeln o Code_Preproc.pretty thy o code_depgr thy; - -fun code_deps thy consts = - let - val eqngr = code_depgr thy consts; - val constss = Graph.strong_conn eqngr; - val mapping = Symtab.empty |> fold (fn consts => fold (fn const => - Symtab.update (const, consts)) consts) constss; - fun succs consts = consts - |> maps (Graph.imm_succs eqngr) - |> subtract (op =) consts - |> map (the o Symtab.lookup mapping) - |> distinct (op =); - val conn = [] |> fold (fn consts => cons (consts, succs consts)) constss; - fun namify consts = map (Code.string_of_const thy) consts - |> commas; - val prgr = map (fn (consts, constss) => - { name = namify consts, ID = namify consts, dir = "", unfold = true, - path = "", parents = map namify constss }) conn; - in Present.display_graph prgr end; - -local - -structure P = OuterParse -and K = OuterKeyword - -fun code_thms_cmd thy = code_thms thy o op @ o read_const_exprs thy; -fun code_deps_cmd thy = code_deps thy o op @ o read_const_exprs thy; - -in - -val _ = - OuterSyntax.improper_command "code_thms" "print system of code equations for code" OuterKeyword.diag - (Scan.repeat P.term_group - >> (fn cs => Toplevel.no_timing o Toplevel.unknown_theory - o Toplevel.keep ((fn thy => code_thms_cmd thy cs) o Toplevel.theory_of))); - -val _ = - OuterSyntax.improper_command "code_deps" "visualize dependencies of code equations for code" OuterKeyword.diag - (Scan.repeat P.term_group - >> (fn cs => Toplevel.no_timing o Toplevel.unknown_theory - o Toplevel.keep ((fn thy => code_deps_cmd thy cs) o Toplevel.theory_of))); - -end; - -end; (*struct*) - - -structure Basic_Code_Thingol: BASIC_CODE_THINGOL = Code_Thingol;