src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
author blanchet
Fri Sep 20 16:32:27 2013 +0200 (2013-09-20)
changeset 53753 ae7f50e70c09
parent 53744 9db52ae90009
child 53791 0ddf045113c9
permissions -rw-r--r--
renamed "primcorec" to "primcorecursive", to open the door to a 'theory -> theory' command called "primcorec" (cf. "fun" vs. "function")
     1 (*  Title:      HOL/BNF/Tools/bnf_fp_rec_sugar.ML
     2     Author:     Lorenz Panny, TU Muenchen
     3     Copyright   2013
     4 
     5 Recursor and corecursor sugar.
     6 *)
     7 
signature BNF_FP_REC_SUGAR =
sig
  (*Define primitively recursive functions from raw (unparsed) fixes and equations in a local
    theory; user errors are reported with the prefix "primrec_new error".*)
  val add_primrec_cmd: (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> local_theory -> local_theory;
  (*Entry point for primitively corecursive definitions ("primcorecursive"). NOTE(review): the
    bool is presumably the "sequential" option and the resulting Proof.state presumably holds
    the remaining proof obligations -- confirm against the implementation, which lies outside
    this excerpt.*)
  val add_primcorecursive_cmd: bool ->
    (binding * string option * mixfix) list * (Attrib.binding * string) list -> Proof.context ->
    Proof.state
end;
    16 
    17 structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
    18 struct
    19 
    20 open BNF_Util
    21 open BNF_FP_Util
    22 open BNF_FP_Rec_Sugar_Util
    23 open BNF_FP_Rec_Sugar_Tactics
    24 
    25 exception Primrec_Error of string * term list;
    26 
    27 fun primrec_error str = raise Primrec_Error (str, []);
    28 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
    29 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
    30 
    31 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
    32 
    33 val free_name = try (fn Free (v, _) => v);
    34 val const_name = try (fn Const (v, _) => v);
    35 val undef_const = Const (@{const_name undefined}, dummyT);
    36 
    37 fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1)))
    38   |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n);
    39 val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple;
    40 fun drop_All t = subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev,
    41   strip_qnt_body @{const_name all} t)
    42 fun mk_not @{const True} = @{const False}
    43   | mk_not @{const False} = @{const True}
    44   | mk_not (@{const Not} $ t) = t
    45   | mk_not (@{const Trueprop} $ t) = @{const Trueprop} $ mk_not t
    46   | mk_not t = HOLogic.mk_not t
    47 val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
    48 val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
    49 fun invert_prems [t] = map mk_not (HOLogic.disjuncts t)
    50   | invert_prems ts = [mk_disjs (map mk_not ts)];
    51 fun invert_prems_disj [t] = map mk_not (HOLogic.disjuncts t)
    52   | invert_prems_disj ts = [mk_disjs (map (mk_conjs o map mk_not o HOLogic.disjuncts) ts)];
    53 fun abstract vs =
    54   let fun a n (t $ u) = a n t $ a n u
    55         | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b)
    56         | a n t = let val idx = find_index (equal t) vs in
    57             if idx < 0 then t else Bound (n + idx) end
    58   in a 0 end;
    59 fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u;
    60 fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts));
    61 
    62 val simp_attrs = @{attributes [simp]};
    63 
    64 
    65 (* Primrec *)
    66 
(*Data extracted by dissect_eqn from a primrec equation
    f l1 ... lk (C x1 ... xm) r1 ... rn = rhs*)
type eqn_data = {
  fun_name: string, (*name of the function f*)
  rec_type: typ, (*body type of the constructor C, i.e. the type recursed on*)
  ctr: term, (*the constructor C*)
  ctr_args: term list, (*pattern variables x1 ... xm*)
  left_args: term list, (*variables l1 ... lk before the constructor pattern*)
  right_args: term list, (*variables r1 ... rn after the constructor pattern*)
  res_type: typ, (*types of l1 ... lk r1 ... rn ---> type of rhs*)
  rhs_term: term, (*right-hand side*)
  user_eqn: term (*the equation as given, before quantifiers are dropped*)
};
    78 
    79 fun dissect_eqn lthy fun_names eqn' =
    80   let
    81     val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
    82       handle TERM _ =>
    83         primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
    84     val (lhs, rhs) = HOLogic.dest_eq eqn
    85         handle TERM _ =>
    86           primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn';
    87     val (fun_name, args) = strip_comb lhs
    88       |>> (fn x => if is_Free x then fst (dest_Free x)
    89           else primrec_error_eqn "malformed function equation (does not start with free)" eqn);
    90     val (left_args, rest) = take_prefix is_Free args;
    91     val (nonfrees, right_args) = take_suffix is_Free rest;
    92     val _ = length nonfrees = 1 orelse if length nonfrees = 0 then
    93       primrec_error_eqn "constructor pattern missing in left-hand side" eqn else
    94       primrec_error_eqn "more than one non-variable argument in left-hand side" eqn;
    95     val _ = member (op =) fun_names fun_name orelse
    96       primrec_error_eqn "malformed function equation (does not start with function name)" eqn
    97 
    98     val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    99     val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
   100       primrec_error_eqn "partially applied constructor in pattern" eqn;
   101     val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse
   102       primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
   103         "\" in left-hand side") eqn end;
   104     val _ = forall is_Free ctr_args orelse
   105       primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
   106     val _ =
   107       let val b = fold_aterms (fn x as Free (v, _) =>
   108         if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso
   109         not (member (op =) fun_names v) andalso
   110         not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs []
   111       in
   112         null b orelse
   113         primrec_error_eqn ("extra variable(s) in right-hand side: " ^
   114           commas (map (Syntax.string_of_term lthy) b)) eqn
   115       end;
   116   in
   117     {fun_name = fun_name,
   118      rec_type = body_type (type_of ctr),
   119      ctr = ctr,
   120      ctr_args = ctr_args,
   121      left_args = left_args,
   122      right_args = right_args,
   123      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
   124      rhs_term = rhs,
   125      user_eqn = eqn'}
   126   end;
   127 
(*Rewrite the function argument of a map in a nested recursive call; subst_rec_calls passes
  this to massage_indirect_rec_call. d tracks the de Bruijn index of the variable bound by the
  outermost abstraction: SOME ~1 before that abstraction is entered, NONE where tracking is
  abandoned. The outermost abstraction is retyped to rec_type * res_type. Occurrences of its
  variable become  fst x, and recursive calls on it become  snd x  applied to the remaining
  arguments. get_ctr_pos gives the constructor-argument position of a function being defined
  (negative for other names). Raises Primrec_Error for any other recursive call.*)
fun rewrite_map_arg get_ctr_pos rec_type res_type =
  let
    val pT = HOLogic.mk_prodT (rec_type, res_type);

    val maybe_suc = Option.map (fn x => x + 1);
    (*the tracked bound variable becomes  fst x*)
    fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
      (*entering the outermost abstraction retypes its variable and starts tracking at 0;
        deeper abstractions shift the tracked index*)
      | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
      | subst d t =
        let
          val (u, vs) = strip_comb t;
          val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1;
        in
          if ctr_pos >= 0 then
            (*a recursive call without its recursive argument, used eta-contracted as the
              map function itself*)
            if d = SOME ~1 andalso length vs = ctr_pos then
              list_comb (permute_args ctr_pos (snd_const pT), vs)
            (*a recursive call on the tracked variable*)
            else if length vs > ctr_pos andalso is_some d
                andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
              list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
            else
              primrec_error_eqn ("recursive call not directly applied to constructor argument") t
          (*top-level composition: the first operand is rewritten without tracking, the
            second one as a map argument*)
          else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then
            list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs)
          else
            list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
        end
  in
    subst (SOME ~1)
  end;
   156 
   157 fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
   158   let
   159     fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
   160       | subst bound_Ts (t as g' $ y) =
   161         let
   162           val maybe_direct_y' = AList.lookup (op =) direct_calls y;
   163           val maybe_indirect_y' = AList.lookup (op =) indirect_calls y;
   164           val (g, g_args) = strip_comb g';
   165           val ctr_pos = try (get_ctr_pos o the) (free_name g) |> the_default ~1;
   166           val _ = ctr_pos < 0 orelse length g_args >= ctr_pos orelse
   167             primrec_error_eqn "too few arguments in recursive call" t;
   168         in
   169           if not (member (op =) ctr_args y) then
   170             pairself (subst bound_Ts) (g', y) |> (op $)
   171           else if ctr_pos >= 0 then
   172             list_comb (the maybe_direct_y', g_args)
   173           else if is_some maybe_indirect_y' then
   174             (if has_call g' then t else y)
   175             |> massage_indirect_rec_call lthy has_call
   176               (rewrite_map_arg get_ctr_pos) bound_Ts y (the maybe_indirect_y')
   177             |> (if has_call g' then I else curry (op $) g')
   178           else
   179             t
   180         end
   181       | subst _ t = t
   182   in
   183     subst [] t
   184     |> tap (fn u => has_call u andalso (* FIXME detect this case earlier *)
   185       primrec_error_eqn "recursive call not directly applied to constructor argument" t)
   186   end;
   187 
(*Build the recursor argument for one constructor (ctr_spec) from its equation, if any
  (undef_const otherwise). The equation's right-hand side, with recursive calls replaced by
  subst_rec_calls, is abstracted over the recursor-argument variables followed by the
  equation's left and right extra arguments. funs_data holds, per function being defined, its
  list of equation data.*)
fun build_rec_arg lthy funs_data has_call ctr_spec maybe_eqn_data =
  if is_none maybe_eqn_data then undef_const else
    let
      val eqn_data = the maybe_eqn_data;
      val t = #rhs_term eqn_data;
      val ctr_args = #ctr_args eqn_data;

      (*number of recursor arguments for this constructor: a directly recursive argument
        occupies two slots, every other argument one*)
      val calls = #calls ctr_spec;
      val n_args = fold (curry (op +) o (fn Direct_Rec _ => 2 | _ => 1)) calls 0;

      (*(constructor-argument index, recursor-slot index) pairs, per kind of call.
        NOTE(review): assumes #calls is aligned with the constructor's arguments, as the
        indexing into ctr_args below requires*)
      val no_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn No_Rec n => n | Direct_Rec (n, _) => n)));
      val direct_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Direct_Rec (_, n) => n)));
      val indirect_calls' = tag_list 0 calls
        |> map_filter (try (apsnd (fn Indirect_Rec n => n)));

      (*the type argument is ignored: slots for recursive results get dummy types*)
      fun make_direct_type T = dummyT; (* FIXME? *)

      val rec_res_type_list = map (fn (x :: _) => (#rec_type x, #res_type x)) funs_data;

      (*replace every type argument that is the recursion type of a function being defined by
        that type paired with the function's result type, recursively*)
      fun make_indirect_type (Type (Tname, Ts)) = Type (Tname, Ts |> map (fn T =>
        let val maybe_res_type = AList.lookup (op =) rec_res_type_list T in
          if is_some maybe_res_type
          then HOLogic.mk_prodT (T, the maybe_res_type)
          else make_indirect_type T end))
        | make_indirect_type T = T;

      (*fresh variables (names distinct from those in t), one per recursor slot, then
        overwritten with the constructor arguments (retyped for recursive slots)*)
      val args = replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ctr_arg_idx, arg_idx) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx)))
          no_calls'
        |> fold (fn (ctr_arg_idx, arg_idx) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_direct_type)))
          direct_calls'
        |> fold (fn (ctr_arg_idx, arg_idx) =>
            nth_map arg_idx (K (nth ctr_args ctr_arg_idx |> map_types make_indirect_type)))
          indirect_calls';

      (*constructor-pattern position (number of left arguments) of each function; ~1 for
        other names*)
      val fun_name_ctr_pos_list =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1;
      (*constructor argument |-> recursor-slot variable*)
      val direct_calls = map (apfst (nth ctr_args) o apsnd (nth args)) direct_calls';
      val indirect_calls = map (apfst (nth ctr_args) o apsnd (nth args)) indirect_calls';

      val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
    in
      t
      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
      |> fold_rev lambda abstractions
    end;
   241 
   242 fun build_defs lthy bs mxs funs_data rec_specs has_call =
   243   let
   244     val n_funs = length funs_data;
   245 
   246     val ctr_spec_eqn_data_list' =
   247       (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
   248       |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
   249           ##> (fn x => null x orelse
   250             primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst);
   251     val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
   252       primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));
   253 
   254     val ctr_spec_eqn_data_list =
   255       ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));
   256 
   257     val recs = take n_funs rec_specs |> map #recx;
   258     val rec_args = ctr_spec_eqn_data_list
   259       |> sort ((op <) o pairself (#offset o fst) |> make_ord)
   260       |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
   261     val ctr_poss = map (fn x =>
   262       if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
   263         primrec_error ("inconstant constructor pattern position for function " ^
   264           quote (#fun_name (hd x)))
   265       else
   266         hd x |> #left_args |> length) funs_data;
   267   in
   268     (recs, ctr_poss)
   269     |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
   270     |> Syntax.check_terms lthy
   271     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   272   end;
   273 
   274 fun find_rec_calls has_call eqn_data =
   275   let
   276     fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg
   277       | find (t as _ $ _) ctr_arg =
   278         let
   279           val (f', args') = strip_comb t;
   280           val n = find_index (equal ctr_arg) args';
   281         in
   282           if n < 0 then
   283             find f' ctr_arg @ maps (fn x => find x ctr_arg) args'
   284           else
   285             let val (f, args) = chop n args' |>> curry list_comb f' in
   286               if has_call f then
   287                 f :: maps (fn x => find x ctr_arg) args
   288               else
   289                 find f ctr_arg @ maps (fn x => find x ctr_arg) args
   290             end
   291         end
   292       | find _ _ = [];
   293   in
   294     map (find (#rhs_term eqn_data)) (#ctr_args eqn_data)
   295     |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss))
   296   end;
   297 
(*Define primitively recursive functions from already read fixes and equations: dissect the
  equations, obtain recursor specifications from rec_specs_of, define each function via its
  recursor, and prove and note the user's equations ("simps", with the simp attribute) and,
  when there are several functions, induction rules.*)
fun add_primrec fixes specs lthy =
  let
    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val eqns_data = map (snd #> dissect_eqn lthy fun_names) specs;
    (*group the equations by function, in the order of fun_names; every function needs at
      least one equation*)
    val funs_data = eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map (fn (x, y) => the_single y handle List.Empty =>
          primrec_error ("missing equations for function " ^ quote x));

    (*does a term mention one of the functions being defined (as a free variable)?*)
    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    (*per function, per group of equations with the same constructor: the recursive calls*)
    val callssss = funs_data
      |> map (partition_eq ((op =) o pairself #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    (*indices (in fixes) of the functions occurring in t*)
    fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
      |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
      |> map_filter I;
    val ((nontriv, rec_specs, _, induct_thm, induct_thms), lthy') =
      rec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;

    val actual_nn = length funs_data;

    (*every pattern head must be a constructor of one of the recursor specifications*)
    val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in
      map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
          " is not a constructor in left-hand side") user_eqn) eqns_data end;

    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;

    (*prove the equations of one function (fun_data) from its definition and the recursor's
      rec_thms, and note them together with the function's induction rule, if any*)
    fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
        lthy =
      let
        val fun_name = #fun_name (hd fun_data);
        val def_thms = map (snd o snd) def_thms';
        (*equations paired with exactly one constructor spec of the same constructor*)
        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (try (fn (x, [y]) =>
            (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)))
          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
            mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm
            |> K |> Goal.prove lthy [] [] user_eqn)

        val notes =
          [(inductN, if actual_nn > 1 then [induct_thm] else [], []),
           (simpsN, simp_thms, simp_attrs)]
          |> filter_out (null o #2)
          |> map (fn (thmN, thms, attrs) =>
            ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]));
      in
        lthy |> Local_Theory.notes notes
      end;

    val common_name = mk_common_name fun_names;

    (*the joint induction rule is noted only if rec_specs_of reports it as nontrivial and
      several functions are defined*)
    val common_notes =
      [(inductN, if nontriv andalso actual_nn > 1 then [induct_thm] else [], [])]
      |> filter_out (null o #2)
      |> map (fn (thmN, thms, attrs) =>
        ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
  in
    lthy'
    |> fold_map Local_Theory.define defs
    |-> snd oo (fn def_thms' => fold_map3 (prove def_thms') (take actual_nn rec_specs)
      (take actual_nn induct_thms) funs_data)
    |> snd o Local_Theory.notes common_notes
  end;
   368 
   369 fun add_primrec_cmd raw_fixes raw_specs lthy =
   370   let
   371     val _ = let val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) in null d orelse
   372       primrec_error ("duplicate function name(s): " ^ commas d) end;
   373     val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
   374   in
   375     add_primrec fixes specs lthy
   376       handle ERROR str => primrec_error str
   377   end
   378   handle Primrec_Error (str, eqns) =>
   379     if null eqns
   380     then error ("primrec_new error:\n  " ^ str)
   381     else error ("primrec_new error:\n  " ^ str ^ "\nin\n  " ^
   382       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns))
   383 
   384 
   385 
   386 (* Primcorec *)
   387 
(*Data extracted from a discriminator equation  prems ==> disc (f x1 ... xn).*)
type co_eqn_data_disc = {
  fun_name: string, (*name of the corecursive function f*)
  fun_T: typ, (*type of f*)
  fun_args: term list, (*arguments x1 ... xn that f is applied to*)
  ctr: term, (*constructor selected by the discriminator*)
  ctr_no: int, (*index of ctr among the constructor specs of f's result type*)
  disc: term, (*discriminator of ctr*)
  prems: term list, (*guard under which ctr is chosen; NOTE(review): as built by
    co_dissect_eqn_disc the fun_args appear as loose Bounds -- confirm for other producers*)
  user_eqn: term (*equation reconstructed from prems and disc*)
};
(*Data extracted from a selector equation  sel (f x1 ... xn) = rhs.*)
type co_eqn_data_sel = {
  fun_name: string, (*name of the corecursive function f*)
  fun_T: typ, (*type of f*)
  fun_args: term list, (*arguments x1 ... xn that f is applied to*)
  ctr: term, (*constructor the selector belongs to*)
  sel: term, (*the selector*)
  rhs_term: term, (*right-hand side*)
  user_eqn: term (*the user's equation with outer quantifiers dropped*)
};
(*A dissected primcorecursive equation.*)
datatype co_eqn_data =
  Disc of co_eqn_data_disc |
  Sel of co_eqn_data_sel;
   410 
(*Dissect a discriminator equation  prems' ==> concl, where concl is disc (f x1 ... xn),
  ~ disc (f ...), f ... = C or ~ (f ...) = C for a nullary constructor C. matchedsss maps each
  function name to the premise lists of its previous discriminator equations. When sequential
  is set, or the premise is a catch-all, their negations are added to the premises. Returns the
  Disc data and the updated matchedsss.*)
fun co_dissect_eqn_disc sequential fun_names corec_specs prems' concl matchedsss =
  let
    (*first subterm satisfying p, searching the term itself, then the function part, then the
      argument*)
    fun find_subterm p = let (* FIXME \<exists>? *)
      fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
        | f t = if p t then SOME t else NONE
      in f end;

    (*the application  f x1 ... xn  of a function being defined*)
    val applied_fun = concl
      |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of))
      |> the
      handle Option.Option => primrec_error_eqn "malformed discriminator equation" concl;
    val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free;
    val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);

    val discs = map #disc ctr_specs;
    val ctrs = map #ctr ctr_specs;
    (*a negated discriminator denotes the other constructor, so it needs exactly two*)
    val not_disc = head_of concl = @{term Not};
    val _ = not_disc andalso length ctrs <> 2 andalso
      primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" concl;
    val disc = find_subterm (member (op =) discs o head_of) concl;
    (*index of C if concl is (~) f ... = C*)
    val eq_ctr0 = concl |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
        |> (fn SOME t => let val n = find_index (equal t) ctrs in
          if n >= 0 then SOME n else NONE end | _ => NONE);
    val _ = is_some disc orelse is_some eq_ctr0 orelse
      primrec_error_eqn "no discriminator in equation" concl;
    val ctr_no' =
      if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
    val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
    val ctr = #ctr (nth ctr_specs ctr_no);

    (*a single premise that is the free variable Name.uu_ marks a catch-all equation*)
    val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_;
    val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default [];
    (*premises with the function arguments turned into loose Bounds*)
    val prems = map (abstract (List.rev fun_args)) prems';
    val real_prems =
      (if catch_all orelse sequential then maps invert_prems_disj matchedss else []) @
      (if catch_all then [] else prems);

    (*record this equation's premises for the following discriminator equations*)
    val matchedsss' = AList.delete (op =) fun_name matchedsss
      |> cons (fun_name, if sequential then matchedss @ [prems] else matchedss @ [real_prems]);

    (*real_prems ==> disc (f x1 ... xn)*)
    val user_eqn =
      (real_prems, betapply (#disc (nth ctr_specs ctr_no), applied_fun))
      |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop
      |> Logic.list_implies;
  in
    (Disc {
      fun_name = fun_name,
      fun_T = fun_T,
      fun_args = fun_args,
      ctr = ctr,
      ctr_no = ctr_no,
      disc = #disc (nth ctr_specs ctr_no),
      prems = real_prems,
      user_eqn = user_eqn
    }, matchedsss')
  end;
   467 
(*Dissect a selector equation  sel (f x1 ... xn) = rhs  (eqn; eqn' is the user's original
  equation) into Sel data, identifying the constructor that sel belongs to.*)
fun co_dissect_eqn_sel fun_names corec_specs eqn' eqn =
  let
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ =>
        primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
    val sel = head_of lhs;
    val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free
      handle TERM _ =>
        primrec_error_eqn "malformed selector argument in left-hand side" eqn;
    val corec_spec = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name)
      handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn;
    (*the constructor spec whose selectors contain sel. NOTE(review): if none does, `the`
      raises Option.Option, which is not turned into a Primrec_Error here*)
    val (ctr_spec, sel) = #ctr_specs corec_spec
      |> the o get_index (try (the o find_first (equal sel) o #sels))
      |>> nth (#ctr_specs corec_spec);
    val user_eqn = drop_All eqn';
  in
    Sel {
      fun_name = fun_name,
      fun_T = fun_T,
      fun_args = fun_args,
      ctr = #ctr ctr_spec,
      sel = sel,
      rhs_term = rhs,
      user_eqn = user_eqn
    }
  end;
   494 
   495 fun co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss =
   496   let 
   497     val (lhs, rhs) = HOLogic.dest_eq imp_rhs;
   498     val fun_name = head_of lhs |> fst o dest_Free;
   499     val {ctr_specs, ...} = the (AList.lookup (op =) (fun_names ~~ corec_specs) fun_name);
   500     val (ctr, ctr_args) = strip_comb rhs;
   501     val {disc, sels, ...} = the (find_first (equal ctr o #ctr) ctr_specs)
   502       handle Option.Option => primrec_error_eqn "not a constructor" ctr;
   503 
   504     val disc_imp_rhs = betapply (disc, lhs);
   505     val (maybe_eqn_data_disc, matchedsss') = if length ctr_specs = 1
   506       then (NONE, matchedsss)
   507       else apfst SOME (co_dissect_eqn_disc
   508           sequential fun_names corec_specs imp_prems disc_imp_rhs matchedsss);
   509 
   510     val sel_imp_rhss = (sels ~~ ctr_args)
   511       |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));
   512 
   513 val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
   514  (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
   515  space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));
   516 
   517     val eqns_data_sel =
   518       map (co_dissect_eqn_sel fun_names corec_specs eqn') sel_imp_rhss;
   519   in
   520     (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss')
   521   end;
   522 
   523 fun co_dissect_eqn sequential fun_names corec_specs eqn' matchedsss =
   524   let
   525     val eqn = drop_All eqn'
   526       handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
   527     val (imp_prems, imp_rhs) = Logic.strip_horn eqn
   528       |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop;
   529 
   530     val head = imp_rhs
   531       |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
   532       |> head_of;
   533 
   534     val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);
   535 
   536     val discs = maps #ctr_specs corec_specs |> map #disc;
   537     val sels = maps #ctr_specs corec_specs |> maps #sels;
   538     val ctrs = maps #ctr_specs corec_specs |> map #ctr;
   539   in
   540     if member (op =) discs head orelse
   541       is_some maybe_rhs andalso
   542         member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then
   543       co_dissect_eqn_disc sequential fun_names corec_specs imp_prems imp_rhs matchedsss
   544       |>> single
   545     else if member (op =) sels head then
   546       ([co_dissect_eqn_sel fun_names corec_specs eqn' imp_rhs], matchedsss)
   547     else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
   548       co_dissect_eqn_ctr sequential fun_names corec_specs eqn' imp_prems imp_rhs matchedsss
   549     else
   550       primrec_error_eqn "malformed function equation" eqn
   551   end;
   552 
   553 fun build_corec_arg_disc ctr_specs {fun_args, ctr_no, prems, ...} =
   554   if is_none (#pred (nth ctr_specs ctr_no)) then I else
   555     mk_conjs prems
   556     |> curry subst_bounds (List.rev fun_args)
   557     |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args)
   558     |> K |> nth_map (the (#pred (nth ctr_specs ctr_no)));
   559 
   560 fun build_corec_arg_no_call sel_eqns sel =
   561   find_first (equal sel o #sel) sel_eqns
   562   |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term)
   563   |> the_default undef_const
   564   |> K;
   565 
   566 fun build_corec_args_direct_call lthy has_call sel_eqns sel =
   567   let
   568     val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
   569     fun rewrite_q t = if has_call t then @{term False} else @{term True};
   570     fun rewrite_g t = if has_call t then undef_const else t;
   571     fun rewrite_h t = if has_call t then HOLogic.mk_tuple (snd (strip_comb t)) else undef_const;
   572     fun massage _ NONE t = t
   573       | massage f (SOME {fun_args, rhs_term, ...}) t =
   574         massage_direct_corec_call lthy has_call f (range_type (fastype_of t)) rhs_term
   575         |> abs_tuple fun_args;
   576   in
   577     (massage rewrite_q maybe_sel_eqn,
   578      massage rewrite_g maybe_sel_eqn,
   579      massage rewrite_h maybe_sel_eqn)
   580   end;
   581 
(*Slot updater for a selector whose value may contain nested (indirect) corecursive calls.
  The selector equation's right-hand side is transformed by massage_indirect_corec_call using
  rewrite, then abstracted over the tuple of function arguments. Without a selector equation
  the slot is left unchanged.*)
fun build_corec_arg_indirect_call lthy has_call sel_eqns sel =
  let
    val maybe_sel_eqn = find_first (equal sel o #sel) sel_eqns;
    (*replace each call  f a1 ... ak  of a function being defined by  Inr_const U T  applied
      to the tuple of a1 ... ak (to unit for a bare f); bound_Ts tracks the types of the
      enclosing abstractions for mk_tuple1*)
    fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b)
      | rewrite bound_Ts U T (t as _ $ _) =
        let val (u, vs) = strip_comb t in
          if is_Free u andalso has_call u then
            Inr_const U T $ mk_tuple1 bound_Ts vs
          (*prod_case gets a dummy type; its arguments are rewritten. NOTE(review): presumably
            re-typed by the Syntax.check_terms in co_build_defs -- confirm*)
          else if try (fst o dest_Const) u = SOME @{const_name prod_case} then
            list_comb (map_types (K dummyT) u, map (rewrite bound_Ts U T) vs)
          else
            list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs)
        end
      | rewrite _ U T t = if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t;
    fun massage NONE t = t
      | massage (SOME {fun_args, rhs_term, ...}) t =
        massage_indirect_corec_call lthy has_call (rewrite []) []
          (range_type (fastype_of t)) rhs_term
        |> abs_tuple fun_args;
  in
    massage maybe_sel_eqn
  end;
   604 
(*Fill the corecursor argument slots of constructor ctr_spec from the selector equations for
  that constructor (taken from all_sel_eqns); the arguments are unchanged if there are none.
  #sels and #calls are paired positionally. No_Corec and Indirect_Corec name one slot each,
  while Direct_Corec names the three slots (q, g, h).*)
fun build_corec_args_sel lthy has_call all_sel_eqns ctr_spec =
  let val sel_eqns = filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns in
    if null sel_eqns then I else
      let
        val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec;

        (*(selector, slot(s)) pairs per kind of call*)
        val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list;
        val direct_calls' = map_filter (try (apsnd (fn Direct_Corec n => n))) sel_call_list;
        val indirect_calls' = map_filter (try (apsnd (fn Indirect_Corec n => n))) sel_call_list;
      in
        I
        #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls'
        #> fold (fn (sel, (q, g, h)) =>
          let val (fq, fg, fh) = build_corec_args_direct_call lthy has_call sel_eqns sel in
            nth_map q fq o nth_map g fg o nth_map h fh end) direct_calls'
        #> fold (fn (sel, n) => nth_map n
          (build_corec_arg_indirect_call lthy has_call sel_eqns sel)) indirect_calls'
      end
  end;
   624 
   625 fun co_build_defs lthy bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss =
   626   let
   627     val corec_specs' = take (length bs) corec_specs;
   628     val corecs = map #corec corec_specs';
   629     val ctr_specss = map #ctr_specs corec_specs';
   630     val corec_args = hd corecs
   631       |> fst o split_last o binder_types o fastype_of
   632       |> map (Const o pair @{const_name undefined})
   633       |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss
   634       |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss;
   635     fun currys [] t = t
   636       | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0))
   637           |> fold_rev (Term.abs o pair Name.uu) Ts;
   638 
   639 val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
   640  space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) corec_args));
   641 
   642     val exclss' =
   643       disc_eqnss
   644       |> map (map (fn {fun_args, ctr_no, prems, ...} => (fun_args, ctr_no, prems))
   645         #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs [])
   646         #> maps (uncurry (map o pair)
   647           #> map (fn ((fun_args, c, x), (_, c', y)) => ((c, c'), (x, mk_not (mk_conjs y)))
   648             ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop
   649             ||> Logic.list_implies
   650             ||> curry Logic.list_all (map dest_Free fun_args))))
   651   in
   652     map (list_comb o rpair corec_args) corecs
   653     |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss
   654     |> map2 currys arg_Tss
   655     |> Syntax.check_terms lthy
   656     |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
   657     |> rpair exclss'
   658   end;
   659 
   660 fun mk_real_disc_eqns fun_binding arg_Ts {ctr_specs, ...} sel_eqns disc_eqns =
   661   if length disc_eqns <> length ctr_specs - 1 then disc_eqns else
   662     let
   663       val n = 0 upto length ctr_specs
   664         |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns));
   665       val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns)
   666         |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options;
   667       val extra_disc_eqn = {
   668         fun_name = Binding.name_of fun_binding,
   669         fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))),
   670         fun_args = fun_args,
   671         ctr = #ctr (nth ctr_specs n),
   672         ctr_no = n,
   673         disc = #disc (nth ctr_specs n),
   674         prems = maps (invert_prems o #prems) disc_eqns,
   675         user_eqn = undef_const};
   676     in
   677       chop n disc_eqns ||> cons extra_disc_eqn |> (op @)
   678     end;
   679 
(* Core of the primcorecursive command, called by add_primcorecursive_cmd.

   fixes/specs are the already-read fixes and specification equations. sequential is passed
   through to co_dissect_eqn; NOTE(review): presumably this is the command's "sequential" option.
   Confirm there.

   Steps:
   1. Compute the corecursor specifications.
   2. Dissect the equations into discriminator and selector equations.
   3. Build the corecursor-based definitions with co_build_defs.
   4. Try to discharge the generated exclusiveness goals with mk_primcorec_assumption_tac.

   The goals that remain become the goals of the returned Proof.state. Once they are proved, the
   definitions are made with Local_Theory.define, and the "disc", "sel" and "ctr" theorems are
   proved and noted. *)
fun add_primcorec sequential fixes specs lthy =
  let
    val (bs, mxs) = map_split (apfst fst) fixes;
    (* arg_Ts: each function's argument types tupled into a single type; res_Ts: result types *)
    val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list;

    (* copied from primrec_new *)
    (* indices (into fixes) of the functions being defined that occur in t *)
    fun get_indices t = map (fst #>> Binding.name_of #> Free) fixes
      |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE)
      |> map_filter I;

    val callssss = []; (* FIXME *)

    (* NOTE(review): nontriv and the coinduction theorems bound here are not used below *)
    val ((nontriv, corec_specs', _, coinduct_thm, strong_co_induct_thm, coinduct_thmss,
          strong_coinduct_thmss), lthy') =
      corec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;

    val fun_names = map Binding.name_of bs;
    val corec_specs = take (length fun_names) corec_specs'; (*###*)

    (* dissect every user equation; co_dissect_eqn yields a list of entries per equation,
       hence the flat *)
    val (eqns_data, _) =
      fold_map (co_dissect_eqn sequential fun_names corec_specs) (map snd specs) []
      |>> flat;

    (* discriminator equations grouped per function in the order of fun_names (a function
       without any gets []), each group sorted by constructor number *)
    val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd);
    (* reject a function that has several discriminator equations for the same constructor *)
    val _ = disc_eqnss' |> map (fn x =>
      let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse
        primrec_error_eqns "excess discriminator equations in definition"
          (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end);

    (* selector equations grouped per function in the order of fun_names *)
    val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> map (flat o snd);

    (* has_call t: does t contain one of the functions being defined (as a Free)? *)
    val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =));
    val arg_Tss = map (binder_types o snd o fst) fixes;
    (* add the missing discriminator equation where mk_real_disc_eqns can synthesize it *)
    val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss';
    val (defs, exclss') =
      co_build_defs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss;

    (* try to prove (automatically generated) tautologies by ourselves *)
    (* each goal is paired with SOME thm if the automatic proof succeeded and NONE otherwise;
       NOTE(review): these proofs run in lthy, whereas co_build_defs received lthy' *)
    val exclss'' = exclss'
      |> map (map (apsnd
        (`(try (fn t => Goal.prove lthy [] [] t (mk_primcorec_assumption_tac lthy |> K))))));
    (* taut_thmss: per function, (index, thm) for the automatically proved goals *)
    val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss'';
    (* remaining goals: their indices and their (goal, []) statements, which become the user's
       proof obligations *)
    val (obligation_idxss, obligationss) = exclss''
      |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd))
      |> split_list o map split_list;

    (* Continuation run after the obligations are proved.
       - thmss' are the proved obligations, per function, in the order of obligationss.
       - def_thms' are the results of Local_Theory.define on defs.
       - lthy (shadowing the outer one) is the local theory after the definitions. *)
    fun prove thmss' def_thms' lthy =
      let
        val def_thms = map (snd o snd) def_thms';

        val exclss' = map (op ~~) (obligation_idxss ~~ thmss');
        (* mk_exclsss excls n: an n x n table indexed by (c, c'). Entry (c, c') is [thm] if excls
           contains thm for that pair. Otherwise it defaults to [TrueI] when c' < c and to []
           when c' >= c. *)
        fun mk_exclsss excls n =
          (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1))
          |-> fold (fn ((c, c'), thm) => nth_map c (nth_map c' (K [thm])));
        val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs)
          |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs));

        (* prove_disc: for a discriminator equation, prove  prems ==> disc (f args)  and return
           [(disc, thm)]. The result is [] when the constructor's discriminator is (untyped)
           alpha-equivalent to %x. x = x. user_eqn is bound but not used. *)
        fun prove_disc {ctr_specs, ...} exclsss
            {fun_name, fun_T, fun_args, ctr_no, prems, user_eqn, ...} =
          if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\<lambda>x. x = x"}) then [] else
            let
              val {disc_corec, ...} = nth ctr_specs ctr_no;
              val k = 1 + ctr_no;
              val m = length prems;
              val t =
                list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
                |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*)
                |> HOLogic.mk_Trueprop
                |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                |> curry Logic.list_all (map dest_Free fun_args);
            in
              mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss
              |> K |> Goal.prove lthy [] [] t
              |> pair (#disc (nth ctr_specs ctr_no))
              |> single
            end;

        (* prove_sel: for a selector equation, prove  prems ==> sel (f args) = rhs  and return
           (sel, thm). If there is a discriminator equation for the same constructor, its
           premises are used. Otherwise the premises are derived from all discriminator
           equations via invert_prems. *)
        fun prove_sel {ctr_specs, nested_maps, nested_map_idents, nested_map_comps, ...}
            disc_eqns exclsss {fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} =
          let
            (* NOTE(review): non-exhaustive binding; raises Bind if ctr is not in ctr_specs *)
            val (SOME ctr_spec) = find_first (equal ctr o #ctr) ctr_specs;
            val ctr_no = find_index (equal ctr o #ctr) ctr_specs;
            val prems = the_default (maps (invert_prems o #prems) disc_eqns)
                (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems);
            val sel_corec = find_index (equal sel) (#sels ctr_spec)
              |> nth (#sel_corecs ctr_spec);
            val k = 1 + ctr_no;
            val m = length prems;
            val t =
              list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0))
              |> curry betapply sel
              |> rpair (abstract (List.rev fun_args) rhs_term)
              |> HOLogic.mk_Trueprop o HOLogic.mk_eq
              |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
              |> curry Logic.list_all (map dest_Free fun_args);
          in
            mk_primcorec_ctr_or_sel_tac lthy def_thms sel_corec k m exclsss
              nested_maps nested_map_idents nested_map_comps
            |> K |> Goal.prove lthy [] [] t
            |> pair sel
          end;

        (* prove_ctr: prove  prems ==> f args = ctr rhs_1 ... rhs_n, taking the right-hand
           sides from the selector equations of ctr. Returns [], after a warning, when the
           constructor has no equations at all or when some of its selectors have no selector
           equation.
           NOTE(review): the warning/tracing lines at column 0 below look like leftover
           debugging output.
           NOTE(review): (fn [x] => ...) raises Match if a selector has more than one equation
           here. *)
        fun prove_ctr (_, disc_thms) (_, sel_thms') disc_eqns sel_eqns
            {ctr, disc, sels, collapse, ...} =
          if not (exists (equal ctr o #ctr) disc_eqns)
              andalso not (exists (equal ctr o #ctr) sel_eqns)
andalso (warning ("no eqns for ctr " ^ Syntax.string_of_term lthy ctr); true)
            orelse (* don't try to prove theorems when some sel_eqns are missing *)
              filter (equal ctr o #ctr) sel_eqns
              |> fst o finds ((op =) o apsnd #sel) sels
              |> exists (null o snd)
andalso (warning ("sel_eqn(s) missing for ctr " ^ Syntax.string_of_term lthy ctr); true)
          then [] else
            let
val _ = tracing ("ctr = " ^ Syntax.string_of_term lthy ctr);
val _ = tracing (the_default "no disc_eqn" (Option.map (curry (op ^) "disc = " o Syntax.string_of_term lthy o #disc) (find_first (equal ctr o #ctr) disc_eqns)));
              (* the premises come from the discriminator equation for ctr, if any *)
              val (fun_name, fun_T, fun_args, prems) =
                (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns)
                |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x))
                ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, []))
                |> the o merge_options;
              val m = length prems;
              val t = sel_eqns
                |> fst o finds ((op =) o apsnd #sel) sels
                |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract)
                |> curry list_comb ctr
                |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T),
                  map Bound (length fun_args - 1 downto 0)))
                |> HOLogic.mk_Trueprop
                |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems)
                |> curry Logic.list_all (map dest_Free fun_args);
              val maybe_disc_thm = AList.lookup (op =) disc_thms disc;
              val sel_thms = map snd (filter (member (op =) sels o fst) sel_thms');
val _ = tracing ("t = " ^ Syntax.string_of_term lthy t);
val _ = tracing ("m = " ^ @{make_string} m);
val _ = tracing ("collapse = " ^ @{make_string} collapse);
val _ = tracing ("maybe_disc_thm = " ^ @{make_string} maybe_disc_thm);
val _ = tracing ("sel_thms = " ^ @{make_string} sel_thms);
            in
              mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms
              |> K |> Goal.prove lthy [] [] t
              |> single
          end;

        (* TODO: Reorganize so that the list looks like elsewhere in the BNF code.
           This is important because we continually add new theorems, change attributes, etc., and
           need to have a clear overview (and keep the documentation in sync). Also, it's confusing
           that some variables called '_thmss' are actually pairs. *)
        (* disc_thmss / sel_thmss: per function, (fun_name, [(disc or sel, thm)]); the notes
           register the theorems as fun_name.disc / fun_name.sel with simp_attrs *)
        val (disc_notes, disc_thmss) =
          fun_names ~~ map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss
          |> `(map (fn (fun_name, thms) =>
            ((Binding.qualify true fun_name (@{binding disc}), simp_attrs), [(map snd thms, [])])));
        val (sel_notes, sel_thmss) =
          fun_names ~~ map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss
          |> `(map (fn (fun_name, thms) =>
            ((Binding.qualify true fun_name (@{binding sel}), simp_attrs), [(map snd thms, [])])));

        val ctr_thmss = map5 (maps oooo prove_ctr) disc_thmss sel_thmss disc_eqnss sel_eqnss
          (map #ctr_specs corec_specs);
        (* only the ctr theorems marked safe are also noted with simp_attrs; as safess is
           currently all false, none are *)
        val safess = map (map (K false)) ctr_thmss; (* FIXME: "true" for non-corecursive theorems *)
        val safe_ctr_thmss =
          map2 (map_filter (fn (safe, thm) => if safe then SOME thm else NONE) oo curry (op ~~))
            safess ctr_thmss;

        (* ctr theorems are noted as fun_name.ctr without attributes *)
        val ctr_notes =
          fun_names ~~ ctr_thmss
          |> map (fn (fun_name, thms) =>
            ((Binding.qualify true fun_name (@{binding ctr}), []), [(thms, [])]));

        val anonymous_notes =
          [(flat safe_ctr_thmss, simp_attrs)]
          |> map (fn (thms, attrs) => ((Binding.empty, attrs), [(thms, [])]));
      in
        (* per-function notes are only registered when they carry at least one theorem *)
        lthy |> snd o Local_Theory.notes (anonymous_notes @
          filter (not o null o fst o the_single o snd) (disc_notes @ sel_notes @ ctr_notes))
      end;
  in
    (* state the remaining obligations as goals; after they are proved, the definitions are
       made with Local_Theory.define and prove is run on the resulting local theory *)
    lthy'
    |> Proof.theorem NONE (curry (op #->) (fold_map Local_Theory.define defs) o prove) obligationss
    |> Proof.refine (Method.primitive_text I)
    |> Seq.hd
  end
   868 
   869 fun add_primcorecursive_cmd seq (raw_fixes, raw_specs) lthy =
   870   let
   871     val (fixes, specs) = fst (Specification.read_spec raw_fixes raw_specs lthy);
   872   in
   873     add_primcorec seq fixes specs lthy
   874     handle ERROR str => primrec_error str
   875   end
   876   handle Primrec_Error (str, eqns) =>
   877     if null eqns
   878     then error ("primcorec error:\n  " ^ str)
   879     else error ("primcorec error:\n  " ^ str ^ "\nin\n  " ^
   880       space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns))
   881 
   882 end;