src/Pure/meta_simplifier.ML
author nipkow
Thu Aug 23 14:32:48 2001 +0200 (2001-08-23)
changeset 11504 935f9e8de0d0
parent 11371 1d5d181b7e28
child 11505 a410fa8acfca
permissions -rw-r--r--
Traced depth of conditional rewriting
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta Simplification
     7 *)
     8 
     9 signature META_SIMPLIFIER =
    10 sig
    11   exception SIMPLIFIER of string * thm
    12   type meta_simpset
    13   val dest_mss		: meta_simpset ->
    14     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    15   val empty_mss         : meta_simpset
    16   val clear_mss		: meta_simpset -> meta_simpset
    17   val merge_mss		: meta_simpset * meta_simpset -> meta_simpset
    18   val add_simps         : meta_simpset * thm list -> meta_simpset
    19   val del_simps         : meta_simpset * thm list -> meta_simpset
    20   val mss_of            : thm list -> meta_simpset
    21   val add_congs         : meta_simpset * thm list -> meta_simpset
    22   val del_congs         : meta_simpset * thm list -> meta_simpset
    23   val add_simprocs	: meta_simpset *
    24     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    25       -> meta_simpset
    26   val del_simprocs	: meta_simpset *
    27     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    28       -> meta_simpset
    29   val add_prems         : meta_simpset * thm list -> meta_simpset
    30   val prems_of_mss      : meta_simpset -> thm list
    31   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    32   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    33   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    34   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    35   val trace_simp        : bool ref
    36   val debug_simp        : bool ref
    37   val rewrite_cterm     : bool * bool * bool
    38                           -> (meta_simpset -> thm -> thm option)
    39                           -> meta_simpset -> cterm -> thm
    40   val rewrite_rule_aux  : (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    41   val rewrite_thm       : bool * bool * bool
    42                           -> (meta_simpset -> thm -> thm option)
    43                           -> meta_simpset -> thm -> thm
    44   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    45   val rewrite_goal_rule : bool* bool * bool
    46                           -> (meta_simpset -> thm -> thm option)
    47                           -> meta_simpset -> int -> thm -> thm
    48 end;
    49 
    50 structure MetaSimplifier : META_SIMPLIFIER =
    51 struct
    52 
    53 (** diagnostics **)
    54 
    55 exception SIMPLIFIER of string * thm;
    56 
    57 fun prnt warn a = if warn then warning a else writeln a;
    58 
    59 fun prtm warn a sign t =
    60   (prnt warn a; prnt warn (Sign.string_of_term sign t));
    61 
    62 fun prctm warn a t =
    63   (prnt warn a; prnt warn (Display.string_of_cterm t));
    64 
    65 fun prthm warn a thm =
    66   let val {sign, prop, ...} = rep_thm thm
    67   in prtm warn a sign prop end;
    68 
    69 val trace_simp = ref false;
    70 val debug_simp = ref false;
    71 
    72 fun trace warn a = if !trace_simp then prnt warn a else ();
    73 fun debug warn a = if !debug_simp then prnt warn a else ();
    74 
    75 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    76 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    77 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    78 
    79 fun trace_thm warn a thm =
    80   let val {sign, prop, ...} = rep_thm thm
    81   in trace_term warn a sign prop end;
    82 
    83 
    84 
    85 (** meta simp sets **)
    86 
    87 (* basic components *)
    88 
    89 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
    90 (* thm: the rewrite rule
    91    lhs: the left-hand side
    92    elhs: the etac-contracted lhs.
    93    fo:  use first-order matching
    94    perm: the rewrite rule is permutative
    95 Reamrks:
    96   - elhs is used for matching,
    97     lhs only for preservation of bound variable names.
    98   - fo is set iff
    99     either elhs is first-order (no Var is applied),
   100            in which case fo-matching is complete,
   101     or elhs is not a pattern,
   102        in which case there is nothing better to do.
   103 *)
   104 type cong = {thm: thm, lhs: cterm};
   105 type simproc =
   106  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   107 
   108 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   109   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   110 
   111 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) = 
   112   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   113 
   114 fun eq_prem (thm1, thm2) =
   115   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   116 
   117 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   118 
   119 fun mk_simproc (name, proc, lhs, id) =
   120   {name = name, proc = proc, lhs = lhs, id = id};
   121 
   122 
   123 (* datatype mss *)
   124 
   125 (*
   126   A "mss" contains data needed during conversion:
   127     rules: discrimination net of rewrite rules;
   128     congs: association list of congruence rules and
   129            a list of `weak' congruence constants.
   130            A congruence is `weak' if it avoids normalization of some argument.
   131     procs: discrimination net of simplification procedures
   132       (functions that prove rewrite rules on the fly);
   133     bounds: names of bound variables already used
   134       (for generating new names when rewriting under lambda abstractions);
   135     prems: current premises;
   136     mk_rews: mk: turns simplification thms into rewrite rules;
   137              mk_sym: turns == around; (needs Drule!)
   138              mk_eq_True: turns P into P == True - logic specific;
   139     termless: relation for ordered rewriting;
   140     depth: depth of conditional rewriting;
   141 *)
   142 
   143 datatype meta_simpset =
   144   Mss of {
   145     rules: rrule Net.net,
   146     congs: (string * cong) list * string list,
   147     procs: simproc Net.net,
   148     bounds: string list,
   149     prems: thm list,
   150     mk_rews: {mk: thm -> thm list,
   151               mk_sym: thm -> thm option,
   152               mk_eq_True: thm -> thm option},
   153     termless: term * term -> bool,
   154     depth: int};
   155 
   156 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   157   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   158        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   159 
   160 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   161   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   162 
   163 val empty_mss =
   164   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   165   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   166 
   167 fun clear_mss (Mss {mk_rews, termless, ...}) =
   168   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   169 
   170 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   171   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   172  
   173 
   174 
   175 (** simpset operations **)
   176 
   177 (* term variables *)
   178 
   179 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   180 fun term_varnames t = add_term_varnames ([], t);
   181 
   182 
   183 (* dest_mss *)
   184 
   185 fun dest_mss (Mss {rules, congs, procs, ...}) =
   186   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   187    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   188    procs =
   189      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   190      |> partition_eq eq_snd
   191      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   192      |> Library.sort_wrt #1};
   193 
   194 
   195 (* merge_mss *)	      (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   196 
   197 fun merge_mss
   198  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   199        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   200   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   201        bounds = bounds2, prems = prems2, ...}) =
   202       mk_mss
   203        (Net.merge (rules1, rules2, eq_rrule),
   204         (generic_merge (eq_cong o pairself snd) I I congs1 congs2,
   205         merge_lists weak1 weak2),
   206         Net.merge (procs1, procs2, eq_simproc),
   207         merge_lists bounds1 bounds2,
   208         generic_merge eq_prem I I prems1 prems2,
   209         mk_rews, termless, depth);
   210 
   211 
   212 (* add_simps *)
   213 
   214 fun mk_rrule2{thm,lhs,elhs,perm} =
   215   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   216   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   217 
   218 fun insert_rrule(mss as Mss {rules,...},
   219                  rrule as {thm,lhs,elhs,perm}) =
   220   (trace_thm false "Adding rewrite rule:" thm;
   221    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   222        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   223    in upd_rules(mss,rules') end
   224    handle Net.INSERT =>
   225      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   226 
   227 fun vperm (Var _, Var _) = true
   228   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   229   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   230   | vperm (t, u) = (t = u);
   231 
   232 fun var_perm (t, u) =
   233   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   234 
   235 (* FIXME: it seems that the conditions on extra variables are too liberal if
   236 prems are nonempty: does solving the prems really guarantee instantiation of
   237 all its Vars? Better: a dynamic check each time a rule is applied.
   238 *)
   239 fun rewrite_rule_extra_vars prems elhs erhs =
   240   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   241   orelse
   242   not ((term_tvars erhs) subset
   243        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   244 
   245 (*Simple test for looping rewrite rules and stupid orientations*)
   246 fun reorient sign prems lhs rhs =
   247    rewrite_rule_extra_vars prems lhs rhs
   248   orelse
   249    is_Var (head_of lhs)
   250   orelse
   251    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   252   orelse
   253    (null prems andalso
   254     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   255     (*the condition "null prems" is necessary because conditional rewrites
   256       with extra variables in the conditions may terminate although
   257       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   258   orelse
   259    (is_Const lhs andalso not(is_Const rhs))
   260 
   261 fun decomp_simp thm =
   262   let val {sign, prop, ...} = rep_thm thm;
   263       val prems = Logic.strip_imp_prems prop;
   264       val concl = Drule.strip_imp_concl (cprop_of thm);
   265       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   266         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   267       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   268       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   269       val erhs = Pattern.eta_contract (term_of rhs);
   270       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   271                  andalso not (is_Var (term_of elhs))
   272   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   273 
   274 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
   275   case mk_eq_True thm of
   276     None => []
   277   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   278                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   279 
   280 (* create the rewrite rule and possibly also the ==True variant,
   281    in case there are extra vars on the rhs *)
   282 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
   283   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   284   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   285         (term_tvars rhs) subset (term_tvars lhs)
   286      then [rrule]
   287      else mk_eq_True mss thm2 @ [rrule]
   288   end;
   289 
   290 fun mk_rrule mss thm =
   291   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   292   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   293      (* weak test for loops: *)
   294      if rewrite_rule_extra_vars prems lhs rhs orelse
   295         is_Var (term_of elhs)
   296      then mk_eq_True mss thm
   297      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   298   end;
   299 
   300 fun orient_rrule mss thm =
   301   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   302   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   303      else if reorient sign prems lhs rhs
   304           then if reorient sign prems rhs lhs
   305                then mk_eq_True mss thm
   306                else let val Mss{mk_rews={mk_sym,...},...} = mss
   307                     in case mk_sym thm of
   308                          None => []
   309                        | Some thm' =>
   310                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   311                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   312                     end
   313           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   314   end;
   315 
   316 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
   317 
   318 fun orient_comb_simps comb mk_rrule (mss,thms) =
   319   let val rews = extract_rews(mss,thms)
   320       val rrules = flat (map mk_rrule rews)
   321   in foldl comb (mss,rrules) end
   322 
   323 (* Add rewrite rules explicitly; do not reorient! *)
   324 fun add_simps(mss,thms) =
   325   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   326 
   327 fun mss_of thms =
   328   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   329 
   330 fun extract_safe_rrules(mss,thm) =
   331   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   332 
   333 fun add_safe_simp(mss,thm) =
   334   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   335 
   336 (* del_simps *)
   337 
   338 fun del_rrule(mss as Mss {rules,...},
   339               rrule as {thm, elhs, ...}) =
   340   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   341    handle Net.DELETE =>
   342      (prthm true "Rewrite rule not in simpset:" thm; mss));
   343 
   344 fun del_simps(mss,thms) =
   345   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   346 
   347 
   348 (* add_congs *)
   349 
   350 fun is_full_cong_prems [] varpairs = null varpairs
   351   | is_full_cong_prems (p::prems) varpairs =
   352     (case Logic.strip_assums_concl p of
   353        Const("==",_) $ lhs $ rhs =>
   354          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   355          in is_Var x  andalso  forall is_Bound xs  andalso
   356             null(findrep(xs))  andalso xs=ys andalso
   357             (x,y) mem varpairs andalso
   358             is_full_cong_prems prems (varpairs\(x,y))
   359          end
   360      | _ => false);
   361 
   362 fun is_full_cong thm =
   363 let val prems = prems_of thm
   364     and concl = concl_of thm
   365     val (lhs,rhs) = Logic.dest_equals concl
   366     val (f,xs) = strip_comb lhs
   367     and (g,ys) = strip_comb rhs
   368 in
   369   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   370   is_full_cong_prems prems (xs ~~ ys)
   371 end
   372 
   373 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   374   let
   375     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   376       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   377 (*   val lhs = Pattern.eta_contract lhs; *)
   378     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   379       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   380     val (alist,weak) = congs
   381     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   382            ("Overwriting congruence rule for " ^ quote a);
   383     val weak2 = if is_full_cong thm then weak else a::weak
   384   in
   385     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   386   end;
   387 
   388 val (op add_congs) = foldl add_cong;
   389 
   390 
   391 (* del_congs *)
   392 
   393 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   394   let
   395     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   396       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   397 (*   val lhs = Pattern.eta_contract lhs; *)
   398     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   399       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   400     val (alist,_) = congs
   401     val alist2 = filter (fn (x,_)=> x<>a) alist
   402     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   403                                               else Some a)
   404                    alist2
   405   in
   406     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   407   end;
   408 
   409 val (op del_congs) = foldl del_cong;
   410 
   411 
   412 (* add_simprocs *)
   413 
   414 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   415     (name, lhs, proc, id)) =
   416   let val {sign, t, ...} = rep_cterm lhs
   417   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   418       sign t;
   419     mk_mss (rules, congs,
   420       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   421         handle Net.INSERT => 
   422 	    (warning ("Ignoring duplicate simplification procedure \"" 
   423 	              ^ name ^ "\""); 
   424 	     procs),
   425         bounds, prems, mk_rews, termless,depth))
   426   end;
   427 
   428 fun add_simproc (mss, (name, lhss, proc, id)) =
   429   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   430 
   431 val add_simprocs = foldl add_simproc;
   432 
   433 
   434 (* del_simprocs *)
   435 
   436 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   437     (name, lhs, proc, id)) =
   438   mk_mss (rules, congs,
   439     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   440       handle Net.DELETE => 
   441 	  (warning ("Simplification procedure \"" ^ name ^
   442 		       "\" not in simpset"); procs),
   443       bounds, prems, mk_rews, termless, depth);
   444 
   445 fun del_simproc (mss, (name, lhss, proc, id)) =
   446   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   447 
   448 val del_simprocs = foldl del_simproc;
   449 
   450 
   451 (* prems *)
   452 
   453 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   454   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   455 
   456 fun prems_of_mss (Mss {prems, ...}) = prems;
   457 
   458 
   459 (* mk_rews *)
   460 
   461 fun set_mk_rews
   462   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   463     mk_mss (rules, congs, procs, bounds, prems,
   464             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   465             termless, depth);
   466 
   467 fun set_mk_sym
   468   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   469     mk_mss (rules, congs, procs, bounds, prems,
   470             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   471             termless,depth);
   472 
   473 fun set_mk_eq_True
   474   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   475     mk_mss (rules, congs, procs, bounds, prems,
   476             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   477             termless,depth);
   478 
   479 (* termless *)
   480 
   481 fun set_termless
   482   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   483     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   484 
   485 
   486 
   487 (** rewriting **)
   488 
   489 (*
   490   Uses conversions, see:
   491     L C Paulson, A higher-order implementation of rewriting,
   492     Science of Computer Programming 3 (1983), pages 119-149.
   493 *)
   494 
   495 type prover = meta_simpset -> thm -> thm option;
   496 type termrec = (Sign.sg_ref * term list) * term;
   497 type conv = meta_simpset -> termrec -> termrec;
   498 
   499 val dest_eq = Drule.dest_equals o cprop_of;
   500 val lhs_of = fst o dest_eq;
   501 val rhs_of = snd o dest_eq;
   502 
   503 fun beta_eta_conversion t =
   504   let val thm = beta_conversion true t;
   505   in transitive thm (eta_conversion (rhs_of thm)) end;
   506 
   507 fun check_conv msg thm thm' =
   508   let
   509     val thm'' = transitive thm (transitive
   510       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   511   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   512   handle THM _ =>
   513     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   514     in
   515       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   516        trace_term false "Should have proved:" sign prop0;
   517        None)
   518     end;
   519 
   520 
   521 (* mk_procrule *)
   522 
   523 fun mk_procrule thm =
   524   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   525   in if rewrite_rule_extra_vars prems lhs rhs
   526      then (prthm true "Extra vars on rhs:" thm; [])
   527      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
   528   end;
   529 
   530 
   531 (* conversion to apply the meta simpset to a term *)
   532 
   533 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   534    normalized terms by carrying around the rhs of the rewrite rule just
   535    applied. This is called the `skeleton'. It is decomposed in parallel
   536    with the term. Once a Var is encountered, the corresponding term is
   537    already in normal form.
   538    skel0 is a dummy skeleton that is to enforce complete normalization.
   539 *)
   540 val skel0 = Bound 0;
   541 
   542 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   543    The latter may happen iff there are weak congruence rules for constants
   544    in the lhs.
   545 *)
   546 fun uncond_skel((_,weak),(lhs,rhs)) =
   547   if null weak then rhs (* optimization *)
   548   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   549        else rhs;
   550 
   551 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   552    Otherwise those vars may become instantiated with unnormalized terms
   553    while the premises are solved.
   554 *)
   555 fun cond_skel(args as (congs,(lhs,rhs))) =
   556   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   557   else skel0;
   558 
   559 (*
   560   we try in order:
   561     (1) beta reduction
   562     (2) unconditional rewrite rules
   563     (3) conditional rewrite rules
   564     (4) simplification procedures
   565 
   566   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   567 
   568 *)
   569 
   570 fun rewritec (prover, signt, maxt)
   571              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   572   let
   573     val eta_thm = Thm.eta_conversion t;
   574     val eta_t' = rhs_of eta_thm;
   575     val eta_t = term_of eta_t';
   576     val tsigt = Sign.tsig_of signt;
   577     fun rew {thm, lhs, elhs, fo, perm} =
   578       let
   579         val {sign, prop, maxidx, ...} = rep_thm thm;
   580         val _ = if Sign.subsig (sign, signt) then ()
   581                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   582                       raise Pattern.MATCH);
   583         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   584           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   585         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   586                           else Thm.cterm_match (elhs', eta_t');
   587         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   588         val prop' = #prop (rep_thm thm');
   589         val unconditional = (Logic.count_prems (prop',0) = 0);
   590         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   591       in
   592         if perm andalso not (termless (rhs', lhs'))
   593         then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
   594               trace_thm false "Term does not become smaller:" thm'; None)
   595         else
   596           let val ds = "[" ^ string_of_int depth ^ "]"
   597           in trace_thm false "Applying instance of rewrite rule:" thm;
   598            if unconditional
   599            then
   600              (trace_thm false (ds ^ "Rewriting:") thm';
   601               let val lr = Logic.dest_equals prop;
   602                   val Some thm'' = check_conv false eta_thm thm'
   603               in Some (thm'', uncond_skel (congs, lr)) end)
   604            else
   605              (trace_thm false (ds ^ "Trying to rewrite:") thm';
   606               case prover (incr_depth mss) thm' of
   607                 None       => (trace_thm false (ds ^ "FAILED") thm'; None)
   608               | Some thm2 =>
   609                   (case check_conv true eta_thm thm2 of
   610                      None => None |
   611                      Some thm2' =>
   612                        let val concl = Logic.strip_imp_concl prop
   613                            val lr = Logic.dest_equals concl
   614                        in Some (thm2', cond_skel (congs, lr)) end))
   615           end
   616       end
   617 
   618     fun rews [] = None
   619       | rews (rrule :: rrules) =
   620           let val opt = rew rrule handle Pattern.MATCH => None
   621           in case opt of None => rews rrules | some => some end;
   622 
   623     fun sort_rrules rrs = let
   624       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of 
   625                                       Const("==",_) $ _ $ _ => true
   626                                       | _                   => false 
   627       fun sort []        (re1,re2) = re1 @ re2
   628         | sort (rr::rrs) (re1,re2) = if is_simple rr 
   629                                      then sort rrs (rr::re1,re2)
   630                                      else sort rrs (re1,rr::re2)
   631     in sort rrs ([],[]) end
   632 
   633     fun proc_rews ([]:simproc list) = None
   634       | proc_rews ({name, proc, lhs, ...} :: ps) =
   635           if Pattern.matches tsigt (term_of lhs, term_of t) then
   636             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   637              case proc signt prems eta_t of
   638                None => (debug false "FAILED"; proc_rews ps)
   639              | Some raw_thm =>
   640                  (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   641                   (case rews (mk_procrule raw_thm) of
   642                     None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
   643                   | some => some)))
   644           else proc_rews ps;
   645   in case eta_t of
   646        Abs _ $ _ => Some (transitive eta_thm
   647          (beta_conversion false (rhs_of eta_thm)), skel0)
   648      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   649                None => proc_rews (Net.match_term procs eta_t)
   650              | some => some)
   651   end;
   652 
   653 
   654 (* conversion to apply a congruence rule to a term *)
   655 
   656 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   657   let val {sign, ...} = rep_thm cong
   658       val _ = if Sign.subsig (sign, signt) then ()
   659                  else error("Congruence rule from different theory")
   660       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   661       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   662       val insts = Thm.cterm_match (rlhs, t)
   663       (* Pattern.match can raise Pattern.MATCH;
   664          is handled when congc is called *)
   665       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   666       val unit = trace_thm false "Applying congruence rule:" thm';
   667       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   668   in case prover thm' of
   669        None => err ("Could not prove", thm')
   670      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   671           None => err ("Should not have proved", thm2)
   672         | Some thm2' => thm2')
   673   end;
   674 
   675 val (cA, (cB, cC)) =
   676   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   677 
   678 fun transitive' thm1 None = Some thm1
   679   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
   680
       (* bottomc: the bottom-up simplification driver.
          Arguments: the mode triple (simprem, useprem, mutsimp) (see the
          "Parameters" comment above rewrite_cterm), the prover used for
          conditional rewrites and congruences, and sign / maxidx, which are
          forwarded to rewritec and congc.
          Result: try_botc : meta_simpset -> cterm -> thm, which produces an
          equation t == t' (reflexive t when nothing was rewritten).
          All inner functions are mutually recursive; the ones returning an
          option use None to mean "no change". *)
   681 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   682   let
           (* botc skel mss t: simplify t bottom-up.  First the subterms are
              simplified (subc).  Then a rewrite at the root is tried (rewritec),
              and if it fires the result is simplified again, guided by the
              skeleton skel2 that rewritec returns.  A Var skeleton makes botc
              return None without inspecting t; presumably it marks a position
              left in normal form by an earlier rewrite step -- TODO confirm
              against rewritec. *)
   683     fun botc skel mss t =
   684           if is_Var skel then None
   685           else
   686           (case subc skel mss t of
   687              some as Some thm1 =>
   688                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   689                   Some (thm2, skel2) =>
   690                     transitive' (transitive thm1 thm2)
   691                       (botc skel2 mss (rhs_of thm2))
   692                 | None => some)
   693            | None =>
   694                (case rewritec (prover, sign, maxidx) mss t of
   695                   Some (thm2, skel2) => transitive' thm2
   696                     (botc skel2 mss (rhs_of thm2))
   697                 | None => None))
   698
           (* try_botc: total wrapper around botc, started with the neutral
              skeleton skel0; it yields reflexive t when botc reports no change. *)
   699     and try_botc mss t =
   700           (case botc skel0 mss t of
   701              Some trec1 => trec1 | None => (reflexive t))
   702
           (* subc skel mss t0: simplify the immediate subterms of t0.
              - Abs: descend under the binder.  A fresh name b (a variant
                avoiding the simpset's bounds) is pushed onto bounds, and the
                resulting equation is re-abstracted.
              - A ==> B: delegated to impc.
              - (Abs ...) $ _: beta-reduce first, then simplify the result.
              - any other application: if the head is a constant with a
                congruence rule in fst congs, congc is used.  Otherwise appc
                simplifies function and argument separately, splitting skel
                in parallel.
              - any other term: None (nothing to descend into). *)
   703     and subc skel
   704           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   705        (case term_of t0 of
   706            Abs (a, T, t) =>
   707              let val b = variant bounds a
   708                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   709                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   710                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   711              in case botc skel' mss' t' of
   712                   Some thm => Some (abstract_rule a v thm)
   713                 | None => None
   714              end
   715          | t $ _ => (case t of
   716              Const ("==>", _) $ _  =>
   717                let val (s, u) = Drule.dest_implies t0
   718                in impc (s, u, mss) end
   719            | Abs _ =>
   720                let val thm = beta_conversion false t0
   721                in case subc skel0 mss (rhs_of thm) of
   722                     None => Some thm
   723                   | Some thm' => Some (transitive thm thm')
   724                end
   725            | _  =>
                       (* appc: congruence-free case; rewrite function part and
                          argument independently and combine the two equations. *)
   726                let fun appc () =
   727                      let
   728                        val (tskel, uskel) = case skel of
   729                            tskel $ uskel => (tskel, uskel)
   730                          | _ => (skel0, skel0);
   731                        val (ct, cu) = Thm.dest_comb t0
   732                      in
   733                      (case botc tskel mss ct of
   734                         Some thm1 =>
   735                           (case botc uskel mss cu of
   736                              Some thm2 => Some (combination thm1 thm2)
   737                            | None => Some (combination thm1 (reflexive cu)))
   738                       | None =>
   739                           (case botc uskel mss cu of
   740                              Some thm1 => Some (combination (reflexive ct) thm1)
   741                            | None => None))
   742                      end
   743                    val (h, ts) = strip_comb t
   744                in case h of
   745                     Const(a, _) =>
   746                       (case assoc_string (fst congs, a) of
   747                          None => appc ()
   748                        | Some cong =>
   749 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   750    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   751                           (let
   752                              val thm = congc (prover mss, sign, maxidx) cong t0;
   753                              val t = rhs_of thm;
   754                              val (cl, cr) = Thm.dest_comb t
   755                              val dVar = Var(("", 0), dummyT)
   756                              val skel =
   757                                list_comb (h, replicate (length ts) dVar)
   758                            in case botc skel mss cl of
   759                                 None => Some thm
   760                               | Some thm' => Some (transitive thm
   761                                   (combination thm' (reflexive cr)))
                                  (* TERM: the result of congc did not have the
                                     expected application shape.  Pattern.MATCH
                                     (presumably: the congruence rule does not
                                     match t0) falls back to plain appc. *)
   762                            end handle TERM _ => error "congc result"
   763                                     | Pattern.MATCH => appc ()))
   764                   | _ => appc ()
   765                end)
   766          | _ => None)
   767
           (* impc (prem, conc, mss): simplify the implication prem ==> conc.
              With mutsimp the mutual premise simplification mut_impc is used,
              starting with no earlier premises, and the counter component of
              its result is dropped.  Otherwise the legacy nonmut_impc is used. *)
   768     and impc args =
   769       if mutsimp
   770       then let val (prem, conc, mss) = args
   771            in apsome snd (mut_impc ([], prem, conc, mss)) end
   772       else nonmut_impc args
   773
           (* mut_impc (prems, prem, conc, mss): simplify prem to prem1 with botc and
              continue with mut_impc1.  If prem changed, the equation
              (prem ==> conc) == (prem1 ==> conc), built from thm1 with
              refl_implies, is prefixed to the result.  prems are the premises
              already processed to the left of prem.
              Result: None if nothing changed, otherwise Some (counter, eqn); the
              int option counter is explained at mut_impc1. *)
   774     and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
   775         None => mut_impc1 (prems, prem, conc, mss)
   776       | Some thm1 =>
   777           let val prem1 = rhs_of thm1
   778           in (case mut_impc1 (prems, prem1, conc, mss) of
   779               None => Some (None,
   780                 combination (combination refl_implies thm1) (reflexive conc))
   781             | Some (x, thm2) => Some (x, transitive (combination (combination
   782                 refl_implies thm1) (reflexive conc)) thm2))
   783           end)
   784
           (* mut_impc1 (prems, prem1, conc, mss): prem1 is an already simplified
              premise.
              - Rules: if prem1 contains no (type) unknowns (maxidx = ~1), it is
                assumed and its safe rewrite rules are added to mss, giving mss1.
                lhss1 are the left-hand sides of the unconditional ones.
              - Back-jump: if an earlier premise in prems is reducible by lhss1,
                the result is Some (Some n, reflexive (prem1 ==> conc)), where
                n is the number of premises after the first reducible one.
                Enclosing calls decrement n on the way up; at 0 simpconc swaps
                the premises and re-simplifies.
              - Otherwise the conclusion is processed by simpconc. *)
   785     and mut_impc1 (prems, prem1, conc, mss) =
   786       let
   787         fun uncond ({thm, lhs, elhs, perm}) =
   788           if Thm.no_prems thm then Some lhs else None
   789
   790         val (lhss1, mss1) =
   791           if maxidx_of_term (term_of prem1) <> ~1
   792           then (trace_cterm true
   793             "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   794                 ([],mss))
   795           else let val thm = assume prem1
   796                    val rrules1 = extract_safe_rrules (mss, thm)
   797                    val lhss1 = mapfilter uncond rrules1
   798                    val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
   799                in (lhss1, mss1) end
   800
               (* disch1: thm is an equation B' == C' that may use the assumption
                  prem1.  Instantiating Drule.imp_cong with A := prem1 discharges it
                  into an equation between prem1 ==> B' and prem1 ==> C'
                  (presumably imp_cong is (A ==> B == C) ==> (A ==> B) == (A ==> C)
                  -- confirm in drule.ML). *)
   801         fun disch1 thm =
   802           let val (cB', cC') = dest_eq thm
   803           in
   804             implies_elim (Thm.instantiate
   805               ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
   806               (implies_intr prem1 thm)
   807           end
   808
               (* rebuild: after the conclusion has been handled, try to rewrite the
                  whole implication at the root with rewritec.
                  - rebuild None: the conclusion was unchanged.
                  - rebuild (Some thm2): thm2 rewrote the conclusion under prem1
                    and is discharged with disch1 first.
                  If the rewritten term is again an implication it is fed back
                  into mut_impc with the original prems and mss.  The TERM
                  handlers cover a rewritten term that is not an implication. *)
   809         fun rebuild None = (case rewritec (prover, sign, maxidx) mss
   810             (mk_implies (prem1, conc)) of
   811               None => None
   812             | Some (thm, _) => 
   813                 let val (prem, conc) = Drule.dest_implies (rhs_of thm)
   814                 in (case mut_impc (prems, prem, conc, mss) of
   815                     None => Some (None, thm)
   816                   | Some (x, thm') => Some (x, transitive thm thm'))
   817                 end handle TERM _ => Some (None, thm))
   818           | rebuild (Some thm2) =
   819             let val thm = disch1 thm2
   820             in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
   821                  None => Some (None, thm)
   822                | Some (thm', _) =>
   823                    let val (prem, conc) = Drule.dest_implies (rhs_of thm')
   824                    in (case mut_impc (prems, prem, conc, mss) of
   825                        None => Some (None, transitive thm thm')
   826                      | Some (x, thm'') =>
   827                          Some (x, transitive (transitive thm thm') thm''))
   828                    end handle TERM _ => Some (None, transitive thm thm'))
   829             end
   830
               (* simpconc: if conc is itself an implication s ==> t, process s
                  recursively, with prem1 appended to prems and mss1 (which holds
                  prem1's rules).  Possible results of that call:
                  - None: rebuild None.
                  - Some (Some 0, _): prem1 is reducible by the rules of the next
                    premise.  The two premises are swapped (Drule.swap_prems_eq)
                    and the moved premise is processed again by mut_impc1 with
                    the original mss.
                  - Some (Some i, _) with i > 0: i - 1 is passed up.
                  - Some (None, thm): rebuild (Some thm).
                  If conc is not an implication, dest_implies raises TERM and conc
                  is simplified with botc under mss1 instead.
                  NOTE(review): the handler is attached to the whole let, so a
                  TERM raised inside the recursive calls is caught here too. *)
   831         fun simpconc () =
   832           let val (s, t) = Drule.dest_implies conc
   833           in case mut_impc (prems @ [prem1], s, t, mss1) of
   834                None => rebuild None
   835              | Some (Some i, thm2) =>
   836                   let
   837                     val (prem, cC') = Drule.dest_implies (rhs_of thm2);
   838                     val thm2' = transitive (disch1 thm2) (Thm.instantiate
   839                       ([], [(cA, prem1), (cB, prem), (cC, cC')])
   840                       Drule.swap_prems_eq)
   841                   in if i=0 then apsome (apsnd (transitive thm2'))
   842                        (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
   843                      else Some (Some (i-1), thm2')
   844                   end
   845              | Some (None, thm) => rebuild (Some thm)
   846           end handle TERM _ => rebuild (botc skel0 mss1 conc)
   847
   848       in
               (* Back-jump test: find the first earlier premise that an
                  unconditional rule of prem1 can rewrite somewhere inside. *)
   849         let
   850           val tsig = Sign.tsig_of sign
   851           fun reducible t =
   852             exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
   853         in case dropwhile (not o reducible) prems of
   854             [] => simpconc ()
   855           | red::rest => (trace_cterm false "Can now reduce premise:" red;
   856               Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
   857         end
   858       end
   859
   860      (* legacy code - only for backwards compatibility *)
            (* nonmut_impc (prem, conc, mss): simplify prem only if simprem holds.
               Use the (possibly simplified) premise prem1 as a rewrite rule only
               if useprem holds and prem1 has no (type) unknowns.  Then simplify
               conc and combine both equations via refl_implies and imp_cong. *)
   861      and nonmut_impc (prem, conc, mss) =
   862        let val thm1 = if simprem then botc skel0 mss prem else None;
   863            val prem1 = if_none (apsome rhs_of thm1) prem;
   864            val maxidx1 = maxidx_of_term (term_of prem1)
   865            val mss1 =
   866              if not useprem then mss else
   867              if maxidx1 <> ~1
   868              then (trace_cterm true
   869                "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   870                    mss)
   871              else let val thm = assume prem1
   872                   in add_safe_simp (add_prems (mss, [thm]), thm) end
   873        in (case botc skel0 mss1 conc of
   874            None => (case thm1 of
   875                None => None
   876              | Some thm1' => Some (combination
   877                  (combination refl_implies thm1') (reflexive conc)))
   878          | Some thm2 =>
   879            let
   880              val conc2 = rhs_of thm2;
   881              val thm2' = implies_elim (Thm.instantiate
   882                ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
   883                (implies_intr prem1 thm2)
   884            in (case thm1 of
   885                None => Some thm2'
   886              | Some thm1' => Some (transitive (combination
   887                  (combination refl_implies thm1') (reflexive conc)) thm2'))
   888            end)
   889        end
   890
   891  in try_botc end;
   892 
   893 
   894 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   895 
   896 (*
   897   Parameters:
   898     mode = (simplify A,
   899             use A in simplifying B,
   900             use prems of B (if B is again a meta-impl.) to simplify A)
   901            when simplifying A ==> B
   902     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   903     prover: how to solve premises in conditional rewrites and congruences
   904 *)
   905 
   906 (* FIXME: check that #bounds(mss) does not "occur" in ct already *)
   907 
   908 (* rewrite_cterm mode prover mss ct: simplify ct with the simpset mss and
          return the equation ct == ct' (reflexive when nothing changes).  mode
          and prover are as described in the "Parameters" comment above.  Only
          the signature and maxidx of ct are needed to set up bottomc; the
          previously bound but unused term component is no longer extracted.
          A THM exception escaping the simplifier is turned into an error that
          reports the message and the offending theorems. *)
       fun rewrite_cterm mode prover mss ct =
         let val {sign, maxidx, ...} = rep_cterm ct
         in bottomc (mode, prover, sign, maxidx) mss ct end
         handle THM (s, _, thms) =>
           error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
             Pretty.string_of (pretty_thms thms));
   914 
   915 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   916 (*Do not rewrite flex-flex pairs (t =?= u); they are also not counted when numbering premises*)
   917 (* goals_conv pred cv: conversion on A1 ==> ... ==> An ==> B.  It applies cv
          to every premise whose index satisfies pred and leaves all other
          parts unchanged.  Indices start at 1; flex-flex pairs (t =?= u) stay
          untouched and do not advance the index.  As soon as the remaining
          term is not an implication (dest_implies raises TERM), it is returned
          via reflexive.
          NOTE(review): the handler also catches a TERM raised by cv itself,
          which then silently leaves the rest of the term unchanged. *)
       fun goals_conv pred cv =
         let
           fun is_flexflex (Const ("=?=", _) $ _ $ _) = true
             | is_flexflex _ = false
           fun gconv i ct =
             (let
                val (A, B) = Drule.dest_implies ct
                val (thA, next) =
                  if is_flexflex (term_of A) then (reflexive A, i)
                  else (if pred i then cv A else reflexive A, i + 1)
              in combination (combination refl_implies thA) (gconv next B) end)
             handle TERM _ => reflexive ct
         in gconv 1 end;
   926 
   927 (*Use a conversion to transform a theorem*)
   928 (* fconv_rule cv th: convert the proposition of th with cv and use the
          resulting equation to turn th into a theorem of the new proposition. *)
       fun fconv_rule cv th =
         let val eqn = cv (cprop_of th)
         in equal_elim eqn th end;
   929 
   930 (*Rewrite a theorem*)
   931 (* rewrite_rule_aux prover thms: rewrite a theorem with the equations thms
          using mode (true, false, false).  With no equations the result is
          the identity function. *)
       fun rewrite_rule_aux prover thms =
         if null thms then (fn th => th)
         else fconv_rule (rewrite_cterm (true, false, false) prover (mss_of thms));
   934 
   935 (* rewrite_thm mode prover mss thm: rewrite the proposition of thm with mss
          and return the correspondingly transformed theorem. *)
       fun rewrite_thm mode prover mss thm =
         fconv_rule (rewrite_cterm mode prover mss) thm;
   936 
   937 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   938 (* rewrite_goals_rule_aux prover thms th: rewrite every premise of th with the
          equations thms, using mode (true, true, false).  th is returned
          unchanged when thms is empty. *)
       fun rewrite_goals_rule_aux prover thms th =
         if null thms then th
         else
           let val cv = rewrite_cterm (true, true, false) prover (mss_of thms)
           in fconv_rule (goals_conv (K true) cv) th end;
   942 
   943 (*Rewrite subgoal i of a proof state (represented by a theorem) *)
   944 (* rewrite_goal_rule mode prover mss i thm: rewrite only premise i of thm,
          with i as numbered by goals_conv.  Raises THM when i is not within
          1 .. nprems_of thm. *)
       fun rewrite_goal_rule mode prover mss i thm =
         if i < 1 orelse i > nprems_of thm
         then raise THM ("rewrite_goal_rule", i, [thm])
         else fconv_rule (goals_conv (fn j => j = i) (rewrite_cterm mode prover mss)) thm;
   948 
   949 end;
   950 
   951 open MetaSimplifier;  (* make the MetaSimplifier bindings available unqualified *)