src/Pure/sorts.ML
changeset 19578 f93b7637a5e6
parent 19531 89970e06351f
child 19584 606d6a73e6d9
equal deleted inserted replaced
19577:fdb3642feb49 19578:f93b7637a5e6
    38   val rebuild_arities: Pretty.pp -> classes -> arities -> arities
    38   val rebuild_arities: Pretty.pp -> classes -> arities -> arities
    39   val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
    39   val merge_arities: Pretty.pp -> classes -> arities * arities -> arities
    40   val add_class: Pretty.pp -> class * class list -> classes -> classes
    40   val add_class: Pretty.pp -> class * class list -> classes -> classes
    41   val add_classrel: Pretty.pp -> class * class -> classes -> classes
    41   val add_classrel: Pretty.pp -> class * class -> classes -> classes
    42   val merge_classes: Pretty.pp -> classes * classes -> classes
    42   val merge_classes: Pretty.pp -> classes * classes -> classes
    43   exception DOMAIN of string * class
    43   type class_error
    44   val domain_error: Pretty.pp -> string * class -> 'a
    44   val class_error: Pretty.pp -> class_error -> 'a
    45   val mg_domain: classes * arities -> string -> sort -> sort list  (*exception DOMAIN*)
    45   exception CLASS_ERROR of class_error
       
    46   val mg_domain: classes * arities -> string -> sort -> sort list   (*exception CLASS_ERROR*)
    46   val of_sort: classes * arities -> typ * sort -> bool
    47   val of_sort: classes * arities -> typ * sort -> bool
    47   val of_sort_derivation: Pretty.pp -> classes * arities ->
    48   val of_sort_derivation: Pretty.pp -> classes * arities ->
    48     {classrel: 'a * class -> class -> 'a,
    49     {classrel: 'a * class -> class -> 'a,
    49      constructor: string -> ('a * class) list list -> class -> 'a,
    50      constructor: string -> ('a * class) list list -> class -> 'a,
    50      variable: typ -> ('a * class) list} -> typ * sort -> 'a list
    51      variable: typ -> ('a * class) list} -> typ * sort -> 'a list   (*exception CLASS_ERROR*)
    51   val witness_sorts: classes * arities -> string list ->
    52   val witness_sorts: classes * arities -> string list ->
    52     sort list -> sort list -> (typ * sort) list
    53     sort list -> sort list -> (typ * sort) list
    53 end;
    54 end;
    54 
    55 
    55 structure Sorts: SORTS =
    56 structure Sorts: SORTS =
   243 
   244 
   244 
   245 
   245 
   246 
   246 (** sorts of types **)
   247 (** sorts of types **)
   247 
   248 
       
   249 (* errors *)
       
   250 
       
   251 datatype class_error = NoClassrel of class * class | NoArity of string * class;
       
   252 
       
   253 fun class_error pp (NoClassrel (c1, c2)) =
       
   254       error ("No class relation " ^ Pretty.string_of_classrel pp [c1, c2])
       
   255   | class_error pp (NoArity (a, c)) =
       
   256       error ("No type arity " ^ Pretty.string_of_arity pp (a, [], [c]));
       
   257 
       
   258 exception CLASS_ERROR of class_error;
       
   259 
       
   260 
   248 (* mg_domain *)
   261 (* mg_domain *)
   249 
       
   250 exception DOMAIN of string * class;
       
   251 
       
   252 fun domain_error pp (a, c) =
       
   253   error ("No way to get " ^ Pretty.string_of_arity pp (a, [], [c]));
       
   254 
   262 
   255 fun mg_domain (classes, arities) a S =
   263 fun mg_domain (classes, arities) a S =
   256   let
   264   let
   257     fun dom c =
   265     fun dom c =
   258       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   266       (case AList.lookup (op =) (Symtab.lookup_list arities a) c of
   259         NONE => raise DOMAIN (a, c)
   267         NONE => raise CLASS_ERROR (NoArity (a, c))
   260       | SOME (_, Ss) => Ss);
   268       | SOME (_, Ss) => Ss);
   261     fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
   269     fun dom_inter c Ss = ListPair.map (inter_sort classes) (dom c, Ss);
   262   in
   270   in
   263     (case S of
   271     (case S of
   264       [] => raise Fail "Unknown domain of empty intersection"
   272       [] => raise Fail "Unknown domain of empty intersection"
   274       | ofS (TFree (_, S), S') = sort_le classes (S, S')
   282       | ofS (TFree (_, S), S') = sort_le classes (S, S')
   275       | ofS (TVar (_, S), S') = sort_le classes (S, S')
   283       | ofS (TVar (_, S), S') = sort_le classes (S, S')
   276       | ofS (Type (a, Ts), S) =
   284       | ofS (Type (a, Ts), S) =
   277           let val Ss = mg_domain (classes, arities) a S in
   285           let val Ss = mg_domain (classes, arities) a S in
   278             ListPair.all ofS (Ts, Ss)
   286             ListPair.all ofS (Ts, Ss)
   279           end handle DOMAIN _ => false;
   287           end handle CLASS_ERROR _ => false;
   280   in ofS end;
   288   in ofS end;
   281 
   289 
   282 
   290 
   283 (* of_sort_derivation *)
   291 (* of_sort_derivation *)
   284 
   292 
   285 fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
   293 fun of_sort_derivation pp (classes, arities) {classrel, constructor, variable} =
   286   let
   294   let
   287     fun weaken (x, c1) c2 = if c1 = c2 then x else classrel (x, c1) c2;
   295     fun weaken_path (x, c1 :: c2 :: cs) = weaken_path (classrel (x, c1) c2, c2 :: cs)
       
   296       | weaken_path (x, _) = x;
       
   297     fun weaken (x, c1) c2 =
       
   298       (case Graph.irreducible_paths classes (c1, c2) of
       
   299         [] => raise CLASS_ERROR (NoClassrel (c1, c2))
       
   300       | cs :: _ => weaken_path (x, cs));
       
   301 
   288     fun weakens S1 S2 = S2 |> map (fn c2 =>
   302     fun weakens S1 S2 = S2 |> map (fn c2 =>
   289       (case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
   303       (case S1 |> find_first (fn (_, c1) => class_le classes (c1, c2)) of
   290         SOME d1 => weaken d1 c2
   304         SOME d1 => weaken d1 c2
   291       | NONE => error ("Cannot derive subsort relation " ^
   305       | NONE => error ("Cannot derive subsort relation " ^
   292           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   306           Pretty.string_of_sort pp (map #2 S1) ^ " < " ^ Pretty.string_of_sort pp S2)));
   293 
   307 
   294     fun derive _ [] = []
   308     fun derive _ [] = []
   295       | derive (Type (a, Ts)) S =
   309       | derive (Type (a, Ts)) S =
   296           let
   310           let
   297             val Ss = mg_domain (classes, arities) a S
   311             val Ss = mg_domain (classes, arities) a S;
   298               handle DOMAIN d => domain_error pp d;
       
   299             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   312             val dom = map2 (fn T => fn S => derive T S ~~ S) Ts Ss;
   300           in
   313           in
   301             S |> map (fn c =>
   314             S |> map (fn c =>
   302               let
   315               let
   303                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   316                 val (c0, Ss') = the (AList.lookup (op =) (Symtab.lookup_list arities a) c);
   308   in uncurry derive end;
   321   in uncurry derive end;
   309 
   322 
   310 
   323 
   311 (* witness_sorts *)
   324 (* witness_sorts *)
   312 
   325 
   313 local
   326 fun witness_sorts (classes, arities) log_types hyps sorts =
   314 
   327   let
   315 fun witness_aux (classes, arities) log_types hyps sorts =
       
   316   let
       
   317     val top_witn = (propT, []);
       
   318     fun le S1 S2 = sort_le classes (S1, S2);
   328     fun le S1 S2 = sort_le classes (S1, S2);
   319     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   329     fun get_solved S2 (T, S1) = if le S1 S2 then SOME (T, S2) else NONE;
   320     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   330     fun get_hyp S2 S1 = if le S1 S2 then SOME (TFree ("'hyp", S1), S2) else NONE;
   321     fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle DOMAIN _ => NONE;
   331     fun mg_dom t S = SOME (mg_domain (classes, arities) t S) handle CLASS_ERROR _ => NONE;
   322 
   332 
   323     fun witn_sort _ (solved_failed, []) = (solved_failed, SOME top_witn)
   333     fun witn_sort _ [] solved_failed = (SOME (propT, []), solved_failed)
   324       | witn_sort path ((solved, failed), S) =
   334       | witn_sort path S (solved, failed) =
   325           if exists (le S) failed then ((solved, failed), NONE)
   335           if exists (le S) failed then (NONE, (solved, failed))
   326           else
   336           else
   327             (case get_first (get_solved S) solved of
   337             (case get_first (get_solved S) solved of
   328               SOME w => ((solved, failed), SOME w)
   338               SOME w => (SOME w, (solved, failed))
   329             | NONE =>
   339             | NONE =>
   330                 (case get_first (get_hyp S) hyps of
   340                 (case get_first (get_hyp S) hyps of
   331                   SOME w => ((w :: solved, failed), SOME w)
   341                   SOME w => (SOME w, (w :: solved, failed))
   332                 | NONE => witn_types path log_types ((solved, failed), S)))
   342                 | NONE => witn_types path log_types S (solved, failed)))
   333 
   343 
   334     and witn_sorts path x = foldl_map (witn_sort path) x
   344     and witn_sorts path x = fold_map (witn_sort path) x
   335 
   345 
   336     and witn_types _ [] ((solved, failed), S) = ((solved, S :: failed), NONE)
   346     and witn_types _ [] S (solved, failed) = (NONE, (solved, S :: failed))
   337       | witn_types path (t :: ts) (solved_failed, S) =
   347       | witn_types path (t :: ts) S solved_failed =
   338           (case mg_dom t S of
   348           (case mg_dom t S of
   339             SOME SS =>
   349             SOME SS =>
   340               (*do not descend into stronger args (achieving termination)*)
   350               (*do not descend into stronger args (achieving termination)*)
   341               if exists (fn D => le D S orelse exists (le D) path) SS then
   351               if exists (fn D => le D S orelse exists (le D) path) SS then
   342                 witn_types path ts (solved_failed, S)
   352                 witn_types path ts S solved_failed
   343               else
   353               else
   344                 let val ((solved', failed'), ws) = witn_sorts (S :: path) (solved_failed, SS) in
   354                 let val (ws, (solved', failed')) = witn_sorts (S :: path) SS solved_failed in
   345                   if forall is_some ws then
   355                   if forall is_some ws then
   346                     let val w = (Type (t, map (#1 o the) ws), S)
   356                     let val w = (Type (t, map (#1 o the) ws), S)
   347                     in ((w :: solved', failed'), SOME w) end
   357                     in (SOME w, (w :: solved', failed')) end
   348                   else witn_types path ts ((solved', failed'), S)
   358                   else witn_types path ts S (solved', failed')
   349                 end
   359                 end
   350           | NONE => witn_types path ts (solved_failed, S));
   360           | NONE => witn_types path ts S solved_failed);
   351 
   361 
   352   in witn_sorts [] (([], []), sorts) end;
   362     fun double_check TS =
   353 
   363       if of_sort (classes, arities) TS then TS
   354 fun str_of_sort [c] = c
   364       else sys_error "FIXME Bad sort witness";
   355   | str_of_sort cs = enclose "{" "}" (commas cs);
   365 
   356 
   366   in map_filter (Option.map double_check) (#1 (witn_sorts [] sorts ([], []))) end;
   357 in
       
   358 
       
   359 fun witness_sorts (classes, arities) log_types hyps sorts =
       
   360   let
       
   361     fun double_check_result NONE = NONE
       
   362       | double_check_result (SOME (T, S)) =
       
   363           if of_sort (classes, arities) (T, S) then SOME (T, S)
       
   364           else sys_error ("Sorts.witness_sorts: bad witness for sort " ^ str_of_sort S);
       
   365   in map_filter double_check_result (#2 (witness_aux (classes, arities) log_types hyps sorts)) end;
       
   366 
   367 
   367 end;
   368 end;
   368 
       
   369 end;