new version of triv_of_class machinery without legacy_unconstrain
authorhaftmann
Wed, 19 May 2010 10:14:37 +0200
changeset 36982 1d4478a797c2
parent 36979 da7c06ab3169
child 36983 e922a5124428
new version of triv_of_class machinery without legacy_unconstrain
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Tue May 18 19:00:55 2010 -0700
+++ b/src/Tools/nbe.ML	Wed May 19 10:14:37 2010 +0200
@@ -76,7 +76,7 @@
 val get_triv_classes = map fst o Triv_Class_Data.get;
 
 val (_, triv_of_class) = Context.>>> (Context.map_theory_result
-  (Thm.add_oracle (Binding.name "triv_of_class", fn (thy, (v, T), class) =>
+  (Thm.add_oracle (Binding.name "triv_of_class", fn (thy, T, class) =>
     Thm.cterm_of thy (Logic.mk_of_class (T, class)))));
 
 in
@@ -84,37 +84,46 @@
 fun lift_triv_classes_conv thy conv ct =
   let
     val algebra = Sign.classes_of thy;
+    val certT = Thm.ctyp_of thy;
     val triv_classes = get_triv_classes thy;
-    val certT = Thm.ctyp_of thy;
-    fun critical_classes sort = filter_out (fn class => Sign.subsort thy (sort, [class])) triv_classes;
-    val vs = Term.add_tfrees (Thm.term_of ct) []
-      |> map_filter (fn (v, sort) => case critical_classes sort
-          of [] => NONE
-           | classes => SOME (v, ((sort, classes), Sorts.inter_sort algebra (triv_classes, sort))));
-    val of_classes = maps (fn (v, ((sort, classes), _)) => map (fn class =>
-      ((v, class), triv_of_class (thy, (v, TVar ((v, 0), sort)), class))) classes
-      @ map (fn class => ((v, class), Thm.of_class (certT (TVar ((v, 0), sort)), class)))
-        sort) vs;
+    fun additional_classes sort = filter_out (fn class => Sorts.sort_le algebra (sort, [class])) triv_classes;
+    fun mk_entry (v, sort) =
+      let
+        val T = TFree (v, sort);
+        val cT = certT T;
+        val triv_sort = additional_classes sort;
+      in
+        (v, (Sorts.inter_sort algebra (sort, triv_sort),
+          (cT, AList.make (fn class => Thm.of_class (cT, class)) sort
+            @ AList.make (fn class => triv_of_class (thy, T, class)) triv_sort)))
+      end;
+    val vs_tab = map mk_entry (Term.add_tfrees (Thm.term_of ct) []);
+    fun instantiate thm =
+      let
+        val cert_tvars = map (certT o TVar) (Term.add_tvars
+          ((fst o Logic.dest_equals o Logic.strip_imp_concl o Thm.prop_of) thm) []);
+        val instantiation =
+          map2 (fn cert_tvar => fn (_, (_, (cT, _))) => (cert_tvar, cT)) cert_tvars vs_tab;
+      in Thm.instantiate (instantiation, []) thm end;
+    fun of_class (TFree (v, _), class) =
+          the (AList.lookup (op =) ((snd o snd o the o AList.lookup (op =) vs_tab) v) class)
+      | of_class (T, _) = error ("Bad type " ^ Syntax.string_of_typ_global thy T);
     fun strip_of_class thm =
       let
-        val prem_props = (Logic.strip_imp_prems o Thm.prop_of) thm;
-        val prem_thms = map (the o AList.lookup (op =) of_classes
-          o apfst (fst o fst o dest_TVar) o Logic.dest_of_class) prem_props;
-      in Drule.implies_elim_list thm prem_thms end;
+        val prems_of_class = Thm.prop_of thm
+          |> Logic.strip_imp_prems
+          |> map (Logic.dest_of_class #> of_class);
+      in fold Thm.elim_implies prems_of_class thm end;
   in
-    (* FIXME avoid legacy operations *)
     ct
-    |> Drule.cterm_rule Thm.varifyT_global
-    |> Thm.instantiate_cterm (Thm.certify_inst thy (map (fn (v, ((sort, _), sort')) =>
-        (((v, 0), sort), TFree (v, sort'))) vs, []))
-    |> Drule.cterm_rule Thm.legacy_freezeT
+    |> (Drule.cterm_fun o map_types o map_type_tfree)
+        (fn (v, sort) => TFree (v, (fst o the o AList.lookup (op =) vs_tab) v))
     |> conv
+    |> Thm.strip_shyps
     |> Thm.varifyT_global
-    |> fold (fn (v, (_, sort')) => Thm.legacy_unconstrainT (certT (TVar ((v, 0), sort')))) vs
-    |> Thm.certify_instantiate (map (fn (v, ((sort, _), _)) =>
-        (((v, 0), []), TVar ((v, 0), sort))) vs, [])
+    |> Thm.unconstrainT
+    |> instantiate
     |> strip_of_class
-    |> Thm.legacy_freezeT
   end;
 
 fun lift_triv_classes_rew thy rew t =
@@ -365,7 +374,7 @@
 
 (* code compilation *)
 
-fun compile_eqnss _ gr raw_deps [] = []
+fun compile_eqnss ctxt gr raw_deps [] = []
   | compile_eqnss ctxt gr raw_deps eqnss =
       let
         val (deps, deps_vals) = split_list (map_filter
@@ -552,7 +561,7 @@
     |> type_infer
     |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
     |> check_tvars
-    |> traced (fn t => "---\n")
+    |> traced (fn _ => "---\n")
   end;