src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53722 e176d6d3345f
parent 53720 03fac7082137
child 53725 9e64151359e8
equal deleted inserted replaced
53721:ccaceea6c768 53722:e176d6d3345f
   659     |> Syntax.check_terms lthy
   659     |> Syntax.check_terms lthy
   660     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   660     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   661     |> rpair exclss'
   661     |> rpair exclss'
   662   end;
   662   end;
   663 
   663 
   664 fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} disc_eqns =
   664 fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} sel_eqns disc_eqns =
   665   if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
   665   if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
   666     let
   666     let
   667       val n = 0 upto length ctr_specs
   667       val n = 0 upto length ctr_specs
   668         |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
   668         |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
       
   669       val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
       
   670         |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
   669       val extra_disc_eqn = {
   671       val extra_disc_eqn = {
   670         fun_name = Binding.name_of fun_binding,
   672         fun_name = Binding.name_of fun_binding,
   671         fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
   673         fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
   672         fun_args = the_default (map (curry Free Name.uu) arg_Ts) (try (#fun_args o hd) disc_eqns),
   674         fun_args = fun_args,
   673         ctr = #ctr (nth ctr_specs n),
   675         ctr = #ctr (nth ctr_specs n),
   674         ctr_no = n,
   676         ctr_no = n,
   675         disc = #disc (nth ctr_specs n),
   677         disc = #disc (nth ctr_specs n),
   676         prems = maps (invert_prems o #prems) disc_eqns,
   678         prems = maps (invert_prems o #prems) disc_eqns,
   677         user_eqn = undef_const};
   679         user_eqn = undef_const};
   716       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
   718       |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
   717       |> map (flat o snd);
   719       |> map (flat o snd);
   718 
   720 
   719     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   721     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   720     val arg_Tss = map (binder_types o snd o fst) fixes;
   722     val arg_Tss = map (binder_types o snd o fst) fixes;
   721     val disc_eqnss = map4 mk_real_disc_eqns bs arg_Tss corec_specs disc_eqnss';
   723     val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
   722     val (defs, exclss') =
   724     val (defs, exclss') =
   723       co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
   725       co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;
   724 
   726 
   725     (* try to prove (automatically generated) tautologies by ourselves *)
   727     (* try to prove (automatically generated) tautologies by ourselves *)
   726     val exclss'' = exclss'
   728     val exclss'' = exclss'
   742         val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
   744         val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
   743           |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));
   745           |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));
   744 
   746 
   745         fun prove_disc {ctr_specs, ...} exclsss
   747         fun prove_disc {ctr_specs, ...} exclsss
   746             {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} =
   748             {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} =
   747           if user_eqn = undef_const then [] else
   749           if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
   748             let
   750             let
   749               val disc_corec = nth ctr_specs ctr_no |> #disc_corec;
   751               val {disc_corec, ...} = nth ctr_specs ctr_no;
   750               val k = 1 + ctr_no;
   752               val k = 1 + ctr_no;
   751               val m = length prems;
   753               val m = length prems;
   752               val t =
   754               val t =
   753                 (* FIXME use applied_fun from dissect_\<dots> instead? *)
       
   754                 list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
   755                 list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
   755                 |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
   756                 |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
   756                 |> HOLogic.mk_Trueprop
   757                 |> HOLogic.mk_Trueprop
   757                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
   758                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
   758                 |> curry Logic.list_all (map dest_Free fun_args);
   759                 |> curry Logic.list_all (map dest_Free fun_args);
   788             |> pair sel
   789             |> pair sel
   789           end;
   790           end;
   790 
   791 
   791         fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
   792         fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
   792             {ctr, disc, sels, collapse, ...} =
   793             {ctr, disc, sels, collapse, ...} =
       
   794 let val _ = tracing ("disc = " ^ @{make_string} disc); in
   793           if not (exists (equal ctr o #ctr) disc_eqns)
   795           if not (exists (equal ctr o #ctr) disc_eqns)
   794 andalso (warning ("no disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true)
   796               andalso not (exists (equal ctr o #ctr) sel_eqns)
   795             orelse (* don't try to prove theorems where some sel_eqns are missing *)
   797 andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true)
       
   798             orelse (* don't try to prove theorems when some sel_eqns are missing *)
   796               filter (equal ctr o #ctr) sel_eqns
   799               filter (equal ctr o #ctr) sel_eqns
   797               |> fst o finds ((op =) o apsnd #sel) sels
   800               |> fst o finds ((op =) o apsnd #sel) sels
   798               |> exists (null o snd)
   801               |> exists (null o snd)
   799 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true)
   802 andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true)
   800             orelse
       
   801               #user_eqn (the (find_first (equal ctr o #ctr) disc_eqns)) = undef_const
       
   802 andalso (warning ("auto-generated disc_eqn for ctr " ^ Syntax.string_of_term lthy ctr); true)
       
   803           then [] else
   803           then [] else
   804             let
   804             let
   805 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
   805 val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
   806 val _ = tracing ("disc = " ^ Syntax.string_of_term lthy (#disc (the (find_first (equal ctr o #ctr) disc_eqns))));
   806 val _ = tracing (the_default "NO disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
   807               val {fun_name, fun_T, fun_args, prems, ...} =
   807               val (fun_name, fun_T, fun_args, prems) =
   808                 the (find_first (equal ctr o #ctr) disc_eqns);
   808                 (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
       
   809                 |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
       
   810                 ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
       
   811                 |> the o merge_options;
   809               val m = length prems;
   812               val m = length prems;
   810               val t = sel_eqns
   813               val t = sel_eqns
   811                 |> fst o finds ((op =) o apsnd #sel) sels
   814                 |> fst o finds ((op =) o apsnd #sel) sels
   812                 |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
   815                 |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
   813                 |> curry list_comb ctr
   816                 |> curry list_comb ctr
   814                 |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
   817                 |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
   815                   map Bound (length fun_args - 1 downto 0)))
   818                   map Bound (length fun_args - 1 downto 0)))
   816                 |> HOLogic.mk_Trueprop
   819                 |> HOLogic.mk_Trueprop
   817                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
   820                 |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
   818                 |> curry Logic.list_all (map dest_Free fun_args);
   821                 |> curry Logic.list_all (map dest_Free fun_args);
   819               val disc_thm = the_default TrueI (AList.lookup (op =) disc_thms disc);
   822               val maybe_disc_thm = AList.lookup (op =) disc_thms disc;
   820               val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms');
   823               val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms');
   821 val _ = tracing ("t = " ^ Syntax.string_of_term lthy t);
   824 val _ = tracing ("t = " ^ Syntax.string_of_term lthy t);
   822 val _ = tracing ("m = " ^ @{make_string} m);
   825 val _ = tracing ("m = " ^ @{make_string} m);
   823 val _ = tracing ("collapse = " ^ @{make_string} collapse);
   826 val _ = tracing ("collapse = " ^ @{make_string} collapse);
   824 val _ = tracing ("disc_thm = " ^ @{make_string} disc_thm);
   827 val _ = tracing ("maybe_disc_thm = " ^ @{make_string} maybe_disc_thm);
   825 val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms);
   828 val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms);
   826             in
   829             in
   827               mk_primcorec_ctr_of_dtr_tac lthy m collapse disc_thm sel_thms
   830               mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
   828               |> K |> Goal.prove lthy [] [] t
   831               |> K |> Goal.prove lthy [] [] t
   829               |> single
   832               |> single
       
   833 (*handle ERROR x => (warning x; []))*)
       
   834 end
   830           end;
   835           end;
   831 
   836 
   832         val (disc_notes, disc_thmss) =
   837         val (disc_notes, disc_thmss) =
   833           fun_names ~~ map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss
   838           fun_names ~~ map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss
   834           |> `(map (fn (fun_name, thms) =>
   839           |> `(map (fn (fun_name, thms) =>