automate definition of map functions; remove unused code
authorhuffman
Thu, 19 Nov 2009 09:04:58 -0800
changeset 33785 2f2d9eb37084
parent 33784 7e434813752f
child 33786 d280c5ebd7d7
automate definition of map functions; remove unused code
src/HOLCF/Tools/Domain/domain_isomorphism.ML
--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Thu Nov 19 08:22:00 2009 -0800
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML	Thu Nov 19 09:04:58 2009 -0800
@@ -41,6 +41,9 @@
 
 val deflT = @{typ "udom alg_defl"};
 
+fun mapT (T as Type (_, Ts)) =
+  Library.foldr cfunT (map (fn T => T ->> T) Ts, T ->> T);     
+
 (******************************************************************************)
 (******************************* building terms *******************************)
 (******************************************************************************)
@@ -98,6 +101,21 @@
   let val (T, _) = dest_cfunT (Term.fastype_of t)
   in mk_capply (Const(@{const_name fix}, (T ->> T) ->> T), t) end;
 
+fun ID_const T = Const (@{const_name ID}, cfunT (T, T));
+
+fun cfcomp_const (T, U, V) =
+  Const (@{const_name cfcomp}, (U ->> V) ->> (T ->> U) ->> (T ->> V));
+
+fun mk_cfcomp (f, g) =
+  let
+    val (U, V) = dest_cfunT (Term.fastype_of f);
+    val (T, U') = dest_cfunT (Term.fastype_of g);
+  in
+    if U = U'
+    then mk_capply (mk_capply (cfcomp_const (T, U, V), f), g)
+    else raise TYPE ("mk_cfcomp", [U, U'], [f, g])
+  end;
+
 fun mk_Rep_of T =
   Const (@{const_name Rep_of}, Term.itselfT T --> deflT) $ Logic.mk_type T;
 
@@ -178,23 +196,6 @@
 
 (******************************************************************************)
 
-fun typ_of_dtyp
-    (descr : (string * string list) list)
-    (sorts : (string * sort) list)
-    : DatatypeAux.dtyp -> typ =
-  let
-    fun tfree a = TFree (a, the (AList.lookup (op =) sorts a))
-    fun typ_of (DatatypeAux.DtTFree a) = tfree a
-      | typ_of (DatatypeAux.DtType (s, ds)) = Type (s, map typ_of ds)
-      | typ_of (DatatypeAux.DtRec i) =
-          let val (s, vs) = nth descr i
-          in Type (s, map tfree vs) end
-  in typ_of end;
-
-fun is_closed_dtyp (DatatypeAux.DtTFree a) = false
-  | is_closed_dtyp (DatatypeAux.DtRec i) = false
-  | is_closed_dtyp (DatatypeAux.DtType (s, ds)) = forall is_closed_dtyp ds;
-
 (* FIXME: use theory data for this *)
 val defl_tab : term Symtab.table =
     Symtab.make [(@{type_name "->"}, @{term "cfun_defl"}),
@@ -208,12 +209,11 @@
 
 fun defl_of_typ
     (tab : term Symtab.table)
-    (free : string -> term)
     (T : typ) : term =
   let
     fun is_closed_typ (Type (_, Ts)) = forall is_closed_typ Ts
       | is_closed_typ _ = false;
-    fun defl_of (TFree (a, _)) = free a
+    fun defl_of (TFree (a, _)) = Free (Library.unprefix "'" a, deflT)
       | defl_of (TVar _) = error ("defl_of_typ: TVar")
       | defl_of (T as Type (c, Ts)) =
         case Symtab.lookup tab c of
@@ -223,25 +223,33 @@
                   else error ("defl_of_typ: type variable under unsupported type constructor " ^ c);
   in defl_of T end;
 
-fun defl_of_dtyp
-    (descr : (string * string list) list)
-    (sorts : (string * sort) list)
-    (f : string -> term)
-    (r : int -> term)
-    (dt : DatatypeAux.dtyp) : term =
+(* FIXME: use theory data for this *)
+val map_tab : string Symtab.table =
+    Symtab.make [(@{type_name "->"}, @{const_name "cfun_map"}),
+                 (@{type_name "++"}, @{const_name "ssum_map"}),
+                 (@{type_name "**"}, @{const_name "sprod_map"}),
+                 (@{type_name "*"}, @{const_name "cprod_map"}),
+                 (@{type_name "u"}, @{const_name "u_map"}),
+                 (@{type_name "upper_pd"}, @{const_name "upper_map"}),
+                 (@{type_name "lower_pd"}, @{const_name "lower_map"}),
+                 (@{type_name "convex_pd"}, @{const_name "convex_map"})];
+
+fun map_of_typ
+    (tab : string Symtab.table)
+    (T : typ) : term =
   let
-    fun tfree a = TFree (a, the (AList.lookup (op =) sorts a))
-    fun defl_of (DatatypeAux.DtTFree a) = f a
-      | defl_of (DatatypeAux.DtRec i) = r i
-      | defl_of (dt as DatatypeAux.DtType (s, ds)) =
-        case Symtab.lookup defl_tab s of
-          SOME t => Library.foldl mk_capply (t, map defl_of ds)
-        | NONE => if DatatypeAux.is_rec_type dt
-                  then error ("defl_of_dtyp: recursion under unsupported type constructor " ^ s)
-                  else if is_closed_dtyp dt
-                  then mk_Rep_of (typ_of_dtyp descr sorts dt)
-                  else error ("defl_of_dtyp: type variable under unsupported type constructor " ^ s);
-  in defl_of dt end;
+    fun is_closed_typ (Type (_, Ts)) = forall is_closed_typ Ts
+      | is_closed_typ _ = false;
+    fun map_of (T as TFree (a, _)) = Free (Library.unprefix "'" a, T ->> T)
+      | map_of (T as TVar _) = error ("map_of_typ: TVar")
+      | map_of (T as Type (c, Ts)) =
+        case Symtab.lookup tab c of
+          SOME t => Library.foldl mk_capply (Const (t, mapT T), map map_of Ts)
+        | NONE => if is_closed_typ T
+                  then ID_const T
+                  else error ("map_of_typ: type variable under unsupported type constructor " ^ c);
+  in map_of T end;
+
 
 (******************************************************************************)
 (* prepare datatype specifications *)
@@ -321,10 +329,9 @@
     val defl_tab2 =
       Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ defl_consts);
     val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2);
-    fun free a = Free (Library.unprefix "'" a, deflT);
     fun mk_defl_spec (lhsT, rhsT) =
-      mk_eqs (defl_of_typ defl_tab' free lhsT,
-              defl_of_typ defl_tab' free rhsT);
+      mk_eqs (defl_of_typ defl_tab' lhsT,
+              defl_of_typ defl_tab' rhsT);
     val defl_specs = map mk_defl_spec dom_eqns;
 
     (* register recursive definition of deflation combinators *)
@@ -385,10 +392,11 @@
             [(Binding.suffix_name "_rep_def" tbind, rep_eqn),
              (Binding.suffix_name "_abs_def" tbind, abs_eqn)];
       in
-        ((rep_def, abs_def), thy)
+        (((rep_const, abs_const), (rep_def, abs_def)), thy)
       end;
-    val (rep_abs_defs, thy) = thy |>
-      fold_map mk_rep_abs (dom_binds ~~ dom_eqns);
+    val ((rep_abs_consts, rep_abs_defs), thy) = thy
+      |> fold_map mk_rep_abs (dom_binds ~~ dom_eqns)
+      |>> ListPair.unzip;
 
     (* prove isomorphism and isodefl rules *)
     fun mk_iso_thms ((tbind, REP_eq), (rep_def, abs_def)) thy =
@@ -412,6 +420,34 @@
       |> fold_map mk_iso_thms (dom_binds ~~ REP_eq_thms ~~ rep_abs_defs)
       |>> ListPair.unzip;
 
+    (* declare map functions *)
+    fun declare_map_const (tbind, (lhsT, rhsT)) thy =
+      let
+        val map_type = mapT lhsT;
+        val map_bind = Binding.suffix_name "_map" tbind;
+      in
+        Sign.declare_const ((map_bind, map_type), NoSyn) thy
+      end;
+    val (map_consts, thy) = thy |>
+      fold_map declare_map_const (dom_binds ~~ dom_eqns);
+
+    (* defining equations for map functions *)
+    val map_tab1 = map_tab; (* FIXME: use theory data *)
+    val map_tab2 =
+      Symtab.make (map (fst o dest_Type o fst) dom_eqns
+                   ~~ map (fst o dest_Const) map_consts);
+    val map_tab' = Symtab.merge (K true) (map_tab1, map_tab2);
+    fun mk_map_spec ((rep_const, abs_const), (lhsT, rhsT)) =
+      let
+        val lhs = map_of_typ map_tab' lhsT;
+        val body = map_of_typ map_tab' rhsT;
+        val rhs = mk_cfcomp (abs_const, mk_cfcomp (body, rep_const));
+      in mk_eqs (lhs, rhs) end;
+    val map_specs = map mk_map_spec (rep_abs_consts ~~ dom_eqns);
+
+    (* register recursive definition of map functions *)
+    val map_binds = map (Binding.suffix_name "_map") dom_binds;
+    val (map_unfold_thms, thy) = add_fixdefs (map_binds ~~ map_specs) thy;
   in
     thy
   end;