src/Pure/meta_simplifier.ML
author wenzelm
Wed Jan 16 20:58:27 2002 +0100 (2002-01-16)
changeset 12779 c5739c1431ab
parent 12603 7d2bca103101
child 12783 36ca5ea8092c
permissions -rw-r--r--
export beta_eta_conversion;
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4     Copyright   1994  University of Cambridge
     5 
     6 Meta-level Simplification.
     7 *)
     8 
       (*Diagnostic switches of the meta simplifier.*)
     9 signature BASIC_META_SIMPLIFIER =
    10 sig
         (*when set, trace, trace_term, trace_cterm and trace_thm print their messages*)
    11   val trace_simp: bool ref
         (*when set, debug and debug_term print their messages (e.g. simproc attempts)*)
    12   val debug_simp: bool ref
    13 end;
    14 
       (*Interface of the meta-level simplifier: construction, inspection and
         modification of simpsets, plus the rewriting entry points.*)
    15 signature META_SIMPLIFIER =
    16 sig
    17   include BASIC_META_SIMPLIFIER
         (*raised for theorems that cannot serve as rewrite or congruence rules*)
    18   exception SIMPLIFIER of string * thm
    19   type meta_simpset
         (*rewrite theorems, congruence theorems, and each simproc's name with
           its lhs patterns*)
    20   val dest_mss          : meta_simpset ->
    21     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    22   val empty_mss         : meta_simpset
    23   val clear_mss         : meta_simpset -> meta_simpset
    24   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    25   val add_simps         : meta_simpset * thm list -> meta_simpset
    26   val del_simps         : meta_simpset * thm list -> meta_simpset
    27   val mss_of            : thm list -> meta_simpset
    28   val add_congs         : meta_simpset * thm list -> meta_simpset
    29   val del_congs         : meta_simpset * thm list -> meta_simpset
         (*simprocs are given as (name, lhs patterns, procedure, identifying stamp)*)
    30   val add_simprocs      : meta_simpset *
    31     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    32       -> meta_simpset
    33   val del_simprocs      : meta_simpset *
    34     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    35       -> meta_simpset
    36   val add_prems         : meta_simpset * thm list -> meta_simpset
    37   val prems_of_mss      : meta_simpset -> thm list
    38   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    39   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    40   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
    42   val beta_eta_conversion: cterm -> thm
         (*Rewriting entry points; their implementations are not part of this
           excerpt.  The (meta_simpset -> thm -> thm option) argument is
           presumably the prover used for the premises of conditional rules, and
           the bool * bool * bool triple presumably the (simprem, useprem,
           mutsimp) flags of bottomc -- NOTE(review): confirm.*)
    43   val rewrite_cterm: bool * bool * bool ->
    44     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    45   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    46   val forall_conv       : (cterm -> thm) -> cterm -> thm
    47   val fconv_rule        : (cterm -> thm) -> thm -> thm
    48   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    49   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    50   val rewrite_thm       : bool * bool * bool
    51                           -> (meta_simpset -> thm -> thm option)
    52                           -> meta_simpset -> thm -> thm
    53   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    54   val rewrite_goal_rule : bool* bool * bool
    55                           -> (meta_simpset -> thm -> thm option)
    56                           -> meta_simpset -> int -> thm -> thm
    57 end;
    58 
    59 structure MetaSimplifier : META_SIMPLIFIER =
    60 struct
    61 
    62 (** diagnostics **)
    63 
    64 exception SIMPLIFIER of string * thm;
    65 
    66 val simp_depth = ref 0;
    67 
    68 local
    69 
    70 fun println a =
    71   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    72 
    73 fun prnt warn a = if warn then warning a else println a;
    74 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    75 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    76 
    77 in
    78 
    79 fun prthm warn a = prctm warn a o Thm.cprop_of;
    80 
    81 val trace_simp = ref false;
    82 val debug_simp = ref false;
    83 
    84 fun trace warn a = if !trace_simp then prnt warn a else ();
    85 fun debug warn a = if !debug_simp then prnt warn a else ();
    86 
    87 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    88 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    89 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    90 
    91 fun trace_thm warn a thm =
    92   let val {sign, prop, ...} = rep_thm thm
    93   in trace_term warn a sign prop end;
    94 
    95 end;
    96 
    97 
    98 (** meta simp sets **)
    99 
   100 (* basic components *)
   101 
   102 type rrule = {thm: thm, lhs: term, elhs: cterm, fo: bool, perm: bool};
   103 (* thm: the rewrite rule
   104    lhs: the left-hand side
   105    elhs: the eta-contracted lhs.
   106    fo:  use first-order matching
   107    perm: the rewrite rule is permutative (see decomp_simp and var_perm)
   108 Remarks:
   109   - elhs is used for matching,
   110     lhs only for preservation of bound variable names.
   111   - fo is set iff
   112     either elhs is first-order (no Var is applied),
   113            in which case fo-matching is complete,
   114     or elhs is not a pattern,
   115        in which case there is nothing better to do.
   116 *)
       (*congruence rule: thm is the rule itself, lhs the lhs of its conclusion*)
   117 type cong = {thm: thm, lhs: cterm};
       (*simplification procedure:
           name: used in trace and warning messages;
           proc: called as  proc sign prems term  and may return a rewrite theorem;
           lhs:  a term must match this pattern before proc is tried on it;
           id:   identity of the procedure (eq_simproc compares only id)*)
   118 type simproc =
   119  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   120 
   121 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
   122   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   123 
   124 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
   125   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   126 
   127 fun eq_prem (thm1, thm2) =
   128   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   129 
   130 fun eq_simproc ({id = s1, ...}:simproc, {id = s2, ...}:simproc) = (s1 = s2);
   131 
   132 fun mk_simproc (name, proc, lhs, id) =
   133   {name = name, proc = proc, lhs = lhs, id = id};
   134 
   135 
   136 (* datatype mss *)
   137 
   138 (*
   139   A "mss" contains data needed during conversion:
   140     rules: discrimination net of rewrite rules;
   141     congs: association list of congruence rules (keyed by the name of the
              head constant of the rule's lhs, see add_cong) and
   142            a list of `weak' congruence constants.
   143            A congruence is `weak' if it avoids normalization of some argument.
              (In the code: a congruence is weak iff is_full_cong rejects it.)
   144     procs: discrimination net of simplification procedures
   145       (functions that prove rewrite rules on the fly);
   146     bounds: names of bound variables already used
   147       (for generating new names when rewriting under lambda abstractions);
   148     prems: current premises;
   149     mk_rews: mk: turns simplification thms into rewrite rules;
   150              mk_sym: turns == around; (needs Drule!)
   151              mk_eq_True: turns P into P == True - logic specific;
   152     termless: relation for ordered rewriting;
   153     depth: depth of conditional rewriting (incremented by incr_depth);
   154 *)
   155 
   156 datatype meta_simpset =
   157   Mss of {
   158     rules: rrule Net.net,
   159     congs: (string * cong) list * string list,
   160     procs: simproc Net.net,
   161     bounds: string list,
   162     prems: thm list,
   163     mk_rews: {mk: thm -> thm list,
   164               mk_sym: thm -> thm option,
   165               mk_eq_True: thm -> thm option},
             (*termless decides whether a permutative rule may be applied (rewritec)*)
   166     termless: term * term -> bool,
   167     depth: int};
   168 
   169 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
   170   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   171        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   172 
   173 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   174   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   175 
   176 val empty_mss =
   177   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   178   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   179 
   180 fun clear_mss (Mss {mk_rews, termless, ...}) =
   181   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   182 
   183 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
   184   mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth+1)
   185 
   186 
   187 
   188 (** simpset operations **)
   189 
   190 (* term variables *)
   191 
   192 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
   193 fun term_varnames t = add_term_varnames ([], t);
   194 
   195 
   196 (* dest_mss *)
   197 
       (*Inspect a simpset: the theorems of its rewrite rules, the theorems of
         its congruence rules, and for each simplification procedure its name
         together with all of its lhs patterns (sorted by name).*)
   198 fun dest_mss (Mss {rules, congs, procs, ...}) =
   199   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   200    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   201    procs =
          (*the net holds one entry per lhs pattern; regroup entries by stamp id*)
   202      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   203      |> partition_eq eq_snd
          (*name of the group's first entry, plus the lhs of every entry*)
   204      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   205      |> Library.sort_wrt #1};
   206 
   207 
   208 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   209 
   210 fun merge_mss
   211  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   212        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   213   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   214        bounds = bounds2, prems = prems2, ...}) =
   215       mk_mss
   216        (Net.merge (rules1, rules2, eq_rrule),
   217         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   218         merge_lists weak1 weak2),
   219         Net.merge (procs1, procs2, eq_simproc),
   220         merge_lists bounds1 bounds2,
   221         gen_merge_lists eq_prem prems1 prems2,
   222         mk_rews, termless, depth);
   223 
   224 
   225 (* add_simps *)
   226 
   227 fun mk_rrule2{thm,lhs,elhs,perm} =
   228   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   229   in {thm=thm,lhs=lhs,elhs=elhs,fo=fo,perm=perm} end
   230 
   231 fun insert_rrule(mss as Mss {rules,...},
   232                  rrule as {thm,lhs,elhs,perm}) =
   233   (trace_thm false "Adding rewrite rule:" thm;
   234    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   235        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   236    in upd_rules(mss,rules') end
   237    handle Net.INSERT =>
   238      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   239 
   240 fun vperm (Var _, Var _) = true
   241   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   242   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   243   | vperm (t, u) = (t = u);
   244 
   245 fun var_perm (t, u) =
   246   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   247 
   248 (* FIXME: it seems that the conditions on extra variables are too liberal if
   249 prems are nonempty: does solving the prems really guarantee instantiation of
   250 all its Vars? Better: a dynamic check each time a rule is applied.
   251 *)
   252 fun rewrite_rule_extra_vars prems elhs erhs =
   253   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   254   orelse
   255   not ((term_tvars erhs) subset
   256        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   257 
   258 (*Simple test for looping rewrite rules and stupid orientations*)
       (*reorient sign prems lhs rhs is true when  lhs == rhs  (under prems) should
         not be used left-to-right, i.e. when
           - the rhs has Vars/TVars not in the lhs or prems,
           - the head of the lhs is a Var,
           - the lhs occurs in the rhs or in a premise (Logic.occs),
           - the rule is unconditional and the rhs is an instance of the lhs, or
           - the lhs is a constant but the rhs is not.*)
   259 fun reorient sign prems lhs rhs =
   260    rewrite_rule_extra_vars prems lhs rhs
   261   orelse
   262    is_Var (head_of lhs)
   263   orelse
   264    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   265   orelse
   266    (null prems andalso
   267     Pattern.matches (#tsig (Sign.rep_sg sign)) (lhs, rhs))
   268     (*the condition "null prems" is necessary because conditional rewrites
   269       with extra variables in the conditions may terminate although
   270       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   271   orelse
   272    (is_Const lhs andalso not(is_Const rhs))
   273 
       (*Decompose a rewrite rule  prems ==> lhs == rhs  into
           (sign, premise terms, lhs term, eta-contracted lhs cterm, rhs term, perm).
         perm holds iff the rule is permutative: lhs and rhs differ only by a
         permutation of the same Vars (var_perm), are not alpha-equivalent, and
         the lhs is not a bare Var.
         Raises SIMPLIFIER if the conclusion is not a meta-equality.*)
   274 fun decomp_simp thm =
   275   let val {sign, prop, ...} = rep_thm thm;
   276       val prems = Logic.strip_imp_prems prop;
   277       val concl = Drule.strip_imp_concl (cprop_of thm);
   278       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   279         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   280       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
           (*reuse the original cterm when eta-contraction changed nothing*)
   281       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   282       val erhs = Pattern.eta_contract (term_of rhs);
   283       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   284                  andalso not (is_Var (term_of elhs))
   285   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   286 
   287 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) thm =
   288   case mk_eq_True thm of
   289     None => []
   290   | Some eq_True => let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   291                     in [{thm=eq_True, lhs=lhs, elhs=elhs, perm=false}] end;
   292 
   293 (* create the rewrite rule and possibly also the ==True variant,
   294    in case there are extra vars on the rhs *)
   295 fun rrule_eq_True(thm,lhs,elhs,rhs,mss,thm2) =
   296   let val rrule = {thm=thm, lhs=lhs, elhs=elhs, perm=false}
   297   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   298         (term_tvars rhs) subset (term_tvars lhs)
   299      then [rrule]
   300      else mk_eq_True mss thm2 @ [rrule]
   301   end;
   302 
   303 fun mk_rrule mss thm =
   304   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   305   in if perm then [{thm=thm, lhs=lhs, elhs=elhs, perm=true}] else
   306      (* weak test for loops: *)
   307      if rewrite_rule_extra_vars prems lhs rhs orelse
   308         is_Var (term_of elhs)
   309      then mk_eq_True mss thm
   310      else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   311   end;
   312 
       (*Rule records for thm, oriented sensibly: permutative rules are kept as
         they are; if reorient rejects  lhs == rhs  the symmetric version
         (obtained with the simpset's mk_sym) is used instead, unless reorient
         rejects that orientation too, in which case only the mk_eq_True
         variant is produced.  If mk_sym fails, no rules are produced.*)
   313 fun orient_rrule mss thm =
   314   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   315   in if perm then [{thm=thm,lhs=lhs,elhs=elhs,perm=true}]
   316      else if reorient sign prems lhs rhs
   317           then if reorient sign prems rhs lhs
   318                then mk_eq_True mss thm
   319                else let val Mss{mk_rews={mk_sym,...},...} = mss
   320                     in case mk_sym thm of
   321                          None => []
   322                        | Some thm' =>
   323                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
                               (*the original thm is passed as thm2, so a ==True
                                 variant (if any) is derived from the original*)
   324                            in rrule_eq_True(thm',lhs',elhs',rhs',mss,thm) end
   325                     end
   326           else rrule_eq_True(thm,lhs,elhs,rhs,mss,thm)
   327   end;
   328 
   329 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) = flat(map mk thms);
   330 
   331 fun orient_comb_simps comb mk_rrule (mss,thms) =
   332   let val rews = extract_rews(mss,thms)
   333       val rrules = flat (map mk_rrule rews)
   334   in foldl comb (mss,rrules) end
   335 
   336 (* Add rewrite rules explicitly; do not reorient! *)
   337 fun add_simps(mss,thms) =
   338   orient_comb_simps insert_rrule (mk_rrule mss) (mss,thms);
   339 
   340 fun mss_of thms =
   341   foldl insert_rrule (empty_mss, flat(map (mk_rrule empty_mss) thms));
   342 
   343 fun extract_safe_rrules(mss,thm) =
   344   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   345 
   346 fun add_safe_simp(mss,thm) =
   347   foldl insert_rrule (mss, extract_safe_rrules(mss,thm))
   348 
   349 (* del_simps *)
   350 
   351 fun del_rrule(mss as Mss {rules,...},
   352               rrule as {thm, elhs, ...}) =
   353   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   354    handle Net.DELETE =>
   355      (prthm true "Rewrite rule not in simpset:" thm; mss));
   356 
   357 fun del_simps(mss,thms) =
   358   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   359 
   360 
   361 (* add_congs *)
   362 
       (*is_full_cong_prems prems varpairs: every premise must conclude
         x xs == y xs  with x a Var, xs distinct Bound variables (identical on
         both sides) and (x, y) one of varpairs; each pair is consumed by one
         premise, and all pairs must be used up.*)
   363 fun is_full_cong_prems [] varpairs = null varpairs
   364   | is_full_cong_prems (p::prems) varpairs =
   365     (case Logic.strip_assums_concl p of
   366        Const("==",_) $ lhs $ rhs =>
   367          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   368          in is_Var x  andalso  forall is_Bound xs  andalso
   369             null(findrep(xs))  andalso xs=ys andalso
   370             (x,y) mem varpairs andalso
   371             is_full_cong_prems prems (varpairs\(x,y))
   372          end
   373      | _ => false);
   374 
       (*A congruence  prems ==> f xs == g ys  is full iff f = g, the argument
         lists have equal length and contain no repeated terms, and the
         premises relate every argument pair (see is_full_cong_prems).
         Congruences that are not full are recorded as `weak' (add_cong).
         Raises TERM (via Logic.dest_equals) if the conclusion is not a
         meta-equality.*)
   375 fun is_full_cong thm =
   376 let val prems = prems_of thm
   377     and concl = concl_of thm
   378     val (lhs,rhs) = Logic.dest_equals concl
   379     val (f,xs) = strip_comb lhs
   380     and (g,ys) = strip_comb rhs
   381 in
   382   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   383   is_full_cong_prems prems (xs ~~ ys)
   384 end
   385 
   386 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   387   let
   388     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   389       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   390 (*   val lhs = Pattern.eta_contract lhs; *)
   391     val (a, _) = dest_Const (head_of (term_of lhs)) handle TERM _ =>
   392       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   393     val (alist,weak) = congs
   394     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   395            ("Overwriting congruence rule for " ^ quote a);
   396     val weak2 = if is_full_cong thm then weak else a::weak
   397   in
   398     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   399   end;
   400 
   401 val (op add_congs) = foldl add_cong;
   402 
   403 
   404 (* del_congs *)
   405 
   406 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   407   let
   408     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   409       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   410 (*   val lhs = Pattern.eta_contract lhs; *)
   411     val (a, _) = dest_Const (head_of lhs) handle TERM _ =>
   412       raise SIMPLIFIER ("Congruence must start with a constant", thm);
   413     val (alist,_) = congs
   414     val alist2 = filter (fn (x,_)=> x<>a) alist
   415     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   416                                               else Some a)
   417                    alist2
   418   in
   419     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   420   end;
   421 
   422 val (op del_congs) = foldl del_cong;
   423 
   424 
   425 (* add_simprocs *)
   426 
   427 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   428     (name, lhs, proc, id)) =
   429   let val {sign, t, ...} = rep_cterm lhs
   430   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   431       sign t;
   432     mk_mss (rules, congs,
   433       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   434         handle Net.INSERT =>
   435             (warning ("Ignoring duplicate simplification procedure \""
   436                       ^ name ^ "\"");
   437              procs),
   438         bounds, prems, mk_rews, termless,depth))
   439   end;
   440 
   441 fun add_simproc (mss, (name, lhss, proc, id)) =
   442   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   443 
   444 val add_simprocs = foldl add_simproc;
   445 
   446 
   447 (* del_simprocs *)
   448 
   449 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   450     (name, lhs, proc, id)) =
   451   mk_mss (rules, congs,
   452     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   453       handle Net.DELETE =>
   454           (warning ("Simplification procedure \"" ^ name ^
   455                        "\" not in simpset"); procs),
   456       bounds, prems, mk_rews, termless, depth);
   457 
   458 fun del_simproc (mss, (name, lhss, proc, id)) =
   459   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   460 
   461 val del_simprocs = foldl del_simproc;
   462 
   463 
   464 (* prems *)
   465 
   466 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   467   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   468 
   469 fun prems_of_mss (Mss {prems, ...}) = prems;
   470 
   471 
   472 (* mk_rews *)
   473 
   474 fun set_mk_rews
   475   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   476     mk_mss (rules, congs, procs, bounds, prems,
   477             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   478             termless, depth);
   479 
   480 fun set_mk_sym
   481   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   482     mk_mss (rules, congs, procs, bounds, prems,
   483             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   484             termless,depth);
   485 
   486 fun set_mk_eq_True
   487   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   488     mk_mss (rules, congs, procs, bounds, prems,
   489             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   490             termless,depth);
   491 
   492 (* termless *)
   493 
   494 fun set_termless
   495   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   496     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   497 
   498 
   499 
   500 (** rewriting **)
   501 
   502 (*
   503   Uses conversions, see:
   504     L C Paulson, A higher-order implementation of rewriting,
   505     Science of Computer Programming 3 (1983), pages 119-149.
   506 *)
   507 
   508 type prover = meta_simpset -> thm -> thm option;
   509 type termrec = (Sign.sg_ref * term list) * term;
   510 type conv = meta_simpset -> termrec -> termrec;
   511 
   512 val dest_eq = Drule.dest_equals o cprop_of;
   513 val lhs_of = fst o dest_eq;
   514 val rhs_of = snd o dest_eq;
   515 
   516 fun beta_eta_conversion t =
   517   let val thm = beta_conversion true t;
   518   in transitive thm (eta_conversion (rhs_of thm)) end;
   519 
       (*check_conv msg thm thm':  thm proves  t == t'  and thm' should prove an
         equation  L == R  whose lhs L beta-eta-converts to t'.  On success,
         returns Some (t == R), chaining thm, the symmetric beta-eta conversion of
         L, and thm'; "SUCCEEDED" is traced only if msg is set.  If the chain
         does not fit (transitive raises THM), thm' and the rhs t' of thm are
         traced and None is returned.*)
   520 fun check_conv msg thm thm' =
   521   let
   522     val thm'' = transitive thm (transitive
   523       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   524   in (if msg then trace_thm false "SUCCEEDED" thm' else (); Some thm'') end
   525   handle THM _ =>
           (*thm is an equation, so its prop has the form  "==" $ t $ t'*)
   526     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   527     in
   528       (trace_thm false "Proved wrong thm (Check subgoaler?)" thm';
   529        trace_term false "Should have proved:" sign prop0;
   530        None)
   531     end;
   532 
   533 
   534 (* mk_procrule *)
   535 
   536 fun mk_procrule thm =
   537   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   538   in if rewrite_rule_extra_vars prems lhs rhs
   539      then (prthm true "Extra vars on rhs:" thm; [])
   540      else [mk_rrule2{thm=thm, lhs=lhs, elhs=elhs, perm=false}]
   541   end;
   542 
   543 
   544 (* conversion to apply the meta simpset to a term *)
   545 
   546 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   547    normalized terms by carrying around the rhs of the rewrite rule just
   548    applied. This is called the `skeleton'. It is decomposed in parallel
   549    with the term. Once a Var is encountered, the corresponding term is
   550    already in normal form.
   551    skel0 is a dummy skeleton used to enforce complete normalization.
   552 *)
   553 val skel0 = Bound 0;
   554 
   555 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   556    The latter may happen iff there are weak congruence rules for constants
   557    in the lhs.
   558 *)
   559 fun uncond_skel((_,weak),(lhs,rhs)) =
   560   if null weak then rhs (* optimization *)
   561   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   562        else rhs;
   563 
   564 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   565    Otherwise those vars may become instantiated with unnormalized terms
   566    while the premises are solved.
   567 *)
   568 fun cond_skel(args as (congs,(lhs,rhs))) =
   569   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   570   else skel0;
   571 
   572 (*
   573   we try in order:
   574     (1) beta reduction
   575     (2) unconditional rewrite rules
   576     (3) conditional rewrite rules
   577     (4) simplification procedures
   578 
   579   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   580 
   581 *)
   582 
       (*rewritec (prover, signt, maxt) mss t: one rewrite step at the root of t.
         Returns Some (thm, skel), where thm proves  t == u  and skel is the
         skeleton for normalizing u further, or None if nothing applies.
         Unless maxt is ~1, rule Vars are shifted by maxt+1 before matching
         (presumably maxt is the maximal Var index of t -- confirm at call site).
         prover is applied with incr_depth mss to instances of conditional rules.*)
   583 fun rewritec (prover, signt, maxt)
   584              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   585   let
           (*all matching is done against the eta-contracted form of t*)
   586     val eta_thm = Thm.eta_conversion t;
   587     val eta_t' = rhs_of eta_thm;
   588     val eta_t = term_of eta_t';
   589     val tsigt = Sign.tsig_of signt;
   590     fun rew {thm, lhs, elhs, fo, perm} =
   591       let
   592         val {sign, prop, maxidx, ...} = rep_thm thm;
               (*a rule from a theory not contained in signt is reported and then
                 treated like a non-matching rule (rews catches Pattern.MATCH)*)
   593         val _ = if Sign.subsig (sign, signt) then ()
   594                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   595                       raise Pattern.MATCH);
   596         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   597           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   598         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   599                           else Thm.cterm_match (elhs', eta_t');
   600         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   601         val prop' = #prop (rep_thm thm');
   602         val unconditional = (Logic.count_prems (prop',0) = 0);
   603         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   604       in
               (*a permutative rule is applied only if the instance gets smaller
                 w.r.t. termless*)
   605         if perm andalso not (termless (rhs', lhs'))
   606         then (trace_thm false "Cannot apply permutative rewrite rule:" thm;
   607               trace_thm false "Term does not become smaller:" thm'; None)
   608         else (trace_thm false "Applying instance of rewrite rule:" thm;
   609            if unconditional
   610            then
   611              (trace_thm false "Rewriting:" thm';
                   (*the skeleton comes from the uninstantiated rule (prop), so
                     rhs Vars mark subterms that are already normalized*)
   612               let val lr = Logic.dest_equals prop;
                       (*NOTE(review): if check_conv returns None here, this
                         binding raises Bind, which rews does not catch*)
   613                   val Some thm'' = check_conv false eta_thm thm'
   614               in Some (thm'', uncond_skel (congs, lr)) end)
   615            else
   616              (trace_thm false "Trying to rewrite:" thm';
   617               case prover (incr_depth mss) thm' of
   618                 None       => (trace_thm false "FAILED" thm'; None)
   619               | Some thm2 =>
   620                   (case check_conv true eta_thm thm2 of
   621                      None => None |
   622                      Some thm2' =>
   623                        let val concl = Logic.strip_imp_concl prop
   624                            val lr = Logic.dest_equals concl
   625                        in Some (thm2', cond_skel (congs, lr)) end)))
   626       end
   627 
           (*result of the first rule that applies; rules whose matching fails
             (Pattern.MATCH) are skipped*)
   628     fun rews [] = None
   629       | rews (rrule :: rrules) =
   630           let val opt = rew rrule handle Pattern.MATCH => None
   631           in case opt of None => rews rrules | some => some end;
   632 
           (*unconditional rules (prop is directly an equation) come before
             conditional ones; each group ends up in reverse order*)
   633     fun sort_rrules rrs = let
   634       fun is_simple({thm, ...}:rrule) = case #prop (rep_thm thm) of
   635                                       Const("==",_) $ _ $ _ => true
   636                                       | _                   => false
   637       fun sort []        (re1,re2) = re1 @ re2
   638         | sort (rr::rrs) (re1,re2) = if is_simple rr
   639                                      then sort rrs (rr::re1,re2)
   640                                      else sort rrs (re1,rr::re2)
   641     in sort rrs ([],[]) end
   642 
           (*NOTE(review): a simproc's lhs is matched against t itself, whereas
             the procedure is applied to the eta-contracted eta_t*)
   643     fun proc_rews ([]:simproc list) = None
   644       | proc_rews ({name, proc, lhs, ...} :: ps) =
   645           if Pattern.matches tsigt (term_of lhs, term_of t) then
   646             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   647              case proc signt prems eta_t of
   648                None => (debug false "FAILED"; proc_rews ps)
   649              | Some raw_thm =>
   650                  (trace_thm false ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   651                   (case rews (mk_procrule raw_thm) of
   652                     None => (trace_cterm false "IGNORED - does not match" t; proc_rews ps)
   653                   | some => some)))
   654           else proc_rews ps;
         (*a beta-redex at the root is contracted (beta_conversion false) before
           any rule or procedure is tried*)
   655   in case eta_t of
   656        Abs _ $ _ => Some (transitive eta_thm
   657          (beta_conversion false eta_t'), skel0)
   658      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   659                None => proc_rews (Net.match_term procs eta_t)
   660              | some => some)
   661   end;
   662 
   663 
   664 (* conversion to apply a congruence rule to a term *)
   665 
       (*congc (prover, signt, maxt) cong t: instantiate the congruence rule cong
         to t and let prover establish the instance.  Returns Some (t == u), or
         None if the proven equation is trivial (both sides alpha-equivalent).
         Calls error if cong stems from a theory not contained in signt, if
         prover fails, or if it proves a wrong equation.*)
   666 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   667   let val {sign, ...} = rep_thm cong
   668       val _ = if Sign.subsig (sign, signt) then ()
   669                  else error("Congruence rule from different theory")
           (*shift the rule's Vars (unless maxt = ~1), then match its lhs
             recomputed from the shifted rule; the record's lhs field is unused*)
   670       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   671       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   672       val insts = Thm.cterm_match (rlhs, t)
   673       (* Pattern.match can raise Pattern.MATCH;
   674          is handled when congc is called *)
   675       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
           (*bound only for the tracing side effect*)
   676       val unit = trace_thm false "Applying congruence rule:" thm';
   677       fun err (msg, thm) = (prthm false msg thm; error "Failed congruence proof!")
   678   in case prover thm' of
   679        None => err ("Could not prove", thm')
   680      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   681           None => err ("Should not have proved", thm2)
   682         | Some thm2' =>
   683             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   684             then None else Some thm2')
   685   end;
   686 
   687 val (cA, (cB, cC)) =
   688   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   689 
   690 fun transitive' thm1 None = Some thm1
   691   | transitive' thm1 (Some thm2) = Some (transitive thm1 thm2);
   692 
   693 fun transitive'' None thm2 = Some thm2
   694   | transitive'' (Some thm1) thm2 = Some (transitive thm1 thm2);
   695 
(*Bottom-up simplification engine.
  mode = (simprem, useprem, mutsimp):
    - mutsimp selects the mutual premise-simplification strategy (mut_impc)
      for meta-implications.
    - simprem and useprem are consulted only by the legacy nonmut_impc.
  The result is try_botc : mss -> cterm -> thm.  It yields t == t', or
  reflexive t when nothing was rewritten.
  The internal functions return None for "unchanged" and Some eqn otherwise.*)
fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
  let
    (*botc skel mss t: a Var skeleton means "do not touch t" (None).
      Otherwise it simplifies the subterms of t (subc) and then tries a
      top-level rewrite of the result (rewritec).  After a successful
      rewrite it continues on the new term, guided by the skeleton skel2
      returned by rewritec.*)
    fun botc skel mss t =
          if is_Var skel then None
          else
          (case subc skel mss t of
             some as Some thm1 =>
               (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
                  Some (thm2, skel2) =>
                    transitive' (transitive thm1 thm2)
                      (botc skel2 mss (rhs_of thm2))
                | None => some)
           | None =>
               (case rewritec (prover, sign, maxidx) mss t of
                  Some (thm2, skel2) => transitive' thm2
                    (botc skel2 mss (rhs_of thm2))
                | None => None))

    (*Entry point: botc with the skeleton skel0, which is defined elsewhere
      in this file.  None is replaced by reflexive t.*)
    and try_botc mss t =
          (case botc skel0 mss t of
             Some trec1 => trec1 | None => (reflexive t))

    (*subc skel mss t0: simplify the immediate subterms of t0.
      - Abs: open the binder with a free variable named "." ^ b, where b is
        a variant of the bound name w.r.t. bounds.  Simplify the body with b
        added to bounds, then abstract again.
      - A ==> B: delegated to impc.
      - (Abs _) $ _: reduce with beta_conversion false, then continue with
        subc on the result using skel0.
      - Other applications: if the head constant has a congruence rule in
        congs, use congc (falling back to appc on Pattern.MATCH).  Otherwise
        appc simplifies the operator and operand separately.
      - Any other term: unchanged (None).*)
    and subc skel
          (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
       (case term_of t0 of
           Abs (a, T, t) =>
             let val b = variant bounds a
                 val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
                 val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
                 val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
             in case botc skel' mss' t' of
                  Some thm => Some (abstract_rule a v thm)
                | None => None
             end
         | t $ _ => (case t of
             Const ("==>", _) $ _  =>
               let val (s, u) = Drule.dest_implies t0
               in impc (s, u, mss) end
           | Abs _ =>
               let val thm = beta_conversion false t0
               in case subc skel0 mss (rhs_of thm) of
                    None => Some thm
                  | Some thm' => Some (transitive thm thm')
               end
           | _  =>
               (*appc: simplify operator and operand independently.  When skel
                 is an application, follow its corresponding parts; otherwise
                 use skel0 for both.*)
               let fun appc () =
                     let
                       val (tskel, uskel) = case skel of
                           tskel $ uskel => (tskel, uskel)
                         | _ => (skel0, skel0);
                       val (ct, cu) = Thm.dest_comb t0
                     in
                     (case botc tskel mss ct of
                        Some thm1 =>
                          (case botc uskel mss cu of
                             Some thm2 => Some (combination thm1 thm2)
                           | None => Some (combination thm1 (reflexive cu)))
                      | None =>
                          (case botc uskel mss cu of
                             Some thm1 => Some (combination (reflexive ct) thm1)
                           | None => None))
                     end
                   val (h, ts) = strip_comb t
               in case h of
                    Const(a, _) =>
                      (case assoc_string (fst congs, a) of
                         None => appc ()
                       | Some cong =>
(* post processing: some partial applications h t1 ... tj, j <= length ts,
   may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
(* The skeleton h dVar ... dVar marks the arguments as already done, so botc
   only retries rewrites at the partial applications of h. *)
                          (let
                             val thm = congc (prover mss, sign, maxidx) cong t0;
                             val t = if_none (apsome rhs_of thm) t0;
                             val (cl, cr) = Thm.dest_comb t
                             val dVar = Var(("", 0), dummyT)
                             val skel =
                               list_comb (h, replicate (length ts) dVar)
                           in case botc skel mss cl of
                                None => thm
                              | Some thm' => transitive'' thm
                                  (combination thm' (reflexive cr))
                           end handle TERM _ => error "congc result"
                                    | Pattern.MATCH => appc ()))
                  | _ => appc ()
               end)
         | _ => None)

    (*impc (A, B, mss): simplify the implication A ==> B.
      If mutsimp is set, use mut_impc and drop its bookkeeping component.
      Otherwise use the legacy nonmut_impc.*)
    and impc args =
      if mutsimp
      then let val (prem, conc, mss) = args
           in apsome snd (mut_impc ([], prem, conc, mss)) end
      else nonmut_impc args

    (*mut_impc (prems, prem, conc, mss): simplify prem ==> conc.
      prems lists the premises of the enclosing implications, outermost
      first.  prem is simplified first, then mut_impc1 handles the rest; the
      two equations are combined via refl_implies.
      Result: None if nothing changed, otherwise Some (x, eqn).  x = Some i
      is a request, produced by mut_impc1, to revisit an enclosing premise.*)
    and mut_impc (prems, prem, conc, mss) = (case botc skel0 mss prem of
        None => mut_impc1 (prems, prem, conc, mss)
      | Some thm1 =>
          let val prem1 = rhs_of thm1
          in (case mut_impc1 (prems, prem1, conc, mss) of
              None => Some (None,
                combination (combination refl_implies thm1) (reflexive conc))
            | Some (x, thm2) => Some (x, transitive (combination (combination
                refl_implies thm1) (reflexive conc)) thm2))
          end)

    (*mut_impc1 (prems, prem1, conc, mss): prem1 is already simplified.
      If prem1 has no schematic (type) unknowns, it is assumed and its safe
      rewrite rules are added to mss, giving mss1.  lhss1 collects the
      left-hand sides of those rules that have no premises.
      If an earlier premise in prems is now reducible by lhss1, return
      Some (Some (length rest), reflexive (prem1 ==> conc)), where rest are
      the premises after the reducible one; enclosing levels count down to
      it.  Otherwise conc is simplified under mss1 (simpconc).*)
    and mut_impc1 (prems, prem1, conc, mss) =
      let
        fun uncond ({thm, lhs, elhs, perm}) =
          if Thm.no_prems thm then Some lhs else None

        val (lhss1, mss1) =
          if maxidx_of_term (term_of prem1) <> ~1
          then (trace_cterm true
            "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                ([],mss))
          else let val thm = assume prem1
                   val rrules1 = extract_safe_rrules (mss, thm)
                   val lhss1 = mapfilter uncond rrules1
                   val mss1 = foldl insert_rrule (add_prems (mss, [thm]), rrules1)
               in (lhss1, mss1) end

        (*Discharge prem1: turn thm : B' == C', proved assuming prem1, into an
          equation between the implications (presumably
          (prem1 ==> B') == (prem1 ==> C')) by instantiating Drule.imp_cong.*)
        fun disch1 thm =
          let val (cB', cC') = dest_eq thm
          in
            implies_elim (Thm.instantiate
              ([], [(cA, prem1), (cB, cB'), (cC, cC')]) Drule.imp_cong)
              (implies_intr prem1 thm)
          end

        (*rebuild eqn_opt: build the equation for prem1 ==> conc from the
          optional equation for conc (None = unchanged).  Then try a
          top-level rewrite of the whole implication with rewritec.  If the
          result is again an implication, continue with mut_impc.  A TERM
          exception (e.g. the result is not an implication) ends the
          recursion with what has been proved so far.*)
        fun rebuild None = (case rewritec (prover, sign, maxidx) mss
            (mk_implies (prem1, conc)) of
              None => None
            | Some (thm, _) =>
                let val (prem, conc) = Drule.dest_implies (rhs_of thm)
                in (case mut_impc (prems, prem, conc, mss) of
                    None => Some (None, thm)
                  | Some (x, thm') => Some (x, transitive thm thm'))
                end handle TERM _ => Some (None, thm))
          | rebuild (Some thm2) =
            let val thm = disch1 thm2
            in (case rewritec (prover, sign, maxidx) mss (rhs_of thm) of
                 None => Some (None, thm)
               | Some (thm', _) =>
                   let val (prem, conc) = Drule.dest_implies (rhs_of thm')
                   in (case mut_impc (prems, prem, conc, mss) of
                       None => Some (None, transitive thm thm')
                     | Some (x, thm'') =>
                         Some (x, transitive (transitive thm thm') thm''))
                   end handle TERM _ => Some (None, transitive thm thm'))
            end

        (*simpconc (): simplify conc under mss1.
          If conc is itself an implication s ==> t, recurse with prem1
          appended to prems, and handle a pending Some i request:
            - i > 0: pass Some (i-1) upwards;
            - i = 0: swap prem1 with the next premise (Drule.swap_prems_eq)
              and rerun mut_impc1 in the swapped order.
          If conc is not an implication (TERM), simplify it with botc and
          rebuild.*)
        fun simpconc () =
          let val (s, t) = Drule.dest_implies conc
          in case mut_impc (prems @ [prem1], s, t, mss1) of
               None => rebuild None
             | Some (Some i, thm2) =>
                  let
                    val (prem, cC') = Drule.dest_implies (rhs_of thm2);
                    val thm2' = transitive (disch1 thm2) (Thm.instantiate
                      ([], [(cA, prem1), (cB, prem), (cC, cC')])
                      Drule.swap_prems_eq)
                  in if i=0 then apsome (apsnd (transitive thm2'))
                       (mut_impc1 (prems, prem, mk_implies (prem1, cC'), mss))
                     else Some (Some (i-1), thm2')
                  end
             | Some (None, thm) => rebuild (Some thm)
          end handle TERM _ => rebuild (botc skel0 mss1 conc)

      in
        (*Look for the first earlier premise containing a subterm that
          matches one of the unconditional lhss of prem1.*)
        let
          val tsig = Sign.tsig_of sign
          fun reducible t =
            exists (fn lhs => Pattern.matches_subterm tsig (lhs, term_of t)) lhss1;
        in case dropwhile (not o reducible) prems of
            [] => simpconc ()
          | red::rest => (trace_cterm false "Can now reduce premise:" red;
              Some (Some (length rest), reflexive (mk_implies (prem1, conc))))
        end
      end

     (* legacy code - only for backwards compatibility *)
     (*nonmut_impc (prem, conc, mss):
         - simplify prem if simprem is set;
         - if useprem is set and the (simplified) premise has no schematic
           unknowns, assume it and add it as a safe simp rule and premise;
         - simplify conc under the resulting simpset;
         - combine the premise and conclusion equations via refl_implies and
           Drule.imp_cong.*)
     and nonmut_impc (prem, conc, mss) =
       let val thm1 = if simprem then botc skel0 mss prem else None;
           val prem1 = if_none (apsome rhs_of thm1) prem;
           val maxidx1 = maxidx_of_term (term_of prem1)
           val mss1 =
             if not useprem then mss else
             if maxidx1 <> ~1
             then (trace_cterm true
               "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem1;
                   mss)
             else let val thm = assume prem1
                  in add_safe_simp (add_prems (mss, [thm]), thm) end
       in (case botc skel0 mss1 conc of
           None => (case thm1 of
               None => None
             | Some thm1' => Some (combination
                 (combination refl_implies thm1') (reflexive conc)))
         | Some thm2 =>
           let
             val conc2 = rhs_of thm2;
             val thm2' = implies_elim (Thm.instantiate
               ([], [(cA, prem1), (cB, conc), (cC, conc2)]) Drule.imp_cong)
               (implies_intr prem1 thm2)
           in (case thm1 of
               None => Some thm2'
             | Some thm1' => Some (transitive (combination
                 (combination refl_implies thm1') (reflexive conc)) thm2'))
           end)
       end

 in try_botc end;
   907 
   908 
   909 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   910 
   911 (*
   912   Parameters:
   913     mode = (simplify A,
   914             use A in simplifying B,
   915             use prems of B (if B is again a meta-impl.) to simplify A)
   916            when simplifying A ==> B
   917     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   918     prover: how to solve premises in conditional rewrites and congruences
   919 *)
   920 
   921 fun rewrite_cterm mode prover mss ct =
   922   let val {sign, t, maxidx, ...} = rep_cterm ct
   923       val Mss{depth, ...} = mss
   924   in simp_depth := depth;
   925      bottomc (mode, prover, sign, maxidx) mss ct
   926   end
   927   handle THM (s, _, thms) =>
   928     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   929       Pretty.string_of (Display.pretty_thms thms));
   930 
   931 (*In [A1,...,An]==>B, rewrite the selected A's only -- for rewrite_goals_tac*)
   932 (*Do not rewrite flex-flex pairs*)
   933 fun goals_conv pred cv =
   934   let fun gconv i ct =
   935         let val (A,B) = Drule.dest_implies ct
   936             val (thA,j) = case term_of A of
   937                   Const("=?=",_)$_$_ => (reflexive A, i)
   938                 | _ => (if pred i then cv A else reflexive A, i+1)
   939         in  combination (combination refl_implies thA) (gconv j B) end
   940         handle TERM _ => reflexive ct
   941   in gconv 1 end;
   942 
   943 (* Rewrite A in !!x1,...,xn. A *)
   944 fun forall_conv cv ct =
   945   let val p as (ct1, ct2) = Thm.dest_comb ct
   946   in (case pairself term_of p of
   947       (Const ("all", _), Abs (s, _, _)) =>
   948          let val (v, ct') = Thm.dest_abs (Some "@") ct2;
   949          in Thm.combination (Thm.reflexive ct1)
   950            (Thm.abstract_rule s v (forall_conv cv ct'))
   951          end
   952     | _ => cv ct)
   953   end handle TERM _ => cv ct;
   954 
   955 (*Use a conversion to transform a theorem*)
   956 fun fconv_rule cv th = equal_elim (cv (cprop_of th)) th;
   957 
   958 (*Rewrite a cterm*)
   959 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
   960   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (mss_of thms);
   961 
   962 (*Rewrite a theorem*)
   963 fun simplify_aux _ _ [] = (fn th => th)
   964   | simplify_aux prover full thms =
   965       fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms));
   966 
   967 fun rewrite_thm mode prover mss = fconv_rule (rewrite_cterm mode prover mss);
   968 
   969 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
   970 fun rewrite_goals_rule_aux _ []   th = th
   971   | rewrite_goals_rule_aux prover thms th =
   972       fconv_rule (goals_conv (K true) (rewrite_cterm (true, true, false) prover
   973         (mss_of thms))) th;
   974 
   975 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
   976 fun rewrite_goal_rule mode prover mss i thm =
   977   if 0 < i  andalso  i <= nprems_of thm
   978   then fconv_rule (goals_conv (fn j => j=i) (rewrite_cterm mode prover mss)) thm
   979   else raise THM("rewrite_goal_rule",i,[thm]);
   980 
   981 end;
   982 
(*Restrict MetaSimplifier to the BASIC_META_SIMPLIFIER interface
  (trace_simp, debug_simp) and open it at top level.*)
structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
open BasicMetaSimplifier;