4 Code generator facilities for inductive datatypes. |
4 Code generator facilities for inductive datatypes. |
5 *) |
5 *) |
6 |
6 |
7 signature DATATYPE_CODEGEN = |
7 signature DATATYPE_CODEGEN = |
8 sig |
8 sig |
9 val get_eq: theory -> string -> thm list |
9 val mk_eq: theory -> string -> thm list |
10 val get_case_cert: theory -> string -> thm |
10 val mk_case_cert: theory -> string -> thm |
11 val setup: theory -> theory |
11 val setup: theory -> theory |
12 end; |
12 end; |
13 |
13 |
14 structure DatatypeCodegen : DATATYPE_CODEGEN = |
14 structure DatatypeCodegen : DATATYPE_CODEGEN = |
15 struct |
15 struct |
83 let |
83 let |
84 val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; |
84 val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; |
85 val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts; |
85 val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts; |
86 val T = Type (tname, dts'); |
86 val T = Type (tname, dts'); |
87 val rest = mk_term_of_def gr "and " xs; |
87 val rest = mk_term_of_def gr "and " xs; |
88 val (_, eqs) = foldl_map (fn (prfx, (cname, Ts)) => |
88 val (_, eqs) = Library.foldl_map (fn (prfx, (cname, Ts)) => |
89 let val args = map (fn i => |
89 let val args = map (fn i => |
90 str ("x" ^ string_of_int i)) (1 upto length Ts) |
90 str ("x" ^ string_of_int i)) (1 upto length Ts) |
91 in (" | ", Pretty.blk (4, |
91 in (" | ", Pretty.blk (4, |
92 [str prfx, mk_term_of gr module' false T, Pretty.brk 1, |
92 [str prfx, mk_term_of gr module' false T, Pretty.brk 1, |
93 if null Ts then str (snd (get_const_id gr cname)) |
93 if null Ts then str (snd (get_const_id gr cname)) |
214 invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr |
214 invoke_codegen thy defs dep module brack (eta_expand c ts (i+1)) gr |
215 else |
215 else |
216 let |
216 let |
217 val ts1 = Library.take (i, ts); |
217 val ts1 = Library.take (i, ts); |
218 val t :: ts2 = Library.drop (i, ts); |
218 val t :: ts2 = Library.drop (i, ts); |
219 val names = foldr OldTerm.add_term_names |
219 val names = List.foldr OldTerm.add_term_names |
220 (map (fst o fst o dest_Var) (foldr OldTerm.add_term_vars [] ts1)) ts1; |
220 (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1; |
221 val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T))); |
221 val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T))); |
222 |
222 |
223 fun pcase [] [] [] gr = ([], gr) |
223 fun pcase [] [] [] gr = ([], gr) |
224 | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr = |
224 | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr = |
225 let |
225 let |
321 end; |
321 end; |
322 |
322 |
323 |
323 |
324 (* case certificates *) |
324 (* case certificates *) |
325 |
325 |
326 fun get_case_cert thy tyco = |
326 fun mk_case_cert thy tyco = |
327 let |
327 let |
328 val raw_thms = |
328 val raw_thms = |
329 (#case_rewrites o DatatypePackage.the_datatype thy) tyco; |
329 (#case_rewrites o DatatypePackage.the_datatype thy) tyco; |
330 val thms as hd_thm :: _ = raw_thms |
330 val thms as hd_thm :: _ = raw_thms |
331 |> Conjunction.intr_balanced |
331 |> Conjunction.intr_balanced |
355 end; |
355 end; |
356 |
356 |
357 fun add_datatype_cases dtco thy = |
357 fun add_datatype_cases dtco thy = |
358 let |
358 let |
359 val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco; |
359 val {case_rewrites, ...} = DatatypePackage.the_datatype thy dtco; |
360 val certs = get_case_cert thy dtco; |
360 val cert = mk_case_cert thy dtco; |
|
361 fun add_case_liberal thy = thy |
|
362 |> try (Code.add_case cert) |
|
363 |> the_default thy; |
361 in |
364 in |
362 thy |
365 thy |
363 |> Code.add_case certs |
366 |> add_case_liberal |
364 |> fold_rev Code.add_default_eqn case_rewrites |
367 |> fold_rev Code.add_default_eqn case_rewrites |
365 end; |
368 end; |
366 |
369 |
367 |
370 |
368 (* equality *) |
371 (* equality *) |
369 |
372 |
370 local |
373 local |
371 |
374 |
372 val not_sym = thm "HOL.not_sym"; |
375 val not_sym = @{thm HOL.not_sym}; |
373 val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI]; |
376 val not_false_true = iffD2 OF [nth @{thms HOL.simp_thms} 7, TrueI]; |
374 val refl = thm "refl"; |
377 val refl = @{thm refl}; |
375 val eqTrueI = thm "eqTrueI"; |
378 val eqTrueI = @{thm eqTrueI}; |
376 |
379 |
377 fun mk_distinct cos = |
380 fun mk_distinct cos = |
378 let |
381 let |
379 fun sym_product [] = [] |
382 fun sym_product [] = [] |
380 | sym_product (x::xs) = map (pair x) xs @ sym_product xs; |
383 | sym_product (x::xs) = map (pair x) xs @ sym_product xs; |
395 in HOLogic.mk_Trueprop t end; |
398 in HOLogic.mk_Trueprop t end; |
396 in map mk_dist (sym_product cos) end; |
399 in map mk_dist (sym_product cos) end; |
397 |
400 |
398 in |
401 in |
399 |
402 |
400 fun get_eq thy dtco = |
403 fun mk_eq thy dtco = |
401 let |
404 let |
402 val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco; |
405 val (vs, cs) = DatatypePackage.the_datatype_spec thy dtco; |
403 fun mk_triv_inject co = |
406 fun mk_triv_inject co = |
404 let |
407 let |
405 val ct' = Thm.cterm_of thy |
408 val ct' = Thm.cterm_of thy |
443 val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy); |
446 val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy); |
444 val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; |
447 val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm; |
445 in (thm', lthy') end; |
448 in (thm', lthy') end; |
446 fun tac thms = Class.intro_classes_tac [] |
449 fun tac thms = Class.intro_classes_tac [] |
447 THEN ALLGOALS (ProofContext.fact_tac thms); |
450 THEN ALLGOALS (ProofContext.fact_tac thms); |
448 fun get_eq' thy dtco = get_eq thy dtco |
451 fun mk_eq' thy dtco = mk_eq thy dtco |
449 |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq]) |
452 |> map (Code_Unit.constrain_thm thy [HOLogic.class_eq]) |
450 |> map Simpdata.mk_eq |
453 |> map Simpdata.mk_eq |
451 |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]) |
454 |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]) |
452 |> map (AxClass.unoverload thy); |
455 |> map (AxClass.unoverload thy); |
453 fun add_eq_thms dtco thy = |
456 fun add_eq_thms dtco thy = |
458 val eq_refl = @{thm HOL.eq_refl} |
461 val eq_refl = @{thm HOL.eq_refl} |
459 |> Thm.instantiate |
462 |> Thm.instantiate |
460 ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []) |
463 ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []) |
461 |> Simpdata.mk_eq |
464 |> Simpdata.mk_eq |
462 |> AxClass.unoverload thy; |
465 |> AxClass.unoverload thy; |
463 fun get_thms () = (eq_refl, false) |
466 fun mk_thms () = (eq_refl, false) |
464 :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco)); |
467 :: rev (map (rpair true) (mk_eq' (Theory.deref thy_ref) dtco)); |
465 in |
468 in |
466 Code.add_eqnl (const, Lazy.lazy get_thms) thy |
469 Code.add_eqnl (const, Lazy.lazy mk_thms) thy |
467 end; |
470 end; |
468 in |
471 in |
469 thy |
472 thy |
470 |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq]) |
473 |> TheoryTarget.instantiation (dtcos, vs', [HOLogic.class_eq]) |
471 |> fold_map add_def dtcos |
474 |> fold_map add_def dtcos |