src/HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
author blanchet
Wed Apr 23 10:23:26 2014 +0200 (2014-04-23)
changeset 56638 092a306bcc3d
parent 55863 fa3a1ec69a1b
child 56857 aa2de99be748
permissions -rw-r--r--
generate size instances for new-style datatypes
blanchet@55571
     1
(*  Title:      HOL/Tools/BNF/bnf_lfp_rec_sugar_more.ML
    Author:     Lorenz Panny, TU Muenchen
    Author:     Jasmin Blanchette, TU Muenchen
    Copyright   2013

More new-style recursor sugar.
*)
blanchet@55571
     8
blanchet@55571
     9
structure BNF_LFP_Rec_Sugar_More : sig end =
blanchet@55571
    10
struct
blanchet@55571
    11
blanchet@55571
    12
open BNF_Util
blanchet@55571
    13
open BNF_Def
blanchet@55571
    14
open BNF_FP_Util
blanchet@55571
    15
open BNF_FP_Def_Sugar
blanchet@55571
    16
open BNF_FP_N2M_Sugar
blanchet@55571
    17
open BNF_LFP_Rec_Sugar
blanchet@55571
    18
blanchet@55575
    19
(* Theorems handed to register_lfp_rec_extension as nested_simps (see the Theory.setup at the end
   of this structure): id_def, split, comp_def, fst_conv and snd_conv. *)
val nested_simps =
  @{thms id_def split comp_def fst_conv snd_conv};
blanchet@55575
    20
blanchet@55571
    21
(* True iff fp_sugar_of finds sugar for the type name s in ctxt and that sugar's fp kind is
   Least_FP; false if no sugar is registered or it is of another kind. *)
fun is_new_datatype ctxt s =
  (case fp_sugar_of ctxt s of
    NONE => false
  | SOME ({fp, ...} : fp_sugar) => fp = Least_FP);
blanchet@55571
    23
blanchet@55863
    24
(* Builds the "basic LFP sugar" record from an fp_sugar. It copies T, fp_res_index, ctr_defs and
   ctr_sugar, exposes co_rec and co_rec_thms under the names recx and rec_thms, and adds the given
   C and fun_arg_Tsss. *)
fun basic_lfp_sugar_of C fun_arg_Tsss (fp_sugar : fp_sugar) =
  {T = #T fp_sugar, fp_res_index = #fp_res_index fp_sugar, C = C, fun_arg_Tsss = fun_arg_Tsss,
   ctr_defs = #ctr_defs fp_sugar, ctr_sugar = #ctr_sugar fp_sugar, recx = #co_rec fp_sugar,
   rec_thms = #co_rec_thms fp_sugar};
blanchet@55571
    28
blanchet@55772
    29
(* Calls nested_to_mutual_fps (for least fixpoints) on the given types and repackages its result.

   Returns, in this order:
   - the missing argument types;
   - perm0_kks;
   - one basic LFP sugar per returned fp_sugar;
   - the nested BNFs' map_id0 theorems (with id_def unfolded);
   - the nested BNFs' map_comp theorems;
   - the single induction theorem of the first fp_sugar;
   - whether lfp_sugar_thms is present;
   - the updated local theory. *)
fun get_basic_lfp_sugars bs arg_Ts callers callssss0 lthy0 =
  let
    val ((missing_arg_Ts, perm0_kks,
          fp_sugars as {nested_bnfs, co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)),
         lthy) =
      nested_to_mutual_fps Least_FP bs arg_Ts callers callssss0 lthy0;

    val Ts = map #T fp_sugars;
    (* C is the body type of each recursor. *)
    val Cs = map (body_type o fastype_of o #co_rec) fp_sugars;
    (* Association list from each placeholder X to its pair (T, C). *)
    val TCs_of_X = map #X fp_sugars ~~ (Ts ~~ Cs);

    (* A placeholder X expands to the two types [T, C]; any other non-Type stays as [U].
       Inside a type constructor, each argument's expansion is tupled. *)
    fun zip_recT U =
      (case U of
        Type (s, Us) => [Type (s, map (HOLogic.mk_tupleT o zip_recT) Us)]
      | _ =>
        (case AList.lookup (op =) TCs_of_X U of
          NONE => [U]
        | SOME (T, C) => [T, C]));

    val fun_arg_Tssss = map (map (map zip_recT) o #ctrXs_Tss) fp_sugars;

    val nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs;
    val nested_map_comps = map map_comp_of_bnf nested_bnfs;
    val basic_lfp_sugars = map3 basic_lfp_sugar_of Cs fun_arg_Tssss fp_sugars;
  in
    (missing_arg_Ts, perm0_kks, basic_lfp_sugars, nested_map_idents, nested_map_comps,
     induct_thm, is_some lfp_sugar_thms, lthy)
  end;
blanchet@55571
    56
blanchet@55571
    57
(* Raised by massage_map (inside massage_nested_rec_call) when a term is not recognized as a map
   function. The handlers there either fall back to massage_mutual_fun or report invalid_map on
   the carried term. *)
exception NOT_A_MAP of term;
blanchet@55571
    58
blanchet@55571
    59
local
  (* Common shape of the errors below: a fixed prefix followed by the quoted, pretty-printed
     term t. Never returns, so callers can use it at any result type. *)
  fun term_error prefix ctxt t = error (prefix ^ quote (Syntax.string_of_term ctxt t));
in

fun ill_formed_rec_call ctxt t = term_error "Ill-formed recursive call: " ctxt t;
fun invalid_map ctxt t = term_error "Invalid map function in " ctxt t;
fun unexpected_rec_call ctxt t = term_error "Unexpected recursive call: " ctxt t;

end;
blanchet@55571
    65
blanchet@55571
    66
(* Returns a function that rewrites a term t mentioning y into one mentioning y' instead.

   - Plain occurrences of y become a map of "fst" over y', built with build_map and fst_const.
   - Where has_call reports recursive calls, t must be one of:
       * a map function applied to y, or
       * a function applied to (y x1 ... xn) with call-free xi.
   - The map function's arguments are processed recursively. An argument that still contains a
     call, whose type is a product U1 * U2 with U1 = T, is handed to raw_massage_fun T U2.
   - Other shapes are rejected with ill_formed_rec_call, invalid_map or unexpected_rec_call. *)
fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' =
  let
    (* Fails (via unexpected_rec_call) if t contains a recursive call. *)
    fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else ();

    val typof = curry fastype_of1 bound_Ts;
    (* massage_no_call (U, T) is a map term from U to T built from fst_const. *)
    val massage_no_call = build_map ctxt (fst_const o fst);

    val yT = typof y;
    val yU = typof y';

    (* Term of y's type obtained from y'; used to replace leftover occurrences of y. *)
    fun y_of_y' () = massage_no_call (yU, yT) $ y';
    val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t);

    (* Massages a function argument of a map.
       - Compositions t1 o t2 are unfolded; t1 must be call-free and t2 is processed recursively.
       - A call-free t becomes t o (map from U to T).
       - A t with calls is passed to raw_massage_fun when U is a product whose first component
         is T; otherwise it is rejected. *)
    fun massage_mutual_fun U T (Const (@{const_name comp}, _) $ t1 $ t2) =
        mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2)
      | massage_mutual_fun U T t =
        if not (has_call t) then
          mk_comp bound_Ts (t, massage_no_call (U, T))
        else
          (case try HOLogic.dest_prodT U of
            SOME (U1, U2) => if U1 <> T then invalid_map ctxt t else raw_massage_fun T U2 t
          | NONE => invalid_map ctxt t);

    (* t must be recognized by dest_map for T's type constructor s. Its arguments are massaged
       componentwise, and the map constant is rebuilt with mk_map at the argument types of U.
       Raises NOT_A_MAP when t is not such a map (or U/T is not a Type). *)
    fun massage_map (Type (_, Us)) (Type (s, Ts)) t =
        (case try (dest_map ctxt s) t of
          NONE => raise NOT_A_MAP t
        | SOME (map0, fs) =>
          let
            val Type (_, ran_Ts) = range_type (typof t);
            val map' = mk_map (length fs) Us ran_Ts map0;
            val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs;
          in
            Term.list_comb (map', fs')
          end)
      | massage_map _ _ t = raise NOT_A_MAP t
    and massage_map_or_map_arg U T t =
      if T <> U then
        (massage_map U T t handle NOT_A_MAP _ => massage_mutual_fun U T t)
      else
        tap check_no_call t;

    fun massage_outer_call (t as t1 $ t2) =
        if not (has_call t) then
          elim_y t
        else if t2 = y then
          (massage_map yU yT (elim_y t1) $ y'
           handle NOT_A_MAP t' => invalid_map ctxt t')
        else
          let
            val (g, xs) = Term.strip_comb t2;
          in
            if g <> y then
              ill_formed_rec_call ctxt t
            else if exists has_call xs then
              unexpected_rec_call ctxt t2
            else
              (* t1 (y x1 ... xn) is handled as (t1 composed with y) x1 ... xn. *)
              Term.list_comb (massage_outer_call (mk_compN (length xs) bound_Ts (t1, y)), xs)
          end
      | massage_outer_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
  in
    massage_outer_call
  end;
blanchet@55571
   129
blanchet@55575
   130
(* Map-argument rewriter used by rewrite_nested_rec_call; pT is rec_type * res_type.

   - The outermost abstraction's variable is retyped to pT.
   - Its occurrences are wrapped in fst.
   - A recursive call, i.e. a Free whose name gets a nonnegative position from get_ctr_pos:
       * applied to that variable at its position, is replaced by snd of the variable;
       * standing at top level with exactly that many arguments, becomes snd with its arguments
         permuted by permute_args.
   - Any other recursive call raises PRIMREC. *)
fun rewrite_map_arg get_ctr_pos rec_type res_type =
  let
    val pT = HOLogic.mk_prodT (rec_type, res_type);

    (* d is SOME ~1 at the top level (before the outermost abstraction is entered) and SOME i
       while Bound i denotes that abstraction's variable. It becomes NONE once the traversal
       leaves the top level through the arguments of a non-call application. *)
    fun subst d (t as Bound d') = if d = SOME d' then fst_const pT $ t else t
      | subst d (Abs (v, T, b)) =
        let
          val T' = if d = SOME ~1 then pT else T;
        in
          Abs (v, T', subst (Option.map (fn i => i + 1) d) b)
        end
      | subst d t =
        let
          val (u, vs) = strip_comb t;
          val ctr_pos = the_default ~1 (try (get_ctr_pos o fst o dest_Free) u);
          val at_top = (d = SOME ~1);
          val nvs = length vs;
        in
          if ctr_pos < 0 then
            Term.list_comb (u, map (subst (if at_top then NONE else d)) vs)
          else if at_top andalso nvs = ctr_pos then
            Term.list_comb (permute_args ctr_pos (snd_const pT), vs)
          else if nvs > ctr_pos andalso is_some d andalso
              d = (case nth vs ctr_pos of Bound n => SOME n | _ => NONE) then
            Term.list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
          else
            raise PRIMREC ("recursive call not directly applied to constructor argument", [t])
        end
  in
    subst (SOME ~1)
  end;
blanchet@55575
   156
blanchet@55575
   157
(* massage_nested_rec_call, with rewrite_map_arg (parameterized by get_ctr_pos) as the handler for
   map arguments that contain recursive calls. *)
fun rewrite_nested_rec_call ctxt has_call get_ctr_pos bound_Ts y y' =
  massage_nested_rec_call ctxt has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y';
blanchet@55575
   159
blanchet@55571
   160
(* At theory setup, registers this structure's functions through register_lfp_rec_extension. *)
val _ =
  Theory.setup
    (register_lfp_rec_extension
      {is_new_datatype = is_new_datatype,
       nested_simps = nested_simps,
       get_basic_lfp_sugars = get_basic_lfp_sugars,
       rewrite_nested_rec_call = rewrite_nested_rec_call});
blanchet@55571
   163
blanchet@55571
   164
end;