659 |> Syntax.check_terms lthy |
659 |> Syntax.check_terms lthy |
660 |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs |
660 |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs |
661 |> rpair exclss' |
661 |> rpair exclss' |
662 end; |
662 end; |
663 |
663 |
664 fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} disc_eqns = |
664 fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} sel_eqns disc_eqns = |
665 if length disc_eqns <> length ctr_specs - 1 then disc_eqns else |
665 if length disc_eqns <> length ctr_specs - 1 then disc_eqns else |
666 let |
666 let |
667 val n = 0 upto length ctr_specs |
667 val n = 0 upto length ctr_specs |
668 |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)); |
668 |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)); |
|
669 val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns) |
|
670 |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options; |
669 val extra_disc_eqn = { |
671 val extra_disc_eqn = { |
670 fun_name = Binding.name_of fun_binding, |
672 fun_name = Binding.name_of fun_binding, |
671 fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))), |
673 fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))), |
672 fun_args = the_default (map (curry Free Name.uu) arg_Ts) (try (#fun_args o hd) disc_eqns), |
674 fun_args = fun_args, |
673 ctr = #ctr (nth ctr_specs n), |
675 ctr = #ctr (nth ctr_specs n), |
674 ctr_no = n, |
676 ctr_no = n, |
675 disc = #disc (nth ctr_specs n), |
677 disc = #disc (nth ctr_specs n), |
676 prems = maps (invert_prems o #prems) disc_eqns, |
678 prems = maps (invert_prems o #prems) disc_eqns, |
677 user_eqn = undef_const}; |
679 user_eqn = undef_const}; |
716 |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |
718 |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names |
717 |> map (flat o snd); |
719 |> map (flat o snd); |
718 |
720 |
719 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
721 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
720 val arg_Tss = map (binder_types o snd o fst) fixes; |
722 val arg_Tss = map (binder_types o snd o fst) fixes; |
721 val disc_eqnss = map4 mk_real_disc_eqns bs arg_Tss corec_specs disc_eqnss'; |
723 val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss'; |
722 val (defs, exclss') = |
724 val (defs, exclss') = |
723 co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss; |
725 co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss; |
724 |
726 |
725 (* try to prove (automatically generated) tautologies by ourselves *) |
727 (* try to prove (automatically generated) tautologies by ourselves *) |
726 val exclss'' = exclss' |
728 val exclss'' = exclss' |
742 val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs) |
744 val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs) |
743 |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs)); |
745 |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs)); |
744 |
746 |
745 fun prove_disc {ctr_specs, ...} exclsss |
747 fun prove_disc {ctr_specs, ...} exclsss |
746 {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} = |
748 {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} = |
747 if user_eqn = undef_const then [] else |
749 if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else |
748 let |
750 let |
749 val disc_corec = nth ctr_specs ctr_no |> #disc_corec; |
751 val {disc_corec, ...} = nth ctr_specs ctr_no; |
750 val k = 1 + ctr_no; |
752 val k = 1 + ctr_no; |
751 val m = length prems; |
753 val m = length prems; |
752 val t = |
754 val t = |
753 (* FIXME use applied_fun from dissect_\<dots> instead? *) |
|
754 list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) |
755 list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) |
755 |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*) |
756 |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*) |
756 |> HOLogic.mk_Trueprop |
757 |> HOLogic.mk_Trueprop |
757 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) |
758 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) |
758 |> curry Logic.list_all (map dest_Free fun_args); |
759 |> curry Logic.list_all (map dest_Free fun_args); |
788 |> pair sel |
789 |> pair sel |
789 end; |
790 end; |
790 |
791 |
791 fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns |
792 fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns |
792 {ctr, disc, sels, collapse, ...} = |
793 {ctr, disc, sels, collapse, ...} = |
|
794 let val _ = tracing ("disc = " ^ @{make_string} disc); in |
793 if not (exists (equal ctr o #ctr) disc_eqns) |
795 if not (exists (equal ctr o #ctr) disc_eqns) |
794 andalso (warning ("no disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true) |
796 andalso not (exists (equal ctr o #ctr) sel_eqns) |
795 orelse (* don't try to prove theorems where some sel_eqns are missing *) |
797 andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true) |
|
798 orelse (* don't try to prove theorems when some sel_eqns are missing *) |
796 filter (equal ctr o #ctr) sel_eqns |
799 filter (equal ctr o #ctr) sel_eqns |
797 |> fst o finds ((op =) o apsnd #sel) sels |
800 |> fst o finds ((op =) o apsnd #sel) sels |
798 |> exists (null o snd) |
801 |> exists (null o snd) |
799 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true) |
802 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true) |
800 orelse |
|
801 #user_eqn (the (find_first (equal ctr o #ctr) disc_eqns)) = undef_const |
|
802 andalso (warning ("auto-generated disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true) |
|
803 then [] else |
803 then [] else |
804 let |
804 let |
805 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr); |
805 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr); |
806 val _ = tracing ("disc = " ^ Syntax.string_of_term lthy (#disc (the (find_first (equal ctr o #ctr) disc_eqns)))); |
806 val _ = tracing (the_default "NO disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns))); |
807 val {fun_name, fun_T, fun_args, prems, ...} = |
807 val (fun_name, fun_T, fun_args, prems) = |
808 the (find_first (equal ctr o #ctr) disc_eqns); |
808 (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns) |
|
809 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x)) |
|
810 ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [])) |
|
811 |> the o merge_options; |
809 val m = length prems; |
812 val m = length prems; |
810 val t = sel_eqns |
813 val t = sel_eqns |
811 |> fst o finds ((op =) o apsnd #sel) sels |
814 |> fst o finds ((op =) o apsnd #sel) sels |
812 |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract) |
815 |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract) |
813 |> curry list_comb ctr |
816 |> curry list_comb ctr |
814 |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), |
817 |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), |
815 map Bound (length fun_args - 1 downto 0))) |
818 map Bound (length fun_args - 1 downto 0))) |
816 |> HOLogic.mk_Trueprop |
819 |> HOLogic.mk_Trueprop |
817 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) |
820 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) |
818 |> curry Logic.list_all (map dest_Free fun_args); |
821 |> curry Logic.list_all (map dest_Free fun_args); |
819 val disc_thm = the_default TrueI (AList.lookup (op =) disc_thms disc); |
822 val maybe_disc_thm = AList.lookup (op =) disc_thms disc; |
820 val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms'); |
823 val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms'); |
821 val _ = tracing ("t = " ^ Syntax.string_of_term lthy t); |
824 val _ = tracing ("t = " ^ Syntax.string_of_term lthy t); |
822 val _ = tracing ("m = " ^ @{make_string} m); |
825 val _ = tracing ("m = " ^ @{make_string} m); |
823 val _ = tracing ("collapse = " ^ @{make_string} collapse); |
826 val _ = tracing ("collapse = " ^ @{make_string} collapse); |
824 val _ = tracing ("disc_thm = " ^ @{make_string} disc_thm); |
827 val _ = tracing ("maybe_disc_thm = " ^ @{make_string} maybe_disc_thm); |
825 val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms); |
828 val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms); |
826 in |
829 in |
827 mk_primcorec_ctr_of_dtr_tac lthy m collapse disc_thm sel_thms |
830 mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms |
828 |> K |> Goal.prove lthy [] [] t |
831 |> K |> Goal.prove lthy [] [] t |
829 |> single |
832 |> single |
|
833 (*handle ERROR x => (warning x; []))*) |
|
834 end |
830 end; |
835 end; |
831 |
836 |
832 val (disc_notes, disc_thmss) = |
837 val (disc_notes, disc_thmss) = |
833 fun_names ~~ map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss |
838 fun_names ~~ map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss |
834 |> `(map (fn (fun_name, thms) => |
839 |> `(map (fn (fun_name, thms) => |