move definition of case combinator to domain_constructors.ML
authorhuffman
Mon, 01 Mar 2010 09:55:32 -0800
changeset 35486 c91854705b1d
parent 35485 7d7495f5e35e
child 35487 d1630f317ed0
move definition of case combinator to domain_constructors.ML
src/HOLCF/Tools/Domain/domain_axioms.ML
src/HOLCF/Tools/Domain/domain_constructors.ML
src/HOLCF/Tools/Domain/domain_syntax.ML
src/HOLCF/Tools/Domain/domain_theorems.ML
--- a/src/HOLCF/Tools/Domain/domain_axioms.ML	Mon Mar 01 08:33:49 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_axioms.ML	Mon Mar 01 09:55:32 2010 -0800
@@ -67,10 +67,6 @@
     val abs_iso_ax = ("abs_iso", mk_trp(dc_rep`(dc_abs`%x_name') === %:x_name'));
     val rep_iso_ax = ("rep_iso", mk_trp(dc_abs`(dc_rep`%x_name') === %:x_name'));
 
-    val when_def = ("when_def",%%:(dname^"_when") == 
-        List.foldr (uncurry /\ ) (/\x_name'((when_body cons (fn (x,y) =>
-          Bound(1+length cons+x-y)))`(dc_rep`Bound 0))) (when_funs cons));
-
     val copy_def =
       let fun r i = proj (Bound 0) eqs i;
       in
@@ -95,7 +91,7 @@
 
   in (dnam,
       (if definitional then [] else [abs_iso_ax, rep_iso_ax, reach_ax]),
-      (if definitional then [when_def] else [when_def, copy_def]) @
+      (if definitional then [] else [copy_def]) @
       [take_def, finite_def])
   end; (* let (calc_axioms) *)
 
--- a/src/HOLCF/Tools/Domain/domain_constructors.ML	Mon Mar 01 08:33:49 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML	Mon Mar 01 09:55:32 2010 -0800
@@ -11,7 +11,6 @@
       string
       -> (binding * (bool * binding option * typ) list * mixfix) list
       -> Domain_Isomorphism.iso_info
-      -> thm
       -> theory
       -> { con_consts : term list,
            con_betas : thm list,
@@ -367,30 +366,81 @@
     (spec : (term * (bool * typ) list) list)
     (lhsT : typ)
     (dname : string)
-    (case_def : thm)
     (con_betas : thm list)
     (casedist : thm)
     (iso_locale : thm)
+    (rep_const : term)
     (thy : theory)
     : ((typ -> term) * thm list) * theory =
   let
 
+    (* TODO: move these to holcf_library.ML *)
+    fun one_when_const T = Const (@{const_name one_when}, T ->> oneT ->> T);
+    fun mk_one_when t = one_when_const (fastype_of t) ` t;
+    fun mk_sscase (t, u) =
+      let
+        val (T, V) = dest_cfunT (fastype_of t);
+        val (U, V) = dest_cfunT (fastype_of u);
+      in sscase_const (T, U, V) ` t ` u end;
+    fun strictify_const T = Const (@{const_name strictify}, T ->> T);
+    fun mk_strictify t = strictify_const (fastype_of t) ` t;
+    fun ssplit_const (T, U, V) =
+      Const (@{const_name ssplit}, (T ->> U ->> V) ->> mk_sprodT (T, U) ->> V);
+    fun mk_ssplit t =
+      let val (T, (U, V)) = apsnd dest_cfunT (dest_cfunT (fastype_of t));
+      in ssplit_const (T, U, V) ` t end;
+    fun lambda_stuple []      t = mk_one_when t
+      | lambda_stuple [x]     t = mk_strictify (big_lambda x t)
+      | lambda_stuple [x,y]   t = mk_ssplit (big_lambdas [x, y] t)
+      | lambda_stuple (x::xs) t = mk_ssplit (big_lambda x (lambda_stuple xs t));
+
+    (* eta contraction for simplifying definitions *)
+    fun cont_eta_contract (Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body)) = 
+        (case cont_eta_contract body  of
+           body' as (Const(@{const_name Abs_CFun},Ta) $ f $ Bound 0) => 
+           if not (0 mem loose_bnos f) then incr_boundvars ~1 f 
+           else   Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body')
+         | body' => Const(@{const_name Abs_CFun},TT) $ Abs(a,T,body'))
+      | cont_eta_contract(f$t) = cont_eta_contract f $ cont_eta_contract t
+      | cont_eta_contract t    = t;
+
     (* prove rep/abs rules *)
     val rep_strict = iso_locale RS @{thm iso.rep_strict};
     val abs_inverse = iso_locale RS @{thm iso.abs_iso};
 
     (* calculate function arguments of case combinator *)
-    val resultT = TVar (("'t",0), @{sort pcpo});
+    val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));
+    val resultT = TFree (Name.variant tns "'t", @{sort pcpo});
     fun fTs T = map (fn (_, args) => map snd args -->> T) spec;
     val fns = Datatype_Prop.indexify_names (map (K "f") spec);
     val fs = map Free (fns ~~ fTs resultT);
     fun caseT T = fTs T -->> (lhsT ->> T);
 
-    (* TODO: move definition of case combinator here *)
-    val case_bind = Binding.name (dname ^ "_when");
-    val case_name = Sign.full_name thy case_bind;
-    fun case_const T = Const (case_name, caseT T);
-    val case_app = list_ccomb (case_const resultT, fs);
+    (* definition of case combinator *)
+    local
+      val case_bind = Binding.name (dname ^ "_when");
+      fun one_con f (_, args) =
+        let
+          fun argT (lazy, T) = if lazy then mk_upT T else T;
+          fun down (lazy, T) v = if lazy then from_up T ` v else v;
+          val Ts = map argT args;
+          val ns = Name.variant_list fns (Datatype_Prop.make_tnames Ts);
+          val vs = map Free (ns ~~ Ts);
+          val xs = map2 down args vs;
+        in
+          cont_eta_contract (lambda_stuple vs (list_ccomb (f, xs)))
+        end;
+      val body = foldr1 mk_sscase (map2 one_con fs spec);
+      val rhs = big_lambdas fs (mk_cfcomp (body, rep_const));
+      val ((case_consts, case_defs), thy) =
+          define_consts [(case_bind, rhs, NoSyn)] thy;
+      val case_name = Sign.full_name thy case_bind;
+    in
+      val case_def = hd case_defs;
+      fun case_const T = Const (case_name, caseT T);
+      val case_app = list_ccomb (case_const resultT, fs);
+      val thy = thy;
+    end;
 
     (* define syntax for case combinator *)
     (* TODO: re-implement case syntax using a parse translation *)
@@ -441,9 +491,8 @@
     (* prove strictness of case combinator *)
     val case_strict =
       let
-        val defs = [case_beta, mk_meta_eq rep_strict];
-        val lhs = case_app ` mk_bottom lhsT;
-        val goal = mk_trp (mk_eq (lhs, mk_bottom resultT));
+        val defs = case_beta :: map mk_meta_eq [rep_strict, @{thm cfcomp2}];
+        val goal = mk_trp (mk_strict case_app);
         val tacs = [resolve_tac @{thms sscase1 ssplit1 strictify1} 1];
       in prove thy defs goal (K tacs) end;
         
@@ -460,7 +509,8 @@
           val defs = case_beta :: con_betas;
           val rules1 = @{thms sscase2 sscase3 ssplit2 fup2 ID1};
           val rules2 = @{thms con_defined_iff_rules};
-          val rules = abs_inverse :: rules1 @ rules2;
+          val rules3 = @{thms cfcomp2 one_when2};
+          val rules = abs_inverse :: rules1 @ rules2 @ rules3;
           val tacs = [asm_simp_tac (beta_ss addsimps rules) 1];
         in prove thy defs goal (K tacs) end;
     in
@@ -961,7 +1011,6 @@
     (dname : string)
     (spec : (binding * (bool * binding option * typ) list * mixfix) list)
     (iso_info : Domain_Isomorphism.iso_info)
-    (case_def : thm)
     (thy : theory) =
   let
 
@@ -995,7 +1044,7 @@
         val case_spec = map2 prep_con con_consts spec;
       in
         add_case_combinator case_spec lhsT dname
-          case_def con_betas casedist iso_locale thy
+          con_betas casedist iso_locale rep_const thy
       end;
 
     (* qualify constants and theorems with domain name *)
--- a/src/HOLCF/Tools/Domain/domain_syntax.ML	Mon Mar 01 08:33:49 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_syntax.ML	Mon Mar 01 09:55:32 2010 -0800
@@ -43,7 +43,6 @@
                                            | _ => foldr1 mk_sprodT (map opt_lazy args);
       fun freetvar s = let val tvar = mk_TFree s in
                          if tvar mem typevars then freetvar ("t"^s) else tvar end;
-      fun when_type (_,args,_) = List.foldr (op ->>) (freetvar "t") (map third args);
     in
     val dtype  = Type(dname,typevars);
     val dtype2 = foldr1 mk_ssumT (map prod cons');
@@ -51,7 +50,6 @@
     fun dbind s = Binding.name (dnam ^ s);
     val const_rep  = (dbind "_rep" ,              dtype  ->> dtype2, NoSyn);
     val const_abs  = (dbind "_abs" ,              dtype2 ->> dtype , NoSyn);
-    val const_when = (dbind "_when", List.foldr (op ->>) (dtype ->> freetvar "t") (map when_type cons'), NoSyn);
     val const_copy = (dbind "_copy", dtypeprod ->> dtype  ->> dtype , NoSyn);
     end;
 
@@ -63,8 +61,7 @@
     val optional_consts =
         if definitional then [] else [const_rep, const_abs, const_copy];
 
-  in (optional_consts @ [const_when] @ 
-      [const_take, const_finite])
+  in (optional_consts @ [const_take, const_finite])
   end; (* let *)
 
 (* ----- putting all the syntax stuff together ------------------------------ *)
--- a/src/HOLCF/Tools/Domain/domain_theorems.ML	Mon Mar 01 08:33:49 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_theorems.ML	Mon Mar 01 09:55:32 2010 -0800
@@ -120,7 +120,6 @@
 in
   val ax_abs_iso  = ga "abs_iso"  dname;
   val ax_rep_iso  = ga "rep_iso"  dname;
-  val ax_when_def = ga "when_def" dname;
   val ax_copy_def = ga "copy_def" dname;
 end; (* local *)
 
@@ -154,7 +153,7 @@
 
 val (result, thy) =
   Domain_Constructors.add_domain_constructors
-    (Long_Name.base_name dname) (snd dom_eqn) iso_info ax_when_def thy;
+    (Long_Name.base_name dname) (snd dom_eqn) iso_info thy;
 
 val con_appls = #con_betas result;
 val {exhaust, casedist, ...} = result;