src/Pure/Tools/codegen_func.ML
changeset 22197 461130ccfef4
parent 22185 24bf0e403526
child 22211 e2b5f3d24a17
--- a/src/Pure/Tools/codegen_func.ML	Fri Jan 26 13:59:03 2007 +0100
+++ b/src/Pure/Tools/codegen_func.ML	Fri Jan 26 13:59:04 2007 +0100
@@ -15,6 +15,7 @@
   val dest_func: thm -> (string * typ) * term list
   val typ_func: thm -> typ
 
+  val inst_thm: sort Vartab.table -> thm -> thm
   val expand_eta: int -> thm -> thm
   val rewrite_func: thm list -> thm -> thm
 end;
@@ -62,7 +63,7 @@
   end;
 
 
-(* making function theorems *)
+(* making defining equations *)
 
 val typ_func = lift_thm_thy (fn thy => snd o dest_Const o fst o strip_comb
   o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
@@ -83,20 +84,30 @@
             ((fold o fold_aterms) (fn Var (v, _) => cons v
               | _ => I
             ) args [])
-          then bad_thm "Repeated variables on left hand side of function equation" thm
+          then bad_thm "Repeated variables on left hand side of defining equation" thm
           else ()
-        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm 
+        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of defining equation" thm 
           | no_abs (t1 $ t2) = (no_abs t1; no_abs t2)
           | no_abs _ = ();
         val _ = map no_abs args;
       in thm end
-  | NONE => bad_thm "Not a function equation" thm;
+  | NONE => bad_thm "Not a defining equation" thm;
 
 val mk_func = map (mk_head o assert_func) o mk_rew;
 
 
 (* utilities *)
 
+fun inst_thm tvars' thm =
+  let
+    val thy = Thm.theory_of_thm thm;
+    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
+    fun mk_inst (tvar as (v, _)) = case Vartab.lookup tvars' v
+     of SOME sort => SOME (pairself (Thm.ctyp_of thy o TVar) (tvar, (v, sort)))
+      | NONE => NONE;
+    val insts = map_filter mk_inst tvars;
+  in Thm.instantiate (insts, []) thm end;
+
 fun expand_eta k thm =
   let
     val thy = Thm.theory_of_thm thm;
@@ -139,8 +150,6 @@
     val args' = map rewrite ct_args;
     val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
       args' (Thm.reflexive ct_f));
-  in
-    Thm.transitive (Thm.transitive lhs' thm) rhs'
-  end handle Bind => raise ERROR "rewrite_func"
+  in Thm.transitive (Thm.transitive lhs' thm) rhs' end;
 
 end;