src/Pure/meta_simplifier.ML
author berghofe
Mon Oct 21 17:09:31 2002 +0200 (2002-10-21)
changeset 13661 ec97dfc2bfe0
parent 13614 0b91269c0b13
child 13828 fb6ec40dd291
permissions -rw-r--r--
No more explicit manipulation of flex-flex constraints in goals_conv.
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
    11   val trace_simp: bool ref
    12   val debug_simp: bool ref
    13 end;
    14 
    15 signature META_SIMPLIFIER =
    16 sig
    17   include BASIC_META_SIMPLIFIER
    18   exception SIMPLIFIER of string * thm
    19   exception SIMPROC_FAIL of string * exn
    20   type meta_simpset
    21   val dest_mss          : meta_simpset ->
    22     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    23   val empty_mss         : meta_simpset
    24   val clear_mss         : meta_simpset -> meta_simpset
    25   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    26   val add_simps         : meta_simpset * thm list -> meta_simpset
    27   val del_simps         : meta_simpset * thm list -> meta_simpset
    28   val mss_of            : thm list -> meta_simpset
    29   val add_congs         : meta_simpset * thm list -> meta_simpset
    30   val del_congs         : meta_simpset * thm list -> meta_simpset
    31   val add_simprocs      : meta_simpset *
    32     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    33       -> meta_simpset
    34   val del_simprocs      : meta_simpset *
    35     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    36       -> meta_simpset
    37   val add_prems         : meta_simpset * thm list -> meta_simpset
    38   val prems_of_mss      : meta_simpset -> thm list
    39   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    40   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    42   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    43   val beta_eta_conversion: cterm -> thm
    44   val rewrite_cterm: bool * bool * bool ->
    45     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    46   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    47   val forall_conv       : (cterm -> thm) -> cterm -> thm
    48   val fconv_rule        : (cterm -> thm) -> thm -> thm
    49   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    50   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    51   val rewrite_thm       : bool * bool * bool
    52                           -> (meta_simpset -> thm -> thm option)
    53                           -> meta_simpset -> thm -> thm
    54   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    55   val rewrite_goal_rule : bool* bool * bool
    56                           -> (meta_simpset -> thm -> thm option)
    57                           -> meta_simpset -> int -> thm -> thm
    58   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
    59 end;
    60 
    61 structure MetaSimplifier : META_SIMPLIFIER =
    62 struct
    63 
    64 (** diagnostics **)
    65 
    66 exception SIMPLIFIER of string * thm;
    67 exception SIMPROC_FAIL of string * exn;
    68 
    69 val simp_depth = ref 0;
    70 
    71 local
    72 
    73 fun println a =
    74   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    75 
    76 fun prnt warn a = if warn then warning a else println a;
    77 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    78 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    79 
    80 in
    81 
    82 fun prthm warn a = prctm warn a o Thm.cprop_of;
    83 
    84 val trace_simp = ref false;
    85 val debug_simp = ref false;
    86 
    87 fun trace warn a = if !trace_simp then prnt warn a else ();
    88 fun debug warn a = if !debug_simp then prnt warn a else ();
    89 
    90 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    91 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    92 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    93 
    94 fun trace_thm a thm =
    95   let val {sign, prop, ...} = rep_thm thm
    96   in trace_term false a sign prop end;
    97 
    98 fun trace_named_thm a (thm, name) =
    99   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   100 
   101 end;
   102 
   103 
   104 (** meta simp sets **)
   105 
   106 (* basic components *)
   107 
   108 type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
   109 (* thm: the rewrite rule
   110    name: name of theorem from which rewrite rule was extracted
   111    lhs: the left-hand side
   112    elhs: the etac-contracted lhs.
   113    fo:  use first-order matching
   114    perm: the rewrite rule is permutative
   115 Remarks:
   116   - elhs is used for matching,
   117     lhs only for preservation of bound variable names.
   118   - fo is set iff
   119     either elhs is first-order (no Var is applied),
   120            in which case fo-matching is complete,
   121     or elhs is not a pattern,
   122        in which case there is nothing better to do.
   123 *)
   124 type cong = {thm: thm, lhs: cterm};
   125 type simproc =
   126  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   127 
   128 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   129   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   130 
   131 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   132   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   133 
   134 fun eq_prem (thm1, thm2) =
   135   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   136 
   137 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   138 
   139 fun mk_simproc (name, proc, lhs, id) =
   140   {name = name, proc = proc, lhs = lhs, id = id};
   141 
   142 
   143 (* datatype mss *)
   144 
   145 (*
   146   A "mss" contains data needed during conversion:
   147     rules: discrimination net of rewrite rules;
   148     congs: association list of congruence rules and
   149            a list of `weak' congruence constants.
   150            A congruence is `weak' if it avoids normalization of some argument.
   151     procs: discrimination net of simplification procedures
   152       (functions that prove rewrite rules on the fly);
   153     bounds: names of bound variables already used
   154       (for generating new names when rewriting under lambda abstractions);
   155     prems: current premises;
   156     mk_rews: mk: turns simplification thms into rewrite rules;
   157              mk_sym: turns == around; (needs Drule!)
   158              mk_eq_True: turns P into P == True - logic specific;
   159     termless: relation for ordered rewriting;
   160     depth: depth of conditional rewriting;
   161 *)
   162 
   163 datatype meta_simpset =
   164   Mss of {
   165     rules: rrule Net.net,
   166     congs: (string * cong) list * string list,
   167     procs: simproc Net.net,
   168     bounds: string list,
   169     prems: thm list,
   170     mk_rews: {mk: thm -> thm list,
   171               mk_sym: thm -> thm option,
   172               mk_eq_True: thm -> thm option},
   173     termless: term * term -> bool,
   174     depth: int};
   175 
   176 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   177   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   178        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   179 
   180 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   181   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   182 
   183 val empty_mss =
   184   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   185   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   186 
   187 fun clear_mss (Mss {mk_rews, termless, ...}) =
   188   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   189 
   190 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   191   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   192 
   193 
   194 
   195 (** simpset operations **)
   196 
   197 (* term variables *)
   198 
   199 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   200 fun term_varnames t = add_term_varnames ([], t);
   201 
   202 
   203 (* dest_mss *)
   204 
   205 fun dest_mss (Mss {rules, congs, procs, ...}) =
   206   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   207    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   208    procs =
   209      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   210      |> partition_eq eq_snd
   211      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   212      |> Library.sort_wrt #1};
   213 
   214 
   215 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   216 
   217 fun merge_mss
   218  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   219        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   220   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   221        bounds = bounds2, prems = prems2, ...}) =
   222       mk_mss
   223        (Net.merge (rules1, rules2, eq_rrule),
   224         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   225         merge_lists weak1 weak2),
   226         Net.merge (procs1, procs2, eq_simproc),
   227         merge_lists bounds1 bounds2,
   228         gen_merge_lists eq_prem prems1 prems2,
   229         mk_rews, termless, depth);
   230 
   231 
   232 (* add_simps *)
   233 
   234 fun mk_rrule2{thm, name, lhs, elhs, perm} =
   235   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   236   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   237 
   238 fun insert_rrule quiet (mss as Mss {rules,...},
   239                  rrule as {thm,name,lhs,elhs,perm}) =
   240   (trace_named_thm "Adding rewrite rule" (thm, name);
   241    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   242        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   243    in upd_rules(mss,rules') end
   244    handle Net.INSERT => if quiet then mss else
   245      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   246 
   247 fun vperm (Var _, Var _) = true
   248   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   249   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   250   | vperm (t, u) = (t = u);
   251 
   252 fun var_perm (t, u) =
   253   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   254 
   255 (* FIXME: it seems that the conditions on extra variables are too liberal if
   256 prems are nonempty: does solving the prems really guarantee instantiation of
   257 all its Vars? Better: a dynamic check each time a rule is applied.
   258 *)
   259 fun rewrite_rule_extra_vars prems elhs erhs =
   260   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   261   orelse
   262   not ((term_tvars erhs) subset
   263        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   264 
   265 (*Simple test for looping rewrite rules and stupid orientations*)
   266 fun reorient sign prems lhs rhs =
   267    rewrite_rule_extra_vars prems lhs rhs
   268   orelse
   269    is_Var (head_of lhs)
   270   orelse
   271    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   272   orelse
   273    (null prems andalso
   274     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   275     (*the condition "null prems" is necessary because conditional rewrites
   276       with extra variables in the conditions may terminate although
   277       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   278   orelse
   279    (is_Const lhs andalso not(is_Const rhs))
   280 
   281 fun decomp_simp thm =
   282   let val {sign, prop, ...} = rep_thm thm;
   283       val prems = Logic.strip_imp_prems prop;
   284       val concl = Drule.strip_imp_concl (cprop_of thm);
   285       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   286         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   287       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   288       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   289       val erhs = Pattern.eta_contract (term_of rhs);
   290       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   291                  andalso not (is_Var (term_of elhs))
   292   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   293 
   294 fun decomp_simp' thm =
   295   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   296     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   297     else (lhs, rhs)
   298   end;
   299 
   300 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
   301   case mk_eq_True thm of
   302     None => []
   303   | Some eq_True =>
   304       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   305       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   306 
   307 (* create the rewrite rule and possibly also the ==True variant,
   308    in case there are extra vars on the rhs *)
   309 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   310   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   311   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   312         (term_tvars rhs) subset (term_tvars lhs)
   313      then [rrule]
   314      else mk_eq_True mss (thm2, name) @ [rrule]
   315   end;
   316 
   317 fun mk_rrule mss (thm, name) =
   318   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   319   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   320      (* weak test for loops: *)
   321      if rewrite_rule_extra_vars prems lhs rhs orelse
   322         is_Var (term_of elhs)
   323      then mk_eq_True mss (thm, name)
   324      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   325   end;
   326 
   327 fun orient_rrule mss (thm, name) =
   328   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   329   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   330      else if reorient sign prems lhs rhs
   331           then if reorient sign prems rhs lhs
   332                then mk_eq_True mss (thm, name)
   333                else let val Mss{mk_rews={mk_sym,...},...} = mss
   334                     in case mk_sym thm of
   335                          None => []
   336                        | Some thm' =>
   337                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   338                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   339                     end
   340           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   341   end;
   342 
   343 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
   344   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   345 
   346 fun orient_comb_simps comb mk_rrule (mss,thms) =
   347   let val rews = extract_rews(mss,thms)
   348       val rrules = flat (map mk_rrule rews)
   349   in foldl comb (mss,rrules) end
   350 
   351 (* Add rewrite rules explicitly; do not reorient! *)
   352 fun add_simps(mss,thms) =
   353   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   354 
   355 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
   356   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   357 
   358 fun extract_safe_rrules(mss,thm) =
   359   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   360 
   361 (* del_simps *)
   362 
   363 fun del_rrule(mss as Mss {rules,...},
   364               rrule as {thm, elhs, ...}) =
   365   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   366    handle Net.DELETE =>
   367      (prthm true "Rewrite rule not in simpset:" thm; mss));
   368 
   369 fun del_simps(mss,thms) =
   370   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   371 
   372 
   373 (* add_congs *)
   374 
   375 fun is_full_cong_prems [] varpairs = null varpairs
   376   | is_full_cong_prems (p::prems) varpairs =
   377     (case Logic.strip_assums_concl p of
   378        Const("==",_) $ lhs $ rhs =>
   379          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   380          in is_Var x  andalso  forall is_Bound xs  andalso
   381             null(findrep(xs))  andalso xs=ys andalso
   382             (x,y) mem varpairs andalso
   383             is_full_cong_prems prems (varpairs\(x,y))
   384          end
   385      | _ => false);
   386 
   387 fun is_full_cong thm =
   388 let val prems = prems_of thm
   389     and concl = concl_of thm
   390     val (lhs,rhs) = Logic.dest_equals concl
   391     val (f,xs) = strip_comb lhs
   392     and (g,ys) = strip_comb rhs
   393 in
   394   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   395   is_full_cong_prems prems (xs ~~ ys)
   396 end
   397 
   398 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   399   let
   400     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   401       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   402 (*   val lhs = Pattern.eta_contract lhs; *)
   403     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   404       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   405     val (alist,weak) = congs
   406     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   407            ("Overwriting congruence rule for " ^ quote a);
   408     val weak2 = if is_full_cong thm then weak else a::weak
   409   in
   410     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   411   end;
   412 
   413 val (op add_congs) = foldl add_cong;
   414 
   415 
   416 (* del_congs *)
   417 
   418 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   419   let
   420     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   421       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   422 (*   val lhs = Pattern.eta_contract lhs; *)
   423     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   424       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   425     val (alist,_) = congs
   426     val alist2 = filter (fn (x,_)=> x<>a) alist
   427     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   428                                               else Some a)
   429                    alist2
   430   in
   431     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   432   end;
   433 
   434 val (op del_congs) = foldl del_cong;
   435 
   436 
   437 (* add_simprocs *)
   438 
   439 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   440     (name, lhs, proc, id)) =
   441   let val {sign, t, ...} = rep_cterm lhs
   442   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   443       sign t;
   444     mk_mss (rules, congs,
   445       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   446         handle Net.INSERT =>
   447             (warning ("Ignoring duplicate simplification procedure \""
   448                       ^ name ^ "\"");
   449              procs),
   450         bounds, prems, mk_rews, termless,depth))
   451   end;
   452 
   453 fun add_simproc (mss, (name, lhss, proc, id)) =
   454   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   455 
   456 val add_simprocs = foldl add_simproc;
   457 
   458 
   459 (* del_simprocs *)
   460 
   461 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   462     (name, lhs, proc, id)) =
   463   mk_mss (rules, congs,
   464     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   465       handle Net.DELETE =>
   466           (warning ("Simplification procedure \"" ^ name ^
   467                        "\" not in simpset"); procs),
   468       bounds, prems, mk_rews, termless, depth);
   469 
   470 fun del_simproc (mss, (name, lhss, proc, id)) =
   471   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   472 
   473 val del_simprocs = foldl del_simproc;
   474 
   475 
   476 (* prems *)
   477 
   478 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   479   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   480 
   481 fun prems_of_mss (Mss {prems, ...}) = prems;
   482 
   483 
   484 (* mk_rews *)
   485 
   486 fun set_mk_rews
   487   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   488     mk_mss (rules, congs, procs, bounds, prems,
   489             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   490             termless, depth);
   491 
   492 fun set_mk_sym
   493   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   494     mk_mss (rules, congs, procs, bounds, prems,
   495             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   496             termless,depth);
   497 
   498 fun set_mk_eq_True
   499   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   500     mk_mss (rules, congs, procs, bounds, prems,
   501             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   502             termless,depth);
   503 
   504 (* termless *)
   505 
   506 fun set_termless
   507   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   508     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   509 
   510 
   511 
   512 (** rewriting **)
   513 
   514 (*
   515   Uses conversions, see:
   516     L C Paulson, A higher-order implementation of rewriting,
   517     Science of Computer Programming 3 (1983), pages 119-149.
   518 *)
   519 
   520 val dest_eq = Drule.dest_equals o cprop_of;
   521 val lhs_of = fst o dest_eq;
   522 val rhs_of = snd o dest_eq;
   523 
   524 fun beta_eta_conversion t =
   525   let val thm = beta_conversion true t;
   526   in transitive thm (eta_conversion (rhs_of thm)) end;
   527 
   528 fun check_conv msg thm thm' =
   529   let
   530     val thm'' = transitive thm (transitive
   531       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   532   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   533   handle THM _ =>
   534     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   535     in
   536       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   537        trace_term false "Should have proved:" sign prop0;
   538        None)
   539     end;
   540 
   541 
   542 (* mk_procrule *)
   543 
   544 fun mk_procrule thm =
   545   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   546   in if rewrite_rule_extra_vars prems lhs rhs
   547      then (prthm true "Extra vars on rhs:" thm; [])
   548      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   549   end;
   550 
   551 
   552 (* conversion to apply the meta simpset to a term *)
   553 
   554 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   555    normalized terms by carrying around the rhs of the rewrite rule just
   556    applied. This is called the `skeleton'. It is decomposed in parallel
   557    with the term. Once a Var is encountered, the corresponding term is
   558    already in normal form.
   559    skel0 is a dummy skeleton that is to enforce complete normalization.
   560 *)
   561 val skel0 = Bound 0;
   562 
   563 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   564    The latter may happen iff there are weak congruence rules for constants
   565    in the lhs.
   566 *)
   567 fun uncond_skel((_,weak),(lhs,rhs)) =
   568   if null weak then rhs (* optimization *)
   569   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   570        else rhs;
   571 
   572 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   573    Otherwise those vars may become instantiated with unnormalized terms
   574    while the premises are solved.
   575 *)
   576 fun cond_skel(args as (congs,(lhs,rhs))) =
   577   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   578   else skel0;
   579 
   580 (*
   581   we try in order:
   582     (1) beta reduction
   583     (2) unconditional rewrite rules
   584     (3) conditional rewrite rules
   585     (4) simplification procedures
   586 
   587   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   588 
   589 *)
   590 
(* rewritec (prover, signt, maxt) mss t: attempt a single rewrite step at the
   root of t.  Returns Some (thm, skel), where thm : t == t' and skel is the
   skeleton guiding further normalization of t' (see skel0, uncond_skel and
   cond_skel), or None if nothing applies.
   - prover is called as  prover (incr_depth mss) thm'  to discharge the
     premises of an instance thm' of a conditional rule;
   - rules whose signature is not a subsignature of signt are skipped, with
     a warning;
   - unless maxt = ~1, the rule and its lhs get their indexes incremented by
     maxt+1 before matching (NOTE(review): presumably maxt is the maximal
     Var index of t, so that rule Vars are renamed apart; confirm at the
     call sites). *)
   591 fun rewritec (prover, signt, maxt)
   592              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   593   let
       (* the rest works on the eta-contracted form of t;
          eta_thm : t == eta_t' *)
   594     val eta_thm = Thm.eta_conversion t;
   595     val eta_t' = rhs_of eta_thm;
   596     val eta_t = term_of eta_t';
   597     val tsigt = Sign.tsig_of signt;
       (* rew: try one rewrite rule.  Raises Pattern.MATCH when the rule does
          not apply (foreign signature, or matching fails); returns None when
          it matches but must not, or could not, be used. *)
   598     fun rew {thm, name, lhs, elhs, fo, perm} =
   599       let
   600         val {sign, prop, maxidx, ...} = rep_thm thm;
   601         val _ = if Sign.subsig (sign, signt) then ()
   602                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   603                       raise Pattern.MATCH);
           (* shift the indexes of the rule and its lhs by maxt+1 unless maxt = ~1 *)
   604         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   605           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
           (* fo selects first-order matching (see the remarks on type rrule) *)
   606         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   607                           else Thm.cterm_match (elhs', eta_t');
           (* thm': the rule instance, with bound variable names taken from
              eta_t where lhs has binders *)
   608         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   609         val prop' = #prop (rep_thm thm');
   610         val unconditional = (Logic.count_prems (prop',0) = 0);
   611         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   612       in
           (* a permutative rule is used only if it makes the term smaller
              w.r.t. the simpset's termless ordering *)
   613         if perm andalso not (termless (rhs', lhs'))
   614         then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
   615               trace_thm "Term does not become smaller:" thm'; None)
   616         else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
   617            if unconditional
   618            then
   619              (trace_thm "Rewriting:" thm';
   620               let val lr = Logic.dest_equals prop;
                      (* NOTE(review): non-exhaustive binding -- raises Bind
                         if check_conv rejects an unconditional instance *)
   621                   val Some thm'' = check_conv false eta_thm thm'
   622               in Some (thm'', uncond_skel (congs, lr)) end)
   623            else
             (* conditional rule: let the prover establish the premises,
                one level deeper in conditional rewriting *)
   624              (trace_thm "Trying to rewrite:" thm';
   625               case prover (incr_depth mss) thm' of
   626                 None       => (trace_thm "FAILED" thm'; None)
   627               | Some thm2 =>
   628                   (case check_conv true eta_thm thm2 of
   629                      None => None |
   630                      Some thm2' =>
   631                        let val concl = Logic.strip_imp_concl prop
   632                            val lr = Logic.dest_equals concl
   633                        in Some (thm2', cond_skel (congs, lr)) end)))
   634       end
   635 
       (* rews: result of the first rule in the list that yields one.  A
          Pattern.MATCH escaping from rew -- including one raised inside
          prover -- counts as "rule not applicable". *)
   636     fun rews [] = None
   637       | rews (rrule :: rrules) =
   638           let val opt = rew rrule handle Pattern.MATCH => None
   639           in case opt of None => rews rrules | some => some end;
   640 
       (* sort_rrules: rules whose proposition is a bare meta-equality (no
          premises) come before the others; because of the accumulating
          recursion, each group ends up in reverse of its input order. *)
   641     fun sort_rrules rrs = let
   642       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
   643                                       Const("==",_) $ _ $ _ => true
   644                                       | _                   => false
   645       fun sort []        (re1,re2) = re1 @ re2
   646         | sort (rr::rrs) (re1,re2) = if is_simple rr
   647                                      then sort rrs (rr::re1,re2)
   648                                      else sort rrs (re1,rr::re2)
   649     in sort rrs ([],[]) end
   650 
       (* proc_rews: try each simproc whose lhs matches.  Exceptions escaping
          from a procedure are wrapped by transform_failure with
          SIMPROC_FAIL name.  A produced theorem is turned into a rule by
          mk_procrule and applied via rews; if it does not match, this is
          reported (as a warning, when trace_simp is set) and the next
          procedure is tried.
          NOTE(review): the guard matches lhs against t, while the net lookup
          and the procedure call use eta_t -- confirm this is intended. *)
   651     fun proc_rews ([]:simproc list) = None
   652       | proc_rews ({name, proc, lhs, ...} :: ps) =
   653           if Pattern.matches tsigt (term_of lhs, term_of t) then
   654             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   655              case transform_failure (curry SIMPROC_FAIL name)
   656                  (fn () => proc signt prems eta_t) () of
   657                None => (debug false "FAILED"; proc_rews ps)
   658              | Some raw_thm =>
   659                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   660                   (case rews (mk_procrule raw_thm) of
   661                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
   662                       " -- does not match") t; proc_rews ps)
   663                   | some => some)))
   664           else proc_rews ps;
     (* Strategy: a beta-redex (after eta-contraction) is beta-reduced, with
        skeleton skel0; otherwise the rewrite rules from the net are tried,
        and then the simprocs. *)
   665   in case eta_t of
   666        Abs _ $ _ => Some (transitive eta_thm
   667          (beta_conversion false eta_t'), skel0)
   668      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   669                None => proc_rews (Net.match_term procs eta_t)
   670              | some => some)
   671   end;
   672 
   673 
   674 (* conversion to apply a congruence rule to a term *)
   675 
   676 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   677   let val {sign, ...} = rep_thm cong
   678       val _ = if Sign.subsig (sign, signt) then ()
   679                  else error("Congruence rule from different theory")
   680       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   681       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   682       val insts = Thm.cterm_match (rlhs, t)
   683       (* Pattern.match can raise Pattern.MATCH;
   684          is handled when congc is called *)
   685       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   686       val unit = trace_thm "Applying congruence rule:" thm';
   687       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   688   in case prover thm' of
   689        None => err ("Could not prove", thm')
   690      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   691           None => err ("Should not have proved", thm2)
   692         | Some thm2' =>
   693             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   694             then None else Some thm2')
   695   end;
   696 
   697 val (cA, (cB, cC)) =
   698   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   699 
   (* Compose two optional equations, where None means "no change": a real
      transitive step is only needed when both equations are present. *)
   fun transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2)
     | transitive1 (Some thm1) None = Some thm1
     | transitive1 None opt2 = opt2
   704 
   (* transitive1 whose first equation is known to be present. *)
   fun transitive2 thm opt2 = transitive1 (Some thm) opt2;
   (* transitive1 whose second equation is known to be present. *)
   fun transitive3 opt1 thm = transitive1 opt1 (Some thm);
   707 
   (* imp_cong' e e': combine two equations through refl_implies with two
      combination steps (presumably giving (A ==> B) == (A' ==> B') from
      e : A == A' and e' : B == B').  The first step is performed as soon as
      e is supplied, exactly as in a partial application. *)
   fun imp_cong' e =
     let val implies_e = combination refl_implies e
     in combination implies_e end;
   709 
   710 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
       (* Bottom-up simplifier behind rewrite_cterm.  The result, try_botc mss t,
          is a theorem t == t' (reflexive t when nothing changed); internally each
          function returns None for "no change" and Some eq otherwise.
          Mode flags: mutsimp selects mutual simplification of the premises of an
          implication (mut_impc0 / mut_impc); simprem (simplify the premise) and
          useprem (use the premise as a rewrite rule for the conclusion) are only
          consulted by the legacy nonmut_impc. *)
   711   let
         (* botc skel mss t: simplify the subterms of t (subc), then try a rewrite
            at the top with rewritec; if one fires, simplify its result again,
            continuing with the skeleton skel2 that rewritec returns.  A Var
            skeleton means "leave this term alone" (None). *)
   712     fun botc skel mss t =
   713           if is_Var skel then None
   714           else
   715           (case subc skel mss t of
   716              some as Some thm1 =>
   717                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   718                   Some (thm2, skel2) =>
   719                     transitive2 (transitive thm1 thm2)
   720                       (botc skel2 mss (rhs_of thm2))
   721                 | None => some)
   722            | None =>
   723                (case rewritec (prover, sign, maxidx) mss t of
   724                   Some (thm2, skel2) => transitive2 thm2
   725                     (botc skel2 mss (rhs_of thm2))
   726                 | None => None))
   727 
         (* Entry point: botc with the default skeleton skel0; reflexive t on no change. *)
   728     and try_botc mss t =
   729           (case botc skel0 mss t of
   730              Some trec1 => trec1 | None => (reflexive t))
   731 
         (* subc skel mss t0: simplify the immediate subterms of t0 only.
            - Abs: the bound name is replaced by a variant of bounds, the body is
              simplified with the extended bounds, then abstracted again.
            - (op ==>) A B: delegated to impc.
            - beta-redex: beta-reduce, then treat the subterms of the result.
            - other applications: if the head is a constant with a congruence
              rule in fst congs, congc is used (appc is the fallback when the
              rule's lhs does not match, i.e. Pattern.MATCH); otherwise appc
              simplifies function and argument separately. *)
   732     and subc skel
   733           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   734        (case term_of t0 of
   735            Abs (a, T, t) =>
   736              let val b = variant bounds a
   737                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   738                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   739                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   740              in case botc skel' mss' t' of
   741                   Some thm => Some (abstract_rule a v thm)
   742                 | None => None
   743              end
   744          | t $ _ => (case t of
   745              Const ("==>", _) $ _  => impc t0 mss
   746            | Abs _ =>
   747                let val thm = beta_conversion false t0
   748                in case subc skel0 mss (rhs_of thm) of
   749                     None => Some thm
   750                   | Some thm' => Some (transitive thm thm')
   751                end
   752            | _  =>
                 (* appc (): simplify function and argument separately, splitting
                    skel the same way when it is an application (skel0 otherwise). *)
   753                let fun appc () =
   754                      let
   755                        val (tskel, uskel) = case skel of
   756                            tskel $ uskel => (tskel, uskel)
   757                          | _ => (skel0, skel0);
   758                        val (ct, cu) = Thm.dest_comb t0
   759                      in
   760                      (case botc tskel mss ct of
   761                         Some thm1 =>
   762                           (case botc uskel mss cu of
   763                              Some thm2 => Some (combination thm1 thm2)
   764                            | None => Some (combination thm1 (reflexive cu)))
   765                       | None =>
   766                           (case botc uskel mss cu of
   767                              Some thm1 => Some (combination (reflexive ct) thm1)
   768                            | None => None))
   769                      end
   770                    val (h, ts) = strip_comb t
   771                in case h of
   772                     Const(a, _) =>
   773                       (case assoc_string (fst congs, a) of
   774                          None => appc ()
   775                        | Some cong =>
   776 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   777    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   778                           (let
   779                              val thm = congc (prover mss, sign, maxidx) cong t0;
   780                              val t = if_none (apsome rhs_of thm) t0;
   781                              val (cl, cr) = Thm.dest_comb t
   782                              val dVar = Var(("", 0), dummyT)
                               (* skeleton h ?_ ... ?_ : its Var arguments are not
                                  descended into, so botc only re-examines the
                                  partial applications of h at the top *)
   783                              val skel =
   784                                list_comb (h, replicate (length ts) dVar)
   785                            in case botc skel mss cl of
   786                                 None => thm
   787                               | Some thm' => transitive3 thm
   788                                   (combination thm' (reflexive cr))
   789                            end handle TERM _ => error "congc result"
   790                                     | Pattern.MATCH => appc ()))
   791                   | _ => appc ()
   792                end)
   793          | _ => None)
   794 
         (* Simplify an implication; mutsimp chooses the mutual or the legacy strategy. *)
   795     and impc ct mss =
   796       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   797 
         (* rules_of_prem mss prem: assume prem and extract safe rewrite rules from it,
            returning (rules, Some asm).  A premise containing schematic (type)
            variables (maxidx <> ~1) is not used: a trace message is printed and
            ([], None) is returned. *)
   798     and rules_of_prem mss prem =
   799       if maxidx_of_term (term_of prem) <> ~1
   800       then (trace_cterm true
   801         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   802       else
   803         let val asm = assume prem
   804         in (extract_safe_rrules (mss, asm), Some asm) end
   805 
         (* Add all rule lists in rrss and all present assumptions in asms to mss. *)
   806     and add_rrules (rrss, asms) mss =
   807       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   808 
         (* disch r (prem, eq): discharge hypothesis prem from eq : lhs == rhs by
            instantiating Drule.imp_cong at cA, cB, cC (presumably giving
            (prem ==> lhs) == (prem ==> rhs)).  If r is true, lhs and rhs are
            implications themselves, and two swap_prems_eq instances are composed on
            both sides, apparently moving prem behind their first premise
            (assumes swap_prems_eq is (A ==> B ==> C) == (B ==> A ==> C) -- TODO confirm). *)
   809     and disch r (prem, eq) =
   810       let
   811         val (lhs, rhs) = dest_eq eq;
   812         val eq' = implies_elim (Thm.instantiate
   813           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   814           (implies_intr prem eq)
   815       in if not r then eq' else
   816         let
   817           val (prem', concl) = dest_implies lhs;
   818           val (prem'', _) = dest_implies rhs
   819         in transitive (transitive
   820           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   821              Drule.swap_prems_eq) eq')
   822           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   823              Drule.swap_prems_eq)
   824         end
   825       end
   826 
         (* rebuild prems concl rrss asms mss eq: reassemble the implication from the
            innermost premise outwards (prems, rrss, asms are in reverse order),
            discharging each premise from eq with disch false.  At every step the
            partial implication is tried with rewritec under the rules of the
            remaining outer premises; if a rewrite fires, the outer premises are
            discharged from the combined equation and simplification restarts on
            the rewritten term via mut_impc0. *)
   827     and rebuild [] _ _ _ _ eq = eq
   828       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   829           let
   830             val mss' = add_rrules (rev rrss, rev asms) mss;
   831             val concl' =
   832               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   833             val dprem = apsome (curry (disch false) prem)
   834           in case rewritec (prover, sign, maxidx) mss' concl' of
   835               None => rebuild prems concl' rrss asms mss (dprem eq)
   836             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   837                   (the (transitive3 (dprem eq) eq'), prems))
   838                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   839           end
   840           
         (* mut_impc0 prems concl rrss asms mss: split off the premises of concl,
            compute rules and assumptions for each, append them to the given
            prems/rrss/asms, and start mut_impc with changed = ~1 and k = ~1. *)
   841     and mut_impc0 prems concl rrss asms mss =
   842       let
   843         val prems' = strip_imp_prems concl;
   844         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   845       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   846         (asms @ asms') [] [] [] [] mss ~1 ~1
   847       end
   848  
         (* All premises processed; prems', rrss', asms' and eqns are in reverse
            order.  The per-premise equations in eqns are discharged (disch false)
            and composed.  If the last change happened at a position changed > 0,
            the premises before it are re-simplified in another round with
            k = changed.  Otherwise the conclusion is simplified with all premises
            as rewrite rules and the implication is rebuilt. *)
   849     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   850         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   851             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   852           (if changed > 0 then
   853              mut_impc (rev prems') concl (rev rrss') (rev asms')
   854                [] [] [] [] mss ~1 changed
   855            else rebuild prems' concl rrss' asms' mss
   856              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   857 
         (* Simplify the next premise with the rules of all other premises (already
            processed ones from rrss', pending ones from rrss).  k counts the
            premises still to be re-simplified in this round: at k = 0 the premise
            is skipped, and a negative k never reaches 0, so every premise is tried.
            When prem changes to prem', its rules are recomputed, changed records
            its position (length prems'), and k is reset to ~1.  The equation may
            depend on later premises as hypotheses; i counts how many of them
            (take (i, prems)) are discharged again, using disch true. *)
   858       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   859           prems' rrss' asms' eqns mss changed k =
   860         case (if k = 0 then None else botc skel0 (add_rrules
   861           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
   862             None => mut_impc prems concl rrss asms (prem :: prems')
   863               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   864               (if k = 0 then 0 else k - 1)
   865           | Some eqn =>
   866             let
   867               val prem' = rhs_of eqn;
   868               val tprems = map term_of prems;
   869               val i = 1 + foldl Int.max (~1, map (fn p =>
   870                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   871               val (rrs', asm') = rules_of_prem mss prem'
   872             in mut_impc prems concl rrss asms (prem' :: prems')
   873               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   874                 (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
   875                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   876             end
   877 
   878      (* legacy code - only for backwards compatibility *)
          (* Simplify the premise (only if simprem), then the conclusion, with the
             (simplified) premise added as rewrite rules only if useprem; the two
             equations are combined with imp_cong' and disch false. *)
   879      and nonmut_impc ct mss =
   880        let val (prem, conc) = dest_implies ct;
   881            val thm1 = if simprem then botc skel0 mss prem else None;
   882            val prem1 = if_none (apsome rhs_of thm1) prem;
   883            val mss1 = if not useprem then mss else add_rrules
   884              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   885        in (case botc skel0 mss1 conc of
   886            None => (case thm1 of
   887                None => None
   888              | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
   889          | Some thm2 =>
   890            let val thm2' = disch false (prem1, thm2)
   891            in (case thm1 of
   892                None => Some thm2'
   893              | Some thm1' =>
   894                  Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
   895            end)
   896        end
   897 
   898  in try_botc end;
   899 
   900 
   901 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   902 
   903 (*
   904   Parameters:
   905     mode = (simplify A,
   906             use A in simplifying B,
   907             use prems of B (if B is again a meta-impl.) to simplify A)
   908            when simplifying A ==> B
   909     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   910     prover: how to solve premises in conditional rewrites and congruences
   911 *)
   912 
   (* Simplify ct with mss and return the theorem ct == ct'.  mode is the
      (simprem, useprem, mutsimp) triple passed on to bottomc; simp_depth is
      first set to the depth stored in mss.  A THM exception escaping the
      simplifier is reported as an error that prints the offending theorems. *)
   fun rewrite_cterm mode prover mss ct =
     let
       val Mss {depth, ...} = mss
       val {sign, maxidx, ...} = rep_cterm ct
     in
       simp_depth := depth;
       bottomc (mode, prover, sign, maxidx) mss ct
     end
     handle THM (s, _, thms) =>
       error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
         Pretty.string_of (Display.pretty_thms thms));
   922 
   (* In [A1,...,An] ==> B, rewrite with cv exactly those premises Ai for which
      pred i holds (premises are numbered from 1); everything else is kept by
      reflexivity.  Intended for rewrite_goals_tac.
      Only the failure of Drule.dest_implies (ct is not an implication) ends
      the recursion.  A TERM exception raised by cv itself now propagates,
      instead of silently dropping the rewriting of this and all later
      premises. *)
   fun goals_conv pred cv =
     let
       fun gconv i ct =
         (case (Some (Drule.dest_implies ct) handle TERM _ => None) of
            None => reflexive ct
          | Some (A, B) =>
              imp_cong' (if pred i then cv A else reflexive A) (gconv (i + 1) B))
     in gconv 1 end;
   930 
   (* Rewrite A in !!x1,...,xn. A: descend through the outermost "all"
      binders, fixing each bound variable with Thm.dest_abs, and apply cv to
      the body.  When ct is not a combination, or not an application of
      "all" to an abstraction, cv is applied to ct directly.
      The TERM handler covers only Thm.dest_comb.  Consequently:
      - cv runs at most once per term;
      - a TERM raised by cv propagates instead of triggering a second
        cv call on the same or the enclosing term. *)
   fun forall_conv cv ct =
     (case (Some (Thm.dest_comb ct) handle TERM _ => None) of
        None => cv ct
      | Some (p as (ct1, ct2)) =>
          (case pairself term_of p of
             (Const ("all", _), Abs (s, _, _)) =>
               let val (v, ct') = Thm.dest_abs (Some "@") ct2
               in Thm.combination (Thm.reflexive ct1)
                    (Thm.abstract_rule s v (forall_conv cv ct'))
               end
           | _ => cv ct));
   942 
   (* Transform theorem th with conversion cv: cv produces an equation from
      the proposition of th, and equal_elim transports th along it. *)
   fun fconv_rule cv th =
     let val eqn = cv (cprop_of th)
     in equal_elim eqn th end;
   945 
   (* Conversion rewriting a cterm with the equations thms in mode
      (full, false, false); with no equations it is plain reflexivity.
      The simpset is built once, when the three arguments are supplied. *)
   fun rewrite_aux prover full thms =
     if null thms then Thm.reflexive
     else
       let val mss = mss_of thms
       in rewrite_cterm (full, false, false) prover mss end;
   949 
   (* Rule rewriting a theorem with the equations thms in mode
      (full, false, false); with no equations the theorem is returned as is. *)
   fun simplify_aux prover full thms =
     if null thms then I
     else fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   954 
   (* Rewrite theorem th with mss under the given mode and prover. *)
   fun rewrite_thm mode prover mss th = fconv_rule (rewrite_cterm mode prover mss) th;
   956 
   (* Rewrite every subgoal (premise) of the proof state th with the equations
      thms in mode (true, true, false); with no equations th is unchanged. *)
   fun rewrite_goals_rule_aux prover thms th =
     if null thms then th
     else
       let val cv = rewrite_cterm (true, true, false) prover (mss_of thms)
       in fconv_rule (goals_conv (K true) cv) th end;
   962 
   (* Rewrite only subgoal i of the proof state thm.  Raises THM unless
      1 <= i <= nprems_of thm. *)
   fun rewrite_goal_rule mode prover mss i thm =
     if i <= 0 orelse i > nprems_of thm
     then raise THM ("rewrite_goal_rule", i, [thm])
     else fconv_rule (goals_conv (fn j => j = i) (rewrite_cterm mode prover mss)) thm;
   968 
   969 
   (* Simple term rewriting without proofs: each rule is decomposed with
      decomp_simp' and passed, together with procs, to Pattern.rewrite_term
      under the type signature of sg. *)
   fun rewrite_term sg rules procs =
     let
       val tsig = Sign.tsig_of sg
       val rrules = map decomp_simp' rules
     in Pattern.rewrite_term tsig rrules procs end;
   973 
   974 end;
   975 
   976 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
       (* Restrict MetaSimplifier to the BASIC_META_SIMPLIFIER signature (the
          trace_simp and debug_simp flags) and make those names visible at top level. *)
   977 open BasicMetaSimplifier;