src/HOL/Tools/Function/size.ML
author wenzelm
Thu Nov 19 14:46:33 2009 +0100 (2009-11-19)
changeset 33766 c679f05600cd
parent 33671 4b0f2599ed48
child 33968 f94fb13ecbb3
permissions -rw-r--r--
adapted Local_Theory.define -- eliminated odd thm kind;
haftmann@31775
     1
(*  Title:      HOL/Tools/Function/size.ML
haftmann@29495
     2
    Author:     Stefan Berghofer, Florian Haftmann & Alexander Krauss, TU Muenchen
haftmann@24710
     3
haftmann@24710
     4
Size functions for datatypes.
haftmann@24710
     5
*)
haftmann@24710
     6
haftmann@24710
     7
(* Interface of the size-function package for datatypes. *)
signature SIZE =
sig
  (* size theorems recorded for the datatype with the given name *)
  val size_thms: theory -> string -> thm list
  (* theory setup of this package *)
  val setup: theory -> theory
end;
haftmann@24710
    12
haftmann@24710
    13
structure Size: SIZE =
haftmann@24710
    14
struct
haftmann@24710
    15
haftmann@24710
    16
open DatatypeAux;
haftmann@24710
    17
wenzelm@33522
    18
(* Theory data: for each type name, a pair of a string (the name of its size
   constant, see the update in prove_size_thms) and a list of theorems. *)
structure SizeData = Theory_Data
(
  type T = (string * thm list) Symtab.table;
  val empty = Symtab.empty;
  val extend = I;
  (* the entry comparison passed to Symtab.merge accepts every pair of entries *)
  fun merge (tab1, tab2) = Symtab.merge (fn _ => true) (tab1, tab2);
);
haftmann@24710
    25
berghofe@25679
    26
(* Looks up the SizeData entry of a type name in the given theory. *)
fun lookup_size thy = Symtab.lookup (SizeData.get thy);
haftmann@24710
    27
haftmann@24710
    28
(* Builds the term "t1 + t2" with the plus constant at type nat => nat => nat. *)
fun plus (t1, t2) =
  let val natT = HOLogic.natT
  in Const ("HOL.plus_class.plus", natT --> natT --> natT) $ t1 $ t2 end;
haftmann@24710
    30
berghofe@25679
    31
(*
  size_of_type f g h T: optional size-function term for type T.

    f  maps a type constructor name to a ready-made size term (tried first)
    g  maps a type constructor name to the name of its size constant; the
       constant is applied to the size terms of the type arguments
    h  maps the name of a type variable to its size term

  size_of_type' is the total variant: where no size term is available it
  yields the constant-zero abstraction "%x. 0".

  Only Type and TFree are matched, exactly as in the original definition.
*)
fun size_of_type f g h (T as Type (s, Ts)) =
      let
        fun via_size_const size_name =
          list_comb (Const (size_name,
            map (fn U => U --> HOLogic.natT) Ts @ [T] ---> HOLogic.natT),
              map (size_of_type' f g h) Ts);
      in
        (case f s of
          NONE => Option.map via_size_const (g s)
        | some => some)
      end
  | size_of_type _ _ h (TFree (s, _)) = h s
and size_of_type' f g h T =
  the_default (Abs ("x", T, HOLogic.zero)) (size_of_type f g h T);
berghofe@25679
    44
berghofe@25679
    45
(* A DtType counts as polymorphic iff Datatype.get_info knows its name and one
   of its arguments is polymorphic; every other dtyp counts as polymorphic. *)
fun is_poly thy (DtType (name, dts)) =
      is_some (Datatype.get_info thy name) andalso exists (is_poly thy) dts
  | is_poly _ _ = true;
berghofe@25679
    50
berghofe@25679
    51
(* Constructor list stored in the datatype's own descriptor, i.e. the descr
   entry found under the index recorded in its info. *)
fun constrs_of thy name =
  let
    val dt_info = Datatype.the_info thy name;
    val SOME (_, _, constrs) = AList.lookup (op =) (#descr dt_info) (#index dt_info);
  in constrs end;
berghofe@25679
    56
berghofe@25679
    57
(* app args head: applies the term head to the argument list args. *)
fun app args head = list_comb (head, args);
haftmann@24710
    58
haftmann@31737
    59
(*
  Defines the size functions for a group of datatypes and proves their
  characteristic equations.

    info            datatype info record; the fields descr, alt_names, sorts,
                    rec_names, rec_rewrites and induct are used
    new_type_names  names of the types being handled; the first
                    (length new_type_names) entries of descr are treated as
                    the new types, the remaining ones as auxiliary types
    thy             theory to extend

  The resulting theory contains the "<name>_size" constants with their
  definitions, an instantiation of class size for the new types, the theorem
  collection "size", and SizeData entries that map each new type name to its
  size constant name and the "size" theorems.
*)
fun prove_size_thms (info : info) new_type_names thy =
  let
    val {descr, alt_names, sorts, rec_names, rec_rewrites, induct, ...} = info;
    val l = length new_type_names;
    val alt_names' = (case alt_names of
      NONE => replicate l NONE | SOME names => map SOME names);
    val descr' = List.take (descr, l);
    val recTs = get_rec_types descr sorts;
    val (recTs1, recTs2) = chop l recTs;
    val (_, (_, paramdts, _)) :: _ = descr;
    val paramTs = map (typ_of_dtyp descr sorts) paramdts;

    (* one free variable per type parameter, named "f" followed by the
       parameter name without its first character, of type T => nat *)
    val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
      map (fn T as TFree (s, _) =>
        let
          val name = "f" ^ implode (tl (explode s));
          val U = T --> HOLogic.natT
        in
          (((s, Free (name, U)), U), name)
        end) |> split_list |>> split_list;
    val param_size = AList.lookup op = param_size_fs;

    (* size theorems and size constant names already recorded in SizeData *)
    val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |>
      map_filter (Option.map snd o lookup_size thy) |> flat;
    val extra_size = Option.map fst o lookup_size thy;

    (* per new type: full name of the size constant, that constant applied to
       the parameter size variables, and the names of its two definitions *)
    val (((size_names, size_fns), def_names), def_names') =
      recTs1 ~~ alt_names' |>
      map (fn (T as Type (s, _), optname) =>
        let
          val s' = the_default (Long_Name.base_name s) optname ^ "_size";
          val s'' = Sign.full_bname thy s'
        in
          (s'',
           (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT),
              map snd param_size_fs),
            (s' ^ "_def", s' ^ "_overloaded_def")))
        end) |> split_list ||>> split_list ||>> split_list;
    val overloaded_size_fns = map HOLogic.size_const recTs1;

    (* instantiation for primrec combinator *)
    (* The result abstracts over the constructor arguments (types Ts) followed
       by k nat-typed binders, k being the number of recursive arguments.
       Recursive arguments contribute Bound i, counted from 0 and thus
       referring to the nat-typed binders; other arguments contribute their
       size term applied to Bound j (counted from k) when b holds and the
       argument is polymorphic. *)
    fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
      let
        val Ts = map (typ_of_dtyp descr sorts) cargs;
        val k = length (filter is_rec_type cargs);
        val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
          if is_rec_type dt then (Bound i :: us, i + 1, j + 1)
          else
            (if b andalso is_poly thy dt' then
               case size_of_type (K NONE) extra_size size_ofp T of
                 NONE => us | SOME sz => sz $ Bound j :: us
             else us, i, j + 1))
              (cargs ~~ cargs' ~~ Ts) ([], 0, k);
        (* zero if nothing contributes, otherwise the sum of the
           contributions plus Suc 0 *)
        val t =
          if null ts andalso (not b orelse not (exists (is_poly thy) cargs'))
          then HOLogic.zero
          else foldl1 plus (ts @ [HOLogic.Suc_zero])
      in
        fold_rev (fn T => fn t' => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT) t
      end;

    (* fs: arguments for the parameterised size functions; fs': arguments for
       the overloaded size function, where argument sizes are only counted
       for descr entries with index >= l *)
    val fs = maps (fn (_, (name, _, constrs)) =>
      map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr;
    val fs' = maps (fn (n, (name, _, constrs)) =>
      map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr;
    val fTs = map fastype_of fs;

    val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) =>
      Const (rec_name, fTs @ [T] ---> HOLogic.natT))
        (recTs ~~ rec_names));

    (* defines the free variable on the left-hand side of eq via
       Local_Theory.define and exports the defining theorem to the context of
       the underlying theory *)
    fun define_overloaded (def_name, eq) lthy =
      let
        val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
        val ((_, (_, thm)), lthy') = lthy
          |> Local_Theory.define ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs));
        val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy');
        val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
      in (thm', lthy') end;

    (* declare and define the size constants, then instantiate class size
       with the overloaded definitions *)
    val ((size_def_thms, size_def_thms'), thy') =
      thy
      |> Sign.add_consts_i (map (fn (s, T) =>
           (Binding.name (Long_Name.base_name s), param_size_fTs @ [T] ---> HOLogic.natT, NoSyn))
           (size_names ~~ recTs1))
      |> PureThy.add_defs false
        (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
           (map Binding.name def_names ~~ (size_fns ~~ rec_combs1)))
      ||> Theory_Target.instantiation
           (map (#1 o snd) descr', map dest_TFree paramTs, [HOLogic.class_size])
      ||>> fold_map define_overloaded
        (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
      ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
      ||> Local_Theory.exit_global;

    val ctxt = ProofContext.init thy';

    val simpset1 = HOL_basic_ss addsimps @{thm add_0} :: @{thm add_0_right} ::
      size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites;
    val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2);

    (* equation between the recursion combinator r applied to fs and the size
       term obtained from size_of_type for T, both applied to the variable p *)
    fun mk_unfolded_size_eq tab size_ofp fs (p as (_, T), r) =
      HOLogic.mk_eq (app fs r $ Free p,
        the (size_of_type tab extra_size size_ofp T) $ Free p);

    (* proves the unfolded equations for the auxiliary types recTs2 by
       induction; the first l conjuncts of the goal are just True *)
    fun prove_unfolded_size_eqs size_ofp fs =
      if null recTs2 then []
      else split_conj_thm (Skip_Proof.prove ctxt xs []
        (HOLogic.mk_Trueprop (mk_conj (replicate l HOLogic.true_const @
           map (mk_unfolded_size_eq (AList.lookup op =
               (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs)
             (xs ~~ recTs2 ~~ rec_combs2))))
        (fn _ => (indtac induct xs THEN_ALL_NEW asm_simp_tac simpset1) 1));

    val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs;
    val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs';

    (* characteristic equations for size functions *)
    (* p selects the constructor arguments whose size is counted *)
    fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
      let
        val Ts = map (typ_of_dtyp descr sorts) cargs;
        val tnames = Name.variant_list f_names (DatatypeProp.make_tnames Ts);
        val ts = map_filter (fn (sT as (_, T), dt) =>
          Option.map (fn sz => sz $ Free sT)
            (if p dt then size_of_type size_of extra_size size_ofp T
             else NONE)) (tnames ~~ Ts ~~ cargs)
      in
        HOLogic.mk_Trueprop (HOLogic.mk_eq
          (size_const $ list_comb (Const (cname, Ts ---> T),
             map2 (curry Free) tnames Ts),
           if null ts then HOLogic.zero
           else foldl1 plus (ts @ [HOLogic.Suc_zero])))
      end;

    val simpset2 = HOL_basic_ss addsimps
      rec_rewrites @ size_def_thms @ unfolded_size_eqs1;
    val simpset3 = HOL_basic_ss addsimps
      rec_rewrites @ size_def_thms' @ unfolded_size_eqs2;

    (* one equation per constructor of each new type, proved by simp_tac
       with the given simpset *)
    fun prove_size_eqs p size_fns size_ofp simpset =
      maps (fn (((_, (_, _, constrs)), size_const), T) =>
        map (fn constr => Drule.standard (Skip_Proof.prove ctxt [] []
          (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns))
             size_ofp size_const T constr)
          (fn _ => simp_tac simpset 1))) constrs)
        (descr' ~~ size_fns ~~ recTs1);

    (* equations for the parameterised size functions, followed by those for
       the overloaded size function *)
    val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @
      prove_size_eqs is_rec_type overloaded_size_fns (K NONE) simpset3;

    val ([size_thms], thy'') = PureThy.add_thmss
      [((Binding.name "size", size_eqns),
        [Simplifier.simp_add, Nitpick_Simps.add,
         Thm.declaration_attribute
             (fn thm => Context.mapping (Code.add_default_eqn thm) I)])] thy'

  in
    (* record size constant name and size theorems for every new type *)
    SizeData.map (fold (Symtab.update_new o apsnd (rpair size_thms))
      (new_type_names ~~ size_names)) thy''
  end;
haftmann@24710
   219
haftmann@31668
   220
(*
  Datatype interpretation: runs prove_size_thms for the given group of
  datatypes, under a naming path built from the "_"-joined alternative names
  (or base names) of the types.  The theory is returned unchanged if some
  constructor argument is a recursive dtyp for which strip_dtyp yields a
  non-empty first component.  The config argument is not used; the list of
  type names must be non-empty.
*)
fun add_size_thms config (new_type_names as name :: _) thy =
  let
    val info as {descr, alt_names, ...} = Datatype.the_info thy name;
    val base_names = the_default (map Long_Name.base_name new_type_names) alt_names;
    val prefix = Long_Name.map_base_name (K (space_implode "_" base_names)) name;

    fun blocks_size dt = is_rec_type dt andalso not (null (fst (strip_dtyp dt)));
    fun constr_blocks (_, cargs) = exists blocks_size cargs;
    fun entry_blocks (_, (_, _, constrs)) = exists constr_blocks constrs;
    val no_size = exists entry_blocks descr;
  in
    if no_size then thy
    else
      thy
      |> Sign.root_path
      |> Sign.add_path prefix
      |> Theory.checkpoint
      |> prove_size_thms info new_type_names
      |> Sign.restore_naming thy
  end;
haftmann@24710
   236
berghofe@25679
   237
(* Size theorems recorded for a type name; fails via "the" if there is no entry. *)
fun size_thms thy name = snd (the (lookup_size thy name));
haftmann@24710
   238
haftmann@31723
   239
(* Theory setup: passes add_size_thms to Datatype.interpretation. *)
val setup = Datatype.interpretation add_size_thms;
haftmann@24710
   240
blanchet@29866
   241
end;