added interface for plugging in preprocessors
authorhaftmann
Tue, 30 Jan 2007 08:21:22 +0100
changeset 22212 079de24eee65
parent 22211 e2b5f3d24a17
child 22213 2dd23002c465
added interface for plugging in preprocessors
src/Pure/Tools/codegen_funcgr.ML
--- a/src/Pure/Tools/codegen_funcgr.ML	Tue Jan 30 08:21:19 2007 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML	Tue Jan 30 08:21:22 2007 +0100
@@ -8,26 +8,29 @@
 
 signature CODEGEN_FUNCGR =
 sig
-  type T;
-  val make: theory -> CodegenConsts.const list -> T
-  val make_consts: theory -> CodegenConsts.const list -> CodegenConsts.const list * T
-  val make_term: theory -> ((thm -> thm) -> cterm -> thm -> 'a) -> cterm -> 'a * T
+  type T
+  val timing: bool ref
   val funcs: T -> CodegenConsts.const -> thm list
   val typ: T -> CodegenConsts.const -> typ
   val deps: T -> CodegenConsts.const list -> CodegenConsts.const list list
   val all: T -> CodegenConsts.const list
-  val norm_varnames: thm list -> thm list
-  val print_codethms: theory -> CodegenConsts.const list -> unit
-  structure Constgraph : GRAPH
-  val timing: bool ref
+  val pretty: theory -> T -> Pretty.T
+end
+
+signature CODEGEN_FUNCGR_RETRIEVAL =
+sig
+  type T (* = CODEGEN_FUNCGR.T *)
+  val make: theory -> CodegenConsts.const list -> T
+  val make_consts: theory -> CodegenConsts.const list -> CodegenConsts.const list * T
+  val make_term: theory -> (T -> (thm -> thm) -> cterm -> thm -> 'a) -> cterm -> 'a * T
+  val init: theory -> theory
 end;
 
-structure CodegenFuncgr: CODEGEN_FUNCGR =
+structure CodegenFuncgr = (*signature is added later*)
 struct
 
-(** code data **)
+(** the graph type **)
 
-structure Consttab = CodegenConsts.Consttab;
 structure Constgraph = GraphFun (
   type key = CodegenConsts.const;
   val ord = CodegenConsts.const_ord;
@@ -35,23 +38,6 @@
 
 type T = (typ * thm list) Constgraph.T;
 
-structure Funcgr = CodeDataFun
-(struct
-  val name = "Pure/codegen_funcgr";
-  type T = T;
-  val empty = Constgraph.empty;
-  fun merge _ _ = Constgraph.empty;
-  fun purge _ NONE _ = Constgraph.empty
-    | purge _ (SOME cs) funcgr =
-        Constgraph.del_nodes ((Constgraph.all_preds funcgr 
-          o filter (can (Constgraph.get_node funcgr))) cs) funcgr;
-end);
-
-val _ = Context.add_setup Funcgr.init;
-
-
-(** retrieval **)
-
 fun funcs funcgr =
   these o Option.map snd o try (Constgraph.get_node funcgr);
 
@@ -69,79 +55,16 @@
 
 fun all funcgr = Constgraph.keys funcgr;
 
-
-(** theorem purification **)
-
-fun norm_args thms =
-  let
-    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
-    val k = fold (curry Int.max o num_args_of o Drule.plain_prop_of) thms 0;
-  in
-    thms
-    |> map (CodegenFunc.expand_eta k)
-    |> map (Drule.fconv_rule Drule.beta_eta_conversion)
-  end;
-
-fun canonical_tvars thm =
-  let
-    val ctyp = Thm.ctyp_of (Thm.theory_of_thm thm);
-    fun tvars_subst_for thm = (fold_types o fold_atyps)
-      (fn TVar (v_i as (v, _), sort) => let
-            val v' = CodegenNames.purify_tvar v
-          in if v = v' then I
-          else insert (op =) (v_i, (v', sort)) end
-        | _ => I) (prop_of thm) [];
-    fun mk_inst (v_i, (v', sort)) (maxidx, acc) =
-      let
-        val ty = TVar (v_i, sort)
-      in
-        (maxidx + 1, (ctyp ty, ctyp (TVar ((v', maxidx), sort))) :: acc)
-      end;
-    val maxidx = Thm.maxidx_of thm + 1;
-    val (_, inst) = fold mk_inst (tvars_subst_for thm) (maxidx + 1, []);
-  in Thm.instantiate (inst, []) thm end;
-
-fun canonical_vars thm =
-  let
-    val cterm = Thm.cterm_of (Thm.theory_of_thm thm);
-    fun vars_subst_for thm = fold_aterms
-      (fn Var (v_i as (v, _), ty) => let
-            val v' = CodegenNames.purify_var v
-          in if v = v' then I
-          else insert (op =) (v_i, (v', ty)) end
-        | _ => I) (prop_of thm) [];
-    fun mk_inst (v_i as (v, i), (v', ty)) (maxidx, acc) =
-      let
-        val t = Var (v_i, ty)
-      in
-        (maxidx + 1, (cterm t, cterm (Var ((v', maxidx), ty))) :: acc)
-      end;
-    val maxidx = Thm.maxidx_of thm + 1;
-    val (_, inst) = fold mk_inst (vars_subst_for thm) (maxidx + 1, []);
-  in Thm.instantiate ([], inst) thm end;
-
-fun canonical_absvars thm =
-  let
-    val t = Thm.prop_of thm;
-    val t' = Term.map_abs_vars CodegenNames.purify_var t;
-  in Thm.rename_boundvars t t' thm end;
-
-fun norm_varnames thms =
-  let
-    fun burrow_thms f [] = []
-      | burrow_thms f thms =
-          thms
-          |> Conjunction.intr_list
-          |> f
-          |> Conjunction.elim_list;
-  in
-    thms
-    |> norm_args
-    |> burrow_thms canonical_tvars
-    |> map canonical_vars
-    |> map canonical_absvars
-    |> map Drule.zero_var_indexes
-  end;
+fun pretty thy funcgr =
+  AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
+  |> (map o apfst) (CodegenConsts.string_of_const thy)
+  |> sort (string_ord o pairself fst)
+  |> map (fn (s, thms) =>
+       (Pretty.block o Pretty.fbreaks) (
+         Pretty.str s
+         :: map Display.pretty_thm thms
+       ))
+  |> Pretty.chunks;
 
 
 (** generic combinators **)
@@ -181,6 +104,20 @@
     flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
   end;
 
+fun drop_classes thy tfrees thm =
+  let
+    val (_, thm') = Thm.varifyT' [] thm;
+    val tvars = Term.add_tvars (Thm.prop_of thm') [];
+    val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
+    val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
+      (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
+  in
+    thm'
+    |> fold Thm.unconstrainT unconstr
+    |> Thm.instantiate (instmap, [])
+    |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
+  end;
+
 
 (** graph algorithm **)
 
@@ -275,7 +212,7 @@
     |> instances_of thy algebra
   end;
 
-fun ensure_const' thy algebra funcgr const auxgr =
+fun ensure_const' rewrites thy algebra funcgr const auxgr =
   if can (Constgraph.get_node funcgr) const
     then (NONE, auxgr)
   else if can (Constgraph.get_node auxgr) const
@@ -285,24 +222,27 @@
     |> Constgraph.new_node (const, [])
     |> pair (SOME const)
   else let
-    val thms = norm_varnames (CodegenData.these_funcs thy const);
+    val thms = CodegenData.these_funcs thy const
+      |> map (CodegenFunc.rewrite_func (rewrites thy))
+      |> CodegenFunc.norm_args
+      |> CodegenFunc.norm_varnames CodegenNames.purify_tvar CodegenNames.purify_var;
     val rhs = consts_of (const, thms);
   in
     auxgr
     |> Constgraph.new_node (const, thms)
-    |> fold_map (ensure_const thy algebra funcgr) rhs
+    |> fold_map (ensure_const rewrites thy algebra funcgr) rhs
     |-> (fn rhs' => fold (fn SOME const' => Constgraph.add_edge (const, const')
                            | NONE => I) rhs')
     |> pair (SOME const)
   end
-and ensure_const thy algebra funcgr const =
+and ensure_const rewrites thy algebra funcgr const =
   let
     val timeap = if !timing
       then Output.timeap_msg ("time for " ^ CodegenConsts.string_of_const thy const)
       else I;
-  in timeap (ensure_const' thy algebra funcgr const) end;
+  in timeap (ensure_const' rewrites thy algebra funcgr const) end;
 
-fun merge_funcss thy algebra raw_funcss funcgr =
+fun merge_funcss rewrites thy algebra raw_funcss funcgr =
   let
     val funcss = resort_funcss thy algebra funcgr raw_funcss;
     fun classop_typ (c, [typarg]) class =
@@ -346,7 +286,7 @@
           (fold_consts (insert (op =)) thms []);
       in
         funcgr
-        |> ensure_consts' thy algebra insts
+        |> ensure_consts' rewrites thy algebra insts
         |> fold (curry Constgraph.add_edge const) deps
         |> fold (curry Constgraph.add_edge const) insts
        end;
@@ -355,64 +295,53 @@
     |> fold add_funcs funcss
     |> fold add_deps funcss
   end
-and ensure_consts' thy algebra cs funcgr =
-  fold (snd oo ensure_const thy algebra funcgr) cs Constgraph.empty
-  |> (fn auxgr => fold (merge_funcss thy algebra)
+and ensure_consts' rewrites thy algebra cs funcgr =
+  fold (snd oo ensure_const rewrites thy algebra funcgr) cs Constgraph.empty
+  |> (fn auxgr => fold (merge_funcss rewrites thy algebra)
        (map (AList.make (Constgraph.get_node auxgr))
        (rev (Constgraph.strong_conn auxgr))) funcgr)
   handle INVALID (cs', msg) => raise INVALID (fold (insert CodegenConsts.eq_const) cs' cs, msg);
 
-fun drop_classes thy tfrees thm =
-  let
-    val (_, thm') = Thm.varifyT' [] thm;
-    val tvars = Term.add_tvars (Thm.prop_of thm') [];
-    val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
-    val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
-      (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
-  in
-    thm'
-    |> fold Thm.unconstrainT unconstr
-    |> Thm.instantiate (instmap, [])
-    |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
-  end;
-
-fun ensure_consts thy consts funcgr =
+fun ensure_consts rewrites thy consts funcgr =
   let
     val algebra = CodegenData.coregular_algebra thy
-  in ensure_consts' thy algebra consts funcgr
+  in ensure_consts' rewrites thy algebra consts funcgr
     handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
     ^ commas (map (CodegenConsts.string_of_const thy) cs'))
   end;
 
 in
 
-
-(** graph retrieval **)
+(** retrieval interfaces **)
 
-fun make thy consts =
-  Funcgr.change thy (ensure_consts thy consts);
+val ensure_consts = ensure_consts;
 
-fun make_consts thy consts =
+fun check_consts rewrites thy consts funcgr =
   let
     val algebra = CodegenData.coregular_algebra thy;
     fun try_const const funcgr =
-      (SOME const, ensure_consts' thy algebra [const] funcgr)
+      (SOME const, ensure_consts' rewrites thy algebra [const] funcgr)
       handle INVALID (cs', msg) => (NONE, funcgr);
-    val (consts', funcgr) = Funcgr.change_yield thy (fold_map try_const consts);
-  in (map_filter I consts', funcgr) end;
+    val (consts', funcgr') = fold_map try_const consts funcgr;
+  in (map_filter I consts', funcgr') end;
 
-fun make_term thy f ct =
+fun ensure_consts_term rewrites thy f ct funcgr =
   let
+    fun rhs_conv conv thm =
+      let
+        val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm;
+      in Thm.transitive thm thm' end
     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
-    val thm1 = CodegenData.preprocess_cterm ct;
+    val thm1 = CodegenData.preprocess_cterm ct
+      |> fold (rhs_conv o MetaSimplifier.rewrite false o single) (rewrites thy);
     val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
     val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
-    val funcgr = make thy consts;
+    val funcgr' = ensure_consts rewrites thy consts funcgr;
     val algebra = CodegenData.coregular_algebra thy;
     val (_, thm2) = Thm.varifyT' [] thm1;
     val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
-    val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodegenConsts.norm_of_typ thy);
+    val typ_funcgr = try (fst o Constgraph.get_node funcgr' o CodegenConsts.norm_of_typ thy);
     val [thm4] = resort_thms algebra typ_funcgr [thm3];
     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
     fun inst thm =
@@ -426,48 +355,49 @@
     val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6);
     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
     val drop = drop_classes thy tfrees;
-    val instdefs = instances_of_consts thy algebra funcgr cs;
-    val funcgr' = ensure_consts thy instdefs funcgr;
-  in (f drop ct'' thm5, Funcgr.change thy (K funcgr')) end;
+    val instdefs = instances_of_consts thy algebra funcgr' cs;
+    val funcgr'' = ensure_consts rewrites thy instdefs funcgr';
+  in (f funcgr'' drop ct'' thm5, funcgr'') end;
 
 end; (*local*)
 
-
-(** diagnostics **)
+end; (*struct*)
 
-fun print_funcgr thy funcgr =
-  AList.make (snd o Constgraph.get_node funcgr) (Constgraph.keys funcgr)
-  |> (map o apfst) (CodegenConsts.string_of_const thy)
-  |> sort (string_ord o pairself fst)
-  |> map (fn (s, thms) =>
-       (Pretty.block o Pretty.fbreaks) (
-         Pretty.str s
-         :: map Display.pretty_thm thms
-       ))
-  |> Pretty.chunks
-  |> Pretty.writeln;
+functor CodegenFuncgrRetrieval (val name: string; val rewrites: theory -> thm list) : CODEGEN_FUNCGR_RETRIEVAL =
+struct
 
-fun print_codethms thy consts =
-  make thy consts |> print_funcgr thy;
+(** code data **)
+
+type T = CodegenFuncgr.T;
 
-fun print_codethms_e thy cs =
-  print_codethms thy (map (CodegenConsts.read_const thy) cs);
-
+structure Funcgr = CodeDataFun
+(struct
+  val name = name;
+  type T = T;
+  val empty = CodegenFuncgr.Constgraph.empty;
+  fun merge _ _ = CodegenFuncgr.Constgraph.empty;
+  fun purge _ NONE _ = CodegenFuncgr.Constgraph.empty
+    | purge _ (SOME cs) funcgr =
+        CodegenFuncgr.Constgraph.del_nodes ((CodegenFuncgr.Constgraph.all_preds funcgr 
+          o filter (can (CodegenFuncgr.Constgraph.get_node funcgr))) cs) funcgr;
+end);
 
-(** Isar setup **)
-
-structure P = OuterParse;
-
-val print_codethmsK = "print_codethms";
+fun make thy =
+  Funcgr.change thy o CodegenFuncgr.ensure_consts rewrites thy;
 
-val print_codethmsP =
-  OuterSyntax.improper_command print_codethmsK "print code theorems of this theory" OuterKeyword.diag
-    (Scan.option (P.$$$ "(" |-- Scan.repeat P.term --| P.$$$ ")")
-      >> (fn NONE => CodegenData.print_thms
-           | SOME cs => fn thy => print_codethms_e thy cs)
-      >> (fn f => Toplevel.no_timing o Toplevel.unknown_theory
-      o Toplevel.keep (f o Toplevel.theory_of)));
+fun make_consts thy =
+  Funcgr.change_yield thy o CodegenFuncgr.check_consts rewrites thy;
+
+fun make_term thy f =
+  Funcgr.change_yield thy o CodegenFuncgr.ensure_consts_term rewrites thy f;
 
-val _ = OuterSyntax.add_parsers [print_codethmsP];
+val init = Funcgr.init;
+
+end; (*functor*)
+
+structure CodegenFuncgr : CODEGEN_FUNCGR =
+struct
+
+open CodegenFuncgr;
 
 end; (*struct*)