src/Pure/axclass.ML
changeset 36325 8715343af626
parent 36106 19deea200358
child 36326 85d026788fce
     1.1 --- a/src/Pure/axclass.ML	Sun Apr 25 16:10:05 2010 +0200
     1.2 +++ b/src/Pure/axclass.ML	Sun Apr 25 19:09:37 2010 +0200
     1.3 @@ -89,11 +89,12 @@
     1.4  val arity_prefix = "arity_";
     1.5  
     1.6  type instances =
     1.7 -  ((class * class) * thm) list *  (*classrel theorems*)
     1.8 -  ((class * sort list) * (thm * string)) list Symtab.table;  (*arity theorems with theory name*)
     1.9 +  (thm * proof) Symreltab.table *  (*classrel theorems*)
    1.10 +  ((class * sort list) * ((thm * string) * proof)) list Symtab.table;  (*arity theorems with theory name*)
    1.11  
    1.12 +(*transitive closure of classrels and arity completion is done in Theory.at_begin hook*)
    1.13  fun merge_instances ((classrel1, arities1): instances, (classrel2, arities2)) =
    1.14 - (merge (eq_fst op =) (classrel1, classrel2),
    1.15 + (Symreltab.join (K fst) (classrel1, classrel2),
    1.16    Symtab.join (K (merge (eq_fst op =))) (arities1, arities2));
    1.17  
    1.18  
    1.19 @@ -113,12 +114,21 @@
    1.20  
    1.21  structure AxClassData = Theory_Data_PP
    1.22  (
    1.23 -  type T = axclasses * (instances * inst_params);
    1.24 -  val empty = ((Symtab.empty, []), (([], Symtab.empty), (Symtab.empty, Symtab.empty)));
    1.25 +  type T = axclasses * ((instances * inst_params) * (class * class) list);
    1.26 +  val empty = ((Symtab.empty, []), (((Symreltab.empty, Symtab.empty), (Symtab.empty, Symtab.empty)), []));
    1.27    val extend = I;
    1.28 -  fun merge pp ((axclasses1, (instances1, inst_params1)), (axclasses2, (instances2, inst_params2))) =
    1.29 -    (merge_axclasses pp (axclasses1, axclasses2),
    1.30 -      (merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)));
    1.31 +  fun merge pp ((axclasses1, ((instances1, inst_params1), diff_merge_classrels1)),
    1.32 +    (axclasses2, ((instances2, inst_params2), diff_merge_classrels2))) =
    1.33 +    let
    1.34 +      val (classrels1, classrels2) = pairself (Symreltab.keys o fst) (instances1, instances2)
    1.35 +      val diff_merge_classrels = subtract (op =) classrels1 classrels2
    1.36 +        @ subtract (op =) classrels2 classrels1
    1.37 +        @ diff_merge_classrels1 @ diff_merge_classrels2
    1.38 +    in
    1.39 +      (merge_axclasses pp (axclasses1, axclasses2),
    1.40 +        ((merge_instances (instances1, instances2), merge_inst_params (inst_params1, inst_params2)),
    1.41 +          diff_merge_classrels))
    1.42 +    end;
    1.43  );
    1.44  
    1.45  
    1.46 @@ -155,48 +165,103 @@
    1.47  
    1.48  fun instance_name (a, c) = Long_Name.base_name c ^ "_" ^ Long_Name.base_name a;
    1.49  
    1.50 -val get_instances = #1 o #2 o AxClassData.get;
    1.51 -val map_instances = AxClassData.map o apsnd o apfst;
    1.52 +val get_instances = #1 o #1 o #2 o AxClassData.get;
    1.53 +val map_instances = AxClassData.map o apsnd o apfst o apfst;
    1.54 +
    1.55 +val get_diff_merge_classrels = #2 o #2 o AxClassData.get;
    1.56 +val clear_diff_merge_classrels = AxClassData.map (apsnd (apsnd (K [])));
    1.57  
    1.58  
    1.59  fun the_classrel thy (c1, c2) =
    1.60 -  (case AList.lookup (op =) (#1 (get_instances thy)) (c1, c2) of
    1.61 -    SOME th => Thm.transfer thy th
    1.62 +  (case Symreltab.lookup (#1 (get_instances thy)) (c1, c2) of
    1.63 +    SOME classrel => classrel
    1.64    | NONE => error ("Unproven class relation " ^
    1.65        Syntax.string_of_classrel (ProofContext.init thy) [c1, c2]));
    1.66  
    1.67 -fun put_classrel arg = map_instances (fn (classrel, arities) =>
    1.68 -  (insert (eq_fst op =) arg classrel, arities));
    1.69 +fun the_classrel_thm thy = Thm.transfer thy o fst o the_classrel thy;
    1.70 +fun the_classrel_prf thy = snd o the_classrel thy;
    1.71 +
    1.72 +fun put_trancl_classrel ((c1, c2), th) thy =
    1.73 +  let
    1.74 +    val classrels = fst (get_instances thy)
    1.75 +    val alg = Sign.classes_of thy
    1.76 +    val {classes, ...} = alg |> Sorts.rep_algebra
    1.77 +
    1.78 +    fun reflcl_classrel (c1', c2') =
    1.79 +      if c1' = c2' then Thm.trivial (Logic.mk_of_class (TVar(("'a",0),[]), c1') |> cterm_of thy)
    1.80 +      else the_classrel_thm thy (c1', c2')
    1.81 +    fun gen_classrel (c1_pred, c2_succ) =
    1.82 +      let
    1.83 +        val th' = ((reflcl_classrel (c1_pred, c1) RS th) RS reflcl_classrel (c2, c2_succ))
    1.84 +          |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [])))] []
    1.85 +          |> Thm.close_derivation
    1.86 +        val prf' = th' |> Thm.proof_of
    1.87 +      in ((c1_pred, c2_succ), (th',prf')) end
    1.88 +
    1.89 +    val new_classrels = Library.map_product pair
    1.90 +        (c1 :: Graph.imm_preds classes c1) (c2 :: Graph.imm_succs classes c2)
    1.91 +      |> filter_out (Symreltab.defined classrels)
    1.92 +      |> map gen_classrel
    1.93 +    val needed = length new_classrels > 0
    1.94 +  in
    1.95 +    (needed,
    1.96 +     if needed then
    1.97 +       thy |> map_instances (fn (classrels, arities) =>
    1.98 +         (classrels |> fold Symreltab.update new_classrels, arities))
    1.99 +     else thy)
   1.100 +  end;
   1.101 +
   1.102 +fun complete_classrels thy =
   1.103 +  let
   1.104 +    val diff_merge_classrels = get_diff_merge_classrels thy
   1.105 +    val classrels = fst (get_instances thy)
   1.106 +    val (needed, thy') = (false, thy) |>
   1.107 +      fold (fn c12 => fn (needed, thy) =>
   1.108 +          put_trancl_classrel (c12, Symreltab.lookup classrels c12 |> the |> fst) thy
   1.109 +          |>> (fn b => needed orelse b))
   1.110 +        diff_merge_classrels
   1.111 +  in
   1.112 +    if null diff_merge_classrels then NONE
   1.113 +    else thy' |> clear_diff_merge_classrels |> SOME
   1.114 +  end;
   1.115  
   1.116  
   1.117  fun the_arity thy a (c, Ss) =
   1.118    (case AList.lookup (op =) (Symtab.lookup_list (#2 (get_instances thy)) a) (c, Ss) of
   1.119 -    SOME (th, _) => Thm.transfer thy th
   1.120 +    SOME arity => arity
   1.121    | NONE => error ("Unproven type arity " ^
   1.122        Syntax.string_of_arity (ProofContext.init thy) (a, Ss, [c])));
   1.123  
   1.124 +fun the_arity_thm thy a c_Ss = the_arity thy a c_Ss |> fst |> fst |> Thm.transfer thy;
   1.125 +fun the_arity_prf thy a c_Ss = the_arity thy a c_Ss |> snd;
   1.126 +
   1.127  fun thynames_of_arity thy (c, a) =
   1.128    Symtab.lookup_list (#2 (get_instances thy)) a
   1.129 -  |> map_filter (fn ((c', _), (_, name)) => if c = c' then SOME name else NONE)
   1.130 +  |> map_filter (fn ((c', _), ((_, name),_)) => if c = c' then SOME name else NONE)
   1.131    |> rev;
   1.132  
   1.133 -fun insert_arity_completions thy (t, ((c, Ss), (th, thy_name))) arities =
   1.134 +fun insert_arity_completions thy (t, ((c, Ss), ((th, thy_name), _))) arities =
   1.135    let
   1.136      val algebra = Sign.classes_of thy;
   1.137      val super_class_completions =
   1.138        Sign.super_classes thy c
   1.139        |> filter_out (fn c1 => exists (fn ((c2, Ss2), _) => c1 = c2
   1.140            andalso Sorts.sorts_le algebra (Ss2, Ss)) (Symtab.lookup_list arities t));
   1.141 -    val completions = map (fn c1 => (Sorts.classrel_derivation algebra
   1.142 -      (fn (th, c2) => fn c3 => th RS the_classrel thy (c2, c3)) (th, c) c1
   1.143 -        |> Thm.close_derivation, c1)) super_class_completions;
   1.144 -    val arities' = fold (fn (th1, c1) => Symtab.cons_list (t, ((c1, Ss), (th1, thy_name))))
   1.145 +    val names_and_Ss = Name.names Name.context Name.aT (map (K []) Ss);
   1.146 +    val completions = super_class_completions |> map (fn c1 =>
   1.147 +      let
   1.148 +        val th1 = (th RS the_classrel_thm thy (c, c1))
   1.149 +          |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names_and_Ss) []
   1.150 +          |> Thm.close_derivation
   1.151 +        val prf1 = Thm.proof_of th1
   1.152 +      in (((th1,thy_name), prf1), c1) end)
   1.153 +    val arities' = fold (fn (th_thy_prf1, c1) => Symtab.cons_list (t, ((c1, Ss), th_thy_prf1)))
   1.154        completions arities;
   1.155    in (null completions, arities') end;
   1.156  
   1.157  fun put_arity ((t, Ss, c), th) thy =
   1.158    let
   1.159 -    val arity' = (t, ((c, Ss), (th, Context.theory_name thy)));
   1.160 +    val arity' = (t, ((c, Ss), ((th, Context.theory_name thy), Thm.proof_of th)));
   1.161    in
   1.162      thy
   1.163      |> map_instances (fn (classrel, arities) => (classrel,
   1.164 @@ -216,13 +281,14 @@
   1.165      else SOME (thy |> map_instances (fn (classrel, _) => (classrel, arities')))
   1.166    end;
   1.167  
   1.168 -val _ = Context.>> (Context.map_theory (Theory.at_begin complete_arities));
   1.169 +val _ = Context.>> (Context.map_theory
   1.170 +  (Theory.at_begin complete_classrels #> Theory.at_begin complete_arities))
   1.171  
   1.172  
   1.173  (* maintain instance parameters *)
   1.174  
   1.175 -val get_inst_params = #2 o #2 o AxClassData.get;
   1.176 -val map_inst_params = AxClassData.map o apsnd o apsnd;
   1.177 +val get_inst_params = #2 o #1 o #2 o AxClassData.get;
   1.178 +val map_inst_params = AxClassData.map o apsnd o apfst o apsnd;
   1.179  
   1.180  fun get_inst_param thy (c, tyco) =
   1.181    case Symtab.lookup ((the_default Symtab.empty o Symtab.lookup (fst (get_inst_params thy))) c) tyco
   1.182 @@ -280,6 +346,11 @@
   1.183    cert_classrel thy (pairself (ProofContext.read_class (ProofContext.init thy)) raw_rel)
   1.184      handle TYPE (msg, _, _) => error msg;
   1.185  
   1.186 +fun check_shyps_topped th errmsg =
   1.187 +  let val {shyps, ...} = Thm.rep_thm th
   1.188 +  in
   1.189 +    forall null shyps orelse raise Fail errmsg
   1.190 +  end;
   1.191  
   1.192  (* declaration and definition of instances of overloaded constants *)
   1.193  
   1.194 @@ -338,10 +409,14 @@
   1.195      fun err () = raise THM ("add_classrel: malformed class relation", 0, [th]);
   1.196      val rel = Logic.dest_classrel prop handle TERM _ => err ();
   1.197      val (c1, c2) = cert_classrel thy rel handle TYPE _ => err ();
   1.198 +    val th' = th
   1.199 +      |> Drule.instantiate' [SOME (ctyp_of thy (TVar ((Name.aT, 0), [c1])))] []
   1.200 +      |> Drule.unconstrainTs;
   1.201 +    val _ = check_shyps_topped th' "add_classrel: nontop shyps after unconstrain"
   1.202    in
   1.203      thy
   1.204      |> Sign.primitive_classrel (c1, c2)
   1.205 -    |> put_classrel ((c1, c2), Thm.close_derivation (Drule.unconstrainTs th))
   1.206 +    |> (snd oo put_trancl_classrel) ((c1, c2), th')
   1.207      |> perhaps complete_arities
   1.208    end;
   1.209  
   1.210 @@ -351,17 +426,22 @@
   1.211      val prop = Thm.plain_prop_of th;
   1.212      fun err () = raise THM ("add_arity: malformed type arity", 0, [th]);
   1.213      val (t, Ss, c) = Logic.dest_arity prop handle TERM _ => err ();
   1.214 -    val T = Type (t, map TFree (Name.names Name.context Name.aT Ss));
   1.215 +    val names = Name.names Name.context Name.aT Ss;
   1.216 +    val T = Type (t, map TFree names);
   1.217      val missing_params = Sign.complete_sort thy [c]
   1.218        |> maps (these o Option.map #params o try (get_info thy))
   1.219        |> filter_out (fn (const, _) => can (get_inst_param thy) (const, t))
   1.220        |> (map o apsnd o map_atyps) (K T);
   1.221      val _ = map (Sign.certify_sort thy) Ss = Ss orelse err ();
   1.222 +    val th' = th
   1.223 +      |> Drule.instantiate' (map (SOME o ctyp_of thy o TVar o apfst (rpair 0)) names) []
   1.224 +      |> Drule.unconstrainTs;
   1.225 +    val _ = check_shyps_topped th' "add_arity: nontop shyps after unconstrain"
   1.226    in
   1.227      thy
   1.228      |> fold (snd oo declare_overloaded) missing_params
   1.229      |> Sign.primitive_arity (t, Ss, [c])
   1.230 -    |> put_arity ((t, Ss, c), Thm.close_derivation (Drule.unconstrainTs th))
   1.231 +    |> put_arity ((t, Ss, c), th')
   1.232    end;
   1.233  
   1.234  
   1.235 @@ -495,7 +575,7 @@
   1.236      val axclass = make_axclass ((def, intro, axioms), params);
   1.237      val result_thy =
   1.238        facts_thy
   1.239 -      |> fold put_classrel (map (pair class) super ~~ classrel)
   1.240 +      |> fold (snd oo put_trancl_classrel) (map (pair class) super ~~ classrel)
   1.241        |> Sign.qualified_path false bconst
   1.242        |> PureThy.note_thmss "" (name_atts ~~ map Thm.simple_fact (unflat axiomss axioms)) |> snd
   1.243        |> Sign.restore_naming facts_thy