src/Pure/codegen.ML
changeset 16649 d88271eb5b26
parent 16458 4c6fd0c01d28
child 16769 7f188f2127f7
equal deleted inserted replaced
16648:fc2a425f0977 16649:d88271eb5b26
    16       Arg
    16       Arg
    17     | Ignore
    17     | Ignore
    18     | Pretty of Pretty.T
    18     | Pretty of Pretty.T
    19     | Quote of 'a;
    19     | Quote of 'a;
    20 
    20 
       
    21   type deftab
       
    22   type codegr
    21   type 'a codegen
    23   type 'a codegen
    22 
    24 
    23   val add_codegen: string -> term codegen -> theory -> theory
    25   val add_codegen: string -> term codegen -> theory -> theory
    24   val add_tycodegen: string -> typ codegen -> theory -> theory
    26   val add_tycodegen: string -> typ codegen -> theory -> theory
    25   val add_attribute: string -> (Args.T list -> theory attribute * Args.T list) -> theory -> theory
    27   val add_attribute: string -> (Args.T list -> theory attribute * Args.T list) -> theory -> theory
    26   val add_preprocessor: (theory -> thm list -> thm list) -> theory -> theory
    28   val add_preprocessor: (theory -> thm list -> thm list) -> theory -> theory
    27   val preprocess: theory -> thm list -> thm list
    29   val preprocess: theory -> thm list -> thm list
    28   val print_codegens: theory -> unit
    30   val print_codegens: theory -> unit
    29   val generate_code: theory -> (string * string) list -> string
    31   val generate_code: theory -> (string * string) list -> (string * string) list
    30   val generate_code_i: theory -> (string * term) list -> string
    32   val generate_code_i: theory -> (string * term) list -> (string * string) list
    31   val assoc_consts: (xstring * string option * term mixfix list) list -> theory -> theory
    33   val assoc_consts: (xstring * string option * term mixfix list) list -> theory -> theory
    32   val assoc_consts_i: (xstring * typ option * term mixfix list) list -> theory -> theory
    34   val assoc_consts_i: (xstring * typ option * term mixfix list) list -> theory -> theory
    33   val assoc_types: (xstring * typ mixfix list) list -> theory -> theory
    35   val assoc_types: (xstring * typ mixfix list) list -> theory -> theory
    34   val get_assoc_code: theory -> string -> typ -> term mixfix list option
    36   val get_assoc_code: theory -> string -> typ -> term mixfix list option
    35   val get_assoc_type: theory -> string -> typ mixfix list option
    37   val get_assoc_type: theory -> string -> typ mixfix list option
    36   val invoke_codegen: theory -> string -> bool ->
    38   val invoke_codegen: theory -> deftab -> string -> string -> bool ->
    37     (exn option * string) Graph.T * term -> (exn option * string) Graph.T * Pretty.T
    39     codegr * term -> codegr * Pretty.T
    38   val invoke_tycodegen: theory -> string -> bool ->
    40   val invoke_tycodegen: theory -> deftab -> string -> string -> bool ->
    39     (exn option * string) Graph.T * typ -> (exn option * string) Graph.T * Pretty.T
    41     codegr * typ -> codegr * Pretty.T
    40   val mk_id: string -> string
    42   val mk_id: string -> string
    41   val mk_const_id: theory -> string -> string
    43   val mk_const_id: theory -> string -> string -> string -> string
    42   val mk_type_id: theory -> string -> string
    44   val mk_type_id: theory -> string -> string -> string -> string
       
    45   val thyname_of_type: string -> theory -> string
       
    46   val thyname_of_const: string -> theory -> string
       
    47   val rename_terms: term list -> term list
    43   val rename_term: term -> term
    48   val rename_term: term -> term
    44   val new_names: term -> string list -> string list
    49   val new_names: term -> string list -> string list
    45   val new_name: term -> string -> string
    50   val new_name: term -> string -> string
    46   val get_defn: theory -> string -> typ -> ((term list * term) * int option) option
    51   val get_defn: theory -> deftab -> string -> typ ->
       
    52     ((typ * (string * (term list * term))) * int option) option
    47   val is_instance: theory -> typ -> typ -> bool
    53   val is_instance: theory -> typ -> typ -> bool
    48   val parens: Pretty.T -> Pretty.T
    54   val parens: Pretty.T -> Pretty.T
    49   val mk_app: bool -> Pretty.T -> Pretty.T list -> Pretty.T
    55   val mk_app: bool -> Pretty.T -> Pretty.T list -> Pretty.T
    50   val eta_expand: term -> term list -> int -> term
    56   val eta_expand: term -> term list -> int -> term
    51   val strip_tname: string -> string
    57   val strip_tname: string -> string
    52   val mk_type: bool -> typ -> Pretty.T
    58   val mk_type: bool -> typ -> Pretty.T
    53   val mk_term_of: theory -> bool -> typ -> Pretty.T
    59   val mk_term_of: theory -> string -> bool -> typ -> Pretty.T
    54   val mk_gen: theory -> bool -> string list -> string -> typ -> Pretty.T
    60   val mk_gen: theory -> string -> bool -> string list -> string -> typ -> Pretty.T
    55   val test_fn: (int -> (string * term) list option) ref
    61   val test_fn: (int -> (string * term) list option) ref
    56   val test_term: theory -> int -> int -> term -> (string * term) list option
    62   val test_term: theory -> int -> int -> term -> (string * term) list option
    57   val parse_mixfix: (string -> 'a) -> string -> 'a mixfix list
    63   val parse_mixfix: (string -> 'a) -> string -> 'a mixfix list
       
    64   val mk_deftab: theory -> deftab
    58 end;
    65 end;
    59 
    66 
    60 structure Codegen : CODEGEN =
    67 structure Codegen : CODEGEN =
    61 struct
    68 struct
    62 
    69 
    91 fun num_args x = length (List.filter is_arg x);
    98 fun num_args x = length (List.filter is_arg x);
    92 
    99 
    93 
   100 
    94 (**** theory data ****)
   101 (**** theory data ****)
    95 
   102 
       
   103 (* preprocessed definition table *)
       
   104 
       
   105 type deftab =
       
   106   (typ *              (* type of constant *)
       
   107     (string *         (* name of theory containing definition of constant *)
       
   108       (term list *    (* parameters *)
       
   109        term)))        (* right-hand side *)
       
   110   list Symtab.table;
       
   111 
       
   112 (* code dependency graph *)
       
   113 
       
   114 type codegr =
       
   115   (exn option *    (* slot for arbitrary data *)
       
   116    string *        (* name of structure containing piece of code *)
       
   117    string)         (* piece of code *)
       
   118   Graph.T;
       
   119 
    96 (* type of code generators *)
   120 (* type of code generators *)
    97 
   121 
    98 type 'a codegen = theory -> (exn option * string) Graph.T ->
   122 type 'a codegen =
    99   string -> bool -> 'a -> ((exn option * string) Graph.T * Pretty.T) option;
   123   theory ->    (* theory in which generate_code was called *)
       
   124   deftab ->    (* definition table (for efficiency) *)
       
   125   codegr ->    (* code dependency graph *)
       
   126   string ->    (* node name of caller (for recording dependencies) *)
       
   127   string ->    (* theory name of caller (for modular code generation) *)
       
   128   bool ->      (* whether to parenthesize generated expression *)
       
   129   'a ->        (* item to generate code from *)
       
   130   (codegr * Pretty.T) option;
   100 
   131 
   101 (* parameters for random testing *)
   132 (* parameters for random testing *)
   102 
   133 
   103 type test_params =
   134 type test_params =
   104   {size: int, iterations: int, default_type: typ option};
   135   {size: int, iterations: int, default_type: typ option};
   296 
   327 
   297 fun assoc_types xs thy = Library.foldl (fn (thy, (s, syn)) =>
   328 fun assoc_types xs thy = Library.foldl (fn (thy, (s, syn)) =>
   298   let
   329   let
   299     val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   330     val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   300       CodegenData.get thy;
   331       CodegenData.get thy;
   301     val tc = Sign.intern_type (sign_of thy) s
   332     val tc = Sign.intern_type thy s
   302   in
   333   in
   303     (case assoc (types, tc) of
   334     (case assoc (types, tc) of
   304        NONE => CodegenData.put {codegens = codegens,
   335        NONE => CodegenData.put {codegens = codegens,
   305          tycodegens = tycodegens, consts = consts,
   336          tycodegens = tycodegens, consts = consts,
   306          types = (tc, syn) :: types, attrs = attrs,
   337          types = (tc, syn) :: types, attrs = attrs,
   337       (List.concat (map (check_str o Symbol.explode) (NameSpace.unpack s)))
   368       (List.concat (map (check_str o Symbol.explode) (NameSpace.unpack s)))
   338   in
   369   in
   339     if Symbol.is_ascii_letter (hd (explode s')) then s' else "id_" ^ s'
   370     if Symbol.is_ascii_letter (hd (explode s')) then s' else "id_" ^ s'
   340   end;
   371   end;
   341 
   372 
   342 fun mk_const_id thy s =
   373 fun extrn thy f thyname s =
   343   let val s' = mk_id (Sign.extern_const thy s)
   374   let
   344   in if s' mem ThmDatabase.ml_reserved then s' ^ "_const" else s' end;
   375     val xs = NameSpace.unpack s;
   345 
   376     val s' = setmp NameSpace.long_names false (setmp NameSpace.short_names false
   346 fun mk_type_id thy s =
   377       (setmp NameSpace.unique_names true (f thy))) s;
   347   let val s' = mk_id (Sign.extern_type thy s)
   378     val xs' = NameSpace.unpack s'
   348   in if s' mem ThmDatabase.ml_reserved then s' ^ "_type" else s' end;
   379   in
       
   380     if "modular" mem !mode andalso length xs = length xs' andalso hd xs' = thyname
       
   381     then NameSpace.pack (tl xs') else s'
       
   382   end;
       
   383 
       
   384 (* thyname:  theory name for caller                                        *)
       
   385 (* thyname': theory name for callee                                        *)
       
   386 (* if caller and callee reside in different theories, use qualified access *)
       
   387 
       
   388 fun mk_const_id thy thyname thyname' s =
       
   389   let
       
   390     val s' = mk_id (extrn thy Sign.extern_const thyname' s);
       
   391     val s'' = if s' mem ThmDatabase.ml_reserved then s' ^ "_const" else s'
       
   392   in
       
   393     if "modular" mem !mode andalso thyname <> thyname' andalso thyname' <> ""
       
   394     then thyname' ^ "." ^ s'' else s''
       
   395   end;
       
   396 
       
   397 fun mk_type_id' f thy thyname thyname' s =
       
   398   let
       
   399     val s' = mk_id (extrn thy Sign.extern_type thyname' s);
       
   400     val s'' = f (if s' mem ThmDatabase.ml_reserved then s' ^ "_type" else s')
       
   401   in
       
   402     if "modular" mem !mode andalso thyname <> thyname' andalso thyname' <> ""
       
   403     then thyname' ^ "." ^ s'' else s''
       
   404   end;
       
   405 
       
   406 val mk_type_id = mk_type_id' I;
       
   407 
       
   408 fun theory_of_type s thy = 
       
   409   if Sign.declared_tyname thy s
       
   410   then SOME (if_none (get_first (theory_of_type s) (Theory.parents_of thy)) thy)
       
   411   else NONE;
       
   412 
       
   413 fun theory_of_const s thy = 
       
   414   if Sign.declared_const thy s
       
   415   then SOME (if_none (get_first (theory_of_const s) (Theory.parents_of thy)) thy)
       
   416   else NONE;
       
   417 
       
   418 fun thyname_of_type s thy = (case theory_of_type s thy of
       
   419     NONE => error ("thyname_of_type: no such type: " ^ quote s)
       
   420   | SOME thy' => Context.theory_name thy');
       
   421 
       
   422 fun thyname_of_const s thy = (case theory_of_const s thy of
       
   423     NONE => error ("thyname_of_const: no such constant: " ^ quote s)
       
   424   | SOME thy' => Context.theory_name thy');
   349 
   425 
   350 fun rename_terms ts =
   426 fun rename_terms ts =
   351   let
   427   let
   352     val names = foldr add_term_names
   428     val names = foldr add_term_names
   353       (map (fst o fst) (Drule.vars_of_terms ts)) ts;
   429       (map (fst o fst) (Drule.vars_of_terms ts)) ts;
   372 
   448 
   373 
   449 
   374 (**** retrieve definition of constant ****)
   450 (**** retrieve definition of constant ****)
   375 
   451 
   376 fun is_instance thy T1 T2 =
   452 fun is_instance thy T1 T2 =
   377   Sign.typ_instance (sign_of thy) (T1, Type.varifyT T2);
   453   Sign.typ_instance thy (T1, Type.varifyT T2);
   378 
   454 
   379 fun get_assoc_code thy s T = Option.map snd (find_first (fn ((s', T'), _) =>
   455 fun get_assoc_code thy s T = Option.map snd (find_first (fn ((s', T'), _) =>
   380   s = s' andalso is_instance thy T T') (#consts (CodegenData.get thy)));
   456   s = s' andalso is_instance thy T T') (#consts (CodegenData.get thy)));
   381 
   457 
   382 fun get_defn thy s T =
   458 fun mk_deftab thy =
   383   let
   459   let
   384     val axms = Theory.all_axioms_of thy;
   460     val axmss = map (fn thy' =>
       
   461       (Context.theory_name thy', snd (#axioms (Theory.rep_theory thy'))))
       
   462       (thy :: Theory.ancestors_of thy);
   385     fun prep_def def = (case preprocess thy [def] of
   463     fun prep_def def = (case preprocess thy [def] of
   386       [def'] => prop_of def' | _ => error "get_defn: bad preprocessor");
   464       [def'] => prop_of def' | _ => error "mk_deftab: bad preprocessor");
   387     fun dest t =
   465     fun dest t =
   388       let
   466       let
   389         val (lhs, rhs) = Logic.dest_equals t;
   467         val (lhs, rhs) = Logic.dest_equals t;
   390         val (c, args) = strip_comb lhs;
   468         val (c, args) = strip_comb lhs;
   391         val (s', T') = dest_Const c
   469         val (s, T) = dest_Const c
   392       in if s = s' then SOME (T', (args, rhs)) else NONE
   470       in if forall is_Var args then SOME (s, (T, (args, rhs))) else NONE
   393       end handle TERM _ => NONE;
   471       end handle TERM _ => NONE;
   394     val defs = List.mapPartial (fn (name, t) => Option.map (pair name) (dest t)) axms;
   472     fun add_def thyname (defs, (name, t)) = (case dest t of
   395     val i = find_index (is_instance thy T o fst o snd) defs
   473         NONE => defs
   396   in
   474       | SOME _ => (case dest (prep_def (Thm.get_axiom thy name)) of
   397     if i >= 0 then
   475           NONE => defs
   398       let val (name, (T', (args, _))) = List.nth (defs, i)
   476         | SOME (s, (T, (args, rhs))) => Symtab.update
   399       in case dest (prep_def (Thm.get_axiom thy name)) of
   477             ((s, (T, (thyname, split_last (rename_terms (args @ [rhs])))) ::
   400           NONE => NONE
   478             if_none (Symtab.lookup (defs, s)) []), defs)))
   401         | SOME (T'', p as (args', rhs)) =>
   479   in
   402             if T' = T'' andalso args = args' then
   480     foldl (fn ((thyname, axms), defs) =>
   403               SOME (split_last (rename_terms (args @ [rhs])),
   481       Symtab.foldl (add_def thyname) (defs, axms)) Symtab.empty axmss
   404                 if length defs = 1 then NONE else SOME i)
   482   end;
   405             else NONE
   483 
   406       end
   484 fun get_defn thy defs s T = (case Symtab.lookup (defs, s) of
   407     else NONE
   485     NONE => NONE
   408   end;
   486   | SOME ds =>
       
   487       let val i = find_index (is_instance thy T o fst) ds
       
   488       in if i >= 0 then
       
   489           SOME (List.nth (ds, i), if length ds = 1 then NONE else SOME i)
       
   490         else NONE
       
   491       end);
   409 
   492 
   410 
   493 
   411 (**** invoke suitable code generator for term / type ****)
   494 (**** invoke suitable code generator for term / type ****)
   412 
   495 
   413 fun invoke_codegen thy dep brack (gr, t) = (case get_first
   496 fun invoke_codegen thy defs dep thyname brack (gr, t) = (case get_first
   414    (fn (_, f) => f thy gr dep brack t) (#codegens (CodegenData.get thy)) of
   497    (fn (_, f) => f thy defs gr dep thyname brack t) (#codegens (CodegenData.get thy)) of
   415       NONE => error ("Unable to generate code for term:\n" ^
   498       NONE => error ("Unable to generate code for term:\n" ^
   416         Sign.string_of_term (sign_of thy) t ^ "\nrequired by:\n" ^
   499         Sign.string_of_term thy t ^ "\nrequired by:\n" ^
   417         commas (Graph.all_succs gr [dep]))
   500         commas (Graph.all_succs gr [dep]))
   418     | SOME x => x);
   501     | SOME x => x);
   419 
   502 
   420 fun invoke_tycodegen thy dep brack (gr, T) = (case get_first
   503 fun invoke_tycodegen thy defs dep thyname brack (gr, T) = (case get_first
   421    (fn (_, f) => f thy gr dep brack T) (#tycodegens (CodegenData.get thy)) of
   504    (fn (_, f) => f thy defs gr dep thyname brack T) (#tycodegens (CodegenData.get thy)) of
   422       NONE => error ("Unable to generate code for type:\n" ^
   505       NONE => error ("Unable to generate code for type:\n" ^
   423         Sign.string_of_typ (sign_of thy) T ^ "\nrequired by:\n" ^
   506         Sign.string_of_typ thy T ^ "\nrequired by:\n" ^
   424         commas (Graph.all_succs gr [dep]))
   507         commas (Graph.all_succs gr [dep]))
   425     | SOME x => x);
   508     | SOME x => x);
   426 
   509 
   427 
   510 
   428 (**** code generator for mixfix expressions ****)
   511 (**** code generator for mixfix expressions ****)
   461   map (fst o fst o dest_Var) (term_vars t) union
   544   map (fst o fst o dest_Var) (term_vars t) union
   462   add_term_names (t, ThmDatabase.ml_reserved));
   545   add_term_names (t, ThmDatabase.ml_reserved));
   463 
   546 
   464 fun new_name t x = hd (new_names t [x]);
   547 fun new_name t x = hd (new_names t [x]);
   465 
   548 
   466 fun default_codegen thy gr dep brack t =
   549 fun default_codegen thy defs gr dep thyname brack t =
   467   let
   550   let
   468     val (u, ts) = strip_comb t;
   551     val (u, ts) = strip_comb t;
   469     fun codegens brack = foldl_map (invoke_codegen thy dep brack)
   552     fun codegens brack = foldl_map (invoke_codegen thy defs dep thyname brack)
   470   in (case u of
   553   in (case u of
   471       Var ((s, i), T) =>
   554       Var ((s, i), T) =>
   472         let
   555         let
   473           val (gr', ps) = codegens true (gr, ts);
   556           val (gr', ps) = codegens true (gr, ts);
   474           val (gr'', _) = invoke_tycodegen thy dep false (gr', T)
   557           val (gr'', _) = invoke_tycodegen thy defs dep thyname false (gr', T)
   475         in SOME (gr'', mk_app brack (Pretty.str (s ^
   558         in SOME (gr'', mk_app brack (Pretty.str (s ^
   476            (if i=0 then "" else string_of_int i))) ps)
   559            (if i=0 then "" else string_of_int i))) ps)
   477         end
   560         end
   478 
   561 
   479     | Free (s, T) =>
   562     | Free (s, T) =>
   480         let
   563         let
   481           val (gr', ps) = codegens true (gr, ts);
   564           val (gr', ps) = codegens true (gr, ts);
   482           val (gr'', _) = invoke_tycodegen thy dep false (gr', T)
   565           val (gr'', _) = invoke_tycodegen thy defs dep thyname false (gr', T)
   483         in SOME (gr'', mk_app brack (Pretty.str s) ps) end
   566         in SOME (gr'', mk_app brack (Pretty.str s) ps) end
   484 
   567 
   485     | Const (s, T) =>
   568     | Const (s, T) =>
   486       (case get_assoc_code thy s T of
   569       (case get_assoc_code thy s T of
   487          SOME ms =>
   570          SOME ms =>
   488            let val i = num_args ms
   571            let val i = num_args ms
   489            in if length ts < i then
   572            in if length ts < i then
   490                default_codegen thy gr dep brack (eta_expand u ts i)
   573                default_codegen thy defs gr dep thyname brack (eta_expand u ts i)
   491              else
   574              else
   492                let
   575                let
   493                  val (ts1, ts2) = args_of ms ts;
   576                  val (ts1, ts2) = args_of ms ts;
   494                  val (gr1, ps1) = codegens false (gr, ts1);
   577                  val (gr1, ps1) = codegens false (gr, ts1);
   495                  val (gr2, ps2) = codegens true (gr1, ts2);
   578                  val (gr2, ps2) = codegens true (gr1, ts2);
   496                  val (gr3, ps3) = codegens false (gr2, quotes_of ms);
   579                  val (gr3, ps3) = codegens false (gr2, quotes_of ms);
   497                in
   580                in
   498                  SOME (gr3, mk_app brack (Pretty.block (pretty_mixfix ms ps1 ps3)) ps2)
   581                  SOME (gr3, mk_app brack (Pretty.block (pretty_mixfix ms ps1 ps3)) ps2)
   499                end
   582                end
   500            end
   583            end
   501        | NONE => (case get_defn thy s T of
   584        | NONE => (case get_defn thy defs s T of
   502            NONE => NONE
   585            NONE => NONE
   503          | SOME ((args, rhs), k) =>
   586          | SOME ((U, (thyname', (args, rhs))), k) =>
   504              let
   587              let
   505                val id = mk_const_id (sign_of thy) s ^ (case k of
   588                val suffix = (case k of NONE => "" | SOME i => "_def" ^ string_of_int i);
   506                  NONE => "" | SOME i => "_def" ^ string_of_int i);
   589                val node_id = s ^ suffix;
       
   590                val def_id = mk_const_id thy thyname' thyname' s ^ suffix;
       
   591                val call_id = mk_const_id thy thyname thyname' s ^ suffix;
   507                val (gr', ps) = codegens true (gr, ts);
   592                val (gr', ps) = codegens true (gr, ts);
   508              in
   593              in
   509                SOME (Graph.add_edge (id, dep) gr' handle Graph.UNDEF _ =>
   594                SOME (Graph.add_edge (node_id, dep) gr' handle Graph.UNDEF _ =>
   510                  let
   595                  let
   511                    val _ = message ("expanding definition of " ^ s);
   596                    val _ = message ("expanding definition of " ^ s);
   512                    val (Ts, _) = strip_type T;
   597                    val (Ts, _) = strip_type T;
   513                    val (args', rhs') =
   598                    val (args', rhs') =
   514                      if not (null args) orelse null Ts then (args, rhs) else
   599                      if not (null args) orelse null Ts then (args, rhs) else
   515                        let val v = Free (new_name rhs "x", hd Ts)
   600                        let val v = Free (new_name rhs "x", hd Ts)
   516                        in ([v], betapply (rhs, v)) end;
   601                        in ([v], betapply (rhs, v)) end;
   517                    val (gr1, p) = invoke_codegen thy id false
   602                    val (gr1, p) = invoke_codegen thy defs node_id thyname' false
   518                      (Graph.add_edge (id, dep)
   603                      (Graph.add_edge (node_id, dep)
   519                         (Graph.new_node (id, (NONE, "")) gr'), rhs');
   604                         (Graph.new_node (node_id, (NONE, "", "")) gr'), rhs');
   520                    val (gr2, xs) = codegens false (gr1, args');
   605                    val (gr2, xs) = codegens false (gr1, args');
   521                    val (gr3, ty) = invoke_tycodegen thy id false (gr2, T);
   606                    val (gr3, _) = invoke_tycodegen thy defs dep thyname false (gr2, T);
   522                  in Graph.map_node id (K (NONE, Pretty.string_of (Pretty.block
   607                    val (gr4, ty) = invoke_tycodegen thy defs node_id thyname' false (gr3, U);
   523                    (separate (Pretty.brk 1) (if null args' then
   608                  in Graph.map_node node_id (K (NONE, thyname', Pretty.string_of
   524                        [Pretty.str ("val " ^ id ^ " :"), ty]
   609                    (Pretty.block (separate (Pretty.brk 1)
   525                      else Pretty.str ("fun " ^ id) :: xs) @
   610                      (if null args' then
   526                     [Pretty.str " =", Pretty.brk 1, p, Pretty.str ";"])) ^ "\n\n")) gr3
   611                         [Pretty.str ("val " ^ def_id ^ " :"), ty]
   527                  end, mk_app brack (Pretty.str id) ps)
   612                       else Pretty.str ("fun " ^ def_id) :: xs) @
       
   613                     [Pretty.str " =", Pretty.brk 1, p, Pretty.str ";"])) ^ "\n\n")) gr4
       
   614                  end, mk_app brack (Pretty.str call_id) ps)
   528              end))
   615              end))
   529 
   616 
   530     | Abs _ =>
   617     | Abs _ =>
   531       let
   618       let
   532         val (bs, Ts) = ListPair.unzip (strip_abs_vars u);
   619         val (bs, Ts) = ListPair.unzip (strip_abs_vars u);
   533         val t = strip_abs_body u
   620         val t = strip_abs_body u
   534         val bs' = new_names t bs;
   621         val bs' = new_names t bs;
   535         val (gr1, ps) = codegens true (gr, ts);
   622         val (gr1, ps) = codegens true (gr, ts);
   536         val (gr2, p) = invoke_codegen thy dep false
   623         val (gr2, p) = invoke_codegen thy defs dep thyname false
   537           (gr1, subst_bounds (map Free (rev (bs' ~~ Ts)), t));
   624           (gr1, subst_bounds (map Free (rev (bs' ~~ Ts)), t));
   538       in
   625       in
   539         SOME (gr2, mk_app brack (Pretty.block (Pretty.str "(" :: pretty_fn bs' p @
   626         SOME (gr2, mk_app brack (Pretty.block (Pretty.str "(" :: pretty_fn bs' p @
   540           [Pretty.str ")"])) ps)
   627           [Pretty.str ")"])) ps)
   541       end
   628       end
   542 
   629 
   543     | _ => NONE)
   630     | _ => NONE)
   544   end;
   631   end;
   545 
   632 
   546 fun default_tycodegen thy gr dep brack (TVar ((s, i), _)) =
   633 fun default_tycodegen thy defs gr dep thyname brack (TVar ((s, i), _)) =
   547       SOME (gr, Pretty.str (s ^ (if i = 0 then "" else string_of_int i)))
   634       SOME (gr, Pretty.str (s ^ (if i = 0 then "" else string_of_int i)))
   548   | default_tycodegen thy gr dep brack (TFree (s, _)) = SOME (gr, Pretty.str s)
   635   | default_tycodegen thy defs gr dep thyname brack (TFree (s, _)) =
   549   | default_tycodegen thy gr dep brack (Type (s, Ts)) =
   636       SOME (gr, Pretty.str s)
       
   637   | default_tycodegen thy defs gr dep thyname brack (Type (s, Ts)) =
   550       (case assoc (#types (CodegenData.get thy), s) of
   638       (case assoc (#types (CodegenData.get thy), s) of
   551          NONE => NONE
   639          NONE => NONE
   552        | SOME ms =>
   640        | SOME ms =>
   553            let
   641            let
   554              val (gr', ps) = foldl_map
   642              val (gr', ps) = foldl_map
   555                (invoke_tycodegen thy dep false) (gr, fst (args_of ms Ts));
   643                (invoke_tycodegen thy defs dep thyname false)
       
   644                (gr, fst (args_of ms Ts));
   556              val (gr'', qs) = foldl_map
   645              val (gr'', qs) = foldl_map
   557                (invoke_tycodegen thy dep false) (gr', quotes_of ms)
   646                (invoke_tycodegen thy defs dep thyname false)
       
   647                (gr', quotes_of ms)
   558            in SOME (gr'', Pretty.block (pretty_mixfix ms ps qs)) end);
   648            in SOME (gr'', Pretty.block (pretty_mixfix ms ps qs)) end);
   559 
   649 
   560 val _ = Context.add_setup
   650 val _ = Context.add_setup
   561  [add_codegen "default" default_codegen,
   651  [add_codegen "default" default_codegen,
   562   add_tycodegen "default" default_tycodegen];
   652   add_tycodegen "default" default_tycodegen];
   563 
   653 
   564 
   654 
   565 fun output_code gr xs = implode (map (snd o Graph.get_node gr)
   655 fun mk_struct name s = "structure " ^ name ^ " =\nstruct\n\n" ^ s ^ "end;\n";
   566   (rev (Graph.all_preds gr xs)));
   656 
       
   657 fun add_to_module name s ms =
       
   658   overwrite (ms, (name, the (assoc (ms, name)) ^ s));
       
   659 
       
   660 fun output_code gr xs =
       
   661   let
       
   662     val code =
       
   663       map (fn s => (s, Graph.get_node gr s)) (rev (Graph.all_preds gr xs))
       
   664     fun string_of_cycle (a :: b :: cs) =
       
   665           let val SOME (x, y) = get_first (fn (x, (_, a', _)) =>
       
   666             if a = a' then Option.map (pair x)
       
   667               (find_first (equal b o #2 o Graph.get_node gr)
       
   668                 (Graph.imm_succs gr x))
       
   669             else NONE) code
       
   670           in x ^ " called by " ^ y ^ "\n" ^ string_of_cycle (b :: cs) end
       
   671       | string_of_cycle _ = ""
       
   672   in
       
   673     if "modular" mem !mode then
       
   674       let
       
   675         val modules = distinct (map (#2 o snd) code);
       
   676         val mod_gr = foldr (uncurry Graph.add_edge_acyclic)
       
   677           (foldr (uncurry (Graph.new_node o rpair ())) Graph.empty modules)
       
   678           (List.concat (map (fn (s, (_, thyname, _)) => map (pair thyname)
       
   679             (filter_out (equal thyname) (map (#2 o Graph.get_node gr)
       
   680               (Graph.imm_succs gr s)))) code));
       
   681         val modules' =
       
   682           rev (Graph.all_preds mod_gr (map (#2 o Graph.get_node gr) xs))
       
   683       in
       
   684         foldl (fn ((_, (_, thyname, s)), ms) => add_to_module thyname s ms)
       
   685           (map (rpair "") modules') code
       
   686       end handle Graph.CYCLES (cs :: _) =>
       
   687         error ("Cyclic dependency of modules:\n" ^ commas cs ^
       
   688           "\n" ^ string_of_cycle cs)
       
   689     else [("Generated", implode (map (#3 o snd) code))]
       
   690   end;
   567 
   691 
   568 fun gen_generate_code prep_term thy =
   692 fun gen_generate_code prep_term thy =
   569   setmp print_mode [] (Pretty.setmp_margin (!margin) (fn xs =>
   693   setmp print_mode [] (Pretty.setmp_margin (!margin) (fn xs =>
   570   let
   694   let
   571     val gr = Graph.new_node ("<Top>", (NONE, "")) Graph.empty;
   695     val defs = mk_deftab thy;
       
   696     val gr = Graph.new_node ("<Top>", (NONE, "Generated", "")) Graph.empty;
       
   697     fun expand (t as Abs _) = t
       
   698       | expand t = (case fastype_of t of
       
   699           Type ("fun", [T, U]) => Abs ("x", T, t $ Bound 0) | _ => t);
   572     val (gr', ps) = foldl_map (fn (gr, (s, t)) => apsnd (pair s)
   700     val (gr', ps) = foldl_map (fn (gr, (s, t)) => apsnd (pair s)
   573       (invoke_codegen thy "<Top>" false (gr, t)))
   701       (invoke_codegen thy defs "<Top>" "Generated" false (gr, t)))
   574         (gr, map (apsnd (prep_term thy)) xs)
   702         (gr, map (apsnd (expand o prep_term thy)) xs);
   575     val code =
   703     val code =
   576       "structure Generated =\nstruct\n\n" ^
       
   577       output_code gr' ["<Top>"] ^
       
   578       space_implode "\n\n" (map (fn (s', p) => Pretty.string_of (Pretty.block
   704       space_implode "\n\n" (map (fn (s', p) => Pretty.string_of (Pretty.block
   579         [Pretty.str ("val " ^ s' ^ " ="), Pretty.brk 1, p, Pretty.str ";"])) ps) ^
   705         [Pretty.str ("val " ^ s' ^ " ="), Pretty.brk 1, p, Pretty.str ";"])) ps) ^
   580       "\n\nend;\n\nopen Generated;\n";
   706       "\n\n"
   581   in code end));
   707   in
       
   708     map (fn (name, s) => (name, mk_struct name s))
       
   709       (add_to_module "Generated" code (output_code gr' ["<Top>"]))
       
   710   end));
   582 
   711 
   583 val generate_code_i = gen_generate_code (K I);
   712 val generate_code_i = gen_generate_code (K I);
   584 val generate_code = gen_generate_code
   713 val generate_code = gen_generate_code
   585   (fn thy => term_of o read_cterm thy o rpair TypeInfer.logicT);
   714   (fn thy => term_of o read_cterm thy o rpair TypeInfer.logicT);
   586 
   715 
   598   | mk_type p (TFree (s, _)) = Pretty.str (strip_tname s ^ "T")
   727   | mk_type p (TFree (s, _)) = Pretty.str (strip_tname s ^ "T")
   599   | mk_type p (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   728   | mk_type p (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   600       [Pretty.str "Type", Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\","),
   729       [Pretty.str "Type", Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\","),
   601        Pretty.brk 1, pretty_list (map (mk_type false) Ts), Pretty.str ")"]);
   730        Pretty.brk 1, pretty_list (map (mk_type false) Ts), Pretty.str ")"]);
   602 
   731 
   603 fun mk_term_of _ p (TVar ((s, i), _)) = Pretty.str
   732 fun mk_term_of thy thyname p (TVar ((s, i), _)) = Pretty.str
   604       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "F")
   733       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "F")
   605   | mk_term_of _ p (TFree (s, _)) = Pretty.str (strip_tname s ^ "F")
   734   | mk_term_of thy thyname p (TFree (s, _)) = Pretty.str (strip_tname s ^ "F")
   606   | mk_term_of thy p (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   735   | mk_term_of thy thyname p (Type (s, Ts)) = (if p then parens else I)
   607       (separate (Pretty.brk 1) (Pretty.str ("term_of_" ^ mk_type_id thy s) ::
   736       (Pretty.block (separate (Pretty.brk 1)
   608         List.concat (map (fn T => [mk_term_of thy true T, mk_type true T]) Ts))));
   737         (Pretty.str (mk_type_id' (fn s' => "term_of_" ^ s')
       
   738           thy thyname (thyname_of_type s thy) s) ::
       
   739         List.concat (map (fn T =>
       
   740           [mk_term_of thy thyname true T, mk_type true T]) Ts))));
   609 
   741 
   610 
   742 
   611 (**** Test data generators ****)
   743 (**** Test data generators ****)
   612 
   744 
   613 fun mk_gen _ p xs a (TVar ((s, i), _)) = Pretty.str
   745 fun mk_gen thy thyname p xs a (TVar ((s, i), _)) = Pretty.str
   614       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "G")
   746       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "G")
   615   | mk_gen _ p xs a (TFree (s, _)) = Pretty.str (strip_tname s ^ "G")
   747   | mk_gen thy thyname p xs a (TFree (s, _)) = Pretty.str (strip_tname s ^ "G")
   616   | mk_gen thy p xs a (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   748   | mk_gen thy thyname p xs a (Type (s, Ts)) = (if p then parens else I)
   617       (separate (Pretty.brk 1) (Pretty.str ("gen_" ^ mk_type_id thy s ^
   749       (Pretty.block (separate (Pretty.brk 1)
   618         (if s mem xs then "'" else "")) :: map (mk_gen thy true xs a) Ts @
   750         (Pretty.str (mk_type_id' (fn s' => "gen_" ^ s')
   619         (if s mem xs then [Pretty.str a] else []))));
   751           thy thyname (thyname_of_type s thy) s ^
       
   752           (if s mem xs then "'" else "")) ::
       
   753          map (mk_gen thy thyname true xs a) Ts @
       
   754          (if s mem xs then [Pretty.str a] else []))));
   620 
   755 
   621 val test_fn : (int -> (string * term) list option) ref = ref (fn _ => NONE);
   756 val test_fn : (int -> (string * term) list option) ref = ref (fn _ => NONE);
   622 
   757 
   623 fun test_term thy sz i = setmp print_mode [] (fn t =>
   758 fun test_term thy sz i = setmp print_mode [] (fn t =>
   624   let
   759   let
   626       "Term to be tested contains type variables";
   761       "Term to be tested contains type variables";
   627     val _ = assert (null (term_vars t))
   762     val _ = assert (null (term_vars t))
   628       "Term to be tested contains schematic variables";
   763       "Term to be tested contains schematic variables";
   629     val frees = map dest_Free (term_frees t);
   764     val frees = map dest_Free (term_frees t);
   630     val szname = variant (map fst frees) "i";
   765     val szname = variant (map fst frees) "i";
   631     val s = "structure TestTerm =\nstruct\n\n" ^
   766     val code = space_implode "\n" (map snd
   632       setmp mode ["term_of", "test"] (generate_code_i thy)
   767       (setmp mode ["term_of", "test"] (generate_code_i thy)
   633         [("testf", list_abs_free (frees, t))] ^
   768         [("testf", list_abs_free (frees, t))]));
   634       "\n" ^ Pretty.string_of
   769     val s = "structure TestTerm =\nstruct\n\n" ^ code ^
       
   770       "\nopen Generated;\n\n" ^ Pretty.string_of
   635         (Pretty.block [Pretty.str "val () = Codegen.test_fn :=",
   771         (Pretty.block [Pretty.str "val () = Codegen.test_fn :=",
   636           Pretty.brk 1, Pretty.str ("(fn " ^ szname ^ " =>"), Pretty.brk 1,
   772           Pretty.brk 1, Pretty.str ("(fn " ^ szname ^ " =>"), Pretty.brk 1,
   637           Pretty.blk (0, [Pretty.str "let", Pretty.brk 1,
   773           Pretty.blk (0, [Pretty.str "let", Pretty.brk 1,
   638             Pretty.blk (0, separate Pretty.fbrk (map (fn (s, T) =>
   774             Pretty.blk (0, separate Pretty.fbrk (map (fn (s, T) =>
   639               Pretty.block [Pretty.str ("val " ^ mk_id s ^ " ="), Pretty.brk 1,
   775               Pretty.block [Pretty.str ("val " ^ mk_id s ^ " ="), Pretty.brk 1,
   640               mk_gen thy false [] "" T, Pretty.brk 1,
   776               mk_gen thy "" false [] "" T, Pretty.brk 1,
   641               Pretty.str (szname ^ ";")]) frees)),
   777               Pretty.str (szname ^ ";")]) frees)),
   642             Pretty.brk 1, Pretty.str "in", Pretty.brk 1,
   778             Pretty.brk 1, Pretty.str "in", Pretty.brk 1,
   643             Pretty.block [Pretty.str "if ",
   779             Pretty.block [Pretty.str "if ",
   644               mk_app false (Pretty.str "testf") (map (Pretty.str o mk_id o fst) frees),
   780               mk_app false (Pretty.str "testf") (map (Pretty.str o mk_id o fst) frees),
   645               Pretty.brk 1, Pretty.str "then NONE",
   781               Pretty.brk 1, Pretty.str "then NONE",
   646               Pretty.brk 1, Pretty.str "else ",
   782               Pretty.brk 1, Pretty.str "else ",
   647               Pretty.block [Pretty.str "SOME ", Pretty.block (Pretty.str "[" ::
   783               Pretty.block [Pretty.str "SOME ", Pretty.block (Pretty.str "[" ::
   648                 List.concat (separate [Pretty.str ",", Pretty.brk 1]
   784                 List.concat (separate [Pretty.str ",", Pretty.brk 1]
   649                   (map (fn (s, T) => [Pretty.block
   785                   (map (fn (s, T) => [Pretty.block
   650                     [Pretty.str ("(" ^ Library.quote (Symbol.escape s) ^ ","), Pretty.brk 1,
   786                     [Pretty.str ("(" ^ Library.quote (Symbol.escape s) ^ ","), Pretty.brk 1,
   651                      mk_app false (mk_term_of thy false T)
   787                      mk_app false (mk_term_of thy "" false T)
   652                        [Pretty.str (mk_id s)], Pretty.str ")"]]) frees)) @
   788                        [Pretty.str (mk_id s)], Pretty.str ")"]]) frees)) @
   653                   [Pretty.str "]"])]],
   789                   [Pretty.str "]"])]],
   654             Pretty.brk 1, Pretty.str "end"]), Pretty.str ");"]) ^
   790             Pretty.brk 1, Pretty.str "end"]), Pretty.str ");"]) ^
   655       "\n\nend;\n";
   791       "\n\nend;\n";
   656     val _ = use_text Context.ml_output false s;
   792     val _ = use_text Context.ml_output false s;
   714   OuterSyntax.command "types_code"
   850   OuterSyntax.command "types_code"
   715   "associate types with target language types" K.thy_decl
   851   "associate types with target language types" K.thy_decl
   716     (Scan.repeat1 (P.xname --| P.$$$ "(" -- P.string --| P.$$$ ")") >>
   852     (Scan.repeat1 (P.xname --| P.$$$ "(" -- P.string --| P.$$$ ")") >>
   717      (fn xs => Toplevel.theory (fn thy => assoc_types
   853      (fn xs => Toplevel.theory (fn thy => assoc_types
   718        (map (fn (name, mfx) => (name, parse_mixfix
   854        (map (fn (name, mfx) => (name, parse_mixfix
   719          (typ_of o read_ctyp (sign_of thy)) mfx)) xs) thy)));
   855          (typ_of o read_ctyp thy) mfx)) xs) thy)));
   720 
   856 
   721 val assoc_constP =
   857 val assoc_constP =
   722   OuterSyntax.command "consts_code"
   858   OuterSyntax.command "consts_code"
   723   "associate constants with target language code" K.thy_decl
   859   "associate constants with target language code" K.thy_decl
   724     (Scan.repeat1
   860     (Scan.repeat1
   725        (P.xname -- (Scan.option (P.$$$ "::" |-- P.typ)) --|
   861        (P.xname -- (Scan.option (P.$$$ "::" |-- P.typ)) --|
   726         P.$$$ "(" -- P.string --| P.$$$ ")") >>
   862         P.$$$ "(" -- P.string --| P.$$$ ")") >>
   727      (fn xs => Toplevel.theory (fn thy => assoc_consts
   863      (fn xs => Toplevel.theory (fn thy => assoc_consts
   728        (map (fn ((name, optype), mfx) => (name, optype, parse_mixfix
   864        (map (fn ((name, optype), mfx) => (name, optype, parse_mixfix
   729          (term_of o read_cterm (sign_of thy) o rpair TypeInfer.logicT) mfx))
   865          (term_of o read_cterm thy o rpair TypeInfer.logicT) mfx))
   730            xs) thy)));
   866            xs) thy)));
   731 
   867 
   732 val generate_codeP =
   868 val generate_codeP =
   733   OuterSyntax.command "generate_code" "generates code for terms" K.thy_decl
   869   OuterSyntax.command "generate_code" "generates code for terms" K.thy_decl
   734     (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") --
   870     (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") --
   735      Scan.optional (P.$$$ "[" |-- P.enum "," P.xname --| P.$$$ "]") (!mode) --
   871      Scan.optional (P.$$$ "[" |-- P.enum "," P.xname --| P.$$$ "]") (!mode) --
   736      Scan.repeat1 (P.name --| P.$$$ "=" -- P.term) >>
   872      Scan.repeat1 (P.name --| P.$$$ "=" -- P.term) >>
   737      (fn ((opt_fname, mode'), xs) => Toplevel.theory (fn thy =>
   873      (fn ((opt_fname, mode'), xs) => Toplevel.theory (fn thy =>
   738         ((case opt_fname of
   874        let val code = setmp mode mode' (generate_code thy) xs
   739             NONE => use_text Context.ml_output false
   875        in ((case opt_fname of
   740           | SOME fname => File.write (Path.unpack fname))
   876            NONE => use_text Context.ml_output false
   741               (setmp mode mode' (generate_code thy) xs); thy))));
   877              (space_implode "\n" (map snd code) ^ "\nopen Generated;\n")
       
   878          | SOME fname =>
       
   879              if "modular" mem mode' then
       
   880                app (fn (name, s) => File.write
       
   881                    (Path.append (Path.unpack fname) (Path.basic (name ^ ".ML"))) s)
       
   882                  (("ROOT", implode (map (fn (name, _) =>
       
   883                      "use \"" ^ name ^ ".ML\";\n") code)) :: code)
       
   884              else File.write (Path.unpack fname) (snd (hd code))); thy)
       
   885        end)));
   742 
   886 
   743 val params =
   887 val params =
   744   [("size", P.nat >> (K o set_size)),
   888   [("size", P.nat >> (K o set_size)),
   745    ("iterations", P.nat >> (K o set_iterations)),
   889    ("iterations", P.nat >> (K o set_iterations)),
   746    ("default_type", P.typ >> set_default_type)];
   890    ("default_type", P.typ >> set_default_type)];
   757 
   901 
   758 val test_paramsP =
   902 val test_paramsP =
   759   OuterSyntax.command "quickcheck_params" "set parameters for random testing" K.thy_decl
   903   OuterSyntax.command "quickcheck_params" "set parameters for random testing" K.thy_decl
   760     (P.$$$ "[" |-- P.list1 parse_test_params --| P.$$$ "]" >>
   904     (P.$$$ "[" |-- P.list1 parse_test_params --| P.$$$ "]" >>
   761       (fn fs => Toplevel.theory (fn thy =>
   905       (fn fs => Toplevel.theory (fn thy =>
   762          map_test_params (app (map (fn f => f (sign_of thy)) fs)) thy)));
   906          map_test_params (app (map (fn f => f thy) fs)) thy)));
   763 
   907 
   764 val testP =
   908 val testP =
   765   OuterSyntax.command "quickcheck" "try to find counterexample for subgoal" K.diag
   909   OuterSyntax.command "quickcheck" "try to find counterexample for subgoal" K.diag
   766   (Scan.option (P.$$$ "[" |-- P.list1
   910   (Scan.option (P.$$$ "[" |-- P.list1
   767     (   parse_test_params >> (fn f => fn thy => apfst (f thy))
   911     (   parse_test_params >> (fn f => fn thy => apfst (f thy))