src/HOLCF/Tools/Domain/domain_constructors.ML
changeset 35459 3d8acfae6fb8
parent 35458 deaf221c4a59
child 35460 8cb42aa19358
--- a/src/HOLCF/Tools/Domain/domain_constructors.ML	Sat Feb 27 10:12:47 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML	Sat Feb 27 14:04:46 2010 -0800
@@ -12,6 +12,7 @@
       -> typ * (binding * (bool * binding option * typ) list * mixfix) list
       -> term * term
       -> thm * thm
+      -> thm
       -> theory
       -> { con_consts : term list,
            con_betas : thm list,
@@ -23,6 +24,7 @@
            injects : thm list,
            dist_les : thm list,
            dist_eqs : thm list,
+           cases : thm list,
            sel_rews : thm list
          } * theory;
 end;
@@ -267,6 +269,12 @@
 
 val simple_ss : simpset = HOL_basic_ss addsimps simp_thms;
 
+val beta_ss =
+  HOL_basic_ss
+    addsimps simp_thms
+    addsimps [@{thm beta_cfun}]
+    addsimprocs [@{simproc cont_proc}];
+
 fun define_consts
     (specs : (binding * term * mixfix) list)
     (thy : theory)
@@ -571,6 +579,76 @@
   end;
 
 (******************************************************************************)
+(**************** definition and theorems for case combinator *****************)
+(******************************************************************************)
+
+fun add_case_combinator
+    (spec : (term * (bool * typ) list) list)
+    (lhsT : typ)
+    (dname : string)
+    (case_def : thm)
+    (con_betas : thm list)
+    (casedist : thm)
+    (iso_locale : thm)
+    (thy : theory) =
+  let
+
+    (* prove rep/abs rules *)
+    val rep_strict = iso_locale RS @{thm iso.rep_strict};
+    val abs_inverse = iso_locale RS @{thm iso.abs_iso};
+
+    (* calculate function arguments of case combinator *)
+    val resultT = TVar (("'a",0), @{sort pcpo});
+    val fTs = map (fn (_, args) => map snd args -->> resultT) spec;
+    val fns = Datatype_Prop.indexify_names (map (K "f") spec);
+    val fs = map Free (fns ~~ fTs);
+    val caseT = fTs -->> (lhsT ->> resultT);
+
+    (* TODO: move definition of case combinator here *)
+    val case_bind = Binding.name (dname ^ "_when");
+    val case_const = Const (Sign.full_name thy case_bind, caseT);
+    val case_app = list_ccomb (case_const, fs);
+
+    (* prove beta reduction rule for case combinator *)
+    val case_beta = beta_of_def thy case_def;
+
+    (* prove strictness of case combinator *)
+    val case_strict =
+      let
+        val defs = [case_beta, mk_meta_eq rep_strict];
+        val lhs = case_app ` mk_bottom lhsT;
+        val goal = mk_trp (mk_eq (lhs, mk_bottom resultT));
+        val tacs = [resolve_tac @{thms sscase1 ssplit1 strictify1} 1];
+      in prove thy defs goal (K tacs) end;
+        
+    (* prove rewrites for case combinator *)
+    local
+      fun one_case (con, args) f =
+        let
+          val Ts = map snd args;
+          val ns = Name.variant_list fns (Datatype_Prop.make_tnames Ts);
+          val vs = map Free (ns ~~ Ts);
+          val nonlazy = map snd (filter_out (fst o fst) (args ~~ vs));
+          val assms = map (mk_trp o mk_defined) nonlazy;
+          val lhs = case_app ` list_ccomb (con, vs);
+          val rhs = list_ccomb (f, vs);
+          val concl = mk_trp (mk_eq (lhs, rhs));
+          val goal = Logic.list_implies (assms, concl);
+          val defs = case_beta :: con_betas;
+          val rules1 = @{thms sscase2 sscase3 ssplit2 fup2 ID1};
+          val rules2 = @{thms con_defined_iff_rules};
+          val rules = abs_inverse :: rules1 @ rules2;
+          val tacs = [asm_simp_tac (beta_ss addsimps rules) 1];
+        in prove thy defs goal (K tacs) end;
+    in
+      val case_apps = map2 one_case spec fs;
+    end
+
+  in
+    (case_strict :: case_apps, thy)
+  end
+
+(******************************************************************************)
 (************** definitions and theorems for selector functions ***************)
 (******************************************************************************)
 
@@ -722,6 +800,7 @@
      spec : (binding * (bool * binding option * typ) list * mixfix) list)
     (rep_const : term, abs_const : term)
     (rep_iso_thm : thm, abs_iso_thm : thm)
+    (case_def : thm)
     (thy : theory) =
   let
 
@@ -741,7 +820,18 @@
       in
         add_constructors con_spec abs_const iso_locale thy
       end;
-    val {con_consts, con_betas, ...} = con_result;
+    val {con_consts, con_betas, casedist, ...} = con_result;
+
+    (* define case combinator *)
+    val (cases : thm list, thy) =
+      let
+        fun prep_arg (lazy, sel, T) = (lazy, T);
+        fun prep_con c (b, args, mx) = (c, map prep_arg args);
+        val case_spec = map2 prep_con con_consts spec;
+      in
+        add_case_combinator case_spec lhsT dname
+          case_def con_betas casedist iso_locale thy
+      end;
 
     (* TODO: enable this earlier *)
     val thy = Sign.add_path dname thy;
@@ -762,13 +852,14 @@
       { con_consts = con_consts,
         con_betas = con_betas,
         exhaust = #exhaust con_result,
-        casedist = #casedist con_result,
+        casedist = casedist,
         con_compacts = #con_compacts con_result,
         con_rews = #con_rews con_result,
         inverts = #inverts con_result,
         injects = #injects con_result,
         dist_les = #dist_les con_result,
         dist_eqs = #dist_eqs con_result,
+        cases = cases,
         sel_rews = sel_thms };
   in
     (result, thy)