src/HOL/Tools/datatype_codegen.ML
changeset 30240 5b25fee0362c
parent 29302 eb782d1dc07c
child 30242 aea5d7fa7ef5
equal deleted inserted replaced
30239:179ff9cb160b 30240:5b25fee0362c
     4 Code generator facilities for inductive datatypes.
     4 Code generator facilities for inductive datatypes.
     5 *)
     5 *)
     6 
     6 
     (* Exported interface of the datatype code generator module.
        This is a side-by-side diff: the first line of each pair is the
        left-hand revision (30239), the second the right-hand revision (30240).
        The only change in the signature is a rename, with unchanged types:
          get_eq        -> mk_eq         (theory -> string -> thm list)
          get_case_cert -> mk_case_cert  (theory -> string -> thm)
        setup (theory -> theory) is identical in both revisions. *)
     7 signature DATATYPE_CODEGEN =
     7 signature DATATYPE_CODEGEN =
     8 sig
     8 sig
     9   val get_eq: theory -> string -> thm list
     9   val mk_eq: theory -> string -> thm list
    10   val get_case_cert: theory -> string -> thm
    10   val mk_case_cert: theory -> string -> thm
    11   val setup: theory -> theory
    11   val setup: theory -> theory
    12 end;
    12 end;
    13 
    13 
    14 structure DatatypeCodegen : DATATYPE_CODEGEN =
    14 structure DatatypeCodegen : DATATYPE_CODEGEN =
    15 struct
    15 struct
    83           let
    83           let
    84             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    84             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    85             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
    85             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
    86             val T = Type (tname, dts');
    86             val T = Type (tname, dts');
    87             val rest = mk_term_of_def gr "and " xs;
    87             val rest = mk_term_of_def gr "and " xs;
    88             val (_, eqs) = foldl_map (fn (prfx, (cname, Ts)) =>
    88             val (_, eqs) = Library.foldl_map (fn (prfx, (cname, Ts)) =>
    89               let val args = map (fn i =>
    89               let val args = map (fn i =>
    90                 str ("x" ^ string_of_int i)) (1 upto length Ts)
    90                 str ("x" ^ string_of_int i)) (1 upto length Ts)
    91               in ("  | ", Pretty.blk (4,
    91               in ("  | ", Pretty.blk (4,
    92                 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,
    92                 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,
    93                  if null Ts then str (snd (get_const_id gr cname))
    93                  if null Ts then str (snd (get_const_id gr cname))
   214        invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr
   214        invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr
   215     else
   215     else
   216       let
   216       let
   217         val ts1 = Library.take (i, ts);
   217         val ts1 = Library.take (i, ts);
   218         val t :: ts2 = Library.drop (i, ts);
   218         val t :: ts2 = Library.drop (i, ts);
   219         val names = foldr OldTerm.add_term_names
   219         val names = List.foldr OldTerm.add_term_names
   220           (map (fst o fst o dest_Var) (foldr OldTerm.add_term_vars [] ts1)) ts1;
   220           (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1;
   221         val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
   221         val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
   222 
   222 
   223         fun pcase [] [] [] gr = ([], gr)
   223         fun pcase [] [] [] gr = ([], gr)
   224           | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
   224           | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
   225               let
   225               let
   321   end;
   321   end;
   322 
   322 
   323 
   323 
   324 (* case certificates *)
   324 (* case certificates *)
   325 
   325 
   326 fun get_case_cert thy tyco =
   326 fun mk_case_cert thy tyco =
   327   let
   327   let
   328     val raw_thms =
   328     val raw_thms =
   329       (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
   329       (#case_rewrites o DatatypePackage.the_datatype thy) tyco;
   330     val thms as hd_thm :: _ = raw_thms
   330     val thms as hd_thm :: _ = raw_thms
   331       |> Conjunction.intr_balanced
   331       |> Conjunction.intr_balanced
   355   end;
   355   end;
   356 
   356 
   (* add_datatype_cases dtco thy: registers case information for the datatype
      named dtco in theory thy and returns the updated theory.
      Both revisions read case_rewrites from DatatypePackage.the_datatype and
      add each of them via Code.add_default_eqn (folded with fold_rev).
      Left revision:  the certificate from get_case_cert is passed straight to
                      Code.add_case, so a failure there propagates.
      Right revision: the certificate comes from mk_case_cert and is registered
                      through the local helper add_case_liberal (see below). *)
   357 fun add_datatype_cases dtco thy =
   357 fun add_datatype_cases dtco thy =
   358   let
   358   let
   359     val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
   359     val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco;
   360     val certs = get_case_cert thy dtco;
   360     val cert = mk_case_cert thy dtco;
           (* Right revision only: best-effort registration. The parameter thy
              shadows the outer thy. Per the Isabelle library combinators,
              try wraps the result of Code.add_case cert in an option and
              the_default thy falls back to the incoming theory when that
              option is NONE, so a rejected certificate leaves the theory
              unchanged instead of aborting. *)
       
   361     fun add_case_liberal thy = thy
       
   362       |> try (Code.add_case cert)
       
   363       |> the_default thy;
   361   in
   364   in
   362     thy
   365     thy
   363     |> Code.add_case certs
   366     |> add_case_liberal
   364     |> fold_rev Code.add_default_eqn case_rewrites
   367     |> fold_rev Code.add_default_eqn case_rewrites
   365   end;
   368   end;
   366 
   369 
   367 
   370 
   368 (* equality *)
   371 (* equality *)
   369 
   372 
   370 local
   373 local
   371 
   374 
   (* Theorems bound once for the equality section that follows.
      Left revision:  looked up by name through thm "..." / thms "...".
      Right revision: the same names referenced through thm / thms
                      antiquotations; the selected theorems are unchanged.
      not_false_true is assembled as iffD2 OF [<element 7 of HOL.simp_thms>, TrueI].
      NOTE(review): the element is picked by position (nth ... 7); confirm the
      index still selects the intended rule if HOL.simp_thms is reordered. *)
   372 val not_sym = thm "HOL.not_sym";
   375 val not_sym = @{thm HOL.not_sym};
   373 val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
   376 val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI];
   374 val refl = thm "refl";
   377 val refl = @{thm refl};
   375 val eqTrueI = thm "eqTrueI";
   378 val eqTrueI = @{thm eqTrueI};
   376 
   379 
   377 fun mk_distinct cos =
   380 fun mk_distinct cos =
   378   let
   381   let
   379     fun sym_product [] = []
   382     fun sym_product [] = []
   380       | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   383       | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   395       in HOLogic.mk_Trueprop t end;
   398       in HOLogic.mk_Trueprop t end;
   396   in map mk_dist (sym_product cos) end;
   399   in map mk_dist (sym_product cos) end;
   397 
   400 
   398 in
   401 in
   399 
   402 
   400 fun get_eq thy dtco =
   403 fun mk_eq thy dtco =
   401   let
   404   let
   402     val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
   405     val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco;
   403     fun mk_triv_inject co =
   406     fun mk_triv_inject co =
   404       let
   407       let
   405         val ct' = Thm.cterm_of thy
   408         val ct' = Thm.cterm_of thy
   443         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
   446         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy);
   444         val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
   447         val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
   445       in (thm', lthy') end;
   448       in (thm', lthy') end;
   (* tac thms: tactic that first runs Class.intro_classes_tac [] and then
      applies ProofContext.fact_tac thms to all resulting goals (ALLGOALS).
      Identical in both revisions; only the line numbers differ.
      NOTE(review): its call site is outside the visible part of this diff. *)
   446     fun tac thms = Class.intro_classes_tac []
   449     fun tac thms = Class.intro_classes_tac []
   447       THEN ALLGOALS (ProofContext.fact_tac thms);
   450       THEN ALLGOALS (ProofContext.fact_tac thms);
   (* get_eq' (left revision) / mk_eq' (right revision): post-processes the
      theorem list returned by get_eq / mk_eq for datatype dtco. Each theorem
      is passed, in order, through:
        1. Code_Unit.constrain_thm thy [HOLogic.class_eq]
        2. Simpdata.mk_eq
        3. MetaSimplifier.rewrite_rule with equals_eq transferred to thy
        4. AxClass.unoverload thy
      The two revisions differ only in the get_ -> mk_ rename. *)
   448     fun get_eq' thy dtco = get_eq thy dtco
   451     fun mk_eq' thy dtco = mk_eq thy dtco
   449       |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq])
   452       |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq])
   450       |> map Simpdata.mk_eq
   453       |> map Simpdata.mk_eq
   451       |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}])
   454       |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}])
   452       |> map (AxClass.unoverload thy);
   455       |> map (AxClass.unoverload thy);
   453     fun add_eq_thms dtco thy =
   456     fun add_eq_thms dtco thy =
   458         val eq_refl = @{thm HOL.eq_refl}
   461         val eq_refl = @{thm HOL.eq_refl}
   459           |> Thm.instantiate
   462           |> Thm.instantiate
   460               ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
   463               ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
   461           |> Simpdata.mk_eq
   464           |> Simpdata.mk_eq
   462           |> AxClass.unoverload thy;
   465           |> AxClass.unoverload thy;
   463         fun get_thms () = (eq_refl, false)
   466         fun mk_thms () = (eq_refl, false)
   464           :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco));
   467           :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco));
   465       in
   468       in
   466         Code.add_eqnl (const, Lazy.lazy get_thms) thy
   469         Code.add_eqnl (const, Lazy.lazy mk_thms) thy
   467       end;
   470       end;
   468   in
   471   in
   469     thy
   472     thy
   470     |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq])
   473     |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq])
   471     |> fold_map add_def dtcos
   474     |> fold_map add_def dtcos