clarified function transformator interface
authorhaftmann
Fri, 26 Sep 2008 09:09:52 +0200
changeset 28368 8437fb395294
parent 28367 10ea34297962
child 28369 196bd0305c0d
clarified function transformator interface
src/Pure/Isar/code.ML
src/Pure/Isar/code_unit.ML
--- a/src/Pure/Isar/code.ML	Fri Sep 26 09:09:51 2008 +0200
+++ b/src/Pure/Isar/code.ML	Fri Sep 26 09:09:52 2008 +0200
@@ -8,14 +8,14 @@
 
 signature CODE =
 sig
-  val add_func: thm -> theory -> theory
-  val add_nonlinear_func: thm -> theory -> theory
-  val add_liberal_func: thm -> theory -> theory
-  val add_default_func: thm -> theory -> theory
-  val add_default_func_attr: Attrib.src
-  val del_func: thm -> theory -> theory
-  val del_funcs: string -> theory -> theory
-  val add_funcl: string * (thm * bool) list Susp.T -> theory -> theory
+  val add_eqn: thm -> theory -> theory
+  val add_nonlinear_eqn: thm -> theory -> theory
+  val add_liberal_eqn: thm -> theory -> theory
+  val add_default_eqn: thm -> theory -> theory
+  val add_default_eqn_attr: Attrib.src
+  val del_eqn: thm -> theory -> theory
+  val del_eqns: string -> theory -> theory
+  val add_eqnl: string * (thm * bool) list Susp.T -> theory -> theory
   val map_pre: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
   val map_post: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
   val add_inline: thm -> theory -> theory
@@ -35,7 +35,7 @@
 
   val coregular_algebra: theory -> Sorts.algebra
   val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
-  val these_funcs: theory -> string -> (thm * bool) list
+  val these_eqns: theory -> string -> (thm * bool) list
   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
   val get_datatype_of_constr: theory -> string -> string option
   val get_case_data: theory -> string -> (int * string list) option
@@ -172,23 +172,23 @@
 (* specification data *)
 
 datatype spec = Spec of {
-  funcs: (bool * (thm * bool) list Susp.T) Symtab.table,
+  eqns: (bool * (thm * bool) list Susp.T) Symtab.table,
   dtyps: ((string * sort) list * (string * typ list) list) Symtab.table,
   cases: (int * string list) Symtab.table * unit Symtab.table
 };
 
-fun mk_spec (funcs, (dtyps, cases)) =
-  Spec { funcs = funcs, dtyps = dtyps, cases = cases };
-fun map_spec f (Spec { funcs = funcs, dtyps = dtyps, cases = cases }) =
-  mk_spec (f (funcs, (dtyps, cases)));
-fun merge_spec (Spec { funcs = funcs1, dtyps = dtyps1, cases = (cases1, undefs1) },
-  Spec { funcs = funcs2, dtyps = dtyps2, cases = (cases2, undefs2) }) =
+fun mk_spec (eqns, (dtyps, cases)) =
+  Spec { eqns = eqns, dtyps = dtyps, cases = cases };
+fun map_spec f (Spec { eqns = eqns, dtyps = dtyps, cases = cases }) =
+  mk_spec (f (eqns, (dtyps, cases)));
+fun merge_spec (Spec { eqns = eqns1, dtyps = dtyps1, cases = (cases1, undefs1) },
+  Spec { eqns = eqns2, dtyps = dtyps2, cases = (cases2, undefs2) }) =
   let
-    val funcs = Symtab.join (K merge_defthms) (funcs1, funcs2);
+    val eqns = Symtab.join (K merge_defthms) (eqns1, eqns2);
     val dtyps = merge_dtyps (dtyps1, dtyps2);
     val cases = (Symtab.merge (K true) (cases1, cases2),
       Symtab.merge (K true) (undefs1, undefs2));
-  in mk_spec (funcs, (dtyps, cases)) end;
+  in mk_spec (eqns, (dtyps, cases)) end;
 
 
 (* pre- and postprocessor *)
@@ -234,11 +234,11 @@
 
 fun the_thmproc (Exec { thmproc = Thmproc x, ...}) = x;
 fun the_spec (Exec { spec = Spec x, ...}) = x;
-val the_funcs = #funcs o the_spec;
+val the_eqns = #eqns o the_spec;
 val the_dtyps = #dtyps o the_spec;
 val the_cases = #cases o the_spec;
 val map_thmproc = map_exec o apfst o map_thmproc;
-val map_funcs = map_exec o apsnd o map_spec o apfst;
+val map_eqns = map_exec o apsnd o map_spec o apfst;
 val map_dtyps = map_exec o apsnd o map_spec o apsnd o apfst;
 val map_cases = map_exec o apsnd o map_spec o apsnd o apsnd;
 
@@ -358,7 +358,7 @@
   let
     val ctxt = ProofContext.init thy;
     val exec = the_exec thy;
-    fun pretty_func (s, (_, lthms)) =
+    fun pretty_eqn (s, (_, lthms)) =
       (Pretty.block o Pretty.fbreaks) (
         Pretty.str s :: pretty_lthms ctxt lthms
       );
@@ -378,7 +378,7 @@
     val pre = (#pre o the_thmproc) exec;
     val post = (#post o the_thmproc) exec;
     val functrans = (map fst o #functrans o the_thmproc) exec;
-    val funcs = the_funcs exec
+    val eqns = the_eqns exec
       |> Symtab.dest
       |> (map o apfst) (Code_Unit.string_of_const thy)
       |> sort (string_ord o pairself fst);
@@ -392,7 +392,7 @@
       Pretty.block (
         Pretty.str "defining equations:"
         :: Pretty.fbrk
-        :: (Pretty.fbreaks o map pretty_func) funcs
+        :: (Pretty.fbreaks o map pretty_eqn) eqns
       ),
       Pretty.block [
         Pretty.str "preprocessing simpset:",
@@ -421,14 +421,13 @@
 
 (** theorem transformation and certification **)
 
-fun const_of thy = dest_Const o fst o strip_comb o fst o Logic.dest_equals
-  o ObjectLogic.drop_judgment thy o Thm.plain_prop_of;
+fun const_of thy = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
+
+fun const_of_eqn thy = AxClass.unoverload_const thy o const_of thy;
 
-fun const_of_func thy = AxClass.unoverload_const thy o const_of thy;
-
-fun common_typ_funcs [] = []
-  | common_typ_funcs [thm] = [thm]
-  | common_typ_funcs (thms as thm :: _) = (*FIXME is too general*)
+fun common_typ_eqns [] = []
+  | common_typ_eqns [thm] = [thm]
+  | common_typ_eqns (thms as thm :: _) = (*FIXME is too general*)
       let
         val thy = Thm.theory_of_thm thm;
         fun incr_thm thm max =
@@ -451,7 +450,7 @@
 
 fun certify_const thy const thms =
   let
-    fun cert thm = if const = const_of_func thy thm
+    fun cert thm = if const = const_of_eqn thy thm
       then thm else error ("Wrong head of defining equation,\nexpected constant "
         ^ Code_Unit.string_of_const thy const ^ "\n" ^ Display.string_of_thm thm)
   in map cert thms end;
@@ -475,15 +474,15 @@
   let
     val vs = Name.invents Name.context "" (Sign.arity_number thy tyco);
     val classparams = (map fst o these o try (#params o AxClass.get_info thy)) class;
-    val funcs = classparams
+    val eqns = classparams
       |> map_filter (fn c => try (AxClass.param_of_inst thy) (c, tyco))
-      |> map (Symtab.lookup ((the_funcs o the_exec) thy))
+      |> map (Symtab.lookup ((the_eqns o the_exec) thy))
       |> (map o Option.map) (map fst o Susp.force o snd)
       |> maps these
       |> map (Thm.transfer thy);
     fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys
       | sorts_of tys = map (snd o dest_TVar) tys;
-    val sorts = map (sorts_of o Sign.const_typargs thy o const_of thy) funcs;
+    val sorts = map (sorts_of o Sign.const_typargs thy o const_of thy) eqns;
   in sorts end;
 
 fun weakest_constraints thy algebra (class, tyco) =
@@ -548,7 +547,12 @@
 val classparam_weakest_typ = gen_classparam_typ weakest_constraints;
 val classparam_strongest_typ = gen_classparam_typ strongest_constraints;
 
-fun assert_func_typ thm =
+fun assert_eqn_linear (eqn as (thm, linear)) =
+  if linear then eqn else Code_Unit.bad_thm
+    ("Duplicate variables on left hand side of defining equation:\n"
+      ^ Display.string_of_thm thm);
+
+fun assert_eqn_typ (thm, linear) =
   let
     val thy = Thm.theory_of_thm thm;
     fun check_typ_classparam tyco (c, thm) =
@@ -597,12 +601,18 @@
       case AxClass.inst_of_param thy c
        of SOME (c, tyco) => check_typ_classparam tyco (c, thm)
         | NONE => check_typ_fun (c, thm);
-  in check_typ (const_of_func thy thm, thm) end;
+    val c = const_of_eqn thy thm;
+    val thm' = check_typ (c, thm);
+  in (thm', linear) end;
 
-fun mk_func linear = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func linear);
-val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func true);
-val mk_syntactic_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func false);
-val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func true);
+fun mk_eqn linear = Code_Unit.error_thm
+  (assert_eqn_typ o (if linear then assert_eqn_linear else I) o Code_Unit.mk_eqn);
+val mk_liberal_eqn = Code_Unit.warning_thm
+  (assert_eqn_typ o assert_eqn_linear o Code_Unit.mk_eqn);
+val mk_syntactic_eqn = Code_Unit.warning_thm
+  (assert_eqn_typ o Code_Unit.mk_eqn);
+val mk_default_eqn = Code_Unit.try_thm
+  (assert_eqn_typ o assert_eqn_linear o Code_Unit.mk_eqn);
 
 end; (*local*)
 
@@ -641,54 +651,54 @@
 
 val is_undefined = Symtab.defined o snd o the_cases o the_exec;
 
-fun gen_add_func linear strict default thm thy =
-  case (if strict then SOME o mk_func linear else mk_liberal_func) thm
-   of SOME func =>
+fun gen_add_eqn linear strict default thm thy =
+  case (if strict then SOME o mk_eqn linear else mk_liberal_eqn) thm
+   of SOME (thm, _) =>
         let
-          val c = const_of_func thy func;
+          val c = const_of_eqn thy thm;
           val _ = if strict andalso (is_some o AxClass.class_of_param thy) c
             then error ("Rejected polymorphic equation for overloaded constant:\n"
               ^ Display.string_of_thm thm)
             else ();
           val _ = if strict andalso (is_some o get_datatype_of_constr thy) c
             then error ("Rejected equation for datatype constructor:\n"
-              ^ Display.string_of_thm func)
+              ^ Display.string_of_thm thm)
             else ();
         in
-          (map_exec_purge (SOME [c]) o map_funcs) (Symtab.map_default
-            (c, (true, Susp.value [])) (add_thm default (func, linear))) thy
+          (map_exec_purge (SOME [c]) o map_eqns) (Symtab.map_default
+            (c, (true, Susp.value [])) (add_thm default (thm, linear))) thy
         end
     | NONE => thy;
 
-val add_func = gen_add_func true true false;
-val add_liberal_func = gen_add_func true false false;
-val add_default_func = gen_add_func true false true;
-val add_nonlinear_func = gen_add_func false true false;
+val add_eqn = gen_add_eqn true true false;
+val add_liberal_eqn = gen_add_eqn true false false;
+val add_default_eqn = gen_add_eqn true false true;
+val add_nonlinear_eqn = gen_add_eqn false true false;
 
-fun del_func thm thy = case mk_syntactic_func thm
- of SOME func => let
-        val c = const_of_func thy func;
-      in map_exec_purge (SOME [c]) (map_funcs
-        (Symtab.map_entry c (del_thm func))) thy
+fun del_eqn thm thy = case mk_syntactic_eqn thm
+ of SOME (thm, _) => let
+        val c = const_of_eqn thy thm;
+      in map_exec_purge (SOME [c]) (map_eqns
+        (Symtab.map_entry c (del_thm thm))) thy
       end
   | NONE => thy;
 
-fun del_funcs c = map_exec_purge (SOME [c])
-  (map_funcs (Symtab.map_entry c (K (false, Susp.value []))));
+fun del_eqns c = map_exec_purge (SOME [c])
+  (map_eqns (Symtab.map_entry c (K (false, Susp.value []))));
 
-fun add_funcl (c, lthms) thy =
+fun add_eqnl (c, lthms) thy =
   let
     val lthms' = certificate thy (fn thy => certify_const thy c) lthms;
       (*FIXME must check compatibility with sort algebra;
         alas, naive checking results in non-termination!*)
   in
     map_exec_purge (SOME [c])
-      (map_funcs (Symtab.map_default (c, (true, Susp.value []))
+      (map_eqns (Symtab.map_default (c, (true, Susp.value []))
         (add_lthms lthms'))) thy
   end;
 
-val add_default_func_attr = Attrib.internal (fn _ => Thm.declaration_attribute
-  (fn thm => Context.mapping (add_default_func thm) I));
+val add_default_eqn_attr = Attrib.internal (fn _ => Thm.declaration_attribute
+  (fn thm => Context.mapping (add_default_eqn thm) I));
 
 structure TypeInterpretation = InterpretationFun(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
 
@@ -703,7 +713,7 @@
   in
     thy
     |> map_exec_purge purge_cs (map_dtyps (Symtab.update (tyco, vs_cos))
-        #> map_funcs (fold (Symtab.delete_safe o fst) cs))
+        #> map_eqns (fold (Symtab.delete_safe o fst) cs))
     |> TypeInterpretation.data (tyco, serial ())
   end;
 
@@ -762,8 +772,8 @@
         || Scan.succeed (mk_attribute add))
   in
     TypeInterpretation.init
-    #> add_del_attribute ("func", (add_func, del_func))
-    #> add_simple_attribute ("nbe", add_nonlinear_func)
+    #> add_del_attribute ("func", (add_eqn, del_eqn))
+    #> add_simple_attribute ("nbe", add_nonlinear_eqn)
     #> add_del_attribute ("inline", (add_inline, del_inline))
     #> add_del_attribute ("post", (add_post, del_post))
   end));
@@ -776,16 +786,12 @@
 fun apply_functrans thy [] = []
   | apply_functrans thy (thms as (thm, _) :: _) =
       let
-        val const = const_of_func thy thm;
+        val const = const_of_eqn thy thm;
         val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
           o the_thmproc o the_exec) thy;
         val thms' = perhaps (perhaps_loop (perhaps_apply functrans)) (map fst thms);
         val thms'' = certify_const thy const thms';
-        val linears = map snd thms;
-      in (*FIXME temporary workaround*) if length thms'' = length linears
-        then thms'' ~~ linears
-        else map (rpair true) thms''
-      end;
+      in map Code_Unit.add_linear thms'' end;
 
 fun rhs_conv conv thm =
   let
@@ -807,10 +813,10 @@
   in
     thms
     |> apply_functrans thy
-    |> (map o apfst) (Code_Unit.rewrite_func pre)
+    |> (map o apfst) (Code_Unit.rewrite_eqn pre)
     (*FIXME - must check here: rewrite rule, defining equation, proper constant *)
     |> (map o apfst) (AxClass.unoverload thy)
-    |> burrow_fst common_typ_funcs
+    |> burrow_fst common_typ_eqns
   end;
 
 
@@ -850,28 +856,28 @@
 
 local
 
-fun get_funcs thy const =
-  Symtab.lookup ((the_funcs o the_exec) thy) const
+fun get_eqns thy const =
+  Symtab.lookup ((the_eqns o the_exec) thy) const
   |> Option.map (Susp.force o snd)
   |> these
   |> (map o apfst) (Thm.transfer thy);
 
 in
 
-fun these_funcs thy const =
+fun these_eqns thy const =
   let
-    fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
-      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of o fst);
+    val drop_refl = filter_out
+      (is_equal o Term.fast_term_ord o Logic.dest_equals o Thm.plain_prop_of o fst);
   in
-    get_funcs thy const
+    get_eqns thy const
     |> preprocess thy
-    |> drop_refl thy
+    |> drop_refl
   end;
 
 fun default_typ thy c = case default_typ_proto thy c
  of SOME ty => Code_Unit.typscheme thy (c, ty)
-  | NONE => (case get_funcs thy c
-     of (thm, _) :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm))
+  | NONE => (case get_eqns thy c
+     of (thm, _) :: _ => snd (Code_Unit.head_eqn (AxClass.unoverload thy thm))
       | [] => Code_Unit.typscheme thy (c, Sign.the_const_type thy c));
 
 end; (*local*)
--- a/src/Pure/Isar/code_unit.ML	Fri Sep 26 09:09:51 2008 +0200
+++ b/src/Pure/Isar/code_unit.ML	Fri Sep 26 09:09:52 2008 +0200
@@ -9,9 +9,9 @@
 sig
   (*generic non-sense*)
   val bad_thm: string -> 'a
-  val error_thm: (thm -> thm) -> thm -> thm
-  val warning_thm: (thm -> thm) -> thm -> thm option
-  val try_thm: (thm -> thm) -> thm -> thm option
+  val error_thm: ('a -> 'b) -> 'a -> 'b
+  val warning_thm: ('a -> 'b) -> 'a -> 'b option
+  val try_thm: ('a -> 'b) -> 'a -> 'b option
 
   (*typ instantiations*)
   val typscheme: theory -> string * typ -> (string * sort) list * typ
@@ -38,10 +38,11 @@
   (*defining equations*)
   val assert_rew: thm -> thm
   val mk_rew: thm -> thm
-  val mk_func: bool -> thm -> thm
-  val head_func: thm -> string * ((string * sort) list * typ)
+  val add_linear: thm -> thm * bool
+  val mk_eqn: thm -> thm * bool
+  val head_eqn: thm -> string * ((string * sort) list * typ)
   val expand_eta: int -> thm -> thm
-  val rewrite_func: simpset -> thm -> thm
+  val rewrite_eqn: simpset -> thm -> thm
   val rewrite_head: thm list -> thm -> thm
   val norm_args: thm list -> thm list 
   val norm_varnames: (string -> string) -> (string -> string) -> thm list -> thm list
@@ -135,7 +136,7 @@
     |> Conv.fconv_rule Drule.beta_eta_conversion
   end;
 
-fun func_conv conv =
+fun eqn_conv conv =
   let
     fun lhs_conv ct = if can Thm.dest_comb ct
       then (Conv.combination_conv lhs_conv conv) ct
@@ -149,7 +150,7 @@
       else conv ct;
   in Conv.fun_conv (Conv.arg_conv lhs_conv) end;
 
-val rewrite_func = Conv.fconv_rule o func_conv o Simplifier.rewrite;
+val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
 val rewrite_head = Conv.fconv_rule o head_conv o MetaSimplifier.rewrite false;
 
 fun norm_args thms =
@@ -361,21 +362,19 @@
 
 (* defining equations *)
 
-fun assert_func linear thm =
+fun add_linear thm =
+  let
+    val (_, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
+    val linear = not (has_duplicates (op =)
+      ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args []))
+  in (thm, linear) end;
+
+fun assert_eqn thm =
   let
     val thy = Thm.theory_of_thm thm;
-    val (head, args) = (strip_comb o fst o Logic.dest_equals
-      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of) thm;
+    val (head, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
     val _ = case head of Const _ => () | _ =>
       bad_thm ("Equation not headed by constant\n" ^ Display.string_of_thm thm);
-    val _ =
-      if linear andalso has_duplicates (op =)
-        ((fold o fold_aterms) (fn Var (v, _) => cons v
-          | _ => I
-        ) args [])
-      then bad_thm ("Duplicated variables on left hand side of equation\n"
-        ^ Display.string_of_thm thm)
-      else ()
     fun check _ (Abs _) = bad_thm
           ("Abstraction on left hand side of equation\n"
             ^ Display.string_of_thm thm)
@@ -390,11 +389,13 @@
                ^ Display.string_of_thm thm)
           else ();
     val _ = map (check 0) args;
-  in thm end;
+    val linear = not (has_duplicates (op =)
+      ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I ) args []))
+  in add_linear thm end;
 
-fun mk_func linear = assert_func linear o mk_rew;
+val mk_eqn = assert_eqn o mk_rew;
 
-fun head_func thm =
+fun head_eqn thm =
   let
     val thy = Thm.theory_of_thm thm;
     val Const (c, ty) = (fst o strip_comb o fst o Logic.dest_equals