src/HOL/Tools/recfun_codegen.ML
author haftmann
Fri Dec 07 15:07:56 2007 +0100 (2007-12-07)
changeset 25570 fdfbbb92dadf
parent 25389 3e58c7cb5a73
child 25894 0ee6e01c5572
permissions -rw-r--r--
proper treatment of code theorems for primrec
     1 (*  Title:      HOL/Tools/recfun_codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Code generator for recursive functions.
     6 *)
     7 
     (* Public interface of the code generator for recursive functions. *)
     8 signature RECFUN_CODEGEN =
     9 sig
       (* attribute storing the theorem as a code equation, tagged with the
          given optional module name; also passes it to Code.add_liberal_func *)
    10   val add: string option -> attribute
       (* like add with no module name, but passes the theorem to
          Code.add_default_func instead *)
    11   val add_default: attribute
       (* attribute removing the theorem from the stored equations and passing
          it to Code.del_func *)
    12   val del: attribute
       (* installs the "recfun" code generator and the attribute syntax *)
    13   val setup: theory -> theory
    14 end;
    15 
    16 structure RecfunCodegen : RECFUN_CODEGEN =
    17 struct
    18 
    19 open Codegen;
    20 
    (* Theory data: for each constant name, the list of registered code
       equations, each paired with an optional module name. *)
    structure RecCodegenData = TheoryDataFun
    (
      type T = (thm * string option) list Symtab.table;
      val empty = Symtab.empty;
      fun copy tab = tab;
      fun extend tab = tab;
      (* entries count as duplicates when their theorems agree under
         Thm.eq_thm_prop; the module tags are not compared *)
      fun merge _ tabs =
        Symtab.merge_list (fn ((th1, _), (th2, _)) => Thm.eq_thm_prop (th1, th2)) tabs;
    );
    29 
    (* Take an equation term apart: HOLogic.dest_Trueprop first, then
       HOLogic.dest_eq, yielding the pair (lhs, rhs). *)
    fun dest_eqn t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
    (* Left-hand side of the equation stated by a theorem. *)
    fun lhs_of thm = fst (dest_eqn (prop_of thm));
    (* Head constant (name, type) of the left-hand side of an equation term,
       as obtained by dest_Const on the head of the left-hand side. *)
    fun const_of t = dest_Const (head_of (fst (dest_eqn t)));
    33 
    (* Emit a warning that thm is not usable as a code equation. *)
    fun warn thm =
      let val msg = "RecfunCodegen: Not a proper equation:\n" ^ string_of_thm thm
      in warning msg end;
    36 
    (* Record thm as a code equation for the head constant of its left-hand
       side, tagged with opt_module.  The table is updated only when the
       left-hand side is a pattern (Pattern.pattern); otherwise, or when taking
       the theorem apart raises TERM, the resulting theory transformer just
       warns and returns the theory unchanged. *)
    fun add_thm opt_module thm =
      let
        (* identity on the theory that complains about thm when applied *)
        fun complain thy = (warn thm; thy);
        (* evaluated eagerly, so that a TERM raised by const_of is still
           caught by the handler below *)
        fun register () =
          let val (c, _) = const_of (prop_of thm)
          in RecCodegenData.map (Symtab.cons_list (c, (thm, opt_module))) end;
      in
        (if not (Pattern.pattern (lhs_of thm)) then complain else register ())
        handle TERM _ => complain
      end;
    43 
    (* Attribute: store the theorem via add_thm under the given optional module
       name, then hand it to Code.add_liberal_func.  Only the theory part of a
       generic context is changed; the proof part is mapped by the identity. *)
    fun add opt_module =
      let
        fun declare thm =
          let val update = add_thm opt_module thm #> Code.add_liberal_func thm
          in Context.mapping update I end;
      in Thm.declaration_attribute declare end;
    46 
    (* Attribute: store the theorem via add_thm without a module name, then
       hand it to Code.add_default_func.  Only the theory part of a generic
       context is changed. *)
    val add_default =
      let
        fun declare thm =
          let val update = add_thm NONE thm #> Code.add_default_func thm
          in Context.mapping update I end;
      in Thm.declaration_attribute declare end;
    49 
    (* Remove thm from the equations stored for its head constant.  If no head
       constant can be extracted, the theory is left unchanged and a warning is
       issued when the transformer is applied. *)
    fun del_thm thm =
      let
        (* stored entries are (theorem, module) pairs; compare theorems only *)
        fun same_thm (th, (th', _)) = Thm.eq_thm (th, th');
        fun drop_from c =
          RecCodegenData.map (Symtab.map_entry c (remove same_thm thm));
      in
        (case try const_of (prop_of thm) of
          NONE => (fn thy => (warn thm; thy))
        | SOME (c, _) => drop_from c)
      end;
    54 
    (* Attribute: delete the theorem via del_thm, then hand it to
       Code.del_func.  Only the theory part of a generic context is changed. *)
    val del =
      let
        fun declare thm =
          let val update = del_thm thm #> Code.del_func thm
          in Context.mapping update I end;
      in Thm.declaration_attribute declare end;
    57 
    (* Walk through the pending equations front to back.  Each equation is
       kept, and every later equation whose left-hand side is matched by the
       kept one's left-hand side (Pattern.matches) is dropped.  Kept equations
       are consed onto the accumulator, so the result is in reverse order of
       traversal. *)
    fun del_redundant thy kept pending =
      (case pending of
        [] => kept
      | eq :: rest =>
          let
            fun subsumed other =
              Pattern.matches thy (lhs_of (fst eq), lhs_of (fst other));
          in del_redundant thy (eq :: kept) (filter_out subsumed rest) end);
    64 
    (* Look up the code equations stored for constant s whose head constant
       type passes is_instance against T, with redundant ones removed.
       Returns the preprocessed theorems together with a module name, or
       ([], "") if nothing applies.  The module name is the tag of the last
       remaining entry; if that entry is untagged, it comes from get_defn, and
       failing that from thyname_of_const. *)
    fun get_equations thy defs (s, T) =
      let
        fun applicable (thm, _) = is_instance T (snd (const_of (prop_of thm)));
        fun module_of (SOME thyname) = thyname
          | module_of NONE =
              (case get_defn thy defs s T of
                SOME ((_, (thyname, _)), _) => thyname
              | NONE => thyname_of_const s thy);
      in
        (case Symtab.lookup (RecCodegenData.get thy) s of
          NONE => ([], "")
        | SOME thms =>
            (case del_redundant thy [] (filter applicable thms) of
              [] => ([], "")
            | thms' =>
                (preprocess thy (map fst thms'),
                 module_of (snd (snd (split_last thms'))))))
      end;
    80 
    (* Name suffix for constant (s, T): " def<i>" when get_defn reports an
       index i for it, the empty string otherwise. *)
    fun mk_suffix thy defs (s, T) =
      let val defn = get_defn thy defs s T
      in
        (case defn of
          SOME (_, SOME i) => " def" ^ string_of_int i
        | _ => "")
      end;
    83 
    (* Tag stored in the code graph nodes created by add_rec_funs.  As used
       there: EQN (c, T, "") marks the node of a function whose head constant
       is (c, T); EQN ("", dummyT, n) marks a node whose third component n
       names the node that dependency edges are attached to instead. *)
    84 exception EQN of string * typ * string;
    85 
    (* Add to the accumulator the node x and, recursively, every node on a
       path from x back to itself in the graph component of g
       (Graph.all_paths).  Nodes already in the accumulator are not visited
       again. *)
    fun cycle g (seen, x : string) =
      (case member (op =) seen x of
        true => seen
      | false =>
          let val on_cycles = flat (Graph.all_paths (fst g) (x, x))
          in Library.foldl (cycle g) (x :: seen, on_cycles) end);
    89 
    (* Make sure the code graph gr contains a node for the function defined by
       the equation terms eqs (which must be non-empty: the first equation
       supplies the node name dname and the head constant), and link that node
       to dep.  Returns the updated graph and the module the code lives in.
       If the function lies on a dependency cycle with other functions, the
       equations of all of them are emitted as one definition under dname. *)
    90 fun add_rec_funs thy defs gr dep eqs module =
    91   let
           (* pair each equation with its function name (constant name plus
              mk_suffix) and its (lhs, rhs) after rename_term *)
    92     fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
    93       dest_eqn (rename_term t));
    94     val eqs' = map dest_eq eqs;
           (* non-exhaustive binding: fails on an empty equation list *)
    95     val (dname, _) :: _ = eqs';
    96     val (s, T) = const_of (hd eqs);
    97 
           (* Generate one pretty-printed clause per equation, threading the
              graph through the code generation of lhs and rhs.  A clause starts
              with "  | " when it continues the function of the previous
              clause, otherwise with prfx ("fun " for the first function,
              "and " for later ones). *)
    98     fun mk_fundef module fname prfx gr [] = (gr, [])
    99       | mk_fundef module fname prfx gr ((fname' : string, (lhs, rhs)) :: xs) =
   100       let
   101         val (gr1, pl) = invoke_codegen thy defs dname module false (gr, lhs);
   102         val (gr2, pr) = invoke_codegen thy defs dname module false (gr1, rhs);
   103         val (gr3, rest) = mk_fundef module fname' "and " gr2 xs
   104       in
   105         (gr3, Pretty.blk (4, [Pretty.str (if fname = fname' then "  | " else prfx),
   106            pl, Pretty.str " =", Pretty.brk 1, pr]) :: rest)
   107       end;
   108 
           (* overwrite node dname with the rendered definition: the clauses
              separated by forced breaks, terminated by ";" and a blank line *)
   109     fun put_code module fundef = map_node dname
   110       (K (SOME (EQN ("", dummyT, dname)), module, Pretty.string_of (Pretty.blk (0,
   111       separate Pretty.fbrk fundef @ [Pretty.str ";"])) ^ "\n\n"));
   112 
   113   in
   114     (case try (get_node gr) dname of
   115        NONE =>
   116          let
                 (* create the node first (with empty code) so that recursive
                    references met during code generation find it *)
   117            val gr1 = add_edge (dname, dep)
   118              (new_node (dname, (SOME (EQN (s, T, "")), module, "")) gr);
   119            val (gr2, fundef) = mk_fundef module "" "fun " gr1 eqs';
                 (* all nodes on dependency cycles through dname *)
   120            val xs = cycle gr2 ([], dname);
                 (* their head constants; every node involved must carry an EQN tag *)
   121            val cs = map (fn x => case get_node gr2 x of
   122                (SOME (EQN (s, T, _)), _, _) => (s, T)
   123              | _ => error ("RecfunCodegen: illegal cyclic dependencies:\n" ^
   124                 implode (separate ", " xs))) xs
   125          in (case xs of
                   (* no other function involved: emit the definition as is *)
   126              [_] => (put_code module fundef gr2, module)
   127            | _ =>
   128              if not (dep mem xs) then
                     (* dname is the entry point of the cycle: collect the
                        equations of all functions involved, replace their
                        nodes by EQN ("", dummyT, dname) nodes with empty code,
                        and regenerate everything as one definition, in the
                        module of the first function's equations *)
   129                let
   130                  val thmss as (_, thyname) :: _ = map (get_equations thy defs) cs;
   131                  val module' = if_library thyname module;
   132                  val eqs'' = map (dest_eq o prop_of) (List.concat (map fst thmss));
   133                  val (gr3, fundef') = mk_fundef module' "" "fun "
   134                    (add_edge (dname, dep)
   135                      (foldr (uncurry new_node) (del_nodes xs gr2)
   136                        (map (fn k =>
   137                          (k, (SOME (EQN ("", dummyT, dname)), module', ""))) xs))) eqs''
   138                in (put_code module' fundef' gr3, module') end
                   (* dep itself is on the cycle: no code is stored here;
                      presumably the call for the cycle's entry point emits the
                      joint definition -- TODO confirm *)
   139              else (gr2, module))
   140          end
              (* node already present: only record the dependency, on the node
                 named by the tag's third component if that is non-empty, and
                 never as a self-edge.  NOTE(review): a node without an EQN tag
                 is not covered by this case expression. *)
   141      | SOME (SOME (EQN (_, _, s)), module', _) =>
   142          (if s = "" then
   143             if dname = dep then gr else add_edge (dname, dep) gr
   144           else if s = dep then gr else add_edge (s, dep) gr,
   145           module'))
   146   end;
   147 
   (* Code generator for applications of constants that have stored code
      equations.  Declines (NONE) unless t is a constant applied to arguments,
      equations are available for it, and get_assoc_code finds nothing for it.
      Otherwise generates code for the arguments, makes sure the function's
      definition is in the code graph, and returns the graph together with the
      pretty-printed application. *)
   fun recfun_codegen thy defs gr dep module brack t =
     let
       fun generate (c as (s, _)) args eqns thyname =
         let
           val module' = if_library thyname module;
           val (gr1, arg_prts) =
             foldl_map (invoke_codegen thy defs dep module true) (gr, args);
           val suffix = mk_suffix thy defs c;
           val (gr2, module'') =
             add_rec_funs thy defs gr1 dep (map prop_of eqns) module';
           val (gr3, fname) = mk_const_id module'' (s ^ suffix) gr2;
         in
           (* the identifier is qualified relative to the calling module *)
           SOME (gr3, mk_app brack (Pretty.str (mk_qual_id module fname)) arg_prts)
         end;
     in
       (case strip_comb t of
         (Const c, args) =>
           let
             val (eqns, thyname) = get_equations thy defs c;
             val assoc = get_assoc_code thy c;
           in
             (case (eqns, assoc) of
               ([], _) => NONE
             | (_, SOME _) => NONE
             | (_, NONE) => generate c args eqns thyname)
           end
       | _ => NONE)
     end;
   165 
   166 
   (* Theory setup: register recfun_codegen under the name "recfun" and
      install the attribute syntax, which is either "del" or an optional
      "target: <name>" argument selecting the module passed to add. *)
   val setup =
     let
       val del_parser = Args.del |-- Scan.succeed del;
       val add_parser =
         Scan.option (Args.$$$ "target" |-- Args.colon |-- Args.name) >> add;
     in
       add_codegen "recfun" recfun_codegen
       #> Code.add_attribute ("", del_parser || add_parser)
     end;
   171 
   172 end;