src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 54132 af11e99e519c
parent 54120 c2f18fd05414
child 54133 a22ded8a7f7d
equal deleted inserted replaced
54131:18b23d787062 54132:af11e99e519c
   472 
   472 
   473 datatype co_eqn_data =
   473 datatype co_eqn_data =
   474   Disc of co_eqn_data_disc |
   474   Disc of co_eqn_data_disc |
   475   Sel of co_eqn_data_sel;
   475   Sel of co_eqn_data_sel;
   476 
   476 
   477 fun co_dissect_eqn_disc seq fun_names (corec_specs : corec_spec list) maybe_ctr_rhs maybe_code_rhs
   477 fun co_dissect_eqn_disc seq fun_names (ctr_specss : corec_ctr_spec list list) maybe_ctr_rhs
   478     prems' concl matchedsss =
   478     maybe_code_rhs prems' concl matchedsss =
   479   let
   479   let
   480     fun find_subterm p = let (* FIXME \<exists>? *)
   480     fun find_subterm p = let (* FIXME \<exists>? *)
   481       fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
   481       fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
   482         | f t = if p t then SOME t else NONE
   482         | f t = if p t then SOME t else NONE
   483       in f end;
   483       in f end;
   485     val applied_fun = concl
   485     val applied_fun = concl
   486       |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
   486       |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
   487       |> the
   487       |> the
   488       handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl;
   488       handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl;
   489     val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
   489     val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
   490     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
   490     val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
   491 
   491 
   492     val discs = map #disc ctr_specs;
   492     val discs = map #disc ctr_specs;
   493     val ctrs = map #ctr ctr_specs;
   493     val ctrs = map #ctr ctr_specs;
   494     val not_disc = head_of concl = @{term Not};
   494     val not_disc = head_of concl = @{term Not};
   495     val _ = not_disc andalso length ctrs <> 2 andalso
   495     val _ = not_disc andalso length ctrs <> 2 andalso
   533       maybe_code_rhs = maybe_code_rhs,
   533       maybe_code_rhs = maybe_code_rhs,
   534       user_eqn = user_eqn
   534       user_eqn = user_eqn
   535     }, matchedsss')
   535     }, matchedsss')
   536   end;
   536   end;
   537 
   537 
   538 fun co_dissect_eqn_sel fun_names (corec_specs : corec_spec list) eqn' of_spec eqn =
   538 fun co_dissect_eqn_sel fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec eqn =
   539   let
   539   let
   540     val (lhs, rhs) = HOLogic.dest_eq eqn
   540     val (lhs, rhs) = HOLogic.dest_eq eqn
   541       handle TERM _ =>
   541       handle TERM _ =>
   542         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
   542         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
   543     val sel = head_of lhs;
   543     val sel = head_of lhs;
   544     val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
   544     val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
   545       handle TERM _ =>
   545       handle TERM _ =>
   546         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
   546         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
   547     val corec_spec = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name)
   547     val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name)
   548       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
   548       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
   549     val ctr_spec =
   549     val ctr_spec =
   550       if is_some of_spec
   550       if is_some of_spec
   551       then the (find_first (equal (the of_spec) o #ctr) (#ctr_specs corec_spec))
   551       then the (find_first (equal (the of_spec) o #ctr) ctr_specs)
   552       else #ctr_specs corec_spec |> filter (exists (equal sel) o #sels) |> the_single
   552       else ctr_specs |> filter (exists (equal sel) o #sels) |> the_single
   553         handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn;
   553         handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn;
   554     val user_eqn = drop_All eqn';
   554     val user_eqn = drop_All eqn';
   555   in
   555   in
   556     Sel {
   556     Sel {
   557       fun_name = fun_name,
   557       fun_name = fun_name,
   562       rhs_term = rhs,
   562       rhs_term = rhs,
   563       user_eqn = user_eqn
   563       user_eqn = user_eqn
   564     }
   564     }
   565   end;
   565   end;
   566 
   566 
   567 fun co_dissect_eqn_ctr seq fun_names (corec_specs : corec_spec list) eqn' maybe_code_rhs
   567 fun co_dissect_eqn_ctr seq fun_names (ctr_specss : corec_ctr_spec list list) eqn' maybe_code_rhs
   568     prems concl matchedsss =
   568     prems concl matchedsss =
   569   let
   569   let
   570     val (lhs, rhs) = HOLogic.dest_eq concl;
   570     val (lhs, rhs) = HOLogic.dest_eq concl;
   571     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
   571     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
   572     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
   572     val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
   573     val (ctr, ctr_args) = strip_comb (unfold_let rhs);
   573     val (ctr, ctr_args) = strip_comb (unfold_let rhs);
   574     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
   574     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
   575       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
   575       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
   576 
   576 
   577     val disc_concl = betapply (disc, lhs);
   577     val disc_concl = betapply (disc, lhs);
   578     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
   578     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
   579       then (NONE, matchedsss)
   579       then (NONE, matchedsss)
   580       else apfst SOME (co_dissect_eqn_disc seq fun_names corec_specs
   580       else apfst SOME (co_dissect_eqn_disc seq fun_names ctr_specss
   581           (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
   581           (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss);
   582 
   582 
   583     val sel_concls = (sels ~~ ctr_args)
   583     val sel_concls = (sels ~~ ctr_args)
   584       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
   584       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
   585 
   585 
   589  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls) ^
   589  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_concls) ^
   590  "\nfor premise(s)\n    \<cdot> " ^
   590  "\nfor premise(s)\n    \<cdot> " ^
   591  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) prems));
   591  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) prems));
   592 *)
   592 *)
   593 
   593 
   594     val eqns_data_sel = map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_concls;
   594     val eqns_data_sel = map (co_dissect_eqn_sel fun_names ctr_specss eqn' (SOME ctr)) sel_concls;
   595   in
   595   in
   596     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   596     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   597   end;
   597   end;
   598 
   598 
   599 fun co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss =
   599 fun co_dissect_eqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss =
   600   let
   600   let
   601     val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
   601     val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []);
   602     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
   602     val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free;
   603     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
   603     val ctr_specs = the (AList.lookup (op =) (fun_names ~~ ctr_specss) fun_name);
   604 
   604 
   605     val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
   605     val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ =>
   606         if member ((op =) o apsnd #ctr) ctr_specs ctr
   606         if member ((op =) o apsnd #ctr) ctr_specs ctr
   607         then cons (ctr, cs)
   607         then cons (ctr, cs)
   608         else primrec_error_eqn "not a constructor" ctr) [] rhs' []
   608         else primrec_error_eqn "not a constructor" ctr) [] rhs' []
   614         |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
   614         |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args =>
   615           if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
   615           if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs')
   616         |> curry list_comb ctr
   616         |> curry list_comb ctr
   617         |> curry HOLogic.mk_eq lhs);
   617         |> curry HOLogic.mk_eq lhs);
   618   in
   618   in
   619     fold_map2 (co_dissect_eqn_ctr false fun_names corec_specs eqn'
   619     fold_map2 (co_dissect_eqn_ctr false fun_names ctr_specss eqn'
   620         (SOME (abstract (List.rev fun_args) rhs)))
   620         (SOME (abstract (List.rev fun_args) rhs)))
   621       ctr_premss ctr_concls matchedsss
   621       ctr_premss ctr_concls matchedsss
   622   end;
   622   end;
   623 
   623 
   624 fun co_dissect_eqn lthy seq has_call fun_names (corec_specs : corec_spec list) eqn' of_spec
   624 fun co_dissect_eqn lthy seq has_call fun_names (ctr_specss : corec_ctr_spec list list) eqn' of_spec
   625     matchedsss =
   625     matchedsss =
   626   let
   626   let
   627     val eqn = drop_All eqn'
   627     val eqn = drop_All eqn'
   628       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
   628       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
   629     val (prems, concl) = Logic.strip_horn eqn
   629     val (prems, concl) = Logic.strip_horn eqn
   633       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
   633       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
   634       |> head_of;
   634       |> head_of;
   635 
   635 
   636     val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
   636     val maybe_rhs = concl |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
   637 
   637 
   638     val discs = maps #ctr_specs corec_specs |> map #disc;
   638     val discs = maps (map #disc) ctr_specss;
   639     val sels = maps #ctr_specs corec_specs |> maps #sels;
   639     val sels = maps (maps #sels) ctr_specss;
   640     val ctrs = maps #ctr_specs corec_specs |> map #ctr;
   640     val ctrs = maps (map #ctr) ctr_specss;
   641   in
   641   in
   642     if member (op =) discs head orelse
   642     if member (op =) discs head orelse
   643       is_some maybe_rhs andalso
   643       is_some maybe_rhs andalso
   644         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
   644         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
   645       co_dissect_eqn_disc seq fun_names corec_specs NONE NONE prems concl matchedsss
   645       co_dissect_eqn_disc seq fun_names ctr_specss NONE NONE prems concl matchedsss
   646       |>> single
   646       |>> single
   647     else if member (op =) sels head then
   647     else if member (op =) sels head then
   648       ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec concl], matchedsss)
   648       ([co_dissect_eqn_sel fun_names ctr_specss eqn' of_spec concl], matchedsss)
   649     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   649     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   650       member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
   650       member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then
   651       co_dissect_eqn_ctr seq fun_names corec_specs eqn' NONE prems concl matchedsss
   651       co_dissect_eqn_ctr seq fun_names ctr_specss eqn' NONE prems concl matchedsss
   652     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   652     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso
   653       null prems then
   653       null prems then
   654       co_dissect_eqn_code lthy has_call fun_names corec_specs eqn' concl matchedsss
   654       co_dissect_eqn_code lthy has_call fun_names ctr_specss eqn' concl matchedsss
   655       |>> flat
   655       |>> flat
   656     else
   656     else
   657       primrec_error_eqn "malformed function equation" eqn
   657       primrec_error_eqn "malformed function equation" eqn
   658   end;
   658   end;
   659 
   659 
   820     val fun_names = map Binding.name_of bs;
   820     val fun_names = map Binding.name_of bs;
   821     val corec_specs = take actual_nn corec_specs'; (*###*)
   821     val corec_specs = take actual_nn corec_specs'; (*###*)
   822 
   822 
   823     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   823     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   824     val eqns_data =
   824     val eqns_data =
   825       fold_map2 (co_dissect_eqn lthy seq has_call fun_names corec_specs)
   825       fold_map2 (co_dissect_eqn lthy seq has_call fun_names (map #ctr_specs corec_specs))
   826         (map snd specs) of_specs []
   826         (map snd specs) of_specs []
   827       |> flat o fst;
   827       |> flat o fst;
   828 
   828 
   829     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
   829     val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
   830       |> partition_eq ((op =) o pairself #fun_name)
   830       |> partition_eq ((op =) o pairself #fun_name)