Datatype.ML
author wenzelm
Wed, 21 Sep 1994 15:40:41 +0200
changeset 145 a9f7ff3a464c
parent 123 8bef44f9b237
permissions -rw-r--r--
minor cleanup, added 'axclass', 'instance', 'syntax', 'defs' sections;

(*  Title:       HOL/Datatype
    ID:          $Id$
    Author:      Max Breitling, Carsten Clasohm,
                 Tobias Nipkow, Norbert Voelker
    Copyright    1994 TU Muenchen
*)


(*dtK is the constructor-count threshold shared by datatype_decls and
  add_datatype.  With fewer than dtK constructors, distinctness is stated by
  pairwise axioms "C1_not_C2" (Ci_neg1; one axiom per pair of constructors).
  With dtK or more constructors, an ordinal function tname_ord is declared
  instead, with one axiom tname_ord_C per constructor plus
  tname_ord_distinct (Ci_neg2).*)
local

val dtK = 5

in

local open ThyParse in
(*Parser for a 'datatype' section:
    [typevars] tname = C1 [(T, ..., T)] [mixfix] | ... | Cn ...
  Yields a pair of ML source strings.  The first calls add_datatype and binds
  tname_add_primrec.  The second declares a structure tname holding the
  generated rules (inject, distinct, induct, cases, recs, simps) and an
  induct_tac.*)
val datatype_decls =
  let val tvar = type_var >> (fn s => "dtVar" ^ s);

      val type_var_list =
        tvar >> (fn s => [s]) || "(" $$-- list1 tvar --$$ ")";

      (*argument type: a type constructor, possibly applied to type
        variables, or a plain type variable*)
      val typ =
         ident                  >> (fn s => "dtTyp([]," ^ quote s ^")")
        ||
         type_var_list -- ident >> (fn (ts, id) => "dtTyp(" ^ mk_list ts ^
                                    "," ^ quote id ^ ")")
        ||
         tvar;

      val typ_list = "(" $$-- list1 typ --$$ ")" || empty;

      val cons = name -- typ_list -- opt_mixfix;

      (*one or more constructors separated by "|"*)
      fun constructs ts =
        ( cons --$$ "|" -- constructs >> op::
         ||
          cons                        >> (fn c => [c])) ts;

      (*remove all quotes from a string*)
      val rem_quotes = implode o filter (fn c => c <> "\"") o explode;

      (*ML source "(name, [types], mixfix)" for every constructor; rejects
        duplicate constructor names.  Names arrive as quoted ML string
        literals, so the quotes are stripped before the name is put into
        the error message (otherwise it would be doubly quoted).*)
      fun mk_cons cs =
        case findrep (map (fst o fst) cs) of
           [] => map (fn ((s,ts),syn) => parens (commas [s,mk_list ts,syn])) cs
         | c::_ => error("Constructor \"" ^ rem_quotes c ^ "\" occurs twice");

      (*names of the distinctness axioms generated by add_datatype:
        pairwise "C1_not_C2" for fewer than dtK constructors, otherwise
        "tname_ord_distinct" plus one "tname_ord_C" per constructor*)
      fun rules_distinct cs tname =
        let val uqcs = map (fn ((s,_),_) => rem_quotes s) cs;
            (*combine all constructor names with all others w/o duplicates*)
            fun negOne c = map (fn c2 => quote (c ^ "_not_" ^ c2));
            fun neg1 [] = []
              | neg1 (c1 :: cs) = (negOne c1 cs) @ (neg1 cs)
        in if length uqcs < dtK then neg1 uqcs
           else quote (tname ^ "_ord_distinct") ::
                map (fn c => quote (tname ^ "_ord_" ^ c)) uqcs
        end;

      (*ML expression fetching the axiom "tname<pre>C" for every constructor C*)
      fun rules tname cons pre =
        " map (get_axiom thy) " ^
        mk_list (map (fn ((s,_),_) => quote(tname ^ pre ^ rem_quotes s)) cons);

      (*generate string for calling 'add_datatype'*)
      fun mk_params ((ts, tname), cons) =
        let (*must agree with the threshold test in add_datatype*)
            val few_cons = length cons < dtK;
            val distinct_axms =
              "map (get_axiom thy) " ^ mk_list (rules_distinct cons tname);
            (*pairwise axioms only state C1 ~= C2; add the symmetric versions*)
            val distinct_expr =
              if few_cons then
                "let val distinct' = " ^ distinct_axms ^
                "  in distinct' @ (map (fn t => sym COMP (t RS contrapos)) distinct') end"
              else distinct_axms;
            (*nullary constructors have no injectivity axiom*)
            val inject_names =
              map (fn ((s,_), _) => quote ("inject_" ^ rem_quotes s))
                  (filter_out (null o snd o fst) cons)
        in
          ("val (thy," ^ tname ^ "_add_primrec) =  add_datatype\n" ^
           parens (commas [mk_list ts, quote tname, mk_list (mk_cons cons)]) ^
           " thy\nval thy=thy",
           "structure " ^ tname ^ " =\nstruct\n" ^
           "  val inject = map (get_axiom thy) " ^ mk_list inject_names ^ ";\n" ^
           "  val distinct = " ^ distinct_expr ^ ";\n" ^
           "  val induct = get_axiom thy \"" ^ tname ^ "_induct\";\n" ^
           "  val cases =" ^ rules tname cons "_case_" ^ ";\n" ^
           "  val recs =" ^ rules tname cons "_rec_" ^ ";\n" ^
           "  val simps = inject @ distinct @ cases @ recs;\n" ^
           "  fun induct_tac a = res_inst_tac[(" ^ quote tname ^ ", a)]induct;\n" ^
           "end;\n")
        end
  in (type_var_list || empty) -- ident --$$ "=" -- constructs >> mk_params end

(*Parser for a 'primrec' section:   fname tname (axname "equation")+
  Yields a pair of ML source strings.  The first is a
  "|> tname_add_primrec [...]" step that adds the definition of fname from
  the equations.  The second contains one "val axname = prove_goalw ..."
  per equation, proving it by unfolding fname_def and resolving with
  tname.recs.*)
val primrec_decl =
  let fun mkstrings ((fname, tname), axms) =
        let val unfold = "[get_axiom thy \"" ^ fname ^ "_def\"] ";
            val tac = "\n(fn _ => [resolve_tac " ^ tname ^ ".recs 1])";
            fun prove (name, eqn) =
              "val " ^ name ^ "= prove_goalw thy " ^ unfold ^ eqn ^ tac
            val add_step =
              "|> " ^ tname ^ "_add_primrec " ^ mk_list (map snd axms)
        in (add_step, cat_lines (map prove axms)) end
  in ident -- long_id -- repeat1 (ident -- string) >> mkstrings end

end;

(*Types of constructor parameters:
    dtVar v          a type variable v
    dtTyp (ts, c)    the type constructor c applied to the types ts
    dtRek (ts, c)    a recursive occurrence of the datatype being defined
                     (introduced by add_datatype's analyse_types)*)
datatype dt_type
  = dtVar of string
  | dtTyp of dt_type list * string
  | dtRek of dt_type list * string;

local open Syntax
           ThyParse
      (*internal failure; also raised by funs_from_eqns on equations that
        do not fit the datatype*)
      exception Impossible

(*true exactly for recursive occurrences of the datatype*)
fun is_Rek (dtRek _) = true
  | is_Rek _ = false;

(* ------------------------------------------------------------------------- *)
(* Functions for turning primitive-recursive equations into a definition     *)
(* that uses the primitive-recursion combinator                              *)

(*** Part 1: handling a single equation   ***)
 
(* Keep those elements of args whose corresponding type in targs is a
   recursive occurrence (dtRek).  The result is in reverse order.
   Raises Match if args and targs differ in length. *)

fun rek_args (args, targs) =
  let fun collect (acc, x :: xs, tx :: txs) =
            collect (if is_Rek tx then x :: acc else acc, xs, txs)
        | collect (acc, [], []) = acc
  in collect ([], args, targs) end;

(* Abstract t over every recursive call (f v), for v in vs, in turn.
   Each new bound variable is named by a variant of v's name with respect
   to the names occurring in the original t.  Returns the abstracted term
   together with the new variable names in reverse order.
   Raises Impossible if the name of f still occurs in the result, i.e. if
   f is used in t other than applied to one of the variables in vs.
*)

fun abstract_recs f vs t =
  let val used = add_term_names (t, []);
      fun wrap (body, names, []) =
            if fst (dest_Const f) mem add_term_names (body, [])
            then raise Impossible
            else (body, names)
        | wrap (body, names, v :: rest) =
            let val name = variant used (fst (dest_Free v));
                val body' = Abs (name, dummyT, abstract_over (f $ v, body))
            in wrap (body', name :: names, rest) end
  in wrap (t, [], vs) end;

(* For every defining equation, we need to abstract over the arguments and
   over the recursive calls. This cannot be done naively in that order,
   because abstracting over v turns (Free v) into a bound variable, so that
   abstract_recs no longer applies.
   abstract_arecs therefore performs the following steps:
    * abstract over (f xi) (in reverse order)
    * remove the outermost length(rargs) abstractions
    * increase the indices of loose bound variables by the number of
      constructor arguments
    * apply the abstraction over the constructor arguments (in reverse order)
    * add length(rargs) lambdas.
    Using lower-level operations on terms and arithmetic, this could probably
    be made more efficient.
*)

(* Strip the n outermost abstractions from a term.
   Raises Match if the term has fewer than n leading abstractions. *)
fun rem_Abs n t =
  if n = 0 then t
  else case t of Abs (_, _, body) => rem_Abs (n - 1) body;
(* Wrap t in one abstraction per name in vs; the first name becomes the
   outermost abstraction.  Types are left as dummyT. *)
fun add_Abs vs t =
  case vs of
      []           => t
    | vname :: vs' => Abs (vname, dummyT, add_Abs vs' t);
(* Turn the right-hand side t of one equation into the function for the
   rec combinator: lambdas for the recursive-call results (rargs) outside
   lambdas for the constructor arguments (args). *)
fun abstract_arecs funct rargs args t =
  let val (arecs, vns) = abstract_recs funct rargs t;
      val frees = map dest_Free args;
      val inner = incr_boundvars (length args) (rem_Abs (length rargs) arecs)
  in add_Abs vns (list_abs_free (frees, inner)) end;

(*** part 2. Processing of list of equations ***) 

(* Take the list of constructors cs and the list of equations eqns.
   For every constructor c in cs, find the corresponding equation in eqns.
   Check that the function name is the same in all equations and that no
   equation repeats a parameter.  Derive a term from each equation using
   abstract_arecs.
   Assumes: the equation list eqns is nonempty,
            length(eqns) = length(cs),
            every constant name identifies a constant and its type.
   Returns (fname, terms), with the terms in reverse constructor order.
   Raises Impossible if the equations and constructors do not match up
   one-to-one, or if an equation violates one of the checks above.
*)

fun funs_from_eqns cs eqns =
let fun dest_eq (Const ("Trueprop", _) $ (Const ("op =", _) $ (f $ capp) $ right)) =
          (f, strip_comb capp, right);
    val fname = (fn (Const (f, _), _, _) => f) (dest_eq (hd eqns));
    (*remaining: equations not yet consumed; pending: the suffix of
      remaining still to be searched for the first constructor of todo*)
    fun match ([], [], [], acc) = acc
      | match (_, _ :: _, [], _) = raise Impossible
      | match (_, [], _ :: _, _) = raise Impossible
      | match (remaining, eq :: pending, todo as (_, cname, targs, _) :: rest, acc) =
          let val (f, (Const (cname_eq, _), args), rhs) = dest_eq eq
          in
            if cname_eq <> cname then match (remaining, pending, todo, acc)
            else if fst (dest_Const f) <> fname
                    orelse duplicates (map (fst o dest_Free) args) <> []
            then raise Impossible
            else
              let val remaining' = remaining \ eq
                  val def_body = abstract_arecs f (rek_args (args, targs)) args rhs
              in match (remaining', remaining', rest, def_body :: acc) end
          end
in (fname, match (eqns, eqns, cs, [])) end;

(* Run type inference on t in the signature of thy, with the schematic type
   TVar(("",0),[]) as the expected type, and return the resulting term.
   NOTE(review): the two (K None) arguments presumably supply no default
   types/sorts for variables; confirm against Sign.infer_types.
*)

fun instant_types thy t =
  let val sg = sign_of thy;
      val target = TVar (("", 0), []);
      val (typed, _) = Sign.infer_types sg (K None) (K None) (t, target)
  in typed end;

in

(*Add the datatype tname, with type parameters typevars and the
  constructors cons_list' (triples of name, argument types and mixfix
  syntax), to thy.  This declares:
    - the type and its arity;
    - the constructors, tname_case and tname_rec (plus tname_ord when there
      are dtK or more constructors);
    - the case-syntax translation;
    - the axioms for induction, injectivity, distinctness, and the case and
      rec equations.
  Returns the extended theory, together with a function that adds
  primitive recursive definitions over the new type.*)
fun add_datatype (typevars, tname, cons_list') thy = 
  let (*check that only declared type variables occur on the rhs and mark
        recursive occurrences of tname as dtRek*)
      fun analyse_types (cons, typlist, syn) =
            let fun analyse(t as dtVar v) =
                     if t mem typevars then t
                     else error ("Free type variable " ^ v ^ " on rhs.")
                  | analyse(dtTyp(typl,s)) =
                     if tname <> s then dtTyp(analyses typl, s)
                     else if typevars = typl then dtRek(typl, s)
                     else error (s ^ " used in different ways")
                  | analyse(dtRek _) = raise Impossible
                 and analyses ts = map analyse ts;
            in (cons, const_name cons syn, analyses typlist, syn) end;

      (*test if all elements are recursive, i.e. if the type is empty*)
      fun non_empty (cs : ('a * 'b * dt_type list * 'c) list) = 
        not(forall (exists is_Rek o #3) cs) orelse
        error("Empty datatype not allowed!");

      val cons_list = map analyse_types cons_list';
      val dummy = non_empty cons_list;
      val num_of_cons = length cons_list;

      (*generate "var<n><delim>...<delim>var<m>"; the empty string if n > m.
        This is the single definition used throughout add_datatype.*)
      fun Args(var, delim, n, m) =
        space_implode delim (map (fn i => var ^ string_of_int i) (n upto m));

      (*Pretty printers for type lists;
        pp_typlist1: parentheses, pp_typlist2: brackets*)
      fun pp_typ (dtVar s) = s
        | pp_typ (dtTyp (typvars, id)) =
            if null typvars then id else (pp_typlist1 typvars) ^ id
        | pp_typ (dtRek (typvars, id)) = (pp_typlist1 typvars) ^ id
      and
          pp_typlist' ts = commas (map pp_typ ts)
      and
          pp_typlist1 ts = if null ts then "" else parens (pp_typlist' ts);

      fun pp_typlist2 ts = if null ts then "" else brackets (pp_typlist' ts);

      (*Generate syntax translation for case rules: returns the branches
        "C(y..) => zi | ..." of the case expression and the corresponding
        arguments "% y.. . zi, ..." of tname_case*)
      fun calc_xrules c_nr y_nr ((_, name, typlist, _) :: cs) = 
            let val arity = length typlist;
                val body  = "z" ^ string_of_int(c_nr);
                val args1 = if arity=0 then ""
                            else parens (Args ("y", ",", y_nr, y_nr+arity-1));
                val args2 = if arity=0 then ""
                            else "% " ^ Args ("y", " ", y_nr, y_nr+arity-1) 
                            ^ ". ";
                val (rest1,rest2) = 
                  if null cs then ("","")
                  else let val (h1, h2) = calc_xrules (c_nr+1) (y_nr+arity) cs
                       in (" | " ^ h1, ", " ^ h2) end;
            in (name ^ args1 ^ " => " ^ body ^ rest1, args2 ^ body ^ rest2) end
        | calc_xrules _ _ [] = raise Impossible;
      
      val xrules =
         let val (first_part, scnd_part) = calc_xrules 1 1 cons_list
         in  [("logic", "case x of " ^ first_part) <->
              ("logic", tname ^ "_case(" ^ scnd_part ^ ", x)" )]
         end;

      (*type declarations for constructors*)
      fun const_type (id, _, typlist, syn) =
           (id,  
            (if null typlist then "" else pp_typlist2 typlist ^ " => ") ^
             pp_typlist1 typevars ^ tname, syn);

      (*induction hypotheses "[| P(v);...;P(w)|] ==>" for the recursive
        arguments; "" if there are none*)
      fun assumpt (dtRek _ :: ts, v :: vs ,found) =
            let val h = if found then ";P(" ^ v ^ ")" else "[| P(" ^ v ^ ")"
            in h ^ (assumpt (ts, vs, true)) end
        | assumpt (t :: ts, v :: vs, found) = assumpt (ts, vs, found)
        | assumpt ([], [], found) = if found then "|] ==>" else ""
        | assumpt _ = raise Impossible;

      (*insert type with suggested name 'varname' into table*)
      fun insert typ varname ((tri as (t, s, n)) :: xs) = 
            if typ = t then (t, s, n+1) :: xs
            else tri :: (if varname = s then insert typ (varname ^ "'") xs
                         else insert typ varname xs)
        | insert typ varname [] = [(typ, varname, 1)];

      fun typid(dtRek(_,id)) = id
        | typid(dtVar s) = implode (tl (explode s))
        | typid(dtTyp(_,id)) = id;

      val insert_types = foldl (fn (tab,typ) => insert typ (typid typ) tab);

      fun update(dtRek _, s, v :: vs, (dtRek _) :: ts) = s :: vs
        | update(t, s, v :: vs, t1 :: ts) = 
            if t=t1 then s :: vs
                    else v :: (update (t, s, vs, ts))
        | update _ = raise Impossible;
      
      fun update_n (dtRek r1, s, v :: vs, (dtRek r2) :: ts, n) =
            if r1 = r2 then (s ^ string_of_int n) :: 
                            (update_n (dtRek r1, s, vs, ts, n+1))
                       else v :: (update_n (dtRek r1, s, vs, ts, n))
        | update_n (t, s, v :: vs, t1 :: ts, n) = 
            if t = t1 then (s ^ string_of_int n) :: 
                           (update_n (t, s, vs, ts, n+1))
                      else v :: (update_n (t, s, vs, ts, n))
        | update_n (_,_,[],[],_) = []
        | update_n _ = raise Impossible;

      (*insert type variables into table*)
      fun convert typs =
        let fun conv(vars, (t, s, n)) =
              if n=1 then update (t, s, vars, typs)
                     else update_n (t, s, vars, typs, 1)
        in foldl conv end;

      fun empty_list n = replicate n "";

      (*premises of the induction rule, one per constructor*)
      fun t_inducting ((_, name, typl, _) :: cs) =
            let val tab = insert_types([],typl);
                val arity = length typl;
                val var_list = convert typl (empty_list arity,tab); 
                val h = if arity = 0 then " P(" ^ name ^ ")"
                        else " !!" ^ (space_implode " " var_list) ^ "." ^
                             (assumpt (typl, var_list, false)) ^ "P(" ^ 
                             name ^ "(" ^ (commas var_list) ^ "))";
                val rest = t_inducting cs;
            in if rest = "" then h else h ^ "; " ^ rest end
        | t_inducting [] = "";

      fun t_induct cl typ_name =
        "[|" ^ t_inducting cl ^ "|] ==> P(" ^ typ_name ^ ")";

      fun gen_typlist typevar f ((_, _, ts, _) :: cs) =
           let val h = if (length ts) > 0
                       then pp_typlist2(f ts) ^ "=>"
                       else ""
           in "," ^ h ^ typevar ^ (gen_typlist typevar f cs) end
        | gen_typlist _ _ [] = "";

      val t_case = tname ^ "_case";

      fun case_rules n ((id, name, typlist, _) :: cs) =
            let val args = if null typlist then ""
                           else parens(Args("x", ",", 1, length typlist))
            in (t_case ^ "_" ^ id,
                t_case ^ "(" ^ name ^ args ^ "," ^
                  Args("f", ",", 1, num_of_cons)
                  ^ ") = f" ^ string_of_int(n) ^ args)
               :: (case_rules (n+1) cs)
            end
        | case_rules _ [] = [];

      val datatype_arity = length typevars;

      val types = [(tname, datatype_arity, NoSyn)];

      val arities = 
        let val term_list = replicate datatype_arity ["term"];
        in [(tname, term_list, ["term"])] end;

      val datatype_name = pp_typlist1 typevars ^ tname;

      (*result type variable of tname_case/tname_rec, distinct from typevars*)
      val new_tvar_name = variant (map (fn dtVar s => s) typevars) "'z";

      val case_const =
         (t_case,
          "[" ^ pp_typlist1 typevars ^ tname ^
                gen_typlist new_tvar_name I cons_list ^ "] =>" ^ new_tvar_name,
          NoSyn);

      val rules_case = case_rules 1 cons_list;

      (* ---------------------------------------------------------------- *)
      (* Functions for the t_rec function.  Analogous to t_case, except    *)
      (* that each constructor's function additionally receives the        *)
      (* results of the recursive calls on its recursive arguments.        *)

      val t_rec = tname ^ "_rec"

      (*argument types of a constructor's rec function: one result type per
        recursive argument, followed by the constructor's own argument types*)
      fun add_reks ts = 
        let val tv = dtVar new_tvar_name; 
            fun h (t::ts) res = h ts (if is_Rek(t) then tv::res else res)
              | h [] res  = res        
        in  h ts ts  end;

      (*the recursive calls "t_rec(xi,f1,...,fn)," for the recursive arguments*)
      fun arg_reks ts = 
        let fun arg_rek (t::ts) n res  = 
              let val h = t_rec ^"(" ^ "x" ^string_of_int(n) 
                          ^"," ^Args("f",",",1,num_of_cons) ^")," 
              in arg_rek ts (n+1) (if is_Rek(t) then res ^ h else res)
              end 
            | arg_rek [] _ res = res        
        in  arg_rek ts 1 ""  end;

      fun rec_rules n ((id,name,ts,_)::cs) =
            let val lts = length ts 
                val args = if lts = 0 then ""
                           else parens(Args("x",",",1,lts)) 
                val rargs = if (lts = 0) then ""
                            else "("^ arg_reks ts ^ Args("x",",",1,lts) ^")"
            in     
              ( t_rec ^ "_" ^ id
              , t_rec ^ "(" ^ name ^ args ^ "," ^ Args("f",",",1,num_of_cons) ^ ") = f"
                ^ string_of_int(n) ^ rargs) 
              :: (rec_rules (n+1) cs)
            end
        | rec_rules _ [] = [];

      val rec_const =
        (t_rec,
         "[" ^ (pp_typlist1 typevars) ^ tname ^
               (gen_typlist new_tvar_name add_reks cons_list) ^
               "] =>" ^ new_tvar_name,
         NoSyn);

      val rules_rec = rec_rules 1 cons_list

      val consts = 
        map const_type cons_list
        @ (if num_of_cons < dtK then []
           else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
        @ [case_const,rec_const];

      (*generate 'name(var1,...,varn)', or just 'name' if n = 0*)
      fun C_exp(name, n, var) =
        if n > 0 then name ^ parens(Args(var, ",", 1, n)) else name;

      (*generate 'xn=yn & ... & xm=ym'; the empty string if n > m*)
      fun Arg_eql(n,m) =
        space_implode " & "
          (map (fn i => "x" ^ string_of_int i ^ "=y" ^ string_of_int i)
               (n upto m));

      (*injectivity axioms for all constructors with arguments*)
      fun Ci_ing ((id, name, typlist, _) :: cs) =
            let val arity = length typlist;
            in if arity = 0 then Ci_ing cs
               else ("inject_" ^ id,
                     "(" ^ C_exp(name,arity,"x") ^ "=" ^ C_exp(name,arity,"y") 
                     ^ ") = (" ^ Arg_eql (1, arity) ^ ")") :: (Ci_ing cs)
            end
        | Ci_ing [] = [];

      fun Ci_negOne (id1, name1, tl1, _) (id2, name2, tl2, _) =
           let val ax = C_exp(name1, length tl1, "x") ^ "~=" ^
                        C_exp(name2, length tl2, "y")
           in (id1 ^ "_not_" ^ id2, ax) end;

      (*pairwise distinctness axioms, one per pair of constructors*)
      fun Ci_neg1 [] = []
        | Ci_neg1 (c1::cs) = (map (Ci_negOne c1) cs) @ Ci_neg1 cs;

      fun suc_expr n = 
        if n=0 then "0" else "Suc(" ^ suc_expr(n-1) ^ ")";

      (*distinctness via tname_ord: the i-th constructor (from 0) has ord i*)
      fun Ci_neg2() =
        let val ord_t = tname ^ "_ord";
            val cis = cons_list ~~ (0 upto (num_of_cons - 1))
            fun Ci_neg2equals ((id, name, typlist, _), n) =
              let val ax = ord_t ^ "(" ^ (C_exp(name, length typlist, "x")) 
                                 ^ ") = " ^ (suc_expr n)
              in (ord_t ^ "_" ^ id, ax) end
        in (ord_t ^ "_distinct", ord_t^"(x) ~= "^ord_t^"(y) ==> x ~= y") ::
           (map Ci_neg2equals cis)
        end;

      val rules_distinct = if num_of_cons < dtK then Ci_neg1 cons_list
                           else Ci_neg2();

      val rules_inject = Ci_ing cons_list;

      val rule_induct = (tname ^ "_induct", t_induct cons_list tname);

      val rules = rule_induct ::
                  (rules_inject @ rules_distinct @ rules_case @ rules_rec);

      (*define fname from its primitive recursive equations eqns as
        %tname. tname_rec(tname, f1, ..., fn)*)
      fun add_primrec eqns thy =
      let val rec_comb = Const(t_rec,dummyT)
          val teqns = map (fn eq => snd(read_axm (sign_of thy) ("",eq))) eqns
          val (fname,rfuns) = funs_from_eqns cons_list teqns
          val rhs = Abs(tname, dummyT,
                        list_comb(rec_comb, Bound 0 :: rev rfuns))
          val def = Const("==",dummyT) $ Const(fname,dummyT) $ rhs
          val tdef = instant_types thy def
      in add_defs_i [(fname ^ "_def", tdef)] thy end;

  in (thy
     |> add_types types
     |> add_arities arities
     |> add_consts consts
     |> add_trrules xrules
     |> add_axioms rules,
     add_primrec)
  end
end
end;