src/Pure/meta_simplifier.ML
author wenzelm
Thu Dec 27 16:46:52 2001 +0100 (2001-12-27)
changeset 12603 7d2bca103101
parent 12285 6490fc7b3eed
child 12779 c5739c1431ab
permissions -rw-r--r--
tuned tracing markup;
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta-level Simplification.
     7 *)
     8 
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
         (*Diagnostic switches of the meta simplifier.*)
         (*trace_simp: when set, rewrite steps are reported via tracing*)
    11   val trace_simp: bool ref
         (*debug_simp: when set, additional debug output is produced,
           e.g. each attempt to apply a simplification procedure*)
    12   val debug_simp: bool ref
    13 end;
    14 
    15 signature META_SIMPLIFIER =
    16 sig
         (*Interface of the meta-level simplifier: construction and modification
           of simpsets, plus the rewriting entry points.*)
    17   include BASIC_META_SIMPLIFIER
         (*raised on malformed rewrite/congruence rules: message and offending theorem*)
    18   exception SIMPLIFIER of string * thm
    19   type meta_simpset
         (*theorems of all rewrite and congruence rules, and the simplification
           procedures as (name, lhs patterns) pairs sorted by name*)
    20   val dest_mss          : meta_simpset ->
    21     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
         (*simpset without rules, with trivial mk_rews and Term.termless*)
    22   val empty_mss         : meta_simpset
         (*drop rules, congruences, procedures, bounds and premises; keep mk_rews
           and termless; reset the depth to 0*)
    23   val clear_mss         : meta_simpset -> meta_simpset
         (*union of two simpsets; mk_rews, termless and depth come from the first*)
    24   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
         (*add/delete rewrite rules obtained from the theorems via mk_rews
           (rules are not reoriented)*)
    25   val add_simps         : meta_simpset * thm list -> meta_simpset
    26   val del_simps         : meta_simpset * thm list -> meta_simpset
         (*simpset built directly from the given theorems, starting from empty_mss*)
    27   val mss_of            : thm list -> meta_simpset
         (*add/delete congruence rules, keyed by the head constant of their lhs*)
    28   val add_congs         : meta_simpset * thm list -> meta_simpset
    29   val del_congs         : meta_simpset * thm list -> meta_simpset
         (*add/delete simplification procedures (name, lhs patterns, proc, stamp)*)
    30   val add_simprocs      : meta_simpset *
    31     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    32       -> meta_simpset
    33   val del_simprocs      : meta_simpset *
    34     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    35       -> meta_simpset
         (*premises available to simplification procedures*)
    36   val add_prems         : meta_simpset * thm list -> meta_simpset
    37   val prems_of_mss      : meta_simpset -> thm list
         (*replace one component of the simpset's rule-making functions or term order*)
    38   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    39   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    40   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
         (*rewriting entry points (defined further down in this structure)*)
    42   val rewrite_cterm: bool * bool * bool ->
    43     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    44   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    45   val forall_conv       : (cterm -> thm) -> cterm -> thm
    46   val fconv_rule        : (cterm -> thm) -> thm -> thm
    47   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    48   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    49   val rewrite_thm       : bool * bool * bool
    50                           -> (meta_simpset -> thm -> thm option)
    51                           -> meta_simpset -> thm -> thm
    52   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    53   val rewrite_goal_rule : bool* bool * bool
    54                           -> (meta_simpset -> thm -> thm option)
    55                           -> meta_simpset -> int -> thm -> thm
    56 end;
    57 
    58 structure MetaSimplifier : META_SIMPLIFIER =
    59 struct
    60 
    61 (** diagnostics **)
    62 
    63 exception SIMPLIFIER of string * thm;
    64 
    65 val simp_depth = ref 0;
    66 
    67 local
    68 
    69 fun println a =
    70   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    71 
    72 fun prnt warn a = if warn then warning a else println a;
    73 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    74 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    75 
    76 in
    77 
    78 fun prthm warn a = prctm warn a o Thm.cprop_of;
    79 
    80 val trace_simp = ref false;
    81 val debug_simp = ref false;
    82 
    83 fun trace warn a = if !trace_simp then prnt warn a else ();
    84 fun debug warn a = if !debug_simp then prnt warn a else ();
    85 
    86 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    87 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    88 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    89 
    90 fun trace_thm warn a thm =
    91   let val {sign, prop, ...} = rep_thm thm
    92   in trace_term warn a sign prop end;
    93 
    94 end;
    95 
    96 
    97 (** meta simp sets **)
    98 
    99 (* basic components *)
   100 
   101 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
   102 (* thm: the rewrite rule
   103    lhs: the left-hand side
   104    elhs: the eta-contracted lhs.
   105    fo:  use first-order matching
   106    perm: the rewrite rule is permutative
   107 Remarks:
   108   - elhs is used for matching,
   109     lhs only for preservation of bound variable names.
   110   - fo is set iff
   111     either elhs is first-order (no Var is applied),
   112            in which case fo-matching is complete,
   113     or elhs is not a pattern,
   114        in which case there is nothing better to do.
   115 *)
       (* cong: a congruence rule together with the lhs of its conclusion *)
   116 type cong = {thm: thm, lhs: cterm};
       (* simproc: a named simplification procedure.
            proc: given the signature, the current premises and a term matching lhs,
                  may return a rewrite rule (meta-equality) for that term
            lhs:  pattern under which the procedure is stored in the net
            id:   identifies the procedure; simprocs are equal iff their ids are *)
   117 type simproc =
   118  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   119 
   120 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   121   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   122 
   123 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   124   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   125 
   126 fun eq_prem (thm1, thm2) =
   127   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   128 
   129 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   130 
   131 fun mk_simproc (name, proc, lhs, id) =
   132   {name = name, proc = proc, lhs = lhs, id = id};
   133 
   134 
   135 (* datatype mss *)
   136 
   137 (*
   138   A "mss" contains data needed during conversion:
   139     rules: discrimination net of rewrite rules;
   140     congs: association list of congruence rules and
   141            a list of `weak' congruence constants.
   142            A congruence is `weak' if it avoids normalization of some argument.
   143     procs: discrimination net of simplification procedures
   144       (functions that prove rewrite rules on the fly);
   145     bounds: names of bound variables already used
   146       (for generating new names when rewriting under lambda abstractions);
   147     prems: current premises;
   148     mk_rews: mk: turns simplification thms into rewrite rules;
   149              mk_sym: turns == around; (needs Drule!)
   150              mk_eq_True: turns P into P == True - logic specific;
   151     termless: relation for ordered rewriting;
   152     depth: depth of conditional rewriting;
   153 *)
   154 
   155 datatype meta_simpset =
         (* The simpset record; see the field description above.  Simpsets are
            immutable values: all operations below rebuild them via mk_mss. *)
   156   Mss of {
   157     rules: rrule Net.net,
             (* (constant name, congruence) assoc list, and the names of constants
                with weak congruences (consulted by uncond_skel) *)
   158     congs: (string * cong) list * string list,
   159     procs: simproc Net.net,
   160     bounds: string list,
   161     prems: thm list,
   162     mk_rews: {mk: thm -> thm list,
   163               mk_sym: thm -> thm option,
   164               mk_eq_True: thm -> thm option},
   165     termless: term * term -> bool,
             (* incremented by incr_depth before the prover is called on a
                conditional rewrite rule (see rewritec) *)
   166     depth: int};
   167 
   168 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   169   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   170        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   171 
   172 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   173   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   174 
   175 val empty_mss =
   176   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   177   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   178 
   179 fun clear_mss (Mss {mk_rews, termless, ...}) =
   180   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   181 
   182 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   183   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   184 
   185 
   186 
   187 (** simpset operations **)
   188 
   189 (* term variables *)
   190 
   191 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   192 fun term_varnames t = add_term_varnames ([], t);
   193 
   194 
   195 (* dest_mss *)
   196 
   197 fun dest_mss (Mss {rules, congs, procs, ...}) =
   198   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   199    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   200    procs =
   201      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   202      |> partition_eq eq_snd
   203      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   204      |> Library.sort_wrt #1};
   205 
   206 
   207 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   208 
   209 fun merge_mss
   210  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   211        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   212   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   213        bounds = bounds2, prems = prems2, ...}) =
   214       mk_mss
   215        (Net.merge (rules1, rules2, eq_rrule),
   216         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   217         merge_lists weak1 weak2),
   218         Net.merge (procs1, procs2, eq_simproc),
   219         merge_lists bounds1 bounds2,
   220         gen_merge_lists eq_prem prems1 prems2,
   221         mk_rews, termless, depth);
   222 
   223 
   224 (* add_simps *)
   225 
   226 fun mk_rrule2{thm,lhs,elhs,perm} =
   227   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   228   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   229 
   230 fun insert_rrule(mss as Mss {rules,...},
   231                  rrule as {thm,lhs,elhs,perm}) =
   232   (trace_thm false "Adding rewrite rule:" thm;
   233    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   234        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   235    in upd_rules(mss,rules') end
   236    handle Net.INSERT =>
   237      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   238 
   239 fun vperm (Var _, Var _) = true
   240   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   241   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   242   | vperm (t, u) = (t = u);
   243 
   244 fun var_perm (t, u) =
   245   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   246 
   247 (* FIXME: it seems that the conditions on extra variables are too liberal if
   248 prems are nonempty: does solving the prems really guarantee instantiation of
   249 all its Vars? Better: a dynamic check each time a rule is applied.
   250 *)
   251 fun rewrite_rule_extra_vars prems elhs erhs =
   252   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   253   orelse
   254   not ((term_tvars erhs) subset
   255        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   256 
   257 (*Simple test for looping rewrite rules and stupid orientations*)
   258 fun reorient sign prems lhs rhs =
   259    rewrite_rule_extra_vars prems lhs rhs
   260   orelse
   261    is_Var (head_of lhs)
   262   orelse
   263    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   264   orelse
   265    (null prems andalso
   266     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   267     (*the condition "null prems" is necessary because conditional rewrites
   268       with extra variables in the conditions may terminate although
   269       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   270   orelse
   271    (is_Const lhs andalso not(is_Const rhs))
   272 
   273 fun decomp_simp thm =
   274   let val {sign, prop, ...} = rep_thm thm;
   275       val prems = Logic.strip_imp_prems prop;
   276       val concl = Drule.strip_imp_concl (cprop_of thm);
   277       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   278         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   279       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   280       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   281       val erhs = Pattern.eta_contract (term_of rhs);
   282       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   283                  andalso not (is_Var (term_of elhs))
   284   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   285 
   286 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
   287   case mk_eq_True thm of
   288     None => []
   289   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   290                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   291 
   292 (* create the rewrite rule and possibly also the ==True variant,
   293    in case there are extra vars on the rhs *)
   294 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
   295   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   296   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   297         (term_tvars rhs) subset (term_tvars lhs)
   298      then [rrule]
   299      else mk_eq_True mss thm2 @ [rrule]
   300   end;
   301 
   302 fun mk_rrule mss thm =
   303   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   304   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   305      (* weak test for loops: *)
   306      if rewrite_rule_extra_vars prems lhs rhs orelse
   307         is_Var (term_of elhs)
   308      then mk_eq_True mss thm
   309      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   310   end;
   311 
   312 fun orient_rrule mss thm =
   313   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   314   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   315      else if reorient sign prems lhs rhs
   316           then if reorient sign prems rhs lhs
   317                then mk_eq_True mss thm
   318                else let val Mss{mk_rews={mk_sym,...},...} = mss
   319                     in case mk_sym thm of
   320                          None => []
   321                        | Some thm' =>
   322                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   323                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   324                     end
   325           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   326   end;
   327 
   328 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
   329 
   330 fun orient_comb_simps comb mk_rrule (mss,thms) =
   331   let val rews = extract_rews(mss,thms)
   332       val rrules = flat (map mk_rrule rews)
   333   in foldl comb (mss,rrules) end
   334 
   335 (* Add rewrite rules explicitly; do not reorient! *)
   336 fun add_simps(mss,thms) =
   337   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   338 
   339 fun mss_of thms =
   340   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   341 
   342 fun extract_safe_rrules(mss,thm) =
   343   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   344 
   345 fun add_safe_simp(mss,thm) =
   346   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   347 
   348 (* del_simps *)
   349 
   350 fun del_rrule(mss as Mss {rules,...},
   351               rrule as {thm, elhs, ...}) =
   352   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   353    handle Net.DELETE =>
   354      (prthm true "Rewrite rule not in simpset:" thm; mss));
   355 
   356 fun del_simps(mss,thms) =
   357   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   358 
   359 
   360 (* add_congs *)
   361 
   362 fun is_full_cong_prems [] varpairs = null varpairs
   363   | is_full_cong_prems (p::prems) varpairs =
   364     (case Logic.strip_assums_concl p of
   365        Const("==",_) $ lhs $ rhs =>
   366          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   367          in is_Var x  andalso  forall is_Bound xs  andalso
   368             null(findrep(xs))  andalso xs=ys andalso
   369             (x,y) mem varpairs andalso
   370             is_full_cong_prems prems (varpairs\(x,y))
   371          end
   372      | _ => false);
   373 
   374 fun is_full_cong thm =
   375 let val prems = prems_of thm
   376     and concl = concl_of thm
   377     val (lhs,rhs) = Logic.dest_equals concl
   378     val (f,xs) = strip_comb lhs
   379     and (g,ys) = strip_comb rhs
   380 in
   381   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   382   is_full_cong_prems prems (xs ~~ ys)
   383 end
   384 
   385 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   386   let
   387     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   388       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   389 (*   val lhs = Pattern.eta_contract lhs; *)
   390     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   391       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   392     val (alist,weak) = congs
   393     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   394            ("Overwriting congruence rule for " ^ quote a);
   395     val weak2 = if is_full_cong thm then weak else a::weak
   396   in
   397     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   398   end;
   399 
   400 val (op add_congs) = foldl add_cong;
   401 
   402 
   403 (* del_congs *)
   404 
   405 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   406   let
   407     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   408       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   409 (*   val lhs = Pattern.eta_contract lhs; *)
   410     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   411       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   412     val (alist,_) = congs
   413     val alist2 = filter (fn (x,_)=> x<>a) alist
   414     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   415                                               else Some a)
   416                    alist2
   417   in
   418     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   419   end;
   420 
   421 val (op del_congs) = foldl del_cong;
   422 
   423 
   424 (* add_simprocs *)
   425 
   426 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   427     (name, lhs, proc, id)) =
   428   let val {sign, t, ...} = rep_cterm lhs
   429   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   430       sign t;
   431     mk_mss (rules, congs,
   432       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   433         handle Net.INSERT =>
   434             (warning ("Ignoring duplicate simplification procedure \""
   435                       ^ name ^ "\"");
   436              procs),
   437         bounds, prems, mk_rews, termless,depth))
   438   end;
   439 
   440 fun add_simproc (mss, (name, lhss, proc, id)) =
   441   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   442 
   443 val add_simprocs = foldl add_simproc;
   444 
   445 
   446 (* del_simprocs *)
   447 
   448 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   449     (name, lhs, proc, id)) =
   450   mk_mss (rules, congs,
   451     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   452       handle Net.DELETE =>
   453           (warning ("Simplification procedure \"" ^ name ^
   454                        "\" not in simpset"); procs),
   455       bounds, prems, mk_rews, termless, depth);
   456 
   457 fun del_simproc (mss, (name, lhss, proc, id)) =
   458   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   459 
   460 val del_simprocs = foldl del_simproc;
   461 
   462 
   463 (* prems *)
   464 
   465 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   466   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   467 
   468 fun prems_of_mss (Mss {prems, ...}) = prems;
   469 
   470 
   471 (* mk_rews *)
   472 
   473 fun set_mk_rews
   474   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   475     mk_mss (rules, congs, procs, bounds, prems,
   476             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   477             termless, depth);
   478 
   479 fun set_mk_sym
   480   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   481     mk_mss (rules, congs, procs, bounds, prems,
   482             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   483             termless,depth);
   484 
   485 fun set_mk_eq_True
   486   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   487     mk_mss (rules, congs, procs, bounds, prems,
   488             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   489             termless,depth);
   490 
   491 (* termless *)
   492 
   493 fun set_termless
   494   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   495     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   496 
   497 
   498 
   499 (** rewriting **)
   500 
   501 (*
   502   Uses conversions, see:
   503     L C Paulson, A higher-order implementation of rewriting,
   504     Science of Computer Programming 3 (1983), pages 119-149.
   505 *)
   506 
   507 type prover = meta_simpset -> thm -> thm option;
   508 type termrec = (Sign.sg_ref * term list) * term;
   509 type conv = meta_simpset -> termrec -> termrec;
   510 
   511 val dest_eq = Drule.dest_equals o cprop_of;
   512 val lhs_of = fst o dest_eq;
   513 val rhs_of = snd o dest_eq;
   514 
   515 fun beta_eta_conversion t =
   516   let val thm = beta_conversion true t;
   517   in transitive thm (eta_conversion (rhs_of thm)) end;
   518 
   519 fun check_conv msg thm thm' =
   520   let
   521     val thm'' = transitive thm (transitive
   522       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   523   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   524   handle THM _ =>
   525     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   526     in
   527       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   528        trace_term false "Should have proved:" sign prop0;
   529        None)
   530     end;
   531 
   532 
   533 (* mk_procrule *)
   534 
   535 fun mk_procrule thm =
   536   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   537   in if rewrite_rule_extra_vars prems lhs rhs
   538      then (prthm true "Extra vars on rhs:" thm; [])
   539      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
   540   end;
   541 
   542 
   543 (* conversion to apply the meta simpset to a term *)
   544 
   545 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   546    normalized terms by carrying around the rhs of the rewrite rule just
   547    applied. This is called the `skeleton'. It is decomposed in parallel
   548    with the term. Once a Var is encountered, the corresponding term is
   549    already in normal form.
   550    skel0 is a dummy skeleton used to enforce complete normalization.
   551 *)
   552 val skel0 = Bound 0;
   553 
   554 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   555    The latter may happen iff there are weak congruence rules for constants
   556    in the lhs.
   557 *)
   558 fun uncond_skel((_,weak),(lhs,rhs)) =
   559   if null weak then rhs (* optimization *)
   560   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   561        else rhs;
   562 
   563 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   564    Otherwise those vars may become instantiated with unnormalized terms
   565    while the premises are solved.
   566 *)
   567 fun cond_skel(args as (congs,(lhs,rhs))) =
   568   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   569   else skel0;
   570 
   571 (*
   572   we try in order:
   573     (1) beta reduction
   574     (2) unconditional rewrite rules
   575     (3) conditional rewrite rules
   576     (4) simplification procedures
   577 
   578   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   579 
   580 *)
   581 
   582 fun rewritec (prover, signt, maxt)
   583              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
       (* Attempt one rewrite step at the root of t, trying in order:
            beta-reduction, if the eta-contracted t is a beta-redex;
            rewrite rules from the net, unconditional ones first;
            simplification procedures.
          Returns Some (th, skel), where th proves t == t' and skel is the
          skeleton guiding further normalization of t', or None.
          maxt: if not ~1, rule indexes are shifted by maxt+1 before matching,
          presumably to keep them apart from the Vars of t -- TODO confirm. *)
   584   let
   585     val eta_thm = Thm.eta_conversion t;
   586     val eta_t' = rhs_of eta_thm;
   587     val eta_t = term_of eta_t';
   588     val tsigt = Sign.tsig_of signt;
           (* Try a single rewrite rule against eta_t.  Raises Pattern.MATCH if the
              rule is not applicable (e.g. matching fails, or the rule comes from a
              theory that is not a subtheory of signt). *)
   589     fun rew {thm, lhs, elhs, fo, perm} =
   590       let
   591         val {sign, prop, maxidx, ...} = rep_thm thm;
   592         val _ = if Sign.subsig (sign, signt) then ()
   593                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   594                       raise Pattern.MATCH);
   595         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   596           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   597         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   598                           else Thm.cterm_match (elhs', eta_t');
   599         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   600         val prop' = #prop (rep_thm thm');
   601         val unconditional = (Logic.count_prems (prop',0) = 0);
   602         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   603       in
               (* permutative rules are used only if they make the term smaller
                  w.r.t. the simpset's termless order *)
   604         if perm andalso not (termless (rhs', lhs'))
   605         then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
   606               trace_thm false "Term does not become smaller:" thm'; None)
   607         else (trace_thm false "Applying instance of rewrite rule:" thm;
   608            if unconditional
   609            then
   610              (trace_thm false "Rewriting:" thm';
   611               let val lr = Logic.dest_equals prop;
                       (* NOTE: non-exhaustive binding -- raises Bind if the
                          instantiated rule does not pass check_conv *)
   612                   val Some thm'' = check_conv false eta_thm thm'
   613               in Some (thm'', uncond_skel (congs, lr)) end)
   614            else
                  (* conditional rule: the prover gets the simpset with depth+1 *)
   615              (trace_thm false "Trying to rewrite:" thm';
   616               case prover (incr_depth mss) thm' of
   617                 None       => (trace_thm false "FAILED" thm'; None)
   618               | Some thm2 =>
   619                   (case check_conv true eta_thm thm2 of
   620                      None => None |
   621                      Some thm2' =>
   622                        let val concl = Logic.strip_imp_concl prop
   623                            val lr = Logic.dest_equals concl
   624                        in Some (thm2', cond_skel (congs, lr)) end)))
   625       end
   626 
           (* the first rule that applies wins; Pattern.MATCH counts as failure *)
   627     fun rews [] = None
   628       | rews (rrule :: rrules) =
   629           let val opt = rew rrule handle Pattern.MATCH => None
   630           in case opt of None => rews rrules | some => some end;
   631 
           (* unconditional rules (proposition is a bare ==) before conditional ones;
              each group ends up in reverse order relative to rrs *)
   632     fun sort_rrules rrs = let
   633       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
   634                                       Const("==",_) $ _ $ _ => true
   635                                       | _                   => false
   636       fun sort []        (re1,re2) = re1 @ re2
   637         | sort (rr::rrs) (re1,re2) = if is_simple rr
   638                                      then sort rrs (rr::re1,re2)
   639                                      else sort rrs (re1,rr::re2)
   640     in sort rrs ([],[]) end
   641 
           (* Try the candidate simplification procedures in turn; a produced
              theorem is turned into a rule via mk_procrule and applied with rews.
              NOTE(review): lhs is matched against t, whereas the candidates are
              retrieved with eta_t and proc is applied to eta_t -- confirm intended. *)
   642     fun proc_rews ([]:simproc list) = None
   643       | proc_rews ({name, proc, lhs, ...} :: ps) =
   644           if Pattern.matches tsigt (term_of lhs, term_of t) then
   645             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   646              case proc signt prems eta_t of
   647                None => (debug false "FAILED"; proc_rews ps)
   648              | Some raw_thm =>
   649                  (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   650                   (case rews (mk_procrule raw_thm) of
   651                     None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
   652                   | some => some)))
   653           else proc_rews ps;
   654   in case eta_t of
   655        Abs _ $ _ => Some (transitive eta_thm
   656          (beta_conversion false eta_t'), skel0)
   657      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   658                None => proc_rews (Net.match_term procs eta_t)
   659              | some => some)
   660   end;
   661 
   662 
   663 (* conversion to apply a congruence rule to a term *)
   664 
   665 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   666   let val {sign, ...} = rep_thm cong
   667       val _ = if Sign.subsig (sign, signt) then ()
   668                  else error("Congruence rule from different theory")
   669       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   670       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   671       val insts = Thm.cterm_match (rlhs, t)
   672       (* Pattern.match can raise Pattern.MATCH;
   673          is handled when congc is called *)
   674       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   675       val unit = trace_thm false "Applying congruence rule:" thm';
   676       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   677   in case prover thm' of
   678        None => err ("Could not prove", thm')
   679      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   680           None => err ("Should not have proved", thm2)
   681         | Some thm2' =>
   682             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   683             then None else Some thm2')
   684   end;
   685 
   686 val (cA, (cB, cC)) =
   687   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   688 
   689 fun transitive' thm1 None = Some thm1
   690   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
   691 
   692 fun transitive'' None thm2 = Some thm2
   693   | transitive'' (Some thm1) thm2 = Some (transitive thm1 thm2);
   694 
   695 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
       (* Build the bottom-up rewriting conversion.  The mode triple controls
          the treatment of meta-implications A ==> B:
            simprem -- simplify A (used by nonmut_impc),
            useprem -- assume the (simplified) A while simplifying B
                       (used by nonmut_impc),
            mutsimp -- use mut_impc instead of the legacy nonmut_impc.
          prover is handed to rewritec and congc for side conditions; sign
          and maxidx are forwarded to them unchanged.  The value returned is
          try_botc, which maps an mss and a cterm t to a theorem t == t'
          (reflexive t if nothing changed). *)
   696   let
   697     fun botc skel mss t =
             (* Simplify t bottom-up: first its subterms (subc), then the
                result at the top (rewritec); after each successful rewrite
                the new right-hand side is processed again with skel2.
                If skel is a Var, t is left alone.  Presumably skeleton Vars
                mark subterms that are already simplified -- verify against
                rewritec.  Returns None if t is unchanged. *)
   698           if is_Var skel then None
   699           else
   700           (case subc skel mss t of
   701              some as Some thm1 =>
   702                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   703                   Some (thm2, skel2) =>
   704                     transitive' (transitive thm1 thm2)
   705                       (botc skel2 mss (rhs_of thm2))
   706                 | None => some)
   707            | None =>
   708                (case rewritec (prover, sign, maxidx) mss t of
   709                   Some (thm2, skel2) => transitive' thm2
   710                     (botc skel2 mss (rhs_of thm2))
   711                 | None => None))
   712 
   713     and try_botc mss t =
             (* botc starting from skeleton skel0, but total: returns
                reflexive t instead of None when nothing changed. *)
   714           (case botc skel0 mss t of
   715              Some trec1 => trec1 | None => (reflexive t))
   716 
   717     and subc skel
   718           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
             (* Simplify the immediate subterms of t0, dispatching on its
                shape: abstraction, meta-implication (impc), beta-redex, or
                application (with a congruence rule when the head constant
                has one).  Returns None if nothing changed. *)
   719        (case term_of t0 of
   720            Abs (a, T, t) =>
                   (* Simplify the body under a bound-name variant b of a
                      (w.r.t. bounds); b is recorded in the mss for the body. *)
   721              let val b = variant bounds a
   722                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   723                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   724                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   725              in case botc skel' mss' t' of
   726                   Some thm => Some (abstract_rule a v thm)
   727                 | None => None
   728              end
   729          | t $ _ => (case t of
   730              Const ("==>", _) $ _  =>
                     (* Meta-implication: handled by impc. *)
   731                let val (s, u) = Drule.dest_implies t0
   732                in impc (s, u, mss) end
   733            | Abs _ =>
                     (* Beta-redex: beta-reduce, then simplify the subterms
                        of the reduct. *)
   734                let val thm = beta_conversion false t0
   735                in case subc skel0 mss (rhs_of thm) of
   736                     None => Some thm
   737                   | Some thm' => Some (transitive thm thm')
   738                end
   739            | _  =>
   740                let fun appc () =
                         (* Plain application: simplify function and argument
                            separately (with the matching sub-skeletons, or
                            skel0) and recombine with combination. *)
   741                      let
   742                        val (tskel, uskel) = case skel of
   743                            tskel $ uskel => (tskel, uskel)
   744                          | _ => (skel0, skel0);
   745                        val (ct, cu) = Thm.dest_comb t0
   746                      in
   747                      (case botc tskel mss ct of
   748                         Some thm1 =>
   749                           (case botc uskel mss cu of
   750                              Some thm2 => Some (combination thm1 thm2)
   751                            | None => Some (combination thm1 (reflexive cu)))
   752                       | None =>
   753                           (case botc uskel mss cu of
   754                              Some thm1 => Some (combination (reflexive ct) thm1)
   755                            | None => None))
   756                      end
   757                    val (h, ts) = strip_comb t
   758                in case h of
   759                     Const(a, _) =>
   760                       (case assoc_string (fst congs, a) of
   761                          None => appc ()
   762                        | Some cong =>
                               (* If congc raises Pattern.MATCH (cong does not
                                  match t0), fall back to appc below. *)
   763 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   764    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   765                           (let
   766                              val thm = congc (prover mss, sign, maxidx) cong t0;
   767                              val t = if_none (apsome rhs_of thm) t0;
   768                              val (cl, cr) = Thm.dest_comb t
   769                              val dVar = Var(("", 0), dummyT)
   770                              val skel =
   771                                list_comb (h, replicate (length ts) dVar)
   772                            in case botc skel mss cl of
   773                                 None => thm
   774                               | Some thm' => transitive'' thm
   775                                   (combination thm' (reflexive cr))
   776                            end handle TERM _ => error "congc result"
   777                                     | Pattern.MATCH => appc ()))
   778                   | _ => appc ()
   779                end)
   780          | _ => None)
   781 
   782     and impc args =
             (* Meta-implication prem ==> conc: mut_impc if mutsimp is set,
                else the legacy nonmut_impc.  mut_impc yields a pair
                (index option, theorem); only the theorem is kept here. *)
   783       if mutsimp
   784       then let val (prem, conc, mss) = args
   785            in apsome snd (mut_impc ([], prem, conc, mss)) end
   786       else nonmut_impc args
   787 
   788     and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
             (* Simplify the premise prem, then let mut_impc1 process the
                (possibly rewritten) premise together with conc.  prems
                lists the premises of the enclosing implications, outermost
                first (simpconc appends to the end). *)
   789         None => mut_impc1 (prems, prem, conc, mss)
   790       | Some thm1 =>
   791           let val prem1 = rhs_of thm1
   792           in (case mut_impc1 (prems, prem1, conc, mss) of
   793               None => Some (None,
   794                 combination (combination refl_implies thm1) (reflexive conc))
   795             | Some (x, thm2) => Some (x, transitive (combination (combination
   796                 refl_implies thm1) (reflexive conc)) thm2))
   797           end)
   798 
   799     and mut_impc1 (prems, prem1, conc, mss) =
             (* Process prem1 ==> conc where prem1 is already simplified.
                Result Some (Some i, thm) signals that an enclosing premise
                can now be rewritten using prem1; i counts how many more
                enclosing levels the signal must travel outward (see
                simpconc).  Some (None, thm) is an ordinary result. *)
   800       let
               (* Left-hand side of a rule without premises; None for
                  conditional rules. *)
   801         fun uncond ({thm, lhs, elhs, perm}) =
   802           if Thm.no_prems thm then Some lhs else None
   803 
               (* Assume prem1 and add its safe rewrite rules to mss, unless
                  it contains (type) unknowns (maxidx <> ~1).  lhss1 holds the
                  left-hand sides of the unconditional rules among them. *)
   804         val (lhss1, mss1) =
   805           if maxidx_of_term (term_of prem1) <> ~1
   806           then (trace_cterm true
   807             "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   808                 ([],mss))
   809           else let val thm = assume prem1
   810                    val rrules1 = extract_safe_rrules (mss, thm)
   811                    val lhss1 = mapfilter uncond rrules1
   812                    val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
   813                in (lhss1, mss1) end
   814 
               (* From thm : B' == C', proved under the assumption prem1,
                  discharge prem1 through Drule.imp_cong.  Presumably this
                  yields (prem1 ==> B') == (prem1 ==> C') -- confirm against
                  the statement of imp_cong. *)
   815         fun disch1 thm =
   816           let val (cB', cC') = dest_eq thm
   817           in
   818             implies_elim (Thm.instantiate
   819               ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
   820               (implies_intr prem1 thm)
   821           end
   822 
               (* Called after the conclusion has been processed; the argument
                  is the optional equation for it.  Rewrites the whole
                  implication at the top (rewritec).  If the result is again
                  an implication, processing continues with mut_impc; if it
                  is not (TERM from dest_implies), the equation so far is
                  returned. *)
   823         fun rebuild None = (case rewritec (prover, sign, maxidx) mss
   824             (mk_implies (prem1, conc)) of
   825               None => None
   826             | Some (thm, _) =>
   827                 let val (prem, conc) = Drule.dest_implies (rhs_of thm)
   828                 in (case mut_impc (prems, prem, conc, mss) of
   829                     None => Some (None, thm)
   830                   | Some (x, thm') => Some (x, transitive thm thm'))
   831                 end handle TERM _ => Some (None, thm))
   832           | rebuild (Some thm2) =
   833             let val thm = disch1 thm2
   834             in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
   835                  None => Some (None, thm)
   836                | Some (thm', _) =>
   837                    let val (prem, conc) = Drule.dest_implies (rhs_of thm')
   838                    in (case mut_impc (prems, prem, conc, mss) of
   839                        None => Some (None, transitive thm thm')
   840                      | Some (x, thm'') =>
   841                          Some (x, transitive (transitive thm thm') thm''))
   842                    end handle TERM _ => Some (None, transitive thm thm'))
   843             end
   844 
               (* Process conc with prem1 assumed (mss1).  If conc is itself an
                  implication, recurse via mut_impc with prem1 appended to
                  prems.  On Some (Some i, _): at i = 0 prem1 is swapped with
                  the next premise (Drule.swap_prems_eq) and the result is
                  reprocessed; otherwise i - 1 is passed outward.  If conc is
                  not an implication, it is simplified with botc in mss1.
                  NOTE(review): the TERM handler covers the whole body, so it
                  also catches TERM raised by the recursive calls, not only by
                  dest_implies conc. *)
   845         fun simpconc () =
   846           let val (s, t) = Drule.dest_implies conc
   847           in case mut_impc (prems @ [prem1], s, t, mss1) of
   848                None => rebuild None
   849              | Some (Some i, thm2) =>
   850                   let
   851                     val (prem, cC') = Drule.dest_implies (rhs_of thm2);
   852                     val thm2' = transitive (disch1 thm2) (Thm.instantiate
   853                       ([], [(cA, prem1), (cB, prem), (cC, cC')])
   854                       Drule.swap_prems_eq)
   855                   in if i=0 then apsome (apsnd (transitive thm2'))
   856                        (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
   857                      else Some (Some (i-1), thm2')
   858                   end
   859              | Some (None, thm) => rebuild (Some thm)
   860           end handle TERM _ => rebuild (botc skel0 mss1 conc)
   861 
   862       in
               (* If some earlier premise in prems contains a subterm matched
                  by an unconditional rule from prem1, report it without
                  changing anything: length rest is the number of premises
                  after it, i.e. how many levels the signal must travel
                  outward.  Otherwise process the conclusion. *)
   863         let
   864           val tsig = Sign.tsig_of sign
   865           fun reducible t =
   866             exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
   867         in case dropwhile (not o reducible) prems of
   868             [] => simpconc ()
   869           | red::rest => (trace_cterm false "Can now reduce premise:" red;
   870               Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
   871         end
   872       end
   873 
   874      (* legacy code - only for backwards compatibility *)
   875      and nonmut_impc (prem, conc, mss) =
             (* Simplify prem (if simprem).  Assume the result while
                simplifying conc (if useprem and it has no unknowns).  Then
                rebuild the implication with refl_implies / Drule.imp_cong.
                Returns None if neither part changed. *)
   876        let val thm1 = if simprem then botc skel0 mss prem else None;
   877            val prem1 = if_none (apsome rhs_of thm1) prem;
   878            val maxidx1 = maxidx_of_term (term_of prem1)
   879            val mss1 =
   880              if not useprem then mss else
   881              if maxidx1 <> ~1
   882              then (trace_cterm true
   883                "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
   884                    mss)
   885              else let val thm = assume prem1
   886                   in add_safe_simp (add_prems (mss, [thm]), thm) end
   887        in (case botc skel0 mss1 conc of
   888            None => (case thm1 of
   889                None => None
   890              | Some thm1' => Some (combination
   891                  (combination refl_implies thm1') (reflexive conc)))
   892          | Some thm2 =>
   893            let
   894              val conc2 = rhs_of thm2;
   895              val thm2' = implies_elim (Thm.instantiate
   896                ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
   897                (implies_intr prem1 thm2)
   898            in (case thm1 of
   899                None => Some thm2'
   900              | Some thm1' => Some (transitive (combination
   901                  (combination refl_implies thm1') (reflexive conc)) thm2'))
   902            end)
   903        end
   904 
   905  in try_botc end;
   906 
   907 
   908 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   909 
   910 (*
   911   Parameters:
   912     mode = (simplify A,
   913             use A in simplifying B,
   914             use prems of B (if B is again a meta-impl.) to simplify A)
   915            when simplifying A ==> B
   916     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   917     prover: how to solve premises in conditional rewrites and congruences
   918 *)
   919 
   920 fun rewrite_cterm mode prover mss ct =
   921   let val {sign, t, maxidx, ...} = rep_cterm ct
   922       val Mss{depth, ...} = mss
   923   in simp_depth := depth;
   924      bottomc (mode, prover, sign, maxidx) mss ct
   925   end
   926   handle THM (s, _, thms) =>
   927     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   928       Pretty.string_of (Display.pretty_thms thms));
   929 
   930 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   931 (*Do not rewrite flex-flex pairs*)
   932 fun goals_conv pred cv =
   933   let fun gconv i ct =
   934         let val (A,B) = Drule.dest_implies ct
   935             val (thA,j) = case term_of A of
   936                   Const("=?=",_)$_$_ => (reflexive A, i)
   937                 | _ => (if pred i then cv A else reflexive A, i+1)
   938         in  combination (combination refl_implies thA) (gconv j B) end
   939         handle TERM _ => reflexive ct
   940   in gconv 1 end;
   941 
   942 (* Rewrite A in !!x1,...,xn. A *)
   943 fun forall_conv cv ct =
   944   let val p as (ct1, ct2) = Thm.dest_comb ct
   945   in (case pairself term_of p of
   946       (Const ("all", _), Abs (s, _, _)) =>
   947          let val (v, ct') = Thm.dest_abs (Some "@") ct2;
   948          in Thm.combination (Thm.reflexive ct1)
   949            (Thm.abstract_rule s v (forall_conv cv ct'))
   950          end
   951     | _ => cv ct)
   952   end handle TERM _ => cv ct;
   953 
   954 (*Use a conversion to transform a theorem*)
   955 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   956 
   957 (*Rewrite a cterm*)
   958 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
   959   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
   960 
   961 (*Rewrite a theorem*)
   962 fun simplify_aux _ _ [] = (fn th => th)
   963   | simplify_aux prover full thms =
   964       fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   965 
   966 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   967 
   968 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   969 fun rewrite_goals_rule_aux _ []   th = th
   970   | rewrite_goals_rule_aux prover thms th =
   971       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   972         (mss_of thms))) th;
   973 
   974 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
   975 fun rewrite_goal_rule mode prover mss i thm =
   976   if 0 < i  andalso  i <= nprems_of thm
   977   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   978   else raise THM("rewrite_goal_rule",i,[thm]);
   979 
   980 end;
   981 
   982 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
   983 open BasicMetaSimplifier;