diff -r 0c7625196d95 -r 5e45dd3b64e9 TFL/tfl.sml --- a/TFL/tfl.sml Tue Jun 03 10:56:04 1997 +0200 +++ b/TFL/tfl.sml Tue Jun 03 11:08:08 1997 +0200 @@ -3,31 +3,17 @@ Author: Konrad Slind, Cambridge University Computer Laboratory Copyright 1997 University of Cambridge -Main TFL functor +Main module *) -functor TFL(structure Rules : Rules_sig - structure Thry : Thry_sig - structure Thms : Thms_sig) : TFL_sig = +structure Prim : TFL_sig = struct -(* Declarations *) -structure Thms = Thms; -structure Rules = Rules; -structure Thry = Thry; -structure USyntax = Thry.USyntax; - - (* Abbreviations *) structure R = Rules; structure S = USyntax; structure U = S.Utils; -nonfix mem -->; -val --> = S.-->; - -infixr 3 -->; - val concl = #2 o R.dest_thm; val hyp = #1 o R.dest_thm; @@ -83,13 +69,13 @@ | part {constrs = [], rows = _::_, A} = pfail"extra cases in defn" | part {constrs = _::_, rows = [], A} = pfail"cases missing in defn" | part {constrs = c::crst, rows, A} = - let val {Name,Ty} = S.dest_const c - val (L,_) = S.strip_type Ty + let val (Name,Ty) = dest_Const c + val L = binder_types Ty val (in_group, not_in_group) = U.itlist (fn (row as (p::rst, rhs)) => fn (in_group,not_in_group) => let val (pc,args) = S.strip_comb p - in if (#Name(S.dest_const pc) = Name) + in if (#1(dest_Const pc) = Name) then ((args@rst, rhs)::in_group, not_in_group) else (in_group, row::not_in_group) end) rows ([],[]) @@ -112,8 +98,10 @@ datatype pattern = GIVEN of term * int | OMITTED of term * int -fun psubst theta (GIVEN (tm,i)) = GIVEN(subst_free theta tm, i) - | psubst theta (OMITTED (tm,i)) = OMITTED(subst_free theta tm, i); +fun pattern_map f (GIVEN (tm,i)) = GIVEN(f tm, i) + | pattern_map f (OMITTED (tm,i)) = OMITTED(f tm, i); + +fun pattern_subst theta = pattern_map (subst_free theta); fun dest_pattern (GIVEN (tm,i)) = ((GIVEN,i),tm) | dest_pattern (OMITTED (tm,i)) = ((OMITTED,i),tm); @@ -125,8 +113,9 @@ * Produce an instance of a constructor, plus genvars for its arguments. *---------------------------------------------------------------------------*) fun fresh_constr ty_match colty gv c = - let val {Ty,...} = S.dest_const c - val (L,ty) = S.strip_type Ty + let val (_,Ty) = dest_Const c + val L = binder_types Ty + and ty = body_type Ty val ty_theta = ty_match ty colty val c' = S.inst ty_theta c val gvars = map (S.inst ty_theta o gv) L @@ -142,7 +131,7 @@ U.itlist (fn (row as ((prefix, p::rst), rhs)) => fn (in_group,not_in_group) => let val (pc,args) = S.strip_comb p - in if ((#Name(S.dest_const pc) = Name) handle _ => false) + in if ((#1(dest_Const pc) = Name) handle _ => false) then (((prefix,args@rst), rhs)::in_group, not_in_group) else (in_group, row::not_in_group) end) rows ([],[]); @@ -157,7 +146,7 @@ fun part {constrs = [], rows, A} = rev A | part {constrs = c::crst, rows, A} = let val (c',gvars) = fresh c - val {Name,Ty} = S.dest_const c' + val (Name,Ty) = dest_Const c' val (in_group, not_in_group) = mk_group Name rows val in_group' = if (null in_group) (* Constructor not given *) @@ -178,7 +167,7 @@ *---------------------------------------------------------------------------*) fun mk_pat (c,l) = - let val L = length(#1(S.strip_type(type_of c))) + let val L = length (binder_types (type_of c)) fun build (prefix,tag,plist) = let val args = take (L,plist) and plist' = drop(L,plist) @@ -211,7 +200,7 @@ then let val fresh = fresh_constr ty_match ty fresh_var fun expnd (c,gvs) = let val capp = list_comb(c,gvs) - in ((prefix, capp::rst), psubst[(p,capp)] rhs) + in ((prefix, capp::rst), pattern_subst[(p,capp)] rhs) end in map expnd (map fresh constructors) end else [row] @@ -229,7 +218,7 @@ val col0 = map(hd o #2) pat_rectangle in if (forall is_Free col0) - then let val rights' = map (fn(v,e) => psubst[(v,u)] e) + then let val rights' = map (fn(v,e) => pattern_subst[(v,u)] e) (ListPair.zip (col0, rights)) val pat_rectangle' = map v_to_prefix pat_rectangle val (pref_patl,tm) = mk{path = rstp, @@ -243,7 +232,7 @@ case (ty_info ty_name) of None => mk_case_fail("Not a known datatype: "^ty_name) | Some{case_const,constructors} => - let val case_const_name = #Name(S.dest_const case_const) + let val case_const_name = #1(dest_Const case_const) val nrows = List_.concat (map (expand constructors pty) rows) val subproblems = divide(constructors, pty, range_ty, nrows) val groups = map #group subproblems @@ -256,8 +245,7 @@ val case_functions = map S.list_mk_abs (ListPair.zip (new_formals, dtrees)) val types = map type_of (case_functions@[u]) @ [range_ty] - val case_const' = S.mk_const{Name = case_const_name, - Ty = list_mk_type types} + val case_const' = Const(case_const_name, list_mk_type types) val tree = list_comb(case_const', case_functions@[u]) val pat_rect1 = List_.concat (ListPair.map mk_pat (constructors', pat_rect)) @@ -271,7 +259,7 @@ (* Repeated variable occurrences in a pattern are not allowed. *) fun FV_multiset tm = case (S.dest_term tm) - of S.VAR v => [S.mk_var v] + of S.VAR{Name,Ty} => [Free(Name,Ty)] | S.CONST _ => [] | S.COMB{Rator, Rand} => FV_multiset Rator @ FV_multiset Rand | S.LAMB _ => raise TFL_ERR{func = "FV_multiset", mesg = "lambda"}; @@ -288,39 +276,35 @@ in check (FV_multiset pat) end; -local fun paired1{lhs,rhs} = (lhs,rhs) - and paired2{Rator,Rand} = (Rator,Rand) - fun mk_functional_err s = raise TFL_ERR{func = "mk_functional", mesg=s} +local fun mk_functional_err s = raise TFL_ERR{func = "mk_functional", mesg=s} fun single [f] = f | single fs = mk_functional_err (Int.toString (length fs) ^ " distinct function names!") in -fun mk_functional thy eqs = - let val clauses = S.strip_conj eqs - val (L,R) = ListPair.unzip - (map (paired1 o S.dest_eq o #2 o S.strip_forall) - clauses) - val (funcs,pats) = ListPair.unzip(map (paired2 o S.dest_comb) L) - val f = single (U.mk_set (S.aconv) funcs) - val fvar = if (is_Free f) then f else S.mk_var(S.dest_const f) +fun mk_functional thy clauses = + let val (L,R) = ListPair.unzip + (map (fn (Const("op =",_) $ t $ u) => (t,u)) clauses) + val (funcs,pats) = ListPair.unzip (map (fn (t$u) =>(t,u)) L) + val f = single (gen_distinct (op aconv) funcs) + (**??why change the Const to a Free??????????????**) + val fvar = if (is_Free f) then f else Free(dest_Const f) val dummy = map (no_repeat_vars thy) pats val rows = ListPair.zip (map (fn x => ([],[x])) pats, map GIVEN (enumerate R)) val fvs = S.free_varsl R - val a = S.variant fvs (S.mk_var{Name="a", Ty = type_of(hd pats)}) + val a = S.variant fvs (Free("a",type_of(hd pats))) val FV = a::fvs val ty_info = Thry.match_info thy val ty_match = Thry.match_type thy val range_ty = type_of (hd R) val (patts, case_tm) = mk_case ty_info ty_match FV range_ty {path=[a], rows=rows} - val patts1 = map (fn (_,(tag,i),[pat]) => tag (pat,i)) patts handle _ - => mk_functional_err "error in pattern-match translation" + val patts1 = map (fn (_,(tag,i),[pat]) => tag (pat,i)) patts + handle _ => mk_functional_err "error in pattern-match translation" val patts2 = U.sort(fn p1=>fn p2=> row_of_pat p1 < row_of_pat p2) patts1 val finals = map row_of_pat patts2 val originals = map (row_of_pat o #2) rows - fun int_eq i1 (i2:int) = (i1=i2) - val dummy = case (U.set_diff int_eq originals finals) + val dummy = case (originals\\finals) of [] => () | L => mk_functional_err("The following rows (counting from zero)\ \ are inaccessible: "^stringize L) @@ -352,20 +336,19 @@ val (fname,_) = dest_Free f val (wfrec,_) = S.strip_comb rhs in -fun wfrec_definition0 thy fid R functional = - let val {Bvar,...} = S.dest_abs functional - val (Name, Ty) = dest_Free Bvar - val def_name = if Name<>fid then +fun wfrec_definition0 thy fid R (functional as Abs(Name, Ty, _)) = + let val def_name = if Name<>fid then raise TFL_ERR{func = "wfrec_definition0", mesg = "Expected a definition of " ^ quote fid ^ " but found one of " ^ quote Name} else Name ^ "_def" - val wfrec_R_M = map_term_types poly_tvars - (wfrec $ R $ (map_term_types poly_tvars functional)) + val wfrec_R_M = map_term_types poly_tvars + (wfrec $ map_term_types poly_tvars R) + $ functional val (_, def_term, _) = Sign.infer_types (sign_of thy) (K None) (K None) [] false - ([HOLogic.mk_eq(Bvar, wfrec_R_M)], + ([HOLogic.mk_eq(Const(Name,Ty), wfrec_R_M)], HOLogic.boolT) in @@ -401,9 +384,10 @@ fun merge full_pats TCs = let fun insert (p,TCs) = let fun insrt ((x as (h,[]))::rst) = - if (S.aconv p h) then (p,TCs)::rst else x::insrt rst + if (p aconv h) then (p,TCs)::rst else x::insrt rst | insrt (x::rst) = x::insrt rst - | insrt[] = raise TFL_ERR{func="merge.insert",mesg="pat not found"} + | insrt[] = raise TFL_ERR{func="merge.insert", + mesg="pattern not found"} in insrt end fun pass ([],ptcl_final) = ptcl_final | pass (ptcs::tcl, ptcl) = pass(tcl, insert ptcs ptcl) @@ -411,10 +395,14 @@ pass (TCs, map (fn p => (p,[])) full_pats) end; -fun not_omitted (GIVEN(tm,_)) = tm - | not_omitted (OMITTED _) = raise TFL_ERR{func="not_omitted",mesg=""} -val givens = U.mapfilter not_omitted; +(*Replace all TFrees by TVars [CURRENTLY UNUSED]*) +val tvars_of_tfrees = + map_term_types (map_type_tfree (fn (a,sort) => TVar ((a, 0), sort))); + +fun givens [] = [] + | givens (GIVEN(tm,_)::pats) = tm :: givens pats + | givens (OMITTED _::pats) = givens pats; fun post_definition (theory, (def, pats)) = let val tych = Thry.typecheck theory @@ -424,7 +412,8 @@ val WFR = #ant(S.dest_imp(concl corollary)) val R = #Rand(S.dest_comb WFR) val corollary' = R.UNDISCH corollary (* put WF R on assums *) - val corollaries = map (U.C R.SPEC corollary' o tych) given_pats + val corollaries = map (fn pat => R.SPEC (tych pat) corollary') + given_pats val (case_rewrites,context_congs) = extraction_thms theory val corollaries' = map(R.simplify case_rewrites) corollaries fun xtract th = R.CONTEXT_REWRITE_RULE(f,R) @@ -438,7 +427,8 @@ in {theory = theory, (* holds def, if it's needed *) rules = rules1, - full_pats_TCs = merge (map pat_of pats) (ListPair.zip (given_pats, TCs)), + full_pats_TCs = merge (map pat_of pats) + (ListPair.zip (given_pats, TCs)), TCs = TCs, patterns = pats} end; @@ -457,7 +447,7 @@ val (case_rewrites,context_congs) = extraction_thms thy val tych = Thry.typecheck thy val WFREC_THM0 = R.ISPEC (tych functional) Thms.WFREC_COROLLARY - val R = S.variant(S.free_vars eqns) + val R = S.variant(foldr add_term_frees (eqns,[])) (#Bvar(S.dest_forall(concl WFREC_THM0))) val WFREC_THM = R.ISPECL [tych R, tych f] WFREC_THM0 val ([proto_def, WFR],_) = S.strip_imp(concl WFREC_THM) @@ -616,12 +606,12 @@ fun complete_cases thy = let val tych = Thry.typecheck thy - fun pmk_var n ty = S.mk_var{Name = n,Ty = ty} val ty_info = Thry.induct_info thy in fn pats => let val FV0 = S.free_varsl pats - val a = S.variant FV0 (pmk_var "a" (type_of(hd pats))) - val v = S.variant (a::FV0) (pmk_var "v" (type_of a)) + val T = type_of (hd pats) + val a = S.variant FV0 (Free ("a", T)) + val v = S.variant (a::FV0) (Free ("v", T)) val FV = a::v::FV0 val a_eq_v = HOLogic.mk_eq(a,v) val ex_th0 = R.EXISTS (tych (S.mk_exists{Bvar=v,Body=a_eq_v}), tych a) @@ -661,7 +651,7 @@ of [] => (P_y, (tm,[])) | _ => let val imp = S.list_mk_conj cntxt ==> P_y - val lvs = U.set_diff S.aconv (S.free_vars_lr imp) globals + val lvs = gen_rems (op aconv) (S.free_vars_lr imp, globals) val locals = #2(U.pluck (S.aconv P) lvs) handle _ => lvs in (S.list_mk_forall(locals,imp), (tm,locals)) end end @@ -750,7 +740,7 @@ val case_thm = complete_cases thy pats val domain = (type_of o hd) pats val P = S.variant (S.all_varsl (pats @ List_.concat TCsl)) - (S.mk_var{Name="P", Ty=domain --> HOLogic.boolT}) + (Free("P",domain --> HOLogic.boolT)) val Sinduct = R.SPEC (tych P) Sinduction val Sinduct_assumf = S.rand ((#ant o S.dest_imp o concl) Sinduct) val Rassums_TCl' = map (build_ih f P) pat_TCs_list @@ -760,7 +750,7 @@ val tasks = U.zip3 cases TCl' (R.CONJUNCTS Rinduct_assum) val proved_cases = map (prove_case f thy) tasks val v = S.variant (S.free_varsl (map concl proved_cases)) - (S.mk_var{Name="v", Ty=domain}) + (Free("v",domain)) val vtyped = tych v val substs = map (R.SYM o R.ASSUME o tych o (curry HOLogic.mk_eq v)) pats val proved_cases1 = ListPair.map (fn (th,th') => R.SUBS[th]th') @@ -774,11 +764,12 @@ (R.SPEC (tych(S.mk_vstruct Parg_ty vars)) dc) in R.GEN (tych P) (R.DISCH (tych(concl Rinduct_assum)) dc') -end +end handle _ => raise TFL_ERR{func = "mk_induction", mesg = "failed derivation"}; + (*--------------------------------------------------------------------------- * * POST PROCESSING @@ -875,7 +866,7 @@ fun loop ([],extras,R,ind) = (rev R, ind, extras) | loop ((r,ftcs)::rst, nthms, R, ind) = let val tcs = #1(strip_imp (concl r)) - val extra_tcs = U.set_diff S.aconv ftcs tcs + val extra_tcs = gen_rems (op aconv) (ftcs, tcs) val extra_tc_thms = map simplify_nested_tc extra_tcs val (r1,ind1) = U.rev_itlist simplify_tc tcs (r,ind) val r2 = R.FILTER_DISCH_ALL(not o S.is_WFR) r1