changed preprocessing framework
authorhaftmann
Mon, 02 Oct 2006 23:01:05 +0200
changeset 20844 6792583aa463
parent 20843 a5343075bdc5
child 20845 c55dcf606f65
changed preprocessing framework
src/Pure/Tools/codegen_data.ML
--- a/src/Pure/Tools/codegen_data.ML	Mon Oct 02 23:01:04 2006 +0200
+++ b/src/Pure/Tools/codegen_data.ML	Mon Oct 02 23:01:05 2006 +0200
@@ -5,8 +5,6 @@
 Basic code generator data structures; abstract executable content of theory.
 *)
 
-(* val _ = PolyML.Compiler.maxInlineSize := 0;  *)
-
 signature CODEGEN_DATA =
 sig
   type lthms = thm list Susp.T;
@@ -21,6 +19,8 @@
   val del_datatype: string -> theory -> theory
   val add_inline: thm -> theory -> theory
   val del_inline: thm -> theory -> theory
+  val add_inline_proc: (theory -> cterm list -> thm list) -> theory -> theory
+  val add_constrains: (theory -> term list -> (indexname * sort) list) -> theory -> theory
   val add_preproc: (theory -> thm list -> thm list) -> theory -> theory
   val these_funcs: theory -> CodegenConsts.const -> thm list
   val get_datatype: theory -> string
@@ -31,10 +31,9 @@
 
   val typ_func: theory -> thm -> typ
   val rewrite_func: thm list -> thm -> thm
-  val preprocess_cterm: theory -> cterm -> thm
-  val preprocess: theory -> thm list -> thm list
+  val preprocess_cterm: theory -> (string * typ -> typ) -> cterm -> thm * cterm
 
-  val debug: bool ref
+  val trace: bool ref
   val strict_functyp: bool ref
 end;
 
@@ -55,8 +54,8 @@
 
 (** diagnostics **)
 
-val debug = ref false;
-fun debug_msg f x = (if !debug then Output.tracing (f x) else (); x);
+val trace = ref false;
+fun tracing f x = (if !trace then Output.tracing (f x) else (); x);
 
 
 
@@ -64,7 +63,6 @@
 
 type lthms = thm list Susp.T;
 val eval_always = ref false;
-val _ = eval_always := true;
 
 fun lazy f = if !eval_always
   then Susp.value (f ())
@@ -78,10 +76,12 @@
  of SOME thms => (map (ProofContext.pretty_thm ctxt) o rev) thms
   | NONE => [Pretty.str "[...]"];
 
-fun certificate f r =
+fun certificate thy f r =
   case Susp.peek r
-   of SOME thms => (Susp.value o f) thms
-     | NONE => lazy (fn () => (f o Susp.force) r);
+   of SOME thms => (Susp.value o f thy) thms
+     | NONE => let
+          val thy_ref = Theory.self_ref thy;
+        in lazy (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end;
 
 fun merge' _ ([], []) = (false, [])
   | merge' _ ([], ys) = (true, ys)
@@ -107,45 +107,104 @@
 
 (** code theorems **)
 
-(* making function theorems *)
+(* making rewrite theorems *)
 
 fun bad_thm msg thm =
   error (msg ^ ": " ^ string_of_thm thm);
 
+fun check_rew thy thm =
+  let
+    val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
+    fun vars_of t = fold_aterms
+     (fn Var (v, _) => insert (op =) v
+       | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
+       | _ => I) t [];
+    fun tvars_of t = fold_term_types
+     (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
+                          | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
+    val lhs_vs = vars_of lhs;
+    val rhs_vs = vars_of rhs;
+    val lhs_tvs = tvars_of lhs;
+    val rhs_tvs = tvars_of lhs;
+    val _ = if null (subtract (op =) lhs_vs rhs_vs)
+      then ()
+      else bad_thm "Free variables on right hand side of rewrite theorems" thm
+    val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
+      then ()
+      else bad_thm "Free type variables on right hand side of rewrite theorems" thm
+  in thm end;
+
+fun mk_rew thy thm =
+  let
+    val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
+  in
+    map (check_rew thy) thms
+  end;
+
+
+(* making function theorems *)
+
 fun typ_func thy = snd o dest_Const o fst o strip_comb o fst 
   o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
 
-val mk_rew =
-  #mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of;
+val strict_functyp = ref true;
+
+fun dest_func thy = apfst dest_Const o strip_comb o Envir.beta_eta_contract
+  o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
+
+fun mk_head thy thm =
+  ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm);
 
-val strict_functyp = ref true;
+fun check_func verbose thy thm = case try (dest_func thy) thm
+ of SOME (c_ty as (c, ty), args) =>
+      let
+        val _ =
+          if has_duplicates (op =)
+            ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args [])
+          then bad_thm "Repeated variables on left hand side of function equation" thm
+          else ()
+        val is_classop = (is_some o AxClass.class_of_param thy) c;
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+          (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then (const, thm)
+        else (if is_classop orelse (!strict_functyp andalso not
+          (Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)))
+          then error else (if verbose then warning else K ()) #> K (const, thm))
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl) 
+      end
+  | NONE => bad_thm "Not a function equation" thm;
+
+fun check_typ_classop thy thm =
+  let
+    val (c_ty as (c, ty), _) = dest_func thy thm;  
+  in case AxClass.class_of_param thy c
+   of SOME class => let
+        val const = CodegenConsts.norm_of_typ thy c_ty;
+        val ty_decl = CodegenConsts.disc_typ_of_const thy
+            (snd o CodegenConsts.typ_of_inst thy) const;
+        val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+      in if Sign.typ_equiv thy (ty_decl, ty)
+        then thm
+        else error
+          ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis strictly less general than declared function type\n"
+           ^ string_of_typ ty_decl)
+      end
+    | NONE => thm
+  end;
 
 fun mk_func thy raw_thm =
-  let
-    fun dest_func thy = dest_Const o fst o strip_comb o Envir.beta_eta_contract
-      o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
-    fun mk_head thm = case try (dest_func thy) thm
-     of SOME (c_ty as (c, ty)) =>
-          let
-            val is_classop = (is_some o AxClass.class_of_param thy) c;
-            val const = CodegenConsts.norm_of_typ thy c_ty;
-            val ty_decl = CodegenConsts.disc_typ_of_const thy
-                (snd o CodegenConsts.typ_of_inst thy) const;
-            val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
-          in if Sign.typ_equiv thy (ty_decl, ty)
-            then (const, thm)
-            else ((if is_classop orelse !strict_functyp then error else warning)
-              ("Type\n" ^ string_of_typ ty
-               ^ "\nof function theorem\n"
-               ^ string_of_thm thm
-               ^ "\nis strictly less general than declared function type\n"
-               ^ string_of_typ ty_decl); (const, thm))
-          end
-      | NONE => bad_thm "Not a function equation" thm;
-  in
-    mk_rew thy raw_thm
-    |> map mk_head
-  end;
+  mk_rew thy raw_thm
+  |> map (check_func true thy);
 
 fun get_prim_def_funcs thy c =
   let
@@ -178,9 +237,7 @@
 
 fun add_drop_redundant thm thms =
   let
-(*     val _ = writeln "add_drop 01";  *)
     val thy = Context.check_thy (Thm.theory_of_thm thm);
-(*     val _ = writeln "add_drop 02";  *)
     val pattern = (fst o Logic.dest_equals o Drule.plain_prop_of) thm;
     fun matches thm' = if (curry (Pattern.matches thy) pattern o
       fst o Logic.dest_equals o Drule.plain_prop_of) thm'
@@ -222,19 +279,24 @@
 
 datatype preproc = Preproc of {
   inlines: thm list,
+  inline_procs: (serial * (theory -> cterm list -> thm list)) list,
+  constrains: (serial * (theory -> term list -> (indexname * sort) list)) list,
   preprocs: (serial * (theory -> thm list -> thm list)) list
 };
 
-fun mk_preproc (inlines, preprocs) =
-  Preproc { inlines = inlines, preprocs = preprocs };
-fun map_preproc f (Preproc { inlines, preprocs }) =
-  mk_preproc (f (inlines, preprocs));
-fun merge_preproc (Preproc { inlines = inlines1, preprocs = preprocs1 },
-  Preproc { inlines = inlines2, preprocs = preprocs2 }) =
+fun mk_preproc ((inlines, inline_procs), (constrains, preprocs)) =
+  Preproc { inlines = inlines, inline_procs = inline_procs, constrains = constrains, preprocs = preprocs };
+fun map_preproc f (Preproc { inlines, inline_procs, constrains, preprocs }) =
+  mk_preproc (f ((inlines, inline_procs), (constrains, preprocs)));
+fun merge_preproc (Preproc { inlines = inlines1, inline_procs = inline_procs1, constrains = constrains1 , preprocs = preprocs1 },
+  Preproc { inlines = inlines2, inline_procs = inline_procs2, constrains = constrains2 , preprocs = preprocs2 }) =
     let
       val (touched1, inlines) = merge_thms (inlines1, inlines2);
-      val (touched2, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2);
-    in (touched1 orelse touched2, mk_preproc (inlines, preprocs)) end;
+      val (touched2, inline_procs) = merge_alist (op =) (K true) (inline_procs1, inline_procs2);
+      val (touched3, constrains) = merge_alist (op =) (K true) (constrains1, constrains2);
+      val (touched4, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2);
+    in (touched1 orelse touched2 orelse touched3 orelse touched4,
+      mk_preproc ((inlines, inline_procs), (constrains, preprocs))) end;
 
 fun join_func_thms (tabs as (tab1, tab2)) =
   let
@@ -257,13 +319,13 @@
     andalso gen_eq_set (eq_pair eq_string (eq_list (is_equal o Term.typ_ord))) (cs1, cs2);
 fun merge_dtyps (tabs as (tab1, tab2)) =
   let
-    (*EXTEND: could be more clever with respect to constructors*)
     val tycos1 = Symtab.keys tab1;
     val tycos2 = Symtab.keys tab2;
     val tycos' = filter (member eq_string tycos2) tycos1;
-    val touched = gen_eq_set (eq_pair (op =) (eq_dtyp))
+    val touched = not (gen_eq_set (op =) (tycos1, tycos2) andalso
+      gen_eq_set (eq_pair (op =) (eq_dtyp))
       (AList.make (the o Symtab.lookup tab1) tycos',
-       AList.make (the o Symtab.lookup tab2) tycos');
+       AList.make (the o Symtab.lookup tab2) tycos'));
   in (touched, Symtab.merge (K true) tabs) end;
 
 datatype spec = Spec of {
@@ -301,7 +363,7 @@
     val (touched_cs, spec) = merge_spec (spec1, spec2);
     val touched = if touched' then NONE else touched_cs;
   in (touched, mk_exec (preproc, spec)) end;
-val empty_exec = mk_exec (mk_preproc ([], []),
+val empty_exec = mk_exec (mk_preproc (([], []), ([], [])),
   mk_spec ((Consttab.empty, Consttab.empty), Symtab.empty));
 
 fun the_preproc (Exec { preproc = Preproc x, ...}) = x;
@@ -450,9 +512,9 @@
 
 fun rewrite_func rewrites thm =
   let
-    val rewrite = Tactic.rewrite true rewrites;
-    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o cprop_of) thm;
-    val Const ("==", _) = term_of ct_eq;
+    val rewrite = Tactic.rewrite false rewrites;
+    val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
+    val Const ("==", _) = Thm.term_of ct_eq;
     val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
     val rhs' = rewrite ct_rhs;
     val args' = map rewrite ct_args;
@@ -484,12 +546,12 @@
           cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
       in map (Thm.instantiate (instT, [])) thms end;
 
-fun certify_const thy c thms =
+fun certify_const thy c c_thms =
   let
     fun cert (c', thm) = if CodegenConsts.eq_const (c, c')
       then thm else bad_thm ("Wrong head of function equation,\nexpected constant "
         ^ CodegenConsts.string_of_const thy c) thm
-  in (map cert o maps (mk_func thy)) thms end;
+  in map cert c_thms end;
 
 fun mk_cos tyco vs cos =
   let
@@ -589,7 +651,7 @@
 fun add_funcl (c, lthms) thy =
   let
     val c' = CodegenConsts.norm thy c;
-    val lthms' = certificate (certify_const thy c') lthms;
+    val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms;
   in
     map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], []))
       (add_lthms lthms'))) thy
@@ -601,7 +663,7 @@
     val consts = map (CodegenConsts.norm_of_typ thy o dest_Const o fst) cs;
     val add =
       map_dtyps (Symtab.update_new (tyco,
-        (vs_cos, certificate (certify_datatype thy tyco cs) lthms)))
+        (vs_cos, certificate thy (fn thy => certify_datatype thy tyco cs) lthms)))
       #> map_dconstrs (fold (fn c => Consttab.update (c, tyco)) consts)
   in map_exec_purge (SOME consts) add thy end;
 
@@ -616,52 +678,145 @@
   in map_exec_purge (SOME consts) del thy end;
 
 fun add_inline thm thy =
-  map_exec_purge NONE (map_preproc (apfst (fold (insert eq_thm) (mk_rew thy thm)))) thy;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy;
 
 fun del_inline thm thy =
-  map_exec_purge NONE (map_preproc (apfst (fold (remove eq_thm) (mk_rew thy thm)))) thy ;
+  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ;
+
+fun add_inline_proc f =
+  (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f));
+
+fun add_constrains f =
+  (map_exec_purge NONE o map_preproc o apsnd o apfst) (cons (serial (), f));
 
 fun add_preproc f =
-  map_exec_purge NONE (map_preproc (apsnd (cons (serial (), f))));
+  (map_exec_purge NONE o map_preproc o apsnd o apsnd) (cons (serial (), f));
+
+local
+
+fun gen_apply_constrain prep post const_typ thy fs x =
+  let
+    val ts = prep x;
+    val tvars = (fold o fold_aterms) Term.add_tvars ts [];
+    val consts = (fold o fold_aterms) (fn Const c => cons c | _ => I) ts [];
+    fun insts_of const_typ (c, ty) =
+      let
+        val ty_decl = const_typ (c, ty);
+        val env = Vartab.dest (Type.raw_match (ty_decl, ty) Vartab.empty);
+        val insts = map_filter
+         (fn (v, (sort, TVar (_, sort'))) =>
+                if Sorts.sort_le (Sign.classes_of thy) (sort, sort')
+                then NONE else SOME (v, sort)
+           | _ => NONE) env
+      in 
+        insts
+      end
+    val const_insts = case const_typ
+     of NONE => []
+      | SOME const_typ => maps (insts_of const_typ) consts;
+    fun add_inst (v, sort') =
+      let
+        val sort = (the o AList.lookup (op =) tvars) v
+      in
+        AList.map_default (op =) (v, (sort, sort))
+          (apsnd (fn sort => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')))
+      end;
+    val inst =
+      []
+      |> fold (fn f => fold add_inst (f thy ts)) fs
+      |> fold add_inst const_insts;
+  in
+    post thy inst x
+  end;
 
-fun getf_first [] _ = NONE
-  | getf_first (f::fs) x = case f x
-     of NONE => getf_first fs x
-      | y as SOME x => y;
+val apply_constrain = gen_apply_constrain (maps
+  ((fn (args, rhs) => rhs :: (snd o strip_comb) args) o Logic.dest_equals o Thm.prop_of))
+  (fn thy => fn inst => map (check_typ_classop thy o Thm.instantiate (map (fn (v, (sort, sort')) =>
+    (Thm.ctyp_of thy (TVar (v, sort)), Thm.ctyp_of thy (TVar (v, sort')))
+  ) inst, []))) NONE;
+fun apply_constrain_cterm thy const_typ = gen_apply_constrain (single o Thm.term_of)
+  (fn thy => fn inst => pair inst o Thm.cterm_of thy o map_types
+    (TermSubst.instantiateT (map (fn (v, (sort, sort')) => ((v, sort), TVar (v, sort'))) inst)) o Thm.term_of) (SOME const_typ) thy;
+
+fun gen_apply_inline_proc prep post thy f x =
+  let
+    val cts = prep x;
+    val rews = map (check_rew thy) (f thy cts);
+  in post rews x end;
+
+val apply_inline_proc = gen_apply_inline_proc (maps
+  ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of))
+  (fn rews => map (rewrite_func rews));
+val apply_inline_proc_cterm = gen_apply_inline_proc single
+  (Tactic.rewrite false);
 
-fun getf_first_list [] x = []
-  | getf_first_list (f::fs) x = case f x
-     of [] => getf_first_list fs x
-      | xs => xs;
+fun apply_preproc thy f [] = []
+  | apply_preproc thy f (thms as (thm :: _)) =
+      let
+        val thms' = f thy thms;
+        val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm;
+      in (certify_const thy c o map (mk_head thy)) thms' end;
+
+fun cmp_thms thy =
+  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2)));
+
+fun rhs_conv conv thm =
+  let
+    val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm;
+  in Thm.transitive thm thm' end
+
+fun drop_classes thy inst thm =
+  let
+    val unconstr = map (fn (v, (_, sort')) =>
+      (Thm.ctyp_of thy o TVar) (v, sort')) inst;
+    val instmap = map (fn (v, (sort, _)) =>
+      pairself (Thm.ctyp_of thy o TVar) ((v, []), (v, sort))) inst;
+  in
+    thm
+    |> fold Thm.unconstrainT unconstr
+    |> Thm.instantiate (instmap, [])
+    |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
+  end;
+
+in
 
 fun preprocess thy thms =
-  let
-    fun cmp_thms (thm1, thm2) =
-      not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2));
-  in
-    thms
-    |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
-    |> fold (fn (_, f) => f thy) ((#preprocs o the_preproc o get_exec) thy)
-    |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
-    |> sort (make_ord cmp_thms)
-    |> common_typ_funcs thy
-  end;
+  thms
+  |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy)
+  |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+  |> apply_constrain thy ((map snd o #constrains o the_preproc o get_exec) thy)
+  |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+  |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy)
+  |> map (snd o check_func false thy)
+  |> sort (cmp_thms thy)
+  |> common_typ_funcs thy;
 
-fun preprocess_cterm thy =
-  Tactic.rewrite false ((#inlines o the_preproc o get_exec) thy);
+fun preprocess_cterm thy const_typ ct =
+  ct
+  |> apply_constrain_cterm thy const_typ ((map snd o #constrains o the_preproc o get_exec) thy)
+  |-> (fn inst =>
+     Thm.reflexive
+  #> fold (rhs_conv o Tactic.rewrite false o single) ((#inlines o the_preproc o get_exec) thy)
+  #> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f)) ((#inline_procs o the_preproc o get_exec) thy)
+  #> (fn thm => (drop_classes thy inst thm, ((fn xs => nth xs 1) o snd o Drule.strip_comb o Thm.cprop_of) thm))
+  );
+
+end; (*local*)
 
 fun these_funcs thy c =
   let
-    fun test_funcs c =
+    val funcs_1 =
       Consttab.lookup ((the_funcs o get_exec) thy) c
       |> Option.map (Susp.force o fst)
       |> these
       |> map (Thm.transfer thy);
-    val test_defs = get_prim_def_funcs thy;
+    val funcs_2 = case funcs_1
+     of [] => get_prim_def_funcs thy c
+      | xs => xs;
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
   in
-    getf_first_list [test_funcs, test_defs] c
+    funcs_2
     |> preprocess thy
     |> drop_refl thy
   end;