src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
author panny
Thu Sep 05 01:58:48 2013 +0200 (2013-09-05)
changeset 53411 ab4edf89992f
parent 53401 2101a97e6220
child 53654 8b9ea4420f81
permissions -rw-r--r--
support indirect corecursion
blanchet@53303
     1
(*  Title:      HOL/BNF/Tools/bnf_fp_rec_sugar.ML
blanchet@53303
     2
    Author:     Lorenz Panny, TU Muenchen
blanchet@53303
     3
    Copyright   2013
blanchet@53303
     4
blanchet@53303
     5
Recursor and corecursor sugar.
blanchet@53303
     6
*)
blanchet@53303
     7
blanchet@53303
     8
(* Exported interface: the two command-level entry points.  add_primrec_cmd maps a
   local theory to a local theory; add_primcorec_cmd takes an additional boolean flag
   and yields a proof state. *)
signature BNF_FP_REC_SUGAR =
sig
  val add_primrec_cmd:
    (binding * string option * mixfix) list ->
    (Attrib.binding * string) list ->
    local_theory -> local_theory;
  val add_primcorec_cmd:
    bool ->
    (binding * string option * mixfix) list * (Attrib.binding * string) list ->
    Proof.context -> Proof.state
end;
blanchet@53303
    16
blanchet@53303
    17
structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR =
blanchet@53303
    18
struct
blanchet@53303
    19
blanchet@53303
    20
open BNF_Util
blanchet@53303
    21
open BNF_FP_Util
blanchet@53303
    22
open BNF_FP_Rec_Sugar_Util
blanchet@53303
    23
open BNF_FP_Rec_Sugar_Tactics
blanchet@53303
    24
blanchet@53303
    25
(* Internal error carrying a message and the (possibly empty) list of offending terms. *)
exception Primrec_Error of string * term list;

(* Raise Primrec_Error with an explicit list of terms. *)
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);

(* Raise Primrec_Error without any associated term. *)
fun primrec_error str = primrec_error_eqns str [];

(* Raise Primrec_Error for a single offending term. *)
fun primrec_error_eqn str eqn = primrec_error_eqns str [eqn];
blanchet@53303
    30
panny@53358
    31
(* finds eq xs ys: for each x in xs (left to right), split off those elements y of the
   remaining ys with eq (x, y).  Returns the list of (x, matches) together with the
   elements of ys that matched no x. *)
fun finds eq = fold_map (fn x => fn ys =>
  let val (matching, others) = List.partition (fn y => eq (x, y)) ys
  in ((x, matching), others) end);
panny@53358
    32
panny@53357
    33
(* Name of a free variable, NONE for any other term. *)
fun free_name (Free (v, _)) = SOME v
  | free_name _ = NONE;

(* Name of a constant, NONE for any other term. *)
fun const_name (Const (c, _)) = SOME c
  | const_name _ = NONE;

(* The constant "undefined" with a dummy type. *)
val undef_const = Const (@{const_name undefined}, dummyT);
panny@53357
    36
panny@53358
    37
(* Applies t to the bound variables 0, n, n - 1, ..., 1 (in that order) and wraps the
   result in n + 1 dummy-typed abstractions. *)
fun permute_args n t =
  let
    val applied = list_comb (t, map Bound (0 :: (n downto 1)));
    fun abstract _ body = Abs (Name.uu, dummyT, body);
  in
    fold abstract (0 upto n) applied
  end;
panny@53401
    39
(* Tupled lambda abstraction over the HOL tuple built from the given terms. *)
fun abs_tuple ts = HOLogic.tupled_lambda (HOLogic.mk_tuple ts);
blanchet@53303
    40
blanchet@53303
    41
val simp_attrs = @{attributes [simp]};
blanchet@53303
    42
blanchet@53303
    43
blanchet@53310
    44
blanchet@53310
    45
(* Primrec *)
blanchet@53310
    46
blanchet@53303
    47
(* The pieces of one primrec equation, as computed by dissect_eqn. *)
type eqn_data = {
  fun_name: string,       (*name of the free variable heading the left-hand side*)
  rec_type: typ,          (*body type of the constructor in the pattern*)
  ctr: term,              (*head of the single non-variable argument*)
  ctr_args: term list,    (*arguments of that constructor pattern*)
  left_args: term list,   (*free-variable arguments before the pattern*)
  right_args: term list,  (*free-variable arguments after the pattern*)
  res_type: typ,          (*types of left_args @ right_args ---> type of the rhs*)
  rhs_term: term,         (*right-hand side of the equation*)
  user_eqn: term          (*the equation as it was passed in*)
};
blanchet@53303
    58
blanchet@53303
    59
(* Splits one primrec equation eqn' into its eqn_data components.

   Outer "all" quantifiers are stripped (bound variables become frees), then the
   equation must have the form  f left_args (C ctr_args) right_args = rhs  where f is
   one of fun_names, all variables on the left are distinct frees and exactly one
   argument is a non-variable.  Any violation raises Primrec_Error with the offending
   equation attached. *)
fun dissect_eqn lthy fun_names eqn' =
  let
    val malformed_msg = "malformed function equation (expected \"lhs = rhs\")";

    val eqn =
      let
        val frees = rev (map Free (strip_qnt_vars @{const_name all} eqn'));
        val body = strip_qnt_body @{const_name all} eqn';
      in
        HOLogic.dest_Trueprop (subst_bounds (frees, body))
      end
      handle TERM _ => primrec_error_eqn malformed_msg eqn';
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ => primrec_error_eqn malformed_msg eqn';

    val (head, args) = strip_comb lhs;
    val fun_name =
      (case head of
        Free (name, _) => name
      | _ => primrec_error_eqn "malformed function equation (does not start with free)" eqn);

    (* leading frees / the non-variable middle part / trailing frees *)
    val (left_args, rest) = take_prefix is_Free args;
    val (nonfrees, right_args) = take_suffix is_Free rest;
    val _ =
      (case length nonfrees of
        1 => true
      | 0 => primrec_error_eqn "constructor pattern missing in left-hand side" eqn
      | _ => primrec_error_eqn "more than one non-variable argument in left-hand side" eqn);
    val _ = member (op =) fun_names fun_name orelse
      primrec_error_eqn "malformed function equation (does not start with function name)" eqn;

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      primrec_error_eqn "partially applied constructor in pattern" eqn;

    val lhs_vars = left_args @ ctr_args @ right_args;
    val _ =
      (case duplicates (op =) lhs_vars of
        [] => true
      | dup :: _ =>
          primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy dup ^
            "\" in left-hand side") eqn);
    val _ = forall is_Free ctr_args orelse
      primrec_error_eqn "non-primitive pattern in left-hand side" eqn;

    (* frees of the rhs that are neither lhs variables, nor defined functions, nor fixed *)
    val _ =
      let
        fun is_extra (x as Free (v, _)) =
              not (member (op =) lhs_vars x) andalso
              not (member (op =) fun_names v) andalso
              not (Variable.is_fixed lthy v)
          | is_extra _ = false;
        val extras = fold_aterms (fn x => if is_extra x then cons x else I) rhs [];
      in
        null extras orelse
        primrec_error_eqn ("extra variable(s) in right-hand side: " ^
          commas (map (Syntax.string_of_term lthy) extras)) eqn
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn'}
  end;
blanchet@53303
   108
panny@53401
   109
(* Returns a term rewriter working relative to the pair type pT = rec_type * res_type.

   The state d is SOME ~1 before any abstraction has been entered; the first
   abstraction entered in that state gets its variable retyped to pT and is tracked by
   its de Bruijn index from then on (NONE = nothing tracked).  Occurrences of the
   tracked variable are wrapped in "fst"; calls of a function known to get_ctr_pos are
   replaced by "snd" (either point-free at the outermost level, or applied to the
   tracked variable sitting at the constructor position).  Any other shape of such a
   call is rejected. *)
fun rewrite_map_arg get_ctr_pos rec_type res_type =
  let
    val pT = HOLogic.mk_prodT (rec_type, res_type);

    fun outermost d = (d = SOME ~1);
    fun lift NONE = NONE
      | lift (SOME k) = SOME (k + 1);

    (* constructor position of a recursive function head, ~1 for anything else *)
    fun ctr_pos_of head =
      (case free_name head of
        NONE => ~1
      | SOME name => the_default ~1 (try get_ctr_pos name));

    fun subst d (t as Bound i) = if d = SOME i then fst_const pT $ t else t
      | subst d (Abs (v, T, body)) =
          Abs (v, if outermost d then pT else T, subst (lift d) body)
      | subst d t =
          let
            val (head, args) = strip_comb t;
            val ctr_pos = ctr_pos_of head;
            (* only called once length args > ctr_pos is known *)
            fun tracked_at_ctr_pos () =
              (case (d, nth args ctr_pos) of
                (SOME k, Bound i) => k = i
              | _ => false);
          in
            if ctr_pos >= 0 then
              if outermost d andalso length args = ctr_pos then
                list_comb (permute_args ctr_pos (snd_const pT), args)
              else if length args > ctr_pos andalso tracked_at_ctr_pos () then
                list_comb (snd_const pT $ nth args ctr_pos,
                  map (subst d) (nth_drop ctr_pos args))
              else
                primrec_error_eqn ("recursive call not directly applied to constructor argument") t
            else if outermost d andalso const_name head = SOME @{const_name comp} then
              list_comb (map_types (K dummyT) head, map2 subst [NONE, d] args)
            else
              list_comb (head, map (subst (if outermost d then NONE else d)) args)
          end
  in
    subst (SOME ~1)
  end;
blanchet@53303
   137
panny@53401
   138
(* Replaces recursive calls in t.  direct_calls and indirect_calls are association
   lists from constructor arguments to their replacement terms.

   For an application g' $ y where y is a constructor argument:
   - if the head of g' is a function known to get_ctr_pos, the application becomes the
     direct replacement of y applied to the arguments already collected in g';
   - otherwise, if y has an indirect replacement, the call is handed to
     massage_indirect_rec_call;
   - otherwise it is left alone.
   If y is not a constructor argument, both sides are processed recursively.
   It is an error if the final result still contains a call (as judged by has_call). *)
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls t =
  let
    fun ctr_pos_of head =
      (case free_name head of
        NONE => ~1
      | SOME name => the_default ~1 (try get_ctr_pos name));

    fun subst bound_Ts (Abs (v, T, body)) = Abs (v, T, subst (T :: bound_Ts) body)
      | subst bound_Ts (app as fun_part $ arg) =
          let
            val direct = AList.lookup (op =) direct_calls arg;
            val indirect = AList.lookup (op =) indirect_calls arg;
            val (head, head_args) = strip_comb fun_part;
            val ctr_pos = ctr_pos_of head;
            val _ = ctr_pos < 0 orelse length head_args >= ctr_pos orelse
              primrec_error_eqn "too few arguments in recursive call" app;
          in
            if not (member (op =) ctr_args arg) then
              subst bound_Ts fun_part $ subst bound_Ts arg
            else if ctr_pos >= 0 then
              list_comb (the direct, head_args)
            else
              (case indirect of
                NONE => app
              | SOME arg' =>
                  let
                    (* does the function part itself contain the recursive call? *)
                    val whole = has_call fun_part;
                  in
                    (if whole then app else arg)
                    |> massage_indirect_rec_call lthy has_call
                      (rewrite_map_arg get_ctr_pos) bound_Ts arg arg'
                    |> (if whole then I else curry (op $) fun_part)
                  end)
          end
      | subst _ other = other;

    val result = subst [] t;
    val _ = has_call result andalso (* FIXME detect this case earlier *)
      primrec_error_eqn "recursive call not directly applied to constructor argument" t;
  in
    result
  end;
blanchet@53303
   168
panny@53358
   169
(* Builds the recursor argument for one constructor.  Without an equation the argument
   is "undefined"; otherwise it is the right-hand side of the equation with recursive
   calls replaced, abstracted over the recursor's argument variables followed by the
   equation's left and right arguments. *)
fun build_rec_arg lthy funs_data has_call ctr_spec maybe_eqn_data =
  (case maybe_eqn_data of
    NONE => undef_const
  | SOME eqn_data =>
      let
        val rhs = #rhs_term eqn_data;
        val ctr_args = #ctr_args eqn_data;

        val calls = #calls ctr_spec;
        (* a direct recursive call accounts for two argument positions, all others for one *)
        val n_args = fold (fn Direct_Rec _ => (fn n => n + 2) | _ => (fn n => n + 1)) calls 0;

        (* (constructor argument index, recursor argument index) pairs by kind of call *)
        fun select pick =
          map_filter (fn (idx, call) => Option.map (pair idx) (pick call)) (tag_list 0 calls);
        val no_calls' = select (fn No_Rec n => SOME n | Direct_Rec (n, _) => SOME n | _ => NONE);
        val direct_calls' = select (fn Direct_Rec (_, n) => SOME n | _ => NONE);
        val indirect_calls' = select (fn Indirect_Rec n => SOME n | _ => NONE);

        val rec_res_type_list = map (fn (x :: _) => (#rec_type x, #res_type x)) funs_data;

        (* replace each type argument T that is one of the recursive types by T * its
           result type, descending into all other type arguments *)
        fun make_indirect_type (Type (Tname, Ts)) =
              Type (Tname, map (fn T =>
                (case AList.lookup (op =) rec_res_type_list T of
                  SOME res_T => HOLogic.mk_prodT (T, res_T)
                | NONE => make_indirect_type T)) Ts)
          | make_indirect_type T = T;

        (* put the (possibly retyped) constructor argument at its recursor position *)
        fun place retype (ctr_arg_idx, arg_idx) =
          nth_map arg_idx (K (retype (nth ctr_args ctr_arg_idx)));

        val args = replicate n_args ("", dummyT)
          |> Term.rename_wrt_term rhs
          |> map Free
          |> fold (place I) no_calls'
          |> fold (place (map_types (K dummyT))) direct_calls' (* FIXME? *)
          |> fold (place (map_types make_indirect_type)) indirect_calls';

        val fun_name_ctr_pos_list =
          map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
        fun get_ctr_pos name = the_default ~1 (AList.lookup (op =) fun_name_ctr_pos_list name);

        fun to_terms (ctr_arg_idx, arg_idx) = (nth ctr_args ctr_arg_idx, nth args arg_idx);
        val direct_calls = map to_terms direct_calls';
        val indirect_calls = map to_terms indirect_calls';

        val abstractions = args @ #left_args eqn_data @ #right_args eqn_data;
      in
        rhs
        |> subst_rec_calls lthy get_ctr_pos has_call ctr_args direct_calls indirect_calls
        |> fold_rev lambda abstractions
      end);
blanchet@53303
   222
panny@53358
   223
(* Builds the definitions of the functions in terms of their recursors.

   funs_data holds the equations grouped by function; rec_specs holds one
   specification per recursor, the first n_funs of which belong to the functions being
   defined.  Result: one ((binding, mixfix), ((def binding, []), rhs)) per function.

   Raises Primrec_Error if an equation matches no constructor of its function's type,
   if a constructor has more than one equation, or if the equations of one function
   disagree on the position of the constructor pattern. *)
fun build_defs lthy bs mxs funs_data rec_specs has_call =
  let
    val n_funs = length funs_data;

    val ctr_spec_eqn_data_list' =
      (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
      |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
          (* report the offending equations themselves, as for "multiple equations" *)
          ##> (fn x => null x orelse
            primrec_error_eqns "excess equations in definition" (map #user_eqn x)) #> fst);
    val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
      primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));

    val ctr_spec_eqn_data_list =
      ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    (* Checked before the recursor arguments are built: build_rec_arg takes the
       constructor position from the first equation of each function only, so
       inconsistent positions must be rejected first to get this specific message. *)
    val ctr_poss = map (fn x =>
      if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
        primrec_error ("inconstant constructor pattern position for function " ^
          quote (#fun_name (hd x)))
      else
        hd x |> #left_args |> length) funs_data;

    val recs = take n_funs rec_specs |> map #recx;
    val rec_args = ctr_spec_eqn_data_list
      |> sort ((op <) o pairself (#offset o fst) |> make_ord)
      |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));
  in
    (recs, ctr_poss)
    |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t))) bs mxs
  end;
blanchet@53303
   254
panny@53358
   255
(* For each constructor argument of the equation, collects the subterms of the
   right-hand side that are applied to it and contain a call (as judged by has_call).
   Returns NONE if the constructor has no arguments, otherwise SOME (ctr, callss) with
   one list of calls per constructor argument. *)
fun find_rec_calls has_call eqn_data =
  let
    fun calls_on ctr_arg =
      let
        fun search (Abs (_, _, body)) = search body
          | search (t as _ $ _) =
              let
                val (head, all_args) = strip_comb t;
                val pos = find_index (equal ctr_arg) all_args;
              in
                if pos >= 0 then
                  let
                    (* the function part applied to everything left of ctr_arg *)
                    val (prefix, rest) = chop pos all_args;
                    val applied = list_comb (head, prefix);
                  in
                    (if has_call applied then [applied] else search applied) @ maps search rest
                  end
                else
                  search head @ maps search all_args
              end
          | search _ = [];
      in
        search
      end;

    val callss = map (fn ctr_arg => calls_on ctr_arg (#rhs_term eqn_data)) (#ctr_args eqn_data);
  in
    if null callss then NONE else SOME (#ctr eqn_data, callss)
  end;
blanchet@53303
   278
blanchet@53303
   279
(* Core of the primrec definition: dissects the equations, obtains the recursor
   specifications via rec_specs_of, defines the functions and then proves and notes
   the user equations as "simps" (plus induction rules where applicable). *)
fun add_primrec fixes specs lthy =
  let
    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    (* the functions being defined, as free variables *)
    val fun_frees = map (fn ((b, T), _) => Free (Binding.name_of b, T)) fixes;

    val eqns_data = map (fn (_, eqn) => dissect_eqn lthy fun_names eqn) specs;

    (* equations grouped by function, in the order of fun_names *)
    val funs_data =
      let
        val groups = partition_eq (fn (x, y) => #fun_name x = #fun_name y) eqns_data;
        val (found, _) = finds (fn (name, group) => name = #fun_name (hd group)) fun_names groups;
      in
        map (fn (name, matches) => the_single matches handle List.Empty =>
          primrec_error ("missing equations for function " ^ quote name)) found
      end;

    val has_call = exists_subterm (member (op =) fun_frees);
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss = funs_data
      |> map (partition_eq ((op =) o pairself #ctr))
      |> map (maps (map_filter (find_rec_calls has_call)));

    (* indices of the functions occurring in t *)
    fun get_indices t =
      map_filter I
        (map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE) fun_frees);

    val ((nontriv, rec_specs, _, induct_thm, induct_thms), lthy') =
      rec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;

    val actual_nn = length funs_data;

    val _ =
      let val ctrs = maps (map #ctr o #ctr_specs) rec_specs in
        map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse
          primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
            " is not a constructor in left-hand side") user_eqn) eqns_data
      end;

    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;

    (* proves the user equations of one function and notes them under its name *)
    fun prove def_thms' {ctr_specs, nested_map_idents, nested_map_comps, ...} induct_thm fun_data
        lthy =
      let
        val fun_name = #fun_name (hd fun_data);
        val def_thms = map (snd o snd) def_thms';

        (* only equations matching exactly one constructor specification are proved *)
        val (eqn_ctr_specs, _) = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs;
        val goals = eqn_ctr_specs
          |> map_filter (fn (x, [y]) =>
                SOME (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)
              | _ => NONE);
        val simp_thms = goals
          |> map (fn (user_eqn, num_extra_args, rec_thm) =>
            let
              val tac = mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps
                def_thms rec_thm;
            in
              Goal.prove lthy [] [] user_eqn (K tac)
            end);

        val notes =
          [(inductN, if actual_nn > 1 then [induct_thm] else [], []),
           (simpsN, simp_thms, simp_attrs)]
          |> filter_out (null o #2)
          |> map (fn (thmN, thms, attrs) =>
            ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])]));
      in
        Local_Theory.notes notes lthy
      end;

    val common_name = mk_common_name fun_names;

    val common_notes =
      [(inductN, if nontriv andalso actual_nn > 1 then [induct_thm] else [], [])]
      |> filter_out (null o #2)
      |> map (fn (thmN, thms, attrs) =>
        ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])]));
  in
    lthy'
    |> fold_map Local_Theory.define defs
    |-> (fn def_thms' => fold_map3 (prove def_thms') (take actual_nn rec_specs)
      (take actual_nn induct_thms) funs_data)
    |> snd
    |> Local_Theory.notes common_notes |> snd
  end;
blanchet@53303
   350
blanchet@53303
   351
(* Command-level entry point: rejects duplicate function names, reads the raw
   specification and runs add_primrec.  An ERROR raised by add_primrec and every
   Primrec_Error are reported as a "primrec_new error", listing the offending terms
   if there are any. *)
fun add_primrec_cmd raw_fixes raw_specs lthy =
  let
    val names = map (Binding.name_of o #1) raw_fixes;
    val _ =
      (case duplicates (op =) names of
        [] => ()
      | dups => primrec_error ("duplicate function name(s): " ^ commas dups));
    val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
  in
    add_primrec fixes specs lthy
      handle ERROR msg => primrec_error msg
  end
  handle Primrec_Error (msg, eqns) =>
    let
      val header = "primrec_new error:\n  " ^ msg;
      fun show eqn = quote (Syntax.string_of_term lthy eqn);
    in
      error
        (if null eqns then header
         else header ^ "\nin\n  " ^ space_implode "\n  " (map show eqns))
    end
blanchet@53303
   365
blanchet@53303
   366
blanchet@53303
   367
blanchet@53310
   368
(* Primcorec *)
blanchet@53303
   369
panny@53341
   370
(* A discriminator equation, as computed by co_dissect_eqn_disc. *)
type co_eqn_data_disc = {
  fun_name: string,     (*name of the function the equation belongs to*)
  fun_args: term list,  (*arguments the function is applied to in the equation*)
  ctr_no: int, (*###*)  (*index of the constructor among the function's ctr_specs*)
  cond: term,           (*condition under which the equation applies*)
  user_eqn: term        (*the equation as it was passed in*)
};

(* A selector equation, as computed by co_dissect_eqn_sel. *)
type co_eqn_data_sel = {
  fun_name: string,     (*name of the function the equation belongs to*)
  fun_args: term list,  (*arguments the function is applied to in the equation*)
  ctr: term,            (*constructor whose specification lists the selector*)
  sel: term,            (*the selector heading the left-hand side*)
  rhs_term: term,       (*right-hand side of the equation*)
  user_eqn: term        (*the equation as it was passed in*)
};

datatype co_eqn_data =
  Disc of co_eqn_data_disc |
  Sel of co_eqn_data_sel;
blanchet@53303
   388
panny@53401
   389
(* Dissects a discriminator equation with premises imp_lhs' and conclusion imp_rhs.
   The conclusion is either a (possibly negated) discriminator applied to a function
   call, or a (possibly negated) equation between a function call and a constructor.

   matched_conds accumulates (function name, condition) pairs of the equations
   dissected so far; it is used to build the condition of a catch-all equation and,
   in sequential mode, to exclude earlier cases.  Returns the Disc data together with
   the extended matched_conds.

   Malformed conclusions are reported as Primrec_Error. *)
fun co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
  let
    fun find_subterm p = let (* FIXME \<exists>? *)
      fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v)
        | f t = if p t then SOME t else NONE
      in f end;

    (* Neither shape matching makes "the" fail; this must be reported like an unknown
       function name below instead of escaping as a raw Option exception. *)
    val fun_name = imp_rhs
      |> perhaps (try HOLogic.dest_not)
      |> `(try (fst o dest_Free o head_of o snd o dest_comb))
      ||> (try (fst o dest_Free o head_of o fst o HOLogic.dest_eq))
      |> the o merge_options
      handle Option.Option => primrec_error_eqn "malformed discriminator equation" imp_rhs;
    val corec_spec = the (AList.lookup (op =) fun_name_corec_spec_list fun_name)
      handle Option.Option => primrec_error_eqn "malformed discriminator equation" imp_rhs;

    val discs = #ctr_specs corec_spec |> map #disc;
    val ctrs = #ctr_specs corec_spec |> map #ctr;
    val not_disc = head_of imp_rhs = @{term Not};
    val _ = not_disc andalso length ctrs <> 2 andalso
      primrec_error_eqn "\<not>ed discriminator for a type with \<noteq> 2 constructors" imp_rhs;
    val disc = find_subterm (member (op =) discs o head_of) imp_rhs;
    val eq_ctr0 = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (HOLogic.dest_eq #> snd)
        |> (fn SOME t => let val n = find_index (equal t) ctrs in
          if n >= 0 then SOME n else NONE end | _ => NONE);
    val _ = is_some disc orelse is_some eq_ctr0 orelse
      primrec_error_eqn "no discriminator in equation" imp_rhs;
    val ctr_no' =
      if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
    (* a negated discriminator selects the other one of the two constructors *)
    val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
    val fun_args =
      if is_none disc then
        imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd
      else
        (the disc |> the_single o snd o strip_comb
          |> (fn t => if free_name (head_of t) = SOME fun_name
            then snd (strip_comb t) else [])
          (* discriminator not applied to exactly one argument *)
          handle List.Empty => primrec_error_eqn "malformed discriminator equation" imp_rhs);

    val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
    val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
    val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
    val matched_cond = filter (equal fun_name o fst) matched_conds |> map snd |> mk_disjs;
    (* the premises with the function arguments replaced by bound variables *)
    val imp_lhs = mk_conjs imp_lhs'
      |> incr_boundvars (length fun_args)
      |> subst_atomic (fun_args ~~ map Bound (length fun_args - 1 downto 0));
    val cond =
      if catch_all then
        matched_cond |> HOLogic.mk_not
      else if sequential then
        HOLogic.mk_conj (HOLogic.mk_not matched_cond, imp_lhs)
      else
        imp_lhs;

    val matched_conds' =
      (fun_name, if catch_all orelse not sequential then cond else imp_lhs) :: matched_conds;
  in
    (Disc {
      fun_name = fun_name,
      fun_args = fun_args,
      ctr_no = ctr_no,
      cond = cond,
      user_eqn = eqn'
    }, matched_conds')
  end;
blanchet@53303
   450
blanchet@53303
   451
(* Dissects a selector equation  sel (f args) = rhs.  The function f must have an
   entry in fun_name_corec_spec_list; the constructor is the first one in that entry
   whose specification lists the selector. *)
fun co_dissect_eqn_sel fun_name_corec_spec_list eqn' eqn =
  let
    val bad_sel_arg = "malformed selector argument in left-hand side";

    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ =>
        primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn;
    val lhs_head = head_of lhs;

    (* the selector's argument must be an application of a free variable *)
    val (fun_name, fun_args) =
      let val (f, args) = strip_comb (snd (dest_comb lhs))
      in (fst (dest_Free f), args) end
      handle TERM _ => primrec_error_eqn bad_sel_arg eqn;
    val corec_spec =
      (case AList.lookup (op =) fun_name_corec_spec_list fun_name of
        SOME spec => spec
      | NONE => primrec_error_eqn bad_sel_arg eqn);

    val ctr_specs = #ctr_specs corec_spec;
    val (ctr_no, sel) =
      the (get_index (fn spec => find_first (equal lhs_head) (#sels spec)) ctr_specs);
    val ctr_spec = nth ctr_specs ctr_no;
  in
    Sel {
      fun_name = fun_name,
      fun_args = fun_args,
      ctr = #ctr ctr_spec,
      sel = sel,
      rhs_term = rhs,
      user_eqn = eqn'
    }
  end;
blanchet@53303
   475
panny@53401
   476
(* Dissects a constructor equation  f args = C ctr_args  by reducing it to a
   discriminator equation for C (omitted if the type has a single constructor) and
   one selector equation per constructor argument.  Returns the resulting equation
   data together with the updated matched_conds. *)
fun co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds =
  let
    (* The caller looks through a negation when computing the head, so imp_rhs need not
       be an equation here (e.g. a negated constructor equation); report that instead
       of letting a raw TERM exception escape. *)
    val (lhs, rhs) = HOLogic.dest_eq imp_rhs
      handle TERM _ =>
        primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" imp_rhs;
    val fun_name = head_of lhs |> fst o dest_Free;
    val corec_spec = the (AList.lookup (op =) fun_name_corec_spec_list fun_name);
    val (ctr, ctr_args) = strip_comb rhs;
    val ctr_spec = the (find_first (equal ctr o #ctr) (#ctr_specs corec_spec))
      handle Option.Option => primrec_error_eqn "not a constructor" ctr;

    val disc_imp_rhs = betapply (#disc ctr_spec, lhs);
    val (maybe_eqn_data_disc, matched_conds') = if length (#ctr_specs corec_spec) = 1
      then (NONE, matched_conds)
      else apfst SOME (co_dissect_eqn_disc
          sequential fun_name_corec_spec_list eqn' imp_lhs' disc_imp_rhs matched_conds);

    val sel_imp_rhss = (#sels ctr_spec ~~ ctr_args)
      |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg));

    (* NOTE(review): debugging output; terms are printed in the static @{context},
       not in the context of the definition. *)
    val _ = tracing ("reduced\n    " ^ Syntax.string_of_term @{context} imp_rhs ^ "\nto\n    \<cdot> " ^
      (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_imp_rhs ^ "\n    \<cdot> ")) "" ^
      space_implode "\n    \<cdot> " (map (Syntax.string_of_term @{context}) sel_imp_rhss));

    val eqns_data_sel =
      map (co_dissect_eqn_sel fun_name_corec_spec_list eqn') sel_imp_rhss;
  in
    (map_filter I [maybe_eqn_data_disc] @ eqns_data_sel, matched_conds')
  end;
blanchet@53303
   503
panny@53401
   504
(* Dissects one primcorec equation and dispatches on the shape of its conclusion:
   - headed by a discriminator, or an equation whose right-hand side is a nullary
     constructor: discriminator equation;
   - headed by a selector: selector equation;
   - headed by one of the functions being defined: constructor equation.
   Returns the list of resulting equation data and the updated matched_conds.
   Anything else is reported as a malformed function equation. *)
fun co_dissect_eqn sequential fun_name_corec_spec_list eqn' matched_conds =
  let
    val eqn = subst_bounds (strip_qnt_vars @{const_name all} eqn' |> map Free |> rev,
        strip_qnt_body @{const_name all} eqn')
        handle TERM _ => primrec_error_eqn "malformed function equation" eqn';
    (* dest_Trueprop fails on premises or conclusions that are not HOL propositions
       (e.g. a meta-equality); report this like the other malformed cases *)
    val (imp_lhs', imp_rhs) = Logic.strip_horn eqn
      |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop
      handle TERM _ => primrec_error_eqn "malformed function equation" eqn';

    (* head of the conclusion, looking through a negation and the lhs of an equation *)
    val head = imp_rhs
      |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq))
      |> head_of;

    val maybe_rhs = imp_rhs |> perhaps (try (HOLogic.dest_not)) |> try (snd o HOLogic.dest_eq);

    val fun_names = map fst fun_name_corec_spec_list;
    val all_ctr_specs = maps (#ctr_specs o snd) fun_name_corec_spec_list;
    val discs = map #disc all_ctr_specs;
    val sels = maps #sels all_ctr_specs;
    val ctrs = map #ctr all_ctr_specs;

    fun is_nullary_ctr t =
      member (op =) (filter (null o binder_types o fastype_of) ctrs) t;
    val rhs_is_nullary_ctr =
      (case maybe_rhs of
        SOME t => is_nullary_ctr t
      | NONE => false);
  in
    if member (op =) discs head orelse rhs_is_nullary_ctr then
      co_dissect_eqn_disc sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
      |>> single
    else if member (op =) sels head then
      ([co_dissect_eqn_sel fun_name_corec_spec_list eqn' imp_rhs], matched_conds)
    else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) then
      co_dissect_eqn_ctr sequential fun_name_corec_spec_list eqn' imp_lhs' imp_rhs matched_conds
    else
      primrec_error_eqn "malformed function equation" eqn
  end;
blanchet@53303
   535
panny@53341
   536
(* Fills the discriminator-predicate slots of a list of corecursor arguments.

   disc_eqns are the discriminator equations of one function, ctr_specs the
   constructor specs belonging to it; the result is a transformer on the argument
   list. Without discriminator equations the list is returned unchanged.
   Otherwise a list of conditions is computed:
     - a single constructor needs no condition;
     - one equation per constructor: the conditions of all equations but the last;
     - exactly one constructor lacks an equation: if it is the last constructor
       the given conditions are used as they are, otherwise its condition is the
       conjunction of the negations of all given conditions and is inserted at
       its position in place of the last given condition;
     - in all other cases a condition is looked up per constructor (undef_const if
       there is none) and the one for the last constructor is dropped.
   Each condition is instantiated with the function arguments, abstracted over
   their tuple, and written to the position named by a #pred field of ctr_specs. *)
fun build_corec_args_discs disc_eqns ctr_specs =
  if null disc_eqns then I else
    let
      val n_ctrs = length ctr_specs;
      val n_eqns = length disc_eqns;
      val conds = map #cond disc_eqns;
      val fun_args = #fun_args (hd disc_eqns);

      fun has_eqn idx = exists (equal idx o #ctr_no) disc_eqns;
      fun cond_of_ctr idx =
        (case find_first (equal idx o #ctr_no) disc_eqns of
          SOME disc_eqn => #cond disc_eqn
        | NONE => undef_const);

      val args =
        if n_ctrs = 1 then []
        else if n_eqns = n_ctrs then fst (split_last conds)
        else if n_eqns = n_ctrs - 1 then
          let
            (* index of the one constructor that has no discriminator equation *)
            val missing = the (find_first (not o has_eqn) (0 upto n_ctrs - 1));
          in
            if missing = n_ctrs - 1 then conds
            else
              let
                val (init_conds, last_cond) = split_last conds;
                val missing_cond =
                  fold (fn c => fn acc => HOLogic.mk_conj (HOLogic.mk_not c, acc))
                    init_conds (HOLogic.mk_not last_cond);
                val (front, back) = chop missing init_conds;
              in
                front @ (missing_cond :: back)
              end
          end
        else fst (split_last (map cond_of_ctr (0 upto n_ctrs - 1)));

      (* closes a condition over the tupled function arguments *)
      fun abstract_cond t =
        HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args) (subst_bounds (List.rev fun_args, t));
    in
      (* FIXME deal with #preds above *)
      fold2 (fn idx => fn t => nth_map idx (K (abstract_cond t)))
        (map_filter #pred ctr_specs) args
    end;
blanchet@53303
   573
panny@53360
   574
(* Corecursor argument for a selector whose argument involves no corecursive call.

   Looks up the selector equation for sel in sel_eqns and abstracts its right-hand
   side over the tupled function arguments; without such an equation the body is
   undef_const abstracted over no arguments. The result is a constant function,
   i.e. whatever stood at the argument position before is discarded. *)
fun build_corec_arg_no_call sel_eqns sel =
  let
    val (fun_args, rhs_term) =
      (case find_first (equal sel o #sel) sel_eqns of
        SOME sel_eqn => (#fun_args sel_eqn, #rhs_term sel_eqn)
      | NONE => ([], undef_const));
  in
    K (abs_tuple fun_args rhs_term)
  end;
panny@53360
   579
panny@53360
   580
(* Corecursor argument builder for a selector with a direct corecursive call.

   The result takes a flag is_end and the term currently stored at the argument
   position. Without a selector equation for sel that term is kept. Otherwise
   the equation's right-hand side is run through massage_direct_corec_call with
   the rewriting function below, targeting the range type of the current term,
   and abstracted over the tupled function arguments. *)
fun build_corec_arg_direct_call lthy has_call sel_eqns sel =
  let
    (* Rewriting of one leaf t; the third argument is not used. *)
    fun rewrite is_end U _ t =
      if U = @{typ bool} then
        (if has_call t then @{term False} else @{term True}) (* stop? *)
      else if is_end = has_call t then undef_const
      else if is_end then t (* end *)
      else HOLogic.mk_tuple (snd (strip_comb t)); (* continue *)

    fun massage rhs_term is_end old_arg =
      massage_direct_corec_call lthy has_call (rewrite is_end) []
        (range_type (fastype_of old_arg)) rhs_term;
  in
    (case find_first (equal sel o #sel) sel_eqns of
      NONE => K I
    | SOME sel_eqn => abs_tuple (#fun_args sel_eqn) oo massage (#rhs_term sel_eqn))
  end;
panny@53360
   594
panny@53411
   595
(* Corecursor argument builder for a selector with an indirect corecursive call.

   The result maps the term currently stored at the argument position to the new
   argument. Without a selector equation for sel the term is kept. Otherwise the
   equation's right-hand side is run through massage_indirect_corec_call,
   targeting the range type of the current term, and abstracted over the tupled
   function arguments. The rewriting function handed to the massaging
     - replaces an application whose head is a Free satisfying has_call by Inr
       applied to the arguments, combined into right-nested Pairs;
     - erases the types of a prod_case head and rewrites its arguments;
     - otherwise descends into abstractions and applications.
   All constants introduced here carry dummyT. *)
fun build_corec_arg_indirect_call lthy has_call sel_eqns sel =
  let
    fun mk_pair (x, y) = Const (@{const_name Pair}, dummyT) $ x $ y;

    fun subst (Abs (v, T, body)) = Abs (v, T, subst body)
      | subst (t as _ $ _) =
          let
            val (head, args) = strip_comb t;
            val is_prod_case = try (fst o dest_Const) head = SOME @{const_name prod_case};
          in
            if is_Free head andalso has_call head then
              Const (@{const_name Inr}, dummyT) $
                the_default (hd args) (try (foldr1 mk_pair) args)
            else if is_prod_case then
              list_comb (map_types (K dummyT) head, map subst args)
            else
              list_comb (subst head, map subst args)
          end
      | subst t = t;

    (* the two leading arguments supplied by the massaging are ignored *)
    fun rewrite _ _ = subst;

    fun massage rhs_term old_arg =
      massage_indirect_corec_call lthy has_call rewrite []
        (range_type (fastype_of old_arg)) rhs_term;
  in
    (case find_first (equal sel o #sel) sel_eqns of
      NONE => I
    | SOME sel_eqn => abs_tuple (#fun_args sel_eqn) o massage (#rhs_term sel_eqn))
  end;
panny@53360
   622
panny@53360
   623
(* Fills the selector-related slots of a list of corecursor arguments for one
   constructor.

   Only the selector equations in all_sel_eqns that mention the constructor of
   ctr_spec are considered; if there are none, the argument list is unchanged.
   Each selector of ctr_spec is paired with its entry in #calls, which names the
   position(s) to update:
     - No_Corec n: position n, via build_corec_arg_no_call;
     - Direct_Corec (q, g, h): positions q and g with the flag true, then h with
       the flag false, via build_corec_arg_direct_call;
     - Indirect_Corec n: position n, via build_corec_arg_indirect_call.
   The three groups are applied in this order. *)
fun build_corec_args_sel lthy has_call all_sel_eqns ctr_spec =
  (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of
    [] => I
  | sel_eqns =>
      let
        val sel_calls = #sels ctr_spec ~~ #calls ctr_spec;

        val no_calls =
          map_filter (fn (sel, No_Corec n) => SOME (sel, n) | _ => NONE) sel_calls;
        val direct_calls =
          map_filter (fn (sel, Direct_Corec qgh) => SOME (sel, qgh) | _ => NONE) sel_calls;
        val indirect_calls =
          map_filter (fn (sel, Indirect_Corec n) => SOME (sel, n) | _ => NONE) sel_calls;

        fun fill_no_call (sel, n) =
          nth_map n (build_corec_arg_no_call sel_eqns sel);
        fun fill_direct_call (sel, (q, g, h)) =
          let
            val build = build_corec_arg_direct_call lthy has_call sel_eqns sel;
          in
            nth_map h (build false) o nth_map g (build true) o nth_map q (build true)
          end;
        fun fill_indirect_call (sel, n) =
          nth_map n (build_corec_arg_indirect_call lthy has_call sel_eqns sel);
      in
        fold fill_no_call no_calls
        #> fold fill_direct_call direct_calls
        #> fold fill_indirect_call indirect_calls
      end);
blanchet@53303
   647
panny@53360
   648
(* Builds the definitions and proof obligations for a group of primcorec functions.

   bs and mxs are the bindings and mixfixes of the functions, arg_Tss their
   argument types, fun_name_corec_spec_list pairs each function name with its
   corecursor spec, and eqns_data holds the dissected equations (Disc and Sel).
   Steps:
     1. group discriminator and selector equations per function, in the order of
        the names of bs; discriminator equations are sorted by constructor number;
     2. reject two discriminator equations for the same constructor;
     3. start from undefined for every corecursor argument of the first
        corecursor and fill in the slots from the equations;
     4. unless sequential is set, emit for every pair of distinct discriminator
        conditions x, y of a function the obligation that x implies not y;
     5. apply every corecursor to the arguments, adapt it to the argument types of
        its function, and type-check the result.
   Returns the list of ((binding, mixfix), ((definition name, []), rhs)) paired
   with the proof obligations. Two tracing messages are emitted on the way. *)
fun co_build_defs lthy sequential bs mxs has_call arg_Tss fun_name_corec_spec_list eqns_data =
  let
    val fun_names = map Binding.name_of bs;

    val disc_eqnss =
      eqns_data
      |> map_filter (fn Disc disc_eqn => SOME disc_eqn | _ => NONE)
      |> partition_eq (fn (eqn1, eqn2) => #fun_name eqn1 = #fun_name eqn2)
      |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> fst
      |> map (fn (_, groups) =>
          sort (make_ord (fn (eqn1, eqn2) => #ctr_no eqn1 < #ctr_no eqn2)) (flat groups));

    (* at most one discriminator equation per constructor *)
    val _ =
      List.app (fn disc_eqns =>
        (case duplicates (fn (eqn1, eqn2) => #ctr_no eqn1 = #ctr_no eqn2) disc_eqns of
          [] => ()
        | dups =>
            primrec_error_eqns "excess discriminator equations in definition"
              (dups
               |> maps (fn dup => filter (fn eqn => #ctr_no dup = #ctr_no eqn) disc_eqns)
               |> map #user_eqn))) disc_eqnss;

    val sel_eqnss =
      eqns_data
      |> map_filter (fn Sel sel_eqn => SOME sel_eqn | _ => NONE)
      |> partition_eq (fn (eqn1, eqn2) => #fun_name eqn1 = #fun_name eqn2)
      |> finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names
      |> fst
      |> map (fn (_, groups) => flat groups);

    val corecs = map (#corec o snd) fun_name_corec_spec_list;
    val ctr_specss = map (#ctr_specs o snd) fun_name_corec_spec_list;

    (* one "undefined" per argument of the first corecursor, except the last *)
    val undef_args =
      map (fn T => Const (@{const_name undefined}, T))
        (fst (split_last (binder_types (fastype_of (hd corecs)))));
    val corec_args =
      undef_args
      |> fold2 build_corec_args_discs disc_eqnss ctr_specss
      |> fold2 (fn sel_eqns => fold (build_corec_args_sel lthy has_call sel_eqns))
          sel_eqnss ctr_specss;

    (* For two or more argument types: applies t to the tuple of the bound
       arguments and abstracts over them one by one. *)
    fun currys Ts t =
      if length Ts <= 1 then t
      else
        let
          val bounds = map Bound (length Ts - 1 downto 0);
          val tuple = foldr1 (fn (u, v) => HOLogic.pair_const dummyT dummyT $ u $ v) bounds;
        in
          fold_rev (fn T => fn u => Abs (Name.uu, T, u)) Ts (t $ tuple)
        end;

    val _ = tracing ("corecursor arguments:\n    \<cdot> " ^
      space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) corec_args));

    (* all pairs (y, x) where x occurs before y in the list *)
    (* FIXME \<exists>? *)
    fun uneq_pairs_rev [] = []
      | uneq_pairs_rev xs =
          let
            val (ys, y) = split_last xs;
          in
            uneq_pairs_rev ys @ map (pair y) ys
          end;

    val proof_obligations =
      if sequential then []
      else
        disc_eqnss
        |> maps (fn disc_eqns =>
            uneq_pairs_rev (map (fn {fun_args, cond, ...} => (fun_args, cond)) disc_eqns))
        |> map (fn ((fun_args, x), (_, y)) =>
            let
              val mk_prop = HOLogic.mk_Trueprop o curry subst_bounds (List.rev fun_args);
            in
              list_comb (@{const ==>}, map mk_prop [x, HOLogic.mk_not y])
            end);

    val _ = tracing ("proof obligations:\n    \<cdot> " ^
      space_implode "\n    \<cdot> " (map (Syntax.string_of_term lthy) proof_obligations));

    (* a function without arguments gets the unit value as corecursor seed *)
    fun mk_rhs Ts corec =
      let
        val t = list_comb (corec, corec_args);
      in
        currys Ts (if null Ts then t $ HOLogic.unit else t)
      end;

    val rhss = Syntax.check_terms lthy (map2 mk_rhs arg_Tss corecs);
    val defs =
      map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.map_name Thm.def_name b, []), t)))
        bs mxs rhss;
  in
    (defs, proof_obligations)
  end;
blanchet@53303
   707
blanchet@53303
   708
(* Processes a primcorec specification that has already been read.

   fixes are the declared functions as ((binding, type), mixfix), specs the
   equations paired with their attribute bindings. The corecursor specs are
   obtained from corec_specs_of and matched to the functions by result type,
   every equation is dissected with co_dissect_eqn, and co_build_defs yields the
   definitions and proof obligations. The definitions are made in the local
   theory, and the result is the proof state of a goal consisting of the proof
   obligations. *)
fun add_primcorec sequential fixes specs lthy =
  let
    val bs = map (fst o fst) fixes;
    val mxs = map snd fixes;
    val fix_Ts = map (snd o fst) fixes;
    val (arg_Ts, res_Ts) =
      split_list (map (apfst HOLogic.mk_tupleT o strip_type) fix_Ts);

    (* the functions being defined, as free variables *)
    val fun_frees = map (fn ((b, T), _) => Free (Binding.name_of b, T)) fixes;

    (* copied from primrec_new *)
    fun get_indices t =
      fun_frees
      |> map_index (fn (i, v) => if exists_subterm (fn u => v = u) t then SOME i else NONE)
      |> map_filter I;

    val callssss = []; (* FIXME *)

    val ((_, corec_specs, _, _, _, _, _), lthy') =
      corec_specs_of bs arg_Ts res_Ts get_indices callssss lthy;

    val fun_names = map Binding.name_of bs;

    (* pair every function with the single corecursor spec of its result type *)
    val (matched_specs, _) =
      finds (fn ((_, T), {corec, ...}) => T = body_type (fastype_of corec))
        (fun_names ~~ res_Ts) corec_specs;
    val fun_name_corec_spec_list =
      map (fn ((fun_name, _), specs) => (fun_name, the_single specs)) matched_specs;

    val eqns_data =
      flat (fst (fold_map (co_dissect_eqn sequential fun_name_corec_spec_list)
        (map snd specs) []));

    val has_call = exists_subterm (member (op =) fun_frees);
    val arg_Tss = map binder_types fix_Ts;
    val (defs, proof_obligations) =
      co_build_defs lthy' sequential bs mxs has_call arg_Tss fun_name_corec_spec_list eqns_data;

    val lthy'' = snd (fold_map Local_Theory.define defs lthy');
    val goals = [map (fn t => (t, [])) proof_obligations];
  in
    Seq.hd (Proof.refine (Method.primitive_text I) (Proof.theorem NONE (K I) goals lthy''))
  end
blanchet@53303
   745
blanchet@53303
   746
(* Command-level entry point: reads the raw fixes and specifications in lthy and
   passes them to add_primcorec. An ERROR raised by add_primcorec is turned into a
   Primrec_Error; every Primrec_Error is reported through error with the prefix
   "primcorec error", followed by the quoted offending equations if there are
   any. *)
fun add_primcorec_cmd seq (raw_fixes, raw_specs) lthy =
  let
    val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy;
  in
    add_primcorec seq fixes specs lthy
      handle ERROR msg => primrec_error msg
  end
  handle Primrec_Error (msg, eqns) =>
    (case eqns of
      [] => error ("primcorec error:\n  " ^ msg)
    | _ =>
        error ("primcorec error:\n  " ^ msg ^ "\nin\n  " ^
          space_implode "\n  " (map (fn eqn => quote (Syntax.string_of_term lthy eqn)) eqns)))
blanchet@53303
   758
blanchet@53303
   759
end;