src/HOLCF/Tools/Domain/domain_constructors.ML
changeset 35460 8cb42aa19358
parent 35459 3d8acfae6fb8
child 35461 34360a1e3537
--- a/src/HOLCF/Tools/Domain/domain_constructors.ML	Sat Feb 27 14:04:46 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML	Sat Feb 27 15:32:42 2010 -0800
@@ -25,7 +25,8 @@
            dist_les : thm list,
            dist_eqs : thm list,
            cases : thm list,
-           sel_rews : thm list
+           sel_rews : thm list,
+           dis_rews : thm list
          } * theory;
 end;
 
@@ -590,7 +591,8 @@
     (con_betas : thm list)
     (casedist : thm)
     (iso_locale : thm)
-    (thy : theory) =
+    (thy : theory)
+    : ((typ -> term) * thm list) * theory =
   let
 
     (* prove rep/abs rules *)
@@ -598,16 +600,17 @@
     val abs_inverse = iso_locale RS @{thm iso.abs_iso};
 
     (* calculate function arguments of case combinator *)
-    val resultT = TVar (("'a",0), @{sort pcpo});
-    val fTs = map (fn (_, args) => map snd args -->> resultT) spec;
+    val resultT = TVar (("'t",0), @{sort pcpo});
+    fun fTs T = map (fn (_, args) => map snd args -->> T) spec;
     val fns = Datatype_Prop.indexify_names (map (K "f") spec);
-    val fs = map Free (fns ~~ fTs);
-    val caseT = fTs -->> (lhsT ->> resultT);
+    val fs = map Free (fns ~~ fTs resultT);
+    fun caseT T = fTs T -->> (lhsT ->> T);
 
     (* TODO: move definition of case combinator here *)
     val case_bind = Binding.name (dname ^ "_when");
-    val case_const = Const (Sign.full_name thy case_bind, caseT);
-    val case_app = list_ccomb (case_const, fs);
+    val case_name = Sign.full_name thy case_bind;
+    fun case_const T = Const (case_name, caseT T);
+    val case_app = list_ccomb (case_const resultT, fs);
 
     (* prove beta reduction rule for case combinator *)
     val case_beta = beta_of_def thy case_def;
@@ -645,7 +648,7 @@
     end
 
   in
-    (case_strict :: case_apps, thy)
+    ((case_const, case_strict :: case_apps), thy)
   end
 
 (******************************************************************************)
@@ -791,6 +794,52 @@
   end
 
 (******************************************************************************)
+(************ definitions and theorems for discriminator functions ************)
+(******************************************************************************)
+
+fun add_discriminators
+    (bindings : binding list)
+    (spec : (term * (bool * typ) list) list)
+    (case_const : typ -> term)
+    (thy : theory) =
+  let
+
+    fun vars_of args =
+      let
+        val Ts = map snd args;
+        val ns = Datatype_Prop.make_tnames Ts;
+      in
+        map Free (ns ~~ Ts)
+      end;
+
+    (* define discriminator functions *)
+    local
+      fun dis_fun i (j, (con, args)) =
+        let
+          val Ts = map snd args;
+          val ns = Datatype_Prop.make_tnames Ts;
+          val vs = map Free (ns ~~ Ts);
+          val tr = if i = j then @{term TT} else @{term FF};
+        in
+          big_lambdas vs tr
+        end;
+      fun dis_eqn (i, bind) : binding * term * mixfix =
+        let
+          val dis_bind = Binding.prefix_name "is_" bind;
+          val rhs = list_ccomb (case_const trT, map_index (dis_fun i) spec);
+        in
+          (dis_bind, rhs, NoSyn)
+        end;
+    in
+      val ((dis_consts, dis_defs), thy) =
+          define_consts (map_index dis_eqn bindings) thy
+    end;
+
+  in
+    (dis_defs, thy)
+  end;
+
+(******************************************************************************)
 (******************************* main function ********************************)
 (******************************************************************************)
 
@@ -823,7 +872,7 @@
     val {con_consts, con_betas, casedist, ...} = con_result;
 
     (* define case combinator *)
-    val (cases : thm list, thy) =
+    val ((case_const : typ -> term, cases : thm list), thy) =
       let
         fun prep_arg (lazy, sel, T) = (lazy, T);
         fun prep_con c (b, args, mx) = (c, map prep_arg args);
@@ -836,14 +885,26 @@
     (* TODO: enable this earlier *)
     val thy = Sign.add_path dname thy;
 
-    (* replace bindings with terms in constructor spec *)
-    val sel_spec : (term * (bool * binding option * typ) list) list =
-      map2 (fn con => fn (b, args, mx) => (con, args)) con_consts spec;
-
     (* define and prove theorems for selector functions *)
     val (sel_thms : thm list, thy : theory) =
-      add_selectors sel_spec rep_const
-        abs_iso_thm rep_strict rep_defined_iff con_betas thy;
+      let
+        val sel_spec : (term * (bool * binding option * typ) list) list =
+          map2 (fn con => fn (b, args, mx) => (con, args)) con_consts spec;
+      in
+        add_selectors sel_spec rep_const
+          abs_iso_thm rep_strict rep_defined_iff con_betas thy
+      end;
+
+    (* define and prove theorems for discriminator functions *)
+    val (dis_thms : thm list, thy : theory) =
+      let
+        val bindings = map #1 spec;
+        fun prep_arg (lazy, sel, T) = (lazy, T);
+        fun prep_con c (b, args, mx) = (c, map prep_arg args);
+        val dis_spec = map2 prep_con con_consts spec;
+      in
+        add_discriminators bindings dis_spec case_const thy
+      end
 
     (* restore original signature path *)
     val thy = Sign.parent_path thy;
@@ -860,7 +921,8 @@
         dist_les = #dist_les con_result,
         dist_eqs = #dist_eqs con_result,
         cases = cases,
-        sel_rews = sel_thms };
+        sel_rews = sel_thms,
+        dis_rews = dis_thms };
   in
     (result, thy)
   end;