# HG changeset patch # User wenzelm # Date 1300798539 -3600 # Node ID e945717b2b155a6ec77e1ab1e5f2cbb06dcbce5b # Parent afd11ca8e01879495e04aa9d7d399baacd2ffdd0 tuned indendation and parentheses; diff -r afd11ca8e018 -r e945717b2b15 src/HOL/Tools/Datatype/datatype_case.ML --- a/src/HOL/Tools/Datatype/datatype_case.ML Tue Mar 22 13:32:20 2011 +0100 +++ b/src/HOL/Tools/Datatype/datatype_case.ML Tue Mar 22 13:55:39 2011 +0100 @@ -37,7 +37,7 @@ *---------------------------------------------------------------------------*) fun ty_info tab sT = - case tab sT of + (case tab sT of SOME ({descr, case_name, index, sorts, ...} : info) => let val (_, (tname, dts, constrs)) = nth descr index; @@ -48,7 +48,7 @@ constructors = map (fn (cname, dts') => Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs} end - | NONE => NONE; + | NONE => NONE); (*--------------------------------------------------------------------------- @@ -93,8 +93,7 @@ raise CASE_ERROR ("type mismatch", ~1) val c' = ty_inst ty_theta c val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts) - in (c', gvars) - end; + in (c', gvars) end; (*--------------------------------------------------------------------------- @@ -102,21 +101,22 @@ * pattern with constructor = name. *---------------------------------------------------------------------------*) fun mk_group (name, T) rows = - let val k = length (binder_types T) - in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) => - fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of - (Const (name', _), args) => - if name = name' then - if length args = k then - let val (args', cnstrts') = split_list (map strip_constraints args) - in - ((((prfx, args' @ rst), rhs) :: in_group, not_in_group), - (default_names names args', map2 append cnstrts cnstrts')) - end - else raise CASE_ERROR - ("Wrong number of arguments for constructor " ^ name, i) - else ((in_group, row :: not_in_group), (names, cnstrts)) - | _ => raise CASE_ERROR ("Not a constructor pattern", i))) + let val k = length (binder_types T) in + fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) => + fn ((in_group, not_in_group), (names, cnstrts)) => + (case strip_comb p of + (Const (name', _), args) => + if name = name' then + if length args = k then + let val (args', cnstrts') = split_list (map strip_constraints args) + in + ((((prfx, args' @ rst), rhs) :: in_group, not_in_group), + (default_names names args', map2 append cnstrts cnstrts')) + end + else raise CASE_ERROR + ("Wrong number of arguments for constructor " ^ name, i) + else ((in_group, row :: not_in_group), (names, cnstrts)) + | _ => raise CASE_ERROR ("Not a constructor pattern", i))) rows (([], []), (replicate k "", replicate k [])) |>> pairself rev end; @@ -158,8 +158,7 @@ constraints = cnstrts, group = in_group'} :: A} end - in part {constrs = constructors, rows = rows, A = []} - end; + in part {constrs = constructors, rows = rows, A = []} end; (*--------------------------------------------------------------------------- * Misc. routines used in mk_case @@ -210,48 +209,46 @@ | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} = mk {path = path, rows = [row]} | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} = - let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows - in case Option.map (apfst head_of) - (find_first (not o is_Free o fst) col0) of + let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows in + (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of NONE => let val rows' = map (fn ((v, _), row) => row ||> pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows); val (pref_patl, tm) = mk {path = rstp, rows = rows'} in (map v_to_pats pref_patl, tm) end - | SOME (Const (cname, cT), i) => (case ty_info tab (cname, cT) of - NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) - | SOME {case_name, constructors} => - let - val pty = body_type cT; - val used' = fold Term.add_free_names rstp used; - val nrows = maps (expand constructors used' pty) rows; - val subproblems = partition ty_match ty_inst type_of used' - constructors pty range_ty nrows; - val constructors' = map #constructor subproblems - val news = map (fn {new_formals, group, ...} => - {path = new_formals @ rstp, rows = group}) subproblems; - val (pat_rect, dtrees) = split_list (map mk news); - val case_functions = map2 - (fn {new_formals, names, constraints, ...} => - fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t => - Abs (if s = "" then name else s, T, - abstract_over (x, t)) |> - fold mk_fun_constrain cnstrts) - (new_formals ~~ names ~~ constraints)) - subproblems dtrees; - val types = map type_of (case_functions @ [u]); - val case_const = Const (case_name, types ---> range_ty) - val tree = list_comb (case_const, case_functions @ [u]) - val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect) - in (pat_rect1, tree) - end) + | SOME (Const (cname, cT), i) => + (case ty_info tab (cname, cT) of + NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) + | SOME {case_name, constructors} => + let + val pty = body_type cT; + val used' = fold Term.add_free_names rstp used; + val nrows = maps (expand constructors used' pty) rows; + val subproblems = partition ty_match ty_inst type_of used' + constructors pty range_ty nrows; + val constructors' = map #constructor subproblems + val news = map (fn {new_formals, group, ...} => + {path = new_formals @ rstp, rows = group}) subproblems; + val (pat_rect, dtrees) = split_list (map mk news); + val case_functions = map2 + (fn {new_formals, names, constraints, ...} => + fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t => + Abs (if s = "" then name else s, T, + abstract_over (x, t)) |> + fold mk_fun_constrain cnstrts) + (new_formals ~~ names ~~ constraints)) + subproblems dtrees; + val types = map type_of (case_functions @ [u]); + val case_const = Const (case_name, types ---> range_ty) + val tree = list_comb (case_const, case_functions @ [u]) + val pat_rect1 = maps mk_pat (constructors ~~ constructors' ~~ pat_rect) + in (pat_rect1, tree) end) | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^ - Syntax.string_of_term ctxt t, i) + Syntax.string_of_term ctxt t, i)) end | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1) - in mk - end; + in mk end; fun case_error s = error ("Error in case expression:\n" ^ s); @@ -289,12 +286,12 @@ val finals = map row_of_pat patts2 val originals = map (row_of_pat o #2) rows val _ = - case subtract (op =) finals originals of - [] => () - | is => - (case config of Error => case_error | Warning => warning | Quiet => fn _ => {}) - ("The following clauses are redundant (covered by preceding clauses):\n" ^ - cat_lines (map (string_of_clause o nth clauses) is)); + (case subtract (op =) finals originals of + [] => () + | is => + (case config of Error => case_error | Warning => warning | Quiet => fn _ => {}) + ("The following clauses are redundant (covered by preceding clauses):\n" ^ + cat_lines (map (string_of_clause o nth clauses) is))); in (case_tm, patts2) end; @@ -308,48 +305,46 @@ (* parse translation *) fun case_tr err tab_of ctxt [t, u] = - let - val thy = ProofContext.theory_of ctxt; - val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy); + let + val thy = ProofContext.theory_of ctxt; + val intern_const_syntax = Consts.intern_syntax (Sign.consts_of thy); - (* replace occurrences of dummy_pattern by distinct variables *) - (* internalize constant names *) - fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used = - let val (t', used') = prep_pat t used - in (c $ t' $ tT, used') end - | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used = - let val x = Name.variant used "x" - in (Free (x, T), x :: used) end - | prep_pat (Const (s, T)) used = - (Const (intern_const_syntax s, T), used) - | prep_pat (v as Free (s, T)) used = - let val s' = Sign.intern_const thy s - in - if Sign.declared_const thy s' then - (Const (s', T), used) - else (v, used) - end - | prep_pat (t $ u) used = - let - val (t', used') = prep_pat t used; - val (u', used'') = prep_pat u used' - in - (t' $ u', used'') - end - | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); - fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) = - let val (l', cnstrts) = strip_constraints l - in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) - end - | dest_case1 t = case_error "dest_case1"; - fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u - | dest_case2 t = [t]; - val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); - val (case_tm, _) = make_case_untyped (tab_of thy) ctxt - (if err then Error else Warning) [] - (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT) - (flat cnstrts) t) cases; - in case_tm end + (* replace occurrences of dummy_pattern by distinct variables *) + (* internalize constant names *) + fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used = + let val (t', used') = prep_pat t used + in (c $ t' $ tT, used') end + | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used = + let val x = Name.variant used "x" + in (Free (x, T), x :: used) end + | prep_pat (Const (s, T)) used = + (Const (intern_const_syntax s, T), used) + | prep_pat (v as Free (s, T)) used = + let val s' = Sign.intern_const thy s in + if Sign.declared_const thy s' then + (Const (s', T), used) + else (v, used) + end + | prep_pat (t $ u) used = + let + val (t', used') = prep_pat t used; + val (u', used'') = prep_pat u used' + in + (t' $ u', used'') + end + | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t); + fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) = + let val (l', cnstrts) = strip_constraints l + in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end + | dest_case1 t = case_error "dest_case1"; + fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u + | dest_case2 t = [t]; + val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u)); + val (case_tm, _) = make_case_untyped (tab_of thy) ctxt + (if err then Error else Warning) [] + (fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT) + (flat cnstrts) t) cases; + in case_tm end | case_tr _ _ _ ts = case_error "case_tr"; @@ -360,7 +355,7 @@ (* destruct one level of pattern matching *) fun gen_dest_case name_of type_of tab d used t = - case apfst name_of (strip_comb t) of + (case apfst name_of (strip_comb t) of (SOME cname, ts as _ :: _) => let val (fs, x) = split_last ts; @@ -375,15 +370,14 @@ in (xs', subst_bounds (rev xs', u)) end; fun is_dependent i t = let val k = length (strip_abs_vars t) - i - in k < 0 orelse exists (fn j => j >= k) - (loose_bnos (strip_abs_body t)) - end; + in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end; fun count_cases (_, _, true) = I | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c); val is_undefined = name_of #> equal (SOME @{const_name undefined}); fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body) - in case ty_info tab cname of + in + (case ty_info tab cname of SOME {constructors, case_name} => if length fs = length constructors then let @@ -400,26 +394,28 @@ val R = type_of t; val dummy = if d then Const (@{const_name dummy_pattern}, R) - else Free (Name.variant used "x", R) + else Free (Name.variant used "x", R); in - SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of - SOME (_, cs) => - if length cs = length constructors then [hd cases] - else filter_out (fn (_, (_, body), _) => is_undefined body) cases - | NONE => case cases' of - [] => cases - | (default, cs) :: _ => - if length cs = 1 then cases - else if length cs = length constructors then - [hd cases, (dummy, ([], default), false)] - else - filter_out (fn (c, _, _) => member op aconv cs c) cases @ - [(dummy, ([], default), false)])) + SOME (x, + map mk_case + (case find_first (is_undefined o fst) cases' of + SOME (_, cs) => + if length cs = length constructors then [hd cases] + else filter_out (fn (_, (_, body), _) => is_undefined body) cases + | NONE => case cases' of + [] => cases + | (default, cs) :: _ => + if length cs = 1 then cases + else if length cs = length constructors then + [hd cases, (dummy, ([], default), false)] + else + filter_out (fn (c, _, _) => member op aconv cs c) cases @ + [(dummy, ([], default), false)])) end handle CASE_ERROR _ => NONE else NONE - | _ => NONE + | _ => NONE) end - | _ => NONE; + | _ => NONE); val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of; val dest_case' = gen_dest_case (try (dest_Const #> fst #> Syntax.unmark_const)) (K dummyT); @@ -428,7 +424,7 @@ (* destruct nested patterns *) fun strip_case'' dest (pat, rhs) = - case dest (Term.add_free_names pat []) rhs of + (case dest (Term.add_free_names pat []) rhs of SOME (exp as Free _, clauses) => if member op aconv (OldTerm.term_frees pat) exp andalso not (exists (fn (_, rhs') => @@ -437,13 +433,13 @@ maps (strip_case'' dest) (map (fn (pat', rhs') => (subst_free [(exp, pat')] pat, rhs')) clauses) else [(pat, rhs)] - | _ => [(pat, rhs)]; + | _ => [(pat, rhs)]); fun gen_strip_case dest t = - case dest [] t of + (case dest [] t of SOME (x, clauses) => SOME (x, maps (strip_case'' dest) clauses) - | NONE => NONE; + | NONE => NONE); val strip_case = gen_strip_case oo dest_case; val strip_case' = gen_strip_case oo dest_case'; @@ -455,8 +451,7 @@ let val thy = ProofContext.theory_of ctxt; fun mk_clause (pat, rhs) = - let val xs = Term.add_frees pat [] - in + let val xs = Term.add_frees pat [] in Syntax.const @{syntax_const "_case1"} $ map_aterms (fn Free p => Syntax.mark_boundT p @@ -466,14 +461,14 @@ (fn x as Free (s, T) => if member (op =) xs (s, T) then Syntax.mark_bound s else x | t => t) rhs - end + end; in - case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of + (case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of SOME (x, clauses) => Syntax.const @{syntax_const "_case_syntax"} $ x $ foldr1 (fn (t, u) => Syntax.const @{syntax_const "_case2"} $ t $ u) (map mk_clause clauses) - | NONE => raise Match + | NONE => raise Match) end; end;