src/Pure/Tools/codegen_func.ML
changeset 22033 8e19bad4125f
parent 22023 487b79b95a20
child 22049 a995f9a8f669
--- a/src/Pure/Tools/codegen_func.ML	Tue Jan 09 08:31:47 2007 +0100
+++ b/src/Pure/Tools/codegen_func.ML	Tue Jan 09 08:31:48 2007 +0100
@@ -2,43 +2,152 @@
     ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
-Handling defining equations ("func"s) for code generator framework
+Handling defining equations ("func"s) for code generator framework.
 *)
 
-(* FIXME move various stuff here *)
-
 signature CODEGEN_FUNC =
 sig
-  val expand_eta: theory -> int -> thm -> thm
+  val check_rew: thm -> thm
+  val mk_rew: thm -> thm list
+  val check_func: thm -> (CodegenConsts.const * thm) option
+  val mk_func: thm -> (CodegenConsts.const * thm) list
+  val dest_func: thm -> (string * typ) * term list
+  val mk_head: thm -> CodegenConsts.const * thm
+  val typ_func: thm -> typ
+  val legacy_mk_func: thm -> (CodegenConsts.const * thm) list
+  val expand_eta: int -> thm -> thm
+  val rewrite_func: thm list -> thm -> thm
+  val get_prim_def_funcs: theory -> string * typ list -> thm list
 end;
 
 structure CodegenFunc : CODEGEN_FUNC =
 struct
 
-(* FIXME get rid of this code duplication *)
-val purify_name =
+fun lift_thm_thy f thm = f (Thm.theory_of_thm thm) thm;
+
+fun bad_thm msg thm =
+  error (msg ^ ": " ^ string_of_thm thm);
+
+
+(* making rewrite theorems *)
+
+fun check_rew thm =
+  let
+    val thy = Thm.theory_of_thm thm;
+    val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
+    fun vars_of t = fold_aterms
+     (fn Var (v, _) => insert (op =) v
+       | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
+       | _ => I) t [];
+    fun tvars_of t = fold_term_types
+     (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
+                          | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
+    val lhs_vs = vars_of lhs;
+    val rhs_vs = vars_of rhs;
+    val lhs_tvs = tvars_of lhs;
+    val rhs_tvs = tvars_of lhs;
+    val _ = if null (subtract (op =) lhs_vs rhs_vs)
+      then ()
+      else bad_thm "Free variables on right hand side of rewrite theorems" thm
+    val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
+      then ()
+      else bad_thm "Free type variables on right hand side of rewrite theorems" thm
+  in thm end;
+
+fun mk_rew thm =
   let
-    fun is_valid s = Symbol.is_ascii_letter s orelse Symbol.is_ascii_digit s orelse s = "'";
-    val is_junk = not o is_valid andf Symbol.not_eof;
-    val junk = Scan.many is_junk;
-    val scan_valids = Symbol.scanner "Malformed input"
-      ((junk |--
-        (Scan.optional (Scan.one Symbol.is_ascii_letter) "x" ^^ (Scan.many is_valid >> implode)
-        --| junk))
-      -- Scan.repeat ((Scan.many1 is_valid >> implode) --| junk) >> op ::);
-  in explode #> scan_valids #> space_implode "_" end;
+    val thy = Thm.theory_of_thm thm;
+    val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
+  in
+    map check_rew thms
+  end;
+
+
+(* making function theorems *)
+
+val typ_func = lift_thm_thy (fn thy => snd o dest_Const o fst o strip_comb
+  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
+
+val dest_func = lift_thm_thy (fn thy => apfst dest_Const o strip_comb
+  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of
+  o Drule.fconv_rule Drule.beta_eta_conversion);
+
+val mk_head = lift_thm_thy (fn thy => fn thm =>
+  ((CodegenConsts.norm_of_typ thy o fst o dest_func) thm, thm));
 
-val purify_lower =
-  explode
-  #> (fn cs => (if forall Symbol.is_ascii_upper cs
-        then map else nth_map 0) Symbol.to_ascii_lower cs)
-  #> implode;
+fun gen_check_func strict_functyp thm = case try dest_func thm
+ of SOME (c_ty as (c, ty), args) =>
+      let
+        val thy = Thm.theory_of_thm thm;
+        val _ =
+          if has_duplicates (op =)
+            ((fold o fold_aterms) (fn Var (v, _) => cons v
+              | _ => I
+            ) args [])
+          then bad_thm "Repeated variables on left hand side of function equation" thm
+          else ()
+        fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm 
+          | no_abs (t1 $ t2) = (no_abs t1; no_abs t2)
+          | no_abs _ = ();
+        val _ = map no_abs args;
+        val is_classop = (is_some o AxClass.class_of_param thy) c;
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+          (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+        val error_warning = if strict_functyp
+          then error
+          else warning #> K NONE
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then SOME (const, thm)
+        else (if is_classop
+            then error_warning
+          else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
+            then warning #> (K o SOME) (const, thm)
+          else error_warning)
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl)
+      end
+  | NONE => bad_thm "Not a function equation" thm;
 
-fun purify_var "" = "x"
-  | purify_var v = (purify_name #> purify_lower) v;
+val check_func = gen_check_func true;
+val legacy_check_func = gen_check_func false;
 
-fun expand_eta thy k thm =
+fun check_typ_classop thm =
   let
+    val thy = Thm.theory_of_thm thm;
+    val (c_ty as (c, ty), _) = dest_func thm;
+  in case AxClass.class_of_param thy c
+   of SOME class => let
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+            (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then thm
+        else error
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl)
+      end
+    | NONE => thm
+  end;
+
+fun gen_mk_func check_func = map_filter check_func o mk_rew;
+val mk_func = gen_mk_func check_func;
+val legacy_mk_func = gen_mk_func legacy_check_func;
+
+
+(* utilities *)
+
+fun expand_eta k thm =
+  let
+    val thy = Thm.theory_of_thm thm;
     val (lhs, rhs) = (Logic.dest_equals o Drule.plain_prop_of) thm;
     val (head, args) = strip_comb lhs;
     val l = if k = ~1
@@ -48,7 +157,7 @@
     fun get_name _ 0 used = ([], used)
       | get_name (Abs (v, ty, t)) k used =
           used
-          |> Name.variants [purify_var v]
+          |> Name.variants [v]
           ||>> get_name t (k - 1)
           |>> (fn ([v'], vs') => (v', ty) :: vs')
       | get_name t k used = 
@@ -68,4 +177,43 @@
     fold (fn refl => fn thm => Thm.combination thm refl) vs_refl thm
   end;
 
+fun get_prim_def_funcs thy c =
+  let
+    fun constrain thm0 thm = case AxClass.class_of_param thy (fst c)
+     of SOME _ =>
+          let
+            val ty_decl = CodegenConsts.disc_typ_of_classop thy c;
+            val max = maxidx_of_typ ty_decl + 1;
+            val thm = Thm.incr_indexes max thm;
+            val ty = typ_func thm;
+            val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max);
+            val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
+              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
+          in Thm.instantiate (instT, []) thm end
+      | NONE => thm
+  in case CodegenConsts.find_def thy c
+   of SOME ((_, thm), _) =>
+        thm
+        |> Thm.transfer thy
+        |> try (map snd o mk_func)
+        |> these
+        |> map (constrain thm)
+        |> map (expand_eta ~1)
+    | NONE => []
+  end;
+
+fun rewrite_func rewrites thm =
+  let
+    val rewrite = MetaSimplifier.rewrite false rewrites;
+    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
+    val Const ("==", _) = Thm.term_of ct_eq;
+    val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
+    val rhs' = rewrite ct_rhs;
+    val args' = map rewrite ct_args;
+    val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
+      args' (Thm.reflexive ct_f));
+  in
+    Thm.transitive (Thm.transitive lhs' thm) rhs'
+  end handle Bind => raise ERROR "rewrite_func"
+
 end;