moved eq handling in nbe into separate oracle
authorhaftmann
Wed, 09 Sep 2009 11:31:20 +0200
changeset 32544 e129333b9df0
parent 32543 62e6c9b67c6f
child 32546 d68b7c181211
child 32547 f3eab1682b0d
moved eq handling in nbe into separate oracle
src/HOL/HOL.thy
src/HOL/ex/NormalForm.thy
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
src/Tools/nbe.ML
--- a/src/HOL/HOL.thy	Tue Sep 08 18:31:26 2009 +0200
+++ b/src/HOL/HOL.thy	Wed Sep 09 11:31:20 2009 +0200
@@ -1887,7 +1887,7 @@
 *}
 
 setup {*
-  Code.add_const_alias @{thm equals_alias_cert}
+  Nbe.add_const_alias @{thm equals_alias_cert}
 *}
 
 hide (open) const eq
--- a/src/HOL/ex/NormalForm.thy	Tue Sep 08 18:31:26 2009 +0200
+++ b/src/HOL/ex/NormalForm.thy	Wed Sep 09 11:31:20 2009 +0200
@@ -4,8 +4,17 @@
 
 theory NormalForm
 imports Main Rational
+uses "~~/src/Tools/nbe.ML"
 begin
 
+setup {*
+  Nbe.add_const_alias @{thm equals_alias_cert}
+*}
+
+method_setup normalization = {*
+  Scan.succeed (K (SIMPLE_METHOD' (CONVERSION Nbe.norm_conv THEN' (fn k => TRY (rtac TrueI k)))))
+*} "solve goal by normalization"
+
 lemma "True" by normalization
 lemma "p \<longrightarrow> True" by normalization
 declare disj_assoc [code nbe]
@@ -120,4 +129,13 @@
 normal_form "(%m n f x. m (n f) x) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))"
 normal_form "(%m n. n m) (%f x. f(f(f(x)))) (%f x. f(f(f(x))))"
 
+(* handling of type classes in connection with equality *)
+
+lemma "map f [x, y] = [f x, f y]" by normalization
+lemma "(map f [x, y], w) = ([f x, f y], w)" by normalization
+lemma "map f [x, y] = [f x \<Colon> 'a\<Colon>semigroup_add, f y]" by normalization
+lemma "map f [x \<Colon> 'a\<Colon>semigroup_add, y] = [f x, f y]" by normalization
+lemma "(map f [x \<Colon> 'a\<Colon>semigroup_add, y], w \<Colon> 'b\<Colon>finite) = ([f x, f y], w)" by normalization
+
+
 end
--- a/src/Pure/Isar/code.ML	Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Pure/Isar/code.ML	Wed Sep 09 11:31:20 2009 +0200
@@ -19,11 +19,6 @@
   val constrset_of_consts: theory -> (string * typ) list
     -> string * ((string * sort) list * (string * typ list) list)
 
-  (*constant aliasses*)
-  val add_const_alias: thm -> theory -> theory
-  val triv_classes: theory -> class list
-  val resubst_alias: theory -> string -> string
-
   (*code equations*)
   val mk_eqn: theory -> thm * bool -> thm * bool
   val mk_eqn_warning: theory -> thm -> (thm * bool) option
@@ -169,7 +164,6 @@
 
 datatype spec = Spec of {
   history_concluded: bool,
-  aliasses: ((string * string) * thm) list * class list,
   eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
     (*with explicit history*),
   dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
@@ -177,19 +171,16 @@
   cases: (int * (int * string list)) Symtab.table * unit Symtab.table
 };
 
-fun make_spec ((history_concluded, aliasses), (eqns, (dtyps, cases))) =
-  Spec { history_concluded = history_concluded, aliasses = aliasses,
-    eqns = eqns, dtyps = dtyps, cases = cases };
-fun map_spec f (Spec { history_concluded = history_concluded, aliasses = aliasses, eqns = eqns,
+fun make_spec (history_concluded, (eqns, (dtyps, cases))) =
+  Spec { history_concluded = history_concluded, eqns = eqns, dtyps = dtyps, cases = cases };
+fun map_spec f (Spec { history_concluded = history_concluded, eqns = eqns,
   dtyps = dtyps, cases = cases }) =
-  make_spec (f ((history_concluded, aliasses), (eqns, (dtyps, cases))));
-fun merge_spec (Spec { history_concluded = _, aliasses = aliasses1, eqns = eqns1,
+  make_spec (f (history_concluded, (eqns, (dtyps, cases))));
+fun merge_spec (Spec { history_concluded = _, eqns = eqns1,
     dtyps = dtyps1, cases = (cases1, undefs1) },
-  Spec { history_concluded = _, aliasses = aliasses2, eqns = eqns2,
+  Spec { history_concluded = _, eqns = eqns2,
     dtyps = dtyps2, cases = (cases2, undefs2) }) =
   let
-    val aliasses = (Library.merge (eq_snd Thm.eq_thm_prop) (pairself fst (aliasses1, aliasses2)),
-      Library.merge (op =) (pairself snd (aliasses1, aliasses2)));
     fun merge_eqns ((_, history1), (_, history2)) =
       let
         val raw_history = AList.merge (op = : serial * serial -> bool)
@@ -202,15 +193,13 @@
     val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
     val cases = (Symtab.merge (K true) (cases1, cases2),
       Symtab.merge (K true) (undefs1, undefs2));
-  in make_spec ((false, aliasses), (eqns, (dtyps, cases))) end;
+  in make_spec (false, (eqns, (dtyps, cases))) end;
 
 fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
-fun the_aliasses (Spec { aliasses, ... }) = aliasses;
 fun the_eqns (Spec { eqns, ... }) = eqns;
 fun the_dtyps (Spec { dtyps, ... }) = dtyps;
 fun the_cases (Spec { cases, ... }) = cases;
-val map_history_concluded = map_spec o apfst o apfst;
-val map_aliasses = map_spec o apfst o apsnd;
+val map_history_concluded = map_spec o apfst;
 val map_eqns = map_spec o apsnd o apfst;
 val map_dtyps = map_spec o apsnd o apsnd o apfst;
 val map_cases = map_spec o apsnd o apsnd o apsnd;
@@ -264,7 +253,7 @@
 structure Code_Data = TheoryDataFun
 (
   type T = spec * data ref;
-  val empty = (make_spec ((false, ([], [])),
+  val empty = (make_spec (false,
     (Symtab.empty, (Symtab.empty, (Symtab.empty, Symtab.empty)))), ref empty_data);
   fun copy (spec, data) = (spec, ref (! data));
   val extend = copy;
@@ -358,24 +347,6 @@
 end; (*local*)
 
 
-(** retrieval interfaces **)
-
-(* constant aliasses *)
-
-fun resubst_alias thy =
-  let
-    val alias = (fst o the_aliasses o the_exec) thy;
-    val subst_inst_param = Option.map fst o AxClass.inst_of_param thy;
-    fun subst_alias c =
-      get_first (fn ((c', c''), _) => if c = c'' then SOME c' else NONE) alias;
-  in
-    perhaps subst_inst_param
-    #> perhaps subst_alias
-  end;
-
-val triv_classes = snd o the_aliasses o the_exec;
-
-
 (** foundation **)
 
 (* constants *)
@@ -669,38 +640,6 @@
 
 (** declaring executable ingredients **)
 
-(* constant aliasses *)
-
-fun add_const_alias thm thy =
-  let
-    val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
-     of SOME ofclass_eq => ofclass_eq
-      | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
-    val (T, class) = case try Logic.dest_of_class ofclass
-     of SOME T_class => T_class
-      | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
-    val tvar = case try Term.dest_TVar T
-     of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort)
-          then tvar
-          else error ("Bad sort: " ^ Display.string_of_thm_global thy thm)
-      | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm);
-    val _ = if Term.add_tvars eqn [] = [tvar] then ()
-      else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm);
-    val lhs_rhs = case try Logic.dest_equals eqn
-     of SOME lhs_rhs => lhs_rhs
-      | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
-    val c_c' = case try (pairself (check_const thy)) lhs_rhs
-     of SOME c_c' => c_c'
-      | _ => error ("Not an equation with two constants: "
-          ^ Syntax.string_of_term_global thy eqn);
-    val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()
-      else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm);
-  in thy |>
-    (map_exec_purge NONE o map_aliasses) (fn (alias, classes) =>
-      ((c_c', thm) :: alias, insert (op =) class classes))
-  end;
-
-
 (* datatypes *)
 
 structure Type_Interpretation = InterpretationFun(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
--- a/src/Tools/Code/code_preproc.ML	Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML	Wed Sep 09 11:31:20 2009 +0200
@@ -23,7 +23,6 @@
   val all: code_graph -> string list
   val pretty: theory -> code_graph -> Pretty.T
   val obtain: theory -> string list -> term list -> code_algebra * code_graph
-  val resubst_triv_consts: theory -> term -> term
   val eval_conv: theory
     -> (code_algebra -> code_graph -> (string * sort) list -> term -> cterm -> thm) -> cterm -> thm
   val eval: theory -> ((term -> term) -> 'a -> 'a)
@@ -73,10 +72,8 @@
   if AList.defined (op =) xs key then AList.delete (op =) key xs
   else error ("No such " ^ msg ^ ": " ^ quote key);
 
-fun map_data f thy =
-  thy
-  |> Code.purge_data
-  |> (Code_Preproc_Data.map o map_thmproc) f;
+fun map_data f = Code.purge_data
+  #> (Code_Preproc_Data.map o map_thmproc) f;
 
 val map_pre_post = map_data o apfst;
 val map_pre = map_pre_post o apfst;
@@ -163,10 +160,7 @@
     |> rhs_conv (Simplifier.rewrite post)
   end;
 
-fun resubst_triv_consts thy = map_aterms (fn t as Const (c, ty) => Const (Code.resubst_alias thy c, ty)
-  | t => t);
-
-fun postprocess_term thy = term_of_conv thy (postprocess_conv thy) #> resubst_triv_consts thy;
+fun postprocess_term thy = term_of_conv thy (postprocess_conv thy);
 
 fun print_codeproc thy =
   let
@@ -489,17 +483,6 @@
 fun obtain thy cs ts = apsnd snd
   (Wellsorted.change_yield thy (extend_arities_eqngr thy cs ts));
 
-fun prepare_sorts_typ prep_sort
-  = map_type_tfree (fn (v, sort) => TFree (v, prep_sort sort));
-
-fun prepare_sorts prep_sort (Const (c, ty)) =
-      Const (c, prepare_sorts_typ prep_sort ty)
-  | prepare_sorts prep_sort (t1 $ t2) =
-      prepare_sorts prep_sort t1 $ prepare_sorts prep_sort t2
-  | prepare_sorts prep_sort (Abs (v, ty, t)) =
-      Abs (v, prepare_sorts_typ prep_sort ty, prepare_sorts prep_sort t)
-  | prepare_sorts _ (t as Bound _) = t;
-
 fun gen_eval thy cterm_of conclude_evaluation evaluator proto_ct =
   let
     val pp = Syntax.pp_global thy;
@@ -512,12 +495,8 @@
     val vs = Term.add_tfrees t' [];
     val consts = fold_aterms
       (fn Const (c, _) => insert (op =) c | _ => I) t' [];
-
-    val add_triv_classes = curry (Sorts.inter_sort (Sign.classes_of thy))
-      (Code.triv_classes thy);
-    val t'' = prepare_sorts add_triv_classes t';
-    val (algebra', eqngr') = obtain thy consts [t''];
-  in conclude_evaluation (evaluator algebra' eqngr' vs t'' ct') thm end;
+    val (algebra', eqngr') = obtain thy consts [t'];
+  in conclude_evaluation (evaluator algebra' eqngr' vs t' ct') thm end;
 
 fun simple_evaluator evaluator algebra eqngr vs t ct =
   evaluator algebra eqngr vs t;
--- a/src/Tools/nbe.ML	Tue Sep 08 18:31:26 2009 +0200
+++ b/src/Tools/nbe.ML	Wed Sep 09 11:31:20 2009 +0200
@@ -23,6 +23,7 @@
   val trace: bool ref
 
   val setup: theory -> theory
+  val add_const_alias: thm -> theory -> theory
 end;
 
 structure Nbe: NBE =
@@ -31,10 +32,107 @@
 (* generic non-sense *)
 
 val trace = ref false;
-fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
+fun traced f x = if !trace then (tracing (f x); x) else x;
 
 
-(** the semantical universe **)
+(** certificates and oracle for "trivial type classes" **)
+
+structure Triv_Class_Data = TheoryDataFun
+(
+  type T = (class * thm) list;
+  val empty = [];
+  val copy = I;
+  val extend = I;
+  fun merge pp = AList.merge (op =) (K true);
+);
+
+fun add_const_alias thm thy =
+  let
+    val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
+     of SOME ofclass_eq => ofclass_eq
+      | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
+    val (T, class) = case try Logic.dest_of_class ofclass
+     of SOME T_class => T_class
+      | _ => error ("Bad certificate: " ^ Display.string_of_thm_global thy thm);
+    val tvar = case try Term.dest_TVar T
+     of SOME (tvar as (_, sort)) => if null (filter (can (AxClass.get_info thy)) sort)
+          then tvar
+          else error ("Bad sort: " ^ Display.string_of_thm_global thy thm)
+      | _ => error ("Bad type: " ^ Display.string_of_thm_global thy thm);
+    val _ = if Term.add_tvars eqn [] = [tvar] then ()
+      else error ("Inconsistent type: " ^ Display.string_of_thm_global thy thm);
+    val lhs_rhs = case try Logic.dest_equals eqn
+     of SOME lhs_rhs => lhs_rhs
+      | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
+    val c_c' = case try (pairself (Code.check_const thy)) lhs_rhs
+     of SOME c_c' => c_c'
+      | _ => error ("Not an equation with two constants: "
+          ^ Syntax.string_of_term_global thy eqn);
+    val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()
+      else error ("Inconsistent class: " ^ Display.string_of_thm_global thy thm);
+  in Triv_Class_Data.map (AList.update (op =) (class, thm)) thy end;
+
+local
+
+val get_triv_classes = map fst o Triv_Class_Data.get;
+
+val (_, triv_of_class) = Context.>>> (Context.map_theory_result
+  (Thm.add_oracle (Binding.name "triv_of_class", fn (thy, (v, T), class) =>
+    Thm.cterm_of thy (Logic.mk_of_class (T, class)))));
+
+in
+
+fun lift_triv_classes_conv thy conv ct =
+  let
+    val algebra = Sign.classes_of thy;
+    val triv_classes = get_triv_classes thy;
+    val certT = Thm.ctyp_of thy;
+    fun critical_classes sort = filter_out (fn class => Sign.subsort thy (sort, [class])) triv_classes;
+    val vs = Term.add_tfrees (Thm.term_of ct) []
+      |> map_filter (fn (v, sort) => case critical_classes sort
+          of [] => NONE
+           | classes => SOME (v, ((sort, classes), Sorts.inter_sort algebra (triv_classes, sort))));
+    val of_classes = maps (fn (v, ((sort, classes), _)) => map (fn class =>
+      ((v, class), triv_of_class (thy, (v, TVar ((v, 0), sort)), class))) classes
+      @ map (fn class => ((v, class), Thm.of_class (certT (TVar ((v, 0), sort)), class)))
+        sort) vs;
+    fun strip_of_class thm =
+      let
+        val prem_props = (Logic.strip_imp_prems o Thm.prop_of) thm;
+        val prem_thms = map (the o AList.lookup (op =) of_classes
+          o apfst (fst o fst o dest_TVar) o Logic.dest_of_class) prem_props;
+      in Drule.implies_elim_list thm prem_thms end;
+  in ct
+    |> Drule.cterm_rule Thm.varifyT
+    |> Thm.instantiate_cterm (Thm.certify_inst thy (map (fn (v, ((sort, _), sort')) =>
+        (((v, 0), sort), TFree (v, sort'))) vs, []))
+    |> Drule.cterm_rule Thm.freezeT
+    |> conv
+    |> Thm.varifyT
+    |> fold (fn (v, (_, sort')) => Thm.unconstrainT (certT (TVar ((v, 0), sort')))) vs
+    |> Thm.certify_instantiate (map (fn (v, ((sort, _), _)) =>
+        (((v, 0), []), TVar ((v, 0), sort))) vs, [])
+    |> strip_of_class
+    |> Thm.freezeT
+  end;
+
+fun lift_triv_classes_rew thy rew t =
+  let
+    val algebra = Sign.classes_of thy;
+    val triv_classes = get_triv_classes thy;
+    val vs = Term.add_tfrees t [];
+  in t
+    |> (map_types o map_type_tfree)
+        (fn (v, sort) => TFree (v, Sorts.inter_sort algebra (sort, triv_classes)))
+    |> rew
+    |> (map_types o map_type_tfree)
+        (fn (v, _) => TFree (v, the (AList.lookup (op =) vs v)))
+  end;
+
+end;
+
+
+(** the semantic universe **)
 
 (*
    Functions are given by their semantical function value. To avoid
@@ -275,7 +373,7 @@
         val cs = map fst eqnss;
       in
         s
-        |> tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
+        |> traced (fn s => "\n--- code to be evaluated:\n" ^ s)
         |> ML_Context.evaluate ctxt (!trace) univs_cookie
         |> (fn f => f deps_vals)
         |> (fn univs => cs ~~ univs)
@@ -450,14 +548,14 @@
     val string_of_term = setmp show_types true (Syntax.string_of_term_global thy);
   in
     compile_eval thy naming program (vs, t) deps
-    |> Code_Preproc.resubst_triv_consts thy
-    |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
+    |> traced (fn t => "Normalized:\n" ^ string_of_term t)
     |> type_infer
-    |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
+    |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
     |> check_tvars
-    |> tracing (fn t => "---\n")
+    |> traced (fn t => "---\n")
   end;
 
+
 (* evaluation oracle *)
 
 fun mk_equals thy lhs raw_rhs =
@@ -500,9 +598,9 @@
 val norm_conv = no_frees_conv (fn ct =>
   let
     val thy = Thm.theory_of_cterm ct;
-  in Code_Thingol.eval_conv thy (norm_oracle thy) ct end);
+  in lift_triv_classes_conv thy (Code_Thingol.eval_conv thy (norm_oracle thy)) ct end);
 
-fun norm thy = no_frees_rew (Code_Thingol.eval thy I (normalize thy));
+fun norm thy = lift_triv_classes_rew thy (no_frees_rew (Code_Thingol.eval thy I (normalize thy)));
 
 (* evaluation command *)