src/HOL/Tools/recfun_codegen.ML
author wenzelm
Tue Apr 19 23:57:28 2011 +0200 (2011-04-19)
changeset 42411 ff997038e8eb
parent 41448 72ba43b47c7f
permissions -rw-r--r--
eliminated Codegen.mode in favour of explicit argument;
haftmann@24584
     1
(*  Title:      HOL/Tools/recfun_codegen.ML
berghofe@12447
     2
    Author:     Stefan Berghofer, TU Muenchen
berghofe@12447
     3
berghofe@12447
     4
Code generator for recursive functions.
berghofe@12447
     5
*)
berghofe@12447
     6
berghofe@12447
     7
(* Interface of the code generator for recursive functions: the only export
   is setup, which installs the generator (and its code-target attribute
   handler) into a theory. *)
signature RECFUN_CODEGEN =
  sig val setup : theory -> theory end;
berghofe@12447
    11
berghofe@12447
    12
(* Implementation of RECFUN_CODEGEN.  Generates ML function definitions
   (fun ... | ... and ...) from code equations and stores the generated text
   as nodes of the Codegen dependency graph. *)
structure RecfunCodegen : RECFUN_CODEGEN =
berghofe@12447
    13
struct
berghofe@12447
    14
haftmann@32358
    15
(* Name and type of the head constant of the left-hand side of a meta-level
   equation lhs == rhs (via Logic.dest_equals). *)
val const_of = dest_Const o head_of o fst o Logic.dest_equals;
haftmann@32358
    16
wenzelm@33522
    17
(* Theory data: maps a constant name to the name of the target module chosen
   for it by add_thm_target.  Theory extension keeps the table unchanged.
   Merging uses Symtab.merge with (K true), so values under a shared key are
   always considered equal and no duplicate-key conflict should arise.
   NOTE(review): presumably the first table's entry wins on a shared key;
   confirm against Table.merge. *)
structure ModuleData = Theory_Data
wenzelm@22846
    18
(
haftmann@28522
    19
  type T = string Symtab.table;
berghofe@12447
    20
  val empty = Symtab.empty;
wenzelm@16424
    21
  val extend = I;
wenzelm@33522
    22
  fun merge data = Symtab.merge (K true) data;
wenzelm@22846
    23
);
berghofe@12447
    24
haftmann@31998
    25
(* add_thm_target module_name thm thy: normalises thm into a code equation
   with Code.mk_eqn.  It then records module_name in ModuleData as the target
   module of the constant that equation defines (the key is the constant name
   from Code.const_typ_eqn).  Installed via Code.set_code_target_attr in
   setup. *)
fun add_thm_target module_name thm thy =
haftmann@31998
    26
  let
haftmann@31998
    27
    val (thm', _) = Code.mk_eqn thy (thm, true)
haftmann@31998
    28
  in
haftmann@31998
    29
    thy
haftmann@31998
    30
    |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
haftmann@31998
    31
  end;
haftmann@24624
    32
haftmann@32358
    33
(* Handles a singleton equation list [thm] whose constant type both contains
   schematic type variables and has at least one argument (binder) type.
   Such an equation is eta-expanded by one argument (Code.expand_eta thy 1).
   Every other equation list is returned unchanged.
   NOTE(review): presumably this avoids printing a polymorphic
   zero-argument val binding (see the val/fun choice in add_rec_funs);
   confirm. *)
fun avoid_value thy [thm] =
haftmann@32358
    34
      let val (_, T) = Code.const_typ_eqn thy thm
wenzelm@40844
    35
      in
wenzelm@40844
    36
        if null (Term.add_tvarsT T []) orelse null (binder_types T)
haftmann@32358
    37
        then [thm]
haftmann@34895
    38
        else [Code.expand_eta thy 1 thm]
haftmann@32358
    39
      end
haftmann@32358
    40
  | avoid_value thy thms = thms;
haftmann@28522
    41
haftmann@38864
    42
(* get_equations thy defs (raw_c, T): the code equations for constant raw_c
   at type T, paired with the name of the module they are generated into.
   - HOL.eq is skipped and yields the empty list with an empty module name.
   - The equations come from the code certificate of the unoverloaded
     constant, built with the preprocessor's function transformers.  They
     are then re-overloaded and filtered to those whose head constant type
     is an instance of T.
   - The module is taken from the first available source: the ModuleData
     entry for the constant, then the theory name of a matching Codegen
     definition (Codegen.get_defn), then Codegen.thyname_of_const.
   - If no equation survives the filter, the result is the empty list with
     an empty module name.  Otherwise the equations are passed through
     Codegen.preprocess and then avoid_value. *)
fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name HOL.eq} then ([], "") else
haftmann@32358
    43
  let
haftmann@32358
    44
    val c = AxClass.unoverload_const thy (raw_c, T);
haftmann@34893
    45
    val raw_thms = Code.get_cert thy (Code_Preproc.preprocess_functrans thy) c
haftmann@35225
    46
      |> Code.bare_thms_of_cert thy
haftmann@34891
    47
      |> map (AxClass.overload thy)
wenzelm@41448
    48
      |> filter (Codegen.is_instance T o snd o const_of o prop_of);
haftmann@32358
    49
    val module_name = case Symtab.lookup (ModuleData.get thy) c
haftmann@32358
    50
     of SOME module_name => module_name
wenzelm@41448
    51
      | NONE =>
wenzelm@41448
    52
        case Codegen.get_defn thy defs c T
haftmann@32358
    53
         of SOME ((_, (thyname, _)), _) => thyname
haftmann@32358
    54
          | NONE => Codegen.thyname_of_const thy c;
haftmann@32358
    55
  in if null raw_thms then ([], "") else
haftmann@32358
    56
    raw_thms
wenzelm@41448
    57
    |> Codegen.preprocess thy
haftmann@32358
    58
    |> avoid_value thy
haftmann@32358
    59
    |> rpair module_name
haftmann@32358
    60
  end;
berghofe@12447
    61
wenzelm@41448
    62
(* Node-name suffix for constant s at type T.  When Codegen.get_defn reports
   a definition with index SOME i, the suffix is a blank, the word def, and
   the number i.  Otherwise it is the empty string. *)
fun mk_suffix thy defs (s, T) =
wenzelm@41448
    63
  (case Codegen.get_defn thy defs s T of
wenzelm@41448
    64
    SOME (_, SOME i) => " def" ^ string_of_int i
wenzelm@41448
    65
  | _ => "");
berghofe@12447
    66
berghofe@12447
    67
(* Payload of the dependency-graph nodes created in add_rec_funs, written
   EQN (s, T, name).  There are two shapes:
   - a node whose function is still being generated carries its constant
     (s, T) and an empty name;
   - once code is stored (put_code), or when the node belongs to a merged
     mutually recursive group, it carries an empty s, dummyT, and the node
     name of the representative (dname in add_rec_funs). *)
exception EQN of string * typ * string;
berghofe@12447
    68
wenzelm@33244
    69
(* cycle g x xs: returns xs unchanged if x is already in xs.  Otherwise it
   adds x, then recursively adds every node on some path from x back to x in
   the graph fst g.  Already-collected nodes are skipped.  If x is new and
   lies on no such path, the result is x :: xs. *)
fun cycle g x xs =
haftmann@22887
    70
  if member (op =) xs x then xs
wenzelm@33244
    71
  else fold (cycle g) (flat (Graph.all_paths (fst g) (x, x))) (x :: xs);
berghofe@12447
    72
wenzelm@42411
    73
(* add_rec_funs thy mode defs dep module eqs gr: makes sure ML code for the
   code equations eqs exists in the code-generator graph gr.  The equations
   are all for one constant, as passed by recfun_codegen, and are assumed
   non-empty (hd eqs is taken below).  It also records that node dep depends
   on that code.  Returns the module holding the code together with the
   updated graph.
   If the new node turns out to lie on a cycle with other nodes, the
   equations of all constants on that cycle are regenerated as a single
   fun ... and ... group, stored under the representative node dname. *)
fun add_rec_funs thy mode defs dep module eqs gr =
berghofe@12447
    74
  let
berghofe@16645
    75
    (* Pairs an equation's node name (constant name plus mk_suffix) with its
       (lhs, rhs) after Codegen.rename_term. *)
    fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
wenzelm@41448
    76
      Logic.dest_equals (Codegen.rename_term t));
berghofe@12447
    77
    val eqs' = map dest_eq eqs;
berghofe@12447
    78
    (* dname: node name of the first equation; (s, T): its head constant. *)
    val (dname, _) :: _ = eqs';
berghofe@12447
    79
    val (s, T) = const_of (hd eqs);
berghofe@12447
    80
haftmann@28535
    81
    (* Pretty-prints the clauses, threading the graph through the code
       generation of each lhs and rhs and of the type T.  Each clause is
       prefixed with a keyword:
         |    when it has the same node name as the previous clause,
         and  for a new function after the first clause,
         val  for the first clause when its lhs has no arguments,
         fun  otherwise.
       ty is generated for every clause, and its graph effects are kept, but
       it is printed only as a type annotation on zero-argument clauses.
       NOTE(review): ty always comes from T, the type of the first equation
       of eqs.  In a regenerated mutually recursive group (eqs'' below) it is
       therefore also attached to zero-argument clauses of other constants;
       confirm this is intended. *)
    fun mk_fundef module fname first [] gr = ([], gr)
haftmann@28535
    82
      | mk_fundef module fname first ((fname' : string, (lhs, rhs)) :: xs) gr =
berghofe@12447
    83
      let
wenzelm@42411
    84
        val (pl, gr1) = Codegen.invoke_codegen thy mode defs dname module false lhs gr;
wenzelm@42411
    85
        val (pr, gr2) = Codegen.invoke_codegen thy mode defs dname module false rhs gr1;
haftmann@28535
    86
        val (rest, gr3) = mk_fundef module fname' false xs gr2 ;
wenzelm@42411
    87
        val (ty, gr4) = Codegen.invoke_tycodegen thy mode defs dname module false T gr3;
haftmann@28535
    88
        val num_args = (length o snd o strip_comb) lhs;
haftmann@28535
    89
        val prfx = if fname = fname' then "  |"
haftmann@28535
    90
          else if not first then "and"
haftmann@28535
    91
          else if num_args = 0 then "val"
haftmann@28535
    92
          else "fun";
wenzelm@41448
    93
        val pl' = Pretty.breaks (Codegen.str prfx
wenzelm@41448
    94
          :: (if num_args = 0 then [pl, Codegen.str ":", ty] else [pl]));
berghofe@12447
    95
      in
haftmann@28535
    96
        (Pretty.blk (4, pl'
wenzelm@41448
    97
           @ [Codegen.str " =", Pretty.brk 1, pr]) :: rest, gr4)
berghofe@12447
    98
      end;
berghofe@12447
    99
wenzelm@41448
   100
    (* Sets node dname to EQN (empty name, dummyT, dname) in the given
       module.  The code text is the clauses separated by forced line breaks,
       followed by a semicolon and two newlines. *)
    fun put_code module fundef = Codegen.map_node dname
wenzelm@41448
   101
      (K (SOME (EQN ("", dummyT, dname)), module, Codegen.string_of (Pretty.blk (0,
wenzelm@41448
   102
      separate Pretty.fbrk fundef @ [Codegen.str ";"])) ^ "\n\n"));
berghofe@12447
   103
berghofe@12447
   104
  in
wenzelm@41448
   105
    (* NONE when Codegen.get_node raises, presumably meaning dname is not yet
       a node of gr. *)
    (case try (Codegen.get_node gr) dname of
skalberg@15531
   106
       NONE =>
berghofe@12447
   107
         let
wenzelm@41448
   108
           val gr1 = Codegen.add_edge (dname, dep)
wenzelm@41448
   109
             (Codegen.new_node (dname, (SOME (EQN (s, T, "")), module, "")) gr);
haftmann@28535
   110
           val (fundef, gr2) = mk_fundef module "" true eqs' gr1 ;
wenzelm@33244
   111
           (* dname plus all nodes on cycles through it after its body has
              been generated. *)
           val xs = cycle gr2 dname [];
wenzelm@41448
   112
           (* Constants of those nodes.  Any node whose payload is not an EQN
              (i.e. one not created by this generator) aborts with an error.
              map is strict, so this check also runs when xs is a
              singleton. *)
           val cs = map (fn x =>
wenzelm@41448
   113
             case Codegen.get_node gr2 x of
berghofe@16645
   114
               (SOME (EQN (s, T, _)), _, _) => (s, T)
berghofe@12447
   115
             | _ => error ("RecfunCodegen: illegal cyclic dependencies:\n" ^
berghofe@12447
   116
                implode (separate ", " xs))) xs
wenzelm@41448
   117
         in
wenzelm@41448
   118
           (* There are three cases:
              - Singleton: no cycle with other nodes, so store the code.
              - dep on the cycle: presumably an enclosing call for dep is
                still in progress and will do the merge, so return gr2 as is.
              - Otherwise: regenerate the equations of every constant on the
                cycle as one group.  The target module comes from
                Codegen.if_library, applied to the module name get_equations
                returned for the first constant.  Every cycle node is replaced
                by one pointing at dname. *)
           (case xs of
haftmann@28535
   119
             [_] => (module, put_code module fundef gr2)
berghofe@12447
   120
           | _ =>
haftmann@36692
   121
             if not (member (op =) xs dep) then
berghofe@12447
   122
               let
berghofe@16645
   123
                 val thmss as (_, thyname) :: _ = map (get_equations thy defs) cs;
wenzelm@42411
   124
                 val module' = Codegen.if_library mode thyname module;
wenzelm@32952
   125
                 val eqs'' = map (dest_eq o prop_of) (maps fst thmss);
haftmann@28535
   126
                 val (fundef', gr3) = mk_fundef module' "" true eqs''
wenzelm@41448
   127
                   (Codegen.add_edge (dname, dep)
wenzelm@41448
   128
                     (List.foldr (uncurry Codegen.new_node) (Codegen.del_nodes xs gr2)
skalberg@15574
   129
                       (map (fn k =>
haftmann@28535
   130
                         (k, (SOME (EQN ("", dummyT, dname)), module', ""))) xs)))
haftmann@28535
   131
               in (module', put_code module' fundef' gr3) end
haftmann@28535
   132
             else (module, gr2))
berghofe@12447
   133
         end
berghofe@17144
   134
     (* Node already present: reuse its module and add a dependency edge to
        dep, skipping self-edges.  The edge starts at dname when the third
        EQN field is empty (the function is still being generated), and
        otherwise at the representative node named there.
        NOTE(review): only EQN payloads match here; any other existing node
        raises Match. *)
     | SOME (SOME (EQN (_, _, s)), module', _) =>
haftmann@28535
   135
         (module', if s = "" then
wenzelm@41448
   136
            if dname = dep then gr else Codegen.add_edge (dname, dep) gr
wenzelm@41448
   137
          else if s = dep then gr else Codegen.add_edge (s, dep) gr))
berghofe@12447
   138
  end;
berghofe@12447
   139
wenzelm@42411
   140
(* Code generator registered under the name recfun.  It applies to a term t
   whose head is a constant p = (s, T) applied to arguments ts, provided p
   has code equations (get_equations) and no associated code
   (Codegen.get_assoc_code).  In that case it:
   - generates the arguments (with brack set to true),
   - makes sure the definition of p exists via add_rec_funs,
   - returns the application of p's identifier, qualified relative to
     module, to the arguments, built with Codegen.mk_app brack.
   Any other term yields NONE.
   NOTE(review): module' (from Codegen.if_library) only decides where the
   definition is placed; arguments and qualification use module. *)
fun recfun_codegen thy mode defs dep module brack t gr =
wenzelm@41448
   141
  (case strip_comb t of
wenzelm@41448
   142
    (Const (p as (s, T)), ts) =>
wenzelm@41448
   143
     (case (get_equations thy defs p, Codegen.get_assoc_code thy (s, T)) of
berghofe@16645
   144
       (([], _), _) => NONE
skalberg@15531
   145
     | (_, SOME _) => NONE
berghofe@17144
   146
     | ((eqns, thyname), NONE) =>
berghofe@16645
   147
        let
wenzelm@42411
   148
          val module' = Codegen.if_library mode thyname module;
haftmann@28535
   149
          val (ps, gr') = fold_map
wenzelm@42411
   150
            (Codegen.invoke_codegen thy mode defs dep module true) ts gr;
berghofe@17144
   151
          val suffix = mk_suffix thy defs p;
haftmann@28535
   152
          val (module'', gr'') =
wenzelm@42411
   153
            add_rec_funs thy mode defs dep module' (map prop_of eqns) gr';
wenzelm@41448
   154
          val (fname, gr''') = Codegen.mk_const_id module'' (s ^ suffix) gr''
berghofe@12447
   155
        in
wenzelm@41448
   156
          SOME (Codegen.mk_app brack (Codegen.str (Codegen.mk_qual_id module fname)) ps, gr''')
berghofe@12447
   157
        end)
skalberg@15531
   158
  | _ => NONE);
berghofe@12447
   159
haftmann@31998
   160
(* Registers recfun_codegen with Codegen under the name recfun.  It also
   installs add_thm_target through Code.set_code_target_attr. *)
val setup = 
wenzelm@41448
   161
  Codegen.add_codegen "recfun" recfun_codegen
haftmann@31998
   162
  #> Code.set_code_target_attr add_thm_target;
berghofe@12447
   163
berghofe@12447
   164
end;