# HG changeset patch # User haftmann # Date 1252488680 -7200 # Node ID e129333b9df05f75069e165623930c7885212ef2 # Parent 62e6c9b67c6fcc3653c4d2295409a5c707ea05f9 moved eq handling in nbe into separate oracle diff -r 62e6c9b67c6f -r e129333b9df0 src/HOL/HOL.thy --- a/src/HOL/HOL.thy Tue Sep 08 18:31:26 2009 +0200 +++ b/src/HOL/HOL.thy Wed Sep 09 11:31:20 2009 +0200 @@ -1887,7 +1887,7 @@ *} setup {* - Code.add_const_alias @{thm equals_alias_cert} + Nbe.add_const_alias @{thm equals_alias_cert} *} hide (open) const eq diff -r 62e6c9b67c6f -r e129333b9df0 src/HOL/ex/NormalForm.thy --- a/src/HOL/ex/NormalForm.thy Tue Sep 08 18:31:26 2009 +0200 +++ b/src/HOL/ex/NormalForm.thy Wed Sep 09 11:31:20 2009 +0200 @@ -4,8 +4,17 @@ theory NormalForm imports Main Rational +uses "~~/src/Tools/nbe.ML" begin +setup {* + Nbe.add_const_alias @{thm equals_alias_cert} +*} + +method_setup normalization = {* + Scan.succeed (K (SIMPLE_METHOD' (CONVERSION Nbe.norm_conv THEN' (fn k => TRY (rtac TrueI k))))) +*} "solve goal by normalization" + lemma "True" by normalization lemma "p \ True" by normalization declare disj_assoc [code nbe] @@ -120,4 +129,13 @@ normal_form "(%m n f x. m (n f) x) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))" normal_form "(%m n. n m) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))" +(* handling of type classes in connection with equality *) + +lemma "map f [x, y] = [f x, f y]" by normalization +lemma "(map f [x, y], w) = ([f x, f y], w)" by normalization +lemma "map f [x, y] = [f x \ 'a\semigroup_add, f y]" by normalization +lemma "map f [x \ 'a\semigroup_add, y] = [f x, f y]" by normalization +lemma "(map f [x \ 'a\semigroup_add, y], w \ 'b\finite) = ([f x, f y], w)" by normalization + + end diff -r 62e6c9b67c6f -r e129333b9df0 src/Pure/Isar/code.ML --- a/src/Pure/Isar/code.ML Tue Sep 08 18:31:26 2009 +0200 +++ b/src/Pure/Isar/code.ML Wed Sep 09 11:31:20 2009 +0200 @@ -19,11 +19,6 @@ val constrset_of_consts: theory -> (string * typ) list -> string * ((string * sort) list * (string * typ list) list) - (*constant aliasses*) - val add_const_alias: thm -> theory -> theory - val triv_classes: theory -> class list - val resubst_alias: theory -> string -> string - (*code equations*) val mk_eqn: theory -> thm * bool -> thm * bool val mk_eqn_warning: theory -> thm -> (thm * bool) option @@ -169,7 +164,6 @@ datatype spec = Spec of { history_concluded: bool, - aliasses: ((string * string) * thm) list * class list, eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table (*with explicit history*), dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table @@ -177,19 +171,16 @@ cases: (int * (int * string list)) Symtab.table * unit Symtab.table }; -fun make_spec ((history_concluded, aliasses), (eqns, (dtyps, cases))) = - Spec { history_concluded = history_concluded, aliasses = aliasses, - eqns = eqns, dtyps = dtyps, cases = cases }; -fun map_spec f (Spec { history_concluded = history_concluded, aliasses = aliasses, eqns = eqns, +fun make_spec (history_concluded, (eqns, (dtyps, cases))) = + Spec { history_concluded = history_concluded, eqns = eqns, dtyps = dtyps, cases = cases }; +fun map_spec f (Spec { history_concluded = history_concluded, eqns = eqns, dtyps = dtyps, cases = cases }) = - make_spec (f ((history_concluded, aliasses), (eqns, (dtyps, cases)))); -fun merge_spec (Spec { history_concluded = _, aliasses = aliasses1, eqns = eqns1, + make_spec (f (history_concluded, (eqns, (dtyps, cases)))); +fun merge_spec (Spec { history_concluded = _, eqns = eqns1, dtyps = dtyps1, cases = (cases1, undefs1) }, - Spec { history_concluded = _, aliasses = aliasses2, eqns = eqns2, + Spec { history_concluded = _, eqns = eqns2, dtyps = dtyps2, cases = (cases2, undefs2) }) = let - val aliasses = (Library.merge (eq_snd Thm.eq_thm_prop) (pairself fst (aliasses1, aliasses2)), - Library.merge (op =) (pairself snd (aliasses1, aliasses2))); fun merge_eqns ((_, history1), (_, history2)) = let val raw_history = AList.merge (op = : serial * serial -> bool) @@ -202,15 +193,13 @@ val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2); val cases = (Symtab.merge (K true) (cases1, cases2), Symtab.merge (K true) (undefs1, undefs2)); - in make_spec ((false, aliasses), (eqns, (dtyps, cases))) end; + in make_spec (false, (eqns, (dtyps, cases))) end; fun history_concluded (Spec { history_concluded, ... }) = history_concluded; -fun the_aliasses (Spec { aliasses, ... }) = aliasses; fun the_eqns (Spec { eqns, ... }) = eqns; fun the_dtyps (Spec { dtyps, ... }) = dtyps; fun the_cases (Spec { cases, ... }) = cases; -val map_history_concluded = map_spec o apfst o apfst; -val map_aliasses = map_spec o apfst o apsnd; +val map_history_concluded = map_spec o apfst; val map_eqns = map_spec o apsnd o apfst; val map_dtyps = map_spec o apsnd o apsnd o apfst; val map_cases = map_spec o apsnd o apsnd o apsnd; @@ -264,7 +253,7 @@ structure Code_Data = TheoryDataFun ( type T = spec * data ref; - val empty = (make_spec ((false, ([], [])), + val empty = (make_spec (false, (Symtab.empty, (Symtab.empty, (Symtab.empty, Symtab.empty)))), ref empty_data); fun copy (spec, data) = (spec, ref (! data)); val extend = copy; @@ -358,24 +347,6 @@ end; (*local*) -(** retrieval interfaces **) - -(* constant aliasses *) - -fun resubst_alias thy = - let - val alias = (fst o the_aliasses o the_exec) thy; - val subst_inst_param = Option.map fst o AxClass.inst_of_param thy; - fun subst_alias c = - get_first (fn ((c', c''), _) => if c = c'' then SOME c' else NONE) alias; - in - perhaps subst_inst_param - #> perhaps subst_alias - end; - -val triv_classes = snd o the_aliasses o the_exec; - - (** foundation **) (* constants *) @@ -669,38 +640,6 @@ (** declaring executable ingredients **) -(* constant aliasses *) - -fun add_const_alias thm thy = - let - val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm) - of SOME ofclass_eq => ofclass_eq - | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm); - val (T, class) = case try Logic.dest_of_class ofclass - of SOME T_class => T_class - | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm); - val tvar = case try Term.dest_TVar T - of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort) - then tvar - else error ("Bad sort: " ^ Display.string_of_thm_global thy thm) - | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm); - val _ = if Term.add_tvars eqn [] = [tvar] then () - else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm); - val lhs_rhs = case try Logic.dest_equals eqn - of SOME lhs_rhs => lhs_rhs - | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn); - val c_c' = case try (pairself (check_const thy)) lhs_rhs - of SOME c_c' => c_c' - | _ => error ("Not an equation with two constants: " - ^ Syntax.string_of_term_global thy eqn); - val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then () - else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm); - in thy |> - (map_exec_purge NONE o map_aliasses) (fn (alias, classes) => - ((c_c', thm) :: alias, insert (op =) class classes)) - end; - - (* datatypes *) structure Type_Interpretation = InterpretationFun(type T = string * serial val eq = eq_snd (op =) : T * T -> bool); diff -r 62e6c9b67c6f -r e129333b9df0 src/Tools/Code/code_preproc.ML --- a/src/Tools/Code/code_preproc.ML Tue Sep 08 18:31:26 2009 +0200 +++ b/src/Tools/Code/code_preproc.ML Wed Sep 09 11:31:20 2009 +0200 @@ -23,7 +23,6 @@ val all: code_graph -> string list val pretty: theory -> code_graph -> Pretty.T val obtain: theory -> string list -> term list -> code_algebra * code_graph - val resubst_triv_consts: theory -> term -> term val eval_conv: theory -> (code_algebra -> code_graph -> (string * sort) list -> term -> cterm -> thm) -> cterm -> thm val eval: theory -> ((term -> term) -> 'a -> 'a) @@ -73,10 +72,8 @@ if AList.defined (op =) xs key then AList.delete (op =) key xs else error ("No such " ^ msg ^ ": " ^ quote key); -fun map_data f thy = - thy - |> Code.purge_data - |> (Code_Preproc_Data.map o map_thmproc) f; +fun map_data f = Code.purge_data + #> (Code_Preproc_Data.map o map_thmproc) f; val map_pre_post = map_data o apfst; val map_pre = map_pre_post o apfst; @@ -163,10 +160,7 @@ |> rhs_conv (Simplifier.rewrite post) end; -fun resubst_triv_consts thy = map_aterms (fn t as Const (c, ty) => Const (Code.resubst_alias thy c, ty) - | t => t); - -fun postprocess_term thy = term_of_conv thy (postprocess_conv thy) #> resubst_triv_consts thy; +fun postprocess_term thy = term_of_conv thy (postprocess_conv thy); fun print_codeproc thy = let @@ -489,17 +483,6 @@ fun obtain thy cs ts = apsnd snd (Wellsorted.change_yield thy (extend_arities_eqngr thy cs ts)); -fun prepare_sorts_typ prep_sort - = map_type_tfree (fn (v, sort) => TFree (v, prep_sort sort)); - -fun prepare_sorts prep_sort (Const (c, ty)) = - Const (c, prepare_sorts_typ prep_sort ty) - | prepare_sorts prep_sort (t1 $ t2) = - prepare_sorts prep_sort t1 $ prepare_sorts prep_sort t2 - | prepare_sorts prep_sort (Abs (v, ty, t)) = - Abs (v, prepare_sorts_typ prep_sort ty, prepare_sorts prep_sort t) - | prepare_sorts _ (t as Bound _) = t; - fun gen_eval thy cterm_of conclude_evaluation evaluator proto_ct = let val pp = Syntax.pp_global thy; @@ -512,12 +495,8 @@ val vs = Term.add_tfrees t' []; val consts = fold_aterms (fn Const (c, _) => insert (op =) c | _ => I) t' []; - - val add_triv_classes = curry (Sorts.inter_sort (Sign.classes_of thy)) - (Code.triv_classes thy); - val t'' = prepare_sorts add_triv_classes t'; - val (algebra', eqngr') = obtain thy consts [t'']; - in conclude_evaluation (evaluator algebra' eqngr' vs t'' ct') thm end; + val (algebra', eqngr') = obtain thy consts [t']; + in conclude_evaluation (evaluator algebra' eqngr' vs t' ct') thm end; fun simple_evaluator evaluator algebra eqngr vs t ct = evaluator algebra eqngr vs t; diff -r 62e6c9b67c6f -r e129333b9df0 src/Tools/nbe.ML --- a/src/Tools/nbe.ML Tue Sep 08 18:31:26 2009 +0200 +++ b/src/Tools/nbe.ML Wed Sep 09 11:31:20 2009 +0200 @@ -23,6 +23,7 @@ val trace: bool ref val setup: theory -> theory + val add_const_alias: thm -> theory -> theory end; structure Nbe: NBE = @@ -31,10 +32,107 @@ (* generic non-sense *) val trace = ref false; -fun tracing f x = if !trace then (Output.tracing (f x); x) else x; +fun traced f x = if !trace then (tracing (f x); x) else x; -(** the semantical universe **) +(** certificates and oracle for "trivial type classes" **) + +structure Triv_Class_Data = TheoryDataFun +( + type T = (class * thm) list; + val empty = []; + val copy = I; + val extend = I; + fun merge pp = AList.merge (op =) (K true); +); + +fun add_const_alias thm thy = + let + val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm) + of SOME ofclass_eq => ofclass_eq + | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm); + val (T, class) = case try Logic.dest_of_class ofclass + of SOME T_class => T_class + | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm); + val tvar = case try Term.dest_TVar T + of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort) + then tvar + else error ("Bad sort: " ^ Display.string_of_thm_global thy thm) + | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm); + val _ = if Term.add_tvars eqn [] = [tvar] then () + else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm); + val lhs_rhs = case try Logic.dest_equals eqn + of SOME lhs_rhs => lhs_rhs + | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn); + val c_c' = case try (pairself (Code.check_const thy)) lhs_rhs + of SOME c_c' => c_c' + | _ => error ("Not an equation with two constants: " + ^ Syntax.string_of_term_global thy eqn); + val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then () + else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm); + in Triv_Class_Data.map (AList.update (op =) (class, thm)) thy end; + +local + +val get_triv_classes = map fst o Triv_Class_Data.get; + +val (_, triv_of_class) = Context.>>> (Context.map_theory_result + (Thm.add_oracle (Binding.name "triv_of_class", fn (thy, (v, T), class) => + Thm.cterm_of thy (Logic.mk_of_class (T, class))))); + +in + +fun lift_triv_classes_conv thy conv ct = + let + val algebra = Sign.classes_of thy; + val triv_classes = get_triv_classes thy; + val certT = Thm.ctyp_of thy; + fun critical_classes sort = filter_out (fn class => Sign.subsort thy (sort, [class])) triv_classes; + val vs = Term.add_tfrees (Thm.term_of ct) [] + |> map_filter (fn (v, sort) => case critical_classes sort + of [] => NONE + | classes => SOME (v, ((sort, classes), Sorts.inter_sort algebra (triv_classes, sort)))); + val of_classes = maps (fn (v, ((sort, classes), _)) => map (fn class => + ((v, class), triv_of_class (thy, (v, TVar ((v, 0), sort)), class))) classes + @ map (fn class => ((v, class), Thm.of_class (certT (TVar ((v, 0), sort)), class))) + sort) vs; + fun strip_of_class thm = + let + val prem_props = (Logic.strip_imp_prems o Thm.prop_of) thm; + val prem_thms = map (the o AList.lookup (op =) of_classes + o apfst (fst o fst o dest_TVar) o Logic.dest_of_class) prem_props; + in Drule.implies_elim_list thm prem_thms end; + in ct + |> Drule.cterm_rule Thm.varifyT + |> Thm.instantiate_cterm (Thm.certify_inst thy (map (fn (v, ((sort, _), sort')) => + (((v, 0), sort), TFree (v, sort'))) vs, [])) + |> Drule.cterm_rule Thm.freezeT + |> conv + |> Thm.varifyT + |> fold (fn (v, (_, sort')) => Thm.unconstrainT (certT (TVar ((v, 0), sort')))) vs + |> Thm.certify_instantiate (map (fn (v, ((sort, _), _)) => + (((v, 0), []), TVar ((v, 0), sort))) vs, []) + |> strip_of_class + |> Thm.freezeT + end; + +fun lift_triv_classes_rew thy rew t = + let + val algebra = Sign.classes_of thy; + val triv_classes = get_triv_classes thy; + val vs = Term.add_tfrees t []; + in t + |> (map_types o map_type_tfree) + (fn (v, sort) => TFree (v, Sorts.inter_sort algebra (sort, triv_classes))) + |> rew + |> (map_types o map_type_tfree) + (fn (v, _) => TFree (v, the (AList.lookup (op =) vs v))) + end; + +end; + + +(** the semantic universe **) (* Functions are given by their semantical function value. To avoid @@ -275,7 +373,7 @@ val cs = map fst eqnss; in s - |> tracing (fn s => "\n--- code to be evaluated:\n" ^ s) + |> traced (fn s => "\n--- code to be evaluated:\n" ^ s) |> ML_Context.evaluate ctxt (!trace) univs_cookie |> (fn f => f deps_vals) |> (fn univs => cs ~~ univs) @@ -450,14 +548,14 @@ val string_of_term = setmp show_types true (Syntax.string_of_term_global thy); in compile_eval thy naming program (vs, t) deps - |> Code_Preproc.resubst_triv_consts thy - |> tracing (fn t => "Normalized:\n" ^ string_of_term t) + |> traced (fn t => "Normalized:\n" ^ string_of_term t) |> type_infer - |> tracing (fn t => "Types inferred:\n" ^ string_of_term t) + |> traced (fn t => "Types inferred:\n" ^ string_of_term t) |> check_tvars - |> tracing (fn t => "---\n") + |> traced (fn t => "---\n") end; + (* evaluation oracle *) fun mk_equals thy lhs raw_rhs = @@ -500,9 +598,9 @@ val norm_conv = no_frees_conv (fn ct => let val thy = Thm.theory_of_cterm ct; - in Code_Thingol.eval_conv thy (norm_oracle thy) ct end); + in lift_triv_classes_conv thy (Code_Thingol.eval_conv thy (norm_oracle thy)) ct end); -fun norm thy = no_frees_rew (Code_Thingol.eval thy I (normalize thy)); +fun norm thy = lift_triv_classes_rew thy (no_frees_rew (Code_Thingol.eval thy I (normalize thy))); (* evaluation command *)