src/HOL/Tools/Datatype/datatype_realizer.ML
changeset 41423 25df154b8ffc
parent 40844 5895c525739d
child 41698 90597e044e5f
equal deleted inserted replaced
41422:8a765db7e0f8 41423:25df154b8ffc
    12 end;
    12 end;
    13 
    13 
    14 structure Datatype_Realizer : DATATYPE_REALIZER =
    14 structure Datatype_Realizer : DATATYPE_REALIZER =
    15 struct
    15 struct
    16 
    16 
    17 open Datatype_Aux;
       
    18 
       
    19 fun subsets i j =
    17 fun subsets i j =
    20   if i <= j then
    18   if i <= j then
    21     let val is = subsets (i+1) j
    19     let val is = subsets (i+1) j
    22     in map (fn ks => i::ks) is @ is end
    20     in map (fn ks => i::ks) is @ is end
    23   else [[]];
    21   else [[]];
    28 fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
    26 fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
    29 
    27 
    30 fun tname_of (Type (s, _)) = s
    28 fun tname_of (Type (s, _)) = s
    31   | tname_of _ = "";
    29   | tname_of _ = "";
    32 
    30 
    33 fun make_ind sorts ({descr, rec_names, rec_rewrites, induct, ...} : info) is thy =
    31 fun make_ind sorts ({descr, rec_names, rec_rewrites, induct, ...} : Datatype_Aux.info) is thy =
    34   let
    32   let
    35     val recTs = get_rec_types descr sorts;
    33     val recTs = Datatype_Aux.get_rec_types descr sorts;
    36     val pnames = if length descr = 1 then ["P"]
    34     val pnames = if length descr = 1 then ["P"]
    37       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    35       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    38 
    36 
    39     val rec_result_Ts = map (fn ((i, _), P) =>
    37     val rec_result_Ts = map (fn ((i, _), P) =>
    40       if member (op =) is i then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
    38       if member (op =) is i then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
    49       if member (op =) is i then list_all_free ([(s, T)], t) else t;
    47       if member (op =) is i then list_all_free ([(s, T)], t) else t;
    50 
    48 
    51     val (prems, rec_fns) = split_list (flat (fst (fold_map
    49     val (prems, rec_fns) = split_list (flat (fst (fold_map
    52       (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    50       (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    53         let
    51         let
    54           val Ts = map (typ_of_dtyp descr sorts) cargs;
    52           val Ts = map (Datatype_Aux.typ_of_dtyp descr sorts) cargs;
    55           val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
    53           val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
    56           val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    54           val recs = filter (Datatype_Aux.is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    57           val frees = tnames ~~ Ts;
    55           val frees = tnames ~~ Ts;
    58 
    56 
    59           fun mk_prems vs [] =
    57           fun mk_prems vs [] =
    60                 let
    58                 let
    61                   val rT = nth (rec_result_Ts) i;
    59                   val rT = nth (rec_result_Ts) i;
    62                   val vs' = filter_out is_unit vs;
    60                   val vs' = filter_out is_unit vs;
    63                   val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
    61                   val f = Datatype_Aux.mk_Free "f" (map fastype_of vs' ---> rT) j;
    64                   val f' = Envir.eta_contract (list_abs_free
    62                   val f' = Envir.eta_contract (list_abs_free
    65                     (map dest_Free vs, if member (op =) is i then list_comb (f, vs')
    63                     (map dest_Free vs, if member (op =) is i then list_comb (f, vs')
    66                       else HOLogic.unit));
    64                       else HOLogic.unit));
    67                 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    65                 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    68                   (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    66                   (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    69                 end
    67                 end
    70             | mk_prems vs (((dt, s), T) :: ds) =
    68             | mk_prems vs (((dt, s), T) :: ds) =
    71                 let
    69                 let
    72                   val k = body_index dt;
    70                   val k = Datatype_Aux.body_index dt;
    73                   val (Us, U) = strip_type T;
    71                   val (Us, U) = strip_type T;
    74                   val i = length Us;
    72                   val i = length Us;
    75                   val rT = nth (rec_result_Ts) k;
    73                   val rT = nth (rec_result_Ts) k;
    76                   val r = Free ("r" ^ s, Us ---> rT);
    74                   val r = Free ("r" ^ s, Us ---> rT);
    77                   val (p, f) = mk_prems (vs @ [r]) ds
    75                   val (p, f) = mk_prems (vs @ [r]) ds
    78                 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    76                 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    79                   (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    77                   (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    80                     (make_pred k rT U (app_bnds r i)
    78                     (make_pred k rT U (Datatype_Aux.app_bnds r i)
    81                       (app_bnds (Free (s, T)) i))), p)), f)
    79                       (Datatype_Aux.app_bnds (Free (s, T)) i))), p)), f)
    82                 end
    80                 end
    83 
    81 
    84         in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end)
    82         in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end)
    85           constrs) (descr ~~ recTs) 1)));
    83           constrs) (descr ~~ recTs) 1)));
    86 
    84 
   152             (map (head_of o strip_abs_body) rec_fns)) r);
   150             (map (head_of o strip_abs_body) rec_fns)) r);
   153 
   151 
   154   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   152   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   155 
   153 
   156 
   154 
   157 fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaust, ...} : info) thy =
   155 fun make_casedists sorts
       
   156     ({index, descr, case_name, case_rewrites, exhaust, ...} : Datatype_Aux.info) thy =
   158   let
   157   let
   159     val cert = cterm_of thy;
   158     val cert = cterm_of thy;
   160     val rT = TFree ("'P", HOLogic.typeS);
   159     val rT = TFree ("'P", HOLogic.typeS);
   161     val rT' = TVar (("'P", 0), HOLogic.typeS);
   160     val rT' = TVar (("'P", 0), HOLogic.typeS);
   162 
   161 
   163     fun make_casedist_prem T (cname, cargs) =
   162     fun make_casedist_prem T (cname, cargs) =
   164       let
   163       let
   165         val Ts = map (typ_of_dtyp descr sorts) cargs;
   164         val Ts = map (Datatype_Aux.typ_of_dtyp descr sorts) cargs;
   166         val frees = Name.variant_list ["P", "y"] (Datatype_Prop.make_tnames Ts) ~~ Ts;
   165         val frees = Name.variant_list ["P", "y"] (Datatype_Prop.make_tnames Ts) ~~ Ts;
   167         val free_ts = map Free frees;
   166         val free_ts = map Free frees;
   168         val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   167         val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   169       in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
   168       in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
   170         (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   169         (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   171           HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   170           HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   172             list_comb (r, free_ts)))))
   171             list_comb (r, free_ts)))))
   173       end;
   172       end;
   174 
   173 
   175     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   174     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   176     val T = List.nth (get_rec_types descr sorts, index);
   175     val T = List.nth (Datatype_Aux.get_rec_types descr sorts, index);
   177     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   176     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   178     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   177     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   179 
   178 
   180     val y = Var (("y", 0), Logic.varifyT_global T);
   179     val y = Var (("y", 0), Logic.varifyT_global T);
   181     val y' = Free ("y", T);
   180     val y' = Free ("y", T);
   216 
   215 
   217 fun add_dt_realizers config names thy =
   216 fun add_dt_realizers config names thy =
   218   if ! Proofterm.proofs < 2 then thy
   217   if ! Proofterm.proofs < 2 then thy
   219   else
   218   else
   220     let
   219     let
   221       val _ = message config "Adding realizers for induction and case analysis ..."
   220       val _ = Datatype_Aux.message config "Adding realizers for induction and case analysis ...";
   222       val infos = map (Datatype.the_info thy) names;
   221       val infos = map (Datatype.the_info thy) names;
   223       val info :: _ = infos;
   222       val info :: _ = infos;
   224     in
   223     in
   225       thy
   224       thy
   226       |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
   225       |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))