Datatype.ML
changeset 101 5f99df1e26c4
parent 96 d94d0b324b4b
child 103 c57ab3ce997e
--- a/Datatype.ML	Wed Aug 03 11:00:40 1994 +0200
+++ b/Datatype.ML	Sat Aug 13 16:33:53 1994 +0200
@@ -1,6 +1,7 @@
 (*  Title:       HOL/Datatype
     ID:          $Id$
-    Author:      Max Breitling / Carsten Clasohm
+    Author:      Max Breitling / Carsten Clasohm /
+                 Norbert Voelker / Tobias Nipkow
     Copyright    1994 TU Muenchen
 *)
 
@@ -9,16 +10,15 @@
 local
 
 val dtK = 5
-val pars = parents "(" ")";
-val brackets = parents "[" "]";
+val pars = enclose "(" ")";
+val brackets = enclose "[" "]";
+val mk_list = brackets o commas;
 
 in
 
 local open ThyParse in
 val datatype_decls =
-  let val mk_list = brackets o commas;
-
-      val tvar = type_var >> (fn s => "dtVar" ^ s);
+  let val tvar = type_var >> (fn s => "dtVar" ^ s);
 
       val type_var_list = 
         tvar >> (fn s => [s]) || "(" $$-- list1 tvar --$$ ")";
@@ -42,10 +42,10 @@
   
       val mk_cons = map (fn ((s, ts), syn) => 
                            pars (commas [s, mk_list ts, syn]));
-  
+
       (*remove all quotes from a string*)
       val rem_quotes = implode o filter (fn c => c <> "\"") o explode;
-            
+
       (*generate names of distinct axioms*)
       fun rules_distinct cs tname = 
         let val uqcs = map (fn ((s,_),_) => rem_quotes s) cs;
@@ -57,11 +57,19 @@
            else quote (tname ^ "_ord_distinct") ::
                 map (fn c => quote (tname ^ "_ord_" ^ c)) uqcs
         end;
-          
+
+       fun rule_names tname cons pre =
+         mk_list (map (fn ((s,_),_) => quote(tname ^ pre ^ rem_quotes s)) cons)
+
+       fun rules tname cons pre =
+         " map (get_axiom thy) " ^ rule_names tname cons pre       
+
       (*generate string for calling 'add_datatype'*)
       fun mk_params ((ts, tname), cons) =
-       ("|> add_datatype\n" ^ 
-       pars (commas [mk_list ts, quote tname, mk_list (mk_cons cons)]),
+       ("val (thy," ^ tname ^ "_add_primrec) =  add_datatype\n" ^
+       pars (commas [mk_list ts, quote tname, mk_list (mk_cons cons)]) ^
+       " thy\n\
+       \val thy=thy",
        "structure " ^ tname ^ " =\n\
        \struct\n\
        \  val inject = map (get_axiom thy) " ^
@@ -73,13 +81,24 @@
            "  in distinct' @ (map (fn t => sym COMP (t RS contrapos)) distinct') end"
           else "") ^ ";\n\
        \  val induct = get_axiom thy \"" ^ tname ^ "_induct\";\n\
-       \  val cases = map (get_axiom thy) " ^
-         mk_list (map (fn ((s,_),_) => 
-                         quote(tname ^ "_case_" ^ rem_quotes s)) cons) ^ ";\n\
-       \  val simps = inject @ distinct @ cases;\n\
+       \  val cases =" ^ rules tname cons "_case_" ^ ";\n\
+       \  val recs =" ^ rules tname cons "_rec_" ^ ";\n\
+       \  val simps = inject @ distinct @ cases @ recs;\n\
        \  fun induct_tac a = res_inst_tac[(" ^ quote tname ^ ", a)]induct;\n\
-       \end;\n");
+       \end;\n")
   in (type_var_list || empty) -- ident --$$ "=" -- constructs >> mk_params end
+
+val primrec_decl =
+  let fun mkstrings((fname,tname),axms) =
+        let fun prove (name,eqn) =
+             "val "^name^"= prove_goalw thy [get_axiom thy \""^fname^"_def\"] "
+                 ^ eqn ^"\n\
+             \(fn _ => [resolve_tac " ^ tname^".recs 1])"
+        in ("|> " ^ tname^"_add_primrec " ^ mk_list (map snd axms),
+            cat_lines(map prove axms))
+        end
+  in ident -- ident -- repeat1 (ident -- string)  >> mkstrings end
+
 end;
 
 (*used for constructor parameters*)
@@ -89,7 +108,121 @@
 
 local open Syntax.Mixfix
       exception Impossible
+
+val is_rek = (fn dtRek _ => true  |  _  => false);
+
+(* ------------------------------------------------------------------------- *)
+(* Die Funktionen fuer das Umsetzen von Gleichungen in eine Definition mit   *)
+(* dem prim-Rek. Kombinator                                                  *)
+
+(*** Part 1: handling a single equation   ***)
+ 
+(* filter REK type args by correspondence with targs. Reverses order *) 
+
+fun rek_args (args, targs) = 
+let fun h (x :: xs, tx :: txs, res) 
+           = h(xs,txs,if is_rek tx then x :: res else res )
+     |  h ([],[],res) = res
+in h (args,targs,[])
+end;
+
+(* abstract over all recursive calls of f in t with param v in vs.
+   Name in abstraction is variant of v w.r.t. free names in t. 
+   Also returns reversed list of new variables names with types. 
+   Checks that there are no free occurences of f left. 
+*) 
+
+fun abstract_recs f vs t  = 
+let val tfrees = add_term_names(t,[]); 
+    fun h [] vns t = if fst(dest_Const f) mem add_term_names(t,[]) 
+		     then raise Impossible 
+		     else (t,vns)
+     |  h (v::vs) vns t
+        = let val vn = variant tfrees (fst(dest_Free v))
+          in  h vs (vn::vns) (Abs(vn, dummyT, abstract_over(f $ v,t)))
+          end;
+in h vs [] t
+end;
+
+(* For every defining equation, I need to abstract over arguments and
+   over the recursive calls. Cant do it simply minded in this order, because 
+   abstracting over v turns (Free v) into a bound variable, so that
+   abstract_recs does not apply anymore.  
+   abstract_arecs_funct performs the following steps 
+    * abstract over (f xi) (reverse order) 
+    * remove outermost length(rargs) abstractions
+    * increase loose bound variables index by #cargs
+    * apply the carg abstraction (reverse order) 
+    * add length(rargs) lambdas. 
+    Using lower level operations on term and arithmetic, this could probably
+    be made more efficient. 
+*) 
+
+(* remove n outermost abstractions from a term *)
+fun rem_Abs 0 t = t
+ |  rem_Abs n (Abs(s,T,t)) = rem_Abs (n-1) t
+;
+(* add one abstraction for for every variable in vs *)  
+fun add_Abs []      t = t
+ |  add_Abs (vname::vs) t = Abs(vname, dummyT, add_Abs vs t)
+; 
+fun abstract_arecs funct rargs args t = 
+let val (arecs,vns) = abstract_recs funct rargs t;
+in  add_Abs vns 
+    ( list_abs_free
+        ( map dest_Free args
+        , incr_boundvars (length args) (rem_Abs (length rargs) arecs)))
+end;
+
+(*** part 2. Processing of list of equations ***) 
+
+(* Take list of constructors cs and equations eqns. 
+   Find for ever element c of cs a corresponding eq in eqns. 
+   Check that the function name is unique and there are no double params.  
+   Derive term from equation using abstract_arecs and instantiate types. 
+   Assume: equation list eqns nonempty
+           length(eqns) = length(cs) 
+           every constant name identifies a constant and its type. 
+   In h: first parameter reqs reflects the remaining equations. 
+*)
+
+fun funs_from_eqns cs eqns =
+let fun dest_eq ( Const("Trueprop",_) $ (Const ("op =",_)
+                 $ (f $ capp) $ right))
+	         = (f, strip_comb(capp), right);
+    val fname = (fn (Const(f,_),_,_) => f) (dest_eq(hd eqns));
+    fun h []   []        []       res = res
+     |  h _    (_ :: _)  []       _   = raise Impossible
+     |  h _    []        (_ :: _) _   = raise Impossible
+     |  h reqs (eq::eqs) (c::cs)  res =
+	let
+          val (f,(Const(cname_eq,_),args),rhs) = dest_eq eq;
+          val (cname,targs,syn) = c;
+        in
+	  if (cname_eq <> const_name cname syn) then h reqs eqs (c::cs) res
+          else
+          if fst(dest_Const(f)) = fname
+             andalso (duplicates (map (fst o dest_Free) args) = [])
+          then let val reqs' = reqs \ eq
+               in h reqs' reqs' cs
+                    (abstract_arecs f (rek_args(args,targs)) args rhs :: res)
+	       end
+          else raise Impossible
+        end
+in (fname, h eqns eqns cs []) end;
+
+(* take datatype and eqns and return a properly type-instantiated 
+   application of the prim-rec-combinator which solves eqns.
+*)
+
+fun instant_types thy t =
+let val rs = Sign.rep_sg(sign_of thy);  
+in  fst(Type.infer_types( #tsig rs,#const_tab rs, K None, K None
+		        , TVar(("",0),[]), t))
+end;
+
 in
+
 fun add_datatype (typevars, tname, cons_list') thy = 
   let (*check if constructor names are unique*)
       fun check_cons (cs : (string * 'b * 'c) list) =
@@ -242,16 +375,18 @@
            in "," ^ h ^ typevar ^ (case_typlist typevar cs) end
         | case_typlist _ [] = "";
 
-      fun case_rules t_case arity n ((id, typlist, syn) :: cs) =
+      val t_case = tname ^ "_case";
+
+      fun case_rules arity n ((id, typlist, syn) :: cs) =
             let val name = const_name id syn;
                 val args = if null typlist then ""
   			   else pars(Args("x", ",", 1, length typlist))
             in (t_case ^ "_" ^ id,
                 t_case ^ "(" ^ name ^ args ^ "," ^ Args ("f", ",", 1, arity) 
                 ^ ") = f" ^ string_of_int(n) ^ args)
-               :: (case_rules t_case arity (n+1) cs)
+               :: (case_rules arity (n+1) cs)
             end
-        | case_rules _ _ _ [] = [];
+        | case_rules _ _ [] = [];
 
       val datatype_arity = length typevars;
 
@@ -265,19 +400,78 @@
 
       val (case_const, rules_case) =
          let val typevar = create_typevar (dtVar "'beta") typevars;
-             val t_case = tname ^ "_case";
              val arity = length cons_list;
              val dekl = (t_case, "[" ^ pp_typlist1 typevars ^ tname ^
                        case_typlist typevar cons_list ^ "]=>" ^ typevar, NoSyn)
-                       :: nil;
-             val rules = case_rules t_case arity 1 cons_list;
+             val rules = case_rules arity 1 cons_list;
          in (dekl, rules) end;
 
+
+
+(* -------------------------------------------------------------------- *)
+(* Die Funktionen fuer die t_rec - Funktion                             *)
+(* Analog zu t_case bis auf Hinzufuegen rek. Aufrufe pro Konstruktor 	*)
+
+      val t_rec = tname ^ "_rec"
+
+fun add_reks typevar ts = 
+  let val tv = dtVar typevar; 
+      fun h (t::ts) res = h ts (if is_rek(t) then tv::res else res)
+	| h [] res  = res        
+  in  h ts ts
+  end;
+
+fun rec_typlist typevar ((c,ts,_)::cs) = 
+    let val h = if (length ts) > 0 
+	        then (pp_typlist2 (add_reks typevar ts)) ^ "=>"
+                else ""
+    in "," ^ h ^ typevar ^ (rec_typlist typevar cs)
+    end
+  | rec_typlist _ [] = "";
+
+fun arg_reks arity ts = 
+  let fun arg_rek (t::ts) n res  = 
+        let val h = t_rec ^"(" ^ "x" ^string_of_int(n) 
+			       ^"," ^Args("f",",",1,arity) ^")," 
+        in arg_rek ts (n+1) (if is_rek(t) then res ^ h else res)
+        end 
+      | arg_rek [] _ res = res        
+  in  arg_rek ts 1 ""
+  end;
+
+fun rec_rules arity n ((id,ts,syn)::cs) =
+  let val name = const_name id syn;
+      val lts = length ts 
+      val args = if (lts = 0) then ""
+	         else "(" ^ Args("x",",",1,lts) ^ ")" 
+      val rargs = if (lts = 0) then ""
+	          else "("^ arg_reks arity ts ^ Args("x",",",1,lts) ^")"
+  in     
+    ( t_rec ^ "_" ^ id
+    , t_rec ^ "(" ^ name ^ args ^ "," ^ Args("f",",",1,arity) ^ ") = f"
+      ^ string_of_int(n) ^ rargs) 
+     :: (rec_rules arity (n+1) cs)
+  end
+  | rec_rules _ _ [] = [];
+
+val (rec_const,rules_rec) =
+   let val typevar = create_typevar (dtVar "'beta") typevars
+       val arity = length cons_list
+       val dekl = (t_rec,
+                    "[" ^ (pp_typlist1 typevars) ^ tname ^
+                      (rec_typlist typevar cons_list) ^ "]=>" ^ typevar,
+                    NoSyn)
+       val rules = rec_rules arity 1 cons_list
+       in (dekl,rules)
+   end;
+
+
+
       val consts = 
         map const_type cons_list
 	@ (if length cons_list < dtK then []
 	   else [(tname ^ "_ord", datatype_name ^ "=>nat", NoSyn)])
-	@ case_const;
+	@ [case_const,rec_const];
 
       (*generate 'var_n, ..., var_m'*)
       fun Args(var, delim, n, m) = 
@@ -335,13 +529,26 @@
 
       val rule_induct = (tname ^ "_induct", t_induct cons_list tname);
 
-      val rules = rule_induct :: (rules_inject @ rules_distinct @ rules_case);
-  in thy
+      val rules = rule_induct ::
+                  (rules_inject @ rules_distinct @ rules_case @ rules_rec);
+
+      fun add_primrec eqns thy =
+      let val rec_comb = Const(t_rec,dummyT)
+          val teqns = map (fn eq => snd(read_axm (sign_of thy) ("",eq))) eqns
+          val (fname,rfuns) = funs_from_eqns cons_list teqns
+          val rhs = Abs(tname, dummyT,
+                        list_comb(rec_comb, Bound 0 :: rev rfuns))
+          val def = Const("==",dummyT) $ Const(fname,dummyT) $ rhs
+          val tdef = instant_types thy def
+      in add_defns_i [(fname ^ "_def", tdef)] thy end;
+
+  in (thy
      |> add_types types
      |> add_arities arities
      |> add_consts consts
      |> add_trrules xrules
-     |> add_axioms rules
+     |> add_axioms rules,
+     add_primrec)
   end
 end
 end;