src/Pure/Syntax/type_ext.ML
author wenzelm
Tue Mar 22 15:32:47 2011 +0100 (2011-03-22)
changeset 42052 34f1d2d81284
parent 42050 5a505dfec04e
child 42053 006095137a81
permissions -rw-r--r--
statespace syntax: strip positions -- type constraints are unexpected here;
     1 (*  Title:      Pure/Syntax/type_ext.ML
     2     Author:     Tobias Nipkow and Markus Wenzel, TU Muenchen
     3 
     4 Utilities for input and output of types.  The concrete syntax of types.
     5 *)
     6 
(*Core of the Type_Ext interface; the full TYPE_EXT signature includes it.*)
signature TYPE_EXT0 =
sig
  (*decode the parse-tree encoding of a sort ("_topsort", "_sort",
    "_classes", marked class constants); raises TERM if malformed*)
  val sort_of_term: term -> sort
  (*collect the explicit "_ofsort" annotations of type variables occurring
    in a parse tree, as (indexname, sort) pairs*)
  val term_sorts: term -> (indexname * sort) list
  (*decode the parse-tree encoding of a type; the sort of every type
    variable is obtained from the given lookup function; raises TERM if
    malformed*)
  val typ_of_term: (indexname -> sort) -> term -> typ
  (*remove "_constrain"/"_constrainAbs" nodes whose constraint encodes a
    source position, in terms and in ASTs respectively*)
  val strip_positions: term -> term
  val strip_positions_ast: Ast.ast -> Ast.ast
  (*turn a parse tree into a raw term: the first argument turns the sort
    environment collected by term_sorts into a sort lookup; the second maps
    a name to (is_constant, resolved name); the third resolves fixed
    variables*)
  val decode_term: (((string * int) * sort) list -> string * int -> sort) ->
    (string -> bool * string) -> (string -> string option) -> term -> term
  (*encode a type as a parse tree; the flag enables "_ofsort" annotations
    on type variables*)
  val term_of_typ: bool -> typ -> term
  (*print-mode queries controlling bracketed output of function types*)
  val no_brackets: unit -> bool
  val no_type_brackets: unit -> bool
  (*context-dependent parse AST translations for class and type names,
    parameterized by the name-reading functions*)
  val type_ast_trs:
   {read_class: Proof.context -> string -> string,
    read_type: Proof.context -> string -> string} ->
    (string * (Proof.context -> Ast.ast list -> Ast.ast)) list
end;
    24 
(*Full interface of Type_Ext: TYPE_EXT0 plus sort encoding, the print
  translation for type applications, and the type syntax extension.*)
signature TYPE_EXT =
sig
  include TYPE_EXT0
  (*encode a sort as a parse tree (the reverse direction of sort_of_term)*)
  val term_of_sort: sort -> term
  (*print translation from a type-constructor AST applied to arguments to
    the "_tapp" / "_tappl" form; raises Ast.AST when there are no arguments*)
  val tappl_ast_tr': Ast.ast * Ast.ast list -> Ast.ast
  (*the type used as the sort category in the productions of type_ext*)
  val sortT: typ
  (*mixfix productions and translations for the concrete syntax of types*)
  val type_ext: Syn_Ext.syn_ext
end;
    33 
    34 structure Type_Ext: TYPE_EXT =
    35 struct
    36 
    37 (** input utils **)
    38 
    39 (* sort_of_term *)
    40 
    41 fun sort_of_term tm =
    42   let
    43     fun err () = raise TERM ("sort_of_term: bad encoding of classes", [tm]);
    44 
    45     fun class s = Lexicon.unmark_class s handle Fail _ => err ();
    46 
    47     fun classes (Const (s, _)) = [class s]
    48       | classes (Const ("_classes", _) $ Const (s, _) $ cs) = class s :: classes cs
    49       | classes _ = err ();
    50 
    51     fun sort (Const ("_topsort", _)) = []
    52       | sort (Const (s, _)) = [class s]
    53       | sort (Const ("_sort", _) $ cs) = classes cs
    54       | sort _ = err ();
    55   in sort tm end;
    56 
    57 
    58 (* term_sorts *)
    59 
    60 fun term_sorts tm =
    61   let
    62     val sort_of = sort_of_term;
    63 
    64     fun add_env (Const ("_ofsort", _) $ Free (x, _) $ cs) =
    65           insert (op =) ((x, ~1), sort_of cs)
    66       | add_env (Const ("_ofsort", _) $ (Const ("_tfree", _) $ Free (x, _)) $ cs) =
    67           insert (op =) ((x, ~1), sort_of cs)
    68       | add_env (Const ("_ofsort", _) $ Var (xi, _) $ cs) =
    69           insert (op =) (xi, sort_of cs)
    70       | add_env (Const ("_ofsort", _) $ (Const ("_tvar", _) $ Var (xi, _)) $ cs) =
    71           insert (op =) (xi, sort_of cs)
    72       | add_env (Abs (_, _, t)) = add_env t
    73       | add_env (t1 $ t2) = add_env t1 #> add_env t2
    74       | add_env _ = I;
    75   in add_env tm [] end;
    76 
    77 
    78 (* typ_of_term *)
    79 
    80 fun typ_of_term get_sort tm =
    81   let
    82     fun err () = raise TERM ("typ_of_term: bad encoding of type", [tm]);
    83 
    84     fun typ_of (Free (x, _)) = TFree (x, get_sort (x, ~1))
    85       | typ_of (Var (xi, _)) = TVar (xi, get_sort xi)
    86       | typ_of (Const ("_tfree",_) $ (t as Free _)) = typ_of t
    87       | typ_of (Const ("_tvar",_) $ (t as Var _)) = typ_of t
    88       | typ_of (Const ("_ofsort", _) $ Free (x, _) $ _) = TFree (x, get_sort (x, ~1))
    89       | typ_of (Const ("_ofsort", _) $ (Const ("_tfree",_) $ Free (x, _)) $ _) =
    90           TFree (x, get_sort (x, ~1))
    91       | typ_of (Const ("_ofsort", _) $ Var (xi, _) $ _) = TVar (xi, get_sort xi)
    92       | typ_of (Const ("_ofsort", _) $ (Const ("_tvar",_) $ Var (xi, _)) $ _) =
    93           TVar (xi, get_sort xi)
    94       | typ_of (Const ("_dummy_ofsort", _) $ t) = TFree ("'_dummy_", sort_of_term t)
    95       | typ_of t =
    96           let
    97             val (head, args) = Term.strip_comb t;
    98             val a =
    99               (case head of
   100                 Const (c, _) => (Lexicon.unmark_type c handle Fail _ => err ())
   101               | _ => err ());
   102           in Type (a, map typ_of args) end;
   103   in typ_of tm end;
   104 
   105 
   106 (* positions *)
   107 
   108 fun is_position (Free (x, _)) = is_some (Lexicon.decode_position x)
   109   | is_position _ = false;
   110 
   111 fun strip_positions ((t as Const (c, _)) $ u $ v) =
   112       if (c = "_constrain" orelse c = "_constrainAbs") andalso is_position v
   113       then strip_positions u
   114       else t $ strip_positions u $ strip_positions v
   115   | strip_positions (t $ u) = strip_positions t $ strip_positions u
   116   | strip_positions (Abs (x, T, t)) = Abs (x, T, strip_positions t)
   117   | strip_positions t = t;
   118 
   119 fun strip_positions_ast (Ast.Appl ((t as Ast.Constant c) :: u :: (v as Ast.Variable x) :: asts)) =
   120       if (c = "_constrain" orelse c = "_constrainAbs") andalso is_some (Lexicon.decode_position x)
   121       then Ast.mk_appl (strip_positions_ast u) (map strip_positions_ast asts)
   122       else Ast.Appl (map strip_positions_ast (t :: u :: v :: asts))
   123   | strip_positions_ast (Ast.Appl asts) = Ast.Appl (map strip_positions_ast asts)
   124   | strip_positions_ast ast = ast;
   125 
   126 
   127 (* decode_term -- transform parse tree into raw term *)
   128 
   129 fun decode_term get_sort map_const map_free tm =
   130   let
   131     val decodeT = typ_of_term (get_sort (term_sorts tm));
   132 
   133     fun decode (Const ("_constrain", _) $ t $ typ) =
   134           if is_position typ then decode t
   135           else Type.constraint (decodeT typ) (decode t)
   136       | decode (Const ("_constrainAbs", _) $ t $ typ) =
   137           if is_position typ then decode t
   138           else Type.constraint (decodeT typ --> dummyT) (decode t)
   139       | decode (Abs (x, T, t)) = Abs (x, T, decode t)
   140       | decode (t $ u) = decode t $ decode u
   141       | decode (Const (a, T)) =
   142           (case try Lexicon.unmark_fixed a of
   143             SOME x => Free (x, T)
   144           | NONE =>
   145               let val c =
   146                 (case try Lexicon.unmark_const a of
   147                   SOME c => c
   148                 | NONE => snd (map_const a))
   149               in Const (c, T) end)
   150       | decode (Free (a, T)) =
   151           (case (map_free a, map_const a) of
   152             (SOME x, _) => Free (x, T)
   153           | (_, (true, c)) => Const (c, T)
   154           | (_, (false, c)) => (if Long_Name.is_qualified c then Const else Free) (c, T))
   155       | decode (Var (xi, T)) = Var (xi, T)
   156       | decode (t as Bound _) = t;
   157   in decode tm end;
   158 
   159 
   160 
   161 (** output utils **)
   162 
   163 (* term_of_sort *)
   164 
   165 fun term_of_sort S =
   166   let
   167     val class = Lexicon.const o Lexicon.mark_class;
   168 
   169     fun classes [c] = class c
   170       | classes (c :: cs) = Lexicon.const "_classes" $ class c $ classes cs;
   171   in
   172     (case S of
   173       [] => Lexicon.const "_topsort"
   174     | [c] => class c
   175     | cs => Lexicon.const "_sort" $ classes cs)
   176   end;
   177 
   178 
   179 (* term_of_typ *)
   180 
   181 fun term_of_typ show_sorts ty =
   182   let
   183     fun of_sort t S =
   184       if show_sorts then Lexicon.const "_ofsort" $ t $ term_of_sort S
   185       else t;
   186 
   187     fun term_of (Type (a, Ts)) =
   188           Term.list_comb (Lexicon.const (Lexicon.mark_type a), map term_of Ts)
   189       | term_of (TFree (x, S)) =
   190           if is_some (Lexicon.decode_position x) then Lexicon.free x
   191           else of_sort (Lexicon.const "_tfree" $ Lexicon.free x) S
   192       | term_of (TVar (xi, S)) = of_sort (Lexicon.const "_tvar" $ Lexicon.var xi) S;
   193   in term_of ty end;
   194 
   195 
   196 
   197 (** the type syntax **)
   198 
   199 (* print mode *)
   200 
   201 val bracketsN = "brackets";
   202 val no_bracketsN = "no_brackets";
   203 
   204 fun no_brackets () =
   205   find_first (fn mode => mode = bracketsN orelse mode = no_bracketsN)
   206     (print_mode_value ()) = SOME no_bracketsN;
   207 
   208 val type_bracketsN = "type_brackets";
   209 val no_type_bracketsN = "no_type_brackets";
   210 
   211 fun no_type_brackets () =
   212   find_first (fn mode => mode = type_bracketsN orelse mode = no_type_bracketsN)
   213     (print_mode_value ()) <> SOME type_bracketsN;
   214 
   215 
   216 (* parse ast translations *)
   217 
   218 val class_ast = Ast.Constant o Lexicon.mark_class;
   219 val type_ast = Ast.Constant o Lexicon.mark_type;
   220 
   221 fun class_name_tr read_class (*"_class_name"*) [Ast.Variable c] = class_ast (read_class c)
   222   | class_name_tr _ (*"_class_name"*) asts = raise Ast.AST ("class_name_tr", asts);
   223 
   224 fun classes_tr read_class (*"_classes"*) [Ast.Variable c, ast] =
   225       Ast.mk_appl (Ast.Constant "_classes") [class_ast (read_class c), ast]
   226   | classes_tr _ (*"_classes"*) asts = raise Ast.AST ("classes_tr", asts);
   227 
   228 fun type_name_tr read_type (*"_type_name"*) [Ast.Variable c] = type_ast (read_type c)
   229   | type_name_tr _ (*"_type_name"*) asts = raise Ast.AST ("type_name_tr", asts);
   230 
   231 fun tapp_ast_tr read_type (*"_tapp"*) [ty, Ast.Variable c] =
   232       Ast.Appl [type_ast (read_type c), ty]
   233   | tapp_ast_tr _ (*"_tapp"*) asts = raise Ast.AST ("tapp_ast_tr", asts);
   234 
   235 fun tappl_ast_tr read_type (*"_tappl"*) [ty, tys, Ast.Variable c] =
   236       Ast.Appl (type_ast (read_type c) :: ty :: Ast.unfold_ast "_types" tys)
   237   | tappl_ast_tr _ (*"_tappl"*) asts = raise Ast.AST ("tappl_ast_tr", asts);
   238 
   239 fun bracket_ast_tr (*"_bracket"*) [dom, cod] =
   240       Ast.fold_ast_p "\\<^type>fun" (Ast.unfold_ast "_types" dom, cod)
   241   | bracket_ast_tr (*"_bracket"*) asts = raise Ast.AST ("bracket_ast_tr", asts);
   242 
   243 
   244 (* print ast translations *)
   245 
   246 fun tappl_ast_tr' (f, []) = raise Ast.AST ("tappl_ast_tr'", [f])
   247   | tappl_ast_tr' (f, [ty]) = Ast.Appl [Ast.Constant "_tapp", ty, f]
   248   | tappl_ast_tr' (f, ty :: tys) =
   249       Ast.Appl [Ast.Constant "_tappl", ty, Ast.fold_ast "_types" tys, f];
   250 
   251 fun fun_ast_tr' (*"\\<^type>fun"*) asts =
   252   if no_brackets () orelse no_type_brackets () then raise Match
   253   else
   254     (case Ast.unfold_ast_p "\\<^type>fun" (Ast.Appl (Ast.Constant "\\<^type>fun" :: asts)) of
   255       (dom as _ :: _ :: _, cod)
   256         => Ast.Appl [Ast.Constant "_bracket", Ast.fold_ast "_types" dom, cod]
   257     | _ => raise Match);
   258 
   259 
   260 (* type_ext *)
   261 
   262 val sortT = Type ("sort", []);
   263 val classesT = Type ("classes", []);
   264 val typesT = Type ("types", []);
   265 
   266 local open Lexicon Syn_Ext in
   267 
   268 val type_ext = syn_ext' false (K false)
   269   [Mfix ("_",           tidT --> typeT,                "", [], max_pri),
   270    Mfix ("_",           tvarT --> typeT,               "", [], max_pri),
   271    Mfix ("_",           idT --> typeT,                 "_type_name", [], max_pri),
   272    Mfix ("_",           longidT --> typeT,             "_type_name", [], max_pri),
   273    Mfix ("_::_",        [tidT, sortT] ---> typeT,      "_ofsort", [max_pri, 0], max_pri),
   274    Mfix ("_::_",        [tvarT, sortT] ---> typeT,     "_ofsort", [max_pri, 0], max_pri),
   275    Mfix ("'_()::_",     sortT --> typeT,               "_dummy_ofsort", [0], max_pri),
   276    Mfix ("_",           idT --> sortT,                 "_class_name", [], max_pri),
   277    Mfix ("_",           longidT --> sortT,             "_class_name", [], max_pri),
   278    Mfix ("{}",          sortT,                         "_topsort", [], max_pri),
   279    Mfix ("{_}",         classesT --> sortT,            "_sort", [], max_pri),
   280    Mfix ("_",           idT --> classesT,              "_class_name", [], max_pri),
   281    Mfix ("_",           longidT --> classesT,          "_class_name", [], max_pri),
   282    Mfix ("_,_",         [idT, classesT] ---> classesT, "_classes", [], max_pri),
   283    Mfix ("_,_",         [longidT, classesT] ---> classesT, "_classes", [], max_pri),
   284    Mfix ("_ _",         [typeT, idT] ---> typeT,       "_tapp", [max_pri, 0], max_pri),
   285    Mfix ("_ _",         [typeT, longidT] ---> typeT,   "_tapp", [max_pri, 0], max_pri),
   286    Mfix ("((1'(_,/ _')) _)", [typeT, typesT, idT] ---> typeT, "_tappl", [], max_pri),
   287    Mfix ("((1'(_,/ _')) _)", [typeT, typesT, longidT] ---> typeT, "_tappl", [], max_pri),
   288    Mfix ("_",           typeT --> typesT,              "", [], max_pri),
   289    Mfix ("_,/ _",       [typeT, typesT] ---> typesT,   "_types", [], max_pri),
   290    Mfix ("(_/ => _)",   [typeT, typeT] ---> typeT,     "\\<^type>fun", [1, 0], 0),
   291    Mfix ("([_]/ => _)", [typesT, typeT] ---> typeT,    "_bracket", [0, 0], 0),
   292    Mfix ("'(_')",       typeT --> typeT,               "", [0], max_pri),
   293    Mfix ("'_",          typeT,                         "\\<^type>dummy", [], max_pri)]
   294   ["_type_prop"]
   295   ([], [], [], map Syn_Ext.mk_trfun [("\\<^type>fun", K fun_ast_tr')])
   296   []
   297   ([], []);
   298 
   299 fun type_ast_trs {read_class, read_type} =
   300  [("_class_name", class_name_tr o read_class),
   301   ("_classes", classes_tr o read_class),
   302   ("_type_name", type_name_tr o read_type),
   303   ("_tapp", tapp_ast_tr o read_type),
   304   ("_tappl", tappl_ast_tr o read_type),
   305   ("_bracket", K bracket_ast_tr)];
   306 
   307 end;
   308 
   309 end;