src/Pure/meta_simplifier.ML
author ballarin
Thu Feb 27 15:12:29 2003 +0100 (2003-02-27)
changeset 13835 12b2ffbe543a
parent 13828 fb6ec40dd291
child 13932 0eb3d91b519a
permissions -rw-r--r--
Change to meta simplifier: congruence rules may now have frees as head of term.
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Meta-level Simplification.
     7 *)
     8 
(*Global switches of the meta simplifier that are exported to basic clients.*)
signature BASIC_META_SIMPLIFIER =
sig
  (*when set, the trace functions of the structure print their messages*)
  val trace_simp: bool ref
  (*when set, the debug functions of the structure print their messages*)
  val debug_simp: bool ref
  (*bound on the depth of conditional rewriting, checked by incr_depth*)
  val simp_depth_limit: int ref
end;
    15 
(*Full interface of the meta simplifier: simpset construction and update,
  plus the rewriting conversions built on top of a simpset.*)
signature META_SIMPLIFIER =
sig
  include BASIC_META_SIMPLIFIER
  (*raised for malformed rewrite or congruence rules, together with the offending thm*)
  exception SIMPLIFIER of string * thm
  (*wraps a failure of a simplification procedure with the procedure's name*)
  exception SIMPROC_FAIL of string * exn
  type meta_simpset
  (*lists the rewrite rules, congruence rules and simprocs of a simpset*)
  val dest_mss          : meta_simpset ->
    {simps: thm list, congs: thm list, procs: (string * cterm list) list}
  val empty_mss         : meta_simpset
  (*drops rules, congs, procs, bounds and prems; keeps mk_rews and termless*)
  val clear_mss         : meta_simpset -> meta_simpset
  (*mk_rews, termless and depth are taken from the first argument*)
  val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
  val add_simps         : meta_simpset * thm list -> meta_simpset
  val del_simps         : meta_simpset * thm list -> meta_simpset
  val mss_of            : thm list -> meta_simpset
  val add_congs         : meta_simpset * thm list -> meta_simpset
  val del_congs         : meta_simpset * thm list -> meta_simpset
  (*simproc entries are (name, lhs patterns, procedure, identifying stamp)*)
  val add_simprocs      : meta_simpset *
    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
      -> meta_simpset
  val del_simprocs      : meta_simpset *
    (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
      -> meta_simpset
  val add_prems         : meta_simpset * thm list -> meta_simpset
  val prems_of_mss      : meta_simpset -> thm list
  val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
  val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
  val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
  val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
  val beta_eta_conversion: cterm -> thm
  val rewrite_cterm: bool * bool * bool ->
    (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
  val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
  val forall_conv       : (cterm -> thm) -> cterm -> thm
  val fconv_rule        : (cterm -> thm) -> thm -> thm
  val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
  val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
  val rewrite_thm       : bool * bool * bool
                          -> (meta_simpset -> thm -> thm option)
                          -> meta_simpset -> thm -> thm
  val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
  val rewrite_goal_rule : bool* bool * bool
                          -> (meta_simpset -> thm -> thm option)
                          -> meta_simpset -> int -> thm -> thm
  val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
end;
    61 
    62 structure MetaSimplifier : META_SIMPLIFIER =
    63 struct
    64 
    65 (** diagnostics **)
    66 
(*SIMPLIFIER (msg, thm): thm is not acceptable as rewrite or congruence rule.*)
exception SIMPLIFIER of string * thm;
(*SIMPROC_FAIL (name, exn): simplification procedure `name' failed with exn.*)
exception SIMPROC_FAIL of string * exn;
    69 
(*simp_depth: shown as "[n]" prefix of trace lines when non-zero.*)
val simp_depth = ref 0;
(*simp_depth_limit: incr_depth gives up once the simpset depth exceeds it.*)
val simp_depth_limit = ref 1000;
    72 
    73 local
    74 
    75 fun println a =
    76   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    77 
    78 fun prnt warn a = if warn then warning a else println a;
    79 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    80 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    81 
    82 in
    83 
    84 fun prthm warn a = prctm warn a o Thm.cprop_of;
    85 
    86 val trace_simp = ref false;
    87 val debug_simp = ref false;
    88 
    89 fun trace warn a = if !trace_simp then prnt warn a else ();
    90 fun debug warn a = if !debug_simp then prnt warn a else ();
    91 
    92 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    93 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    94 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    95 
    96 fun trace_thm a thm =
    97   let val {sign, prop, ...} = rep_thm thm
    98   in trace_term false a sign prop end;
    99 
   100 fun trace_named_thm a (thm, name) =
   101   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   102 
   103 end;
   104 
   105 
   106 (** meta simp sets **)
   107 
   108 (* basic components *)
   109 
(*A rewrite rule as stored in the simpset.*)
type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
(* thm: the rewrite rule
   name: name of theorem from which rewrite rule was extracted
   lhs: the left-hand side
   elhs: the eta-contracted lhs.
   fo:  use first-order matching
   perm: the rewrite rule is permutative
Remarks:
  - elhs is used for matching,
    lhs only for preservation of bound variable names.
  - fo is set iff
    either elhs is first-order (no Var is applied),
           in which case fo-matching is complete,
    or elhs is not a pattern,
       in which case there is nothing better to do.
*)
(*A congruence rule together with the lhs of its conclusion.*)
type cong = {thm: thm, lhs: cterm};
(*A simplification procedure: proc is called with signature, premises and term;
  lhs is the pattern under which it is indexed; id is the stamp used by
  eq_simproc to identify the procedure.*)
type simproc =
 {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   129 
   130 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   131   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   132 
   133 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   134   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   135 
   136 fun eq_prem (thm1, thm2) =
   137   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   138 
   139 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   140 
   141 fun mk_simproc (name, proc, lhs, id) =
   142   {name = name, proc = proc, lhs = lhs, id = id};
   143 
   144 
   145 (* datatype mss *)
   146 
   147 (*
   148   A "mss" contains data needed during conversion:
   149     rules: discrimination net of rewrite rules;
   150     congs: association list of congruence rules and
   151            a list of `weak' congruence constants.
   152            A congruence is `weak' if it avoids normalization of some argument.
   153     procs: discrimination net of simplification procedures
   154       (functions that prove rewrite rules on the fly);
   155     bounds: names of bound variables already used
   156       (for generating new names when rewriting under lambda abstractions);
   157     prems: current premises;
   158     mk_rews: mk: turns simplification thms into rewrite rules;
   159              mk_sym: turns == around; (needs Drule!)
   160              mk_eq_True: turns P into P == True - logic specific;
   161     termless: relation for ordered rewriting;
   162     depth: depth of conditional rewriting;
   163 *)
   164 
(*The simpset record; all update functions below rebuild it via mk_mss.*)
datatype meta_simpset =
  Mss of {
    (*net of rewrite rules, indexed by their eta-contracted lhs (see insert_rrule)*)
    rules: rrule Net.net,
    (*congruence rules keyed by cong_name of their head, and the keys of
      those rules for which is_full_cong fails (the `weak' ones)*)
    congs: (string * cong) list * string list,
    (*net of simplification procedures, indexed by their lhs pattern*)
    procs: simproc Net.net,
    bounds: string list,
    (*current premises, see add_prems*)
    prems: thm list,
    mk_rews: {mk: thm -> thm list,
              mk_sym: thm -> thm option,
              mk_eq_True: thm -> thm option},
    (*order used to decide applicability of permutative rules in rewritec*)
    termless: term * term -> bool,
    (*incremented by incr_depth and compared with simp_depth_limit*)
    depth: int};
   177 
   178 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   179   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   180        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   181 
   182 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   183   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   184 
   185 val empty_mss =
   186   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   187   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   188 
   189 fun clear_mss (Mss {mk_rews, termless, ...}) =
   190   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   191 
   192 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   193   let val depth1 = depth+1
   194   in if depth1 > !simp_depth_limit
   195      then (warning "simp_depth_limit exceeded - giving up"; None)
   196      else (if depth1 mod 5 = 0
   197            then warning("Simplification depth " ^ string_of_int depth1)
   198            else ();
   199            Some(mk_mss(rules,congs,procs,bounds,prems,mk_rews,termless,depth1))
   200           )
   201   end;
   202 
   203 
   204 (** simpset operations **)
   205 
   206 (* term variables *)
   207 
   208 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   209 fun term_varnames t = add_term_varnames ([], t);
   210 
   211 
   212 (* dest_mss *)
   213 
   214 fun dest_mss (Mss {rules, congs, procs, ...}) =
   215   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   216    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   217    procs =
   218      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   219      |> partition_eq eq_snd
   220      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   221      |> Library.sort_wrt #1};
   222 
   223 
   224 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   225 
   226 fun merge_mss
   227  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   228        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   229   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   230        bounds = bounds2, prems = prems2, ...}) =
   231       mk_mss
   232        (Net.merge (rules1, rules2, eq_rrule),
   233         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   234         merge_lists weak1 weak2),
   235         Net.merge (procs1, procs2, eq_simproc),
   236         merge_lists bounds1 bounds2,
   237         gen_merge_lists eq_prem prems1 prems2,
   238         mk_rews, termless, depth);
   239 
   240 
   241 (* add_simps *)
   242 
   243 fun mk_rrule2{thm, name, lhs, elhs, perm} =
   244   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   245   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   246 
   247 fun insert_rrule quiet (mss as Mss {rules,...},
   248                  rrule as {thm,name,lhs,elhs,perm}) =
   249   (trace_named_thm "Adding rewrite rule" (thm, name);
   250    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   251        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   252    in upd_rules(mss,rules') end
   253    handle Net.INSERT => if quiet then mss else
   254      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   255 
   256 fun vperm (Var _, Var _) = true
   257   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   258   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   259   | vperm (t, u) = (t = u);
   260 
   261 fun var_perm (t, u) =
   262   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   263 
   264 (* FIXME: it seems that the conditions on extra variables are too liberal if
   265 prems are nonempty: does solving the prems really guarantee instantiation of
   266 all its Vars? Better: a dynamic check each time a rule is applied.
   267 *)
   268 fun rewrite_rule_extra_vars prems elhs erhs =
   269   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   270   orelse
   271   not ((term_tvars erhs) subset
   272        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   273 
   274 (*Simple test for looping rewrite rules and stupid orientations*)
   275 fun reorient sign prems lhs rhs =
   276    rewrite_rule_extra_vars prems lhs rhs
   277   orelse
   278    is_Var (head_of lhs)
   279   orelse
   280    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   281   orelse
   282    (null prems andalso
   283     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   284     (*the condition "null prems" is necessary because conditional rewrites
   285       with extra variables in the conditions may terminate although
   286       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   287   orelse
   288    (is_Const lhs andalso not(is_Const rhs))
   289 
(*Decompose a (conditional) meta-equation into
    (sign, prems, lhs, elhs, rhs, perm)
  where lhs and rhs are the plain terms of both sides, elhs is the certified
  eta-contracted lhs, and perm tells whether the rule is permutative.
  Raises SIMPLIFIER if the conclusion is not a meta-equality.*)
fun decomp_simp thm =
  let val {sign, prop, ...} = rep_thm thm;
      val prems = Logic.strip_imp_prems prop;
      val concl = Drule.strip_imp_concl (cprop_of thm);
      val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
        raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
      (*eta-contract lhs by taking the rhs of its eta-conversion theorem*)
      val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
      val elhs = if elhs=lhs then lhs else elhs (* try to share *)
      val erhs = Pattern.eta_contract (term_of rhs);
      (*permutative: both sides differ only in the arrangement of their Vars,
        are not identical, and the lhs is not a single Var*)
      val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
                 andalso not (is_Var (term_of elhs))
  in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   302 
   303 fun decomp_simp' thm =
   304   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   305     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   306     else (lhs, rhs)
   307   end;
   308 
   309 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
   310   case mk_eq_True thm of
   311     None => []
   312   | Some eq_True =>
   313       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   314       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   315 
   316 (* create the rewrite rule and possibly also the ==True variant,
   317    in case there are extra vars on the rhs *)
   318 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   319   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   320   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   321         (term_tvars rhs) subset (term_tvars lhs)
   322      then [rrule]
   323      else mk_eq_True mss (thm2, name) @ [rrule]
   324   end;
   325 
   326 fun mk_rrule mss (thm, name) =
   327   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   328   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   329      (* weak test for loops: *)
   330      if rewrite_rule_extra_vars prems lhs rhs orelse
   331         is_Var (term_of elhs)
   332      then mk_eq_True mss (thm, name)
   333      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   334   end;
   335 
   336 fun orient_rrule mss (thm, name) =
   337   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   338   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   339      else if reorient sign prems lhs rhs
   340           then if reorient sign prems rhs lhs
   341                then mk_eq_True mss (thm, name)
   342                else let val Mss{mk_rews={mk_sym,...},...} = mss
   343                     in case mk_sym thm of
   344                          None => []
   345                        | Some thm' =>
   346                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   347                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   348                     end
   349           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   350   end;
   351 
   352 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
   353   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   354 
   355 fun orient_comb_simps comb mk_rrule (mss,thms) =
   356   let val rews = extract_rews(mss,thms)
   357       val rrules = flat (map mk_rrule rews)
   358   in foldl comb (mss,rrules) end
   359 
   360 (* Add rewrite rules explicitly; do not reorient! *)
   361 fun add_simps(mss,thms) =
   362   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   363 
   364 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
   365   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   366 
   367 fun extract_safe_rrules(mss,thm) =
   368   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   369 
   370 (* del_simps *)
   371 
   372 fun del_rrule(mss as Mss {rules,...},
   373               rrule as {thm, elhs, ...}) =
   374   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   375    handle Net.DELETE =>
   376      (prthm true "Rewrite rule not in simpset:" thm; mss));
   377 
   378 fun del_simps(mss,thms) =
   379   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   380 
   381 
   382 (* add_congs *)
   383 
   384 fun is_full_cong_prems [] varpairs = null varpairs
   385   | is_full_cong_prems (p::prems) varpairs =
   386     (case Logic.strip_assums_concl p of
   387        Const("==",_) $ lhs $ rhs =>
   388          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   389          in is_Var x  andalso  forall is_Bound xs  andalso
   390             null(findrep(xs))  andalso xs=ys andalso
   391             (x,y) mem varpairs andalso
   392             is_full_cong_prems prems (varpairs\(x,y))
   393          end
   394      | _ => false);
   395 
   396 fun is_full_cong thm =
   397 let val prems = prems_of thm
   398     and concl = concl_of thm
   399     val (lhs,rhs) = Logic.dest_equals concl
   400     val (f,xs) = strip_comb lhs
   401     and (g,ys) = strip_comb rhs
   402 in
   403   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   404   is_full_cong_prems prems (xs ~~ ys)
   405 end
   406 
   407 fun cong_name (Const (a, _)) = Some a
   408   | cong_name (Free (a, _)) = Some ("Free: " ^ a)
   409   | cong_name _ = None;
   410 
   411 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   412   let
   413     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   414       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   415 (*   val lhs = Pattern.eta_contract lhs; *)
   416     val a = (case cong_name (head_of (term_of lhs)) of
   417         Some a => a
   418       | None =>
   419         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   420     val (alist,weak) = congs
   421     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   422            ("Overwriting congruence rule for " ^ quote a);
   423     val weak2 = if is_full_cong thm then weak else a::weak
   424   in
   425     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   426   end;
   427 
   428 val (op add_congs) = foldl add_cong;
   429 
   430 
   431 (* del_congs *)
   432 
   433 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   434   let
   435     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   436       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   437 (*   val lhs = Pattern.eta_contract lhs; *)
   438     val a = (case cong_name (head_of lhs) of
   439         Some a => a
   440       | None =>
   441         raise SIMPLIFIER ("Congruence must start with a constant", thm));
   442     val (alist,_) = congs
   443     val alist2 = filter (fn (x,_)=> x<>a) alist
   444     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   445                                               else Some a)
   446                    alist2
   447   in
   448     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   449   end;
   450 
   451 val (op del_congs) = foldl del_cong;
   452 
   453 
   454 (* add_simprocs *)
   455 
   456 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   457     (name, lhs, proc, id)) =
   458   let val {sign, t, ...} = rep_cterm lhs
   459   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   460       sign t;
   461     mk_mss (rules, congs,
   462       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   463         handle Net.INSERT =>
   464             (warning ("Ignoring duplicate simplification procedure \""
   465                       ^ name ^ "\"");
   466              procs),
   467         bounds, prems, mk_rews, termless,depth))
   468   end;
   469 
   470 fun add_simproc (mss, (name, lhss, proc, id)) =
   471   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   472 
   473 val add_simprocs = foldl add_simproc;
   474 
   475 
   476 (* del_simprocs *)
   477 
   478 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   479     (name, lhs, proc, id)) =
   480   mk_mss (rules, congs,
   481     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   482       handle Net.DELETE =>
   483           (warning ("Simplification procedure \"" ^ name ^
   484                        "\" not in simpset"); procs),
   485       bounds, prems, mk_rews, termless, depth);
   486 
   487 fun del_simproc (mss, (name, lhss, proc, id)) =
   488   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   489 
   490 val del_simprocs = foldl del_simproc;
   491 
   492 
   493 (* prems *)
   494 
   495 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   496   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   497 
   498 fun prems_of_mss (Mss {prems, ...}) = prems;
   499 
   500 
   501 (* mk_rews *)
   502 
   503 fun set_mk_rews
   504   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   505     mk_mss (rules, congs, procs, bounds, prems,
   506             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   507             termless, depth);
   508 
   509 fun set_mk_sym
   510   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   511     mk_mss (rules, congs, procs, bounds, prems,
   512             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   513             termless,depth);
   514 
   515 fun set_mk_eq_True
   516   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   517     mk_mss (rules, congs, procs, bounds, prems,
   518             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   519             termless,depth);
   520 
   521 (* termless *)
   522 
   523 fun set_termless
   524   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   525     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   526 
   527 
   528 
   529 (** rewriting **)
   530 
   531 (*
   532   Uses conversions, see:
   533     L C Paulson, A higher-order implementation of rewriting,
   534     Science of Computer Programming 3 (1983), pages 119-149.
   535 *)
   536 
   537 val dest_eq = Drule.dest_equals o cprop_of;
   538 val lhs_of = fst o dest_eq;
   539 val rhs_of = snd o dest_eq;
   540 
   541 fun beta_eta_conversion t =
   542   let val thm = beta_conversion true t;
   543   in transitive thm (eta_conversion (rhs_of thm)) end;
   544 
(*Chain thm with thm', bridging both by the symmetric beta-eta conversion of
  the lhs of thm'.  Yields Some of the composed theorem; if the composition
  fails with THM, the mismatch is traced and None is returned.
  msg: trace "SUCCEEDED" in case of success.*)
fun check_conv msg thm thm' =
  let
    val thm'' = transitive thm (transitive
      (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
  in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
  (*the handler covers the whole let expression above*)
  handle THM _ =>
    (*prop0 is the second argument of the binary operator at the top of the
      proposition of thm, i.e. the term thm' should have started from*)
    let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
    in
      (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
       trace_term false "Should have proved:" sign prop0;
       None)
    end;
   557 
   558 
   559 (* mk_procrule *)
   560 
   561 fun mk_procrule thm =
   562   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   563   in if rewrite_rule_extra_vars prems lhs rhs
   564      then (prthm true "Extra vars on rhs:" thm; [])
   565      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   566   end;
   567 
   568 
   569 (* conversion to apply the meta simpset to a term *)
   570 
(* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   normalized terms by carrying around the rhs of the rewrite rule just
   applied. This is called the `skeleton'. It is decomposed in parallel
   with the term. Once a Var is encountered, the corresponding term is
   already in normal form.
   skel0 is a dummy skeleton that enforces complete normalization: it is
   not a Var, so it never stops the traversal.
*)
val skel0 = Bound 0;
   579 
   580 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   581    The latter may happen iff there are weak congruence rules for constants
   582    in the lhs.
   583 *)
   584 fun uncond_skel((_,weak),(lhs,rhs)) =
   585   if null weak then rhs (* optimization *)
   586   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   587        else rhs;
   588 
   589 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   590    Otherwise those vars may become instantiated with unnormalized terms
   591    while the premises are solved.
   592 *)
   593 fun cond_skel(args as (congs,(lhs,rhs))) =
   594   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   595   else skel0;
   596 
   597 (*
   598   we try in order:
   599     (1) beta reduction
   600     (2) unconditional rewrite rules
   601     (3) conditional rewrite rules
   602     (4) simplification procedures
   603 
   604   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   605 
   606 *)
   607 
(*Try to rewrite the term t at its root.  Yields Some (thm, skel), where thm
  is an equation with lhs t and skel the skeleton for further normalization
  of its rhs, or None if nothing applies.
    prover: solves the premises of an instantiated conditional rule;
    signt: signature of the term; rules must stem from a sub-signature;
    maxt: maximal index of the term; ~1 means that rules need not be
      renamed apart.*)
fun rewritec (prover, signt, maxt)
             (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
  let
    (*all matching is done against the eta-contracted term*)
    val eta_thm = Thm.eta_conversion t;
    val eta_t' = rhs_of eta_thm;
    val eta_t = term_of eta_t';
    val tsigt = Sign.tsig_of signt;
    (*try a single rewrite rule; raises Pattern.MATCH if it is not applicable*)
    fun rew {thm, name, lhs, elhs, fo, perm} =
      let
        val {sign, prop, maxidx, ...} = rep_thm thm;
        val _ = if Sign.subsig (sign, signt) then ()
                else (prthm true "Ignoring rewrite rule from different theory:" thm;
                      raise Pattern.MATCH);
        (*rename the rule apart from the term*)
        val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
          else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
        val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
                          else Thm.cterm_match (elhs', eta_t');
        val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
        val prop' = #prop (rep_thm thm');
        val unconditional = (Logic.count_prems (prop',0) = 0);
        val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
      in
        (*a permutative rule is applied only if the rhs of the instance is
          smaller than its lhs wrt. termless*)
        if perm andalso not (termless (rhs', lhs'))
        then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
              trace_thm "Term does not become smaller:" thm'; None)
        else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
           if unconditional
           then
             (trace_thm "Rewriting:" thm';
              let val lr = Logic.dest_equals prop;
                  (*NOTE(review): this binding is not exhaustive; should
                    check_conv yield None here, exception Bind is raised,
                    which the handler in rews does not catch*)
                  val Some thm'' = check_conv false eta_thm thm'
              in Some (thm'', uncond_skel (congs, lr)) end)
           else
             (trace_thm "Trying to rewrite:" thm';
              (*the premises are solved with a simpset of increased depth*)
              case incr_depth mss of
                None => (trace_thm "FAILED - reached depth limit" thm'; None)
              | Some mss =>
              (case prover mss thm' of
                None       => (trace_thm "FAILED" thm'; None)
              | Some thm2 =>
                  (case check_conv true eta_thm thm2 of
                     None => None |
                     Some thm2' =>
                       let val concl = Logic.strip_imp_concl prop
                           val lr = Logic.dest_equals concl
                       in Some (thm2', cond_skel (congs, lr)) end))))
      end

    (*the result of the first rule that applies*)
    fun rews [] = None
      | rews (rrule :: rrules) =
          let val opt = rew rrule handle Pattern.MATCH => None
          in case opt of None => rews rrules | some => some end;

    (*put the rules whose proposition is a plain equation, i.e. the
      unconditional ones, in front of all others*)
    fun sort_rrules rrs = let
      fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
                                      Const("==",_) $ _ $ _ => true
                                      | _                   => false
      fun sort []        (re1,re2) = re1 @ re2
        | sort (rr::rrs) (re1,re2) = if is_simple rr
                                     then sort rrs (rr::re1,re2)
                                     else sort rrs (re1,rr::re2)
    in sort rrs ([],[]) end

    (*try the simplification procedures in turn; the theorem produced by a
      procedure is applied like a rewrite rule via mk_procrule and rews*)
    fun proc_rews ([]:simproc list) = None
      | proc_rews ({name, proc, lhs, ...} :: ps) =
          if Pattern.matches tsigt (term_of lhs, term_of t) then
            (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
             (*failures of the procedure are wrapped into SIMPROC_FAIL*)
             case transform_failure (curry SIMPROC_FAIL name)
                 (fn () => proc signt prems eta_t) () of
               None => (debug false "FAILED"; proc_rews ps)
             | Some raw_thm =>
                 (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
                  (case rews (mk_procrule raw_thm) of
                    None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
                      " -- does not match") t; proc_rews ps)
                  | some => some)))
          else proc_rews ps;
  in case eta_t of
       (*a beta redex at the root is reduced before anything else*)
       Abs _ $ _ => Some (transitive eta_thm
         (beta_conversion false eta_t'), skel0)
     | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
               None => proc_rews (Net.match_term procs eta_t)
             | some => some)
  end;
   692 
   693 
   694 (* conversion to apply a congruence rule to a term *)
   695 
(*Apply the congruence rule cong to the term t.  Yields None if the resulting
  equation has identical sides (modulo aconv), Some of the equation otherwise.
  Raises ERROR if the rule stems from a different theory or if prover fails
  on the instantiated rule; Pattern.MATCH from matching is left to the caller.
    prover: solves the premises of the instantiated congruence rule;
    signt, maxt: signature and maximal index of the term, as for rewritec.*)
fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
  let val {sign, ...} = rep_thm cong
      val _ = if Sign.subsig (sign, signt) then ()
                 else error("Congruence rule from different theory")
      (*rename the rule apart from the term, unless maxt = ~1*)
      val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
      val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
      val insts = Thm.cterm_match (rlhs, t)
      (* Pattern.match can raise Pattern.MATCH;
         is handled when congc is called *)
      val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
      val unit = trace_thm "Applying congruence rule:" thm';
      fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
  in case prover thm' of
       None => err ("Could not prove", thm')
     | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
          None => err ("Should not have proved", thm2)
        | Some thm2' =>
            (*an equation with identical sides counts as no progress*)
            if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
            then None else Some thm2')
  end;
   716 
   (*Components of the first premise of Drule.imp_cong, destructured as
     cA ==> (cB == cC).  Used below as instantiation targets for Drule.imp_cong
     and Drule.swap_prems_eq.*)
   val (cA, (cB, cC)) =
     let val (ante, eqn) = dest_implies (hd (cprems_of Drule.imp_cong))
     in (ante, dest_equals eqn) end;
   719 
   (*Chain two optional equality theorems.  A missing theorem (None) is
     neutral; two present theorems are composed with transitive.*)
   fun transitive1 eq1 eq2 =
     (case (eq1, eq2) of
        (Some thm1, Some thm2) => Some (transitive thm1 thm2)
      | (None, _) => eq2
      | (Some _, None) => eq1)
   724 
   (*transitive1 where the first theorem is known to be present*)
   fun transitive2 thm eq2 = transitive1 (Some thm) eq2;
   (*transitive1 where the second theorem is known to be present;
     the first argument is still optional*)
   fun transitive3 thm thm2 = transitive1 thm (Some thm2);
   727 
   (*Combine two equations under refl_implies using combination.
     The first combination is performed as soon as e is supplied.*)
   fun imp_cong' e =
     let val head = combination refl_implies e
     in fn eq2 => combination head eq2 end;
   729 
   (* Bottom-up rewriting conversion.

      bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) yields try_botc,
      which maps a meta_simpset and a cterm t to an equality theorem with
      left-hand side t (reflexive t if nothing could be rewritten).

      mutsimp selects mut_impc0 instead of nonmut_impc for meta-implications;
      simprem and useprem are consulted by nonmut_impc only.
      The numbers at the start of the code lines are the listing's own line
      numbers and are not part of the program text. *)
   730 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   731   let
           (* botc skel mss t: optional rewrite of t.  Gives up immediately when
              the skeleton is a Var (presumably marking a subterm known to be
              normal -- confirm against rewritec).  Otherwise subterms are
              rewritten first (subc), then rewritec is tried on the result and
              botc iterates on the new right-hand side with the skeleton that
              rewritec returned. *)
   732     fun botc skel mss t =
   733           if is_Var skel then None
   734           else
   735           (case subc skel mss t of
   736              some as Some thm1 =>
   737                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   738                   Some (thm2, skel2) =>
   739                     transitive2 (transitive thm1 thm2)
   740                       (botc skel2 mss (rhs_of thm2))
   741                 | None => some)
   742            | None =>
   743                (case rewritec (prover, sign, maxidx) mss t of
   744                   Some (thm2, skel2) => transitive2 thm2
   745                     (botc skel2 mss (rhs_of thm2))
   746                 | None => None))
   747 
           (* total version of botc with the full skeleton skel0:
              falls back to reflexive t *)
   748     and try_botc mss t =
   749           (case botc skel0 mss t of
   750              Some trec1 => trec1 | None => (reflexive t))
   751 
           (* subc: rewrite the immediate subterms of t0 (not t0 itself at the top) *)
   752     and subc skel
   753           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   754        (case term_of t0 of
               (* abstraction: open it with a fresh bound name b (recorded in the
                  bounds of the simpset), rewrite the body, abstract again *)
   755            Abs (a, T, t) =>
   756              let val b = variant bounds a
   757                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   758                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   759                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   760              in case botc skel' mss' t' of
   761                   Some thm => Some (abstract_rule a v thm)
   762                 | None => None
   763              end
   764          | t $ _ => (case t of
               (* meta-implication: handled by impc *)
   765              Const ("==>", _) $ _  => impc t0 mss
               (* beta-redex: contract first, then continue on the result *)
   766            | Abs _ =>
   767                let val thm = beta_conversion false t0
   768                in case subc skel0 mss (rhs_of thm) of
   769                     None => Some thm
   770                   | Some thm' => Some (transitive thm thm')
   771                end
   772            | _  =>
                   (* appc: default treatment of an application -- rewrite function
                      and argument separately (each with its part of the skeleton)
                      and join the results with combination *)
   773                let fun appc () =
   774                      let
   775                        val (tskel, uskel) = case skel of
   776                            tskel $ uskel => (tskel, uskel)
   777                          | _ => (skel0, skel0);
   778                        val (ct, cu) = Thm.dest_comb t0
   779                      in
   780                      (case botc tskel mss ct of
   781                         Some thm1 =>
   782                           (case botc uskel mss cu of
   783                              Some thm2 => Some (combination thm1 thm2)
   784                            | None => Some (combination thm1 (reflexive cu)))
   785                       | None =>
   786                           (case botc uskel mss cu of
   787                              Some thm1 => Some (combination (reflexive ct) thm1)
   788                            | None => None))
   789                      end
   790                    val (h, ts) = strip_comb t
                   (* a congruence rule registered under the name of the head h
                      takes precedence over appc *)
   791                in case cong_name h of
   792                     Some a =>
   793                       (case assoc_string (fst congs, a) of
   794                          None => appc ()
   795                        | Some cong =>
   796 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   797    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   798                           (let
   799                              val thm = congc (prover mss, sign, maxidx) cong t0;
   800                              val t = if_none (apsome rhs_of thm) t0;
   801                              val (cl, cr) = Thm.dest_comb t
   802                              val dVar = Var(("", 0), dummyT)
                                 (* skeleton h ? ... ?: only the head position is not
                                    a Var, so botc stops at the arguments *)
   803                              val skel =
   804                                list_comb (h, replicate (length ts) dVar)
   805                            in case botc skel mss cl of
   806                                 None => thm
   807                               | Some thm' => transitive3 thm
   808                                   (combination thm' (reflexive cr))
                               (* a rule whose lhs does not match falls back to appc *)
   809                            end handle TERM _ => error "congc result"
   810                                     | Pattern.MATCH => appc ()))
   811                   | _ => appc ()
   812                end)
   813          | _ => None)
   814 
           (* impc: rewrite a meta-implication according to the mutsimp flag *)
   815     and impc ct mss =
   816       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   817 
           (* rules_of_prem: rewrite rules obtained from a premise together with the
              assumed premise; a premise whose term has maxidx <> ~1 is only traced
              and contributes ([], None) *)
   818     and rules_of_prem mss prem =
   819       if maxidx_of_term (term_of prem) <> ~1
   820       then (trace_cterm true
   821         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   822       else
   823         let val asm = assume prem
   824         in (extract_safe_rrules (mss, asm), Some asm) end
   825 
           (* add_rrules: insert all given rules (insert_rrule true) and register the
              present assumptions as premises of the simpset *)
   826     and add_rrules (rrss, asms) mss =
   827       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   828 
           (* disch r (prem, eq): from eq with sides lhs and rhs build
                (prem ==> lhs) == (prem ==> rhs)
              via an instance of Drule.imp_cong.  If r holds, lhs and rhs are taken
              apart as implications and the result is composed on both sides with
              instances of Drule.swap_prems_eq.
              NOTE(review): assuming swap_prems_eq exchanges two premises, this
              yields (prem' ==> prem ==> concl) == (prem'' ==> prem ==> concl);
              the conclusion of rhs is ignored -- confirm it equals concl. *)
   829     and disch r (prem, eq) =
   830       let
   831         val (lhs, rhs) = dest_eq eq;
   832         val eq' = implies_elim (Thm.instantiate
   833           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   834           (implies_intr prem eq)
   835       in if not r then eq' else
   836         let
   837           val (prem', concl) = dest_implies lhs;
   838           val (prem'', _) = dest_implies rhs
   839         in transitive (transitive
   840           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   841              Drule.swap_prems_eq) eq')
   842           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   843              Drule.swap_prems_eq)
   844         end
   845       end
   846 
           (* rebuild: put the premises back in front of the conclusion one at a
              time.  After each step rewritec is tried on the resulting implication,
              using the rules of the premises not yet re-attached; on success the
              whole procedure restarts through mut_impc0 on the new right-hand side.
              The lists arrive in the (reversed) order built up by mut_impc. *)
   847     and rebuild [] _ _ _ _ eq = eq
   848       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   849           let
   850             val mss' = add_rrules (rev rrss, rev asms) mss;
   851             val concl' =
   852               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   853             val dprem = apsome (curry (disch false) prem)
   854           in case rewritec (prover, sign, maxidx) mss' concl' of
   855               None => rebuild prems concl' rrss asms mss (dprem eq)
   856             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   857                   (the (transitive3 (dprem eq) eq'), prems))
   858                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   859           end
   860           
           (* mut_impc0: entry point of mutual simplification.  Strips the premises
              off concl, computes their rules, appends them to the given lists and
              starts mut_impc with empty accumulators and changed = k = ~1. *)
   861     and mut_impc0 prems concl rrss asms mss =
   862       let
   863         val prems' = strip_imp_prems concl;
   864         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   865       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   866         (asms @ asms') [] [] [] [] mss ~1 ~1
   867       end
   868  
           (* mut_impc: one pass over the premises.  prems/rrss/asms are still to be
              processed; prems'/rrss'/asms'/eqns accumulate the processed ones in
              reverse order.

              First clause (pass finished): join the collected equations, each
              discharged over its premise.  If changed > 0 another pass is made
              with k = changed; otherwise the conclusion is rewritten with the
              rules of all premises and rebuild re-attaches the premises. *)
   869     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   870         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   871             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   872           (if changed > 0 then
   873              mut_impc (rev prems') concl (rev rrss') (rev asms')
   874                [] [] [] [] mss ~1 changed
   875            else rebuild prems' concl rrss' asms' mss
   876              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   877 
           (* Second clause: rewrite the current premise with the rules of all other
              premises; no rewriting is attempted once k has reached 0. *)
   878       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   879           prems' rrss' asms' eqns mss changed k =
   880         case (if k = 0 then None else botc skel0 (add_rrules
   881           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
               (* premise unchanged: move it to the accumulators, count k down *)
   882             None => mut_impc prems concl rrss asms (prem :: prems')
   883               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   884               (if k = 0 then 0 else k - 1)
               (* premise rewritten to prem': record the new premise and its rules;
                  changed becomes the number of premises processed before it and
                  k is reset to ~1 *)
   885           | Some eqn =>
   886             let
   887               val prem' = rhs_of eqn;
   888               val tprems = map term_of prems;
                   (* i = 1 + largest position among the remaining premises of a
                      hypothesis of eqn (0 if none occurs there) *)
   889               val i = 1 + foldl Int.max (~1, map (fn p =>
   890                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   891               val (rrs', asm') = rules_of_prem mss prem'
   892             in mut_impc prems concl rrss asms (prem' :: prems')
   893               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   894                 (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
   895                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   896             end
   897 
   898      (* legacy code - only for backwards compatibility *)
            (* nonmut_impc: the premise is rewritten only if simprem holds; its
               (possibly rewritten) form supplies rules for the conclusion only
               if useprem holds *)
   899      and nonmut_impc ct mss =
   900        let val (prem, conc) = dest_implies ct;
   901            val thm1 = if simprem then botc skel0 mss prem else None;
   902            val prem1 = if_none (apsome rhs_of thm1) prem;
   903            val mss1 = if not useprem then mss else add_rrules
   904              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   905        in (case botc skel0 mss1 conc of
   906            None => (case thm1 of
   907                None => None
   908              | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
   909          | Some thm2 =>
   910            let val thm2' = disch false (prem1, thm2)
   911            in (case thm1 of
   912                None => Some thm2'
   913              | Some thm1' =>
   914                  Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
   915            end)
   916        end
   917 
   918  in try_botc end;
   919 
   920 
   921 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   922 
   923 (*
   924   Parameters:
   925     mode = (simplify A,
   926             use A in simplifying B,
   927             use prems of B (if B is again a meta-impl.) to simplify A)
   928            when simplifying A ==> B
   929     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   930     prover: how to solve premises in conditional rewrites and congruences
   931 *)
   932 
   (*Rewrite the cterm ct with simpset mss: sets simp_depth to the depth stored
     in mss, then runs bottomc.  A THM exception escaping from the rewriting is
     turned into an error message listing the theorems involved.*)
   fun rewrite_cterm mode prover mss ct =
     let
       val {sign = sg, maxidx = maxi, ...} = rep_cterm ct;
       val Mss {depth = d, ...} = mss;
       val _ = simp_depth := d;
     in
       bottomc (mode, prover, sg, maxi) mss ct
     end
     handle THM (msg, _, ths) =>
       error ("Exception THM was raised in simplifier:\n" ^ msg ^ "\n" ^
         Pretty.string_of (Display.pretty_thms ths));
   942 
   (*In [A1,...,An]==>B, apply cv to those premises Ai for which pred i holds
     (counting from 1) and leave the others alone -- for rewrite_goals_tac.
     A TERM exception at any stage makes the remaining term stay unchanged.*)
   fun goals_conv pred cv =
     let
       fun prem_conv i A = if pred i then cv A else reflexive A;
       fun gconv i ct =
         (case Drule.dest_implies ct of
            (A, B) => imp_cong' (prem_conv i A) (gconv (i + 1) B))
         handle TERM _ => reflexive ct;
     in gconv 1 end;
   950 
   (* Apply cv to A in !!x1,...,xn. A.  Descends through applications of "all"
      to an abstraction; anything else, including a term for which a TERM
      exception is raised, is passed to cv directly. *)
   fun forall_conv cv ct =
     let
       val (quant, body) = Thm.dest_comb ct;
     in
       (case (term_of quant, term_of body) of
          (Const ("all", _), Abs (x, _, _)) =>
            let val (v, ct') = Thm.dest_abs (Some "@") body
            in
              Thm.combination (Thm.reflexive quant)
                (Thm.abstract_rule x v (forall_conv cv ct'))
            end
        | _ => cv ct)
     end handle TERM _ => cv ct;
   962 
   (*Transform theorem th along the equation that the conversion cv produces
     for its proposition*)
   fun fconv_rule cv th =
     let val eq = cv (cprop_of th)
     in equal_elim eq th end;
   965 
   (*Rewrite a cterm with the given theorems; with no theorems the result is
     just the reflexive equation*)
   fun rewrite_aux prover full thms =
     (case thms of
        [] => (fn ct => Thm.reflexive ct)
      | _ => rewrite_cterm (full, false, false) prover (mss_of thms));
   969 
   (*Rewrite a theorem with the given theorems; with no theorems it is
     returned unchanged*)
   fun simplify_aux prover full thms =
     (case thms of
        [] => (fn th => th)
      | _ =>
          let val conv = rewrite_cterm (full, false, false) prover (mss_of thms)
          in fconv_rule conv end);
   974 
   (*Rewrite a theorem with a simpset: fconv_rule applied to rewrite_cterm*)
   fun rewrite_thm mode prover mss =
     let val conv = rewrite_cterm mode prover mss
     in fconv_rule conv end;
   976 
   (*Rewrite all subgoals of a proof state (represented by a theorem) with the
     given theorems; with no theorems the state is returned unchanged*)
   fun rewrite_goals_rule_aux prover thms th =
     if null thms then th
     else
       let
         val conv = rewrite_cterm (true, true, false) prover (mss_of thms);
         val all_goals = goals_conv (K true) conv;
       in fconv_rule all_goals th end;
   982 
   (*Rewrite subgoal i of a proof state (represented by a theorem);
     raises THM unless 1 <= i <= number of premises of thm*)
   fun rewrite_goal_rule mode prover mss i thm =
     if i < 1 orelse nprems_of thm < i
     then raise THM("rewrite_goal_rule",i,[thm])
     else
       let
         val conv = rewrite_cterm mode prover mss;
         fun selected j = (j = i);
       in fconv_rule (goals_conv selected conv) thm end;
   988 
   989 
   (*Simple term rewriting without proofs: delegates to Pattern.rewrite_term,
     with each rule passed through decomp_simp'*)
   fun rewrite_term sg rules procs =
     let
       val tsig = Sign.tsig_of sg;
       val rewriter = Pattern.rewrite_term tsig;
     in rewriter (map decomp_simp' rules) procs end;
   993 
   994 end;
   995 
   (*View of MetaSimplifier restricted to signature BASIC_META_SIMPLIFIER
     (trace_simp, debug_simp, simp_depth_limit); opening it makes exactly
     these names available unqualified.*)
   996 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
   997 open BasicMetaSimplifier;