tuned;
authorwenzelm
Fri, 16 Dec 2011 11:02:55 +0100
changeset 45898 b619242b0439
parent 45897 65cef0298158
child 45899 df887263a379
child 45904 c9ae2bc95fad
tuned;
src/HOL/Tools/Datatype/datatype_case.ML
src/HOL/Tools/Datatype/primrec.ML
--- a/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 16 10:52:35 2011 +0100
+++ b/src/HOL/Tools/Datatype/datatype_case.ML	Fri Dec 16 11:02:55 2011 +0100
@@ -130,7 +130,7 @@
                  names = names,
                  constraints = cnstrts,
                  group = in_group'} :: part cs not_in_group
-              end
+              end;
       in part constructors rows end;
 
 fun v_to_prfx (prfx, Free v :: pats) = (v :: prfx, pats)
@@ -143,7 +143,6 @@
   let
     val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
 
-    val name = singleton (Name.variant_list used) "a";
     fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1)
       | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
           if is_Free p then
@@ -153,7 +152,10 @@
                 let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
                 in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end
             in map expnd constructors end
-          else [row]
+          else [row];
+
+    val name = singleton (Name.variant_list used) "a";
+
     fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
       | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
       | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
@@ -277,19 +279,22 @@
                 val (u', used'') = prep_pat u used';
               in (t' $ u', used'') end
           | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);
+
         fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
               let val (l', cnstrts) = strip_constraints l
               in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
           | dest_case1 t = case_error "dest_case1";
+
         fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
           | dest_case2 t = [t];
+
         val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
-        val case_tm =
-          make_case_untyped ctxt
-            (if err then Error else Warning) []
-            (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
-               (flat cnstrts) t) cases;
-      in case_tm end
+      in
+        make_case_untyped ctxt
+          (if err then Error else Warning) []
+          (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT)
+             (flat cnstrts) t) cases
+      end
   | case_tr _ _ _ = case_error "case_tr";
 
 val trfun_setup =
--- a/src/HOL/Tools/Datatype/primrec.ML	Fri Dec 16 10:52:35 2011 +0100
+++ b/src/HOL/Tools/Datatype/primrec.ML	Fri Dec 16 11:02:55 2011 +0100
@@ -206,11 +206,11 @@
 
 (* find datatypes which contain all datatypes in tnames' *)
 
-fun find_dts (dt_info : Datatype_Aux.info Symtab.table) _ [] = []
+fun find_dts _ _ [] = []
   | find_dts dt_info tnames' (tname :: tnames) =
       (case Symtab.lookup dt_info tname of
         NONE => primrec_error (quote tname ^ " is not a datatype")
-      | SOME dt =>
+      | SOME (dt : Datatype_Aux.info) =>
           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
             (tname, dt) :: (find_dts dt_info tnames' tnames)
           else find_dts dt_info tnames' tnames);
@@ -218,12 +218,12 @@
 
 (* distill primitive definition(s) from primrec specification *)
 
-fun distill lthy fixes eqs =
+fun distill ctxt fixes eqs =
   let
-    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v
+    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
       orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
     val tnames = distinct (op =) (map (#1 o snd) eqns);
-    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of lthy)) tnames tnames;
+    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
     val main_fns = map (fn (tname, {index, ...}) =>
       (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
     val {descr, rec_names, rec_rewrites, ...} =
@@ -232,7 +232,7 @@
       else snd (hd dts);
     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
-    val defs = map (make_def lthy fixes fs) raw_defs;
+    val defs = map (make_def ctxt fixes fs) raw_defs;
     val names = map snd fnames;
     val names_eqns = map fst eqns;
     val _ =
@@ -241,17 +241,17 @@
         "\nare not mutually recursive");
     val rec_rewrites' = map mk_meta_eq rec_rewrites;
     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
-    fun prove lthy defs =
+    fun prove ctxt defs =
       let
-        val frees = fold (Variable.add_free_names lthy) eqs [];
+        val frees = fold (Variable.add_free_names ctxt) eqs [];
         val rewrites = rec_rewrites' @ map (snd o snd) defs;
         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
-      in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
+      in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end;
   in ((prefix, (fs, defs)), prove) end
   handle PrimrecError (msg, some_eqn) =>
     error ("Primrec definition error:\n" ^ msg ^
       (case some_eqn of
-        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
+        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
       | NONE => ""));
 
 
@@ -259,7 +259,7 @@
 
 fun add_primrec_simple fixes ts lthy =
   let
-    val ((prefix, (fs, defs)), prove) = distill lthy fixes ts;
+    val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
   in
     lthy
     |> fold_map Local_Theory.define defs