src/HOL/Tools/Function/size.ML
author wenzelm
Sun Nov 08 18:43:42 2009 +0100 (2009-11-08)
changeset 33522 737589bb9bb8
parent 33339 d41f77196338
child 33553 35f2b30593a8
permissions -rw-r--r--
adapted Theory_Data;
tuned;
     1 (*  Title:      HOL/Tools/Function/size.ML
     2     Author:     Stefan Berghofer, Florian Haftmann & Alexander Krauss, TU Muenchen
     3 
     4 Size functions for datatypes.
     5 *)
     6 
     7 signature SIZE =
     8 sig
     9   val size_thms: theory -> string -> thm list
    10   val setup: theory -> theory
    11 end;
    12 
    13 structure Size: SIZE =
    14 struct
    15 
    16 open DatatypeAux;
    17 
    18 structure SizeData = Theory_Data
    19 (
    20   type T = (string * thm list) Symtab.table;
    21   val empty = Symtab.empty;
    22   val extend = I
    23   fun merge data = Symtab.merge (K true) data;
    24 );
    25 
    26 val lookup_size = SizeData.get #> Symtab.lookup;
    27 
    28 fun plus (t1, t2) = Const ("HOL.plus_class.plus",
    29   HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2;
    30 
    31 fun size_of_type f g h (T as Type (s, Ts)) =
    32       (case f s of
    33          SOME t => SOME t
    34        | NONE => (case g s of
    35            SOME size_name =>
    36              SOME (list_comb (Const (size_name,
    37                map (fn U => U --> HOLogic.natT) Ts @ [T] ---> HOLogic.natT),
    38                  map (size_of_type' f g h) Ts))
    39          | NONE => NONE))
    40   | size_of_type f g h (TFree (s, _)) = h s
    41 and size_of_type' f g h T = (case size_of_type f g h T of
    42       NONE => Abs ("x", T, HOLogic.zero)
    43     | SOME t => t);
    44 
    45 fun is_poly thy (DtType (name, dts)) =
    46       (case Datatype.get_info thy name of
    47          NONE => false
    48        | SOME _ => exists (is_poly thy) dts)
    49   | is_poly _ _ = true;
    50 
    51 fun constrs_of thy name =
    52   let
    53     val {descr, index, ...} = Datatype.the_info thy name
    54     val SOME (_, _, constrs) = AList.lookup op = descr index
    55   in constrs end;
    56 
    57 val app = curry (list_comb o swap);
    58 
       (* prove_size_thms info new_type_names thy: handles the datatype group
          described by info. Its first l descriptor entries are assumed to
          correspond, in order, to the l names in new_type_names.
          - For each new type T, declares a constant <base>_size, where <base>
            is T's alternative name if given and its base name otherwise.
            The constant takes one size function per type parameter before
            the T argument and is defined via T's recursion combinator.
          - Instantiates class size for the new type constructors, defining
            the overloaded size via the same combinators.
          - Proves the characteristic equation of every constructor for both,
            and stores them as "size" with the simp, Nitpick_Simps and
            default-code-equation attributes.
          - Records (size constant name, "size" theorems) in SizeData for
            each name in new_type_names. *)
    59 fun prove_size_thms (info : info) new_type_names thy =
    60   let
    61     val {descr, alt_names, sorts, rec_names, rec_rewrites, induct, ...} = info;
    62     val l = length new_type_names;
           (* One optional alternative base name per new type. *)
    63     val alt_names' = (case alt_names of
    64       NONE => replicate l NONE | SOME names => map SOME names);
           (* Split descriptor entries, recursion-combinator names and recursive
              types into the l new types (suffix 1, descr') and the remaining
              entries (suffix 2). The remaining entries are presumably
              auxiliary types introduced by nested recursion; confirm.
              NOTE(review): rec_names1 and rec_names2 are not used below. *)
    65     val descr' = List.take (descr, l);
    66     val (rec_names1, rec_names2) = chop l rec_names;
    67     val recTs = get_rec_types descr sorts;
    68     val (recTs1, recTs2) = chop l recTs;
           (* Type parameters of the group, read off the first descriptor entry. *)
    69     val (_, (_, paramdts, _)) :: _ = descr;
    70     val paramTs = map (typ_of_dtyp descr sorts) paramdts;
           (* For each type parameter s, a free variable of type s => nat
              standing for its size function. It is named "f" followed by s
              without its first character (normally the leading apostrophe).
              Every parameter is expected to be a TFree. *)
    71     val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
    72       map (fn T as TFree (s, _) =>
    73         let
    74           val name = "f" ^ implode (tl (explode s));
    75           val U = T --> HOLogic.natT
    76         in
    77           (((s, Free (name, U)), U), name)
    78         end) |> split_list |>> split_list;
           (* param_size s: the size-function variable for type parameter s, if any. *)
    79     val param_size = AList.lookup op = param_size_fs;
    80 
           (* Size theorems of already-registered datatypes among the descriptor
              entries (used as extra simp rules), and their size constants
              looked up by type name. *)
    81     val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |>
    82       map_filter (Option.map snd o lookup_size thy) |> flat;
    83     val extra_size = Option.map fst o lookup_size thy;
    84 
           (* Per new type: full name of its size constant, that constant
              applied to the parameter size functions, and the names of its
              plain and overloaded definitions. *)
    85     val (((size_names, size_fns), def_names), def_names') =
    86       recTs1 ~~ alt_names' |>
    87       map (fn (T as Type (s, _), optname) =>
    88         let
    89           val s' = the_default (Long_Name.base_name s) optname ^ "_size";
    90           val s'' = Sign.full_bname thy s'
    91         in
    92           (s'',
    93            (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT),
    94               map snd param_size_fs),
    95             (s' ^ "_def", s' ^ "_overloaded_def")))
    96         end) |> split_list ||>> split_list ||>> split_list;
           (* The overloaded size (class size) at each new type. *)
    97     val overloaded_size_fns = map HOLogic.size_const recTs1;
    98 
    99     (* instantiation for primrec combinator *)
           (* size_of_constr b size_ofp handles one constructor. Its inputs are
              the argument dtyps cargs in this group and cargs' as recorded in
              the registered info of the constructor's datatype. It returns the
              combinator argument
                %x1 ... xn r1 ... rk. body
              where the xi have the argument types Ts and the rj (of type nat)
              are the results for the k recursive arguments.
              body sums the rj. When b holds, it also adds the sizes of those
              non-recursive arguments whose cargs' entry satisfies is_poly and
              for which size_of_type yields a size function. It then adds 1.
              body is 0 instead if there are no summands and either b is false
              or no cargs' entry satisfies is_poly. *)
   100     fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
   101       let
   102         val Ts = map (typ_of_dtyp descr sorts) cargs;
   103         val k = length (filter is_rec_type cargs);
               (* Arguments are visited right to left. i indexes the innermost
                  binders (the recursive results). j starts at k and indexes
                  the constructor-argument binders outside them. *)
   104         val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
   105           if is_rec_type dt then (Bound i :: us, i + 1, j + 1)
   106           else
   107             (if b andalso is_poly thy dt' then
   108                case size_of_type (K NONE) extra_size size_ofp T of
   109                  NONE => us | SOME sz => sz $ Bound j :: us
   110              else us, i, j + 1))
   111               (cargs ~~ cargs' ~~ Ts) ([], 0, k);
   112         val t =
   113           if null ts andalso (not b orelse not (exists (is_poly thy) cargs'))
   114           then HOLogic.zero
   115           else foldl1 plus (ts @ [HOLogic.Suc_zero])
   116       in
   117         fold_rev (fn T => fn t' => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT) t
   118       end;
   119 
           (* Combinator arguments for the parameterised size functions: every
              entry measures non-recursive arguments (b = true), and type
              parameters are measured via param_size. *)
   120     val fs = maps (fn (_, (name, _, constrs)) =>
   121       map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr;
           (* Combinator arguments for the overloaded size: non-recursive
              arguments are measured only for descriptor entries with index
              n >= l, and type variables get no size function (K NONE). *)
   122     val fs' = maps (fn (n, (name, _, constrs)) =>
   123       map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr;
   124     val fTs = map fastype_of fs;
   125 
           (* Recursion combinators of all entries with result type nat, split
              into the new types (rec_combs1) and the rest (rec_combs2). *)
   126     val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) =>
   127       Const (rec_name, fTs @ [T] ---> HOLogic.natT))
   128         (recTs ~~ rec_names));
   129 
           (* Type-checks eq in lthy, then defines the constant named by the
              Free on its left-hand side as its right-hand side, with the
              definition named def_name. Returns the definition exported to
              the plain theory context, together with the updated local
              theory. *)
   130     fun define_overloaded (def_name, eq) lthy =
   131       let
   132         val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
   133         val ((_, (_, thm)), lthy') = lthy |> LocalTheory.define Thm.definitionK
   134           ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs));
   135         val ctxt_thy = ProofContext.init (ProofContext.theory_of lthy');
   136         val thm' = singleton (ProofContext.export lthy' ctxt_thy) thm;
   137       in (thm', lthy') end;
   138 
           (* Declare the size constants and define them through the
              combinators applied to fs. Then instantiate class size for the
              new type constructors, define the overloaded size through fs',
              prove the instance with intro_classes_tac, and exit to the
              global theory. *)
   139     val ((size_def_thms, size_def_thms'), thy') =
   140       thy
   141       |> Sign.add_consts_i (map (fn (s, T) =>
   142            (Binding.name (Long_Name.base_name s), param_size_fTs @ [T] ---> HOLogic.natT, NoSyn))
   143            (size_names ~~ recTs1))
   144       |> PureThy.add_defs false
   145         (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
   146            (map Binding.name def_names ~~ (size_fns ~~ rec_combs1)))
   147       ||> TheoryTarget.instantiation
   148            (map (#1 o snd) descr', map dest_TFree paramTs, [HOLogic.class_size])
   149       ||>> fold_map define_overloaded
   150         (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
   151       ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
   152       ||> LocalTheory.exit_global;
   153 
   154     val ctxt = ProofContext.init thy';
   155 
           (* Simp rules for the induction below: add_0, add_0_right, both sets
              of definitions, the recursion equations and extra_rewrites. *)
   156     val simpset1 = HOL_basic_ss addsimps @{thm add_0} :: @{thm add_0_right} ::
   157       size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites;
           (* Variable names x1, x2, ..., one per remaining (non-new) type. *)
   158     val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2);
   159 
           (* Equation: combinator r applied to fs, at Free p, equals the size
              of Free p computed by size_of_type. tab yields the applied size
              function of each new type. *)
   160     fun mk_unfolded_size_eq tab size_ofp fs (p as (x, T), r) =
   161       HOLogic.mk_eq (app fs r $ Free p,
   162         the (size_of_type tab extra_size size_ofp T) $ Free p);
   163 
           (* Proves these equations for all remaining types in one induction
              over induct, with a True conjunct for each of the l new types,
              simplifying with simpset1. The conjunction is then split into
              single theorems. Returns [] when there are no remaining types. *)
   164     fun prove_unfolded_size_eqs size_ofp fs =
   165       if null recTs2 then []
   166       else split_conj_thm (Skip_Proof.prove ctxt xs []
   167         (HOLogic.mk_Trueprop (mk_conj (replicate l HOLogic.true_const @
   168            map (mk_unfolded_size_eq (AList.lookup op =
   169                (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs)
   170              (xs ~~ recTs2 ~~ rec_combs2))))
   171         (fn _ => (indtac induct xs THEN_ALL_NEW asm_simp_tac simpset1) 1));
   172 
           (* The unfolded equations for the parameterised (1) and the
              overloaded (2) size functions. *)
   173     val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs;
   174     val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs';
   175 
   176     (* characteristic equations for size functions *)
           (* Statement: size_const (cname x1 ... xn) = s1 + ... + sm + 1. The si
              are the sizes of the arguments whose dtyp satisfies p and for
              which size_of_type yields a size function. The right-hand side
              is 0 if there are no such arguments. Variable names are chosen
              to avoid f_names. *)
   177     fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
   178       let
   179         val Ts = map (typ_of_dtyp descr sorts) cargs;
   180         val tnames = Name.variant_list f_names (DatatypeProp.make_tnames Ts);
   181         val ts = map_filter (fn (sT as (s, T), dt) =>
   182           Option.map (fn sz => sz $ Free sT)
   183             (if p dt then size_of_type size_of extra_size size_ofp T
   184              else NONE)) (tnames ~~ Ts ~~ cargs)
   185       in
   186         HOLogic.mk_Trueprop (HOLogic.mk_eq
   187           (size_const $ list_comb (Const (cname, Ts ---> T),
   188              map2 (curry Free) tnames Ts),
   189            if null ts then HOLogic.zero
   190            else foldl1 plus (ts @ [HOLogic.Suc_zero])))
   191       end;
   192 
           (* Simp rules for the parameterised (simpset2) and the overloaded
              (simpset3) size equations. *)
   193     val simpset2 = HOL_basic_ss addsimps
   194       rec_rewrites @ size_def_thms @ unfolded_size_eqs1;
   195     val simpset3 = HOL_basic_ss addsimps
   196       rec_rewrites @ size_def_thms' @ unfolded_size_eqs2;
   197 
           (* Proves the equation of every constructor of every new type by
              simp_tac with simpset, and passes each result through
              Drule.standard. *)
   198     fun prove_size_eqs p size_fns size_ofp simpset =
   199       maps (fn (((_, (_, _, constrs)), size_const), T) =>
   200         map (fn constr => Drule.standard (Skip_Proof.prove ctxt [] []
   201           (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns))
   202              size_ofp size_const T constr)
   203           (fn _ => simp_tac simpset 1))) constrs)
   204         (descr' ~~ size_fns ~~ recTs1);
   205 
           (* The parameterised size counts arguments satisfying is_poly; the
              overloaded size counts only recursive arguments. *)
   206     val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @
   207       prove_size_eqs is_rec_type overloaded_size_fns (K NONE) simpset3;
   208 
           (* Store all equations as "size". The declaration attribute adds each
              one as a default code equation on theories and does nothing on
              proof contexts. *)
   209     val ([size_thms], thy'') =  PureThy.add_thmss
   210       [((Binding.name "size", size_eqns),
   211         [Simplifier.simp_add, Nitpick_Simps.add,
   212          Thm.declaration_attribute
   213              (fn thm => Context.mapping (Code.add_default_eqn thm) I)])] thy'
   214 
   215   in
           (* Record (size constant name, all "size" theorems) under each new
              type name. Symtab.update_new presumably fails on an
              already-registered name; confirm. *)
   216     SizeData.map (fold (Symtab.update_new o apsnd (rpair size_thms))
   217       (new_type_names ~~ size_names)) thy''
   218   end;
   219 
   220 fun add_size_thms config (new_type_names as name :: _) thy =
   221   let
   222     val info as {descr, alt_names, ...} = Datatype.the_info thy name;
   223     val prefix = Long_Name.map_base_name (K (space_implode "_"
   224       (the_default (map Long_Name.base_name new_type_names) alt_names))) name;
   225     val no_size = exists (fn (_, (_, _, constrs)) => exists (fn (_, cargs) => exists (fn dt =>
   226       is_rec_type dt andalso not (null (fst (strip_dtyp dt)))) cargs) constrs) descr
   227   in if no_size then thy
   228     else
   229       thy
   230       |> Sign.root_path
   231       |> Sign.add_path prefix
   232       |> Theory.checkpoint
   233       |> prove_size_thms info new_type_names
   234       |> Sign.restore_naming thy
   235   end;
   236 
   237 val size_thms = snd oo (the oo lookup_size);
   238 
       (* Theory setup: registers add_size_thms with the datatype package
          through Datatype.interpretation. *)
   239 val setup = Datatype.interpretation add_size_thms;
   240 
   241 end;