src/HOL/HOLCF/Tools/Domain/domain_isomorphism.ML
changeset 40774 0437dbc127b3
parent 40771 1c6f7d4b110e
child 40832 4352ca878c41
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/HOLCF/Tools/Domain/domain_isomorphism.ML	Sat Nov 27 16:08:10 2010 -0800
     1.3 @@ -0,0 +1,787 @@
     1.4 +(*  Title:      HOLCF/Tools/Domain/domain_isomorphism.ML
     1.5 +    Author:     Brian Huffman
     1.6 +
     1.7 +Defines new types satisfying the given domain equations.
     1.8 +*)
     1.9 +
    1.10 +signature DOMAIN_ISOMORPHISM =
    1.11 +sig
    1.12 +  val domain_isomorphism :
    1.13 +      (string list * binding * mixfix * typ
    1.14 +       * (binding * binding) option) list ->
    1.15 +      theory ->
    1.16 +      (Domain_Take_Proofs.iso_info list
    1.17 +       * Domain_Take_Proofs.take_induct_info) * theory
    1.18 +
    1.19 +  val define_map_functions :
    1.20 +      (binding * Domain_Take_Proofs.iso_info) list ->
    1.21 +      theory ->
    1.22 +      {
    1.23 +        map_consts : term list,
    1.24 +        map_apply_thms : thm list,
    1.25 +        map_unfold_thms : thm list,
    1.26 +        deflation_map_thms : thm list
    1.27 +      }
    1.28 +      * theory
    1.29 +
    1.30 +  val domain_isomorphism_cmd :
    1.31 +    (string list * binding * mixfix * string * (binding * binding) option) list
    1.32 +      -> theory -> theory
    1.33 +
    1.34 +  val setup : theory -> theory
    1.35 +end;
    1.36 +
    1.37 +structure Domain_Isomorphism : DOMAIN_ISOMORPHISM =
    1.38 +struct
    1.39 +
    1.40 +val beta_rules =
    1.41 +  @{thms beta_cfun cont_id cont_const cont2cont_APP cont2cont_LAM'} @
    1.42 +  @{thms cont2cont_fst cont2cont_snd cont2cont_Pair cont2cont_prod_case'};
    1.43 +
    1.44 +val beta_ss = HOL_basic_ss addsimps (simp_thms @ beta_rules);
    1.45 +
    1.46 +val beta_tac = simp_tac beta_ss;
    1.47 +
    1.48 +fun is_cpo thy T = Sign.of_sort thy (T, @{sort cpo});
    1.49 +
    1.50 +(******************************************************************************)
    1.51 +(******************************** theory data *********************************)
    1.52 +(******************************************************************************)
    1.53 +
    1.54 +structure RepData = Named_Thms
    1.55 +(
    1.56 +  val name = "domain_defl_simps"
    1.57 +  val description = "theorems like DEFL('a t) = t_defl$DEFL('a)"
    1.58 +)
    1.59 +
    1.60 +structure IsodeflData = Named_Thms
    1.61 +(
    1.62 +  val name = "domain_isodefl"
    1.63 +  val description = "theorems like isodefl d t ==> isodefl (foo_map$d) (foo_defl$t)"
    1.64 +);
    1.65 +
    1.66 +val setup = RepData.setup #> IsodeflData.setup
    1.67 +
    1.68 +
    1.69 +(******************************************************************************)
    1.70 +(************************** building types and terms **************************)
    1.71 +(******************************************************************************)
    1.72 +
    1.73 +open HOLCF_Library;
    1.74 +
    1.75 +infixr 6 ->>;
    1.76 +infixr -->>;
    1.77 +
    1.78 +val udomT = @{typ udom};
    1.79 +val deflT = @{typ "defl"};
    1.80 +
    1.81 +fun mk_DEFL T =
    1.82 +  Const (@{const_name defl}, Term.itselfT T --> deflT) $ Logic.mk_type T;
    1.83 +
    1.84 +fun dest_DEFL (Const (@{const_name defl}, _) $ t) = Logic.dest_type t
    1.85 +  | dest_DEFL t = raise TERM ("dest_DEFL", [t]);
    1.86 +
    1.87 +fun mk_LIFTDEFL T =
    1.88 +  Const (@{const_name liftdefl}, Term.itselfT T --> deflT) $ Logic.mk_type T;
    1.89 +
    1.90 +fun dest_LIFTDEFL (Const (@{const_name liftdefl}, _) $ t) = Logic.dest_type t
    1.91 +  | dest_LIFTDEFL t = raise TERM ("dest_LIFTDEFL", [t]);
    1.92 +
    1.93 +fun mk_u_defl t = mk_capply (@{const "u_defl"}, t);
    1.94 +
    1.95 +fun mk_u_map t =
    1.96 +  let
    1.97 +    val (T, U) = dest_cfunT (fastype_of t);
    1.98 +    val u_map_type = (T ->> U) ->> (mk_upT T ->> mk_upT U);
    1.99 +    val u_map_const = Const (@{const_name u_map}, u_map_type);
   1.100 +  in
   1.101 +    mk_capply (u_map_const, t)
   1.102 +  end;
   1.103 +
   1.104 +fun emb_const T = Const (@{const_name emb}, T ->> udomT);
   1.105 +fun prj_const T = Const (@{const_name prj}, udomT ->> T);
   1.106 +fun coerce_const (T, U) = mk_cfcomp (prj_const U, emb_const T);
   1.107 +
   1.108 +fun isodefl_const T =
   1.109 +  Const (@{const_name isodefl}, (T ->> T) --> deflT --> HOLogic.boolT);
   1.110 +
   1.111 +fun mk_deflation t =
   1.112 +  Const (@{const_name deflation}, Term.fastype_of t --> boolT) $ t;
   1.113 +
   1.114 +(* splits a cterm into the right and lefthand sides of equality *)
   1.115 +fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
   1.116 +
   1.117 +fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u));
   1.118 +
   1.119 +(******************************************************************************)
   1.120 +(****************************** isomorphism info ******************************)
   1.121 +(******************************************************************************)
   1.122 +
   1.123 +fun deflation_abs_rep (info : Domain_Take_Proofs.iso_info) : thm =
   1.124 +  let
   1.125 +    val abs_iso = #abs_inverse info;
   1.126 +    val rep_iso = #rep_inverse info;
   1.127 +    val thm = @{thm deflation_abs_rep} OF [abs_iso, rep_iso];
   1.128 +  in
   1.129 +    Drule.zero_var_indexes thm
   1.130 +  end
   1.131 +
   1.132 +(******************************************************************************)
   1.133 +(*************** fixed-point definitions and unfolding theorems ***************)
   1.134 +(******************************************************************************)
   1.135 +
   1.136 +fun mk_projs []      t = []
   1.137 +  | mk_projs (x::[]) t = [(x, t)]
   1.138 +  | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);
   1.139 +
   1.140 +fun add_fixdefs
   1.141 +    (spec : (binding * term) list)
   1.142 +    (thy : theory) : (thm list * thm list) * theory =
   1.143 +  let
   1.144 +    val binds = map fst spec;
   1.145 +    val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
   1.146 +    val functional = lambda_tuple lhss (mk_tuple rhss);
   1.147 +    val fixpoint = mk_fix (mk_cabs functional);
   1.148 +
   1.149 +    (* project components of fixpoint *)
   1.150 +    val projs = mk_projs lhss fixpoint;
   1.151 +
   1.152 +    (* convert parameters to lambda abstractions *)
   1.153 +    fun mk_eqn (lhs, rhs) =
   1.154 +        case lhs of
   1.155 +          Const (@{const_name Rep_cfun}, _) $ f $ (x as Free _) =>
   1.156 +            mk_eqn (f, big_lambda x rhs)
   1.157 +        | f $ Const (@{const_name TYPE}, T) =>
   1.158 +            mk_eqn (f, Abs ("t", T, rhs))
   1.159 +        | Const _ => Logic.mk_equals (lhs, rhs)
   1.160 +        | _ => raise TERM ("lhs not of correct form", [lhs, rhs]);
   1.161 +    val eqns = map mk_eqn projs;
   1.162 +
   1.163 +    (* register constant definitions *)
   1.164 +    val (fixdef_thms, thy) =
   1.165 +      (Global_Theory.add_defs false o map Thm.no_attributes)
   1.166 +        (map (Binding.suffix_name "_def") binds ~~ eqns) thy;
   1.167 +
   1.168 +    (* prove applied version of definitions *)
   1.169 +    fun prove_proj (lhs, rhs) =
   1.170 +      let
   1.171 +        val tac = rewrite_goals_tac fixdef_thms THEN beta_tac 1;
   1.172 +        val goal = Logic.mk_equals (lhs, rhs);
   1.173 +      in Goal.prove_global thy [] [] goal (K tac) end;
   1.174 +    val proj_thms = map prove_proj projs;
   1.175 +
   1.176 +    (* mk_tuple lhss == fixpoint *)
   1.177 +    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
   1.178 +    val tuple_fixdef_thm = foldr1 pair_equalI proj_thms;
   1.179 +
   1.180 +    val cont_thm =
   1.181 +      Goal.prove_global thy [] [] (mk_trp (mk_cont functional))
   1.182 +        (K (beta_tac 1));
   1.183 +    val tuple_unfold_thm =
   1.184 +      (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
   1.185 +      |> Local_Defs.unfold (ProofContext.init_global thy) @{thms split_conv};
   1.186 +
   1.187 +    fun mk_unfold_thms [] thm = []
   1.188 +      | mk_unfold_thms (n::[]) thm = [(n, thm)]
   1.189 +      | mk_unfold_thms (n::ns) thm = let
   1.190 +          val thmL = thm RS @{thm Pair_eqD1};
   1.191 +          val thmR = thm RS @{thm Pair_eqD2};
   1.192 +        in (n, thmL) :: mk_unfold_thms ns thmR end;
   1.193 +    val unfold_binds = map (Binding.suffix_name "_unfold") binds;
   1.194 +
   1.195 +    (* register unfold theorems *)
   1.196 +    val (unfold_thms, thy) =
   1.197 +      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   1.198 +        (mk_unfold_thms unfold_binds tuple_unfold_thm) thy;
   1.199 +  in
   1.200 +    ((proj_thms, unfold_thms), thy)
   1.201 +  end;
   1.202 +
   1.203 +
   1.204 +(******************************************************************************)
   1.205 +(****************** deflation combinators and map functions *******************)
   1.206 +(******************************************************************************)
   1.207 +
   1.208 +fun defl_of_typ
   1.209 +    (thy : theory)
   1.210 +    (tab1 : (typ * term) list)
   1.211 +    (tab2 : (typ * term) list)
   1.212 +    (T : typ) : term =
   1.213 +  let
   1.214 +    val defl_simps = RepData.get (ProofContext.init_global thy);
   1.215 +    val rules = map (Thm.concl_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq) defl_simps;
   1.216 +    val rules' = map (apfst mk_DEFL) tab1 @ map (apfst mk_LIFTDEFL) tab2;
   1.217 +    fun proc1 t =
   1.218 +      (case dest_DEFL t of
   1.219 +        TFree (a, _) => SOME (Free ("d" ^ Library.unprefix "'" a, deflT))
   1.220 +      | _ => NONE) handle TERM _ => NONE;
   1.221 +    fun proc2 t =
   1.222 +      (case dest_LIFTDEFL t of
   1.223 +        TFree (a, _) => SOME (Free ("p" ^ Library.unprefix "'" a, deflT))
   1.224 +      | _ => NONE) handle TERM _ => NONE;
   1.225 +  in
   1.226 +    Pattern.rewrite_term thy (rules @ rules') [proc1, proc2] (mk_DEFL T)
   1.227 +  end;
   1.228 +
   1.229 +(******************************************************************************)
   1.230 +(********************* declaring definitions and theorems *********************)
   1.231 +(******************************************************************************)
   1.232 +
   1.233 +fun define_const
   1.234 +    (bind : binding, rhs : term)
   1.235 +    (thy : theory)
   1.236 +    : (term * thm) * theory =
   1.237 +  let
   1.238 +    val typ = Term.fastype_of rhs;
   1.239 +    val (const, thy) = Sign.declare_const ((bind, typ), NoSyn) thy;
   1.240 +    val eqn = Logic.mk_equals (const, rhs);
   1.241 +    val def = Thm.no_attributes (Binding.suffix_name "_def" bind, eqn);
   1.242 +    val (def_thm, thy) = yield_singleton (Global_Theory.add_defs false) def thy;
   1.243 +  in
   1.244 +    ((const, def_thm), thy)
   1.245 +  end;
   1.246 +
   1.247 +fun add_qualified_thm name (dbind, thm) =
   1.248 +    yield_singleton Global_Theory.add_thms
   1.249 +      ((Binding.qualified true name dbind, thm), []);
   1.250 +
   1.251 +(******************************************************************************)
   1.252 +(*************************** defining map functions ***************************)
   1.253 +(******************************************************************************)
   1.254 +
   1.255 +fun define_map_functions
   1.256 +    (spec : (binding * Domain_Take_Proofs.iso_info) list)
   1.257 +    (thy : theory) =
   1.258 +  let
   1.259 +
   1.260 +    (* retrieve components of spec *)
   1.261 +    val dbinds = map fst spec;
   1.262 +    val iso_infos = map snd spec;
   1.263 +    val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
   1.264 +    val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
   1.265 +
   1.266 +    fun mapT (T as Type (_, Ts)) =
   1.267 +        (map (fn T => T ->> T) (filter (is_cpo thy) Ts)) -->> (T ->> T)
   1.268 +      | mapT T = T ->> T;
   1.269 +
   1.270 +    (* declare map functions *)
   1.271 +    fun declare_map_const (tbind, (lhsT, rhsT)) thy =
   1.272 +      let
   1.273 +        val map_type = mapT lhsT;
   1.274 +        val map_bind = Binding.suffix_name "_map" tbind;
   1.275 +      in
   1.276 +        Sign.declare_const ((map_bind, map_type), NoSyn) thy
   1.277 +      end;
   1.278 +    val (map_consts, thy) = thy |>
   1.279 +      fold_map declare_map_const (dbinds ~~ dom_eqns);
   1.280 +
   1.281 +    (* defining equations for map functions *)
   1.282 +    local
   1.283 +      fun unprime a = Library.unprefix "'" a;
   1.284 +      fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T);
   1.285 +      fun map_lhs (map_const, lhsT) =
   1.286 +          (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (snd (dest_Type lhsT)))));
   1.287 +      val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns);
   1.288 +      val Ts = (snd o dest_Type o fst o hd) dom_eqns;
   1.289 +      val tab = (Ts ~~ map mapvar Ts) @ tab1;
   1.290 +      fun mk_map_spec (((rep_const, abs_const), map_const), (lhsT, rhsT)) =
   1.291 +        let
   1.292 +          val lhs = Domain_Take_Proofs.map_of_typ thy tab lhsT;
   1.293 +          val body = Domain_Take_Proofs.map_of_typ thy tab rhsT;
   1.294 +          val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const));
   1.295 +        in mk_eqs (lhs, rhs) end;
   1.296 +    in
   1.297 +      val map_specs =
   1.298 +          map mk_map_spec (rep_abs_consts ~~ map_consts ~~ dom_eqns);
   1.299 +    end;
   1.300 +
   1.301 +    (* register recursive definition of map functions *)
   1.302 +    val map_binds = map (Binding.suffix_name "_map") dbinds;
   1.303 +    val ((map_apply_thms, map_unfold_thms), thy) =
   1.304 +      add_fixdefs (map_binds ~~ map_specs) thy;
   1.305 +
   1.306 +    (* prove deflation theorems for map functions *)
   1.307 +    val deflation_abs_rep_thms = map deflation_abs_rep iso_infos;
   1.308 +    val deflation_map_thm =
   1.309 +      let
   1.310 +        fun unprime a = Library.unprefix "'" a;
   1.311 +        fun mk_f T = Free (unprime (fst (dest_TFree T)), T ->> T);
   1.312 +        fun mk_assm T = mk_trp (mk_deflation (mk_f T));
   1.313 +        fun mk_goal (map_const, (lhsT, rhsT)) =
   1.314 +          let
   1.315 +            val (_, Ts) = dest_Type lhsT;
   1.316 +            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
   1.317 +          in mk_deflation map_term end;
   1.318 +        val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns;
   1.319 +        val goals = map mk_goal (map_consts ~~ dom_eqns);
   1.320 +        val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
   1.321 +        val start_thms =
   1.322 +          @{thm split_def} :: map_apply_thms;
   1.323 +        val adm_rules =
   1.324 +          @{thms adm_conj adm_subst [OF _ adm_deflation]
   1.325 +                 cont2cont_fst cont2cont_snd cont_id};
   1.326 +        val bottom_rules =
   1.327 +          @{thms fst_strict snd_strict deflation_UU simp_thms};
   1.328 +        val deflation_rules =
   1.329 +          @{thms conjI deflation_ID}
   1.330 +          @ deflation_abs_rep_thms
   1.331 +          @ Domain_Take_Proofs.get_deflation_thms thy;
   1.332 +      in
   1.333 +        Goal.prove_global thy [] assms goal (fn {prems, ...} =>
   1.334 +         EVERY
   1.335 +          [simp_tac (HOL_basic_ss addsimps start_thms) 1,
   1.336 +           rtac @{thm fix_ind} 1,
   1.337 +           REPEAT (resolve_tac adm_rules 1),
   1.338 +           simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
   1.339 +           simp_tac beta_ss 1,
   1.340 +           simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
   1.341 +           REPEAT (etac @{thm conjE} 1),
   1.342 +           REPEAT (resolve_tac (deflation_rules @ prems) 1 ORELSE atac 1)])
   1.343 +      end;
   1.344 +    fun conjuncts [] thm = []
   1.345 +      | conjuncts (n::[]) thm = [(n, thm)]
   1.346 +      | conjuncts (n::ns) thm = let
   1.347 +          val thmL = thm RS @{thm conjunct1};
   1.348 +          val thmR = thm RS @{thm conjunct2};
   1.349 +        in (n, thmL):: conjuncts ns thmR end;
   1.350 +    val deflation_map_binds = dbinds |>
   1.351 +        map (Binding.prefix_name "deflation_" o Binding.suffix_name "_map");
   1.352 +    val (deflation_map_thms, thy) = thy |>
   1.353 +      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   1.354 +        (conjuncts deflation_map_binds deflation_map_thm);
   1.355 +
   1.356 +    (* register indirect recursion in theory data *)
   1.357 +    local
   1.358 +      fun register_map (dname, args) =
   1.359 +        Domain_Take_Proofs.add_rec_type (dname, args);
   1.360 +      val dnames = map (fst o dest_Type o fst) dom_eqns;
   1.361 +      val map_names = map (fst o dest_Const) map_consts;
   1.362 +      fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => [];
   1.363 +      val argss = map args dom_eqns;
   1.364 +    in
   1.365 +      val thy =
   1.366 +          fold register_map (dnames ~~ argss) thy;
   1.367 +    end;
   1.368 +
   1.369 +    (* register deflation theorems *)
   1.370 +    val thy = fold Domain_Take_Proofs.add_deflation_thm deflation_map_thms thy;
   1.371 +
   1.372 +    val result =
   1.373 +      {
   1.374 +        map_consts = map_consts,
   1.375 +        map_apply_thms = map_apply_thms,
   1.376 +        map_unfold_thms = map_unfold_thms,
   1.377 +        deflation_map_thms = deflation_map_thms
   1.378 +      }
   1.379 +  in
   1.380 +    (result, thy)
   1.381 +  end;
   1.382 +
   1.383 +(******************************************************************************)
   1.384 +(******************************* main function ********************************)
   1.385 +(******************************************************************************)
   1.386 +
   1.387 +fun read_typ thy str sorts =
   1.388 +  let
   1.389 +    val ctxt = ProofContext.init_global thy
   1.390 +      |> fold (Variable.declare_typ o TFree) sorts;
   1.391 +    val T = Syntax.read_typ ctxt str;
   1.392 +  in (T, Term.add_tfreesT T sorts) end;
   1.393 +
   1.394 +fun cert_typ sign raw_T sorts =
   1.395 +  let
   1.396 +    val T = Type.no_tvars (Sign.certify_typ sign raw_T)
   1.397 +      handle TYPE (msg, _, _) => error msg;
   1.398 +    val sorts' = Term.add_tfreesT T sorts;
   1.399 +    val _ =
   1.400 +      case duplicates (op =) (map fst sorts') of
   1.401 +        [] => ()
   1.402 +      | dups => error ("Inconsistent sort constraints for " ^ commas dups)
   1.403 +  in (T, sorts') end;
   1.404 +
   1.405 +fun gen_domain_isomorphism
   1.406 +    (prep_typ: theory -> 'a -> (string * sort) list -> typ * (string * sort) list)
   1.407 +    (doms_raw: (string list * binding * mixfix * 'a * (binding * binding) option) list)
   1.408 +    (thy: theory)
   1.409 +    : (Domain_Take_Proofs.iso_info list
   1.410 +       * Domain_Take_Proofs.take_induct_info) * theory =
   1.411 +  let
   1.412 +    val _ = Theory.requires thy "Domain" "domain isomorphisms";
   1.413 +
   1.414 +    (* this theory is used just for parsing *)
   1.415 +    val tmp_thy = thy |>
   1.416 +      Theory.copy |>
   1.417 +      Sign.add_types (map (fn (tvs, tbind, mx, _, morphs) =>
   1.418 +        (tbind, length tvs, mx)) doms_raw);
   1.419 +
   1.420 +    fun prep_dom thy (vs, t, mx, typ_raw, morphs) sorts =
   1.421 +      let val (typ, sorts') = prep_typ thy typ_raw sorts
   1.422 +      in ((vs, t, mx, typ, morphs), sorts') end;
   1.423 +
   1.424 +    val (doms : (string list * binding * mixfix * typ * (binding * binding) option) list,
   1.425 +         sorts : (string * sort) list) =
   1.426 +      fold_map (prep_dom tmp_thy) doms_raw [];
   1.427 +
   1.428 +    (* lookup function for sorts of type variables *)
   1.429 +    fun the_sort v = the (AList.lookup (op =) sorts v);
   1.430 +
   1.431 +    (* declare arities in temporary theory *)
   1.432 +    val tmp_thy =
   1.433 +      let
   1.434 +        fun arity (vs, tbind, mx, _, _) =
   1.435 +          (Sign.full_name thy tbind, map the_sort vs, @{sort "domain"});
   1.436 +      in
   1.437 +        fold AxClass.axiomatize_arity (map arity doms) tmp_thy
   1.438 +      end;
   1.439 +
   1.440 +    (* check bifiniteness of right-hand sides *)
   1.441 +    fun check_rhs (vs, tbind, mx, rhs, morphs) =
   1.442 +      if Sign.of_sort tmp_thy (rhs, @{sort "domain"}) then ()
   1.443 +      else error ("Type not of sort domain: " ^
   1.444 +        quote (Syntax.string_of_typ_global tmp_thy rhs));
   1.445 +    val _ = map check_rhs doms;
   1.446 +
   1.447 +    (* domain equations *)
   1.448 +    fun mk_dom_eqn (vs, tbind, mx, rhs, morphs) =
   1.449 +      let fun arg v = TFree (v, the_sort v);
   1.450 +      in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end;
   1.451 +    val dom_eqns = map mk_dom_eqn doms;
   1.452 +
   1.453 +    (* check for valid type parameters *)
   1.454 +    val (tyvars, _, _, _, _) = hd doms;
   1.455 +    val new_doms = map (fn (tvs, tname, mx, _, _) =>
   1.456 +      let val full_tname = Sign.full_name tmp_thy tname
   1.457 +      in
   1.458 +        (case duplicates (op =) tvs of
   1.459 +          [] =>
   1.460 +            if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
   1.461 +            else error ("Mutually recursive domains must have same type parameters")
   1.462 +        | dups => error ("Duplicate parameter(s) for domain " ^ quote (Binding.str_of tname) ^
   1.463 +            " : " ^ commas dups))
   1.464 +      end) doms;
   1.465 +    val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms;
   1.466 +    val morphs = map (fn (_, _, _, _, morphs) => morphs) doms;
   1.467 +
   1.468 +    (* determine deflation combinator arguments *)
   1.469 +    val lhsTs : typ list = map fst dom_eqns;
   1.470 +    val defl_rec = Free ("t", mk_tupleT (map (K deflT) lhsTs));
   1.471 +    val defl_recs = mk_projs lhsTs defl_rec;
   1.472 +    val defl_recs' = map (apsnd mk_u_defl) defl_recs;
   1.473 +    fun defl_body (_, _, _, rhsT, _) =
   1.474 +      defl_of_typ tmp_thy defl_recs defl_recs' rhsT;
   1.475 +    val functional = Term.lambda defl_rec (mk_tuple (map defl_body doms));
   1.476 +
   1.477 +    val tfrees = map fst (Term.add_tfrees functional []);
   1.478 +    val frees = map fst (Term.add_frees functional []);
   1.479 +    fun get_defl_flags (vs, _, _, _, _) =
   1.480 +      let
   1.481 +        fun argT v = TFree (v, the_sort v);
   1.482 +        fun mk_d v = "d" ^ Library.unprefix "'" v;
   1.483 +        fun mk_p v = "p" ^ Library.unprefix "'" v;
   1.484 +        val args = maps (fn v => [(mk_d v, mk_DEFL (argT v)), (mk_p v, mk_LIFTDEFL (argT v))]) vs;
   1.485 +        val typeTs = map argT (filter (member (op =) tfrees) vs);
   1.486 +        val defl_args = map snd (filter (member (op =) frees o fst) args);
   1.487 +      in
   1.488 +        (typeTs, defl_args)
   1.489 +      end;
   1.490 +    val defl_flagss = map get_defl_flags doms;
   1.491 +
   1.492 +    (* declare deflation combinator constants *)
   1.493 +    fun declare_defl_const ((typeTs, defl_args), (_, tbind, _, _, _)) thy =
   1.494 +      let
   1.495 +        val defl_bind = Binding.suffix_name "_defl" tbind;
   1.496 +        val defl_type =
   1.497 +          map Term.itselfT typeTs ---> map (K deflT) defl_args -->> deflT;
   1.498 +      in
   1.499 +        Sign.declare_const ((defl_bind, defl_type), NoSyn) thy
   1.500 +      end;
   1.501 +    val (defl_consts, thy) =
   1.502 +      fold_map declare_defl_const (defl_flagss ~~ doms) thy;
   1.503 +
   1.504 +    (* defining equations for type combinators *)
   1.505 +    fun mk_defl_term (defl_const, (typeTs, defl_args)) =
   1.506 +      let
   1.507 +        val type_args = map Logic.mk_type typeTs;
   1.508 +      in
   1.509 +        list_ccomb (list_comb (defl_const, type_args), defl_args)
   1.510 +      end;
   1.511 +    val defl_terms = map mk_defl_term (defl_consts ~~ defl_flagss);
   1.512 +    val defl_tab = map fst dom_eqns ~~ defl_terms;
   1.513 +    val defl_tab' = map fst dom_eqns ~~ map mk_u_defl defl_terms;
   1.514 +    fun mk_defl_spec (lhsT, rhsT) =
   1.515 +      mk_eqs (defl_of_typ tmp_thy defl_tab defl_tab' lhsT,
   1.516 +              defl_of_typ tmp_thy defl_tab defl_tab' rhsT);
   1.517 +    val defl_specs = map mk_defl_spec dom_eqns;
   1.518 +
   1.519 +    (* register recursive definition of deflation combinators *)
   1.520 +    val defl_binds = map (Binding.suffix_name "_defl") dbinds;
   1.521 +    val ((defl_apply_thms, defl_unfold_thms), thy) =
   1.522 +      add_fixdefs (defl_binds ~~ defl_specs) thy;
   1.523 +
   1.524 +    (* define types using deflation combinators *)
   1.525 +    fun make_repdef ((vs, tbind, mx, _, _), defl) thy =
   1.526 +      let
   1.527 +        val spec = (tbind, map (rpair dummyS) vs, mx);
   1.528 +        val ((_, _, _, {DEFL, liftemb_def, liftprj_def, ...}), thy) =
   1.529 +          Domaindef.add_domaindef false NONE spec defl NONE thy;
   1.530 +        (* declare domain_defl_simps rules *)
   1.531 +        val thy = Context.theory_map (RepData.add_thm DEFL) thy;
   1.532 +      in
   1.533 +        (DEFL, thy)
   1.534 +      end;
   1.535 +    val (DEFL_thms, thy) = fold_map make_repdef (doms ~~ defl_terms) thy;
   1.536 +
   1.537 +    (* prove DEFL equations *)
   1.538 +    fun mk_DEFL_eq_thm (lhsT, rhsT) =
   1.539 +      let
   1.540 +        val goal = mk_eqs (mk_DEFL lhsT, mk_DEFL rhsT);
   1.541 +        val DEFL_simps = RepData.get (ProofContext.init_global thy);
   1.542 +        val tac =
   1.543 +          rewrite_goals_tac (map mk_meta_eq DEFL_simps)
   1.544 +          THEN TRY (resolve_tac defl_unfold_thms 1);
   1.545 +      in
   1.546 +        Goal.prove_global thy [] [] goal (K tac)
   1.547 +      end;
   1.548 +    val DEFL_eq_thms = map mk_DEFL_eq_thm dom_eqns;
   1.549 +
   1.550 +    (* register DEFL equations *)
   1.551 +    val DEFL_eq_binds = map (Binding.prefix_name "DEFL_eq_") dbinds;
   1.552 +    val (_, thy) = thy |>
   1.553 +      (Global_Theory.add_thms o map Thm.no_attributes)
   1.554 +        (DEFL_eq_binds ~~ DEFL_eq_thms);
   1.555 +
   1.556 +    (* define rep/abs functions *)
   1.557 +    fun mk_rep_abs ((tbind, morphs), (lhsT, rhsT)) thy =
   1.558 +      let
   1.559 +        val rep_bind = Binding.suffix_name "_rep" tbind;
   1.560 +        val abs_bind = Binding.suffix_name "_abs" tbind;
   1.561 +        val ((rep_const, rep_def), thy) =
   1.562 +            define_const (rep_bind, coerce_const (lhsT, rhsT)) thy;
   1.563 +        val ((abs_const, abs_def), thy) =
   1.564 +            define_const (abs_bind, coerce_const (rhsT, lhsT)) thy;
   1.565 +      in
   1.566 +        (((rep_const, abs_const), (rep_def, abs_def)), thy)
   1.567 +      end;
   1.568 +    val ((rep_abs_consts, rep_abs_defs), thy) = thy
   1.569 +      |> fold_map mk_rep_abs (dbinds ~~ morphs ~~ dom_eqns)
   1.570 +      |>> ListPair.unzip;
   1.571 +
   1.572 +    (* prove isomorphism and isodefl rules *)
   1.573 +    fun mk_iso_thms ((tbind, DEFL_eq), (rep_def, abs_def)) thy =
   1.574 +      let
   1.575 +        fun make thm =
   1.576 +            Drule.zero_var_indexes (thm OF [DEFL_eq, abs_def, rep_def]);
   1.577 +        val rep_iso_thm = make @{thm domain_rep_iso};
   1.578 +        val abs_iso_thm = make @{thm domain_abs_iso};
   1.579 +        val isodefl_thm = make @{thm isodefl_abs_rep};
   1.580 +        val thy = thy
   1.581 +          |> snd o add_qualified_thm "rep_iso" (tbind, rep_iso_thm)
   1.582 +          |> snd o add_qualified_thm "abs_iso" (tbind, abs_iso_thm)
   1.583 +          |> snd o add_qualified_thm "isodefl_abs_rep" (tbind, isodefl_thm);
   1.584 +      in
   1.585 +        (((rep_iso_thm, abs_iso_thm), isodefl_thm), thy)
   1.586 +      end;
   1.587 +    val ((iso_thms, isodefl_abs_rep_thms), thy) =
   1.588 +      thy
   1.589 +      |> fold_map mk_iso_thms (dbinds ~~ DEFL_eq_thms ~~ rep_abs_defs)
   1.590 +      |>> ListPair.unzip;
   1.591 +
   1.592 +    (* collect info about rep/abs *)
   1.593 +    val iso_infos : Domain_Take_Proofs.iso_info list =
   1.594 +      let
   1.595 +        fun mk_info (((lhsT, rhsT), (repC, absC)), (rep_iso, abs_iso)) =
   1.596 +          {
   1.597 +            repT = rhsT,
   1.598 +            absT = lhsT,
   1.599 +            rep_const = repC,
   1.600 +            abs_const = absC,
   1.601 +            rep_inverse = rep_iso,
   1.602 +            abs_inverse = abs_iso
   1.603 +          };
   1.604 +      in
   1.605 +        map mk_info (dom_eqns ~~ rep_abs_consts ~~ iso_thms)
   1.606 +      end
   1.607 +
   1.608 +    (* definitions and proofs related to map functions *)
   1.609 +    val (map_info, thy) =
   1.610 +        define_map_functions (dbinds ~~ iso_infos) thy;
   1.611 +    val { map_consts, map_apply_thms, map_unfold_thms,
   1.612 +          deflation_map_thms } = map_info;
   1.613 +
   1.614 +    (* prove isodefl rules for map functions *)
   1.615 +    val isodefl_thm =
   1.616 +      let
   1.617 +        fun unprime a = Library.unprefix "'" a;
   1.618 +        fun mk_d T = Free ("d" ^ unprime (fst (dest_TFree T)), deflT);
   1.619 +        fun mk_p T = Free ("p" ^ unprime (fst (dest_TFree T)), deflT);
   1.620 +        fun mk_f T = Free ("f" ^ unprime (fst (dest_TFree T)), T ->> T);
   1.621 +        fun mk_assm t =
   1.622 +          case try dest_LIFTDEFL t of
   1.623 +            SOME T => mk_trp (isodefl_const (mk_upT T) $ mk_u_map (mk_f T) $ mk_p T)
   1.624 +          | NONE =>
   1.625 +            let val T = dest_DEFL t
   1.626 +            in mk_trp (isodefl_const T $ mk_f T $ mk_d T) end;
   1.627 +        fun mk_goal (map_const, (T, rhsT)) =
   1.628 +          let
   1.629 +            val (_, Ts) = dest_Type T;
   1.630 +            val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
   1.631 +            val defl_term = defl_of_typ thy (Ts ~~ map mk_d Ts) (Ts ~~ map mk_p Ts) T;
   1.632 +          in isodefl_const T $ map_term $ defl_term end;
   1.633 +        val assms = (map mk_assm o snd o hd) defl_flagss;
   1.634 +        val goals = map mk_goal (map_consts ~~ dom_eqns);
   1.635 +        val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
   1.636 +        val start_thms =
   1.637 +          @{thm split_def} :: defl_apply_thms @ map_apply_thms;
   1.638 +        val adm_rules =
   1.639 +          @{thms adm_conj adm_isodefl cont2cont_fst cont2cont_snd cont_id};
   1.640 +        val bottom_rules =
   1.641 +          @{thms fst_strict snd_strict isodefl_bottom simp_thms};
   1.642 +        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy;
   1.643 +        val map_ID_simps = map (fn th => th RS sym) map_ID_thms;
   1.644 +        val isodefl_rules =
   1.645 +          @{thms conjI isodefl_ID_DEFL isodefl_LIFTDEFL}
   1.646 +          @ isodefl_abs_rep_thms
   1.647 +          @ IsodeflData.get (ProofContext.init_global thy);
   1.648 +      in
   1.649 +        Goal.prove_global thy [] assms goal (fn {prems, ...} =>
   1.650 +         EVERY
   1.651 +          [simp_tac (HOL_basic_ss addsimps start_thms) 1,
   1.652 +           (* FIXME: how reliable is unification here? *)
   1.653 +           (* Maybe I should instantiate the rule. *)
   1.654 +           rtac @{thm parallel_fix_ind} 1,
   1.655 +           REPEAT (resolve_tac adm_rules 1),
   1.656 +           simp_tac (HOL_basic_ss addsimps bottom_rules) 1,
   1.657 +           simp_tac beta_ss 1,
   1.658 +           simp_tac (HOL_basic_ss addsimps @{thms fst_conv snd_conv}) 1,
   1.659 +           simp_tac (HOL_basic_ss addsimps map_ID_simps) 1,
   1.660 +           REPEAT (etac @{thm conjE} 1),
   1.661 +           REPEAT (resolve_tac (isodefl_rules @ prems) 1 ORELSE atac 1)])
   1.662 +      end;
   1.663 +    val isodefl_binds = map (Binding.prefix_name "isodefl_") dbinds;
   1.664 +    fun conjuncts [] thm = []
   1.665 +      | conjuncts (n::[]) thm = [(n, thm)]
   1.666 +      | conjuncts (n::ns) thm = let
   1.667 +          val thmL = thm RS @{thm conjunct1};
   1.668 +          val thmR = thm RS @{thm conjunct2};
   1.669 +        in (n, thmL):: conjuncts ns thmR end;
   1.670 +    val (isodefl_thms, thy) = thy |>
   1.671 +      (Global_Theory.add_thms o map (Thm.no_attributes o apsnd Drule.zero_var_indexes))
   1.672 +        (conjuncts isodefl_binds isodefl_thm);
   1.673 +    val thy = fold (Context.theory_map o IsodeflData.add_thm) isodefl_thms thy;
   1.674 +
   1.675 +    (* prove map_ID theorems *)
   1.676 +    fun prove_map_ID_thm
   1.677 +        (((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
   1.678 +      let
   1.679 +        val Ts = snd (dest_Type lhsT);
   1.680 +        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
   1.681 +        val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts));
   1.682 +        val goal = mk_eqs (lhs, mk_ID lhsT);
   1.683 +        val tac = EVERY
   1.684 +          [rtac @{thm isodefl_DEFL_imp_ID} 1,
   1.685 +           stac DEFL_thm 1,
   1.686 +           rtac isodefl_thm 1,
   1.687 +           REPEAT (resolve_tac @{thms isodefl_ID_DEFL isodefl_LIFTDEFL} 1)];
   1.688 +      in
   1.689 +        Goal.prove_global thy [] [] goal (K tac)
   1.690 +      end;
   1.691 +    val map_ID_binds = map (Binding.suffix_name "_map_ID") dbinds;
   1.692 +    val map_ID_thms =
   1.693 +      map prove_map_ID_thm
   1.694 +        (map_consts ~~ dom_eqns ~~ DEFL_thms ~~ isodefl_thms);
   1.695 +    val (_, thy) = thy |>
   1.696 +      (Global_Theory.add_thms o map (rpair [Domain_Take_Proofs.map_ID_add]))
   1.697 +        (map_ID_binds ~~ map_ID_thms);
   1.698 +
   1.699 +    (* definitions and proofs related to take functions *)
   1.700 +    val (take_info, thy) =
   1.701 +        Domain_Take_Proofs.define_take_functions
   1.702 +          (dbinds ~~ iso_infos) thy;
   1.703 +    val { take_consts, chain_take_thms, take_0_thms, take_Suc_thms, ...} =
   1.704 +        take_info;
   1.705 +
   1.706 +    (* least-upper-bound lemma for take functions *)
   1.707 +    val lub_take_lemma =
   1.708 +      let
   1.709 +        val lhs = mk_tuple (map mk_lub take_consts);
   1.710 +        fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
   1.711 +        fun mk_map_ID (map_const, (lhsT, rhsT)) =
   1.712 +          list_ccomb (map_const, map mk_ID (filter is_cpo (snd (dest_Type lhsT))));
   1.713 +        val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns));
   1.714 +        val goal = mk_trp (mk_eq (lhs, rhs));
   1.715 +        val map_ID_thms = Domain_Take_Proofs.get_map_ID_thms thy;
   1.716 +        val start_rules =
   1.717 +            @{thms lub_Pair [symmetric] ch2ch_Pair} @ chain_take_thms
   1.718 +            @ @{thms pair_collapse split_def}
   1.719 +            @ map_apply_thms @ map_ID_thms;
   1.720 +        val rules0 =
   1.721 +            @{thms iterate_0 Pair_strict} @ take_0_thms;
   1.722 +        val rules1 =
   1.723 +            @{thms iterate_Suc Pair_fst_snd_eq fst_conv snd_conv}
   1.724 +            @ take_Suc_thms;
   1.725 +        val tac =
   1.726 +            EVERY
   1.727 +            [simp_tac (HOL_basic_ss addsimps start_rules) 1,
   1.728 +             simp_tac (HOL_basic_ss addsimps @{thms fix_def2}) 1,
   1.729 +             rtac @{thm lub_eq} 1,
   1.730 +             rtac @{thm nat.induct} 1,
   1.731 +             simp_tac (HOL_basic_ss addsimps rules0) 1,
   1.732 +             asm_full_simp_tac (beta_ss addsimps rules1) 1];
   1.733 +      in
   1.734 +        Goal.prove_global thy [] [] goal (K tac)
   1.735 +      end;
   1.736 +
   1.737 +    (* prove lub of take equals ID *)
   1.738 +    fun prove_lub_take (((dbind, take_const), map_ID_thm), (lhsT, rhsT)) thy =
   1.739 +      let
   1.740 +        val n = Free ("n", natT);
   1.741 +        val goal = mk_eqs (mk_lub (lambda n (take_const $ n)), mk_ID lhsT);
   1.742 +        val tac =
   1.743 +            EVERY
   1.744 +            [rtac @{thm trans} 1, rtac map_ID_thm 2,
   1.745 +             cut_facts_tac [lub_take_lemma] 1,
   1.746 +             REPEAT (etac @{thm Pair_inject} 1), atac 1];
   1.747 +        val lub_take_thm = Goal.prove_global thy [] [] goal (K tac);
   1.748 +      in
   1.749 +        add_qualified_thm "lub_take" (dbind, lub_take_thm) thy
   1.750 +      end;
   1.751 +    val (lub_take_thms, thy) =
   1.752 +        fold_map prove_lub_take
   1.753 +          (dbinds ~~ take_consts ~~ map_ID_thms ~~ dom_eqns) thy;
   1.754 +
   1.755 +    (* prove additional take theorems *)
   1.756 +    val (take_info2, thy) =
   1.757 +        Domain_Take_Proofs.add_lub_take_theorems
   1.758 +          (dbinds ~~ iso_infos) take_info lub_take_thms thy;
   1.759 +  in
   1.760 +    ((iso_infos, take_info2), thy)
   1.761 +  end;
   1.762 +
   1.763 +val domain_isomorphism = gen_domain_isomorphism cert_typ;
   1.764 +val domain_isomorphism_cmd = snd oo gen_domain_isomorphism read_typ;
   1.765 +
   1.766 +(******************************************************************************)
   1.767 +(******************************** outer syntax ********************************)
   1.768 +(******************************************************************************)
   1.769 +
   1.770 +local
   1.771 +
   1.772 +val parse_domain_iso :
   1.773 +    (string list * binding * mixfix * string * (binding * binding) option)
   1.774 +      parser =
   1.775 +  (Parse.type_args -- Parse.binding -- Parse.opt_mixfix -- (Parse.$$$ "=" |-- Parse.typ) --
   1.776 +    Scan.option (Parse.$$$ "morphisms" |-- Parse.!!! (Parse.binding -- Parse.binding)))
   1.777 +    >> (fn ((((vs, t), mx), rhs), morphs) => (vs, t, mx, rhs, morphs));
   1.778 +
   1.779 +val parse_domain_isos = Parse.and_list1 parse_domain_iso;
   1.780 +
   1.781 +in
   1.782 +
   1.783 +val _ =
   1.784 +  Outer_Syntax.command "domain_isomorphism" "define domain isomorphisms (HOLCF)"
   1.785 +    Keyword.thy_decl
   1.786 +    (parse_domain_isos >> (Toplevel.theory o domain_isomorphism_cmd));
   1.787 +
   1.788 +end;
   1.789 +
   1.790 +end;