src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53352 43a1cc050943
parent 53341 63015d035301
child 53354 b7469b85ca28
equal deleted inserted replaced
53349:ae8c9380bbc4 53352:43a1cc050943
   251     in
   251     in
   252       t' |> fold_rev absfree abstractions
   252       t' |> fold_rev absfree abstractions
   253     end
   253     end
   254   else Const (@{const_name undefined}, dummyT)
   254   else Const (@{const_name undefined}, dummyT)
   255 
   255 
   256 fun build_defs lthy bs funs_data rec_specs get_indices =
   256 fun build_defs lthy bs mxs funs_data rec_specs get_indices =
   257   let
   257   let
   258     val n_funs = length funs_data;
   258     val n_funs = length funs_data;
   259 
   259 
   260     val ctr_spec_eqn_data_list' =
   260     val ctr_spec_eqn_data_list' =
   261       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   261       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   280         hd x |> #left_args |> length) funs_data;
   280         hd x |> #left_args |> length) funs_data;
   281   in
   281   in
   282     (recs, ctr_poss)
   282     (recs, ctr_poss)
   283     |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
   283     |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
   284     |> Syntax.check_terms lthy
   284     |> Syntax.check_terms lthy
   285     |> map2 (fn b => fn t => ((b, NoSyn), ((Binding.map_name Thm.def_name b, []), t))) bs
   285     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   286   end;
   286   end;
   287 
   287 
   288 fun find_rec_calls get_indices eqn_data =
   288 fun find_rec_calls get_indices eqn_data =
   289   let
   289   let
   290     fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
   290     fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
   309     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
   309     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
   310   end;
   310   end;
   311 
   311 
   312 fun add_primrec fixes specs lthy =
   312 fun add_primrec fixes specs lthy =
   313   let
   313   let
   314     val bs = map (fst o fst) fixes;
   314     val (bs, mxs) = map_split (apfst fst) fixes;
   315     val fun_names = map Binding.name_of bs;
   315     val fun_names = map Binding.name_of bs;
   316     val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
   316     val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
   317     val funs_data = eqns_data
   317     val funs_data = eqns_data
   318       |> partition_eq ((op =) o pairself #fun_name)
   318       |> partition_eq ((op =) o pairself #fun_name)
   319       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
   319       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
   338     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
   338     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
   339       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   339       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   340         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
   340         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
   341           " is not a constructor in left-hand side") user_eqn) eqns_data end;
   341           " is not a constructor in left-hand side") user_eqn) eqns_data end;
   342 
   342 
   343     val defs = build_defs lthy' bs funs_data rec_specs get_indices;
   343     val defs = build_defs lthy' bs mxs funs_data rec_specs get_indices;
   344 
   344 
   345     fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
   345     fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
   346         lthy =
   346         lthy =
   347       let
   347       let
   348         val fun_name = #fun_name (hd fun_data);
   348         val fun_name = #fun_name (hd fun_data);
   636       in
   636       in
   637         abs_args o update_args
   637         abs_args o update_args
   638       end
   638       end
   639   end;
   639   end;
   640 
   640 
   641 fun co_build_defs lthy sequential bs arg_Tss fun_name_corec_spec_list eqns_data =
   641 fun co_build_defs lthy sequential bs mxs arg_Tss fun_name_corec_spec_list eqns_data =
   642   let
   642   let
   643     val fun_names = map Binding.name_of bs;
   643     val fun_names = map Binding.name_of bs;
   644 
   644 
   645     val disc_eqnss = map_filter (try (fn Disc x => x)) eqns_data
   645     val disc_eqnss = map_filter (try (fn Disc x => x)) eqns_data
   646       |> partition_eq ((op =) o pairself #fun_name)
   646       |> partition_eq ((op =) o pairself #fun_name)
   690   in
   690   in
   691     map (list_comb o rpair corec_args) corecs
   691     map (list_comb o rpair corec_args) corecs
   692     |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
   692     |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
   693     |> map2 currys arg_Tss
   693     |> map2 currys arg_Tss
   694     |> Syntax.check_terms lthy
   694     |> Syntax.check_terms lthy
   695     |> map2 (fn b => fn t => ((b, NoSyn), ((Binding.map_name Thm.def_name b, []), t))) bs
   695     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   696     |> rpair proof_obligations
   696     |> rpair proof_obligations
   697   end;
   697   end;
   698 
   698 
   699 fun add_primcorec sequential fixes specs lthy =
   699 fun add_primcorec sequential fixes specs lthy =
   700   let
   700   let
   701     val bs = map (fst o fst) fixes;
   701     val (bs, mxs) = map_split (apfst fst) fixes;
   702     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
   702     val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;
   703 
   703 
   704     (* copied from primrec_new *)
   704     (* copied from primrec_new *)
   705     fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
   705     fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
   706       |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
   706       |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
   721     val (eqns_data, _) =
   721     val (eqns_data, _) =
   722       fold_map (co_dissect_eqn sequential fun_name_corec_spec_list) (map snd specs) []
   722       fold_map (co_dissect_eqn sequential fun_name_corec_spec_list) (map snd specs) []
   723       |>> flat;
   723       |>> flat;
   724 
   724 
   725     val (defs, proof_obligations) =
   725     val (defs, proof_obligations) =
   726       co_build_defs lthy' sequential bs (map (binder_types o snd o fst) fixes)
   726       co_build_defs lthy' sequential bs mxs (map (binder_types o snd o fst) fixes)
   727         fun_name_corec_spec_list eqns_data;
   727         fun_name_corec_spec_list eqns_data;
   728   in
   728   in
   729     lthy'
   729     lthy'
   730     |> fold_map Local_Theory.define defs |> snd
   730     |> fold_map Local_Theory.define defs |> snd
   731     |> Proof.theorem NONE (K I) [map (rpair []) proof_obligations]
   731     |> Proof.theorem NONE (K I) [map (rpair []) proof_obligations]