cleaned up; factored out fixed-point definition code
authorhuffman
Wed, 18 Nov 2009 15:01:00 -0800
changeset 33775 7a1518c42c56
parent 33774 e11e05b32548
child 33776 5048b02c2bbb
cleaned up; factored out fixed-point definition code
src/HOLCF/Tools/Domain/domain_isomorphism.ML
--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Wed Nov 18 12:41:43 2009 -0800
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Wed Nov 18 15:01:00 2009 -0800
@@ -15,6 +15,14 @@
 structure Domain_Isomorphism :> DOMAIN_ISOMORPHISM =
 struct
 
+val beta_ss =
+  HOL_basic_ss
+    addsimps simp_thms
+    addsimps [@{thm beta_cfun}]
+    addsimprocs [@{simproc cont_proc}];
+
+val beta_tac = simp_tac beta_ss;
+
 (******************************************************************************)
 (******************************* building types *******************************)
 (******************************************************************************)
@@ -79,6 +87,9 @@
 
 val mk_trp = HOLogic.mk_Trueprop;
 
+val mk_fst = HOLogic.mk_fst;
+val mk_snd = HOLogic.mk_snd;
+
 fun mk_cont t =
   let val T = Term.fastype_of t
   in Const(@{const_name cont}, T --> HOLogic.boolT) $ t end;
@@ -90,6 +101,79 @@
 fun mk_Rep_of T =
   Const (@{const_name Rep_of}, Term.itselfT T --> deflT) $ Logic.mk_type T;
 
+(* splits a cterm into the right and lefthand sides of equality *)
+fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
+
+fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u));
+
+(******************************************************************************)
+(*************** fixed-point definitions and unfolding theorems ***************)
+(******************************************************************************)
+
+fun add_fixdefs
+    (spec : (binding * term) list)
+    (thy : theory) : thm list * theory =
+  let
+    val binds = map fst spec;
+    val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
+    val functional = lambda_tuple lhss (mk_tuple rhss);
+    val fixpoint = mk_fix (mk_cabs functional);
+
+    (* project components of fixpoint *)
+    fun mk_projs (x::[]) t = [(x, t)]
+      | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);
+    val projs = mk_projs lhss fixpoint;
+
+    (* convert parameters to lambda abstractions *)
+    fun mk_eqn (lhs, rhs) =
+        case lhs of
+          Const (@{const_name Rep_CFun}, _) $ f $ (x as Free _) =>
+            mk_eqn (f, big_lambda x rhs)
+        | Const _ => Logic.mk_equals (lhs, rhs)
+        | _ => raise TERM ("lhs not of correct form", [lhs, rhs]);
+    val eqns = map mk_eqn projs;
+
+    (* register constant definitions *)
+    val (fixdef_thms, thy2) =
+      (PureThy.add_defs false o map Thm.no_attributes)
+        (map (Binding.suffix_name "_def") binds ~~ eqns) thy;
+
+    (* prove applied version of definitions *)
+    fun prove_proj (lhs, rhs) =
+      let
+        val tac = rewrite_goals_tac fixdef_thms THEN beta_tac 1;
+        val goal = Logic.mk_equals (lhs, rhs);
+      in Goal.prove_global thy2 [] [] goal (K tac) end;
+    val proj_thms = map prove_proj projs;
+
+    (* mk_tuple lhss == fixpoint *)
+    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
+    val tuple_fixdef_thm = foldr1 pair_equalI proj_thms;
+
+    val cont_thm =
+      Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional))
+        (K (beta_tac 1));
+    val tuple_unfold_thm =
+      (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
+      |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv};
+
+    fun mk_unfold_thms [] thm = []
+      | mk_unfold_thms (n::[]) thm = [(n, thm)]
+      | mk_unfold_thms (n::ns) thm = let
+          val thmL = thm RS @{thm Pair_eqD1};
+          val thmR = thm RS @{thm Pair_eqD2};
+        in (n, thmL) :: mk_unfold_thms ns thmR end;
+    val unfold_binds = map (Binding.suffix_name "_unfold") binds;
+
+    (* register unfold theorems *)
+    val (unfold_thms, thy3) =
+      (PureThy.add_thms o map (Thm.no_attributes o apsnd Drule.standard))
+        (mk_unfold_thms unfold_binds tuple_unfold_thm) thy2;
+  in
+    (unfold_thms, thy3)
+  end;
+
+
 (******************************************************************************)
 
 fun typ_of_dtyp
@@ -130,7 +214,7 @@
     fun defl_of (TFree (a, _)) = free a
       | defl_of (TVar _) = error ("defl_of_typ: TVar")
       | defl_of (T as Type (c, Ts)) =
-        case Symtab.lookup defl_tab c of
+        case Symtab.lookup tab c of
           SOME t => Library.foldl mk_capply (t, map defl_of Ts)
         | NONE => if is_closed_typ T
                   then mk_Rep_of T
@@ -200,114 +284,49 @@
          sorts : (string * sort) list) =
       fold_map (prep_dom tmp_thy) doms_raw [];
 
+    (* domain equations *)
+    fun mk_dom_eqn (vs, tbind, mx, rhs) =
+      let fun arg v = TFree (v, the (AList.lookup (op =) sorts v));
+      in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end;
+    val dom_eqns = map mk_dom_eqn doms;
+
+    (* check for valid type parameters *)
     val (tyvars, _, _, _)::_ = doms;
-    val (new_doms, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
+    val new_doms = map (fn (tvs, tname, mx, _) =>
       let val full_tname = Sign.full_name tmp_thy tname
       in
         (case duplicates (op =) tvs of
           [] =>
-            if eq_set (op =) (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
+            if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
             else error ("Mutually recursive domains must have same type parameters")
         | dups => error ("Duplicate parameter(s) for domain " ^ quote (Binding.str_of tname) ^
             " : " ^ commas dups))
-      end) doms);
+      end) doms;
     val dom_names = map fst new_doms;
 
-    val dtyps =
-      map (fn (vs, t, mx, rhs) => DatatypeAux.dtyp_of_typ new_doms rhs) doms;
-
-    fun unprime a = Library.unprefix "'" a;
-    fun free_defl a = Free (a, deflT);
-
-    val (ts, rs) =
-      let
-        val used = map unprime tyvars;
-        val i = length doms;
-        val ns = map (fn i => "r" ^ ML_Syntax.print_int i) (1 upto i);
-        val ns' = Name.variant_list used ns;
-      in (map free_defl used, map free_defl ns') end;
-
-    val defls =
-      map (defl_of_dtyp new_doms sorts (free_defl o unprime) (nth rs)) dtyps;
-    val functional = lambda_tuple rs (mk_tuple defls);
-    val fixpoint = mk_fix (mk_cabs functional);
-
-    fun projs t (_::[]) = [t]
-      | projs t (_::xs) = HOLogic.mk_fst t :: projs (HOLogic.mk_snd t) xs;
-    fun typ_eqn ((tvs, tbind, mx, _), t) =
+    (* declare type combinator constants *)
+    fun declare_typ_const (vs, tbind, mx, rhs) thy =
       let
-        val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT);
+        val typ_type = Library.foldr cfunT (map (K deflT) vs, deflT);
         val typ_bind = Binding.suffix_name "_typ" tbind;
-        val typ_name = Sign.full_name tmp_thy typ_bind;
-        val typ_const = Const (typ_name, typ_type);
-        val args = map (free_defl o unprime) tvs;
-        val typ_rhs = big_lambdas args t;
-        val typ_eqn = Logic.mk_equals (typ_const, typ_rhs);
-        val typ_beta = Logic.mk_equals
-          (Library.foldl mk_capply (typ_const, args), t);
-        val typ_syn = (typ_bind, typ_type, NoSyn);
-        val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn);
       in
-        ((typ_syn, typ_def), (typ_beta, typ_const))
+        Sign.declare_const ((typ_bind, typ_type), NoSyn) thy
       end;
-    val ((typ_syns, typ_defs), (typ_betas, typ_consts)) =
-      map typ_eqn (doms ~~ projs fixpoint doms)
-      |> ListPair.unzip
-      |> apfst ListPair.unzip
-      |> apsnd ListPair.unzip;
-    val (typ_def_thms, thy2) =
-      thy
-      |> Sign.add_consts_i typ_syns
-      |> (PureThy.add_defs false o map Thm.no_attributes) typ_defs;
+    val (typ_consts, thy2) = fold_map declare_typ_const doms thy;
 
-    val beta_ss = HOL_basic_ss
-      addsimps simp_thms
-      addsimps [@{thm beta_cfun}]
-      addsimprocs [@{simproc cont_proc}];
-    val beta_tac = rewrite_goals_tac typ_def_thms THEN simp_tac beta_ss 1;
-    val typ_beta_thms =
-      map (fn t => Goal.prove_global thy2 [] [] t (K beta_tac)) typ_betas;
-
-    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
-    val tuple_typ_thm = Drule.standard (foldr1 pair_equalI typ_beta_thms);
-
-    val tuple_cont_thm =
-      Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional))
-        (K (simp_tac beta_ss 1));
-    val tuple_unfold_thm =
-      (@{thm def_cont_fix_eq} OF [tuple_typ_thm, tuple_cont_thm])
-      |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv};
+    (* defining equations for type combinators *)
+    val defl_tab1 = defl_tab; (* FIXME: use theory data *)
+    val defl_tab2 =
+      Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ typ_consts);
+    val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2);
+    fun free a = Free (Library.unprefix "'" a, deflT);
+    fun mk_defl_spec (lhs, rhs) =
+      mk_eqs (defl_of_typ defl_tab' free lhs, defl_of_typ defl_tab' free rhs);
+    val defl_specs = map mk_defl_spec dom_eqns;
 
-    fun typ_unfold_eqn ((tvs, tbind, mx, _), t) =
-      let
-        val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT);
-        val typ_bind = Binding.suffix_name "_typ" tbind;
-        val typ_name = Sign.full_name tmp_thy typ_bind;
-        val typ_const = Const (typ_name, typ_type);
-        val args = map (free_defl o unprime) tvs;
-        val typ_rhs = big_lambdas args t;
-        val typ_eqn = Logic.mk_equals (typ_const, typ_rhs);
-        val typ_beta = Logic.mk_equals
-          (Library.foldl mk_capply (typ_const, args), t);
-        val typ_syn = (typ_bind, typ_type, NoSyn);
-        val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn);
-      in
-        ((typ_syn, typ_def), (typ_beta, typ_const))
-      end;
-
-    val typ_unfold_names =
-      map (Binding.suffix_name "_typ_unfold" o #2) doms;
-    fun unfolds [] thm = []
-      | unfolds (n::[]) thm = [(n, thm)]
-      | unfolds (n::ns) thm = let
-          val thmL = thm RS @{thm Pair_eqD1};
-          val thmR = thm RS @{thm Pair_eqD2};
-        in (n, thmL) :: unfolds ns thmR end;
-    val typ_unfold_thms =
-      map (apsnd Drule.standard) (unfolds typ_unfold_names tuple_unfold_thm);
-
-    val (_, thy3) = thy2
-      |> (PureThy.add_thms o map Thm.no_attributes) typ_unfold_thms;
+    (* register recursive definition of type combinators *)
+    val typ_binds = map (Binding.suffix_name "_typ" o #2) doms;
+    val (typ_unfold_thms, thy3) = add_fixdefs (typ_binds ~~ defl_specs) thy2;
 
     fun make_repdef ((vs, tbind, mx, _), typ_const) thy =
       let