src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 55535 10194808430d
parent 55530 3dfb724db099
child 55538 6a5986170c1d
equal deleted inserted replaced
55534:b18bdcbda41b 55535:10194808430d
   464   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   464   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   465   unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @
   465   unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @
   466     map_idents) THEN
   466     map_idents) THEN
   467   HEADGOAL (rtac refl);
   467   HEADGOAL (rtac refl);
   468 
   468 
   469 fun prepare_primrec fixes specs lthy =
   469 fun prepare_primrec fixes specs lthy0 =
   470   let
   470   let
   471     val thy = Proof_Context.theory_of lthy;
   471     val thy = Proof_Context.theory_of lthy0;
   472 
   472 
   473     val (bs, mxs) = map_split (apfst fst) fixes;
   473     val (bs, mxs) = map_split (apfst fst) fixes;
   474     val fun_names = map Binding.name_of bs;
   474     val fun_names = map Binding.name_of bs;
   475     val eqns_data = map (dissect_eqn lthy fun_names) specs;
   475     val eqns_data = map (dissect_eqn lthy0 fun_names) specs;
   476     val funs_data = eqns_data
   476     val funs_data = eqns_data
   477       |> partition_eq ((op =) o pairself #fun_name)
   477       |> partition_eq ((op =) o pairself #fun_name)
   478       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
   478       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
   479       |> map (fn (x, y) => the_single y
   479       |> map (fn (x, y) => the_single y
   480           handle List.Empty =>
   480           handle List.Empty =>
   486     val callssss = funs_data
   486     val callssss = funs_data
   487       |> map (partition_eq ((op =) o pairself #ctr))
   487       |> map (partition_eq ((op =) o pairself #ctr))
   488       |> map (maps (map_filter (find_rec_calls has_call)));
   488       |> map (maps (map_filter (find_rec_calls has_call)));
   489 
   489 
   490     fun is_only_old_datatype (Type (s, _)) =
   490     fun is_only_old_datatype (Type (s, _)) =
   491         is_none (fp_sugar_of lthy s) andalso is_some (Datatype_Data.get_info thy s)
   491         is_none (fp_sugar_of lthy0 s) andalso is_some (Datatype_Data.get_info thy s)
   492       | is_only_old_datatype _ = false;
   492       | is_only_old_datatype _ = false;
   493 
   493 
   494     val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
   494     val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
   495     val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of
   495     val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of
   496         [] => ()
   496         [] => ()
   497       | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
   497       | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
   498 
   498 
   499     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
   499     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
   500       rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
   500       rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy0;
   501 
   501 
   502     val actual_nn = length funs_data;
   502     val actual_nn = length funs_data;
   503 
   503 
   504     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
   504     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
   505       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   505       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   506         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
   506         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
   507           " is not a constructor in left-hand side") user_eqn) eqns_data end;
   507           " is not a constructor in left-hand side") user_eqn) eqns_data end;
   508 
   508 
   509     val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
   509     val defs = build_defs lthy bs mxs funs_data rec_specs has_call;
   510 
   510 
   511     fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
   511     fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
   512         (fun_data : eqn_data list) =
   512         (fun_data : eqn_data list) =
   513       let
   513       let
       
   514         val js =
       
   515           find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
       
   516             fun_data eqns_data;
       
   517 
   514         val def_thms = map (snd o snd) def_thms';
   518         val def_thms = map (snd o snd) def_thms';
   515         val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
   519         val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
   516           |> fst
   520           |> fst
   517           |> map_filter (try (fn (x, [y]) =>
   521           |> map_filter (try (fn (x, [y]) =>
   518             (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
   522             (#fun_name x, #user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
   519           |> map (fn (user_eqn, num_extra_args, rec_thm) =>
   523           |> map2 (fn j => fn (fun_name, user_eqn, num_extra_args, rec_thm) =>
   520             mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
   524               mk_primrec_tac lthy' num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
   521             |> K |> Goal.prove_sorry lthy [] [] user_eqn
   525               |> K |> Goal.prove_sorry lthy' [] [] user_eqn
   522             |> Thm.close_derivation);
   526               (* for code extraction from proof terms: *)
   523         val poss =
   527               |> singleton (Proof_Context.export lthy' lthy)
   524           find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
   528               |> Thm.name_derivation (Sign.full_name thy (Binding.name fun_name) ^
   525             fun_data eqns_data;
   529                 Long_Name.separator ^ simpsN ^
       
   530                 (if js = [0] then "" else "_" ^ string_of_int (j + 1))))
       
   531             js;
   526       in
   532       in
   527         (poss, simp_thmss)
   533         (js, simp_thms)
   528       end;
   534       end;
   529 
   535 
   530     val notes =
   536     val notes =
   531       (if n2m then
   537       (if n2m then
   532          map2 (fn name => fn thm =>
   538          map2 (fn name => fn thm =>
   544         ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
   550         ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
   545   in
   551   in
   546     (((fun_names, defs),
   552     (((fun_names, defs),
   547       fn lthy => fn defs =>
   553       fn lthy => fn defs =>
   548         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   554         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   549       lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
   555       lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
   550   end;
   556   end;
   551 
   557 
   552 fun add_primrec_simple fixes ts lthy =
   558 fun add_primrec_simple fixes ts lthy =
   553   let
   559   let
   554     val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
   560     val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
   572     val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
   578     val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
   573 
   579 
   574     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   580     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   575 
   581 
   576     val mk_notes =
   582     val mk_notes =
   577       flat ooo map3 (fn poss => fn prefix => fn thms =>
   583       flat ooo map3 (fn js => fn prefix => fn thms =>
   578         let
   584         let
   579           val (bs, attrss) = map_split (fst o nth specs) poss;
   585           val (bs, attrss) = map_split (fst o nth specs) js;
   580           val notes =
   586           val notes =
   581             map3 (fn b => fn attrs => fn thm =>
   587             map3 (fn b => fn attrs => fn thm =>
   582                 ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs),
   588                 ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs),
   583                  [([thm], [])]))
   589                  [([thm], [])]))
   584               bs attrss thms;
   590               bs attrss thms;
   586           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   592           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   587         end);
   593         end);
   588   in
   594   in
   589     lthy
   595     lthy
   590     |> add_primrec_simple fixes (map snd specs)
   596     |> add_primrec_simple fixes (map snd specs)
   591     |-> (fn (names, (ts, (posss, simpss))) =>
   597     |-> (fn (names, (ts, (jss, simpss))) =>
   592       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   598       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   593       #> Local_Theory.notes (mk_notes posss names simpss)
   599       #> Local_Theory.notes (mk_notes jss names simpss)
   594       #>> pair ts o map snd)
   600       #>> pair ts o map snd)
   595   end
   601   end
   596   handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   602   handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   597 
   603 
   598 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec;
   604 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec;
   599 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   605 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   600 
   606 
   601 fun add_primrec_global fixes specs thy =
   607 fun add_primrec_global fixes specs =
   602   let
   608   Named_Target.theory_init
   603     val lthy = Named_Target.theory_init thy;
   609   #> add_primrec fixes specs
   604     val ((ts, simpss), lthy') = add_primrec fixes specs lthy;
   610   ##> Local_Theory.exit_global;
   605     val simpss' = burrow (Proof_Context.export lthy' lthy) simpss;
   611 
   606   in ((ts, simpss'), Local_Theory.exit_global lthy') end;
   612 fun add_primrec_overloaded ops fixes specs =
   607 
   613   Overloading.overloading ops
   608 fun add_primrec_overloaded ops fixes specs thy =
   614   #> add_primrec fixes specs
   609   let
   615   ##> Local_Theory.exit_global;
   610     val lthy = Overloading.overloading ops thy;
       
   611     val ((ts, simpss), lthy') = add_primrec fixes specs lthy;
       
   612     val simpss' = burrow (Proof_Context.export lthy' lthy) simpss;
       
   613   in ((ts, simpss'), Local_Theory.exit_global lthy') end;
       
   614 
   616 
   615 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   617 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   616   "define primitive recursive functions"
   618   "define primitive recursive functions"
   617   (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd));
   619   (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd));
   618 
   620