src/HOL/Tools/Datatype/datatype_data.ML
changeset 45822 843dc212f69e
parent 45821 c2f6c50e3d42
child 45839 43a5b86bc102
equal deleted inserted replaced
45821:c2f6c50e3d42 45822:843dc212f69e
     5 *)
     5 *)
     6 
     6 
     7 signature DATATYPE_DATA =
     7 signature DATATYPE_DATA =
     8 sig
     8 sig
     9   include DATATYPE_COMMON
     9   include DATATYPE_COMMON
    10   val derive_datatype_props : config -> string list -> descr list -> (string * sort) list ->
    10   val derive_datatype_props : config -> string list -> descr list ->
    11     thm -> thm list list -> thm list list -> theory -> string list * theory
    11     thm -> thm list list -> thm list list -> theory -> string list * theory
    12   val rep_datatype : config -> (string list -> Proof.context -> Proof.context) ->
    12   val rep_datatype : config -> (string list -> Proof.context -> Proof.context) ->
    13     term list -> theory -> Proof.state
    13     term list -> theory -> Proof.state
    14   val rep_datatype_cmd : string list -> theory -> Proof.state
    14   val rep_datatype_cmd : string list -> theory -> Proof.state
    15   val get_info : theory -> string -> info option
    15   val get_info : theory -> string -> info option
   107 
   107 
   108 (* complex queries *)
   108 (* complex queries *)
   109 
   109 
   110 fun the_spec thy dtco =
   110 fun the_spec thy dtco =
   111   let
   111   let
   112     val {descr, index, sorts = raw_sorts, ...} = the_info thy dtco;
   112     val {descr, index, ...} = the_info thy dtco;
   113     val (_, dtys, raw_cos) = the (AList.lookup (op =) descr index);
   113     val (_, dtys, raw_cos) = the (AList.lookup (op =) descr index);
   114     val sorts =
   114     val args = map Datatype_Aux.dest_DtTFree dtys;
   115       map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v)) o Datatype_Aux.dest_DtTFree) dtys;
   115     val cos = map (fn (co, tys) => (co, map (Datatype_Aux.typ_of_dtyp descr) tys)) raw_cos;
   116     val cos = map (fn (co, tys) => (co, map (Datatype_Aux.typ_of_dtyp descr sorts) tys)) raw_cos;
   116   in (args, cos) end;
   117   in (sorts, cos) end;
       
   118 
   117 
   119 fun the_descr thy (raw_tycos as raw_tyco :: _) =
   118 fun the_descr thy (raw_tycos as raw_tyco :: _) =
   120   let
   119   let
   121     val info = the_info thy raw_tyco;
   120     val info = the_info thy raw_tyco;
   122     val descr = #descr info;
   121     val descr = #descr info;
   123 
   122 
   124     val SOME (_, dtys, _) = AList.lookup (op =) descr (#index info);
   123     val (_, dtys, _) = the (AList.lookup (op =) descr (#index info));
   125     val vs =
   124     val vs = map Datatype_Aux.dest_DtTFree dtys;
   126       map ((fn v => (v, the (AList.lookup (op =) (#sorts info) v))) o Datatype_Aux.dest_DtTFree)
       
   127         dtys;
       
   128 
   125 
   129     fun is_DtTFree (Datatype_Aux.DtTFree _) = true
   126     fun is_DtTFree (Datatype_Aux.DtTFree _) = true
   130       | is_DtTFree _ = false;
   127       | is_DtTFree _ = false;
   131     val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr;
   128     val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr;
   132     val protoTs as (dataTs, _) =
   129     val protoTs as (dataTs, _) =
   133       chop k descr
   130       chop k descr
   134       |> (pairself o map)
   131       |> (pairself o map)
   135         (fn (_, (tyco, dTs, _)) => (tyco, map (Datatype_Aux.typ_of_dtyp descr vs) dTs));
   132         (fn (_, (tyco, dTs, _)) => (tyco, map (Datatype_Aux.typ_of_dtyp descr) dTs));
   136 
   133 
   137     val tycos = map fst dataTs;
   134     val tycos = map fst dataTs;
   138     val _ =
   135     val _ =
   139       if eq_set (op =) (tycos, raw_tycos) then ()
   136       if eq_set (op =) (tycos, raw_tycos) then ()
   140       else
   137       else
   158     val tycos = fold add_tycos Ts [];
   155     val tycos = fold add_tycos Ts [];
   159   in map_filter (Option.map #distinct o get_info thy) tycos end;
   156   in map_filter (Option.map #distinct o get_info thy) tycos end;
   160 
   157 
   161 fun get_constrs thy dtco =
   158 fun get_constrs thy dtco =
   162   (case try (the_spec thy) dtco of
   159   (case try (the_spec thy) dtco of
   163     SOME (sorts, cos) =>
   160     SOME (args, cos) =>
   164       let
   161       let
   165         fun subst (v, sort) = TVar ((v, 0), sort);
   162         fun subst (v, sort) = TVar ((v, 0), sort);
   166         fun subst_ty (TFree v) = subst v
   163         fun subst_ty (TFree v) = subst v
   167           | subst_ty ty = ty;
   164           | subst_ty ty = ty;
   168         val dty = Type (dtco, map subst sorts);
   165         val dty = Type (dtco, map subst args);
   169         fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
   166         fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
   170       in SOME (map mk_co cos) end
   167       in SOME (map mk_co cos) end
   171   | NONE => NONE);
   168   | NONE => NONE);
   172 
   169 
   173 
   170 
   281   type T = Datatype_Aux.config * string list;
   278   type T = Datatype_Aux.config * string list;
   282   val eq: T * T -> bool = eq_snd (op =);
   279   val eq: T * T -> bool = eq_snd (op =);
   283 );
   280 );
   284 fun interpretation f = Datatype_Interpretation.interpretation (uncurry f);
   281 fun interpretation f = Datatype_Interpretation.interpretation (uncurry f);
   285 
   282 
   286 fun make_dt_info descr sorts induct inducts rec_names rec_rewrites
   283 fun make_dt_info descr induct inducts rec_names rec_rewrites
   287     (index, (((((((((((_, (tname, _, _))), inject), distinct),
   284     (index, (((((((((((_, (tname, _, _))), inject), distinct),
   288       exhaust), nchotomy), case_name), case_rewrites), case_cong), weak_case_cong),
   285       exhaust), nchotomy), case_name), case_rewrites), case_cong), weak_case_cong),
   289         (split, split_asm))) =
   286         (split, split_asm))) =
   290   (tname,
   287   (tname,
   291    {index = index,
   288    {index = index,
   292     descr = descr,
   289     descr = descr,
   293     sorts = sorts,
       
   294     inject = inject,
   290     inject = inject,
   295     distinct = distinct,
   291     distinct = distinct,
   296     induct = induct,
   292     induct = induct,
   297     inducts = inducts,
   293     inducts = inducts,
   298     exhaust = exhaust,
   294     exhaust = exhaust,
   304     case_cong = case_cong,
   300     case_cong = case_cong,
   305     weak_case_cong = weak_case_cong,
   301     weak_case_cong = weak_case_cong,
   306     split = split,
   302     split = split,
   307     split_asm = split_asm});
   303     split_asm = split_asm});
   308 
   304 
   309 fun derive_datatype_props config dt_names descr sorts
   305 fun derive_datatype_props config dt_names descr induct inject distinct thy1 =
   310     induct inject distinct thy1 =
       
   311   let
   306   let
   312     val thy2 = thy1 |> Theory.checkpoint;
   307     val thy2 = thy1 |> Theory.checkpoint;
   313     val flat_descr = flat descr;
   308     val flat_descr = flat descr;
   314     val new_type_names = map Long_Name.base_name dt_names;
   309     val new_type_names = map Long_Name.base_name dt_names;
   315     val _ =
   310     val _ =
   316       Datatype_Aux.message config
   311       Datatype_Aux.message config
   317         ("Deriving properties for datatype(s) " ^ commas_quote new_type_names);
   312         ("Deriving properties for datatype(s) " ^ commas_quote new_type_names);
   318 
   313 
   319     val (exhaust, thy3) =
   314     val (exhaust, thy3) = thy2
   320       Datatype_Abs_Proofs.prove_casedist_thms config new_type_names
   315       |> Datatype_Abs_Proofs.prove_casedist_thms config new_type_names
   321         descr sorts induct (mk_case_names_exhausts flat_descr dt_names) thy2;
   316         descr induct (mk_case_names_exhausts flat_descr dt_names);
   322     val (nchotomys, thy4) =
   317     val (nchotomys, thy4) = thy3
   323       Datatype_Abs_Proofs.prove_nchotomys config new_type_names
   318       |> Datatype_Abs_Proofs.prove_nchotomys config new_type_names descr exhaust;
   324         descr sorts exhaust thy3;
   319     val ((rec_names, rec_rewrites), thy5) = thy4
   325     val ((rec_names, rec_rewrites), thy5) =
   320       |> Datatype_Abs_Proofs.prove_primrec_thms
   326       Datatype_Abs_Proofs.prove_primrec_thms
   321         config new_type_names descr (#inject o the o Symtab.lookup (get_all thy4))
   327         config new_type_names descr sorts (#inject o the o Symtab.lookup (get_all thy4))
   322         inject (distinct, all_distincts thy2 (Datatype_Aux.get_rec_types flat_descr)) induct;
   328         inject (distinct, all_distincts thy2 (Datatype_Aux.get_rec_types flat_descr sorts))
   323     val ((case_rewrites, case_names), thy6) = thy5
   329         induct thy4;
   324       |> Datatype_Abs_Proofs.prove_case_thms config new_type_names descr rec_names rec_rewrites;
   330     val ((case_rewrites, case_names), thy6) =
   325     val (case_congs, thy7) = thy6
   331       Datatype_Abs_Proofs.prove_case_thms
   326       |> Datatype_Abs_Proofs.prove_case_congs new_type_names descr nchotomys case_rewrites;
   332         config new_type_names descr sorts rec_names rec_rewrites thy5;
   327     val (weak_case_congs, thy8) = thy7
   333     val (case_congs, thy7) =
   328       |> Datatype_Abs_Proofs.prove_weak_case_congs new_type_names descr;
   334       Datatype_Abs_Proofs.prove_case_congs new_type_names
   329     val (splits, thy9) = thy8
   335         descr sorts nchotomys case_rewrites thy6;
   330       |> Datatype_Abs_Proofs.prove_split_thms
   336     val (weak_case_congs, thy8) =
   331         config new_type_names descr inject distinct exhaust case_rewrites;
   337       Datatype_Abs_Proofs.prove_weak_case_congs new_type_names descr sorts thy7;
       
   338     val (splits, thy9) =
       
   339       Datatype_Abs_Proofs.prove_split_thms
       
   340         config new_type_names descr sorts inject distinct exhaust case_rewrites thy8;
       
   341 
   332 
   342     val inducts = Project_Rule.projections (Proof_Context.init_global thy2) induct;
   333     val inducts = Project_Rule.projections (Proof_Context.init_global thy2) induct;
   343     val dt_infos = map_index
   334     val dt_infos = map_index
   344       (make_dt_info flat_descr sorts induct inducts rec_names rec_rewrites)
   335       (make_dt_info flat_descr induct inducts rec_names rec_rewrites)
   345       (hd descr ~~ inject ~~ distinct ~~ exhaust ~~ nchotomys ~~
   336       (hd descr ~~ inject ~~ distinct ~~ exhaust ~~ nchotomys ~~
   346         case_names ~~ case_rewrites ~~ case_congs ~~ weak_case_congs ~~ splits);
   337         case_names ~~ case_rewrites ~~ case_congs ~~ weak_case_congs ~~ splits);
   347     val dt_names = map fst dt_infos;
   338     val dt_names = map fst dt_infos;
   348     val prfx = Binding.qualify true (space_implode "_" new_type_names);
   339     val prfx = Binding.qualify true (space_implode "_" new_type_names);
   349     val simps = flat (inject @ distinct @ case_rewrites) @ rec_rewrites;
   340     val simps = flat (inject @ distinct @ case_rewrites) @ rec_rewrites;
   376 
   367 
   377 
   368 
   378 
   369 
   379 (** declare existing type as datatype **)
   370 (** declare existing type as datatype **)
   380 
   371 
   381 fun prove_rep_datatype config dt_names descr sorts raw_inject half_distinct raw_induct thy1 =
   372 fun prove_rep_datatype config dt_names descr raw_inject half_distinct raw_induct thy1 =
   382   let
   373   let
   383     val raw_distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
   374     val raw_distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
   384     val new_type_names = map Long_Name.base_name dt_names;
   375     val new_type_names = map Long_Name.base_name dt_names;
   385     val prfx = Binding.qualify true (space_implode "_" new_type_names);
   376     val prfx = Binding.qualify true (space_implode "_" new_type_names);
   386     val (((inject, distinct), [induct]), thy2) =
   377     val (((inject, distinct), [induct]), thy2) =
   390       ||>> Global_Theory.add_thms
   381       ||>> Global_Theory.add_thms
   391         [((prfx (Binding.name "induct"), raw_induct),
   382         [((prfx (Binding.name "induct"), raw_induct),
   392           [mk_case_names_induct descr])];
   383           [mk_case_names_induct descr])];
   393   in
   384   in
   394     thy2
   385     thy2
   395     |> derive_datatype_props config dt_names [descr] sorts induct inject distinct
   386     |> derive_datatype_props config dt_names [descr] induct inject distinct
   396  end;
   387  end;
   397 
   388 
   398 fun gen_rep_datatype prep_term config after_qed raw_ts thy =
   389 fun gen_rep_datatype prep_term config after_qed raw_ts thy =
   399   let
   390   let
   400     val ctxt = Proof_Context.init_global thy;
   391     val ctxt = Proof_Context.init_global thy;
   439     val dtyps_of_typ =
   430     val dtyps_of_typ =
   440       map (Datatype_Aux.dtyp_of_typ (map (rpair (map fst vs) o fst) cs)) o binder_types;
   431       map (Datatype_Aux.dtyp_of_typ (map (rpair (map fst vs) o fst) cs)) o binder_types;
   441     val dt_names = map fst cs;
   432     val dt_names = map fst cs;
   442 
   433 
   443     fun mk_spec (i, (tyco, constr)) =
   434     fun mk_spec (i, (tyco, constr)) =
   444       (i, (tyco,
   435       (i, (tyco, map Datatype_Aux.DtTFree vs, (map o apsnd) dtyps_of_typ constr));
   445         map (Datatype_Aux.DtTFree o fst) vs,
       
   446         (map o apsnd) dtyps_of_typ constr));
       
   447     val descr = map_index mk_spec cs;
   436     val descr = map_index mk_spec cs;
   448     val injs = Datatype_Prop.make_injs [descr] vs;
   437     val injs = Datatype_Prop.make_injs [descr];
   449     val half_distincts = map snd (Datatype_Prop.make_distincts [descr] vs);
   438     val half_distincts = map snd (Datatype_Prop.make_distincts [descr]);
   450     val ind = Datatype_Prop.make_ind [descr] vs;
   439     val ind = Datatype_Prop.make_ind [descr];
   451     val rules = (map o map o map) Logic.close_form [[[ind]], injs, half_distincts];
   440     val rules = (map o map o map) Logic.close_form [[[ind]], injs, half_distincts];
   452 
   441 
   453     fun after_qed' raw_thms =
   442     fun after_qed' raw_thms =
   454       let
   443       let
   455         val [[[raw_induct]], raw_inject, half_distinct] =
   444         val [[[raw_induct]], raw_inject, half_distinct] =
   456           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   445           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   457             (*FIXME somehow dubious*)
   446             (*FIXME somehow dubious*)
   458       in
   447       in
   459         Proof_Context.background_theory_result  (* FIXME !? *)
   448         Proof_Context.background_theory_result  (* FIXME !? *)
   460           (prove_rep_datatype config dt_names descr vs raw_inject half_distinct raw_induct)
   449           (prove_rep_datatype config dt_names descr raw_inject half_distinct raw_induct)
   461         #-> after_qed
   450         #-> after_qed
   462       end;
   451       end;
   463   in
   452   in
   464     ctxt
   453     ctxt
   465     |> Proof.theorem NONE after_qed' ((map o map) (rpair []) (flat rules))
   454     |> Proof.theorem NONE after_qed' ((map o map) (rpair []) (flat rules))