src/HOL/HOLCF/Tools/Domain/domain_isomorphism.ML
author huffman
Tue Nov 30 14:21:57 2010 -0800 (2010-11-30)
changeset 40832 4352ca878c41
parent 40774 0437dbc127b3
child 40833 4f130bd9e17e
permissions -rw-r--r--
remove gratuitous semicolons from ML code
     1 (*  Title:      HOLCF/Tools/Domain/domain_isomorphism.ML
     2     Author:     Brian Huffman
     3 
     4 Defines new types satisfying the given domain equations.
     5 *)
     6 
     signature DOMAIN_ISOMORPHISM =
     sig
       (* Defines new types for a list of (possibly mutually recursive)
          domain equations.  Each entry consists of the type parameters, the
          binding of the new type, its mixfix syntax, the already-parsed
          right-hand side, and an optional pair of morphism bindings. *)
       val domain_isomorphism :
         (string list * binding * mixfix * typ * (binding * binding) option) list ->
         theory ->
         (Domain_Take_Proofs.iso_info list * Domain_Take_Proofs.take_induct_info) * theory

       (* Declares and defines one "_map" constant per given isomorphism and
          proves the associated deflation theorems. *)
       val define_map_functions :
         (binding * Domain_Take_Proofs.iso_info) list ->
         theory ->
         {
           map_consts : term list,
           map_apply_thms : thm list,
           map_unfold_thms : thm list,
           deflation_map_thms : thm list
         } * theory

       (* Variant of domain_isomorphism whose right-hand sides are still
          unparsed strings; only the updated theory is returned. *)
       val domain_isomorphism_cmd :
         (string list * binding * mixfix * string * (binding * binding) option) list ->
         theory -> theory

       (* Runs the setup of the two named theorem collections declared in
          this structure. *)
       val setup : theory -> theory
     end
    33 
    34 structure Domain_Isomorphism : DOMAIN_ISOMORPHISM =
    35 struct
    36 
    (* Rules used to beta-reduce continuous function applications and to
       discharge the continuity side conditions that arise while doing so. *)
    val beta_rules =
      @{thms beta_cfun cont_id cont_const cont2cont_APP cont2cont_LAM'
             cont2cont_fst cont2cont_snd cont2cont_Pair cont2cont_prod_case'}

    (* Simpset consisting of simp_thms followed by beta_rules. *)
    val beta_ss =
      let val rules = simp_thms @ beta_rules
      in HOL_basic_ss addsimps rules end

    (* Simplification tactic for subgoal i using beta_ss. *)
    fun beta_tac i = simp_tac beta_ss i

    (* Tests whether type T has sort cpo in theory thy. *)
    fun is_cpo thy T =
      let val cpo = @{sort cpo}
      in Sign.of_sort thy (T, cpo) end
    46 
    47 (******************************************************************************)
    48 (******************************** theory data *********************************)
    49 (******************************************************************************)
    50 
    (* Named theorem collection "domain_defl_simps". *)
    structure RepData = Named_Thms
    (
      val description = "theorems like DEFL('a t) = t_defl$DEFL('a)"
      val name = "domain_defl_simps"
    )

    (* Named theorem collection "domain_isodefl". *)
    structure IsodeflData = Named_Thms
    (
      val description = "theorems like isodefl d t ==> isodefl (foo_map$d) (foo_defl$t)"
      val name = "domain_isodefl"
    )

    (* Theory setup: first RepData, then IsodeflData. *)
    val setup = IsodeflData.setup o RepData.setup
    64 
    65 
    66 (******************************************************************************)
    67 (************************** building types and terms **************************)
    68 (******************************************************************************)
    69 
    70 open HOLCF_Library
    71 
    (* Fixity of the type-building operators used throughout this file.
       NOTE(review): presumably these repeat fixity declarations made where
       the operators are defined, since `open` does not carry fixity over --
       confirm against HOLCF_Library. *)
    infixr 6 ->>
    infixr 0 -->>

    (* Frequently used type constants. *)
    val (udomT, deflT) = (@{typ udom}, @{typ defl})
    77 
    (* Builds the term "defl TYPE(T)". *)
    fun mk_DEFL T =
      let val defl_const = Const (@{const_name defl}, Term.itselfT T --> deflT)
      in defl_const $ Logic.mk_type T end

    (* Inverse of mk_DEFL: returns the type argument; raises TERM on any
       term that is not an application of the defl constant. *)
    fun dest_DEFL tm =
      (case tm of
        Const (@{const_name defl}, _) $ arg => Logic.dest_type arg
      | _ => raise TERM ("dest_DEFL", [tm]))

    (* Builds the term "liftdefl TYPE(T)". *)
    fun mk_LIFTDEFL T =
      let val liftdefl_const = Const (@{const_name liftdefl}, Term.itselfT T --> deflT)
      in liftdefl_const $ Logic.mk_type T end

    (* Inverse of mk_LIFTDEFL: returns the type argument; raises TERM on any
       term that is not an application of the liftdefl constant. *)
    fun dest_LIFTDEFL tm =
      (case tm of
        Const (@{const_name liftdefl}, _) $ arg => Logic.dest_type arg
      | _ => raise TERM ("dest_LIFTDEFL", [tm]))
    89 
    (* Applies the constant u_defl to the term t (continuous application). *)
    fun mk_u_defl t =
      let val u_defl_const = @{const "u_defl"}
      in mk_capply (u_defl_const, t) end

    (* Applies u_map to the term t, whose type must be a continuous function
       type A ->> B; the u_map constant is instantiated accordingly. *)
    fun mk_u_map t =
      let
        val (A, B) = dest_cfunT (fastype_of t)
        val argT = A ->> B
        val resT = mk_upT A ->> mk_upT B
      in
        mk_capply (Const (@{const_name u_map}, argT ->> resT), t)
      end
   100 
   (* The constant emb at type T ->> udom. *)
   fun emb_const T =
     let val embT = T ->> udomT
     in Const (@{const_name emb}, embT) end

   (* The constant prj at type udom ->> T. *)
   fun prj_const T =
     let val prjT = udomT ->> T
     in Const (@{const_name prj}, prjT) end

   (* Coercion from T to U: prj at U composed with emb at T. *)
   fun coerce_const (T, U) =
     let
       val prj = prj_const U
       val emb = emb_const T
     in mk_cfcomp (prj, emb) end

   (* The predicate constant isodefl for functions of type T ->> T. *)
   fun isodefl_const T =
     let val isodeflT = (T ->> T) --> deflT --> HOLogic.boolT
     in Const (@{const_name isodefl}, isodeflT) end

   (* Builds the term "deflation t". *)
   fun mk_deflation t =
     let val deflation_const = Const (@{const_name deflation}, Term.fastype_of t --> boolT)
     in deflation_const $ t end
   110 
   (* Splits a term of the form "Trueprop (t = u)" into the pair (t, u). *)
   fun dest_eqs prop =
     let val eq = HOLogic.dest_Trueprop prop
     in HOLogic.dest_eq eq end

   (* Builds the term "Trueprop (t = u)" from the pair (t, u). *)
   fun mk_eqs (t, u) =
     let val eq = HOLogic.mk_eq (t, u)
     in HOLogic.mk_Trueprop eq end
   115 
   116 (******************************************************************************)
   117 (****************************** isomorphism info ******************************)
   118 (******************************************************************************)
   119 
   (* Instantiates the rule deflation_abs_rep with the abs_inverse and
      rep_inverse theorems stored in the given isomorphism info, and
      normalizes the schematic variable indexes of the result. *)
   fun deflation_abs_rep (info : Domain_Take_Proofs.iso_info) : thm =
     Drule.zero_var_indexes
       (@{thm deflation_abs_rep} OF [#abs_inverse info, #rep_inverse info])
   128 
   129 (******************************************************************************)
   130 (*************** fixed-point definitions and unfolding theorems ***************)
   131 (******************************************************************************)
   132 
   (* Pairs each element of the list with the matching component of the
      right-nested tuple term t: every element but the last gets a first
      projection, the last element gets whatever remains. *)
   fun mk_projs xs t =
     (case xs of
       [] => []
     | [x] => [(x, t)]
     | x :: rest => (x, mk_fst t) :: mk_projs rest (mk_snd t))
   136 
   137 fun add_fixdefs
   138     (spec : (binding * term) list)
   139     (thy : theory) : (thm list * thm list) * theory =
         (* Defines a group of constants simultaneously as the components of one
            fixed point.  Each spec entry pairs a binding with an equation
            "lhs = rhs" (in the form accepted by dest_eqs).  The left-hand sides
            are tupled and abstracted over the tuple of right-hand sides; the
            fixed point of that functional is projected back onto the individual
            left-hand sides.  Registers the "_def" definitions and the "_unfold"
            theorems in the theory.
            Returns ((applied definition theorems, unfold theorems), theory). *)
   140   let
   141     val binds = map fst spec
   142     val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec)
   143     val functional = lambda_tuple lhss (mk_tuple rhss)
   144     val fixpoint = mk_fix (mk_cabs functional)
   145 
   146     (* project components of fixpoint *)
   147     val projs = mk_projs lhss fixpoint
   148 
   149     (* convert parameters to lambda abstractions *)
           (* Arguments are stripped from the left-hand side one at a time and
              abstracted on the right-hand side, until only the constant is
              left; any other shape of left-hand side raises TERM. *)
   150     fun mk_eqn (lhs, rhs) =
   151         case lhs of
   152           Const (@{const_name Rep_cfun}, _) $ f $ (x as Free _) =>
   153             mk_eqn (f, big_lambda x rhs)
   154         | f $ Const (@{const_name TYPE}, T) =>
   155             mk_eqn (f, Abs ("t", T, rhs))
   156         | Const _ => Logic.mk_equals (lhs, rhs)
   157         | _ => raise TERM ("lhs not of correct form", [lhs, rhs])
   158     val eqns = map mk_eqn projs
   159 
   160     (* register constant definitions *)
   161     val (fixdef_thms, thy) =
   162       (Global_Theory.add_defs false o map Thm.no_attributes)
   163         (map (Binding.suffix_name "_def") binds ~~ eqns) thy
   164 
   165     (* prove applied version of definitions *)
   166     fun prove_proj (lhs, rhs) =
   167       let
   168         val tac = rewrite_goals_tac fixdef_thms THEN beta_tac 1
   169         val goal = Logic.mk_equals (lhs, rhs)
   170       in Goal.prove_global thy [] [] goal (K tac) end
   171     val proj_thms = map prove_proj projs
   172 
   173     (* mk_tuple lhss == fixpoint *)
   174     fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2]
   175     val tuple_fixdef_thm = foldr1 pair_equalI proj_thms
   176 
           (* continuity of the functional, needed by def_cont_fix_eq below *)
   177     val cont_thm =
   178       Goal.prove_global thy [] [] (mk_trp (mk_cont functional))
   179         (K (beta_tac 1))
   180     val tuple_unfold_thm =
   181       (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
   182       |> Local_Defs.unfold (ProofContext.init_global thy) @{thms split_conv}
   183 
           (* split the tupled unfold theorem into one theorem per name, using
              Pair_eqD1 / Pair_eqD2 for every name except the last *)
   184     fun mk_unfold_thms [] thm = []
   185       | mk_unfold_thms (n::[]) thm = [(n, thm)]
   186       | mk_unfold_thms (n::ns) thm = let
   187           val thmL = thm RS @{thm Pair_eqD1}
   188           val thmR = thm RS @{thm Pair_eqD2}
   189         in (n, thmL) :: mk_unfold_thms ns thmR end
   190     val unfold_binds = map (Binding.suffix_name "_unfold") binds
   191 
   192     (* register unfold theorems *)
   193     val (unfold_thms, thy) =
   194       (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   195         (mk_unfold_thms unfold_binds tuple_unfold_thm) thy
   196   in
   197     ((proj_thms, unfold_thms), thy)
   198   end
   199 
   200 
   201 (******************************************************************************)
   202 (****************** deflation combinators and map functions *******************)
   203 (******************************************************************************)
   204 
   (* Rewrites the term "defl TYPE(T)" into a deflation-combinator term.
      The rewrite rules are the equations of the registered
      domain_defl_simps theorems, followed by tab1 (keyed by DEFL of the
      given types) and tab2 (keyed by LIFTDEFL of the given types).  A
      remaining DEFL or LIFTDEFL of a type variable 'a is replaced by the
      free variable "da" or "pa", respectively, of type defl. *)
   fun defl_of_typ
       (thy : theory)
       (tab1 : (typ * term) list)
       (tab2 : (typ * term) list)
       (T : typ) : term =
     let
       val ctxt = ProofContext.init_global thy
       fun rule_of thm = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.concl_of thm))
       val simp_rules = map rule_of (RepData.get ctxt)
       val tab_rules =
         map (fn (U, t) => (mk_DEFL U, t)) tab1 @
         map (fn (U, t) => (mk_LIFTDEFL U, t)) tab2
       (* rewrite procedure for type variables; terms rejected by dest
          (signalled by TERM) are left alone *)
       fun tfree_proc dest prefix t =
         (case dest t of
           TFree (a, _) => SOME (Free (prefix ^ Library.unprefix "'" a, deflT))
         | _ => NONE) handle TERM _ => NONE
       val procs = [tfree_proc dest_DEFL "d", tfree_proc dest_LIFTDEFL "p"]
     in
       Pattern.rewrite_term thy (simp_rules @ tab_rules) procs (mk_DEFL T)
     end
   225 
   226 (******************************************************************************)
   227 (********************* declaring definitions and theorems *********************)
   228 (******************************************************************************)
   229 
   (* Declares a constant named bind with the type of rhs and no syntax,
      then registers the definition "const == rhs" under the name
      bind ^ "_def".  Returns the constant, its definitional theorem, and
      the updated theory. *)
   fun define_const
       (bind : binding, rhs : term)
       (thy : theory)
       : (term * thm) * theory =
     let
       val (c, thy1) = Sign.declare_const ((bind, Term.fastype_of rhs), NoSyn) thy
       val def_bind = Binding.suffix_name "_def" bind
       val def_spec = Thm.no_attributes (def_bind, Logic.mk_equals (c, rhs))
       val (def_thm, thy2) =
         yield_singleton (Global_Theory.add_defs false) def_spec thy1
     in
       ((c, def_thm), thy2)
     end
   243 
   (* Stores thm, without attributes, under the binding obtained by
      qualifying name with dbind. *)
   fun add_qualified_thm name (dbind, thm) =
     let val qualified_bind = Binding.qualified true name dbind
     in yield_singleton Global_Theory.add_thms ((qualified_bind, thm), []) end
   247 
   248 (******************************************************************************)
   249 (*************************** defining map functions ***************************)
   250 (******************************************************************************)
   251 
   252 fun define_map_functions
   253     (spec : (binding * Domain_Take_Proofs.iso_info) list)
   254     (thy : theory) =
         (* For every (binding, iso_info) entry this declares a "_map" constant,
            defines all map constants simultaneously via add_fixdefs, proves that
            each map function preserves deflations, and records the new types and
            deflation theorems with Domain_Take_Proofs.
            Returns a record with the map constants, their applied-definition and
            unfold theorems, and the deflation theorems, plus the updated theory.
            Uses hd on the spec-derived list dom_eqns, so spec is assumed to be
            non-empty. *)
   255   let
   256 
   257     (* retrieve components of spec *)
   258     val dbinds = map fst spec
   259     val iso_infos = map snd spec
   260     val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos
   261     val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos
   262 
           (* type of the map function for T: one argument T' ->> T' for every
              type argument of sort cpo, result T ->> T *)
   263     fun mapT (T as Type (_, Ts)) =
   264         (map (fn T => T ->> T) (filter (is_cpo thy) Ts)) -->> (T ->> T)
   265       | mapT T = T ->> T
   266 
   267     (* declare map functions *)
   268     fun declare_map_const (tbind, (lhsT, rhsT)) thy =
   269       let
   270         val map_type = mapT lhsT
   271         val map_bind = Binding.suffix_name "_map" tbind
   272       in
   273         Sign.declare_const ((map_bind, map_type), NoSyn) thy
   274       end
   275     val (map_consts, thy) = thy |>
   276       fold_map declare_map_const (dbinds ~~ dom_eqns)
   277 
   278     (* defining equations for map functions *)
   279     local
   280       fun unprime a = Library.unprefix "'" a
   281       fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T)
   282       fun map_lhs (map_const, lhsT) =
   283           (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (snd (dest_Type lhsT)))))
   284       val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns)
             (* type arguments are taken from the first equation only *)
   285       val Ts = (snd o dest_Type o fst o hd) dom_eqns
   286       val tab = (Ts ~~ map mapvar Ts) @ tab1
   287       fun mk_map_spec (((rep_const, abs_const), map_const), (lhsT, rhsT)) =
   288         let
   289           val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT
   290           val body = Domain_Take_Proofs.map_of_typ thy tab rhsT
   291           val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const))
   292         in mk_eqs (lhs, rhs) end
   293     in
   294       val map_specs =
   295           map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns)
   296     end
   297 
   298     (* register recursive definition of map functions *)
   299     val map_binds = map (Binding.suffix_name "_map") dbinds
   300     val ((map_apply_thms, map_unfold_thms), thy) =
   301       add_fixdefs (map_binds ~~ map_specs) thy
   302 
   303     (* prove deflation theorems for map functions *)
           (* A single theorem is proved whose conclusion is the conjunction of
              the deflation statements of all map functions, under one deflation
              assumption per cpo type argument of the first equation. *)
   304     val deflation_abs_rep_thms = map deflation_abs_rep iso_infos
   305     val deflation_map_thm =
   306       let
   307         fun unprime a = Library.unprefix "'" a
   308         fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T)
   309         fun mk_assm T = mk_trp (mk_deflation (mk_f T))
   310         fun mk_goal (map_const, (lhsT, rhsT)) =
   311           let
   312             val (_, Ts) = dest_Type lhsT
   313             val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
   314           in mk_deflation map_term end
   315         val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns
   316         val goals = map mk_goal (map_consts ~~ dom_eqns)
   317         val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
   318         val start_thms =
   319           @{thm split_def} :: map_apply_thms
   320         val adm_rules =
   321           @{thms adm_conj adm_subst [OF _ adm_deflation]
   322                  cont2cont_fst cont2cont_snd cont_id}
   323         val bottom_rules =
   324           @{thms fst_strict snd_strict deflation_UU simp_thms}
   325         val deflation_rules =
   326           @{thms conjI deflation_ID}
   327           @ deflation_abs_rep_thms
   328           @ Domain_Take_Proofs.get_deflation_thms thy
   329       in
   330         Goal.prove_global thy [] assms goal (fn {prems, ...} =>
   331          EVERY
   332           [simp_tac (HOL_basic_ss addsimps start_thms) 1,
   333            rtac @{thm fix_ind} 1,
   334            REPEAT (resolve_tac adm_rules 1),
   335            simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
   336            simp_tac beta_ss 1,
   337            simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
   338            REPEAT (etac @{thm conjE} 1),
   339            REPEAT (resolve_tac (deflation_rules @ prems) 1 ORELSE atac 1)])
   340       end
           (* split a conjunction theorem into one named theorem per conjunct *)
   341     fun conjuncts [] thm = []
   342       | conjuncts (n::[]) thm = [(n, thm)]
   343       | conjuncts (n::ns) thm = let
   344           val thmL = thm RS @{thm conjunct1}
   345           val thmR = thm RS @{thm conjunct2}
   346         in (n, thmL):: conjuncts ns thmR end
   347     val deflation_map_binds = dbinds |>
   348         map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map")
   349     val (deflation_map_thms, thy) = thy |>
   350       (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   351         (conjuncts deflation_map_binds deflation_map_thm)
   352 
   353     (* register indirect recursion in theory data *)
   354     local
   355       fun register_map (dname, args) =
   356         Domain_Take_Proofs.add_rec_type (dname, args)
   357       val dnames = map (fst o dest_Type o fst) dom_eqns
             (* NOTE(review): map_names is not referenced anywhere below *)
   358       val map_names = map (fst o dest_Const) map_consts
   359       fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => []
   360       val argss = map args dom_eqns
   361     in
   362       val thy =
   363           fold register_map (dnames ~~ argss) thy
   364     end
   365 
   366     (* register deflation theorems *)
   367     val thy = fold Domain_Take_Proofs.add_deflation_thm deflation_map_thms thy
   368 
   369     val result =
   370       {
   371         map_consts = map_consts,
   372         map_apply_thms = map_apply_thms,
   373         map_unfold_thms = map_unfold_thms,
   374         deflation_map_thms = deflation_map_thms
   375       }
   376   in
   377     (result, thy)
   378   end
   379 
   380 (******************************************************************************)
   381 (******************************* main function ********************************)
   382 (******************************************************************************)
   383 
   (* Parses the string str as a type in a context where the type variables
      listed in sorts have been declared with their sorts.  Returns the type
      together with sorts extended by the free type variables of the type. *)
   fun read_typ thy str sorts =
     let
       val declare_tfree = Variable.declare_typ o TFree
       val ctxt = fold declare_tfree sorts (ProofContext.init_global thy)
       val T = Syntax.read_typ ctxt str
       val sorts' = Term.add_tfreesT T sorts
     in
       (T, sorts')
     end
   390 
   (* Certifies raw_T, turning a TYPE exception into an error with the same
      message.  Returns the certified type together with sorts extended by
      its free type variables; fails with an error if some type variable
      name then occurs more than once in the extended list. *)
   fun cert_typ sign raw_T sorts =
     let
       val T =
         Type.no_tvars (Sign.certify_typ sign raw_T)
           handle TYPE (msg, _, _) => error msg
       val sorts' = Term.add_tfreesT T sorts
       val dups = duplicates (op =) (map fst sorts')
       val _ =
         if null dups then ()
         else error ("Inconsistent sort constraints for " ^ commas dups)
     in
       (T, sorts')
     end
   401 
   402 fun gen_domain_isomorphism
   403     (prep_typ: theory -> 'a -> (string * sort) list -> typ * (string * sort) list)
   404     (doms_raw: (string list * binding * mixfix * 'a * (binding * binding) option) list)
   405     (thy: theory)
   406     : (Domain_Take_Proofs.iso_info list
   407        * Domain_Take_Proofs.take_induct_info) * theory =
         (* Common worker behind domain_isomorphism and domain_isomorphism_cmd.
            prep_typ converts each raw right-hand side into a typ while
            accumulating the sorts of the type variables encountered.
            In order, this function: prepares the right-hand sides in a temporary
            theory and checks that they have sort "domain"; declares and defines
            one deflation combinator per domain (add_fixdefs); defines the new
            types with Domaindef.add_domaindef; proves the DEFL equations;
            defines the rep/abs constants and proves the isomorphism and isodefl
            rules; defines the map functions (define_map_functions) and proves
            their isodefl and map_ID theorems; defines the take functions and
            proves the lub_take theorems.
            Returns the isomorphism infos and the take info produced by
            Domain_Take_Proofs.add_lub_take_theorems, with the final theory.
            Uses hd on the list of domains, so doms_raw is assumed non-empty. *)
   408   let
   409     val _ = Theory.requires thy "Domain" "domain isomorphisms"
   410 
   411     (* this theory is used just for parsing *)
   412     val tmp_thy = thy |>
   413       Theory.copy |>
   414       Sign.add_types (map (fn (tvs, tbind, mx, _, morphs) =>
   415         (tbind, length tvs, mx)) doms_raw)
   416 
   417     fun prep_dom thy (vs, t, mx, typ_raw, morphs) sorts =
   418       let val (typ, sorts') = prep_typ thy typ_raw sorts
   419       in ((vs, t, mx, typ, morphs), sorts') end
   420 
   421     val (doms : (string list * binding * mixfix * typ * (binding * binding) option) list,
   422          sorts : (string * sort) list) =
   423       fold_map (prep_dom tmp_thy) doms_raw []
   424 
   425     (* lookup function for sorts of type variables *)
   426     fun the_sort v = the (AList.lookup (op =) sorts v)
   427 
   428     (* declare arities in temporary theory *)
   429     val tmp_thy =
   430       let
   431         fun arity (vs, tbind, mx, _, _) =
   432           (Sign.full_name thy tbind, map the_sort vs, @{sort "domain"})
   433       in
   434         fold AxClass.axiomatize_arity (map arity doms) tmp_thy
   435       end
   436 
   437     (* check bifiniteness of right-hand sides *)
   438     fun check_rhs (vs, tbind, mx, rhs, morphs) =
   439       if Sign.of_sort tmp_thy (rhs, @{sort "domain"}) then ()
   440       else error ("Type not of sort domain: " ^
   441         quote (Syntax.string_of_typ_global tmp_thy rhs))
   442     val _ = map check_rhs doms
   443 
   444     (* domain equations *)
   445     fun mk_dom_eqn (vs, tbind, mx, rhs, morphs) =
   446       let fun arg v = TFree (v, the_sort v)
   447       in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end
   448     val dom_eqns = map mk_dom_eqn doms
   449 
   450     (* check for valid type parameters *)
           (* NOTE(review): new_doms is not referenced again in this function;
              the map is evaluated for the errors it may raise. *)
   451     val (tyvars, _, _, _, _) = hd doms
   452     val new_doms = map (fn (tvs, tname, mx, _, _) =>
   453       let val full_tname = Sign.full_name tmp_thy tname
   454       in
   455         (case duplicates (op =) tvs of
   456           [] =>
   457             if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
   458             else error ("Mutually recursive domains must have same type parameters")
   459         | dups => error ("Duplicate parameter(s) for domain " ^ quote (Binding.str_of tname) ^
   460             " : " ^ commas dups))
   461       end) doms
   462     val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms
   463     val morphs = map (fn (_, _, _, _, morphs) => morphs) doms
   464 
   465     (* determine deflation combinator arguments *)
           (* The right-hand sides are translated once with a free variable "t"
              standing for the tuple of recursive occurrences; the type and term
              variables that survive in the result decide which arguments each
              deflation combinator takes. *)
   466     val lhsTs : typ list = map fst dom_eqns
   467     val defl_rec = Free ("t", mk_tupleT (map (K deflT) lhsTs))
   468     val defl_recs = mk_projs lhsTs defl_rec
   469     val defl_recs' = map (apsnd mk_u_defl) defl_recs
   470     fun defl_body (_, _, _, rhsT, _) =
   471       defl_of_typ tmp_thy defl_recs defl_recs' rhsT
   472     val functional = Term.lambda defl_rec (mk_tuple (map defl_body doms))
   473 
   474     val tfrees = map fst (Term.add_tfrees functional [])
   475     val frees = map fst (Term.add_frees functional [])
   476     fun get_defl_flags (vs, _, _, _, _) =
   477       let
   478         fun argT v = TFree (v, the_sort v)
   479         fun mk_d v = "d" ^ Library.unprefix "'" v
   480         fun mk_p v = "p" ^ Library.unprefix "'" v
   481         val args = maps (fn v => [(mk_d v, mk_DEFL (argT v)), (mk_p v, mk_LIFTDEFL (argT v))]) vs
   482         val typeTs = map argT (filter (member (op =) tfrees) vs)
   483         val defl_args = map snd (filter (member (op =) frees o fst) args)
   484       in
   485         (typeTs, defl_args)
   486       end
   487     val defl_flagss = map get_defl_flags doms
   488 
   489     (* declare deflation combinator constants *)
   490     fun declare_defl_const ((typeTs, defl_args), (_, tbind, _, _, _)) thy =
   491       let
   492         val defl_bind = Binding.suffix_name "_defl" tbind
   493         val defl_type =
   494           map Term.itselfT typeTs ---> map (K deflT) defl_args -->> deflT
   495       in
   496         Sign.declare_const ((defl_bind, defl_type), NoSyn) thy
   497       end
   498     val (defl_consts, thy) =
   499       fold_map declare_defl_const (defl_flagss ~~ doms) thy
   500 
   501     (* defining equations for type combinators *)
   502     fun mk_defl_term (defl_const, (typeTs, defl_args)) =
   503       let
   504         val type_args = map Logic.mk_type typeTs
   505       in
   506         list_ccomb (list_comb (defl_const, type_args), defl_args)
   507       end
   508     val defl_terms = map mk_defl_term (defl_consts ~~ defl_flagss)
   509     val defl_tab = map fst dom_eqns ~~ defl_terms
   510     val defl_tab' = map fst dom_eqns ~~ map mk_u_defl defl_terms
   511     fun mk_defl_spec (lhsT, rhsT) =
   512       mk_eqs (defl_of_typ tmp_thy defl_tab defl_tab' lhsT,
   513               defl_of_typ tmp_thy defl_tab defl_tab' rhsT)
   514     val defl_specs = map mk_defl_spec dom_eqns
   515 
   516     (* register recursive definition of deflation combinators *)
   517     val defl_binds = map (Binding.suffix_name "_defl") dbinds
   518     val ((defl_apply_thms, defl_unfold_thms), thy) =
   519       add_fixdefs (defl_binds ~~ defl_specs) thy
   520 
   521     (* define types using deflation combinators *)
   522     fun make_repdef ((vs, tbind, mx, _, _), defl) thy =
   523       let
   524         val spec = (tbind, map (rpair dummyS) vs, mx)
   525         val ((_, _, _, {DEFL, liftemb_def, liftprj_def, ...}), thy) =
   526           Domaindef.add_domaindef false NONE spec defl NONE thy
   527         (* declare domain_defl_simps rules *)
   528         val thy = Context.theory_map (RepData.add_thm DEFL) thy
   529       in
   530         (DEFL, thy)
   531       end
   532     val (DEFL_thms, thy) = fold_map make_repdef (doms ~~ defl_terms) thy
   533 
   534     (* prove DEFL equations *)
   535     fun mk_DEFL_eq_thm (lhsT, rhsT) =
   536       let
   537         val goal = mk_eqs (mk_DEFL lhsT, mk_DEFL rhsT)
   538         val DEFL_simps = RepData.get (ProofContext.init_global thy)
   539         val tac =
   540           rewrite_goals_tac (map mk_meta_eq DEFL_simps)
   541           THEN TRY (resolve_tac defl_unfold_thms 1)
   542       in
   543         Goal.prove_global thy [] [] goal (K tac)
   544       end
   545     val DEFL_eq_thms = map mk_DEFL_eq_thm dom_eqns
   546 
   547     (* register DEFL equations *)
   548     val DEFL_eq_binds = map (Binding.prefix_name "DEFL_eq_") dbinds
   549     val (_, thy) = thy |>
   550       (Global_Theory.add_thms o map Thm.no_attributes)
   551         (DEFL_eq_binds ~~ DEFL_eq_thms)
   552 
   553     (* define rep/abs functions *)
           (* NOTE(review): the morphs component is bound but not used; the
              constants are always named with the "_rep" / "_abs" suffixes. *)
   554     fun mk_rep_abs ((tbind, morphs), (lhsT, rhsT)) thy =
   555       let
   556         val rep_bind = Binding.suffix_name "_rep" tbind
   557         val abs_bind = Binding.suffix_name "_abs" tbind
   558         val ((rep_const, rep_def), thy) =
   559             define_const (rep_bind, coerce_const (lhsT, rhsT)) thy
   560         val ((abs_const, abs_def), thy) =
   561             define_const (abs_bind, coerce_const (rhsT, lhsT)) thy
   562       in
   563         (((rep_const, abs_const), (rep_def, abs_def)), thy)
   564       end
   565     val ((rep_abs_consts, rep_abs_defs), thy) = thy
   566       |> fold_map mk_rep_abs (dbinds ~~ morphs ~~ dom_eqns)
   567       |>> ListPair.unzip
   568 
   569     (* prove isomorphism and isodefl rules *)
   570     fun mk_iso_thms ((tbind, DEFL_eq), (rep_def, abs_def)) thy =
   571       let
   572         fun make thm =
   573             Drule.zero_var_indexes (thm OF [DEFL_eq, abs_def, rep_def])
   574         val rep_iso_thm = make @{thm domain_rep_iso}
   575         val abs_iso_thm = make @{thm domain_abs_iso}
   576         val isodefl_thm = make @{thm isodefl_abs_rep}
   577         val thy = thy
   578           |> snd o add_qualified_thm "rep_iso" (tbind, rep_iso_thm)
   579           |> snd o add_qualified_thm "abs_iso" (tbind, abs_iso_thm)
   580           |> snd o add_qualified_thm "isodefl_abs_rep" (tbind, isodefl_thm)
   581       in
   582         (((rep_iso_thm, abs_iso_thm), isodefl_thm), thy)
   583       end
   584     val ((iso_thms, isodefl_abs_rep_thms), thy) =
   585       thy
   586       |> fold_map mk_iso_thms (dbinds ~~ DEFL_eq_thms ~~ rep_abs_defs)
   587       |>> ListPair.unzip
   588 
   589     (* collect info about rep/abs *)
   590     val iso_infos : Domain_Take_Proofs.iso_info list =
   591       let
   592         fun mk_info (((lhsT, rhsT), (repC, absC)), (rep_iso, abs_iso)) =
   593           {
   594             repT = rhsT,
   595             absT = lhsT,
   596             rep_const = repC,
   597             abs_const = absC,
   598             rep_inverse = rep_iso,
   599             abs_inverse = abs_iso
   600           }
   601       in
   602         map mk_info (dom_eqns ~~ rep_abs_consts ~~ iso_thms)
   603       end
   604 
   605     (* definitions and proofs related to map functions *)
   606     val (map_info, thy) =
   607         define_map_functions (dbinds ~~ iso_infos) thy
   608     val { map_consts, map_apply_thms, map_unfold_thms,
   609           deflation_map_thms } = map_info
   610 
   611     (* prove isodefl rules for map functions *)
           (* One theorem is proved whose conclusion is the conjunction of the
              isodefl statements for all map functions; its assumptions are
              derived from the deflation arguments of the first domain. *)
   612     val isodefl_thm =
   613       let
   614         fun unprime a = Library.unprefix "'" a
   615         fun mk_d T = Free ("d" ^ unprime (fst (dest_TFree T)), deflT)
   616         fun mk_p T = Free ("p" ^ unprime (fst (dest_TFree T)), deflT)
   617         fun mk_f T = Free ("f" ^ unprime (fst (dest_TFree T)), T ->> T)
   618         fun mk_assm t =
   619           case try dest_LIFTDEFL t of
   620             SOME T => mk_trp (isodefl_const (mk_upT T) $ mk_u_map (mk_f T) $ mk_p T)
   621           | NONE =>
   622             let val T = dest_DEFL t
   623             in mk_trp (isodefl_const T $ mk_f T $ mk_d T) end
   624         fun mk_goal (map_const, (T, rhsT)) =
   625           let
   626             val (_, Ts) = dest_Type T
   627             val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts))
   628             val defl_term = defl_of_typ thy (Ts ~~ map mk_d Ts) (Ts ~~ map mk_p Ts) T
   629           in isodefl_const T $ map_term $ defl_term end
   630         val assms = (map mk_assm o snd o hd) defl_flagss
   631         val goals = map mk_goal (map_consts ~~ dom_eqns)
   632         val goal = mk_trp (foldr1 HOLogic.mk_conj goals)
   633         val start_thms =
   634           @{thm split_def} :: defl_apply_thms @ map_apply_thms
   635         val adm_rules =
   636           @{thms adm_conj adm_isodefl cont2cont_fst cont2cont_snd cont_id}
   637         val bottom_rules =
   638           @{thms fst_strict snd_strict isodefl_bottom simp_thms}
   639         val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
   640         val map_ID_simps = map (fn th => th RS sym) map_ID_thms
   641         val isodefl_rules =
   642           @{thms conjI isodefl_ID_DEFL isodefl_LIFTDEFL}
   643           @ isodefl_abs_rep_thms
   644           @ IsodeflData.get (ProofContext.init_global thy)
   645       in
   646         Goal.prove_global thy [] assms goal (fn {prems, ...} =>
   647          EVERY
   648           [simp_tac (HOL_basic_ss addsimps start_thms) 1,
   649            (* FIXME: how reliable is unification here? *)
   650            (* Maybe I should instantiate the rule. *)
   651            rtac @{thm parallel_fix_ind} 1,
   652            REPEAT (resolve_tac adm_rules 1),
   653            simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
   654            simp_tac beta_ss 1,
   655            simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
   656            simp_tac (HOL_basic_ss addsimps map_ID_simps) 1,
   657            REPEAT (etac @{thm conjE} 1),
   658            REPEAT (resolve_tac (isodefl_rules @ prems) 1 ORELSE atac 1)])
   659       end
   660     val isodefl_binds = map (Binding.prefix_name "isodefl_") dbinds
           (* split a conjunction theorem into one named theorem per conjunct *)
   661     fun conjuncts [] thm = []
   662       | conjuncts (n::[]) thm = [(n, thm)]
   663       | conjuncts (n::ns) thm = let
   664           val thmL = thm RS @{thm conjunct1}
   665           val thmR = thm RS @{thm conjunct2}
   666         in (n, thmL):: conjuncts ns thmR end
   667     val (isodefl_thms, thy) = thy |>
   668       (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   669         (conjuncts isodefl_binds isodefl_thm)
   670     val thy = fold (Context.theory_map o IsodeflData.add_thm) isodefl_thms thy
   671 
   672     (* prove map_ID theorems *)
   673     fun prove_map_ID_thm
   674         (((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
   675       let
   676         val Ts = snd (dest_Type lhsT)
               (* local is_cpo, fixed to the current theory *)
   677         fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
   678         val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts))
   679         val goal = mk_eqs (lhs, mk_ID lhsT)
   680         val tac = EVERY
   681           [rtac @{thm isodefl_DEFL_imp_ID} 1,
   682            stac DEFL_thm 1,
   683            rtac isodefl_thm 1,
   684            REPEAT (resolve_tac @{thms isodefl_ID_DEFL isodefl_LIFTDEFL} 1)]
   685       in
   686         Goal.prove_global thy [] [] goal (K tac)
   687       end
   688     val map_ID_binds = map (Binding.suffix_name "_map_ID") dbinds
   689     val map_ID_thms =
   690       map prove_map_ID_thm
   691         (map_consts ~~ dom_eqns ~~ DEFL_thms ~~ isodefl_thms)
   692     val (_, thy) = thy |>
   693       (Global_Theory.add_thms o map (rpair [Domain_Take_Proofs.map_ID_add]))
   694         (map_ID_binds ~~ map_ID_thms)
   695 
   696     (* definitions and proofs related to take functions *)
   697     val (take_info, thy) =
   698         Domain_Take_Proofs.define_take_functions
   699           (dbinds ~~ iso_infos) thy
   700     val { take_consts, chain_take_thms, take_0_thms, take_Suc_thms, ...} =
   701         take_info
   702 
   703     (* least-upper-bound lemma for take functions *)
           (* states that the tuple of lubs of the take functions equals the tuple
              of map functions applied to identity arguments *)
   704     val lub_take_lemma =
   705       let
   706         val lhs = mk_tuple (map mk_lub take_consts)
   707         fun is_cpo T = Sign.of_sort thy (T, @{sort cpo})
   708         fun mk_map_ID (map_const, (lhsT, rhsT)) =
   709           list_ccomb (map_const, map mk_ID (filter is_cpo (snd (dest_Type lhsT))))
   710         val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns))
   711         val goal = mk_trp (mk_eq (lhs, rhs))
   712         val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy
   713         val start_rules =
   714             @{thms lub_Pair [symmetric] ch2ch_Pair} @ chain_take_thms
   715             @ @{thms pair_collapse split_def}
   716             @ map_apply_thms @ map_ID_thms
   717         val rules0 =
   718             @{thms iterate_0 Pair_strict} @ take_0_thms
   719         val rules1 =
   720             @{thms iterate_Suc Pair_fst_snd_eq fst_conv snd_conv}
   721             @ take_Suc_thms
   722         val tac =
   723             EVERY
   724             [simp_tac (HOL_basic_ss addsimps start_rules) 1,
   725              simp_tac (HOL_basic_ss addsimps @{thms fix_def2}) 1,
   726              rtac @{thm lub_eq} 1,
   727              rtac @{thm nat.induct} 1,
   728              simp_tac (HOL_basic_ss addsimps rules0) 1,
   729              asm_full_simp_tac (beta_ss addsimps rules1) 1]
   730       in
   731         Goal.prove_global thy [] [] goal (K tac)
   732       end
   733 
   734     (* prove lub of take equals ID *)
   735     fun prove_lub_take (((dbind, take_const), map_ID_thm), (lhsT, rhsT)) thy =
   736       let
   737         val n = Free ("n", natT)
   738         val goal = mk_eqs (mk_lub (lambda n (take_const $ n)), mk_ID lhsT)
   739         val tac =
   740             EVERY
   741             [rtac @{thm trans} 1, rtac map_ID_thm 2,
   742              cut_facts_tac [lub_take_lemma] 1,
   743              REPEAT (etac @{thm Pair_inject} 1), atac 1]
   744         val lub_take_thm = Goal.prove_global thy [] [] goal (K tac)
   745       in
   746         add_qualified_thm "lub_take" (dbind, lub_take_thm) thy
   747       end
   748     val (lub_take_thms, thy) =
   749         fold_map prove_lub_take
   750           (dbinds ~~ take_consts ~~ map_ID_thms ~~ dom_eqns) thy
   751 
   752     (* prove additional take theorems *)
   753     val (take_info2, thy) =
   754         Domain_Take_Proofs.add_lub_take_theorems
   755           (dbinds ~~ iso_infos) take_info lub_take_thms thy
   756   in
   757     ((iso_infos, take_info2), thy)
   758   end
   759 
   (* Entry point for right-hand sides given as types (prepared by cert_typ). *)
   fun domain_isomorphism doms thy =
     gen_domain_isomorphism cert_typ doms thy

   (* Entry point for right-hand sides given as strings (prepared by
      read_typ); only the resulting theory is kept. *)
   fun domain_isomorphism_cmd doms thy =
     let val (_, thy') = gen_domain_isomorphism read_typ doms thy
     in thy' end
   762 
   763 (******************************************************************************)
   764 (******************************** outer syntax ********************************)
   765 (******************************************************************************)
   766 
   local

   (* right-hand side of a domain equation: "=" followed by a type *)
   val parse_rhs = Parse.$$$ "=" |-- Parse.typ

   (* optional clause: "morphisms" followed by exactly two bindings *)
   val parse_morphisms =
     Scan.option (Parse.$$$ "morphisms" |-- Parse.!!! (Parse.binding -- Parse.binding))

   (* one domain equation: type arguments, name, optional mixfix,
      right-hand side, optional morphisms *)
   val parse_domain_iso :
       (string list * binding * mixfix * string * (binding * binding) option)
         parser =
     (Parse.type_args -- Parse.binding -- Parse.opt_mixfix -- parse_rhs -- parse_morphisms)
       >> (fn ((((tvs, tbind), mixfix), rhs_str), morphisms) =>
             (tvs, tbind, mixfix, rhs_str, morphisms))

   (* one or more domain equations, as accepted by Parse.and_list1 *)
   val parse_domain_isos = Parse.and_list1 parse_domain_iso

   in

   (* registers the "domain_isomorphism" command as a theory declaration *)
   val _ =
     Outer_Syntax.command "domain_isomorphism" "define domain isomorphisms (HOLCF)"
       Keyword.thy_decl
       (parse_domain_isos >> (fn doms => Toplevel.theory (domain_isomorphism_cmd doms)))

   end
   786 
   787 end