src/HOL/Tools/datatype_codegen.ML
changeset 25505 4d531475129a
parent 25502 9200b36280c0
child 25534 d0b74fdd6067
equal deleted inserted replaced
25504:dc960d760052 25505:4d531475129a
     7 
     7 
     8 signature DATATYPE_CODEGEN =
     8 signature DATATYPE_CODEGEN =
     9 sig
     9 sig
    10   val get_eq: theory -> string -> thm list
    10   val get_eq: theory -> string -> thm list
    11   val get_eq_datatype: theory -> string -> thm list
    11   val get_eq_datatype: theory -> string -> thm list
    12   val dest_case_expr: theory -> term
       
    13     -> ((string * typ) list * ((term * typ) * (term * term) list)) option
       
    14   val get_case_cert: theory -> string -> thm
    12   val get_case_cert: theory -> string -> thm
    15 
    13 
    16   type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
    14   type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
    17     -> theory -> theory
    15     -> theory -> theory
    18   val add_codetypes_hook: hook -> theory -> theory
    16   val add_codetypes_hook: hook -> theory -> theory
   312   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   310   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
   313 
   311 
   314 
   312 
   315 (** datatypes for code 2nd generation **)
   313 (** datatypes for code 2nd generation **)
   316 
   314 
   317 fun dtyp_of_case_const thy c =
   315 local
   318   Option.map (fn {descr, index, ...} => #1 (the (AList.lookup op = descr index)))
   316 
   319     (DatatypePackage.datatype_of_case thy c);
   317 val not_sym = thm "HOL.not_sym";
   320 
   318 val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
   321 fun dest_case_app cs ts tys =
   319 val refl = thm "refl";
   322   let
   320 val eqTrueI = thm "eqTrueI";
   323     val names = (Name.make_context o map fst) (fold Term.add_tfrees ts []);
       
   324     val abs = Name.names names "a" (Library.drop (length ts, tys));
       
   325     val (ts', t) = split_last (ts @ map Free abs);
       
   326     val (tys', sty) = split_last tys;
       
   327     fun dest_case ((c, tys_decl), ty) t =
       
   328       let
       
   329         val (vs, t') = Term.strip_abs_eta (length tys_decl) t;
       
   330         val c' = list_comb (Const (c, map snd vs ---> sty), map Free vs);
       
   331       in case t'
       
   332        of Const ("HOL.undefined", _) => NONE
       
   333         | _ => SOME (c', t')
       
   334       end;
       
   335   in (abs, ((t, sty), map2 dest_case (cs ~~ tys') ts' |> map_filter I)) end;
       
   336 
       
   337 fun dest_case_expr thy t =
       
   338   case strip_comb t
       
   339    of (Const (c, ty), ts) =>
       
   340         (case dtyp_of_case_const thy c
       
   341          of SOME dtco =>
       
   342               let val (vs, cs) = (the o DatatypePackage.get_datatype_spec thy) dtco;
       
   343               in SOME (dest_case_app cs ts (Library.take (length cs + 1, (fst o strip_type) ty))) end
       
   344           | _ => NONE)
       
   345     | _ => NONE;
       
   346 
   321 
   347 fun mk_distinct cos =
   322 fun mk_distinct cos =
   348   let
   323   let
   349     fun sym_product [] = []
   324     fun sym_product [] = []
   350       | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   325       | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   363           (list_comb (co1, xs1), list_comb (co2, xs2));
   338           (list_comb (co1, xs1), list_comb (co2, xs2));
   364         val t = HOLogic.mk_not prem;
   339         val t = HOLogic.mk_not prem;
   365       in HOLogic.mk_Trueprop t end;
   340       in HOLogic.mk_Trueprop t end;
   366   in map mk_dist (sym_product cos) end;
   341   in map mk_dist (sym_product cos) end;
   367 
   342 
   368 local
       
   369   val not_sym = thm "HOL.not_sym";
       
   370   val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
       
   371   val refl = thm "refl";
       
   372   val eqTrueI = thm "eqTrueI";
       
   373 in
   343 in
   374 
   344 
   375 fun get_eq_datatype thy dtco =
   345 fun get_eq_datatype thy dtco =
   376   let
   346   let
   377     val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
   347     val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
   436 
   406 
   437 (** codetypes for code 2nd generation **)
   407 (** codetypes for code 2nd generation **)
   438 
   408 
   439 (* abstraction over datatypes vs. type copies *)
   409 (* abstraction over datatypes vs. type copies *)
   440 
   410 
       
   411 fun get_typecopy_spec thy tyco =
       
   412   let
       
   413     val SOME { vs, constr, typ, ... } = TypecopyPackage.get_info thy tyco
       
   414   in (vs, [(constr, [typ])]) end;
       
   415 
       
   416 
   441 fun get_spec thy (dtco, true) =
   417 fun get_spec thy (dtco, true) =
   442       (the o DatatypePackage.get_datatype_spec thy) dtco
   418       (the o DatatypePackage.get_datatype_spec thy) dtco
   443   | get_spec thy (tyco, false) =
   419   | get_spec thy (tyco, false) =
   444       TypecopyPackage.get_spec thy tyco;
   420       get_typecopy_spec thy tyco;
   445 
   421 
   446 local
   422 local
   447   fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
   423   fun get_eq_thms thy tyco = case DatatypePackage.get_datatype thy tyco
   448    of SOME _ => get_eq_datatype thy tyco
   424    of SOME _ => get_eq_datatype thy tyco
   449     | NONE => [TypecopyPackage.get_eq thy tyco];
   425     | NONE => [TypecopyPackage.get_eq thy tyco];
   477     fun add_spec thy (tyco, is_dt) =
   453     fun add_spec thy (tyco, is_dt) =
   478       (tyco, (is_dt, get_spec thy (tyco, is_dt)));
   454       (tyco, (is_dt, get_spec thy (tyco, is_dt)));
   479     fun datatype_hook dtcos thy =
   455     fun datatype_hook dtcos thy =
   480       hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
   456       hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
   481     fun typecopy_hook tyco thy =
   457     fun typecopy_hook tyco thy =
   482       hook ([(tyco, (false, TypecopyPackage.get_spec thy tyco))]) thy;
   458       hook ([(tyco, (false, get_typecopy_spec thy tyco))]) thy;
   483   in
   459   in
   484     thy
   460     thy
   485     |> DatatypePackage.interpretation datatype_hook
   461     |> DatatypePackage.interpretation datatype_hook
   486     |> TypecopyPackage.interpretation typecopy_hook
   462     |> TypecopyPackage.interpretation typecopy_hook
   487   end;
   463   end;
   574 
   550 
   575 fun add_datatype_spec dtco thy =
   551 fun add_datatype_spec dtco thy =
   576   let
   552   let
   577     val SOME (vs, cos) = DatatypePackage.get_datatype_spec thy dtco;
   553     val SOME (vs, cos) = DatatypePackage.get_datatype_spec thy dtco;
   578     val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
   554     val cs = map (fn (c, tys) => (c, tys ---> Type (dtco, map TFree vs))) cos;
   579   in try (Code.add_datatype cs) thy |> the_default thy end;
       
   580 
       
   581 fun add_datatype_case_certs dtco thy =
       
   582   Code.add_case (get_case_cert thy dtco) thy;
       
   583 
       
   584 fun add_datatype_case_defs dtco thy =
       
   585   let
       
   586     val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
   555     val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
       
   556     val certs = get_case_cert thy dtco;
   587   in
   557   in
   588     fold_rev Code.add_default_func case_rewrites thy
   558     thy
       
   559     |> try (Code.add_datatype cs)
       
   560     |> the_default thy
       
   561     |> Code.add_case certs
       
   562     |> fold_rev Code.add_default_func case_rewrites
   589   end;
   563   end;
   590 
   564 
   591 val setup = 
   565 val setup = 
   592   add_codegen "datatype" datatype_codegen
   566   add_codegen "datatype" datatype_codegen
   593   #> add_tycodegen "datatype" datatype_tycodegen
   567   #> add_tycodegen "datatype" datatype_tycodegen
   594   #> DatatypePackage.interpretation (fold add_datatype_spec)
   568   #> DatatypePackage.interpretation (fold add_datatype_spec)
   595   #> DatatypePackage.interpretation (fold add_datatype_case_certs)
       
   596   #> DatatypePackage.interpretation (fold add_datatype_case_defs)
       
   597   #> add_codetypes_hook eq_hook
   569   #> add_codetypes_hook eq_hook
   598 
   570 
   599 end;
   571 end;