src/Pure/Tools/codegen_data.ML
changeset 22033 8e19bad4125f
parent 22020 e52aef4ab54b
child 22050 859e5784c58c
--- a/src/Pure/Tools/codegen_data.ML	Tue Jan 09 08:31:47 2007 +0100
+++ b/src/Pure/Tools/codegen_data.ML	Tue Jan 09 08:31:48 2007 +0100
@@ -29,10 +29,8 @@
 
   val print_thms: theory -> unit
 
-  val typ_func: theory -> thm -> typ
   val typ_funcs: theory -> CodegenConsts.const * thm list -> typ
-  val rewrite_func: thm list -> thm -> thm
-  val preprocess_cterm: theory -> cterm -> thm
+  val preprocess_cterm: cterm -> thm
 
   val trace: bool ref
 end;
@@ -107,144 +105,6 @@
 
 (** code theorems **)
 
-(* making rewrite theorems *)
-
-fun bad_thm msg thm =
-  error (msg ^ ": " ^ string_of_thm thm);
-
-fun check_rew thy thm =
-  let
-    val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
-    fun vars_of t = fold_aterms
-     (fn Var (v, _) => insert (op =) v
-       | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
-       | _ => I) t [];
-    fun tvars_of t = fold_term_types
-     (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
-                          | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
-    val lhs_vs = vars_of lhs;
-    val rhs_vs = vars_of rhs;
-    val lhs_tvs = tvars_of lhs;
-    val rhs_tvs = tvars_of lhs;
-    val _ = if null (subtract (op =) lhs_vs rhs_vs)
-      then ()
-      else bad_thm "Free variables on right hand side of rewrite theorems" thm
-    val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
-      then ()
-      else bad_thm "Free type variables on right hand side of rewrite theorems" thm
-  in thm end;
-
-fun mk_rew thy thm =
-  let
-    val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
-  in
-    map (check_rew thy) thms
-  end;
-
-
-(* making function theorems *)
-
-fun typ_func thy = snd o dest_Const o fst o strip_comb
-  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
-
-val strict_functyp = ref true;
-
-fun dest_func thy = apfst dest_Const o strip_comb
-  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of
-  o Drule.fconv_rule Drule.beta_eta_conversion;
-
-fun mk_head thy thm =
-  ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm);
-
-fun check_func thy thm = case try (dest_func thy) thm
- of SOME (c_ty as (c, ty), args) =>
-      let
-        val _ =
-          if has_duplicates (op =)
-            ((fold o fold_aterms) (fn Var (v, _) => cons v
-              | _ => I
-            ) args [])
-          then bad_thm "Repeated variables on left hand side of function equation" thm
-          else ()
-        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm 
-          | no_abs (t1 $ t2) = (no_abs t1; no_abs t2)
-          | no_abs _ = ();
-        val _ = map no_abs args;
-        val is_classop = (is_some o AxClass.class_of_param thy) c;
-        val const = CodegenConsts.norm_of_typ thy c_ty;
-        val ty_decl = CodegenConsts.disc_typ_of_const thy
-          (snd o CodegenConsts.typ_of_inst thy) const;
-        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
-      in if Sign.typ_equiv thy (ty_decl, ty)
-        then SOME (const, thm)
-        else (if is_classop
-            then if !strict_functyp
-              then error
-              else warning #> K NONE
-          else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
-            then warning #> (K o SOME) (const, thm)
-          else if !strict_functyp
-            then error
-          else warning #> K NONE)
-          ("Type\n" ^ string_of_typ ty
-           ^ "\nof function theorem\n"
-           ^ string_of_thm thm
-           ^ "\nis strictly less general than declared function type\n"
-           ^ string_of_typ ty_decl)
-      end
-  | NONE => bad_thm "Not a function equation" thm;
-
-fun check_typ_classop thy thm =
-  let
-    val (c_ty as (c, ty), _) = dest_func thy thm;  
-  in case AxClass.class_of_param thy c
-   of SOME class => let
-        val const = CodegenConsts.norm_of_typ thy c_ty;
-        val ty_decl = CodegenConsts.disc_typ_of_const thy
-            (snd o CodegenConsts.typ_of_inst thy) const;
-        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
-      in if Sign.typ_equiv thy (ty_decl, ty)
-        then thm
-        else error
-          ("Type\n" ^ string_of_typ ty
-           ^ "\nof function theorem\n"
-           ^ string_of_thm thm
-           ^ "\nis strictly less general than declared function type\n"
-           ^ string_of_typ ty_decl)
-      end
-    | NONE => thm
-  end;
-
-fun mk_func thy raw_thm =
-  mk_rew thy raw_thm
-  |> map_filter (check_func thy);
-
-fun get_prim_def_funcs thy c =
-  let
-    fun constrain thm0 thm = case AxClass.class_of_param thy (fst c)
-     of SOME _ =>
-          let
-            val ty_decl = CodegenConsts.disc_typ_of_classop thy c;
-            val max = maxidx_of_typ ty_decl + 1;
-            val thm = Thm.incr_indexes max thm;
-            val ty = typ_func thy thm;
-            val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max);
-            val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
-              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
-          in Thm.instantiate (instT, []) thm end
-      | NONE => thm
-  in case CodegenConsts.find_def thy c
-   of SOME ((_, thm), _) =>
-        thm
-        |> Thm.transfer thy
-        |> try (map snd o mk_func thy)
-        |> these
-        |> map (constrain thm)
-        |> map (CodegenFunc.expand_eta thy ~1)
-    | NONE => []
-  end;
-
-
 (* pairs of (selected, deleted) function theorems *)
 
 type sdthms = thm list Susp.T * thm list;
@@ -529,31 +389,18 @@
 
 (** theorem transformation and certification **)
 
-fun rewrite_func rewrites thm =
-  let
-    val rewrite = MetaSimplifier.rewrite false rewrites;
-    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
-    val Const ("==", _) = Thm.term_of ct_eq;
-    val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
-    val rhs' = rewrite ct_rhs;
-    val args' = map rewrite ct_args;
-    val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
-      args' (Thm.reflexive ct_f));
-  in
-    Thm.transitive (Thm.transitive lhs' thm) rhs'
-  end handle Bind => raise ERROR "rewrite_func"
-
-fun common_typ_funcs thy [] = []
-  | common_typ_funcs thy [thm] = [thm]
-  | common_typ_funcs thy thms =
+fun common_typ_funcs [] = []
+  | common_typ_funcs [thm] = [thm]
+  | common_typ_funcs thms =
       let
+        val thy = Thm.theory_of_thm (hd thms)
         fun incr_thm thm max =
           let
             val thm' = incr_indexes max thm;
             val max' = Thm.maxidx_of thm' + 1;
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
-        val (ty1::tys) = map (typ_func thy) thms';
+        val (ty1::tys) = map CodegenFunc.typ_func thms';
         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
           handle Type.TUNIFY =>
             error ("Type unificaton failed, while unifying function equations\n"
@@ -568,8 +415,8 @@
 fun certify_const thy c c_thms =
   let
     fun cert (c', thm) = if CodegenConsts.eq_const (c, c')
-      then thm else bad_thm ("Wrong head of function equation,\nexpected constant "
-        ^ CodegenConsts.string_of_const thy c) thm
+      then thm else error ("Wrong head of function equation,\nexpected constant "
+        ^ CodegenConsts.string_of_const thy c ^ "\n" ^ string_of_thm thm)
   in map cert c_thms end;
 
 fun mk_cos tyco vs cos =
@@ -647,9 +494,9 @@
 
 (** interfaces **)
 
-fun add_func thm thy =
+fun gen_add_func mk_func thm thy =
   let
-    val thms = mk_func thy thm;
+    val thms = mk_func thm;
     val cs = map fst thms;
   in
     map_exec_purge (SOME cs) (map_funcs 
@@ -657,11 +504,12 @@
        (c, (Susp.value [], [])) (add_thm thm)) thms)) thy
   end;
 
-fun add_func_legacy thm = setmp strict_functyp false (add_func thm);
+val add_func = gen_add_func CodegenFunc.mk_func;
+val add_func_legacy = gen_add_func CodegenFunc.legacy_mk_func;
 
 fun del_func thm thy =
   let
-    val thms = mk_func thy thm;
+    val thms = CodegenFunc.mk_func thm;
     val cs = map fst thms;
   in
     map_exec_purge (SOME cs) (map_funcs
@@ -672,7 +520,7 @@
 fun add_funcl (c, lthms) thy =
   let
     val c' = CodegenConsts.norm thy c;
-    val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms;
+    val lthms' = certificate thy (fn thy => certify_const thy c' o maps (CodegenFunc.mk_func)) lthms;
   in
     map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], []))
       (add_lthms lthms'))) thy
@@ -699,10 +547,10 @@
   in map_exec_purge (SOME consts) del thy end;
 
 fun add_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (CodegenFunc.mk_rew thm)) thy;
 
 fun del_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (CodegenFunc.mk_rew thm)) thy ;
 
 fun add_inline_proc f =
   (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f));
@@ -715,12 +563,12 @@
 fun gen_apply_inline_proc prep post thy f x =
   let
     val cts = prep x;
-    val rews = map (check_rew thy) (f thy cts);
+    val rews = map CodegenFunc.check_rew (f thy cts);
   in post rews x end;
 
 val apply_inline_proc = gen_apply_inline_proc (maps
   ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of))
-  (fn rews => map (rewrite_func rews));
+  (fn rews => map (CodegenFunc.rewrite_func rews));
 val apply_inline_proc_cterm = gen_apply_inline_proc single
   (MetaSimplifier.rewrite false);
 
@@ -728,11 +576,11 @@
   | apply_preproc thy f (thms as (thm :: _)) =
       let
         val thms' = f thy thms;
-        val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm;
-      in (certify_const thy c o map (mk_head thy)) thms' end;
+        val c = (CodegenConsts.norm_of_typ thy o fst o CodegenFunc.dest_func) thm;
+      in (certify_const thy c o map CodegenFunc.mk_head) thms' end;
 
 fun cmp_thms thy =
-  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2)));
+  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (CodegenFunc.typ_func thm1, CodegenFunc.typ_func thm2)));
 
 fun rhs_conv conv thm =
   let
@@ -744,19 +592,23 @@
 fun preprocess thy thms =
   thms
   |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy)
-  |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+  |> map (CodegenFunc.rewrite_func ((#inlines o the_preproc o get_exec) thy))
   |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy)
 (*FIXME - must check: rewrite rule, function equation, proper constant |> map (snd o check_func false thy) *)
   |> sort (cmp_thms thy)
-  |> common_typ_funcs thy;
+  |> common_typ_funcs;
 
-fun preprocess_cterm thy ct =
-  ct
-  |> Thm.reflexive
-  |> fold (rhs_conv o MetaSimplifier.rewrite false o single)
-    ((#inlines o the_preproc o get_exec) thy)
-  |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f))
-    ((#inline_procs o the_preproc o get_exec) thy)
+fun preprocess_cterm ct =
+  let
+    val thy = Thm.theory_of_cterm ct
+  in
+    ct
+    |> Thm.reflexive
+    |> fold (rhs_conv o MetaSimplifier.rewrite false o single)
+      ((#inlines o the_preproc o get_exec) thy)
+    |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f))
+      ((#inline_procs o the_preproc o get_exec) thy)
+  end;
 
 end; (*local*)
 
@@ -768,7 +620,7 @@
       |> these
       |> map (Thm.transfer thy);
     val funcs_2 = case funcs_1
-     of [] => get_prim_def_funcs thy c
+     of [] => CodegenFunc.get_prim_def_funcs thy c
       | xs => xs;
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
@@ -789,9 +641,9 @@
 fun typ_funcs thy (c as (name, _), []) = (case AxClass.class_of_param thy name
      of SOME class => CodegenConsts.disc_typ_of_classop thy c
       | NONE => (case Option.map (Susp.force o fst) (Consttab.lookup ((the_funcs o get_exec) thy) c)
-         of SOME [eq] => typ_func thy eq
+         of SOME [eq] => CodegenFunc.typ_func eq
           | _ => Sign.the_const_type thy name))
-  | typ_funcs thy (_, eq :: _) = typ_func thy eq;
+  | typ_funcs thy (_, eq :: _) = CodegenFunc.typ_func eq;
 
 
 (** code attributes **)