src/Pure/Isar/code_unit.ML
changeset 28423 9fc3befd8191
parent 28368 8437fb395294
child 28704 8703d17c5e68
--- a/src/Pure/Isar/code_unit.ML	Tue Sep 30 12:49:17 2008 +0200
+++ b/src/Pure/Isar/code_unit.ML	Tue Sep 30 12:49:18 2008 +0200
@@ -15,8 +15,8 @@
 
   (*typ instantiations*)
   val typscheme: theory -> string * typ -> (string * sort) list * typ
-  val inst_thm: sort Vartab.table -> thm -> thm
-  val constrain_thm: sort -> thm -> thm
+  val inst_thm: theory -> sort Vartab.table -> thm -> thm
+  val constrain_thm: theory -> sort -> thm -> thm
 
   (*constant aliasses*)
   val add_const_alias: thm -> theory -> theory
@@ -36,16 +36,17 @@
     -> string * ((string * sort) list * (string * typ list) list)
 
   (*defining equations*)
-  val assert_rew: thm -> thm
-  val mk_rew: thm -> thm
-  val add_linear: thm -> thm * bool
-  val mk_eqn: thm -> thm * bool
-  val head_eqn: thm -> string * ((string * sort) list * typ)
-  val expand_eta: int -> thm -> thm
+  val assert_eqn: theory -> thm -> thm
+  val mk_eqn: theory -> thm -> thm * bool
+  val assert_linear: thm * bool -> thm * bool
+  val const_eqn: thm -> string
+  val const_typ_eqn: thm -> string * typ
+  val head_eqn: theory -> thm -> string * ((string * sort) list * typ)
+  val expand_eta: theory -> int -> thm -> thm
   val rewrite_eqn: simpset -> thm -> thm
   val rewrite_head: thm list -> thm -> thm
-  val norm_args: thm list -> thm list 
-  val norm_varnames: (string -> string) -> (string -> string) -> thm list -> thm list
+  val norm_args: theory -> thm list -> thm list 
+  val norm_varnames: theory -> (string -> string) -> (string -> string) -> thm list -> thm list
 
   (*case certificates*)
   val case_cert: thm -> string * (int * string list)
@@ -81,9 +82,8 @@
     val vs = map dest (Sign.const_typargs thy (c, ty));
   in (vs, ty) end;
 
-fun inst_thm tvars' thm =
+fun inst_thm thy tvars' thm =
   let
-    val thy = Thm.theory_of_thm thm;
     val tvars = (Term.add_tvars o Thm.prop_of) thm [];
     val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
     fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
@@ -93,9 +93,8 @@
     val insts = map_filter mk_inst tvars;
   in Thm.instantiate (insts, []) thm end;
 
-fun constrain_thm sort thm =
+fun constrain_thm thy sort thm =
   let
-    val thy = Thm.theory_of_thm thm;
     val constrain = curry (Sorts.inter_sort (Sign.classes_of thy)) sort
     val tvars = (Term.add_tvars o Thm.prop_of) thm [];
     fun mk_inst (tvar as (v, sort)) = pairself (Thm.ctyp_of thy o TVar o pair v)
@@ -103,9 +102,8 @@
     val insts = map mk_inst tvars;
   in Thm.instantiate (insts, []) thm end;
 
-fun expand_eta k thm =
+fun expand_eta thy k thm =
   let
-    val thy = Thm.theory_of_thm thm;
     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
     val (head, args) = strip_comb lhs;
     val l = if k = ~1
@@ -153,19 +151,19 @@
 val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
 val rewrite_head = Conv.fconv_rule o head_conv o MetaSimplifier.rewrite false;
 
-fun norm_args thms =
+fun norm_args thy thms =
   let
     val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
     val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
   in
     thms
-    |> map (expand_eta k)
+    |> map (expand_eta thy k)
     |> map (Conv.fconv_rule Drule.beta_eta_conversion)
   end;
 
-fun canonical_tvars purify_tvar thm =
+fun canonical_tvars thy purify_tvar thm =
   let
-    val ctyp = Thm.ctyp_of (Thm.theory_of_thm thm);
+    val ctyp = Thm.ctyp_of thy;
     fun tvars_subst_for thm = (fold_types o fold_atyps)
       (fn TVar (v_i as (v, _), sort) => let
             val v' = purify_tvar v
@@ -182,9 +180,9 @@
     val (_, inst) = fold mk_inst (tvars_subst_for thm) (maxidx + 1, []);
   in Thm.instantiate (inst, []) thm end;
 
-fun canonical_vars purify_var thm =
+fun canonical_vars thy purify_var thm =
   let
-    val cterm = Thm.cterm_of (Thm.theory_of_thm thm);
+    val cterm = Thm.cterm_of thy;
     fun vars_subst_for thm = fold_aterms
       (fn Var (v_i as (v, _), ty) => let
             val v' = purify_var v
@@ -207,7 +205,7 @@
     val t' = Term.map_abs_vars purify_var t;
   in Thm.rename_boundvars t t' thm end;
 
-fun norm_varnames purify_tvar purify_var thms =
+fun norm_varnames thy purify_tvar purify_var thms =
   let
     fun burrow_thms f [] = []
       | burrow_thms f thms =
@@ -217,8 +215,8 @@
           |> Conjunction.elim_balanced (length thms)
   in
     thms
-    |> burrow_thms (canonical_tvars purify_tvar)
-    |> map (canonical_vars purify_var)
+    |> burrow_thms (canonical_tvars thy purify_tvar)
+    |> map (canonical_vars thy purify_var)
     |> map (canonical_absvars purify_var)
     |> map Drule.zero_var_indexes
   end;
@@ -237,18 +235,16 @@
       Library.merge (op =) (classes1, classes2));
 );
 
-fun add_const_alias thm =
+fun add_const_alias thm thy =
   let
-    val t = Thm.prop_of thm;
-    val thy = Thm.theory_of_thm thm;
-    val lhs_rhs = case try Logic.dest_equals t
+    val lhs_rhs = case try Logic.dest_equals (Thm.prop_of thm)
      of SOME lhs_rhs => lhs_rhs
       | _ => error ("Not an equation: " ^ Display.string_of_thm thm);
     val c_c' = case try (pairself (AxClass.unoverload_const thy o dest_Const)) lhs_rhs
      of SOME c_c' => c_c'
       | _ => error ("Not an equation with two constants: " ^ Display.string_of_thm thm);
     val some_class = the_list (AxClass.class_of_param thy (snd c_c'));
-  in
+  in thy |>
     ConstAlias.map (fn (alias, classes) =>
       ((c_c', thm) :: alias, fold (insert (op =)) some_class classes))
   end;
@@ -319,9 +315,9 @@
   in (tyco, (vs, cs''')) end;
 
 
-(* rewrite theorems *)
+(* defining equations *)
 
-fun assert_rew thm =
+fun assert_eqn thy thm =
   let
     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
       handle TERM _ => bad_thm ("Not an equation: " ^ Display.string_of_thm thm)
@@ -346,34 +342,8 @@
     val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
       then ()
       else bad_thm ("Free type variables on right hand side of rewrite theorem\n"
-        ^ Display.string_of_thm thm)
-  in thm end;
-
-fun mk_rew thm =
-  let
-    val thy = Thm.theory_of_thm thm;
-    val ctxt = ProofContext.init thy;
-  in
-    thm
-    |> LocalDefs.meta_rewrite_rule ctxt
-    |> assert_rew
-  end;
-
-
-(* defining equations *)
-
-fun add_linear thm =
-  let
-    val (_, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
-    val linear = not (has_duplicates (op =)
-      ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args []))
-  in (thm, linear) end;
-
-fun assert_eqn thm =
-  let
-    val thy = Thm.theory_of_thm thm;
-    val (head, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
-    val _ = case head of Const _ => () | _ =>
+        ^ Display.string_of_thm thm)    val (head, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
+    val (c, ty) = case head of Const c_ty => c_ty | _ =>
       bad_thm ("Equation not headed by constant\n" ^ Display.string_of_thm thm);
     fun check _ (Abs _) = bad_thm
           ("Abstraction on left hand side of equation\n"
@@ -389,25 +359,41 @@
                ^ Display.string_of_thm thm)
           else ();
     val _ = map (check 0) args;
-    val linear = not (has_duplicates (op =)
-      ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I ) args []))
-  in add_linear thm end;
-
-val mk_eqn = assert_eqn o mk_rew;
+    val ty_decl = Sign.the_const_type thy c;
+    val _ = if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
+      then () else bad_thm ("Type\n" ^ string_of_typ thy ty
+           ^ "\nof defining equation\n"
+           ^ Display.string_of_thm thm
+           ^ "\nis incompatible with declared function type\n"
+           ^ string_of_typ thy ty_decl)
+  in thm end;
 
-fun head_eqn thm =
+fun add_linear thm =
   let
-    val thy = Thm.theory_of_thm thm;
-    val Const (c, ty) = (fst o strip_comb o fst o Logic.dest_equals
-      o Thm.plain_prop_of) thm;
-  in (c, typscheme thy (c, ty)) end;
+    val (_, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
+    val linear = not (has_duplicates (op =)
+      ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args []))
+  in (thm, linear) end;
+
+fun assert_linear (thm, false) = (thm, false)
+  | assert_linear (thm, true) = if snd (add_linear thm) then (thm, true)
+      else bad_thm
+        ("Duplicate variables on left hand side of defining equation:\n"
+          ^ Display.string_of_thm thm);
+
+
+fun mk_eqn thy = add_linear o assert_eqn thy o AxClass.unoverload thy
+  o LocalDefs.meta_rewrite_rule (ProofContext.init thy);
+
+val const_typ_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
+val const_eqn = fst o const_typ_eqn;
+fun head_eqn thy thm = let val (c, ty) = const_typ_eqn thm in (c, typscheme thy (c, ty)) end;
 
 
 (* case cerificates *)
 
 fun case_certificate thm =
   let
-    val thy = Thm.theory_of_thm thm;
     val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals
       o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.prop_of) thm;
     val _ = case head of Free _ => true