src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 54671 d64a4ef26edb
parent 54670 cfb21e03fe2a
parent 54635 30666a281ae3
child 54672 748778ac0ab8
equal deleted inserted replaced
54670:cfb21e03fe2a 54671:d64a4ef26edb
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_rec_sugar.ML
       
     2     Author:     Lorenz Panny, TU Muenchen
       
     3     Copyright   2013
       
     4 
       
     5 Recursor and corecursor sugar.
       
     6 *)
       
     7 
       
     8 signature BNF_FP_REC_SUGAR =
       
     9 sig
       
    10   val add_primrec: (binding * typ option * mixfix) list ->
       
    11     (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
       
    12   val add_primrec_cmd: (binding * string option * mixfix) list ->
       
    13     (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
       
    14   val add_primrec_global: (binding * typ option * mixfix) list ->
       
    15     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
       
    16   val add_primrec_overloaded: (string * (string * typ) * bool) list ->
       
    17     (binding * typ option * mixfix) list ->
       
    18     (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
       
    19   val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
       
    20     local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
       
    21   val add_primcorecursive_cmd: bool ->
       
    22     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
       
    23     Proof.context -> Proof.state
       
    24   val add_primcorec_cmd: bool ->
       
    25     (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list ->
       
    26     local_theory -> local_theory
       
    27 end;
       
    28 
       
    29 structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
       
    30 struct
       
    31 
       
    32 open BNF_Util
       
    33 open BNF_FP_Util
       
    34 open BNF_FP_Rec_Sugar_Util
       
    35 open BNF_FP_Rec_Sugar_Tactics
       
    36 
       
    37 val codeN = "code"
       
    38 val ctrN = "ctr"
       
    39 val discN = "disc"
       
    40 val selN = "sel"
       
    41 
       
    42 val nitpick_attrs = @{attributes [nitpick_simp]};
       
    43 val simp_attrs = @{attributes [simp]};
       
    44 val code_nitpick_attrs = Code.add_default_eqn_attrib :: nitpick_attrs;
       
    45 val code_nitpick_simp_attrs = Code.add_default_eqn_attrib :: nitpick_attrs @ simp_attrs;
       
    46 
       
    47 exception Primrec_Error of string * term list;
       
    48 
       
    49 fun primrec_error str = raise Primrec_Error (str, []);
       
    50 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
       
    51 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
       
    52 
       
    53 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
       
    54 
       
    55 val free_name = try (fn Free (v, _) => v);
       
    56 val const_name = try (fn Const (v, _) => v);
       
    57 val undef_const = Const (@{const_name undefined}, dummyT);
       
    58 
       
    59 fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
       
    60   |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
       
    61 val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
       
    62 fun drop_All t = subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev,
       
    63   strip_qnt_body @{const_name all} t)
       
    64 fun abstract vs =
       
    65   let fun a n (t $ u) = a n t $ a n u
       
    66         | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b)
       
    67         | a n t = let val idx = find_index (equal t) vs in
       
    68             if idx < 0 then t else Bound (n + idx) end
       
    69   in a 0 end;
       
    70 fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
       
    71 fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
       
    72 
       
    73 fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes
       
    74   |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
       
    75   |> map_filter I;
       
    76 
       
    77 
       
    78 (* Primrec *)
       
    79 
       
    80 type eqn_data = {
       
    81   fun_name: string,
       
    82   rec_type: typ,
       
    83   ctr: term,
       
    84   ctr_args: term list,
       
    85   left_args: term list,
       
    86   right_args: term list,
       
    87   res_type: typ,
       
    88   rhs_term: term,
       
    89   user_eqn: term
       
    90 };
       
    91 
       
    92 fun dissect_eqn lthy fun_names eqn' =
       
    93   let
       
    94     val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
       
    95       handle TERM _ =>
       
    96         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
       
    97     val (lhs, rhs) = HOLogic.dest_eq eqn
       
    98         handle TERM _ =>
       
    99           primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
       
   100     val (fun_name, args) = strip_comb lhs
       
   101       |>> (fn x => if is_Free x then fst (dest_Free x)
       
   102           else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
       
   103     val (left_args, rest) = take_prefix is_Free args;
       
   104     val (nonfrees, right_args) = take_suffix is_Free rest;
       
   105     val num_nonfrees = length nonfrees;
       
   106     val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
       
   107       primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
       
   108       primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
       
   109     val _ = member (op =) fun_names fun_name orelse
       
   110       primrec_error_eqn "malformed function equation (does not start with function name)" eqn
       
   111 
       
   112     val (ctr, ctr_args) = strip_comb (the_single nonfrees);
       
   113     val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
       
   114       primrec_error_eqn "partially applied constructor in pattern" eqn;
       
   115     val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
       
   116       primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
       
   117         "\" in left-hand side") eqn end;
       
   118     val _ = forall is_Free ctr_args orelse
       
   119       primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
       
   120     val _ =
       
   121       let val b = fold_aterms (fn x as Free (v, _) =>
       
   122         if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
       
   123         not (member (op =) fun_names v) andalso
       
   124         not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
       
   125       in
       
   126         null b orelse
       
   127         primrec_error_eqn ("extra variable(s) in right-hand side: " ^
       
   128           commas (map (Syntax.string_of_term lthy) b)) eqn
       
   129       end;
       
   130   in
       
   131     {fun_name = fun_name,
       
   132      rec_type = body_type (type_of ctr),
       
   133      ctr = ctr,
       
   134      ctr_args = ctr_args,
       
   135      left_args = left_args,
       
   136      right_args = right_args,
       
   137      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
       
   138      rhs_term = rhs,
       
   139      user_eqn = eqn'}
       
   140   end;
       
   141 
       
   142 fun rewrite_map_arg get_ctr_pos rec_type res_type =
       
   143   let
       
   144     val pT = HOLogic.mk_prodT (rec_type, res_type);
       
   145 
       
   146     val maybe_suc = Option.map (fn x => x + 1);
       
   147     fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
       
   148       | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
       
   149       | subst d t =
       
   150         let
       
   151           val (u, vs) = strip_comb t;
       
   152           val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
       
   153         in
       
   154           if ctr_pos >= 0 then
       
   155             if d = SOME ~1 andalso length vs = ctr_pos then
       
   156               list_comb (permute_args ctr_pos (snd_const pT), vs)
       
   157             else if length vs > ctr_pos andalso is_some d
       
   158                 andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
       
   159               list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
       
   160             else
       
   161               primrec_error_eqn ("recursive call not directly applied to constructor argument") t
       
   162           else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then
       
   163             list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs)
       
   164           else
       
   165             list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
       
   166         end
       
   167   in
       
   168     subst (SOME ~1)
       
   169   end;
       
   170 
       
   171 fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
       
   172   let
       
   173     fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
       
   174       | subst bound_Ts (t as g' $ y) =
       
   175         let
       
   176           val maybe_direct_y' = AList.lookup (op =) direct_calls y;
       
   177           val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
       
   178           val (g, g_args) = strip_comb g';
       
   179           val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
       
   180           val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
       
   181             primrec_error_eqn "too few arguments in recursive call" t;
       
   182         in
       
   183           if not (member (op =) ctr_args y) then
       
   184             pairself (subst bound_Ts) (g', y) |> (op $)
       
   185           else if ctr_pos >= 0 then
       
   186             list_comb (the maybe_direct_y', g_args)
       
   187           else if is_some maybe_indirect_y' then
       
   188             (if has_call g' then t else y)
       
   189             |> massage_indirect_rec_call lthy has_call
       
   190               (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
       
   191             |> (if has_call g' then I else curry (op $) g')
       
   192           else
       
   193             t
       
   194         end
       
   195       | subst _ t = t
       
   196   in
       
   197     subst [] t
       
   198     |> tap (fn u => has_call u andalso (* FIXME detect this case earlier *)
       
   199       primrec_error_eqn "recursive call not directly applied to constructor argument" t)
       
   200   end;
       
   201 
       
   202 fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
       
   203     (maybe_eqn_data : eqn_data option) =
       
   204   if is_none maybe_eqn_data then undef_const else
       
   205     let
       
   206       val eqn_data = the maybe_eqn_data;
       
   207       val t = #rhs_term eqn_data;
       
   208       val ctr_args = #ctr_args eqn_data;
       
   209 
       
   210       val calls = #calls ctr_spec;
       
   211       val n_args = fold (curry (op +) o (fn Direct_Rec _ => 2 | _ => 1)) calls 0;
       
   212 
       
   213       val no_calls' = tag_list 0 calls
       
   214         |> map_filter (try (apsnd (fn No_Rec n => n | Direct_Rec (n, _) => n)));
       
   215       val direct_calls' = tag_list 0 calls
       
   216         |> map_filter (try (apsnd (fn Direct_Rec (_, n) => n)));
       
   217       val indirect_calls' = tag_list 0 calls
       
   218         |> map_filter (try (apsnd (fn Indirect_Rec n => n)));
       
   219 
       
   220       fun make_direct_type _ = dummyT; (* FIXME? *)
       
   221 
       
   222       val rec_res_type_list = map (fn (x :: _) => (#rec_type x, #res_type x)) funs_data;
       
   223 
       
   224       fun make_indirect_type (Type (Tname, Ts)) = Type (Tname, Ts |> map (fn T =>
       
   225         let val maybe_res_type = AList.lookup (op =) rec_res_type_list T in
       
   226           if is_some maybe_res_type
       
   227           then HOLogic.mk_prodT (T, the maybe_res_type)
       
   228           else make_indirect_type T end))
       
   229         | make_indirect_type T = T;
       
   230 
       
   231       val args = replicate n_args ("", dummyT)
       
   232         |> Term.rename_wrt_term t
       
   233         |> map Free
       
   234         |> fold (fn (ctr_arg_idx, arg_idx) =>
       
   235             nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
       
   236           no_calls'
       
   237         |> fold (fn (ctr_arg_idx, arg_idx) =>
       
   238             nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_direct_type)))
       
   239           direct_calls'
       
   240         |> fold (fn (ctr_arg_idx, arg_idx) =>
       
   241             nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
       
   242           indirect_calls';
       
   243 
       
   244       val fun_name_ctr_pos_list =
       
   245         map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
       
   246       val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
       
   247       val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
       
   248       val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';
       
   249 
       
   250       val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
       
   251     in
       
   252       t
       
   253       |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
       
   254       |> fold_rev lambda abstractions
       
   255     end;
       
   256 
       
   257 fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
       
   258   let
       
   259     val n_funs = length funs_data;
       
   260 
       
   261     val ctr_spec_eqn_data_list' =
       
   262       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
       
   263       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
       
   264           ##> (fn x => null x orelse
       
   265             primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
       
   266     val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
       
   267       primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
       
   268 
       
   269     val ctr_spec_eqn_data_list =
       
   270       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
       
   271 
       
   272     val recs = take n_funs rec_specs |> map #recx;
       
   273     val rec_args = ctr_spec_eqn_data_list
       
   274       |> sort ((op <) o pairself (#offset o fst) |> make_ord)
       
   275       |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
       
   276     val ctr_poss = map (fn x =>
       
   277       if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
       
   278         primrec_error ("inconstant constructor pattern position for function " ^
       
   279           quote (#fun_name (hd x)))
       
   280       else
       
   281         hd x |> #left_args |> length) funs_data;
       
   282   in
       
   283     (recs, ctr_poss)
       
   284     |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
       
   285     |> Syntax.check_terms lthy
       
   286     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
       
   287   end;
       
   288 
       
   289 fun find_rec_calls has_call (eqn_data : eqn_data) =
       
   290   let
       
   291     fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
       
   292       | find (t as _ $ _) ctr_arg =
       
   293         let
       
   294           val (f', args') = strip_comb t;
       
   295           val n = find_index (equal ctr_arg) args';
       
   296         in
       
   297           if n < 0 then
       
   298             find f' ctr_arg @ maps (fn x => find x ctr_arg) args'
       
   299           else
       
   300             let val (f, args) = chop n args' |>> curry list_comb f' in
       
   301               if has_call f then
       
   302                 f :: maps (fn x => find x ctr_arg) args
       
   303               else
       
   304                 find f ctr_arg @ maps (fn x => find x ctr_arg) args
       
   305             end
       
   306         end
       
   307       | find _ _ = [];
       
   308   in
       
   309     map (find (#rhs_term eqn_data)) (#ctr_args eqn_data)
       
   310     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
       
   311   end;
       
   312 
       
   313 fun prepare_primrec fixes specs lthy =
       
   314   let
       
   315     val (bs, mxs) = map_split (apfst fst) fixes;
       
   316     val fun_names = map Binding.name_of bs;
       
   317     val eqns_data = map (dissect_eqn lthy fun_names) specs;
       
   318     val funs_data = eqns_data
       
   319       |> partition_eq ((op =) o pairself #fun_name)
       
   320       |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
       
   321       |> map (fn (x, y) => the_single y handle List.Empty =>
       
   322           primrec_error ("missing equations for function " ^ quote x));
       
   323 
       
   324     val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
       
   325     val arg_Ts = map (#rec_type o hd) funs_data;
       
   326     val res_Ts = map (#res_type o hd) funs_data;
       
   327     val callssss = funs_data
       
   328       |> map (partition_eq ((op =) o pairself #ctr))
       
   329       |> map (maps (map_filter (find_rec_calls has_call)));
       
   330 
       
   331     val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
       
   332       rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;
       
   333 
       
   334     val actual_nn = length funs_data;
       
   335 
       
   336     val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
       
   337       map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
       
   338         primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
       
   339           " is not a constructor in left-hand side") user_eqn) eqns_data end;
       
   340 
       
   341     val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;
       
   342 
       
   343     fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
       
   344         (fun_data : eqn_data list) =
       
   345       let
       
   346         val def_thms = map (snd o snd) def_thms';
       
   347         val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
       
   348           |> fst
       
   349           |> map_filter (try (fn (x, [y]) =>
       
   350             (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
       
   351           |> map (fn (user_eqn, num_extra_args, rec_thm) =>
       
   352             mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
       
   353             |> K |> Goal.prove lthy [] [] user_eqn);
       
   354         val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
       
   355       in
       
   356         (poss, simp_thmss)
       
   357       end;
       
   358 
       
   359     val notes =
       
   360       (if n2m then map2 (fn name => fn thm =>
       
   361         (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else [])
       
   362       |> map (fn (prefix, thmN, thms, attrs) =>
       
   363         ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])]));
       
   364 
       
   365     val common_name = mk_common_name fun_names;
       
   366 
       
   367     val common_notes =
       
   368       (if n2m then [(inductN, [induct_thm], [])] else [])
       
   369       |> map (fn (thmN, thms, attrs) =>
       
   370         ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
       
   371   in
       
   372     (((fun_names, defs),
       
   373       fn lthy => fn defs =>
       
   374         split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
       
   375       lthy' |> Local_Theory.notes (notes @ common_notes) |> snd)
       
   376   end;
       
   377 
       
   378 (* primrec definition *)
       
   379 
       
   380 fun add_primrec_simple fixes ts lthy =
       
   381   let
       
   382     val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy
       
   383       handle ERROR str => primrec_error str;
       
   384   in
       
   385     lthy
       
   386     |> fold_map Local_Theory.define defs
       
   387     |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs))))
       
   388   end
       
   389   handle Primrec_Error (str, eqns) =>
       
   390     if null eqns
       
   391     then error ("primrec_new error:\n  " ^ str)
       
   392     else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
       
   393       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
       
   394 
       
   395 local
       
   396 
       
   397 fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
       
   398   let
       
   399     val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes)
       
   400     val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d);
       
   401 
       
   402     val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);
       
   403 
       
   404     val mk_notes =
       
   405       flat ooo map3 (fn poss => fn prefix => fn thms =>
       
   406         let
       
   407           val (bs, attrss) = map_split (fst o nth specs) poss;
       
   408           val notes =
       
   409             map3 (fn b => fn attrs => fn thm =>
       
   410               ((Binding.qualify false prefix b, code_nitpick_simp_attrs @ attrs), [([thm], [])]))
       
   411             bs attrss thms;
       
   412         in
       
   413           ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes
       
   414         end);
       
   415   in
       
   416     lthy
       
   417     |> add_primrec_simple fixes (map snd specs)
       
   418     |-> (fn (names, (ts, (posss, simpss))) =>
       
   419       Spec_Rules.add Spec_Rules.Equational (ts, flat simpss)
       
   420       #> Local_Theory.notes (mk_notes posss names simpss)
       
   421       #>> pair ts o map snd)
       
   422   end;
       
   423 
       
   424 in
       
   425 
       
   426 val add_primrec = gen_primrec Specification.check_spec;
       
   427 val add_primrec_cmd = gen_primrec Specification.read_spec;
       
   428 
       
   429 end;
       
   430 
       
   431 fun add_primrec_global fixes specs thy =
       
   432   let
       
   433     val lthy = Named_Target.theory_init thy;
       
   434     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
       
   435     val simps' = burrow (Proof_Context.export lthy' lthy) simps;
       
   436   in ((ts, simps'), Local_Theory.exit_global lthy') end;
       
   437 
       
   438 fun add_primrec_overloaded ops fixes specs thy =
       
   439   let
       
   440     val lthy = Overloading.overloading ops thy;
       
   441     val ((ts, simps), lthy') = add_primrec fixes specs lthy;
       
   442     val simps' = burrow (Proof_Context.export lthy' lthy) simps;
       
   443   in ((ts, simps'), Local_Theory.exit_global lthy') end;
       
   444 
       
   445 
       
   446 
       
   447 (* Primcorec *)
       
   448 
       
   449 type co_eqn_data_disc = {
       
   450   fun_name: string,
       
   451   fun_T: typ,
       
   452   fun_args: term list,
       
   453   ctr: term,
       
   454   ctr_no: int, (*###*)
       
   455   disc: term,
       
   456   prems: term list,
       
   457   auto_gen: bool,
       
   458   user_eqn: term
       
   459 };
       
   460 
       
   461 type co_eqn_data_sel = {
       
   462   fun_name: string,
       
   463   fun_T: typ,
       
   464   fun_args: term list,
       
   465   ctr: term,
       
   466   sel: term,
       
   467   rhs_term: term,
       
   468   user_eqn: term
       
   469 };
       
   470 
       
   471 datatype co_eqn_data =
       
   472   Disc of co_eqn_data_disc |
       
   473   Sel of co_eqn_data_sel;
       
   474 
       
   475 fun co_dissect_eqn_disc sequential fun_names (corec_specs : corec_spec list) prems' concl
       
   476     matchedsss =
       
   477   let
       
   478     fun find_subterm p = let (* FIXME \<exists>? *)
       
   479       fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
       
   480         | f t = if p t then SOME t else NONE
       
   481       in f end;
       
   482 
       
   483     val applied_fun = concl
       
   484       |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
       
   485       |> the
       
   486       handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl;
       
   487     val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
       
   488     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
       
   489 
       
   490     val discs = map #disc ctr_specs;
       
   491     val ctrs = map #ctr ctr_specs;
       
   492     val not_disc = head_of concl = @{term Not};
       
   493     val _ = not_disc andalso length ctrs <> 2 andalso
       
   494       primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" concl;
       
   495     val disc = find_subterm (member (op =) discs o head_of) concl;
       
   496     val eq_ctr0 = concl |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
       
   497         |> (fn SOME t => let val n = find_index (equal t) ctrs in
       
   498           if n >= 0 then SOME n else NONE end | _ => NONE);
       
   499     val _ = is_some disc orelse is_some eq_ctr0 orelse
       
   500       primrec_error_eqn "no discriminator in equation" concl;
       
   501     val ctr_no' =
       
   502       if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
       
   503     val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
       
   504     val ctr = #ctr (nth ctr_specs ctr_no);
       
   505 
       
   506     val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_;
       
   507     val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
       
   508     val prems = map (abstract (List.rev fun_args)) prems';
       
   509     val real_prems =
       
   510       (if catch_all orelse sequential then maps negate_disj matchedss else []) @
       
   511       (if catch_all then [] else prems);
       
   512 
       
   513     val matchedsss' = AList.delete (op =) fun_name matchedsss
       
   514       |> cons (fun_name, if sequential then matchedss @ [prems] else matchedss @ [real_prems]);
       
   515 
       
   516     val user_eqn =
       
   517       (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
       
   518       |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop
       
   519       |> Logic.list_implies;
       
   520   in
       
   521     (Disc {
       
   522       fun_name = fun_name,
       
   523       fun_T = fun_T,
       
   524       fun_args = fun_args,
       
   525       ctr = ctr,
       
   526       ctr_no = ctr_no,
       
   527       disc = #disc (nth ctr_specs ctr_no),
       
   528       prems = real_prems,
       
   529       auto_gen = catch_all,
       
   530       user_eqn = user_eqn
       
   531     }, matchedsss')
       
   532   end;
       
   533 
       
   534 fun co_dissect_eqn_sel fun_names (corec_specs : corec_spec list) eqn' of_spec eqn =
       
   535   let
       
   536     val (lhs, rhs) = HOLogic.dest_eq eqn
       
   537       handle TERM _ =>
       
   538         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
       
   539     val sel = head_of lhs;
       
   540     val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
       
   541       handle TERM _ =>
       
   542         primrec_error_eqn "malformed selector argument in left-hand side" eqn;
       
   543     val corec_spec = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name)
       
   544       handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
       
   545     val ctr_spec =
       
   546       if is_some of_spec
       
   547       then the (find_first (equal (the of_spec) o #ctr) (#ctr_specs corec_spec))
       
   548       else #ctr_specs corec_spec |> filter (exists (equal sel) o #sels) |> the_single
       
   549         handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn;
       
   550     val user_eqn = drop_All eqn';
       
   551   in
       
   552     Sel {
       
   553       fun_name = fun_name,
       
   554       fun_T = fun_T,
       
   555       fun_args = fun_args,
       
   556       ctr = #ctr ctr_spec,
       
   557       sel = sel,
       
   558       rhs_term = rhs,
       
   559       user_eqn = user_eqn
       
   560     }
       
   561   end;
       
   562 
       
   563 fun co_dissect_eqn_ctr sequential fun_names (corec_specs : corec_spec list) eqn' imp_prems imp_rhs
       
   564     matchedsss =
       
   565   let
       
   566     val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
       
   567     val fun_name = head_of lhs |> fst o dest_Free;
       
   568     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
       
   569     val (ctr, ctr_args) = strip_comb rhs;
       
   570     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
       
   571       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
       
   572 
       
   573     val disc_imp_rhs = betapply (disc, lhs);
       
   574     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
       
   575       then (NONE, matchedsss)
       
   576       else apfst SOME (co_dissect_eqn_disc
       
   577           sequential fun_names corec_specs imp_prems disc_imp_rhs matchedsss);
       
   578 
       
   579     val sel_imp_rhss = (sels ~~ ctr_args)
       
   580       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
       
   581 
       
   582 (*
       
   583 val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
       
   584  (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
       
   585  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
       
   586 *)
       
   587 
       
   588     val eqns_data_sel =
       
   589       map (co_dissect_eqn_sel fun_names corec_specs eqn' (SOME ctr)) sel_imp_rhss;
       
   590   in
       
   591     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
       
   592   end;
       
   593 
       
   594 fun co_dissect_eqn sequential fun_names (corec_specs : corec_spec list) eqn' of_spec matchedsss =
       
   595   let
       
   596     val eqn = drop_All eqn'
       
   597       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
       
   598     val (imp_prems, imp_rhs) = Logic.strip_horn eqn
       
   599       |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
       
   600 
       
   601     val head = imp_rhs
       
   602       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
       
   603       |> head_of;
       
   604 
       
   605     val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
       
   606 
       
   607     val discs = maps #ctr_specs corec_specs |> map #disc;
       
   608     val sels = maps #ctr_specs corec_specs |> maps #sels;
       
   609     val ctrs = maps #ctr_specs corec_specs |> map #ctr;
       
   610   in
       
   611     if member (op =) discs head orelse
       
   612       is_some maybe_rhs andalso
       
   613         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
       
   614       co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
       
   615       |>> single
       
   616     else if member (op =) sels head then
       
   617       ([co_dissect_eqn_sel fun_names corec_specs eqn' of_spec imp_rhs], matchedsss)
       
   618     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
       
   619       co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
       
   620     else
       
   621       primrec_error_eqn "malformed function equation" eqn
       
   622   end;
       
   623 
       
   624 fun build_corec_arg_disc (ctr_specs : corec_ctr_spec list)
       
   625     ({fun_args, ctr_no, prems, ...} : co_eqn_data_disc) =
       
   626   if is_none (#pred (nth ctr_specs ctr_no)) then I else
       
   627     mk_conjs prems
       
   628     |> curry subst_bounds (List.rev fun_args)
       
   629     |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args)
       
   630     |> K |> nth_map (the (#pred (nth ctr_specs ctr_no)));
       
   631 
       
   632 fun build_corec_arg_no_call (sel_eqns : co_eqn_data_sel list) sel =
       
   633   find_first (equal sel o #sel) sel_eqns
       
   634   |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term)
       
   635   |> the_default undef_const
       
   636   |> K;
       
   637 
       
   638 fun build_corec_args_direct_call lthy has_call (sel_eqns : co_eqn_data_sel list) sel =
       
   639   let
       
   640     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
       
   641   in
       
   642     if is_none maybe_sel_eqn then (I, I, I) else
       
   643     let
       
   644       val {fun_args, rhs_term, ... } = the maybe_sel_eqn;
       
   645       fun rewrite_q _ t = if has_call t then @{term False} else @{term True};
       
   646       fun rewrite_g _ t = if has_call t then undef_const else t;
       
   647       fun rewrite_h bound_Ts t =
       
   648         if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const;
       
   649       fun massage f t = massage_direct_corec_call lthy has_call f [] rhs_term |> abs_tuple fun_args;
       
   650     in
       
   651       (massage rewrite_q,
       
   652        massage rewrite_g,
       
   653        massage rewrite_h)
       
   654     end
       
   655   end;
       
   656 
       
   657 fun build_corec_arg_indirect_call lthy has_call (sel_eqns : co_eqn_data_sel list) sel =
       
   658   let
       
   659     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
       
   660   in
       
   661     if is_none maybe_sel_eqn then I else
       
   662     let
       
   663       val {fun_args, rhs_term, ...} = the maybe_sel_eqn;
       
   664       fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
       
   665         | rewrite bound_Ts U T (t as _ $ _) =
       
   666           let val (u, vs) = strip_comb t in
       
   667             if is_Free u andalso has_call u then
       
   668               Inr_const U T $ mk_tuple1 bound_Ts vs
       
   669             else if try (fst o dest_Const) u = SOME @{const_name prod_case} then
       
   670               map (rewrite bound_Ts U T) vs |> chop 1 |>> HOLogic.mk_split o the_single |> list_comb
       
   671             else
       
   672               list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
       
   673           end
       
   674         | rewrite _ U T t =
       
   675           if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
       
   676       fun massage t =
       
   677         massage_indirect_corec_call lthy has_call rewrite [] (range_type (fastype_of t)) rhs_term
       
   678         |> abs_tuple fun_args;
       
   679     in
       
   680       massage
       
   681     end
       
   682   end;
       
   683 
       
   684 fun build_corec_args_sel lthy has_call (all_sel_eqns : co_eqn_data_sel list)
       
   685     (ctr_spec : corec_ctr_spec) =
       
   686   let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in
       
   687     if null sel_eqns then I else
       
   688       let
       
   689         val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;
       
   690 
       
   691         val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
       
   692         val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list;
       
   693         val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
       
   694       in
       
   695         I
       
   696         #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
       
   697         #> fold (fn (sel, (q, g, h)) =>
       
   698           let val (fq, fg, fh) = build_corec_args_direct_call lthy has_call sel_eqns sel in
       
   699             nth_map q fq o nth_map g fg o nth_map h fh end) direct_calls'
       
   700         #> fold (fn (sel, n) => nth_map n
       
   701           (build_corec_arg_indirect_call lthy has_call sel_eqns sel)) indirect_calls'
       
   702       end
       
   703   end;
       
   704 
       
(* Assembles the definitions of the functions being defined.

   The argument list of the first corecursor (its binder types minus the last one) starts out
   as undefined constants. It is then filled in with the discriminator guards
   (build_corec_arg_disc) and the selector arguments (build_corec_args_sel).

   Every corecursor of the first length bs entries of corec_specs is applied to these
   arguments. Each result is applied to unit when the function has no arguments, curried
   according to arg_Tss, and type-checked in lthy.

   Returns the definition specifications ((binding, mixfix), ((def binding, []), term)) paired
   with exclss', the exclusiveness goals between each function's discriminator equations. *)
fun co_build_defs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list)
    (disc_eqnss : co_eqn_data_disc list list) (sel_eqnss : co_eqn_data_sel list list) =
  let
    (* only the first length bs specs correspond to functions being defined *)
    val corec_specs' = take (length bs) corec_specs;
    val corecs = map #corec corec_specs';
    val ctr_specss = map #ctr_specs corec_specs';
    (* initial arguments: one undefined constant per binder type of the first corecursor,
       excluding the last binder type *)
    val corec_args = hd corecs
      |> fst o split_last o binder_types o fastype_of
      |> map (Const o pair @{const_name undefined})
      |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
      |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
    (* wraps t, which takes its arguments as a single tuple, into a curried lambda over Ts *)
    fun currys [] t = t
      | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
          |> fold_rev (Term.abs o pair Name.uu) Ts;

(*
val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
 space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
*)

    (* For each function, every discriminator equation is paired with each equation that
       precedes it in the list. Each pair yields ((c, c', a orelse a'), goal), where:
       - c and c' are the ctr_no of the later and the earlier equation;
       - the flag is true if either equation was auto-generated;
       - goal states that the later equation's premises imply the negated conjunction of the
         earlier equation's premises, universally quantified over the later equation's
         fun_args. *)
    val exclss' =
      disc_eqnss
      |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x))
        #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs [])
        #> maps (uncurry (map o pair)
          #> map (fn ((fun_args, c, x, a), (_, c', y, a')) =>
              ((c, c', a orelse a'), (x, s_not (mk_conjs y)))
            ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop
            ||> Logic.list_implies
            ||> curry Logic.list_all (map dest_Free fun_args))))
  in
    map (list_comb o rpair corec_args) corecs
    (* functions without arguments are applied to unit *)
    |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
    |> map2 currys arg_Tss
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
    |> rpair exclss'
  end;
       
   743 
       
   744 fun mk_real_disc_eqns fun_binding arg_Ts ({ctr_specs, ...} : corec_spec)
       
   745     (sel_eqns : co_eqn_data_sel list) (disc_eqns : co_eqn_data_disc list) =
       
   746   if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
       
   747     let
       
   748       val n = 0 upto length ctr_specs
       
   749         |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
       
   750       val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
       
   751         |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
       
   752       val extra_disc_eqn = {
       
   753         fun_name = Binding.name_of fun_binding,
       
   754         fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
       
   755         fun_args = fun_args,
       
   756         ctr = #ctr (nth ctr_specs n),
       
   757         ctr_no = n,
       
   758         disc = #disc (nth ctr_specs n),
       
   759         prems = maps (negate_conj o #prems) disc_eqns,
       
   760         auto_gen = true,
       
   761         user_eqn = undef_const};
       
   762     in
       
   763       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
       
   764     end;
       
   765 
       
(* Core of the primcorec / primcorecursive commands.

   Overall flow:
   1. Dissects the specification equations.
   2. Builds the function definitions via the corecursors.
   3. Sorts the exclusiveness goals into those proved automatically and the remaining
      obligations.
   4. Sets up the proof of the discriminator, selector and constructor theorems, which are
      noted in the local theory.

   When simple is true (primcorec), no obligations may remain. The definitions and theorems are
   then produced immediately and the result is (NONE, SOME lthy). Otherwise (primcorecursive) a
   proof state for the obligations is returned as (SOME state, NONE). In sequential mode every
   exclusiveness goal is discharged by mk_primcorec_assumption_tac. *)
fun add_primcorec simple sequential fixes specs of_specs lthy =
  let
    val (bs, mxs) = map_split (apfst fst) fixes;
    (* argument types are tupled into a single type per function *)
    val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;

    val callssss = []; (* FIXME *)

    val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms,
          strong_coinduct_thms), lthy') =
      corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;

    val actual_nn = length bs;
    val fun_names = map Binding.name_of bs;
    val corec_specs = take actual_nn corec_specs'; (*###*)

    (* dissect every user equation (with its optional "of" constructor) into disc/sel data *)
    val eqns_data =
      fold_map2 (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) of_specs []
      |> flat o fst;

    (* discriminator equations grouped per function (in fun_names order), sorted by ctr_no *)
    val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
    (* reject two discriminator equations for the same constructor of the same function *)
    val _ = disc_eqnss' |> map (fn x =>
      let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
        primrec_error_eqns "excess discriminator equations in definition"
          (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);

    (* selector equations grouped per function, in fun_names order *)
    val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> map (flat o snd);

    (* true iff a term contains one of the functions being defined (as a Free) *)
    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
    val arg_Tss = map (binder_types o snd o fst) fixes;
    val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
    val (defs, exclss') =
      co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;

    (* Tactic for an exclusiveness goal indexed (c, c', auto_gen), or NONE if the goal must be
       left to the user:
       - an auto-generated equation, c = c', or sequential mode: mk_primcorec_assumption_tac;
       - otherwise, in primcorec mode: auto_tac. *)
    fun excl_tac (c, c', a) =
      if a orelse c = c' orelse sequential then
        SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy [])))
      else if simple then
        SOME (K (auto_tac lthy))
      else
        NONE;

(*
val _ = tracing ("exclusiveness properties:\n    \<cdot> " ^
 space_implode "\n    \<cdot> " (maps (map (Syntax.string_of_term lthy o snd)) exclss'));
*)

    (* attempt every goal that has a tactic; proved ones become taut_thmss, the rest become
       user obligations (with their indices kept in obligation_idxss) *)
    val exclss'' = exclss' |> map (map (fn (idx, t) =>
      (idx, (Option.map (Goal.prove lthy [] [] t) (excl_tac idx), t))));
    val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss'';
    val (obligation_idxss, obligationss) = exclss''
      |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
      |> split_list o map split_list;

    (* Given the proved obligations thmss' and the definition theorems, proves the disc, sel and
       ctr theorems and notes them (plus coinduction rules when n2m) in lthy. *)
    fun prove thmss' def_thms' lthy =
      let
        val def_thms = map (snd o snd) def_thms';

        val exclss' = map (op ~~) (obligation_idxss ~~ thmss');
        (* n x n table indexed by (c, c'): starts with [TrueI] below the diagonal and []
           elsewhere, then entry (c, c') is set to [thm] for every exclusiveness theorem *)
        fun mk_exclsss excls n =
          (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1))
          |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm])));
        val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
          |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));

        (* proves  prems ==> disc (f args); skipped (empty list) when the discriminator is the
           trivially true lambda x. x = x. Returns a singleton (disc, thm) association list. *)
        fun prove_disc ({ctr_specs, ...} : corec_spec) exclsss
            ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : co_eqn_data_disc) =
          if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
            let
              val {disc_corec, ...} = nth ctr_specs ctr_no;
              val k = 1 + ctr_no;
              val m = length prems;
              val t =
                list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
                |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
                |> HOLogic.mk_Trueprop
                |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                |> curry Logic.list_all (map dest_Free fun_args);
            in
              mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss
              |> K |> Goal.prove lthy [] [] t
              |> pair (#disc (nth ctr_specs ctr_no))
              |> single
            end;

        (* proves  prems ==> sel (f args) = rhs. The premises come from the disc equation of
           the same constructor or, if there is none, are the negated premises of all disc
           equations. Returns (sel, thm). *)
        fun prove_sel ({nested_maps, nested_map_idents, nested_map_comps, ctr_specs, ...}
            : corec_spec) (disc_eqns : co_eqn_data_disc list) exclsss
            ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : co_eqn_data_sel) =
          let
            val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs;
            val ctr_no = find_index (equal ctr o #ctr) ctr_specs;
            val prems = the_default (maps (negate_conj o #prems) disc_eqns)
                (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
            val sel_corec = find_index (equal sel) (#sels ctr_spec)
              |> nth (#sel_corecs ctr_spec);
            val k = 1 + ctr_no;
            val m = length prems;
            val t =
              list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
              |> curry betapply sel
              |> rpair (abstract (List.rev fun_args) rhs_term)
              |> HOLogic.mk_Trueprop o HOLogic.mk_eq
              |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
              |> curry Logic.list_all (map dest_Free fun_args);
            val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term;
          in
            mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps
              nested_map_idents nested_map_comps sel_corec k m exclsss
            |> K |> Goal.prove lthy [] [] t
            |> pair sel
          end;

        (* proves  prems ==> f args = ctr (sel rhss)  from the disc and sel theorems via the
           collapse rule. Returns [] when the constructor has no equations at all or when some
           selector of it has no equation. *)
        fun prove_ctr disc_alist sel_alist (disc_eqns : co_eqn_data_disc list)
            (sel_eqns : co_eqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) =
          if not (exists (equal ctr o #ctr) disc_eqns)
              andalso not (exists (equal ctr o #ctr) sel_eqns)
            orelse (* don't try to prove theorems when some sel_eqns are missing *)
              filter (equal ctr o #ctr) sel_eqns
              |> fst o finds ((op =) o apsnd #sel) sels
              |> exists (null o snd)
          then [] else
            let
              val (fun_name, fun_T, fun_args, prems) =
                (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
                |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
                |> the o merge_options;
              val m = length prems;
              val t = filter (equal ctr o #ctr) sel_eqns
                |> fst o finds ((op =) o apsnd #sel) sels
                |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
                |> curry list_comb ctr
                |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
                  map Bound (length fun_args - 1 downto 0)))
                |> HOLogic.mk_Trueprop
                |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                |> curry Logic.list_all (map dest_Free fun_args);
              val maybe_disc_thm = AList.lookup (op =) disc_alist disc;
              val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist);
            in
              mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
              |> K |> Goal.prove lthy [] [] t
              |> single
            end;

        val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss;
        val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss;

        val disc_thmss = map (map snd) disc_alists;
        val sel_thmss = map (map snd) sel_alists;
        val ctr_thmss = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss
          (map #ctr_specs corec_specs);

        val simp_thmss = map2 append disc_thmss sel_thmss

        val common_name = mk_common_name fun_names;

        (* per-function theorem collections, qualified by the function name; empty ones are
           dropped *)
        val notes =
          [(coinductN, map (if n2m then single else K []) coinduct_thms, []),
           (codeN, ctr_thmss(*FIXME*), code_nitpick_attrs),
           (ctrN, ctr_thmss, []),
           (discN, disc_thmss, simp_attrs),
           (selN, sel_thmss, simp_attrs),
           (simpsN, simp_thmss, []),
           (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])]
          |> maps (fn (thmN, thmss, attrs) =>
            map2 (fn fun_name => fn thms =>
                ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]))
              fun_names (take actual_nn thmss))
          |> filter_out (null o fst o hd o snd);

        (* coinduction rules shared by all functions, qualified by the common name (only when
           n2m) *)
        val common_notes =
          [(coinductN, if n2m then [coinduct_thm] else [], []),
           (strong_coinductN, if n2m then [strong_coinduct_thm] else [], [])]
          |> filter_out (null o #2)
          |> map (fn (thmN, thms, attrs) =>
            ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
      in
        lthy |> Local_Theory.notes (notes @ common_notes) |> snd
      end;

    (* define all functions, then prove and note their theorems *)
    fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss';

    val _ = if not simple orelse forall null obligationss then () else
      primrec_error "need exclusiveness proofs - use primcorecursive instead of primcorec";
  in
    if simple then
      lthy'
      (* every obligation list is empty here (checked above), so each maps to [] *)
      |> after_qed (map (fn [] => []) obligationss)
      |> pair NONE o SOME
    else
      lthy'
      |> Proof.theorem NONE after_qed obligationss
      |> Proof.refine (Method.primitive_text I)
      |> Seq.hd
      |> rpair NONE o SOME
  end;
       
   968 
       
   969 fun add_primcorec_ursive_cmd simple seq (raw_fixes, raw_specs') lthy =
       
   970   let
       
   971     val (raw_specs, of_specs) = split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy));
       
   972     val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
       
   973   in
       
   974     add_primcorec simple seq fixes specs of_specs lthy
       
   975     handle ERROR str => primrec_error str
       
   976   end
       
   977   handle Primrec_Error (str, eqns) =>
       
   978     if null eqns
       
   979     then error ("primcorec error:\n  " ^ str)
       
   980     else error ("primcorec error:\n  " ^ str ^ "\nin\n  " ^
       
   981       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns));
       
   982 
       
   983 val add_primcorecursive_cmd = (the o fst) ooo add_primcorec_ursive_cmd false;
       
   984 val add_primcorec_cmd = (the o snd) ooo add_primcorec_ursive_cmd true;
       
   985 
       
   986 end;