src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
changeset 56121 52e8f110fec3
parent 55869 54ddb003e128
child 56254 a2dd9200854d
equal deleted inserted replaced
56120:04c37dfef684 56121:52e8f110fec3
     6 Recursor sugar ("primrec").
     6 Recursor sugar ("primrec").
     7 *)
     7 *)
     8 
     8 
     9 signature BNF_LFP_REC_SUGAR =
     9 signature BNF_LFP_REC_SUGAR =
    10 sig
    10 sig
       
    11   datatype primrec_option = Nonexhaustive_Option
       
    12 
    11   type basic_lfp_sugar =
    13   type basic_lfp_sugar =
    12     {T: typ,
    14     {T: typ,
    13      fp_res_index: int,
    15      fp_res_index: int,
    14      C: typ,
    16      C: typ,
    15      fun_arg_Tsss : typ list list list,
    17      fun_arg_Tsss : typ list list list,
    31 
    33 
    32   val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
    34   val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory
    33 
    35 
    34   val add_primrec: (binding * typ option * mixfix) list ->
    36   val add_primrec: (binding * typ option * mixfix) list ->
    35     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
    37     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
    36   val add_primrec_cmd: (binding * string option * mixfix) list ->
    38   val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list ->
    37     (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
    39     (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
    38   val add_primrec_global: (binding * typ option * mixfix) list ->
    40   val add_primrec_global: (binding * typ option * mixfix) list ->
    39     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
    41     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
    40   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    42   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    41     (binding * typ option * mixfix) list ->
    43     (binding * typ option * mixfix) list ->
    59 val simp_attrs = @{attributes [simp]};
    61 val simp_attrs = @{attributes [simp]};
    60 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
    62 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
    61 
    63 
    62 exception OLD_PRIMREC of unit;
    64 exception OLD_PRIMREC of unit;
    63 exception PRIMREC of string * term list;
    65 exception PRIMREC of string * term list;
       
    66 
       
    67 datatype primrec_option = Nonexhaustive_Option;
    64 
    68 
    65 datatype rec_call =
    69 datatype rec_call =
    66   No_Rec of int * typ |
    70   No_Rec of int * typ |
    67   Mutual_Rec of (int * typ) * (int * typ) |
    71   Mutual_Rec of (int * typ) * (int * typ) |
    68   Nested_Rec of int * typ;
    72   Nested_Rec of int * typ;
   344       t
   348       t
   345       |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
   349       |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
   346       |> fold_rev lambda (args @ left_args @ right_args)
   350       |> fold_rev lambda (args @ left_args @ right_args)
   347     end);
   351     end);
   348 
   352 
   349 fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
   353 fun build_defs lthy nonexhaustive bs mxs (funs_data : eqn_data list list)
       
   354     (rec_specs : rec_spec list) has_call =
   350   let
   355   let
   351     val n_funs = length funs_data;
   356     val n_funs = length funs_data;
   352 
   357 
   353     val ctr_spec_eqn_data_list' =
   358     val ctr_spec_eqn_data_list' =
   354       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   359       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   355       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
   360       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
   356           ##> (fn x => null x orelse
   361           ##> (fn x => null x orelse
   357             raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
   362             raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
   358     val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
   363     val _ = ctr_spec_eqn_data_list' |> map (fn ({ctr, ...}, x) =>
   359       raise PRIMREC ("multiple equations for constructor", map #user_eqn x));
   364         if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x)
       
   365         else if length x = 1 orelse nonexhaustive then ()
       
   366         else warning ("no equation for constructor " ^ Syntax.string_of_term lthy ctr));
   360 
   367 
   361     val ctr_spec_eqn_data_list =
   368     val ctr_spec_eqn_data_list =
   362       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
   369       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
   363 
   370 
   364     val recs = take n_funs rec_specs |> map #recx;
   371     val recs = take n_funs rec_specs |> map #recx;
   412   unfold_thms_tac ctxt fun_defs THEN
   419   unfold_thms_tac ctxt fun_defs THEN
   413   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   420   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   414   unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
   421   unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
   415   HEADGOAL (rtac refl);
   422   HEADGOAL (rtac refl);
   416 
   423 
   417 fun prepare_primrec fixes specs lthy0 =
   424 fun prepare_primrec nonexhaustive fixes specs lthy0 =
   418   let
   425   let
   419     val thy = Proof_Context.theory_of lthy0;
   426     val thy = Proof_Context.theory_of lthy0;
   420 
   427 
   421     val (bs, mxs) = map_split (apfst fst) fixes;
   428     val (bs, mxs) = map_split (apfst fst) fixes;
   422     val fun_names = map Binding.name_of bs;
   429     val fun_names = map Binding.name_of bs;
   453     val _ =
   460     val _ =
   454       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   461       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   455         raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
   462         raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
   456           " is not a constructor in left-hand side", [user_eqn])) eqns_data;
   463           " is not a constructor in left-hand side", [user_eqn])) eqns_data;
   457 
   464 
   458     val defs = build_defs lthy bs mxs funs_data rec_specs has_call;
   465     val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call;
   459 
   466 
   460     fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
   467     fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
   461         (fun_data : eqn_data list) =
   468         (fun_data : eqn_data list) =
   462       let
   469       let
   463         val js =
   470         val js =
   503       fn lthy => fn defs =>
   510       fn lthy => fn defs =>
   504         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   511         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   505       lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
   512       lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
   506   end;
   513   end;
   507 
   514 
   508 fun add_primrec_simple fixes ts lthy =
   515 fun add_primrec_simple' opts fixes ts lthy =
   509   let
   516   let
   510     val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy
   517     val nonexhaustive = member (op =) opts Nonexhaustive_Option;
       
   518     val (((names, defs), prove), lthy') = prepare_primrec nonexhaustive fixes ts lthy
   511       handle ERROR str => raise PRIMREC (str, []);
   519       handle ERROR str => raise PRIMREC (str, []);
   512   in
   520   in
   513     lthy'
   521     lthy'
   514     |> fold_map Local_Theory.define defs
   522     |> fold_map Local_Theory.define defs
   515     |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   523     |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   519            error ("primrec error:\n  " ^ str)
   527            error ("primrec error:\n  " ^ str)
   520          else
   528          else
   521            error ("primrec error:\n  " ^ str ^ "\nin\n  " ^
   529            error ("primrec error:\n  " ^ str ^ "\nin\n  " ^
   522              space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
   530              space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
   523 
   531 
   524 fun gen_primrec old_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec
   532 val add_primrec_simple = add_primrec_simple' [];
   525     lthy =
   533 
       
   534 fun gen_primrec old_primrec prep_spec opts
       
   535     (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
   526   let
   536   let
   527     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
   537     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
   528     val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []);
   538     val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []);
   529 
   539 
   530     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   540     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   541         in
   551         in
   542           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   552           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   543         end);
   553         end);
   544   in
   554   in
   545     lthy
   555     lthy
   546     |> add_primrec_simple fixes (map snd specs)
   556     |> add_primrec_simple' opts fixes (map snd specs)
   547     |-> (fn (names, (ts, (jss, simpss))) =>
   557     |-> (fn (names, (ts, (jss, simpss))) =>
   548       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   558       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   549       #> Local_Theory.notes (mk_notes jss names simpss)
   559       #> Local_Theory.notes (mk_notes jss names simpss)
   550       #>> pair ts o map snd)
   560       #>> pair ts o map snd)
   551   end
   561   end
   552   handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   562   handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   553 
   563 
   554 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec;
   564 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec [];
   555 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   565 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   556 
   566 
   557 fun add_primrec_global fixes specs =
   567 fun add_primrec_global fixes specs =
   558   Named_Target.theory_init
   568   Named_Target.theory_init
   559   #> add_primrec fixes specs
   569   #> add_primrec fixes specs
   562 fun add_primrec_overloaded ops fixes specs =
   572 fun add_primrec_overloaded ops fixes specs =
   563   Overloading.overloading ops
   573   Overloading.overloading ops
   564   #> add_primrec fixes specs
   574   #> add_primrec fixes specs
   565   ##> Local_Theory.exit_global;
   575   ##> Local_Theory.exit_global;
   566 
   576 
       
   577 val primrec_option_parser = Parse.group (fn () => "option")
       
   578   (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option)
       
   579 
   567 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   580 val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
   568   "define primitive recursive functions"
   581   "define primitive recursive functions"
   569   (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd));
   582   ((Scan.optional (@{keyword "("} |--
       
   583       Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) --
       
   584     (Parse.fixes -- Parse_Spec.where_alt_specs)
       
   585     >> (fn (opts, (fixes, spec)) => snd o add_primrec_cmd opts fixes spec));
   570 
   586 
   571 end;
   587 end;