src/HOLCF/Tools/Domain/domain_constructors.ML
changeset 35486 c91854705b1d
parent 35485 7d7495f5e35e
child 35487 d1630f317ed0
--- a/src/HOLCF/Tools/Domain/domain_constructors.ML	Mon Mar 01 08:33:49 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML	Mon Mar 01 09:55:32 2010 -0800
@@ -11,7 +11,6 @@
       string
       -> (binding * (bool * binding option * typ) list * mixfix) list
       -> Domain_Isomorphism.iso_info
-      -> thm
       -> theory
       -> { con_consts : term list,
            con_betas : thm list,
@@ -367,30 +366,81 @@
     (spec : (term * (bool * typ) list) list)
     (lhsT : typ)
     (dname : string)
-    (case_def : thm)
     (con_betas : thm list)
     (casedist : thm)
     (iso_locale : thm)
+    (rep_const : term)
     (thy : theory)
     : ((typ -> term) * thm list) * theory =
   let
 
+    (* TODO: move these to holcf_library.ML *)
+    fun one_when_const T = Const (@{const_name one_when}, T ->> oneT ->> T);
+    fun mk_one_when t = one_when_const (fastype_of t) ` t;
+    fun mk_sscase (t, u) =
+      let
+        val (T, V) = dest_cfunT (fastype_of t);
+        val (U, V) = dest_cfunT (fastype_of u);
+      in sscase_const (T, U, V) ` t ` u end;
+    fun strictify_const T = Const (@{const_name strictify}, T ->> T);
+    fun mk_strictify t = strictify_const (fastype_of t) ` t;
+    fun ssplit_const (T, U, V) =
+      Const (@{const_name ssplit}, (T ->> U ->> V) ->> mk_sprodT (T, U) ->> V);
+    fun mk_ssplit t =
+      let val (T, (U, V)) = apsnd dest_cfunT (dest_cfunT (fastype_of t));
+      in ssplit_const (T, U, V) ` t end;
+    fun lambda_stuple []      t = mk_one_when t
+      | lambda_stuple [x]     t = mk_strictify (big_lambda x t)
+      | lambda_stuple [x,y]   t = mk_ssplit (big_lambdas [x, y] t)
+      | lambda_stuple (x::xs) t = mk_ssplit (big_lambda x (lambda_stuple xs t));
+
+    (* eta contraction for simplifying definitions *)
+    fun cont_eta_contract (Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body)) = 
+        (case cont_eta_contract body  of
+           body' as (Const(@{const_name Abs_CFun},Ta) $ f $ Bound 0) => 
+           if not (0 mem loose_bnos f) then incr_boundvars ~1 f 
+           else   Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body')
+         | body' => Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body'))
+      | cont_eta_contract(f$t) = cont_eta_contract f $ cont_eta_contract t
+      | cont_eta_contract t    = t;
+
     (* prove rep/abs rules *)
     val rep_strict = iso_locale RS @{thm iso.rep_strict};
     val abs_inverse = iso_locale RS @{thm iso.abs_iso};
 
     (* calculate function arguments of case combinator *)
-    val resultT = TVar (("'t",0), @{sort pcpo});
+    val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));
+    val resultT = TFree (Name.variant tns "'t", @{sort pcpo});
     fun fTs T = map (fn (_, args) => map snd args -->> T) spec;
     val fns = Datatype_Prop.indexify_names (map (K "f") spec);
     val fs = map Free (fns ~~ fTs resultT);
     fun caseT T = fTs T -->> (lhsT ->> T);
 
-    (* TODO: move definition of case combinator here *)
-    val case_bind = Binding.name (dname ^ "_when");
-    val case_name = Sign.full_name thy case_bind;
-    fun case_const T = Const (case_name, caseT T);
-    val case_app = list_ccomb (case_const resultT, fs);
+    (* definition of case combinator *)
+    local
+      val case_bind = Binding.name (dname ^ "_when");
+      fun one_con f (_, args) =
+        let
+          fun argT (lazy, T) = if lazy then mk_upT T else T;
+          fun down (lazy, T) v = if lazy then from_up T ` v else v;
+          val Ts = map argT args;
+          val ns = Name.variant_list fns (Datatype_Prop.make_tnames Ts);
+          val vs = map Free (ns ~~ Ts);
+          val xs = map2 down args vs;
+        in
+          cont_eta_contract (lambda_stuple vs (list_ccomb (f, xs)))
+        end;
+      val body = foldr1 mk_sscase (map2 one_con fs spec);
+      val rhs = big_lambdas fs (mk_cfcomp (body, rep_const));
+      val ((case_consts, case_defs), thy) =
+          define_consts [(case_bind, rhs, NoSyn)] thy;
+      val case_name = Sign.full_name thy case_bind;
+    in
+      val case_def = hd case_defs;
+      fun case_const T = Const (case_name, caseT T);
+      val case_app = list_ccomb (case_const resultT, fs);
+      val thy = thy;
+    end;
 
     (* define syntax for case combinator *)
     (* TODO: re-implement case syntax using a parse translation *)
@@ -441,9 +491,8 @@
     (* prove strictness of case combinator *)
     val case_strict =
       let
-        val defs = [case_beta, mk_meta_eq rep_strict];
-        val lhs = case_app ` mk_bottom lhsT;
-        val goal = mk_trp (mk_eq (lhs, mk_bottom resultT));
+        val defs = case_beta :: map mk_meta_eq [rep_strict, @{thm cfcomp2}];
+        val goal = mk_trp (mk_strict case_app);
         val tacs = [resolve_tac @{thms sscase1 ssplit1 strictify1} 1];
       in prove thy defs goal (K tacs) end;
         
@@ -460,7 +509,8 @@
           val defs = case_beta :: con_betas;
           val rules1 = @{thms sscase2 sscase3 ssplit2 fup2 ID1};
           val rules2 = @{thms con_defined_iff_rules};
-          val rules = abs_inverse :: rules1 @ rules2;
+          val rules3 = @{thms cfcomp2 one_when2};
+          val rules = abs_inverse :: rules1 @ rules2 @ rules3;
           val tacs = [asm_simp_tac (beta_ss addsimps rules) 1];
         in prove thy defs goal (K tacs) end;
     in
@@ -961,7 +1011,6 @@
     (dname : string)
     (spec : (binding * (bool * binding option * typ) list * mixfix) list)
     (iso_info : Domain_Isomorphism.iso_info)
-    (case_def : thm)
     (thy : theory) =
   let
 
@@ -995,7 +1044,7 @@
         val case_spec = map2 prep_con con_consts spec;
       in
         add_case_combinator case_spec lhsT dname
-          case_def con_betas casedist iso_locale thy
+          con_betas casedist iso_locale rep_const thy
       end;
 
     (* qualify constants and theorems with domain name *)