--- a/src/HOL/HOL.thy Tue Sep 08 18:31:26 2009 +0200
+++ b/src/HOL/HOL.thy Wed Sep 09 11:31:20 2009 +0200
@@ -1887,7 +1887,7 @@
*}
setup {*
- Code.add_const_alias @{thm equals_alias_cert}
+ Nbe.add_const_alias @{thm equals_alias_cert}
*}
hide (open) const eq
--- a/src/HOL/ex/NormalForm.thy Tue Sep 08 18:31:26 2009 +0200
+++ b/src/HOL/ex/NormalForm.thy Wed Sep 09 11:31:20 2009 +0200
@@ -4,8 +4,17 @@
theory NormalForm
imports Main Rational
+uses "~~/src/Tools/nbe.ML"
begin
+setup {*
+ Nbe.add_const_alias @{thm equals_alias_cert}
+*}
+
+method_setup normalization = {*
+ Scan.succeed (K (SIMPLE_METHOD' (CONVERSION Nbe.norm_conv THEN' (fn k => TRY (rtac TrueI k)))))
+*} "solve goal by normalization"
+
lemma "True" by normalization
lemma "p \<longrightarrow> True" by normalization
declare disj_assoc [code nbe]
@@ -120,4 +129,13 @@
normal_form "(%m n f x. m (n f) x) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))"
normal_form "(%m n. n m) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))"
+(* handling of type classes in connection with equality *)
+
+lemma "map f [x, y] = [f x, f y]" by normalization
+lemma "(map f [x, y], w) = ([f x, f y], w)" by normalization
+lemma "map f [x, y] = [f x \<Colon> 'a\<Colon>semigroup_add, f y]" by normalization
+lemma "map f [x \<Colon> 'a\<Colon>semigroup_add, y] = [f x, f y]" by normalization
+lemma "(map f [x \<Colon> 'a\<Colon>semigroup_add, y], w \<Colon> 'b\<Colon>finite) = ([f x, f y], w)" by normalization
+
+
end
--- a/src/Pure/Isar/code.ML Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Pure/Isar/code.ML Wed Sep 09 11:31:20 2009 +0200
@@ -19,11 +19,6 @@
val constrset_of_consts: theory -> (string * typ) list
-> string * ((string * sort) list * (string * typ list) list)
- (*constant aliasses*)
- val add_const_alias: thm -> theory -> theory
- val triv_classes: theory -> class list
- val resubst_alias: theory -> string -> string
-
(*code equations*)
val mk_eqn: theory -> thm * bool -> thm * bool
val mk_eqn_warning: theory -> thm -> (thm * bool) option
@@ -169,7 +164,6 @@
datatype spec = Spec of {
history_concluded: bool,
- aliasses: ((string * string) * thm) list * class list,
eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
(*with explicit history*),
dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
@@ -177,19 +171,16 @@
cases: (int * (int * string list)) Symtab.table * unit Symtab.table
};
-fun make_spec ((history_concluded, aliasses), (eqns, (dtyps, cases))) =
- Spec { history_concluded = history_concluded, aliasses = aliasses,
- eqns = eqns, dtyps = dtyps, cases = cases };
-fun map_spec f (Spec { history_concluded = history_concluded, aliasses = aliasses, eqns = eqns,
+fun make_spec (history_concluded, (eqns, (dtyps, cases))) =
+ Spec { history_concluded = history_concluded, eqns = eqns, dtyps = dtyps, cases = cases };
+fun map_spec f (Spec { history_concluded = history_concluded, eqns = eqns,
dtyps = dtyps, cases = cases }) =
- make_spec (f ((history_concluded, aliasses), (eqns, (dtyps, cases))));
-fun merge_spec (Spec { history_concluded = _, aliasses = aliasses1, eqns = eqns1,
+ make_spec (f (history_concluded, (eqns, (dtyps, cases))));
+fun merge_spec (Spec { history_concluded = _, eqns = eqns1,
dtyps = dtyps1, cases = (cases1, undefs1) },
- Spec { history_concluded = _, aliasses = aliasses2, eqns = eqns2,
+ Spec { history_concluded = _, eqns = eqns2,
dtyps = dtyps2, cases = (cases2, undefs2) }) =
let
- val aliasses = (Library.merge (eq_snd Thm.eq_thm_prop) (pairself fst (aliasses1, aliasses2)),
- Library.merge (op =) (pairself snd (aliasses1, aliasses2)));
fun merge_eqns ((_, history1), (_, history2)) =
let
val raw_history = AList.merge (op = : serial * serial -> bool)
@@ -202,15 +193,13 @@
val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
val cases = (Symtab.merge (K true) (cases1, cases2),
Symtab.merge (K true) (undefs1, undefs2));
- in make_spec ((false, aliasses), (eqns, (dtyps, cases))) end;
+ in make_spec (false, (eqns, (dtyps, cases))) end;
fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
-fun the_aliasses (Spec { aliasses, ... }) = aliasses;
fun the_eqns (Spec { eqns, ... }) = eqns;
fun the_dtyps (Spec { dtyps, ... }) = dtyps;
fun the_cases (Spec { cases, ... }) = cases;
-val map_history_concluded = map_spec o apfst o apfst;
-val map_aliasses = map_spec o apfst o apsnd;
+val map_history_concluded = map_spec o apfst;
val map_eqns = map_spec o apsnd o apfst;
val map_dtyps = map_spec o apsnd o apsnd o apfst;
val map_cases = map_spec o apsnd o apsnd o apsnd;
@@ -264,7 +253,7 @@
structure Code_Data = TheoryDataFun
(
type T = spec * data ref;
- val empty = (make_spec ((false, ([], [])),
+ val empty = (make_spec (false,
(Symtab.empty, (Symtab.empty, (Symtab.empty, Symtab.empty)))), ref empty_data);
fun copy (spec, data) = (spec, ref (! data));
val extend = copy;
@@ -358,24 +347,6 @@
end; (*local*)
-(** retrieval interfaces **)
-
-(* constant aliasses *)
-
-fun resubst_alias thy =
- let
- val alias = (fst o the_aliasses o the_exec) thy;
- val subst_inst_param = Option.map fst o AxClass.inst_of_param thy;
- fun subst_alias c =
- get_first (fn ((c', c''), _) => if c = c'' then SOME c' else NONE) alias;
- in
- perhaps subst_inst_param
- #> perhaps subst_alias
- end;
-
-val triv_classes = snd o the_aliasses o the_exec;
-
-
(** foundation **)
(* constants *)
@@ -669,38 +640,6 @@
(** declaring executable ingredients **)
-(* constant aliasses *)
-
-fun add_const_alias thm thy =
- let
- val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
- of SOME ofclass_eq => ofclass_eq
- | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
- val (T, class) = case try Logic.dest_of_class ofclass
- of SOME T_class => T_class
- | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
- val tvar = case try Term.dest_TVar T
- of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort)
- then tvar
- else error ("Bad sort: " ^ Display.string_of_thm_global thy thm)
- | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm);
- val _ = if Term.add_tvars eqn [] = [tvar] then ()
- else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm);
- val lhs_rhs = case try Logic.dest_equals eqn
- of SOME lhs_rhs => lhs_rhs
- | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
- val c_c' = case try (pairself (check_const thy)) lhs_rhs
- of SOME c_c' => c_c'
- | _ => error ("Not an equation with two constants: "
- ^ Syntax.string_of_term_global thy eqn);
- val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()
- else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm);
- in thy |>
- (map_exec_purge NONE o map_aliasses) (fn (alias, classes) =>
- ((c_c', thm) :: alias, insert (op =) class classes))
- end;
-
-
(* datatypes *)
structure Type_Interpretation = InterpretationFun(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
--- a/src/Tools/Code/code_preproc.ML Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML Wed Sep 09 11:31:20 2009 +0200
@@ -23,7 +23,6 @@
val all: code_graph -> string list
val pretty: theory -> code_graph -> Pretty.T
val obtain: theory -> string list -> term list -> code_algebra * code_graph
- val resubst_triv_consts: theory -> term -> term
val eval_conv: theory
-> (code_algebra -> code_graph -> (string * sort) list -> term -> cterm -> thm) -> cterm -> thm
val eval: theory -> ((term -> term) -> 'a -> 'a)
@@ -73,10 +72,8 @@
if AList.defined (op =) xs key then AList.delete (op =) key xs
else error ("No such " ^ msg ^ ": " ^ quote key);
-fun map_data f thy =
- thy
- |> Code.purge_data
- |> (Code_Preproc_Data.map o map_thmproc) f;
+fun map_data f = Code.purge_data
+ #> (Code_Preproc_Data.map o map_thmproc) f;
val map_pre_post = map_data o apfst;
val map_pre = map_pre_post o apfst;
@@ -163,10 +160,7 @@
|> rhs_conv (Simplifier.rewrite post)
end;
-fun resubst_triv_consts thy = map_aterms (fn t as Const (c, ty) => Const (Code.resubst_alias thy c, ty)
- | t => t);
-
-fun postprocess_term thy = term_of_conv thy (postprocess_conv thy) #> resubst_triv_consts thy;
+fun postprocess_term thy = term_of_conv thy (postprocess_conv thy);
fun print_codeproc thy =
let
@@ -489,17 +483,6 @@
fun obtain thy cs ts = apsnd snd
(Wellsorted.change_yield thy (extend_arities_eqngr thy cs ts));
-fun prepare_sorts_typ prep_sort
- = map_type_tfree (fn (v, sort) => TFree (v, prep_sort sort));
-
-fun prepare_sorts prep_sort (Const (c, ty)) =
- Const (c, prepare_sorts_typ prep_sort ty)
- | prepare_sorts prep_sort (t1 $ t2) =
- prepare_sorts prep_sort t1 $ prepare_sorts prep_sort t2
- | prepare_sorts prep_sort (Abs (v, ty, t)) =
- Abs (v, prepare_sorts_typ prep_sort ty, prepare_sorts prep_sort t)
- | prepare_sorts _ (t as Bound _) = t;
-
fun gen_eval thy cterm_of conclude_evaluation evaluator proto_ct =
let
val pp = Syntax.pp_global thy;
@@ -512,12 +495,8 @@
val vs = Term.add_tfrees t' [];
val consts = fold_aterms
(fn Const (c, _) => insert (op =) c | _ => I) t' [];
-
- val add_triv_classes = curry (Sorts.inter_sort (Sign.classes_of thy))
- (Code.triv_classes thy);
- val t'' = prepare_sorts add_triv_classes t';
- val (algebra', eqngr') = obtain thy consts [t''];
- in conclude_evaluation (evaluator algebra' eqngr' vs t'' ct') thm end;
+ val (algebra', eqngr') = obtain thy consts [t'];
+ in conclude_evaluation (evaluator algebra' eqngr' vs t' ct') thm end;
fun simple_evaluator evaluator algebra eqngr vs t ct =
evaluator algebra eqngr vs t;
--- a/src/Tools/nbe.ML Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Tools/nbe.ML Wed Sep 09 11:31:20 2009 +0200
@@ -23,6 +23,7 @@
val trace: bool ref
val setup: theory -> theory
+ val add_const_alias: thm -> theory -> theory
end;
structure Nbe: NBE =
@@ -31,10 +32,107 @@
(* generic non-sense *)
val trace = ref false;
-fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
+fun traced f x = if !trace then (tracing (f x); x) else x;
-(** the semantical universe **)
+(** certificates and oracle for "trivial type classes" **)
+
+structure Triv_Class_Data = TheoryDataFun
+(
+ type T = (class * thm) list;
+ val empty = [];
+ val copy = I;
+ val extend = I;
+ fun merge pp = AList.merge (op =) (K true);
+);
+
+fun add_const_alias thm thy =
+ let
+ val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
+ of SOME ofclass_eq => ofclass_eq
+ | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
+ val (T, class) = case try Logic.dest_of_class ofclass
+ of SOME T_class => T_class
+ | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
+ val tvar = case try Term.dest_TVar T
+ of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort)
+ then tvar
+ else error ("Bad sort: " ^ Display.string_of_thm_global thy thm)
+ | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm);
+ val _ = if Term.add_tvars eqn [] = [tvar] then ()
+ else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm);
+ val lhs_rhs = case try Logic.dest_equals eqn
+ of SOME lhs_rhs => lhs_rhs
+ | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
+ val c_c' = case try (pairself (Code.check_const thy)) lhs_rhs
+ of SOME c_c' => c_c'
+ | _ => error ("Not an equation with two constants: "
+ ^ Syntax.string_of_term_global thy eqn);
+ val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()
+ else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm);
+ in Triv_Class_Data.map (AList.update (op =) (class, thm)) thy end;
+
+local
+
+val get_triv_classes = map fst o Triv_Class_Data.get;
+
+val (_, triv_of_class) = Context.>>> (Context.map_theory_result
+ (Thm.add_oracle (Binding.name "triv_of_class", fn (thy, (v, T), class) =>
+ Thm.cterm_of thy (Logic.mk_of_class (T, class)))));
+
+in
+
+fun lift_triv_classes_conv thy conv ct =
+ let
+ val algebra = Sign.classes_of thy;
+ val triv_classes = get_triv_classes thy;
+ val certT = Thm.ctyp_of thy;
+ fun critical_classes sort = filter_out (fn class => Sign.subsort thy (sort, [class])) triv_classes;
+ val vs = Term.add_tfrees (Thm.term_of ct) []
+ |> map_filter (fn (v, sort) => case critical_classes sort
+ of [] => NONE
+ | classes => SOME (v, ((sort, classes), Sorts.inter_sort algebra (triv_classes, sort))));
+ val of_classes = maps (fn (v, ((sort, classes), _)) => map (fn class =>
+ ((v, class), triv_of_class (thy, (v, TVar ((v, 0), sort)), class))) classes
+ @ map (fn class => ((v, class), Thm.of_class (certT (TVar ((v, 0), sort)), class)))
+ sort) vs;
+ fun strip_of_class thm =
+ let
+ val prem_props = (Logic.strip_imp_prems o Thm.prop_of) thm;
+ val prem_thms = map (the o AList.lookup (op =) of_classes
+ o apfst (fst o fst o dest_TVar) o Logic.dest_of_class) prem_props;
+ in Drule.implies_elim_list thm prem_thms end;
+ in ct
+ |> Drule.cterm_rule Thm.varifyT
+ |> Thm.instantiate_cterm (Thm.certify_inst thy (map (fn (v, ((sort, _), sort')) =>
+ (((v, 0), sort), TFree (v, sort'))) vs, []))
+ |> Drule.cterm_rule Thm.freezeT
+ |> conv
+ |> Thm.varifyT
+ |> fold (fn (v, (_, sort')) => Thm.unconstrainT (certT (TVar ((v, 0), sort')))) vs
+ |> Thm.certify_instantiate (map (fn (v, ((sort, _), _)) =>
+ (((v, 0), []), TVar ((v, 0), sort))) vs, [])
+ |> strip_of_class
+ |> Thm.freezeT
+ end;
+
+fun lift_triv_classes_rew thy rew t =
+ let
+ val algebra = Sign.classes_of thy;
+ val triv_classes = get_triv_classes thy;
+ val vs = Term.add_tfrees t [];
+ in t
+ |> (map_types o map_type_tfree)
+ (fn (v, sort) => TFree (v, Sorts.inter_sort algebra (sort, triv_classes)))
+ |> rew
+ |> (map_types o map_type_tfree)
+ (fn (v, _) => TFree (v, the (AList.lookup (op =) vs v)))
+ end;
+
+end;
+
+
+(** the semantic universe **)
(*
Functions are given by their semantical function value. To avoid
@@ -275,7 +373,7 @@
val cs = map fst eqnss;
in
s
- |> tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
+ |> traced (fn s => "\n--- code to be evaluated:\n" ^ s)
|> ML_Context.evaluate ctxt (!trace) univs_cookie
|> (fn f => f deps_vals)
|> (fn univs => cs ~~ univs)
@@ -450,14 +548,14 @@
val string_of_term = setmp show_types true (Syntax.string_of_term_global thy);
in
compile_eval thy naming program (vs, t) deps
- |> Code_Preproc.resubst_triv_consts thy
- |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
+ |> traced (fn t => "Normalized:\n" ^ string_of_term t)
|> type_infer
- |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
+ |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
|> check_tvars
- |> tracing (fn t => "---\n")
+ |> traced (fn t => "---\n")
end;
+
(* evaluation oracle *)
fun mk_equals thy lhs raw_rhs =
@@ -500,9 +598,9 @@
val norm_conv = no_frees_conv (fn ct =>
let
val thy = Thm.theory_of_cterm ct;
- in Code_Thingol.eval_conv thy (norm_oracle thy) ct end);
+ in lift_triv_classes_conv thy (Code_Thingol.eval_conv thy (norm_oracle thy)) ct end);
-fun norm thy = no_frees_rew (Code_Thingol.eval thy I (normalize thy));
+fun norm thy = lift_triv_classes_rew thy (no_frees_rew (Code_Thingol.eval thy I (normalize thy)));
(* evaluation command *)