src/Pure/meta_simplifier.ML
author skalberg
Tue Oct 21 11:09:23 2003 +0200 (2003-10-21)
changeset 14242 ec70653a02bf
parent 14040 8a2c8f762837
child 14330 eb8b8241ef5b
permissions -rw-r--r--
Added access to the mk_rews field (and friends).
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
    11   val trace_simp: bool ref
    12   val debug_simp: bool ref
    13   val simp_depth_limit: int ref
    14 end;
    15 
    16 signature META_SIMPLIFIER =
    17 sig
    18   include BASIC_META_SIMPLIFIER
    19   exception SIMPLIFIER of string * thm
    20   exception SIMPROC_FAIL of string * exn
    21   type meta_simpset
    22   val dest_mss          : meta_simpset ->
    23     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    24   val empty_mss         : meta_simpset
    25   val clear_mss         : meta_simpset -> meta_simpset
    26   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    27   val add_simps         : meta_simpset * thm list -> meta_simpset
    28   val del_simps         : meta_simpset * thm list -> meta_simpset
    29   val mss_of            : thm list -> meta_simpset
    30   val add_congs         : meta_simpset * thm list -> meta_simpset
    31   val del_congs         : meta_simpset * thm list -> meta_simpset
    32   val add_simprocs      : meta_simpset *
    33     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    34       -> meta_simpset
    35   val del_simprocs      : meta_simpset *
    36     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    37       -> meta_simpset
    38   val add_prems         : meta_simpset * thm list -> meta_simpset
    39   val prems_of_mss      : meta_simpset -> thm list
    40   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    41   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    42   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    43   val get_mk_rews       : meta_simpset -> thm -> thm list
    44   val get_mk_sym        : meta_simpset -> thm -> thm option
    45   val get_mk_eq_True    : meta_simpset -> thm -> thm option
    46   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    47   val beta_eta_conversion: cterm -> thm
    48   val rewrite_cterm: bool * bool * bool ->
    49     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    50   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    51   val forall_conv       : (cterm -> thm) -> cterm -> thm
    52   val fconv_rule        : (cterm -> thm) -> thm -> thm
    53   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    54   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    55   val rewrite_thm       : bool * bool * bool
    56                           -> (meta_simpset -> thm -> thm option)
    57                           -> meta_simpset -> thm -> thm
    58   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    59   val rewrite_goal_rule : bool* bool * bool
    60                           -> (meta_simpset -> thm -> thm option)
    61                           -> meta_simpset -> int -> thm -> thm
    62   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
    63 end;
    64 
    65 structure MetaSimplifier : META_SIMPLIFIER =
    66 struct
    67 
    68 (** diagnostics **)
    69 
    70 exception SIMPLIFIER of string * thm;
    71 exception SIMPROC_FAIL of string * exn;
    72 
    73 val simp_depth = ref 0;
    74 val simp_depth_limit = ref 1000;
    75 
    76 local
    77 
    78 fun println a =
    79   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    80 
    81 fun prnt warn a = if warn then warning a else println a;
    82 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    83 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    84 
    85 in
    86 
    87 fun prthm warn a = prctm warn a o Thm.cprop_of;
    88 
    89 val trace_simp = ref false;
    90 val debug_simp = ref false;
    91 
    92 fun trace warn a = if !trace_simp then prnt warn a else ();
    93 fun debug warn a = if !debug_simp then prnt warn a else ();
    94 
    95 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    96 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    97 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    98 
    99 fun trace_thm a thm =
   100   let val {sign, prop, ...} = rep_thm thm
   101   in trace_term false a sign prop end;
   102 
   103 fun trace_named_thm a (thm, name) =
   104   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   105 
   106 end;
   107 
   108 
   109 (** meta simp sets **)
   110 
   111 (* basic components *)
   112 
   113 type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
   114 (* thm: the rewrite rule
   115    name: name of theorem from which rewrite rule was extracted
   116    lhs: the left-hand side
   117    elhs: the etac-contracted lhs.
   118    fo:  use first-order matching
   119    perm: the rewrite rule is permutative
   120 Remarks:
   121   - elhs is used for matching,
   122     lhs only for preservation of bound variable names.
   123   - fo is set iff
   124     either elhs is first-order (no Var is applied),
   125            in which case fo-matching is complete,
   126     or elhs is not a pattern,
   127        in which case there is nothing better to do.
   128 *)
   129 type cong = {thm: thm, lhs: cterm};
   130 type simproc =
   131  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   132 
   133 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   134   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   135 
   136 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   137   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   138 
   139 fun eq_prem (thm1, thm2) =
   140   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   141 
   142 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   143 
   144 fun mk_simproc (name, proc, lhs, id) =
   145   {name = name, proc = proc, lhs = lhs, id = id};
   146 
   147 
   148 (* datatype mss *)
   149 
   150 (*
   151   A "mss" contains data needed during conversion:
   152     rules: discrimination net of rewrite rules;
   153     congs: association list of congruence rules and
   154            a list of `weak' congruence constants.
   155            A congruence is `weak' if it avoids normalization of some argument.
   156     procs: discrimination net of simplification procedures
   157       (functions that prove rewrite rules on the fly);
   158     bounds: names of bound variables already used
   159       (for generating new names when rewriting under lambda abstractions);
   160     prems: current premises;
   161     mk_rews: mk: turns simplification thms into rewrite rules;
   162              mk_sym: turns == around; (needs Drule!)
   163              mk_eq_True: turns P into P == True - logic specific;
   164     termless: relation for ordered rewriting;
   165     depth: depth of conditional rewriting;
   166 *)
   167 
   168 datatype meta_simpset =
   169   Mss of {
   170     rules: rrule Net.net,
   171     congs: (string * cong) list * string list,
   172     procs: simproc Net.net,
   173     bounds: string list,
   174     prems: thm list,
   175     mk_rews: {mk: thm -> thm list,
   176               mk_sym: thm -> thm option,
   177               mk_eq_True: thm -> thm option},
   178     termless: term * term -> bool,
   179     depth: int};
   180 
   181 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   182   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   183        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   184 
   185 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   186   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   187 
   188 val empty_mss =
   189   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   190   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   191 
   192 fun clear_mss (Mss {mk_rews, termless, ...}) =
   193   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   194 
   195 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   196   let val depth1 = depth+1
   197   in if depth1 > !simp_depth_limit
   198      then (warning "simp_depth_limit exceeded - giving up"; None)
   199      else (if depth1 mod 10 = 0
   200            then warning("Simplification depth " ^ string_of_int depth1)
   201            else ();
   202            Some(mk_mss(rules,congs,procs,bounds,prems,mk_rews,termless,depth1))
   203           )
   204   end;
   205 
   206 
   207 (** simpset operations **)
   208 
   209 (* term variables *)
   210 
   211 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   212 fun term_varnames t = add_term_varnames ([], t);
   213 
   214 
   215 (* dest_mss *)
   216 
   217 fun dest_mss (Mss {rules, congs, procs, ...}) =
   218   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   219    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   220    procs =
   221      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   222      |> partition_eq eq_snd
   223      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   224      |> Library.sort_wrt #1};
   225 
   226 
   227 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   228 
   229 fun merge_mss
   230  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   231        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   232   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   233        bounds = bounds2, prems = prems2, ...}) =
   234       mk_mss
   235        (Net.merge (rules1, rules2, eq_rrule),
   236         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   237         merge_lists weak1 weak2),
   238         Net.merge (procs1, procs2, eq_simproc),
   239         merge_lists bounds1 bounds2,
   240         gen_merge_lists eq_prem prems1 prems2,
   241         mk_rews, termless, depth);
   242 
   243 
   244 (* add_simps *)
   245 
   246 fun mk_rrule2{thm, name, lhs, elhs, perm} =
   247   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   248   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   249 
   250 fun insert_rrule quiet (mss as Mss {rules,...},
   251                  rrule as {thm,name,lhs,elhs,perm}) =
   252   (trace_named_thm "Adding rewrite rule" (thm, name);
   253    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   254        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   255    in upd_rules(mss,rules') end
   256    handle Net.INSERT => if quiet then mss else
   257      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   258 
   259 fun vperm (Var _, Var _) = true
   260   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   261   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   262   | vperm (t, u) = (t = u);
   263 
   264 fun var_perm (t, u) =
   265   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   266 
   267 (* FIXME: it seems that the conditions on extra variables are too liberal if
   268 prems are nonempty: does solving the prems really guarantee instantiation of
   269 all its Vars? Better: a dynamic check each time a rule is applied.
   270 *)
   271 fun rewrite_rule_extra_vars prems elhs erhs =
   272   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   273   orelse
   274   not ((term_tvars erhs) subset
   275        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   276 
   277 (*Simple test for looping rewrite rules and stupid orientations*)
   278 fun reorient sign prems lhs rhs =
   279    rewrite_rule_extra_vars prems lhs rhs
   280   orelse
   281    is_Var (head_of lhs)
   282   orelse
   283    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   284   orelse
   285    (null prems andalso
   286     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   287     (*the condition "null prems" is necessary because conditional rewrites
   288       with extra variables in the conditions may terminate although
   289       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   290   orelse
   291    (is_Const lhs andalso not(is_Const rhs))
   292 
   293 fun decomp_simp thm =
   294   let val {sign, prop, ...} = rep_thm thm;
   295       val prems = Logic.strip_imp_prems prop;
   296       val concl = Drule.strip_imp_concl (cprop_of thm);
   297       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   298         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   299       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   300       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   301       val erhs = Pattern.eta_contract (term_of rhs);
   302       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   303                  andalso not (is_Var (term_of elhs))
   304   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   305 
   306 fun decomp_simp' thm =
   307   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   308     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   309     else (lhs, rhs)
   310   end;
   311 
   312 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
   313   case mk_eq_True thm of
   314     None => []
   315   | Some eq_True =>
   316       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   317       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   318 
   319 (* create the rewrite rule and possibly also the ==True variant,
   320    in case there are extra vars on the rhs *)
   321 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   322   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   323   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   324         (term_tvars rhs) subset (term_tvars lhs)
   325      then [rrule]
   326      else mk_eq_True mss (thm2, name) @ [rrule]
   327   end;
   328 
   329 fun mk_rrule mss (thm, name) =
   330   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   331   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   332      (* weak test for loops: *)
   333      if rewrite_rule_extra_vars prems lhs rhs orelse
   334         is_Var (term_of elhs)
   335      then mk_eq_True mss (thm, name)
   336      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   337   end;
   338 
   339 fun orient_rrule mss (thm, name) =
   340   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   341   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   342      else if reorient sign prems lhs rhs
   343           then if reorient sign prems rhs lhs
   344                then mk_eq_True mss (thm, name)
   345                else let val Mss{mk_rews={mk_sym,...},...} = mss
   346                     in case mk_sym thm of
   347                          None => []
   348                        | Some thm' =>
   349                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   350                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   351                     end
   352           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   353   end;
   354 
   355 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
   356   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   357 
   358 fun orient_comb_simps comb mk_rrule (mss,thms) =
   359   let val rews = extract_rews(mss,thms)
   360       val rrules = flat (map mk_rrule rews)
   361   in foldl comb (mss,rrules) end
   362 
   363 (* Add rewrite rules explicitly; do not reorient! *)
   364 fun add_simps(mss,thms) =
   365   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   366 
   367 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
   368   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   369 
   370 fun extract_safe_rrules(mss,thm) =
   371   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   372 
   373 (* del_simps *)
   374 
   375 fun del_rrule(mss as Mss {rules,...},
   376               rrule as {thm, elhs, ...}) =
   377   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   378    handle Net.DELETE =>
   379      (prthm true "Rewrite rule not in simpset:" thm; mss));
   380 
   381 fun del_simps(mss,thms) =
   382   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   383 
   384 
   385 (* add_congs *)
   386 
   387 fun is_full_cong_prems [] varpairs = null varpairs
   388   | is_full_cong_prems (p::prems) varpairs =
   389     (case Logic.strip_assums_concl p of
   390        Const("==",_) $ lhs $ rhs =>
   391          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   392          in is_Var x  andalso  forall is_Bound xs  andalso
   393             null(findrep(xs))  andalso xs=ys andalso
   394             (x,y) mem varpairs andalso
   395             is_full_cong_prems prems (varpairs\(x,y))
   396          end
   397      | _ => false);
   398 
   399 fun is_full_cong thm =
   400 let val prems = prems_of thm
   401     and concl = concl_of thm
   402     val (lhs,rhs) = Logic.dest_equals concl
   403     val (f,xs) = strip_comb lhs
   404     and (g,ys) = strip_comb rhs
   405 in
   406   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   407   is_full_cong_prems prems (xs ~~ ys)
   408 end
   409 
   410 fun cong_name (Const (a, _)) = Some a
   411   | cong_name (Free (a, _)) = Some ("Free: " ^ a)
   412   | cong_name _ = None;
   413 
   414 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   415   let
   416     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   417       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   418 (*   val lhs = Pattern.eta_contract lhs; *)
   419     val a = (case cong_name (head_of (term_of lhs)) of
   420         Some a => a
   421       | None =>
   422         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   423     val (alist,weak) = congs
   424     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   425            ("Overwriting congruence rule for " ^ quote a);
   426     val weak2 = if is_full_cong thm then weak else a::weak
   427   in
   428     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   429   end;
   430 
   431 val (op add_congs) = foldl add_cong;
   432 
   433 
   434 (* del_congs *)
   435 
   436 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   437   let
   438     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   439       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   440 (*   val lhs = Pattern.eta_contract lhs; *)
   441     val a = (case cong_name (head_of lhs) of
   442         Some a => a
   443       | None =>
   444         raise SIMPLIFIER ("Congruence must start with a constant", thm));
   445     val (alist,_) = congs
   446     val alist2 = filter (fn (x,_)=> x<>a) alist
   447     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   448                                               else Some a)
   449                    alist2
   450   in
   451     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   452   end;
   453 
   454 val (op del_congs) = foldl del_cong;
   455 
   456 
   457 (* add_simprocs *)
   458 
   459 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   460     (name, lhs, proc, id)) =
   461   let val {sign, t, ...} = rep_cterm lhs
   462   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   463       sign t;
   464     mk_mss (rules, congs,
   465       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   466         handle Net.INSERT =>
   467             (warning ("Ignoring duplicate simplification procedure \""
   468                       ^ name ^ "\"");
   469              procs),
   470         bounds, prems, mk_rews, termless,depth))
   471   end;
   472 
   473 fun add_simproc (mss, (name, lhss, proc, id)) =
   474   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   475 
   476 val add_simprocs = foldl add_simproc;
   477 
   478 
   479 (* del_simprocs *)
   480 
   481 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   482     (name, lhs, proc, id)) =
   483   mk_mss (rules, congs,
   484     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   485       handle Net.DELETE =>
   486           (warning ("Simplification procedure \"" ^ name ^
   487                        "\" not in simpset"); procs),
   488       bounds, prems, mk_rews, termless, depth);
   489 
   490 fun del_simproc (mss, (name, lhss, proc, id)) =
   491   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   492 
   493 val del_simprocs = foldl del_simproc;
   494 
   495 
   496 (* prems *)
   497 
   498 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   499   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   500 
   501 fun prems_of_mss (Mss {prems, ...}) = prems;
   502 
   503 
   504 (* mk_rews *)
   505 
   506 fun set_mk_rews
   507   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   508     mk_mss (rules, congs, procs, bounds, prems,
   509             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   510             termless, depth);
   511 
   512 fun set_mk_sym
   513   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   514     mk_mss (rules, congs, procs, bounds, prems,
   515             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   516             termless,depth);
   517 
   518 fun set_mk_eq_True
   519   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   520     mk_mss (rules, congs, procs, bounds, prems,
   521             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   522             termless,depth);
   523 
   524 fun get_mk_rews    (Mss {mk_rews,...}) = #mk         mk_rews
   525 fun get_mk_sym     (Mss {mk_rews,...}) = #mk_sym     mk_rews
   526 fun get_mk_eq_True (Mss {mk_rews,...}) = #mk_eq_True mk_rews
   527 
   528 (* termless *)
   529 
   530 fun set_termless
   531   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   532     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   533 
   534 
   535 
   536 (** rewriting **)
   537 
   538 (*
   539   Uses conversions, see:
   540     L C Paulson, A higher-order implementation of rewriting,
   541     Science of Computer Programming 3 (1983), pages 119-149.
   542 *)
   543 
   544 val dest_eq = Drule.dest_equals o cprop_of;
   545 val lhs_of = fst o dest_eq;
   546 val rhs_of = snd o dest_eq;
   547 
   548 fun beta_eta_conversion t =
   549   let val thm = beta_conversion true t;
   550   in transitive thm (eta_conversion (rhs_of thm)) end;
   551 
   552 fun check_conv msg thm thm' =
   553   let
   554     val thm'' = transitive thm (transitive
   555       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   556   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   557   handle THM _ =>
   558     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   559     in
   560       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   561        trace_term false "Should have proved:" sign prop0;
   562        None)
   563     end;
   564 
   565 
   566 (* mk_procrule *)
   567 
   568 fun mk_procrule thm =
   569   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   570   in if rewrite_rule_extra_vars prems lhs rhs
   571      then (prthm true "Extra vars on rhs:" thm; [])
   572      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   573   end;
   574 
   575 
   576 (* conversion to apply the meta simpset to a term *)
   577 
   578 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   579    normalized terms by carrying around the rhs of the rewrite rule just
   580    applied. This is called the `skeleton'. It is decomposed in parallel
   581    with the term. Once a Var is encountered, the corresponding term is
   582    already in normal form.
   583    skel0 is a dummy skeleton that is to enforce complete normalization.
   584 *)
   585 val skel0 = Bound 0;
   586 
   587 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   588    The latter may happen iff there are weak congruence rules for constants
   589    in the lhs.
   590 *)
   591 fun uncond_skel((_,weak),(lhs,rhs)) =
   592   if null weak then rhs (* optimization *)
   593   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   594        else rhs;
   595 
   596 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   597    Otherwise those vars may become instantiated with unnormalized terms
   598    while the premises are solved.
   599 *)
   600 fun cond_skel(args as (congs,(lhs,rhs))) =
   601   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   602   else skel0;
   603 
   604 (*
   605   we try in order:
   606     (1) beta reduction
   607     (2) unconditional rewrite rules
   608     (3) conditional rewrite rules
   609     (4) simplification procedures
   610 
   611   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   612 
   613 *)
   614 
   615 fun rewritec (prover, signt, maxt)
   616              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   617   let
   618     val eta_thm = Thm.eta_conversion t;
   619     val eta_t' = rhs_of eta_thm;
   620     val eta_t = term_of eta_t';
   621     val tsigt = Sign.tsig_of signt;
   622     fun rew {thm, name, lhs, elhs, fo, perm} =
   623       let
   624         val {sign, prop, maxidx, ...} = rep_thm thm;
   625         val _ = if Sign.subsig (sign, signt) then ()
   626                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   627                       raise Pattern.MATCH);
   628         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   629           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   630         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   631                           else Thm.cterm_match (elhs', eta_t');
   632         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   633         val prop' = #prop (rep_thm thm');
   634         val unconditional = (Logic.count_prems (prop',0) = 0);
   635         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   636       in
   637         if perm andalso not (termless (rhs', lhs'))
   638         then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
   639               trace_thm "Term does not become smaller:" thm'; None)
   640         else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
   641            if unconditional
   642            then
   643              (trace_thm "Rewriting:" thm';
   644               let val lr = Logic.dest_equals prop;
   645                   val Some thm'' = check_conv false eta_thm thm'
   646               in Some (thm'', uncond_skel (congs, lr)) end)
   647            else
   648              (trace_thm "Trying to rewrite:" thm';
   649               case incr_depth mss of
   650                 None => (trace_thm "FAILED - reached depth limit" thm'; None)
   651               | Some mss =>
   652               (case prover mss thm' of
   653                 None       => (trace_thm "FAILED" thm'; None)
   654               | Some thm2 =>
   655                   (case check_conv true eta_thm thm2 of
   656                      None => None |
   657                      Some thm2' =>
   658                        let val concl = Logic.strip_imp_concl prop
   659                            val lr = Logic.dest_equals concl
   660                        in Some (thm2', cond_skel (congs, lr)) end))))
   661       end
   662 
   663     fun rews [] = None
   664       | rews (rrule :: rrules) =
   665           let val opt = rew rrule handle Pattern.MATCH => None
   666           in case opt of None => rews rrules | some => some end;
   667 
   668     fun sort_rrules rrs = let
   669       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
   670                                       Const("==",_) $ _ $ _ => true
   671                                       | _                   => false
   672       fun sort []        (re1,re2) = re1 @ re2
   673         | sort (rr::rrs) (re1,re2) = if is_simple rr
   674                                      then sort rrs (rr::re1,re2)
   675                                      else sort rrs (re1,rr::re2)
   676     in sort rrs ([],[]) end
   677 
   678     fun proc_rews ([]:simproc list) = None
   679       | proc_rews ({name, proc, lhs, ...} :: ps) =
   680           if Pattern.matches tsigt (term_of lhs, term_of t) then
   681             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   682              case transform_failure (curry SIMPROC_FAIL name)
   683                  (fn () => proc signt prems eta_t) () of
   684                None => (debug false "FAILED"; proc_rews ps)
   685              | Some raw_thm =>
   686                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   687                   (case rews (mk_procrule raw_thm) of
   688                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
   689                       " -- does not match") t; proc_rews ps)
   690                   | some => some)))
   691           else proc_rews ps;
   692   in case eta_t of
   693        Abs _ $ _ => Some (transitive eta_thm
   694          (beta_conversion false eta_t'), skel0)
   695      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   696                None => proc_rews (Net.match_term procs eta_t)
   697              | some => some)
   698   end;
   699 
   700 
   701 (* congc (prover, signt, maxt) {thm, lhs} t: try to rewrite the cterm t
       with the congruence rule thm.  The rule must come from a theory contained
       in signt (otherwise error).  Unless maxt = ~1 its schematic indexes are
       shifted by maxt+1; the lhs of its conclusion is matched against t and the
       instantiated rule is handed to prover.  Returns Some (t == t') when the
       prover succeeds, check_conv (defined earlier) accepts the result and t'
       is not aconv t; None otherwise.  The record's lhs field is not used. *)
   702 
   703 fun congc (prover,signt,maxt) {thm=cong,lhs=_} t =
   704   let val {sign, ...} = rep_thm cong
   705       val _ = if Sign.subsig (sign, signt) then ()
   706                  else error("Congruence rule from different theory")
   707       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   708       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   709       val insts = Thm.cterm_match (rlhs, t)
   710       (* A failed Thm.cterm_match is expected to raise Pattern.MATCH; it is
             deliberately not caught here but in the caller (subc), which then
             falls back to ordinary argument-wise simplification. *)
   711       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   712       val _ = trace_thm "Applying congruence rule:" thm';
   713       fun err (msg, thm) = (trace_thm msg thm; None)
   714   in case prover thm' of
   715        None => err ("Congruence proof failed.  Could not prove", thm')
   716      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   717           None => err ("Congruence proof failed.  Should not have proved", thm2)
   718         | Some thm2' =>
   719             (* a result whose two sides are aconv is reported as no change *)
   720             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   721             then None else Some thm2')
   722   end;
   723 
   724 (* The first premise of Drule.imp_cong, split as A ==> (B == C): cA, cB
          and cC are these three cterms, used as instantiation targets by the
          disch helper inside bottomc. *)
   725 val (cA, (cB, cC)) =
         let val (prem, eqn) = dest_implies (hd (cprems_of Drule.imp_cong))
         in (prem, dest_equals eqn) end;
   726 
   727 (* transitive1 eq1 eq2: compose two optional equations, where None stands
          for a reflexive step (so two Nones give None). *)
   728 fun transitive1 None eq2 = eq2
   729   | transitive1 eq1 None = eq1
   730   | transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2);
   731 
   732 (* transitive2 thm opt: a definite equation followed by an optional one;
          transitive3 opt thm: an optional equation followed by a definite one. *)
   733 fun transitive2 thm opt = transitive1 (Some thm) opt;
       fun transitive3 opt thm = transitive1 opt (Some thm);
   734 
   735 fun imp_cong' e =
         (* Congruence for ==>: from e : A == A' and a second equation B == B'
            build (A ==> B) == (A' ==> B') -- assuming refl_implies (defined
            earlier) is the reflexivity theorem for ==>.  The inner
            combination is performed as soon as e is supplied. *)
         let val lifted = combination refl_implies e
         in combination lifted end;
   736 
   737 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
         (* Builds the bottom-up rewriting conversion used by rewrite_cterm.
            The result, try_botc, takes a simpset and a cterm t and returns a
            theorem t == t' (reflexive t when nothing changed).
            simprem / useprem are only consulted by nonmut_impc (simplify the
            premise A of A ==> B; add the simplified A as rewrite rules for B);
            mutsimp makes impc use mut_impc0 instead of nonmut_impc.
            prover, sign and maxidx are passed on to rewritec and congc. *)
   738   let
           (* botc skel mss t: simplify the subterms of t (subc), then rewrite
              t at the top with rewritec and simplify that result again, using
              the skeleton skel2 returned by rewritec.  If skel is a Var, t is
              left alone.  Returns Some (t == t') or None if nothing changed. *)
   739     fun botc skel mss t =
   740           if is_Var skel then None
   741           else
   742           (case subc skel mss t of
   743              some as Some thm1 =>
   744                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   745                   Some (thm2, skel2) =>
   746                     transitive2 (transitive thm1 thm2)
   747                       (botc skel2 mss (rhs_of thm2))
   748                 | None => some)
   749            | None =>
   750                (case rewritec (prover, sign, maxidx) mss t of
   751                   Some (thm2, skel2) => transitive2 thm2
   752                     (botc skel2 mss (rhs_of thm2))
   753                 | None => None))
   754 
           (* try_botc mss t: botc with skeleton skel0 (defined earlier);
              None is turned into reflexive t. *)
   755     and try_botc mss t =
   756           (case botc skel0 mss t of
   757              Some trec1 => trec1 | None => (reflexive t))
   758 
           (* subc skel mss t0: simplify the immediate subterms of t0.
              - Abs: simplify the body under a fresh bound name (a variant of
                the names in bounds), then re-abstract;
              - (==> A) B: handled by impc;
              - beta-redex: beta-reduce, then simplify the result with subc;
              - other applications: if cong_name yields a name with a
                congruence rule in fst congs, use congc (falling back to appc
                on Pattern.MATCH); otherwise appc simplifies function and
                argument separately, guided by skel when it is an application;
              - anything else: None. *)
   759     and subc skel
   760           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   761        (case term_of t0 of
   762            Abs (a, T, t) =>
   763              let val b = variant bounds a
   764                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   765                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   766                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   767              in case botc skel' mss' t' of
   768                   Some thm => Some (abstract_rule a v thm)
   769                 | None => None
   770              end
   771          | t $ _ => (case t of
   772              Const ("==>", _) $ _  => impc t0 mss
   773            | Abs _ =>
   774                let val thm = beta_conversion false t0
   775                in case subc skel0 mss (rhs_of thm) of
   776                     None => Some thm
   777                   | Some thm' => Some (transitive thm thm')
   778                end
   779            | _  =>
   780                let fun appc () =
   781                      let
   782                        val (tskel, uskel) = case skel of
   783                            tskel $ uskel => (tskel, uskel)
   784                          | _ => (skel0, skel0);
   785                        val (ct, cu) = Thm.dest_comb t0
   786                      in
   787                      (case botc tskel mss ct of
   788                         Some thm1 =>
   789                           (case botc uskel mss cu of
   790                              Some thm2 => Some (combination thm1 thm2)
   791                            | None => Some (combination thm1 (reflexive cu)))
   792                       | None =>
   793                           (case botc uskel mss cu of
   794                              Some thm1 => Some (combination (reflexive ct) thm1)
   795                            | None => None))
   796                      end
   797                    val (h, ts) = strip_comb t
   798                in case cong_name h of
   799                     Some a =>
   800                       (case assoc_string (fst congs, a) of
   801                          None => appc ()
   802                        | Some cong =>
   803 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   804    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   805                           (let
   806                              val thm = congc (prover mss, sign, maxidx) cong t0;
   807                              val t = if_none (apsome rhs_of thm) t0;
   808                              val (cl, cr) = Thm.dest_comb t
   809                              val dVar = Var(("", 0), dummyT)
   810                              val skel =
   811                                list_comb (h, replicate (length ts) dVar)
   812                            in case botc skel mss cl of
   813                                 None => thm
   814                               | Some thm' => transitive3 thm
   815                                   (combination thm' (reflexive cr))
   816                            end handle TERM _ => error "congc result"
   817                                     | Pattern.MATCH => appc ()))
   818                   | _ => appc ()
   819                end)
   820          | _ => None)
   821 
           (* impc ct mss: simplify the implication ct, either with mutual
              premise simplification (mutsimp) or with the legacy scheme. *)
   822     and impc ct mss =
   823       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   824 
           (* rules_of_prem mss prem: assume prem and extract rewrite rules
              from it with extract_safe_rrules, giving (rules, Some asm).  A prem
              with schematic unknowns (maxidx <> ~1) is only traced, giving
              ([], None). *)
   825     and rules_of_prem mss prem =
   826       if maxidx_of_term (term_of prem) <> ~1
   827       then (trace_cterm true
   828         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   829       else
   830         let val asm = assume prem
   831         in (extract_safe_rrules (mss, asm), Some asm) end
   832 
           (* add_rrules (rrss, asms) mss: insert every rule of the lists in
              rrss into mss (insert_rrule true) and add the present
              assumptions of asms (the Some entries, via mapfilter I) as
              premises. *)
   833     and add_rrules (rrss, asms) mss =
   834       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   835 
           (* disch r (prem, eq): from eq : lhs == rhs, which may depend on the
              assumption prem, derive (prem ==> lhs) == (prem ==> rhs) using
              Drule.imp_cong.  If r is true, lhs and rhs are expected to be
              implications prem' ==> concl and prem'' ==> concl (the same
              conclusion is assumed); the result is then reordered with
              Drule.swap_prems_eq, presumably giving
              (prem' ==> prem ==> concl) == (prem'' ==> prem ==> concl). *)
   836     and disch r (prem, eq) =
   837       let
   838         val (lhs, rhs) = dest_eq eq;
   839         val eq' = implies_elim (Thm.instantiate
   840           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   841           (implies_intr prem eq)
   842       in if not r then eq' else
   843         let
   844           val (prem', concl) = dest_implies lhs;
   845           val (prem'', _) = dest_implies rhs
   846         in transitive (transitive
   847           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   848              Drule.swap_prems_eq) eq')
   849           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   850              Drule.swap_prems_eq)
   851         end
   852       end
   853 
           (* rebuild prems concl rrss asms mss eq: reattach the premises in
              prems, innermost first, to the conclusion concl, whose
              simplification is eq (None if unchanged), discharging each one
              from eq.  At each step the partial implication prem ==> concl is
              tried with rewritec, using the outer premises as assumptions; if
              it rewrites, the remaining premises are discharged and the
              result is simplified again with mut_impc0. *)
   854     and rebuild [] _ _ _ _ eq = eq
   855       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   856           let
   857             val mss' = add_rrules (rev rrss, rev asms) mss;
   858             val concl' =
   859               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   860             val dprem = apsome (curry (disch false) prem)
   861           in case rewritec (prover, sign, maxidx) mss' concl' of
   862               None => rebuild prems concl' rrss asms mss (dprem eq)
   863             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   864                   (the (transitive3 (dprem eq) eq'), prems))
   865                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   866           end
   867           
           (* mut_impc0 prems concl rrss asms mss: split concl into its
              premises and final conclusion, compute rewrite rules for the new
              premises (rules_of_prem), append them to prems / rrss / asms and
              start mut_impc with changed = k = ~1. *)
   868     and mut_impc0 prems concl rrss asms mss =
   869       let
   870         val prems' = strip_imp_prems concl;
   871         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   872       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   873         (asms @ asms') [] [] [] [] mss ~1 ~1
   874       end
   875  
           (* mut_impc prems concl rrss asms prems' rrss' asms' eqns mss changed k:
              simplify each premise in prems, with the rules of all other
              premises as assumptions.  Processed premises are accumulated in
              reverse in prems' / rrss' / asms', their optional equations in
              eqns.  changed is the number of premises preceding the most
              recently changed one (~1 if none changed); k limits how many
              further premises are simplified in this pass (0: pass the rest
              through unsimplified; negative: no limit).  When prems is
              exhausted: if changed > 0, a new pass re-simplifies only the
              first changed-many premises; otherwise the conclusion is
              simplified with all premises as assumptions and rebuild
              reassembles the implication. *)
   876     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   877         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   878             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   879           (if changed > 0 then
   880              mut_impc (rev prems') concl (rev rrss') (rev asms')
   881                [] [] [] [] mss ~1 changed
   882            else rebuild prems' concl rrss' asms' mss
   883              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   884 
   885       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   886           prems' rrss' asms' eqns mss changed k =
   887         case (if k = 0 then None else botc skel0 (add_rrules
   888           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
   889             None => mut_impc prems concl rrss asms (prem :: prems')
   890               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   891               (if k = 0 then 0 else k - 1)
   892           | Some eqn =>
   893             let
   894               val prem' = rhs_of eqn;
   895               val tprems = map term_of prems;
                     (* i = 1 + the largest position, within the pending prems,
                        of a premise occurring in the hypotheses of eqn (0 if
                        none); those first i pending premises are discharged
                        into the new equation with disch true. *)
   896               val i = 1 + foldl Int.max (~1, map (fn p =>
   897                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   898               val (rrs', asm') = rules_of_prem mss prem'
   899             in mut_impc prems concl rrss asms (prem' :: prems')
   900               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   901                 (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
   902                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   903             end
   904 
   905      (* legacy code - only for backwards compatibility *)
            (* nonmut_impc ct mss: for ct = A ==> B, simplify A if simprem,
               then simplify B, adding the (possibly simplified) A as rewrite
               rules if useprem; the two equations are combined with imp_cong'
               and disch. *)
   906      and nonmut_impc ct mss =
   907        let val (prem, conc) = dest_implies ct;
   908            val thm1 = if simprem then botc skel0 mss prem else None;
   909            val prem1 = if_none (apsome rhs_of thm1) prem;
   910            val mss1 = if not useprem then mss else add_rrules
   911              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   912        in (case botc skel0 mss1 conc of
   913            None => (case thm1 of
   914                None => None
   915              | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
   916          | Some thm2 =>
   917            let val thm2' = disch false (prem1, thm2)
   918            in (case thm1 of
   919                None => Some thm2'
   920              | Some thm1' =>
   921                  Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
   922            end)
   923        end
   924 
   925  in try_botc end;
   926 
   927 
   928 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   929 
   930 (*
   931   Parameters:
   932     mode = (simplify A,
   933             use A in simplifying B,
   934             use prems of B (if B is again a meta-impl.) to simplify A)
   935            when simplifying A ==> B
   936     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   937     prover: how to solve premises in conditional rewrites and congruences
   938     ct: the term t; simp_depth is first reset to the depth stored in mss,
   939         and a THM exception escaping the simplifier is reported via error.
   940 *)
   941 fun rewrite_cterm mode prover mss ct =
   942   (let
   943      val {sign, maxidx, ...} = rep_cterm ct;
   944      val Mss {depth, ...} = mss;
   945      val () = simp_depth := depth
   946    in bottomc (mode, prover, sign, maxidx) mss ct end)
   947   handle THM (msg, _, offenders) =>
   948     error ("Exception THM was raised in simplifier:\n" ^ msg ^ "\n" ^
             Pretty.string_of (Display.pretty_thms offenders));
   949 
   950 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac.
         pred receives the 1-based position of each premise; selected premises
         are rewritten by cv, the others kept via reflexive.  Once ct is no
         longer an implication, reflexive ct is returned.  The TERM handler also
         covers cv and imp_cong' at the same position: a TERM raised there leaves
         the whole remaining implication unchanged.*)
   951 fun goals_conv pred cv =
   952   let fun gconv i ct =
   953         let val (prem, rest) = Drule.dest_implies ct
   954             val prem_eq = if pred i then cv prem else reflexive prem
   955         in imp_cong' prem_eq (gconv (i + 1) rest) end
   956         handle TERM _ => reflexive ct
         in gconv 1 end;
   957 
   958 (* Rewrite A in !!x1,...,xn. A: descend through the outermost all-binders,
         opening each one with Thm.dest_abs (name hint @), and apply cv to the
         body.  If ct is not an application, or not an all-quantification, cv
         is applied to ct itself.  Only a failing Thm.dest_comb selects that
         fallback: a TERM raised by cv propagates to the caller, and cv is never
         re-run on the same or an enclosing term. *)
   959 fun forall_conv cv ct =
   960   (case (Some (Thm.dest_comb ct) handle TERM _ => None) of
   961      None => cv ct
   962    | Some (ct1, ct2) =>
   963        (case (term_of ct1, term_of ct2) of
   964           (Const ("all", _), Abs (s, _, _)) =>
   965             let val (v, ct') = Thm.dest_abs (Some "@") ct2
   966             in Thm.combination (Thm.reflexive ct1)
   967                  (Thm.abstract_rule s v (forall_conv cv ct'))
   968             end
             | _ => cv ct));
   969 
   970 (*Use a conversion to transform a theorem: cv turns the proposition P of
         th into an equation P == P', and equal_elim then yields P'.*)
   971 fun fconv_rule cv th =
         let val eqn = cv (cprop_of th) in equal_elim eqn th end;
   972 
   973 (*Rewrite a cterm with the rules thms in mode (full, false, false);
         without rules the conversion is plain reflexivity.*)
   974 fun rewrite_aux prover full thms =
   975   (case thms of
           [] => (fn ct => Thm.reflexive ct)
         | _ => rewrite_cterm (full, false, false) prover (mss_of thms));
   976 
   977 (*Rewrite a theorem with the rules thms in mode (full, false, false);
         without rules the theorem is returned unchanged.*)
   978 fun simplify_aux prover full thms =
   979   if null thms then (fn th => th)
   980   else fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   981 
   982 fun rewrite_thm mode prover mss =
         (*Rewrite a theorem with an explicit mode, prover and simpset.*)
         let val conv = rewrite_cterm mode prover mss in fconv_rule conv end;
   983 
   984 (*Rewrite all subgoals of a proof state (represented by a theorem) with the
         rules thms in mode (true, true, false); with no rules th is returned as is.*)
   985 fun rewrite_goals_rule_aux prover thms th =
   986   (case thms of
   987      [] => th
   988    | _ =>
           let val conv = rewrite_cterm (true, true, false) prover (mss_of thms)
           in fconv_rule (goals_conv (K true) conv) th end);
   989 
   990 (*Rewrite subgoal i of a proof state (represented by a theorem); raises
         THM unless 1 <= i <= nprems_of thm.*)
   991 fun rewrite_goal_rule mode prover mss i thm =
   992   if i < 1 orelse nprems_of thm < i
   993   then raise THM ("rewrite_goal_rule", i, [thm])
   994   else fconv_rule (goals_conv (fn j => j = i) (rewrite_cterm mode prover mss)) thm;
   995 
   996 
   997 (*simple term rewriting -- without proofs.  Each rule is first passed
         through decomp_simp' (defined earlier); the results and procs are handed
         to Pattern.rewrite_term together with the type signature of sg.*)
   998 fun rewrite_term sg rules procs =
   999   let val tsig = Sign.tsig_of sg
           val eqns = map decomp_simp' rules
       in Pattern.rewrite_term tsig eqns procs end;
  1000 
  1001 end;
  1002 
  1003 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
       (* Restrict MetaSimplifier to BASIC_META_SIMPLIFIER (trace_simp,
          debug_simp, simp_depth_limit) and open it, so that these refs are
          available unqualified at top level. *)
  1004 open BasicMetaSimplifier;