Datatype.ML
author nipkow
Fri, 08 Jul 1994 17:22:58 +0200
changeset 92 bcd0ee8d71aa
parent 91 a94029edb01f
child 96 d94d0b324b4b
permissions -rw-r--r--
Hidden dtK and Impossible with a "local" clause

(*  Title:       HOL/Datatype
    ID:          $Id$
    Author:      Max Breitling / Carsten Clasohm
    Copyright    1994 TU Muenchen
*)


(*dtK: from this number of constructors on, the inequality axioms are stated
  via an ordinal function tname_ord (Ci_neg2) instead of one axiom per pair
  of constructors (Ci_neg1)*)
local val dtK = 5
in

local open ThyParse in
(*Parser for the "datatype" section of a theory file:
      [type-vars] tname = C1 [(T, ..., T)] [mixfix] | ... | Cn ...
  where an argument type T is an identifier, a type variable, or a
  type-variable list followed by an identifier.
  The result is a pair of ML source strings: the "|> add_datatype (...)"
  call that extends the theory, and the text of a structure named tname
  that fetches the generated axioms with get_axiom thy.*)
val datatype_decls =
  let val in_parens = parents "(" ")";
      fun ml_list xs = parents "[" "]" (commas xs);

      (*type variable, rendered as a dtVar constructor expression*)
      val dt_var = type_var >> (fn v => "dtVar " ^ v);

      (*a single type variable or a parenthesised list of them*)
      val var_args =
        dt_var >> (fn v => [v]) || "(" $$-- list1 dt_var --$$ ")";

      (*constructor argument type, rendered as a dtId/dtComp/dtVar expression*)
      val arg_type =
           ident                 >> (fn id => "dtId " ^ quote id)
        || var_args -- ident     >> (fn (vs, id) =>
                                       "dtComp (" ^ ml_list vs ^ ", " ^
                                       quote id ^ ")")
        || dt_var;

      val arg_types = "(" $$-- list1 arg_type --$$ ")" || empty;

      val constructor = name -- arg_types -- opt_mixfix;

      (*one or more constructors separated by "|"*)
      fun constructors toks =
        (constructor --$$ "|" -- constructors >> op::
         || constructor >> (fn c => [c])) toks;

      (*ML text "(name, [args], mixfix)" for one parsed constructor*)
      fun ml_constructor ((c, args), mx) =
        in_parens (commas [c, ml_list args, mx]);

      (*drop every double-quote character from s*)
      fun unquote s =
        let fun keep ch = ch <> "\"" in implode (filter keep (explode s)) end;

      (*quoted names of the inequality axioms: "ineq_Ci_Cj" for every pair
        with Ci listed before Cj when there are fewer than dtK constructors,
        otherwise tname_ord0, ..., tname_ordN for N constructors*)
      fun ineq_names cs tname =
        if length cs < dtK then
          let fun con_name ((s, _), _) = unquote s;
              fun pairs [] = []
                | pairs (c :: rest) =
                    map (fn d => quote ("ineq_" ^ con_name c ^ "_" ^ con_name d))
                        rest
                    @ pairs rest
          in pairs cs end
        else map (fn n => quote (tname ^ "_ord" ^ string_of_int n))
                 (0 upto length cs);

      fun has_args ((_, args), _) = not (null args);

      (*assemble the add_datatype call and the accompanying structure text*)
      fun mk_params ((tvars, tname), cons) =
        let val few = length cons < dtK;
            val get_all = "map (get_axiom thy) ";
            val call =
              "|> add_datatype\n" ^
              in_parens (commas [ml_list tvars, quote tname,
                                 ml_list (map ml_constructor cons)]);
            val inject_names =
              map (fn ((s, _), _) => quote ("inject_" ^ unquote s))
                  (filter has_args cons);
            val case_names =
              map (fn ((s, _), _) => quote (tname ^ "_case_" ^ unquote s)) cons;
            val ineq_axioms = get_all ^ ml_list (ineq_names cons tname);
            (*below dtK constructors, ineq also contains every axiom combined
              with sym and contrapos (presumably the reversed inequalities)*)
            val ineq_def =
              if few then
                "let val ineq' = " ^ ineq_axioms ^
                "  in ineq' @ (map (fn t => sym COMP (t RS contrapos)) ineq') end"
              else ineq_axioms;
            val struct_text =
              "structure " ^ tname ^ " =\n" ^
              "struct\n" ^
              "  val inject = " ^ get_all ^ ml_list inject_names ^ ";\n" ^
              "  val ineq = " ^ ineq_def ^ ";\n" ^
              "  val induct = get_axiom thy \"" ^ tname ^ "_induct\";\n" ^
              "  val cases = " ^ get_all ^ ml_list case_names ^ ";\n" ^
              "  val simps = inject @ ineq @ cases;\n" ^
              "  fun induct_tac a = res_inst_tac [(" ^ quote tname ^
              ", a)] induct;\n" ^
              "end;\n"
        in (call, struct_text) end;
  in (var_args || empty) -- ident --$$ "=" -- constructors >> mk_params end
end;

(*Constructor argument types:
    dtVar v        - the type variable v
    dtId s         - the type name s, without type arguments
    dtComp (ts, s) - the type constructor s applied to the argument types ts
    dtRek (ts, s)  - an occurrence of the datatype being defined; produced
                     only by add_datatype's analysis, never by the parser*)
datatype dt_type =
    dtVar  of string
  | dtId   of string
  | dtComp of dt_type list * string
  | dtRek  of dt_type list * string;

local open Syntax.Mixfix
      exception Impossible  (*raised only in cases the code never produces*)
in
(*Extend theory thy by the datatype tname.
    typevars   : type parameters of the datatype, each of the form dtVar v
    tname      : name of the new type
    cons_list' : constructors as (name, argument types, mixfix) triples
  Adds the type tname with its arity, one constant per constructor, the
  case constant tname_case (and tname_ord if there are dtK or more
  constructors), translation rules between "case x of ..." and
  tname_case(x, ...), and the axioms tname_induct, inject_C (constructors
  with arguments only), the inequality axioms and tname_case_C.
  Calls error if a constructor name occurs twice, a type variable is not
  among typevars, tname occurs with arguments other than typevars, or every
  constructor has a recursive argument.*)
fun add_datatype (typevars, tname, cons_list') thy = 
  let val pars = parents "(" ")";
      val brackets = parents "[" "]";

      (*check if constructor names are unique*)
      fun check_cons (cs : (string * 'b * 'c) list) =
        (case findrep (map #1 cs) of
           [] => true
         | c::_ => error("Constructor \"" ^ c ^ "\" occurs twice"));

      (*Check that a constructor's argument types mention only type variables
        from typevars, and replace every occurrence of tname itself (which
        must carry exactly the arguments typevars) by dtRek*)
      fun analyse_types (cons, typlist, syn) =
            let fun analyse ((dtVar v) :: typlist) =
                     if ((dtVar v) mem typevars) then
                       (dtVar v) :: analyse typlist
                     else error ("Variable " ^ v ^ " is free.")
                  | analyse ((dtId s) :: typlist) =
                     if tname<>s then (dtId s) :: analyse typlist
                     else if null typevars then 
                       dtRek ([], tname) :: analyse typlist
                     else error (s ^ " used in different ways")
                  | analyse (dtComp (typl,s) :: typlist) =
                     if tname <> s then dtComp (analyse typl, s)
                                     :: analyse typlist
                     else if typevars = typl then
                       dtRek (typl, s) :: analyse typlist
                     else 
                       error (s ^ " used in different ways")
                  | analyse [] = []
                  (*dtRek is only produced by analyse itself*)
                  | analyse ((dtRek _) :: _) = raise Impossible;
            in (cons, analyse typlist, syn) end;

      (*Check that at least one constructor has no (top-level) dtRek argument,
        i.e. that the type is not empty*)
      fun one_not_rek (cs : ('a * dt_type list * 'c) list) = 
        let val contains_no_rek = forall (fn dtRek _ => false | _ => true);
        in exists (contains_no_rek o #2) cs orelse
           error("Empty type not allowed!") end;

      val _ = check_cons cons_list';
      val cons_list = map analyse_types cons_list';
      val _ = one_not_rek cons_list;

      (*Pretty printers for type lists;
        pp_typlist1: parentheses, pp_typlist2: brackets; both give "" for []*)
      fun pp_typ (dtVar s) = s
        | pp_typ (dtId s) = s
        | pp_typ (dtComp (typvars, id)) = (pp_typlist1 typvars) ^ id
        | pp_typ (dtRek (typvars, id)) = (pp_typlist1 typvars) ^ id
      and
          pp_typlist' ts = commas (map pp_typ ts)
      and
          pp_typlist1 ts = if null ts then "" else pars (pp_typlist' ts);

      fun pp_typlist2 ts = if null ts then "" else brackets (pp_typlist' ts);

      (*generate "var<n><delim>var<n+1>...<delim>var<m>", or "" if n > m.
        This is the only definition of Args: an earlier recursive version
        did not terminate for n > m and was shadowed by a second copy.*)
      fun Args(var, delim, n, m) =
        space_implode delim (map (fn i => var ^ string_of_int i) (n upto m));

      (*Generate the syntax translation for case expressions:
        calc_xrules c_nr y_nr cs returns the pair
          ("C1(y1,...) => z1 | C2 => z2 | ...", "% y1 ... . z1, z2, ...")
        where the z's are numbered from c_nr and the y's from y_nr*)
      fun calc_xrules c_nr y_nr ((id, typlist, syn) :: cs) = 
            let val name = const_name id syn;
                val arity = length typlist;
                val body  = "z" ^ string_of_int(c_nr);
                val args1 = if arity=0 then ""
                            else pars (Args ("y", ",", y_nr, y_nr+arity-1));
                val args2 = if arity=0 then ""
                            else "% " ^ Args ("y", " ", y_nr, y_nr+arity-1) 
                            ^ ". ";
                val (rest1,rest2) = 
                  if null cs then ("","")
                  else let val (h1, h2) = calc_xrules (c_nr+1) (y_nr+arity) cs
                       in (" | " ^ h1, ", " ^ h2) end;
            in (name ^ args1 ^ " => " ^ body ^ rest1, args2 ^ body ^ rest2) end
        | calc_xrules _ _ [] = raise Impossible;
      
      (*"case x of ..." <-> "tname_case(x, ...)"*)
      val xrules =
         let val (first_part, scnd_part) = calc_xrules 1 1 cons_list
         in  [("logic", "case x of " ^ first_part) <->
              ("logic", tname ^ "_case(x, " ^ scnd_part ^ ")" )]
         end;

      (*type declarations for constructors:
        "[T1,...,Tk] => (typevars)tname", or "(typevars)tname" if nullary*)
      fun const_types ((id, typlist, syn) :: cs) =
           (id,  
            (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
             pp_typlist1 typevars ^ tname, syn)
           :: const_types cs
        | const_types [] = [];

      (*name of a type variable based on s that does not occur in typlist;
        primes are appended until it is fresh*)
      fun create_typevar (dtVar s) typlist =
            if (dtVar s) mem typlist then 
              create_typevar (dtVar (s ^ "'")) typlist 
            else s
        | create_typevar _ _ = raise Impossible;

      (*induction hypotheses "[| P(v1);...;P(vk)|] ==>" for the arguments
        vi of recursive type (dtRek); "" if there are none*)
      fun assumpt (dtRek _ :: ts, v :: vs ,found) =
            let val h = if found then ";P(" ^ v ^ ")"
                                 else "[| P(" ^ v ^ ")"
            in h ^ (assumpt (ts, vs, true)) end
        | assumpt (t :: ts, v :: vs, found) = assumpt (ts, vs, found)
        | assumpt ([], [], found) = if found then "|] ==>" else ""
        | assumpt _ = raise Impossible;

      (*insert type with suggested name 'varname' into a table of
        (type, name, count) entries: an existing type has its count
        increased; a name already used by another type gets primed*)
      fun insert typ varname ((t, s, n) :: xs) = 
            if typ = t then (t, s, n+1) :: xs
            else if varname = s then (t,s,n) :: (insert typ (varname ^ "'") xs)
                                else (t,s,n) :: (insert typ varname xs)
        | insert typ varname [] = [(typ, varname, 1)];

      (*enter argument types into the table; the suggested name is the type
        name, for a type variable without its leading character*)
      fun insert_types (dtRek (l,id) :: ts) tab =
            insert_types ts (insert (dtRek(l,id)) id tab)
        | insert_types ((dtVar s) :: ts) tab =
            insert_types ts (insert (dtVar s) (implode (tl (explode s))) tab)
        | insert_types ((dtId s) :: ts) tab =
            insert_types ts (insert (dtId s) s tab)
        | insert_types (dtComp (l,id) :: ts) tab =
            insert_types ts (insert (dtComp(l,id)) id tab)
        | insert_types [] tab = tab;

      (*put name s at the position of the argument of type t (any dtRek
        matches any dtRek)*)
      fun update(dtRek _, s, v :: vs, (dtRek _) :: ts) = s :: vs
        | update(t, s, v :: vs, t1 :: ts) = 
            if t=t1 then s :: vs
                    else v :: (update (t, s, vs, ts))
        | update _ = raise Impossible;
      
      (*name the arguments of type t s<n>, s<n+1>, ... in order*)
      fun update_n (dtRek r1, s, v :: vs, (dtRek r2) :: ts, n) =
            if r1 = r2 then (s ^ string_of_int n) :: 
                            (update_n (dtRek r1, s, vs, ts, n+1))
                       else v :: (update_n (dtRek r1, s, vs, ts, n))
        | update_n (t, s, v :: vs, t1 :: ts, n) = 
            if t = t1 then (s ^ string_of_int n) :: 
                           (update_n (t, s, vs, ts, n+1))
                      else v :: (update_n (t, s, vs, ts, n))
        | update_n (_,_,[],[],_) = []
        | update_n _ = raise Impossible;

      (*compute the argument variable names from the table: a type occurring
        once gets its plain name, a type occurring n > 1 times gets the
        numbered names name1, ..., namen*)
      fun convert ((t, s, n) :: ts) var_list typ_list =
            let val h = if n=1 then update (t, s, var_list, typ_list)
                               else update_n (t, s, var_list, typ_list, 1)
            in convert ts h typ_list end
        | convert [] var_list _ = var_list;

      fun empty_list n = replicate n "";

      (*premises of the induction rule, one per constructor, joined by "; "*)
      fun t_inducting ((id, typl, syn) :: cs) =
            let val name = const_name id syn;
                val tab = insert_types typl [];
                val arity = length typl;
                val var_list = convert tab (empty_list arity) typl; 
                val h = if arity = 0 then " P(" ^ name ^ ")"
                        else " !!" ^ (space_implode " " var_list) ^ "." ^
                             (assumpt (typl, var_list, false)) ^ "P(" ^ 
                             name ^ "(" ^ (commas var_list) ^ "))";
                val rest = t_inducting cs;
            in if rest = "" then h else h ^ "; " ^ rest end
        | t_inducting [] = "";

      (*induction axiom "[| ... |] ==> P(typ_name)"*)
      fun t_induct cl typ_name=
        "[|" ^ t_inducting cl ^ "|] ==> P(" ^ typ_name ^ ")";

      (*",[T1,...]=>typevar" for each constructor: the function argument
        types of the case constant*)
      fun case_typlist typevar ((_, typlist, _) :: cs) =
           let val h = if (length typlist) > 0 then 
                         (pp_typlist2 typlist) ^ "=>"
                       else ""
           in "," ^ h ^ typevar ^ (case_typlist typevar cs) end
        | case_typlist _ [] = "";

      (*axioms tname_case_C: t_case(C(x1,...),f1,...,fm) = fn(x1,...) for
        the n-th constructor C*)
      fun case_rules t_case arity n ((id, typlist, syn) :: cs) =
            let val name = const_name id syn;
                val args = if null typlist then ""
                           else "(" ^ Args ("x", ",", 1, length typlist) ^ ")"
            in (t_case ^ "_" ^ id,
                t_case ^ "(" ^ name ^ args ^ "," ^ Args ("f", ",", 1, arity) 
                ^ ") = f" ^ string_of_int(n) ^ args)
               :: (case_rules t_case arity (n+1) cs)
            end
        | case_rules _ _ _ [] = [];

      val datatype_arity = length typevars;

      val types = [(tname, datatype_arity, NoSyn)];

      val arities = 
        let val term_list = replicate datatype_arity ["term"];
        in [(tname, term_list, ["term"])] end;

      val datatype_name = pp_typlist1 typevars ^ tname;

      val (case_const, rules_case) =
         let val typevar = create_typevar (dtVar "'beta") typevars;
             val t_case = tname ^ "_case";
             val arity = length cons_list;
             val dekl = (t_case, "[" ^ pp_typlist1 typevars ^ tname ^
                       case_typlist typevar cons_list ^ "]=>" ^ typevar, NoSyn)
                       :: nil;
             val rules = case_rules t_case arity 1 cons_list;
         in (dekl, rules) end;

      val consts = 
        const_types cons_list
        @ (if length cons_list < dtK then []
           else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
        @ case_const;

      (*generate "name(var1,...,varn)", or just "name" if n is not positive*)
      fun C_exp(name, n, var) =
        if n > 0 then name ^ "(" ^ Args (var, ",", 1, n) ^ ")"
                 else name;

      (*generate "x<n>=y<n> & ... & x<m>=y<m>"*)
      fun Arg_eql(n,m) = 
        space_implode " & "
          (map (fn i => "x" ^ string_of_int i ^ "=y" ^ string_of_int i)
               (n upto m));

      (*injectivity axioms inject_C:
        (C(x1,...,xk)=C(y1,...,yk)) = (x1=y1 & ... & xk=yk),
        only for constructors with arguments*)
      fun Ci_ing ((id, typlist, syn) :: cs) =
            let val name = const_name id syn;
                val arity = length typlist;
            in if arity > 0 
               then ("inject_" ^ id,
                     "(" ^ C_exp(name,arity,"x") ^ "=" ^ C_exp(name,arity,"y") 
                     ^ ") = (" ^ Arg_eql (1, arity) ^ ")") :: (Ci_ing cs)
               else (Ci_ing cs)      
            end
        | Ci_ing [] = [];

      (*distinctness axioms ineq_C1_C2 of constructor c against each of cs*)
      fun Ci_negOne _ [] = []
        | Ci_negOne c (c1::cs) =
           let val (id1, tl1, syn1) = c
               val (id2, tl2, syn2) = c1
               val name1 = const_name id1 syn1;
               val name2 = const_name id2 syn2;
               val arit1 = length tl1
               val arit2 = length tl2
               val h = "(" ^ C_exp(name1, arit1, "x") ^ "~=" ^
                             C_exp(name2, arit2, "y") ^ ")"
           in ("ineq_" ^ id1 ^ "_" ^ id2, h):: (Ci_negOne c cs) 
           end;

      (*distinctness axioms for every pair of constructors, each pair once*)
      fun Ci_neg1 [] = []
        | Ci_neg1 (c1::cs) = Ci_negOne c1 cs @ Ci_neg1 cs;

      (*numeral "Suc(...Suc(0)...)" with n applications of Suc*)
      fun suc_expr n = 
        if n=0 then "0" else "Suc(" ^ suc_expr(n-1) ^ ")";

      (*axioms ord_t<n+1>: ord_t(C(x1,...)) = n for the constructors in order*)
      fun Ci_neg2equals (ord_t, ((id, typlist, syn) :: cs), n) =
          let val name = const_name id syn;
              val h = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
                      ^ ") = " ^ (suc_expr n)
          in (ord_t ^ (string_of_int (n+1)), h) 
             :: (Ci_neg2equals (ord_t, cs , n+1))
          end
        | Ci_neg2equals (_, [], _) = [];

      (*inequality via the ordinal function tname_ord, plus axiom tname_ord0:
        different ordinals imply different values*)
      val Ci_neg2 =
        let val ord_t = tname ^ "_ord";
        in (Ci_neg2equals (ord_t, cons_list, 0)) @
           [(ord_t ^ "0",
            "(" ^ ord_t ^ "(x) ~= " ^ ord_t ^ "(y)) ==> (x ~= y)")]
        end;

      val rules_ineq = if length cons_list < dtK then Ci_neg1 cons_list
                                                 else Ci_neg2;

      val rules_inject = Ci_ing cons_list;

      val rule_induct = (tname ^ "_induct", t_induct cons_list tname);

      val rules = rule_induct :: (rules_inject @ rules_ineq @ rules_case);
  in thy
     |> add_types types
     |> add_arities arities
     |> add_consts consts
     |> add_trrules xrules
     |> add_axioms rules
  end
end
end;