src/HOL/Tools/Function/size.ML
changeset 58112 8081087096ad
parent 58111 82db9ad610b9
child 58113 ab6220d6cf70
equal deleted inserted replaced
58111:82db9ad610b9 58112:8081087096ad
     1 (*  Title:      HOL/Tools/Function/size.ML
       
     2     Author:     Stefan Berghofer, Florian Haftmann, TU Muenchen
       
     3 
       
     4 Size functions for datatypes.
       
     5 *)
       
     6 
       
     7 signature SIZE =
       
     8 sig
       
     9   val setup: theory -> theory
       
    10 end;
       
    11 
       
    12 structure Size: SIZE =
       
    13 struct
       
    14 
       
    15 fun plus (t1, t2) = Const (@{const_name Groups.plus},
       
    16   HOLogic.natT --> HOLogic.natT --> HOLogic.natT) $ t1 $ t2;
       
    17 
       
    18 fun size_of_type f g h (T as Type (s, Ts)) =
       
    19       (case f s of
       
    20          SOME t => SOME t
       
    21        | NONE => (case g s of
       
    22            SOME size_name =>
       
    23              SOME (list_comb (Const (size_name,
       
    24                map (fn U => U --> HOLogic.natT) Ts @ [T] ---> HOLogic.natT),
       
    25                  map (size_of_type' f g h) Ts))
       
    26          | NONE => NONE))
       
    27   | size_of_type _ _ h (TFree (s, _)) = h s
       
    28 and size_of_type' f g h T = (case size_of_type f g h T of
       
    29       NONE => Abs ("x", T, HOLogic.zero)
       
    30     | SOME t => t);
       
    31 
       
    32 fun is_poly thy (Datatype_Aux.DtType (name, dts)) =
       
    33       is_some (BNF_LFP_Size.lookup_size_global thy name) andalso exists (is_poly thy) dts
       
    34   | is_poly _ _ = true;
       
    35 
       
    36 fun constrs_of thy name =
       
    37   let
       
    38     val {descr, index, ...} = Datatype_Data.the_info thy name
       
    39     val SOME (_, _, constrs) = AList.lookup op = descr index
       
    40   in constrs end;
       
    41 
       
    42 val app = curry (list_comb o swap);
       
    43 
       
(*Define the size functions for a (possibly mutually recursive / nested)
  datatype and prove their characteristic equations.
  For each new type two functions are introduced:
    - "size_<T>", taking one size function per type parameter
      (definitions named "size_<T>_def");
    - the overloaded class constant "size" via an instantiation of class
      size (definitions named "size_<T>_overloaded_def").
  Both are defined through the datatype's primitive recursion combinators.
  The resulting equations are noted as "size" (simp, nitpick_simp, code) and
  registered with BNF_LFP_Size.
  If every type constructor already has an arity for class size, the theory
  is returned unchanged.*)
fun prove_size_thms (info : Datatype_Aux.info) new_type_names thy =
  let
    val {descr, rec_names, rec_rewrites, induct, ...} = info;
    (*the first l descriptor entries describe the new types themselves*)
    val l = length new_type_names;
    val descr' = List.take (descr, l);
    val tycos = map (#1 o snd) descr';
  in
    if forall (fn tyco => can (Sign.arity_sorts thy tyco) [HOLogic.class_size]) tycos then
      (* nothing to do -- the "size" function is already defined *)
      thy
    else
      let
        (*recTs1: the new types; recTs2: the remaining recursive types of
          descr (presumably auxiliary types arising from nesting)*)
        val recTs = Datatype_Aux.get_rec_types descr;
        val (recTs1, recTs2) = chop l recTs;
        (*type parameters are read off the first descriptor entry; the
          pattern below assumes they are all TFrees*)
        val (_, (_, paramdts, _)) :: _ = descr;
        val paramTs = map (Datatype_Aux.typ_of_dtyp descr) paramdts;
        (*for each type parameter 'a a free variable "fa" :: 'a => nat,
          used as the size function for that parameter*)
        val ((param_size_fs, param_size_fTs), f_names) = paramTs |>
          map (fn T as TFree (s, _) =>
            let
              val name = "f" ^ unprefix "'" s;
              val U = T --> HOLogic.natT
            in
              (((s, Free (name, U)), U), name)
            end) |> split_list |>> split_list;
        val param_size = AList.lookup op = param_size_fs;

        (*theorems (and size constant names) registered with BNF_LFP_Size
          for type constructors occurring in descr -- presumably their size
          equations; used as extra simp rules below*)
        val extra_rewrites = descr |> map (#1 o snd) |> distinct op = |>
          map_filter (Option.map (fst o snd) o BNF_LFP_Size.lookup_size_global thy) |> flat;
        val extra_size = Option.map fst o BNF_LFP_Size.lookup_size_global thy;

        (*per new type: full name "size_<T>", that constant applied to the
          parameter size functions, and the two definition names*)
        val (((size_names, size_fns), def_names), def_names') =
          recTs1 |> map (fn T as Type (s, _) =>
            let
              val s' = "size_" ^ Long_Name.base_name s;
              val s'' = Sign.full_bname thy s';
            in
              (s'',
               (list_comb (Const (s'', param_size_fTs @ [T] ---> HOLogic.natT),
                  map snd param_size_fs),
                (s' ^ "_def", s' ^ "_overloaded_def")))
            end) |> split_list ||>> split_list ||>> split_list;
        val overloaded_size_fns = map HOLogic.size_const recTs1;

        (* instantiation for primrec combinator *)
        (*Builds the combinator argument for one constructor: an abstraction
          over the constructor's argument types Ts followed by k nat
          arguments (the results for the k recursive arguments).  The body
          sums the recursive results and, when b holds, the sizes of
          non-recursive arguments whose generic form (cargs') is_poly, then
          adds Suc 0; it is 0 when nothing is summed and either b is false
          or no argument is polymorphic.
          Bound indices: i (from 0) addresses the nat results, j (from k)
          addresses the constructor arguments, both counted from the right.*)
        fun size_of_constr b size_ofp ((_, cargs), (_, cargs')) =
          let
            val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
            val k = length (filter Datatype_Aux.is_rec_type cargs);
            val (ts, _, _) = fold_rev (fn ((dt, dt'), T) => fn (us, i, j) =>
              if Datatype_Aux.is_rec_type dt then (Bound i :: us, i + 1, j + 1)
              else
                (if b andalso is_poly thy dt' then
                   case size_of_type (K NONE) extra_size size_ofp T of
                     NONE => us | SOME sz => sz $ Bound j :: us
                 else us, i, j + 1))
                  (cargs ~~ cargs' ~~ Ts) ([], 0, k);
            val t =
              if null ts andalso (not b orelse not (exists (is_poly thy) cargs'))
              then HOLogic.zero
              else foldl1 plus (ts @ [HOLogic.Suc_zero])
          in
            fold_rev (fn T => fn t' => Abs ("x", T, t')) (Ts @ replicate k HOLogic.natT) t
          end;

        (*combinator arguments for "size_<T>": non-recursive polymorphic
          arguments are counted for all descriptor entries*)
        val fs = maps (fn (_, (name, _, constrs)) =>
          map (size_of_constr true param_size) (constrs ~~ constrs_of thy name)) descr;
        (*combinator arguments for overloaded "size": non-recursive
          arguments are only counted for entries with index n >= l; no
          parameter size functions are available*)
        val fs' = maps (fn (n, (name, _, constrs)) =>
          map (size_of_constr (l <= n) (K NONE)) (constrs ~~ constrs_of thy name)) descr;
        val fTs = map fastype_of fs;

        (*recursion combinators instantiated at result type nat, split into
          those of the new types and the rest*)
        val (rec_combs1, rec_combs2) = chop l (map (fn (T, rec_name) =>
          Const (rec_name, fTs @ [T] ---> HOLogic.natT))
            (recTs ~~ rec_names));

        (*Define the constant on the lhs of eq (a Free after type checking)
          as its rhs in the local theory, record the definition as an
          equational Spec_Rule, and export the definitional theorem to the
          global theory context.*)
        fun define_overloaded (def_name, eq) lthy =
          let
            val (Free (c, _), rhs) = (Logic.dest_equals o Syntax.check_term lthy) eq;
            val (thm, lthy') = lthy
              |> Local_Theory.define ((Binding.name c, NoSyn), ((Binding.name def_name, []), rhs))
              |-> (fn (t, (_, thm)) => Spec_Rules.add Spec_Rules.Equational ([t], [thm]) #> pair thm);
            val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy');
            val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
          in (thm', lthy') end;

        (*1. declare "size_<T>" constants and define them via the rec
             combinators applied to fs;
          2. instantiate class size for the new type constructors, defining
             overloaded "size" via the rec combinators applied to fs';
          3. discharge the instance with intro_classes and exit back to the
             global theory*)
        val ((size_def_thms, size_def_thms'), thy') =
          thy
          |> Sign.add_consts (map (fn (s, T) => (Binding.name (Long_Name.base_name s),
              param_size_fTs @ [T] ---> HOLogic.natT, NoSyn))
            (size_names ~~ recTs1))
          |> Global_Theory.add_defs false
            (map (Thm.no_attributes o apsnd (Logic.mk_equals o apsnd (app fs)))
               (map Binding.name def_names ~~ (size_fns ~~ rec_combs1)))
          ||> Class.instantiation (tycos, map dest_TFree paramTs, [HOLogic.class_size])
          ||>> fold_map define_overloaded
            (def_names' ~~ map Logic.mk_equals (overloaded_size_fns ~~ map (app fs') rec_combs1))
          ||> Class.prove_instantiation_instance (K (Class.intro_classes_tac []))
          ||> Local_Theory.exit_global;

        val ctxt = Proof_Context.init_global thy';

        val simpset1 =
          put_simpset HOL_basic_ss ctxt addsimps @{thm Nat.add_0} :: @{thm Nat.add_0_right} ::
            size_def_thms @ size_def_thms' @ rec_rewrites @ extra_rewrites;
        (*one variable name x1, x2, ... per type in recTs2*)
        val xs = map (fn i => "x" ^ string_of_int i) (1 upto length recTs2);

        (*"rec_comb fs x = <size function for T> x" for a variable x :: T;
          tab supplies the known size terms of the new types*)
        fun mk_unfolded_size_eq tab size_ofp fs (p as (_, T), r) =
          HOLogic.mk_eq (app fs r $ Free p,
            the (size_of_type tab extra_size size_ofp T) $ Free p);

        (*For every type in recTs2, prove by induction that its rec
          combinator (applied to fs) coincides with the size function
          obtained from size_of_type.  The l leading "True" conjuncts stand
          for the new types (presumably to line up with the mutual induction
          rule -- TODO confirm).  Returns [] when recTs2 is empty.*)
        fun prove_unfolded_size_eqs size_ofp fs =
          if null recTs2 then []
          else Datatype_Aux.split_conj_thm (Goal.prove_sorry ctxt xs []
            (HOLogic.mk_Trueprop (Datatype_Aux.mk_conj (replicate l @{term True} @
               map (mk_unfolded_size_eq (AList.lookup op =
                   (new_type_names ~~ map (app fs) rec_combs1)) size_ofp fs)
                 (xs ~~ recTs2 ~~ rec_combs2))))
            (fn _ => (Datatype_Aux.ind_tac induct xs THEN_ALL_NEW asm_simp_tac simpset1) 1));

        val unfolded_size_eqs1 = prove_unfolded_size_eqs param_size fs;
        val unfolded_size_eqs2 = prove_unfolded_size_eqs (K NONE) fs';

        (* characteristic equations for size functions *)
        (*Goal "size_const (cname x1 ... xn) = s1 + ... + sm + Suc 0", where
          the s_i are the sizes of the arguments whose dtyp satisfies p and
          for which size_of_type yields a size term; the rhs is 0 when there
          are no such arguments.*)
        fun gen_mk_size_eq p size_of size_ofp size_const T (cname, cargs) =
          let
            val Ts = map (Datatype_Aux.typ_of_dtyp descr) cargs;
            val tnames = Name.variant_list f_names (Datatype_Prop.make_tnames Ts);
            val ts = map_filter (fn (sT as (_, T), dt) =>
              Option.map (fn sz => sz $ Free sT)
                (if p dt then size_of_type size_of extra_size size_ofp T
                 else NONE)) (tnames ~~ Ts ~~ cargs)
          in
            HOLogic.mk_Trueprop (HOLogic.mk_eq
              (size_const $ list_comb (Const (cname, Ts ---> T),
                 map2 (curry Free) tnames Ts),
               if null ts then HOLogic.zero
               else foldl1 plus (ts @ [HOLogic.Suc_zero])))
          end;

        (*simpset2 proves the "size_<T>" equations, simpset3 the overloaded ones*)
        val simpset2 =
          put_simpset HOL_basic_ss ctxt
            addsimps (rec_rewrites @ size_def_thms @ unfolded_size_eqs1);
        val simpset3 =
          put_simpset HOL_basic_ss ctxt
            addsimps (rec_rewrites @ size_def_thms' @ unfolded_size_eqs2);

        (*one characteristic equation per constructor of each new type,
          each proved by simp_tac with the given simpset*)
        fun prove_size_eqs p size_fns size_ofp simpset =
          maps (fn (((_, (_, _, constrs)), size_const), T) =>
            map (fn constr => Drule.export_without_context (Goal.prove_sorry ctxt [] []
              (gen_mk_size_eq p (AList.lookup op = (new_type_names ~~ size_fns))
                 size_ofp size_const T constr)
              (fn _ => simp_tac simpset 1))) constrs)
            (descr' ~~ size_fns ~~ recTs1);

        (*"size_<T>" counts polymorphic arguments; overloaded "size" counts
          only recursive ones*)
        val size_eqns = prove_size_eqs (is_poly thy') size_fns param_size simpset2 @
          prove_size_eqs Datatype_Aux.is_rec_type overloaded_size_fns (K NONE) simpset3;

        (*note all equations as "size" with simp, nitpick_simp and
          default-code-equation attributes*)
        val ([(_, size_thms)], thy'') = thy'
          |> Global_Theory.note_thmss ""
            [((Binding.name "size",
                [Simplifier.simp_add, Named_Theorems.add @{named_theorems nitpick_simp},
                 Thm.declaration_attribute (fn thm =>
                   Context.mapping (Code.add_default_eqn thm) I)]),
              [(size_eqns, [])])];

      in
        (*register each new type's "size_<T>" constant and the size theorems*)
        fold2 (fn new_type_name => fn size_name =>
            BNF_LFP_Size.register_size_global new_type_name size_name size_thms [])
          new_type_names size_names thy''
      end
  end;
       
   214 
       
   215 fun add_size_thms _ (new_type_names as name :: _) thy =
       
   216   let
       
   217     val info as {descr, ...} = Datatype_Data.the_info thy name;
       
   218     val prefix = space_implode "_" (map Long_Name.base_name new_type_names);
       
   219     val no_size = exists (fn (_, (_, _, constrs)) => exists (fn (_, cargs) => exists (fn dt =>
       
   220       Datatype_Aux.is_rec_type dt andalso
       
   221         not (null (fst (Datatype_Aux.strip_dtyp dt)))) cargs) constrs) descr
       
   222   in
       
   223     if no_size then thy
       
   224     else
       
   225       thy
       
   226       |> Sign.add_path prefix
       
   227       |> prove_size_thms info new_type_names
       
   228       |> Sign.restore_naming thy
       
   229   end;
       
   230 
       
   231 val setup = Datatype_Data.interpretation add_size_thms;
       
   232 
       
   233 end;