move definition and syntax of pattern combinators into domain_constructors.ML
authorhuffman
Sun, 28 Feb 2010 08:55:01 -0800
changeset 35468 09bc6a2e2296
parent 35467 561d8e98d9d3
child 35469 6e59de61d501
move definition and syntax of pattern combinators into domain_constructors.ML
src/HOLCF/Tools/Domain/domain_axioms.ML
src/HOLCF/Tools/Domain/domain_constructors.ML
src/HOLCF/Tools/Domain/domain_syntax.ML
src/HOLCF/Tools/Domain/domain_theorems.ML
--- a/src/HOLCF/Tools/Domain/domain_axioms.ML	Sat Feb 27 21:38:24 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_axioms.ML	Sun Feb 28 08:55:01 2010 -0800
@@ -78,22 +78,6 @@
           (dc_abs oo (copy_of_dtyp map_tab r (dtyp_of_eq eqn)) oo dc_rep))
       end;
 
-    val pat_defs =
-      let
-        fun pdef (con, _, args) =
-          let
-            val ps = mapn (fn n => fn _ => %:("pat" ^ string_of_int n)) 1 args;
-            val xs = map (bound_arg args) args;
-            val r = Bound (length args);
-            val rhs = case args of [] => mk_return HOLogic.unit
-                                 | _ => mk_ctuple_pat ps ` mk_ctuple xs;
-            fun one_con (con', _, args') = List.foldr /\# (if con'=con then rhs else mk_fail) args';
-          in (pat_name con ^"_def", list_comb (%%:(pat_name con), ps) == 
-                                              list_ccomb(%%:(dname^"_when"), map one_con cons))
-          end
-      in map pdef cons end;
-
-
 (* ----- axiom and definitions concerning induction ------------------------- *)
 
     val reach_ax = ("reach", mk_trp(proj (mk_fix (%%:(comp_dname^"_copy"))) eqs n
@@ -112,7 +96,6 @@
   in (dnam,
       (if definitional then [] else [abs_iso_ax, rep_iso_ax, reach_ax]),
       (if definitional then [when_def] else [when_def, copy_def]) @
-      pat_defs @
       [take_def, finite_def])
   end; (* let (calc_axioms) *)
 
--- a/src/HOLCF/Tools/Domain/domain_constructors.ML	Sat Feb 27 21:38:24 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML	Sun Feb 28 08:55:01 2010 -0800
@@ -27,7 +27,8 @@
            cases : thm list,
            sel_rews : thm list,
            dis_rews : thm list,
-           match_rews : thm list
+           match_rews : thm list,
+           pat_rews : thm list
          } * theory;
 end;
 
@@ -94,6 +95,10 @@
   let val T = fastype_of t
   in cabs_const (Term.domain_type T, Term.range_type T) $ t end
 
+(* builds the expression (% v1 v2 .. vn. rhs) *)
+fun lambdas [] rhs = rhs
+  | lambdas (v::vs) rhs = Term.lambda v (lambdas vs rhs);
+
 (* builds the expression (LAM v. rhs) *)
 fun big_lambda v rhs =
   cabs_const (fastype_of v, fastype_of rhs) $ Term.lambda v rhs;
@@ -131,9 +136,11 @@
 
 (*** Product type ***)
 
+val mk_prodT = HOLogic.mk_prodT
+
 fun mk_tupleT [] = HOLogic.unitT
   | mk_tupleT [T] = T
-  | mk_tupleT (T :: Ts) = HOLogic.mk_prodT (T, mk_tupleT Ts);
+  | mk_tupleT (T :: Ts) = mk_prodT (T, mk_tupleT Ts);
 
 (* builds the expression (v1,v2,..,vn) *)
 fun mk_tuple [] = HOLogic.unit
@@ -235,8 +242,14 @@
 
 fun mk_matchT T = Type (@{type_name "maybe"}, [T]);
 
+fun dest_matchT (Type(@{type_name "maybe"}, [T])) = T
+  | dest_matchT T = raise TYPE ("dest_matchT", [T], []);
+
 fun mk_fail T = Const (@{const_name "Fixrec.fail"}, mk_matchT T);
 
+fun return_const T = Const (@{const_name "Fixrec.return"}, T ->> mk_matchT T);
+fun mk_return t = return_const (fastype_of t) ` t;
+
 
 (*** miscellaneous constructions ***)
 
@@ -997,6 +1010,109 @@
   end;
 
 (******************************************************************************)
+(************** definitions and theorems for pattern combinators **************)
+(******************************************************************************)
+
+fun add_pattern_combinators
+    (bindings : binding list)
+    (spec : (term * (bool * typ) list) list)
+    (lhsT : typ)
+    (casedist : thm)
+    (case_const : typ -> term)
+    (case_rews : thm list)
+    (thy : theory) =
+  let
+
+    (* define pattern combinators *)
+    local
+      fun mk_pair_pat (p1, p2) =
+        let
+          val T1 = fastype_of p1;
+          val T2 = fastype_of p2;
+          val (U1, V1) = apsnd dest_matchT (dest_cfunT T1);
+          val (U2, V2) = apsnd dest_matchT (dest_cfunT T2);
+          val pat_typ = [T1, T2] --->
+              (mk_prodT (U1, U2) ->> mk_matchT (mk_prodT (V1, V2)));
+          val pat_const = Const (@{const_name cpair_pat}, pat_typ);
+        in
+          pat_const $ p1 $ p2
+        end;
+      fun mk_tuple_pat [] = return_const HOLogic.unitT
+        | mk_tuple_pat ps = foldr1 mk_pair_pat ps;
+
+      val tns = map (fst o dest_TFree) (snd (dest_Type lhsT));
+
+      fun pat_eqn (i, (bind, (con, args))) : binding * term * mixfix =
+        let
+          val pat_bind = Binding.suffix_name "_pat" bind;
+          val Ts = map snd args;
+          val Vs =
+              (map (K "t") args)
+              |> Datatype_Prop.indexify_names
+              |> Name.variant_list tns
+              |> map (fn t => TFree (t, @{sort pcpo}));
+          val patNs = Datatype_Prop.indexify_names (map (K "pat") args);
+          val patTs = map2 (fn T => fn V => T ->> mk_matchT V) Ts Vs;
+          val pats = map Free (patNs ~~ patTs);
+          val fail = mk_fail (mk_tupleT Vs);
+          val ns = Name.variant_list patNs (Datatype_Prop.make_tnames Ts);
+          val vs = map Free (ns ~~ Ts);
+          val rhs = big_lambdas vs (mk_tuple_pat pats ` mk_tuple vs);
+          fun one_fun (j, (_, args')) =
+            let
+              val Ts = map snd args';
+              val ns = Name.variant_list patNs (Datatype_Prop.make_tnames Ts);
+              val vs' = map Free (ns ~~ Ts);
+            in if i = j then rhs else big_lambdas vs' fail end;
+          val funs = map_index one_fun spec;
+          val body = list_ccomb (case_const (mk_matchT (mk_tupleT Vs)), funs);
+        in
+          (pat_bind, lambdas pats body, NoSyn)
+        end;
+    in
+      val ((pat_consts, pat_defs), thy) =
+          define_consts (map_index pat_eqn (bindings ~~ spec)) thy
+    end;
+
+    (* syntax translations for pattern combinators *)
+    local
+      open Syntax
+      fun syntax c = Syntax.mark_const (fst (dest_Const c));
+      fun app s (l, r) = Syntax.mk_appl (Constant s) [l, r];
+      val capp = app @{const_syntax Rep_CFun};
+      val capps = Library.foldl capp
+
+      fun app_var x = Syntax.mk_appl (Constant "_variable") [x, Variable "rhs"];
+      fun app_pat x = Syntax.mk_appl (Constant "_pat") [x];
+      fun args_list [] = Constant "_noargs"
+        | args_list xs = foldr1 (app "_args") xs;
+      fun one_case_trans (pat, (con, args)) =
+        let
+          val cname = Constant (syntax con);
+          val pname = Constant (syntax con ^ "_pat");
+          val ns = 1 upto length args;
+          val xs = map (fn n => Variable ("x"^(string_of_int n))) ns;
+          val ps = map (fn n => Variable ("p"^(string_of_int n))) ns;
+          val vs = map (fn n => Variable ("v"^(string_of_int n))) ns;
+        in
+          [ParseRule (app_pat (capps (cname, xs)),
+                      mk_appl pname (map app_pat xs)),
+           ParseRule (app_var (capps (cname, xs)),
+                      app_var (args_list xs)),
+           PrintRule (capps (cname, ListPair.map (app "_match") (ps,vs)),
+                      app "_match" (mk_appl pname ps, args_list vs))]
+        end;
+      val trans_rules : Syntax.ast Syntax.trrule list =
+          maps one_case_trans (pat_consts ~~ spec);
+    in
+      val thy = Sign.add_trrules_i trans_rules thy;
+    end;
+
+  in
+    (pat_defs, thy)
+  end
+
+(******************************************************************************)
 (******************************* main function ********************************)
 (******************************************************************************)
 
@@ -1077,6 +1193,18 @@
           casedist case_const cases thy
       end
 
+    (* define and prove theorems for pattern combinators *)
+    val (pat_thms : thm list, thy : theory) =
+      let
+        val bindings = map #1 spec;
+        fun prep_arg (lazy, sel, T) = (lazy, T);
+        fun prep_con c (b, args, mx) = (c, map prep_arg args);
+        val pat_spec = map2 prep_con con_consts spec;
+      in
+        add_pattern_combinators bindings pat_spec lhsT
+          casedist case_const cases thy
+      end
+
     (* restore original signature path *)
     val thy = Sign.parent_path thy;
 
@@ -1094,7 +1222,8 @@
         cases = cases,
         sel_rews = sel_thms,
         dis_rews = dis_thms,
-        match_rews = match_thms };
+        match_rews = match_thms,
+        pat_rews = pat_thms };
   in
     (result, thy)
   end;
--- a/src/HOLCF/Tools/Domain/domain_syntax.ML	Sat Feb 27 21:38:24 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_syntax.ML	Sun Feb 28 08:55:01 2010 -0800
@@ -55,36 +55,6 @@
     val const_copy = (dbind "_copy", dtypeprod ->> dtype  ->> dtype , NoSyn);
     end;
 
-(* ----- constants concerning constructors, discriminators, and selectors --- *)
-
-    local
-      val escape = let
-        fun esc (c::cs) = if c mem ["'","_","(",")","/"] then "'"::c::esc cs
-                          else      c::esc cs
-          | esc []      = []
-      in implode o esc o Symbol.explode end;
-
-      fun pat_name_ con =
-          Binding.name (strip_esc (Binding.name_of con) ^ "_pat");
-      (* strictly speaking, these constants have one argument,
-       but the mixfix (without arguments) is introduced only
-           to generate parse rules for non-alphanumeric names*)
-      fun freetvar s n =
-          let val tvar = mk_TFree (s ^ string_of_int n)
-          in if tvar mem typevars then freetvar ("t"^s) n else tvar end;
-
-      fun mk_patT (a,b)     = a ->> mk_maybeT b;
-      fun pat_arg_typ n arg = mk_patT (third arg, freetvar "t" n);
-      fun pat (con,args,mx) =
-          (pat_name_ con,
-           (mapn pat_arg_typ 1 args)
-             --->
-             mk_patT (dtype, mk_ctupleT (map (freetvar "t") (1 upto length args))),
-           Mixfix(escape (Binding.name_of con ^ "_pat"), [], Syntax.max_pri));
-    in
-    val consts_pat = map pat cons';
-    end;
-
 (* ----- constants concerning induction ------------------------------------- *)
 
     val const_take   = (dbind "_take"  , HOLogic.natT-->dtype->>dtype, NoSyn);
@@ -150,7 +120,6 @@
         if definitional then [] else [const_rep, const_abs, const_copy];
 
   in (optional_consts @ [const_when] @ 
-      consts_pat @
       [const_take, const_finite],
       (case_trans false :: case_trans true :: (abscon_trans false @ abscon_trans true @ Case_trans)))
   end; (* let *)
--- a/src/HOLCF/Tools/Domain/domain_theorems.ML	Sat Feb 27 21:38:24 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_theorems.ML	Sun Feb 28 08:55:01 2010 -0800
@@ -121,8 +121,6 @@
   val ax_abs_iso  = ga "abs_iso"  dname;
   val ax_rep_iso  = ga "rep_iso"  dname;
   val ax_when_def = ga "when_def" dname;
-  fun get_def mk_name (con, _, _) = ga (mk_name con^"_def") dname;
-  val axs_pat_def = map (get_def pat_name) cons;
   val ax_copy_def = ga "copy_def" dname;
 end; (* local *)
 
@@ -157,6 +155,7 @@
 val when_strict = hd when_rews;
 val dis_rews = #dis_rews result;
 val mat_rews = #match_rews result;
+val axs_pat_def = #pat_rews result;
 
 (* ----- theorems concerning the isomorphism -------------------------------- *)