src/Pure/meta_simplifier.ML
author skalberg
Wed Jun 23 14:44:22 2004 +0200 (2004-06-23)
changeset 15001 fb2141a9f8c0
parent 14981 e73f8140af78
child 15006 107e4dfd3b96
permissions -rw-r--r--
Moved conversion rules from MetaSimplifier to Drule. refl_implies removed
from Drule, instead imp_cong' exported from there.
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4 
     5 Meta-level Simplification.
     6 *)
     7 
(*User-settable simplifier flags:
    trace_simp       -- enables trace output (see trace, trace_term, ...);
    debug_simp       -- enables debug output (see debug, debug_term);
    simp_depth_limit -- depth beyond which incr_depth gives up (returns None).*)
signature BASIC_META_SIMPLIFIER =
sig
  val trace_simp: bool ref
  val debug_simp: bool ref
  val simp_depth_limit: int ref
end;
    14 
(*Interface of the meta-level simplifier: building and modifying meta
  simpsets, and rewriting cterms, theorems and goals with them.*)
signature META_SIMPLIFIER =
sig
  include BASIC_META_SIMPLIFIER
  (*raised with an explanatory message and the offending rule*)
  exception SIMPLIFIER of string * thm
  (*wraps an exception raised by the simproc of the given name*)
  exception SIMPROC_FAIL of string * exn
  type meta_simpset
  (*contents: simp rules, congruence rules, and simproc names with their lhs patterns*)
  val dest_mss          : meta_simpset ->
    {simps: thm list, congs: thm list, procs: (string * cterm list) list}
  val empty_mss         : meta_simpset
  val clear_mss         : meta_simpset -> meta_simpset
  val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
  (*rewrite rules*)
  val add_simps         : meta_simpset * thm list -> meta_simpset
  val del_simps         : meta_simpset * thm list -> meta_simpset
  val mss_of            : thm list -> meta_simpset
  (*congruence rules*)
  val add_congs         : meta_simpset * thm list -> meta_simpset
  val del_congs         : meta_simpset * thm list -> meta_simpset
  (*simplification procedures, given as (name, lhs patterns, procedure, identifying stamp)*)
  val add_simprocs      : meta_simpset *
    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
      -> meta_simpset
  val del_simprocs      : meta_simpset *
    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
      -> meta_simpset
  (*premises*)
  val add_prems         : meta_simpset * thm list -> meta_simpset
  val prems_of_mss      : meta_simpset -> thm list
  (*rule-preprocessing functions*)
  val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
  val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
  val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
  val get_mk_rews       : meta_simpset -> thm -> thm list
  val get_mk_sym        : meta_simpset -> thm -> thm option
  val get_mk_eq_True    : meta_simpset -> thm -> thm option
  (*term ordering used for permutative rewrite rules*)
  val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
  (*rewriting entry points.  NOTE(review): the (meta_simpset -> thm -> thm option)
    argument presumably is the prover for premises of conditional rules (same
    type as the prover used by rewritec), and the bool triple presumably matches
    the (simprem, useprem, mutsimp) flags of bottomc -- confirm*)
  val rewrite_cterm: bool * bool * bool ->
    (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
  val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
  val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
  val rewrite_thm       : bool * bool * bool
                          -> (meta_simpset -> thm -> thm option)
                          -> meta_simpset -> thm -> thm
  val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
  val rewrite_goal_rule : bool* bool * bool
                          -> (meta_simpset -> thm -> thm option)
                          -> meta_simpset -> int -> thm -> thm
  val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
end;
    59 
    60 structure MetaSimplifier : META_SIMPLIFIER =
    61 struct
    62 
    63 (** diagnostics **)
    64 
(*raised with a message and the rule that could not be used*)
exception SIMPLIFIER of string * thm;
(*carries the name of a simproc together with the exception it raised*)
exception SIMPROC_FAIL of string * exn;

(*shown as a "[n]" prefix on trace lines when non-zero (see println)*)
val simp_depth = ref 0;
(*incr_depth gives up once the conditional-rewriting depth exceeds this*)
val simp_depth_limit = ref 1000;
    70 
    71 local
    72 
    73 fun println a =
    74   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    75 
    76 fun prnt warn a = if warn then warning a else println a;
    77 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    78 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    79 
    80 in
    81 
    82 fun prthm warn a = prctm warn a o Thm.cprop_of;
    83 
    84 val trace_simp = ref false;
    85 val debug_simp = ref false;
    86 
    87 fun trace warn a = if !trace_simp then prnt warn a else ();
    88 fun debug warn a = if !debug_simp then prnt warn a else ();
    89 
    90 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    91 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    92 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    93 
    94 fun trace_thm a thm =
    95   let val {sign, prop, ...} = rep_thm thm
    96   in trace_term false a sign prop end;
    97 
    98 fun trace_named_thm a (thm, name) =
    99   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   100 
   101 end;
   102 
   103 
   104 (** meta simp sets **)
   105 
   106 (* basic components *)
   107 
(*A rewrite rule as stored in the rule net of a simpset (indexed by elhs).*)
type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
(* thm: the rewrite rule
   name: name of theorem from which rewrite rule was extracted
   lhs: the left-hand side
   elhs: the eta-contracted lhs.
   fo:  use first-order matching
   perm: the rewrite rule is permutative
Remarks:
  - elhs is used for matching,
    lhs only for preservation of bound variable names.
  - fo is set iff
    either elhs is first-order (no Var is applied),
           in which case fo-matching is complete,
    or elhs is not a pattern,
       in which case there is nothing better to do.
*)
(*a congruence rule together with the lhs of its conclusion*)
type cong = {thm: thm, lhs: cterm};
(*a simplification procedure: proc is tried on terms matching lhs and may
  return a rewrite rule; id identifies the procedure (see eq_simproc)*)
type simproc =
 {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   127 
   128 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   129   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   130 
   131 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   132   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   133 
   134 fun eq_prem (thm1, thm2) =
   135   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   136 
   137 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   138 
   139 fun mk_simproc (name, proc, lhs, id) =
   140   {name = name, proc = proc, lhs = lhs, id = id};
   141 
   142 
   143 (* datatype mss *)
   144 
(*
  A "mss" contains data needed during conversion:
    rules: discrimination net of rewrite rules;
    congs: association list of congruence rules and
           a list of `weak' congruence constants.
           A congruence is `weak' if it avoids normalization of some argument
           (add_cong records a name as weak iff its rule is not is_full_cong);
    procs: discrimination net of simplification procedures
      (functions that prove rewrite rules on the fly);
    bounds: names of bound variables already used
      (for generating new names when rewriting under lambda abstractions);
    prems: current premises;
    mk_rews: mk: turns simplification thms into rewrite rules;
             mk_sym: turns == around (needs Drule!);
             mk_eq_True: turns P into P == True - logic specific;
    termless: relation for ordered rewriting;
    depth: depth of conditional rewriting; incremented by incr_depth, which
      gives up once it exceeds !simp_depth_limit.
*)

datatype meta_simpset =
  Mss of {
    rules: rrule Net.net,
    congs: (string * cong) list * string list,
    procs: simproc Net.net,
    bounds: string list,
    prems: thm list,
    mk_rews: {mk: thm -> thm list,
              mk_sym: thm -> thm option,
              mk_eq_True: thm -> thm option},
    termless: term * term -> bool,
    depth: int};
   175 
   176 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   177   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   178        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   179 
   180 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   181   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   182 
(*The empty simpset: no rules, congruences, simprocs, bounds or premises;
  all mk_rews functions produce nothing; ordering Term.termless; depth 0.*)
val empty_mss =
  let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
  in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;

(*Drop all rules, congruences, simprocs, bounds and premises and reset the
  depth to 0, but keep the mk_rews functions and the termless ordering.*)
fun clear_mss (Mss {mk_rews, termless, ...}) =
  mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   189 
   190 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   191   let val depth1 = depth+1
   192   in if depth1 > !simp_depth_limit
   193      then (warning "simp_depth_limit exceeded - giving up"; None)
   194      else (if depth1 mod 10 = 0
   195            then warning("Simplification depth " ^ string_of_int depth1)
   196            else ();
   197            Some(mk_mss(rules,congs,procs,bounds,prems,mk_rews,termless,depth1))
   198           )
   199   end;
   200 
   201 
   202 (** simpset operations **)
   203 
   204 (* term variables *)
   205 
(*accumulate (via ins_ix) the indexnames of all Vars occurring in a term*)
val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
(*indexnames of the Vars of a term*)
fun term_varnames t = add_term_varnames ([], t);
   208 
   209 
   210 (* dest_mss *)
   211 
(*Contents of a simpset.  A simproc occurs in the net once per lhs pattern;
  the entries are regrouped by their id stamp into (name, lhs patterns)
  pairs and sorted by name.*)
fun dest_mss (Mss {rules, congs, procs, ...}) =
  {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   procs =
     map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
     |> partition_eq eq_snd
     |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
     |> Library.sort_wrt #1};
   220 
   221 
   222 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   223 
(*Union of two simpsets: the rule and simproc nets are merged modulo
  eq_rrule/eq_simproc; congruences (modulo eq_cong), weak names, bounds and
  premises (modulo eq_prem) are merged as lists.  mk_rews, termless and
  depth are taken from the first simpset.*)
fun merge_mss
 (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
       bounds = bounds1, prems = prems1, mk_rews, termless, depth},
  Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
       bounds = bounds2, prems = prems2, ...}) =
      mk_mss
       (Net.merge (rules1, rules2, eq_rrule),
        (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
        merge_lists weak1 weak2),
        Net.merge (procs1, procs2, eq_simproc),
        merge_lists bounds1 bounds2,
        gen_merge_lists eq_prem prems1 prems2,
        mk_rews, termless, depth);
   237 
   238 
   239 (* add_simps *)
   240 
(*Complete a rule record with its fo flag: first-order matching is used when
  elhs is first-order or is not a (higher-order) pattern.*)
fun mk_rrule2{thm, name, lhs, elhs, perm} =
  let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
  in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end

(*Insert a rule into the rule net of mss, indexed by its elhs.  If the net
  rejects it as a duplicate (Net.INSERT), mss is returned unchanged -- with a
  warning unless quiet is set.*)
fun insert_rrule quiet (mss as Mss {rules,...},
                 rrule as {thm,name,lhs,elhs,perm}) =
  (trace_named_thm "Adding rewrite rule" (thm, name);
   let val rrule2 as {elhs,...} = mk_rrule2 rrule
       val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   in upd_rules(mss,rules') end
   handle Net.INSERT => if quiet then mss else
     (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   253 
   254 fun vperm (Var _, Var _) = true
   255   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   256   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   257   | vperm (t, u) = (t = u);
   258 
   259 fun var_perm (t, u) =
   260   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   261 
(* FIXME: it seems that the conditions on extra variables are too liberal if
prems are nonempty: does solving the prems really guarantee instantiation of
all its Vars? Better: a dynamic check each time a rule is applied.
*)
(*True iff erhs contains Vars or TVars that occur neither in elhs nor in any
  of the premises prems, i.e. the rule could introduce new variables.*)
fun rewrite_rule_extra_vars prems elhs erhs =
  not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
  orelse
  not ((term_tvars erhs) subset
       (term_tvars elhs  union  List.concat(map term_tvars prems)));
   271 
(*Simple test for looping rewrite rules and stupid orientations.
  True if lhs == rhs should not be used in this direction: the rhs has
  extra variables, the lhs is headed by a Var, the lhs occurs (Logic.occs)
  in the rhs or in a premise, the rhs is an instance of the lhs (only
  checked for unconditional rules), or a constant lhs would be rewritten
  to a non-constant rhs.*)
fun reorient sign prems lhs rhs =
   rewrite_rule_extra_vars prems lhs rhs
  orelse
   is_Var (head_of lhs)
  orelse
   (exists (apl (lhs, Logic.occs)) (rhs :: prems))
  orelse
   (null prems andalso
    Pattern.matches (Sign.tsig_of sign) (lhs, rhs))
    (*the condition "null prems" is necessary because conditional rewrites
      with extra variables in the conditions may terminate although
      the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
  orelse
   (is_Const lhs andalso not(is_Const rhs))
   287 
(*Decompose a (possibly conditional) meta-equality into
  (sign, premises, lhs, elhs, rhs, perm) where elhs is the eta-contracted lhs
  as a cterm (shared with lhs when nothing changes) and perm holds iff the
  rule is a non-trivial permutation of Vars whose elhs is not a bare Var.
  Raises SIMPLIFIER if the conclusion is not a meta-equality.*)
fun decomp_simp thm =
  let val {sign, prop, ...} = rep_thm thm;
      val prems = Logic.strip_imp_prems prop;
      val concl = Drule.strip_imp_concl (cprop_of thm);
      val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
        raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
      val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
      val elhs = if elhs=lhs then lhs else elhs (* try to share *)
      val erhs = Pattern.eta_contract (term_of rhs);
      val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
                 andalso not (is_Var (term_of elhs))
  in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;

(*(lhs, rhs) of an unconditional meta-equality; raises SIMPLIFIER if the
  rule has premises (or is not a meta-equality at all)*)
fun decomp_simp' thm =
  let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
    if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
    else (lhs, rhs)
  end;
   306 
(*Use the simpset's mk_eq_True to turn thm into a "P == True" rule; yields
  the corresponding (non-permutative) rule record, or [] if mk_eq_True
  returns None.*)
fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
  case mk_eq_True thm of
    None => []
  | Some eq_True =>
      let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
      in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;

(* create the rewrite rule and possibly also the ==True variant,
   in case there are extra vars on the rhs *)
(*thm2 is the theorem handed to mk_eq_True; the ==True variant, if any,
  comes first in the result.*)
fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
  let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
  in if (term_varnames rhs)  subset (term_varnames lhs) andalso
        (term_tvars rhs) subset (term_tvars lhs)
     then [rrule]
     else mk_eq_True mss (thm2, name) @ [rrule]
  end;
   323 
(*Rule records for thm in its given orientation: permutative rules are kept
  as such; rules with extra variables or a bare Var lhs are replaced by
  their ==True form only.*)
fun mk_rrule mss (thm, name) =
  let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
  in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
     (* weak test for loops: *)
     if rewrite_rule_extra_vars prems lhs rhs orelse
        is_Var (term_of elhs)
     then mk_eq_True mss (thm, name)
     else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
  end;

(*Like mk_rrule, but if reorient rejects the given direction and accepts the
  symmetric one, the rule is flipped via the simpset's mk_sym (no rule if
  mk_sym yields None); if both directions are rejected only the ==True form
  is used.*)
fun orient_rrule mss (thm, name) =
  let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
  in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
     else if reorient sign prems lhs rhs
          then if reorient sign prems rhs lhs
               then mk_eq_True mss (thm, name)
               else let val Mss{mk_rews={mk_sym,...},...} = mss
                    in case mk_sym thm of
                         None => []
                       | Some thm' =>
                           let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
                           in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
                    end
          else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
  end;
   349 
(*Apply the simpset's mk function to each theorem, pairing every resulting
  rule with the name of the theorem it came from.*)
fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
  flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);

(*Preprocess thms with mk, turn the results into rule records with the given
  mk_rrule function (shadowing the global one), and fold comb over them
  starting from mss.*)
fun orient_comb_simps comb mk_rrule (mss,thms) =
  let val rews = extract_rews(mss,thms)
      val rrules = flat (map mk_rrule rews)
  in foldl comb (mss,rrules) end

(* Add rewrite rules explicitly; do not reorient! *)
fun add_simps(mss,thms) =
  orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);

(*Simpset built from empty_mss with the rules mk_rrule derives directly from
  thms (no mk preprocessing).*)
fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
  (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));

(*Rule records for thm, reoriented where necessary (see orient_rrule).*)
fun extract_safe_rrules(mss,thm) =
  flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   367 
   368 (* del_simps *)
   369 
(*Remove a rule from the rule net (looked up by its elhs); if it is absent
  (Net.DELETE), warn and return mss unchanged.*)
fun del_rrule(mss as Mss {rules,...},
              rrule as {thm, elhs, ...}) =
  (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   handle Net.DELETE =>
     (prthm true "Rewrite rule not in simpset:" thm; mss));

(*Remove the rules that add_simps would derive from thms (completed with
  their fo flag by mk_rrule2).*)
fun del_simps(mss,thms) =
  orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   378 
   379 
   380 (* add_congs *)
   381 
(*Each premise must have a conclusion of the form  x bs == y bs  with x a
  Var, bs distinct bound variables, and (x, y) one of the still-unused
  varpairs; all varpairs must be used up by the premises.*)
fun is_full_cong_prems [] varpairs = null varpairs
  | is_full_cong_prems (p::prems) varpairs =
    (case Logic.strip_assums_concl p of
       Const("==",_) $ lhs $ rhs =>
         let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
         in is_Var x  andalso  forall is_Bound xs  andalso
            null(findrep(xs))  andalso xs=ys andalso
            (x,y) mem varpairs andalso
            is_full_cong_prems prems (varpairs\(x,y))
         end
     | _ => false);

(*A congruence rule  f xs == f ys  is full if all arguments are distinct and
  each argument pair is related by exactly one premise (see above); a rule
  that is not full is treated as `weak'.  Raises TERM if the conclusion is
  not a meta-equality.*)
fun is_full_cong thm =
let val prems = prems_of thm
    and concl = concl_of thm
    val (lhs,rhs) = Logic.dest_equals concl
    val (f,xs) = strip_comb lhs
    and (g,ys) = strip_comb rhs
in
  f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
  is_full_cong_prems prems (xs ~~ ys)
end
   404 
   405 fun cong_name (Const (a, _)) = Some a
   406   | cong_name (Free (a, _)) = Some ("Free: " ^ a)
   407   | cong_name _ = None;
   408 
   409 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   410   let
   411     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   412       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   413 (*   val lhs = Pattern.eta_contract lhs; *)
   414     val a = (case cong_name (head_of (term_of lhs)) of
   415         Some a => a
   416       | None =>
   417         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   418     val (alist,weak) = congs
   419     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   420            ("Overwriting congruence rule for " ^ quote a);
   421     val weak2 = if is_full_cong thm then weak else a::weak
   422   in
   423     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   424   end;
   425 
   426 val (op add_congs) = foldl add_cong;
   427 
   428 
   429 (* del_congs *)
   430 
   431 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   432   let
   433     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   434       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   435 (*   val lhs = Pattern.eta_contract lhs; *)
   436     val a = (case cong_name (head_of lhs) of
   437         Some a => a
   438       | None =>
   439         raise SIMPLIFIER ("Congruence must start with a constant", thm));
   440     val (alist,_) = congs
   441     val alist2 = filter (fn (x,_)=> x<>a) alist
   442     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   443                                               else Some a)
   444                    alist2
   445   in
   446     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   447   end;
   448 
   449 val (op del_congs) = foldl del_cong;
   450 
   451 
   452 (* add_simprocs *)
   453 
(*Insert one simproc into the simproc net under the term of lhs; a duplicate
  (Net.INSERT) is ignored with a warning.*)
fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
    (name, lhs, proc, id)) =
  let val {sign, t, ...} = rep_cterm lhs
  in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
      sign t;
    mk_mss (rules, congs,
      Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
        handle Net.INSERT =>
            (warning ("Ignoring duplicate simplification procedure \""
                      ^ name ^ "\"");
             procs),
        bounds, prems, mk_rews, termless,depth))
  end;

(*a simproc is entered once for each of its lhs patterns*)
fun add_simproc (mss, (name, lhss, proc, id)) =
  foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);

val add_simprocs = foldl add_simproc;
   472 
   473 
   474 (* del_simprocs *)
   475 
(*Delete one simproc entry (keyed by the term of lhs, compared with
  eq_simproc); an absent entry (Net.DELETE) is reported with a warning.*)
fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
    (name, lhs, proc, id)) =
  mk_mss (rules, congs,
    Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
      handle Net.DELETE =>
          (warning ("Simplification procedure \"" ^ name ^
                       "\" not in simpset"); procs),
      bounds, prems, mk_rews, termless, depth);

(*remove the entries for all lhs patterns of a simproc*)
fun del_simproc (mss, (name, lhss, proc, id)) =
  foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);

val del_simprocs = foldl del_simproc;
   489 
   490 
   491 (* prems *)
   492 
(*prepend thms to the current premises*)
fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
  mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);

(*the current premises, most recently added first*)
fun prems_of_mss (Mss {prems, ...}) = prems;
   497 
   498 
   499 (* mk_rews *)
   500 
   501 fun set_mk_rews
   502   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   503     mk_mss (rules, congs, procs, bounds, prems,
   504             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   505             termless, depth);
   506 
   507 fun set_mk_sym
   508   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   509     mk_mss (rules, congs, procs, bounds, prems,
   510             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   511             termless,depth);
   512 
   513 fun set_mk_eq_True
   514   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   515     mk_mss (rules, congs, procs, bounds, prems,
   516             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   517             termless,depth);
   518 
   519 fun get_mk_rews    (Mss {mk_rews,...}) = #mk         mk_rews
   520 fun get_mk_sym     (Mss {mk_rews,...}) = #mk_sym     mk_rews
   521 fun get_mk_eq_True (Mss {mk_rews,...}) = #mk_eq_True mk_rews
   522 
   523 (* termless *)
   524 
   525 fun set_termless
   526   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   527     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   528 
   529 
   530 
   531 (** rewriting **)
   532 
   533 (*
   534   Uses conversions, see:
   535     L C Paulson, A higher-order implementation of rewriting,
   536     Science of Computer Programming 3 (1983), pages 119-149.
   537 *)
   538 
   539 val dest_eq = Drule.dest_equals o cprop_of;
   540 val lhs_of = fst o dest_eq;
   541 val rhs_of = snd o dest_eq;
   542 
   543 fun check_conv msg thm thm' =
   544   let
   545     val thm'' = transitive thm (transitive
   546       (symmetric (Drule.beta_eta_conversion (lhs_of thm'))) thm')
   547   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   548   handle THM _ =>
   549     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   550     in
   551       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   552        trace_term false "Should have proved:" sign prop0;
   553        None)
   554     end;
   555 
   556 
   557 (* mk_procrule *)
   558 
   559 fun mk_procrule thm =
   560   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   561   in if rewrite_rule_extra_vars prems lhs rhs
   562      then (prthm true "Extra vars on rhs:" thm; [])
   563      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   564   end;
   565 
   566 
   567 (* conversion to apply the meta simpset to a term *)
   568 
   569 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   570    normalized terms by carrying around the rhs of the rewrite rule just
   571    applied. This is called the `skeleton'. It is decomposed in parallel
   572    with the term. Once a Var is encountered, the corresponding term is
   573    already in normal form.
   574    skel0 is a dummy skeleton that is to enforce complete normalization.
   575 *)
   576 val skel0 = Bound 0;
   577 
   578 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   579    The latter may happen iff there are weak congruence rules for constants
   580    in the lhs.
   581 *)
   582 fun uncond_skel((_,weak),(lhs,rhs)) =
   583   if null weak then rhs (* optimization *)
   584   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   585        else rhs;
   586 
   587 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   588    Otherwise those vars may become instantiated with unnormalized terms
   589    while the premises are solved.
   590 *)
   591 fun cond_skel(args as (congs,(lhs,rhs))) =
   592   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   593   else skel0;
   594 
   595 (*
   596   we try in order:
   597     (1) beta reduction
   598     (2) unconditional rewrite rules
   599     (3) conditional rewrite rules
   600     (4) simplification procedures
   601 
   602   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   603 
   604 *)
   605 
   606 fun rewritec (prover, signt, maxt)
   607              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   608   let
   609     val eta_thm = Thm.eta_conversion t;
   610     val eta_t' = rhs_of eta_thm;
   611     val eta_t = term_of eta_t';
   612     val tsigt = Sign.tsig_of signt;
   613     fun rew {thm, name, lhs, elhs, fo, perm} =
   614       let
   615         val {sign, prop, maxidx, ...} = rep_thm thm;
   616         val _ = if Sign.subsig (sign, signt) then ()
   617                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   618                       raise Pattern.MATCH);
   619         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   620           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   621         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   622                           else Thm.cterm_match (elhs', eta_t');
   623         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   624         val prop' = Thm.prop_of thm';
   625         val unconditional = (Logic.count_prems (prop',0) = 0);
   626         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   627       in
   628         if perm andalso not (termless (rhs', lhs'))
   629         then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
   630               trace_thm "Term does not become smaller:" thm'; None)
   631         else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
   632            if unconditional
   633            then
   634              (trace_thm "Rewriting:" thm';
   635               let val lr = Logic.dest_equals prop;
   636                   val Some thm'' = check_conv false eta_thm thm'
   637               in Some (thm'', uncond_skel (congs, lr)) end)
   638            else
   639              (trace_thm "Trying to rewrite:" thm';
   640               case incr_depth mss of
   641                 None => (trace_thm "FAILED - reached depth limit" thm'; None)
   642               | Some mss =>
   643               (case prover mss thm' of
   644                 None       => (trace_thm "FAILED" thm'; None)
   645               | Some thm2 =>
   646                   (case check_conv true eta_thm thm2 of
   647                      None => None |
   648                      Some thm2' =>
   649                        let val concl = Logic.strip_imp_concl prop
   650                            val lr = Logic.dest_equals concl
   651                        in Some (thm2', cond_skel (congs, lr)) end))))
   652       end
   653 
   654     fun rews [] = None
   655       | rews (rrule :: rrules) =
   656           let val opt = rew rrule handle Pattern.MATCH => None
   657           in case opt of None => rews rrules | some => some end;
   658 
   659     fun sort_rrules rrs = let
   660       fun is_simple({thm, ...}:rrule) = case Thm.prop_of thm of
   661                                       Const("==",_) $ _ $ _ => true
   662                                       | _                   => false
   663       fun sort []        (re1,re2) = re1 @ re2
   664         | sort (rr::rrs) (re1,re2) = if is_simple rr
   665                                      then sort rrs (rr::re1,re2)
   666                                      else sort rrs (re1,rr::re2)
   667     in sort rrs ([],[]) end
   668 
   669     fun proc_rews ([]:simproc list) = None
   670       | proc_rews ({name, proc, lhs, ...} :: ps) =
   671           if Pattern.matches tsigt (term_of lhs, term_of t) then
   672             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   673              case transform_failure (curry SIMPROC_FAIL name)
   674                  (fn () => proc signt prems eta_t) () of
   675                None => (debug false "FAILED"; proc_rews ps)
   676              | Some raw_thm =>
   677                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   678                   (case rews (mk_procrule raw_thm) of
   679                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
   680                       " -- does not match") t; proc_rews ps)
   681                   | some => some)))
   682           else proc_rews ps;
   683   in case eta_t of
   684        Abs _ $ _ => Some (transitive eta_thm
   685          (beta_conversion false eta_t'), skel0)
   686      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   687                None => proc_rews (Net.match_term procs eta_t)
   688              | some => some)
   689   end;
   690 
   691 
   692 (* conversion to apply a congruence rule to a term *)
   693 
   694 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   695   let val sign = Thm.sign_of_thm cong
   696       val _ = if Sign.subsig (sign, signt) then ()
   697                  else error("Congruence rule from different theory")
   698       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   699       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   700       val insts = Thm.cterm_match (rlhs, t)
   701       (* Pattern.match can raise Pattern.MATCH;
   702          is handled when congc is called *)
   703       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   704       val unit = trace_thm "Applying congruence rule:" thm';
   705       fun err (msg, thm) = (trace_thm msg thm; None)
   706   in case prover thm' of
   707        None => err ("Congruence proof failed.  Could not prove", thm')
   708      | Some thm2 => (case check_conv true (Drule.beta_eta_conversion t) thm2 of
   709           None => err ("Congruence proof failed.  Should not have proved", thm2)
   710         | Some thm2' =>
   711             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   712             then None else Some thm2')
   713   end;
   714 
   715 val (cA, (cB, cC)) =
   716   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   717 
   718 fun transitive1 None None = None
   719   | transitive1 (Some thm1) None = Some thm1
   720   | transitive1 None (Some thm2) = Some thm2
   721   | transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2)
   722 
   723 fun transitive2 thm = transitive1 (Some thm);
   724 fun transitive3 thm = transitive1 thm o Some;
   725 
       (* bottomc ((simprem, useprem, mutsimp), prover, sign, maxidx): build the
          bottom-up rewriting conversion.  Mode flags:
            simprem -- simplify the premise of an implication (nonmut_impc);
            useprem -- use the premise as rewrite rules for the conclusion
                       (nonmut_impc);
            mutsimp -- simplify the premises using each other (mut_impc0)
                       instead of the legacy nonmut_impc.
          prover is passed to rewritec and, applied to the current simpset,
          to congc.  The result is try_botc: given a simpset and a cterm t it
          returns an equation t == t', which is reflexive if nothing changed. *)
   726 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   727   let
           (* botc skel mss t: None if t is unchanged, otherwise Some (t == t').
              A Var skeleton stops the descent (presumably the subterm is
              already known to be normal).  Otherwise the subterms are
              simplified first (subc) and then rewritec is tried at the top.
              After each successful rewrite, the result is simplified again
              under the skeleton skel2 returned by rewritec. *)
   728     fun botc skel mss t =
   729           if is_Var skel then None
   730           else
   731           (case subc skel mss t of
   732              some as Some thm1 =>
   733                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   734                   Some (thm2, skel2) =>
   735                     transitive2 (transitive thm1 thm2)
   736                       (botc skel2 mss (rhs_of thm2))
   737                 | None => some)
   738            | None =>
   739                (case rewritec (prover, sign, maxidx) mss t of
   740                   Some (thm2, skel2) => transitive2 thm2
   741                     (botc skel2 mss (rhs_of thm2))
   742                 | None => None))
   743 
           (* Total version of botc: yields reflexive t when nothing changed. *)
   744     and try_botc mss t =
   745           (case botc skel0 mss t of
   746              Some trec1 => trec1 | None => (reflexive t))
   747 
           (* subc skel mss t0: simplify the immediate subterms of t0.
              - Abs: simplify the body under a fresh bound name (a variant of
                a not in bounds), then re-abstract with abstract_rule.
              - ==> applications are handled by impc.
              - beta-redexes are beta-reduced first, then simplified further.
              - other applications: if the head has a congruence rule
                (cong_name / fst congs), use congc; otherwise appc simplifies
                the function and argument parts independently. *)
   748     and subc skel
   749           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   750        (case term_of t0 of
   751            Abs (a, T, t) =>
   752              let val b = variant bounds a
   753                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   754                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   755                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   756              in case botc skel' mss' t' of
   757                   Some thm => Some (abstract_rule a v thm)
   758                 | None => None
   759              end
   760          | t $ _ => (case t of
   761              Const ("==>", _) $ _  => impc t0 mss
   762            | Abs _ =>
   763                let val thm = beta_conversion false t0
   764                in case subc skel0 mss (rhs_of thm) of
   765                     None => Some thm
   766                   | Some thm' => Some (transitive thm thm')
   767                end
   768            | _  =>
                     (* appc (): simplify function and argument of t0 separately
                        and recombine; None if neither part changed. *)
   769                let fun appc () =
   770                      let
   771                        val (tskel, uskel) = case skel of
   772                            tskel $ uskel => (tskel, uskel)
   773                          | _ => (skel0, skel0);
   774                        val (ct, cu) = Thm.dest_comb t0
   775                      in
   776                      (case botc tskel mss ct of
   777                         Some thm1 =>
   778                           (case botc uskel mss cu of
   779                              Some thm2 => Some (combination thm1 thm2)
   780                            | None => Some (combination thm1 (reflexive cu)))
   781                       | None =>
   782                           (case botc uskel mss cu of
   783                              Some thm1 => Some (combination (reflexive ct) thm1)
   784                            | None => None))
   785                      end
   786                    val (h, ts) = strip_comb t
   787                in case cong_name h of
   788                     Some a =>
   789                       (case assoc_string (fst congs, a) of
   790                          None => appc ()
   791                        | Some cong =>
   792 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   793    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
                          (* The skeleton h applied to dummy Vars makes botc skip
                             the arguments, which congc has already handled, so
                             that only the partial applications of h are retried.
                             A Pattern.MATCH raised anywhere in the let below
                             (in particular by the matching in congc) falls back
                             to appc; a TERM exception is turned into an error. *)
   794                           (let
   795                              val thm = congc (prover mss, sign, maxidx) cong t0;
   796                              val t = if_none (apsome rhs_of thm) t0;
   797                              val (cl, cr) = Thm.dest_comb t
   798                              val dVar = Var(("", 0), dummyT)
   799                              val skel =
   800                                list_comb (h, replicate (length ts) dVar)
   801                            in case botc skel mss cl of
   802                                 None => thm
   803                               | Some thm' => transitive3 thm
   804                                   (combination thm' (reflexive cr))
   805                            end handle TERM _ => error "congc result"
   806                                     | Pattern.MATCH => appc ()))
   807                   | _ => appc ()
   808                end)
   809          | _ => None)
   810 
           (* Simplify the meta-implication ct: mutual premise simplification
              (mut_impc0) if mutsimp is set, otherwise the legacy nonmut_impc. *)
   811     and impc ct mss =
   812       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   813 
           (* rules_of_prem mss prem: assume prem and extract rewrite rules
              from it, returning (rules, Some asm).  If prem contains
              unknowns (maxidx <> ~1), it emits a trace message and returns
              ([], None). *)
   814     and rules_of_prem mss prem =
   815       if maxidx_of_term (term_of prem) <> ~1
   816       then (trace_cterm true
   817         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   818       else
   819         let val asm = assume prem
   820         in (extract_safe_rrules (mss, asm), Some asm) end
   821 
           (* Insert all rules of rrss into mss and add the assumed premises
              in asms (skipping None entries) as prems. *)
   822     and add_rrules (rrss, asms) mss =
   823       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   824 
           (* disch r (prem, eq): from eq : lhs == rhs (which may depend on the
              hypothesis prem), derive (prem ==> lhs) == (prem ==> rhs) via
              imp_cong.  If r is true, lhs and rhs must themselves be
              implications, and swap_prems_eq is used on both sides.  Assuming
              swap_prems_eq is (A ==> B ==> C) == (B ==> A ==> C), the result
              has prem as the second premise on both sides.  The conclusion of
              lhs is reused for rhs; presumably both have the same one. *)
   825     and disch r (prem, eq) =
   826       let
   827         val (lhs, rhs) = dest_eq eq;
   828         val eq' = implies_elim (Thm.instantiate
   829           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   830           (implies_intr prem eq)
   831       in if not r then eq' else
   832         let
   833           val (prem', concl) = dest_implies lhs;
   834           val (prem'', _) = dest_implies rhs
   835         in transitive (transitive
   836           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   837              Drule.swap_prems_eq) eq')
   838           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   839              Drule.swap_prems_eq)
   840         end
   841       end
   842 
           (* rebuild prems concl rrss asms mss eq: re-attach the premises to
              the simplified conclusion, where eq is concl == concl' or None.
              The premises arrive in reverse order, innermost first, because
              mut_impc accumulates them that way.  At each step rewritec is
              tried on the partially rebuilt implication, with the rules of
              the remaining outer premises.  If it rewrites, the remaining
              premises are discharged and mut_impc0 restarts on the result. *)
   843     and rebuild [] _ _ _ _ eq = eq
   844       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   845           let
   846             val mss' = add_rrules (rev rrss, rev asms) mss;
   847             val concl' =
   848               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   849             val dprem = apsome (curry (disch false) prem)
   850           in case rewritec (prover, sign, maxidx) mss' concl' of
   851               None => rebuild prems concl' rrss asms mss (dprem eq)
   852             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   853                   (the (transitive3 (dprem eq) eq'), prems))
   854                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   855           end
   856           
           (* mut_impc0 prems concl rrss asms mss: split off the premises of
              concl and compute their rewrite rules.  Then start mut_impc with
              changed = ~1 and k = ~1, which processes every premise.
              prems, rrss and asms are the premises (with their rules) that the
              caller has already split off. *)
   857     and mut_impc0 prems concl rrss asms mss =
   858       let
   859         val prems' = strip_imp_prems concl;
   860         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   861       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   862         (asms @ asms') [] [] [] [] mss ~1 ~1
   863       end
   864  
           (* mut_impc: one pass over the premises.
              prems/rrss/asms       -- premises still to be processed;
              prems'/rrss'/asms'    -- processed premises, in reverse order;
              eqns                  -- optional equation for each processed premise;
              changed               -- position of the last premise rewritten in
                                       this pass (~1 if none);
              k                     -- number of premises still to simplify;
                                       0 = skip, negative = no limit.
              Each premise is simplified with the rules of all other premises.
              At the end of a pass with changed > 0, the premises before
              position changed are processed again.  They were simplified
              before that premise was rewritten.  Otherwise the conclusion is
              simplified with all premise rules and rebuild reassembles the
              implication.  The per-premise equations are discharged and
              chained with transitive1. *)
   865     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   866         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   867             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   868           (if changed > 0 then
   869              mut_impc (rev prems') concl (rev rrss') (rev asms')
   870                [] [] [] [] mss ~1 changed
   871            else rebuild prems' concl rrss' asms' mss
   872              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   873 
   874       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   875           prems' rrss' asms' eqns mss changed k =
   876         case (if k = 0 then None else botc skel0 (add_rrules
   877           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
   878             None => mut_impc prems concl rrss asms (prem :: prems')
   879               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   880               (if k = 0 then 0 else k - 1)
   881           | Some eqn =>
                 (* i = 1 + the largest position in prems of a hypothesis
                    used by eqn, or 0 if eqn uses none of them.  eqn is
                    lifted with imp_cong' over the implication formed by
                    the premises after position i and concl.  The first i
                    premises are then discharged with disch true. *)
   882             let
   883               val prem' = rhs_of eqn;
   884               val tprems = map term_of prems;
   885               val i = 1 + foldl Int.max (~1, map (fn p =>
   886                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   887               val (rrs', asm') = rules_of_prem mss prem'
   888             in mut_impc prems concl rrss asms (prem' :: prems')
   889               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   890                 (take (i, prems), Drule.imp_cong' eqn (reflexive (Drule.list_implies
   891                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   892             end
   893 
   894      (* legacy code - only for backwards compatibility *)
            (* nonmut_impc ct mss: for ct = prem ==> conc, simplify prem if
               simprem is set.  If useprem is set, add the simplified premise
               as rewrite rules.  Then simplify conc and combine both results
               with imp_cong' and disch. *)
   895      and nonmut_impc ct mss =
   896        let val (prem, conc) = dest_implies ct;
   897            val thm1 = if simprem then botc skel0 mss prem else None;
   898            val prem1 = if_none (apsome rhs_of thm1) prem;
   899            val mss1 = if not useprem then mss else add_rrules
   900              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   901        in (case botc skel0 mss1 conc of
   902            None => (case thm1 of
   903                None => None
   904              | Some thm1' => Some (Drule.imp_cong' thm1' (reflexive conc)))
   905          | Some thm2 =>
   906            let val thm2' = disch false (prem1, thm2)
   907            in (case thm1 of
   908                None => Some thm2'
   909              | Some thm1' =>
   910                  Some (transitive (Drule.imp_cong' thm1' (reflexive conc)) thm2'))
   911            end)
   912        end
   913 
   914  in try_botc end;
   915 
   916 
   917 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   918 
   919 (*
   920   Parameters:
   921     mode = (simplify A,
   922             use A in simplifying B,
   923             use prems of B (if B is again a meta-impl.) to simplify A)
   924            when simplifying A ==> B
   925     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   926     prover: how to solve premises in conditional rewrites and congruences
   927 *)
   928 
   929 fun rewrite_cterm mode prover mss ct =
   930   let val {sign, t, maxidx, ...} = rep_cterm ct
   931       val Mss{depth, ...} = mss
   932   in trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
   933      simp_depth := depth;
   934      bottomc (mode, prover, sign, maxidx) mss ct
   935   end
   936   handle THM (s, _, thms) =>
   937     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   938       Pretty.string_of (Display.pretty_thms thms));
   939 
   940 (*Rewrite a cterm*)
   941 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
   942   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
   943 
   944 (*Rewrite a theorem*)
   945 fun simplify_aux _ _ [] = (fn th => th)
   946   | simplify_aux prover full thms =
   947       Drule.fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   948 
   949 fun rewrite_thm mode prover mss = Drule.fconv_rule (rewrite_cterm mode prover mss);
   950 
   951 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   952 fun rewrite_goals_rule_aux _ []   th = th
   953   | rewrite_goals_rule_aux prover thms th =
   954       Drule.fconv_rule (Drule.goals_conv (K true) (rewrite_cterm (true, true, false) prover
   955         (mss_of thms))) th;
   956 
   957 (*Rewrite subgoal i of a proof state (represented by a theorem) *)
   958 fun rewrite_goal_rule mode prover mss i thm =
   959   if 0 < i  andalso  i <= nprems_of thm
   960   then Drule.fconv_rule (Drule.goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   961   else raise THM("rewrite_goal_rule",i,[thm]);
   962 
   963 
   964 (*simple term rewriting -- without proofs*)
   965 fun rewrite_term sg rules procs =
   966   Pattern.rewrite_term (Sign.tsig_of sg) (map decomp_simp' rules) procs;
   967 
   968 end;
   969 
       (* Restrict MetaSimplifier to the BASIC_META_SIMPLIFIER signature
          (trace_simp, debug_simp, simp_depth_limit) and open it, so that
          these flags are visible at the top level. *)
   970 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
   971 open BasicMetaSimplifier;