src/HOL/Tools/datatype_codegen.ML
changeset 20177 0af885e3dabf
parent 20105 454f4be984b7
child 20182 79c9ff40d760
equal deleted inserted replaced
20176:36737fb58614 20177:0af885e3dabf
     9 sig
     9 sig
    10   val get_datatype_spec_thms: theory -> string
    10   val get_datatype_spec_thms: theory -> string
    11     -> (((string * sort) list * (string * typ list) list) * tactic) option
    11     -> (((string * sort) list * (string * typ list) list) * tactic) option
    12   val get_all_datatype_cons: theory -> (string * string) list
    12   val get_all_datatype_cons: theory -> (string * string) list
    13   val dest_case_expr: theory -> term
    13   val dest_case_expr: theory -> term
    14     -> ((string * typ) list * ((term * typ) * (term * term) list)) option;
    14     -> ((string * typ) list * ((term * typ) * (term * term) list)) option
    15   val add_datatype_case_const: string -> theory -> theory
    15   val add_datatype_case_const: string -> theory -> theory
    16   val add_datatype_case_defs: string -> theory -> theory
    16   val add_datatype_case_defs: string -> theory -> theory
       
    17   val datatypes_dependency: theory -> string list list
       
    18   val get_datatype_mut_specs: theory -> string list
       
    19     -> ((string * sort) list * (string * (string * typ list) list) list)
       
    20   val get_datatype_arities: theory -> string list -> sort
       
    21     -> (string * (((string * sort list) * sort)  * term list)) list option
       
    22   val datatype_prove_arities : tactic -> string list -> sort
       
    23     -> ((string * term list) list
       
    24     -> ((bstring * attribute list) * term) list) -> theory -> theory
    17   val setup: theory -> theory
    25   val setup: theory -> theory
    18 end;
    26 end;
    19 
    27 
    20 structure DatatypeCodegen : DATATYPE_CODEGEN =
    28 structure DatatypeCodegen : DATATYPE_CODEGEN =
    21 struct
    29 struct
   304   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   312   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   305 
   313 
   306 
   314 
   307 (** code 2nd generation **)
   315 (** code 2nd generation **)
   308 
   316 
       
   317 fun datatypes_dependency thy =
       
   318   let
       
   319     val dtnames = DatatypePackage.get_datatypes thy;
       
   320     fun add_node (dtname, _) =
       
   321       let
       
   322         fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
       
   323           | add_tycos _ = I;
       
   324         val deps = (filter (Symtab.defined dtnames) o maps (fn ty =>
       
   325           add_tycos ty [])
       
   326             o maps snd o snd o the o DatatypePackage.get_datatype_spec thy) dtname
       
   327       in
       
   328         Graph.default_node (dtname, ())
       
   329         #> fold (fn dtname' =>
       
   330              Graph.default_node (dtname', ())
       
   331              #> Graph.add_edge (dtname', dtname)
       
   332            ) deps
       
   333       end
       
   334   in
       
   335     Graph.empty
       
   336     |> Symtab.fold add_node dtnames
       
   337     |> Graph.strong_conn
       
   338   end;
       
   339 
       
   340 fun get_datatype_mut_specs thy (tycos as tyco :: _) =
       
   341   let
       
   342     val tycos' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
       
   343     val _ = if gen_subset (op =) (tycos, tycos') then () else
       
   344       error ("datatype constructors are not mutually recursive: " ^ (commas o map quote) tycos);
       
   345     val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
       
   346   in (vs, tycos ~~ css) end;
       
   347 
       
   348 fun get_datatype_arities thy tycos sort =
       
   349   let
       
   350     val algebra = Sign.classes_of thy;
       
   351     val (vs_proto, css_proto) = get_datatype_mut_specs thy tycos;
       
   352     val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
       
   353     fun inst_type tyco (c, tys) =
       
   354       let
       
   355         val tys' = (map o map_atyps)
       
   356           (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) vs v))) tys
       
   357       in (c, tys') end;
       
   358     val css = map (fn (tyco, cs) => (tyco, (map (inst_type tyco) cs))) css_proto;
       
   359     fun mk_arity tyco =
       
   360       ((tyco, map snd vs), sort);
       
   361     fun typ_of_sort ty =
       
   362       let
       
   363         val arities = map (fn (tyco, _) => ((tyco, map snd vs), sort)) css;
       
   364       in ClassPackage.assume_arities_of_sort thy arities (ty, sort) end;
       
   365     fun mk_cons tyco (c, tys) =
       
   366       let
       
   367         val ts = Name.give_names Name.context "a" tys;
       
   368         val ty = tys ---> Type (tyco, map TFree vs);
       
   369       in list_comb (Const (c, ty), map Free ts) end;
       
   370   in if forall (fn (_, cs) => forall (fn (_, tys) => forall typ_of_sort tys) cs) css
       
   371     then SOME (
       
   372       map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
       
   373     ) else NONE
       
   374   end;
       
   375 
       
   376 fun datatype_prove_arities tac tycos sort f thy =
       
   377   case get_datatype_arities thy tycos sort
       
   378    of NONE => thy
       
   379     | SOME insts => let
       
   380         fun proven ((tyco, asorts), sort) =
       
   381           Sorts.of_sort (Sign.classes_of thy)
       
   382             (Type (tyco, map TFree (Name.give_names Name.context "'a" asorts)), sort);
       
   383         val (arities, css) = (split_list o map_filter
       
   384           (fn (tyco, (arity, cs)) => if proven arity
       
   385             then SOME (arity, (tyco, cs)) else NONE)) insts;
       
   386       in
       
   387         thy
       
   388         |> ClassPackage.prove_instance_arity tac
       
   389              arities ("", []) (f css)
       
   390       end;
       
   391 
    (* dtyp_of_case_const thy c: SOME dtco for the first registered datatype (in
       Symtab.dest order) whose case_name equals c, NONE if there is none.
       Diff context: old lines 309-311 and new lines 392-394 are identical. *)
    309 fun dtyp_of_case_const thy c =
    392 fun dtyp_of_case_const thy c =
    310   get_first (fn (dtco, { case_name, ... }) => if case_name = c then SOME dtco else NONE)
    393   get_first (fn (dtco, { case_name, ... }) => if case_name = c then SOME dtco else NONE)
    311     ((Symtab.dest o DatatypePackage.get_datatypes) thy);
    394     ((Symtab.dest o DatatypePackage.get_datatypes) thy);
   312 
   395 
   313 fun dest_case_app cs ts tys =
   396 fun dest_case_app cs ts tys =
   314   let
   397   let
   315     val abs = CodegenThingol.give_names [] (Library.drop (length ts, tys));
   398     val abs = Name.give_names Name.context "a" (Library.drop (length ts, tys));
   316     val (ts', t) = split_last (ts @ map Free abs);
   399     val (ts', t) = split_last (ts @ map Free abs);
   317     val (tys', sty) = split_last tys;
   400     val (tys', sty) = split_last tys;
   318     fun freenames_of t = fold_aterms
   401     fun freenames_of t = fold_aterms
   319       (fn Free (v, _) => insert (op =) v | _ => I) t [];
   402       (fn Free (v, _) => insert (op =) v | _ => I) t [];
   320     fun dest_case ((c, tys_decl), ty) t =
   403     fun dest_case ((c, tys_decl), ty) t =
   380   add_tycodegen "datatype" datatype_tycodegen #>
   463   add_tycodegen "datatype" datatype_tycodegen #>
   381   CodegenTheorems.add_datatype_extr
   464   CodegenTheorems.add_datatype_extr
   382     get_datatype_spec_thms #>
   465     get_datatype_spec_thms #>
   383   CodegenPackage.set_get_all_datatype_cons
   466   CodegenPackage.set_get_all_datatype_cons
   384     get_all_datatype_cons #>
   467     get_all_datatype_cons #>
   385   DatatypeHooks.add add_datatype_case_const #>
   468   DatatypeHooks.add (fold add_datatype_case_const) #>
   386   DatatypeHooks.add add_datatype_case_defs
   469   DatatypeHooks.add (fold add_datatype_case_defs)
   387 
   470 
   388 end;
   471 end;