src/HOL/Nominal/nominal_datatype.ML
changeset 45822 843dc212f69e
parent 45741 088256c289e7
child 45838 653c84d5c6c9
equal deleted inserted replaced
45821:c2f6c50e3d42 45822:843dc212f69e
    75 type descr = (int * (string * dtyp list * (string * (dtyp list * dtyp) list) list)) list;
    75 type descr = (int * (string * dtyp list * (string * (dtyp list * dtyp) list) list)) list;
    76 
    76 
    77 type nominal_datatype_info =
    77 type nominal_datatype_info =
    78   {index : int,
    78   {index : int,
    79    descr : descr,
    79    descr : descr,
    80    sorts : (string * sort) list,
       
    81    rec_names : string list,
    80    rec_names : string list,
    82    rec_rewrites : thm list,
    81    rec_rewrites : thm list,
    83    induction : thm,
    82    induction : thm,
    84    distinct : thm list,
    83    distinct : thm list,
    85    inject : thm list};
    84    inject : thm list};
    98 val get_nominal_datatype = Symtab.lookup o get_nominal_datatypes;
    97 val get_nominal_datatype = Symtab.lookup o get_nominal_datatypes;
    99 
    98 
   100 
    99 
   101 (**** make datatype info ****)
   100 (**** make datatype info ****)
   102 
   101 
   103 fun make_dt_info descr sorts induct reccomb_names rec_thms
   102 fun make_dt_info descr induct reccomb_names rec_thms
   104     (i, (((_, (tname, _, _)), distinct), inject)) =
   103     (i, (((_, (tname, _, _)), distinct), inject)) =
   105   (tname,
   104   (tname,
   106    {index = i,
   105    {index = i,
   107     descr = descr,
   106     descr = descr,
   108     sorts = sorts,
       
   109     rec_names = reccomb_names,
   107     rec_names = reccomb_names,
   110     rec_rewrites = rec_thms,
   108     rec_rewrites = rec_thms,
   111     induction = induct,
   109     induction = induct,
   112     distinct = distinct,
   110     distinct = distinct,
   113     inject = inject});
   111     inject = inject});
   243     val new_type_names' = map (fn n => n ^ "_Rep") new_type_names;
   241     val new_type_names' = map (fn n => n ^ "_Rep") new_type_names;
   244 
   242 
   245     val (full_new_type_names',thy1) = Datatype.add_datatype config dts'' thy;
   243     val (full_new_type_names',thy1) = Datatype.add_datatype config dts'' thy;
   246 
   244 
   247     val {descr, induct, ...} = Datatype.the_info thy1 (hd full_new_type_names');
   245     val {descr, induct, ...} = Datatype.the_info thy1 (hd full_new_type_names');
   248     fun nth_dtyp i = typ_of_dtyp descr sorts (DtRec i);
   246     fun nth_dtyp i = typ_of_dtyp descr (DtRec i);
   249 
   247 
   250     val big_name = space_implode "_" new_type_names;
   248     val big_name = space_implode "_" new_type_names;
   251 
   249 
   252 
   250 
   253     (**** define permutation functions ****)
   251     (**** define permutation functions ****)
   266 
   264 
   267     val perm_eqs = maps (fn (i, (_, _, constrs)) =>
   265     val perm_eqs = maps (fn (i, (_, _, constrs)) =>
   268       let val T = nth_dtyp i
   266       let val T = nth_dtyp i
   269       in map (fn (cname, dts) =>
   267       in map (fn (cname, dts) =>
   270         let
   268         let
   271           val Ts = map (typ_of_dtyp descr sorts) dts;
   269           val Ts = map (typ_of_dtyp descr) dts;
   272           val names = Name.variant_list ["pi"] (Datatype_Prop.make_tnames Ts);
   270           val names = Name.variant_list ["pi"] (Datatype_Prop.make_tnames Ts);
   273           val args = map Free (names ~~ Ts);
   271           val args = map Free (names ~~ Ts);
   274           val c = Const (cname, Ts ---> T);
   272           val c = Const (cname, Ts ---> T);
   275           fun perm_arg (dt, x) =
   273           fun perm_arg (dt, x) =
   276             let val T = type_of x
   274             let val T = type_of x
   516            | _ => ([], dtf))
   514            | _ => ([], dtf))
   517       | strip_option (DtType ("fun", [dt, DtType ("Nominal.noption", [dt'])])) =
   515       | strip_option (DtType ("fun", [dt, DtType ("Nominal.noption", [dt'])])) =
   518           apfst (cons dt) (strip_option dt')
   516           apfst (cons dt) (strip_option dt')
   519       | strip_option dt = ([], dt);
   517       | strip_option dt = ([], dt);
   520 
   518 
   521     val dt_atomTs = distinct op = (map (typ_of_dtyp descr sorts)
   519     val dt_atomTs = distinct op = (map (typ_of_dtyp descr)
   522       (maps (fn (_, (_, _, cs)) => maps (maps (fst o strip_option) o snd) cs) descr));
   520       (maps (fn (_, (_, _, cs)) => maps (maps (fst o strip_option) o snd) cs) descr));
   523     val dt_atoms = map (fst o dest_Type) dt_atomTs;
   521     val dt_atoms = map (fst o dest_Type) dt_atomTs;
   524 
   522 
   525     fun make_intr s T (cname, cargs) =
   523     fun make_intr s T (cname, cargs) =
   526       let
   524       let
   527         fun mk_prem dt (j, j', prems, ts) =
   525         fun mk_prem dt (j, j', prems, ts) =
   528           let
   526           let
   529             val (dts, dt') = strip_option dt;
   527             val (dts, dt') = strip_option dt;
   530             val (dts', dt'') = strip_dtyp dt';
   528             val (dts', dt'') = strip_dtyp dt';
   531             val Ts = map (typ_of_dtyp descr sorts) dts;
   529             val Ts = map (typ_of_dtyp descr) dts;
   532             val Us = map (typ_of_dtyp descr sorts) dts';
   530             val Us = map (typ_of_dtyp descr) dts';
   533             val T = typ_of_dtyp descr sorts dt'';
   531             val T = typ_of_dtyp descr dt'';
   534             val free = mk_Free "x" (Us ---> T) j;
   532             val free = mk_Free "x" (Us ---> T) j;
   535             val free' = app_bnds free (length Us);
   533             val free' = app_bnds free (length Us);
   536             fun mk_abs_fun T (i, t) =
   534             fun mk_abs_fun T (i, t) =
   537               let val U = fastype_of t
   535               let val U = fastype_of t
   538               in (i + 1, Const ("Nominal.abs_fun", [T, U, T] --->
   536               in (i + 1, Const ("Nominal.abs_fun", [T, U, T] --->
   754 
   752 
   755     val pdescr = map (fn ((i, (s, dts, constrs)), (_, idxss)) => (i, (s, dts,
   753     val pdescr = map (fn ((i, (s, dts, constrs)), (_, idxss)) => (i, (s, dts,
   756       map (fn ((cname, cargs), idxs) => (cname, partition_cargs idxs cargs))
   754       map (fn ((cname, cargs), idxs) => (cname, partition_cargs idxs cargs))
   757         (constrs ~~ idxss)))) (descr'' ~~ ndescr);
   755         (constrs ~~ idxss)))) (descr'' ~~ ndescr);
   758 
   756 
   759     fun nth_dtyp' i = typ_of_dtyp descr'' sorts (DtRec i);
   757     fun nth_dtyp' i = typ_of_dtyp descr'' (DtRec i);
   760 
   758 
   761     val rep_names = map (fn s =>
   759     val rep_names = map (fn s =>
   762       Sign.intern_const thy7 ("Rep_" ^ s)) new_type_names;
   760       Sign.intern_const thy7 ("Rep_" ^ s)) new_type_names;
   763     val abs_names = map (fn s =>
   761     val abs_names = map (fn s =>
   764       Sign.intern_const thy7 ("Abs_" ^ s)) new_type_names;
   762       Sign.intern_const thy7 ("Abs_" ^ s)) new_type_names;
   765 
   763 
   766     val recTs = get_rec_types descr'' sorts;
   764     val recTs = get_rec_types descr'';
   767     val newTs' = take (length new_type_names) recTs';
   765     val newTs' = take (length new_type_names) recTs';
   768     val newTs = take (length new_type_names) recTs;
   766     val newTs = take (length new_type_names) recTs;
   769 
   767 
   770     val full_new_type_names = map (Sign.full_bname thy) new_type_names;
   768     val full_new_type_names = map (Sign.full_bname thy) new_type_names;
   771 
   769 
   772     fun make_constr_def tname T T' (((cname_rep, _), (cname, cargs)), (cname', mx))
   770     fun make_constr_def tname T T' (((cname_rep, _), (cname, cargs)), (cname', mx))
   773         (thy, defs, eqns) =
   771         (thy, defs, eqns) =
   774       let
   772       let
   775         fun constr_arg (dts, dt) (j, l_args, r_args) =
   773         fun constr_arg (dts, dt) (j, l_args, r_args) =
   776           let
   774           let
   777             val xs = map (fn (dt, i) => mk_Free "x" (typ_of_dtyp descr'' sorts dt) i)
   775             val xs = map (fn (dt, i) => mk_Free "x" (typ_of_dtyp descr'' dt) i)
   778               (dts ~~ (j upto j + length dts - 1))
   776               (dts ~~ (j upto j + length dts - 1))
   779             val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
   777             val x = mk_Free "x" (typ_of_dtyp descr'' dt) (j + length dts)
   780           in
   778           in
   781             (j + length dts + 1,
   779             (j + length dts + 1,
   782              xs @ x :: l_args,
   780              xs @ x :: l_args,
   783              fold_rev mk_abs_fun xs
   781              fold_rev mk_abs_fun xs
   784                (case dt of
   782                (case dt of
   785                   DtRec k => if k < length new_type_names then
   783                   DtRec k => if k < length new_type_names then
   786                       Const (nth rep_names k, typ_of_dtyp descr'' sorts dt -->
   784                       Const (nth rep_names k, typ_of_dtyp descr'' dt -->
   787                         typ_of_dtyp descr sorts dt) $ x
   785                         typ_of_dtyp descr dt) $ x
   788                     else error "nested recursion not (yet) supported"
   786                     else error "nested recursion not (yet) supported"
   789                 | _ => x) :: r_args)
   787                 | _ => x) :: r_args)
   790           end
   788           end
   791 
   789 
   792         val (_, l_args, r_args) = fold_rev constr_arg cargs (1, [], []);
   790         val (_, l_args, r_args) = fold_rev constr_arg cargs (1, [], []);
   864 
   862 
   865     val perm_rep_perm_thms = maps prove_perm_rep_perm (atoms ~~ perm_closed_thmss);
   863     val perm_rep_perm_thms = maps prove_perm_rep_perm (atoms ~~ perm_closed_thmss);
   866 
   864 
   867     (* prove distinctness theorems *)
   865     (* prove distinctness theorems *)
   868 
   866 
   869     val distinct_props = Datatype_Prop.make_distincts descr' sorts;
   867     val distinct_props = Datatype_Prop.make_distincts descr';
   870     val dist_rewrites = map2 (fn rep_thms => fn dist_lemma =>
   868     val dist_rewrites = map2 (fn rep_thms => fn dist_lemma =>
   871       dist_lemma :: rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0])
   869       dist_lemma :: rep_thms @ [In0_eq, In1_eq, In0_not_In1, In1_not_In0])
   872         constr_rep_thmss dist_lemmas;
   870         constr_rep_thmss dist_lemmas;
   873 
   871 
   874     fun prove_distinct_thms _ (_, []) = []
   872     fun prove_distinct_thms _ (_, []) = []
   900             let val T = fastype_of t
   898             let val T = fastype_of t
   901             in Const ("Nominal.perm", permT --> T --> T) $ pi $ t end;
   899             in Const ("Nominal.perm", permT --> T --> T) $ pi $ t end;
   902 
   900 
   903           fun constr_arg (dts, dt) (j, l_args, r_args) =
   901           fun constr_arg (dts, dt) (j, l_args, r_args) =
   904             let
   902             let
   905               val Ts = map (typ_of_dtyp descr'' sorts) dts;
   903               val Ts = map (typ_of_dtyp descr'') dts;
   906               val xs = map (fn (T, i) => mk_Free "x" T i)
   904               val xs = map (fn (T, i) => mk_Free "x" T i)
   907                 (Ts ~~ (j upto j + length dts - 1))
   905                 (Ts ~~ (j upto j + length dts - 1))
   908               val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
   906               val x = mk_Free "x" (typ_of_dtyp descr'' dt) (j + length dts)
   909             in
   907             in
   910               (j + length dts + 1,
   908               (j + length dts + 1,
   911                xs @ x :: l_args,
   909                xs @ x :: l_args,
   912                map perm (xs @ [x]) @ r_args)
   910                map perm (xs @ [x]) @ r_args)
   913             end
   911             end
   950           val cname = Sign.intern_const thy8
   948           val cname = Sign.intern_const thy8
   951             (Long_Name.append tname (Long_Name.base_name cname));
   949             (Long_Name.append tname (Long_Name.base_name cname));
   952 
   950 
   953           fun make_inj (dts, dt) (j, args1, args2, eqs) =
   951           fun make_inj (dts, dt) (j, args1, args2, eqs) =
   954             let
   952             let
   955               val Ts_idx = map (typ_of_dtyp descr'' sorts) dts ~~ (j upto j + length dts - 1);
   953               val Ts_idx = map (typ_of_dtyp descr'') dts ~~ (j upto j + length dts - 1);
   956               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
   954               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
   957               val ys = map (fn (T, i) => mk_Free "y" T i) Ts_idx;
   955               val ys = map (fn (T, i) => mk_Free "y" T i) Ts_idx;
   958               val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts);
   956               val x = mk_Free "x" (typ_of_dtyp descr'' dt) (j + length dts);
   959               val y = mk_Free "y" (typ_of_dtyp descr'' sorts dt) (j + length dts)
   957               val y = mk_Free "y" (typ_of_dtyp descr'' dt) (j + length dts);
   960             in
   958             in
   961               (j + length dts + 1,
   959               (j + length dts + 1,
   962                xs @ (x :: args1), ys @ (y :: args2),
   960                xs @ (x :: args1), ys @ (y :: args2),
   963                HOLogic.mk_eq
   961                HOLogic.mk_eq
   964                  (fold_rev mk_abs_fun xs x, fold_rev mk_abs_fun ys y) :: eqs)
   962                  (fold_rev mk_abs_fun xs x, fold_rev mk_abs_fun ys y) :: eqs)
   993             (Long_Name.append tname (Long_Name.base_name cname));
   991             (Long_Name.append tname (Long_Name.base_name cname));
   994           val atomT = Type (atom, []);
   992           val atomT = Type (atom, []);
   995 
   993 
   996           fun process_constr (dts, dt) (j, args1, args2) =
   994           fun process_constr (dts, dt) (j, args1, args2) =
   997             let
   995             let
   998               val Ts_idx = map (typ_of_dtyp descr'' sorts) dts ~~ (j upto j + length dts - 1);
   996               val Ts_idx = map (typ_of_dtyp descr'') dts ~~ (j upto j + length dts - 1);
   999               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
   997               val xs = map (fn (T, i) => mk_Free "x" T i) Ts_idx;
  1000               val x = mk_Free "x" (typ_of_dtyp descr'' sorts dt) (j + length dts)
   998               val x = mk_Free "x" (typ_of_dtyp descr'' dt) (j + length dts);
  1001             in
   999             in
  1002               (j + length dts + 1,
  1000               (j + length dts + 1,
  1003                xs @ (x :: args1), fold_rev mk_abs_fun xs x :: args2)
  1001                xs @ (x :: args1), fold_rev mk_abs_fun xs x :: args2)
  1004             end;
  1002             end;
  1005 
  1003 
  1064     val indrule_lemma' = cterm_instantiate
  1062     val indrule_lemma' = cterm_instantiate
  1065       (map (cterm_of thy8) Ps ~~ map (cterm_of thy8) frees) indrule_lemma;
  1063       (map (cterm_of thy8) Ps ~~ map (cterm_of thy8) frees) indrule_lemma;
  1066 
  1064 
  1067     val Abs_inverse_thms' = map (fn r => r RS subst) Abs_inverse_thms;
  1065     val Abs_inverse_thms' = map (fn r => r RS subst) Abs_inverse_thms;
  1068 
  1066 
  1069     val dt_induct_prop = Datatype_Prop.make_ind descr' sorts;
  1067     val dt_induct_prop = Datatype_Prop.make_ind descr';
  1070     val dt_induct = Goal.prove_global thy8 []
  1068     val dt_induct = Goal.prove_global thy8 []
  1071       (Logic.strip_imp_prems dt_induct_prop) (Logic.strip_imp_concl dt_induct_prop)
  1069       (Logic.strip_imp_prems dt_induct_prop) (Logic.strip_imp_concl dt_induct_prop)
  1072       (fn {prems, ...} => EVERY
  1070       (fn {prems, ...} => EVERY
  1073         [rtac indrule_lemma' 1,
  1071         [rtac indrule_lemma' 1,
  1074          (Datatype_Aux.ind_tac rep_induct [] THEN_ALL_NEW Object_Logic.atomize_prems_tac) 1,
  1072          (Datatype_Aux.ind_tac rep_induct [] THEN_ALL_NEW Object_Logic.atomize_prems_tac) 1,
  1161           mk_fresh2 (p :: xss) yss;
  1159           mk_fresh2 (p :: xss) yss;
  1162 
  1160 
  1163     fun make_ind_prem fsT f k T ((cname, cargs), idxs) =
  1161     fun make_ind_prem fsT f k T ((cname, cargs), idxs) =
  1164       let
  1162       let
  1165         val recs = filter is_rec_type cargs;
  1163         val recs = filter is_rec_type cargs;
  1166         val Ts = map (typ_of_dtyp descr'' sorts) cargs;
  1164         val Ts = map (typ_of_dtyp descr'') cargs;
  1167         val recTs' = map (typ_of_dtyp descr'' sorts) recs;
  1165         val recTs' = map (typ_of_dtyp descr'') recs;
  1168         val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
  1166         val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
  1169         val rec_tnames = map fst (filter (is_rec_type o snd) (tnames ~~ cargs));
  1167         val rec_tnames = map fst (filter (is_rec_type o snd) (tnames ~~ cargs));
  1170         val frees = tnames ~~ Ts;
  1168         val frees = tnames ~~ Ts;
  1171         val frees' = partition_cargs idxs frees;
  1169         val frees' = partition_cargs idxs frees;
  1172         val z = (singleton (Name.variant_list tnames) "z", fsT);
  1170         val z = (singleton (Name.variant_list tnames) "z", fsT);
  1414 
  1412 
  1415     val _ = warning "defining recursion combinator ...";
  1413     val _ = warning "defining recursion combinator ...";
  1416 
  1414 
  1417     val used = fold Term.add_tfree_namesT recTs [];
  1415     val used = fold Term.add_tfree_namesT recTs [];
  1418 
  1416 
  1419     val (rec_result_Ts', rec_fn_Ts') = Datatype_Prop.make_primrec_Ts descr' sorts used;
  1417     val (rec_result_Ts', rec_fn_Ts') = Datatype_Prop.make_primrec_Ts descr' used;
  1420 
  1418 
  1421     val rec_sort = if null dt_atomTs then HOLogic.typeS else
  1419     val rec_sort = if null dt_atomTs then HOLogic.typeS else
  1422       Sign.minimize_sort thy10 (Sign.certify_sort thy10 pt_cp_sort);
  1420       Sign.minimize_sort thy10 (Sign.certify_sort thy10 pt_cp_sort);
  1423 
  1421 
  1424     val rec_result_Ts = map (fn TFree (s, _) => TFree (s, rec_sort)) rec_result_Ts';
  1422     val rec_result_Ts = map (fn TFree (s, _) => TFree (s, rec_sort)) rec_result_Ts';
  1457     val rec_ctxt = Free ("z", fsT');
  1455     val rec_ctxt = Free ("z", fsT');
  1458 
  1456 
  1459     fun make_rec_intr T p rec_set ((cname, cargs), idxs)
  1457     fun make_rec_intr T p rec_set ((cname, cargs), idxs)
  1460         (rec_intr_ts, rec_prems, rec_prems', rec_eq_prems, l) =
  1458         (rec_intr_ts, rec_prems, rec_prems', rec_eq_prems, l) =
  1461       let
  1459       let
  1462         val Ts = map (typ_of_dtyp descr'' sorts) cargs;
  1460         val Ts = map (typ_of_dtyp descr'') cargs;
  1463         val frees = map (fn i => "x" ^ string_of_int i) (1 upto length Ts) ~~ Ts;
  1461         val frees = map (fn i => "x" ^ string_of_int i) (1 upto length Ts) ~~ Ts;
  1464         val frees' = partition_cargs idxs frees;
  1462         val frees' = partition_cargs idxs frees;
  1465         val binders = maps fst frees';
  1463         val binders = maps fst frees';
  1466         val atomTs = distinct op = (maps (map snd o fst) frees');
  1464         val atomTs = distinct op = (maps (map snd o fst) frees');
  1467         val recs = map_filter
  1465         val recs = map_filter
  2044              rtac the1_equality 1,
  2042              rtac the1_equality 1,
  2045              solve rec_unique_thms prems 1,
  2043              solve rec_unique_thms prems 1,
  2046              resolve_tac rec_intrs 1,
  2044              resolve_tac rec_intrs 1,
  2047              REPEAT (solve (prems @ rec_total_thms) prems 1)])
  2045              REPEAT (solve (prems @ rec_total_thms) prems 1)])
  2048       end) (rec_eq_prems ~~
  2046       end) (rec_eq_prems ~~
  2049         Datatype_Prop.make_primrecs new_type_names descr' sorts thy12);
  2047         Datatype_Prop.make_primrecs new_type_names descr' thy12);
  2050 
  2048 
  2051     val dt_infos = map_index (make_dt_info pdescr sorts induct reccomb_names rec_thms)
  2049     val dt_infos = map_index (make_dt_info pdescr induct reccomb_names rec_thms)
  2052       (descr1 ~~ distinct_thms ~~ inject_thms);
  2050       (descr1 ~~ distinct_thms ~~ inject_thms);
  2053 
  2051 
  2054     (* FIXME: theorems are stored in database for testing only *)
  2052     (* FIXME: theorems are stored in database for testing only *)
  2055     val (_, thy13) = thy12 |>
  2053     val (_, thy13) = thy12 |>
  2056       Global_Theory.add_thmss
  2054       Global_Theory.add_thmss