src/HOL/Tools/basic_codegen.ML
changeset 11537 e007d35359c3
child 11539 0f17da240450
equal deleted inserted replaced
11536:6adf4d532679 11537:e007d35359c3
       
     1 (*  Title:      HOL/Tools/basic_codegen.ML
       
     2     ID:         $Id$
       
     3     Author:     Stefan Berghofer
       
     4     Copyright   2000  TU Muenchen
       
     5 
       
     6 Code generator for inductive datatypes and recursive functions
       
     7 *)
       
     8 
       
     9 signature BASIC_CODEGEN =
       
    10 sig
       
    11   val setup: (theory -> theory) list
       
    12 end;
       
    13 
       
    14 structure BasicCodegen : BASIC_CODEGEN =
       
    15 struct
       
    16 
       
    17 open Codegen;
       
    18 
       
    19 fun mk_poly_id thy (s, T) = mk_const_id (sign_of thy) s ^
       
    20   (case get_defn thy s T of
       
    21      Some (_, Some i) => "_def" ^ string_of_int i
       
    22    | _ => "");
       
    23 
       
    24 fun mk_tuple [p] = p
       
    25   | mk_tuple ps = Pretty.block (Pretty.str "(" ::
       
    26       flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
       
    27         [Pretty.str ")"]);
       
    28 
       
    29 fun add_rec_funs thy dep (gr, eqs) =
       
    30   let
       
    31     fun dest_eq t =
       
    32       let val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop
       
    33             (Logic.strip_imp_concl (rename_term t)))
       
    34       in
       
    35         (mk_poly_id thy (dest_Const (head_of lhs)), (lhs, rhs))
       
    36       end;
       
    37     val eqs' = sort (string_ord o pairself fst) (map dest_eq eqs);
       
    38     val (dname, _) :: _ = eqs';
       
    39 
       
    40     fun mk_fundef fname prfx gr [] = (gr, [])
       
    41       | mk_fundef fname prfx gr ((fname', (lhs, rhs))::xs) =
       
    42       let
       
    43         val (gr1, pl) = invoke_codegen thy gr dname false lhs;
       
    44         val (gr2, pr) = invoke_codegen thy gr1 dname false rhs;
       
    45         val (gr3, rest) = mk_fundef fname' "and " gr2 xs
       
    46       in
       
    47         (gr3, Pretty.blk (4, [Pretty.str (if fname=fname' then "  | " else prfx),
       
    48            pl, Pretty.str " =", Pretty.brk 1, pr]) :: rest)
       
    49       end
       
    50 
       
    51   in
       
    52     (Graph.add_edge (dname, dep) gr handle Graph.UNDEF _ =>
       
    53        let
       
    54          val gr1 = Graph.add_edge (dname, dep)
       
    55            (Graph.new_node (dname, (None, "")) gr);
       
    56          val (gr2, fundef) = mk_fundef "" "fun " gr1 eqs'
       
    57        in
       
    58          Graph.map_node dname (K (None, Pretty.string_of (Pretty.blk (0,
       
    59            separate Pretty.fbrk fundef @ [Pretty.str ";"])) ^ "\n\n")) gr2
       
    60        end)
       
    61   end;
       
    62 
       
    63 
       
    64 (**** generate functions for datatypes specified by descr ****)
       
    65 (**** (i.e. constructors and case combinators)            ****)
       
    66 
       
    67 fun mk_typ _ _ (TVar ((s, i), _)) =
       
    68      Pretty.str (s ^ (if i=0 then "" else string_of_int i))
       
    69   | mk_typ _ _ (TFree (s, _)) = Pretty.str s
       
    70   | mk_typ sg types (Type ("fun", [T, U])) = Pretty.block [Pretty.str "(",
       
    71      mk_typ sg types T, Pretty.str " ->", Pretty.brk 1,
       
    72      mk_typ sg types U, Pretty.str ")"]
       
    73   | mk_typ sg types (Type (s, Ts)) = Pretty.block ((if null Ts then [] else
       
    74       [mk_tuple (map (mk_typ sg types) Ts), Pretty.str " "]) @
       
    75       [Pretty.str (if_none (assoc (types, s)) (mk_type_id sg s))]);
       
    76 
       
    77 fun add_dt_defs thy dep (gr, descr) =
       
    78   let
       
    79     val sg = sign_of thy;
       
    80     val tab = DatatypePackage.get_datatypes thy;
       
    81 
       
    82     val descr' = filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
       
    83 
       
    84     val (_, (_, _, (cname, _) :: _)) :: _ = descr';
       
    85     val dname = mk_const_id sg cname;
       
    86 
       
    87     fun mk_dtdef gr prfx [] = (gr, [])
       
    88       | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
       
    89           let
       
    90             val types = get_assoc_types thy;
       
    91             val tvs = map DatatypeAux.dest_DtTFree dts;
       
    92             val sorts = map (rpair []) tvs;
       
    93             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
       
    94             val tycons = foldr add_typ_tycons (flat (map snd cs'), []) \\
       
    95               ("fun" :: map fst types);
       
    96             val descrs = map (fn s => case Symtab.lookup (tab, s) of
       
    97                 None => error ("Not a datatype: " ^ s ^ "\nrequired by:\n" ^
       
    98                   commas (Graph.all_succs gr [dep]))
       
    99               | Some info => #descr info) tycons;
       
   100             val gr' = foldl (add_dt_defs thy dname) (gr, descrs);
       
   101             val (gr'', rest) = mk_dtdef gr' "and " xs
       
   102           in
       
   103             (gr'',
       
   104              Pretty.block (Pretty.str prfx ::
       
   105                (if null tvs then [] else
       
   106                   [mk_tuple (map Pretty.str tvs), Pretty.str " "]) @
       
   107                [Pretty.str (mk_type_id sg tname ^ " ="), Pretty.brk 1] @
       
   108                flat (separate [Pretty.brk 1, Pretty.str "| "]
       
   109                  (map (fn (cname, cargs) => [Pretty.block
       
   110                    (Pretty.str (mk_const_id sg cname) ::
       
   111                     (if null cargs then [] else
       
   112                      flat ([Pretty.str " of", Pretty.brk 1] ::
       
   113                        separate [Pretty.str " *", Pretty.brk 1]
       
   114                          (map (single o mk_typ sg types) cargs))))]) cs'))) :: rest)
       
   115           end
       
   116   in
       
   117     ((Graph.add_edge_acyclic (dname, dep) gr
       
   118         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
       
   119          let
       
   120            val gr1 = Graph.add_edge (dname, dep)
       
   121              (Graph.new_node (dname, (None, "")) gr);
       
   122            val (gr2, dtdef) = mk_dtdef gr1 "datatype " descr';
       
   123          in
       
   124            Graph.map_node dname (K (None,
       
   125              Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
       
   126                [Pretty.str ";"])) ^ "\n\n")) gr2
       
   127          end)
       
   128   end;
       
   129 
       
   130 
       
   131 (**** generate code for applications of constructors and case ****)
       
   132 (**** combinators for datatypes                               ****)
       
   133 
       
   134 fun pretty_case thy gr dep brack constrs (c as Const (_, T)) ts =
       
   135   let val i = length constrs
       
   136   in if length ts <= i then
       
   137        invoke_codegen thy gr dep brack (eta_expand c ts (i+1))
       
   138     else
       
   139       let
       
   140         val ts1 = take (i, ts);
       
   141         val t :: ts2 = drop (i, ts);
       
   142         val names = foldr add_term_names (ts1,
       
   143           map (fst o fst o dest_Var) (foldr add_term_vars (ts1, [])));
       
   144         val (Ts, dT) = split_last (take (i+1, fst (strip_type T)));
       
   145 
       
   146         fun pcase gr [] [] [] = ([], gr)
       
   147           | pcase gr ((cname, cargs)::cs) (t::ts) (U::Us) =
       
   148               let
       
   149                 val j = length cargs;
       
   150                 val (Ts, _) = strip_type (fastype_of t);
       
   151                 val xs = variantlist (replicate j "x", names);
       
   152                 val Us' = take (j, fst (strip_type U));
       
   153                 val frees = map Free (xs ~~ Us');
       
   154                 val (gr0, cp) = invoke_codegen thy gr dep false
       
   155                   (list_comb (Const (cname, Us' ---> dT), frees));
       
   156                 val t' = Envir.beta_norm (list_comb (t, frees));
       
   157                 val (gr1, p) = invoke_codegen thy gr0 dep false t';
       
   158                 val (ps, gr2) = pcase gr1 cs ts Us;
       
   159               in
       
   160                 ([Pretty.block [cp, Pretty.str " =>", Pretty.brk 1, p]] :: ps, gr2)
       
   161               end;
       
   162 
       
   163         val (ps1, gr1) = pcase gr constrs ts1 Ts;
       
   164         val ps = flat (separate [Pretty.brk 1, Pretty.str "| "] ps1);
       
   165         val (gr2, p) = invoke_codegen thy gr1 dep false t;
       
   166         val (gr3, ps2) = foldl_map
       
   167          (fn (gr, t) => invoke_codegen thy gr dep true t) (gr2, ts2)
       
   168       in (gr3, (if not (null ts2) andalso brack then parens else I)
       
   169         (Pretty.block (separate (Pretty.brk 1)
       
   170           (Pretty.block ([Pretty.str "(case ", p, Pretty.str " of",
       
   171              Pretty.brk 1] @ ps @ [Pretty.str ")"]) :: ps2))))
       
   172       end
       
   173   end;
       
   174 
       
   175 
       
   176 fun pretty_constr thy gr dep brack args (c as Const (s, _)) ts =
       
   177   let val i = length args
       
   178   in if length ts < i then
       
   179       invoke_codegen thy gr dep brack (eta_expand c ts i)
       
   180      else
       
   181        let
       
   182          val id = mk_const_id (sign_of thy) s;
       
   183          val (gr', ps) = foldl_map
       
   184            (fn (gr, t) => invoke_codegen thy gr dep (i = 1) t) (gr, ts);
       
   185        in (case args of
       
   186           [] => (gr', Pretty.str id)
       
   187         | [_] => (gr', mk_app brack (Pretty.str id) ps)
       
   188         | _ => (gr', (if brack then parens else I) (Pretty.block
       
   189             ([Pretty.str id, Pretty.brk 1, Pretty.str "("] @
       
   190              flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
       
   191              [Pretty.str ")"]))))
       
   192        end
       
   193   end;
       
   194 
       
   195 
       
   196 fun mk_recfun thy gr dep brack s T ts eqns =
       
   197   let val (gr', ps) = foldl_map
       
   198     (fn (gr, t) => invoke_codegen thy gr dep true t) (gr, ts)
       
   199   in
       
   200     Some (add_rec_funs thy dep (gr', map (#prop o rep_thm) eqns),
       
   201       mk_app brack (Pretty.str (mk_poly_id thy (s, T))) ps)
       
   202   end;
       
   203 
       
   204 
       
   205 fun datatype_codegen thy gr dep brack t = (case strip_comb t of
       
   206    (c as Const (s, T), ts) =>
       
   207        (case find_first (fn (_, {index, descr, case_name, rec_names, ...}) =>
       
   208          s = case_name orelse s mem rec_names orelse
       
   209            is_some (assoc (#3 (the (assoc (descr, index))), s)))
       
   210              (Symtab.dest (DatatypePackage.get_datatypes thy)) of
       
   211           None => None
       
   212         | Some (tname, {index, descr, case_name, rec_names, rec_rewrites, ...}) =>
       
   213            if is_some (get_assoc_code thy s T) then None else
       
   214            let
       
   215              val Some (_, _, constrs) = assoc (descr, index);
       
   216              val gr1 =
       
   217               if exists (equal tname o fst) (get_assoc_types thy) then gr
       
   218               else add_dt_defs thy dep (gr, descr);
       
   219            in
       
   220              (case assoc (constrs, s) of
       
   221                 None => if s mem rec_names then
       
   222                     mk_recfun thy gr1 dep brack s T ts rec_rewrites
       
   223                   else Some (pretty_case thy gr1 dep brack constrs c ts)
       
   224               | Some args => Some (pretty_constr thy gr1 dep brack args c ts))
       
   225            end)
       
   226  |  _ => None);
       
   227 
       
   228 
       
   229 (**** generate code for primrec and recdef ****)
       
   230 
       
   231 fun recfun_codegen thy gr dep brack t = (case strip_comb t of
       
   232     (Const (s, T), ts) =>
       
   233       (case PrimrecPackage.get_primrec thy s of
       
   234          Some ps => (case find_first (fn (_, thm::_) =>
       
   235                is_instance thy T (snd (dest_Const (head_of
       
   236                  (fst (HOLogic.dest_eq
       
   237                    (HOLogic.dest_Trueprop (#prop (rep_thm thm))))))))) ps of
       
   238              Some (_, thms) => mk_recfun thy gr dep brack s T ts thms
       
   239            | None => None)
       
   240        | None => case RecdefPackage.get_recdef thy s of
       
   241             Some {simps, ...} => mk_recfun thy gr dep brack s T ts simps
       
   242           | None => None)
       
   243   | _ => None);
       
   244 
       
   245 
       
   246 val setup = [add_codegen "datatype" datatype_codegen,
       
   247              add_codegen "primrec+recdef" recfun_codegen];
       
   248 
       
   249 end;