src/HOL/Tools/record.ML
changeset 32744 50406c4951d9
parent 32743 c4e9a48bc50e
child 32745 192d58483fdf
equal deleted inserted replaced
32743:c4e9a48bc50e 32744:50406c4951d9
   284 (* theory data *)
   284 (* theory data *)
   285 
   285 
   286 type record_data =
   286 type record_data =
   287  {records: record_info Symtab.table,
   287  {records: record_info Symtab.table,
   288   sel_upd:
   288   sel_upd:
   289    {selectors: unit Symtab.table,
   289    {selectors: (int * bool) Symtab.table,
   290     updates: string Symtab.table,
   290     updates: string Symtab.table,
   291     simpset: Simplifier.simpset},
   291     simpset: Simplifier.simpset,
       
   292     defset: Simplifier.simpset,
       
   293     foldcong: Simplifier.simpset,
       
   294     unfoldcong: Simplifier.simpset},
   292   equalities: thm Symtab.table,
   295   equalities: thm Symtab.table,
   293   extinjects: thm list,
   296   extinjects: thm list,
   294   extsplit: thm Symtab.table, (* maps extension name to split rule *)
   297   extsplit: thm Symtab.table, (* maps extension name to split rule *)
   295   splits: (thm*thm*thm*thm) Symtab.table,    (* !!,!,EX - split-equalities,induct rule *)
   298   splits: (thm*thm*thm*thm) Symtab.table,    (* !!,!,EX - split-equalities,induct rule *)
   296   extfields: (string*typ) list Symtab.table, (* maps extension to its fields *)
   299   extfields: (string*typ) list Symtab.table, (* maps extension to its fields *)
   408     SOME s => let
   411     SOME s => let
   409         val SOME (dep, ismore) = Symtab.lookup (#selectors sel_upd) s;
   412         val SOME (dep, ismore) = Symtab.lookup (#selectors sel_upd) s;
   410       in SOME (s, dep, ismore) end
   413       in SOME (s, dep, ismore) end
   411   | NONE => NONE end;
   414   | NONE => NONE end;
   412 
   415 
   413 fun put_sel_upd names simps = RecordsData.map (fn {records,
   416 fun put_sel_upd names more depth simps defs (folds, unfolds) thy =
   414   sel_upd = {selectors, updates, simpset},
   417   let
   415     equalities, extinjects, extsplit, splits, extfields, fieldext} =>
   418     val all  = names @ [more];
   416   make_record_data records
   419     val sels = map (rpair (depth, false)) names @ [(more, (depth, true))];
   417     {selectors = fold (fn name => Symtab.update (name, ())) names selectors,
   420     val upds = map (suffix updateN) all ~~ all;
   418       updates = fold (fn name => Symtab.update ((suffix updateN) name, name)) names updates,
   421 
   419       simpset = Simplifier.addsimps (simpset, simps)}
   422     val {records, sel_upd = {selectors, updates, simpset,
   420       equalities extinjects extsplit splits extfields fieldext);
   423                              defset, foldcong, unfoldcong},
   421 
   424       equalities, extinjects, extsplit, splits, extfields,
       
   425       fieldext} = RecordsData.get thy;
       
   426     val data = make_record_data records
       
   427       {selectors = fold Symtab.update_new sels selectors,
       
   428         updates = fold Symtab.update_new upds updates,
       
   429         simpset = Simplifier.addsimps (simpset, simps),
       
   430         defset = Simplifier.addsimps (defset, defs),
       
   431         foldcong = foldcong addcongs folds,
       
   432         unfoldcong = unfoldcong addcongs unfolds}
       
   433        equalities extinjects extsplit splits extfields fieldext;
       
   434   in RecordsData.put data thy end;
   422 
   435 
   423 (* access 'equalities' *)
   436 (* access 'equalities' *)
   424 
   437 
   425 fun add_record_equalities name thm thy =
   438 fun add_record_equalities name thm thy =
   426   let
   439   let
   956   | get_updfuns _             = [];
   969   | get_updfuns _             = [];
   957 
   970 
   958 fun get_accupd_simps thy term defset intros_tac = let
   971 fun get_accupd_simps thy term defset intros_tac = let
   959     val (acc, [body]) = strip_comb term;
   972     val (acc, [body]) = strip_comb term;
   960     val recT          = domain_type (fastype_of acc);
   973     val recT          = domain_type (fastype_of acc);
   961     val updfuns       = sort_distinct Term.fast_term_ord
   974     val updfuns       = sort_distinct TermOrd.fast_term_ord
   962                            (get_updfuns body);
   975                            (get_updfuns body);
   963     fun get_simp upd  = let
   976     fun get_simp upd  = let
   964         val T    = domain_type (fastype_of upd);
   977         val T    = domain_type (fastype_of upd);
   965         val lhs  = mk_comp acc (upd $ Free ("f", T));
   978         val lhs  = mk_comp acc (upd $ Free ("f", T));
   966         val rhs  = if is_sel_upd_pair thy acc upd
   979         val rhs  = if is_sel_upd_pair thy acc upd
   973         val dest = if is_sel_upd_pair thy acc upd
   986         val dest = if is_sel_upd_pair thy acc upd
   974                    then o_eq_dest else o_eq_id_dest;
   987                    then o_eq_dest else o_eq_id_dest;
   975       in standard (othm RS dest) end;
   988       in standard (othm RS dest) end;
   976   in map get_simp updfuns end;
   989   in map get_simp updfuns end;
   977 
   990 
   978 structure SymSymTab = TableFun(type key = string * string
   991 structure SymSymTab = Table(type key = string * string
   979                                 val ord = prod_ord fast_string_ord fast_string_ord);
   992                             val ord = prod_ord fast_string_ord fast_string_ord);
   980 
   993 
   981 fun get_updupd_simp thy defset intros_tac u u' comp = let
   994 fun get_updupd_simp thy defset intros_tac u u' comp = let
   982     val f    = Free ("f", domain_type (fastype_of u));
   995     val f    = Free ("f", domain_type (fastype_of u));
   983     val f'   = Free ("f'", domain_type (fastype_of u'));
   996     val f'   = Free ("f'", domain_type (fastype_of u'));
   984     val lhs  = mk_comp (u $ f) (u' $ f');
   997     val lhs  = mk_comp (u $ f) (u' $ f');
  1017   in swapsneeded updfuns [] Symtab.empty SymSymTab.empty end;
  1030   in swapsneeded updfuns [] Symtab.empty SymSymTab.empty end;
  1018 
  1031 
  1019 fun named_cterm_instantiate values thm = let
  1032 fun named_cterm_instantiate values thm = let
  1020     fun match name (Var ((name', _), _)) = name = name'
  1033     fun match name (Var ((name', _), _)) = name = name'
  1021       | match name _ = false;
  1034       | match name _ = false;
  1022     fun getvar name = case (find_first (match name) (term_vars (prop_of thm)))
  1035     fun getvar name = case (find_first (match name)
       
  1036                                     (OldTerm.term_vars (prop_of thm)))
  1023       of SOME var => cterm_of (theory_of_thm thm) var
  1037       of SOME var => cterm_of (theory_of_thm thm) var
  1024        | NONE => raise THM ("named_cterm_instantiate: " ^ name, 0, [thm])
  1038        | NONE => raise THM ("named_cterm_instantiate: " ^ name, 0, [thm])
  1025   in
  1039   in
  1026     cterm_instantiate (map (apfst getvar) values) thm
  1040     cterm_instantiate (map (apfst getvar) values) thm
  1027   end;
  1041   end;
  1112                 | mk_eq_terms r = NONE
  1126                 | mk_eq_terms r = NONE
  1113             in
  1127             in
  1114               (case mk_eq_terms (upd$k$r) of
  1128               (case mk_eq_terms (upd$k$r) of
  1115                  SOME (trm,trm',vars)
  1129                  SOME (trm,trm',vars)
  1116                  => SOME (prove_unfold_defs thy ss domS [] []
  1130                  => SOME (prove_unfold_defs thy ss domS [] []
  1117                              (list_all(vars,(equals rangeS$(sel$trm)$trm'))))
  1131                              (list_all(vars,(Logic.mk_equals (sel$trm, trm')))))
  1118                | NONE => NONE)
  1132                | NONE => NONE)
  1119             end
  1133             end
  1120           | NONE => NONE)
  1134           | NONE => NONE)
  1121         else NONE
  1135         else NONE
  1122       | _ => NONE));
  1136       | _ => NONE));
  1204          *  used for eliminating case (2) defined above
  1218          *  used for eliminating case (2) defined above
  1205          *)
  1219          *)
  1206         fun mk_updterm ((upd as Const (u, T), s, f) :: upds) above term = let
  1220         fun mk_updterm ((upd as Const (u, T), s, f) :: upds) above term = let
  1207             val (lhs, rhs, vars, dups, simp, noops) =
  1221             val (lhs, rhs, vars, dups, simp, noops) =
  1208                   mk_updterm upds (Symtab.update (u, ()) above) term;
  1222                   mk_updterm upds (Symtab.update (u, ()) above) term;
  1209             val (fvar, skelf) = K_skeleton (Sign.base_name s) (domain_type T)
  1223             val (fvar, skelf) = K_skeleton (Long_Name.base_name s) (domain_type T)
  1210                                       (Bound (length vars)) f;
  1224                                       (Bound (length vars)) f;
  1211             val (isnoop, skelf') = is_upd_noop s f term;
  1225             val (isnoop, skelf') = is_upd_noop s f term;
  1212             val funT  = domain_type T;
  1226             val funT  = domain_type T;
  1213             fun mk_comp_local (f, f') =
  1227             fun mk_comp_local (f, f') =
  1214               Const ("Fun.comp", funT --> funT --> funT) $ f $ f';
  1228               Const ("Fun.comp", funT --> funT --> funT) $ f $ f';
  1238                   = mk_updterm upds Symtab.empty base;
  1252                   = mk_updterm upds Symtab.empty base;
  1239         val noops' = flat (map snd (Symtab.dest noops));
  1253         val noops' = flat (map snd (Symtab.dest noops));
  1240       in
  1254       in
  1241         if simp then
  1255         if simp then
  1242            SOME (prove_unfold_defs thy ss baseT noops' [record_simproc]
  1256            SOME (prove_unfold_defs thy ss baseT noops' [record_simproc]
  1243                              (list_all(vars,(equals baseT$lhs$rhs))))
  1257                              (list_all(vars,(Logic.mk_equals (lhs, rhs)))))
  1244         else NONE
  1258         else NONE
  1245       end)
  1259       end)
  1246 
  1260 
  1247 end
  1261 end
  1248 
  1262 
  1557     end;
  1571     end;
  1558 
  1572 
  1559     fun mk_istuple ((thy, i), (left, rght)) =
  1573     fun mk_istuple ((thy, i), (left, rght)) =
  1560     let
  1574     let
  1561       val suff = if i = 0 then ext_typeN else inner_typeN ^ (string_of_int i);
  1575       val suff = if i = 0 then ext_typeN else inner_typeN ^ (string_of_int i);
  1562       val nm   = suffix suff (Sign.base_name name);
  1576       val nm   = suffix suff (Long_Name.base_name name);
  1563       val (cons, thy') = IsTupleSupport.add_istuple_type
  1577       val (cons, thy') = IsTupleSupport.add_istuple_type
  1564                (nm, alphas_zeta) (fastype_of left, fastype_of rght) thy;
  1578                (nm, alphas_zeta) (fastype_of left, fastype_of rght) thy;
  1565     in
  1579     in
  1566       ((thy', i + 1), cons $ left $ rght)
  1580       ((thy', i + 1), cons $ left $ rght)
  1567     end;
  1581     end;
  1584           fun group16 [] = []
  1598           fun group16 [] = []
  1585             | group16 xs = Library.take (16, xs)
  1599             | group16 xs = Library.take (16, xs)
  1586                              :: group16 (Library.drop (16, xs));
  1600                              :: group16 (Library.drop (16, xs));
  1587           val vars' = group16 vars;
  1601           val vars' = group16 vars;
  1588           val ((thy', i'), composites) =
  1602           val ((thy', i'), composites) =
  1589                    foldl_map mk_even_istuple ((thy, i), vars');
  1603                    Library.foldl_map mk_even_istuple ((thy, i), vars');
  1590         in
  1604         in
  1591           build_meta_tree_type i' thy' composites more
  1605           build_meta_tree_type i' thy' composites more
  1592         end
  1606         end
  1593       else let
  1607       else let
  1594           val ((thy', i'), term)
  1608           val ((thy', i'), term)
  1709            EVERY [cut_rules_tac [split_meta RS meta_iffD2] 1,
  1723            EVERY [cut_rules_tac [split_meta RS meta_iffD2] 1,
  1710                   resolve_tac prems 2,
  1724                   resolve_tac prems 2,
  1711                   asm_simp_tac HOL_ss 1]) end;
  1725                   asm_simp_tac HOL_ss 1]) end;
  1712     val induct = timeit_msg "record extension induct proof:" induct_prf;
  1726     val induct = timeit_msg "record extension induct proof:" induct_prf;
  1713 
  1727 
  1714     val (([inject',induct',surjective',split_meta',ext_def'],
  1728     val ([inject',induct',surjective',split_meta',ext_def'],
  1715           [dest_convs',upd_convs']),
       
  1716       thm_thy) =
  1729       thm_thy) =
  1717       defs_thy
  1730       defs_thy
  1718       |> (PureThy.add_thms o map (Thm.no_attributes o apfst Binding.name))
  1731       |> (PureThy.add_thms o map (Thm.no_attributes o apfst Binding.name))
  1719            [("ext_inject", inject),
  1732            [("ext_inject", inject),
  1720             ("ext_induct", induct),
  1733             ("ext_induct", induct),
  1721             ("ext_surjective", surjective),
  1734             ("ext_surjective", surject),
  1722             ("ext_split", split_meta),
  1735             ("ext_split", split_meta),
  1723             ("ext_def", ext_def)]
  1736             ("ext_def", ext_def)]
  1724 
  1737 
  1725   in (thm_thy,extT,induct',inject',split_meta',ext_def')
  1738   in (thm_thy,extT,induct',inject',split_meta',ext_def')
  1726   end;
  1739   end;
  2112           timeit_msg "record upd_convs_standard proof:" upd_convs_standard_prf;
  2125           timeit_msg "record upd_convs_standard proof:" upd_convs_standard_prf;
  2113 
  2126 
  2114     fun get_upd_acc_congs () = let
  2127     fun get_upd_acc_congs () = let
  2115         val symdefs  = map symmetric (sel_defs @ upd_defs);
  2128         val symdefs  = map symmetric (sel_defs @ upd_defs);
  2116         val fold_ss  = HOL_basic_ss addsimps symdefs;
  2129         val fold_ss  = HOL_basic_ss addsimps symdefs;
  2117         val ua_congs = map (simplify fold_ss) upd_acc_cong_assists;
  2130         val ua_congs = map (standard o simplify fold_ss) upd_acc_cong_assists;
  2118       in (ua_congs RL [updacc_foldE], ua_congs RL [updacc_unfoldE]) end;
  2131       in (ua_congs RL [updacc_foldE], ua_congs RL [updacc_unfoldE]) end;
  2119     val (fold_congs, unfold_congs) =
  2132     val (fold_congs, unfold_congs) =
  2120           timeit_msg "record upd fold/unfold congs:" get_upd_acc_congs;
  2133           timeit_msg "record upd fold/unfold congs:" get_upd_acc_congs;
  2121 
  2134 
  2122     val parent_induct = if null parents then [] else [#induct (hd (rev parents))];
  2135     val parent_induct = if null parents then [] else [#induct (hd (rev parents))];
  2135           try_param_tac rN induct_scheme 1
  2148           try_param_tac rN induct_scheme 1
  2136           THEN try_param_tac "more" @{thm unit.induct} 1
  2149           THEN try_param_tac "more" @{thm unit.induct} 1
  2137           THEN resolve_tac prems 1)
  2150           THEN resolve_tac prems 1)
  2138       end;
  2151       end;
  2139     val induct = timeit_msg "record induct proof:" induct_prf;
  2152     val induct = timeit_msg "record induct proof:" induct_prf;
  2140 
       
  2141     fun surjective_prf () =
       
  2142       prove_standard [] surjective_prop (fn prems =>
       
  2143           (EVERY [try_param_tac rN induct_scheme 1,
       
  2144                   simp_tac (ss addsimps sel_convs_standard) 1]))
       
  2145     val surjective = timeit_msg "record surjective proof:" surjective_prf;
       
  2146 
  2153 
  2147     fun cases_scheme_prf_opt () =
  2154     fun cases_scheme_prf_opt () =
  2148       let
  2155       let
  2149         val (_$(Pvar$_)) = concl_of induct_scheme;
  2156         val (_$(Pvar$_)) = concl_of induct_scheme;
  2150         val ind = cterm_instantiate
  2157         val ind = cterm_instantiate
  2169         try_param_tac rN cases_scheme 1
  2176         try_param_tac rN cases_scheme 1
  2170         THEN simp_all_tac HOL_basic_ss [unit_all_eq1]);
  2177         THEN simp_all_tac HOL_basic_ss [unit_all_eq1]);
  2171     val cases = timeit_msg "record cases proof:" cases_prf;
  2178     val cases = timeit_msg "record cases proof:" cases_prf;
  2172 
  2179 
  2173     fun surjective_prf () = let
  2180     fun surjective_prf () = let
  2174         val o_ass_thm = symmetric (mk_meta_eq o_assoc);
  2181         val leaf_ss   = get_sel_upd_defs defs_thy
  2175         val o_reassoc = simplify (HOL_basic_ss addsimps [o_ass_thm]);
  2182                                 addsimps (sel_defs @ (o_assoc :: id_o_apps));
  2176         val sel_defs' = map o_reassoc sel_defs;
  2183         val init_ss   = HOL_basic_ss addsimps ext_defs;
  2177         val ss        = HOL_basic_ss addsimps (ext_defs @ sel_defs');
       
  2178       in
  2184       in
  2179         prove_standard [] surjective_prop (fn prems =>
  2185         prove_standard [] surjective_prop (fn prems =>
  2180             (EVERY [rtac surject_assist_idE 1,
  2186             (EVERY [rtac surject_assist_idE 1,
  2181                     simp_tac ss 1,
  2187                     simp_tac init_ss 1,
  2182                     REPEAT (intros_tac 1 ORELSE
  2188                     REPEAT (intros_tac 1 ORELSE
  2183                             (rtac surject_assistI 1 THEN
  2189                             (rtac surject_assistI 1 THEN
  2184                              simp_tac (HOL_basic_ss addsimps id_o_apps) 1))]))
  2190                              simp_tac leaf_ss 1))]))
  2185       end;
  2191       end;
  2186     val surjective = timeit_msg "record surjective proof:" surjective_prf;
  2192     val surjective = timeit_msg "record surjective proof:" surjective_prf;
  2187 
  2193 
  2188     fun split_meta_prf () =
  2194     fun split_meta_prf () =
  2189         prove false [] split_meta_prop (fn prems =>
  2195         prove false [] split_meta_prop (fn prems =>
  2228 
  2234 
  2229 
  2235 
  2230     fun split_ex_prf () =
  2236     fun split_ex_prf () =
  2231       let
  2237       let
  2232         val ss   = HOL_basic_ss addsimps [not_ex RS sym, nth simp_thms 1];
  2238         val ss   = HOL_basic_ss addsimps [not_ex RS sym, nth simp_thms 1];
  2233         val [Pv] = term_vars (prop_of split_object);
  2239         val [Pv] = OldTerm.term_vars (prop_of split_object);
  2234         val cPv  = cterm_of defs_thy Pv;
  2240         val cPv  = cterm_of defs_thy Pv;
  2235         val cP   = cterm_of defs_thy (lambda r0 (HOLogic.mk_not (P $ r0)));
  2241         val cP   = cterm_of defs_thy (lambda r0 (HOLogic.mk_not (P $ r0)));
  2236         val so3  = cterm_instantiate ([(cPv, cP)]) split_object;
  2242         val so3  = cterm_instantiate ([(cPv, cP)]) split_object;
  2237         val so4  = simplify ss so3;
  2243         val so4  = simplify ss so3;
  2238       in
  2244       in
  2254              (* simp_all_tac ss (sel_convs) would also work but is less efficient *)
  2260              (* simp_all_tac ss (sel_convs) would also work but is less efficient *)
  2255       end);
  2261       end);
  2256      val equality = timeit_msg "record equality proof:" equality_prf;
  2262      val equality = timeit_msg "record equality proof:" equality_prf;
  2257 
  2263 
  2258     val ((([sel_convs', upd_convs', sel_defs', upd_defs',
  2264     val ((([sel_convs', upd_convs', sel_defs', upd_defs',
  2259             fold_congs', unfold_congs', surjective',
  2265             fold_congs', unfold_congs',
  2260           [split_meta', split_object', split_ex'], derived_defs'],
  2266           [split_meta', split_object', split_ex'], derived_defs'],
  2261           [surjective', equality']),
  2267           [surjective', equality']),
  2262           [induct_scheme', induct', cases_scheme', cases']), thms_thy) =
  2268           [induct_scheme', induct', cases_scheme', cases']), thms_thy) =
  2263       defs_thy
  2269       defs_thy
  2264       |> (PureThy.add_thmss o map (Thm.no_attributes o apfst Binding.name))
  2270       |> (PureThy.add_thmss o map (Thm.no_attributes o apfst Binding.name))
  2265          [("select_convs", sel_convs_standard),
  2271          [("select_convs", sel_convs_standard),
  2266           ("update_convs", upd_convs),
  2272           ("update_convs", upd_convs_standard),
  2267           ("select_defs", sel_defs),
  2273           ("select_defs", sel_defs),
  2268           ("update_defs", upd_defs),
  2274           ("update_defs", upd_defs),
       
  2275           ("fold_congs", fold_congs),
       
  2276           ("unfold_congs", unfold_congs),
  2269           ("splits", [split_meta_standard,split_object,split_ex]),
  2277           ("splits", [split_meta_standard,split_object,split_ex]),
  2270           ("defs", derived_defs)]
  2278           ("defs", derived_defs)]
  2271       ||>> (PureThy.add_thms o map (Thm.no_attributes o apfst Binding.name))
  2279       ||>> (PureThy.add_thms o map (Thm.no_attributes o apfst Binding.name))
  2272           [("surjective", surjective),
  2280           [("surjective", surjective),
  2273            ("equality", equality)]
  2281            ("equality", equality)]
  2287       |> (snd oo PureThy.add_thmss)
  2295       |> (snd oo PureThy.add_thmss)
  2288           [((Binding.name "simps", sel_upd_simps),
  2296           [((Binding.name "simps", sel_upd_simps),
  2289             [Simplifier.simp_add, Nitpick_Const_Simps.add]),
  2297             [Simplifier.simp_add, Nitpick_Const_Simps.add]),
  2290            ((Binding.name "iffs", iffs), [iff_add])]
  2298            ((Binding.name "iffs", iffs), [iff_add])]
  2291       |> put_record name (make_record_info args parent fields extension induct_scheme' ext_def)
  2299       |> put_record name (make_record_info args parent fields extension induct_scheme' ext_def)
  2292       |> put_sel_upd_names full_moreN depth sel_upd_simps
  2300       |> put_sel_upd names full_moreN depth sel_upd_simps
  2293                            sel_upd_defs (fold_congs', unfold_congs')
  2301                            sel_upd_defs (fold_congs', unfold_congs')
  2294       |> add_record_equalities extension_id equality'
  2302       |> add_record_equalities extension_id equality'
  2295       |> add_extinjects ext_inject
  2303       |> add_extinjects ext_inject
  2296       |> add_extsplit extension_name ext_split
  2304       |> add_extsplit extension_name ext_split
  2297       |> add_record_splits extension_id (split_meta',split_object',split_ex',induct_scheme')
  2305       |> add_record_splits extension_id (split_meta',split_object',split_ex',induct_scheme')