src/HOL/Tools/datatype_codegen.ML
changeset 20426 9ffea7a8b31c
parent 20389 8b6ecb22ef35
child 20435 d2a30fed7596
     1.1 --- a/src/HOL/Tools/datatype_codegen.ML	Tue Aug 29 14:31:12 2006 +0200
     1.2 +++ b/src/HOL/Tools/datatype_codegen.ML	Tue Aug 29 14:31:13 2006 +0200
     1.3 @@ -2,11 +2,12 @@
     1.4      ID:         $Id$
     1.5      Author:     Stefan Berghofer & Florian Haftmann, TU Muenchen
     1.6  
     1.7 -Code generator for inductive datatypes.
     1.8 +Code generator for inductive datatypes and type copies ("code types").
     1.9  *)
    1.10  
    1.11  signature DATATYPE_CODEGEN =
    1.12  sig
    1.13 +  val get_eq: theory -> string -> thm list
    1.14    val get_datatype_spec_thms: theory -> string
    1.15      -> (((string * sort) list * (string * typ list) list) * tactic) option
    1.16    val datatype_tac: theory -> string -> tactic
    1.17 @@ -14,15 +15,19 @@
    1.18      -> ((string * typ) list * ((term * typ) * (term * term) list)) option
    1.19    val add_datatype_case_const: string -> theory -> theory
    1.20    val add_datatype_case_defs: string -> theory -> theory
    1.21 -  val datatypes_dependency: theory -> string list list
    1.22 -  val add_hook_bootstrap: DatatypeHooks.hook -> theory -> theory
    1.23 -  val get_datatype_mut_specs: theory -> string list
    1.24 -    -> ((string * sort) list * (string * (string * typ list) list) list)
    1.25 -  val get_datatype_arities: theory -> string list -> sort
    1.26 +
    1.27 +  type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
    1.28 +    -> theory -> theory
    1.29 +  val codetypes_dependency: theory -> (string * bool) list list
    1.30 +  val add_codetypes_hook_bootstrap: hook -> theory -> theory
    1.31 +  val the_codetypes_mut_specs: theory -> (string * bool) list
    1.32 +    -> ((string * sort) list * (string * (bool * (string * typ list) list)) list)
    1.33 +  val get_codetypes_arities: theory -> (string * bool) list -> sort
    1.34      -> (string * (((string * sort list) * sort) * term list)) list option
    1.35 -  val prove_arities: (thm list -> tactic) -> string list -> sort
    1.36 +  val prove_codetypes_arities: (thm list -> tactic) -> (string * bool) list -> sort
    1.37      -> (theory -> ((string * sort list) * sort) list -> (string * term list) list
    1.38      -> ((bstring * attribute list) * term) list) -> theory -> theory
    1.39 +
    1.40    val setup: theory -> theory
    1.41  end;
    1.42  
    1.43 @@ -313,87 +318,7 @@
    1.44    | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
    1.45  
    1.46  
    1.47 -(** code 2nd generation **)
    1.48 -
    1.49 -fun datatypes_dependency thy =
    1.50 -  let
    1.51 -    val dtnames = DatatypePackage.get_datatypes thy;
    1.52 -    fun add_node (dtname, _) =
    1.53 -      let
    1.54 -        fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
    1.55 -          | add_tycos _ = I;
    1.56 -        val deps = (filter (Symtab.defined dtnames) o maps (fn ty =>
    1.57 -          add_tycos ty [])
    1.58 -            o maps snd o snd o the o DatatypePackage.get_datatype_spec thy) dtname
    1.59 -      in
    1.60 -        Graph.default_node (dtname, ())
    1.61 -        #> fold (fn dtname' =>
    1.62 -             Graph.default_node (dtname', ())
    1.63 -             #> Graph.add_edge (dtname', dtname)
    1.64 -           ) deps
    1.65 -      end
    1.66 -  in
    1.67 -    Graph.empty
    1.68 -    |> Symtab.fold add_node dtnames
    1.69 -    |> Graph.strong_conn
    1.70 -  end;
    1.71 -
    1.72 -fun add_hook_bootstrap hook thy =
    1.73 -  thy
    1.74 -  |> fold hook (datatypes_dependency thy)
    1.75 -  |> DatatypeHooks.add hook;
    1.76 -
    1.77 -fun get_datatype_mut_specs thy (tycos as tyco :: _) =
    1.78 -  let
    1.79 -    val tycos' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
    1.80 -    val _ = if gen_subset (op =) (tycos, tycos') then () else
    1.81 -      error ("datatype constructors are not mutually recursive: " ^ (commas o map quote) tycos);
    1.82 -    val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
    1.83 -  in (vs, tycos ~~ css) end;
    1.84 -
    1.85 -fun get_datatype_arities thy tycos sort =
    1.86 -  let
    1.87 -    val algebra = Sign.classes_of thy;
    1.88 -    val (vs_proto, css_proto) = get_datatype_mut_specs thy tycos;
    1.89 -    val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
    1.90 -    fun inst_type tyco (c, tys) =
    1.91 -      let
    1.92 -        val tys' = (map o map_atyps)
    1.93 -          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) vs v))) tys
    1.94 -      in (c, tys') end;
    1.95 -    val css = map (fn (tyco, cs) => (tyco, (map (inst_type tyco) cs))) css_proto;
    1.96 -    fun mk_arity tyco =
    1.97 -      ((tyco, map snd vs), sort);
    1.98 -    fun typ_of_sort ty =
    1.99 -      let
   1.100 -        val arities = map (fn (tyco, _) => ((tyco, map snd vs), sort)) css;
   1.101 -      in ClassPackage.assume_arities_of_sort thy arities (ty, sort) end;
   1.102 -    fun mk_cons tyco (c, tys) =
   1.103 -      let
   1.104 -        val ts = Name.names Name.context "a" tys;
   1.105 -        val ty = tys ---> Type (tyco, map TFree vs);
   1.106 -      in list_comb (Const (c, ty), map Free ts) end;
   1.107 -  in if forall (fn (_, cs) => forall (fn (_, tys) => forall typ_of_sort tys) cs) css
   1.108 -    then SOME (
   1.109 -      map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
   1.110 -    ) else NONE
   1.111 -  end;
   1.112 -
   1.113 -fun prove_arities tac tycos sort f thy =
   1.114 -  case get_datatype_arities thy tycos sort
   1.115 -   of NONE => thy
   1.116 -    | SOME insts => let
   1.117 -        fun proven ((tyco, asorts), sort) =
   1.118 -          Sorts.of_sort (Sign.classes_of thy)
   1.119 -            (Type (tyco, map TFree (Name.names Name.context "'a" asorts)), sort);
   1.120 -        val (arities, css) = (split_list o map_filter
   1.121 -          (fn (tyco, (arity, cs)) => if proven arity
   1.122 -            then NONE else SOME (arity, (tyco, cs)))) insts;
   1.123 -      in
   1.124 -        thy
   1.125 -        |> K ((not o null) arities) ? ClassPackage.prove_instance_arity tac
   1.126 -             arities ("", []) (f thy arities css)
   1.127 -      end;
   1.128 +(** datatypes for code 2nd generation **)
   1.129  
   1.130  fun dtyp_of_case_const thy c =
   1.131    get_first (fn (dtco, { case_name, ... }) => if case_name = c then SOME dtco else NONE)
   1.132 @@ -423,6 +348,57 @@
   1.133            | _ => NONE)
   1.134      | _ => NONE;
   1.135  
   1.136 +fun mk_distinct cos =
   1.137 +  let
   1.138 +    fun sym_product [] = []
   1.139 +      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
   1.140 +    fun mk_co_args (co, tys) ctxt =
   1.141 +      let
   1.142 +        val names = Name.invents ctxt "a" (length tys);
   1.143 +        val ctxt' = fold Name.declare names ctxt;
   1.144 +        val vs = map2 (curry Free) names tys;
   1.145 +      in (vs, ctxt) end;
   1.146 +    fun mk_dist ((co1, tys1), (co2, tys2)) =
   1.147 +      let
   1.148 +        val ((xs1, xs2), _) = Name.context
   1.149 +          |> mk_co_args (co1, tys1)
   1.150 +          ||>> mk_co_args (co2, tys2);
   1.151 +        val prem = HOLogic.mk_eq
   1.152 +          (list_comb (co1, xs1), list_comb (co2, xs2));
   1.153 +        val t = HOLogic.mk_not prem;
   1.154 +      in HOLogic.mk_Trueprop t end;
   1.155 +  in map mk_dist (sym_product cos) end;
   1.156 +
   1.157 +local
   1.158 +  val not_false_true = iffD2 OF [nth (thms "HOL.simp_thms") 7, TrueI];
   1.159 +in fun get_eq thy dtco =
   1.160 +  let
   1.161 +    val SOME (vs, cs) = DatatypePackage.get_datatype_spec thy dtco;
   1.162 +    fun mk_triv_inject co =
   1.163 +      let
   1.164 +        val ct' = Thm.cterm_of thy
   1.165 +          (Const (co, Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vs)))
   1.166 +        val cty' = Thm.ctyp_of_term ct';
   1.167 +        val refl = Thm.prop_of HOL.refl;
   1.168 +        val SOME (ct, cty) = fold_aterms (fn Var (v, ty) =>
   1.169 +          (K o SOME) (Thm.cterm_of thy (Var (v, Thm.typ_of cty')), Thm.ctyp_of thy ty) | _ => I)
   1.170 +          refl NONE;
   1.171 +      in eqTrueI OF [Thm.instantiate ([(cty, cty')], [(ct, ct')]) HOL.refl] end;
   1.172 +    val inject1 = map_filter (fn (co, []) => SOME (mk_triv_inject co) | _ => NONE) cs
   1.173 +    val inject2 = (#inject o DatatypePackage.the_datatype thy) dtco;
   1.174 +    val ctxt = Context.init_proof thy;
   1.175 +    val simpset = Simplifier.context ctxt
   1.176 +      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
   1.177 +    val cos = map (fn (co, tys) =>
   1.178 +        (Const (co, tys ---> Type (dtco, map TFree vs)), tys)) cs;
   1.179 +    val tac = ALLGOALS (simp_tac simpset)
   1.180 +      THEN ALLGOALS (ProofContext.fact_tac [not_false_true, TrueI]);
   1.181 +    val distinct =
   1.182 +      mk_distinct cos
   1.183 +      |> map (fn t => Goal.prove_global thy [] [] t (K tac))
   1.184 +  in inject1 @ inject2 @ distinct end;
   1.185 +end (*local*);
   1.186 +
   1.187  fun datatype_tac thy dtco =
   1.188    let
   1.189      val ctxt = Context.init_proof thy;
   1.190 @@ -458,6 +434,126 @@
   1.191      fold CodegenTheorems.add_fun case_rewrites thy
   1.192    end;
   1.193  
   1.194 +
   1.195 +(** codetypes for code 2nd generation **)
   1.196 +
   1.197 +type hook = (string * (bool * ((string * sort) list * (string * typ list) list))) list
   1.198 +  -> theory -> theory;
   1.199 +
   1.200 +fun codetypes_dependency thy =
   1.201 +  let
   1.202 +    val names =
   1.203 +      map (rpair true) (Symtab.keys (DatatypePackage.get_datatypes thy))
   1.204 +        @ map (rpair false) (TypecopyPackage.get_typecopies thy);
   1.205 +    fun add_node (name, is_dt) =
   1.206 +      let
   1.207 +        fun add_tycos (Type (tyco, tys)) = insert (op =) tyco #> fold add_tycos tys
   1.208 +          | add_tycos _ = I;
   1.209 +        val tys = if is_dt then
   1.210 +            (maps snd o snd o the o DatatypePackage.get_datatype_spec thy) name
   1.211 +          else
   1.212 +            [(#typ o the o TypecopyPackage.get_typecopy_info thy) name]
   1.213 +        val deps = (filter (AList.defined (op =) names) o maps (fn ty =>
   1.214 +          add_tycos ty [])) tys;
   1.215 +      in
   1.216 +        Graph.default_node (name, ())
   1.217 +        #> fold (fn name' =>
   1.218 +             Graph.default_node (name', ())
   1.219 +             #> Graph.add_edge (name', name)
   1.220 +           ) deps
   1.221 +      end
   1.222 +  in
   1.223 +    Graph.empty
   1.224 +    |> fold add_node names
   1.225 +    |> Graph.strong_conn
   1.226 +    |> map (AList.make (the o AList.lookup (op =) names))
   1.227 +  end;
   1.228 +
   1.229 +fun mk_typecopy_spec ({ vs, constr, typ, ... } : TypecopyPackage.info) =
   1.230 +  (vs, [(constr, [typ])]);
   1.231 +
   1.232 +fun get_spec thy (dtco, true) =
   1.233 +      (the o DatatypePackage.get_datatype_spec thy) dtco
   1.234 +  | get_spec thy (tyco, false) =
   1.235 +      (mk_typecopy_spec o the o TypecopyPackage.get_typecopy_info thy) tyco;
   1.236 +
   1.237 +fun add_spec thy (tyco, is_dt) =
   1.238 +  (tyco, (is_dt, get_spec thy (tyco, is_dt)));
   1.239 +
   1.240 +fun add_codetypes_hook_bootstrap hook thy =
   1.241 +  let
   1.242 +    fun datatype_hook dtcos thy =
   1.243 +      hook (map (add_spec thy) (map (rpair true) dtcos)) thy;
   1.244 +    fun typecopy_hook ((tyco, info )) thy =
   1.245 +      hook ([(tyco, (false, mk_typecopy_spec info))]) thy;
   1.246 +  in
   1.247 +    thy
   1.248 +    |> fold hook ((map o map) (add_spec thy) (codetypes_dependency thy))
   1.249 +    |> DatatypeHooks.add datatype_hook
   1.250 +    |> TypecopyPackage.add_hook typecopy_hook
   1.251 +  end;
   1.252 +
   1.253 +fun the_codetypes_mut_specs thy ([(tyco, is_dt)]) =
   1.254 +      let
   1.255 +        val (vs, cs) = get_spec thy (tyco, is_dt)
   1.256 +      in (vs, [(tyco, (is_dt, cs))]) end
   1.257 +  | the_codetypes_mut_specs thy (tycos' as (tyco, true) :: _) =
   1.258 +      let
   1.259 +        val tycos = map fst tycos';
   1.260 +        val tycos'' = (map (#1 o snd) o #descr o DatatypePackage.the_datatype thy) tyco;
   1.261 +        val _ = if gen_subset (op =) (tycos, tycos'') then () else
   1.262 +          error ("datatype constructors are not mutually recursive: " ^ (commas o map quote) tycos);
   1.263 +        val (vs::_, css) = split_list (map (the o DatatypePackage.get_datatype_spec thy) tycos);
   1.264 +      in (vs, map2 (fn (tyco, is_dt) => fn cs => (tyco, (is_dt, cs))) tycos' css) end;
   1.265 +
   1.266 +fun get_codetypes_arities thy tycos sort =
   1.267 +  let
   1.268 +    val algebra = Sign.classes_of thy;
   1.269 +    val (vs_proto, css_proto) = the_codetypes_mut_specs thy tycos;
   1.270 +    val vs = map (fn (v, vsort) => (v, Sorts.inter_sort algebra (vsort, sort))) vs_proto;
   1.271 +    fun inst_type tyco (c, tys) =
   1.272 +      let
   1.273 +        val tys' = (map o map_atyps)
   1.274 +          (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) vs v))) tys
   1.275 +      in (c, tys') end;
   1.276 +    val css = map (fn (tyco, (_, cs)) => (tyco, (map (inst_type tyco) cs))) css_proto;
   1.277 +    fun mk_arity tyco =
   1.278 +      ((tyco, map snd vs), sort);
   1.279 +    fun typ_of_sort ty =
   1.280 +      let
   1.281 +        val arities = map (fn (tyco, _) => ((tyco, map snd vs), sort)) css;
   1.282 +      in ClassPackage.assume_arities_of_sort thy arities (ty, sort) end;
   1.283 +    fun mk_cons tyco (c, tys) =
   1.284 +      let
   1.285 +        val ts = Name.names Name.context "a" tys;
   1.286 +        val ty = tys ---> Type (tyco, map TFree vs);
   1.287 +      in list_comb (Const (c, ty), map Free ts) end;
   1.288 +  in if forall (fn (_, cs) => forall (fn (_, tys) => forall typ_of_sort tys) cs) css
   1.289 +    then SOME (
   1.290 +      map (fn (tyco, cs) => (tyco, (mk_arity tyco, map (mk_cons tyco) cs))) css
   1.291 +    ) else NONE
   1.292 +  end;
   1.293 +
   1.294 +fun prove_codetypes_arities tac tycos sort f thy =
   1.295 +  case get_codetypes_arities thy tycos sort
   1.296 +   of NONE => thy
   1.297 +    | SOME insts => let
   1.298 +        fun proven ((tyco, asorts), sort) =
   1.299 +          Sorts.of_sort (Sign.classes_of thy)
   1.300 +            (Type (tyco, map TFree (Name.names Name.context "'a" asorts)), sort);
   1.301 +        val (arities, css) = (split_list o map_filter
   1.302 +          (fn (tyco, (arity, cs)) => if proven arity
   1.303 +            then NONE else SOME (arity, (tyco, cs)))) insts;
   1.304 +      in
   1.305 +        thy
   1.306 +        |> K ((not o null) arities) ? ClassPackage.prove_instance_arity tac
   1.307 +             arities ("", []) (f thy arities css)
   1.308 +      end;
   1.309 +
   1.310 +
   1.311 +
   1.312 +(** theory setup **)
   1.313 +
   1.314  val setup = 
   1.315    add_codegen "datatype" datatype_codegen #>
   1.316    add_tycodegen "datatype" datatype_tycodegen #>