src/HOLCF/Tools/Domain/domain_take_proofs.ML
changeset 35514 a2cfa413eaab
child 35515 d631dc53ede0
equal deleted inserted replaced
35513:89eddccbb93d 35514:a2cfa413eaab
       
     1 (*  Title:      HOLCF/Tools/Domain/domain_take_proofs.ML
       
     2     Author:     Brian Huffman
       
     3 
       
     4 Defines take functions for the given domain equation
       
     5 and proves related theorems.
       
     6 *)
       
     7 
       
     8 signature DOMAIN_TAKE_PROOFS =
       
     9 sig
       
         (* Record describing one domain isomorphism: the abstract type absT, the
            representation type repT, the constants abs_const and rep_const, and
            the theorems abs_inverse and rep_inverse. *)
       
    10   type iso_info =
       
    11     {
       
    12       absT : typ,
       
    13       repT : typ,
       
    14       abs_const : term,
       
    15       rep_const : term,
       
    16       abs_inverse : thm,
       
    17       rep_inverse : thm
       
    18     }
       
    19 
       
         (* Takes one (binding, iso_info) pair per domain, declares a constant
            "<name>_take" for each binding and proves the theorems listed in the
            result record.  Returns that record together with the updated theory;
            every list in the record has one entry per input pair. *)
       
    20   val define_take_functions :
       
    21     (binding * iso_info) list -> theory ->
       
    22     { take_consts : term list,
       
    23       take_defs : thm list,
       
    24       chain_take_thms : thm list,
       
    25       take_0_thms : thm list,
       
    26       take_Suc_thms : thm list,
       
    27       deflation_take_thms : thm list
       
    28     } * theory
       
    29 
       
         (* Builds a map term for a type: a type found in the (typ * term) list
            is replaced by its paired term, a type constructor with a registered
            map constant is handled recursively, and anything else becomes an
            identity term. *)
       
    30   val map_of_typ :
       
    31     theory -> (typ * term) list -> typ -> term
       
    32 
       
         (* Argument is (type name, map constant name, deflation theorem): the
            name pair goes into the map table and the theorem into the list of
            deflation theorems. *)
       
    33   val add_map_function :
       
    34     (string * string * thm) -> theory -> theory
       
    35 
       
         (* Accessors for the two pieces of theory data filled by
            add_map_function. *)
       
    36   val get_map_tab : theory -> string Symtab.table
       
    37   val get_deflation_thms : theory -> thm list
       
    38 end;
       
    39 
       
    40 structure Domain_Take_Proofs : DOMAIN_TAKE_PROOFS =
       
    41 struct
       
    42 
       
    43 type iso_info =
       
         (* Implementation of the iso_info type required by DOMAIN_TAKE_PROOFS.
            absT / repT: the two types of one domain equation;
            abs_const / rep_const: the constants composed around a map body by
            define_take_functions;
            abs_inverse / rep_inverse: the theorems passed on by
            deflation_abs_rep. *)
       
    44   {
       
    45     absT : typ,
       
    46     repT : typ,
       
    47     abs_const : term,
       
    48     rep_const : term,
       
    49     abs_inverse : thm,
       
    50     rep_inverse : thm
       
    51   };
       
    52 
       
    53 val beta_ss =
       
         (* Simpset: HOL_basic_ss extended with simp_thms, the rule beta_cfun
            and the simproc cont_proc. *)
       
    54   HOL_basic_ss
       
    55     addsimps simp_thms
       
    56     addsimps [@{thm beta_cfun}]
       
    57     addsimprocs [@{simproc cont_proc}];
       
    58 
       
       (* Simplification tactic that runs with beta_ss. *)
       
    59 val beta_tac = simp_tac beta_ss;
       
    60 
       
    61 (******************************************************************************)
       
    62 (******************************** theory data *********************************)
       
    63 (******************************************************************************)
       
    64 
       
    65 structure MapData = Theory_Data
       
    66 (
       
    67   (* constant names like "foo_map" *)
       
         (* the table is keyed by the type name passed to add_map_function *)
       
    68   type T = string Symtab.table;
       
    69   val empty = Symtab.empty;
       
    70   val extend = I;
       
         (* combines the tables of two theories; the (K true) argument
            presumably makes values under a shared key count as equal, so that
            merging never fails -- confirm against Symtab.merge *)
       
    71   fun merge data = Symtab.merge (K true) data;
       
    72 );
       
    73 
       
    74 structure DeflMapData = Theory_Data
       
    75 (
       
    76   (* theorems like "deflation a ==> deflation (foo_map$a)" *)
       
         (* filled by add_map_function; read back through get_deflation_thms *)
       
    77   type T = thm list;
       
    78   val empty = [];
       
    79   val extend = I;
       
         (* the lists of two theories are combined with Thm.merge_thms *)
       
    80   val merge = Thm.merge_thms;
       
    81 );
       
    82 
       
    83 fun add_map_function (tname, map_name, deflation_map_thm) thy =
       
         (* Records map_name as the map constant for tname in MapData, then adds
            deflation_map_thm to the theorems kept in DeflMapData. *)
       
    84   thy
       
    85   |> MapData.map (Symtab.insert (K true) (tname, map_name))
       
         |> DeflMapData.map (Thm.add_thm deflation_map_thm);
       
    86 
       
    87 fun get_map_tab thy = MapData.get thy;  (* table from type name to map constant name *)
       
    88 fun get_deflation_thms thy = DeflMapData.get thy;  (* theorems stored by add_map_function *)
       
    89 
       
    90 (******************************************************************************)
       
    91 (************************** building types and terms **************************)
       
    92 (******************************************************************************)
       
    93 
       
    94 open HOLCF_Library;
       
    95 
       
    96 infixr 6 ->>;
       
    97 infix -->>;
       
    98 
       
    99 val deflT = @{typ "udom alg_defl"};  (* type used below by mk_Rep_of and isodefl_const *)
       
   100 
       
   101 fun mapT (T as Type (_, Ts)) =
       
         (* Type of a map function for T.  For a type constructor application
            there is one argument type "A ->> A" per type argument A, followed
            by the result type "T ->> T" (presumably -->> chains the argument
            list with ->> -- confirm in HOLCF_Library).
            Not referenced elsewhere in the visible source. *)
       
   102     (map (fn T => T ->> T) Ts) -->> (T ->> T)
       
         (* any type that is not a constructor application: just "T ->> T" *)
       
   103   | mapT T = T ->> T;
       
   104 
       
   105 fun mk_Rep_of T =
       
         (* The constant Rep_of, at type "T itself => deflT", applied to the
            type-as-term Logic.mk_type T. *)
       
   106   Const (@{const_name Rep_of}, Term.itselfT T --> deflT) $ Logic.mk_type T;
       
   107 
       
   108 fun coerce_const T = Const (@{const_name coerce}, T);  (* the constant "coerce" at the caller-supplied type T *)
       
   109 
       
   110 fun isodefl_const T =
       
         (* The constant isodefl at type "(T ->> T) => deflT => bool". *)
       
   111   Const (@{const_name isodefl}, (T ->> T) --> deflT --> HOLogic.boolT);
       
   112 
       
   113 fun mk_deflation t =
       
         (* Applies the constant "deflation", typed from t's own type, to t. *)
       
   114   let val argT = Term.fastype_of t
       
         in Const (@{const_name deflation}, argT --> boolT) $ t end;
       
   115 
       
   116 fun mk_lub t =
       
         (* For t of type "nat => T": the term "lub (image t UNIV)", where UNIV
            is the set of all naturals and sets over T are typed "T => bool". *)
       
   117   let
       
   118     val elemT = Term.range_type (Term.fastype_of t);
       
   119     val setT = elemT --> boolT;
       
   120     val image_of_t =
       
   121       Const (@{const_name image}, (natT --> elemT) --> (natT --> boolT) --> setT)
       
   122         $ t $ @{term "UNIV :: nat set"};
       
   123   in
       
   124     Const (@{const_name lub}, setT --> elemT) $ image_of_t
       
   125   end;
       
   126 
       
   127 (* splits a Trueprop-wrapped HOL equation (a term) into its left- and right-hand sides *)
       
   128 val dest_eqs = HOLogic.dest_eq o HOLogic.dest_Trueprop;  (* strip Trueprop, then split the equation *)
       
   129 
       
   130 val mk_eqs = HOLogic.mk_Trueprop o HOLogic.mk_eq;  (* pair (t, u) to the proposition "t = u" *)
       
   131 
       
   132 (******************************************************************************)
       
   133 (****************************** isomorphism info ******************************)
       
   134 (******************************************************************************)
       
   135 
       
   136 fun deflation_abs_rep ({abs_inverse, rep_inverse, ...} : iso_info) : thm =
       
         (* Resolves @{thm deflation_abs_rep} with the two inverse theorems of
            the isomorphism, in the order abs_inverse, rep_inverse, and exports
            the result without context. *)
       
   137   Drule.export_without_context
       
   138     (@{thm deflation_abs_rep} OF [abs_inverse, rep_inverse])
       
   144 
       
   145 (******************************************************************************)
       
   146 (********************* building map functions over types **********************)
       
   147 (******************************************************************************)
       
   148 
       
   149 fun map_of_typ (thy : theory) (sub : (typ * term) list) (T : typ) : term =
       
         (* Builds a map term for type T.  The local functions map_of and
            map_of' return the term paired with a flag that is true exactly
            when an entry of sub was used somewhere inside the type; only the
            term is returned to the caller. *)
       
   150   let
       
   151     val map_tab = get_map_tab thy;
       
   152     fun auto T = T ->> T;
       
           (* a type listed in sub is replaced by its paired term *)
       
   153     fun map_of T =
       
   154         case AList.lookup (op =) sub T of
       
   155           SOME m => (m, true) | NONE => map_of' T
       
           (* type constructor application: consult the table of map constants *)
       
   156     and map_of' (T as (Type (c, Ts))) =
       
   157         (case Symtab.lookup map_tab c of
       
   158           SOME map_name =>
       
   159           let
       
   160             val map_type = map auto Ts -->> auto T;
       
   161             val (ms, bs) = map_split map_of Ts;
       
   162           in
       
                 (* the map constant is applied only when some type argument
                    was substituted; otherwise the identity on T is used *)
       
   163             if exists I bs
       
   164             then (list_ccomb (Const (map_name, map_type), ms), true)
       
   165             else (mk_ID T, false)
       
   166           end
       
               (* no map constant registered for c: identity *)
       
   167         | NONE => (mk_ID T, false))
       
             (* any type that is not a constructor application: identity *)
       
   168       | map_of' T = (mk_ID T, false);
       
   169   in
       
   170     fst (map_of T)
       
   171   end;
       
   172 
       
   173 
       
   174 (******************************************************************************)
       
   175 (********************* declaring definitions and theorems *********************)
       
   176 (******************************************************************************)
       
   177 
       
   178 fun define_const
       
   179     (bind : binding, rhs : term)
       
   180     (thy : theory)
       
   181     : (term * thm) * theory =
       
         (* Declares a constant named bind with the type of rhs, then adds the
            definition "<bind>_def" stating const == rhs.  Returns the constant,
            the definitional theorem and the updated theory. *)
       
   182   let
       
   183     val (c, thy1) =
       
   184       Sign.declare_const ((bind, Term.fastype_of rhs), NoSyn) thy;
       
   185     val def_bind = Binding.suffix_name "_def" bind;
       
   186     val def_spec = Thm.no_attributes (def_bind, Logic.mk_equals (c, rhs));
       
   187     val (def_thm, thy2) = yield_singleton (PureThy.add_defs false) def_spec thy1;
       
   188   in
       
   189     ((c, def_thm), thy2)
       
   190   end;
       
   191 
       
   192 fun add_qualified_thm name (path, thm) thy =
       
         (* Stores thm under the given name while path is added to the naming
            path, then goes back to the parent path.  Returns the stored theorem
            and the updated theory. *)
       
   193   let
       
   194     val spec = Thm.no_attributes (Binding.name name, thm);
       
   195     val (stored, thy') =
       
   196       yield_singleton PureThy.add_thms spec (Sign.add_path path thy);
       
   197   in (stored, Sign.parent_path thy') end;
       
   198 
       
   199 (******************************************************************************)
       
   200 (************************** defining take functions ***************************)
       
   201 (******************************************************************************)
       
   202 
       
   203 fun define_take_functions
       
   204     (spec : (binding * iso_info) list)
       
   205     (thy : theory) =
       
         (* Defines one take function per entry of spec and proves the basic
            lemmas about it.  In order:
              - builds a "take functional" over a tuple of functions, one
                component per domain, and iterates it;
              - declares a constant "<name>_take" per domain, with a definition
                named "take_def" stored under the domain's name as path;
              - proves and stores, per domain, the theorems chain_take, take_0,
                take_Suc, deflation_take, take_strict and take_take.
            Returns the record required by DOMAIN_TAKE_PROOFS and the updated
            theory.  The take_strict and take_take theorems are stored in the
            theory but are not part of the returned record.
            NOTE(review): spec is presumably expected to be non-empty, since
            foldr1 is applied to one conjunct per domain below -- confirm. *)
       
   206   let
       
   207 
       
   208     (* retrieve components of spec *)
       
   209     val dom_binds = map fst spec;
       
   210     val iso_infos = map snd spec;
       
   211     val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
       
   212     val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
       
   213     val dnames = map Binding.name_of dom_binds;
       
   214 
       
           (* pairs each list element with a projection of t built from mk_fst
              and mk_snd; the last element receives the remaining term as is *)
       
   218     fun mk_projs []      _ = []
       
   219       | mk_projs (x::[]) t = [(x, t)]
       
   220       | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);
       
   221 
       
           (* composes abs_const with f with rep_const *)
       
   222     fun mk_cfcomp2 ((rep_const, abs_const), f) =
       
   223         mk_cfcomp (abs_const, mk_cfcomp (f, rep_const));
       
   224 
       
   225     (* define take functional *)
       
   226     val newTs : typ list = map fst dom_eqns;
       
   227     val copy_arg_type = mk_tupleT (map (fn T => T ->> T) newTs);
       
   228     val copy_arg = Free ("f", copy_arg_type);
       
   229     val copy_args = map snd (mk_projs dom_binds copy_arg);
       
   230     fun one_copy_rhs (rep_abs, (_, rhsT)) =
       
   231       let
       
   232         val body = map_of_typ thy (newTs ~~ copy_args) rhsT;
       
   233       in
       
   234         mk_cfcomp2 (rep_abs, body)
       
   235       end;
       
   236     val take_functional =
       
   237         big_lambda copy_arg
       
   238           (mk_tuple (map one_copy_rhs (rep_abs_consts ~~ dom_eqns)));
       
           (* one right-hand side per domain: a projection of the iterated
              functional, abstracted over the index i *)
       
   239     val take_rhss =
       
   240       let
       
   241         val i = Free ("i", HOLogic.natT);
       
   242         val rhs = mk_iterate (i, take_functional)
       
   243       in
       
   244         map (Term.lambda i o snd) (mk_projs dom_binds rhs)
       
   245       end;
       
   246 
       
   247     (* define take constants *)
       
   248     fun define_take_const ((tbind, take_rhs), (lhsT, _)) thy =
       
   249       let
       
   250         val take_type = HOLogic.natT --> lhsT ->> lhsT;
       
   251         val take_bind = Binding.suffix_name "_take" tbind;
       
   252         val (take_const, thy) =
       
   253           Sign.declare_const ((take_bind, take_type), NoSyn) thy;
       
   254         val take_eqn = Logic.mk_equals (take_const, take_rhs);
       
   255         val (take_def_thm, thy) =
       
   256           thy
       
   257           |> Sign.add_path (Binding.name_of tbind)
       
   258           |> yield_singleton
       
   259               (PureThy.add_defs false o map Thm.no_attributes)
       
   260               (Binding.name "take_def", take_eqn)
       
   261           ||> Sign.parent_path;
       
   262       in ((take_const, take_def_thm), thy) end;
       
   263     val ((take_consts, take_defs), thy) = thy
       
   264       |> fold_map define_take_const (dom_binds ~~ take_rhss ~~ dom_eqns)
       
   265       |>> ListPair.unzip;
       
   266 
       
   267     (* prove chain_take lemmas *)
       
   268     fun prove_chain_take (take_const, dname) thy =
       
   269       let
       
   270         val goal = mk_trp (mk_chain take_const);
       
   271         val rules = take_defs @ @{thms chain_iterate ch2ch_fst ch2ch_snd};
       
   272         val tac = simp_tac (HOL_basic_ss addsimps rules) 1;
       
   273         val chain_take_thm = Goal.prove_global thy [] [] goal (K tac);
       
   274       in
       
   275         add_qualified_thm "chain_take" (dname, chain_take_thm) thy
       
   276       end;
       
   277     val (chain_take_thms, thy) =
       
   278       fold_map prove_chain_take (take_consts ~~ dnames) thy;
       
   279 
       
   280     (* prove take_0 lemmas *)
       
   281     fun prove_take_0 ((take_const, dname), (lhsT, _)) thy =
       
   282       let
       
   283         val lhs = take_const $ @{term "0::nat"};
       
   284         val goal = mk_eqs (lhs, mk_bottom (lhsT ->> lhsT));
       
   285         val rules = take_defs @ @{thms iterate_0 fst_strict snd_strict};
       
   286         val tac = simp_tac (HOL_basic_ss addsimps rules) 1;
       
   287         val take_0_thm = Goal.prove_global thy [] [] goal (K tac);
       
   288       in
       
   289         add_qualified_thm "take_0" (dname, take_0_thm) thy
       
   290       end;
       
   291     val (take_0_thms, thy) =
       
   292       fold_map prove_take_0 (take_consts ~~ dnames ~~ dom_eqns) thy;
       
   293 
       
   294     (* prove take_Suc lemmas *)
       
   295     val i = Free ("i", natT);
       
   296     val take_is = map (fn t => t $ i) take_consts;
       
   297     fun prove_take_Suc
       
   298           (((take_const, rep_abs), dname), (_, rhsT)) thy =
       
   299       let
       
   300         val lhs = take_const $ (@{term Suc} $ i);
       
               (* same shape as one_copy_rhs, with the take functions at index i
                  in place of the components of the functional's argument *)
       
   301         val body = map_of_typ thy (newTs ~~ take_is) rhsT;
       
   302         val rhs = mk_cfcomp2 (rep_abs, body);
       
   303         val goal = mk_eqs (lhs, rhs);
       
   304         val simps = @{thms iterate_Suc fst_conv snd_conv}
       
   305         val rules = take_defs @ simps;
       
   306         val tac = simp_tac (beta_ss addsimps rules) 1;
       
   307         val take_Suc_thm = Goal.prove_global thy [] [] goal (K tac);
       
   308       in
       
   309         add_qualified_thm "take_Suc" (dname, take_Suc_thm) thy
       
   310       end;
       
   311     val (take_Suc_thms, thy) =
       
   312       fold_map prove_take_Suc
       
   313         (take_consts ~~ rep_abs_consts ~~ dnames ~~ dom_eqns) thy;
       
   314 
       
   315     (* prove deflation theorems for take functions *)
       
   316     val deflation_abs_rep_thms = map deflation_abs_rep iso_infos;
       
           (* one theorem for all domains: the conjunction of "deflation
              (take i)" over all take constants, proved by induction on i *)
       
   317     val deflation_take_thm =
       
   318       let
       
   319         val i = Free ("i", natT);
       
   320         fun mk_goal take_const = mk_deflation (take_const $ i);
       
   321         val goal = mk_trp (foldr1 mk_conj (map mk_goal take_consts));
       
   322         val adm_rules =
       
   323           @{thms adm_conj adm_subst [OF _ adm_deflation]
       
   324                  cont2cont_fst cont2cont_snd cont_id};
       
   325         val bottom_rules =
       
   326           take_0_thms @ @{thms deflation_UU simp_thms};
       
   327         val deflation_rules =
       
   328           @{thms conjI deflation_ID}
       
   329           @ deflation_abs_rep_thms
       
   330           @ DeflMapData.get thy;
       
   331       in
       
   332         Goal.prove_global thy [] [] goal (fn _ =>
       
   333          EVERY
       
   334           [rtac @{thm nat.induct} 1,
       
   335            simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
       
   336            asm_simp_tac (HOL_basic_ss addsimps take_Suc_thms) 1,
       
   337            REPEAT (etac @{thm conjE} 1
       
   338                    ORELSE resolve_tac deflation_rules 1
       
   339                    ORELSE atac 1)])
       
   340       end;
       
           (* splits a right-nested conjunction into one theorem per name *)
       
   341     fun conjuncts [] _ = []
       
   342       | conjuncts (n::[]) thm = [(n, thm)]
       
   343       | conjuncts (n::ns) thm = let
       
   344           val thmL = thm RS @{thm conjunct1};
       
   345           val thmR = thm RS @{thm conjunct2};
       
   346         in (n, thmL):: conjuncts ns thmR end;
       
   347     val (deflation_take_thms, thy) =
       
   348       fold_map (add_qualified_thm "deflation_take")
       
   349         (map (apsnd Drule.export_without_context)
       
   350           (conjuncts dnames deflation_take_thm)) thy;
       
   351 
       
   352     (* prove strictness of take functions *)
       
   353     fun prove_take_strict (take_const, dname) thy =
       
   354       let
       
   355         val goal = mk_trp (mk_strict (take_const $ Free ("i", natT)));
       
   356         val tac = rtac @{thm deflation_strict} 1
       
   357                   THEN resolve_tac deflation_take_thms 1;
       
   358         val take_strict_thm = Goal.prove_global thy [] [] goal (K tac);
       
   359       in
       
   360         add_qualified_thm "take_strict" (dname, take_strict_thm) thy
       
   361       end;
       
           (* stored in thy only; the result record has no field for these *)
       
   362     val (take_strict_thms, thy) =
       
   363       fold_map prove_take_strict (take_consts ~~ dnames) thy;
       
   364 
       
   365     (* prove take/take rules *)
       
   366     fun prove_take_take ((chain_take, deflation_take), dname) thy =
       
   367       let
       
   368         val take_take_thm =
       
   369             @{thm deflation_chain_min} OF [chain_take, deflation_take];
       
   370       in
       
   371         add_qualified_thm "take_take" (dname, take_take_thm) thy
       
   372       end;
       
           (* stored in thy only; the result record has no field for these *)
       
   373     val (take_take_thms, thy) =
       
   374       fold_map prove_take_take
       
   375         (chain_take_thms ~~ deflation_take_thms ~~ dnames) thy;
       
   376 
       
   377     val result =
       
   378       {
       
   379         take_consts = take_consts,
       
   380         take_defs = take_defs,
       
   381         chain_take_thms = chain_take_thms,
       
   382         take_0_thms = take_0_thms,
       
   383         take_Suc_thms = take_Suc_thms,
       
   384         deflation_take_thms = deflation_take_thms
       
   385       };
       
   386 
       
   387   in
       
   388     (result, thy)
       
   389   end;
       
   390 
       
   391 end;