TFL/tfl.sml
changeset 3391 5e45dd3b64e9
parent 3388 dbf61e36f8e9
child 3405 2cccd0e3e9ea
--- a/TFL/tfl.sml	Tue Jun 03 10:56:04 1997 +0200
+++ b/TFL/tfl.sml	Tue Jun 03 11:08:08 1997 +0200
@@ -3,31 +3,17 @@
     Author:     Konrad Slind, Cambridge University Computer Laboratory
     Copyright   1997  University of Cambridge
 
-Main TFL functor
+Main module
 *)
 
-functor TFL(structure Rules : Rules_sig
-            structure Thry  : Thry_sig
-            structure Thms  : Thms_sig) : TFL_sig  =
+structure Prim : TFL_sig =
 struct
 
-(* Declarations *)
-structure Thms = Thms;
-structure Rules = Rules;
-structure Thry = Thry;
-structure USyntax = Thry.USyntax;
-
-
 (* Abbreviations *)
 structure R = Rules;
 structure S = USyntax;
 structure U = S.Utils;
 
-nonfix mem -->;
-val --> = S.-->;
-
-infixr 3 -->;
-
 val concl = #2 o R.dest_thm;
 val hyp = #1 o R.dest_thm;
 
@@ -83,13 +69,13 @@
         | part {constrs = [],   rows = _::_, A} = pfail"extra cases in defn"
         | part {constrs = _::_, rows = [],   A} = pfail"cases missing in defn"
         | part {constrs = c::crst, rows,     A} =
-          let val {Name,Ty} = S.dest_const c
-              val (L,_) = S.strip_type Ty
+          let val (Name,Ty) = dest_Const c
+              val L = binder_types Ty
               val (in_group, not_in_group) =
                U.itlist (fn (row as (p::rst, rhs)) =>
                          fn (in_group,not_in_group) =>
                   let val (pc,args) = S.strip_comb p
-                  in if (#Name(S.dest_const pc) = Name)
+                  in if (#1(dest_Const pc) = Name)
                      then ((args@rst, rhs)::in_group, not_in_group)
                      else (in_group, row::not_in_group)
                   end)      rows ([],[])
@@ -112,8 +98,10 @@
 datatype pattern = GIVEN   of term * int
                  | OMITTED of term * int
 
-fun psubst theta (GIVEN (tm,i)) = GIVEN(subst_free theta tm, i)
-  | psubst theta (OMITTED (tm,i)) = OMITTED(subst_free theta tm, i);
+fun pattern_map f (GIVEN (tm,i)) = GIVEN(f tm, i)
+  | pattern_map f (OMITTED (tm,i)) = OMITTED(f tm, i);
+
+fun pattern_subst theta = pattern_map (subst_free theta);
 
 fun dest_pattern (GIVEN (tm,i)) = ((GIVEN,i),tm)
   | dest_pattern (OMITTED (tm,i)) = ((OMITTED,i),tm);
@@ -125,8 +113,9 @@
  * Produce an instance of a constructor, plus genvars for its arguments.
  *---------------------------------------------------------------------------*)
 fun fresh_constr ty_match colty gv c =
-  let val {Ty,...} = S.dest_const c
-      val (L,ty) = S.strip_type Ty
+  let val (_,Ty) = dest_Const c
+      val L = binder_types Ty
+      and ty = body_type Ty
       val ty_theta = ty_match ty colty
       val c' = S.inst ty_theta c
       val gvars = map (S.inst ty_theta o gv) L
@@ -142,7 +131,7 @@
   U.itlist (fn (row as ((prefix, p::rst), rhs)) =>
             fn (in_group,not_in_group) =>
                let val (pc,args) = S.strip_comb p
-               in if ((#Name(S.dest_const pc) = Name) handle _ => false)
+               in if ((#1(dest_Const pc) = Name) handle _ => false)
                   then (((prefix,args@rst), rhs)::in_group, not_in_group)
                   else (in_group, row::not_in_group) end)
       rows ([],[]);
@@ -157,7 +146,7 @@
      fun part {constrs = [],      rows, A} = rev A
        | part {constrs = c::crst, rows, A} =
          let val (c',gvars) = fresh c
-             val {Name,Ty} = S.dest_const c'
+             val (Name,Ty) = dest_Const c'
              val (in_group, not_in_group) = mk_group Name rows
              val in_group' =
                  if (null in_group)  (* Constructor not given *)
@@ -178,7 +167,7 @@
  *---------------------------------------------------------------------------*)
 
 fun mk_pat (c,l) =
-  let val L = length(#1(S.strip_type(type_of c)))
+  let val L = length (binder_types (type_of c))
       fun build (prefix,tag,plist) =
           let val args   = take (L,plist)
               and plist' = drop(L,plist)
@@ -211,7 +200,7 @@
        then let val fresh = fresh_constr ty_match ty fresh_var
                 fun expnd (c,gvs) = 
                   let val capp = list_comb(c,gvs)
-                  in ((prefix, capp::rst), psubst[(p,capp)] rhs)
+                  in ((prefix, capp::rst), pattern_subst[(p,capp)] rhs)
                   end
             in map expnd (map fresh constructors)  end
        else [row]
@@ -229,7 +218,7 @@
          val col0 = map(hd o #2) pat_rectangle
      in 
      if (forall is_Free col0) 
-     then let val rights' = map (fn(v,e) => psubst[(v,u)] e)
+     then let val rights' = map (fn(v,e) => pattern_subst[(v,u)] e)
                                 (ListPair.zip (col0, rights))
               val pat_rectangle' = map v_to_prefix pat_rectangle
               val (pref_patl,tm) = mk{path = rstp,
@@ -243,7 +232,7 @@
      case (ty_info ty_name)
      of None => mk_case_fail("Not a known datatype: "^ty_name)
       | Some{case_const,constructors} =>
-        let val case_const_name = #Name(S.dest_const case_const)
+        let val case_const_name = #1(dest_Const case_const)
             val nrows = List_.concat (map (expand constructors pty) rows)
             val subproblems = divide(constructors, pty, range_ty, nrows)
             val groups      = map #group subproblems
@@ -256,8 +245,7 @@
             val case_functions = map S.list_mk_abs
                                   (ListPair.zip (new_formals, dtrees))
             val types = map type_of (case_functions@[u]) @ [range_ty]
-            val case_const' = S.mk_const{Name = case_const_name,
-                                         Ty   = list_mk_type types}
+            val case_const' = Const(case_const_name, list_mk_type types)
             val tree = list_comb(case_const', case_functions@[u])
             val pat_rect1 = List_.concat
                               (ListPair.map mk_pat (constructors', pat_rect))
@@ -271,7 +259,7 @@
 (* Repeated variable occurrences in a pattern are not allowed. *)
 fun FV_multiset tm = 
    case (S.dest_term tm)
-     of S.VAR v => [S.mk_var v]
+     of S.VAR{Name,Ty} => [Free(Name,Ty)]
       | S.CONST _ => []
       | S.COMB{Rator, Rand} => FV_multiset Rator @ FV_multiset Rand
       | S.LAMB _ => raise TFL_ERR{func = "FV_multiset", mesg = "lambda"};
@@ -288,39 +276,35 @@
  in check (FV_multiset pat)
  end;
 
-local fun paired1{lhs,rhs} = (lhs,rhs) 
-      and paired2{Rator,Rand} = (Rator,Rand)
-      fun mk_functional_err s = raise TFL_ERR{func = "mk_functional", mesg=s}
+local fun mk_functional_err s = raise TFL_ERR{func = "mk_functional", mesg=s}
       fun single [f] = f
         | single fs  = mk_functional_err (Int.toString (length fs) ^ 
                                           " distinct function names!")
 in
-fun mk_functional thy eqs =
- let val clauses = S.strip_conj eqs
-     val (L,R) = ListPair.unzip 
-                    (map (paired1 o S.dest_eq o #2 o S.strip_forall)
-                         clauses)
-     val (funcs,pats) = ListPair.unzip(map (paired2 o S.dest_comb) L)
-     val f = single (U.mk_set (S.aconv) funcs)
-     val fvar = if (is_Free f) then f else S.mk_var(S.dest_const f)
+fun mk_functional thy clauses =
+ let val (L,R) = ListPair.unzip 
+                    (map (fn (Const("op =",_) $ t $ u) => (t,u)) clauses)
+     val (funcs,pats) = ListPair.unzip (map (fn (t$u) =>(t,u)) L)
+     val f = single (gen_distinct (op aconv) funcs)
+     (**??why change the Const to a Free??????????????**)
+     val fvar = if (is_Free f) then f else Free(dest_Const f)
      val dummy = map (no_repeat_vars thy) pats
      val rows = ListPair.zip (map (fn x => ([],[x])) pats,
                               map GIVEN (enumerate R))
      val fvs = S.free_varsl R
-     val a = S.variant fvs (S.mk_var{Name="a", Ty = type_of(hd pats)})
+     val a = S.variant fvs (Free("a",type_of(hd pats)))
      val FV = a::fvs
      val ty_info = Thry.match_info thy
      val ty_match = Thry.match_type thy
      val range_ty = type_of (hd R)
      val (patts, case_tm) = mk_case ty_info ty_match FV range_ty 
                                     {path=[a], rows=rows}
-     val patts1 = map (fn (_,(tag,i),[pat]) => tag (pat,i)) patts handle _ 
-                  => mk_functional_err "error in pattern-match translation"
+     val patts1 = map (fn (_,(tag,i),[pat]) => tag (pat,i)) patts 
+	  handle _ => mk_functional_err "error in pattern-match translation"
      val patts2 = U.sort(fn p1=>fn p2=> row_of_pat p1 < row_of_pat p2) patts1
      val finals = map row_of_pat patts2
      val originals = map (row_of_pat o #2) rows
-     fun int_eq i1 (i2:int) =  (i1=i2)
-     val dummy = case (U.set_diff int_eq originals finals)
+     val dummy = case (originals\\finals)
              of [] => ()
           | L => mk_functional_err("The following rows (counting from zero)\
                                    \ are inaccessible: "^stringize L)
@@ -352,20 +336,19 @@
       val (fname,_) = dest_Free f
       val (wfrec,_) = S.strip_comb rhs
 in
-fun wfrec_definition0 thy fid R functional =
- let val {Bvar,...} = S.dest_abs functional
-     val (Name, Ty) = dest_Free Bvar  
-     val def_name = if Name<>fid then 
+fun wfrec_definition0 thy fid R (functional as Abs(Name, Ty, _)) =
+ let val def_name = if Name<>fid then 
 			raise TFL_ERR{func = "wfrec_definition0",
 				      mesg = "Expected a definition of " ^ 
 					     quote fid ^ " but found one of " ^
 				      quote Name}
 		    else Name ^ "_def"
-     val wfrec_R_M = map_term_types poly_tvars
-	               (wfrec $ R $ (map_term_types poly_tvars functional))
+     val wfrec_R_M =  map_term_types poly_tvars 
+	                  (wfrec $ map_term_types poly_tvars R) 
+	              $ functional
      val (_, def_term, _) = 
 	 Sign.infer_types (sign_of thy) (K None) (K None) [] false
-	       ([HOLogic.mk_eq(Bvar, wfrec_R_M)], 
+	       ([HOLogic.mk_eq(Const(Name,Ty), wfrec_R_M)], 
 		HOLogic.boolT)
 		    
   in 
@@ -401,9 +384,10 @@
 fun merge full_pats TCs =
 let fun insert (p,TCs) =
       let fun insrt ((x as (h,[]))::rst) = 
-                 if (S.aconv p h) then (p,TCs)::rst else x::insrt rst
+                 if (p aconv h) then (p,TCs)::rst else x::insrt rst
             | insrt (x::rst) = x::insrt rst
-            | insrt[] = raise TFL_ERR{func="merge.insert",mesg="pat not found"}
+            | insrt[] = raise TFL_ERR{func="merge.insert",
+				      mesg="pattern not found"}
       in insrt end
     fun pass ([],ptcl_final) = ptcl_final
       | pass (ptcs::tcl, ptcl) = pass(tcl, insert ptcs ptcl)
@@ -411,10 +395,14 @@
   pass (TCs, map (fn p => (p,[])) full_pats)
 end;
 
-fun not_omitted (GIVEN(tm,_)) = tm
-  | not_omitted (OMITTED _) = raise TFL_ERR{func="not_omitted",mesg=""}
-val givens = U.mapfilter not_omitted;
 
+(*Replace all TFrees by TVars [CURRENTLY UNUSED]*)
+val tvars_of_tfrees = 
+    map_term_types (map_type_tfree (fn (a,sort) => TVar ((a, 0), sort)));
+
+fun givens [] = []
+  | givens (GIVEN(tm,_)::pats) = tm :: givens pats
+  | givens (OMITTED _::pats)   = givens pats;
 
 fun post_definition (theory, (def, pats)) =
  let val tych = Thry.typecheck theory 
@@ -424,7 +412,8 @@
      val WFR = #ant(S.dest_imp(concl corollary))
      val R = #Rand(S.dest_comb WFR)
      val corollary' = R.UNDISCH corollary  (* put WF R on assums *)
-     val corollaries = map (U.C R.SPEC corollary' o tych) given_pats
+     val corollaries = map (fn pat => R.SPEC (tych pat) corollary') 
+	                   given_pats
      val (case_rewrites,context_congs) = extraction_thms theory
      val corollaries' = map(R.simplify case_rewrites) corollaries
      fun xtract th = R.CONTEXT_REWRITE_RULE(f,R)
@@ -438,7 +427,8 @@
  in
  {theory = theory,   (* holds def, if it's needed *)
   rules = rules1,
-  full_pats_TCs = merge (map pat_of pats) (ListPair.zip (given_pats, TCs)), 
+  full_pats_TCs = merge (map pat_of pats) 
+                        (ListPair.zip (given_pats, TCs)), 
   TCs = TCs, 
   patterns = pats}
  end;
@@ -457,7 +447,7 @@
      val (case_rewrites,context_congs) = extraction_thms thy
      val tych = Thry.typecheck thy
      val WFREC_THM0 = R.ISPEC (tych functional) Thms.WFREC_COROLLARY
-     val R = S.variant(S.free_vars eqns) 
+     val R = S.variant(foldr add_term_frees (eqns,[])) 
                       (#Bvar(S.dest_forall(concl WFREC_THM0)))
      val WFREC_THM = R.ISPECL [tych R, tych f] WFREC_THM0
      val ([proto_def, WFR],_) = S.strip_imp(concl WFREC_THM)
@@ -616,12 +606,12 @@
 
 fun complete_cases thy =
  let val tych = Thry.typecheck thy
-     fun pmk_var n ty = S.mk_var{Name = n,Ty = ty}
      val ty_info = Thry.induct_info thy
  in fn pats =>
  let val FV0 = S.free_varsl pats
-     val a = S.variant FV0 (pmk_var "a" (type_of(hd pats)))
-     val v = S.variant (a::FV0) (pmk_var "v" (type_of a))
+     val T = type_of (hd pats)
+     val a = S.variant FV0 (Free ("a", T))
+     val v = S.variant (a::FV0) (Free ("v", T))
      val FV = a::v::FV0
      val a_eq_v = HOLogic.mk_eq(a,v)
      val ex_th0 = R.EXISTS (tych (S.mk_exists{Bvar=v,Body=a_eq_v}), tych a)
@@ -661,7 +651,7 @@
               of [] => (P_y, (tm,[]))
                | _  => let 
                     val imp = S.list_mk_conj cntxt ==> P_y
-                    val lvs = U.set_diff S.aconv (S.free_vars_lr imp) globals
+                    val lvs = gen_rems (op aconv) (S.free_vars_lr imp, globals)
                     val locals = #2(U.pluck (S.aconv P) lvs) handle _ => lvs
                     in (S.list_mk_forall(locals,imp), (tm,locals)) end
          end
@@ -750,7 +740,7 @@
     val case_thm = complete_cases thy pats
     val domain = (type_of o hd) pats
     val P = S.variant (S.all_varsl (pats @ List_.concat TCsl))
-                      (S.mk_var{Name="P", Ty=domain --> HOLogic.boolT})
+                      (Free("P",domain --> HOLogic.boolT))
     val Sinduct = R.SPEC (tych P) Sinduction
     val Sinduct_assumf = S.rand ((#ant o S.dest_imp o concl) Sinduct)
     val Rassums_TCl' = map (build_ih f P) pat_TCs_list
@@ -760,7 +750,7 @@
     val tasks = U.zip3 cases TCl' (R.CONJUNCTS Rinduct_assum)
     val proved_cases = map (prove_case f thy) tasks
     val v = S.variant (S.free_varsl (map concl proved_cases))
-                      (S.mk_var{Name="v", Ty=domain})
+                      (Free("v",domain))
     val vtyped = tych v
     val substs = map (R.SYM o R.ASSUME o tych o (curry HOLogic.mk_eq v)) pats
     val proved_cases1 = ListPair.map (fn (th,th') => R.SUBS[th]th') 
@@ -774,11 +764,12 @@
                        (R.SPEC (tych(S.mk_vstruct Parg_ty vars)) dc)
 in 
    R.GEN (tych P) (R.DISCH (tych(concl Rinduct_assum)) dc')
-end 
+end
 handle _ => raise TFL_ERR{func = "mk_induction", mesg = "failed derivation"};
 
 
 
+
 (*---------------------------------------------------------------------------
  * 
  *                        POST PROCESSING
@@ -875,7 +866,7 @@
    fun loop ([],extras,R,ind) = (rev R, ind, extras)
      | loop ((r,ftcs)::rst, nthms, R, ind) =
         let val tcs = #1(strip_imp (concl r))
-            val extra_tcs = U.set_diff S.aconv ftcs tcs
+            val extra_tcs = gen_rems (op aconv) (ftcs, tcs)
             val extra_tc_thms = map simplify_nested_tc extra_tcs
             val (r1,ind1) = U.rev_itlist simplify_tc tcs (r,ind)
             val r2 = R.FILTER_DISCH_ALL(not o S.is_WFR) r1