src/Pure/axclass.ML
changeset 36327 c0415cb24a10
parent 36326 85d026788fce
child 36328 4d9deabf6474
equal deleted inserted replaced
36326:85d026788fce 36327:c0415cb24a10
    11     (Thm.binding * term list) list -> theory -> class * theory
    11     (Thm.binding * term list) list -> theory -> class * theory
    12   val add_classrel: thm -> theory -> theory
    12   val add_classrel: thm -> theory -> theory
    13   val add_arity: thm -> theory -> theory
    13   val add_arity: thm -> theory -> theory
    14   val prove_classrel: class * class -> tactic -> theory -> theory
    14   val prove_classrel: class * class -> tactic -> theory -> theory
    15   val prove_arity: string * sort list * sort -> tactic -> theory -> theory
    15   val prove_arity: string * sort list * sort -> tactic -> theory -> theory
    16   val get_info: theory -> class ->
    16   type info = {def: thm, intro: thm, axioms: thm list, params: (string * typ) list}
    17     {def: thm, intro: thm, axioms: thm list, params: (string * typ) list}
    17   val get_info: theory -> class -> info
    18   val class_intros: theory -> thm list
    18   val class_intros: theory -> thm list
    19   val class_of_param: theory -> string -> class option
    19   val class_of_param: theory -> string -> class option
    20   val cert_classrel: theory -> class * class -> class * class
    20   val cert_classrel: theory -> class * class -> class * class
    21   val read_classrel: theory -> xstring * xstring -> class * class
    21   val read_classrel: theory -> xstring * xstring -> class * class
    22   val axiomatize_class: binding * class list -> theory -> theory
    22   val axiomatize_class: binding * class list -> theory -> theory
    58 fun merge_params _ ([], qs) = qs
    58 fun merge_params _ ([], qs) = qs
    59   | merge_params pp (ps, qs) =
    59   | merge_params pp (ps, qs) =
    60       fold_rev (fn q => if member (op =) ps q then I else add_param pp q) qs ps;
    60       fold_rev (fn q => if member (op =) ps q then I else add_param pp q) qs ps;
    61 
    61 
    62 
    62 
    63 (* axclasses *)
    63 (* axclass info *)
    64 
    64 
    65 datatype axclass = AxClass of
    65 type info =
    66  {def: thm,
    66  {def: thm,
    67   intro: thm,
    67   intro: thm,
    68   axioms: thm list,
    68   axioms: thm list,
    69   params: (string * typ) list};
    69   params: (string * typ) list};
    70 
    70 
    71 type axclasses = axclass Symtab.table * param list;
    71 type axclasses = info Symtab.table * param list;
    72 
    72 
    73 fun make_axclass ((def, intro, axioms), params) = AxClass
    73 fun make_axclass ((def, intro, axioms), params): info =
    74   {def = def, intro = intro, axioms = axioms, params = params};
    74   {def = def, intro = intro, axioms = axioms, params = params};
    75 
    75 
    76 fun merge_axclasses pp ((tab1, params1), (tab2, params2)) : axclasses =
    76 fun merge_axclasses pp ((tab1, params1), (tab2, params2)) : axclasses =
    77   (Symtab.merge (K true) (tab1, tab2), merge_params pp (params1, params2));
    77   (Symtab.merge (K true) (tab1, tab2), merge_params pp (params1, params2));
    78 
    78 
   104     Symtab.merge (K true) (param_const1, param_const2));
   104     Symtab.merge (K true) (param_const1, param_const2));
   105 
   105 
   106 
   106 
   107 (* setup data *)
   107 (* setup data *)
   108 
   108 
   109 structure AxClassData = Theory_Data_PP
   109 structure Data = Theory_Data_PP
   110 (
   110 (
   111   type T = axclasses * ((instances * inst_params) * (class * class) list);
   111   type T = axclasses * ((instances * inst_params) * (class * class) list);
   112   val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
   112   val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
   113   val extend = I;
   113   val extend = I;
   114   fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
   114   fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
   115     (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
   115     (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
   116     let
   116     let
   117       val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2)
   117       val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2);
   118       val diff_merge_classrels = subtract (op =) classrels1 classrels2
   118       val diff_merge_classrels =
   119         @ subtract (op =) classrels2 classrels1
   119         subtract (op =) classrels1 classrels2 @
   120         @ diff_merge_classrels1 @ diff_merge_classrels2
   120         subtract (op =) classrels2 classrels1 @
       
   121         diff_merge_classrels1 @ diff_merge_classrels2;
   121     in
   122     in
   122       (merge_axclasses pp (axclasses1, axclasses2),
   123       (merge_axclasses pp (axclasses1, axclasses2),
   123         ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
   124         ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
   124           diff_merge_classrels))
   125           diff_merge_classrels))
   125     end;
   126     end;
   126 );
   127 );
   127 
   128 
   128 
   129 
   129 (* maintain axclasses *)
   130 (* maintain axclasses *)
   130 
   131 
   131 val get_axclasses = #1 o AxClassData.get;
   132 val get_axclasses = #1 o Data.get;
   132 val map_axclasses = AxClassData.map o apfst;
   133 val map_axclasses = Data.map o apfst;
   133 
       
   134 val lookup_def = Symtab.lookup o #1 o get_axclasses;
       
   135 
   134 
   136 fun get_info thy c =
   135 fun get_info thy c =
   137   (case lookup_def thy c of
   136   (case Symtab.lookup (#1 (get_axclasses thy)) c of
   138     SOME (AxClass info) => info
   137     SOME info => info
   139   | NONE => error ("No such axclass: " ^ quote c));
   138   | NONE => error ("No such axclass: " ^ quote c));
   140 
   139 
   141 fun class_intros thy =
   140 fun class_intros thy =
   142   let
   141   let
   143     fun add_intro c =
   142     fun add_intro c = (case try (get_info thy) c of SOME {intro, ...} => cons intro | _ => I);
   144       (case lookup_def thy c of SOME (AxClass {intro, ...}) => cons intro | _ => I);
       
   145     val classes = Sign.all_classes thy;
   143     val classes = Sign.all_classes thy;
   146   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
   144   in map (Thm.class_triv thy) classes @ fold add_intro classes [] end;
   147 
   145 
   148 
   146 fun all_params_of thy S =
   149 fun get_params thy pred =
       
   150   let val params = #2 (get_axclasses thy);
   147   let val params = #2 (get_axclasses thy);
   151   in fold (fn (x, c) => if pred c then cons x else I) params [] end;
   148   in fold (fn (x, c) => if Sign.subsort thy (S, [c]) then cons x else I) params [] end;
   152 
       
   153 fun all_params_of thy S = get_params thy (fn c => Sign.subsort thy (S, [c]));
       
   154 
   149 
   155 fun class_of_param thy = AList.lookup (op =) (#2 (get_axclasses thy));
   150 fun class_of_param thy = AList.lookup (op =) (#2 (get_axclasses thy));
   156 
   151 
   157 
   152 
   158 (* maintain instances *)
   153 (* maintain instances *)
   159 
   154 
   160 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
   155 fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
   161 
   156 
   162 val get_instances = #1 o #1 o #2 o AxClassData.get;
   157 val get_instances = #1 o #1 o #2 o Data.get;
   163 val map_instances = AxClassData.map o apsnd o apfst o apfst;
   158 val map_instances = Data.map o apsnd o apfst o apfst;
   164 
   159 
   165 val get_diff_merge_classrels = #2 o #2 o AxClassData.get;
   160 val get_diff_merge_classrels = #2 o #2 o Data.get;
   166 val clear_diff_merge_classrels = AxClassData.map (apsnd (apsnd (K [])));
   161 val clear_diff_merge_classrels = Data.map (apsnd (apsnd (K [])));
   167 
   162 
   168 
   163 
   169 fun the_classrel thy (c1, c2) =
   164 fun the_classrel thy (c1, c2) =
   170   (case Symreltab.lookup (#1 (get_instances thy)) (c1, c2) of
   165   (case Symreltab.lookup (#1 (get_instances thy)) (c1, c2) of
   171     SOME classrel => classrel
   166     SOME classrel => classrel
   175 fun the_classrel_thm thy = Thm.transfer thy o fst o the_classrel thy;
   170 fun the_classrel_thm thy = Thm.transfer thy o fst o the_classrel thy;
   176 fun the_classrel_prf thy = snd o the_classrel thy;
   171 fun the_classrel_prf thy = snd o the_classrel thy;
   177 
   172 
   178 fun put_trancl_classrel ((c1, c2), th) thy =
   173 fun put_trancl_classrel ((c1, c2), th) thy =
   179   let
   174   let
   180     val classrels = fst (get_instances thy)
   175     val cert = Thm.cterm_of thy;
   181     val alg = Sign.classes_of thy
   176     val certT = Thm.ctyp_of thy;
   182     val {classes, ...} = alg |> Sorts.rep_algebra
   177 
       
   178     val classrels = fst (get_instances thy);
       
   179     val classes = #classes (Sorts.rep_algebra (Sign.classes_of thy));
   183 
   180 
   184     fun reflcl_classrel (c1', c2') =
   181     fun reflcl_classrel (c1', c2') =
   185       if c1' = c2' then Thm.trivial (Logic.mk_of_class (TVar(("'a",0),[]), c1') |> cterm_of thy)
   182       if c1' = c2'
   186       else the_classrel_thm thy (c1', c2')
   183       then Thm.trivial (cert (Logic.mk_of_class (TVar ((Name.aT, 0), []), c1')))
       
   184       else the_classrel_thm thy (c1', c2');
   187     fun gen_classrel (c1_pred, c2_succ) =
   185     fun gen_classrel (c1_pred, c2_succ) =
   188       let
   186       let
   189         val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
   187         val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
   190           |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
   188           |> Drule.instantiate' [SOME (certT (TVar ((Name.aT, 0), [])))] []
   191           |> Thm.close_derivation
   189           |> Thm.close_derivation;
   192         val prf' = th' |> Thm.proof_of
   190         val prf' = th' |> Thm.proof_of;
   193       in ((c1_pred, c2_succ), (th',prf')) end
   191       in ((c1_pred, c2_succ), (th', prf')) end;
   194 
   192 
   195     val new_classrels = Library.map_product pair
   193     val new_classrels =
   196         (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
   194       Library.map_product pair (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
   197       |> filter_out (Symreltab.defined classrels)
   195       |> filter_out (Symreltab.defined classrels)
   198       |> map gen_classrel
   196       |> map gen_classrel;
   199     val needed = length new_classrels > 0
   197     val needed = not (null new_classrels);
   200   in
   198   in
   201     (needed,
   199     (needed,
   202      if needed then
   200      if needed then
   203        thy |> map_instances (fn (classrels, arities) =>
   201        thy |> map_instances (fn (classrels, arities) =>
   204          (classrels |> fold Symreltab.update new_classrels, arities))
   202          (classrels |> fold Symreltab.update new_classrels, arities))
   205      else thy)
   203      else thy)
   206   end;
   204   end;
   207 
   205 
   208 fun complete_classrels thy =
   206 fun complete_classrels thy =
   209   let
   207   let
   210     val diff_merge_classrels = get_diff_merge_classrels thy
   208     val diff_merge_classrels = get_diff_merge_classrels thy;
   211     val classrels = fst (get_instances thy)
   209     val classrels = fst (get_instances thy);
   212     val (needed, thy') = (false, thy) |>
   210     val (needed, thy') = (false, thy) |>
   213       fold (fn c12 => fn (needed, thy) =>
   211       fold (fn c12 => fn (needed, thy) =>
   214           put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
   212           put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
   215           |>> (fn b => needed orelse b))
   213           |>> (fn b => needed orelse b))
   216         diff_merge_classrels
   214         diff_merge_classrels;
   217   in
   215   in
   218     if null diff_merge_classrels then NONE
   216     if null diff_merge_classrels then NONE
   219     else thy' |> clear_diff_merge_classrels |> SOME
   217     else thy' |> clear_diff_merge_classrels |> SOME
   220   end;
   218   end;
   221 
   219 
   244     val names_and_Ss = Name.names Name.context Name.aT (map (K []) Ss);
   242     val names_and_Ss = Name.names Name.context Name.aT (map (K []) Ss);
   245     val completions = super_class_completions |> map (fn c1 =>
   243     val completions = super_class_completions |> map (fn c1 =>
   246       let
   244       let
   247         val th1 = (th RS the_classrel_thm thy (c, c1))
   245         val th1 = (th RS the_classrel_thm thy (c, c1))
   248           |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
   246           |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
   249           |> Thm.close_derivation
   247           |> Thm.close_derivation;
   250         val prf1 = Thm.proof_of th1
   248         val prf1 = Thm.proof_of th1;
   251       in (((th1,thy_name), prf1), c1) end)
   249       in (((th1, thy_name), prf1), c1) end);
   252     val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
   250     val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
   253       completions arities;
   251       completions arities;
   254   in (null completions, arities') end;
   252   in (null completions, arities') end;
   255 
   253 
   256 fun put_arity ((t, Ss, c), th) thy =
   254 fun put_arity ((t, Ss, c), th) thy =
   279   (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities))
   277   (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities))
   280 
   278 
   281 
   279 
   282 (* maintain instance parameters *)
   280 (* maintain instance parameters *)
   283 
   281 
   284 val get_inst_params = #2 o #1 o #2 o AxClassData.get;
   282 val get_inst_params = #2 o #1 o #2 o Data.get;
   285 val map_inst_params = AxClassData.map o apsnd o apfst o apsnd;
   283 val map_inst_params = Data.map o apsnd o apfst o apsnd;
   286 
   284 
   287 fun get_inst_param thy (c, tyco) =
   285 fun get_inst_param thy (c, tyco) =
   288   case Symtab.lookup ((the_default Symtab.empty o Symtab.lookup (fst (get_inst_params thy))) c) tyco
   286   (case Symtab.lookup (the_default Symtab.empty (Symtab.lookup (#1 (get_inst_params thy)) c)) tyco of
   289    of SOME c' => c'
   287     SOME c' => c'
   290     | NONE => error ("No instance parameter for constant " ^ quote c
   288   | NONE => error ("No instance parameter for constant " ^ quote c ^ " on type " ^ quote tyco));
   291         ^ " on type constructor " ^ quote tyco);
   289 
   292 
   290 fun add_inst_param (c, tyco) inst =
   293 fun add_inst_param (c, tyco) inst = (map_inst_params o apfst
   291   (map_inst_params o apfst o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
   294       o Symtab.map_default (c, Symtab.empty)) (Symtab.update_new (tyco, inst))
       
   295   #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
   292   #> (map_inst_params o apsnd) (Symtab.update_new (fst inst, (c, tyco)));
   296 
   293 
   297 val inst_of_param = Symtab.lookup o snd o get_inst_params;
   294 val inst_of_param = Symtab.lookup o snd o get_inst_params;
   298 val param_of_inst = fst oo get_inst_param;
   295 val param_of_inst = fst oo get_inst_param;
   299 
   296 
   300 fun inst_thms thy = (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst)
   297 fun inst_thms thy =
   301   (get_inst_params thy) [];
   298   (Symtab.fold (Symtab.fold (cons o snd o snd) o snd) o fst) (get_inst_params thy) [];
   302 
   299 
   303 fun get_inst_tyco consts = try (fst o dest_Type o the_single o Consts.typargs consts);
   300 fun get_inst_tyco consts = try (fst o dest_Type o the_single o Consts.typargs consts);
   304 
   301 
   305 fun unoverload thy = MetaSimplifier.simplify true (inst_thms thy);
   302 fun unoverload thy = MetaSimplifier.simplify true (inst_thms thy);
   306 fun overload thy = MetaSimplifier.simplify true (map Thm.symmetric (inst_thms thy));
   303 fun overload thy = MetaSimplifier.simplify true (map Thm.symmetric (inst_thms thy));
   307 
   304 
   308 fun unoverload_conv thy = MetaSimplifier.rewrite true (inst_thms thy);
   305 fun unoverload_conv thy = MetaSimplifier.rewrite true (inst_thms thy);
   309 fun overload_conv thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
   306 fun overload_conv thy = MetaSimplifier.rewrite true (map Thm.symmetric (inst_thms thy));
   310 
   307 
   311 fun lookup_inst_param consts params (c, T) = case get_inst_tyco consts (c, T)
   308 fun lookup_inst_param consts params (c, T) =
   312  of SOME tyco => AList.lookup (op =) params (c, tyco)
   309   (case get_inst_tyco consts (c, T) of
   313   | NONE => NONE;
   310     SOME tyco => AList.lookup (op =) params (c, tyco)
       
   311   | NONE => NONE);
   314 
   312 
   315 fun unoverload_const thy (c_ty as (c, _)) =
   313 fun unoverload_const thy (c_ty as (c, _)) =
   316   if is_some (class_of_param thy c)
   314   if is_some (class_of_param thy c) then
   317   then case get_inst_tyco (Sign.consts_of thy) c_ty
   315     (case get_inst_tyco (Sign.consts_of thy) c_ty of
   318    of SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
   316       SOME tyco => try (param_of_inst thy) (c, tyco) |> the_default c
   319     | NONE => c
   317     | NONE => c)
   320   else c;
   318   else c;
       
   319 
   321 
   320 
   322 
   321 
   323 (** instances **)
   322 (** instances **)
   324 
   323 
   325 (* class relations *)
   324 (* class relations *)
   338 
   337 
   339 fun read_classrel thy raw_rel =
   338 fun read_classrel thy raw_rel =
   340   cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
   339   cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
   341     handle TYPE (msg, _, _) => error msg;
   340     handle TYPE (msg, _, _) => error msg;
   342 
   341 
   343 fun check_shyps_topped th errmsg =
   342 val shyps_topped = forall null o #shyps o Thm.rep_thm;
   344   let val {shyps, ...} = Thm.rep_thm th
   343 
   345   in
       
   346     forall null shyps orelse raise Fail errmsg
       
   347   end;
       
   348 
   344 
   349 (* declaration and definition of instances of overloaded constants *)
   345 (* declaration and definition of instances of overloaded constants *)
   350 
   346 
   351 fun inst_tyco_of thy (c, T) =
   347 fun inst_tyco_of thy (c, T) =
   352   (case get_inst_tyco (Sign.consts_of thy) (c, T) of
   348   (case get_inst_tyco (Sign.consts_of thy) (c, T) of
   404     val rel = Logic.dest_classrel prop handle TERM _ => err ();
   400     val rel = Logic.dest_classrel prop handle TERM _ => err ();
   405     val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
   401     val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
   406     val th' = th
   402     val th' = th
   407       |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [c1])))] []
   403       |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [c1])))] []
   408       |> Drule.unconstrainTs;
   404       |> Drule.unconstrainTs;
   409     val _ = check_shyps_topped th' "add_classrel: nontop shyps after unconstrain"
   405     val _ = shyps_topped th' orelse raise Fail "add_classrel: nontop shyps after unconstrain";
   410   in
   406   in
   411     thy
   407     thy
   412     |> Sign.primitive_classrel (c1, c2)
   408     |> Sign.primitive_classrel (c1, c2)
   413     |> (snd oo put_trancl_classrel) ((c1, c2), th')
   409     |> (snd oo put_trancl_classrel) ((c1, c2), th')
   414     |> perhaps complete_arities
   410     |> perhaps complete_arities
   428       |> (map o apsnd o map_atyps) (K T);
   424       |> (map o apsnd o map_atyps) (K T);
   429     val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
   425     val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
   430     val th' = th
   426     val th' = th
   431       |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
   427       |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
   432       |> Drule.unconstrainTs;
   428       |> Drule.unconstrainTs;
   433     val _ = check_shyps_topped th' "add_arity: nontop shyps after unconstrain"
   429     val _ = shyps_topped th' orelse raise Fail "add_arity: nontop shyps after unconstrain";
   434   in
   430   in
   435     thy
   431     thy
   436     |> fold (snd oo declare_overloaded) missing_params
   432     |> fold (snd oo declare_overloaded) missing_params
   437     |> Sign.primitive_arity (t, Ss, [c])
   433     |> Sign.primitive_arity (t, Ss, [c])
   438     |> put_arity ((t, Ss, c), th')
   434     |> put_arity ((t, Ss, c), th')