src/HOL/Tools/datatype_package.ML
changeset 14799 a405aadff16c
parent 14471 5688b05b2575
child 14887 4938ce4ef295
     1.1 --- a/src/HOL/Tools/datatype_package.ML	Fri May 21 21:47:07 2004 +0200
     1.2 +++ b/src/HOL/Tools/datatype_package.ML	Fri May 21 21:48:03 2004 +0200
     1.3 @@ -393,6 +393,116 @@
     1.4     fn thy => (simpset_ref_of thy := simpset_of thy addsimprocs [distinct_simproc]; thy)];
     1.5  
     1.6  
     1.7 +(**** translation rules for case ****)
     1.8 +
     1.9 +fun case_tr sg [t, u] =
    1.10 +    let
    1.11 +      fun case_error s name ts = raise TERM ("Error in case expression" ^
    1.12 +        if_none (apsome (curry op ^ " for datatype ") name) "" ^ ":\n" ^ s, ts);
    1.13 +      fun dest_case1 (Const ("_case1", _) $ t $ u) = (case strip_comb t of
    1.14 +            (Const (s, _), ts) => (Sign.intern_const sg s, ts)
    1.15 +          | (Free (s, _), ts) => (Sign.intern_const sg s, ts)
    1.16 +          | _ => case_error "Head is not a constructor" None [t, u], u)
    1.17 +        | dest_case1 t = raise TERM ("dest_case1", [t]);
    1.18 +      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
    1.19 +        | dest_case2 t = [t];
    1.20 +      val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
    1.21 +      val tab = Symtab.dest (get_datatypes_sg sg);
    1.22 +      val (cases', default) = (case split_last cases of
    1.23 +          (cases', (("dummy_pattern", []), t)) => (cases', Some t)
    1.24 +        | _ => (cases, None))
    1.25 +      fun abstr (Free (x, T), body) = Term.absfree (x, T, body)
    1.26 +        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT, body) =
    1.27 +            Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
    1.28 +        | abstr (Const ("Pair", _) $ x $ y, body) =
    1.29 +            Syntax.const "split" $ abstr (x, abstr (y, body))
    1.30 +        | abstr (t, _) = case_error "Illegal pattern" None [t];
    1.31 +    in case find_first (fn (_, {descr, index, ...}) =>
    1.32 +      exists (equal cname o fst) (#3 (snd (nth_elem (index, descr))))) tab of
    1.33 +        None => case_error ("Not a datatype constructor: " ^ cname) None [u]
    1.34 +      | Some (tname, {descr, case_name, index, ...}) =>
    1.35 +        let
    1.36 +          val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
    1.37 +            case_error "Illegal occurrence of '_' dummy pattern" (Some tname) [u] else ();
    1.38 +          val (_, (_, dts, constrs)) = nth_elem (index, descr);
    1.39 +          val sorts = map (rpair [] o dest_DtTFree) dts;
    1.40 +          fun find_case (cases, (s, dt)) =
    1.41 +            (case find_first (equal s o fst o fst) cases' of
    1.42 +               None => (case default of
    1.43 +                   None => case_error ("No clause for constructor " ^ s) (Some tname) [u]
    1.44 +                 | Some t => (cases, list_abs (map (rpair dummyT) (DatatypeProp.make_tnames
    1.45 +                     (map (typ_of_dtyp descr sorts) dt)), t)))
    1.46 +             | Some (c as ((_, vs), t)) =>
    1.47 +                 if length dt <> length vs then
    1.48 +                    case_error ("Wrong number of arguments for constructor " ^ s)
    1.49 +                      (Some tname) vs
    1.50 +                 else (cases \ c, foldr abstr (vs, t)))
    1.51 +          val (cases'', fs) = foldl_map find_case (cases', constrs)
    1.52 +        in case (cases'', length constrs = length cases', default) of
    1.53 +            ([], true, Some _) =>
    1.54 +              case_error "Extra '_' dummy pattern" (Some tname) [u]
    1.55 +          | (_ :: _, _, _) =>
    1.56 +              let val extra = distinct (map (fst o fst) cases'')
    1.57 +              in case extra \\ map fst constrs of
    1.58 +                  [] => case_error ("More than one clause for constructor(s) " ^
    1.59 +                    commas extra) (Some tname) [u]
    1.60 +                | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
    1.61 +                    (Some tname) [u]
    1.62 +              end
    1.63 +          | _ => list_comb (Syntax.const case_name, fs) $ t
    1.64 +        end
    1.65 +    end
    1.66 +  | case_tr sg ts = raise TERM ("case_tr", ts);
    1.67 +
    1.68 +fun case_tr' constrs sg ts =
    1.69 +  if length ts <> length constrs + 1 then raise Match else
    1.70 +  let
    1.71 +    val (fs, x) = split_last ts;
    1.72 +    fun strip_abs 0 t = ([], t)
    1.73 +      | strip_abs i (Abs p) =
    1.74 +        let val (x, u) = Syntax.atomic_abs_tr' p
    1.75 +        in apfst (cons x) (strip_abs (i-1) u) end
    1.76 +      | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
    1.77 +          (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
    1.78 +    fun is_dependent i t =
    1.79 +      let val k = length (strip_abs_vars t) - i
    1.80 +      in k < 0 orelse exists (fn j => j >= k)
    1.81 +        (loose_bnos (strip_abs_body t))
    1.82 +      end;
    1.83 +    val cases = map (fn ((cname, dts), t) =>
    1.84 +      (Sign.cond_extern sg Sign.constK cname,
    1.85 +       strip_abs (length dts) t, is_dependent (length dts) t))
    1.86 +      (constrs ~~ fs);
    1.87 +    fun count_cases (cs, (_, _, true)) = cs
    1.88 +      | count_cases (cs, (cname, (_, body), false)) = (case assoc (cs, body) of
    1.89 +          None => (body, [cname]) :: cs
    1.90 +        | Some cnames => overwrite (cs, (body, cnames @ [cname])));
    1.91 +    val cases' = sort (int_ord o Library.swap o pairself (length o snd))
    1.92 +      (foldl count_cases ([], cases));
    1.93 +    fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
    1.94 +      list_comb (Syntax.const cname, vs) $ body;
    1.95 +  in
    1.96 +    Syntax.const "_case_syntax" $ x $
    1.97 +      foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
    1.98 +        (case cases' of
    1.99 +           [] => cases
   1.100 +         | (default, cnames) :: _ =>
   1.101 +           if length cnames = 1 then cases
   1.102 +           else if length cnames = length constrs then
   1.103 +             [hd cases, ("dummy_pattern", ([], default), false)]
   1.104 +           else
   1.105 +             filter_out (fn (cname, _, _) => cname mem cnames) cases @
   1.106 +             [("dummy_pattern", ([], default), false)]))
   1.107 +  end;
   1.108 +
   1.109 +fun make_case_tr' case_names descr = flat (map
   1.110 +  (fn ((_, (_, _, constrs)), case_name) => map (rpair (case_tr' constrs))
   1.111 +    (NameSpace.accesses' case_name)) (descr ~~ case_names));
   1.112 +
   1.113 +val trfun_setup =
   1.114 +  [Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], [])];
   1.115 +
   1.116 +
   1.117  (* prepare types *)
   1.118  
   1.119  fun read_typ sign ((Ts, sorts), str) =
   1.120 @@ -529,8 +639,7 @@
   1.121  
   1.122        Theory.add_consts_i (map (fn ((name, T), Ts) =>
   1.123          (name, Ts @ [T] ---> freeT, NoSyn))
   1.124 -          (case_names ~~ newTs ~~ case_fn_Ts)) |>
   1.125 -      Theory.add_trrules_i (DatatypeProp.make_case_trrules new_type_names descr);
   1.126 +          (case_names ~~ newTs ~~ case_fn_Ts));
   1.127  
   1.128      val reccomb_names' = map (Sign.intern_const (Theory.sign_of thy2')) reccomb_names;
   1.129      val case_names' = map (Sign.intern_const (Theory.sign_of thy2')) case_names;
   1.130 @@ -600,6 +709,7 @@
   1.131      val split_thms = split ~~ split_asm;
   1.132  
   1.133      val thy12 = thy11 |>
   1.134 +      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), []) |>
   1.135        Theory.add_path (space_implode "_" new_type_names) |>
   1.136        add_rules simps case_thms size_thms rec_thms inject distinct
   1.137                  weak_case_congs Simplifier.cong_add_global |> 
   1.138 @@ -657,6 +767,7 @@
   1.139      val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
   1.140  
   1.141      val thy12 = thy11 |>
   1.142 +      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), []) |>
   1.143        Theory.add_path (space_implode "_" new_type_names) |>
   1.144        add_rules simps case_thms size_thms rec_thms inject distinct
   1.145                  weak_case_congs (Simplifier.change_global_ss (op addcongs)) |> 
   1.146 @@ -765,6 +876,7 @@
   1.147      val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
   1.148  
   1.149      val thy11 = thy10 |>
   1.150 +      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, []) |>
   1.151        add_rules simps case_thms size_thms rec_thms inject distinct
   1.152                  weak_case_congs (Simplifier.change_global_ss (op addcongs)) |> 
   1.153        put_datatypes (foldr Symtab.update (dt_infos, dt_info)) |>
   1.154 @@ -870,7 +982,7 @@
   1.155  
   1.156  (* setup theory *)
   1.157  
   1.158 -val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup;
   1.159 +val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup @ trfun_setup;
   1.160  
   1.161  
   1.162  (* outer syntax *)