src/HOL/BNF/Tools/bnf_lfp_rec_sugar.ML
author blanchet
Tue Nov 05 16:53:40 2013 +0100 (2013-11-05)
changeset 54272 9d623cada37f
parent 54256 4843082be7ef
child 54851 48a24d371ebb
permissions -rw-r--r--
avoid subtle failure in the presence of top sort
blanchet@54246
     1
(*  Title:      HOL/BNF/Tools/bnf_lfp_rec_sugar.ML
blanchet@54246
     2
    Author:     Lorenz Panny, TU Muenchen
blanchet@54246
     3
    Author:     Jasmin Blanchette, TU Muenchen
blanchet@54246
     4
    Copyright   2013
blanchet@54246
     5
blanchet@54246
     6
Recursor sugar.
blanchet@54246
     7
*)
blanchet@54246
     8
blanchet@54246
     9
(* Public interface of the recursor sugar.  The entry points differ in whether fixes and
   equations arrive as internal types/terms or as strings, and in whether they operate on
   a local_theory or on a theory. *)
signature BNF_LFP_REC_SUGAR =
sig
  (* fixes with optional types, equations as attributed terms; local_theory variant *)
  val add_primrec:
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    local_theory -> (term list * thm list list) * local_theory

  (* same shape as add_primrec, but types and equations are given as strings *)
  val add_primrec_cmd:
    (binding * string option * mixfix) list ->
    (Attrib.binding * string) list ->
    local_theory -> (term list * thm list list) * local_theory

  (* same argument shape as add_primrec, but transforms a theory *)
  val add_primrec_global:
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    theory -> (term list * thm list list) * theory

  (* theory variant taking an additional leading list of
     (string * (string * typ) * bool) entries *)
  val add_primrec_overloaded:
    (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    theory -> (term list * thm list list) * theory

  (* fixes with mandatory types and plain equation terms; the result additionally carries
     a string list and, next to the theorems, an int list list *)
  val add_primrec_simple:
    ((binding * typ) * mixfix) list -> term list ->
    local_theory ->
    (string list * (term list * (int list list * thm list list))) * local_theory
end;
blanchet@54246
    23
blanchet@54246
    24
structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR =
blanchet@54246
    25
struct
blanchet@54246
    26
blanchet@54246
    27
open Ctr_Sugar
blanchet@54246
    28
open BNF_Util
blanchet@54246
    29
open BNF_Tactics
blanchet@54246
    30
open BNF_Def
blanchet@54246
    31
open BNF_FP_Util
blanchet@54246
    32
open BNF_FP_Def_Sugar
blanchet@54246
    33
open BNF_FP_N2M_Sugar
blanchet@54246
    34
open BNF_FP_Rec_Sugar_Util
blanchet@54246
    35
blanchet@54246
    36
(* Attribute lists: [nitpick_simp], [simp], and the default-code-equation attribute
   followed by both of them. *)
val nitpicksimp_attrs = @{attributes [nitpick_simp]};
val simp_attrs = @{attributes [simp]};
val code_nitpicksimp_simp_attrs =
  Code.add_default_eqn_attrib :: (nitpicksimp_attrs @ simp_attrs);
blanchet@54246
    39
blanchet@54246
    40
(* Error raised by this module: a message together with the terms it refers to
   (no term, one term, or several terms). *)
exception Primrec_Error of string * term list;

(* All three raisers funnel into the list version. *)
fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
fun primrec_error_eqn str eqn = primrec_error_eqns str [eqn];
fun primrec_error str = primrec_error_eqns str [];
blanchet@54246
    45
blanchet@54246
    46
(* Classification of one constructor argument, as computed by call_of in rec_specs_of.
   Every payload is an (index, type) pair; build_rec_arg uses the index as a position in
   the argument list of the function passed to the recursor.
     No_Rec:     single argument whose type contains none of the Cs
     Nested_Rec: single argument whose type contains one of the Cs
     Mutual_Rec: argument that occupies two positions (two index/type pairs) *)
datatype rec_call
  = No_Rec of int * typ
  | Mutual_Rec of (int * typ) * (int * typ)
  | Nested_Rec of int * typ;
blanchet@54246
    50
blanchet@54246
    51
(* Per-constructor recursion data, produced by mk_ctr_spec in rec_specs_of. *)
type rec_ctr_spec = {
  ctr: term,             (* the constructor, after type instantiation *)
  offset: int,           (* running constructor number; build_defs sorts by it *)
  calls: rec_call list,  (* one entry per constructor argument *)
  rec_thm: thm           (* recursor theorem for this constructor *)
};
blanchet@54246
    56
blanchet@54246
    57
(* Per-type recursion data, produced by mk_spec in rec_specs_of. *)
type rec_spec = {
  recx: term,                      (* the recursor term *)
  nested_map_idents: thm list,     (* map identity theorems of the nested BNFs *)
  nested_map_comps: thm list,      (* map composition theorems of the nested BNFs *)
  ctr_specs: rec_ctr_spec list     (* one entry per constructor *)
};
blanchet@54246
    62
blanchet@54246
    63
(* Raised by massage_map (inside massage_nested_rec_call) when a term is not recognized as
   an application of a map function; carries the offending term.  It is caught again
   within massage_nested_rec_call. *)
exception AINT_NO_MAP of term;
blanchet@54246
    64
blanchet@54246
    65
(* User-level errors about recursive calls.  Each one prints the given term, quoted,
   after a fixed message prefix. *)
local
  fun rec_call_error prefix ctxt t = error (prefix ^ quote (Syntax.string_of_term ctxt t));
in

fun ill_formed_rec_call ctxt t = rec_call_error "Ill-formed recursive call: " ctxt t;
fun invalid_map ctxt t = rec_call_error "Invalid map function in " ctxt t;
fun unexpected_rec_call ctxt t = rec_call_error "Unexpected recursive call: " ctxt t;

end;
blanchet@54246
    71
blanchet@54246
    72
(* Rewrites a term that mentions the constructor argument y into one that mentions y'
   instead.  Where the term contains no recursive call (according to has_call), y is
   simply replaced by a map of "fst" applied to y'.  Where a recursive call is applied
   to y through map functions, the maps are rebuilt over the type of y' and
   raw_massage_fun is used on the innermost functions that contain the call.
   bound_Ts gives the types of the enclosing bound variables. *)
fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
  let
    fun ensure_no_call t = if has_call t then unexpected_rec_call ctxt t else ();

    fun type_in_scope t = fastype_of1 (bound_Ts, t);
    val fst_map = build_map ctxt (fst_const o fst);

    val yT = type_in_scope y;
    val yU = type_in_scope y';

    (* y expressed in terms of y' *)
    fun projected_y' () = fst_map (yU, yT) $ y';
    val replace_y = Term.map_aterms (fn a => if a = y then projected_y' () else a);

    (* A composition "f o g" is massaged in g only; f must be free of recursive calls. *)
    fun massage_mutual_fun U T (Const (@{const_name comp}, _) $ f $ g) =
          mk_comp bound_Ts (tap ensure_no_call f, massage_mutual_fun U T g)
      | massage_mutual_fun U T t =
          if not (has_call t) then
            mk_comp bound_Ts (t, fst_map (U, T))
          else
            (case try HOLogic.dest_prodT U of
              SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t
            | NONE => invalid_map ctxt t);

    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
          (case try (dest_map ctxt s) t of
            NONE => raise AINT_NO_MAP t
          | SOME (map0, fs) =>
              let
                val Type (_, ran_Ts) = range_type (type_in_scope t);
                val map' = mk_map (length fs) Us ran_Ts map0;
                val fs' =
                  map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
              in
                Term.list_comb (map', fs')
              end)
      | massage_map _ _ t = raise AINT_NO_MAP t
    and massage_map_or_map_arg U T t =
      if T = U then
        tap ensure_no_call t
      else
        (* not a map: treat the argument as a function containing the call *)
        (massage_map U T t
         handle AINT_NO_MAP _ => massage_mutual_fun U T t);

    fun massage_call (t as t1 $ t2) =
          if not (has_call t) then
            replace_y t
          else if t2 = y then
            (massage_map yU yT (replace_y t1) $ y'
             handle AINT_NO_MAP t' => invalid_map ctxt t')
          else
            let val (g, xs) = Term.strip_comb t2 in
              if g <> y then
                ill_formed_rec_call ctxt t
              else if exists has_call xs then
                unexpected_rec_call ctxt t2
              else
                (* y applied to xs: compose t1 with y and retry, then re-apply to xs *)
                Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
            end
      | massage_call t = if t = y then projected_y' () else ill_formed_rec_call ctxt t;
  in
    massage_call
  end;
blanchet@54246
   135
blanchet@54246
   136
(* Calls nested_to_mutual_fps for the given bindings, argument types and call
   information, and converts the resulting fp_sugars into rec_specs (one per fp_sugar).
   Returns
     ((whether lfp_sugar_thms is present, rec_specs, missing_arg_Ts,
       induct_thm, induct_thms), lthy'). *)
fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    val ((missing_arg_Ts, perm0_kks,
          fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...},
            co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') =
      nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy;

    (* the same sugars, ordered by their index field; "perm_" values follow this order *)
    val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars;

    val indices = map #index fp_sugars;
    val perm_indices = map #index perm_fp_sugars;

    val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars;
    val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss;
    val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss;

    val nn0 = length arg_Ts;
    val nn = length perm_lfpTs;
    val kks = 0 upto nn - 1;
    val perm_ns = map length perm_ctr_Tsss;
    val perm_mss = map (map length) perm_ctr_Tsss;

    val perm_Cs =
      map (fn fp_sugar =>
          body_type (fastype_of (co_rec_of
            (of_fp_sugar (#xtor_co_iterss o #fp_res) fp_sugar))))
        perm_fp_sugars;
    val perm_fun_arg_Tssss =
      mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1);

    (* bring "perm_"-ordered lists back into the order of kks / of fp_sugars *)
    fun unpermute0 xs = permute_like (op =) perm0_kks kks xs;
    fun unpermute xs = permute_like (op =) perm_indices indices xs;

    val induct_thms = unpermute0 (conj_dests nn induct_thm);

    val lfpTs = unpermute perm_lfpTs;
    val Cs = unpermute perm_Cs;

    (* type instantiations: the As from the argument types, the Cs from the result types
       (padded with unit up to nn entries) *)
    val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts;
    val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts;

    val substA = Term.subst_TVars As_rho;
    val substAT = Term.typ_subst_TVars As_rho;
    val substCT = Term.typ_subst_TVars Cs_rho;
    val substACT = substAT o substCT;

    val perm_Cs' = map substCT perm_Cs;

    (* total number of constructors of the first n ctr_sugars *)
    fun offset_of_ctr 0 _ = 0
      | offset_of_ctr n ((sugar : ctr_sugar) :: rest) =
          length (#ctrs sugar) + offset_of_ctr (n - 1) rest;

    fun call_of [i] [T] =
          let val mk_call = if exists_subtype_in Cs T then Nested_Rec else No_Rec
          in mk_call (i, substACT T) end
      | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T'));

    fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm =
      let
        val fun_arg_hss = fst (indexedd fun_arg_Tss 0);
        val fun_arg_hs = flat_rec_arg_args fun_arg_hss;
        val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss;
        val calls = map2 call_of fun_arg_iss fun_arg_Tss;
      in
        {ctr = substA ctr, offset = offset, calls = calls, rec_thm = rec_thm}
      end;

    fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss =
      let
        val ctrs = #ctrs (nth ctr_sugars index);
        val rec_thms = co_rec_of (nth iter_thmsss index);
        val first_offset = offset_of_ctr index ctr_sugars;
        val num_ctrs = length ctrs;
        val offsets = first_offset upto first_offset + num_ctrs - 1;
      in
        map4 mk_ctr_spec ctrs offsets (nth perm_fun_arg_Tssss index) rec_thms
      end;

    fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...}
      : fp_sugar) =
      {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)),
       nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs,
       nested_map_comps = map map_comp_of_bnf nested_bnfs,
       ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss};

    val rec_specs = map mk_spec fp_sugars;
  in
    ((is_some lfp_sugar_thms, rec_specs, missing_arg_Ts, induct_thm, induct_thms), lthy')
  end;
blanchet@54246
   220
blanchet@54246
   221
(* The constant "undefined" with a dummy type; build_rec_arg returns it for a constructor
   that has no equation. *)
val undef_const = Const (@{const_name undefined}, dummyT);
blanchet@54246
   222
blanchet@54246
   223
(* Wraps t in n + 1 dummy-typed abstractions whose body applies t to the innermost bound
   variable first and then to the remaining ones from the outermost inwards; i.e. the
   resulting function hands its last argument to t as the first one. *)
fun permute_args n t =
  let
    val applied = list_comb (t, map Bound (0 :: (n downto 1)));
    fun abstract_once _ body = Term.abs (Name.uu, dummyT) body;
  in
    fold abstract_once (0 upto n) applied
  end;
blanchet@54246
   225
blanchet@54246
   226
(* One user equation "f ls (C xs) rs = rhs", taken apart by dissect_eqn. *)
type eqn_data =
  {fun_name: string,       (* name of the free variable heading the left-hand side *)
   rec_type: typ,          (* body type of the constructor's type *)
   ctr: term,              (* head of the single non-variable argument *)
   ctr_args: term list,    (* arguments of that head *)
   left_args: term list,   (* free-variable arguments before the pattern *)
   right_args: term list,  (* free-variable arguments after the pattern *)
   res_type: typ,          (* types of left_args @ right_args ---> type of rhs_term *)
   rhs_term: term,         (* right-hand side *)
   user_eqn: term};        (* the equation exactly as passed to dissect_eqn *)
blanchet@54246
   237
blanchet@54246
   238
(* Takes a user equation apart into an eqn_data record and checks its shape.

   eqn' is expected to be, after dropping outer quantifiers and Trueprop, of the form
     f l1 ... lm (C x1 ... xk) r1 ... rn = rhs
   where f is a free variable named in fun_names, the li, xi and ri are distinct free
   variables, and C is applied to as many arguments as its type has binders.

   Raises Primrec_Error (with the offending equation) if the equation is malformed, has
   no or more than one non-variable argument, uses a partially applied or non-primitive
   pattern, repeats a left-hand-side variable, or mentions on the right-hand side a free
   variable that is neither bound on the left, nor one of fun_names, nor fixed in lthy. *)
fun dissect_eqn lthy fun_names eqn' =
  let
    val malformed = "malformed function equation (expected \"lhs = rhs\")";

    val eqn = drop_All eqn' |> HOLogic.dest_Trueprop
      handle TERM _ => primrec_error_eqn malformed eqn';
    val (lhs, rhs) = HOLogic.dest_eq eqn
      handle TERM _ => primrec_error_eqn malformed eqn';
    val (fun_name, args) = strip_comb lhs
      |>> (fn x => if is_Free x then fst (dest_Free x)
          else primrec_error_eqn "malformed function equation (does not start with free)" eqn);

    (* split the arguments into leading frees, the pattern(s), and trailing frees *)
    val (left_args, rest) = take_prefix is_Free args;
    val (nonfrees, right_args) = take_suffix is_Free rest;
    val num_nonfrees = length nonfrees;
    val _ = num_nonfrees = 1 orelse
      (if num_nonfrees = 0 then
        primrec_error_eqn "constructor pattern missing in left-hand side" eqn
      else
        primrec_error_eqn "more than one non-variable argument in left-hand side" eqn);
    val _ = member (op =) fun_names fun_name orelse
      primrec_error_eqn "malformed function equation (does not start with function name)" eqn;

    val (ctr, ctr_args) = strip_comb (the_single nonfrees);
    val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse
      primrec_error_eqn "partially applied constructor in pattern" eqn;

    (* all arguments occurring on the left-hand side; computed once and shared by the
       duplicate check and the extra-variable check *)
    val lhs_vars = left_args @ ctr_args @ right_args;

    val _ = let val d = duplicates (op =) lhs_vars in null d orelse
      primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^
        "\" in left-hand side") eqn end;
    val _ = forall is_Free ctr_args orelse
      primrec_error_eqn "non-primitive pattern in left-hand side" eqn;
    val _ =
      let
        fun is_extra (x as Free (v, _)) =
              not (member (op =) lhs_vars x) andalso
              not (member (op =) fun_names v) andalso
              not (Variable.is_fixed lthy v)
          | is_extra _ = false;
        (* insert instead of cons: a variable that occurs several times on the
           right-hand side is reported only once *)
        val extras =
          fold_aterms (fn x => if is_extra x then insert (op =) x else I) rhs [];
      in
        null extras orelse
        primrec_error_eqn ("extra variable(s) in right-hand side: " ^
          commas (map (Syntax.string_of_term lthy) extras)) eqn
      end;
  in
    {fun_name = fun_name,
     rec_type = body_type (type_of ctr),
     ctr = ctr,
     ctr_args = ctr_args,
     left_args = left_args,
     right_args = right_args,
     res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
     rhs_term = rhs,
     user_eqn = eqn'}
  end;
blanchet@54246
   287
blanchet@54246
   288
(* Rewrites a function that occurs as a map argument and contains a recursive call, so
   that it operates on pairs of type rec_type * res_type: occurrences of the tracked
   bound variable become "fst" of it, and a recursive call on it becomes "snd" of it.

   The traversal state d is an int option.  It starts as SOME ~1 ("no abstraction entered
   yet"); the first abstraction entered gets the pair type and its variable is tracked
   from then on (the index is incremented under every further abstraction).  NONE means
   that no variable is tracked. *)
fun rewrite_map_arg get_ctr_pos rec_type res_type =
  let
    val pT = HOLogic.mk_prodT (rec_type, res_type);

    fun shift NONE = NONE
      | shift (SOME k) = SOME (k + 1);

    fun subst d (t as Bound d') = if d = SOME d' then fst_const pT $ t else t
      | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (shift d) b)
      | subst d t =
          let
            val (u, vs) = strip_comb t;
            (* position of the constructor argument if u is one of the functions being
               defined, ~1 otherwise *)
            val ctr_pos =
              (case try (get_ctr_pos o fst o dest_Free) u of
                SOME pos => pos
              | NONE => ~1);
            val at_top = (d = SOME ~1);
          in
            if ctr_pos < 0 then
              list_comb (u, map (subst (if at_top then NONE else d)) vs)
            else if at_top andalso length vs = ctr_pos then
              (* the call is missing exactly its constructor argument *)
              list_comb (permute_args ctr_pos (snd_const pT), vs)
            else if length vs > ctr_pos andalso is_some d
                andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
              (* the call is applied to the tracked variable *)
              list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
            else
              primrec_error_eqn ("recursive call not directly applied to constructor argument") t
          end;
  in
    subst (SOME ~1)
  end;
blanchet@54246
   314
blanchet@54246
   315
(* Replaces the recursive calls in a right-hand side.  mutual_calls and nested_calls are
   association lists from constructor arguments to their replacements.  An application
   whose argument is headed by a constructor argument is first offered to
   massage_nested_rec_call (if the argument has a nested_calls entry); otherwise, if the
   function position is headed by a free variable and the argument has a mutual_calls
   entry, the call is replaced by that entry applied to the remaining arguments. *)
fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls =
  let
    fun try_nested_rec bound_Ts y t =
      (case AList.lookup (op =) nested_calls y of
        NONE => NONE
      | SOME y' =>
          SOME (massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos)
            bound_Ts y y' t));

    fun subst bound_Ts (t as g' $ y) =
          let
            (* plain structural descent into both sides of the application *)
            fun descend () = subst bound_Ts g' $ subst bound_Ts y;
            val y_head = head_of y;

            fun mutual_or_descend () =
              let val (g, g_args) = strip_comb g' in
                (case try (get_ctr_pos o fst o dest_Free) g of
                  NONE => descend ()
                | SOME ctr_pos =>
                    (length g_args >= ctr_pos orelse
                       primrec_error_eqn "too few arguments in recursive call" t;
                     (case AList.lookup (op =) mutual_calls y of
                       SOME y' => list_comb (y', g_args)
                     | NONE => descend ())))
              end;
          in
            if member (op =) ctr_args y_head then
              (case try_nested_rec bound_Ts y_head t of
                SOME t' => t'
              | NONE => mutual_or_descend ())
            else
              descend ()
          end
      | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b)
      | subst _ t = t;

    (* any recursive call still present after subst is an error *)
    fun subst' t =
      if has_call t then
        (* FIXME detect this case earlier? *)
        primrec_error_eqn "recursive call not directly applied to constructor argument" t
      else
        the_default t (try_nested_rec [] (head_of t) t);
  in
    subst' o subst []
  end;
blanchet@54246
   356
blanchet@54246
   357
(* Builds the function that is passed to the recursor for one constructor.  Without an
   equation the result is undef_const.  Otherwise the right-hand side, with its recursive
   calls substituted, is abstracted over the recursor-argument variables followed by the
   equation's left and right arguments. *)
fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec)
    (maybe_eqn_data : eqn_data option) =
  (case maybe_eqn_data of
    NONE => undef_const
  | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} =>
    let
      val calls = #calls ctr_spec;

      (* a Mutual_Rec argument occupies two positions, every other one a single position *)
      fun arity_of (Mutual_Rec _) = 2
        | arity_of _ = 1;
      val n_args = fold (fn call => fn sum => arity_of call + sum) calls 0;

      (* (constructor argument index, (position, type)) for the calls selected by pick *)
      val tagged_calls = tag_list 0 calls;
      fun select pick =
        map_filter (fn (ctr_arg_idx, call) => Option.map (pair ctr_arg_idx) (pick call))
          tagged_calls;

      val no_calls' =
        select (fn No_Rec p => SOME p | Mutual_Rec (p, _) => SOME p | _ => NONE);
      val mutual_calls' = select (fn Mutual_Rec (_, p) => SOME p | _ => NONE);
      val nested_calls' = select (fn Nested_Rec p => SOME p | _ => NONE);

      fun put_ctr_arg (ctr_arg_idx, (arg_idx, _)) =
        nth_map arg_idx (K (nth ctr_args ctr_arg_idx));
      fun put_retyped_ctr_arg (ctr_arg_idx, (arg_idx, T)) =
        nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx)));

      (* start from fresh variables, then overwrite every position with the (possibly
         retyped) constructor argument that belongs there *)
      val fresh_frees = map Free (Term.rename_wrt_term t (replicate n_args ("", dummyT)));
      val args =
        fresh_frees
        |> fold put_ctr_arg no_calls'
        |> fold put_retyped_ctr_arg mutual_calls'
        |> fold put_retyped_ctr_arg nested_calls';

      (* position of the constructor pattern per function name; ~1 for unknown names *)
      val fun_name_ctr_pos_list =
        map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data;
      fun get_ctr_pos fun_name =
        the_default ~1 (AList.lookup (op =) fun_name_ctr_pos_list fun_name);

      fun replacement_of (ctr_arg_idx, (arg_idx, _)) =
        (nth ctr_args ctr_arg_idx, nth args arg_idx);
      val mutual_calls = map replacement_of mutual_calls';
      val nested_calls = map replacement_of nested_calls';

      val body = subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls t;
    in
      fold_rev lambda (args @ left_args @ right_args) body
    end);
blanchet@54246
   396
blanchet@54246
   397
(* Builds the definitions of the functions described by funs_data: for every function the
   corresponding recursor applied to one argument per constructor (in offset order), with
   the constructor-pattern argument moved into place by permute_args.

   Returns one ((binding, mixfix), ((concealed definition binding, []), term)) entry per
   function.

   Raises Primrec_Error if a function has equations for a constructor that is not among
   its rec_spec's constructors, several equations for one constructor, or equations whose
   constructor pattern is not always at the same argument position. *)
fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call =
  let
    val n_funs = length funs_data;

    (* pair every constructor spec with the equations for that constructor; equations
       left over match no constructor of their function.  The error carries the user
       equations, like every other equation-level error in this file. *)
    val ctr_spec_eqn_data_list' =
      (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data
      |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y))
          ##> (fn x => null x orelse
            primrec_error_eqns "excess equations in definition" (map #user_eqn x)) #> fst);
    val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse
      primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x));

    (* rec_specs beyond the user's functions contribute constructors without equations *)
    val ctr_spec_eqn_data_list =
      ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair []));

    val recs = take n_funs rec_specs |> map #recx;
    val rec_args = ctr_spec_eqn_data_list
      |> sort ((op <) o pairself (#offset o fst) |> make_ord)
      |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single));

    (* number of arguments in front of the constructor pattern, per function *)
    val ctr_poss = map (fn x =>
      if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then
        primrec_error ("inconstant constructor pattern position for function " ^
          quote (#fun_name (hd x)))
      else
        hd x |> #left_args |> length) funs_data;
  in
    (recs, ctr_poss)
    |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos)
    |> Syntax.check_terms lthy
    |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t)))
      bs mxs
  end;
blanchet@54246
   429
blanchet@54246
   430
(* Collects, for every constructor argument of an equation, the functions that are applied
   to it on the right-hand side and contain a recursive call.  Returns NONE if the
   constructor has no arguments, otherwise SOME (ctr, one list per constructor argument). *)
fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) =
  let
    fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg
      | find bound_Ts (t as _ $ _) ctr_arg =
          let
            fun find_all ts = maps (fn u => find bound_Ts u ctr_arg) ts;
            val (head, all_args) = strip_comb t;
            (* first argument that is headed by the constructor argument, if any *)
            val pos = find_index (fn a => ctr_arg = head_of a) all_args;
          in
            if pos < 0 then
              find bound_Ts head ctr_arg @ find_all all_args
            else
              let
                val (prefix_args, args as arg :: _) = chop pos all_args;
                (* the function part: head applied to everything before that argument *)
                val f = list_comb (head, prefix_args);
                val (arg_head, arg_args) = Term.strip_comb arg;
              in
                if has_call f then
                  mk_partial_compN (length arg_args) (fastype_of1 (bound_Ts, arg_head)) f ::
                  find_all args
                else
                  find bound_Ts f ctr_arg @ find_all args
              end
          end
      | find _ _ _ = [];

    val callss = map (find [] rhs_term) ctr_args;
  in
    if null callss then NONE else SOME (ctr, callss)
  end;
blanchet@54246
   458
blanchet@54246
   459
(* Tactic for one user equation: unfold the function definitions, apply the recursor
   theorem (specialized with fun_cong once per extra argument) through transitivity,
   unfold the auxiliary definitions and map theorems, and close the goal by reflexivity. *)
fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx =
  let
    val rec_rule = funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans;
    val cleanup_thms =
      @{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents;
  in
    unfold_thms_tac ctxt fun_defs THEN
    HEADGOAL (rtac rec_rule) THEN
    unfold_thms_tac ctxt cleanup_thms THEN
    HEADGOAL (rtac refl)
  end;
blanchet@54246
   464
blanchet@54246
   465
(* Analyses the fixes and equations of a primrec specification.  Returns
     (((function names, definitions), prove), lthy'')
   where "prove lthy defs" proves the user equations from the given definition results
   and returns, per function, the positions of its equations in specs together with the
   proved theorems.  If rec_specs_of reports n2m, the induction theorems are noted in the
   returned local theory. *)
fun prepare_primrec fixes specs lthy =
  let
    val thy = Proof_Context.theory_of lthy;

    val (bs, mxs) = map_split (apfst fst) fixes;
    val fun_names = map Binding.name_of bs;
    val eqns_data = map (dissect_eqn lthy fun_names) specs;

    (* the equations grouped by function, in the order of fun_names *)
    fun sole_group (fun_name, groups) =
      the_single groups
      handle List.Empty => primrec_error ("missing equations for function " ^ quote fun_name);
    val eqn_groups = partition_eq ((op =) o pairself #fun_name) eqns_data;
    val funs_data =
      map sole_group (fst (finds (fn (x, y) => x = #fun_name (hd y)) fun_names eqn_groups));

    (* a term contains a recursive call if it mentions one of the fixed functions *)
    val fixed_frees = map (fn ((b, T), _) => Free (Binding.name_of b, T)) fixes;
    val has_call = exists_subterm (member (op =) fixed_frees);

    val arg_Ts = map (#rec_type o hd) funs_data;
    val res_Ts = map (#res_type o hd) funs_data;
    val ctr_groupss = map (partition_eq ((op =) o pairself #ctr)) funs_data;
    val callssss = map (maps (map_filter (find_rec_calls has_call))) ctr_groupss;

    (* reject result types that are not of sort "type" *)
    val _ =
      (case find_first (fn (_, T) => not (Sign.of_sort thy (T, HOLogic.typeS))) (bs ~~ res_Ts) of
        NONE => ()
      | SOME (b, _) => primrec_error ("type of " ^ Binding.print b ^ " contains top sort"));

    val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') =
      rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy;

    val actual_nn = length funs_data;

    (* every pattern head must be one of the constructors of the rec_specs *)
    val known_ctrs = maps (map #ctr o #ctr_specs) rec_specs;
    val _ =
      map (fn {ctr, user_eqn, ...} => member (op =) known_ctrs ctr orelse
        primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^
          " is not a constructor in left-hand side") user_eqn) eqns_data;

    val defs = build_defs lthy' bs mxs funs_data rec_specs has_call;

    fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec)
        (fun_data : eqn_data list) =
      let
        val def_thms = map (snd o snd) def_thms';

        (* the tactic is built before Goal.prove is called and passed as a constant
           function of the proof context *)
        fun prove_eqn (user_eqn, num_extra_args, rec_thm) =
          let
            val tac =
              mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms
                rec_thm;
          in
            Thm.close_derivation (Goal.prove lthy [] [] user_eqn (K tac))
          end;

        (* equations paired with the constructor specs they match; only equations with
           exactly one matching spec are proved *)
        val matched = fst (finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs);
        val simp_thmss =
          matched
          |> map_filter
            (fn (x, [y]) =>
                SOME (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y)
              | _ => NONE)
          |> map prove_eqn;

        val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data;
      in
        (poss, simp_thmss)
      end;

    fun note_of qualifier (thmN, thms, attrs) =
      ((Binding.qualify true qualifier (Binding.name thmN), attrs), [(thms, [])]);

    val notes =
      if n2m then
        map2 (fn name => fn thm => note_of name (inductN, [thm], []))
          fun_names (take actual_nn induct_thms)
      else [];

    val common_name = mk_common_name fun_names;

    val common_notes =
      if n2m then [note_of common_name (inductN, [induct_thm], [])] else [];

    val lthy'' = snd (Local_Theory.notes (notes @ common_notes) lthy');
  in
    (((fun_names, defs),
      fn lthy => fn defs =>
        split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)),
      lthy'')
  end;
blanchet@54246
   536
blanchet@54246
   537
(* primrec definition *)
blanchet@54246
   538
blanchet@54246
   539
(* Runs prepare_primrec on the given fixes and equations, introduces the
   resulting definitions with Local_Theory.define, and returns the function
   names together with the defined terms and the output of the prover
   callback, paired with the updated local theory.

   An ERROR raised by prepare_primrec is re-raised through primrec_error.
   Any Primrec_Error escaping the body is turned into a user-level error
   whose message starts with "primrec_new error:" and, when equations are
   attached, lists them pretty-printed in the caller's context. *)
fun add_primrec_simple fixes ts lthy =
  let
    val (((fun_names, raw_defs), prove_simps), lthy1) =
      (prepare_primrec fixes ts lthy)
      handle ERROR msg => primrec_error msg;

    (* result component computed in the context obtained after all definitions *)
    fun result_of defs lthy2 = (fun_names, (map fst defs, prove_simps lthy2 defs));
  in
    lthy1
    |> fold_map Local_Theory.define raw_defs
    |-> (fn defs => fn lthy2 => (result_of defs lthy2, lthy2))
  end
  handle Primrec_Error (msg, eqns) =>
    let
      val header = "primrec_new error:\n  " ^ msg;
    in
      if null eqns then error header
      else
        error (header ^ "\nin\n  " ^
          space_implode "\n  " (map (quote o Syntax.string_of_term lthy) eqns))
    end;
blanchet@54246
   553
blanchet@54246
   554
local

(* Generic driver behind add_primrec and add_primrec_cmd.

   prep_spec turns the raw fixes and raw specification into checked fixes and
   (attribute binding, equation) pairs.  Duplicate function names among the
   raw fixes are rejected up front via primrec_error.  The equations are then
   handed to add_primrec_simple; afterwards the proved equations are
   registered as equational Spec_Rules and stored as named facts.

   Result: the defined terms paired with the theorem lists of the stored
   facts, plus the updated local theory. *)
fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy =
  let
    val dup_names = duplicates (op =) (map (fn (b, _, _) => Binding.name_of b) raw_fixes);
    val _ =
      if null dup_names then ()
      else primrec_error ("duplicate function name(s): " ^ commas dup_names);

    val (fixes, specs) = prep_spec raw_fixes raw_spec lthy |> fst;

    (* facts for a single function: the collective simps fact followed by one
       fact per equation, carrying the user-supplied binding and attributes *)
    fun notes_for poss prefix thms =
      let
        val (bs, attrss) = map_split (fn pos => fst (nth specs pos)) poss;
        fun eqn_note b attrs thm =
          ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])]);
        val eqn_notes = map3 eqn_note bs attrss thms;
        val simps_note = ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]);
      in
        simps_note :: eqn_notes
      end;

    fun mk_notes posss names simpss = flat (map3 notes_for posss names simpss);
  in
    lthy
    |> add_primrec_simple fixes (map snd specs)
    |-> (fn (names, (ts, (posss, simpss))) =>
      let
        val add_rules = Spec_Rules.add Spec_Rules.Equational (ts, flat simpss);
        val note_simps = Local_Theory.notes (mk_notes posss names simpss);
      in
        fn lthy1 =>
          lthy1
          |> add_rules
          |> note_simps
          |>> (fn noted => (ts, map snd noted))
      end)
  end;

in

(* gen_primrec instantiated with Specification.check_spec *)
val add_primrec = gen_primrec Specification.check_spec;

(* gen_primrec instantiated with Specification.read_spec *)
val add_primrec_cmd = gen_primrec Specification.read_spec;

end;
blanchet@54246
   589
blanchet@54246
   590
(* Theory-level wrapper around add_primrec: enters a local theory with
   Named_Target.theory_init, runs add_primrec there, exports the resulting
   theorem lists from the final context back to the initial one, and leaves
   the local theory again via Local_Theory.exit_global. *)
fun add_primrec_global fixes specs thy =
  let
    val lthy0 = Named_Target.theory_init thy;
    val ((consts, simpss), lthy1) = lthy0 |> add_primrec fixes specs;
    val export_thms = Proof_Context.export lthy1 lthy0;
  in
    ((consts, burrow export_thms simpss), Local_Theory.exit_global lthy1)
  end;
blanchet@54246
   596
blanchet@54246
   597
(* Theory-level wrapper around add_primrec for overloaded constants: opens
   the target with Overloading.overloading for the given ops, runs
   add_primrec there, exports the resulting theorem lists from the final
   context back to the initial one, and closes the target with
   Local_Theory.exit_global. *)
fun add_primrec_overloaded ops fixes specs thy =
  let
    val target = Overloading.overloading ops thy;
    val ((consts, simpss), target') = add_primrec fixes specs target;
    fun export thms = Proof_Context.export target' target thms;
    val exported = burrow export simpss;
  in
    ((consts, exported), Local_Theory.exit_global target')
  end;
blanchet@54246
   603
blanchet@54246
   604
end;