src/Pure/meta_simplifier.ML
author berghofe
Fri Oct 12 16:57:07 2001 +0200 (2001-10-12)
changeset 11736 da6fc37ed6fa
parent 11689 38788d98504f
child 11737 0ec18d3131b5
permissions -rw-r--r--
- Exported goals_conv and fconv_rule
- Added forall_conv
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
         (* Global switches for diagnostic output of the simplifier.
            trace_simp enables the trace messages (rule application, results);
            debug_simp enables extra debug messages such as simproc attempts. *)
    11   val trace_simp: bool ref
    12   val debug_simp: bool ref
    13 end;
    14 
    15 signature META_SIMPLIFIER =
    16 sig
    17   include BASIC_META_SIMPLIFIER
         (* Raised with a message and the offending theorem, e.g. when a rewrite
            rule or a congruence rule is not a meta-equality. *)
    18   exception SIMPLIFIER of string * thm
         (* Abstract meta-level simpset: rewrite rules, congruences, simprocs,
            premises and the rule-construction functions. *)
    19   type meta_simpset
    20   val dest_mss		: meta_simpset ->
    21     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    22   val empty_mss         : meta_simpset
    23   val clear_mss		: meta_simpset -> meta_simpset
    24   val merge_mss		: meta_simpset * meta_simpset -> meta_simpset
    25   val add_simps         : meta_simpset * thm list -> meta_simpset
    26   val del_simps         : meta_simpset * thm list -> meta_simpset
    27   val mss_of            : thm list -> meta_simpset
    28   val add_congs         : meta_simpset * thm list -> meta_simpset
    29   val del_congs         : meta_simpset * thm list -> meta_simpset
    30   val add_simprocs	: meta_simpset *
    31     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    32       -> meta_simpset
    33   val del_simprocs	: meta_simpset *
    34     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    35       -> meta_simpset
    36   val add_prems         : meta_simpset * thm list -> meta_simpset
    37   val prems_of_mss      : meta_simpset -> thm list
    38   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    39   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    40   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
         (* Rewriting entry points. The bool triple presumably corresponds to the
            (simprem, useprem, mutsimp) flags taken by bottomc -- TODO confirm
            against the definitions later in the file. *)
    42   val rewrite_cterm: bool * bool * bool ->
    43     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
         (* Conversion combinators over cterms and theorems. *)
    44   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    45   val forall_conv       : (cterm -> thm) -> cterm -> thm
    46   val fconv_rule        : (cterm -> thm) -> thm -> thm
    47   val full_rewrite_cterm_aux: (meta_simpset -> thm -> thm option) -> thm list -> cterm -> cterm
    48   val rewrite_rule_aux  : (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    49   val rewrite_thm       : bool * bool * bool
    50                           -> (meta_simpset -> thm -> thm option)
    51                           -> meta_simpset -> thm -> thm
    52   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    53   val rewrite_goal_rule : bool* bool * bool
    54                           -> (meta_simpset -> thm -> thm option)
    55                           -> meta_simpset -> int -> thm -> thm
    56 end;
    57 
    58 structure MetaSimplifier : META_SIMPLIFIER =
    59 struct
    60 
    61 (** diagnostics **)
    62 
    exception SIMPLIFIER of string * thm;

    (* Current nesting depth of the simplifier; diagnostic lines are indented
       by one space per level. *)
    val simp_depth = ref 0;

    (* Write msg as a plain line, indented according to simp_depth. *)
    fun println msg =
      let val indent = replicate_string (!simp_depth) " "
      in writeln (indent ^ msg) end;

    (* Emit msg as a warning (first argument true) or as an indented line. *)
    fun prnt true msg = warning msg
      | prnt false msg = println msg;

    (* Emit a heading followed by a term printed in signature sign. *)
    fun prtm warn a sign t =
      let val emit = prnt warn
      in emit a; emit (Sign.string_of_term sign t) end;

    (* Emit a heading followed by a certified term. *)
    fun prctm warn a t =
      let val emit = prnt warn
      in emit a; emit (Display.string_of_cterm t) end;

    (* Emit a heading followed by the proposition of a theorem. *)
    fun prthm warn a thm =
      (case rep_thm thm of
         {sign, prop, ...} => prtm warn a sign prop);
    80 
    val trace_simp = ref false;
    val debug_simp = ref false;

    local
      (* Perform the printing action only if the given switch is set. *)
      fun when_set flag action = if !flag then action () else ();
    in
      (* Printers guarded by trace_simp or debug_simp respectively. *)
      fun trace warn a = when_set trace_simp (fn () => prnt warn a);
      fun debug warn a = when_set debug_simp (fn () => prnt warn a);

      fun trace_term warn a sign t = when_set trace_simp (fn () => prtm warn a sign t);
      fun trace_cterm warn a t = when_set trace_simp (fn () => prctm warn a t);
      fun debug_term warn a sign t = when_set debug_simp (fn () => prtm warn a sign t);

      (* Trace the proposition of a theorem (only if trace_simp is set). *)
      fun trace_thm warn a thm =
        let val {sign, prop, ...} = rep_thm thm
        in trace_term warn a sign prop end;
    end;
    94 
    95 
    96 
    97 (** meta simp sets **)
    98 
    99 (* basic components *)
   100 
   101 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
   102 (* thm: the rewrite rule
   103    lhs: the left-hand side
   104    elhs: the eta-contracted lhs.
   105    fo:  use first-order matching
   106    perm: the rewrite rule is permutative
   107 Remarks:
   108   - elhs is used for matching,
   109     lhs only for preservation of bound variable names.
   110   - fo is set iff
   111     either elhs is first-order (no Var is applied),
   112            in which case fo-matching is complete,
   113     or elhs is not a pattern,
   114        in which case there is nothing better to do.
   115 *)
       (* A congruence rule together with the lhs of its conclusion. *)
   116 type cong = {thm: thm, lhs: cterm};
       (* A simplification procedure. proc is called with the signature, the
          current premises and the (eta-contracted) term and may return a
          rewrite theorem. lhs is the pattern selecting candidate terms. id is
          the stamp by which simprocs are identified (see eq_simproc). *)
   117 type simproc =
   118  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   119 
   local
     (* Theorems are identified when their propositions are alpha-convertible. *)
     fun same_prop (th1, th2) = #prop (rep_thm th1) aconv #prop (rep_thm th2);
   in
     fun eq_rrule (r1: rrule, r2: rrule) = same_prop (#thm r1, #thm r2);
     fun eq_cong (c1: cong, c2: cong) = same_prop (#thm c1, #thm c2);
     fun eq_prem (thm1, thm2) = same_prop (thm1, thm2);
   end;

   (* Simprocs are identified by their stamp only, not by name or lhs. *)
   fun eq_simproc (p1: simproc, p2: simproc) = (#id p1 = #id p2);

   (* Package the components of a simproc as a record. *)
   fun mk_simproc (name, proc, lhs, id) =
     {name = name, proc = proc, lhs = lhs, id = id};
   133 
   134 
   135 (* datatype mss *)
   136 
   137 (*
   138   A "mss" contains data needed during conversion:
   139     rules: discrimination net of rewrite rules;
   140     congs: association list of congruence rules and
   141            a list of `weak' congruence constants.
   142            A congruence is `weak' if it avoids normalization of some argument.
   143     procs: discrimination net of simplification procedures
   144       (functions that prove rewrite rules on the fly);
   145     bounds: names of bound variables already used
   146       (for generating new names when rewriting under lambda abstractions);
   147     prems: current premises;
   148     mk_rews: mk: turns simplification thms into rewrite rules;
   149              mk_sym: turns == around; (needs Drule!)
   150              mk_eq_True: turns P into P == True - logic specific;
   151     termless: relation for ordered rewriting;
   152     depth: depth of conditional rewriting;
   153 *)
   154 
   155 datatype meta_simpset =
       (* The meaning of each field is described in the comment above.
          Notes from the code below:
          - congs pairs an association list from constant names to congruence
            rules with the names of constants that have weak congruences;
          - bounds grows by one fresh name when rewriting under an abstraction;
          - depth is incremented (incr_depth) before the prover is called on
            the instance of a conditional rule. *)
   156   Mss of {
   157     rules: rrule Net.net,
   158     congs: (string * cong) list * string list,
   159     procs: simproc Net.net,
   160     bounds: string list,
   161     prems: thm list,
   162     mk_rews: {mk: thm -> thm list,
   163               mk_sym: thm -> thm option,
   164               mk_eq_True: thm -> thm option},
   165     termless: term * term -> bool,
   166     depth: int};
   167 
   (* Assemble a simpset from its eight components (fixed argument order). *)
   fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
     Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
          prems = prems, mk_rews = mk_rews, termless = termless, depth = depth};

   (* Replace only the net of rewrite rules. *)
   fun upd_rules (Mss {congs, procs, bounds, prems, mk_rews, termless, depth, ...}, new_rules) =
     mk_mss (new_rules, congs, procs, bounds, prems, mk_rews, termless, depth);

   (* No rules, congruences, simprocs or premises; mk functions yield nothing;
      standard term ordering; depth 0. *)
   val empty_mss =
     mk_mss (Net.empty, ([], []), Net.empty, [], [],
             {mk = K [], mk_sym = K None, mk_eq_True = K None},
             Term.termless, 0);

   (* Drop all contents but keep mk_rews and termless; reset depth to 0. *)
   fun clear_mss (Mss {mk_rews, termless, ...}) =
     mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless, 0);

   (* Same simpset, one level deeper in conditional rewriting. *)
   fun incr_depth (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}) =
     let val deeper = depth + 1
     in mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, deeper) end;
   184  
   185 
   186 
   187 (** simpset operations **)
   188 
   189 (* term variables *)
   190 
   local
     (* Accumulate the index name of a schematic variable; ignore other atoms. *)
     fun add_var (acc, Var (ixn, _)) = ins_ix (ixn, acc)
       | add_var (acc, _) = acc;
   in
     val add_term_varnames = foldl_aterms add_var;

     (* Index names of all Vars occurring in t. *)
     fun term_varnames t = add_term_varnames ([], t);
   end;
   193 
   194 
   195 (* dest_mss *)
   196 
   197 fun dest_mss (Mss {rules, congs, procs, ...}) =
       (* Expose the contents of a simpset: the theorems behind all rewrite
          rules and congruence rules, and the simprocs grouped by their stamp.
          Each group is reported as the name of its first entry together with
          the lhs patterns of all its entries; groups are sorted by name. *)
   198   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   199    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   200    procs =
   201      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   202      |> partition_eq eq_snd
   203      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   204      |> Library.sort_wrt #1};
   205 
   206 
   207 (* merge_mss *)	      (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   208 
   209 fun merge_mss
   210  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   211        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   212   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   213        bounds = bounds2, prems = prems2, ...}) =
       (* Union of two simpsets. Rule and simproc nets are merged modulo
          eq_rrule / eq_simproc. Congruence alists and premises are merged with
          generic_merge (modulo eq_cong / eq_prem). Weak constants and bound
          names are merged with merge_lists. mk_rews, termless and depth are
          taken from the first simpset only. *)
   214       mk_mss
   215        (Net.merge (rules1, rules2, eq_rrule),
   216         (generic_merge (eq_cong o pairself snd) I I congs1 congs2,
   217         merge_lists weak1 weak2),
   218         Net.merge (procs1, procs2, eq_simproc),
   219         merge_lists bounds1 bounds2,
   220         generic_merge eq_prem I I prems1 prems2,
   221         mk_rews, termless, depth);
   222 
   223 
   224 (* add_simps *)
   225 
   (* Complete a partial rule with the fo flag: first-order matching is used
      when elhs is first-order, or when it is not a higher-order pattern. *)
   fun mk_rrule2 {thm, lhs, elhs, perm} =
     let
       val pat = term_of elhs;
       val use_fo = if Pattern.first_order pat then true else not (Pattern.pattern pat);
     in {thm = thm, lhs = lhs, elhs = elhs, fo = use_fo, perm = perm} end
   229 
   230 fun insert_rrule(mss as Mss {rules,...},
   231                  rrule as {thm,lhs,elhs,perm}) =
       (* Add one rule (completed by mk_rrule2) to the rule net, indexed by its
          eta-contracted lhs. If the net reports a duplicate (Net.INSERT), a
          warning is printed and the simpset is returned unchanged. *)
   232   (trace_thm false "Adding rewrite rule:" thm;
   233    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   234        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   235    in upd_rules(mss,rules') end
   236    handle Net.INSERT =>
   237      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   238 
   (* Structural comparison in which any two Vars are considered equal:
      t and u agree up to the identity of their schematic variables. *)
   fun vperm (t, u) =
     (case (t, u) of
        (Var _, Var _) => true
      | (Abs (_, _, s), Abs (_, _, s')) => vperm (s, s')
      | (f $ x, g $ y) => vperm (f, g) andalso vperm (x, y)
      | _ => t = u);

   (* t and u have the same shape and exactly the same set of Vars. *)
   fun var_perm (t, u) =
     vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   246 
   247 (* FIXME: it seems that the conditions on extra variables are too liberal if
   248 prems are nonempty: does solving the prems really guarantee instantiation of
   249 all its Vars? Better: a dynamic check each time a rule is applied.
   250 *)
   (* true iff erhs mentions a Var or TVar that occurs neither in elhs nor in
      any of the premises. *)
   fun rewrite_rule_extra_vars prems elhs erhs =
     let
       val known_vars = foldl add_term_varnames (term_varnames elhs, prems);
       val known_tvars = term_tvars elhs union List.concat (map term_tvars prems);
     in
       not (term_varnames erhs subset known_vars andalso
            term_tvars erhs subset known_tvars)
     end;
   256 
   257 (*Simple test for looping rewrite rules and stupid orientations*)
   258 fun reorient sign prems lhs rhs =
       (* true if lhs == rhs should not be used from left to right: the rhs has
          extra variables, the lhs is headed by a Var, the lhs occurs in the rhs
          or in a premise, an unconditional rule's rhs is an instance of its
          lhs, or a constant lhs is rewritten to a non-constant. *)
   259    rewrite_rule_extra_vars prems lhs rhs
   260   orelse
   261    is_Var (head_of lhs)
   262   orelse
   263    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   264   orelse
   265    (null prems andalso
   266     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   267     (*the condition "null prems" is necessary because conditional rewrites
   268       with extra variables in the conditions may terminate although
   269       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   270   orelse
   271    (is_Const lhs andalso not(is_Const rhs))
   272 
   273 fun decomp_simp thm =
       (* Decompose a (possibly conditional) meta-equation into
          (sign, premises, lhs, eta-contracted lhs as cterm, rhs, perm).
          perm holds when both sides agree up to a permutation of Vars, are
          not alpha-equivalent, and the lhs is not a bare Var.
          Raises SIMPLIFIER if the conclusion is not a meta-equality. *)
   274   let val {sign, prop, ...} = rep_thm thm;
   275       val prems = Logic.strip_imp_prems prop;
   276       val concl = Drule.strip_imp_concl (cprop_of thm);
   277       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   278         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   279       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   280       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   281       val erhs = Pattern.eta_contract (term_of rhs);
   282       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   283                  andalso not (is_Var (term_of elhs))
   284   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   285 
   (* The "P == True" variant of thm produced by the simpset's mk_eq_True
      function, as a non-permutative partial rule; [] if there is none. *)
   fun mk_eq_True (Mss {mk_rews = {mk_eq_True, ...}, ...}) thm =
     (case mk_eq_True thm of
        Some eq_True =>
          let val (_, _, lhs, elhs, _, _) = decomp_simp eq_True
          in [{thm = eq_True, lhs = lhs, elhs = elhs, perm = false}] end
      | None => []);
   291 
   292 (* create the rewrite rule and possibly also the ==True variant,
   293    in case there are extra vars on the rhs *)
   fun rrule_eq_True (thm, lhs, elhs, rhs, mss, thm2) =
     let
       val rrule = {thm = thm, lhs = lhs, elhs = elhs, perm = false};
       (* rhs Vars and TVars all occur in lhs *)
       val closed_rhs =
         term_varnames rhs subset term_varnames lhs andalso
         term_tvars rhs subset term_tvars lhs;
     in
       if closed_rhs then [rrule] else mk_eq_True mss thm2 @ [rrule]
     end;
   301 
   302 fun mk_rrule mss thm =
       (* Turn thm into partial rewrite rules without reorienting it.
          Permutative rules are kept as they are. Rules with extra variables
          or a Var lhs are replaced by their ==True variant (if any).
          Otherwise see rrule_eq_True. *)
   303   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   304   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   305      (* weak test for loops: *)
   306      if rewrite_rule_extra_vars prems lhs rhs orelse
   307         is_Var (term_of elhs)
   308      then mk_eq_True mss thm
   309      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   310   end;
   311 
   312 fun orient_rrule mss thm =
       (* Like mk_rrule, but may flip the equation. If the rule should not be
          used left to right (reorient) but its symmetric version may be, the
          symmetric theorem from mk_sym is used. If neither direction is
          acceptable, only the ==True variant is produced. *)
   313   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   314   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   315      else if reorient sign prems lhs rhs
   316           then if reorient sign prems rhs lhs
   317                then mk_eq_True mss thm
   318                else let val Mss{mk_rews={mk_sym,...},...} = mss
   319                     in case mk_sym thm of
   320                          None => []
   321                        | Some thm' =>
   322                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   323                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   324                     end
   325           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   326   end;
   327 
   (* Apply the simpset's mk function to each theorem and collect the results. *)
   fun extract_rews (Mss {mk_rews = {mk, ...}, ...}, thms) =
     List.concat (map mk thms);

   (* Extract rewrite theorems, turn them into rules with mk_rrule, and fold
      comb (insertion or deletion) over the simpset. *)
   fun orient_comb_simps comb mk_rrule (mss, thms) =
     foldl comb (mss, List.concat (map mk_rrule (extract_rews (mss, thms))));
   334 
   335 (* Add rewrite rules explicitly; do not reorient! *)
   fun add_simps (args as (mss, _)) =
     orient_comb_simps insert_rrule (mk_rrule mss) args;

   (* Build a simpset whose rules come from thms via mk_rrule alone
      (no mk_rews function is applied). *)
   fun mss_of thms =
     let val rrules = List.concat (map (mk_rrule empty_mss) thms)
     in foldl insert_rrule (empty_mss, rrules) end;

   (* Rules from thm, with each rewrite theorem possibly reoriented. *)
   fun extract_safe_rrules (mss, thm) =
     List.concat (map (orient_rrule mss) (extract_rews (mss, [thm])));

   fun add_safe_simp (mss, thm) =
     foldl insert_rrule (mss, extract_safe_rrules (mss, thm));
   347 
   348 (* del_simps *)
   349 
   350 fun del_rrule(mss as Mss {rules,...},
   351               rrule as {thm, elhs, ...}) =
       (* Remove one rule from the net (indexed by its eta-contracted lhs).
          A missing rule (Net.DELETE) only produces a warning. *)
   352   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   353    handle Net.DELETE =>
   354      (prthm true "Rewrite rule not in simpset:" thm; mss));
   355 
       (* Delete the rules that add_simps would have created from thms. *)
   356 fun del_simps(mss,thms) =
   357   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   358 
   359 
   360 (* add_congs *)
   361 
   362 fun is_full_cong_prems [] varpairs = null varpairs
       (* Every premise must be (after stripping assumptions) an equation
          ?x bs == ?y bs with distinct bound variables bs, where (?x, ?y) is one
          of the not-yet-used varpairs; all pairs must be used exactly once. *)
   363   | is_full_cong_prems (p::prems) varpairs =
   364     (case Logic.strip_assums_concl p of
   365        Const("==",_) $ lhs $ rhs =>
   366          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   367          in is_Var x  andalso  forall is_Bound xs  andalso
   368             null(findrep(xs))  andalso xs=ys andalso
   369             (x,y) mem varpairs andalso
   370             is_full_cong_prems prems (varpairs\(x,y))
   371          end
   372      | _ => false);
   373 
       (* A congruence f xs == f ys is "full" (not weak) if xs @ ys are distinct,
          xs and ys have equal length, and the premises relate each pair of
          corresponding arguments, i.e. no argument escapes normalization. *)
   374 fun is_full_cong thm =
   375 let val prems = prems_of thm
   376     and concl = concl_of thm
   377     val (lhs,rhs) = Logic.dest_equals concl
   378     val (f,xs) = strip_comb lhs
   379     and (g,ys) = strip_comb rhs
   380 in
   381   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   382   is_full_cong_prems prems (xs ~~ ys)
   383 end
   384 
   (* Register thm as the congruence rule for the head constant a of its lhs,
      overwriting (with a warning) any previous rule for a. The list of weak
      congruence constants is updated so that a appears in it iff the new rule
      is not full. Earlier versions left a stale entry for a when a weak rule
      was overwritten by a full one.
      Raises SIMPLIFIER if thm is not a meta-equality headed by a constant. *)
   fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
     let
       val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
         raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   (*   val lhs = Pattern.eta_contract lhs; *)
       val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
         raise SIMPLIFIER ("Congruence must start with a constant", thm);
       val (alist,weak) = congs
       val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
              ("Overwriting congruence rule for " ^ quote a);
       (* forget whatever the overwritten rule for a contributed *)
       val weak_others = filter (fn c => c <> a) weak
       val weak2 = if is_full_cong thm then weak_others else a :: weak_others
     in
       mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
     end;

   val (op add_congs) = foldl add_cong;
   401 
   402 
   403 (* del_congs *)
   404 
   (* Remove the congruence rule registered for the head constant a of thm's
      lhs. The weak-constant list is recomputed from the remaining rules.
      Like del_rrule and del_proc, print a warning when there is nothing to
      delete (previously a silent no-op).
      Raises SIMPLIFIER if thm is not a meta-equality headed by a constant. *)
   fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
     let
       val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
         raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   (*   val lhs = Pattern.eta_contract lhs; *)
       val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
         raise SIMPLIFIER ("Congruence must start with a constant", thm);
       val (alist,_) = congs
       val alist2 = filter (fn (c,_) => c <> a) alist
       val _ =
         if length alist2 = length alist
         then warning ("Congruence rule for " ^ quote a ^ " not in simpset")
         else ()
       val weak2 = mapfilter (fn (c, {thm = cthm, ...}) =>
                               if is_full_cong cthm then None else Some c)
                             alist2
     in
       mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
     end;

   val (op del_congs) = foldl del_cong;
   422 
   423 
   424 (* add_simprocs *)
   425 
   (* Insert one simproc under one lhs pattern; a duplicate is reported as a
      warning and the net is left unchanged. *)
   fun add_proc (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth},
                 (name, lhs, proc, id)) =
     let
       val {sign, t, ...} = rep_cterm lhs;
       val _ = trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
                 sign t;
       val procs' =
         Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
           handle Net.INSERT =>
             (warning ("Ignoring duplicate simplification procedure \"" ^ name ^ "\"");
              procs);
     in
       mk_mss (rules, congs, procs', bounds, prems, mk_rews, termless, depth)
     end;

   (* Insert a simproc once for each of its lhs patterns. *)
   fun add_simproc (mss, (name, lhss, proc, id)) =
     let fun entry lhs = (name, lhs, proc, id)
     in foldl add_proc (mss, map entry lhss) end;

   fun add_simprocs (mss, simprocs) = foldl add_simproc (mss, simprocs);
   444 
   445 
   446 (* del_simprocs *)
   447 
   (* Remove one simproc entry for one lhs pattern; a missing entry is
      reported as a warning and the net is left unchanged. *)
   fun del_proc (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth},
                 (name, lhs, proc, id)) =
     let
       val procs' =
         Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
           handle Net.DELETE =>
             (warning ("Simplification procedure \"" ^ name ^ "\" not in simpset");
              procs);
     in
       mk_mss (rules, congs, procs', bounds, prems, mk_rews, termless, depth)
     end;

   (* Remove a simproc for each of its lhs patterns. *)
   fun del_simproc (mss, (name, lhss, proc, id)) =
     let fun entry lhs = (name, lhs, proc, id)
     in foldl del_proc (mss, map entry lhss) end;

   fun del_simprocs (mss, simprocs) = foldl del_simproc (mss, simprocs);
   461 
   462 
   463 (* prems *)
   464 
   (* Put the new premises in front of those already recorded. *)
   fun add_prems (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, thms) =
     let val all_prems = thms @ prems
     in mk_mss (rules, congs, procs, bounds, all_prems, mk_rews, termless, depth) end;

   fun prems_of_mss (Mss r) = #prems r;
   469 
   470 
   471 (* mk_rews *)
   472 
   local
     (* Rebuild a simpset with its mk_rews record transformed by f. *)
     fun map_mk_rews f (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}) =
       mk_mss (rules, congs, procs, bounds, prems, f mk_rews, termless, depth);
   in
     (* Replace the function turning theorems into rewrite theorems. *)
     fun set_mk_rews (mss, mk) =
       map_mk_rews (fn {mk = _, mk_sym, mk_eq_True} =>
                      {mk = mk, mk_sym = mk_sym, mk_eq_True = mk_eq_True}) mss;

     (* Replace the function producing the symmetric version of a rule. *)
     fun set_mk_sym (mss, mk_sym) =
       map_mk_rews (fn {mk, mk_sym = _, mk_eq_True} =>
                      {mk = mk, mk_sym = mk_sym, mk_eq_True = mk_eq_True}) mss;

     (* Replace the function producing the "P == True" variant of a rule. *)
     fun set_mk_eq_True (mss, mk_eq_True) =
       map_mk_rews (fn {mk, mk_sym, mk_eq_True = _} =>
                      {mk = mk, mk_sym = mk_sym, mk_eq_True = mk_eq_True}) mss;
   end;
   490 
   491 (* termless *)
   492 
   (* Replace the term ordering used for permutative (ordered) rewriting. *)
   fun set_termless
     (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, termless = _}, less) =
       mk_mss (rules, congs, procs, bounds, prems, mk_rews, less, depth);
   496 
   497 
   498 
   499 (** rewriting **)
   500 
   501 (*
   502   Uses conversions, see:
   503     L C Paulson, A higher-order implementation of rewriting,
   504     Science of Computer Programming 3 (1983), pages 119-149.
   505 *)
   506 
   type prover = meta_simpset -> thm -> thm option;
   type termrec = (Sign.sg_ref * term list) * term;
   type conv = meta_simpset -> termrec -> termrec;

   (* The two sides (as cterms) of an equational theorem's proposition. *)
   fun dest_eq thm = Drule.dest_equals (cprop_of thm);
   fun lhs_of thm = #1 (dest_eq thm);
   fun rhs_of thm = #2 (dest_eq thm);

   (* Chain beta_conversion true on t with eta_conversion of the result. *)
   fun beta_eta_conversion t =
     let
       val beta_thm = beta_conversion true t;
       val eta_thm = eta_conversion (rhs_of beta_thm);
     in transitive beta_thm eta_thm end;
   518 
   519 fun check_conv msg thm thm' =
       (* Given thm : t == u and thm' : u' == v, where u' beta/eta-converts to
          u, return Some (t == v). If the theorems do not compose (THM), trace
          the mismatch against the rhs of thm and return None. msg controls
          whether success is traced ("SUCCEEDED"). *)
   520   let
   521     val thm'' = transitive thm (transitive
   522       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   523   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   524   handle THM _ =>
   525     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   526     in
   527       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   528        trace_term false "Should have proved:" sign prop0;
   529        None)
   530     end;
   531 
   532 
   533 (* mk_procrule *)
   534 
   (* Turn a theorem produced by a simproc into a rule list. A theorem whose
      rhs has extra variables is rejected with a warning. *)
   fun mk_procrule thm =
     let val (_, prems, lhs, elhs, rhs, _) = decomp_simp thm
     in
       if not (rewrite_rule_extra_vars prems lhs rhs)
       then [mk_rrule2 {thm = thm, lhs = lhs, elhs = elhs, perm = false}]
       else (prthm true "Extra vars on rhs:" thm; [])
     end;
   541 
   542 
   543 (* conversion to apply the meta simpset to a term *)
   544 
   545 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   546    normalized terms by carrying around the rhs of the rewrite rule just
   547    applied. This is called the `skeleton'. It is decomposed in parallel
   548    with the term. Once a Var is encountered, the corresponding term is
   549    already in normal form.
   550    skel0 is a dummy skeleton used to enforce complete normalization.
   551 *)
   val skel0 = Bound 0;

   (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
      The latter may happen iff there are weak congruence rules for constants
      in the lhs.
   *)
   fun uncond_skel ((_, weak), (lhs, rhs)) =
     if not (null weak) andalso exists_Const (fn (c, _) => c mem weak) lhs
     then skel0
     else rhs;

   (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
      Otherwise those vars may become instantiated with unnormalized terms
      while the premises are solved.
   *)
   fun cond_skel (args as (_, (lhs, rhs))) =
     if term_varnames rhs subset term_varnames lhs then uncond_skel args
     else skel0;
   570 
   571 (*
   572   we try in order:
   573     (1) beta reduction
   574     (2) unconditional rewrite rules
   575     (3) conditional rewrite rules
   576     (4) simplification procedures
   577 
   578   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   579 
   580 *)
   581 
   (* Try to rewrite the cterm t at its root with simpset mss. A top-level
      beta-redex is contracted. Otherwise the matching rewrite rules are tried
      (plain meta-equalities before conditional rules, see sort_rrules),
      followed by the matching simplification procedures.
      Returns Some (thm, skel), where thm proves t == t' and skel is the
      skeleton guiding further normalization of t', or None.
      prover discharges the premises of conditional rules. Unless maxt is ~1,
      rule indexes are shifted by maxt+1 (presumably maxt is the maximal index
      in t -- confirm at call sites). An unconditional rule whose instance
      fails check_conv is now treated as not applicable; previously the
      pattern `val Some ...` raised an uncaught Bind. *)
   fun rewritec (prover, signt, maxt)
                (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
     let
       val eta_thm = Thm.eta_conversion t;
       val eta_t' = rhs_of eta_thm;
       val eta_t = term_of eta_t';
       val tsigt = Sign.tsig_of signt;

       (* Apply one rule; raising Pattern.MATCH means "does not apply". *)
       fun rew {thm, lhs, elhs, fo, perm} =
         let
           val {sign, prop, ...} = rep_thm thm;
           val _ = if Sign.subsig (sign, signt) then ()
                   else (prthm true "Ignoring rewrite rule from different theory:" thm;
                         raise Pattern.MATCH);
           val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
             else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
           val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
                             else Thm.cterm_match (elhs', eta_t');
           val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
           val prop' = #prop (rep_thm thm');
           val unconditional = (Logic.count_prems (prop',0) = 0);
           val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
         in
           if perm andalso not (termless (rhs', lhs'))
           then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
                 trace_thm false "Term does not become smaller:" thm'; None)
           else (trace_thm false "Applying instance of rewrite rule:" thm;
              if unconditional
              then
                (trace_thm false "Rewriting:" thm';
                 let val lr = Logic.dest_equals prop
                 in
                   (* check_conv can fail, e.g. if the instance does not compose
                      with eta_thm; treat that as "rule does not apply" *)
                   case check_conv false eta_thm thm' of
                     None => None
                   | Some thm'' => Some (thm'', uncond_skel (congs, lr))
                 end)
              else
                (trace_thm false "Trying to rewrite:" thm';
                 case prover (incr_depth mss) thm' of
                   None       => (trace_thm false "FAILED" thm'; None)
                 | Some thm2 =>
                     (case check_conv true eta_thm thm2 of
                        None => None |
                        Some thm2' =>
                          let val concl = Logic.strip_imp_concl prop
                              val lr = Logic.dest_equals concl
                          in Some (thm2', cond_skel (congs, lr)) end)))
         end

       (* Try the rules in order until one applies. *)
       fun rews [] = None
         | rews (rrule :: rrules) =
             let val opt = rew rrule handle Pattern.MATCH => None
             in case opt of None => rews rrules | some => some end;

       (* Rules whose proposition is a plain meta-equality come before
          conditional ones; within each group the order is reversed. *)
       fun sort_rrules rrs = let
         fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
                                         Const("==",_) $ _ $ _ => true
                                         | _                   => false
         fun sort []        (re1,re2) = re1 @ re2
           | sort (rr::rrs) (re1,re2) = if is_simple rr
                                        then sort rrs (rr::re1,re2)
                                        else sort rrs (re1,rr::re2)
       in sort rrs ([],[]) end

       (* Try the simprocs whose lhs matches t; a produced theorem is used
          only if it actually rewrites eta_t. *)
       fun proc_rews ([]:simproc list) = None
         | proc_rews ({name, proc, lhs, ...} :: ps) =
             if Pattern.matches tsigt (term_of lhs, term_of t) then
               (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
                case proc signt prems eta_t of
                  None => (debug false "FAILED"; proc_rews ps)
                | Some raw_thm =>
                    (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
                     (case rews (mk_procrule raw_thm) of
                       None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
                     | some => some)))
             else proc_rews ps;
     in case eta_t of
          Abs _ $ _ => Some (transitive eta_thm
            (beta_conversion false (rhs_of eta_thm)), skel0)
        | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
                  None => proc_rews (Net.match_term procs eta_t)
                | some => some)
     end;
   661 
   662 
   663 (* conversion to apply a congruence rule to a term *)
   664 
   665 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
       (* Apply a congruence rule to t: instantiate the rule so that its lhs
          matches t (indexes shifted by maxt+1 unless maxt = ~1), let prover
          derive the instantiated equation, and normalize its lhs back to t
          via check_conv. A rule from a theory outside signt, or a failed or
          wrong proof, leads to error. A failed match (raised by
          Thm.cterm_match) propagates to the caller. *)
   666   let val {sign, ...} = rep_thm cong
   667       val _ = if Sign.subsig (sign, signt) then ()
   668                  else error("Congruence rule from different theory")
   669       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   670       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   671       val insts = Thm.cterm_match (rlhs, t)
   672       (* Pattern.match can raise Pattern.MATCH;
   673          is handled when congc is called *)
   674       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   675       val unit = trace_thm false "Applying congruence rule:" thm';
   676       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   677   in case prover thm' of
   678        None => err ("Could not prove", thm')
   679      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   680           None => err ("Should not have proved", thm2)
   681         | Some thm2' => thm2')
   682   end;
   683 
   (* cA, cB, cC: the components A, B and C of the first premise of
      Drule.imp_cong, read as A ==> (B == C). *)
   val (cA, (cB, cC)) =
     let val (prem_lhs, prem_rhs) = dest_implies (hd (cprems_of Drule.imp_cong))
     in (prem_lhs, dest_equals prem_rhs) end;

   (* Extend thm1 by an optional follow-up equation. *)
   fun transitive' thm1 opt =
     (case opt of
        None => Some thm1
      | Some thm2 => Some (transitive thm1 thm2));
   689 
(*Bottom-up simplifier core.
    simprem, useprem - consulted only by the legacy nonmut_impc: simplify
                       the premise A of A ==> B, resp. assume A while
                       simplifying B
    mutsimp          - use mut_impc (mutual simplification of premises)
                       instead of nonmut_impc
    prover           - passed to rewritec; applied to the current mss
                       before being passed to congc
    sign, maxidx     - passed to rewritec and congc; sign also supplies
                       the type signature for the reducibility test
  Returns try_botc, which maps an mss and a cterm t to a theorem
  t == t' (the reflexive theorem if nothing could be rewritten).*)
fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
  let
    (*botc skel mss t: first simplify the subterms of t (subc), then
      rewrite at the top (rewritec).  After a successful top-level
      rewrite, continue on the result, guided by the skeleton skel2
      returned by rewritec.  Returns None if nothing changed.  A Var
      skeleton stops the descent -- presumably it marks a subterm that
      is already normalized (NOTE(review): confirm against rewritec).*)
    fun botc skel mss t =
          if is_Var skel then None
          else
          (case subc skel mss t of
             some as Some thm1 =>
               (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
                  Some (thm2, skel2) =>
                    transitive' (transitive thm1 thm2)
                      (botc skel2 mss (rhs_of thm2))
                | None => some)
           | None =>
               (case rewritec (prover, sign, maxidx) mss t of
                  Some (thm2, skel2) => transitive' thm2
                    (botc skel2 mss (rhs_of thm2))
                | None => None))

    (*Like botc with the trivial skeleton skel0, but yields the reflexive
      theorem instead of None.*)
    and try_botc mss t =
          (case botc skel0 mss t of
             Some trec1 => trec1 | None => (reflexive t))

    (*subc skel mss t0: simplify the immediate subterms of t0.
        abstraction       - open the body with a fresh bound name (added
                            to bounds) and simplify it with botc
        A ==> B           - delegate to impc
        beta redex        - beta-convert, then continue on the result
        other application - if a congruence rule is registered for the
                            head constant, use it (congc); otherwise
                            simplify function and argument separately
      Returns None if nothing changed.*)
    and subc skel
          (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
       (case term_of t0 of
           Abs (a, T, t) =>
             let val b = variant bounds a
                 val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
                 val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
                 val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
             in case botc skel' mss' t' of
                  Some thm => Some (abstract_rule a v thm)
                | None => None
             end
         | t $ _ => (case t of
             Const ("==>", _) $ _  =>
               let val (s, u) = Drule.dest_implies t0
               in impc (s, u, mss) end
           | Abs _ =>
               let val thm = beta_conversion false t0
               in case subc skel0 mss (rhs_of thm) of
                    None => Some thm
                  | Some thm' => Some (transitive thm thm')
               end
           | _  =>
               (*appc: simplify the function part ct and the argument cu
                 independently, following the matching parts of the
                 skeleton when it is itself an application.*)
               let fun appc () =
                     let
                       val (tskel, uskel) = case skel of
                           tskel $ uskel => (tskel, uskel)
                         | _ => (skel0, skel0);
                       val (ct, cu) = Thm.dest_comb t0
                     in
                     (case botc tskel mss ct of
                        Some thm1 =>
                          (case botc uskel mss cu of
                             Some thm2 => Some (combination thm1 thm2)
                           | None => Some (combination thm1 (reflexive cu)))
                      | None =>
                          (case botc uskel mss cu of
                             Some thm1 => Some (combination (reflexive ct) thm1)
                           | None => None))
                     end
                   val (h, ts) = strip_comb t
               in case h of
                    Const(a, _) =>
                      (case assoc_string (fst congs, a) of
                         None => appc ()
                       | Some cong =>
(* post processing: some partial applications h t1 ... tj, j <= length ts,
   may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
(* A TERM raised here (e.g. by Thm.dest_comb on the congruence result) is
   reported as an error; Pattern.MATCH, raised when the congruence rule
   does not match t0, falls back to appc. *)
                          (let
                             val thm = congc (prover mss, sign, maxidx) cong t0;
                             val t = rhs_of thm;
                             val (cl, cr) = Thm.dest_comb t
                             val dVar = Var(("", 0), dummyT)
                             val skel =
                               list_comb (h, replicate (length ts) dVar)
                           in case botc skel mss cl of
                                None => Some thm
                              | Some thm' => Some (transitive thm
                                  (combination thm' (reflexive cr)))
                           end handle TERM _ => error "congc result"
                                    | Pattern.MATCH => appc ()))
                  | _ => appc ()
               end)
         | _ => None)

    (*Simplify the implication prem ==> conc according to mutsimp.  The
      restart index returned by mut_impc is dropped at this level.*)
    and impc args =
      if mutsimp
      then let val (prem, conc, mss) = args
           in apsome snd (mut_impc ([], prem, conc, mss)) end
      else nonmut_impc args

    (*mut_impc (prems, prem, conc, mss): simplify prem ==> conc, where
      prems are the already simplified premises of the enclosing
      implications.  prem is simplified first, then mut_impc1 continues.
      Result: None if nothing changed, otherwise Some (restart, thm), where
      thm is an equation for prem ==> conc.  restart = Some i requests that
      an enclosing premise be re-simplified; i counts the levels that still
      have to be unwound (see simpconc in mut_impc1).*)
    and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
        None => mut_impc1 (prems, prem, conc, mss)
      | Some thm1 =>
          let val prem1 = rhs_of thm1
          in (case mut_impc1 (prems, prem1, conc, mss) of
              None => Some (None,
                combination (combination refl_implies thm1) (reflexive conc))
            | Some (x, thm2) => Some (x, transitive (combination (combination
                refl_implies thm1) (reflexive conc)) thm2))
          end)

    (*mut_impc1 (prems, prem1, conc, mss): prem1 is the simplified premise.
      Unless prem1 contains unknowns (maxidx <> ~1), assume it and add the
      rewrite rules extracted from it to mss.  If the left-hand sides of the
      unconditional ones match a subterm of an earlier premise in prems,
      stop and report the number of premises following that one as restart
      index; otherwise simplify the conclusion (simpconc).*)
    and mut_impc1 (prems, prem1, conc, mss) =
      let
        (*Left-hand side of a rewrite rule whose theorem has no premises.*)
        fun uncond ({thm, lhs, elhs, perm}) =
          if Thm.no_prems thm then Some lhs else None

        val (lhss1, mss1) =
          if maxidx_of_term (term_of prem1) <> ~1
          then (trace_cterm true
            "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                ([],mss))
          else let val thm = assume prem1
                   val rrules1 = extract_safe_rrules (mss, thm)
                   val lhss1 = mapfilter uncond rrules1
                   val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
               in (lhss1, mss1) end

        (*From thm : B' == C', proved under the assumption prem1, derive
          (prem1 ==> B') == (prem1 ==> C') with Drule.imp_cong,
          discharging prem1.*)
        fun disch1 thm =
          let val (cB', cC') = dest_eq thm
          in
            implies_elim (Thm.instantiate
              ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
              (implies_intr prem1 thm)
          end

        (*rebuild: the conclusion has been handled (None: unchanged,
          Some thm2: conc == conc').  Try to rewrite the whole implication
          at the top; if the result is again an implication, simplify it
          once more with mut_impc.  A TERM from Drule.dest_implies (the
          result is no implication) ends this loop.*)
        fun rebuild None = (case rewritec (prover, sign, maxidx) mss
            (mk_implies (prem1, conc)) of
              None => None
            | Some (thm, _) => 
                let val (prem, conc) = Drule.dest_implies (rhs_of thm)
                in (case mut_impc (prems, prem, conc, mss) of
                    None => Some (None, thm)
                  | Some (x, thm') => Some (x, transitive thm thm'))
                end handle TERM _ => Some (None, thm))
          | rebuild (Some thm2) =
            let val thm = disch1 thm2
            in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
                 None => Some (None, thm)
               | Some (thm', _) =>
                   let val (prem, conc) = Drule.dest_implies (rhs_of thm')
                   in (case mut_impc (prems, prem, conc, mss) of
                       None => Some (None, transitive thm thm')
                     | Some (x, thm'') =>
                         Some (x, transitive (transitive thm thm') thm''))
                   end handle TERM _ => Some (None, transitive thm thm'))
            end

        (*simpconc: if conc is an implication s ==> t, simplify it with
          prem1 appended to prems and with mss1.  On a restart request
          Some i: for i = 0, swap prem1 with the inner premise prem
          (Drule.swap_prems_eq) and continue with prem as the new premise,
          so that prem1 is simplified again under prem; otherwise pass
          i-1 upwards.  If conc is no implication (TERM), it is simplified
          with botc under mss1 instead.*)
        fun simpconc () =
          let val (s, t) = Drule.dest_implies conc
          in case mut_impc (prems @ [prem1], s, t, mss1) of
               None => rebuild None
             | Some (Some i, thm2) =>
                  let
                    val (prem, cC') = Drule.dest_implies (rhs_of thm2);
                    val thm2' = transitive (disch1 thm2) (Thm.instantiate
                      ([], [(cA, prem1), (cB, prem), (cC, cC')])
                      Drule.swap_prems_eq)
                  in if i=0 then apsome (apsnd (transitive thm2'))
                       (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
                     else Some (Some (i-1), thm2')
                  end
             | Some (None, thm) => rebuild (Some thm)
          end handle TERM _ => rebuild (botc skel0 mss1 conc)

      in
        (*Look for the first earlier premise that one of the new
          unconditional left-hand sides matches somewhere inside.*)
        let
          val tsig = Sign.tsig_of sign
          fun reducible t =
            exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
        in case dropwhile (not o reducible) prems of
            [] => simpconc ()
          | red::rest => (trace_cterm false "Can now reduce premise:" red;
              Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
        end
      end

     (* legacy code - only for backwards compatibility *)
     (* Simplify prem (only if simprem), then simplify conc under the
        simplified premise added as an assumption (only if useprem and it
        has no unknowns).  The steps are combined via refl_implies and
        Drule.imp_cong.  Returns None if nothing changed. *)
     and nonmut_impc (prem, conc, mss) =
       let val thm1 = if simprem then botc skel0 mss prem else None;
           val prem1 = if_none (apsome rhs_of thm1) prem;
           val maxidx1 = maxidx_of_term (term_of prem1)
           val mss1 =
             if not useprem then mss else
             if maxidx1 <> ~1
             then (trace_cterm true
               "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                   mss)
             else let val thm = assume prem1
                  in add_safe_simp (add_prems (mss, [thm]), thm) end
       in (case botc skel0 mss1 conc of
           None => (case thm1 of
               None => None
             | Some thm1' => Some (combination
                 (combination refl_implies thm1') (reflexive conc)))
         | Some thm2 =>
           let
             val conc2 = rhs_of thm2;
             val thm2' = implies_elim (Thm.instantiate
               ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
               (implies_intr prem1 thm2)
           in (case thm1 of
               None => Some thm2'
             | Some thm1' => Some (transitive (combination
                 (combination refl_implies thm1') (reflexive conc)) thm2'))
           end)
       end

 in try_botc end;
   901 
   902 
   903 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   904 
   905 (*
   906   Parameters:
   907     mode = (simplify A,
   908             use A in simplifying B,
   909             use prems of B (if B is again a meta-impl.) to simplify A)
   910            when simplifying A ==> B
   911     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   912     prover: how to solve premises in conditional rewrites and congruences
   913 *)
   914 
   915 fun rewrite_cterm mode prover mss ct =
   916   let val {sign, t, maxidx, ...} = rep_cterm ct
   917       val Mss{depth, ...} = mss
   918   in simp_depth := depth;
   919      bottomc (mode, prover, sign, maxidx) mss ct
   920   end
   921   handle THM (s, _, thms) =>
   922     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   923       Pretty.string_of (pretty_thms thms));
   924 
   925 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   926 (*Do not rewrite flex-flex pairs*)
   927 fun goals_conv pred cv =
   928   let fun gconv i ct =
   929         let val (A,B) = Drule.dest_implies ct
   930             val (thA,j) = case term_of A of
   931                   Const("=?=",_)$_$_ => (reflexive A, i)
   932                 | _ => (if pred i then cv A else reflexive A, i+1)
   933         in  combination (combination refl_implies thA) (gconv j B) end
   934         handle TERM _ => reflexive ct
   935   in gconv 1 end;
   936 
   937 (* Rewrite A in !!x1,...x1. A *)
   938 fun forall_conv cv ct =
   939   let val p as (ct1, ct2) = Thm.dest_comb ct
   940   in (case pairself term_of p of
   941       (Const ("all", _), Abs (s, _, _)) =>
   942          let val (v, ct') = Thm.dest_abs (Some "@") ct2;
   943          in Thm.combination (Thm.reflexive ct1)
   944            (Thm.abstract_rule s v (forall_conv cv ct'))
   945          end
   946     | _ => cv ct)
   947   end handle TERM _ => cv ct;
   948 
   949 (*Use a conversion to transform a theorem*)
   950 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   951 
   952 (*Rewrite a cterm (yielding again a cterm instead of a theorem)*)
   953 fun full_rewrite_cterm_aux _ [] = (fn ct => ct)
   954   | full_rewrite_cterm_aux prover thms =
   955       #2 o Thm.dest_comb o #prop o Thm.crep_thm o
   956         rewrite_cterm (true, false, false) prover (mss_of thms);
   957 
   958 (*Rewrite a theorem*)
   959 fun rewrite_rule_aux _ [] = (fn th => th)
   960   | rewrite_rule_aux prover thms =
   961       fconv_rule (rewrite_cterm (true,false,false) prover (mss_of thms));
   962 
   963 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   964 
   965 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   966 fun rewrite_goals_rule_aux _ []   th = th
   967   | rewrite_goals_rule_aux prover thms th =
   968       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   969         (mss_of thms))) th;
   970 
   971 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
   972 fun rewrite_goal_rule mode prover mss i thm =
   973   if 0 < i  andalso  i <= nprems_of thm
   974   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   975   else raise THM("rewrite_goal_rule",i,[thm]);
   976 
   977 end;
   978 
(*Restrict MetaSimplifier to its BASIC_META_SIMPLIFIER part (the
  trace_simp and debug_simp flags) and make those names available at
  top level.*)
structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
open BasicMetaSimplifier;