src/HOL/Tools/BNF/Tools/bnf_lfp_rec_sugar.ML
author blanchet
Mon Jan 20 18:24:56 2014 +0100 (2014-01-20)
changeset 55058 4e700eb471d4
parent 55005 src/HOL/BNF/Tools/bnf_lfp_rec_sugar.ML@38ea5ee18a06
permissions -rw-r--r--
moved BNF files to 'HOL'
     1 (*  Title:      HOL/Tools/BNF/Tools/bnf_lfp_rec_sugar.ML
     2     Author:     Lorenz Panny, TU Muenchen
     3     Author:     Jasmin Blanchette, TU Muenchen
     4     Copyright   2013
     5 
     6 Recursor sugar.
     7 *)
     8 
     (* Interface of the primrec sugar: entry points that define primitive recursive functions
        from a list of fixes (binding, type, mixfix) and a list of defining equations. *)
     9 signature BNF_LFP_REC_SUGAR =
    10 sig
         (* Local-theory entry point on already typed input (uses Specification.check_spec);
            yields the defined terms together with the noted theorem lists. *)
    11   val add_primrec: (binding * typ option * mixfix) list ->
    12     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
         (* Same as add_primrec, but on string input (uses Specification.read_spec). *)
    13   val add_primrec_cmd: (binding * string option * mixfix) list ->
    14     (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
         (* Theory-level variant: runs add_primrec in a local theory created by
            Named_Target.theory_init and exits back to the theory. *)
    15   val add_primrec_global: (binding * typ option * mixfix) list ->
    16     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
         (* Theory-level variant that runs add_primrec inside Overloading.overloading. *)
    17   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    18     (binding * typ option * mixfix) list ->
    19     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
         (* Lower-level entry point on typed fixes and bare equation terms; yields the function
            names, the defined terms, and per function the equation positions and theorems. *)
    20   val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    21     local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
    22 end;
    23 
    24 structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
    25 struct
    26 
    27 open Ctr_Sugar
    28 open BNF_Util
    29 open BNF_Tactics
    30 open BNF_Def
    31 open BNF_FP_Util
    32 open BNF_FP_Def_Sugar
    33 open BNF_FP_N2M_Sugar
    34 open BNF_FP_Rec_Sugar_Util
    35 
    (* Attribute lists for the theorems noted by gen_primrec: nitpick_simp, simp, and their
       concatenation prefixed with Code.add_default_eqn_attrib. *)
    36 val nitpicksimp_attrs = @{attributes [nitpick_simp]};
    37 val simp_attrs = @{attributes [simp]};
    38 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs;
    39 
    (* Internal error carrying a message and the offending terms (possibly none). It is caught
       in add_primrec_simple and turned into a user-facing error message. *)
    40 exception Primrec_Error of string * term list;
    41 
       (* Convenience raisers: without a term, with a single equation, with a list of equations. *)
    42 fun primrec_error str = raise Primrec_Error (str, []);
    43 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
    44 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
    45 
    (* Classification of a constructor argument with respect to the recursor, as computed by
       call_of in rec_specs_of. Each int * typ pair is the index of one recursor function
       argument (found via find_index_eq in the flattened argument list) and its type:
       No_Rec and Nested_Rec carry one such argument, Mutual_Rec carries two. *)
    46 datatype rec_call =
    47   No_Rec of int * typ |
    48   Mutual_Rec of (int * typ) * (int * typ) |
    49   Nested_Rec of int * typ;
    50 
    (* Per-constructor recursor data, built by mk_ctr_spec in rec_specs_of:
       ctr     - the constructor term (type variables instantiated via substA);
       offset  - running constructor number; numbering starts at offset_of_ctr, i.e. after the
                 constructors of the preceding ctr_sugars;
       calls   - classification of the constructor's arguments (see rec_call);
       rec_thm - the recursor theorem associated with this constructor. *)
    51 type rec_ctr_spec =
    52   {ctr: term,
    53    offset: int,
    54    calls: rec_call list,
    55    rec_thm: thm};
    56 
    (* Per-type recursor data, built by mk_spec in rec_specs_of:
       recx              - the recursor term (from mk_co_iter);
       nested_map_idents - map_id0 theorems of the nested BNFs, unfolded with id_def;
       nested_map_comps  - map_comp theorems of the nested BNFs;
       ctr_specs         - one rec_ctr_spec per constructor of the type. *)
    57 type rec_spec =
    58   {recx: term,
    59    nested_map_idents: thm list,
    60    nested_map_comps: thm list,
    61    ctr_specs: rec_ctr_spec list};
    62 
    (* Raised by massage_map (inside massage_nested_rec_call) when the given term is not
       recognized as a map function application; the argument is the offending term. *)
    63 exception AINT_NO_MAP of term;
    64 
       (* Error reporters used by massage_nested_rec_call; each aborts via error with the
          pretty-printed term t quoted in the message. *)
    65 fun ill_formed_rec_call ctxt t =
    66   error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t));
    67 fun invalid_map ctxt t =
    68   error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t));
    69 fun unexpected_rec_call ctxt t =
    70   error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t));
    71 
    (* Returns a term transformer (massage_call) that re-expresses a term mentioning y in terms
       of y'. bound_Ts are the types of the enclosing bound variables (needed by fastype_of1);
       has_call tells whether a term contains a recursive call; raw_massage_fun is applied to
       function arguments of a map that contain recursive calls.

       Cases handled by massage_call:
       - the term is exactly y: it becomes build_map_fst (yU, yT) $ y';
       - an application without recursive call: atomic occurrences of y are replaced as above;
       - an application t1 $ y with a recursive call: t1 (after elim_y) must be recognized by
         dest_map as a map application, whose function arguments are massaged; otherwise
         invalid_map is reported;
       - an application t1 $ (y $ xs): t1 is composed with y via mk_compN and massaged, then
         re-applied to xs; recursive calls inside xs are rejected;
       - anything else is reported as an ill-formed recursive call. *)
    72 fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
    73   let
    74     fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();
    75 
    76     val typof = curry fastype_of1 bound_Ts;
    77     val build_map_fst = build_map ctxt (fst_const o fst);
    78 
    79     val yT = typof y;
    80     val yU = typof y';
    81 
           (* y expressed through y': a map built by build_map_fst from yU to yT, applied to y'. *)
    82     fun y_of_y' () = build_map_fst (yU, yT) $ y';
    83     val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);
    84 
           (* Massages a function argument of a map. Compositions are traversed on their right
              operand (the left operand must be call-free). A function containing a call requires
              U to be a product type whose first component is T and is handed to raw_massage_fun;
              a call-free function is composed with build_map_fst (U, T). *)
    85     fun massage_mutual_fun U T t =
    86       (case t of
    87         Const (@{const_name comp}, _) $ t1 $ t2 =>
    88         mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
    89       | _ =>
    90         if has_call t then
    91           (case try HOLogic.dest_prodT U of
    92             SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t
    93           | NONE => invalid_map ctxt t)
    94         else
    95           mk_comp bound_Ts (t, build_map_fst (U, T)));
    96 
           (* Rebuilds a map application with its function arguments massaged; raises AINT_NO_MAP
              when t is not a map application for type constructor s or the types are not
              Type-constructed. *)
    97     fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
    98         (case try (dest_map ctxt s) t of
    99           SOME (map0, fs) =>
   100           let
   101             val Type (_, ran_Ts) = range_type (typof t);
   102             val map' = mk_map (length fs) Us ran_Ts map0;
   103             val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
   104           in
   105             Term.list_comb (map', fs')
   106           end
   107         | NONE => raise AINT_NO_MAP t)
   108       | massage_map _ _ t = raise AINT_NO_MAP t
           (* Equal types: argument is kept (but must be call-free). Otherwise try a nested map
              first and fall back to massage_mutual_fun. *)
   109     and massage_map_or_map_arg U T t =
   110       if T = U then
   111         tap check_no_call t
   112       else
   113         massage_map U T t
   114         handle AINT_NO_MAP _ => massage_mutual_fun U T t;
   115 
   116     fun massage_call (t as t1 $ t2) =
   117         if has_call t then
   118           if t2 = y then
   119             massage_map yU yT (elim_y t1) $ y'
   120             handle AINT_NO_MAP t' => invalid_map ctxt t'
   121           else
   122             let val (g, xs) = Term.strip_comb t2 in
   123               if g = y then
   124                 if exists has_call xs then unexpected_rec_call ctxt t2
   125                 else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
   126               else
   127                 ill_formed_rec_call ctxt t
   128             end
   129         else
   130           elim_y t
   131       | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
   132   in
   133     massage_call
   134   end;
   135 
   (* Obtains the fixpoint sugar for the recursed-over types via nested_to_mutual_fps (with
      Least_FP) and derives one rec_spec per returned fp_sugar.
      bs, arg_Ts, get_indices and callssss0 are passed on to nested_to_mutual_fps; arg_Ts and
      res_Ts are additionally used to instantiate the type variables of the datatypes and of
      the recursor result types.
      Returns ((flag, rec_specs, missing_arg_Ts, induct_thm, induct_thms), lthy), where flag is
      is_some lfp_sugar_thms and induct_thms are the conjuncts of induct_thm reordered by
      unpermute0. *)
   136 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy0 =
   137   let
   138     val thy = Proof_Context.theory_of lthy0;
   139 
   140     val ((missing_arg_Ts, perm0_kks,
   141           fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
   142             co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy) =
   143       nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy0;
   144 
           (* The perm_ values below are in ascending #index order of the fp_sugars. *)
   145     val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;
   146 
   147     val indices = map #index fp_sugars;
   148     val perm_indices = map #index perm_fp_sugars;
   149 
   150     val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
   151     val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
   152     val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;
   153 
   154     val nn0 = length arg_Ts;
   155     val nn = length perm_lfpTs;
   156     val kks = 0 upto nn - 1;
   157     val perm_ns = map length perm_ctr_Tsss;
   158     val perm_mss = map (map length) perm_ctr_Tsss;
   159 
   160     val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res))
   161       perm_fp_sugars;
   162     val perm_fun_arg_Tssss =
   163       mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);
   164 
           (* Reorderings back from the permuted orders (NOTE(review): exact semantics of
              permute_like are defined elsewhere). *)
   165     fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs;
   166     fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs;
   167 
   168     val induct_thms = unpermute0 (conj_dests nn induct_thm);
   169 
   170     val lfpTs = unpermute perm_lfpTs;
   171     val Cs = unpermute perm_Cs;
   172 
           (* Type-variable instantiations: datatype parameters against the user's arg_Ts, and
              recursor result type variables against res_Ts padded with unit up to nn. *)
   173     val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
   174     val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;
   175 
   176     val substA = Term.subst_TVars As_rho;
   177     val substAT = Term.typ_subst_TVars As_rho;
   178     val substCT = Term.typ_subst_TVars Cs_rho;
   179     val substACT = substAT o substCT;
   180 
   181     val perm_Cs' = map substCT perm_Cs;
   182 
           (* Total number of constructors of the first n ctr_sugars. *)
   183     fun offset_of_ctr 0 _ = 0
   184       | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) =
   185         length ctrs + offset_of_ctr (n - 1) ctr_sugars;
   186 
           (* One recursor argument: Nested_Rec if its type mentions one of the Cs, else No_Rec.
              Two recursor arguments: Mutual_Rec. Other shapes are not matched. *)
   187     fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T)
   188       | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));
   189 
   190     fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
   191       let
   192         val (fun_arg_hss, _) = indexedd fun_arg_Tss 0;
   193         val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
   194         val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
   195       in
   196         {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss,
   197          rec_thm = rec_thm}
   198       end;
   199 
           (* Constructor specs of the type at position index; offsets continue after the
              constructors of the preceding ctr_sugars. *)
   200     fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss =
   201       let
   202         val ctrs = #ctrs (nth ctr_sugars index);
   203         val rec_thms = co_rec_of (nth iter_thmsss index);
   204         val k = offset_of_ctr index ctr_sugars;
   205         val n = length ctrs;
   206       in
   207         map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thms
   208       end;
   209 
   210     fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...}
   211       : fp_sugar) =
   212       {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
   213        nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs,
   214        nested_map_comps = map map_comp_of_bnf nested_bnfs,
   215        ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};
   216   in
   217     ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), lthy)
   218   end;
   219 
   (* The constant "undefined" with dummy type; build_rec_arg uses it as the recursor argument
      for constructors that have no equation. *)
   220 val undef_const = Const (@{const_name undefined}, dummyT);
   221 
   (* Wraps t in n + 1 abstractions (dummy-typed, named Name.uu) and applies it to the bound
      variables in the order 0, n, n - 1, ..., 1. The resulting function thus takes n + 1
      arguments and passes the last one to t first, followed by the others in their order. *)
   222 fun permute_args n t =
   223   list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
   224 
   (* A user equation as decomposed by dissect_eqn:
      fun_name   - name of the function at the head of the left-hand side;
      rec_type   - body type of the constructor in the pattern;
      ctr        - the constructor of the pattern, ctr_args its (variable) arguments;
      left_args  - the variable arguments before the constructor pattern;
      right_args - the variable arguments after the constructor pattern;
      res_type   - function type from the left/right argument types to the type of the rhs;
      rhs_term   - the right-hand side;
      user_eqn   - the equation as originally given. *)
   225 type eqn_data = {
   226   fun_name: string,
   227   rec_type: typ,
   228   ctr: term,
   229   ctr_args: term list,
   230   left_args: term list,
   231   right_args: term list,
   232   res_type: typ,
   233   rhs_term: term,
   234   user_eqn: term
   235 };
   236 
   (* Validates one user equation eqn' and decomposes it into an eqn_data record.
      fun_names are the names of the functions being defined; lthy is used for printing terms
      and for checking whether a free variable is fixed.
      Raises Primrec_Error (via primrec_error_eqn) when the equation is not of the form
      "lhs = rhs", the lhs head is not a free variable or not one of fun_names, there is not
      exactly one non-variable argument, the constructor is partially applied or has
      non-variable arguments, a variable is duplicated on the lhs, or the rhs mentions extra
      free variables. *)
   237 fun dissect_eqn lthy fun_names eqn' =
   238   let
   239     val eqn = drop_all eqn' |> HOLogic.dest_Trueprop
   240       handle TERM _ =>
   241         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
   242     val (lhs, rhs) = HOLogic.dest_eq eqn
   243         handle TERM _ =>
   244           primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
   245     val (fun_name, args) = strip_comb lhs
   246       |>> (fn x => if is_Free x then fst (dest_Free x)
   247           else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
           (* Split the arguments into leading variables, the non-variable middle part (expected
              to be the single constructor pattern), and trailing variables. *)
   248     val (left_args, rest) = take_prefix is_Free args;
   249     val (nonfrees, right_args) = take_suffix is_Free rest;
   250     val num_nonfrees = length nonfrees;
   251     val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
   252       primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
   253       primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
   254     val _ = member (op =) fun_names fun_name orelse
   255       primrec_error_eqn "malformed function equation (does not start with function name)" eqn
   256 
   257     val (ctr, ctr_args) = strip_comb (the_single nonfrees);
   258     val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
   259       primrec_error_eqn "partially applied constructor in pattern" eqn;
   260     val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
   261       primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
   262         "\" in left-hand side") eqn end;
   263     val _ = forall is_Free ctr_args orelse
   264       primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
           (* Collect free variables of the rhs that are neither lhs variables, nor function
              names being defined, nor fixed in the context. *)
   265     val _ =
   266       let val b = fold_aterms (fn x as Free (v, _) =>
   267         if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
   268         not (member (op =) fun_names v) andalso
   269         not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
   270       in
   271         null b orelse
   272         primrec_error_eqn ("extra variable(s) in right-hand side: " ^
   273           commas (map (Syntax.string_of_term lthy) b)) eqn
   274       end;
   275   in
   276     {fun_name = fun_name,
   277      rec_type = body_type (type_of ctr),
   278      ctr = ctr,
   279      ctr_args = ctr_args,
   280      left_args = left_args,
   281      right_args = right_args,
   282      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
   283      rhs_term = rhs,
   284      user_eqn = eqn'}
   285   end;
   286 
   (* Returns a term rewriter for a function argument of a map that contains recursive calls
      (it is passed as raw_massage_fun to massage_nested_rec_call by subst_rec_calls).
      pT is the product type rec_type * res_type. get_ctr_pos maps the name of a function being
      defined to the position of its constructor argument (see build_rec_arg, which yields ~1
      for other names).
      The parameter d of subst tracks a de Bruijn index: SOME ~1 initially (before entering the
      outermost abstraction), NONE when no variable is tracked. *)
   287 fun rewrite_map_arg get_ctr_pos rec_type res_type =
   288   let
   289     val pT = HOLogic.mk_prodT (rec_type, res_type);
   290 
           (* The tracked bound variable is replaced by its first projection. *)
   291     fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
           (* The outermost abstraction gets type pT; the tracked index is shifted under binders. *)
   292       | subst d (Abs (v, T, b)) =
   293         Abs (v, if d = SOME ~1 then pT else T, subst (Option.map (Integer.add 1) d) b)
   294       | subst d t =
   295         let
   296           val (u, vs) = strip_comb t;
   297           val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1;
   298         in
   299           if ctr_pos >= 0 then
                 (* Head is a function with a known constructor position. Outside any abstraction
                    and with exactly ctr_pos arguments, the call becomes the second projection
                    with permuted arguments; if the argument at ctr_pos is the tracked variable,
                    the call becomes the second projection of that variable applied to the
                    remaining arguments; anything else is rejected. *)
   300             if d = SOME ~1 andalso length vs = ctr_pos then
   301               list_comb (permute_args ctr_pos (snd_const pT), vs)
   302             else if length vs > ctr_pos andalso is_some d
   303                 andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
   304               list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
   305             else
   306               primrec_error_eqn ("recursive call not directly applied to constructor argument") t
   307           else
                 (* Other heads: rewrite the arguments; if still outside any abstraction, no
                    variable is tracked inside them. *)
   308             list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
   309         end
   310   in
   311     subst (SOME ~1)
   312   end;
   313 
   (* Returns a term transformer that replaces recursive calls in a right-hand side.
      mutual_calls and nested_calls are association lists from constructor arguments to the
      terms standing for their recursive results (built in build_rec_arg); ctr_args are the
      constructor arguments of the equation; get_ctr_pos gives the constructor-argument
      position of a function being defined.
      The transformer first rewrites applications bottom-up (subst) and then checks the result
      (subst'): any remaining recursive call is an error. *)
   314 fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
   315   let
           (* If y has a nested-call entry, massage t accordingly; otherwise NONE. *)
   316     fun try_nested_rec bound_Ts y t =
   317       AList.lookup (op =) nested_calls y
   318       |> Option.map (fn y' =>
   319         massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t);
   320 
   321     fun subst bound_Ts (t as g' $ y) =
   322         let
   323           fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y;
   324           val y_head = head_of y;
   325         in
                 (* Only applications whose argument is headed by a constructor argument are
                    candidates; everything else is traversed structurally. *)
   326           if not (member (op =) ctr_args y_head) then
   327             subst_rec ()
   328           else
   329             (case try_nested_rec bound_Ts y_head t of
   330               SOME t' => t'
   331             | NONE =>
   332               let val (g, g_args) = strip_comb g' in
   333                 (case try (get_ctr_pos o fst o dest_Free) g of
   334                   SOME ctr_pos =>
   335                   (length g_args >= ctr_pos orelse
   336                    primrec_error_eqn "too few arguments in recursive call" t;
                        (* A mutual call on y is replaced by its stand-in applied to the
                           arguments that preceded y. *)
   337                    (case AList.lookup (op =) mutual_calls y of
   338                      SOME y' => list_comb (y', g_args)
   339                    | NONE => subst_rec ()))
   340                 | NONE => subst_rec ())
   341               end)
   342         end
   343       | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
   344       | subst _ t = t
   345 
   346     fun subst' t =
   347       if has_call t then
   348         (* FIXME detect this case earlier? *)
   349         primrec_error_eqn "recursive call not directly applied to constructor argument" t
   350       else
   351         try_nested_rec [] (head_of t) t |> the_default t
   352   in
   353     subst' o subst []
   354   end;
   355 
   (* Builds the recursor argument for one constructor (ctr_spec).
      Without an equation (NONE) the argument is undef_const. With an equation, the right-hand
      side has its recursive calls replaced (subst_rec_calls) and is abstracted over the
      recursor's arguments followed by the equation's left and right arguments.
      funs_data (equations grouped per function) supplies the constructor position of each
      function being defined. *)
   356 fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
   357     (eqn_data_opt : eqn_data option) =
   358   (case eqn_data_opt of
   359     NONE => undef_const
   360   | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
   361     let
   362       val calls = #calls ctr_spec;
             (* A Mutual_Rec call occupies two recursor arguments, any other call one. *)
   363       val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0;
   364 
             (* Pair each call with its constructor-argument index and select by kind; the partial
                patterns under try are meant to drop entries of the other kinds (assuming try
                turns the resulting match failure into NONE). *)
   365       val no_calls' = tag_list 0 calls
   366         |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p)));
   367       val mutual_calls' = tag_list 0 calls
   368         |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p)));
   369       val nested_calls' = tag_list 0 calls
   370         |> map_filter (try (apsnd (fn Nested_Rec p => p)));
   371 
             (* Renames the free variable t (via Term.variant_frees) when it already occurs in
                frees. *)
   372       fun ensure_unique frees t =
   373         if member (op =) frees t then Free (the_single (Term.variant_frees t [dest_Free t])) else t;
   374 
             (* Start from n_args placeholder variables renamed with respect to the rhs, then put
                the constructor arguments at their recursor positions: as they are for the
                no_calls' entries, retyped to the recursor argument type for mutual and nested
                calls. *)
   375       val args = replicate n_args ("", dummyT)
   376         |> Term.rename_wrt_term t
   377         |> map Free
   378         |> fold (fn (ctr_arg_idx, (arg_idx, _)) =>
   379             nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
   380           no_calls'
   381         |> fold (fn (ctr_arg_idx, (arg_idx, T)) => fn xs =>
   382             nth_map arg_idx (K (ensure_unique xs (retype_free T (nth ctr_args ctr_arg_idx)))) xs)
   383           mutual_calls'
   384         |> fold (fn (ctr_arg_idx, (arg_idx, T)) =>
   385             nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx))))
   386           nested_calls';
   387 
             (* Constructor position of a function = number of its left arguments (taken from its
                first equation); ~1 for unknown names. *)
   388       val fun_name_ctr_pos_list =
   389         map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
   390       val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
             (* Association lists: constructor argument |-> recursor argument standing for it. *)
   391       val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls';
   392       val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls';
   393     in
   394       t
   395       |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
   396       |> fold_rev lambda (args @ left_args @ right_args)
   397     end);
   398 
   (* Builds the definitions of the functions: for each of the n_funs functions, its recursor
      applied to one argument per constructor (build_rec_arg), with arguments permuted so that
      the constructor argument sits at the user's position (permute_args).
      Returns, per function, ((binding, mixfix), ((concealed definition binding, []), rhs)).
      Raises Primrec_Error for excess equations, multiple equations for one constructor, and
      equations of one function that disagree on the constructor position. *)
   399 fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
   400   let
   401     val n_funs = length funs_data;
   402 
           (* Match each function's constructor specs with its equations by constructor
              (NOTE(review): relies on finds returning the matches plus the unmatched rest). *)
   403     val ctr_spec_eqn_data_list' =
   404       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   405       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
   406           ##> (fn x => null x orelse
   407             primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
   408     val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
   409       primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
   410 
           (* Constructors of the remaining rec_specs (beyond the user's functions) get no
              equations. *)
   411     val ctr_spec_eqn_data_list =
   412       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
   413 
   414     val recs = take n_funs rec_specs |> map #recx;
           (* Recursor arguments in ascending constructor offset order. *)
   415     val rec_args = ctr_spec_eqn_data_list
   416       |> sort ((op <) o pairself (#offset o fst) |> make_ord)
   417       |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
   418     val ctr_poss = map (fn x =>
   419       if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
   420         primrec_error ("inconstant constructor pattern position for function " ^
   421           quote (#fun_name (hd x)))
   422       else
   423         hd x |> #left_args |> length) funs_data;
   424   in
   425     (recs, ctr_poss)
   426     |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
   427     |> Syntax.check_terms lthy
   428     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
   429       bs mxs
   430   end;
   431 
   (* Collects, for each constructor argument of the equation, the recursive calls in the
      right-hand side that are applied to it. Returns SOME (ctr, callss) with one list per
      constructor argument, or NONE if the constructor has no arguments.
      has_call tells whether a term contains a recursive call. *)
   432 fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
   433   let
   434     fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
   435       | find bound_Ts (t as _ $ _) ctr_arg =
   436         let
   437           val typof = curry fastype_of1 bound_Ts;
   438           val (f', args') = strip_comb t;
                 (* Position of the first argument headed by ctr_arg, or ~1. *)
   439           val n = find_index (equal ctr_arg o head_of) args';
   440         in
   441           if n < 0 then
   442             find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args'
   443           else
   444             let
                     (* f is the head applied to the arguments before position n; arg is the
                        argument headed by ctr_arg. *)
   445               val (f, args as arg :: _) = chop n args' |>> curry list_comb f'
   446               val (arg_head, arg_args) = Term.strip_comb arg;
   447             in
   448               if has_call f then
   449                 mk_partial_compN (length arg_args) (typof arg_head) f ::
   450                 maps (fn x => find bound_Ts x ctr_arg) args
   451               else
   452                 find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args
   453             end
   454         end
   455       | find _ _ _ = [];
   456   in
   457     map (find [] rhs_term) ctr_args
   458     |> (fn [] => NONE | callss => SOME (ctr, callss))
   459   end;
   460 
   (* Tactic proving one user equation: unfold the function definitions (fun_defs), rewrite
      with the recursor theorem recx specialized by num_extra_args applications of fun_cong
      (chained through trans), unfold id/split/o/fst/snd together with the nested map
      composition and identity theorems, and close the goal with refl. *)
   461 fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
   462   unfold_thms_tac ctxt fun_defs THEN
   463   HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN
   464   unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN
   465   HEADGOAL (rtac refl);
   466 
   (* Analyses the user's fixes and equations and prepares everything needed for the actual
      definition. Returns (((fun_names, defs), prove), lthy''), where defs are the definitions
      from build_defs, prove is a function that, given a context and the defined terms/theorems,
      proves the user equations per function (yielding equation positions and theorems), and
      lthy'' is the context after noting the induction theorems (only noted when rec_specs_of
      reports its flag as true).
      Raises Primrec_Error for missing equations, result types outside HOLogic.typeS, and
      patterns whose head is not a constructor of the recursed-over types. *)
   467 fun prepare_primrec fixes specs lthy =
   468   let
   469     val thy = Proof_Context.theory_of lthy;
   470 
   471     val (bs, mxs) = map_split (apfst fst) fixes;
   472     val fun_names = map Binding.name_of bs;
   473     val eqns_data = map (dissect_eqn lthy fun_names) specs;
           (* Group the equations per function, ordered like fun_names; every function needs
              equations. *)
   474     val funs_data = eqns_data
   475       |> partition_eq ((op =) o pairself #fun_name)
   476       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
   477       |> map (fn (x, y) => the_single y handle List.Empty =>
   478           primrec_error ("missing equations for function " ^ quote x));
   479 
           (* A term has a recursive call if it contains one of the functions being defined. *)
   480     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
   481     val arg_Ts = map (#rec_type o hd) funs_data;
   482     val res_Ts = map (#res_type o hd) funs_data;
   483     val callssss = funs_data
   484       |> map (partition_eq ((op =) o pairself #ctr))
   485       |> map (maps (map_filter (find_rec_calls has_call)));
   486 
   487     val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of
   488         [] => ()
   489       | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));
   490 
   491     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
   492       rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
   493 
   494     val actual_nn = length funs_data;
   495 
           (* Every pattern head must be among the constructors known to the rec_specs. *)
   496     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
   497       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
   498         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
   499           " is not a constructor in left-hand side") user_eqn) eqns_data end;
   500 
   501     val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
   502 
           (* Proves the equations of one function with mk_primrec_tac, using the recursor theorem
              of the matching constructor; poss are indices computed by find_indices over
              eqns_data for this function's equations. *)
   503     fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
   504         (fun_data : eqn_data list) =
   505       let
   506         val def_thms = map (snd o snd) def_thms';
   507         val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
   508           |> fst
   509           |> map_filter (try (fn (x, [y]) =>
   510             (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
   511           |> map (fn (user_eqn, num_extra_args, rec_thm) =>
   512             mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
   513             |> K |> Goal.prove_sorry lthy [] [] user_eqn
   514             |> Thm.close_derivation);
   515         val poss = find_indices (op = o pairself #ctr) fun_data eqns_data;
   516       in
   517         (poss, simp_thmss)
   518       end;
   519 
           (* Per-function induction theorems, qualified by the function name. *)
   520     val notes =
   521       (if n2m then map2 (fn name => fn thm =>
   522         (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
   523       |> map (fn (prefix, thmN, thms, attrs) =>
   524         ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
   525 
   526     val common_name = mk_common_name fun_names;
   527 
           (* The joint induction theorem, qualified by the common name of all functions. *)
   528     val common_notes =
   529       (if n2m then [(inductN, [induct_thm], [])] else [])
   530       |> map (fn (thmN, thms, attrs) =>
   531         ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
   532   in
   533     (((fun_names, defs),
   534       fn lthy => fn defs =>
   535         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
   536       lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
   537   end;
   538 
   539 (* primrec definition *)
   540 
   (* Defines the functions given by typed fixes and equation terms ts in lthy.
      Runs prepare_primrec, introduces the definitions with Local_Theory.define, and proves the
      equations in the resulting context. Returns
      ((names, (defined terms, (equation positions, theorems))), lthy').
      An ERROR from prepare_primrec is converted to Primrec_Error; every Primrec_Error is
      reported via error with the prefix "primrec_new error" and, if present, the offending
      terms printed in the original context. *)
   541 fun add_primrec_simple fixes ts lthy =
   542   let
   543     val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
   544       handle ERROR str => primrec_error str;
   545   in
   546     lthy
   547     |> fold_map Local_Theory.define defs
   548     |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
   549   end
   550   handle Primrec_Error (str, eqns) =>
   551     if null eqns
   552     then error ("primrec_new error:\n  " ^ str)
   553     else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
   554       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
   555 
   (* Generic primrec entry point, parameterized by the specification preparation function
      prep_spec; add_primrec uses Specification.check_spec, add_primrec_cmd uses
      Specification.read_spec. *)
   556 local
   557 
       (* Rejects duplicate function names, prepares the specification, defines the functions via
          add_primrec_simple, registers the proved equations with Spec_Rules as Equational, and
          notes the theorems. Returns the defined terms paired with the noted theorem lists. *)
   558 fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
   559   let
   560     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
   561     val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
   562 
   563     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
   564 
           (* Per function (prefix): a collective note named simpsN for all its theorems, followed
              by one note per theorem under the binding and attributes the user gave to the
              equation at the corresponding position, extended with
              code_nitpicksimp_simp_attrs. *)
   565     val mk_notes =
   566       flat ooo map3 (fn poss => fn prefix => fn thms =>
   567         let
   568           val (bs, attrss) = map_split (fst o nth specs) poss;
   569           val notes =
   570             map3 (fn b => fn attrs => fn thm =>
   571               ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]))
   572             bs attrss thms;
   573         in
   574           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
   575         end);
   576   in
   577     lthy
   578     |> add_primrec_simple fixes (map snd specs)
   579     |-> (fn (names, (ts, (posss, simpss))) =>
   580       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
   581       #> Local_Theory.notes (mk_notes posss names simpss)
   582       #>> pair ts o map snd)
   583   end;
   584 
   585 in
   586 
   587 val add_primrec = gen_primrec Specification.check_spec;
   588 val add_primrec_cmd = gen_primrec Specification.read_spec;
   589 
   590 end;
   591 
   (* Theory-level primrec: runs add_primrec in the local theory obtained from
      Named_Target.theory_init, exports the resulting theorems from the final context back to
      the initial one, and exits to the global theory. *)
   592 fun add_primrec_global fixes specs thy =
   593   let
   594     val lthy = Named_Target.theory_init thy;
   595     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   596     val simps' = burrow (Proof_Context.export lthy' lthy) simps;
   597   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   598 
   (* Theory-level primrec for overloaded constants: like add_primrec_global, but the local
      theory is set up by Overloading.overloading with the given ops. *)
   599 fun add_primrec_overloaded ops fixes specs thy =
   600   let
   601     val lthy = Overloading.overloading ops thy;
   602     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
   603     val simps' = burrow (Proof_Context.export lthy' lthy) simps;
   604   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   605 
   606 end;