src/Pure/Tools/codegen_package.ML
changeset 18912 dd168daf172d
parent 18885 ee8b5c36ba2b
child 18915 7521b849ae98
equal deleted inserted replaced
18911:74edab16166f 18912:dd168daf172d
    34   val set_get_all_datatype_cons : (theory -> (string * string) list)
    34   val set_get_all_datatype_cons : (theory -> (string * string) list)
    35     -> theory -> theory;
    35     -> theory -> theory;
    36   val set_defgen_datatype: defgen -> theory -> theory;
    36   val set_defgen_datatype: defgen -> theory -> theory;
    37   val set_int_tyco: string -> theory -> theory;
    37   val set_int_tyco: string -> theory -> theory;
    38 
    38 
    39   val exprgen_type: theory -> auxtab
       
    40     -> typ -> CodegenThingol.transact -> CodegenThingol.itype * CodegenThingol.transact;
       
    41   val exprgen_term: theory -> auxtab
       
    42     -> term -> CodegenThingol.transact -> CodegenThingol.iexpr * CodegenThingol.transact;
       
    43   val appgen_default: appgen;
    39   val appgen_default: appgen;
    44 
       
    45   val appgen_let: (int -> term -> term list * term)
    40   val appgen_let: (int -> term -> term list * term)
    46     -> appgen;
    41     -> appgen;
    47   val appgen_split: (int -> term -> term list * term)
    42   val appgen_split: (int -> term -> term list * term)
    48     -> appgen;
    43     -> appgen;
    49   val appgen_number_of: (term -> term) -> (term -> IntInf.int) -> string -> string
    44   val appgen_number_of: (term -> term) -> (term -> IntInf.int) -> string -> string
    82 infixr 6 `-->;
    77 infixr 6 `-->;
    83 infix 4 `$;
    78 infix 4 `$;
    84 infix 4 `$$;
    79 infix 4 `$$;
    85 infixr 3 `|->;
    80 infixr 3 `|->;
    86 infixr 3 `|-->;
    81 infixr 3 `|-->;
    87 
       
    88 (* auxiliary *)
       
    89 
       
    90 fun devarify_type ty = (fst o Type.freeze_thaw_type o Term.zero_var_indexesT) ty;
       
    91 fun devarify_term t = (fst o Type.freeze_thaw o Term.zero_var_indexes) t;
       
    92 
       
    93 val is_number = is_some o Int.fromString;
       
    94 
       
    95 fun merge_opt _ (x1, NONE) = x1
       
    96   | merge_opt _ (NONE, x2) = x2
       
    97   | merge_opt eq (SOME x1, SOME x2) =
       
    98       if eq (x1, x2) then SOME x1 else error ("incompatible options during merge");
       
    99 
       
   100 
    82 
   101 (* shallow name spaces *)
    83 (* shallow name spaces *)
   102 
    84 
   103 val nsp_module = ""; (* a dummy by convention *)
    85 val nsp_module = ""; (* a dummy by convention *)
   104 val nsp_class = "class";
    86 val nsp_class = "class";
   202      )
   184      )
   203 );
   185 );
   204 
   186 
   205 
   187 
   206 (* theory data for code generator *)
   188 (* theory data for code generator *)
       
   189 
       
   190 fun merge_opt _ (x1, NONE) = x1
       
   191   | merge_opt _ (NONE, x2) = x2
       
   192   | merge_opt eq (SOME x1, SOME x2) =
       
   193       if eq (x1, x2) then SOME x1 else error ("incompatible options during merge");
   207 
   194 
   208 type gens = {
   195 type gens = {
   209   appconst: ((int * int) * (appgen * stamp)) Symtab.table,
   196   appconst: ((int * int) * (appgen * stamp)) Symtab.table,
   210   eqextrs: (string * (eqextr * stamp)) list
   197   eqextrs: (string * (eqextr * stamp)) list
   211 };
   198 };
   504             )
   491             )
   505        )
   492        )
   506     ); thy);
   493     ); thy);
   507 
   494 
   508 
   495 
       
   496 (* sophisticated devarification *)
       
   497 
       
   498 fun assert f msg x =
       
   499   if f x then x
       
   500     else error msg;
       
   501 
       
   502 val _ : ('a -> bool) -> string -> 'a -> 'a = assert;
       
   503 
       
   504 fun devarify_typs tys =
       
   505   let
       
   506     fun add_rename (var as ((v, _), sort)) used = 
       
   507       let
       
   508         val v' = variant used v
       
   509       in (((var, TFree (v', sort)), (v', TVar var)), v' :: used) end;
       
   510     fun typ_names (Type (tyco, tys)) (vars, names) =
       
   511           (vars, names |> insert (op =) (NameSpace.base tyco))
       
   512           |> fold typ_names tys
       
   513       | typ_names (TFree (v, _)) (vars, names) =
       
   514           (vars, names |> insert (op =) v)
       
   515       | typ_names (TVar (v, sort)) (vars, names) =
       
   516           (vars |> AList.update (op =) (v, sort), names);
       
   517     val (vars, used) = fold typ_names tys ([], []);
       
   518     val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
       
   519   in
       
   520     (reverse, (map o map_atyps) (Term.instantiateT renames) tys)
       
   521   end;
       
   522 
       
   523 fun burrow_typs_yield f ts =
       
   524   let
       
   525     val typtab =
       
   526       fold (fold_types (fn ty => Typtab.update (ty, dummyT)))
       
   527         ts Typtab.empty;
       
   528     val typs = Typtab.keys typtab;
       
   529     val (x, typs') = f typs;
       
   530     val typtab' = fold2 (Typtab.update oo pair) typs typs' typtab;
       
   531   in
       
   532     (x, (map o map_term_types) (the o Typtab.lookup typtab') ts)
       
   533   end;
       
   534 
       
   535 fun devarify_terms ts =
       
   536   let
       
   537     fun add_rename (var as ((v, _), ty)) used = 
       
   538       let
       
   539         val v' = variant used v
       
   540       in (((var, Free (v', ty)), (v', Var var)), v' :: used) end;
       
   541     fun term_names (Const (c, _)) (vars, names) =
       
   542           (vars, names |> insert (op =) (NameSpace.base c))
       
   543       | term_names (Free (v, _)) (vars, names) =
       
   544           (vars, names |> insert (op =) v)
       
   545       | term_names (Var (v, sort)) (vars, names) =
       
   546           (vars |> AList.update (op =) (v, sort), names)
       
   547       | term_names (Bound _) vars_names =
       
   548           vars_names
       
   549       | term_names (Abs (v, _, _)) (vars, names) =
       
   550           (vars, names |> insert (op =) v)
       
   551       | term_names (t1 $ t2) vars_names =
       
   552           vars_names |> term_names t1 |> term_names t2
       
   553     val (vars, used) = fold term_names ts ([], []);
       
   554     val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list;
       
   555   in
       
   556     (reverse, (map o map_aterms) (Term.instantiate ([], renames)) ts)
       
   557   end;
       
   558 
       
   559 fun devarify_term_typs ts =
       
   560   ts
       
   561   |> devarify_terms
       
   562   |-> (fn reverse => burrow_typs_yield devarify_typs
       
   563   #-> (fn reverseT => pair (reverseT, reverse)));
       
   564 
   509 (* definition and expression generators *)
   565 (* definition and expression generators *)
   510 
   566 
   511 fun ensure_def_class thy tabs cls trns =
   567 fun ensure_def_class thy tabs cls trns =
   512   let
   568   let
   513     fun defgen_class thy (tabs as (_, (insttab, _, _))) cls trns =
   569     fun defgen_class thy (tabs as (_, (insttab, _, _))) cls trns =
   519               val idfs = map (idf_of_name thy nsp_mem o fst) cs;
   575               val idfs = map (idf_of_name thy nsp_mem o fst) cs;
   520             in
   576             in
   521               trns
   577               trns
   522               |> debug 5 (fn _ => "trying defgen class declaration for " ^ quote cls)
   578               |> debug 5 (fn _ => "trying defgen class declaration for " ^ quote cls)
   523               |> fold_map (ensure_def_class thy tabs) (ClassPackage.the_superclasses thy cls)
   579               |> fold_map (ensure_def_class thy tabs) (ClassPackage.the_superclasses thy cls)
   524               ||>> fold_map (exprgen_type thy tabs o devarify_type o snd) cs
   580               ||>> (codegen_type thy tabs o map snd) cs
   525               ||>> (fold_map o fold_map) (exprgen_tyvar_sort thy tabs) sortctxts
   581               ||>> (fold_map o fold_map) (exprgen_tyvar_sort thy tabs) sortctxts
   526               |-> (fn ((supcls, memtypes), sortctxts) => succeed
   582               |-> (fn ((supcls, memtypes), sortctxts) => succeed
   527                 (Class ((supcls, ("a", idfs ~~ (sortctxts ~~ memtypes))), [])))
   583                 (Class ((supcls, ("a", idfs ~~ (sortctxts ~~ memtypes))), [])))
   528             end
   584             end
   529         | _ =>
   585         | _ =>
   562       |-> (fn (t1', t2') => pair (t1' `-> t2'))
   618       |-> (fn (t1', t2') => pair (t1' `-> t2'))
   563   | exprgen_type thy tabs (Type (tyco, tys)) trns =
   619   | exprgen_type thy tabs (Type (tyco, tys)) trns =
   564       trns
   620       trns
   565       |> ensure_def_tyco thy tabs tyco
   621       |> ensure_def_tyco thy tabs tyco
   566       ||>> fold_map (exprgen_type thy tabs) tys
   622       ||>> fold_map (exprgen_type thy tabs) tys
   567       |-> (fn (tyco, tys) => pair (tyco `%% tys));
   623       |-> (fn (tyco, tys) => pair (tyco `%% tys))
       
   624 and codegen_type thy tabs =
       
   625   fold_map (exprgen_type thy tabs) o snd o devarify_typs;
   568 
   626 
   569 fun exprgen_classlookup thy tabs (ClassPackage.Instance (inst, ls)) trns =
   627 fun exprgen_classlookup thy tabs (ClassPackage.Instance (inst, ls)) trns =
   570       trns
   628       trns
   571       |> ensure_def_inst thy tabs inst
   629       |> ensure_def_inst thy tabs inst
   572       ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs) ls
   630       ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs) ls
   588              of Const (c', _) => if c' = c then (args, rhs)
   646              of Const (c', _) => if c' = c then (args, rhs)
   589                  else error ("illegal function equation for " ^ quote c
   647                  else error ("illegal function equation for " ^ quote c
   590                    ^ ", actually defining " ^ quote c')
   648                    ^ ", actually defining " ^ quote c')
   591               | _ => error ("illegal function equation for " ^ quote c)
   649               | _ => error ("illegal function equation for " ^ quote c)
   592             end;
   650             end;
   593           fun mk_eq (args, rhs) trns =
       
   594             trns
       
   595             |> fold_map (exprgen_term thy tabs o devarify_term) args
       
   596             ||>> (exprgen_term thy tabs o devarify_term) rhs
       
   597             |-> (fn (args, rhs) => pair (args, rhs))
       
   598         in
   651         in
   599           trns
   652           trns
   600           |> fold_map (mk_eq o dest_eqthm) eq_thms
   653           |> (codegen_eqs thy tabs o map dest_eqthm) eq_thms
   601           ||>> (exprgen_type thy tabs o devarify_type) ty
   654           ||>> codegen_type thy tabs [ty]
   602           ||>> fold_map (exprgen_tyvar_sort thy tabs) sortctxt
   655           ||>> fold_map (exprgen_tyvar_sort thy tabs) sortctxt
   603           |-> (fn ((eqs, ty), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty)))
   656           |-> (fn ((eqs, [ty]), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty)))
   604         end
   657         end
   605     | NONE => (NONE, trns)
   658     | NONE => (NONE, trns)
   606 and ensure_def_inst thy (tabs as (_, (insttab, _, _))) (cls, tyco) trns =
   659 and ensure_def_inst thy (tabs as (_, (insttab, _, _))) (cls, tyco) trns =
   607   let
   660   let
   608     fun defgen_inst thy (tabs as (_, (insttab, _, _))) inst trns =
   661     fun defgen_inst thy (tabs as (_, (insttab, _, _))) inst trns =
   688   end
   741   end
   689 and exprgen_term thy tabs (Const (f, ty)) trns =
   742 and exprgen_term thy tabs (Const (f, ty)) trns =
   690       trns
   743       trns
   691       |> appgen thy tabs ((f, ty), [])
   744       |> appgen thy tabs ((f, ty), [])
   692       |-> (fn e => pair e)
   745       |-> (fn e => pair e)
   693   | exprgen_term thy tabs (Var ((v, 0), ty)) trns =
   746   (* | exprgen_term thy tabs (Var ((v, 0), ty)) trns =
   694       trns
   747       trns
   695       |> (exprgen_type thy tabs o devarify_type) ty
   748       |> (exprgen_type thy tabs) ty
   696       |-> (fn ty => pair (IVarE (v, ty)))
   749       |-> (fn ty => pair (IVarE (v, ty)))
   697   | exprgen_term thy tabs (Var ((_, _), _)) trns =
   750   | exprgen_term thy tabs (Var ((_, _), _)) trns =
   698       error "Var with index greater 0 encountered during code generation"
   751       error "Var with index greater 0 encountered during code generation" *)
       
   752   | exprgen_term thy tabs (Var _) trns =
       
   753       error "Var encountered during code generation"
   699   | exprgen_term thy tabs (Free (v, ty)) trns =
   754   | exprgen_term thy tabs (Free (v, ty)) trns =
   700       trns
   755       trns
   701       |> (exprgen_type thy tabs o devarify_type) ty
   756       |> exprgen_type thy tabs ty
   702       |-> (fn ty => pair (IVarE (v, ty)))
   757       |-> (fn ty => pair (IVarE (v, ty)))
   703   | exprgen_term thy tabs (Abs (v, ty, t)) trns =
   758   | exprgen_term thy tabs (Abs (v, ty, t)) trns =
   704       trns
   759       trns
   705       |> (exprgen_type thy tabs o devarify_type) ty
   760       |> exprgen_type thy tabs ty
   706       ||>> exprgen_term thy tabs (subst_bound (Free (v, ty), t))
   761       ||>> exprgen_term thy tabs (subst_bound (Free (v, ty), t))
   707       |-> (fn (ty, e) => pair ((v, ty) `|-> e))
   762       |-> (fn (ty, e) => pair ((v, ty) `|-> e))
   708   | exprgen_term thy tabs (t as t1 $ t2) trns =
   763   | exprgen_term thy tabs (t as t1 $ t2) trns =
   709       let
   764       let
   710         val (t', ts) = strip_comb t
   765         val (t', ts) = strip_comb t
   717             trns
   772             trns
   718             |> exprgen_term thy tabs t'
   773             |> exprgen_term thy tabs t'
   719             ||>> fold_map (exprgen_term thy tabs) ts
   774             ||>> fold_map (exprgen_term thy tabs) ts
   720             |-> (fn (e, es) => pair (e `$$ es))
   775             |-> (fn (e, es) => pair (e `$$ es))
   721       end
   776       end
       
   777 and codegen_term thy tabs =
       
   778   fold_map (exprgen_term thy tabs) o snd o devarify_term_typs
       
   779 and codegen_eqs thy tabs =
       
   780   apfst (map (fn (rhs::args) => (args, rhs)))
       
   781     oo fold_burrow (codegen_term thy tabs)
       
   782     o map (fn (args, rhs) => (rhs :: args))
   722 and appgen_default thy tabs ((c, ty), ts) trns =
   783 and appgen_default thy tabs ((c, ty), ts) trns =
   723   trns
   784   trns
   724   |> ensure_def_const thy tabs (c, ty)
   785   |> ensure_def_const thy tabs (c, ty)
   725   ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
   786   ||>> (fold_map o fold_map) (exprgen_classlookup thy tabs)
   726          (ClassPackage.extract_classlookup thy (c, ty))
   787          (ClassPackage.extract_classlookup thy (c, ty))
   727   ||>> (exprgen_type thy tabs o devarify_type) ty
   788   ||>> codegen_type thy tabs [ty]
   728   ||>> fold_map (exprgen_term thy tabs o devarify_term) ts
   789   ||>> fold_map (exprgen_term thy tabs) ts
   729   |-> (fn (((c, ls), ty), es) =>
   790   |-> (fn (((c, ls), [ty]), es) =>
   730          pair (IConst ((c, ty), ls) `$$ es))
   791          pair (IConst ((c, ty), ls) `$$ es))
   731 and appgen thy tabs ((f, ty), ts) trns =
   792 and appgen thy tabs ((f, ty), ts) trns =
   732   case Symtab.lookup ((#appconst o #gens o CodegenData.get) thy) f
   793   case Symtab.lookup ((#appconst o #gens o CodegenData.get) thy) f
   733    of SOME ((imin, imax), (ag, _)) =>
   794    of SOME ((imin, imax), (ag, _)) =>
   734         if length ts < imin then
   795         if length ts < imin then
   737             val vs = Term.invent_names (add_term_names (Const (f, ty), [])) "x" d;
   798             val vs = Term.invent_names (add_term_names (Const (f, ty), [])) "x" d;
   738             val tys = Library.take (d, ((fst o strip_type) ty));
   799             val tys = Library.take (d, ((fst o strip_type) ty));
   739           in
   800           in
   740             trns
   801             trns
   741             |> debug 10 (fn _ => "eta-expanding")
   802             |> debug 10 (fn _ => "eta-expanding")
   742             |> fold_map (exprgen_type thy tabs o devarify_type) tys
   803             |> fold_map (exprgen_type thy tabs) tys
   743             ||>> ag thy tabs ((f, ty), ts @ map2 (curry Free) vs tys)
   804             ||>> ag thy tabs ((f, ty), ts @ map2 (curry Free) vs tys)
   744             |-> (fn (tys, e) => pair ((vs ~~ tys) `|--> e))
   805             |-> (fn (tys, e) => pair ((vs ~~ tys) `|--> e))
   745           end
   806           end
   746         else if length ts > imax then
   807         else if length ts > imax then
   747           trns
   808           trns
   840 
   901 
   841 fun appgen_number_of mk_int_to_nat bin_to_int tyco_int tyco_nat thy tabs ((_,
   902 fun appgen_number_of mk_int_to_nat bin_to_int tyco_int tyco_nat thy tabs ((_,
   842   Type (_, [_, ty as Type (tyco, [])])), [bin]) trns =
   903   Type (_, [_, ty as Type (tyco, [])])), [bin]) trns =
   843     if tyco = tyco_int then
   904     if tyco = tyco_int then
   844       trns
   905       trns
   845       |> (exprgen_type thy tabs o devarify_type) ty
   906       |> exprgen_type thy tabs ty
   846       |-> (fn ty => pair (CodegenThingol.IConst (((IntInf.toString o bin_to_int) bin, ty), [])))
   907       |-> (fn ty => pair (CodegenThingol.IConst (((IntInf.toString o bin_to_int) bin, ty), [])))
   847     else if tyco = tyco_nat then
   908     else if tyco = tyco_nat then
   848       trns
   909       trns
   849       |> exprgen_term thy tabs (mk_int_to_nat bin)
   910       |> exprgen_term thy tabs (mk_int_to_nat bin)
   850     else error ("invalid type constructor for numeral: " ^ quote tyco);
   911     else error ("invalid type constructor for numeral: " ^ quote tyco);
   900                   idf_of_name thy nsp_dtcon) cos;
   961                   idf_of_name thy nsp_dtcon) cos;
   901               in
   962               in
   902                 trns
   963                 trns
   903                 |> debug 5 (fn _ => "trying defgen datatype for " ^ quote dtco)
   964                 |> debug 5 (fn _ => "trying defgen datatype for " ^ quote dtco)
   904                 |> fold_map (exprgen_tyvar_sort thy tabs) vars
   965                 |> fold_map (exprgen_tyvar_sort thy tabs) vars
   905                 ||>> (fold_map o fold_map) (exprgen_type thy tabs o devarify_type) cotys
   966                 ||>> fold_map (codegen_type thy tabs) cotys
   906                 |-> (fn (sorts, tys) => succeed (Datatype
   967                 |-> (fn (sorts, tys) => succeed (Datatype
   907                      ((sorts, coidfs ~~ tys), [])))
   968                      ((sorts, coidfs ~~ tys), [])))
   908               end
   969               end
   909           | NONE =>
   970           | NONE =>
   910               trns
   971               trns
  1054      of NONE => Sign.the_const_constraint thy c
  1115      of NONE => Sign.the_const_constraint thy c
  1055       | SOME raw_ty => read_typ thy raw_ty;
  1116       | SOME raw_ty => read_typ thy raw_ty;
  1056   in (c, ty) end;
  1117   in (c, ty) end;
  1057 
  1118 
  1058 fun read_quote reader gen raw thy =
  1119 fun read_quote reader gen raw thy =
  1059   expand_module
  1120   thy
  1060     (fn thy => fn tabs => gen thy tabs (reader thy raw))
  1121   |> expand_module
  1061     thy;
  1122        (fn thy => fn tabs => (gen thy tabs o single o reader thy) raw)
       
  1123   |-> (fn [x] => pair x);
  1062 
  1124 
  1063 fun gen_add_prim prep_name prep_primdef raw_name deps (target, raw_primdef) thy =
  1125 fun gen_add_prim prep_name prep_primdef raw_name deps (target, raw_primdef) thy =
  1064   let
  1126   let
  1065     val _ = if Symtab.defined ((#target_data o CodegenData.get) thy) target
  1127     val _ = if Symtab.defined ((#target_data o CodegenData.get) thy) target
  1066       then () else error ("unknown target language: " ^ quote target);
  1128       then () else error ("unknown target language: " ^ quote target);
  1131                       (tyco, (pretty, stamp ())),
  1193                       (tyco, (pretty, stamp ())),
  1132                     syntax_const))),
  1194                     syntax_const))),
  1133               logic_data)))
  1195               logic_data)))
  1134       end;
  1196       end;
  1135   in
  1197   in
  1136     CodegenSerializer.parse_syntax
  1198     CodegenSerializer.parse_syntax (read_quote read_typ codegen_type)
  1137       (read_quote read_typ (fn thy => fn tabs => exprgen_type thy tabs o devarify_type))
       
  1138     #-> (fn reader => pair (mk reader))
  1199     #-> (fn reader => pair (mk reader))
  1139   end;
  1200   end;
  1140 
  1201 
  1141 fun add_pretty_syntax_const c target pretty =
  1202 fun add_pretty_syntax_const c target pretty =
  1142   map_codegen_data
  1203   map_codegen_data
  1163         |> ensure_prim c target
  1224         |> ensure_prim c target
  1164         |> reader
  1225         |> reader
  1165         |-> (fn pretty => add_pretty_syntax_const c target pretty)
  1226         |-> (fn pretty => add_pretty_syntax_const c target pretty)
  1166       end;
  1227       end;
  1167   in
  1228   in
  1168     CodegenSerializer.parse_syntax (read_quote Sign.read_term exprgen_term)
  1229     CodegenSerializer.parse_syntax (read_quote Sign.read_term codegen_term)
  1169     #-> (fn reader => pair (mk reader))
  1230     #-> (fn reader => pair (mk reader))
  1170   end;
  1231   end;
  1171 
  1232 
  1172 fun add_pretty_list raw_nil raw_cons (target, seri) thy =
  1233 fun add_pretty_list raw_nil raw_cons (target, seri) thy =
  1173   let
  1234   let