src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
author blanchet
Wed Apr 23 10:23:26 2014 +0200 (2014-04-23)
changeset 56638 092a306bcc3d
parent 56254 a2dd9200854d
child 56651 fc105315822a
permissions -rw-r--r--
generate size instances for new-style datatypes
     1 (*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
     2     Author:     Lorenz Panny, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4     Copyright   2013
     5 
     6 New-style recursor sugar ("primrec").
     7 *)
     8 
(* Interface of the new-style "primrec" command and of the extension hook that supplies the
   datatype-specific functionality. *)
signature BNF_LFP_REC_SUGAR =
sig
  (* Options of the "primrec" command; "nonexhaustive" suppresses the warning about
     constructors without an equation. *)
  datatype primrec_option = Nonexhaustive_Option

  (* Per-datatype information returned by an extension's get_basic_lfp_sugars. *)
  type basic_lfp_sugar =
    {T: typ,
     fp_res_index: int,
     C: typ,
     fun_arg_Tsss : typ list list list,
     ctr_defs: thm list,
     ctr_sugar: Ctr_Sugar.ctr_sugar,
     recx: term,
     rec_thms: thm list};

  (* The functionality that has to be registered (register_lfp_rec_extension) before
     primrec definitions can be processed by this module. *)
  type lfp_rec_extension =
    {nested_simps: thm list,
     is_new_datatype: Proof.context -> string -> bool,
     get_basic_lfp_sugars: binding list -> typ list -> term list ->
       (term * term list list) list list -> local_theory ->
       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
     rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
       term -> term -> term -> term};

  (* Error message together with the terms (typically equations) it refers to. *)
  exception PRIMREC of string * term list;

  (* Stores the extension in the theory data of this module. *)
  val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory

  (* Entry points; each returns the defined terms and the proved equations together with the
     updated context. *)
  val add_primrec: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
  val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
  val add_primrec_global: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  (* Variant without options that neither notes theorems nor registers spec rules; it also
     returns the function names and, per function, the indices of its equations. *)
  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
end;
    48 
    49 structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
    50 struct
    51 
    52 open Ctr_Sugar
    53 open Ctr_Sugar_Util
    54 open Ctr_Sugar_General_Tactics
    55 open BNF_FP_Rec_Sugar_Util
    56 
(* base names under which induction rules and simplification rules are noted *)
val inductN = "induct"
val simpsN = "simps"

(* attributes given to each individually noted equation (see gen_primrec) *)
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
    63 
(* Raised by prepare_primrec when an argument type is only an old-style datatype; gen_primrec
   reacts by delegating to the old primrec implementation. *)
exception OLD_PRIMREC of unit;
(* Error message plus the terms it refers to; add_primrec_simple' turns it into a user error. *)
exception PRIMREC of string * term list;

datatype primrec_option = Nonexhaustive_Option;

(* Classification of a constructor argument (see call_of in rec_specs_of). Each pair consists
   of an index into the flattened recursor arguments of the constructor and the argument's
   instantiated type; Mutual_Rec carries two such pairs. *)
datatype rec_call =
  No_Rec of int * typ |
  Mutual_Rec of (int * typ) * (int * typ) |
  Nested_Rec of int * typ;

(* What is known about one constructor: the constructor term, its offset among all
   constructors of the mutually recursive types, the classification of its arguments, and the
   recursor theorem used to prove the user's equation for it. *)
type rec_ctr_spec =
  {ctr: term,
   offset: int,
   calls: rec_call list,
   rec_thm: thm};

(* What is known about one datatype: its instantiated recursor, the map theorems used by
   mk_primrec_tac, and the specs of its constructors. *)
type rec_spec =
  {recx: term,
   nested_map_idents: thm list,
   nested_map_comps: thm list,
   ctr_specs: rec_ctr_spec list};

(* Per-datatype information returned by an extension's get_basic_lfp_sugars. *)
type basic_lfp_sugar =
  {T: typ,
   fp_res_index: int,
   C: typ,
   fun_arg_Tsss : typ list list list,
   ctr_defs: thm list,
   ctr_sugar: ctr_sugar,
   recx: term,
   rec_thms: thm list};

(* The datatype-specific functionality that must be registered before use. *)
type lfp_rec_extension =
  {nested_simps: thm list,
   is_new_datatype: Proof.context -> string -> bool,
   get_basic_lfp_sugars: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory,
   rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
     term -> term -> term -> term};
   104 
(* Theory data slot for the registered extension: initially NONE, combined with merge_options
   when theories are merged. *)
structure Data = Theory_Data
(
  type T = lfp_rec_extension option;
  val empty = NONE;
  val extend = I;
  val merge = merge_options;
);
   112 
   113 val register_lfp_rec_extension = Data.put o SOME;
   114 
   115 fun nested_simps ctxt =
   116   (case Data.get (Proof_Context.theory_of ctxt) of
   117     SOME {nested_simps, ...} => nested_simps
   118   | NONE => []);
   119 
   120 fun is_new_datatype ctxt =
   121   (case Data.get (Proof_Context.theory_of ctxt) of
   122     SOME {is_new_datatype, ...} => is_new_datatype ctxt
   123   | NONE => K false);
   124 
   125 fun get_basic_lfp_sugars bs arg_Ts callers callssss lthy =
   126   (case Data.get (Proof_Context.theory_of lthy) of
   127     SOME {get_basic_lfp_sugars, ...} => get_basic_lfp_sugars bs arg_Ts callers callssss lthy
   128   | NONE => error "Functionality not loaded yet");
   129 
   130 fun rewrite_nested_rec_call ctxt =
   131   (case Data.get (Proof_Context.theory_of ctxt) of
   132     SOME {rewrite_nested_rec_call, ...} => rewrite_nested_rec_call ctxt);
   133 
(* Obtains the basic sugar of the datatypes involved in a definition from the registered
   extension and derives one rec_spec per returned datatype.

   bs, arg_Ts, callers, callssss0 and lthy0 are passed on to get_basic_lfp_sugars; res_Ts are
   the result types used to instantiate the type variables Cs of the recursors.

   Returns ((n2m, rec_specs, missing_arg_Ts, induct_thm, induct_thms), lthy), where n2m,
   missing_arg_Ts, induct_thm and lthy are taken unchanged from get_basic_lfp_sugars. *)
fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, nested_map_idents, nested_map_comps,
         induct_thm, n2m, lthy) =
      get_basic_lfp_sugars bs arg_Ts callers callssss0 lthy0;

    (* the same sugars, ordered by ascending fp_res_index *)
    val perm_basic_lfp_sugars = sort (int_ord o pairself #fp_res_index) basic_lfp_sugars;

    val indices = map #fp_res_index basic_lfp_sugars;
    val perm_indices = map #fp_res_index perm_basic_lfp_sugars;

    val perm_ctrss = map (#ctrs o #ctr_sugar) perm_basic_lfp_sugars;

    val nn0 = length arg_Ts;
    val nn = length perm_ctrss;
    val kks = 0 upto nn - 1;

    (* for each datatype (in sorted order), the number of constructors of all preceding ones *)
    val perm_ctr_offsets = map (fn kk => Integer.sum (map length (take kk perm_ctrss))) kks;

    (* the datatype itself is read off the result type of its first constructor *)
    val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;
    val perm_Cs = map #C perm_basic_lfp_sugars;
    val perm_fun_arg_Tssss = map #fun_arg_Tsss perm_basic_lfp_sugars;

    (* NOTE(review): permute_like_unique is defined elsewhere; these presumably reorder a list
       given in permuted order back into the order of kks resp. indices -- confirm *)
    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;

    val induct_thms = unpermute0 (conj_dests nn induct_thm);

    val lfpTs = unpermute perm_lfpTs;
    val Cs = unpermute perm_Cs;
    val ctr_offsets = unpermute perm_ctr_offsets;

    (* As_rho instantiates the first nn0 datatypes to the user's argument types; Cs_rho maps
       each C (a type variable) to the corresponding result type, with HOLogic.unitT supplied
       by pad_list for the remaining ones (presumably padding res_Ts up to length nn) *)
    val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    val substACT = substAT o substCT;

    val perm_Cs' = map substCT perm_Cs;

    (* Classifies one constructor argument from the indices and types of its recursor
       arguments: a single argument is nested-recursive if its type mentions one of the Cs and
       non-recursive otherwise; two arguments mean a mutually recursive call. Other shapes are
       not expected and raise Match. *)
    fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        (* tag every recursor argument type with a running number, then look up where each
           one ends up in the flattened argument list *)
        val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
        val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
        val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
      in
        {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    (* constructors of one datatype receive consecutive offsets starting at k *)
    fun mk_ctr_specs fp_res_index k ctrs rec_thms =
      map4 mk_ctr_spec ctrs (k upto k + length ctrs - 1) (nth perm_fun_arg_Tssss fp_res_index)
        rec_thms;

    fun mk_spec ctr_offset
        ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) =
      {recx = mk_co_rec thy Least_FP (substAT T) perm_Cs' recx,
       nested_map_idents = nested_map_idents, nested_map_comps = nested_map_comps,
       ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms};
  in
    ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, induct_thm, induct_thms),
     lthy)
  end;
   204 
(* "undefined" with a dummy type; used as recursor argument for constructors without equation *)
val undef_const = Const (@{const_name undefined}, dummyT);

(* The parts of a user equation "f left_args (ctr ctr_args) right_args = rhs_term" as
   computed by dissect_eqn. rec_type is the result type of ctr, res_type the type of the
   function once the constructor argument has been removed, and user_eqn the equation as
   given by the user. *)
type eqn_data = {
  fun_name: string,
  rec_type: typ,
  ctr: term,
  ctr_args: term list,
  left_args: term list,
  right_args: term list,
  res_type: typ,
  rhs_term: term,
  user_eqn: term
};
   218 
   219 fun dissect_eqn lthy fun_names eqn' =
   220   let
   221     val eqn = drop_all eqn' |> HOLogic.dest_Trueprop
   222       handle TERM _ =>
   223              raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
   224     val (lhs, rhs) = HOLogic.dest_eq eqn
   225         handle TERM _ =>
   226                raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
   227     val (fun_name, args) = strip_comb lhs
   228       |>> (fn x => if is_Free x then fst (dest_Free x)
   229           else raise PRIMREC ("malformed function equation (does not start with free)", [eqn]));
   230     val (left_args, rest) = take_prefix is_Free args;
   231     val (nonfrees, right_args) = take_suffix is_Free rest;
   232     val num_nonfrees = length nonfrees;
   233     val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
   234       raise PRIMREC ("constructor pattern missing in left-hand side", [eqn]) else
   235       raise PRIMREC ("more than one non-variable argument in left-hand side", [eqn]);
   236     val _ = member (op =) fun_names fun_name orelse
   237       raise PRIMREC ("malformed function equation (does not start with function name)", [eqn]);
   238 
   239     val (ctr, ctr_args) = strip_comb (the_single nonfrees);
   240     val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
   241       raise PRIMREC ("partially applied constructor in pattern", [eqn]);
   242     val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
   243       raise PRIMREC ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
   244         "\" in left-hand side", [eqn]) end;
   245     val _ = forall is_Free ctr_args orelse
   246       raise PRIMREC ("non-primitive pattern in left-hand side", [eqn]);
   247     val _ =
   248       let val b = fold_aterms (fn x as Free (v, _) =>
   249         if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
   250         not (member (op =) fun_names v) andalso
   251         not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
   252       in
   253         null b orelse
   254         raise PRIMREC ("extra variable(s) in right-hand side: " ^
   255           commas (map (Syntax.string_of_term lthy) b), [eqn])
   256       end;
   257   in
   258     {fun_name = fun_name,
   259      rec_type = body_type (type_of ctr),
   260      ctr = ctr,
   261      ctr_args = ctr_args,
   262      left_args = left_args,
   263      right_args = right_args,
   264      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
   265      rhs_term = rhs,
   266      user_eqn = eqn'}
   267   end;
   268 
(* Returns a function that replaces the recursive calls in a right-hand side.

   get_ctr_pos maps a function name to the position of its constructor argument (the lookup
   is wrapped in "try", so it may fail for other names); has_call tells whether a term still
   mentions one of the functions being defined; mutual_calls and nested_calls are association
   lists from constructor arguments to their replacements.

   Raises PRIMREC if a recursive call has too few arguments or if, after the replacement, a
   call to one of the functions being defined is still present. *)
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    (* if y has a nested replacement y', let the extension rewrite t accordingly *)
    fun try_nested_rec bound_Ts y t =
      AList.lookup (op =) nested_calls y
      |> Option.map (fn y' => rewrite_nested_rec_call lthy has_call get_ctr_pos bound_Ts y y' t);

    fun subst bound_Ts (t as g' $ y) =
        let
          fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y;
          val y_head = head_of y;
        in
          (* only applications to (something headed by) a constructor argument are candidates *)
          if not (member (op =) ctr_args y_head) then
            subst_rec ()
          else
            (case try_nested_rec bound_Ts y_head t of
              SOME t' => t'
            | NONE =>
              let val (g, g_args) = strip_comb g' in
                (* is the head a free variable for which get_ctr_pos succeeds? *)
                (case try (get_ctr_pos o fst o dest_Free) g of
                  SOME ctr_pos =>
                  (length g_args >= ctr_pos orelse
                   raise PRIMREC ("too few arguments in recursive call", [t]);
                   (case AList.lookup (op =) mutual_calls y of
                     SOME y' => list_comb (y', g_args)
                   | NONE => subst_rec ()))
                | NONE => subst_rec ())
              end)
        end
      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst _ t = t

    (* final pass over the result of subst: reject leftover calls, then try the nested
       rewriting once more on the whole term *)
    fun subst' t =
      if has_call t then
        (* FIXME detect this case earlier? *)
        raise PRIMREC ("recursive call not directly applied to constructor argument", [t])
      else
        try_nested_rec [] (head_of t) t |> the_default t
  in
    subst' o subst []
  end;
   309 
(* Builds the argument passed to the recursor for one constructor.

   Without an equation (NONE) this is undef_const. Otherwise it is the equation's right-hand
   side with its recursive calls replaced (subst_rec_calls), abstracted over the recursor
   arguments of the constructor followed by the equation's left_args and right_args. *)
fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (eqn_data_opt : eqn_data option) =
  (case eqn_data_opt of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;
      (* a mutually recursive argument occupies two recursor arguments, all others one *)
      val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;

      (* Pairs (constructor argument index, (recursor argument index, type)) per kind of call;
         "try" drops the entries on which the partial function raises Match. *)
      val no_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
      val mutual_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
      val nested_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Nested_Rec p => p)));

      (* give t a variant name if it already occurs among frees *)
      fun ensure_unique frees t =
        if member (op =) frees t then Free (the_single (Term.variant_frees t [dest_Free t])) else t;

      (* Start from placeholder variables (renamed with respect to t) and overwrite each
         position with the corresponding constructor argument, retyped for recursive calls. *)
      val args = replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
          no_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) => fn xs =>
            nth_map arg_idx (K (ensure_unique xs (retype_free T (nth ctr_args ctr_arg_idx)))) xs)
          mutual_calls'
        |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
            nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
          nested_calls';

      (* position of the constructor argument per function (taken from its first equation);
         ~1 for names that are not among the functions being defined *)
      val fun_name_ctr_pos_list =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
      (* association lists from constructor arguments to the variables replacing their calls *)
      val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls';
      val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls';
    in
      t
      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
      |> fold_rev lambda (args @ left_args @ right_args)
    end);
   352 
(* Produces the definition specifications for Local_Theory.define, one per function: the
   right-hand side is the function's recursor applied to the recursor arguments of all
   constructors, rearranged by permute_args according to the position of the constructor
   argument.

   Raises PRIMREC for excess or multiple equations and for functions whose equations disagree
   on the position of the constructor argument. Unless nonexhaustive is set, a warning is
   issued for each constructor without an equation. *)
fun build_defs lthy nonexhaustive bs mxs (funs_data : eqn_data list list)
    (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;

    (* pair each constructor spec of the first n_funs datatypes with the equations for that
       constructor; equations matching no constructor are an error *)
    val ctr_spec_eqn_data_list' =
      (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
      |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
          ##> (fn x => null x orelse
            raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst);
    val _ = ctr_spec_eqn_data_list' |> map (fn ({ctr, ...}, x) =>
        if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x)
        else if length x = 1 orelse nonexhaustive then ()
        else warning ("no equation for constructor " ^ Syntax.string_of_term lthy ctr));

    (* constructors of the remaining datatypes have no equations *)
    val ctr_spec_eqn_data_list =
      ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    val recs = take n_funs rec_specs |> map #recx;
    (* one recursor argument per constructor, in the order of the constructor offsets *)
    val rec_args = ctr_spec_eqn_data_list
      |> sort ((op <) o pairself (#offset o fst) |> make_ord)
      |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
    (* position of the constructor argument of each function; it must be the same in all
       equations of that function *)
    val ctr_poss = map (fn x =>
      if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
        raise PRIMREC ("inconstant constructor pattern position for function " ^
          quote (#fun_name (hd x)), [])
      else
        hd x |> #left_args |> length) funs_data;
  in
    (recs, ctr_poss)
    |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
      bs mxs
  end;
   388 
(* Collects, for every constructor argument of an equation, the calls found for it in the
   right-hand side. Returns NONE if the constructor has no arguments and SOME (ctr, callss)
   otherwise, with one list of calls per constructor argument. *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
        let
          val typof = curry fastype_of1 bound_Ts;
          val (f', args') = strip_comb t;
          (* index of the first argument headed by ctr_arg, or negative if there is none *)
          val n = find_index (equal ctr_arg o head_of) args';
        in
          if n < 0 then
            find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
          else
            let
              (* f is the head applied to the arguments before position n; arg is the
                 argument headed by ctr_arg *)
              val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
              val (arg_head, arg_args) = Term.strip_comb arg;
            in
              if has_call f then
                (* record the call and keep searching in the remaining arguments *)
                mk_partial_compN (length arg_args) (typof arg_head) f ::
                maps (fn x => find bound_Ts x ctr_arg) args
              else
                find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
            end
        end
      | find _ _ _ = [];
  in
    map (find [] rhs_term) ctr_args
    |> (fn [] => NONE | callss => SOME (ctr, callss))
  end;
   417 
   418 fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
   419   unfold_thms_tac ctxt fun_defs THEN
   420   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   421   unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
   422   HEADGOAL (rtac refl);
   423 
(* Analyses a primrec specification and prepares everything needed to carry it out.

   Returns (((fun_names, defs), prove), lthy'), where defs are the definition specifications
   built by build_defs, lthy' is the context after noting the induction rules (only if the
   extension reports n2m), and "prove lthy defs" proves the user's equations from the results
   of the definitions, yielding per function the indices of its equations and the theorems.

   Raises OLD_PRIMREC if an argument type is only an old-style datatype, and PRIMREC for
   ill-formed specifications. *)
fun prepare_primrec nonexhaustive fixes specs lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val eqns_data = map (dissect_eqn lthy0 fun_names) specs;
    (* the equations grouped by function, in the order of fun_names; every function needs at
       least one equation *)
    val funs_data = eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map (fn (x, y) => the_single y
          handle List.Empty => raise PRIMREC ("missing equations for function " ^ quote x, []));

    val frees = map (fst #>> Binding.name_of #> Free) fixes;
    (* does a term mention one of the functions being defined? *)
    val has_call = exists_subterm (member (op =) frees);
    (* argument and result types are read off the first equation of each function *)
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss = funs_data
      |> map (partition_eq ((op =) o pairself #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    fun is_only_old_datatype (Type (s, _)) =
        is_some (Datatype_Data.get_info thy s) andalso not (is_new_datatype lthy0 s)
      | is_only_old_datatype _ = false;

    val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else ();
    val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, @{sort type})) (bs ~~ res_Ts) of
        [] => ()
      | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));

    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
      rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;

    val actual_nn = length funs_data;

    (* every pattern must be headed by a constructor of one of the datatypes involved *)
    val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
    val _ =
      map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
        raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
          " is not a constructor in left-hand side", [user_eqn])) eqns_data;

    val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call;

    (* proves the equations of one function from the definitional theorems in def_thms' *)
    fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
        (fun_data : eqn_data list) =
      let
        (* indices of this function's equations within the whole specification *)
        val js =
          find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
            fun_data eqns_data;

        val def_thms = map (snd o snd) def_thms';
        (* pair each equation with the spec of its constructor; "try" drops equations that
           do not match exactly one constructor spec *)
        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (try (fn (x, [y]) =>
            (#fun_name x, #user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
          |> map2 (fn j => fn (fun_name, user_eqn, num_extra_args, rec_thm) =>
              mk_primrec_tac lthy' num_extra_args nested_map_idents nested_map_comps def_thms
                rec_thm
              |> K |> Goal.prove_sorry lthy' [] [] user_eqn
              (* for code extraction from proof terms: *)
              |> singleton (Proof_Context.export lthy' lthy)
              |> Thm.name_derivation (Sign.full_name thy (Binding.name fun_name) ^
                Long_Name.separator ^ simpsN ^
                (if js = [0] then "" else "_" ^ string_of_int (j + 1))))
            js;
      in
        (js, simp_thms)
      end;

    (* per-function induction rules, noted only if the extension reports n2m *)
    val notes =
      (if n2m then
         map2 (fn name => fn thm => (name, inductN, [thm], [])) fun_names
           (take actual_nn induct_thms)
       else
         [])
      |> map (fn (prefix, thmN, thms, attrs) =>
        ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));

    val common_name = mk_common_name fun_names;

    (* the joint induction rule, noted under the common name of all functions *)
    val common_notes =
      (if n2m then [(inductN, [induct_thm], [])] else [])
      |> map (fn (thmN, thms, attrs) =>
        ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
  in
    (((fun_names, defs),
      fn lthy => fn defs =>
        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
      lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
  end;
   514 
   515 fun add_primrec_simple' opts fixes ts lthy =
   516   let
   517     val nonexhaustive = member (op =) opts Nonexhaustive_Option;
   518     val (((names, defs), prove), lthy') = prepare_primrec nonexhaustive fixes ts lthy
   519       handle ERROR str => raise PRIMREC (str, []);
   520   in
   521     lthy'
   522     |> fold_map Local_Theory.define defs
   523     |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   524   end
   525   handle PRIMREC (str, eqns) =>
   526          if null eqns then
   527            error ("primrec error:\n  " ^ str)
   528          else
   529            error ("primrec error:\n  " ^ str ^ "\nin\n  " ^
   530              space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
   531 
   532 val add_primrec_simple = add_primrec_simple' [];
   533 
   534 fun gen_primrec old_primrec prep_spec opts
   535     (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
   536   let
   537     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
   538     val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []);
   539 
   540     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   541 
   542     val mk_notes =
   543       flat ooo map3 (fn js => fn prefix => fn thms =>
   544         let
   545           val (bs, attrss) = map_split (fst o nth specs) js;
   546           val notes =
   547             map3 (fn b => fn attrs => fn thm =>
   548                 ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs),
   549                  [([thm], [])]))
   550               bs attrss thms;
   551         in
   552           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   553         end);
   554   in
   555     lthy
   556     |> add_primrec_simple' opts fixes (map snd specs)
   557     |-> (fn (names, (ts, (jss, simpss))) =>
   558       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   559       #> Local_Theory.notes (mk_notes jss names simpss)
   560       #>> pair ts o map snd)
   561   end
   562   handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single;
   563 
   564 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec [];
   565 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec;
   566 
   567 fun add_primrec_global fixes specs =
   568   Named_Target.theory_init
   569   #> add_primrec fixes specs
   570   ##> Local_Theory.exit_global;
   571 
   572 fun add_primrec_overloaded ops fixes specs =
   573   Overloading.overloading ops
   574   #> add_primrec fixes specs
   575   ##> Local_Theory.exit_global;
   576 
(* parser for a single option; the only one is the reserved word "nonexhaustive" *)
val primrec_option_parser = Parse.group (fn () => "option")
  (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option)

(* Registers the "primrec" command: an optional parenthesized, non-empty list of options,
   followed by the fixes and the specification, all handed to add_primrec_cmd. *)
val _ = Outer_Syntax.local_theory @{command_spec "primrec"}
  "define primitive recursive functions"
  ((Scan.optional (@{keyword "("} |--
      Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) --
    (Parse.fixes -- Parse_Spec.where_alt_specs)
    >> (fn (opts, (fixes, spec)) => snd o add_primrec_cmd opts fixes spec));
   586 
   587 end;