src/HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
author wenzelm
Sat Mar 22 18:19:57 2014 +0100 (2014-03-22)
changeset 56254 a2dd9200854d
parent 56121 52e8f110fec3
child 56638 092a306bcc3d
permissions -rw-r--r--
more antiquotations;
blanchet@55061
     1
(*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar.ML
blanchet@54246
     2
    Author:     Lorenz Panny, TU Muenchen
blanchet@54246
     3
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@54246
     4
    Copyright   2013
blanchet@54246
     5
blanchet@55538
     6
Recursor sugar ("primrec").
blanchet@54246
     7
*)
blanchet@54246
     8
blanchet@54246
     9
signature BNF_LFP_REC_SUGAR =
sig
  (* Options accepted by the "primrec" command front end (see add_primrec_cmd). *)
  datatype primrec_option = Nonexhaustive_Option

  (* Per-type information about a least-fixpoint datatype, as delivered by the registered
     extension's get_basic_lfp_sugars. *)
  type basic_lfp_sugar =
    {T: typ,
     fp_res_index: int,
     C: typ,
     fun_arg_Tsss: typ list list list,
     ctr_defs: thm list,
     ctr_sugar: Ctr_Sugar.ctr_sugar,
     recx: term,
     rec_thms: thm list};

  (* Hooks supplied by the datatype package; primrec reports "Functionality not loaded yet"
     when it needs them and none have been registered. *)
  type lfp_rec_extension =
    {nested_simps: thm list,
     is_new_datatype: Proof.context -> string -> bool,
     get_basic_lfp_sugars: binding list -> typ list -> term list ->
       (term * term list list) list list -> local_theory ->
       typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool
       * local_theory,
     rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
       term -> term -> term -> term};

  (* A primrec failure: an error message together with the offending terms (possibly none). *)
  exception PRIMREC of string * term list;

  val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory

  (* Front ends for defining primitive recursive functions. *)
  val add_primrec: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory
  val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory
  val add_primrec_global: (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  val add_primrec_overloaded: (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory
  val add_primrec_simple: ((binding * typ) * mixfix) list -> term list ->
    local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory
end;
blanchet@54246
    48
blanchet@54246
    49
structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
blanchet@54246
    50
struct
blanchet@54246
    51
blanchet@54246
    52
open Ctr_Sugar
blanchet@55575
    53
open Ctr_Sugar_Util
blanchet@55574
    54
open Ctr_Sugar_General_Tactics
blanchet@54246
    55
open BNF_FP_Rec_Sugar_Util
blanchet@54246
    56
blanchet@55575
    57
(* Base names of the facts noted by primrec. *)
val inductN = "induct";
val simpsN = "simps";

(* Attribute bundles; the last one is attached to every generated equation. *)
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val code_nitpicksimp_simp_attrs =
  Code.add_default_eqn_attrib :: (nitpicksimp_attrs @ simp_attrs);
blanchet@54246
    63
blanchet@55528
    64
(* Raised by prepare_primrec when some argument type is registered only with the old datatype
   package (presumably so that a caller can fall back to the old primrec -- confirm). *)
exception OLD_PRIMREC of unit;

(* A user-facing primrec failure: message and offending terms. *)
exception PRIMREC of string * term list;

datatype primrec_option = Nonexhaustive_Option;

(* How one constructor argument feeds the recursor (built by call_of in rec_specs_of):
   No_Rec     -- a single recursor argument whose type involves none of the result types,
   Nested_Rec -- a single recursor argument whose type mentions a result type,
   Mutual_Rec -- two recursor arguments (the argument itself and its recursive result). *)
datatype rec_call =
    No_Rec of int * typ
  | Mutual_Rec of (int * typ) * (int * typ)
  | Nested_Rec of int * typ;
blanchet@54246
    73
blanchet@54246
    74
(* Recursor information for a single constructor. *)
type rec_ctr_spec =
  {ctr: term,
   offset: int,            (* index of the constructor among all constructors; orders rec args *)
   calls: rec_call list,   (* one entry per constructor argument *)
   rec_thm: thm};

(* Recursor information for a single datatype. *)
type rec_spec =
  {recx: term,
   nested_map_idents: thm list,
   nested_map_comps: thm list,
   ctr_specs: rec_ctr_spec list};

(* Per-type data provided by the registered extension (see get_basic_lfp_sugars). *)
type basic_lfp_sugar =
  {T: typ,
   fp_res_index: int,                  (* used to sort the sugars *)
   C: typ,                             (* type variable for the recursor's result type *)
   fun_arg_Tsss : typ list list list,
   ctr_defs: thm list,
   ctr_sugar: ctr_sugar,
   recx: term,
   rec_thms: thm list};

(* Functionality plugged in by the datatype package via register_lfp_rec_extension. *)
type lfp_rec_extension =
  {nested_simps: thm list,
   is_new_datatype: Proof.context -> string -> bool,
   get_basic_lfp_sugars: binding list -> typ list -> term list ->
     (term * term list list) list list -> local_theory ->
     typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool
     * local_theory,
   rewrite_nested_rec_call: Proof.context -> (term -> bool) -> (string -> int) -> typ list ->
     term -> term -> term -> term};
blanchet@55571
   104
blanchet@55571
   105
(* Theory data slot holding the registered extension, if any. *)
structure Data = Theory_Data
(
  type T = lfp_rec_extension option;
  val empty = NONE;
  val extend = I;
  val merge = merge_options;
);

(* Install [ext] as the current extension, overwriting whatever was stored before. *)
fun register_lfp_rec_extension ext = Data.put (SOME ext);
blanchet@55571
   114
blanchet@55575
   115
(* Accessors for the registered extension. Without an extension, nested_simps yields no
   theorems and is_new_datatype answers false, whereas get_basic_lfp_sugars fails. *)
fun nested_simps ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    NONE => []
  | SOME {nested_simps = simps, ...} => simps);

fun is_new_datatype ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    NONE => K false
  | SOME {is_new_datatype = is_new, ...} => is_new ctxt);

fun get_basic_lfp_sugars bs arg_Ts callers callssss lthy =
  (case Data.get (Proof_Context.theory_of lthy) of
    NONE => error "Functionality not loaded yet"
  | SOME {get_basic_lfp_sugars = get_sugars, ...} => get_sugars bs arg_Ts callers callssss lthy);
blanchet@55571
   129
blanchet@55575
   130
(* Rewrite a nested recursive call using the registered extension. As in get_basic_lfp_sugars,
   a missing extension is reported with a proper ERROR (which add_primrec_simple' converts into
   a primrec error) instead of escaping as an uncaught Match exception. *)
fun rewrite_nested_rec_call ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {rewrite_nested_rec_call, ...} => rewrite_nested_rec_call ctxt
  | NONE => error "Functionality not loaded yet");
blanchet@54246
   133
blanchet@55772
   134
(* Compute the recursor specifications for the functions being defined. The extension supplies
   the underlying sugars (possibly after an n2m construction, flagged by [n2m]). These are sorted
   by fp_res_index, and per-type data is mapped back to the caller's order with unpermute.
   Returns ((n2m, specs, missing arg types, induct thm, split induct thms), lthy). *)
fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, nested_map_idents, nested_map_comps,
         induct_thm, n2m, lthy) =
      get_basic_lfp_sugars bs arg_Ts callers callssss0 lthy0;

    val perm_basic_lfp_sugars = sort (int_ord o pairself #fp_res_index) basic_lfp_sugars;
    val indices = map #fp_res_index basic_lfp_sugars;
    val perm_indices = map #fp_res_index perm_basic_lfp_sugars;

    val perm_ctrss = map (#ctrs o #ctr_sugar) perm_basic_lfp_sugars;
    val perm_Cs = map #C perm_basic_lfp_sugars;
    val perm_fun_arg_Tssss = map #fun_arg_Tsss perm_basic_lfp_sugars;

    val nn0 = length arg_Ts;
    val nn = length perm_ctrss;
    val kks = 0 upto nn - 1;

    (* number of constructors of all (sorted) types preceding type kk *)
    val perm_ctr_offsets = map (fn kk => Integer.sum (map length (take kk perm_ctrss))) kks;
    val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;

    fun unpermute0 perm0_xs = permute_like_unique (op =) perm0_kks kks perm0_xs;
    fun unpermute perm_xs = permute_like_unique (op =) perm_indices indices perm_xs;

    val induct_thms = unpermute0 (conj_dests nn induct_thm);

    val lfpTs = unpermute perm_lfpTs;
    val Cs = unpermute perm_Cs;
    val ctr_offsets = unpermute perm_ctr_offsets;

    (* instantiate the datatypes' type arguments and the result type variables *)
    val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    val substACT = substAT o substCT;

    val perm_Cs' = map substCT perm_Cs;

    fun call_of [i] [T] =
        let val call = (i, substACT T) in
          if exists_subtype_in Cs T then Nested_Rec call else No_Rec call
        end
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val arg_hss = fst (indexedd fun_arg_Tss 0);
        val arg_hs = flat_rec_arg_args arg_hss;
        val arg_iss = map (map (find_index_eq arg_hs)) arg_hss;
      in
        {ctr = substA ctr, offset = offset, calls = map2 call_of arg_iss fun_arg_Tss,
         rec_thm = rec_thm}
      end;

    fun mk_spec ctr_offset
        ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) =
      {recx = mk_co_rec thy Least_FP (substAT T) perm_Cs' recx,
       nested_map_idents = nested_map_idents,
       nested_map_comps = nested_map_comps,
       ctr_specs =
         map4 mk_ctr_spec ctrs (ctr_offset upto ctr_offset + length ctrs - 1)
           (nth perm_fun_arg_Tssss fp_res_index) rec_thms};

    val specs = map2 mk_spec ctr_offsets basic_lfp_sugars;
  in
    ((n2m, specs, missing_arg_Ts, induct_thm, induct_thms), lthy)
  end;
blanchet@54246
   204
blanchet@54246
   205
(* Recursor argument used for constructors that have no equation (see build_rec_arg). *)
val undef_const = Const (@{const_name undefined}, dummyT);

(* One user equation, as decomposed by dissect_eqn. *)
type eqn_data = {
  fun_name: string,       (* name of the function the equation defines *)
  rec_type: typ,          (* body type of the pattern's constructor *)
  ctr: term,              (* constructor of the pattern *)
  ctr_args: term list,    (* variables bound by the constructor pattern *)
  left_args: term list,   (* variable arguments before the pattern *)
  right_args: term list,  (* variable arguments after the pattern *)
  res_type: typ,          (* types of left_args @ right_args ---> type of the rhs *)
  rhs_term: term,         (* right-hand side *)
  user_eqn: term          (* the equation as supplied *)
};
blanchet@54246
   218
blanchet@54246
   219
(* Decompose one user equation "f x1 ... (C y1 ... yn) ... xk = rhs" into an eqn_data record.
   Raises PRIMREC if the equation does not have that primitive-recursive shape: a malformed
   equation, a head that is not one of [fun_names], zero or several non-variable arguments,
   a partially applied constructor, duplicate or non-variable pattern arguments, or free
   variables in the rhs that are neither bound on the lhs, nor defined functions, nor fixed
   in [lthy]. [eqn'] is the equation as given; drop_all is applied before destructing it. *)
fun dissect_eqn lthy fun_names eqn' =
  let
    val eqn = drop_all eqn' |> HOLogic.dest_Trueprop
      handle TERM _ =>
             raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
    val (lhs, rhs) = HOLogic.dest_eq eqn
        handle TERM _ =>
               raise PRIMREC ("malformed function equation (expected \"lhs = rhs\")", [eqn']);
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x)
          else raise PRIMREC ("malformed function equation (does not start with free)", [eqn]));
    val (left_args, rest) = take_prefix is_Free args;
    val (nonfrees, right_args) = take_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then
      raise PRIMREC ("constructor pattern missing in left-hand side", [eqn]) else
      raise PRIMREC ("more than one non-variable argument in left-hand side", [eqn]);
    val _ = member (op =) fun_names fun_name orelse
      raise PRIMREC ("malformed function equation (does not start with function name)", [eqn]);

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      raise PRIMREC ("partially applied constructor in pattern", [eqn]);

    val lhs_vars = left_args @ ctr_args @ right_args;
    val _ = let val d = duplicates (op =) lhs_vars in null d orelse
      raise PRIMREC ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
        "\" in left-hand side", [eqn]) end;
    val _ = forall is_Free ctr_args orelse
      raise PRIMREC ("non-primitive pattern in left-hand side", [eqn]);
    val _ =
      let
        (* "insert" rather than "cons": a variable occurring several times in the rhs is
           reported only once. *)
        val extras = fold_aterms (fn x as Free (v, _) =>
            if not (member (op =) lhs_vars x) andalso
               not (member (op =) fun_names v) andalso
               not (Variable.is_fixed lthy v)
            then insert (op =) x else I
          | _ => I) rhs [];
      in
        null extras orelse
        raise PRIMREC ("extra variable(s) in right-hand side: " ^
          commas (map (Syntax.string_of_term lthy) extras), [eqn])
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn'}
  end;
blanchet@54246
   268
blanchet@54246
   269
(* Replace the recursive calls in a right-hand side by recursor-argument variables.
   [mutual_calls] and [nested_calls] associate constructor arguments with their variables, and
   [get_ctr_pos] gives the position of the constructor argument of a function by name. Nested
   calls are delegated to the extension's rewrite_nested_rec_call. A call that survives the
   substitution is rejected with PRIMREC. *)
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    fun try_nested_rec bound_Ts y t =
      (case AList.lookup (op =) nested_calls y of
        NONE => NONE
      | SOME y' => SOME (rewrite_nested_rec_call lthy has_call get_ctr_pos bound_Ts y y' t));

    fun subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst bound_Ts (t as g' $ y) =
        let
          fun recurse () = subst bound_Ts g' $ subst bound_Ts y;
          val y_head = head_of y;

          (* "g ... y" where g is one of the functions being defined *)
          fun subst_mutual () =
            let val (g, g_args) = strip_comb g' in
              (case try (get_ctr_pos o fst o dest_Free) g of
                NONE => recurse ()
              | SOME ctr_pos =>
                (length g_args >= ctr_pos orelse
                 raise PRIMREC ("too few arguments in recursive call", [t]);
                 (case AList.lookup (op =) mutual_calls y of
                   SOME y' => list_comb (y', g_args)
                 | NONE => recurse ())))
            end;
        in
          if member (op =) ctr_args y_head then
            (case try_nested_rec bound_Ts y_head t of
              SOME t' => t'
            | NONE => subst_mutual ())
          else
            recurse ()
        end
      | subst _ t = t;

    fun finish t =
      if has_call t then
        (* FIXME detect this case earlier? *)
        raise PRIMREC ("recursive call not directly applied to constructor argument", [t])
      else
        the_default t (try_nested_rec [] (head_of t) t);
  in
    finish o subst []
  end;
blanchet@54246
   309
blanchet@54246
   310
(* Build the recursor argument for one constructor. It is the user's right-hand side with
   recursive calls replaced by variables, lambda-abstracted over one variable per recursor
   argument followed by the function's left and right arguments. A constructor without an
   equation yields undef_const. *)
fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (eqn_data_opt : eqn_data option) =
  (case eqn_data_opt of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;
      val tagged_calls = tag_list 0 calls;

      (* a mutual call contributes two recursor arguments, any other call one *)
      fun arity (Mutual_Rec _) = 2
        | arity _ = 1;
      val n_args = fold (fn call => fn n => arity call + n) calls 0;

      val no_calls' = tagged_calls |> map_filter
        (fn (k, No_Rec p) => SOME (k, p)
          | (k, Mutual_Rec (p, _)) => SOME (k, p)
          | _ => NONE);
      val mutual_calls' = tagged_calls |> map_filter
        (fn (k, Mutual_Rec (_, p)) => SOME (k, p) | _ => NONE);
      val nested_calls' = tagged_calls |> map_filter
        (fn (k, Nested_Rec p) => SOME (k, p) | _ => NONE);

      fun ensure_unique frees x =
        if member (op =) frees x then Free (the_single (Term.variant_frees x [dest_Free x]))
        else x;

      (* start from fresh variables, then reuse (retyped) constructor arguments where possible *)
      val args =
        replicate n_args ("", dummyT)
        |> Term.rename_wrt_term t
        |> map Free
        |> fold (fn (ci, (ai, _)) => nth_map ai (K (nth ctr_args ci))) no_calls'
        |> fold (fn (ci, (ai, T)) => fn xs =>
            nth_map ai (K (ensure_unique xs (retype_free T (nth ctr_args ci)))) xs)
          mutual_calls'
        |> fold (fn (ci, (ai, T)) => nth_map ai (K (retype_free T (nth ctr_args ci))))
          nested_calls';

      val ctr_pos_of_fun =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      fun get_ctr_pos fun_name = the_default ~1 (AList.lookup (op =) ctr_pos_of_fun fun_name);

      fun to_call_var (ci, (ai, _)) = (nth ctr_args ci, nth args ai);
      val mutual_calls = map to_call_var mutual_calls';
      val nested_calls = map to_call_var nested_calls';
    in
      t
      |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls
      |> fold_rev lambda (args @ left_args @ right_args)
    end);
blanchet@54246
   352
panny@56121
   353
(* Build the definitions of the functions being defined. Each one is its recursor applied to one
   argument per constructor (ordered by constructor offset), with the constructor argument
   permuted to the position it has in the user's equations. Rejects excess equations and
   multiple equations for one constructor. Unless [nonexhaustive] is set, it warns about
   constructors without an equation. *)
fun build_defs lthy nonexhaustive bs mxs (funs_data : eqn_data list list)
    (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;
    val (fun_rec_specs, extra_rec_specs) = chop n_funs rec_specs;

    (* pair every constructor with its equations; leftover equations are excess *)
    fun match_ctrs (ctr_specs : rec_ctr_spec list, fun_data : eqn_data list) =
      let
        val (matched, excess) = finds (fn (x, y) => #ctr x = #ctr y) ctr_specs fun_data;
      in
        if null excess then matched
        else raise PRIMREC ("excess equations in definition", map #rhs_term excess)
      end;

    val ctr_spec_eqn_data_list' = maps match_ctrs (map #ctr_specs fun_rec_specs ~~ funs_data);

    val _ = ctr_spec_eqn_data_list' |> List.app (fn ({ctr, ...}, eqns) =>
      (case eqns of
        [_] => ()
      | [] =>
        if nonexhaustive then ()
        else warning ("no equation for constructor " ^ Syntax.string_of_term lthy ctr)
      | _ => raise PRIMREC ("multiple equations for constructor", map #user_eqn eqns)));

    val ctr_spec_eqn_data_list =
      ctr_spec_eqn_data_list' @ map (rpair []) (maps #ctr_specs extra_rec_specs);

    val recs = map #recx fun_rec_specs;
    val rec_args =
      ctr_spec_eqn_data_list
      |> sort (int_ord o pairself (#offset o fst))
      |> map (fn (ctr_spec, eqns) =>
        build_rec_arg lthy funs_data has_call ctr_spec (try the_single eqns));

    fun ctr_pos_of (fun_data : eqn_data list) =
      (case distinct (op =) (map (length o #left_args) fun_data) of
        [ctr_pos] => ctr_pos
      | _ =>
        raise PRIMREC ("inconstant constructor pattern position for function " ^
          quote (#fun_name (hd fun_data)), []));
    val ctr_poss = map ctr_pos_of funs_data;
  in
    map2 (fn recx => fn ctr_pos => permute_args ctr_pos (list_comb (recx, rec_args)))
      recs ctr_poss
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
      bs mxs
  end;
blanchet@54246
   388
blanchet@54246
   389
(* For each constructor argument of the equation's pattern, collect the recursive calls applied
   to it in the right-hand side (as partial compositions built by mk_partial_compN). Yields NONE
   iff the pattern binds no constructor arguments, otherwise SOME (ctr, one call list per arg). *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
        let
          fun find_all ts = maps (fn u => find bound_Ts u ctr_arg) ts;
          val (f', args') = strip_comb t;
          val n = find_index (equal ctr_arg o head_of) args';
        in
          if n < 0 then
            find bound_Ts f' ctr_arg @ find_all args'
          else
            let
              val (f, args as arg :: _) = chop n args' |>> curry list_comb f';
              val (arg_head, arg_args) = Term.strip_comb arg;
            in
              if has_call f then
                mk_partial_compN (length arg_args) (curry fastype_of1 bound_Ts arg_head) f ::
                find_all args
              else
                find bound_Ts f ctr_arg @ find_all args
            end
        end
      | find _ _ _ = [];
  in
    (case map (find [] rhs_term) ctr_args of
      [] => NONE
    | callss => SOME (ctr, callss))
  end;
blanchet@54246
   417
blanchet@54246
   418
(* Tactic for one user equation. It unfolds the function definitions, rewrites with the
   recursor theorem (extended by fun_cong once per extra argument), and finishes by unfolding
   nested simps, map compositions and map identities followed by reflexivity. *)
fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
  let
    val recx' = funpow num_extra_args (fn thm => thm RS fun_cong) recx;
  in
    unfold_thms_tac ctxt fun_defs THEN
    HEADGOAL (rtac (recx' RS trans)) THEN
    unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN
    HEADGOAL (rtac refl)
  end;
blanchet@54246
   423
panny@56121
   424
(* Analyse the user's equations and prepare everything needed to define the functions.
   Returns (((function names, definitions), prover), lthy). The prover maps the final local
   theory and the made definitions to, per function, its equation indices and proved equations.
   When an n2m construction took place, the returned lthy has the induct facts noted. Raises
   OLD_PRIMREC if some argument type is known only to the old datatype package. *)
fun prepare_primrec nonexhaustive fixes specs lthy0 =
  let
    val thy = Proof_Context.theory_of lthy0;

    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val eqns_data = map (dissect_eqn lthy0 fun_names) specs;

    fun the_eqns_of (fun_name, eqnss) =
      the_single eqnss
      handle List.Empty => raise PRIMREC ("missing equations for function " ^ quote fun_name, []);

    (* equations grouped by function, in the order of fun_names *)
    val funs_data = eqns_data
      |> partition_eq ((op =) o pairself #fun_name)
      |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst
      |> map the_eqns_of;

    val frees = map (fn ((b, T), _) => Free (Binding.name_of b, T)) fixes;
    val has_call = exists_subterm (member (op =) frees);
    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val callssss =
      map (maps (map_filter (find_rec_calls has_call)) o partition_eq ((op =) o pairself #ctr))
        funs_data;

    fun is_only_old_datatype (Type (s, _)) =
        is_some (Datatype_Data.get_info thy s) andalso not (is_new_datatype lthy0 s)
      | is_only_old_datatype _ = false;

    val _ = not (exists is_only_old_datatype arg_Ts) orelse raise OLD_PRIMREC ();
    val _ =
      (case find_first (fn (_, T) => not (Sign.of_sort thy (T, @{sort type}))) (bs ~~ res_Ts) of
        NONE => ()
      | SOME (b, _) => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", []));

    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) =
      rec_specs_of bs arg_Ts res_Ts frees callssss lthy0;

    val actual_nn = length funs_data;

    val ctrs = maps (map #ctr o #ctr_specs) rec_specs;
    val _ = eqns_data |> List.app (fn {ctr, user_eqn, ...} =>
      if member (op =) ctrs ctr then ()
      else
        raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^
          " is not a constructor in left-hand side", [user_eqn]));

    val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call;

    fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
        (fun_data : eqn_data list) =
      let
        (* positions of this function's equations among all user equations *)
        val js =
          find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr)))
            fun_data eqns_data;
        val def_thms = map (snd o snd) def_thms';

        fun derivation_name fun_name j =
          Sign.full_name thy (Binding.name fun_name) ^ Long_Name.separator ^ simpsN ^
          (if js = [0] then "" else "_" ^ string_of_int (j + 1));

        fun prove_eqn j (fun_name, user_eqn, num_extra_args, rec_thm) =
          mk_primrec_tac lthy' num_extra_args nested_map_idents nested_map_comps def_thms
            rec_thm
          |> K |> Goal.prove_sorry lthy' [] [] user_eqn
          (* for code extraction from proof terms: *)
          |> singleton (Proof_Context.export lthy' lthy)
          |> Thm.name_derivation (derivation_name fun_name j);

        val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs
          |> fst
          |> map_filter (fn (x, [y]) =>
              SOME (#fun_name x, #user_eqn x, length (#left_args x) + length (#right_args x),
                #rec_thm y)
            | _ => NONE)
          |> map2 prove_eqn js;
      in
        (js, simp_thms)
      end;

    fun note_induct qualifier thms =
      ((Binding.qualify true qualifier (Binding.name inductN), []), [(thms, [])]);

    val notes =
      if n2m then
        map2 (fn name => fn thm => note_induct name [thm]) fun_names (take actual_nn induct_thms)
      else
        [];

    val common_name = mk_common_name fun_names;
    val common_notes = if n2m then [note_induct common_name [induct_thm]] else [];
  in
    (((fun_names, defs),
      fn lthy' => fn def_thms =>
        split_list (map2 (prove lthy' def_thms) (take actual_nn rec_specs) funs_data)),
     lthy |> Local_Theory.notes (notes @ common_notes) |> snd)
  end;
blanchet@54246
   514
panny@56121
   515
(* Prepares the equations "ts" for the functions "fixes" and defines them via Local_Theory.define.
   Only Nonexhaustive_Option is looked up in "opts". An ERROR from prepare_primrec is re-raised
   as PRIMREC. Returns the function names, the first components of the definition results, and
   whatever "prove" yields on the extended context. Any PRIMREC escaping the body is reported
   through "error"; if equations are attached to it, they are printed as well. *)
fun add_primrec_simple' opts fixes ts lthy =
  let
    val (((names, defs), prove), lthy') =
      prepare_primrec (member (op =) opts Nonexhaustive_Option) fixes ts lthy
      handle ERROR str => raise PRIMREC (str, []);
    val (defined, lthy'') = fold_map Local_Theory.define defs lthy';
  in
    ((names, (map fst defined, prove lthy'' defined)), lthy'')
  end
  handle PRIMREC (str, eqns) =>
    let
      (* Suffix listing the offending equations; empty when none were supplied. *)
      val eqns_suffix =
        if null eqns then ""
        else "\nin\n  " ^ space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns);
    in
      error ("primrec error:\n  " ^ str ^ eqns_suffix)
    end;
blanchet@54246
   531
panny@56121
   532
(* add_primrec_simple' with the empty option list. *)
fun add_primrec_simple fixes ts = add_primrec_simple' [] fixes ts;
panny@56121
   533
panny@56121
   534
(* Shared implementation of add_primrec and add_primrec_cmd.
   - Rejects duplicate function names with PRIMREC.
   - Prepares the specification with "prep_spec" and defines the functions via
     add_primrec_simple'.
   - Registers the resulting simp rules as Spec_Rules.Equational.
   - Notes, for each function, a collection qualified by the function name (binding simpsN),
     plus one note per equation whose attributes are prefixed by code_nitpicksimp_simp_attrs.
   - If OLD_PRIMREC is raised, delegates to "old_primrec" and wraps its theorem list as a
     singleton.
   NOTE(review): the duplicate-name PRIMREC is not caught here, unlike the ones caught in
   add_primrec_simple'. Confirm that callers are meant to see the raw exception. *)
fun gen_primrec old_primrec prep_spec opts
    (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
  let
    val dups = duplicates (op =) (map (Binding.name_of o #1) raw_fixes);
    val _ =
      if null dups then ()
      else raise PRIMREC ("duplicate function name(s): " ^ commas dups, []);

    val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy);

    (* Notes for one function: the simps collection first, then one note per equation,
       with the equation bindings taken from the specs indexed by "js". *)
    fun notes_for_fun js prefix thms =
      let
        val (bs, attrss) = map_split (fst o nth specs) js;
        fun eqn_note b attrs thm =
          ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]);
        val eqn_notes = map3 eqn_note bs attrss thms;
      in
        ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: eqn_notes
      end;

    val ((names, (ts, (jss, simpss))), lthy1) =
      add_primrec_simple' opts fixes (map snd specs) lthy;
    val all_simps = flat simpss;
    val notes = flat (map3 notes_for_fun jss names simpss);
    val lthy2 = Spec_Rules.add Spec_Rules.Equational (ts, all_simps) lthy1;
    val (noted, lthy3) = Local_Theory.notes notes lthy2;
  in
    ((ts, map snd noted), lthy3)
  end
  handle OLD_PRIMREC () =>
    let
      val ((ts, simps), lthy') = old_primrec raw_fixes raw_spec lthy;
    in
      ((ts, [simps]), lthy')
    end;
blanchet@54246
   563
panny@56121
   564
(* Programmatic entry point: checked specification, no options, Primrec.add_primrec fallback. *)
fun add_primrec fixes specs =
  gen_primrec Primrec.add_primrec Specification.check_spec [] fixes specs;
blanchet@55528
   565
(* Command entry point: raw (read) specification, user options, Primrec.add_primrec_cmd fallback. *)
fun add_primrec_cmd opts fixes specs =
  gen_primrec Primrec.add_primrec_cmd Specification.read_spec opts fixes specs;
blanchet@54246
   566
blanchet@55535
   567
(* Runs add_primrec on a theory: opens a Named_Target local theory, defines the functions,
   and exits back to the global theory. The add_primrec result is passed through unchanged. *)
fun add_primrec_global fixes specs thy =
  let
    val (result, lthy) = add_primrec fixes specs (Named_Target.theory_init thy);
  in
    (result, Local_Theory.exit_global lthy)
  end;
blanchet@54246
   571
blanchet@55535
   572
(* Like add_primrec_global, but the local theory is an Overloading target for "ops". *)
fun add_primrec_overloaded ops fixes specs thy =
  let
    val (result, lthy) = add_primrec fixes specs (Overloading.overloading ops thy);
  in
    (result, Local_Theory.exit_global lthy)
  end;
blanchet@54246
   576
panny@56121
   577
(* Parses a single primrec option: the reserved word "nonexhaustive" yields
   Nonexhaustive_Option. Failures are reported under the group name "option". *)
val primrec_option_parser =
  Parse.reserved "nonexhaustive" >> (fn _ => Nonexhaustive_Option)
  |> Parse.group (fn () => "option");
panny@56121
   579
blanchet@55530
   580
(* Registers the "primrec" command. It takes an optional parenthesized non-empty list of options,
   then the fixes and "where" equations, and passes all three to add_primrec_cmd. Only the
   resulting local theory is kept. *)
val _ =
  let
    val options_parser =
      Scan.optional
        (@{keyword "("} |-- Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) [];
    fun primrec_cmd (opts, (fixes, spec)) = snd o add_primrec_cmd opts fixes spec;
  in
    Outer_Syntax.local_theory @{command_spec "primrec"} "define primitive recursive functions"
      (options_parser -- (Parse.fixes -- Parse_Spec.where_alt_specs) >> primrec_cmd)
  end;
blanchet@55529
   586
blanchet@54246
   587
end;