src/Pure/meta_simplifier.ML
author wenzelm
Thu Oct 04 15:21:47 2001 +0200 (2001-10-04)
changeset 11672 8e75b78f33f3
parent 11505 a410fa8acfca
child 11689 38788d98504f
permissions -rw-r--r--
full_rewrite_cterm_aux (see also tactic.ML);
no longer open MetaSimplifier;
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
    11   val trace_simp: bool ref  (* flag consulted by the trace* diagnostic functions *)
    12   val debug_simp: bool ref  (* flag consulted by the debug* diagnostic functions *)
    13 end;
    14 
    15 signature META_SIMPLIFIER =
    16 sig
    17   include BASIC_META_SIMPLIFIER  (* re-exports trace_simp and debug_simp *)
    18   exception SIMPLIFIER of string * thm  (* message plus the offending theorem *)
    19   type meta_simpset
    20   val dest_mss		: meta_simpset ->
    21     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    22   val empty_mss         : meta_simpset
    23   val clear_mss		: meta_simpset -> meta_simpset  (* keeps only mk_rews and termless *)
    24   val merge_mss		: meta_simpset * meta_simpset -> meta_simpset  (* mk_rews, termless, depth come from the first argument *)
    25   val add_simps         : meta_simpset * thm list -> meta_simpset
    26   val del_simps         : meta_simpset * thm list -> meta_simpset
    27   val mss_of            : thm list -> meta_simpset  (* rules are added to empty_mss *)
    28   val add_congs         : meta_simpset * thm list -> meta_simpset
    29   val del_congs         : meta_simpset * thm list -> meta_simpset
    30   val add_simprocs	: meta_simpset *
    31     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    32       -> meta_simpset
    33   val del_simprocs	: meta_simpset *
    34     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    35       -> meta_simpset
    36   val add_prems         : meta_simpset * thm list -> meta_simpset  (* new premises go in front of the existing ones *)
    37   val prems_of_mss      : meta_simpset -> thm list
    38   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    39   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    40   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    42   val rewrite_cterm: bool * bool * bool ->
    43     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    44   val full_rewrite_cterm_aux: (meta_simpset -> thm -> thm option) -> thm list -> cterm -> cterm
    45   val rewrite_rule_aux  : (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    46   val rewrite_thm       : bool * bool * bool
    47                           -> (meta_simpset -> thm -> thm option)
    48                           -> meta_simpset -> thm -> thm
    49   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    50   val rewrite_goal_rule : bool* bool * bool
    51                           -> (meta_simpset -> thm -> thm option)
    52                           -> meta_simpset -> int -> thm -> thm
    53 end;
    54 
    55 structure MetaSimplifier : META_SIMPLIFIER =
    56 struct
    57 
    58 (** diagnostics **)
    59 
    60 exception SIMPLIFIER of string * thm;  (* raised in this file with a message and the offending theorem *)
    61 
    62 val simp_depth = ref 0;  (* println indents every message by this many spaces *)
    63 
    64 fun println a = writeln(replicate_string (!simp_depth) " " ^ a)  (* write a, prefixed by !simp_depth spaces *)
    65 
    66 fun prnt warn a = if warn then warning a else println a;  (* warn = true: emit via warning; otherwise via println *)
    67 
    68 fun prtm warn a sign t =  (* print message a, then term t rendered with signature sign *)
    69   (prnt warn a; prnt warn (Sign.string_of_term sign t));
    70 
    71 fun prctm warn a t =  (* print message a, then the certified term t *)
    72   (prnt warn a; prnt warn (Display.string_of_cterm t));
    73 
    74 fun prthm warn a thm =  (* print message a, then the proposition of thm *)
    75   let val {sign, prop, ...} = rep_thm thm
    76   in prtm warn a sign prop end;
    77 
    78 val trace_simp = ref false;  (* tracing output is off initially *)
    79 val debug_simp = ref false;  (* debugging output is off initially *)
    80 
    81 fun trace warn a = if !trace_simp then prnt warn a else ();  (* prnt, but only while trace_simp is set *)
    82 fun debug warn a = if !debug_simp then prnt warn a else ();  (* prnt, but only while debug_simp is set *)
    83 
    84 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();  (* prtm guarded by trace_simp *)
    85 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();  (* prctm guarded by trace_simp *)
    86 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();  (* prtm guarded by debug_simp *)
    87 
    88 fun trace_thm warn a thm =  (* trace message a followed by the proposition of thm *)
    89   let val {sign, prop, ...} = rep_thm thm
    90   in trace_term warn a sign prop end;
    91 
    92 
    93 
    94 (** meta simp sets **)
    95 
    96 (* basic components *)
    97 
    98 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};  (* one rewrite rule; the fields are explained in the comment that follows *)
    99 (* thm: the rewrite rule
   100    lhs: the left-hand side
   101    elhs: the eta-contracted lhs.
   102    fo:  use first-order matching
   103    perm: the rewrite rule is permutative
   104 Remarks:
   105   - elhs is used for matching,
   106     lhs only for preservation of bound variable names.
   107   - fo is set iff
   108     either elhs is first-order (no Var is applied),
   109            in which case fo-matching is complete,
   110     or elhs is not a pattern,
   111        in which case there is nothing better to do.
   112 *)
   113 type cong = {thm: thm, lhs: cterm};  (* congruence rule together with the lhs of its conclusion (as stored by add_cong) *)
   114 type simproc =  (* simplification procedure; id is the only field compared by eq_simproc *)
   115  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   116 
   117 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =  (* rules are equal iff the props of their theorems are aconv *)
   118   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   119 
   120 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =  (* same test as eq_rrule, on congruence theorems *)
   121   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   122 
   123 fun eq_prem (thm1, thm2) =  (* same test again, on plain premises *)
   124   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   125 
   126 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);  (* identity is decided by the stamp alone; name, proc and lhs are ignored *)
   127 
   128 fun mk_simproc (name, proc, lhs, id) =  (* tuple-to-record constructor for simproc *)
   129   {name = name, proc = proc, lhs = lhs, id = id};
   130 
   131 
   132 (* datatype mss *)
   133 
   134 (*
   135   A "mss" contains data needed during conversion:
   136     rules: discrimination net of rewrite rules;
   137     congs: association list of congruence rules and
   138            a list of `weak' congruence constants.
   139            A congruence is `weak' if it avoids normalization of some argument.
   140     procs: discrimination net of simplification procedures
   141       (functions that prove rewrite rules on the fly);
   142     bounds: names of bound variables already used
   143       (for generating new names when rewriting under lambda abstractions);
   144     prems: current premises;
   145     mk_rews: mk: turns simplification thms into rewrite rules;
   146              mk_sym: turns == around; (needs Drule!)
   147              mk_eq_True: turns P into P == True - logic specific;
   148     termless: relation for ordered rewriting;
   149     depth: depth of conditional rewriting;
   150 *)
   151 
   152 datatype meta_simpset =  (* the components are described in the comment directly above *)
   153   Mss of {
   154     rules: rrule Net.net,
   155     congs: (string * cong) list * string list,  (* (constant name -> congruence) alist, and names of weak congruence constants *)
   156     procs: simproc Net.net,
   157     bounds: string list,
   158     prems: thm list,
   159     mk_rews: {mk: thm -> thm list,
   160               mk_sym: thm -> thm option,
   161               mk_eq_True: thm -> thm option},
   162     termless: term * term -> bool,
   163     depth: int};
   164 
   165 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =  (* positional constructor; callers rely on this argument order *)
   166   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   167        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   168 
   169 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =  (* replace only the rule net; the old net is discarded *)
   170   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   171 
   172 val empty_mss =  (* no rules, congruences, procedures, bounds or premises; depth 0 *)
   173   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}  (* mk yields no rules; mk_sym and mk_eq_True always answer None *)
   174   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   175 
   176 fun clear_mss (Mss {mk_rews, termless, ...}) =  (* keep mk_rews and termless; reset every other component as in empty_mss *)
   177   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   178 
   179 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =  (* same simpset with depth + 1; used in this file before calling the prover *)
   180   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   181  
   182 
   183 
   184 (** simpset operations **)
   185 
   186 (* term variables *)
   187 
   188 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);  (* fold over a term, adding the name of every Var to the accumulator via ins_ix *)
   189 fun term_varnames t = add_term_varnames ([], t);  (* Var names of t, starting from the empty list *)
   190 
   191 
   192 (* dest_mss *)
   193 
   194 fun dest_mss (Mss {rules, congs, procs, ...}) =  (* expose rules, congruences and procedures as plain lists *)
   195   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   196    congs = map (fn (_, {thm, ...}) => thm) (fst congs),  (* the weak-constant list (snd congs) is not reported *)
   197    procs =
   198      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   199      |> partition_eq eq_snd  (* group the net entries by their id stamp *)
   200      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))  (* per group: name of the first entry and all lhs patterns *)
   201      |> Library.sort_wrt #1};  (* sorted with respect to the name component *)
   202 
   203 
   204 (* merge_mss *)	      (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   205 
   206 fun merge_mss  (* combine two simpsets; mk_rews, termless and depth are taken from the first one only *)
   207  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   208        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   209   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   210        bounds = bounds2, prems = prems2, ...}) =
   211       mk_mss
   212        (Net.merge (rules1, rules2, eq_rrule),
   213         (generic_merge (eq_cong o pairself snd) I I congs1 congs2,  (* alist entries are compared by their cong component only *)
   214         merge_lists weak1 weak2),
   215         Net.merge (procs1, procs2, eq_simproc),
   216         merge_lists bounds1 bounds2,
   217         generic_merge eq_prem I I prems1 prems2,
   218         mk_rews, termless, depth);
   219 
   220 
   221 (* add_simps *)
   222 
   223 fun mk_rrule2{thm,lhs,elhs,perm} =  (* complete a rule record by computing its fo flag *)
   224   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))  (* first-order elhs, or elhs is not a pattern *)
   225   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   226 
   227 fun insert_rrule(mss as Mss {rules,...},  (* add one rule to the net, indexed by the term of its elhs *)
   228                  rrule as {thm,lhs,elhs,perm}) =
   229   (trace_thm false "Adding rewrite rule:" thm;
   230    let val rrule2 as {elhs,...} = mk_rrule2 rrule  (* fills in the fo flag *)
   231        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   232    in upd_rules(mss,rules') end
   233    handle Net.INSERT =>  (* Net.INSERT is treated as a duplicate: warn and return mss unchanged *)
   234      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   235 
   236 fun vperm (Var _, Var _) = true  (* any Var may stand opposite any other Var *)
   237   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)  (* abstractions: compare bodies only *)
   238   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   239   | vperm (t, u) = (t = u);  (* everything else must be identical *)
   240 
   241 fun var_perm (t, u) =  (* same shape up to Vars (vperm) and the same set of Var names on both sides *)
   242   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   243 
   244 (* FIXME: it seems that the conditions on extra variables are too liberal if
   245 prems are nonempty: does solving the prems really guarantee instantiation of
   246 all its Vars? Better: a dynamic check each time a rule is applied.
   247 *)
   248 fun rewrite_rule_extra_vars prems elhs erhs =  (* true iff erhs mentions a Var or TVar found neither in elhs nor in prems *)
   249   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))  (* term variables *)
   250   orelse
   251   not ((term_tvars erhs) subset  (* type variables *)
   252        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   253 
   254 (*Simple test for looping rewrite rules and stupid orientations*)
   255 fun reorient sign prems lhs rhs =  (* true if lhs == rhs looks unsuitable in this direction; each disjunct is one reason *)
   256    rewrite_rule_extra_vars prems lhs rhs  (* rhs has extra (type) variables *)
   257   orelse
   258    is_Var (head_of lhs)  (* head of the lhs is a Var *)
   259   orelse
   260    (exists (apl (lhs, Logic.occs)) (rhs :: prems))  (* Logic.occs holds for lhs and rhs or some premise -- presumably an occurrence test *)
   261   orelse
   262    (null prems andalso
   263     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   264     (*the condition "null prems" is necessary because conditional rewrites
   265       with extra variables in the conditions may terminate although
   266       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   267   orelse
   268    (is_Const lhs andalso not(is_Const rhs))  (* a bare constant on the left with a non-constant on the right *)
   269 
   270 fun decomp_simp thm =  (* split a rule into (sign, prems, lhs, elhs, rhs, perm); raises SIMPLIFIER if the conclusion is not an == *)
   271   let val {sign, prop, ...} = rep_thm thm;
   272       val prems = Logic.strip_imp_prems prop;
   273       val concl = Drule.strip_imp_concl (cprop_of thm);
   274       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   275         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   276       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));  (* rhs of the eta-conversion theorem for lhs *)
   277       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   278       val erhs = Pattern.eta_contract (term_of rhs);
   279       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)  (* permutative: sides differ only by Var placement, are not aconv, ... *)
   280                  andalso not (is_Var (term_of elhs))  (* ... and the lhs is not a lone Var *)
   281   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;  (* note: rhs is returned without eta-contraction; erhs is used for perm only *)
   282 
   283 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =  (* rules from the simpset's own mk_eq_True: none on failure, else one non-permutative rule *)
   284   case mk_eq_True thm of
   285     None => []
   286   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   287                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   288 
   289 (* create the rewrite rule and possibly also the ==True variant,
   290    in case there are extra vars on the rhs *)
   291 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =  (* rule for thm, preceded by the mk_eq_True variant of thm2 when rhs has extra vars *)
   292   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   293   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   294         (term_tvars rhs) subset (term_tvars lhs)
   295      then [rrule]
   296      else mk_eq_True mss thm2 @ [rrule]  (* the ==True variant comes first in the list *)
   297   end;
   298 
   299 fun mk_rrule mss thm =  (* turn thm into rewrite rules without reorienting it *)
   300   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   301   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else  (* permutative rules are kept as they are *)
   302      (* weak test for loops: *)
   303      if rewrite_rule_extra_vars prems lhs rhs orelse
   304         is_Var (term_of elhs)
   305      then mk_eq_True mss thm  (* unusable as an equation: fall back to the ==True form only *)
   306      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   307   end;
   308 
   309 fun orient_rrule mss thm =  (* like mk_rrule, but may turn the equation around using the simpset's mk_sym *)
   310   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   311   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   312      else if reorient sign prems lhs rhs  (* given direction rejected *)
   313           then if reorient sign prems rhs lhs  (* reverse direction rejected as well *)
   314                then mk_eq_True mss thm
   315                else let val Mss{mk_rews={mk_sym,...},...} = mss
   316                     in case mk_sym thm of
   317                          None => []  (* cannot be turned around: no rule at all *)
   318                        | Some thm' =>
   319                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   320                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end  (* the ==True variant is still derived from the original thm *)
   321                     end
   322           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   323   end;
   324 
   325 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);  (* apply the simpset's mk to every theorem and flatten the results *)
   326 
   327 fun orient_comb_simps comb mk_rrule (mss,thms) =  (* extract rewrites, turn each into rules with mk_rrule, then fold comb over mss *)
   328   let val rews = extract_rews(mss,thms)
   329       val rrules = flat (map mk_rrule rews)
   330   in foldl comb (mss,rrules) end
   331 
   332 (* Add rewrite rules explicitly; do not reorient! *)
   333 fun add_simps(mss,thms) =  (* insert the rules that mk_rrule derives from thms *)
   334   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   335 
   336 fun mss_of thms =  (* simpset built on empty_mss; thms go straight to mk_rrule, not through extract_rews *)
   337   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   338 
   339 fun extract_safe_rrules(mss,thm) =  (* rules for one theorem, built with orient_rrule instead of mk_rrule *)
   340   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   341 
   342 fun add_safe_simp(mss,thm) =  (* insert the rules produced by extract_safe_rrules *)
   343   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   344 
   345 (* del_simps *)
   346 
   347 fun del_rrule(mss as Mss {rules,...},  (* remove one rule from the net *)
   348               rrule as {thm, elhs, ...}) =
   349   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   350    handle Net.DELETE =>  (* on Net.DELETE: warn and return mss unchanged *)
   351      (prthm true "Rewrite rule not in simpset:" thm; mss));
   352 
   353 fun del_simps(mss,thms) =  (* rebuild the rules exactly as add_simps would (incl. mk_rrule2), then delete each *)
   354   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   355 
   356 
   357 (* add_congs *)
   358 
   359 fun is_full_cong_prems [] varpairs = null varpairs  (* all premises consumed: every pair must have been used *)
   360   | is_full_cong_prems (p::prems) varpairs =
   361     (case Logic.strip_assums_concl p of
   362        Const("==",_) $ lhs $ rhs =>
   363          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   364          in is_Var x  andalso  forall is_Bound xs  andalso  (* lhs is a Var applied to bound variables only *)
   365             null(findrep(xs))  andalso xs=ys andalso  (* findrep finds nothing in xs (presumably: no repeats); rhs has the same arguments *)
   366             (x,y) mem varpairs andalso  (* the heads form one of the still-expected pairs *)
   367             is_full_cong_prems prems (varpairs\(x,y))  (* continue with that pair removed *)
   368          end
   369      | _ => false);  (* premise conclusion is not an == *)
   370 
   371 fun is_full_cong thm =  (* does thm have the shape checked below: same head, matching argument lists, one suitable premise per argument pair? *)
   372 let val prems = prems_of thm
   373     and concl = concl_of thm
   374     val (lhs,rhs) = Logic.dest_equals concl
   375     val (f,xs) = strip_comb lhs
   376     and (g,ys) = strip_comb rhs
   377 in
   378   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   379   is_full_cong_prems prems (xs ~~ ys)  (* premises must account for exactly the argument pairs *)
   380 end
   381 
   382 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =  (* register thm as the congruence for the head constant of its lhs *)
   383   let
   384     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   385       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   386 (*   val lhs = Pattern.eta_contract lhs; *)
   387     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   388       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   389     val (alist,weak) = congs
   390     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))  (* an existing entry for a is replaced, with the warning below *)
   391            ("Overwriting congruence rule for " ^ quote a);
   392     val weak2 = if is_full_cong thm then weak else a::weak  (* a non-full congruence makes its constant weak *)
   393   in
   394     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   395   end;
   396 
   397 val (op add_congs) = foldl add_cong;  (* add a list of congruences, left to right *)
   398 
   399 
   400 (* del_congs *)
   401 
   402 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =  (* drop the congruence registered for the head constant of thm's lhs *)
   403   let
   404     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   405       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   406 (*   val lhs = Pattern.eta_contract lhs; *)
   407     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   408       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   409     val (alist,_) = congs
   410     val alist2 = filter (fn (x,_)=> x<>a) alist  (* removal is by constant name only; thm itself is not compared *)
   411     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None  (* weak list is recomputed from the remaining entries *)
   412                                               else Some a)
   413                    alist2
   414   in
   415     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   416   end;
   417 
   418 val (op del_congs) = foldl del_cong;  (* delete a list of congruences, left to right *)
   419 
   420 
   421 (* add_simprocs *)
   422 
   423 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},  (* register one (name, lhs pattern, proc, id) entry *)
   424     (name, lhs, proc, id)) =
   425   let val {sign, t, ...} = rep_cterm lhs
   426   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   427       sign t;
   428     mk_mss (rules, congs,
   429       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)  (* indexed by the term of the lhs pattern *)
   430         handle Net.INSERT => 
   431 	    (warning ("Ignoring duplicate simplification procedure \"" 
   432 	              ^ name ^ "\""); 
   433 	     procs),
   434         bounds, prems, mk_rews, termless,depth))
   435   end;
   436 
   437 fun add_simproc (mss, (name, lhss, proc, id)) =  (* one add_proc per lhs pattern, all sharing name, proc and id *)
   438   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   439 
   440 val add_simprocs = foldl add_simproc;  (* add a list of procedures, left to right *)
   441 
   442 
   443 (* del_simprocs *)
   444 
   445 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},  (* remove one (name, lhs pattern, proc, id) entry *)
   446     (name, lhs, proc, id)) =
   447   mk_mss (rules, congs,
   448     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)  (* eq_simproc compares the id stamp only *)
   449       handle Net.DELETE => 
   450 	  (warning ("Simplification procedure \"" ^ name ^
   451 		       "\" not in simpset"); procs),
   452       bounds, prems, mk_rews, termless, depth);
   453 
   454 fun del_simproc (mss, (name, lhss, proc, id)) =  (* one del_proc per lhs pattern *)
   455   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   456 
   457 val del_simprocs = foldl del_simproc;  (* delete a list of procedures, left to right *)
   458 
   459 
   460 (* prems *)
   461 
   462 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =  (* thms are put in front of the existing premises; no duplicate check *)
   463   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   464 
   465 fun prems_of_mss (Mss {prems, ...}) = prems;  (* current premise list *)
   466 
   467 
   468 (* mk_rews *)
   469 
   470 fun set_mk_rews  (* replace the mk component of mk_rews; mk_sym and mk_eq_True are kept *)
   471   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   472     mk_mss (rules, congs, procs, bounds, prems,
   473             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   474             termless, depth);
   475 
   476 fun set_mk_sym  (* replace the mk_sym component; mk and mk_eq_True are kept *)
   477   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   478     mk_mss (rules, congs, procs, bounds, prems,
   479             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   480             termless,depth);
   481 
   482 fun set_mk_eq_True  (* replace the mk_eq_True component; mk and mk_sym are kept *)
   483   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   484     mk_mss (rules, congs, procs, bounds, prems,
   485             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   486             termless,depth);
   487 
   488 (* termless *)
   489 
   490 fun set_termless  (* replace the ordering consulted for permutative rules; all other components are kept *)
   491   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   492     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   493 
   494 
   495 
   496 (** rewriting **)
   497 
   498 (*
   499   Uses conversions, see:
   500     L C Paulson, A higher-order implementation of rewriting,
   501     Science of Computer Programming 3 (1983), pages 119-149.
   502 *)
   503 
   504 type prover = meta_simpset -> thm -> thm option;  (* tries to discharge the premises of an instantiated rule; None means failure *)
   505 type termrec = (Sign.sg_ref * term list) * term;
   506 type conv = meta_simpset -> termrec -> termrec;
   507 
   508 val dest_eq = Drule.dest_equals o cprop_of;  (* both sides, as cterms, of a theorem whose prop is an == *)
   509 val lhs_of = fst o dest_eq;
   510 val rhs_of = snd o dest_eq;
   511 
   512 fun beta_eta_conversion t =  (* theorem chaining the beta conversion of t with the eta conversion of the result *)
   513   let val thm = beta_conversion true t;
   514   in transitive thm (eta_conversion (rhs_of thm)) end;
   515 
   516 fun check_conv msg thm thm' =  (* try to chain thm with thm'; Some combined theorem, or None if the chaining raises THM *)
   517   let
   518     val thm'' = transitive thm (transitive
   519       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')  (* bridge from the beta-eta form of thm' 's lhs back to thm' itself *)
   520   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   521   handle THM _ =>  (* the pieces did not fit together: trace what was proved and what was expected *)
   522     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;  (* prop0: the right-hand argument of thm's proposition *)
   523     in
   524       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   525        trace_term false "Should have proved:" sign prop0;
   526        None)
   527     end;
   528 
   529 
   530 (* mk_procrule *)
   531 
   532 fun mk_procrule thm =  (* turn a theorem returned by a simplification procedure into at most one rewrite rule *)
   533   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   534   in if rewrite_rule_extra_vars prems lhs rhs
   535      then (prthm true "Extra vars on rhs:" thm; [])  (* rejected with a warning *)
   536      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]  (* always treated as non-permutative *)
   537   end;
   538 
   539 
   540 (* conversion to apply the meta simpset to a term *)
   541 
   542 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   543    normalized terms by carrying around the rhs of the rewrite rule just
   544    applied. This is called the `skeleton'. It is decomposed in parallel
   545    with the term. Once a Var is encountered, the corresponding term is
   546    already in normal form.
   547    skel0 is a dummy skeleton that is to enforce complete normalization.
   548 *)
   549 val skel0 = Bound 0;  (* not a Var, so botc never stops early on it *)
   550 
   551 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   552    The latter may happen iff there are weak congruence rules for constants
   553    in the lhs.
   554 *)
   555 fun uncond_skel((_,weak),(lhs,rhs)) =  (* skeleton after applying an unconditional rule lhs == rhs *)
   556   if null weak then rhs (* optimization *)
   557   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0  (* lhs mentions a weak constant: force full normalization *)
   558        else rhs;
   559 
   560 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   561    Otherwise those vars may become instantiated with unnormalized terms
   562    while the premises are solved.
   563 *)
   564 fun cond_skel(args as (congs,(lhs,rhs))) =  (* skeleton after applying a conditional rule *)
   565   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)  (* no extra Vars on the rhs: same as the unconditional case *)
   566   else skel0;
   567 
   568 (*
   569   we try in order:
   570     (1) beta reduction
   571     (2) unconditional rewrite rules
   572     (3) conditional rewrite rules
   573     (4) simplification procedures
   574 
   575   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   576 
   577 *)
   578 
   579 fun rewritec (prover, signt, maxt)  (* one rewrite step at the root of t: Some (theorem, skeleton) or None *)
   580              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   581   let
   582     val eta_thm = Thm.eta_conversion t;
   583     val eta_t' = rhs_of eta_thm;  (* eta-converted t as a cterm *)
   584     val eta_t = term_of eta_t';
   585     val tsigt = Sign.tsig_of signt;
   586     fun rew {thm, lhs, elhs, fo, perm} =  (* try a single rule; raises Pattern.MATCH for a rule from a foreign theory *)
   587       let
   588         val {sign, prop, maxidx, ...} = rep_thm thm;
   589         val _ = if Sign.subsig (sign, signt) then ()
   590                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   591                       raise Pattern.MATCH);
   592         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)  (* maxt = ~1: use the rule without shifting indexes *)
   593           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   594         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')  (* matcher chosen by the rule's fo flag *)
   595                           else Thm.cterm_match (elhs', eta_t');
   596         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   597         val prop' = #prop (rep_thm thm');
   598         val unconditional = (Logic.count_prems (prop',0) = 0);  (* the instance has no premises *)
   599         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   600       in
   601         if perm andalso not (termless (rhs', lhs'))  (* permutative rules apply only when termless (rhs', lhs') holds *)
   602         then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
   603               trace_thm false "Term does not become smaller:" thm'; None)
   604         else (trace_thm false "Applying instance of rewrite rule:" thm;
   605            if unconditional
   606            then
   607              (trace_thm false "Rewriting:" thm';
   608               let val lr = Logic.dest_equals prop;
   609                   val Some thm'' = check_conv false eta_thm thm'  (* NOTE(review): a None result here would fail the pattern binding *)
   610               in Some (thm'', uncond_skel (congs, lr)) end)
   611            else
   612              (trace_thm false "Trying to rewrite:" thm';
   613               case prover (incr_depth mss) thm' of  (* the prover sees the simpset with depth + 1 *)
   614                 None       => (trace_thm false "FAILED" thm'; None)
   615               | Some thm2 =>
   616                   (case check_conv true eta_thm thm2 of
   617                      None => None |
   618                      Some thm2' =>
   619                        let val concl = Logic.strip_imp_concl prop
   620                            val lr = Logic.dest_equals concl
   621                        in Some (thm2', cond_skel (congs, lr)) end)))
   622       end
   623 
   624     fun rews [] = None  (* first rule that yields Some wins; a Pattern.MATCH counts as None *)
   625       | rews (rrule :: rrules) =
   626           let val opt = rew rrule handle Pattern.MATCH => None
   627           in case opt of None => rews rrules | some => some end;
   628 
   629     fun sort_rrules rrs = let  (* rules whose prop is itself an == come first; each group ends up in reversed order *)
   630       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of 
   631                                       Const("==",_) $ _ $ _ => true
   632                                       | _                   => false 
   633       fun sort []        (re1,re2) = re1 @ re2
   634         | sort (rr::rrs) (re1,re2) = if is_simple rr 
   635                                      then sort rrs (rr::re1,re2)
   636                                      else sort rrs (re1,rr::re2)
   637     in sort rrs ([],[]) end
   638 
   639     fun proc_rews ([]:simproc list) = None  (* try the procedures in order until one produces an applicable rule *)
   640       | proc_rews ({name, proc, lhs, ...} :: ps) =
   641           if Pattern.matches tsigt (term_of lhs, term_of t) then  (* the pattern is matched against t, while the procedure is run on eta_t *)
   642             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   643              case proc signt prems eta_t of
   644                None => (debug false "FAILED"; proc_rews ps)
   645              | Some raw_thm =>
   646                  (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   647                   (case rews (mk_procrule raw_thm) of  (* the produced theorem is applied like an ordinary rule *)
   648                     None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
   649                   | some => some)))
   650           else proc_rews ps;
   651   in case eta_t of
   652        Abs _ $ _ => Some (transitive eta_thm  (* abstraction applied to an argument: beta-convert, and use the dummy skeleton *)
   653          (beta_conversion false (rhs_of eta_thm)), skel0)
   654      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of  (* rewrite rules first ... *)
   655                None => proc_rews (Net.match_term procs eta_t)  (* ... simplification procedures only if no rule applied *)
   656              | some => some)
   657   end;
   658 
   659 
   660 (* conversion to apply a congruence rule to a term *)
   661 
   662 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =  (* apply congruence rule cong to t; the stored lhs field is not used here *)
   663   let val {sign, ...} = rep_thm cong
   664       val _ = if Sign.subsig (sign, signt) then ()
   665                  else error("Congruence rule from different theory")
   666       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   667       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));  (* lhs is recomputed from the index-shifted rule *)
   668       val insts = Thm.cterm_match (rlhs, t)
   669       (* Pattern.match can raise Pattern.MATCH;
   670          is handled when congc is called *)
   671       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   672       val unit = trace_thm false "Applying congruence rule:" thm';  (* bound only for the tracing side effect *)
   673       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   674   in case prover thm' of
   675        None => err ("Could not prove", thm')  (* unlike rewrite rules, a failed congruence proof is an error *)
   676      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   677           None => err ("Should not have proved", thm2)
   678         | Some thm2' => thm2')
   679   end;
   680 
   681 val (cA, (cB, cC)) =  (* first premise of Drule.imp_cong, split as an implication whose conclusion is an equation *)
   682   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   683 
   684 fun transitive' thm1 None = Some thm1  (* no second step: thm1 is the result *)
   685   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);  (* otherwise chain the two equations *)
   686 (* bottomc: build the bottom-up rewriting conversion.
         (simprem, useprem, mutsimp): mode flags for implications; simprem and useprem are
           read only by nonmut_impc, mutsimp selects mut_impc inside impc.
         prover, sign, maxidx: handed on to rewritec; prover mss is handed on to congc.
         Result: try_botc, mapping mss and a cterm t to an equation theorem
           (reflexive t when nothing was rewritten).
         The local functions return an option; None is treated as: term unchanged. *)
   687 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   688   let
           (* botc: first simplify inside t via subc, then rewrite at the root via rewritec and
              repeat on the result with the skeleton returned by rewritec.
              A skeleton that is a Var makes botc return None without looking at t. *)
   689     fun botc skel mss t =
   690           if is_Var skel then None
   691           else
   692           (case subc skel mss t of
   693              some as Some thm1 =>
   694                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   695                   Some (thm2, skel2) =>
   696                     transitive' (transitive thm1 thm2)
   697                       (botc skel2 mss (rhs_of thm2))
   698                 | None => some)
   699            | None =>
   700                (case rewritec (prover, sign, maxidx) mss t of
   701                   Some (thm2, skel2) => transitive' thm2
   702                     (botc skel2 mss (rhs_of thm2))
   703                 | None => None))
   704     (* try_botc: botc started with skel0; an unchanged term yields reflexive t *)
   705     and try_botc mss t =
   706           (case botc skel0 mss t of
   707              Some trec1 => trec1 | None => (reflexive t))
   708     (* subc: simplify inside t0, by case analysis on the shape of term_of t0 *)
   709     and subc skel
   710           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   711        (case term_of t0 of
   712            Abs (a, T, t) =>   (* abstraction: simplify the body, then re-abstract with abstract_rule *)
   713              let val b = variant bounds a
   714                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   715                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   716                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   717              in case botc skel' mss' t' of
   718                   Some thm => Some (abstract_rule a v thm)
   719                 | None => None
   720              end
   721          | t $ _ => (case t of
   722              Const ("==>", _) $ _  =>   (* meta-implication: handled by impc *)
   723                let val (s, u) = Drule.dest_implies t0
   724                in impc (s, u, mss) end
   725            | Abs _ =>   (* abstraction applied to an argument: beta_conversion first, then subc on its rhs *)
   726                let val thm = beta_conversion false t0
   727                in case subc skel0 mss (rhs_of thm) of
   728                     None => Some thm
   729                   | Some thm' => Some (transitive thm thm')
   730                end
   731            | _  =>
                      (* appc: plain application -- simplify function part and argument part
                         separately (each with its share of the skeleton) and join the two
                         results with combination; None if neither part changed *)
   732                let fun appc () =
   733                      let
   734                        val (tskel, uskel) = case skel of
   735                            tskel $ uskel => (tskel, uskel)
   736                          | _ => (skel0, skel0);
   737                        val (ct, cu) = Thm.dest_comb t0
   738                      in
   739                      (case botc tskel mss ct of
   740                         Some thm1 =>
   741                           (case botc uskel mss cu of
   742                              Some thm2 => Some (combination thm1 thm2)
   743                            | None => Some (combination thm1 (reflexive cu)))
   744                       | None =>
   745                           (case botc uskel mss cu of
   746                              Some thm1 => Some (combination (reflexive ct) thm1)
   747                            | None => None))
   748                      end
   749                    val (h, ts) = strip_comb t
   750                in case h of
   751                     Const(a, _) =>   (* a congruence rule stored under the head constant name takes precedence *)
   752                       (case assoc_string (fst congs, a) of
   753                          None => appc ()
   754                        | Some cong =>
   755 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   756    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
                              (* the skeleton built below has a Var in every argument position of h,
                                 so botc skips the arguments themselves when revisiting cl *)
   757                           (let
   758                              val thm = congc (prover mss, sign, maxidx) cong t0;
   759                              val t = rhs_of thm;
   760                              val (cl, cr) = Thm.dest_comb t
   761                              val dVar = Var(("", 0), dummyT)
   762                              val skel =
   763                                list_comb (h, replicate (length ts) dVar)
   764                            in case botc skel mss cl of
   765                                 None => Some thm
   766                               | Some thm' => Some (transitive thm
   767                                   (combination thm' (reflexive cr)))
   768                            end handle TERM _ => error "congc result"
   769                                     | Pattern.MATCH => appc ()))   (* rule does not match: fall back to appc *)
   770                   | _ => appc ()
   771                end)
   772          | _ => None)
   773     (* impc: simplify an implication given as (prem, conc, mss); dispatches on the mutsimp flag *)
   774     and impc args =
   775       if mutsimp
   776       then let val (prem, conc, mss) = args
   777            in apsome snd (mut_impc ([], prem, conc, mss)) end
   778       else nonmut_impc args
   779     (* mut_impc: simplify prem with botc, continue with mut_impc1 on the result, and chain both equations.
               prems: the premises collected so far (simpconc extends the list by prem1).
               Result: None, or Some (int option, thm); the int option is produced in mut_impc1. *)
   780     and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
   781         None => mut_impc1 (prems, prem, conc, mss)
   782       | Some thm1 =>
   783           let val prem1 = rhs_of thm1
   784           in (case mut_impc1 (prems, prem1, conc, mss) of
   785               None => Some (None,
   786                 combination (combination refl_implies thm1) (reflexive conc))
   787             | Some (x, thm2) => Some (x, transitive (combination (combination
   788                 refl_implies thm1) (reflexive conc)) thm2))
   789           end)
   790     (* mut_impc1: continue with the already simplified premise prem1 *)
   791     and mut_impc1 (prems, prem1, conc, mss) =
   792       let
   793         fun uncond ({thm, lhs, elhs, perm}) =
   794           if Thm.no_prems thm then Some lhs else None
   795         (* lhss1: lhs of every rule from prem1 whose thm has no premises; mss1: mss plus prem1 and its rules *)
   796         val (lhss1, mss1) =
   797           if maxidx_of_term (term_of prem1) <> ~1
   798           then (trace_cterm true
   799             "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   800                 ([],mss))
   801           else let val thm = assume prem1
   802                    val rrules1 = extract_safe_rrules (mss, thm)
   803                    val lhss1 = mapfilter uncond rrules1
   804                    val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
   805                in (lhss1, mss1) end
   806         (* disch1: instantiate Drule.imp_cong with prem1 and the pair from dest_eq thm, then discharge with prem1 ==> thm *)
   807         fun disch1 thm =
   808           let val (cB', cC') = dest_eq thm
   809           in
   810             implies_elim (Thm.instantiate
   811               ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
   812               (implies_intr prem1 thm)
   813           end
   814         (* rebuild: argument is the optional equation for conc; try rewritec on the whole implication, restart mut_impc if it fires *)
   815         fun rebuild None = (case rewritec (prover, sign, maxidx) mss
   816             (mk_implies (prem1, conc)) of
   817               None => None
   818             | Some (thm, _) => 
   819                 let val (prem, conc) = Drule.dest_implies (rhs_of thm)
   820                 in (case mut_impc (prems, prem, conc, mss) of
   821                     None => Some (None, thm)
   822                   | Some (x, thm') => Some (x, transitive thm thm'))
   823                 end handle TERM _ => Some (None, thm))   (* result is no longer an implication *)
   824           | rebuild (Some thm2) =
   825             let val thm = disch1 thm2
   826             in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
   827                  None => Some (None, thm)
   828                | Some (thm', _) =>
   829                    let val (prem, conc) = Drule.dest_implies (rhs_of thm')
   830                    in (case mut_impc (prems, prem, conc, mss) of
   831                        None => Some (None, transitive thm thm')
   832                      | Some (x, thm'') =>
   833                          Some (x, transitive (transitive thm thm') thm''))
   834                    end handle TERM _ => Some (None, transitive thm thm'))
   835             end
   836         (* simpconc: simplify conc under mss1; a nested implication goes through mut_impc with prem1 appended to prems *)
   837         fun simpconc () =
   838           let val (s, t) = Drule.dest_implies conc
   839           in case mut_impc (prems @ [prem1], s, t, mss1) of
   840                None => rebuild None
   841              | Some (Some i, thm2) =>
                        (* an index came back from the nested call: prem1 and the first premise of
                           the result are exchanged via Drule.swap_prems_eq; with i = 0 the work is
                           redone by mut_impc1 under the original mss, otherwise i-1 is passed on.
                           NOTE(review): i looks like the number of premises between the reducible
                           one and the current one -- confirm *)
   842                   let
   843                     val (prem, cC') = Drule.dest_implies (rhs_of thm2);
   844                     val thm2' = transitive (disch1 thm2) (Thm.instantiate
   845                       ([], [(cA, prem1), (cB, prem), (cC, cC')])
   846                       Drule.swap_prems_eq)
   847                   in if i=0 then apsome (apsnd (transitive thm2'))
   848                        (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
   849                      else Some (Some (i-1), thm2')
   850                   end
   851              | Some (None, thm) => rebuild (Some thm)
   852           end handle TERM _ => rebuild (botc skel0 mss1 conc)   (* conc is not an implication *)
   853         (* if some element of prems has a subterm matching one of lhss1, report how many premises follow it and stop *)
   854       in
   855         let
   856           val tsig = Sign.tsig_of sign
   857           fun reducible t =
   858             exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
   859         in case dropwhile (not o reducible) prems of
   860             [] => simpconc ()
   861           | red::rest => (trace_cterm false "Can now reduce premise:" red;
   862               Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
   863         end
   864       end
   865      (* nonmut_impc: simplify prem if simprem, add it to mss if useprem (and free of unknowns), then simplify conc *)
   866      (* legacy code - only for backwards compatibility *)
   867      and nonmut_impc (prem, conc, mss) =
   868        let val thm1 = if simprem then botc skel0 mss prem else None;
   869            val prem1 = if_none (apsome rhs_of thm1) prem;
   870            val maxidx1 = maxidx_of_term (term_of prem1)
   871            val mss1 =
   872              if not useprem then mss else
   873              if maxidx1 <> ~1
   874              then (trace_cterm true
   875                "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   876                    mss)
   877              else let val thm = assume prem1
   878                   in add_safe_simp (add_prems (mss, [thm]), thm) end
   879        in (case botc skel0 mss1 conc of
   880            None => (case thm1 of
   881                None => None
   882              | Some thm1' => Some (combination
   883                  (combination refl_implies thm1') (reflexive conc)))
   884          | Some thm2 =>
   885            let
   886              val conc2 = rhs_of thm2;
   887              val thm2' = implies_elim (Thm.instantiate
   888                ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
   889                (implies_intr prem1 thm2)
   890            in (case thm1 of
   891                None => Some thm2'
   892              | Some thm1' => Some (transitive (combination
   893                  (combination refl_implies thm1') (reflexive conc)) thm2'))
   894            end)
   895        end
   896     (* only try_botc leaves the local scope *)
   897  in try_botc end;
   898 
   899 
   900 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   901 
   902 (*
   903   rewrite_cterm mode prover mss ct: rewrite ct and return the resulting equation theorem.
   904     mode = (simplify A,
   905             use A in simplifying B,
   906             use prems of B (if B is again a meta-impl.) to simplify A)
   907            when simplifying A ==> B
   908     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   909     prover: how to solve premises in conditional rewrites and congruences
   910 *)
   911 
   912 (* FIXME: check that #bounds(mss) does not "occur" in ct already *)
   913 
   914 fun rewrite_cterm mode prover mss ct =
   915   let val {sign, maxidx, ...} = rep_cterm ct   (*only these two fields of ct are used*)
   916       val Mss{depth, ...} = mss
   917   in simp_depth := depth;                      (*the global depth ref is set from the simpset*)
   918      bottomc (mode, prover, sign, maxidx) mss ct
   919   end
   920   handle THM (s, _, thms) =>                   (*THM failures are reported as a plain error*)
   921     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   922       Pretty.string_of (pretty_thms thms));
   923 
   924 (*Apply cv to the premises Ai of [A1,...,An]==>B selected by pred i; flex-flex pairs are neither rewritten nor counted*)
   925 fun goals_conv pred cv =
   926   let
   927     fun flexflex (Const ("=?=", _) $ _ $ _) = true | flexflex _ = false
   928     fun gconv i ct =
   929       let val (A, B) = Drule.dest_implies ct
   930           val skip = flexflex (term_of A)
   931           val thA = if not skip andalso pred i then cv A else reflexive A
   932       in combination (combination refl_implies thA) (gconv (if skip then i else i + 1) B) end
   933       handle TERM _ => reflexive ct   (*no further premise: leave the rest unchanged*)
   934   in gconv 1 end;
   935 
   936 (*Transform theorem th along the equation that conversion cv yields for its proposition*)
   937 fun fconv_rule cv th = let val eq = cv (cprop_of th) in equal_elim eq th end;
   938 
   939 (*Rewrite a cterm with thms and return the second component of Thm.dest_comb on the equation's prop; no rules: identity*)
   940 fun full_rewrite_cterm_aux _ [] = (fn ct => ct)
   941   | full_rewrite_cterm_aux prover thms =
   942       let val conv = rewrite_cterm (true, false, false) prover (mss_of thms)
   943       in fn ct => #2 (Thm.dest_comb (#prop (Thm.crep_thm (conv ct)))) end;
   944 
   945 (*Rewrite a theorem with the rules thms; an empty rule list gives the identity function*)
   946 fun rewrite_rule_aux prover thms =
   947   (case thms of [] => (fn th => th)
   948    | _ => fconv_rule (rewrite_cterm (true, false, false) prover (mss_of thms)));
   949 (*Rewrite theorem th with an explicitly given mode, prover and simpset*)
   950 fun rewrite_thm mode prover mss th = fconv_rule (rewrite_cterm mode prover mss) th;
   951 
   952 (*Rewrite every subgoal of the proof state th with the rules thms; no rules: th is returned as is*)
   953 fun rewrite_goals_rule_aux _ [] th = th
   954   | rewrite_goals_rule_aux prover thms th =
   955       let val conv = rewrite_cterm (true, true, false) prover (mss_of thms)
   956       in fconv_rule (goals_conv (fn _ => true) conv) th end;
   957 
   958 (*Rewrite only subgoal i of the proof state thm; raises THM unless 1 <= i <= nprems_of thm*)
   959 fun rewrite_goal_rule mode prover mss i thm =
   960   if i < 1 orelse nprems_of thm < i
   961   then raise THM ("rewrite_goal_rule", i, [thm])
   962   else fconv_rule (goals_conv (fn j => j = i) (rewrite_cterm mode prover mss)) thm;
   963 
   964 end;
   965 (*Only the BASIC_META_SIMPLIFIER view of MetaSimplifier (trace_simp, debug_simp) is opened at top level*)
   966 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
   967 open BasicMetaSimplifier;