--- a/src/HOL/Tools/case_translation.ML Sun May 26 14:02:03 2013 +0200
+++ b/src/HOL/Tools/case_translation.ML Sun May 26 18:37:43 2013 +0200
@@ -90,12 +90,14 @@
val lookup_by_case = Symtab.lookup o cases_of;
+
(** installation **)
fun case_error s = error ("Error in case expression:\n" ^ s);
val name_of = try (dest_Const #> fst);
+
(* parse translation *)
fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
@@ -129,9 +131,9 @@
| replace_dummies t used = (t, used);
fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
- let val (l', _) = replace_dummies l (Term.declare_term_frees t Name.context)
- in abs_pat l' []
- (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l' $ r)
+ let val (l', _) = replace_dummies l (Term.declare_term_frees t Name.context) in
+ abs_pat l' []
+ (Syntax.const @{const_syntax case_elem} $ Term_Position.strip_positions l' $ r)
end
| dest_case1 _ = case_error "dest_case1";
@@ -140,11 +142,11 @@
val errt = if err then @{term True} else @{term False};
in
- Syntax.const @{const_syntax case_guard} $ errt $ t $ (fold_rev
- (fn t => fn u =>
- Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
- (dest_case2 u)
- (Syntax.const @{const_syntax case_nil}))
+ Syntax.const @{const_syntax case_guard} $ errt $ t $
+ (fold_rev
+ (fn t => fn u => Syntax.const @{const_syntax case_cons} $ dest_case1 t $ u)
+ (dest_case2 u)
+ (Syntax.const @{const_syntax case_nil}))
end
| case_tr _ _ _ = case_error "case_tr";
@@ -154,24 +156,24 @@
(* print translation *)
fun case_tr' (_ :: x :: t :: ts) =
- let
- fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
- let val (s', used') = Name.variant s used
- in mk_clause t ((s', T) :: xs) used' end
- | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
- Syntax.const @{syntax_const "_case1"} $
- subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
- subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
+ let
+ fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, t)) xs used =
+ let val (s', used') = Name.variant s used
+ in mk_clause t ((s', T) :: xs) used' end
+ | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) xs _ =
+ Syntax.const @{syntax_const "_case1"} $
+ subst_bounds (map Syntax_Trans.mark_bound_abs xs, pat) $
+ subst_bounds (map Syntax_Trans.mark_bound_body xs, rhs);
- fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
- | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
- mk_clause t [] (Term.declare_term_frees t Name.context) ::
- mk_clauses u
- in
- list_comb (Syntax.const @{syntax_const "_case_syntax"} $ x $
- foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
- (mk_clauses t), ts)
- end;
+ fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
+ | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
+ mk_clause t [] (Term.declare_term_frees t Name.context) ::
+ mk_clauses u
+ in
+ list_comb (Syntax.const @{syntax_const "_case_syntax"} $ x $
+ foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u)
+ (mk_clauses t), ts)
+ end;
val trfun_setup' = Sign.print_translation [(@{const_syntax "case_guard"}, K case_tr')];
@@ -204,13 +206,13 @@
(*Each pattern carries with it a tag i, which denotes the clause it
-came from. i = ~1 indicates that the clause was added by pattern
-completion.*)
+ came from. i = ~1 indicates that the clause was added by pattern
+ completion.*)
fun add_row_used ((prfx, pats), (tm, tag)) =
fold Term.declare_term_frees (tm :: pats @ map Free prfx);
-(* try to preserve names given by user *)
+(*try to preserve names given by user*)
fun default_name "" (Free (name', _)) = name'
| default_name name _ = name;
@@ -220,8 +222,9 @@
let
val (_, T) = dest_Const c;
val Ts = binder_types T;
- val (names, _) = fold_map Name.variant
- (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
+ val (names, _) =
+ fold_map Name.variant
+ (Datatype_Prop.make_tnames (map Logic.unvarifyT_global Ts)) used;
val ty = body_type T;
val ty_theta = Type.raw_match (ty, colty) Vartab.empty
handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
@@ -251,8 +254,7 @@
(* Partitioning *)
fun partition _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
- | partition used constructors colty res_ty
- (rows as (((prfx, _ :: ps), _) :: _)) =
+ | partition used constructors colty res_ty (rows as (((prfx, _ :: ps), _) :: _)) =
let
fun part [] [] = []
| part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
@@ -311,8 +313,7 @@
| mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
| mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
- (case Option.map (apfst head_of)
- (find_first (not o is_Free o fst) col0) of
+ (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
NONE =>
let
val rows' = map (fn ((v, _), row) => row ||>
@@ -326,8 +327,7 @@
val pty = body_type cT;
val used' = fold Term.declare_term_frees us used;
val nrows = maps (expand constructors used' pty) rows;
- val subproblems =
- partition used' constructors pty range_ty nrows;
+ val subproblems = partition used' constructors pty range_ty nrows;
val (pat_rect, dtrees) =
split_list (map (fn {new_formals, group, ...} =>
mk (new_formals @ us) group) subproblems);
@@ -509,13 +509,13 @@
encode_clause recur S T p $ encode_cases recur S T ps;
fun encode_case recur (t, ps as (pat, rhs) :: _) =
- let
- val tT = fastype_of t;
- val T = fastype_of rhs;
- in
- Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
- @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
- end
+ let
+ val tT = fastype_of t;
+ val T = fastype_of rhs;
+ in
+ Const (@{const_name case_guard}, @{typ bool} --> tT --> (tT --> T) --> T) $
+ @{term True} $ t $ (encode_cases recur (fastype_of pat) (fastype_of rhs) ps)
+ end
| encode_case _ _ = case_error "encode_case";
fun strip_case' ctxt d (pat, rhs) =
@@ -542,8 +542,8 @@
(strip_case_full ctxt d x, maps (strip_case' ctxt d) clauses)
| NONE =>
(case t of
- (t $ u) => strip_case_full ctxt d t $ strip_case_full ctxt d u
- | (Abs (x, T, u)) =>
+ t $ u => strip_case_full ctxt d t $ strip_case_full ctxt d u
+ | Abs (x, T, u) =>
let val (x', u') = Term.dest_abs (x, T, u);
in Term.absfree (x', T) (strip_case_full ctxt d u') end
| _ => t));