src/Pure/codegen.ML
changeset 15261 ba3c9fdbace3
parent 15029 a4d0ed993050
child 15326 ff21cddee442
--- a/src/Pure/codegen.ML	Tue Oct 26 16:31:09 2004 +0200
+++ b/src/Pure/codegen.ML	Tue Oct 26 16:32:09 2004 +0200
@@ -23,6 +23,8 @@
   val add_codegen: string -> term codegen -> theory -> theory
   val add_tycodegen: string -> typ codegen -> theory -> theory
   val add_attribute: string -> (Args.T list -> theory attribute * Args.T list) -> theory -> theory
+  val add_preprocessor: (theory -> thm list -> thm list) -> theory -> theory
+  val preprocess: theory -> thm list -> thm list
   val print_codegens: theory -> unit
   val generate_code: theory -> (string * string) list -> string
   val generate_code_i: theory -> (string * term) list -> string
@@ -134,26 +136,28 @@
      consts : ((string * typ) * term mixfix list) list,
      types : (string * typ mixfix list) list,
      attrs: (string * (Args.T list -> theory attribute * Args.T list)) list,
+     preprocs: (stamp * (theory -> thm list -> thm list)) list,
      test_params: test_params};
 
   val empty =
     {codegens = [], tycodegens = [], consts = [], types = [], attrs = [],
-     test_params = default_test_params};
+     preprocs = [], test_params = default_test_params};
   val copy = I;
   val prep_ext = I;
 
   fun merge
     ({codegens = codegens1, tycodegens = tycodegens1,
       consts = consts1, types = types1, attrs = attrs1,
-      test_params = test_params1},
+      preprocs = preprocs1, test_params = test_params1},
      {codegens = codegens2, tycodegens = tycodegens2,
       consts = consts2, types = types2, attrs = attrs2,
-      test_params = test_params2}) =
-    {codegens = rev (merge_alists (rev codegens1) (rev codegens2)),
-     tycodegens = rev (merge_alists (rev tycodegens1) (rev tycodegens2)),
+      preprocs = preprocs2, test_params = test_params2}) =
+    {codegens = merge_alists' codegens1 codegens2,
+     tycodegens = merge_alists' tycodegens1 tycodegens2,
      consts = merge_alists consts1 consts2,
      types = merge_alists types1 types2,
      attrs = merge_alists attrs1 attrs2,
+     preprocs = merge_alists' preprocs1 preprocs2,
      test_params = merge_test_params test_params1 test_params2};
 
   fun print sg ({codegens, tycodegens, ...} : T) =
@@ -171,10 +175,10 @@
 fun get_test_params thy = #test_params (CodegenData.get thy);
 
 fun map_test_params f thy =
-  let val {codegens, tycodegens, consts, types, attrs, test_params} =
+  let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
     CodegenData.get thy;
   in CodegenData.put {codegens = codegens, tycodegens = tycodegens,
-    consts = consts, types = types, attrs = attrs,
+    consts = consts, types = types, attrs = attrs, preprocs = preprocs,
     test_params = f test_params} thy
   end;
 
@@ -182,22 +186,22 @@
 (**** add new code generators to theory ****)
 
 fun add_codegen name f thy =
-  let val {codegens, tycodegens, consts, types, attrs, test_params} =
+  let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
     CodegenData.get thy
   in (case assoc (codegens, name) of
       None => CodegenData.put {codegens = (name, f) :: codegens,
         tycodegens = tycodegens, consts = consts, types = types,
-        attrs = attrs, test_params = test_params} thy
+        attrs = attrs, preprocs = preprocs, test_params = test_params} thy
     | Some _ => error ("Code generator " ^ name ^ " already declared"))
   end;
 
 fun add_tycodegen name f thy =
-  let val {codegens, tycodegens, consts, types, attrs, test_params} =
+  let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
     CodegenData.get thy
   in (case assoc (tycodegens, name) of
       None => CodegenData.put {tycodegens = (name, f) :: tycodegens,
         codegens = codegens, consts = consts, types = types,
-        attrs = attrs, test_params = test_params} thy
+        attrs = attrs, preprocs = preprocs, test_params = test_params} thy
     | Some _ => error ("Code generator " ^ name ^ " already declared"))
   end;
 
@@ -205,12 +209,14 @@
 (**** code attribute ****)
 
 fun add_attribute name att thy =
-  let val {codegens, tycodegens, consts, types, attrs, test_params} =
+  let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
     CodegenData.get thy
   in (case assoc (attrs, name) of
       None => CodegenData.put {tycodegens = tycodegens,
         codegens = codegens, consts = consts, types = types,
-        attrs = (name, att) :: attrs, test_params = test_params} thy
+        attrs = if name = "" then attrs @ [(name, att)] else (name, att) :: attrs,
+        preprocs = preprocs,
+        test_params = test_params} thy
     | Some _ => error ("Code attribute " ^ name ^ " already declared"))
   end;
 
@@ -221,12 +227,41 @@
     (#attrs (CodegenData.get thy)), Scan.fail) >> pair thy));
 
 
+(**** preprocessors ****)
+
+fun add_preprocessor p thy =
+  let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
+    CodegenData.get thy
+  in CodegenData.put {tycodegens = tycodegens,
+    codegens = codegens, consts = consts, types = types,
+    attrs = attrs, preprocs = (stamp (), p) :: preprocs,
+    test_params = test_params} thy
+  end;
+
+fun preprocess thy ths =
+  let val {preprocs, ...} = CodegenData.get thy
+  in foldl (fn (ths, (_, f)) => f thy ths) (ths, preprocs) end;
+
+fun unfold_attr (thy, eqn) =
+  let
+    val (name, _) = dest_Const (head_of
+      (fst (Logic.dest_equals (prop_of eqn))));
+    fun prep thy = map (fn th =>
+      if name mem term_consts (prop_of th) then
+        let val sg = sign_of_thm eqn
+        in rewrite_rule [eqn] (if Sign.subsig (sign_of_thm th, sg) then
+          Thm.transfer_sg sg th else th)
+        end
+      else th)
+  in (add_preprocessor prep thy, eqn) end;
+
+
 (**** associate constants with target language code ****)
 
 fun gen_assoc_consts prep_type xs thy = foldl (fn (thy, (s, tyopt, syn)) =>
   let
     val sg = sign_of thy;
-    val {codegens, tycodegens, consts, types, attrs, test_params} =
+    val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
       CodegenData.get thy;
     val cname = Sign.intern_const sg s;
   in
@@ -243,7 +278,8 @@
              None => CodegenData.put {codegens = codegens,
                tycodegens = tycodegens,
                consts = ((cname, T'), syn) :: consts,
-               types = types, attrs = attrs, test_params = test_params} thy
+               types = types, attrs = attrs, preprocs = preprocs,
+               test_params = test_params} thy
            | Some _ => error ("Constant " ^ cname ^ " already associated with code"))
          end
      | _ => error ("Not a constant: " ^ s))
@@ -256,7 +292,7 @@
 
 fun assoc_types xs thy = foldl (fn (thy, (s, syn)) =>
   let
-    val {codegens, tycodegens, consts, types, attrs, test_params} =
+    val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
       CodegenData.get thy;
     val tc = Sign.intern_tycon (sign_of thy) s
   in
@@ -264,7 +300,7 @@
        None => CodegenData.put {codegens = codegens,
          tycodegens = tycodegens, consts = consts,
          types = (tc, syn) :: types, attrs = attrs,
-         test_params = test_params} thy
+         preprocs = preprocs, test_params = test_params} thy
      | Some _ => error ("Type " ^ tc ^ " already associated with code"))
   end) (thy, xs);
 
@@ -343,17 +379,28 @@
   let
     val axms = flat (map (Symtab.dest o #axioms o Theory.rep_theory)
       (thy :: Theory.ancestors_of thy));
-    val defs = mapfilter (fn (_, t) =>
-      (let
-         val (lhs, rhs) = Logic.dest_equals t;
-         val (c, args) = strip_comb lhs;
-         val (s', T') = dest_Const c
-       in if s=s' then Some (T', split_last (rename_terms (args @ [rhs])))
-         else None end) handle TERM _ => None) axms;
-    val i = find_index (is_instance thy T o fst) defs
+    fun prep_def def = (case preprocess thy [def] of
+      [def'] => prop_of def' | _ => error "get_defn: bad preprocessor");
+    fun dest t =
+      let
+        val (lhs, rhs) = Logic.dest_equals t;
+        val (c, args) = strip_comb lhs;
+        val (s', T') = dest_Const c
+      in if s = s' then Some (T', (args, rhs)) else None
+      end handle TERM _ => None;
+    val defs = mapfilter (fn (name, t) => apsome (pair name) (dest t)) axms;
+    val i = find_index (is_instance thy T o fst o snd) defs
   in
-    if i>=0 then Some (snd (nth_elem (i, defs)),
-      if length defs = 1 then None else Some i)
+    if i >= 0 then
+      let val (name, (T', (args, _))) = nth_elem (i, defs)
+      in case dest (prep_def (Thm.get_axiom thy name)) of
+          None => None
+        | Some (T'', p as (args', rhs)) =>
+            if T' = T'' andalso args = args' then
+              Some (split_last (rename_terms (args @ [rhs])),
+                if length defs = 1 then None else Some i)
+            else None
+      end
     else None
   end;
 
@@ -724,7 +771,8 @@
    assoc_types [("fun", parse_mixfix (K dummyT) "(_ ->/ _)")],
    Attrib.add_attributes [("code",
      (code_attr, K Attrib.undef_local_attribute),
-     "declare theorems for code generation")]];
+     "declare theorems for code generation")],
+   add_attribute "unfold" (Scan.succeed unfold_attr)];
 
 end;