src/HOL/Tools/Datatype/datatype_aux.ML
changeset 31775 2b04504fcb69
parent 31737 b3f63611784e
child 32124 954321008785
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/Datatype/datatype_aux.ML	Tue Jun 23 12:09:30 2009 +0200
     1.3 @@ -0,0 +1,381 @@
     1.4 +(*  Title:      HOL/Tools/datatype_aux.ML
     1.5 +    Author:     Stefan Berghofer, TU Muenchen
     1.6 +
     1.7 +Auxiliary functions for defining datatypes.
     1.8 +*)
     1.9 +
    1.10 +signature DATATYPE_COMMON =
    1.11 +sig
    1.12 +  type config
    1.13 +  val default_config : config
    1.14 +  datatype dtyp =
    1.15 +      DtTFree of string
    1.16 +    | DtType of string * (dtyp list)
    1.17 +    | DtRec of int;
    1.18 +  type descr
    1.19 +  type info
    1.20 +end
    1.21 +
    1.22 +signature DATATYPE_AUX =
    1.23 +sig
    1.24 +  include DATATYPE_COMMON
    1.25 +
    1.26 +  val message : config -> string -> unit
    1.27 +  
    1.28 +  val add_path : bool -> string -> theory -> theory
    1.29 +  val parent_path : bool -> theory -> theory
    1.30 +
    1.31 +  val store_thmss_atts : string -> string list -> attribute list list -> thm list list
    1.32 +    -> theory -> thm list list * theory
    1.33 +  val store_thmss : string -> string list -> thm list list -> theory -> thm list list * theory
    1.34 +  val store_thms_atts : string -> string list -> attribute list list -> thm list
    1.35 +    -> theory -> thm list * theory
    1.36 +  val store_thms : string -> string list -> thm list -> theory -> thm list * theory
    1.37 +
    1.38 +  val split_conj_thm : thm -> thm list
    1.39 +  val mk_conj : term list -> term
    1.40 +  val mk_disj : term list -> term
    1.41 +
    1.42 +  val app_bnds : term -> int -> term
    1.43 +
    1.44 +  val cong_tac : int -> tactic
    1.45 +  val indtac : thm -> string list -> int -> tactic
    1.46 +  val exh_tac : (string -> thm) -> int -> tactic
    1.47 +
    1.48 +  datatype simproc_dist = FewConstrs of thm list
    1.49 +                        | ManyConstrs of thm * simpset;
    1.50 +
    1.51 +
    1.52 +  exception Datatype
    1.53 +  exception Datatype_Empty of string
    1.54 +  val name_of_typ : typ -> string
    1.55 +  val dtyp_of_typ : (string * string list) list -> typ -> dtyp
    1.56 +  val mk_Free : string -> typ -> int -> term
    1.57 +  val is_rec_type : dtyp -> bool
    1.58 +  val typ_of_dtyp : descr -> (string * sort) list -> dtyp -> typ
    1.59 +  val dest_DtTFree : dtyp -> string
    1.60 +  val dest_DtRec : dtyp -> int
    1.61 +  val strip_dtyp : dtyp -> dtyp list * dtyp
    1.62 +  val body_index : dtyp -> int
    1.63 +  val mk_fun_dtyp : dtyp list -> dtyp -> dtyp
    1.64 +  val get_nonrec_types : descr -> (string * sort) list -> typ list
    1.65 +  val get_branching_types : descr -> (string * sort) list -> typ list
    1.66 +  val get_arities : descr -> int list
    1.67 +  val get_rec_types : descr -> (string * sort) list -> typ list
    1.68 +  val interpret_construction : descr -> (string * sort) list
    1.69 +    -> { atyp: typ -> 'a, dtyp: typ list -> int * bool -> string * typ list -> 'a }
    1.70 +    -> ((string * Term.typ list) * (string * 'a list) list) list
    1.71 +  val check_nonempty : descr list -> unit
    1.72 +  val unfold_datatypes : 
    1.73 +    theory -> descr -> (string * sort) list -> info Symtab.table ->
    1.74 +      descr -> int -> descr list * int
    1.75 +end;
    1.76 +
    1.77 +structure DatatypeAux : DATATYPE_AUX =
    1.78 +struct
    1.79 +
    1.80 +(* datatype option flags *)
    1.81 +
    1.82 +type config = { strict: bool, flat_names: bool, quiet: bool };
    1.83 +val default_config : config =
    1.84 +  { strict = true, flat_names = false, quiet = false };
    1.85 +fun message ({ quiet, ...} : config) s =
    1.86 +  if quiet then () else writeln s;
    1.87 +
    1.88 +fun add_path flat_names s = if flat_names then I else Sign.add_path s;
    1.89 +fun parent_path flat_names = if flat_names then I else Sign.parent_path;
    1.90 +
    1.91 +
    1.92 +(* store theorems in theory *)
    1.93 +
    1.94 +fun store_thmss_atts label tnames attss thmss =
    1.95 +  fold_map (fn ((tname, atts), thms) =>
    1.96 +    Sign.add_path tname
    1.97 +    #> PureThy.add_thmss [((Binding.name label, thms), atts)]
    1.98 +    #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss)
    1.99 +  ##> Theory.checkpoint;
   1.100 +
   1.101 +fun store_thmss label tnames = store_thmss_atts label tnames (replicate (length tnames) []);
   1.102 +
   1.103 +fun store_thms_atts label tnames attss thmss =
   1.104 +  fold_map (fn ((tname, atts), thms) =>
   1.105 +    Sign.add_path tname
   1.106 +    #> PureThy.add_thms [((Binding.name label, thms), atts)]
   1.107 +    #-> (fn thm::_ => Sign.parent_path #> pair thm)) (tnames ~~ attss ~~ thmss)
   1.108 +  ##> Theory.checkpoint;
   1.109 +
   1.110 +fun store_thms label tnames = store_thms_atts label tnames (replicate (length tnames) []);
   1.111 +
   1.112 +
   1.113 +(* split theorem thm_1 & ... & thm_n into n theorems *)
   1.114 +
   1.115 +fun split_conj_thm th =
   1.116 +  ((th RS conjunct1)::(split_conj_thm (th RS conjunct2))) handle THM _ => [th];
   1.117 +
   1.118 +val mk_conj = foldr1 (HOLogic.mk_binop "op &");
   1.119 +val mk_disj = foldr1 (HOLogic.mk_binop "op |");
   1.120 +
   1.121 +fun app_bnds t i = list_comb (t, map Bound (i - 1 downto 0));
   1.122 +
   1.123 +
   1.124 +fun cong_tac i st = (case Logic.strip_assums_concl
   1.125 +  (List.nth (prems_of st, i - 1)) of
   1.126 +    _ $ (_ $ (f $ x) $ (g $ y)) =>
   1.127 +      let
   1.128 +        val cong' = Thm.lift_rule (Thm.cprem_of st i) cong;
   1.129 +        val _ $ (_ $ (f' $ x') $ (g' $ y')) =
   1.130 +          Logic.strip_assums_concl (prop_of cong');
   1.131 +        val insts = map (pairself (cterm_of (Thm.theory_of_thm st)) o
   1.132 +          apsnd (curry list_abs (Logic.strip_params (concl_of cong'))) o
   1.133 +            apfst head_of) [(f', f), (g', g), (x', x), (y', y)]
   1.134 +      in compose_tac (false, cterm_instantiate insts cong', 2) i st
   1.135 +        handle THM _ => no_tac st
   1.136 +      end
   1.137 +  | _ => no_tac st);
   1.138 +
   1.139 +(* instantiate induction rule *)
   1.140 +
   1.141 +fun indtac indrule indnames i st =
   1.142 +  let
   1.143 +    val ts = HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule));
   1.144 +    val ts' = HOLogic.dest_conj (HOLogic.dest_Trueprop
   1.145 +      (Logic.strip_imp_concl (List.nth (prems_of st, i - 1))));
   1.146 +    val getP = if can HOLogic.dest_imp (hd ts) then
   1.147 +      (apfst SOME) o HOLogic.dest_imp else pair NONE;
   1.148 +    val flt = if null indnames then I else
   1.149 +      filter (fn Free (s, _) => s mem indnames | _ => false);
   1.150 +    fun abstr (t1, t2) = (case t1 of
   1.151 +        NONE => (case flt (OldTerm.term_frees t2) of
   1.152 +            [Free (s, T)] => SOME (absfree (s, T, t2))
   1.153 +          | _ => NONE)
   1.154 +      | SOME (_ $ t') => SOME (Abs ("x", fastype_of t', abstract_over (t', t2))))
   1.155 +    val cert = cterm_of (Thm.theory_of_thm st);
   1.156 +    val insts = List.mapPartial (fn (t, u) => case abstr (getP u) of
   1.157 +        NONE => NONE
   1.158 +      | SOME u' => SOME (t |> getP |> snd |> head_of |> cert, cert u')) (ts ~~ ts');
   1.159 +    val indrule' = cterm_instantiate insts indrule
   1.160 +  in
   1.161 +    rtac indrule' i st
   1.162 +  end;
   1.163 +
   1.164 +(* perform exhaustive case analysis on last parameter of subgoal i *)
   1.165 +
   1.166 +fun exh_tac exh_thm_of i state =
   1.167 +  let
   1.168 +    val thy = Thm.theory_of_thm state;
   1.169 +    val prem = nth (prems_of state) (i - 1);
   1.170 +    val params = Logic.strip_params prem;
   1.171 +    val (_, Type (tname, _)) = hd (rev params);
   1.172 +    val exhaustion = Thm.lift_rule (Thm.cprem_of state i) (exh_thm_of tname);
   1.173 +    val prem' = hd (prems_of exhaustion);
   1.174 +    val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem'));
   1.175 +    val exhaustion' = cterm_instantiate [(cterm_of thy (head_of lhs),
   1.176 +      cterm_of thy (List.foldr (fn ((_, T), t) => Abs ("z", T, t))
   1.177 +        (Bound 0) params))] exhaustion
   1.178 +  in compose_tac (false, exhaustion', nprems_of exhaustion) i state
   1.179 +  end;
   1.180 +
   1.181 +(* handling of distinctness theorems *)
   1.182 +
   1.183 +datatype simproc_dist = FewConstrs of thm list
   1.184 +                      | ManyConstrs of thm * simpset;
   1.185 +
   1.186 +(********************** Internal description of datatypes *********************)
   1.187 +
   1.188 +datatype dtyp =
   1.189 +    DtTFree of string
   1.190 +  | DtType of string * (dtyp list)
   1.191 +  | DtRec of int;
   1.192 +
   1.193 +(* information about datatypes *)
   1.194 +
   1.195 +(* index, datatype name, type arguments, constructor name, types of constructor's arguments *)
   1.196 +type descr = (int * (string * dtyp list * (string * dtyp list) list)) list;
   1.197 +
   1.198 +type info =
   1.199 +  {index : int,
   1.200 +   alt_names : string list option,
   1.201 +   descr : descr,
   1.202 +   sorts : (string * sort) list,
   1.203 +   rec_names : string list,
   1.204 +   rec_rewrites : thm list,
   1.205 +   case_name : string,
   1.206 +   case_rewrites : thm list,
   1.207 +   induction : thm,
   1.208 +   exhaustion : thm,
   1.209 +   distinct : simproc_dist,
   1.210 +   inject : thm list,
   1.211 +   nchotomy : thm,
   1.212 +   case_cong : thm,
   1.213 +   weak_case_cong : thm};
   1.214 +
   1.215 +fun mk_Free s T i = Free (s ^ (string_of_int i), T);
   1.216 +
   1.217 +fun subst_DtTFree _ substs (T as (DtTFree name)) =
   1.218 +      AList.lookup (op =) substs name |> the_default T
   1.219 +  | subst_DtTFree i substs (DtType (name, ts)) =
   1.220 +      DtType (name, map (subst_DtTFree i substs) ts)
   1.221 +  | subst_DtTFree i _ (DtRec j) = DtRec (i + j);
   1.222 +
   1.223 +exception Datatype;
   1.224 +exception Datatype_Empty of string;
   1.225 +
   1.226 +fun dest_DtTFree (DtTFree a) = a
   1.227 +  | dest_DtTFree _ = raise Datatype;
   1.228 +
   1.229 +fun dest_DtRec (DtRec i) = i
   1.230 +  | dest_DtRec _ = raise Datatype;
   1.231 +
   1.232 +fun is_rec_type (DtType (_, dts)) = exists is_rec_type dts
   1.233 +  | is_rec_type (DtRec _) = true
   1.234 +  | is_rec_type _ = false;
   1.235 +
   1.236 +fun strip_dtyp (DtType ("fun", [T, U])) = apfst (cons T) (strip_dtyp U)
   1.237 +  | strip_dtyp T = ([], T);
   1.238 +
   1.239 +val body_index = dest_DtRec o snd o strip_dtyp;
   1.240 +
   1.241 +fun mk_fun_dtyp [] U = U
   1.242 +  | mk_fun_dtyp (T :: Ts) U = DtType ("fun", [T, mk_fun_dtyp Ts U]);
   1.243 +
   1.244 +fun name_of_typ (Type (s, Ts)) =
   1.245 +      let val s' = Long_Name.base_name s
   1.246 +      in space_implode "_" (List.filter (not o equal "") (map name_of_typ Ts) @
   1.247 +        [if Syntax.is_identifier s' then s' else "x"])
   1.248 +      end
   1.249 +  | name_of_typ _ = "";
   1.250 +
   1.251 +fun dtyp_of_typ _ (TFree (n, _)) = DtTFree n
   1.252 +  | dtyp_of_typ _ (TVar _) = error "Illegal schematic type variable(s)"
   1.253 +  | dtyp_of_typ new_dts (Type (tname, Ts)) =
   1.254 +      (case AList.lookup (op =) new_dts tname of
   1.255 +         NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
   1.256 +       | SOME vs => if map (try (fst o dest_TFree)) Ts = map SOME vs then
   1.257 +             DtRec (find_index (curry op = tname o fst) new_dts)
   1.258 +           else error ("Illegal occurrence of recursive type " ^ tname));
   1.259 +
   1.260 +fun typ_of_dtyp descr sorts (DtTFree a) = TFree (a, (the o AList.lookup (op =) sorts) a)
   1.261 +  | typ_of_dtyp descr sorts (DtRec i) =
   1.262 +      let val (s, ds, _) = (the o AList.lookup (op =) descr) i
   1.263 +      in Type (s, map (typ_of_dtyp descr sorts) ds) end
   1.264 +  | typ_of_dtyp descr sorts (DtType (s, ds)) =
   1.265 +      Type (s, map (typ_of_dtyp descr sorts) ds);
   1.266 +
   1.267 +(* find all non-recursive types in datatype description *)
   1.268 +
   1.269 +fun get_nonrec_types descr sorts =
   1.270 +  map (typ_of_dtyp descr sorts) (Library.foldl (fn (Ts, (_, (_, _, constrs))) =>
   1.271 +    Library.foldl (fn (Ts', (_, cargs)) =>
   1.272 +      filter_out is_rec_type cargs union Ts') (Ts, constrs)) ([], descr));
   1.273 +
   1.274 +(* get all recursive types in datatype description *)
   1.275 +
   1.276 +fun get_rec_types descr sorts = map (fn (_ , (s, ds, _)) =>
   1.277 +  Type (s, map (typ_of_dtyp descr sorts) ds)) descr;
   1.278 +
   1.279 +(* get all branching types *)
   1.280 +
   1.281 +fun get_branching_types descr sorts =
   1.282 +  map (typ_of_dtyp descr sorts) (fold (fn (_, (_, _, constrs)) =>
   1.283 +    fold (fn (_, cargs) => fold (strip_dtyp #> fst #> fold (insert op =)) cargs)
   1.284 +      constrs) descr []);
   1.285 +
   1.286 +fun get_arities descr = fold (fn (_, (_, _, constrs)) =>
   1.287 +  fold (fn (_, cargs) => fold (insert op =) (map (length o fst o strip_dtyp)
   1.288 +    (List.filter is_rec_type cargs))) constrs) descr [];
   1.289 +
   1.290 +(* interpret construction of datatype *)
   1.291 +
   1.292 +fun interpret_construction descr vs { atyp, dtyp } =
   1.293 +  let
   1.294 +    val typ_of_dtyp = typ_of_dtyp descr vs;
   1.295 +    fun interpT dT = case strip_dtyp dT
   1.296 +     of (dTs, DtRec l) =>
   1.297 +          let
   1.298 +            val (tyco, dTs', _) = (the o AList.lookup (op =) descr) l;
   1.299 +            val Ts = map typ_of_dtyp dTs;
   1.300 +            val Ts' = map typ_of_dtyp dTs';
   1.301 +            val is_proper = forall (can dest_TFree) Ts';
   1.302 +          in dtyp Ts (l, is_proper) (tyco, Ts') end
   1.303 +      | _ => atyp (typ_of_dtyp dT);
   1.304 +    fun interpC (c, dTs) = (c, map interpT dTs);
   1.305 +    fun interpD (_, (tyco, dTs, cs)) = ((tyco, map typ_of_dtyp dTs), map interpC cs);
   1.306 +  in map interpD descr end;
   1.307 +
   1.308 +(* nonemptiness check for datatypes *)
   1.309 +
   1.310 +fun check_nonempty descr =
   1.311 +  let
   1.312 +    val descr' = List.concat descr;
   1.313 +    fun is_nonempty_dt is i =
   1.314 +      let
   1.315 +        val (_, _, constrs) = (the o AList.lookup (op =) descr') i;
   1.316 +        fun arg_nonempty (_, DtRec i) = if i mem is then false
   1.317 +              else is_nonempty_dt (i::is) i
   1.318 +          | arg_nonempty _ = true;
   1.319 +      in exists ((forall (arg_nonempty o strip_dtyp)) o snd) constrs
   1.320 +      end
   1.321 +  in assert_all (fn (i, _) => is_nonempty_dt [i] i) (hd descr)
   1.322 +    (fn (_, (s, _, _)) => raise Datatype_Empty s)
   1.323 +  end;
   1.324 +
   1.325 +(* unfold a list of mutually recursive datatype specifications *)
   1.326 +(* all types of the form DtType (dt_name, [..., DtRec _, ...]) *)
   1.327 +(* need to be unfolded                                         *)
   1.328 +
   1.329 +fun unfold_datatypes sign orig_descr sorts (dt_info : info Symtab.table) descr i =
   1.330 +  let
   1.331 +    fun typ_error T msg = error ("Non-admissible type expression\n" ^
   1.332 +      Syntax.string_of_typ_global sign (typ_of_dtyp (orig_descr @ descr) sorts T) ^ "\n" ^ msg);
   1.333 +
   1.334 +    fun get_dt_descr T i tname dts =
   1.335 +      (case Symtab.lookup dt_info tname of
   1.336 +         NONE => typ_error T (tname ^ " is not a datatype - can't use it in\
   1.337 +           \ nested recursion")
   1.338 +       | (SOME {index, descr, ...}) =>
   1.339 +           let val (_, vars, _) = (the o AList.lookup (op =) descr) index;
   1.340 +               val subst = ((map dest_DtTFree vars) ~~ dts) handle Library.UnequalLengths =>
   1.341 +                 typ_error T ("Type constructor " ^ tname ^ " used with wrong\
   1.342 +                  \ number of arguments")
   1.343 +           in (i + index, map (fn (j, (tn, args, cs)) => (i + j,
   1.344 +             (tn, map (subst_DtTFree i subst) args,
   1.345 +              map (apsnd (map (subst_DtTFree i subst))) cs))) descr)
   1.346 +           end);
   1.347 +
   1.348 +    (* unfold a single constructor argument *)
   1.349 +
   1.350 +    fun unfold_arg ((i, Ts, descrs), T) =
   1.351 +      if is_rec_type T then
   1.352 +        let val (Us, U) = strip_dtyp T
   1.353 +        in if exists is_rec_type Us then
   1.354 +            typ_error T "Non-strictly positive recursive occurrence of type"
   1.355 +          else (case U of
   1.356 +              DtType (tname, dts) =>  
   1.357 +                let
   1.358 +                  val (index, descr) = get_dt_descr T i tname dts;
   1.359 +                  val (descr', i') = unfold_datatypes sign orig_descr sorts
   1.360 +                    dt_info descr (i + length descr)
   1.361 +                in (i', Ts @ [mk_fun_dtyp Us (DtRec index)], descrs @ descr') end
   1.362 +            | _ => (i, Ts @ [T], descrs))
   1.363 +        end
   1.364 +      else (i, Ts @ [T], descrs);
   1.365 +
   1.366 +    (* unfold a constructor *)
   1.367 +
   1.368 +    fun unfold_constr ((i, constrs, descrs), (cname, cargs)) =
   1.369 +      let val (i', cargs', descrs') = Library.foldl unfold_arg ((i, [], descrs), cargs)
   1.370 +      in (i', constrs @ [(cname, cargs')], descrs') end;
   1.371 +
   1.372 +    (* unfold a single datatype *)
   1.373 +
   1.374 +    fun unfold_datatype ((i, dtypes, descrs), (j, (tname, tvars, constrs))) =
   1.375 +      let val (i', constrs', descrs') =
   1.376 +        Library.foldl unfold_constr ((i, [], descrs), constrs)
   1.377 +      in (i', dtypes @ [(j, (tname, tvars, constrs'))], descrs')
   1.378 +      end;
   1.379 +
   1.380 +    val (i', descr', descrs) = Library.foldl unfold_datatype ((i, [],[]), descr);
   1.381 +
   1.382 +  in (descr' :: descrs, i') end;
   1.383 +
   1.384 +end;