src/Pure/meta_simplifier.ML
author nipkow
Tue Aug 28 14:25:26 2001 +0200 (2001-08-28)
changeset 11505 a410fa8acfca
parent 11504 935f9e8de0d0
child 11672 8e75b78f33f3
permissions -rw-r--r--
Implemented indentation schema for conditional rewrite trace.
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta Simplification
     7 *)
     8 
     9 signature META_SIMPLIFIER =
    10 sig
    11   exception SIMPLIFIER of string * thm
    12   type meta_simpset
    13   val dest_mss		: meta_simpset ->
    14     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    15   val empty_mss         : meta_simpset
    16   val clear_mss		: meta_simpset -> meta_simpset
    17   val merge_mss		: meta_simpset * meta_simpset -> meta_simpset
    18   val add_simps         : meta_simpset * thm list -> meta_simpset
    19   val del_simps         : meta_simpset * thm list -> meta_simpset
    20   val mss_of            : thm list -> meta_simpset
    21   val add_congs         : meta_simpset * thm list -> meta_simpset
    22   val del_congs         : meta_simpset * thm list -> meta_simpset
    23   val add_simprocs	: meta_simpset *
    24     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    25       -> meta_simpset
    26   val del_simprocs	: meta_simpset *
    27     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    28       -> meta_simpset
    29   val add_prems         : meta_simpset * thm list -> meta_simpset
    30   val prems_of_mss      : meta_simpset -> thm list
    31   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    32   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    33   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    34   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    35   val trace_simp        : bool ref
    36   val debug_simp        : bool ref
    37   val rewrite_cterm     : bool * bool * bool
    38                           -> (meta_simpset -> thm -> thm option)
    39                           -> meta_simpset -> cterm -> thm
    40   val rewrite_rule_aux  : (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    41   val rewrite_thm       : bool * bool * bool
    42                           -> (meta_simpset -> thm -> thm option)
    43                           -> meta_simpset -> thm -> thm
    44   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    45   val rewrite_goal_rule : bool* bool * bool
    46                           -> (meta_simpset -> thm -> thm option)
    47                           -> meta_simpset -> int -> thm -> thm
    48 end;
    49 
    50 structure MetaSimplifier : META_SIMPLIFIER =
    51 struct
    52 
    53 (** diagnostics **)
    54 
    55 exception SIMPLIFIER of string * thm;
    56 
    57 val simp_depth = ref 0;
    58 
    59 fun println a = writeln(replicate_string (!simp_depth) " " ^ a)
    60 
    61 fun prnt warn a = if warn then warning a else println a;
    62 
    63 fun prtm warn a sign t =
    64   (prnt warn a; prnt warn (Sign.string_of_term sign t));
    65 
    66 fun prctm warn a t =
    67   (prnt warn a; prnt warn (Display.string_of_cterm t));
    68 
    69 fun prthm warn a thm =
    70   let val {sign, prop, ...} = rep_thm thm
    71   in prtm warn a sign prop end;
    72 
    73 val trace_simp = ref false;
    74 val debug_simp = ref false;
    75 
    76 fun trace warn a = if !trace_simp then prnt warn a else ();
    77 fun debug warn a = if !debug_simp then prnt warn a else ();
    78 
    79 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    80 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    81 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    82 
    83 fun trace_thm warn a thm =
    84   let val {sign, prop, ...} = rep_thm thm
    85   in trace_term warn a sign prop end;
    86 
    87 
    88 
    89 (** meta simp sets **)
    90 
    91 (* basic components *)
    92 
    93 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
    94 (* thm: the rewrite rule
    95    lhs: the left-hand side
    96    elhs: the etac-contracted lhs.
    97    fo:  use first-order matching
    98    perm: the rewrite rule is permutative
    99 Reamrks:
   100   - elhs is used for matching,
   101     lhs only for preservation of bound variable names.
   102   - fo is set iff
   103     either elhs is first-order (no Var is applied),
   104            in which case fo-matching is complete,
   105     or elhs is not a pattern,
   106        in which case there is nothing better to do.
   107 *)
   108 type cong = {thm: thm, lhs: cterm};
   109 type simproc =
   110  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   111 
   112 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   113   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   114 
   115 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) = 
   116   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   117 
   118 fun eq_prem (thm1, thm2) =
   119   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   120 
   121 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   122 
   123 fun mk_simproc (name, proc, lhs, id) =
   124   {name = name, proc = proc, lhs = lhs, id = id};
   125 
   126 
   127 (* datatype mss *)
   128 
   129 (*
   130   A "mss" contains data needed during conversion:
   131     rules: discrimination net of rewrite rules;
   132     congs: association list of congruence rules and
   133            a list of `weak' congruence constants.
   134            A congruence is `weak' if it avoids normalization of some argument.
   135     procs: discrimination net of simplification procedures
   136       (functions that prove rewrite rules on the fly);
   137     bounds: names of bound variables already used
   138       (for generating new names when rewriting under lambda abstractions);
   139     prems: current premises;
   140     mk_rews: mk: turns simplification thms into rewrite rules;
   141              mk_sym: turns == around; (needs Drule!)
   142              mk_eq_True: turns P into P == True - logic specific;
   143     termless: relation for ordered rewriting;
   144     depth: depth of conditional rewriting;
   145 *)
   146 
   147 datatype meta_simpset =
   148   Mss of {
   149     rules: rrule Net.net,
   150     congs: (string * cong) list * string list,
   151     procs: simproc Net.net,
   152     bounds: string list,
   153     prems: thm list,
   154     mk_rews: {mk: thm -> thm list,
   155               mk_sym: thm -> thm option,
   156               mk_eq_True: thm -> thm option},
   157     termless: term * term -> bool,
   158     depth: int};
   159 
   160 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   161   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   162        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   163 
   164 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   165   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   166 
   167 val empty_mss =
   168   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   169   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   170 
   171 fun clear_mss (Mss {mk_rews, termless, ...}) =
   172   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   173 
   174 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   175   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   176  
   177 
   178 
   179 (** simpset operations **)
   180 
   181 (* term variables *)
   182 
   183 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   184 fun term_varnames t = add_term_varnames ([], t);
   185 
   186 
   187 (* dest_mss *)
   188 
   189 fun dest_mss (Mss {rules, congs, procs, ...}) =
   190   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   191    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   192    procs =
   193      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   194      |> partition_eq eq_snd
   195      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   196      |> Library.sort_wrt #1};
   197 
   198 
   199 (* merge_mss *)	      (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   200 
   201 fun merge_mss
   202  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   203        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   204   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   205        bounds = bounds2, prems = prems2, ...}) =
   206       mk_mss
   207        (Net.merge (rules1, rules2, eq_rrule),
   208         (generic_merge (eq_cong o pairself snd) I I congs1 congs2,
   209         merge_lists weak1 weak2),
   210         Net.merge (procs1, procs2, eq_simproc),
   211         merge_lists bounds1 bounds2,
   212         generic_merge eq_prem I I prems1 prems2,
   213         mk_rews, termless, depth);
   214 
   215 
   216 (* add_simps *)
   217 
   218 fun mk_rrule2{thm,lhs,elhs,perm} =
   219   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   220   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   221 
   222 fun insert_rrule(mss as Mss {rules,...},
   223                  rrule as {thm,lhs,elhs,perm}) =
   224   (trace_thm false "Adding rewrite rule:" thm;
   225    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   226        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   227    in upd_rules(mss,rules') end
   228    handle Net.INSERT =>
   229      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   230 
   231 fun vperm (Var _, Var _) = true
   232   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   233   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   234   | vperm (t, u) = (t = u);
   235 
   236 fun var_perm (t, u) =
   237   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   238 
   239 (* FIXME: it seems that the conditions on extra variables are too liberal if
   240 prems are nonempty: does solving the prems really guarantee instantiation of
   241 all its Vars? Better: a dynamic check each time a rule is applied.
   242 *)
   243 fun rewrite_rule_extra_vars prems elhs erhs =
   244   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   245   orelse
   246   not ((term_tvars erhs) subset
   247        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   248 
   249 (*Simple test for looping rewrite rules and stupid orientations*)
   250 fun reorient sign prems lhs rhs =
   251    rewrite_rule_extra_vars prems lhs rhs
   252   orelse
   253    is_Var (head_of lhs)
   254   orelse
   255    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   256   orelse
   257    (null prems andalso
   258     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   259     (*the condition "null prems" is necessary because conditional rewrites
   260       with extra variables in the conditions may terminate although
   261       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   262   orelse
   263    (is_Const lhs andalso not(is_Const rhs))
   264 
   265 fun decomp_simp thm =
   266   let val {sign, prop, ...} = rep_thm thm;
   267       val prems = Logic.strip_imp_prems prop;
   268       val concl = Drule.strip_imp_concl (cprop_of thm);
   269       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   270         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   271       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   272       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   273       val erhs = Pattern.eta_contract (term_of rhs);
   274       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   275                  andalso not (is_Var (term_of elhs))
   276   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   277 
   278 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
   279   case mk_eq_True thm of
   280     None => []
   281   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   282                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   283 
   284 (* create the rewrite rule and possibly also the ==True variant,
   285    in case there are extra vars on the rhs *)
   286 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
   287   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   288   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   289         (term_tvars rhs) subset (term_tvars lhs)
   290      then [rrule]
   291      else mk_eq_True mss thm2 @ [rrule]
   292   end;
   293 
   294 fun mk_rrule mss thm =
   295   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   296   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   297      (* weak test for loops: *)
   298      if rewrite_rule_extra_vars prems lhs rhs orelse
   299         is_Var (term_of elhs)
   300      then mk_eq_True mss thm
   301      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   302   end;
   303 
   304 fun orient_rrule mss thm =
   305   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   306   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   307      else if reorient sign prems lhs rhs
   308           then if reorient sign prems rhs lhs
   309                then mk_eq_True mss thm
   310                else let val Mss{mk_rews={mk_sym,...},...} = mss
   311                     in case mk_sym thm of
   312                          None => []
   313                        | Some thm' =>
   314                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   315                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   316                     end
   317           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   318   end;
   319 
   320 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
   321 
   322 fun orient_comb_simps comb mk_rrule (mss,thms) =
   323   let val rews = extract_rews(mss,thms)
   324       val rrules = flat (map mk_rrule rews)
   325   in foldl comb (mss,rrules) end
   326 
   327 (* Add rewrite rules explicitly; do not reorient! *)
   328 fun add_simps(mss,thms) =
   329   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   330 
   331 fun mss_of thms =
   332   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   333 
   334 fun extract_safe_rrules(mss,thm) =
   335   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   336 
   337 fun add_safe_simp(mss,thm) =
   338   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   339 
   340 (* del_simps *)
   341 
   342 fun del_rrule(mss as Mss {rules,...},
   343               rrule as {thm, elhs, ...}) =
   344   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   345    handle Net.DELETE =>
   346      (prthm true "Rewrite rule not in simpset:" thm; mss));
   347 
   348 fun del_simps(mss,thms) =
   349   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   350 
   351 
   352 (* add_congs *)
   353 
   354 fun is_full_cong_prems [] varpairs = null varpairs
   355   | is_full_cong_prems (p::prems) varpairs =
   356     (case Logic.strip_assums_concl p of
   357        Const("==",_) $ lhs $ rhs =>
   358          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   359          in is_Var x  andalso  forall is_Bound xs  andalso
   360             null(findrep(xs))  andalso xs=ys andalso
   361             (x,y) mem varpairs andalso
   362             is_full_cong_prems prems (varpairs\(x,y))
   363          end
   364      | _ => false);
   365 
   366 fun is_full_cong thm =
   367 let val prems = prems_of thm
   368     and concl = concl_of thm
   369     val (lhs,rhs) = Logic.dest_equals concl
   370     val (f,xs) = strip_comb lhs
   371     and (g,ys) = strip_comb rhs
   372 in
   373   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   374   is_full_cong_prems prems (xs ~~ ys)
   375 end
   376 
   377 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   378   let
   379     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   380       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   381 (*   val lhs = Pattern.eta_contract lhs; *)
   382     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   383       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   384     val (alist,weak) = congs
   385     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   386            ("Overwriting congruence rule for " ^ quote a);
   387     val weak2 = if is_full_cong thm then weak else a::weak
   388   in
   389     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   390   end;
   391 
   392 val (op add_congs) = foldl add_cong;
   393 
   394 
   395 (* del_congs *)
   396 
   397 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   398   let
   399     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   400       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   401 (*   val lhs = Pattern.eta_contract lhs; *)
   402     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   403       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   404     val (alist,_) = congs
   405     val alist2 = filter (fn (x,_)=> x<>a) alist
   406     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   407                                               else Some a)
   408                    alist2
   409   in
   410     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   411   end;
   412 
   413 val (op del_congs) = foldl del_cong;
   414 
   415 
   416 (* add_simprocs *)
   417 
   418 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   419     (name, lhs, proc, id)) =
   420   let val {sign, t, ...} = rep_cterm lhs
   421   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   422       sign t;
   423     mk_mss (rules, congs,
   424       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   425         handle Net.INSERT => 
   426 	    (warning ("Ignoring duplicate simplification procedure \"" 
   427 	              ^ name ^ "\""); 
   428 	     procs),
   429         bounds, prems, mk_rews, termless,depth))
   430   end;
   431 
   432 fun add_simproc (mss, (name, lhss, proc, id)) =
   433   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   434 
   435 val add_simprocs = foldl add_simproc;
   436 
   437 
   438 (* del_simprocs *)
   439 
   440 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   441     (name, lhs, proc, id)) =
   442   mk_mss (rules, congs,
   443     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   444       handle Net.DELETE => 
   445 	  (warning ("Simplification procedure \"" ^ name ^
   446 		       "\" not in simpset"); procs),
   447       bounds, prems, mk_rews, termless, depth);
   448 
   449 fun del_simproc (mss, (name, lhss, proc, id)) =
   450   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   451 
   452 val del_simprocs = foldl del_simproc;
   453 
   454 
   455 (* prems *)
   456 
   457 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   458   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   459 
   460 fun prems_of_mss (Mss {prems, ...}) = prems;
   461 
   462 
   463 (* mk_rews *)
   464 
   465 fun set_mk_rews
   466   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   467     mk_mss (rules, congs, procs, bounds, prems,
   468             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   469             termless, depth);
   470 
   471 fun set_mk_sym
   472   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   473     mk_mss (rules, congs, procs, bounds, prems,
   474             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   475             termless,depth);
   476 
   477 fun set_mk_eq_True
   478   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   479     mk_mss (rules, congs, procs, bounds, prems,
   480             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   481             termless,depth);
   482 
   483 (* termless *)
   484 
   485 fun set_termless
   486   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   487     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   488 
   489 
   490 
   491 (** rewriting **)
   492 
   493 (*
   494   Uses conversions, see:
   495     L C Paulson, A higher-order implementation of rewriting,
   496     Science of Computer Programming 3 (1983), pages 119-149.
   497 *)
   498 
   499 type prover = meta_simpset -> thm -> thm option;
   500 type termrec = (Sign.sg_ref * term list) * term;
   501 type conv = meta_simpset -> termrec -> termrec;
   502 
   503 val dest_eq = Drule.dest_equals o cprop_of;
   504 val lhs_of = fst o dest_eq;
   505 val rhs_of = snd o dest_eq;
   506 
   507 fun beta_eta_conversion t =
   508   let val thm = beta_conversion true t;
   509   in transitive thm (eta_conversion (rhs_of thm)) end;
   510 
   511 fun check_conv msg thm thm' =
   512   let
   513     val thm'' = transitive thm (transitive
   514       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   515   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   516   handle THM _ =>
   517     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   518     in
   519       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   520        trace_term false "Should have proved:" sign prop0;
   521        None)
   522     end;
   523 
   524 
   525 (* mk_procrule *)
   526 
   527 fun mk_procrule thm =
   528   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   529   in if rewrite_rule_extra_vars prems lhs rhs
   530      then (prthm true "Extra vars on rhs:" thm; [])
   531      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
   532   end;
   533 
   534 
   535 (* conversion to apply the meta simpset to a term *)
   536 
   537 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   538    normalized terms by carrying around the rhs of the rewrite rule just
   539    applied. This is called the `skeleton'. It is decomposed in parallel
   540    with the term. Once a Var is encountered, the corresponding term is
   541    already in normal form.
   skel0 is a dummy skeleton used to enforce complete normalization.
   543 *)
   544 val skel0 = Bound 0;
   545 
   546 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   547    The latter may happen iff there are weak congruence rules for constants
   548    in the lhs.
   549 *)
   550 fun uncond_skel((_,weak),(lhs,rhs)) =
   551   if null weak then rhs (* optimization *)
   552   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   553        else rhs;
   554 
   555 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   556    Otherwise those vars may become instantiated with unnormalized terms
   557    while the premises are solved.
   558 *)
   559 fun cond_skel(args as (congs,(lhs,rhs))) =
   560   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   561   else skel0;
   562 
   563 (*
   564   we try in order:
   565     (1) beta reduction
   566     (2) unconditional rewrite rules
   567     (3) conditional rewrite rules
   568     (4) simplification procedures
   569 
   570   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   571 
   572 *)
   573 
(*Attempt a single rewrite step at the root of the cterm t.
  Result: Some (thm, skel), where thm proves t == t' and skel is the skeleton
  to use for t' (see uncond_skel / cond_skel), or None if nothing applies.
  Candidates are tried in the order given in the comment above.
  prover: proves the instantiated premises of a conditional rule; it is given
          the simpset with depth incremented (incr_depth).
  signt:  signature of t; a rule whose signature is not a subsignature of it
          is ignored with a warning, and simprocs are called with signt.
  maxt:   ~1 means rules are used as they are; otherwise rule indexes are
          incremented by maxt+1 before matching (presumably maxt is the largest
          Var index of the term being simplified, so rule Vars cannot clash).*)
fun rewritec (prover, signt, maxt)
             (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
  let
    (*rules are matched against the eta-contracted form of t*)
    val eta_thm = Thm.eta_conversion t;
    val eta_t' = rhs_of eta_thm;
    val eta_t = term_of eta_t';
    val tsigt = Sign.tsig_of signt;
    (*Apply one rewrite rule.  Raises Pattern.MATCH when the rule is not
      applicable: foreign theory, or no match (presumably signalled by the
      cterm matching functions).  rews below maps that to None.*)
    fun rew {thm, lhs, elhs, fo, perm} =
      let
        val {sign, prop, maxidx, ...} = rep_thm thm;
        val _ = if Sign.subsig (sign, signt) then ()
                else (prthm true "Ignoring rewrite rule from different theory:" thm;
                      raise Pattern.MATCH);
        val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
          else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
        val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
                          else Thm.cterm_match (elhs', eta_t');
        (*lhs is used only to preserve bound variable names*)
        val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
        val prop' = #prop (rep_thm thm');
        val unconditional = (Logic.count_prems (prop',0) = 0);
        val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
      in
        (*a permutative rule is applied only if it makes the term smaller*)
        if perm andalso not (termless (rhs', lhs'))
        then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
              trace_thm false "Term does not become smaller:" thm'; None)
        else (trace_thm false "Applying instance of rewrite rule:" thm;
           if unconditional
           then
             (trace_thm false "Rewriting:" thm';
              let val lr = Logic.dest_equals prop;
                  (*NOTE(review): raises Bind if check_conv returns None;
                    the successful match above presumably rules that out*)
                  val Some thm'' = check_conv false eta_thm thm'
              in Some (thm'', uncond_skel (congs, lr)) end)
           else
             (*the premises are solved one conditional-rewriting level deeper*)
             (trace_thm false "Trying to rewrite:" thm';
              case prover (incr_depth mss) thm' of
                None       => (trace_thm false "FAILED" thm'; None)
              | Some thm2 =>
                  (case check_conv true eta_thm thm2 of
                     None => None |
                     Some thm2' =>
                       let val concl = Logic.strip_imp_concl prop
                           val lr = Logic.dest_equals concl
                       in Some (thm2', cond_skel (congs, lr)) end)))
      end

    (*first rule that applies, if any*)
    fun rews [] = None
      | rews (rrule :: rrules) =
          let val opt = rew rrule handle Pattern.MATCH => None
          in case opt of None => rews rrules | some => some end;

    (*Rules whose proposition is a plain == (unconditional) come before the
      conditional ones.  Consing reverses the order within each group.*)
    fun sort_rrules rrs = let
      fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of 
                                      Const("==",_) $ _ $ _ => true
                                      | _                   => false 
      fun sort []        (re1,re2) = re1 @ re2
        | sort (rr::rrs) (re1,re2) = if is_simple rr 
                                     then sort rrs (rr::re1,re2)
                                     else sort rrs (re1,rr::re2)
    in sort rrs ([],[]) end

    (*Try the simplification procedures in turn; a produced theorem is used
      only if it turns into a rule that actually applies to t.
      NOTE(review): the lhs pattern is matched against t itself, whereas the
      net lookup below uses eta_t -- confirm this is intended.*)
    fun proc_rews ([]:simproc list) = None
      | proc_rews ({name, proc, lhs, ...} :: ps) =
          if Pattern.matches tsigt (term_of lhs, term_of t) then
            (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
             case proc signt prems eta_t of
               None => (debug false "FAILED"; proc_rews ps)
             | Some raw_thm =>
                 (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
                  (case rews (mk_procrule raw_thm) of
                    None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
                  | some => some)))
          else proc_rews ps;
  (*a beta-redex is contracted first (with the dummy skeleton); otherwise
    rules are tried before procedures*)
  in case eta_t of
       Abs _ $ _ => Some (transitive eta_thm
         (beta_conversion false (rhs_of eta_thm)), skel0)
     | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
               None => proc_rews (Net.match_term procs eta_t)
             | some => some)
  end;
   653 
   654 
   655 (* conversion to apply a congruence rule to a term *)
   656 
(*Apply the congruence rule cong to the cterm t.  The instance of the rule
  for t is handed to prover, and the theorem it returns is checked against t
  by check_conv.
  Returns a theorem t == t'.  Calls error if cong comes from a theory whose
  signature is not below signt, if the prover fails, or if it proves a wrong
  equation.
  maxt: as in rewritec (~1 means the rule's indexes are not incremented).
  The lhs field of the congruence record is not used here; the lhs is
  recomputed from the possibly renumbered rule.*)
fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
  let val {sign, ...} = rep_thm cong
      val _ = if Sign.subsig (sign, signt) then ()
                 else error("Congruence rule from different theory")
      val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
      val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
      val insts = Thm.cterm_match (rlhs, t)
      (* Matching failure raises Pattern.MATCH; it is not handled here.
         According to the original note, congc's caller handles it. *)
      val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
      (*bound only for the tracing side effect*)
      val unit = trace_thm false "Applying congruence rule:" thm';
      fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
  in case prover thm' of
       None => err ("Could not prove", thm')
     | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
          None => err ("Should not have proved", thm2)
        | Some thm2' => thm2')
  end;
   675 
   676 val (cA, (cB, cC)) =
   677   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   678 
   679 fun transitive' thm1 None = Some thm1
   680   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
   681 
(*Bottom-up simplification engine.  Returns try_botc, which maps a
  simpset and a cterm t to an equation t == t' (reflexive t when nothing
  was rewritten).  The internal functions botc, subc, impc, ... return a
  thm option, where None means that the term was left unchanged.
  (simprem, useprem, mutsimp) is the mode described in the Parameters
  comment before rewrite_cterm: simprem and useprem are consulted only by
  the legacy nonmut_impc, while mutsimp makes impc use mut_impc instead.
  prover, sign and maxidx are passed through to rewritec and congc.*)
fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
  let
    (*Simplify t bottom-up: first its subterms (subc), then rewrite the
      result at the top level (rewritec).  After a successful top-level
      rewrite, continue on the new term guided by the skeleton skel2 that
      rewritec returns.  A Var skeleton means: do not simplify here.
      NOTE(review): presumably a skeleton records which parts of a rule's
      right-hand side are already simplified -- confirm in rewritec.*)
    fun botc skel mss t =
          if is_Var skel then None
          else
          (case subc skel mss t of
             some as Some thm1 =>
               (*subterms changed: try a top-level rewrite of the new term*)
               (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
                  Some (thm2, skel2) =>
                    transitive' (transitive thm1 thm2)
                      (botc skel2 mss (rhs_of thm2))
                | None => some)
           | None =>
               (*subterms unchanged: try a top-level rewrite of t itself*)
               (case rewritec (prover, sign, maxidx) mss t of
                  Some (thm2, skel2) => transitive' thm2
                    (botc skel2 mss (rhs_of thm2))
                | None => None))

    (*Total version of botc with skeleton skel0 (defined elsewhere):
      yields reflexive t instead of None.*)
    and try_botc mss t =
          (case botc skel0 mss t of
             Some trec1 => trec1 | None => (reflexive t))

    (*Simplify the immediate subterms of t0, by the shape of t0:
        abstraction     -> rename the bound variable apart from bounds,
                           simplify the body, re-abstract;
        s ==> u         -> delegate to impc;
        beta-redex      -> beta-reduce, then simplify subterms of the result;
        application     -> congruence rule of the head constant if one is
                           registered in congs, otherwise appc;
        anything else   -> None.*)
    and subc skel
          (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
       (case term_of t0 of
           Abs (a, T, t) =>
             (*b is a variant of a not occurring in bounds; the body is opened
               with a free variable whose name is b prefixed by a dot
               (presumably to avoid clashes with user names)*)
             let val b = variant bounds a
                 val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
                 val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
                 val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
             in case botc skel' mss' t' of
                  Some thm => Some (abstract_rule a v thm)
                | None => None
             end
         | t $ _ => (case t of
             Const ("==>", _) $ _  =>
               let val (s, u) = Drule.dest_implies t0
               in impc (s, u, mss) end
           | Abs _ =>
               let val thm = beta_conversion false t0
               in case subc skel0 mss (rhs_of thm) of
                    None => Some thm
                  | Some thm' => Some (transitive thm thm')
               end
           | _  =>
               (*appc: simplify function part and argument independently and
                 recombine; the skeleton is split alongside the term when it
                 is itself an application, otherwise both halves get skel0*)
               let fun appc () =
                     let
                       val (tskel, uskel) = case skel of
                           tskel $ uskel => (tskel, uskel)
                         | _ => (skel0, skel0);
                       val (ct, cu) = Thm.dest_comb t0
                     in
                     (case botc tskel mss ct of
                        Some thm1 =>
                          (case botc uskel mss cu of
                             Some thm2 => Some (combination thm1 thm2)
                           | None => Some (combination thm1 (reflexive cu)))
                      | None =>
                          (case botc uskel mss cu of
                             Some thm1 => Some (combination (reflexive ct) thm1)
                           | None => None))
                     end
                   val (h, ts) = strip_comb t
               in case h of
                    Const(a, _) =>
                      (case assoc_string (fst congs, a) of
                         None => appc ()
                       | Some cong =>
(* post processing: some partial applications h t1 ... tj, j <= length ts,
   may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
                          (*A TERM anywhere in this block (e.g. a congruence
                            result that is not an application) is reported as
                            an error; Pattern.MATCH falls back to appc.
                            NOTE(review): presumably MATCH means the rule did
                            not match t0 -- confirm in congc.*)
                          (let
                             val thm = congc (prover mss, sign, maxidx) cong t0;
                             val t = rhs_of thm;
                             val (cl, cr) = Thm.dest_comb t
                             val dVar = Var(("", 0), dummyT)
                             val skel =
                               list_comb (h, replicate (length ts) dVar)
                           in case botc skel mss cl of
                                None => Some thm
                              | Some thm' => Some (transitive thm
                                  (combination thm' (reflexive cr)))
                           end handle TERM _ => error "congc result"
                                    | Pattern.MATCH => appc ()))
                  | _ => appc ()
               end)
         | _ => None)

    (*Simplify an implication prem ==> conc, with args = (prem, conc, mss).
      With mutsimp, mut_impc is used and its restart component is dropped;
      otherwise the legacy nonmut_impc is used.*)
    and impc args =
      if mutsimp
      then let val (prem, conc, mss) = args
           in apsome snd (mut_impc ([], prem, conc, mss)) end
      else nonmut_impc args

    (*Mutual simplification of prem ==> conc.  prems are the outer premises
      already processed, innermost last.  The premise is simplified first and
      the rest is done by mut_impc1.  Result: None if unchanged, otherwise
      Some (restart, thm), where restart = Some i signals that an outer
      premise has become reducible (produced at the end of mut_impc1).*)
    and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
        None => mut_impc1 (prems, prem, conc, mss)
      | Some thm1 =>
          let val prem1 = rhs_of thm1
          in (case mut_impc1 (prems, prem1, conc, mss) of
              None => Some (None,
                combination (combination refl_implies thm1) (reflexive conc))
            | Some (x, thm2) => Some (x, transitive (combination (combination
                refl_implies thm1) (reflexive conc)) thm2))
          end)

    (*Core of mut_impc for prem1 ==> conc, where prem1 is already simplified.*)
    and mut_impc1 (prems, prem1, conc, mss) =
      let
        (*left-hand side of a rule without premises, if it is one*)
        fun uncond ({thm, lhs, elhs, perm}) =
          if Thm.no_prems thm then Some lhs else None

        (*Assume prem1 and add its rewrite rules to the simpset (mss1),
          unless it contains (type) unknowns.  lhss1 collects the left-hand
          sides of the unconditional rules obtained from prem1.*)
        val (lhss1, mss1) =
          if maxidx_of_term (term_of prem1) <> ~1
          then (trace_cterm true
            "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                ([],mss))
          else let val thm = assume prem1
                   val rrules1 = extract_safe_rrules (mss, thm)
                   val lhss1 = mapfilter uncond rrules1
                   val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
               in (lhss1, mss1) end

        (*From  prem1 |- B == C  derive an equation between  prem1 ==> B  and
          prem1 ==> C  by discharging prem1 and instantiating Drule.imp_cong
          (whose first premise is  A ==> B == C).*)
        fun disch1 thm =
          let val (cB', cC') = dest_eq thm
          in
            implies_elim (Thm.instantiate
              ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
              (implies_intr prem1 thm)
          end

        (*Given the optional equation for the simplified conclusion, rebuild
          the implication, try a top-level rewrite of it, and when the
          rewritten term is again an implication simplify it once more with
          mut_impc.  TERM (e.g. from dest_implies on a non-implication) stops
          at the current equation.*)
        fun rebuild None = (case rewritec (prover, sign, maxidx) mss
            (mk_implies (prem1, conc)) of
              None => None
            | Some (thm, _) => 
                let val (prem, conc) = Drule.dest_implies (rhs_of thm)
                in (case mut_impc (prems, prem, conc, mss) of
                    None => Some (None, thm)
                  | Some (x, thm') => Some (x, transitive thm thm'))
                end handle TERM _ => Some (None, thm))
          | rebuild (Some thm2) =
            let val thm = disch1 thm2
            in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
                 None => Some (None, thm)
               | Some (thm', _) =>
                   let val (prem, conc) = Drule.dest_implies (rhs_of thm')
                   in (case mut_impc (prems, prem, conc, mss) of
                       None => Some (None, transitive thm thm')
                     | Some (x, thm'') =>
                         Some (x, transitive (transitive thm thm') thm''))
                   end handle TERM _ => Some (None, transitive thm thm'))
            end

        (*Simplify the conclusion using mss1 (which contains prem1).  If conc
          is itself an implication, recurse with prem1 appended to prems.  A
          restart signal Some i is decremented on the way out; when it is 0
          here, prem1 and the next premise prem are swapped (via
          Drule.swap_prems_eq) and mut_impc1 is run again with prem in front.
          TERM anywhere in this body (e.g. conc is not an implication) falls
          back to simplifying conc directly with botc.*)
        fun simpconc () =
          let val (s, t) = Drule.dest_implies conc
          in case mut_impc (prems @ [prem1], s, t, mss1) of
               None => rebuild None
             | Some (Some i, thm2) =>
                  let
                    val (prem, cC') = Drule.dest_implies (rhs_of thm2);
                    val thm2' = transitive (disch1 thm2) (Thm.instantiate
                      ([], [(cA, prem1), (cB, prem), (cC, cC')])
                      Drule.swap_prems_eq)
                  in if i=0 then apsome (apsnd (transitive thm2'))
                       (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
                     else Some (Some (i-1), thm2')
                  end
             | Some (None, thm) => rebuild (Some thm)
          end handle TERM _ => rebuild (botc skel0 mss1 conc)

      in
        (*If an unconditional rule obtained from prem1 matches a subterm of
          some outer premise, stop and report it: the counter is the number
          of premises after the first such premise.  Otherwise simplify the
          conclusion.*)
        let
          val tsig = Sign.tsig_of sign
          fun reducible t =
            exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
        in case dropwhile (not o reducible) prems of
            [] => simpconc ()
          | red::rest => (trace_cterm false "Can now reduce premise:" red;
              Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
        end
      end

     (* legacy code - only for backwards compatibility *)
     (*Simplify prem only if simprem; assume the (possibly simplified)
       premise as a rewrite rule for conc only if useprem and it contains no
       (type) unknowns; then simplify conc and combine the two equations via
       refl_implies and Drule.imp_cong.*)
     and nonmut_impc (prem, conc, mss) =
       let val thm1 = if simprem then botc skel0 mss prem else None;
           val prem1 = if_none (apsome rhs_of thm1) prem;
           val maxidx1 = maxidx_of_term (term_of prem1)
           val mss1 =
             if not useprem then mss else
             if maxidx1 <> ~1
             then (trace_cterm true
               "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                   mss)
             else let val thm = assume prem1
                  in add_safe_simp (add_prems (mss, [thm]), thm) end
       in (case botc skel0 mss1 conc of
           None => (case thm1 of
               None => None
             | Some thm1' => Some (combination
                 (combination refl_implies thm1') (reflexive conc)))
         | Some thm2 =>
           let
             val conc2 = rhs_of thm2;
             val thm2' = implies_elim (Thm.instantiate
               ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
               (implies_intr prem1 thm2)
           in (case thm1 of
               None => Some thm2'
             | Some thm1' => Some (transitive (combination
                 (combination refl_implies thm1') (reflexive conc)) thm2'))
           end)
       end

 in try_botc end;
   893 
   894 
   895 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   896 
   897 (*
   898   Parameters:
   899     mode = (simplify A,
   900             use A in simplifying B,
   901             use prems of B (if B is again a meta-impl.) to simplify A)
   902            when simplifying A ==> B
   903     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   904     prover: how to solve premises in conditional rewrites and congruences
   905 *)
   906 
   907 (* FIXME: check that #bounds(mss) does not "occur" in ct already *)
   908 
   909 fun rewrite_cterm mode prover mss ct =
   910   let val {sign, t, maxidx, ...} = rep_cterm ct
   911       val Mss{depth, ...} = mss
   912   in simp_depth := depth;
   913      bottomc (mode, prover, sign, maxidx) mss ct
   914   end
   915   handle THM (s, _, thms) =>
   916     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   917       Pretty.string_of (pretty_thms thms));
   918 
   919 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   920 (*Do not rewrite flex-flex pairs*)
   921 fun goals_conv pred cv =
   922   let fun gconv i ct =
   923         let val (A,B) = Drule.dest_implies ct
   924             val (thA,j) = case term_of A of
   925                   Const("=?=",_)$_$_ => (reflexive A, i)
   926                 | _ => (if pred i then cv A else reflexive A, i+1)
   927         in  combination (combination refl_implies thA) (gconv j B) end
   928         handle TERM _ => reflexive ct
   929   in gconv 1 end;
   930 
   931 (*Use a conversion to transform a theorem*)
   932 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   933 
   934 (*Rewrite a theorem*)
   935 fun rewrite_rule_aux _ [] = (fn th => th)
   936   | rewrite_rule_aux prover thms =
   937       fconv_rule (rewrite_cterm (true,false,false) prover (mss_of thms));
   938 
   939 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   940 
   941 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   942 fun rewrite_goals_rule_aux _ []   th = th
   943   | rewrite_goals_rule_aux prover thms th =
   944       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   945         (mss_of thms))) th;
   946 
(*Rewrite subgoal i of a proof state (represented by a theorem) *)
   948 fun rewrite_goal_rule mode prover mss i thm =
   949   if 0 < i  andalso  i <= nprems_of thm
   950   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   951   else raise THM("rewrite_goal_rule",i,[thm]);
   952 
   953 end;
   954 
   955 open MetaSimplifier;