src/Pure/meta_simplifier.ML
author wenzelm
Thu Aug 08 23:51:24 2002 +0200 (2002-08-08)
changeset 13486 54464ea94d6f
parent 13458 a73823f70159
child 13569 69a6b3aa0f38
permissions -rw-r--r--
exception SIMPROC_FAIL: solid error reporting of simprocs;
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
    11   val trace_simp: bool ref
    12   val debug_simp: bool ref
    13 end;
    14 
    15 signature META_SIMPLIFIER =
    16 sig
    17   include BASIC_META_SIMPLIFIER
    18   exception SIMPLIFIER of string * thm
    19   exception SIMPROC_FAIL of string * exn
    20   type meta_simpset
    21   val dest_mss          : meta_simpset ->
    22     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    23   val empty_mss         : meta_simpset
    24   val clear_mss         : meta_simpset -> meta_simpset
    25   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    26   val add_simps         : meta_simpset * thm list -> meta_simpset
    27   val del_simps         : meta_simpset * thm list -> meta_simpset
    28   val mss_of            : thm list -> meta_simpset
    29   val add_congs         : meta_simpset * thm list -> meta_simpset
    30   val del_congs         : meta_simpset * thm list -> meta_simpset
    31   val add_simprocs      : meta_simpset *
    32     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    33       -> meta_simpset
    34   val del_simprocs      : meta_simpset *
    35     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    36       -> meta_simpset
    37   val add_prems         : meta_simpset * thm list -> meta_simpset
    38   val prems_of_mss      : meta_simpset -> thm list
    39   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    40   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    42   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    43   val beta_eta_conversion: cterm -> thm
    44   val rewrite_cterm: bool * bool * bool ->
    45     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    46   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    47   val forall_conv       : (cterm -> thm) -> cterm -> thm
    48   val fconv_rule        : (cterm -> thm) -> thm -> thm
    49   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    50   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    51   val rewrite_thm       : bool * bool * bool
    52                           -> (meta_simpset -> thm -> thm option)
    53                           -> meta_simpset -> thm -> thm
    54   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    55   val rewrite_goal_rule : bool* bool * bool
    56                           -> (meta_simpset -> thm -> thm option)
    57                           -> meta_simpset -> int -> thm -> thm
    58   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
    59 end;
    60 
    61 structure MetaSimplifier : META_SIMPLIFIER =
    62 struct
    63 
    64 (** diagnostics **)
    65 
    66 exception SIMPLIFIER of string * thm;
    67 exception SIMPROC_FAIL of string * exn;
    68 
    69 val simp_depth = ref 0;
    70 
    71 local
    72 
    73 fun println a =
    74   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    75 
    76 fun prnt warn a = if warn then warning a else println a;
    77 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    78 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    79 
    80 in
    81 
    82 fun prthm warn a = prctm warn a o Thm.cprop_of;
    83 
    84 val trace_simp = ref false;
    85 val debug_simp = ref false;
    86 
    87 fun trace warn a = if !trace_simp then prnt warn a else ();
    88 fun debug warn a = if !debug_simp then prnt warn a else ();
    89 
    90 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    91 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    92 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    93 
    94 fun trace_thm warn a thm =
    95   let val {sign, prop, ...} = rep_thm thm
    96   in trace_term warn a sign prop end;
    97 
    98 end;
    99 
   100 
   101 (** meta simp sets **)
   102 
   103 (* basic components *)
   104 
   105 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
   106 (* thm: the rewrite rule
   107    lhs: the left-hand side
   108    elhs: the etac-contracted lhs.
   109    fo:  use first-order matching
   110    perm: the rewrite rule is permutative
   111 Remarks:
   112   - elhs is used for matching,
   113     lhs only for preservation of bound variable names.
   114   - fo is set iff
   115     either elhs is first-order (no Var is applied),
   116            in which case fo-matching is complete,
   117     or elhs is not a pattern,
   118        in which case there is nothing better to do.
   119 *)
   120 type cong = {thm: thm, lhs: cterm};
   121 type simproc =
   122  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   123 
   124 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   125   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   126 
   127 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   128   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   129 
   130 fun eq_prem (thm1, thm2) =
   131   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   132 
   133 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   134 
   135 fun mk_simproc (name, proc, lhs, id) =
   136   {name = name, proc = proc, lhs = lhs, id = id};
   137 
   138 
   139 (* datatype mss *)
   140 
   141 (*
   142   A "mss" contains data needed during conversion:
   143     rules: discrimination net of rewrite rules;
   144     congs: association list of congruence rules and
   145            a list of `weak' congruence constants.
   146            A congruence is `weak' if it avoids normalization of some argument.
   147     procs: discrimination net of simplification procedures
   148       (functions that prove rewrite rules on the fly);
   149     bounds: names of bound variables already used
   150       (for generating new names when rewriting under lambda abstractions);
   151     prems: current premises;
   152     mk_rews: mk: turns simplification thms into rewrite rules;
   153              mk_sym: turns == around; (needs Drule!)
   154              mk_eq_True: turns P into P == True - logic specific;
   155     termless: relation for ordered rewriting;
   156     depth: depth of conditional rewriting;
   157 *)
   158 
   159 datatype meta_simpset =
   160   Mss of {
   161     rules: rrule Net.net,
   162     congs: (string * cong) list * string list,
   163     procs: simproc Net.net,
   164     bounds: string list,
   165     prems: thm list,
   166     mk_rews: {mk: thm -> thm list,
   167               mk_sym: thm -> thm option,
   168               mk_eq_True: thm -> thm option},
   169     termless: term * term -> bool,
   170     depth: int};
   171 
   172 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   173   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   174        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   175 
   176 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   177   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   178 
   179 val empty_mss =
   180   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   181   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   182 
   183 fun clear_mss (Mss {mk_rews, termless, ...}) =
   184   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   185 
   186 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   187   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   188 
   189 
   190 
   191 (** simpset operations **)
   192 
   193 (* term variables *)
   194 
   195 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   196 fun term_varnames t = add_term_varnames ([], t);
   197 
   198 
   199 (* dest_mss *)
   200 
   201 fun dest_mss (Mss {rules, congs, procs, ...}) =
   202   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   203    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   204    procs =
   205      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   206      |> partition_eq eq_snd
   207      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   208      |> Library.sort_wrt #1};
   209 
   210 
   211 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   212 
   213 fun merge_mss
   214  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   215        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   216   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   217        bounds = bounds2, prems = prems2, ...}) =
   218       mk_mss
   219        (Net.merge (rules1, rules2, eq_rrule),
   220         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   221         merge_lists weak1 weak2),
   222         Net.merge (procs1, procs2, eq_simproc),
   223         merge_lists bounds1 bounds2,
   224         gen_merge_lists eq_prem prems1 prems2,
   225         mk_rews, termless, depth);
   226 
   227 
   228 (* add_simps *)
   229 
   230 fun mk_rrule2{thm,lhs,elhs,perm} =
   231   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   232   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   233 
   234 fun insert_rrule(mss as Mss {rules,...},
   235                  rrule as {thm,lhs,elhs,perm}) =
   236   (trace_thm false "Adding rewrite rule:" thm;
   237    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   238        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   239    in upd_rules(mss,rules') end
   240    handle Net.INSERT =>
   241      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   242 
   243 fun vperm (Var _, Var _) = true
   244   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   245   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   246   | vperm (t, u) = (t = u);
   247 
   248 fun var_perm (t, u) =
   249   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   250 
   251 (* FIXME: it seems that the conditions on extra variables are too liberal if
   252 prems are nonempty: does solving the prems really guarantee instantiation of
   253 all its Vars? Better: a dynamic check each time a rule is applied.
   254 *)
   255 fun rewrite_rule_extra_vars prems elhs erhs =
   256   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   257   orelse
   258   not ((term_tvars erhs) subset
   259        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   260 
   261 (*Simple test for looping rewrite rules and stupid orientations*)
   262 fun reorient sign prems lhs rhs =
   263    rewrite_rule_extra_vars prems lhs rhs
   264   orelse
   265    is_Var (head_of lhs)
   266   orelse
   267    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   268   orelse
   269    (null prems andalso
   270     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   271     (*the condition "null prems" is necessary because conditional rewrites
   272       with extra variables in the conditions may terminate although
   273       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   274   orelse
   275    (is_Const lhs andalso not(is_Const rhs))
   276 
   277 fun decomp_simp thm =
   278   let val {sign, prop, ...} = rep_thm thm;
   279       val prems = Logic.strip_imp_prems prop;
   280       val concl = Drule.strip_imp_concl (cprop_of thm);
   281       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   282         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   283       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   284       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   285       val erhs = Pattern.eta_contract (term_of rhs);
   286       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   287                  andalso not (is_Var (term_of elhs))
   288   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   289 
   290 fun decomp_simp' thm =
   291   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   292     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   293     else (lhs, rhs)
   294   end;
   295 
   296 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
   297   case mk_eq_True thm of
   298     None => []
   299   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   300                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   301 
   302 (* create the rewrite rule and possibly also the ==True variant,
   303    in case there are extra vars on the rhs *)
   304 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
   305   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   306   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   307         (term_tvars rhs) subset (term_tvars lhs)
   308      then [rrule]
   309      else mk_eq_True mss thm2 @ [rrule]
   310   end;
   311 
   312 fun mk_rrule mss thm =
   313   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   314   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   315      (* weak test for loops: *)
   316      if rewrite_rule_extra_vars prems lhs rhs orelse
   317         is_Var (term_of elhs)
   318      then mk_eq_True mss thm
   319      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   320   end;
   321 
   322 fun orient_rrule mss thm =
   323   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   324   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   325      else if reorient sign prems lhs rhs
   326           then if reorient sign prems rhs lhs
   327                then mk_eq_True mss thm
   328                else let val Mss{mk_rews={mk_sym,...},...} = mss
   329                     in case mk_sym thm of
   330                          None => []
   331                        | Some thm' =>
   332                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   333                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   334                     end
   335           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   336   end;
   337 
   338 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
   339 
   340 fun orient_comb_simps comb mk_rrule (mss,thms) =
   341   let val rews = extract_rews(mss,thms)
   342       val rrules = flat (map mk_rrule rews)
   343   in foldl comb (mss,rrules) end
   344 
   345 (* Add rewrite rules explicitly; do not reorient! *)
   346 fun add_simps(mss,thms) =
   347   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   348 
   349 fun mss_of thms =
   350   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   351 
   352 fun extract_safe_rrules(mss,thm) =
   353   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   354 
   355 fun add_safe_simp(mss,thm) =
   356   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   357 
   358 (* del_simps *)
   359 
   360 fun del_rrule(mss as Mss {rules,...},
   361               rrule as {thm, elhs, ...}) =
   362   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   363    handle Net.DELETE =>
   364      (prthm true "Rewrite rule not in simpset:" thm; mss));
   365 
   366 fun del_simps(mss,thms) =
   367   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   368 
   369 
   370 (* add_congs *)
   371 
   372 fun is_full_cong_prems [] varpairs = null varpairs
   373   | is_full_cong_prems (p::prems) varpairs =
   374     (case Logic.strip_assums_concl p of
   375        Const("==",_) $ lhs $ rhs =>
   376          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   377          in is_Var x  andalso  forall is_Bound xs  andalso
   378             null(findrep(xs))  andalso xs=ys andalso
   379             (x,y) mem varpairs andalso
   380             is_full_cong_prems prems (varpairs\(x,y))
   381          end
   382      | _ => false);
   383 
   384 fun is_full_cong thm =
   385 let val prems = prems_of thm
   386     and concl = concl_of thm
   387     val (lhs,rhs) = Logic.dest_equals concl
   388     val (f,xs) = strip_comb lhs
   389     and (g,ys) = strip_comb rhs
   390 in
   391   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   392   is_full_cong_prems prems (xs ~~ ys)
   393 end
   394 
   395 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   396   let
   397     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   398       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   399 (*   val lhs = Pattern.eta_contract lhs; *)
   400     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   401       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   402     val (alist,weak) = congs
   403     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   404            ("Overwriting congruence rule for " ^ quote a);
   405     val weak2 = if is_full_cong thm then weak else a::weak
   406   in
   407     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   408   end;
   409 
   410 val (op add_congs) = foldl add_cong;
   411 
   412 
   413 (* del_congs *)
   414 
   415 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   416   let
   417     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   418       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   419 (*   val lhs = Pattern.eta_contract lhs; *)
   420     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   421       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   422     val (alist,_) = congs
   423     val alist2 = filter (fn (x,_)=> x<>a) alist
   424     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   425                                               else Some a)
   426                    alist2
   427   in
   428     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   429   end;
   430 
   431 val (op del_congs) = foldl del_cong;
   432 
   433 
   434 (* add_simprocs *)
   435 
   436 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   437     (name, lhs, proc, id)) =
   438   let val {sign, t, ...} = rep_cterm lhs
   439   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   440       sign t;
   441     mk_mss (rules, congs,
   442       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   443         handle Net.INSERT =>
   444             (warning ("Ignoring duplicate simplification procedure \""
   445                       ^ name ^ "\"");
   446              procs),
   447         bounds, prems, mk_rews, termless,depth))
   448   end;
   449 
   450 fun add_simproc (mss, (name, lhss, proc, id)) =
   451   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   452 
   453 val add_simprocs = foldl add_simproc;
   454 
   455 
   456 (* del_simprocs *)
   457 
   458 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   459     (name, lhs, proc, id)) =
   460   mk_mss (rules, congs,
   461     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   462       handle Net.DELETE =>
   463           (warning ("Simplification procedure \"" ^ name ^
   464                        "\" not in simpset"); procs),
   465       bounds, prems, mk_rews, termless, depth);
   466 
   467 fun del_simproc (mss, (name, lhss, proc, id)) =
   468   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   469 
   470 val del_simprocs = foldl del_simproc;
   471 
   472 
   473 (* prems *)
   474 
   475 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   476   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   477 
   478 fun prems_of_mss (Mss {prems, ...}) = prems;
   479 
   480 
   481 (* mk_rews *)
   482 
   483 fun set_mk_rews
   484   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   485     mk_mss (rules, congs, procs, bounds, prems,
   486             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   487             termless, depth);
   488 
   489 fun set_mk_sym
   490   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   491     mk_mss (rules, congs, procs, bounds, prems,
   492             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   493             termless,depth);
   494 
   495 fun set_mk_eq_True
   496   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   497     mk_mss (rules, congs, procs, bounds, prems,
   498             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   499             termless,depth);
   500 
   501 (* termless *)
   502 
   503 fun set_termless
   504   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   505     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   506 
   507 
   508 
   509 (** rewriting **)
   510 
   511 (*
   512   Uses conversions, see:
   513     L C Paulson, A higher-order implementation of rewriting,
   514     Science of Computer Programming 3 (1983), pages 119-149.
   515 *)
   516 
   517 type prover = meta_simpset -> thm -> thm option;
   518 type termrec = (Sign.sg_ref * term list) * term;
   519 type conv = meta_simpset -> termrec -> termrec;
   520 
   521 val dest_eq = Drule.dest_equals o cprop_of;
   522 val lhs_of = fst o dest_eq;
   523 val rhs_of = snd o dest_eq;
   524 
   525 fun beta_eta_conversion t =
   526   let val thm = beta_conversion true t;
   527   in transitive thm (eta_conversion (rhs_of thm)) end;
   528 
   529 fun check_conv msg thm thm' =
   530   let
   531     val thm'' = transitive thm (transitive
   532       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   533   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   534   handle THM _ =>
   535     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   536     in
   537       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   538        trace_term false "Should have proved:" sign prop0;
   539        None)
   540     end;
   541 
   542 
   543 (* mk_procrule *)
   544 
   545 fun mk_procrule thm =
   546   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   547   in if rewrite_rule_extra_vars prems lhs rhs
   548      then (prthm true "Extra vars on rhs:" thm; [])
   549      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
   550   end;
   551 
   552 
   553 (* conversion to apply the meta simpset to a term *)
   554 
   555 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   556    normalized terms by carrying around the rhs of the rewrite rule just
   557    applied. This is called the `skeleton'. It is decomposed in parallel
   558    with the term. Once a Var is encountered, the corresponding term is
   559    already in normal form.
   560    skel0 is a dummy skeleton that is to enforce complete normalization.
   561 *)
   562 val skel0 = Bound 0;
   563 
   564 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   565    The latter may happen iff there are weak congruence rules for constants
   566    in the lhs.
   567 *)
   568 fun uncond_skel((_,weak),(lhs,rhs)) =
   569   if null weak then rhs (* optimization *)
   570   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   571        else rhs;
   572 
   573 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   574    Otherwise those vars may become instantiated with unnormalized terms
   575    while the premises are solved.
   576 *)
   577 fun cond_skel(args as (congs,(lhs,rhs))) =
   578   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   579   else skel0;
   580 
   581 (*
   582   we try in order:
   583     (1) beta reduction
   584     (2) unconditional rewrite rules
   585     (3) conditional rewrite rules
   586     (4) simplification procedures
   587 
   588   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   589 
   590 *)
   591 
   592 fun rewritec (prover, signt, maxt)
   593              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   594   let
   595     val eta_thm = Thm.eta_conversion t;
   596     val eta_t' = rhs_of eta_thm;
   597     val eta_t = term_of eta_t';
   598     val tsigt = Sign.tsig_of signt;
   599     fun rew {thm, lhs, elhs, fo, perm} =
   600       let
   601         val {sign, prop, maxidx, ...} = rep_thm thm;
   602         val _ = if Sign.subsig (sign, signt) then ()
   603                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   604                       raise Pattern.MATCH);
   605         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   606           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   607         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   608                           else Thm.cterm_match (elhs', eta_t');
   609         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   610         val prop' = #prop (rep_thm thm');
   611         val unconditional = (Logic.count_prems (prop',0) = 0);
   612         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   613       in
   614         if perm andalso not (termless (rhs', lhs'))
   615         then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
   616               trace_thm false "Term does not become smaller:" thm'; None)
   617         else (trace_thm false "Applying instance of rewrite rule:" thm;
   618            if unconditional
   619            then
   620              (trace_thm false "Rewriting:" thm';
   621               let val lr = Logic.dest_equals prop;
   622                   val Some thm'' = check_conv false eta_thm thm'
   623               in Some (thm'', uncond_skel (congs, lr)) end)
   624            else
   625              (trace_thm false "Trying to rewrite:" thm';
   626               case prover (incr_depth mss) thm' of
   627                 None       => (trace_thm false "FAILED" thm'; None)
   628               | Some thm2 =>
   629                   (case check_conv true eta_thm thm2 of
   630                      None => None |
   631                      Some thm2' =>
   632                        let val concl = Logic.strip_imp_concl prop
   633                            val lr = Logic.dest_equals concl
   634                        in Some (thm2', cond_skel (congs, lr)) end)))
   635       end
   636 
   637     fun rews [] = None
   638       | rews (rrule :: rrules) =
   639           let val opt = rew rrule handle Pattern.MATCH => None
   640           in case opt of None => rews rrules | some => some end;
   641 
   642     fun sort_rrules rrs = let
   643       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
   644                                       Const("==",_) $ _ $ _ => true
   645                                       | _                   => false
   646       fun sort []        (re1,re2) = re1 @ re2
   647         | sort (rr::rrs) (re1,re2) = if is_simple rr
   648                                      then sort rrs (rr::re1,re2)
   649                                      else sort rrs (re1,rr::re2)
   650     in sort rrs ([],[]) end
   651 
   652     fun proc_rews ([]:simproc list) = None
   653       | proc_rews ({name, proc, lhs, ...} :: ps) =
   654           if Pattern.matches tsigt (term_of lhs, term_of t) then
   655             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   656              case transform_failure (curry SIMPROC_FAIL name)
   657                  (fn () => proc signt prems eta_t) () of
   658                None => (debug false "FAILED"; proc_rews ps)
   659              | Some raw_thm =>
   660                  (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   661                   (case rews (mk_procrule raw_thm) of
   662                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
   663                       " -- does not match") t; proc_rews ps)
   664                   | some => some)))
   665           else proc_rews ps;
   666   in case eta_t of
   667        Abs _ $ _ => Some (transitive eta_thm
   668          (beta_conversion false eta_t'), skel0)
   669      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   670                None => proc_rews (Net.match_term procs eta_t)
   671              | some => some)
   672   end;
   673 
   674 
   675 (* conversion to apply a congruence rule to a term *)
   676 
   677 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   678   let val {sign, ...} = rep_thm cong
   679       val _ = if Sign.subsig (sign, signt) then ()
   680                  else error("Congruence rule from different theory")
   681       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   682       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   683       val insts = Thm.cterm_match (rlhs, t)
   684       (* Pattern.match can raise Pattern.MATCH;
   685          is handled when congc is called *)
   686       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   687       val unit = trace_thm false "Applying congruence rule:" thm';
   688       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   689   in case prover thm' of
   690        None => err ("Could not prove", thm')
   691      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   692           None => err ("Should not have proved", thm2)
   693         | Some thm2' =>
   694             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   695             then None else Some thm2')
   696   end;
   697 
   698 val (cA, (cB, cC)) =
   699   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   700 
   701 fun transitive' thm1 None = Some thm1
   702   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
   703 
   704 fun transitive'' None thm2 = Some thm2
   705   | transitive'' (Some thm1) thm2 = Some (transitive thm1 thm2);
   706 
   707 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   708   let
   709     fun botc skel mss t =
   710           if is_Var skel then None
   711           else
   712           (case subc skel mss t of
   713              some as Some thm1 =>
   714                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   715                   Some (thm2, skel2) =>
   716                     transitive' (transitive thm1 thm2)
   717                       (botc skel2 mss (rhs_of thm2))
   718                 | None => some)
   719            | None =>
   720                (case rewritec (prover, sign, maxidx) mss t of
   721                   Some (thm2, skel2) => transitive' thm2
   722                     (botc skel2 mss (rhs_of thm2))
   723                 | None => None))
   724 
   725     and try_botc mss t =
   726           (case botc skel0 mss t of
   727              Some trec1 => trec1 | None => (reflexive t))
   728 
   729     and subc skel
   730           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   731        (case term_of t0 of
   732            Abs (a, T, t) =>
   733              let val b = variant bounds a
   734                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   735                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   736                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   737              in case botc skel' mss' t' of
   738                   Some thm => Some (abstract_rule a v thm)
   739                 | None => None
   740              end
   741          | t $ _ => (case t of
   742              Const ("==>", _) $ _  =>
   743                let val (s, u) = Drule.dest_implies t0
   744                in impc (s, u, mss) end
   745            | Abs _ =>
   746                let val thm = beta_conversion false t0
   747                in case subc skel0 mss (rhs_of thm) of
   748                     None => Some thm
   749                   | Some thm' => Some (transitive thm thm')
   750                end
   751            | _  =>
   752                let fun appc () =
   753                      let
   754                        val (tskel, uskel) = case skel of
   755                            tskel $ uskel => (tskel, uskel)
   756                          | _ => (skel0, skel0);
   757                        val (ct, cu) = Thm.dest_comb t0
   758                      in
   759                      (case botc tskel mss ct of
   760                         Some thm1 =>
   761                           (case botc uskel mss cu of
   762                              Some thm2 => Some (combination thm1 thm2)
   763                            | None => Some (combination thm1 (reflexive cu)))
   764                       | None =>
   765                           (case botc uskel mss cu of
   766                              Some thm1 => Some (combination (reflexive ct) thm1)
   767                            | None => None))
   768                      end
   769                    val (h, ts) = strip_comb t
   770                in case h of
   771                     Const(a, _) =>
   772                       (case assoc_string (fst congs, a) of
   773                          None => appc ()
   774                        | Some cong =>
   775 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   776    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   777                           (let
   778                              val thm = congc (prover mss, sign, maxidx) cong t0;
   779                              val t = if_none (apsome rhs_of thm) t0;
   780                              val (cl, cr) = Thm.dest_comb t
   781                              val dVar = Var(("", 0), dummyT)
   782                              val skel =
   783                                list_comb (h, replicate (length ts) dVar)
   784                            in case botc skel mss cl of
   785                                 None => thm
   786                               | Some thm' => transitive'' thm
   787                                   (combination thm' (reflexive cr))
   788                            end handle TERM _ => error "congc result"
   789                                     | Pattern.MATCH => appc ()))
   790                   | _ => appc ()
   791                end)
   792          | _ => None)
   793 
   794     and impc args =
   795       if mutsimp
   796       then let val (prem, conc, mss) = args
   797            in apsome snd (mut_impc ([], prem, conc, mss)) end
   798       else nonmut_impc args
   799 
   800     and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
   801         None => mut_impc1 (prems, prem, conc, mss)
   802       | Some thm1 =>
   803           let val prem1 = rhs_of thm1
   804           in (case mut_impc1 (prems, prem1, conc, mss) of
   805               None => Some (None,
   806                 combination (combination refl_implies thm1) (reflexive conc))
   807             | Some (x, thm2) => Some (x, transitive (combination (combination
   808                 refl_implies thm1) (reflexive conc)) thm2))
   809           end)
   810 
   811     and mut_impc1 (prems, prem1, conc, mss) =
   812       let
   813         fun uncond ({thm, lhs, elhs, perm}) =
   814           if Thm.no_prems thm then Some lhs else None
   815 
   816         val (lhss1, mss1) =
   817           if maxidx_of_term (term_of prem1) <> ~1
   818           then (trace_cterm true
   819             "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   820                 ([],mss))
   821           else let val thm = assume prem1
   822                    val rrules1 = extract_safe_rrules (mss, thm)
   823                    val lhss1 = mapfilter uncond rrules1
   824                    val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
   825                in (lhss1, mss1) end
   826 
   827         fun disch1 thm =
   828           let val (cB', cC') = dest_eq thm
   829           in
   830             implies_elim (Thm.instantiate
   831               ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
   832               (implies_intr prem1 thm)
   833           end
   834 
   835         fun rebuild None = (case rewritec (prover, sign, maxidx) mss
   836             (mk_implies (prem1, conc)) of
   837               None => None
   838             | Some (thm, _) =>
   839                 let val (prem, conc) = Drule.dest_implies (rhs_of thm)
   840                 in (case mut_impc (prems, prem, conc, mss) of
   841                     None => Some (None, thm)
   842                   | Some (x, thm') => Some (x, transitive thm thm'))
   843                 end handle TERM _ => Some (None, thm))
   844           | rebuild (Some thm2) =
   845             let val thm = disch1 thm2
   846             in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
   847                  None => Some (None, thm)
   848                | Some (thm', _) =>
   849                    let val (prem, conc) = Drule.dest_implies (rhs_of thm')
   850                    in (case mut_impc (prems, prem, conc, mss) of
   851                        None => Some (None, transitive thm thm')
   852                      | Some (x, thm'') =>
   853                          Some (x, transitive (transitive thm thm') thm''))
   854                    end handle TERM _ => Some (None, transitive thm thm'))
   855             end
   856 
   857         fun simpconc () =
   858           let val (s, t) = Drule.dest_implies conc
   859           in case mut_impc (prems @ [prem1], s, t, mss1) of
   860                None => rebuild None
   861              | Some (Some i, thm2) =>
   862                   let
   863                     val (prem, cC') = Drule.dest_implies (rhs_of thm2);
   864                     val thm2' = transitive (disch1 thm2) (Thm.instantiate
   865                       ([], [(cA, prem1), (cB, prem), (cC, cC')])
   866                       Drule.swap_prems_eq)
   867                   in if i=0 then apsome (apsnd (transitive thm2'))
   868                        (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
   869                      else Some (Some (i-1), thm2')
   870                   end
   871              | Some (None, thm) => rebuild (Some thm)
   872           end handle TERM _ => rebuild (botc skel0 mss1 conc)
   873 
   874       in
   875         let
   876           val tsig = Sign.tsig_of sign
   877           fun reducible t =
   878             exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
   879         in case dropwhile (not o reducible) prems of
   880             [] => simpconc ()
   881           | red::rest => (trace_cterm false "Can now reduce premise:" red;
   882               Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
   883         end
   884       end
   885 
   886      (* legacy code - only for backwards compatibility *)
   887      and nonmut_impc (prem, conc, mss) =
   888        let val thm1 = if simprem then botc skel0 mss prem else None;
   889            val prem1 = if_none (apsome rhs_of thm1) prem;
   890            val maxidx1 = maxidx_of_term (term_of prem1)
   891            val mss1 =
   892              if not useprem then mss else
   893              if maxidx1 <> ~1
   894              then (trace_cterm true
   895                "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   896                    mss)
   897              else let val thm = assume prem1
   898                   in add_safe_simp (add_prems (mss, [thm]), thm) end
   899        in (case botc skel0 mss1 conc of
   900            None => (case thm1 of
   901                None => None
   902              | Some thm1' => Some (combination
   903                  (combination refl_implies thm1') (reflexive conc)))
   904          | Some thm2 =>
   905            let
   906              val conc2 = rhs_of thm2;
   907              val thm2' = implies_elim (Thm.instantiate
   908                ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
   909                (implies_intr prem1 thm2)
   910            in (case thm1 of
   911                None => Some thm2'
   912              | Some thm1' => Some (transitive (combination
   913                  (combination refl_implies thm1') (reflexive conc)) thm2'))
   914            end)
   915        end
   916 
   917  in try_botc end;
   918 
   919 
   920 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   921 
   922 (*
   923   Parameters:
   924     mode = (simplify A,
   925             use A in simplifying B,
   926             use prems of B (if B is again a meta-impl.) to simplify A)
   927            when simplifying A ==> B
   928     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   929     prover: how to solve premises in conditional rewrites and congruences
   930 *)
   931 
   932 fun rewrite_cterm mode prover mss ct =
   933   let val {sign, t, maxidx, ...} = rep_cterm ct
   934       val Mss{depth, ...} = mss
   935   in simp_depth := depth;
   936      bottomc (mode, prover, sign, maxidx) mss ct
   937   end
   938   handle THM (s, _, thms) =>
   939     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   940       Pretty.string_of (Display.pretty_thms thms));
   941 
   942 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   943 (*Do not rewrite flex-flex pairs*)
   944 fun goals_conv pred cv =
   945   let fun gconv i ct =
   946         let val (A,B) = Drule.dest_implies ct
   947             val (thA,j) = case term_of A of
   948                   Const("=?=",_)$_$_ => (reflexive A, i)
   949                 | _ => (if pred i then cv A else reflexive A, i+1)
   950         in  combination (combination refl_implies thA) (gconv j B) end
   951         handle TERM _ => reflexive ct
   952   in gconv 1 end;
   953 
   954 (* Rewrite A in !!x1,...,xn. A *)
   955 fun forall_conv cv ct =
   956   let val p as (ct1, ct2) = Thm.dest_comb ct
   957   in (case pairself term_of p of
   958       (Const ("all", _), Abs (s, _, _)) =>
   959          let val (v, ct') = Thm.dest_abs (Some "@") ct2;
   960          in Thm.combination (Thm.reflexive ct1)
   961            (Thm.abstract_rule s v (forall_conv cv ct'))
   962          end
   963     | _ => cv ct)
   964   end handle TERM _ => cv ct;
   965 
   966 (*Use a conversion to transform a theorem*)
   967 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   968 
   969 (*Rewrite a cterm*)
   970 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
   971   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
   972 
   973 (*Rewrite a theorem*)
   974 fun simplify_aux _ _ [] = (fn th => th)
   975   | simplify_aux prover full thms =
   976       fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   977 
   978 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   979 
   980 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   981 fun rewrite_goals_rule_aux _ []   th = th
   982   | rewrite_goals_rule_aux prover thms th =
   983       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   984         (mss_of thms))) th;
   985 
   986 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
   987 fun rewrite_goal_rule mode prover mss i thm =
   988   if 0 < i  andalso  i <= nprems_of thm
   989   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   990   else raise THM("rewrite_goal_rule",i,[thm]);
   991 
   992 
   993 (*simple term rewriting -- without proofs*)
   994 fun rewrite_term sg rules procs =
   995   Pattern.rewrite_term (Sign.tsig_of sg) (map decomp_simp' rules) procs;
   996 
   997 end;
   998 
   999 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
  1000 open BasicMetaSimplifier;