exported equation transformator
authorhaftmann
Fri, 21 Jul 2006 14:46:27 +0200
changeset 20176 36737fb58614
parent 20175 0a8ca32f6e64
child 20177 0af885e3dabf
exported equation transformator
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Fri Jul 21 14:45:43 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Fri Jul 21 14:46:27 2006 +0200
@@ -8,14 +8,15 @@
 signature PRIMREC_PACKAGE =
 sig
   val quiet_mode: bool ref
+  val mk_combdefs: theory -> term list -> (string * term) list
   val add_primrec: string -> ((bstring * string) * Attrib.src list) list
-    -> theory -> theory * thm list
+    -> theory -> thm list * theory
   val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
-    -> theory -> theory * thm list
+    -> theory -> thm list * theory
   val add_primrec_i: string -> ((bstring * term) * attribute list) list
-    -> theory -> theory * thm list
+    -> theory -> thm list * theory
   val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
-    -> theory -> theory * thm list
+    -> theory -> thm list * theory
 end;
 
 structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -26,8 +27,8 @@
 exception RecError of string;
 
 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
-fun primrec_eq_err sign s eq =
-  primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq));
+fun primrec_eq_err thy s eq =
+  primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq));
 
 
 (* messages *)
@@ -38,13 +39,13 @@
 
 (* preprocessing of equations *)
 
-fun process_eqn sign (eq, rec_fns) = 
+fun process_eqn thy (eq, rec_fns) = 
   let
     val (lhs, rhs) = 
-	if null (term_vars eq) then
-	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
-	      handle TERM _ => raise RecError "not a proper equation"
-	else raise RecError "illegal schematic variable(s)";
+      if null (term_vars eq) then
+        HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
+        handle TERM _ => raise RecError "not a proper equation"
+      else raise RecError "illegal schematic variable(s)";
 
     val (recfun, args) = strip_comb lhs;
     val fnameT = dest_Const recfun handle TERM _ => 
@@ -61,9 +62,8 @@
     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
       raise RecError "cannot determine datatype associated with function"
 
-    val (ls, cargs, rs) = (map dest_Free ls', 
-			   map dest_Free cargs', 
-			   map dest_Free rs')
+    val (ls, cargs, rs) =
+      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
       handle TERM _ => raise RecError "illegal argument in pattern";
     val lfrees = ls @ rs @ cargs;
 
@@ -88,9 +88,9 @@
             AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns))
               rec_fns)
   end
-  handle RecError s => primrec_eq_err sign s eq;
+  handle RecError s => primrec_eq_err thy s eq;
 
-fun process_fun sign descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
+fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
   let
     val (_, (tname, _, constrs)) = List.nth (descr, i);
 
@@ -122,7 +122,7 @@
                   | SOME (i', y) =>
                       let
                         val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
-                        val fs'' = process_fun sign descr rec_eqns ((i', fnameT'), fs')
+                        val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
                       in (fs'', list_comb (y, ts'))
                       end)
               end
@@ -145,15 +145,15 @@
               val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
               val rargs = map fst recs;
               val subs = map (rpair dummyT o fst) 
-		             (rev (rename_wrt_term rhs rargs));
+                (rev (rename_wrt_term rhs rargs));
               val ((fnameTs'', fnss''), rhs') = 
-		  (subst (map (fn ((x, y), z) =>
-			       (Free x, (body_index y, Free z)))
-			  (recs ~~ subs))
-		   ((fnameTs', fnss'), rhs))
-                  handle RecError s => primrec_eq_err sign s eq
+                  (subst (map (fn ((x, y), z) =>
+                               (Free x, (body_index y, Free z)))
+                          (recs ~~ subs))
+                   ((fnameTs', fnss'), rhs))
+                  handle RecError s => primrec_eq_err thy s eq
             in (fnameTs'', fnss'', 
-		(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
+                (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
             end)
 
   in (case AList.lookup (op =) fnameTs i of
@@ -176,7 +176,7 @@
 
 (* prepare functions needed for definitions *)
 
-fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
+fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
   case AList.lookup (op =) fns i of
      NONE =>
        let
@@ -192,15 +192,15 @@
 
 (* make definition *)
 
-fun make_def sign fs (fname, ls, rec_name, tname) =
+fun make_def thy fs (fname, ls, rec_name, tname) =
   let
     val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
-	            (list_comb (Const (rec_name, dummyT),
-				fs @ map Bound (0 ::(length ls downto 1))))
-		    ((map snd ls) @ [dummyT]);
+                    (list_comb (Const (rec_name, dummyT),
+                                fs @ map Bound (0 ::(length ls downto 1))))
+                    ((map snd ls) @ [dummyT]);
     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
-		   Logic.mk_equals (Const (fname, dummyT), rhs))
-  in Theory.inferT_axm sign defpair end;
+                   Logic.mk_equals (Const (fname, dummyT), rhs))
+  in Theory.inferT_axm thy defpair end;
 
 
 (* find datatypes which contain all datatypes in tnames' *)
@@ -225,12 +225,10 @@
     |> RuleCases.save induction
   end;
 
-fun gen_primrec_i unchecked alt_name eqns_atts thy =
+fun mk_defs thy eqns =
   let
-    val (eqns, atts) = split_list eqns_atts;
-    val sg = Theory.sign_of thy;
     val dt_info = DatatypePackage.get_datatypes thy;
-    val rec_eqns = foldr (process_eqn sg) [] (map snd eqns);
+    val rec_eqns = foldr (process_eqn thy) [] eqns;
     val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
     val dts = find_dts dt_info tnames tnames;
     val main_fns = 
@@ -242,10 +240,20 @@
       if null dts then
         primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
       else snd (hd dts);
-    val (fnameTs, fnss) = foldr (process_fun sg descr rec_eqns)
-	                       ([], []) main_fns;
+    val (fnameTs, fnss) =
+      foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
     val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
-    val defs' = map (make_def sg fs) defs;
+    val defs' = map (make_def thy fs) defs;
+  in (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') end;
+
+fun mk_combdefs thy =
+  #6 o mk_defs thy o map (ObjectLogic.ensure_propT thy);
+
+fun gen_primrec_i unchecked alt_name eqns_atts thy =
+  let
+    val (eqns, atts) = split_list eqns_atts;
+    val (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') =
+      mk_defs thy (map snd eqns);
     val nameTs1 = map snd fnameTs;
     val nameTs2 = map fst rec_eqns;
     val primrec_name =
@@ -261,26 +269,25 @@
       commas_quote (map fst nameTs1) ^ " ...");
     val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
         (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
-    val (simps', thy'') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy';
-    val thy''' = thy''
-      |> (snd o PureThy.add_thmss [(("simps", simps'),
-          [Simplifier.simp_add, RecfunCodegen.add NONE])])
-      |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
-      |> Theory.parent_path
+    val (simps', thy'') = thy' |> PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts);
   in
-    (thy''', simps')
+    thy''
+    |> (snd o PureThy.add_thmss [(("simps", simps'),
+        [Simplifier.simp_add, RecfunCodegen.add NONE])])
+    |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
+    |> Theory.parent_path
+    |> pair simps'
   end;
 
 fun gen_primrec unchecked alt_name eqns thy =
   let
-    val sign = Theory.sign_of thy;
     val ((names, strings), srcss) = apfst split_list (split_list eqns);
     val atts = map (map (Attrib.attribute thy)) srcss;
-    val eqn_ts = map (fn s => term_of (Thm.read_cterm sign (s, propT))
+    val eqn_ts = map (fn s => term_of (Thm.read_cterm thy (s, propT))
       handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings;
     val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
-      handle TERM _ => primrec_eq_err sign "not a proper equation" eq) eqn_ts;
-    val (_, eqn_ts') = InductivePackage.unify_consts (sign_of thy) rec_ts eqn_ts
+      handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
+    val (_, eqn_ts') = InductivePackage.unify_consts thy rec_ts eqn_ts
   in
     gen_primrec_i unchecked alt_name (names ~~ eqn_ts' ~~ atts) thy
   end;
@@ -306,7 +313,7 @@
 val primrecP =
   OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
     (primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
-      Toplevel.theory (#1 o
+      Toplevel.theory (snd o
         (if unchecked then add_primrec_unchecked else add_primrec) alt_name
           (map P.triple_swap eqns))));