src/Pure/meta_simplifier.ML
author nipkow
Tue Feb 25 18:49:23 2003 +0100 (2003-02-25)
changeset 13828 fb6ec40dd291
parent 13661 ec97dfc2bfe0
child 13835 12b2ffbe543a
permissions -rw-r--r--
added simp_depth_limit
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
    11   val trace_simp: bool ref
    12   val debug_simp: bool ref
    13   val simp_depth_limit: int ref
    14 end;
    15 
    16 signature META_SIMPLIFIER =
    17 sig
    18   include BASIC_META_SIMPLIFIER
    19   exception SIMPLIFIER of string * thm
    20   exception SIMPROC_FAIL of string * exn
    21   type meta_simpset
    22   val dest_mss          : meta_simpset ->
    23     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    24   val empty_mss         : meta_simpset
    25   val clear_mss         : meta_simpset -> meta_simpset
    26   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    27   val add_simps         : meta_simpset * thm list -> meta_simpset
    28   val del_simps         : meta_simpset * thm list -> meta_simpset
    29   val mss_of            : thm list -> meta_simpset
    30   val add_congs         : meta_simpset * thm list -> meta_simpset
    31   val del_congs         : meta_simpset * thm list -> meta_simpset
    32   val add_simprocs      : meta_simpset *
    33     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    34       -> meta_simpset
    35   val del_simprocs      : meta_simpset *
    36     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    37       -> meta_simpset
    38   val add_prems         : meta_simpset * thm list -> meta_simpset
    39   val prems_of_mss      : meta_simpset -> thm list
    40   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    41   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    42   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    43   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    44   val beta_eta_conversion: cterm -> thm
    45   val rewrite_cterm: bool * bool * bool ->
    46     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    47   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    48   val forall_conv       : (cterm -> thm) -> cterm -> thm
    49   val fconv_rule        : (cterm -> thm) -> thm -> thm
    50   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    51   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    52   val rewrite_thm       : bool * bool * bool
    53                           -> (meta_simpset -> thm -> thm option)
    54                           -> meta_simpset -> thm -> thm
    55   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    56   val rewrite_goal_rule : bool* bool * bool
    57                           -> (meta_simpset -> thm -> thm option)
    58                           -> meta_simpset -> int -> thm -> thm
    59   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
    60 end;
    61 
    62 structure MetaSimplifier : META_SIMPLIFIER =
    63 struct
    64 
    65 (** diagnostics **)
    66 
    67 exception SIMPLIFIER of string * thm;
    68 exception SIMPROC_FAIL of string * exn;
    69 
    70 val simp_depth = ref 0;
    71 val simp_depth_limit = ref 1000;
    72 
    73 local
    74 
    75 fun println a =
    76   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    77 
    78 fun prnt warn a = if warn then warning a else println a;
    79 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    80 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    81 
    82 in
    83 
    84 fun prthm warn a = prctm warn a o Thm.cprop_of;
    85 
    86 val trace_simp = ref false;
    87 val debug_simp = ref false;
    88 
    89 fun trace warn a = if !trace_simp then prnt warn a else ();
    90 fun debug warn a = if !debug_simp then prnt warn a else ();
    91 
    92 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    93 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    94 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    95 
    96 fun trace_thm a thm =
    97   let val {sign, prop, ...} = rep_thm thm
    98   in trace_term false a sign prop end;
    99 
   100 fun trace_named_thm a (thm, name) =
   101   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   102 
   103 end;
   104 
   105 
   106 (** meta simp sets **)
   107 
   108 (* basic components *)
   109 
   110 type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
   111 (* thm: the rewrite rule
   112    name: name of theorem from which rewrite rule was extracted
   113    lhs: the left-hand side
   114    elhs: the etac-contracted lhs.
   115    fo:  use first-order matching
   116    perm: the rewrite rule is permutative
   117 Remarks:
   118   - elhs is used for matching,
   119     lhs only for preservation of bound variable names.
   120   - fo is set iff
   121     either elhs is first-order (no Var is applied),
   122            in which case fo-matching is complete,
   123     or elhs is not a pattern,
   124        in which case there is nothing better to do.
   125 *)
   126 type cong = {thm: thm, lhs: cterm};
   127 type simproc =
   128  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   129 
   130 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   131   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   132 
   133 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   134   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   135 
   136 fun eq_prem (thm1, thm2) =
   137   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   138 
   139 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   140 
   141 fun mk_simproc (name, proc, lhs, id) =
   142   {name = name, proc = proc, lhs = lhs, id = id};
   143 
   144 
   145 (* datatype mss *)
   146 
   147 (*
   148   A "mss" contains data needed during conversion:
   149     rules: discrimination net of rewrite rules;
   150     congs: association list of congruence rules and
   151            a list of `weak' congruence constants.
   152            A congruence is `weak' if it avoids normalization of some argument.
   153     procs: discrimination net of simplification procedures
   154       (functions that prove rewrite rules on the fly);
   155     bounds: names of bound variables already used
   156       (for generating new names when rewriting under lambda abstractions);
   157     prems: current premises;
   158     mk_rews: mk: turns simplification thms into rewrite rules;
   159              mk_sym: turns == around; (needs Drule!)
   160              mk_eq_True: turns P into P == True - logic specific;
   161     termless: relation for ordered rewriting;
   162     depth: depth of conditional rewriting;
   163 *)
   164 
   165 datatype meta_simpset =
   166   Mss of {
   167     rules: rrule Net.net,
   168     congs: (string * cong) list * string list,
   169     procs: simproc Net.net,
   170     bounds: string list,
   171     prems: thm list,
   172     mk_rews: {mk: thm -> thm list,
   173               mk_sym: thm -> thm option,
   174               mk_eq_True: thm -> thm option},
   175     termless: term * term -> bool,
   176     depth: int};
   177 
   178 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   179   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   180        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   181 
   182 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   183   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   184 
   185 val empty_mss =
   186   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   187   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   188 
   189 fun clear_mss (Mss {mk_rews, termless, ...}) =
   190   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   191 
   192 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   193   let val depth1 = depth+1
   194   in if depth1 > !simp_depth_limit
   195      then (warning "simp_depth_limit exceeded - giving up"; None)
   196      else (if depth1 mod 5 = 0
   197            then warning("Simplification depth " ^ string_of_int depth1)
   198            else ();
   199            Some(mk_mss(rules,congs,procs,bounds,prems,mk_rews,termless,depth1))
   200           )
   201   end;
   202 
   203 
   204 (** simpset operations **)
   205 
   206 (* term variables *)
   207 
   208 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   209 fun term_varnames t = add_term_varnames ([], t);
   210 
   211 
   212 (* dest_mss *)
   213 
   214 fun dest_mss (Mss {rules, congs, procs, ...}) =
   215   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   216    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   217    procs =
   218      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   219      |> partition_eq eq_snd
   220      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   221      |> Library.sort_wrt #1};
   222 
   223 
   224 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   225 
   226 fun merge_mss
   227  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   228        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   229   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   230        bounds = bounds2, prems = prems2, ...}) =
   231       mk_mss
   232        (Net.merge (rules1, rules2, eq_rrule),
   233         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   234         merge_lists weak1 weak2),
   235         Net.merge (procs1, procs2, eq_simproc),
   236         merge_lists bounds1 bounds2,
   237         gen_merge_lists eq_prem prems1 prems2,
   238         mk_rews, termless, depth);
   239 
   240 
   241 (* add_simps *)
   242 
   243 fun mk_rrule2{thm, name, lhs, elhs, perm} =
   244   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   245   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   246 
   247 fun insert_rrule quiet (mss as Mss {rules,...},
   248                  rrule as {thm,name,lhs,elhs,perm}) =
   249   (trace_named_thm "Adding rewrite rule" (thm, name);
   250    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   251        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   252    in upd_rules(mss,rules') end
   253    handle Net.INSERT => if quiet then mss else
   254      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   255 
   256 fun vperm (Var _, Var _) = true
   257   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   258   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   259   | vperm (t, u) = (t = u);
   260 
   261 fun var_perm (t, u) =
   262   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   263 
   264 (* FIXME: it seems that the conditions on extra variables are too liberal if
   265 prems are nonempty: does solving the prems really guarantee instantiation of
   266 all its Vars? Better: a dynamic check each time a rule is applied.
   267 *)
   268 fun rewrite_rule_extra_vars prems elhs erhs =
   269   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   270   orelse
   271   not ((term_tvars erhs) subset
   272        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   273 
   274 (*Simple test for looping rewrite rules and stupid orientations*)
   275 fun reorient sign prems lhs rhs =
   276    rewrite_rule_extra_vars prems lhs rhs
   277   orelse
   278    is_Var (head_of lhs)
   279   orelse
   280    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   281   orelse
   282    (null prems andalso
   283     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   284     (*the condition "null prems" is necessary because conditional rewrites
   285       with extra variables in the conditions may terminate although
   286       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   287   orelse
   288    (is_Const lhs andalso not(is_Const rhs))
   289 
   290 fun decomp_simp thm =
   291   let val {sign, prop, ...} = rep_thm thm;
   292       val prems = Logic.strip_imp_prems prop;
   293       val concl = Drule.strip_imp_concl (cprop_of thm);
   294       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   295         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   296       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   297       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   298       val erhs = Pattern.eta_contract (term_of rhs);
   299       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   300                  andalso not (is_Var (term_of elhs))
   301   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   302 
   303 fun decomp_simp' thm =
   304   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   305     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   306     else (lhs, rhs)
   307   end;
   308 
   309 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
   310   case mk_eq_True thm of
   311     None => []
   312   | Some eq_True =>
   313       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   314       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   315 
   316 (* create the rewrite rule and possibly also the ==True variant,
   317    in case there are extra vars on the rhs *)
   318 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   319   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   320   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   321         (term_tvars rhs) subset (term_tvars lhs)
   322      then [rrule]
   323      else mk_eq_True mss (thm2, name) @ [rrule]
   324   end;
   325 
   326 fun mk_rrule mss (thm, name) =
   327   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   328   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   329      (* weak test for loops: *)
   330      if rewrite_rule_extra_vars prems lhs rhs orelse
   331         is_Var (term_of elhs)
   332      then mk_eq_True mss (thm, name)
   333      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   334   end;
   335 
   336 fun orient_rrule mss (thm, name) =
   337   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   338   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   339      else if reorient sign prems lhs rhs
   340           then if reorient sign prems rhs lhs
   341                then mk_eq_True mss (thm, name)
   342                else let val Mss{mk_rews={mk_sym,...},...} = mss
   343                     in case mk_sym thm of
   344                          None => []
   345                        | Some thm' =>
   346                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   347                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   348                     end
   349           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   350   end;
   351 
   352 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
   353   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   354 
   355 fun orient_comb_simps comb mk_rrule (mss,thms) =
   356   let val rews = extract_rews(mss,thms)
   357       val rrules = flat (map mk_rrule rews)
   358   in foldl comb (mss,rrules) end
   359 
   360 (* Add rewrite rules explicitly; do not reorient! *)
   361 fun add_simps(mss,thms) =
   362   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   363 
   364 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
   365   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   366 
   367 fun extract_safe_rrules(mss,thm) =
   368   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   369 
   370 (* del_simps *)
   371 
   372 fun del_rrule(mss as Mss {rules,...},
   373               rrule as {thm, elhs, ...}) =
   374   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   375    handle Net.DELETE =>
   376      (prthm true "Rewrite rule not in simpset:" thm; mss));
   377 
   378 fun del_simps(mss,thms) =
   379   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   380 
   381 
   382 (* add_congs *)
   383 
   384 fun is_full_cong_prems [] varpairs = null varpairs
   385   | is_full_cong_prems (p::prems) varpairs =
   386     (case Logic.strip_assums_concl p of
   387        Const("==",_) $ lhs $ rhs =>
   388          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   389          in is_Var x  andalso  forall is_Bound xs  andalso
   390             null(findrep(xs))  andalso xs=ys andalso
   391             (x,y) mem varpairs andalso
   392             is_full_cong_prems prems (varpairs\(x,y))
   393          end
   394      | _ => false);
   395 
   396 fun is_full_cong thm =
   397 let val prems = prems_of thm
   398     and concl = concl_of thm
   399     val (lhs,rhs) = Logic.dest_equals concl
   400     val (f,xs) = strip_comb lhs
   401     and (g,ys) = strip_comb rhs
   402 in
   403   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   404   is_full_cong_prems prems (xs ~~ ys)
   405 end
   406 
   407 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   408   let
   409     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   410       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   411 (*   val lhs = Pattern.eta_contract lhs; *)
   412     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   413       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   414     val (alist,weak) = congs
   415     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   416            ("Overwriting congruence rule for " ^ quote a);
   417     val weak2 = if is_full_cong thm then weak else a::weak
   418   in
   419     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   420   end;
   421 
   422 val (op add_congs) = foldl add_cong;
   423 
   424 
   425 (* del_congs *)
   426 
   427 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   428   let
   429     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   430       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   431 (*   val lhs = Pattern.eta_contract lhs; *)
   432     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   433       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   434     val (alist,_) = congs
   435     val alist2 = filter (fn (x,_)=> x<>a) alist
   436     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   437                                               else Some a)
   438                    alist2
   439   in
   440     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   441   end;
   442 
   443 val (op del_congs) = foldl del_cong;
   444 
   445 
   446 (* add_simprocs *)
   447 
   448 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   449     (name, lhs, proc, id)) =
   450   let val {sign, t, ...} = rep_cterm lhs
   451   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   452       sign t;
   453     mk_mss (rules, congs,
   454       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   455         handle Net.INSERT =>
   456             (warning ("Ignoring duplicate simplification procedure \""
   457                       ^ name ^ "\"");
   458              procs),
   459         bounds, prems, mk_rews, termless,depth))
   460   end;
   461 
   462 fun add_simproc (mss, (name, lhss, proc, id)) =
   463   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   464 
   465 val add_simprocs = foldl add_simproc;
   466 
   467 
   468 (* del_simprocs *)
   469 
   470 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   471     (name, lhs, proc, id)) =
   472   mk_mss (rules, congs,
   473     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   474       handle Net.DELETE =>
   475           (warning ("Simplification procedure \"" ^ name ^
   476                        "\" not in simpset"); procs),
   477       bounds, prems, mk_rews, termless, depth);
   478 
   479 fun del_simproc (mss, (name, lhss, proc, id)) =
   480   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   481 
   482 val del_simprocs = foldl del_simproc;
   483 
   484 
   485 (* prems *)
   486 
   487 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   488   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   489 
   490 fun prems_of_mss (Mss {prems, ...}) = prems;
   491 
   492 
   493 (* mk_rews *)
   494 
   495 fun set_mk_rews
   496   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   497     mk_mss (rules, congs, procs, bounds, prems,
   498             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   499             termless, depth);
   500 
   501 fun set_mk_sym
   502   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   503     mk_mss (rules, congs, procs, bounds, prems,
   504             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   505             termless,depth);
   506 
   507 fun set_mk_eq_True
   508   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   509     mk_mss (rules, congs, procs, bounds, prems,
   510             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   511             termless,depth);
   512 
   513 (* termless *)
   514 
   515 fun set_termless
   516   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   517     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   518 
   519 
   520 
   521 (** rewriting **)
   522 
   523 (*
   524   Uses conversions, see:
   525     L C Paulson, A higher-order implementation of rewriting,
   526     Science of Computer Programming 3 (1983), pages 119-149.
   527 *)
   528 
   529 val dest_eq = Drule.dest_equals o cprop_of;
   530 val lhs_of = fst o dest_eq;
   531 val rhs_of = snd o dest_eq;
   532 
   533 fun beta_eta_conversion t =
   534   let val thm = beta_conversion true t;
   535   in transitive thm (eta_conversion (rhs_of thm)) end;
   536 
   537 fun check_conv msg thm thm' =
   538   let
   539     val thm'' = transitive thm (transitive
   540       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   541   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   542   handle THM _ =>
   543     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   544     in
   545       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   546        trace_term false "Should have proved:" sign prop0;
   547        None)
   548     end;
   549 
   550 
   551 (* mk_procrule *)
   552 
   553 fun mk_procrule thm =
   554   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   555   in if rewrite_rule_extra_vars prems lhs rhs
   556      then (prthm true "Extra vars on rhs:" thm; [])
   557      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   558   end;
   559 
   560 
   561 (* conversion to apply the meta simpset to a term *)
   562 
   563 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   564    normalized terms by carrying around the rhs of the rewrite rule just
   565    applied. This is called the `skeleton'. It is decomposed in parallel
   566    with the term. Once a Var is encountered, the corresponding term is
   567    already in normal form.
   568    skel0 is a dummy skeleton that is to enforce complete normalization.
   569 *)
   570 val skel0 = Bound 0;
   571 
   572 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   573    The latter may happen iff there are weak congruence rules for constants
   574    in the lhs.
   575 *)
   576 fun uncond_skel((_,weak),(lhs,rhs)) =
   577   if null weak then rhs (* optimization *)
   578   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   579        else rhs;
   580 
   581 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   582    Otherwise those vars may become instantiated with unnormalized terms
   583    while the premises are solved.
   584 *)
   585 fun cond_skel(args as (congs,(lhs,rhs))) =
   586   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   587   else skel0;
   588 
   589 (*
   590   we try in order:
   591     (1) beta reduction
   592     (2) unconditional rewrite rules
   593     (3) conditional rewrite rules
   594     (4) simplification procedures
   595 
   596   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   597 
   598 *)
   599 
    600 fun rewritec (prover, signt, maxt)
    601              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
        (* Try to rewrite the cterm t once, at its root.
           Result: Some (eqn, skel), where eqn is an equation whose lhs is t (every
           branch is built on eta_thm) and skel is the skeleton to use when
           normalizing its rhs further; None if nothing applies.
           prover: discharges the premises of conditional rules; it receives the
                   simpset one level deeper (incr_depth);
           signt:  signature of t; rules whose signature is not a subsignature
                   of it are skipped with a warning;
           maxt:   ~1 for no shifting, otherwise the rule is shifted by maxt+1
                   before matching (presumably so that its Vars cannot clash
                   with those of t).
           Order of attempts: beta reduction, rewrite rules from the net
           (unconditional ones first), simplification procedures. *)
    602   let
        (* eta_thm: t == eta_t'; matching is done against the eta form *)
    603     val eta_thm = Thm.eta_conversion t;
    604     val eta_t' = rhs_of eta_thm;
    605     val eta_t = term_of eta_t';
    606     val tsigt = Sign.tsig_of signt;
        (* Try a single rule.  A rule that does not apply (foreign theory, or a
           failed match, which the cterm match functions presumably report as
           Pattern.MATCH) raises Pattern.MATCH, caught in rews.  A permutative
           instance that does not make the term smaller, unsolvable premises or
           the depth limit give None. *)
    607     fun rew {thm, name, lhs, elhs, fo, perm} =
    608       let
    609         val {sign, prop, maxidx, ...} = rep_thm thm;
    610         val _ = if Sign.subsig (sign, signt) then ()
    611                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
    612                       raise Pattern.MATCH);
        (* shift the rule (and its elhs) by maxt+1 unless maxt = ~1 *)
    613         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
    614           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
        (* fo rules use first-order matching, all others general matching *)
    615         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
    616                           else Thm.cterm_match (elhs', eta_t');
        (* bound variables of the rule are renamed after those of eta_t
           (this is what the original lhs is kept for) *)
    617         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
    618         val prop' = #prop (rep_thm thm');
    619         val unconditional = (Logic.count_prems (prop',0) = 0);
    620         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
    621       in
        (* ordered rewriting: a permutative rule fires only if termless (rhs', lhs') *)
    622         if perm andalso not (termless (rhs', lhs'))
    623         then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
    624               trace_thm "Term does not become smaller:" thm'; None)
    625         else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
    626            if unconditional
    627            then
    628              (trace_thm "Rewriting:" thm';
        (* the skeleton comes from the uninstantiated rule (prop, not prop'), so
           its Vars mark parts of the result that are already normalized; the
           val Some binding raises Bind should check_conv fail here *)
    629               let val lr = Logic.dest_equals prop;
    630                   val Some thm'' = check_conv false eta_thm thm'
    631               in Some (thm'', uncond_skel (congs, lr)) end)
    632            else
        (* conditional: the prover must establish the instance with the simpset
           one level deeper; cond_skel decides the skeleton *)
    633              (trace_thm "Trying to rewrite:" thm';
    634               case incr_depth mss of
    635                 None => (trace_thm "FAILED - reached depth limit" thm'; None)
    636               | Some mss =>
    637               (case prover mss thm' of
    638                 None       => (trace_thm "FAILED" thm'; None)
    639               | Some thm2 =>
    640                   (case check_conv true eta_thm thm2 of
    641                      None => None |
    642                      Some thm2' =>
    643                        let val concl = Logic.strip_imp_concl prop
    644                            val lr = Logic.dest_equals concl
    645                        in Some (thm2', cond_skel (congs, lr)) end))))
    646       end
    647 
        (* result of the first rule in the list that applies; Pattern.MATCH
           counts as not applicable *)
    648     fun rews [] = None
    649       | rews (rrule :: rrules) =
    650           let val opt = rew rrule handle Pattern.MATCH => None
    651           in case opt of None => rews rrules | some => some end;
    652 
        (* rules whose proposition is an == at top level (unconditional rules)
           come before the others; consing onto the accumulators reverses the
           relative order within each group *)
    653     fun sort_rrules rrs = let
    654       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
    655                                       Const("==",_) $ _ $ _ => true
    656                                       | _                   => false
    657       fun sort []        (re1,re2) = re1 @ re2
    658         | sort (rr::rrs) (re1,re2) = if is_simple rr
    659                                      then sort rrs (rr::re1,re2)
    660                                      else sort rrs (re1,rr::re2)
    661     in sort rrs ([],[]) end
    662 
        (* Try the simprocs in turn, running those whose lhs pattern matches t
           (note: t, not eta_t).  Exceptions from a procedure are wrapped as
           SIMPROC_FAIL with its name (presumably what transform_failure does).
           A produced theorem becomes a rule via mk_procrule and must then apply
           to eta_t through rews; otherwise it is reported (only when trace_simp
           is set) and the next simproc is tried. *)
    663     fun proc_rews ([]:simproc list) = None
    664       | proc_rews ({name, proc, lhs, ...} :: ps) =
    665           if Pattern.matches tsigt (term_of lhs, term_of t) then
    666             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
    667              case transform_failure (curry SIMPROC_FAIL name)
    668                  (fn () => proc signt prems eta_t) () of
    669                None => (debug false "FAILED"; proc_rews ps)
    670              | Some raw_thm =>
    671                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
    672                   (case rews (mk_procrule raw_thm) of
    673                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
    674                       " -- does not match") t; proc_rews ps)
    675                   | some => some)))
    676           else proc_rews ps;
        (* a beta-redex is contracted directly, with skel0 so that the result is
           normalized completely; otherwise rules first, then simprocs *)
    677   in case eta_t of
    678        Abs _ $ _ => Some (transitive eta_thm
    679          (beta_conversion false eta_t'), skel0)
    680      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
    681                None => proc_rews (Net.match_term procs eta_t)
    682              | some => some)
    683   end;
   684 
   685 
   686 (* conversion to apply a congruence rule to a term *)
   687 
   688 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   689   let val {sign, ...} = rep_thm cong
   690       val _ = if Sign.subsig (sign, signt) then ()
   691                  else error("Congruence rule from different theory")
   692       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   693       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   694       val insts = Thm.cterm_match (rlhs, t)
   695       (* Pattern.match can raise Pattern.MATCH;
   696          is handled when congc is called *)
   697       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   698       val unit = trace_thm "Applying congruence rule:" thm';
   699       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   700   in case prover thm' of
   701        None => err ("Could not prove", thm')
   702      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   703           None => err ("Should not have proved", thm2)
   704         | Some thm2' =>
   705             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   706             then None else Some thm2')
   707   end;
   708 
   709 val (cA, (cB, cC)) =
   710   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   711 
   712 fun transitive1 None None = None
   713   | transitive1 (Some thm1) None = Some thm1
   714   | transitive1 None (Some thm2) = Some thm2
   715   | transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2)
   716 
   717 fun transitive2 thm = transitive1 (Some thm);
   718 fun transitive3 thm = transitive1 thm o Some;
   719 
   720 fun imp_cong' e = combination (combination refl_implies e);
   721 
       (* bottomc (mode, prover, sign, maxidx): build the bottom-up rewriting
          conversion. It returns try_botc, which maps mss and a cterm t to a
          theorem t == t' (reflexive t when nothing changes).
          mode = (simprem, useprem, mutsimp):
            simprem - nonmut_impc simplifies the premise of A ==> B
            useprem - nonmut_impc adds the (simplified) premise as rewrite rules
            mutsimp - impc uses the mutual premise simplification (mut_impc0)
                      instead of the legacy nonmut_impc *)
   722 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
   723   let
           (* botc skel mss t: simplify the subterms (subc), then rewrite at the
              top (rewritec). If a top-level rewrite fires, recurse on its rhs
              using the skeleton returned by rewritec. A Var skeleton position
              is not traversed (presumably it is known to be normalized --
              confirm against rewritec). None means "no change". *)
   724     fun botc skel mss t =
   725           if is_Var skel then None
   726           else
   727           (case subc skel mss t of
   728              some as Some thm1 =>
   729                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   730                   Some (thm2, skel2) =>
   731                     transitive2 (transitive thm1 thm2)
   732                       (botc skel2 mss (rhs_of thm2))
   733                 | None => some)
   734            | None =>
   735                (case rewritec (prover, sign, maxidx) mss t of
   736                   Some (thm2, skel2) => transitive2 thm2
   737                     (botc skel2 mss (rhs_of thm2))
   738                 | None => None))
   739 
           (* Like botc with the full skeleton, but always returns a theorem. *)
   740     and try_botc mss t =
   741           (case botc skel0 mss t of
   742              Some trec1 => trec1 | None => (reflexive t))
   743 
           (* subc: simplify the immediate subterms of t0.
              - abstraction: go under the binder with a fresh name (variant of
                the bound names in mss, prefixed by "."), recording it in mss.
              - "==>" applications: delegate to impc.
              - beta-redex: beta-reduce, then simplify the result.
              - other applications: if the head constant has a congruence
                rule, use congc (falling back to appc on Pattern.MATCH);
                otherwise simplify function and argument separately (appc). *)
   744     and subc skel
   745           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   746        (case term_of t0 of
   747            Abs (a, T, t) =>
   748              let val b = variant bounds a
   749                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   750                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   751                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   752              in case botc skel' mss' t' of
   753                   Some thm => Some (abstract_rule a v thm)
   754                 | None => None
   755              end
   756          | t $ _ => (case t of
   757              Const ("==>", _) $ _  => impc t0 mss
   758            | Abs _ =>
   759                let val thm = beta_conversion false t0
   760                in case subc skel0 mss (rhs_of thm) of
   761                     None => Some thm
   762                   | Some thm' => Some (transitive thm thm')
   763                end
   764            | _  =>
                     (* appc: congruence-free case; simplify operator and operand
                        independently and recombine with combination *)
   765                let fun appc () =
   766                      let
   767                        val (tskel, uskel) = case skel of
   768                            tskel $ uskel => (tskel, uskel)
   769                          | _ => (skel0, skel0);
   770                        val (ct, cu) = Thm.dest_comb t0
   771                      in
   772                      (case botc tskel mss ct of
   773                         Some thm1 =>
   774                           (case botc uskel mss cu of
   775                              Some thm2 => Some (combination thm1 thm2)
   776                            | None => Some (combination thm1 (reflexive cu)))
   777                       | None =>
   778                           (case botc uskel mss cu of
   779                              Some thm1 => Some (combination (reflexive ct) thm1)
   780                            | None => None))
   781                      end
   782                    val (h, ts) = strip_comb t
   783                in case h of
   784                     Const(a, _) =>
   785                       (case assoc_string (fst congs, a) of
   786                          None => appc ()
   787                        | Some cong =>
   788 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   789    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   790                           (let
   791                              val thm = congc (prover mss, sign, maxidx) cong t0;
   792                              val t = if_none (apsome rhs_of thm) t0;
   793                              val (cl, cr) = Thm.dest_comb t
   794                              val dVar = Var(("", 0), dummyT)
   795                              val skel =
   796                                list_comb (h, replicate (length ts) dVar)
   797                            in case botc skel mss cl of
   798                                 None => thm
   799                               | Some thm' => transitive3 thm
   800                                   (combination thm' (reflexive cr))
   801                            end handle TERM _ => error "congc result"
   802                                     | Pattern.MATCH => appc ()))
   803                   | _ => appc ()
   804                end)
   805          | _ => None)
   806 
           (* Simplify an implication; the strategy is chosen by mutsimp. *)
   807     and impc ct mss =
   808       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   809 
           (* Rewrite rules derived from assuming prem, plus the assumption
              itself. Premises containing (type) unknowns, i.e. with
              maxidx <> ~1, are traced and not used: ([], None). *)
   810     and rules_of_prem mss prem =
   811       if maxidx_of_term (term_of prem) <> ~1
   812       then (trace_cterm true
   813         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   814       else
   815         let val asm = assume prem
   816         in (extract_safe_rrules (mss, asm), Some asm) end
   817 
           (* Insert all rule lists into mss and record the available
              assumptions (the Some entries) as premises. *)
   818     and add_rrules (rrss, asms) mss =
   819       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   820 
           (* disch r (prem, eq): from eq : lhs == rhs (proved under prem),
              derive (prem ==> lhs) == (prem ==> rhs) via Drule.imp_cong.
              If r is true, lhs and rhs are themselves implications, and prem
              is additionally moved below their first premise using
              Drule.swap_prems_eq on both sides (assumed to state
              (A ==> B ==> C) == (B ==> A ==> C) -- confirm in Drule). *)
   821     and disch r (prem, eq) =
   822       let
   823         val (lhs, rhs) = dest_eq eq;
   824         val eq' = implies_elim (Thm.instantiate
   825           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   826           (implies_intr prem eq)
   827       in if not r then eq' else
   828         let
   829           val (prem', concl) = dest_implies lhs;
   830           val (prem'', _) = dest_implies rhs
   831         in transitive (transitive
   832           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   833              Drule.swap_prems_eq) eq')
   834           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   835              Drule.swap_prems_eq)
   836         end
   837       end
   838 
           (* rebuild: re-attach the (already simplified) premises to the
              simplified conclusion one at a time, from the innermost premise
              outward, discharging each from the equation. After each step,
              try rewritec on the rebuilt implication. If it rewrites, chain
              that step and restart premise simplification (mut_impc0) on the
              result. *)
   839     and rebuild [] _ _ _ _ eq = eq
   840       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   841           let
   842             val mss' = add_rrules (rev rrss, rev asms) mss;
   843             val concl' =
   844               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   845             val dprem = apsome (curry (disch false) prem)
   846           in case rewritec (prover, sign, maxidx) mss' concl' of
   847               None => rebuild prems concl' rrss asms mss (dprem eq)
   848             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   849                   (the (transitive3 (dprem eq) eq'), prems))
   850                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   851           end
   852           
           (* Entry point of mutual simplification: strip the premises of
              concl, compute the rules for each, and start mut_impc with
              changed = ~1 and k = ~1. *)
   853     and mut_impc0 prems concl rrss asms mss =
   854       let
   855         val prems' = strip_imp_prems concl;
   856         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   857       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   858         (asms @ asms') [] [] [] [] mss ~1 ~1
   859       end
   860  
           (* mut_impc: one pass over the premises. Each premise is simplified
              using the rules of all other premises. Processed premises are
              accumulated in reverse in prems'/rrss'/asms', and their equations
              in eqns.
              changed - number of premises processed before the last premise
                        that changed in this pass (~1 if none changed)
              k       - premises are only simplified while k <> 0. k is
                        decremented per premise, so a rerun with k = changed
                        only revisits the premises before the last change.
              At the end of a pass, rerun if changed > 0. Otherwise simplify
              the conclusion and rebuild the implication. *)
   861     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   862         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   863             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   864           (if changed > 0 then
   865              mut_impc (rev prems') concl (rev rrss') (rev asms')
   866                [] [] [] [] mss ~1 changed
   867            else rebuild prems' concl rrss' asms' mss
   868              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   869 
             (* For a changed premise, i is 1 + the largest index among the
                remaining premises that occur as hypotheses of eqn. Those first
                i premises are discharged with disch true. *)
   870       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   871           prems' rrss' asms' eqns mss changed k =
   872         case (if k = 0 then None else botc skel0 (add_rrules
   873           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
   874             None => mut_impc prems concl rrss asms (prem :: prems')
   875               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   876               (if k = 0 then 0 else k - 1)
   877           | Some eqn =>
   878             let
   879               val prem' = rhs_of eqn;
   880               val tprems = map term_of prems;
   881               val i = 1 + foldl Int.max (~1, map (fn p =>
   882                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   883               val (rrs', asm') = rules_of_prem mss prem'
   884             in mut_impc prems concl rrss asms (prem' :: prems')
   885               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   886                 (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
   887                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   888             end
   889 
   890      (* legacy code - only for backwards compatibility *)
            (* Simplify A (if simprem), optionally use the simplified A as rewrite
               rules (if useprem) while simplifying B, and combine the two
               equations for A ==> B. *)
   891      and nonmut_impc ct mss =
   892        let val (prem, conc) = dest_implies ct;
   893            val thm1 = if simprem then botc skel0 mss prem else None;
   894            val prem1 = if_none (apsome rhs_of thm1) prem;
   895            val mss1 = if not useprem then mss else add_rrules
   896              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   897        in (case botc skel0 mss1 conc of
   898            None => (case thm1 of
   899                None => None
   900              | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
   901          | Some thm2 =>
   902            let val thm2' = disch false (prem1, thm2)
   903            in (case thm1 of
   904                None => Some thm2'
   905              | Some thm1' =>
   906                  Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
   907            end)
   908        end
   909 
   910  in try_botc end;
   911 
   912 
   913 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   914 
   915 (*
   916   Parameters:
   917     mode = (simplify A,
   918             use A in simplifying B,
   919             use prems of B (if B is again a meta-impl.) to simplify A)
   920            when simplifying A ==> B
   921     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   922     prover: how to solve premises in conditional rewrites and congruences
   923 *)
   924 
   925 fun rewrite_cterm mode prover mss ct =
   926   let val {sign, t, maxidx, ...} = rep_cterm ct
   927       val Mss{depth, ...} = mss
   928   in simp_depth := depth;
   929      bottomc (mode, prover, sign, maxidx) mss ct
   930   end
   931   handle THM (s, _, thms) =>
   932     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   933       Pretty.string_of (Display.pretty_thms thms));
   934 
   935 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   936 fun goals_conv pred cv =
   937   let fun gconv i ct =
   938         let val (A,B) = Drule.dest_implies ct
   939         in imp_cong' (if pred i then cv A else reflexive A) (gconv (i+1) B) end
   940         handle TERM _ => reflexive ct
   941   in gconv 1 end;
   942 
   943 (* Rewrite A in !!x1,...,xn. A *)
   944 fun forall_conv cv ct =
   945   let val p as (ct1, ct2) = Thm.dest_comb ct
   946   in (case pairself term_of p of
   947       (Const ("all", _), Abs (s, _, _)) =>
   948          let val (v, ct') = Thm.dest_abs (Some "@") ct2;
   949          in Thm.combination (Thm.reflexive ct1)
   950            (Thm.abstract_rule s v (forall_conv cv ct'))
   951          end
   952     | _ => cv ct)
   953   end handle TERM _ => cv ct;
   954 
   955 (*Use a conversion to transform a theorem*)
   956 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   957 
   958 (*Rewrite a cterm*)
   959 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
   960   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
   961 
   962 (*Rewrite a theorem*)
   963 fun simplify_aux _ _ [] = (fn th => th)
   964   | simplify_aux prover full thms =
   965       fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   966 
   967 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   968 
   969 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   970 fun rewrite_goals_rule_aux _ []   th = th
   971   | rewrite_goals_rule_aux prover thms th =
   972       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   973         (mss_of thms))) th;
   974 
   975 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
   976 fun rewrite_goal_rule mode prover mss i thm =
   977   if 0 < i  andalso  i <= nprems_of thm
   978   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   979   else raise THM("rewrite_goal_rule",i,[thm]);
   980 
   981 
   982 (*simple term rewriting -- without proofs*)
   983 fun rewrite_term sg rules procs =
   984   Pattern.rewrite_term (Sign.tsig_of sg) (map decomp_simp' rules) procs;
   985 
   986 end;
   987 
       (* Restrict MetaSimplifier to BASIC_META_SIMPLIFIER (trace_simp,
          debug_simp, simp_depth_limit) and open it at top level. *)
   988 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
   989 open BasicMetaSimplifier;