src/Pure/codegen.ML
author obua
Sun May 29 12:39:12 2005 +0200 (2005-05-29)
changeset 16108 cf468b93a02e
parent 15801 d2f5ca3c048d
child 16122 864fda4a4056
permissions -rw-r--r--
Implement cycle-free overloading, so that definitions cannot harm consistency any more (except of course via interaction with axioms).
     1 (*  Title:      Pure/codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Generic code generator.
     6 *)
     7 
(* Interface of the generic code generator. *)
signature CODEGEN =
sig
  (* If true (the default), "message" prints nothing. *)
  val quiet_mode : bool ref
  (* Print an informational message unless quiet_mode is set. *)
  val message : string -> unit
  (* Names of the active code generation modes; empty by default and set
     to ["term_of", "test"] by test_term while it generates test code. *)
  val mode : string list ref
  (* Right margin used when pretty-printing generated code (default 80). *)
  val margin : int ref

  (* Target-language rendering of a constant or type constructor:
     Arg inserts the code of the next argument, Ignore consumes an argument
     without printing it, Pretty is literal output, and Quote embeds the
     code generated for a term or type. *)
  datatype 'a mixfix =
      Arg
    | Ignore
    | Pretty of Pretty.T
    | Quote of 'a;

  (* A code generator for objects of type 'a (terms or types); it returns
     NONE if it is not applicable. *)
  type 'a codegen

  (* Registration of term/type code generators, "code" attribute argument
     parsers and theorem preprocessors.  Names of code generators and of
     attributes must be unique. *)
  val add_codegen: string -> term codegen -> theory -> theory
  val add_tycodegen: string -> typ codegen -> theory -> theory
  val add_attribute: string -> (Args.T list -> theory attribute * Args.T list) -> theory -> theory
  val add_preprocessor: (theory -> thm list -> thm list) -> theory -> theory
  (* Run all registered preprocessors over a list of theorems. *)
  val preprocess: theory -> thm list -> thm list
  val print_codegens: theory -> unit
  (* ML structure text containing code for the given named terms, which
     are supplied as strings (generate_code) or as terms (generate_code_i). *)
  val generate_code: theory -> (string * string) list -> string
  val generate_code_i: theory -> (string * term) list -> string
  (* Associate constants (optionally at an instance of their type) and
     type constructors with target-language code given as mixfix. *)
  val assoc_consts: (xstring * string option * term mixfix list) list -> theory -> theory
  val assoc_consts_i: (xstring * typ option * term mixfix list) list -> theory -> theory
  val assoc_types: (xstring * typ mixfix list) list -> theory -> theory
  val get_assoc_code: theory -> string -> typ -> term mixfix list option
  val get_assoc_type: theory -> string -> typ mixfix list option
  (* Run the first applicable term/type code generator.  Arguments: theory,
     name of the graph node requiring the code, bracketing flag, and the
     dependency graph paired with the object to translate. *)
  val invoke_codegen: theory -> string -> bool ->
    (exn option * string) Graph.T * term -> (exn option * string) Graph.T * Pretty.T
  val invoke_tycodegen: theory -> string -> bool ->
    (exn option * string) Graph.T * typ -> (exn option * string) Graph.T * Pretty.T
  (* Valid ML identifiers for arbitrary names, constants and types. *)
  val mk_id: string -> string
  val mk_const_id: Sign.sg -> string -> string
  val mk_type_id: Sign.sg -> string -> string
  (* Rename variables whose names are not usable in ML; fresh names. *)
  val rename_term: term -> term
  val new_names: term -> string list -> string list
  val new_name: term -> string -> string
  (* Definition (arguments, right-hand side) of a constant at a type,
     together with an index if the constant has several definitions. *)
  val get_defn: theory -> string -> typ -> ((term list * term) * int option) option
  val is_instance: theory -> typ -> typ -> bool
  (* Pretty-printing helpers for generated code. *)
  val parens: Pretty.T -> Pretty.T
  val mk_app: bool -> Pretty.T -> Pretty.T list -> Pretty.T
  val eta_expand: term -> term list -> int -> term
  (* Reflection: ML code for types, term_of functions and random value
     generators. *)
  val strip_tname: string -> string
  val mk_type: bool -> typ -> Pretty.T
  val mk_term_of: Sign.sg -> bool -> typ -> Pretty.T
  val mk_gen: Sign.sg -> bool -> string list -> string -> typ -> Pretty.T
  (* Random testing: test_fn is set by the generated test code. *)
  val test_fn: (int -> (string * term) list option) ref
  val test_term: theory -> int -> int -> term -> (string * term) list option
  (* Parse a mixfix annotation, reading quotations with the given function. *)
  val parse_mixfix: (string -> 'a) -> string -> 'a mixfix list
end;
    59 
    60 structure Codegen : CODEGEN =
    61 struct
    62 
    63 val quiet_mode = ref true;
    64 fun message s = if !quiet_mode then () else writeln s;
    65 
    66 val mode = ref ([] : string list);
    67 
    68 val margin = ref 80;
    69 
(**** Mixfix syntax ****)

(* Annotation describing how a constant (or type constructor) is rendered
   in the target language:
     Arg      - insert the code of the next argument here
     Ignore   - consume the next argument without printing anything
     Pretty p - emit the literal pretty-printed text p
     Quote q  - emit the code generated for the embedded term/type q *)
datatype 'a mixfix =
    Arg
  | Ignore
  | Pretty of Pretty.T
  | Quote of 'a;
    77 
    78 fun is_arg Arg = true
    79   | is_arg Ignore = true
    80   | is_arg _ = false;
    81 
    82 fun quotes_of [] = []
    83   | quotes_of (Quote q :: ms) = q :: quotes_of ms
    84   | quotes_of (_ :: ms) = quotes_of ms;
    85 
    86 fun args_of [] xs = ([], xs)
    87   | args_of (Arg :: ms) (x :: xs) = apfst (cons x) (args_of ms xs)
    88   | args_of (Ignore :: ms) (_ :: xs) = args_of ms xs
    89   | args_of (_ :: ms) xs = args_of ms xs;
    90 
    91 fun num_args x = length (List.filter is_arg x);
    92 
    93 
    94 (**** theory data ****)
    95 
(* type of code generators *)

(* A code generator is applied to the theory, the dependency graph built
   so far (each node carries an exn option, always NONE in this file, and
   the generated code text), the name of the node that requires the code,
   a flag telling whether the result must be parenthesised, and the object
   to translate.  It returns NONE if it does not apply, otherwise the
   extended graph and the pretty-printed code. *)
type 'a codegen = theory -> (exn option * string) Graph.T ->
  string -> bool -> 'a -> ((exn option * string) Graph.T * Pretty.T) option;

(* parameters for random testing *)

(* size: largest test data size tried; iterations: number of tests per
   size; default_type: instantiation for type variables not fixed
   explicitly when testing a goal. *)
type test_params =
  {size: int, iterations: int, default_type: typ option};
   105 
   106 fun merge_test_params
   107   {size = size1, iterations = iterations1, default_type = default_type1}
   108   {size = size2, iterations = iterations2, default_type = default_type2} =
   109   {size = Int.max (size1, size2),
   110    iterations = Int.max (iterations1, iterations2),
   111    default_type = case default_type1 of
   112        NONE => default_type2
   113      | _ => default_type1};
   114 
   115 val default_test_params : test_params =
   116   {size = 10, iterations = 100, default_type = NONE};
   117 
   118 fun set_size size ({iterations, default_type, ...} : test_params) =
   119   {size = size, iterations = iterations, default_type = default_type};
   120 
   121 fun set_iterations iterations ({size, default_type, ...} : test_params) =
   122   {size = size, iterations = iterations, default_type = default_type};
   123 
   124 fun set_default_type s sg ({size, iterations, ...} : test_params) =
   125   {size = size, iterations = iterations,
   126    default_type = SOME (typ_of (read_ctyp sg s))};
   127 
   128 (* data kind 'Pure/codegen' *)
   129  
   130 structure CodegenArgs =
   131 struct
   132   val name = "Pure/codegen";
   133   type T =
   134     {codegens : (string * term codegen) list,
   135      tycodegens : (string * typ codegen) list,
   136      consts : ((string * typ) * term mixfix list) list,
   137      types : (string * typ mixfix list) list,
   138      attrs: (string * (Args.T list -> theory attribute * Args.T list)) list,
   139      preprocs: (stamp * (theory -> thm list -> thm list)) list,
   140      test_params: test_params};
   141 
   142   val empty =
   143     {codegens = [], tycodegens = [], consts = [], types = [], attrs = [],
   144      preprocs = [], test_params = default_test_params};
   145   val copy = I;
   146   val prep_ext = I;
   147 
   148   fun merge
   149     ({codegens = codegens1, tycodegens = tycodegens1,
   150       consts = consts1, types = types1, attrs = attrs1,
   151       preprocs = preprocs1, test_params = test_params1},
   152      {codegens = codegens2, tycodegens = tycodegens2,
   153       consts = consts2, types = types2, attrs = attrs2,
   154       preprocs = preprocs2, test_params = test_params2}) =
   155     {codegens = merge_alists' codegens1 codegens2,
   156      tycodegens = merge_alists' tycodegens1 tycodegens2,
   157      consts = merge_alists consts1 consts2,
   158      types = merge_alists types1 types2,
   159      attrs = merge_alists attrs1 attrs2,
   160      preprocs = merge_alists' preprocs1 preprocs2,
   161      test_params = merge_test_params test_params1 test_params2};
   162 
   163   fun print sg ({codegens, tycodegens, ...} : T) =
   164     Pretty.writeln (Pretty.chunks
   165       [Pretty.strs ("term code generators:" :: map fst codegens),
   166        Pretty.strs ("type code generators:" :: map fst tycodegens)]);
   167 end;
   168 
(* Theory data slot holding the code generator setup described by
   CodegenArgs. *)
structure CodegenData = TheoryDataFun(CodegenArgs);
(* register CodegenData.init as a theory setup function *)
val _ = Context.add_setup [CodegenData.init];
(* print the registered code generators of a theory (presumably via
   CodegenArgs.print) *)
val print_codegens = CodegenData.print;
   172 
   173 
   174 (**** access parameters for random testing ****)
   175 
   176 fun get_test_params thy = #test_params (CodegenData.get thy);
   177 
   178 fun map_test_params f thy =
   179   let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   180     CodegenData.get thy;
   181   in CodegenData.put {codegens = codegens, tycodegens = tycodegens,
   182     consts = consts, types = types, attrs = attrs, preprocs = preprocs,
   183     test_params = f test_params} thy
   184   end;
   185 
   186 
   187 (**** add new code generators to theory ****)
   188 
   189 fun add_codegen name f thy =
   190   let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   191     CodegenData.get thy
   192   in (case assoc (codegens, name) of
   193       NONE => CodegenData.put {codegens = (name, f) :: codegens,
   194         tycodegens = tycodegens, consts = consts, types = types,
   195         attrs = attrs, preprocs = preprocs, test_params = test_params} thy
   196     | SOME _ => error ("Code generator " ^ name ^ " already declared"))
   197   end;
   198 
   199 fun add_tycodegen name f thy =
   200   let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   201     CodegenData.get thy
   202   in (case assoc (tycodegens, name) of
   203       NONE => CodegenData.put {tycodegens = (name, f) :: tycodegens,
   204         codegens = codegens, consts = consts, types = types,
   205         attrs = attrs, preprocs = preprocs, test_params = test_params} thy
   206     | SOME _ => error ("Code generator " ^ name ^ " already declared"))
   207   end;
   208 
   209 
   210 (**** code attribute ****)
   211 
   212 fun add_attribute name att thy =
   213   let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   214     CodegenData.get thy
   215   in (case assoc (attrs, name) of
   216       NONE => CodegenData.put {tycodegens = tycodegens,
   217         codegens = codegens, consts = consts, types = types,
   218         attrs = if name = "" then attrs @ [(name, att)] else (name, att) :: attrs,
   219         preprocs = preprocs,
   220         test_params = test_params} thy
   221     | SOME _ => error ("Code attribute " ^ name ^ " already declared"))
   222   end;
   223 
   224 fun mk_parser (a, p) = (if a = "" then Scan.succeed "" else Args.$$$ a) |-- p;
   225 
   226 val code_attr =
   227   Attrib.syntax (Scan.peek (fn thy => foldr op || Scan.fail (map mk_parser
   228     (#attrs (CodegenData.get thy)))));
   229 
   230 val _ = Context.add_setup
   231  [Attrib.add_attributes
   232   [("code", (code_attr, K Attrib.undef_local_attribute),
   233      "declare theorems for code generation")]];
   234 
   235 
   236 (**** preprocessors ****)
   237 
   238 fun add_preprocessor p thy =
   239   let val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   240     CodegenData.get thy
   241   in CodegenData.put {tycodegens = tycodegens,
   242     codegens = codegens, consts = consts, types = types,
   243     attrs = attrs, preprocs = (stamp (), p) :: preprocs,
   244     test_params = test_params} thy
   245   end;
   246 
   247 fun preprocess thy ths =
   248   let val {preprocs, ...} = CodegenData.get thy
   249   in Library.foldl (fn (ths, (_, f)) => f thy ths) (ths, preprocs) end;
   250 
   251 fun unfold_attr (thy, eqn) =
   252   let
   253     val (name, _) = dest_Const (head_of
   254       (fst (Logic.dest_equals (prop_of eqn))));
   255     fun prep thy = map (fn th =>
   256       if name mem term_consts (prop_of th) then
   257         rewrite_rule [eqn] (Thm.transfer thy th)
   258       else th)
   259   in (add_preprocessor prep thy, eqn) end;
   260 
   261 val _ = Context.add_setup [add_attribute "unfold" (Scan.succeed unfold_attr)];
   262 
   263 
   264 (**** associate constants with target language code ****)
   265 
   266 fun gen_assoc_consts prep_type xs thy = Library.foldl (fn (thy, (s, tyopt, syn)) =>
   267   let
   268     val sg = sign_of thy;
   269     val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   270       CodegenData.get thy;
   271     val cname = Sign.intern_const sg s;
   272   in
   273     (case Sign.const_type sg cname of
   274        SOME T =>
   275          let val T' = (case tyopt of
   276                 NONE => T
   277               | SOME ty =>
   278                   let val U = prep_type sg ty
   279                   in if Sign.typ_instance sg (U, T) then U
   280                     else error ("Illegal type constraint for constant " ^ cname)
   281                   end)
   282          in (case assoc (consts, (cname, T')) of
   283              NONE => CodegenData.put {codegens = codegens,
   284                tycodegens = tycodegens,
   285                consts = ((cname, T'), syn) :: consts,
   286                types = types, attrs = attrs, preprocs = preprocs,
   287                test_params = test_params} thy
   288            | SOME _ => error ("Constant " ^ cname ^ " already associated with code"))
   289          end
   290      | _ => error ("Not a constant: " ^ s))
   291   end) (thy, xs);
   292 
   293 val assoc_consts_i = gen_assoc_consts (K I);
   294 val assoc_consts = gen_assoc_consts (fn sg => typ_of o read_ctyp sg);
   295 
   296 
   297 (**** associate types with target language types ****)
   298 
   299 fun assoc_types xs thy = Library.foldl (fn (thy, (s, syn)) =>
   300   let
   301     val {codegens, tycodegens, consts, types, attrs, preprocs, test_params} =
   302       CodegenData.get thy;
   303     val tc = Sign.intern_tycon (sign_of thy) s
   304   in
   305     (case assoc (types, tc) of
   306        NONE => CodegenData.put {codegens = codegens,
   307          tycodegens = tycodegens, consts = consts,
   308          types = (tc, syn) :: types, attrs = attrs,
   309          preprocs = preprocs, test_params = test_params} thy
   310      | SOME _ => error ("Type " ^ tc ^ " already associated with code"))
   311   end) (thy, xs);
   312 
   313 fun get_assoc_type thy s = assoc (#types (CodegenData.get thy), s);
   314 
   315 
   316 (**** make valid ML identifiers ****)
   317 
   318 fun is_ascii_letdig x = Symbol.is_ascii_letter x orelse
   319   Symbol.is_ascii_digit x orelse Symbol.is_ascii_quasi x;
   320 
   321 fun dest_sym s = (case split_last (snd (take_prefix (equal "\\") (explode s))) of
   322     ("<" :: "^" :: xs, ">") => (true, implode xs)
   323   | ("<" :: xs, ">") => (false, implode xs)
   324   | _ => sys_error "dest_sym");
   325   
   326 fun mk_id s = if s = "" then "" else
   327   let
   328     fun check_str [] = []
   329       | check_str xs = (case take_prefix is_ascii_letdig xs of
   330           ([], " " :: zs) => check_str zs
   331         | ([], z :: zs) =>
   332           if size z = 1 then string_of_int (ord z) :: check_str zs
   333           else (case dest_sym z of
   334               (true, "isub") => check_str zs
   335             | (true, "isup") => "" :: check_str zs
   336             | (ctrl, s') => (if ctrl then "ctrl_" ^ s' else s') :: check_str zs)
   337         | (ys, zs) => implode ys :: check_str zs);
   338     val s' = space_implode "_"
   339       (List.concat (map (check_str o Symbol.explode) (NameSpace.unpack s)))
   340   in
   341     if Symbol.is_ascii_letter (hd (explode s')) then s' else "id_" ^ s'
   342   end;
   343 
   344 fun mk_const_id sg s =
   345   let val s' = mk_id (Sign.cond_extern sg Sign.constK s)
   346   in if s' mem ThmDatabase.ml_reserved then s' ^ "_const" else s' end;
   347 
   348 fun mk_type_id sg s =
   349   let val s' = mk_id (Sign.cond_extern sg Sign.typeK s)
   350   in if s' mem ThmDatabase.ml_reserved then s' ^ "_type" else s' end;
   351 
   352 fun rename_terms ts =
   353   let
   354     val names = foldr add_term_names
   355       (map (fst o fst) (Drule.vars_of_terms ts)) ts;
   356     val reserved = names inter ThmDatabase.ml_reserved;
   357     val (illegal, alt_names) = split_list (List.mapPartial (fn s =>
   358       let val s' = mk_id s in if s = s' then NONE else SOME (s, s') end) names)
   359     val ps = (reserved @ illegal) ~~
   360       variantlist (map (suffix "'") reserved @ alt_names, names);
   361 
   362     fun rename_id s = getOpt (assoc (ps, s), s);
   363 
   364     fun rename (Var ((a, i), T)) = Var ((rename_id a, i), T)
   365       | rename (Free (a, T)) = Free (rename_id a, T)
   366       | rename (Abs (s, T, t)) = Abs (s, T, rename t)
   367       | rename (t $ u) = rename t $ rename u
   368       | rename t = t;
   369   in
   370     map rename ts
   371   end;
   372 
   373 val rename_term = hd o rename_terms o single;
   374 
   375 
   376 (**** retrieve definition of constant ****)
   377 
   378 fun is_instance thy T1 T2 =
   379   Sign.typ_instance (sign_of thy) (T1, Type.varifyT T2);
   380 
   381 fun get_assoc_code thy s T = Option.map snd (find_first (fn ((s', T'), _) =>
   382   s = s' andalso is_instance thy T T') (#consts (CodegenData.get thy)));
   383 
(* Find a definition of constant s usable at type T: the first axiom of thy
   or one of its ancestors of the form  s args == rhs  whose type for s has
   T as an instance.  That axiom is run through the preprocessors (which
   must yield exactly one theorem); if the preprocessed equation no longer
   has the same constant type and arguments, NONE is returned.  On success
   the result is ((args, rhs), k): args and rhs are renamed by
   rename_terms, and k is NONE if s has exactly one definition, otherwise
   SOME i with i the position of the chosen definition among them. *)
fun get_defn thy s T =
  let
    (* all named axioms of thy and its ancestors *)
    val axms = List.concat (map (Symtab.dest o #axioms o Theory.rep_theory)
      (thy :: Theory.ancestors_of thy));
    fun prep_def def = (case preprocess thy [def] of
      [def'] => prop_of def' | _ => error "get_defn: bad preprocessor");
    (* decompose  c args == rhs  into (type of c, (args, rhs)) if c is the
       constant s; anything else, including ill-formed equations, gives
       NONE *)
    fun dest t =
      let
        val (lhs, rhs) = Logic.dest_equals t;
        val (c, args) = strip_comb lhs;
        val (s', T') = dest_Const c
      in if s = s' then SOME (T', (args, rhs)) else NONE
      end handle TERM _ => NONE;
    (* candidate definitions of s, paired with their axiom names *)
    val defs = List.mapPartial (fn (name, t) => Option.map (pair name) (dest t)) axms;
    val i = find_index (is_instance thy T o fst o snd) defs
  in
    if i >= 0 then
      let val (name, (T', (args, _))) = List.nth (defs, i)
      (* preprocess the axiom as a theorem and accept the result only if
         its type and arguments are unchanged (the binding p is unused) *)
      in case dest (prep_def (Thm.get_axiom thy name)) of
          NONE => NONE
        | SOME (T'', p as (args', rhs)) =>
            if T' = T'' andalso args = args' then
              SOME (split_last (rename_terms (args @ [rhs])),
                if length defs = 1 then NONE else SOME i)
            else NONE
      end
    else NONE
  end;
   412 
   413 
   414 (**** invoke suitable code generator for term / type ****)
   415 
   416 fun invoke_codegen thy dep brack (gr, t) = (case get_first
   417    (fn (_, f) => f thy gr dep brack t) (#codegens (CodegenData.get thy)) of
   418       NONE => error ("Unable to generate code for term:\n" ^
   419         Sign.string_of_term (sign_of thy) t ^ "\nrequired by:\n" ^
   420         commas (Graph.all_succs gr [dep]))
   421     | SOME x => x);
   422 
   423 fun invoke_tycodegen thy dep brack (gr, T) = (case get_first
   424    (fn (_, f) => f thy gr dep brack T) (#tycodegens (CodegenData.get thy)) of
   425       NONE => error ("Unable to generate code for type:\n" ^
   426         Sign.string_of_typ (sign_of thy) T ^ "\nrequired by:\n" ^
   427         commas (Graph.all_succs gr [dep]))
   428     | SOME x => x);
   429 
   430 
   431 (**** code generator for mixfix expressions ****)
   432 
   433 fun parens p = Pretty.block [Pretty.str "(", p, Pretty.str ")"];
   434 
   435 fun pretty_fn [] p = [p]
   436   | pretty_fn (x::xs) p = Pretty.str ("fn " ^ x ^ " =>") ::
   437       Pretty.brk 1 :: pretty_fn xs p;
   438 
(* Fill in a mixfix annotation: each Arg takes the next element of ps,
   Ignore produces nothing (its argument was already removed by args_of),
   Pretty items are copied, and each Quote takes the next element of qs
   (the code generated for the quotations).  The clauses are deliberately
   not exhaustive: ps must contain exactly one element per Arg item and qs
   at least one per Quote item, otherwise Match is raised. *)
fun pretty_mixfix [] [] _ = []
  | pretty_mixfix (Arg :: ms) (p :: ps) qs = p :: pretty_mixfix ms ps qs
  | pretty_mixfix (Ignore :: ms) ps qs = pretty_mixfix ms ps qs
  | pretty_mixfix (Pretty p :: ms) ps qs = p :: pretty_mixfix ms ps qs
  | pretty_mixfix (Quote _ :: ms) ps (q :: qs) = q :: pretty_mixfix ms ps qs;
   444 
   445 
   446 (**** default code generators ****)
   447 
   448 fun eta_expand t ts i =
   449   let
   450     val (Ts, _) = strip_type (fastype_of t);
   451     val j = i - length ts
   452   in
   453     foldr (fn (T, t) => Abs ("x", T, t))
   454       (list_comb (t, ts @ map Bound (j-1 downto 0))) (Library.take (j, Ts))
   455   end;
   456 
   457 fun mk_app _ p [] = p
   458   | mk_app brack p ps = if brack then
   459        Pretty.block (Pretty.str "(" ::
   460          separate (Pretty.brk 1) (p :: ps) @ [Pretty.str ")"])
   461      else Pretty.block (separate (Pretty.brk 1) (p :: ps));
   462 
   463 fun new_names t xs = variantlist (map mk_id xs,
   464   map (fst o fst o dest_Var) (term_vars t) union
   465   add_term_names (t, ThmDatabase.ml_reserved));
   466 
   467 fun new_name t x = hd (new_names t [x]);
   468 
(* Default term code generator.  The term t is split into its head u and
   arguments ts, and u is translated according to its kind:
   - Var / Free: printed by name (a Var's non-zero index is appended) and
     applied to the translated arguments; code for the variable's type is
     generated, which may extend the graph, but it is not printed.
   - Const with associated code (get_assoc_code): if fewer arguments are
     present than the annotation consumes, the term is eta-expanded and
     translated again; otherwise the annotation is filled with the consumed
     arguments and the quoted terms and applied to the remaining arguments.
   - Const with a definition (get_defn): the definition is emitted as an ML
     "val" or "fun" declaration stored in graph node id (the constant's ML
     identifier, suffixed with "_def" and the index if get_defn returned
     SOME i), generated only if that node does not exist yet, and the
     constant is referenced by id.
   - Abs: printed as a parenthesised "fn x1 => ..." with fresh names.
   Other heads, and constants with neither code nor definition, give NONE. *)
fun default_codegen thy gr dep brack t =
  let
    val (u, ts) = strip_comb t;
    (* translate a list of terms, threading the dependency graph *)
    fun codegens brack = foldl_map (invoke_codegen thy dep brack)
  in (case u of
      Var ((s, i), T) =>
        let
          val (gr', ps) = codegens true (gr, ts);
          val (gr'', _) = invoke_tycodegen thy dep false (gr', T)
        in SOME (gr'', mk_app brack (Pretty.str (s ^
           (if i=0 then "" else string_of_int i))) ps)
        end

    | Free (s, T) =>
        let
          val (gr', ps) = codegens true (gr, ts);
          val (gr'', _) = invoke_tycodegen thy dep false (gr', T)
        in SOME (gr'', mk_app brack (Pretty.str s) ps) end

    | Const (s, T) =>
      (case get_assoc_code thy s T of
         SOME ms =>
           let val i = num_args ms
           (* too few arguments for the annotation: eta-expand first *)
           in if length ts < i then
               default_codegen thy gr dep brack (eta_expand u ts i)
             else
               let
                 (* ts1: arguments placed by the annotation (unbracketed);
                    ts2: remaining arguments, applied afterwards *)
                 val (ts1, ts2) = args_of ms ts;
                 val (gr1, ps1) = codegens false (gr, ts1);
                 val (gr2, ps2) = codegens true (gr1, ts2);
                 val (gr3, ps3) = codegens false (gr2, quotes_of ms);
               in
                 SOME (gr3, mk_app brack (Pretty.block (pretty_mixfix ms ps1 ps3)) ps2)
               end
           end
       | NONE => (case get_defn thy s T of
           NONE => NONE
         | SOME ((args, rhs), k) =>
             let
               val id = mk_const_id (sign_of thy) s ^ (case k of
                 NONE => "" | SOME i => "_def" ^ string_of_int i);
               val (gr', ps) = codegens true (gr, ts);
             in
               (* If node id already exists, only the dependency edge is
                  added.  Otherwise Graph.add_edge fails with UNDEF and the
                  handler generates the declaration for id. *)
               SOME (Graph.add_edge (id, dep) gr' handle Graph.UNDEF _ =>
                 let
                   val _ = message ("expanding definition of " ^ s);
                   val (Ts, _) = strip_type T;
                   (* a definition without arguments but of function type
                      gets one argument x, obtained by beta-applying rhs *)
                   val (args', rhs') =
                     if not (null args) orelse null Ts then (args, rhs) else
                       let val v = Free (new_name rhs "x", hd Ts)
                       in ([v], betapply (rhs, v)) end;
                   (* create node id with empty code before translating rhs;
                      presumably so that recursive occurrences of the
                      constant only add an edge *)
                   val (gr1, p) = invoke_codegen thy id false
                     (Graph.add_edge (id, dep)
                        (Graph.new_node (id, (NONE, "")) gr'), rhs');
                   (* NOTE(review): codegens uses dep, not id, as the
                      requiring node for the arguments *)
                   val (gr2, xs) = codegens false (gr1, args');
                   val (gr3, ty) = invoke_tycodegen thy id false (gr2, T);
                 (* store "val id : ty = p;" or "fun id xs = p;" in node id *)
                 in Graph.map_node id (K (NONE, Pretty.string_of (Pretty.block
                   (separate (Pretty.brk 1) (if null args' then
                       [Pretty.str ("val " ^ id ^ " :"), ty]
                     else Pretty.str ("fun " ^ id) :: xs) @
                    [Pretty.str " =", Pretty.brk 1, p, Pretty.str ";"])) ^ "\n\n")) gr3
                 end, mk_app brack (Pretty.str id) ps)
             end))

    | Abs _ =>
      let
        val (bs, Ts) = ListPair.unzip (strip_abs_vars u);
        val t = strip_abs_body u
        val bs' = new_names t bs;
        val (gr1, ps) = codegens true (gr, ts);
        (* replace the loose bound variables by Frees with the fresh names *)
        val (gr2, p) = invoke_codegen thy dep false
          (gr1, subst_bounds (map Free (rev (bs' ~~ Ts)), t));
      in
        SOME (gr2, mk_app brack (Pretty.block (Pretty.str "(" :: pretty_fn bs' p @
          [Pretty.str ")"])) ps)
      end

    | _ => NONE)
  end;
   548 
   549 fun default_tycodegen thy gr dep brack (TVar ((s, i), _)) =
   550       SOME (gr, Pretty.str (s ^ (if i = 0 then "" else string_of_int i)))
   551   | default_tycodegen thy gr dep brack (TFree (s, _)) = SOME (gr, Pretty.str s)
   552   | default_tycodegen thy gr dep brack (Type (s, Ts)) =
   553       (case assoc (#types (CodegenData.get thy), s) of
   554          NONE => NONE
   555        | SOME ms =>
   556            let
   557              val (gr', ps) = foldl_map
   558                (invoke_tycodegen thy dep false) (gr, fst (args_of ms Ts));
   559              val (gr'', qs) = foldl_map
   560                (invoke_tycodegen thy dep false) (gr', quotes_of ms)
   561            in SOME (gr'', Pretty.block (pretty_mixfix ms ps qs)) end);
   562 
   563 val _ = Context.add_setup
   564  [add_codegen "default" default_codegen,
   565   add_tycodegen "default" default_tycodegen];
   566 
   567 
(* Concatenate the code stored in the graph nodes returned by
   Graph.all_preds gr xs, in reverse of that order (presumably so that
   every declaration precedes its uses -- confirm against Graph). *)
fun output_code gr xs = implode (map (snd o Graph.get_node gr)
  (rev (Graph.all_preds gr xs)));

(* Generate the text of an ML structure "Generated" for the named terms xs
   (prepared by prep_term): first the auxiliary declarations collected in
   the dependency graph below the root node "<Top>", then one
   "val name = code;" per term, followed by "open Generated;".  Printing
   uses the empty print mode and the margin !margin. *)
fun gen_generate_code prep_term thy =
  setmp print_mode [] (Pretty.setmp_margin (!margin) (fn xs =>
  let
    val sg = sign_of thy;
    val gr = Graph.new_node ("<Top>", (NONE, "")) Graph.empty;
    (* translate all terms, each required by the root node *)
    val (gr', ps) = foldl_map (fn (gr, (s, t)) => apsnd (pair s)
      (invoke_codegen thy "<Top>" false (gr, t)))
        (gr, map (apsnd (prep_term sg)) xs)
    val code =
      "structure Generated =\nstruct\n\n" ^
      output_code gr' ["<Top>"] ^
      space_implode "\n\n" (map (fn (s', p) => Pretty.string_of (Pretty.block
        [Pretty.str ("val " ^ s' ^ " ="), Pretty.brk 1, p, Pretty.str ";"])) ps) ^
      "\n\nend;\n\nopen Generated;\n";
  in code end));

(* terms given directly, resp. as strings read with read_cterm against
   TypeInfer.logicT *)
val generate_code_i = gen_generate_code (K I);
val generate_code = gen_generate_code
  (fn sg => term_of o read_cterm sg o rpair TypeInfer.logicT);
   590 
   591 
   592 (**** Reflection ****)
   593 
   594 val strip_tname = implode o tl o explode;
   595 
   596 fun pretty_list xs = Pretty.block (Pretty.str "[" ::
   597   List.concat (separate [Pretty.str ",", Pretty.brk 1] (map single xs)) @
   598   [Pretty.str "]"]);
   599 
   600 fun mk_type p (TVar ((s, i), _)) = Pretty.str
   601       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "T")
   602   | mk_type p (TFree (s, _)) = Pretty.str (strip_tname s ^ "T")
   603   | mk_type p (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   604       [Pretty.str "Type", Pretty.brk 1, Pretty.str ("(\"" ^ s ^ "\","),
   605        Pretty.brk 1, pretty_list (map (mk_type false) Ts), Pretty.str ")"]);
   606 
   607 fun mk_term_of sg p (TVar ((s, i), _)) = Pretty.str
   608       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "F")
   609   | mk_term_of sg p (TFree (s, _)) = Pretty.str (strip_tname s ^ "F")
   610   | mk_term_of sg p (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   611       (separate (Pretty.brk 1) (Pretty.str ("term_of_" ^ mk_type_id sg s) ::
   612         List.concat (map (fn T => [mk_term_of sg true T, mk_type true T]) Ts))));
   613 
   614 
   615 (**** Test data generators ****)
   616 
   617 fun mk_gen sg p xs a (TVar ((s, i), _)) = Pretty.str
   618       (strip_tname s ^ (if i = 0 then "" else string_of_int i) ^ "G")
   619   | mk_gen sg p xs a (TFree (s, _)) = Pretty.str (strip_tname s ^ "G")
   620   | mk_gen sg p xs a (Type (s, Ts)) = (if p then parens else I) (Pretty.block
   621       (separate (Pretty.brk 1) (Pretty.str ("gen_" ^ mk_type_id sg s ^
   622         (if s mem xs then "'" else "")) :: map (mk_gen sg true xs a) Ts @
   623         (if s mem xs then [Pretty.str a] else []))));
   624 
   625 val test_fn : (int -> (string * term) list option) ref = ref (fn _ => NONE);
   626 
(* Random testing of the term t, which must contain neither type variables
   nor schematic variables.  Code is generated for t abstracted over its
   free variables ("testf", in modes ["term_of", "test"]).  Alongside it, a
   function is stored in test_fn.  Given a size, that function draws values
   for the free variables from the gen_... generators and returns NONE if
   testf holds for them; otherwise it returns SOME assignment, converted to
   terms with the term_of_... functions.  The code is compiled with
   use_text.  Then, for every size 0 .. sz, up to i tests are run, and the
   first counterexample found is returned (NONE if there is none).  A Match
   exception raised by the generated code is reported as a warning and
   treated like a successful test. *)
fun test_term thy sz i = setmp print_mode [] (fn t =>
  let
    val _ = assert (null (term_tvars t) andalso null (term_tfrees t))
      "Term to be tested contains type variables";
    val _ = assert (null (term_vars t))
      "Term to be tested contains schematic variables";
    val sg = sign_of thy;
    val frees = map dest_Free (term_frees t);
    (* name of the size parameter, distinct from the free variable names *)
    val szname = variant (map fst frees) "i";
    (* ML source of structure TestTerm: the generated code for testf and the
       assignment to Codegen.test_fn.
       NOTE(review): mk_id does not avoid reserved ML words, so a free
       variable named like one would produce invalid code here *)
    val s = "structure TestTerm =\nstruct\n\n" ^
      setmp mode ["term_of", "test"] (generate_code_i thy)
        [("testf", list_abs_free (frees, t))] ^
      "\n" ^ Pretty.string_of
        (Pretty.block [Pretty.str "val () = Codegen.test_fn :=",
          Pretty.brk 1, Pretty.str ("(fn " ^ szname ^ " =>"), Pretty.brk 1,
          Pretty.blk (0, [Pretty.str "let", Pretty.brk 1,
            Pretty.blk (0, separate Pretty.fbrk (map (fn (s, T) =>
              Pretty.block [Pretty.str ("val " ^ mk_id s ^ " ="), Pretty.brk 1,
              mk_gen sg false [] "" T, Pretty.brk 1,
              Pretty.str (szname ^ ";")]) frees)),
            Pretty.brk 1, Pretty.str "in", Pretty.brk 1,
            Pretty.block [Pretty.str "if ",
              mk_app false (Pretty.str "testf") (map (Pretty.str o mk_id o fst) frees),
              Pretty.brk 1, Pretty.str "then NONE",
              Pretty.brk 1, Pretty.str "else ",
              Pretty.block [Pretty.str "SOME ", Pretty.block (Pretty.str "[" ::
                List.concat (separate [Pretty.str ",", Pretty.brk 1]
                  (map (fn (s, T) => [Pretty.block
                    [Pretty.str ("(" ^ Library.quote (Symbol.escape s) ^ ","), Pretty.brk 1,
                     mk_app false (mk_term_of sg false T)
                       [Pretty.str (mk_id s)], Pretty.str ")"]]) frees)) @
                  [Pretty.str "]"])]],
            Pretty.brk 1, Pretty.str "end"]), Pretty.str ");"]) ^
      "\n\nend;\n";
    (* compile and run the generated code; this sets test_fn *)
    val _ = use_text Context.ml_output false s;
    (* run test f for k = current .. i, stopping at the first SOME *)
    fun iter f k = if k > i then NONE
      else (case (f () handle Match =>
          (warning "Exception Match raised in generated code"; NONE)) of
        NONE => iter f (k+1) | SOME x => SOME x);
    (* try the sizes k = current .. sz *)
    fun test k = if k > sz then NONE
      else (priority ("Test data size: " ^ string_of_int k);
        case iter (fn () => !test_fn k) 1 of
          NONE => test (k+1) | SOME x => SOME x);
  in test 0 end);
   671 
(* Test subgoal i of the proof state in st and print the outcome.
   Preparation of the subgoal:
   - outer meta-level "all" quantifiers are stripped, and the loose bound
     variables are replaced by the free variables from Logic.goal_params;
   - every type variable TFree named s is replaced by its instantiation in
     tvinsts, or else by default_type if one is set;
   - the result is atomized.
   test_term is then run with the given size and number of iterations. *)
fun test_goal ({size, iterations, default_type}, tvinsts) i st =
  let
    val sg = Toplevel.sign_of st;
    fun strip (Const ("all", _) $ Abs (_, _, t)) = strip t
      | strip t = t;
    (* presumably gi is subgoal i with its parameters bound by outer
       quantifiers and frees are Free terms for those parameters *)
    val (gi, frees) = Logic.goal_params
      (prop_of (snd (snd (Proof.get_goal (Toplevel.proof_of st))))) i;
    val gi' = ObjectLogic.atomize_term sg (map_term_types
      (map_type_tfree (fn p as (s, _) => getOpt (assoc (tvinsts, s),
        getOpt (default_type,TFree p)))) (subst_bounds (frees, strip gi)));
  in case test_term (Toplevel.theory_of st) size iterations gi' of
      NONE => writeln "No counterexamples found."
    | SOME cex => writeln ("Counterexample found:\n" ^
        Pretty.string_of (Pretty.chunks (map (fn (s, t) =>
          Pretty.block [Pretty.str (s ^ " ="), Pretty.brk 1,
            Sign.pretty_term sg t]) cex)))
  end;
   689 
   690 
   691 (**** Interface ****)
   692 
(* Pretty.str applied with the empty print mode. *)
val str = setmp print_mode [] Pretty.str;

(* Parse a mixfix annotation string s:
     _                 Arg
     ?                 Ignore
     / followed by n blanks
                       a break of width n
     {* text *}        Quote of rd text (the prime escapes inside it too)
     other text        literal text; a prime escapes the next symbol
   An error is raised if s cannot be parsed completely. *)
fun parse_mixfix rd s =
  (case Scan.finite Symbol.stopper (Scan.repeat
     (   $$ "_" >> K Arg
      || $$ "?" >> K Ignore
      || $$ "/" |-- Scan.repeat ($$ " ") >> (Pretty o Pretty.brk o length)
      || $$ "{" |-- $$ "*" |-- Scan.repeat1
           (   $$ "'" |-- Scan.one Symbol.not_eof
            || Scan.unless ($$ "*" -- $$ "}") (Scan.one Symbol.not_eof)) --|
         $$ "*" --| $$ "}" >> (Quote o rd o implode)
      || Scan.repeat1
           (   $$ "'" |-- Scan.one Symbol.not_eof
            || Scan.unless ($$ "_" || $$ "?" || $$ "/" || $$ "{" |-- $$ "*")
                 (Scan.one Symbol.not_eof)) >> (Pretty o str o implode)))
       (Symbol.explode s) of
     (p, []) => p
   | _ => error ("Malformed annotation: " ^ quote s));

(* the function type "fun" is rendered as the ML function arrow *)
val _ = Context.add_setup
  [assoc_types [("fun", parse_mixfix (K dummyT) "(_ ->/ _)")]];
   714 
   715 
   716 structure P = OuterParse and K = OuterSyntax.Keyword;
   717 
   718 val assoc_typeP =
   719   OuterSyntax.command "types_code"
   720   "associate types with target language types" K.thy_decl
   721     (Scan.repeat1 (P.xname --| P.$$$ "(" -- P.string --| P.$$$ ")") >>
   722      (fn xs => Toplevel.theory (fn thy => assoc_types
   723        (map (fn (name, mfx) => (name, parse_mixfix
   724          (typ_of o read_ctyp (sign_of thy)) mfx)) xs) thy)));
   725 
   726 val assoc_constP =
   727   OuterSyntax.command "consts_code"
   728   "associate constants with target language code" K.thy_decl
   729     (Scan.repeat1
   730        (P.xname -- (Scan.option (P.$$$ "::" |-- P.typ)) --|
   731         P.$$$ "(" -- P.string --| P.$$$ ")") >>
   732      (fn xs => Toplevel.theory (fn thy => assoc_consts
   733        (map (fn ((name, optype), mfx) => (name, optype, parse_mixfix
   734          (term_of o read_cterm (sign_of thy) o rpair TypeInfer.logicT) mfx))
   735            xs) thy)));
   736 
   737 val generate_codeP =
   738   OuterSyntax.command "generate_code" "generates code for terms" K.thy_decl
   739     (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") --
   740      Scan.optional (P.$$$ "[" |-- P.enum "," P.xname --| P.$$$ "]") (!mode) --
   741      Scan.repeat1 (P.name --| P.$$$ "=" -- P.term) >>
   742      (fn ((opt_fname, mode'), xs) => Toplevel.theory (fn thy =>
   743         ((case opt_fname of
   744             NONE => use_text Context.ml_output false
   745           | SOME fname => File.write (Path.unpack fname))
   746               (setmp mode mode' (generate_code thy) xs); thy))));
   747 
   748 val params =
   749   [("size", P.nat >> (K o set_size)),
   750    ("iterations", P.nat >> (K o set_iterations)),
   751    ("default_type", P.typ >> set_default_type)];
   752 
   753 val parse_test_params = P.short_ident :-- (fn s =>
   754   P.$$$ "=" |-- getOpt (assoc (params, s), Scan.fail)) >> snd;
   755 
   756 fun parse_tyinst xs =
   757   (P.type_ident --| P.$$$ "=" -- P.typ >> (fn (v, s) => fn sg =>
   758     fn (x, ys) => (x, (v, typ_of (read_ctyp sg s)) :: ys))) xs;
   759 
   760 fun app [] x = x
   761   | app (f :: fs) x = app fs (f x);
   762 
   763 val test_paramsP =
   764   OuterSyntax.command "quickcheck_params" "set parameters for random testing" K.thy_decl
   765     (P.$$$ "[" |-- P.list1 parse_test_params --| P.$$$ "]" >>
   766       (fn fs => Toplevel.theory (fn thy =>
   767          map_test_params (app (map (fn f => f (sign_of thy)) fs)) thy)));
   768 
   769 val testP =
   770   OuterSyntax.command "quickcheck" "try to find counterexample for subgoal" K.diag
   771   (Scan.option (P.$$$ "[" |-- P.list1
   772     (   parse_test_params >> (fn f => fn sg => apfst (f sg))
   773      || parse_tyinst) --| P.$$$ "]") -- Scan.optional P.nat 1 >>
   774     (fn (ps, g) => Toplevel.keep (fn st =>
   775       test_goal (app (getOpt (Option.map
   776           (map (fn f => f (Toplevel.sign_of st))) ps, []))
   777         (get_test_params (Toplevel.theory_of st), [])) g st)));
   778 
   779 val _ = OuterSyntax.add_parsers
   780   [assoc_typeP, assoc_constP, generate_codeP, test_paramsP, testP];
   781 
   782 end;