src/Pure/meta_simplifier.ML
author kleing
Mon Jun 21 10:25:57 2004 +0200 (2004-06-21)
changeset 14981 e73f8140af78
parent 14643 130076a81b84
child 15001 fb2141a9f8c0
permissions -rw-r--r--
Merged in license change from Isabelle2004
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4 
     5 Meta-level Simplification.
     6 *)
     7 
     8 signature BASIC_META_SIMPLIFIER =
       (*Global switches of the meta simplifier.*)
     9 sig
    10   val trace_simp: bool ref          (*when set, the trace functions produce output*)
    11   val debug_simp: bool ref          (*when set, the debug functions produce output*)
    12   val simp_depth_limit: int ref     (*bound checked by incr_depth for conditional rewriting*)
    13 end;
    14 
    15 signature META_SIMPLIFIER =
       (*Interface of the meta simplifier: simpset construction and rewriting conversions.*)
    16 sig
    17   include BASIC_META_SIMPLIFIER
       (*raised with a message and the offending theorem for malformed rewrite or congruence rules*)
    18   exception SIMPLIFIER of string * thm
       (*carries the name of a simplification procedure together with an exception*)
    19   exception SIMPROC_FAIL of string * exn
    20   type meta_simpset
       (*theorems of the rewrite rules, theorems of the congruence rules, and the
         simprocs as (name, lhs patterns) sorted by name*)
    21   val dest_mss          : meta_simpset ->
    22     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    23   val empty_mss         : meta_simpset
       (*clear_mss keeps only mk_rews and termless of its argument*)
    24   val clear_mss         : meta_simpset -> meta_simpset
       (*merge_mss takes mk_rews, termless and depth from the first simpset*)
    25   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    26   val add_simps         : meta_simpset * thm list -> meta_simpset
    27   val del_simps         : meta_simpset * thm list -> meta_simpset
    28   val mss_of            : thm list -> meta_simpset
    29   val add_congs         : meta_simpset * thm list -> meta_simpset
    30   val del_congs         : meta_simpset * thm list -> meta_simpset
       (*a simproc is given as (name, lhs patterns, procedure, identifying stamp)*)
    31   val add_simprocs      : meta_simpset *
    32     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    33       -> meta_simpset
    34   val del_simprocs      : meta_simpset *
    35     (string * cterm list * (Sign.sg -> thm list -> term -> thm option) * stamp) list
    36       -> meta_simpset
    37   val add_prems         : meta_simpset * thm list -> meta_simpset
    38   val prems_of_mss      : meta_simpset -> thm list
       (*setters and getters for the three components of the mk_rews record*)
    39   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
    40   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
    41   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
    42   val get_mk_rews       : meta_simpset -> thm -> thm list
    43   val get_mk_sym        : meta_simpset -> thm -> thm option
    44   val get_mk_eq_True    : meta_simpset -> thm -> thm option
    45   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
       (*beta conversion followed by eta conversion of the result*)
    46   val beta_eta_conversion: cterm -> thm
       (*conversions and rules below take a prover of type meta_simpset -> thm -> thm option*)
    47   val rewrite_cterm: bool * bool * bool ->
    48     (meta_simpset -> thm -> thm option) -> meta_simpset -> cterm -> thm
    49   val goals_conv        : (int -> bool) -> (cterm -> thm) -> cterm -> thm
    50   val forall_conv       : (cterm -> thm) -> cterm -> thm
    51   val fconv_rule        : (cterm -> thm) -> thm -> thm
    52   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
    53   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
    54   val rewrite_thm       : bool * bool * bool
    55                           -> (meta_simpset -> thm -> thm option)
    56                           -> meta_simpset -> thm -> thm
    57   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
    58   val rewrite_goal_rule : bool* bool * bool
    59                           -> (meta_simpset -> thm -> thm option)
    60                           -> meta_simpset -> int -> thm -> thm
    61   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
    62 end;
    63 
    64 structure MetaSimplifier : META_SIMPLIFIER =
    65 struct
    66 
    67 (** diagnostics **)
    68 
    69 exception SIMPLIFIER of string * thm;  (*error message and the offending theorem*)
    70 exception SIMPROC_FAIL of string * exn;  (*simproc name and an exception, see proc_rews in rewritec*)
    71 
       (*read by the trace printer: a non-zero value n is shown as prefix "[n]"*)
    72 val simp_depth = ref 0;
       (*incr_depth gives up once the conditional rewriting depth exceeds this value*)
    73 val simp_depth_limit = ref 1000;
    74 
    75 local
       (*Printing helpers.  In all of them the flag warn selects output via
         warning (true) or via the depth-prefixed tracing channel (false).*)
    76 
       (*emit a on the tracing channel, prefixed with "[n]" if simp_depth is n <> 0*)
    77 fun println a =
    78   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
    79 
    80 fun prnt warn a = if warn then warning a else println a;
       (*message a, newline, then the printed term / certified term*)
    81 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
    82 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
    83 
    84 in
    85 
       (*print message a followed by the proposition of a theorem, regardless of the flags*)
    86 fun prthm warn a = prctm warn a o Thm.cprop_of;
    87 
    88 val trace_simp = ref false;
    89 val debug_simp = ref false;
    90 
       (*output only if trace_simp resp. debug_simp is set*)
    91 fun trace warn a = if !trace_simp then prnt warn a else ();
    92 fun debug warn a = if !debug_simp then prnt warn a else ();
    93 
    94 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();
    95 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();
    96 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();
    97 
       (*trace message a with the proposition of thm (never as a warning)*)
    98 fun trace_thm a thm =
    99   let val {sign, prop, ...} = rep_thm thm
   100   in trace_term false a sign prop end;
   101 
       (*like trace_thm; the message gets the quoted name appended unless it is empty, then ":"*)
   102 fun trace_named_thm a (thm, name) =
   103   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   104 
   105 end;
   106 
   107 
   108 (** meta simp sets **)
   109 
   110 (* basic components *)
   111 
   112 type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};
   113 (* thm: the rewrite rule
   114    name: name of theorem from which rewrite rule was extracted
   115    lhs: the left-hand side
   116    elhs: the eta-contracted lhs.
   117    fo:  use first-order matching
   118    perm: the rewrite rule is permutative
   119 Remarks:
   120   - elhs is used for matching,
   121     lhs only for preservation of bound variable names.
   122   - fo is set iff
   123     either elhs is first-order (no Var is applied),
   124            in which case fo-matching is complete,
   125     or elhs is not a pattern,
   126        in which case there is nothing better to do.
   127 *)
       (*congruence rule: the theorem and the lhs of its conclusion (see add_cong)*)
   128 type cong = {thm: thm, lhs: cterm};
       (*simplification procedure: name, the procedure itself, the lhs pattern it is
         registered for, and a stamp id used by eq_simproc to identify it*)
   129 type simproc =
   130  {name: string, proc: Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   131 
   132 fun eq_rrule (rr1: rrule, rr2: rrule) =  (*rewrite rules are equal iff their props are aconv*)
   133   let val ({prop = p, ...}, {prop = q, ...}) = (rep_thm (#thm rr1), rep_thm (#thm rr2)) in p aconv q end;
   134 
   135 fun eq_cong (cg1: cong, cg2: cong) =  (*same criterion for congruence rules*)
   136   let val ({prop = p, ...}, {prop = q, ...}) = (rep_thm (#thm cg1), rep_thm (#thm cg2)) in p aconv q end;
   137 
   138 fun eq_prem (prem1, prem2) =  (*same criterion for premises*)
   139   let val ({prop = p, ...}, {prop = q, ...}) = (rep_thm prem1, rep_thm prem2) in p aconv q end;
   140 
   141 fun eq_simproc (sp1: simproc, sp2: simproc) = (#id sp1 = #id sp2);  (*identified by stamp only*)
   142 
   143 fun mk_simproc (a, f, pat, stamp) =  (*tuple (name, proc, lhs, id) to record*)
   144   {name = a, proc = f, lhs = pat, id = stamp};
   145 
   146 
   147 (* datatype mss *)
   148 
   149 (*
   150   A "mss" contains data needed during conversion:
   151     rules: discrimination net of rewrite rules;
   152     congs: association list of congruence rules (keyed by the name from cong_name) and
   153            a list of `weak' congruence constants.
   154            A congruence is `weak' if it avoids normalization of some argument.
   155     procs: discrimination net of simplification procedures
   156       (functions that prove rewrite rules on the fly);
   157     bounds: names of bound variables already used
   158       (for generating new names when rewriting under lambda abstractions);
   159     prems: current premises;
   160     mk_rews: mk: turns simplification thms into rewrite rules;
   161              mk_sym: turns == around; (needs Drule!)
   162              mk_eq_True: turns P into P == True - logic specific;
   163     termless: relation for ordered rewriting (consulted for permutative rules);
   164     depth: depth of conditional rewriting (bounded by simp_depth_limit, see incr_depth);
   165 *)
   166 
   167 datatype meta_simpset =
   168   Mss of {
   169     rules: rrule Net.net,
   170     congs: (string * cong) list * string list,
   171     procs: simproc Net.net,
   172     bounds: string list,
   173     prems: thm list,
   174     mk_rews: {mk: thm -> thm list,
   175               mk_sym: thm -> thm option,
   176               mk_eq_True: thm -> thm option},
   177     termless: term * term -> bool,
   178     depth: int};
   179 
   180 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
       (*positional constructor for Mss; the argument order is the field order of the datatype*)
   181   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   182        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   183 
       (*replace only the net of rewrite rules*)
   184 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
   185   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   186 
       (*no rules, congs, procs, bounds or prems; the mk_rews functions yield nothing;
         ordering Term.termless; depth 0*)
   187 val empty_mss =
   188   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   189   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   190 
       (*like empty_mss, but mk_rews and termless are taken over from the argument*)
   191 fun clear_mss (Mss {mk_rews, termless, ...}) =
   192   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   193 
   194 fun incr_depth (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}) =
       (*Enter one more level of conditional rewriting: None (after a warning) if the
         new depth exceeds simp_depth_limit, otherwise Some simpset with the depth
         increased; every tenth level is reported by a warning.*)
   195   let val next = depth + 1 in
   196     if next <= ! simp_depth_limit then
   197       (if next mod 10 <> 0 then ()
   198        else warning ("Simplification depth " ^ string_of_int next);
   199        Some (mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, next)))
   200     else
   201       (warning "simp_depth_limit exceeded - giving up";
   202        None)
   203   end;
   204 
   205 
   206 (** simpset operations **)
   207 
   208 (* term variables: collect the names x of the atoms Var (x, _) of a term *)
   209 
       (*fold over the atomic subterms, adding each Var name to the accumulator via ins_ix*)
   210 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);
       (*same, starting from the empty list*)
   211 fun term_varnames t = add_term_varnames ([], t);
   212 
   213 
   214 (* dest_mss: list the user-visible contents of a simpset *)
   215 
   216 fun dest_mss (Mss {rules, congs, procs, ...}) =
         (*simps: theorems of all entries of the rule net;
           congs: theorems of the congruence association list (weak list is dropped);
           procs: net entries are grouped with partition_eq on their stamp id -- presumably
             one group per simproc -- and each group yields the name of its first entry
             with the lhs patterns of all its entries; finally sorted on the name*)
   217   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   218    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   219    procs =
   220      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   221      |> partition_eq eq_snd
   222      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   223      |> Library.sort_wrt #1};
   224 
   225 
   226 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
       (*Rules and procs are merged as nets (duplicates judged by eq_rrule / eq_simproc),
         congruence rules and prems as lists (judged by eq_cong / eq_prem), weak
         constants and bounds as plain lists.*)
   227 
   228 fun merge_mss
   229  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   230        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   231   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   232        bounds = bounds2, prems = prems2, ...}) =
   233       mk_mss
   234        (Net.merge (rules1, rules2, eq_rrule),
   235         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   236         merge_lists weak1 weak2),
   237         Net.merge (procs1, procs2, eq_simproc),
   238         merge_lists bounds1 bounds2,
   239         gen_merge_lists eq_prem prems1 prems2,
   240         mk_rews, termless, depth);
   241 
   242 
   243 (* add_simps *)
   244 
       (*complete a rule lacking the fo field: fo is set iff elhs is first-order
         or elhs is not a pattern (cf. the remarks at type rrule)*)
   245 fun mk_rrule2{thm, name, lhs, elhs, perm} =
   246   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   247   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   248 
       (*Insert a rule into the net under the term of its elhs.  If the net signals
         Net.INSERT the simpset is returned unchanged, with a warning unless quiet.*)
   249 fun insert_rrule quiet (mss as Mss {rules,...},
   250                  rrule as {thm,name,lhs,elhs,perm}) =
   251   (trace_named_thm "Adding rewrite rule" (thm, name);
   252    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   253        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   254    in upd_rules(mss,rules') end
   255    handle Net.INSERT => if quiet then mss else
   256      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   257 
   258 fun vperm (Var _, Var _) = true  (*any Var may stand opposite any other Var*)
   259   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   260   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   261   | vperm (t, u) = (t = u);  (*everything else must coincide*)
   262 
       (*t and u have the same structure up to their Vars (vperm) and the same set of Var names*)
   263 fun var_perm (t, u) =
   264   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   265 
   266 (* FIXME: it seems that the conditions on extra variables are too liberal if
   267 prems are nonempty: does solving the prems really guarantee instantiation of
   268 all its Vars? Better: a dynamic check each time a rule is applied.
   269 *)
       (*true iff erhs contains a Var name or a TVar that occurs neither in elhs nor in prems*)
   270 fun rewrite_rule_extra_vars prems elhs erhs =
   271   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   272   orelse
   273   not ((term_tvars erhs) subset
   274        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   275 
   276 (*Simple test for looping rewrite rules and stupid orientations*)
       (*The orientation lhs == rhs is rejected (result true) if any of these holds:
           - the rhs has extra variables,
           - the head of the lhs is a Var,
           - the lhs occurs (Logic.occs) in the rhs or in a premise,
           - there are no premises and Pattern.matches holds for (lhs, rhs),
           - the lhs is a constant while the rhs is not.*)
   277 fun reorient sign prems lhs rhs =
   278    rewrite_rule_extra_vars prems lhs rhs
   279   orelse
   280    is_Var (head_of lhs)
   281   orelse
   282    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   283   orelse
   284    (null prems andalso
   285     Pattern.matches (Sign.tsig_of sign) (lhs, rhs))
   286     (*the condition "null prems" is necessary because conditional rewrites
   287       with extra variables in the conditions may terminate although
   288       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   289   orelse
   290    (is_Const lhs andalso not(is_Const rhs))
   291 
   292 fun decomp_simp thm =
       (*Decompose a rewrite rule into (sign, prems, lhs, elhs, rhs, perm): prems are the
         premises of the proposition, lhs/rhs the sides of the concluding meta-equality
         (as terms), elhs the certified result of eta-converting the lhs.
         Raises SIMPLIFIER if the conclusion is not a meta-equality.*)
   293   let val {sign, prop, ...} = rep_thm thm;
   294       val prems = Logic.strip_imp_prems prop;
   295       val concl = Drule.strip_imp_concl (cprop_of thm);
   296       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   297         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   298       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   299       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   300       val erhs = Pattern.eta_contract (term_of rhs);
             (*permutative: both sides agree up to a renaming of Vars (var_perm),
               are not aconv, and the lhs is not a mere Var*)
   301       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   302                  andalso not (is_Var (term_of elhs))
   303   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   304 
       (*(lhs, rhs) of an unconditional rule; raises SIMPLIFIER if thm has premises*)
   305 fun decomp_simp' thm =
   306   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   307     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   308     else (lhs, rhs)
   309   end;
   310 
   311 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
       (*apply the simpset's mk_eq_True to thm; if it yields a theorem, return it as a
         single non-permutative rule (without the fo field), otherwise no rule*)
   312   case mk_eq_True thm of
   313     None => []
   314   | Some eq_True =>
   315       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   316       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   317 
   318 (* create the rewrite rule and possibly also the ==True variant,
   319    in case there are extra vars on the rhs *)
       (*the ==True variant is built from thm2 and is placed before the rule itself*)
   320 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   321   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   322   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   323         (term_tvars rhs) subset (term_tvars lhs)
   324      then [rrule]
   325      else mk_eq_True mss (thm2, name) @ [rrule]
   326   end;
   327 
   328 fun mk_rrule mss (thm, name) =
       (*Turn thm into rules without reorienting it: a permutative rule is taken as is;
         a rule with extra variables or with a Var as elhs is only used in its ==True
         form; otherwise rrule_eq_True decides.*)
   329   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   330   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   331      (* weak test for loops: *)
   332      if rewrite_rule_extra_vars prems lhs rhs orelse
   333         is_Var (term_of elhs)
   334      then mk_eq_True mss (thm, name)
   335      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   336   end;
   337 
       (*Like mk_rrule, but tries the other orientation when reorient rejects the given
         one: if both orientations are rejected only the ==True form is used; otherwise
         the theorem is turned around with mk_sym (no rule at all if mk_sym fails).*)
   338 fun orient_rrule mss (thm, name) =
   339   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   340   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   341      else if reorient sign prems lhs rhs
   342           then if reorient sign prems rhs lhs
   343                then mk_eq_True mss (thm, name)
   344                else let val Mss{mk_rews={mk_sym,...},...} = mss
   345                     in case mk_sym thm of
   346                          None => []
   347                        | Some thm' =>
   348                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   349                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   350                     end
   351           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   352   end;
   353 
   354 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
       (*apply the simpset's mk to every theorem; each result is paired with the
         name of the theorem it was derived from*)
   355   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   356 
       (*extract rewrites from thms, turn each into rules with mk_rrule, and fold
         comb over the resulting rules starting from mss*)
   357 fun orient_comb_simps comb mk_rrule (mss,thms) =
   358   let val rews = extract_rews(mss,thms)
   359       val rrules = flat (map mk_rrule rews)
   360   in foldl comb (mss,rrules) end
   361 
   362 (* Add rewrite rules explicitly; do not reorient! *)
   363 fun add_simps(mss,thms) =
   364   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   365 
       (*simpset made from empty_mss and the given theorems; the theorems are not
         passed through mk but fed to mk_rrule directly*)
   366 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat
   367   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   368 
       (*rules extracted from a single theorem, reoriented where necessary (orient_rrule)*)
   369 fun extract_safe_rrules(mss,thm) =
   370   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   371 
   372 (* del_simps *)
   373 
       (*Delete a rule from the net under the term of its elhs.  If the net signals
         Net.DELETE a warning is printed and the simpset is returned unchanged.*)
   374 fun del_rrule(mss as Mss {rules,...},
   375               rrule as {thm, elhs, ...}) =
   376   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   377    handle Net.DELETE =>
   378      (prthm true "Rewrite rule not in simpset:" thm; mss));
   379 
       (*the rules to delete are computed exactly as add_simps computes the rules to add*)
   380 fun del_simps(mss,thms) =
   381   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   382 
   383 
   384 (* add_congs *)
   385 
       (*Check the premises of a congruence rule against the list of argument pairs
         (x,y) taken from its conclusion: the conclusion of each premise must be a
         meta-equality whose sides are x and y applied to the same list of distinct
         Bounds, with x a Var and (x,y) still in varpairs; the pair is then removed.
         The result is true iff every premise passes and no pair is left over.*)
   386 fun is_full_cong_prems [] varpairs = null varpairs
   387   | is_full_cong_prems (p::prems) varpairs =
   388     (case Logic.strip_assums_concl p of
   389        Const("==",_) $ lhs $ rhs =>
   390          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   391          in is_Var x  andalso  forall is_Bound xs  andalso
   392             null(findrep(xs))  andalso xs=ys andalso
   393             (x,y) mem varpairs andalso
   394             is_full_cong_prems prems (varpairs\(x,y))
   395          end
   396      | _ => false);
   397 
       (*A congruence rule f xs == f ys is `full' if the arguments xs@ys are pairwise
         distinct, both sides have the same head and number of arguments, and the
         premises cover exactly the pairs (x,y) in the sense of is_full_cong_prems.
         add_cong records the constants of all other rules as `weak'.*)
   398 fun is_full_cong thm =
   399 let val prems = prems_of thm
   400     and concl = concl_of thm
   401     val (lhs,rhs) = Logic.dest_equals concl
   402     val (f,xs) = strip_comb lhs
   403     and (g,ys) = strip_comb rhs
   404 in
   405   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   406   is_full_cong_prems prems (xs ~~ ys)
   407 end
   408 
   409 fun cong_name (Const (a, _)) = Some a  (*key of a congruence rule, computed from the head of its lhs*)
   410   | cong_name (Free (a, _)) = Some ("Free: " ^ a)
   411   | cong_name _ = None;
   412 
       (*Add one congruence rule.  The rule is stored under the cong_name of the head
         of its lhs, replacing (with a warning) an existing rule for that name; the
         name is added to the weak list unless the rule is a full congruence.
         Raises SIMPLIFIER if the conclusion is not a meta-equality or if the head of
         the lhs is neither a constant nor a free variable.*)
   413 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   414   let
   415     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   416       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   417 (*   val lhs = Pattern.eta_contract lhs; *)
   418     val a = (case cong_name (head_of (term_of lhs)) of
   419         Some a => a
   420       | None =>
   421         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   422     val (alist,weak) = congs
   423     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   424            ("Overwriting congruence rule for " ^ quote a);
   425     val weak2 = if is_full_cong thm then weak else a::weak
   426   in
   427     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   428   end;
   429 
       (*add a list of congruence rules, from left to right*)
   430 val (op add_congs) = foldl add_cong;
   431 
   432 
   433 (* del_congs *)
   434 
       (*Delete the congruence rule stored under the cong_name of the head of thm's
         lhs; thm itself is only used to compute that name.  The weak list is
         recomputed from the remaining rules: it holds the names of all rules that
         are not full congruences.
         Raises SIMPLIFIER if the conclusion is not a meta-equality or if the head of
         the lhs is neither a constant nor a free variable -- cong_name accepts both,
         and the message now says so, in agreement with add_cong.*)
   435 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
   436   let
   437     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   438       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   439 (*   val lhs = Pattern.eta_contract lhs; *)
   440     val a = (case cong_name (head_of lhs) of
   441         Some a => a
   442       | None =>
   443         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   444     val (alist,_) = congs
   445     val alist2 = filter (fn (x,_)=> x<>a) alist
           (*distinct names b/cong_thm: do not shadow the a and thm of this function*)
   446     val weak2 = mapfilter (fn (b,{thm=cong_thm,...}) => if is_full_cong cong_thm then None
   447                                               else Some b)
   448                    alist2
   449   in
   450     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   451   end;
   452 
       (*delete a list of congruence rules, from left to right*)
   453 val (op del_congs) = foldl del_cong;
   454 
   455 
   456 (* add_simprocs *)
   457 
       (*Register a procedure for one lhs pattern: it is inserted into the procs net
         under the term of lhs.  If the net signals Net.INSERT a warning is printed
         and the net is left unchanged.*)
   458 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   459     (name, lhs, proc, id)) =
   460   let val {sign, t, ...} = rep_cterm lhs
   461   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   462       sign t;
   463     mk_mss (rules, congs,
   464       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   465         handle Net.INSERT =>
   466             (warning ("Ignoring duplicate simplification procedure \""
   467                       ^ name ^ "\"");
   468              procs),
   469         bounds, prems, mk_rews, termless,depth))
   470   end;
   471 
       (*register a procedure for each of its lhs patterns; all entries share name, proc and id*)
   472 fun add_simproc (mss, (name, lhss, proc, id)) =
   473   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   474 
   475 val add_simprocs = foldl add_simproc;
   476 
   477 
   478 (* del_simprocs *)
   479 
       (*Remove the entry of a procedure for one lhs pattern from the procs net; entries
         are compared with eq_simproc, i.e. by stamp.  If the net signals Net.DELETE a
         warning is printed and the net is left unchanged.*)
   480 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
   481     (name, lhs, proc, id)) =
   482   mk_mss (rules, congs,
   483     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   484       handle Net.DELETE =>
   485           (warning ("Simplification procedure \"" ^ name ^
   486                        "\" not in simpset"); procs),
   487       bounds, prems, mk_rews, termless, depth);
   488 
       (*remove the entries for all lhs patterns of a procedure*)
   489 fun del_simproc (mss, (name, lhss, proc, id)) =
   490   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   491 
   492 val del_simprocs = foldl del_simproc;
   493 
   494 
   495 (* prems *)
   496 
       (*new premises are put in front of the existing ones; no duplicate check is made*)
   497 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
   498   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   499 
   500 fun prems_of_mss (Mss {prems, ...}) = prems;
   501 
   502 
   503 (* mk_rews: each setter replaces one field of the mk_rews record and keeps the other two *)
   504 
   505 fun set_mk_rews
   506   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   507     mk_mss (rules, congs, procs, bounds, prems,
   508             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   509             termless, depth);
   510 
   511 fun set_mk_sym
   512   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   513     mk_mss (rules, congs, procs, bounds, prems,
   514             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   515             termless,depth);
   516 
   517 fun set_mk_eq_True
   518   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   519     mk_mss (rules, congs, procs, bounds, prems,
   520             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   521             termless,depth);
   522 
       (*selectors for the three fields of the mk_rews record*)
   523 fun get_mk_rews    (Mss {mk_rews,...}) = #mk         mk_rews
   524 fun get_mk_sym     (Mss {mk_rews,...}) = #mk_sym     mk_rews
   525 fun get_mk_eq_True (Mss {mk_rews,...}) = #mk_eq_True mk_rews
   526 
   527 (* termless: replace the term ordering; rewritec consults it for permutative rules *)
   528 
   529 fun set_termless
   530   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   531     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   532 
   533 
   534 
   535 (** rewriting **)
   536 
   537 (*
   538   Uses conversions, see:
   539     L C Paulson, A higher-order implementation of rewriting,
   540     Science of Computer Programming 3 (1983), pages 119-149.
   541 *)
   542 
   543 val dest_eq = Drule.dest_equals o cprop_of;  (*both sides of a theorem's proposition*)
   544 val lhs_of = fst o dest_eq;
   545 val rhs_of = snd o dest_eq;
   546 
       (*beta conversion of t, followed by eta conversion of its result*)
   547 fun beta_eta_conversion t =
   548   let val thm = beta_conversion true t;
   549   in transitive thm (eta_conversion (rhs_of thm)) end;
   550 
       (*Chain thm with thm': the lhs of thm' is first linked to its beta-eta converted
         form (used symmetrically), then everything is joined by transitivity.  Success
         yields Some of the combined theorem, traced as "SUCCEEDED" if msg is set.
         If THM is raised -- presumably because the equations do not chain -- the
         mismatch is traced and the result is None.*)
   551 fun check_conv msg thm thm' =
   552   let
   553     val thm'' = transitive thm (transitive
   554       (symmetric (beta_eta_conversion (lhs_of thm'))) thm')
   555   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   556   handle THM _ =>
           (*prop0 is the second argument of thm's proposition, i.e. the rhs of the equation thm*)
   557     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   558     in
   559       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   560        trace_term false "Should have proved:" sign prop0;
   561        None)
   562     end;
   563 
   564 
   565 (* mk_procrule: turn the theorem produced by a simproc into a rewrite rule *)
   566 
       (*The result is a list with one unnamed, non-permutative rule, or the empty list
         (after a warning) if the rhs has extra variables.*)
   567 fun mk_procrule thm =
   568   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   569   in if rewrite_rule_extra_vars prems lhs rhs
   570      then (prthm true "Extra vars on rhs:" thm; [])
   571      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   572   end;
   573 
   574 
   575 (* conversion to apply the meta simpset to a term *)
   576 
   577 (* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   578    normalized terms by carrying around the rhs of the rewrite rule just
   579    applied. This is called the `skeleton'. It is decomposed in parallel
   580    with the term. Once a Var is encountered, the corresponding term is
   581    already in normal form.
   582    skel0 is a dummy skeleton that enforces complete normalization: it is
          not a Var, so traversal never stops at it.
   583 *)
   584 val skel0 = Bound 0;
   585 
   586 (* Skeleton after an unconditional rewrite step: the rhs of the rule, unless
   587    some constant of its lhs has a weak congruence rule. In that case parts of
   588    the term may still be unnormalized and the dummy skeleton skel0 is used.
   589 *)
   590 fun uncond_skel ((_, weak_consts), (lhs, rhs)) =
   591   let fun is_weak (c, _) = c mem weak_consts in
   592     (*with an empty weak list the lhs is not traversed at all*)
   593     if not (null weak_consts) andalso exists_Const is_weak lhs then skel0 else rhs end;
   594 
   595 (* Skeleton after a conditional rewrite step. Vars of the rhs that are missing
   596    from the lhs may get instantiated with unnormalized terms while the premises
   597    are solved, so skel0 is used then; otherwise as for unconditional rules.
   598 *)
   599 fun cond_skel (args as (_, (lhs, rhs))) =
   600   if not (term_varnames rhs subset term_varnames lhs) then skel0
   601   else uncond_skel args;
   602 
   603 (*
   604   we try in order:
   605     (1) beta reduction
   606     (2) unconditional rewrite rules
   607     (3) conditional rewrite rules
   608     (4) simplification procedures
   609 
   610   IMPORTANT: rewrite rules must not introduce new Vars or TVars!
   611 
         The result is None if nothing applies, otherwise Some (thm, skel), where thm
         is the rewrite step as an equation and skel the skeleton for its rhs.
         prover is called as prover mss thm' to solve the premises of a conditional
         rule instance; signt is the signature used for checking rules and simprocs;
         maxt = ~1 means that the rules need no index shifting, any other value makes
         rule and elhs be shifted by maxt+1 before matching.
   612 *)
   613 
   614 fun rewritec (prover, signt, maxt)
   615              (mss as Mss{rules, procs, termless, prems, congs, depth,...}) t =
   616   let
   617     val eta_thm = Thm.eta_conversion t;
   618     val eta_t' = rhs_of eta_thm;
   619     val eta_t = term_of eta_t';
   620     val tsigt = Sign.tsig_of signt;
           (*try a single rule; Pattern.MATCH escapes from here and is caught in rews*)
   621     fun rew {thm, name, lhs, elhs, fo, perm} =
   622       let
   623         val {sign, prop, maxidx, ...} = rep_thm thm;
   624         val _ = if Sign.subsig (sign, signt) then ()
   625                 else (prthm true "Ignoring rewrite rule from different theory:" thm;
   626                       raise Pattern.MATCH);
   627         val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
   628           else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
   629         val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
   630                           else Thm.cterm_match (elhs', eta_t');
   631         val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
   632         val prop' = Thm.prop_of thm';
   633         val unconditional = (Logic.count_prems (prop',0) = 0);
   634         val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
   635       in
               (*a permutative rule is applied only if termless (rhs', lhs') holds for the instance*)
   636         if perm andalso not (termless (rhs', lhs'))
   637         then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
   638               trace_thm "Term does not become smaller:" thm'; None)
   639         else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
   640            if unconditional
   641            then
   642              (trace_thm "Rewriting:" thm';
                    (*the skeleton is computed from the uninstantiated rule (prop, not prop')*)
   643               let val lr = Logic.dest_equals prop;
                        (*NOTE(review): refutable binding -- if check_conv returned None
                          (its THM handler) this would raise Bind; confirm that this
                          cannot happen for a matching unconditional rule*)
   644                   val Some thm'' = check_conv false eta_thm thm'
   645               in Some (thm'', uncond_skel (congs, lr)) end)
   646            else
   647              (trace_thm "Trying to rewrite:" thm';
   648               case incr_depth mss of
   649                 None => (trace_thm "FAILED - reached depth limit" thm'; None)
   650               | Some mss =>
   651               (case prover mss thm' of
   652                 None       => (trace_thm "FAILED" thm'; None)
   653               | Some thm2 =>
   654                   (case check_conv true eta_thm thm2 of
   655                      None => None |
   656                      Some thm2' =>
   657                        let val concl = Logic.strip_imp_concl prop
   658                            val lr = Logic.dest_equals concl
   659                        in Some (thm2', cond_skel (congs, lr)) end))))
   660       end
   661 
           (*first rule in the list that applies; a failed match just moves on*)
   662     fun rews [] = None
   663       | rews (rrule :: rrules) =
   664           let val opt = rew rrule handle Pattern.MATCH => None
   665           in case opt of None => rews rrules | some => some end;
   666 
           (*rules whose proposition is a plain meta-equality (no premises) come before
             all others; the accumulators are built by consing, so the original order
             within each of the two groups is reversed*)
   667     fun sort_rrules rrs = let
   668       fun is_simple({thm, ...}:rrule) = case Thm.prop_of thm of
   669                                       Const("==",_) $ _ $ _ => true
   670                                       | _                   => false
   671       fun sort []        (re1,re2) = re1 @ re2
   672         | sort (rr::rrs) (re1,re2) = if is_simple rr
   673                                      then sort rrs (rr::re1,re2)
   674                                      else sort rrs (re1,rr::re2)
   675     in sort rrs ([],[]) end
   676 
           (*Try the simprocs in turn.  An exception raised by a procedure passes through
             transform_failure (curry SIMPROC_FAIL name); a theorem it returns is turned
             into a rule by mk_procrule and applied via rews like any other rule.
             NOTE(review): the pattern is matched against t although the net lookup and
             the procedure itself use eta_t -- confirm that this is intended.*)
   677     fun proc_rews ([]:simproc list) = None
   678       | proc_rews ({name, proc, lhs, ...} :: ps) =
   679           if Pattern.matches tsigt (term_of lhs, term_of t) then
   680             (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
   681              case transform_failure (curry SIMPROC_FAIL name)
   682                  (fn () => proc signt prems eta_t) () of
   683                None => (debug false "FAILED"; proc_rews ps)
   684              | Some raw_thm =>
   685                  (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
   686                   (case rews (mk_procrule raw_thm) of
   687                     None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
   688                       " -- does not match") t; proc_rews ps)
   689                   | some => some)))
   690           else proc_rews ps;
   691   in case eta_t of
            (*(1) a beta redex at the top is contracted first, with the dummy skeleton*)
   692        Abs _ $ _ => Some (transitive eta_thm
   693          (beta_conversion false eta_t'), skel0)
            (*(2),(3) rewrite rules from the net; (4) simprocs only if no rule applied*)
   694      | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
   695                None => proc_rews (Net.match_term procs eta_t)
   696              | some => some)
   697   end;
   698 
   699 
   700 (* conversion to apply a congruence rule to a term *)
       (*The rule is matched against t and instantiated, and prover is asked to solve
         the instance.  The result is None if the prover fails, if check_conv rejects
         the proved theorem, or if both sides of the resulting equation are aconv
         (nothing was changed); otherwise Some of the resulting equation.
         A congruence rule from a signature that is not a subsignature of signt is an error.*)
   701 
   702 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   703   let val sign = Thm.sign_of_thm cong
   704       val _ = if Sign.subsig (sign, signt) then ()
   705                  else error("Congruence rule from different theory")
             (*maxt = ~1: no index shifting needed, as in rewritec*)
   706       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   707       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   708       val insts = Thm.cterm_match (rlhs, t)
   709       (* Pattern.match can raise Pattern.MATCH;
   710          is handled when congc is called *)
   711       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   712       val unit = trace_thm "Applying congruence rule:" thm';
   713       fun err (msg, thm) = (trace_thm msg thm; None)
   714   in case prover thm' of
   715        None => err ("Congruence proof failed.  Could not prove", thm')
   716      | Some thm2 => (case check_conv true (beta_eta_conversion t) thm2 of
   717           None => err ("Congruence proof failed.  Should not have proved", thm2)
   718         | Some thm2' =>
   719             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   720             then None else Some thm2')
   721   end;
   722 
   723 val (cA, (cB, cC)) =   (*cterms taken from the first premise of Drule.imp_cong; instantiated in disch*)
   724   (apsnd dest_equals o dest_implies o hd o cprems_of) Drule.imp_cong;
   725 
   726 fun transitive1 opt1 opt2 =   (*chain two optional equations; None acts as the identity*)
   727   (case (opt1, opt2) of
   728      (Some eq1, Some eq2) => Some (transitive eq1 eq2)
   729    | (Some _, None) => opt1
          | (None, _) => opt2)
   730 
   731 fun transitive2 eq1 opt2 = transitive1 (Some eq1) opt2;   (*transitive1 with a definite first equation*)
   732 fun transitive3 opt1 eq2 = transitive1 opt1 (Some eq2);   (*transitive1 with a definite second equation*)
   733 
   734 fun imp_cong' eq_prem =   (*combine refl_implies with a premise equation, then with a conclusion equation*)
         let val imp_prem = combination refl_implies eq_prem   (*computed as soon as the first argument arrives*)
         in combination imp_prem end;
   735 
   736 fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) =
       (* Bottom-up rewriting.  The flags simprem, useprem and mutsimp only influence how
          meta-implications are treated (see impc and nonmut_impc); prover, sign and maxidx
          are passed on to rewritec and congc.  The value of bottomc is try_botc, which maps
          a simpset mss and a cterm t to an equation theorem for t.
          All local functions are mutually recursive.  Those returning an option yield
          None when nothing was rewritten. *)
   737   let
       (* botc: first simplify the subterms of t (subc), then rewrite at the root (rewritec)
          and keep simplifying the result.  A skeleton that is a Var stops the descent:
          None is returned without inspecting t. *)
   738     fun botc skel mss t =
   739           if is_Var skel then None
   740           else
   741           (case subc skel mss t of
   742              some as Some thm1 =>
   743                (case rewritec (prover, sign, maxidx) mss (rhs_of thm1) of
   744                   Some (thm2, skel2) =>
   745                     transitive2 (transitive thm1 thm2)
   746                       (botc skel2 mss (rhs_of thm2))
   747                 | None => some)
   748            | None =>
   749                (case rewritec (prover, sign, maxidx) mss t of
   750                   Some (thm2, skel2) => transitive2 thm2
   751                     (botc skel2 mss (rhs_of thm2))
   752                 | None => None))
   753 
       (* try_botc: total version of botc, falling back to the reflexive equation for t *)
   754     and try_botc mss t =
   755           (case botc skel0 mss t of
   756              Some trec1 => trec1 | None => (reflexive t))
   757 
       (* subc: simplify the immediate subterms of t0, dispatching on the shape of the term *)
   758     and subc skel
   759           (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
   760        (case term_of t0 of
       (* abstraction: simplify the body under a fresh bound name, then re-abstract *)
   761            Abs (a, T, t) =>
   762              let val b = variant bounds a
   763                  val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
   764                  val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
   765                  val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
   766              in case botc skel' mss' t' of
   767                   Some thm => Some (abstract_rule a v thm)
   768                 | None => None
   769              end
   770          | t $ _ => (case t of
       (* meta-implication: handled by impc *)
   771              Const ("==>", _) $ _  => impc t0 mss
       (* beta-redex: contract it first, then continue on the result *)
   772            | Abs _ =>
   773                let val thm = beta_conversion false t0
   774                in case subc skel0 mss (rhs_of thm) of
   775                     None => Some thm
   776                   | Some thm' => Some (transitive thm thm')
   777                end
   778            | _  =>
       (* appc: default treatment of an application; simplify function and argument
          separately and join the results with combination *)
   779                let fun appc () =
   780                      let
   781                        val (tskel, uskel) = case skel of
   782                            tskel $ uskel => (tskel, uskel)
   783                          | _ => (skel0, skel0);
   784                        val (ct, cu) = Thm.dest_comb t0
   785                      in
   786                      (case botc tskel mss ct of
   787                         Some thm1 =>
   788                           (case botc uskel mss cu of
   789                              Some thm2 => Some (combination thm1 thm2)
   790                            | None => Some (combination thm1 (reflexive cu)))
   791                       | None =>
   792                           (case botc uskel mss cu of
   793                              Some thm1 => Some (combination (reflexive ct) thm1)
   794                            | None => None))
   795                      end
   796                    val (h, ts) = strip_comb t
       (* use a congruence rule registered under the name of the head h, if there is one *)
   797                in case cong_name h of
   798                     Some a =>
   799                       (case assoc_string (fst congs, a) of
   800                          None => appc ()
   801                        | Some cong =>
   802 (* post processing: some partial applications h t1 ... tj, j <= length ts,
   803    may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
   804                           (let
   805                              val thm = congc (prover mss, sign, maxidx) cong t0;
   806                              val t = if_none (apsome rhs_of thm) t0;
   807                              val (cl, cr) = Thm.dest_comb t
       (* skeleton with Var placeholders for the arguments: botc returns None on them,
          so only the partial applications themselves are tried *)
   808                              val dVar = Var(("", 0), dummyT)
   809                              val skel =
   810                                list_comb (h, replicate (length ts) dVar)
   811                            in case botc skel mss cl of
   812                                 None => thm
   813                               | Some thm' => transitive3 thm
   814                                   (combination thm' (reflexive cr))
       (* if the rule does not match (Pattern.MATCH), fall back to appc *)
   815                            end handle TERM _ => error "congc result"
   816                                     | Pattern.MATCH => appc ()))
   817                   | _ => appc ()
   818                end)
   819          | _ => None)
   820 
       (* impc: simplify a meta-implication, choosing the strategy by the mutsimp flag *)
   821     and impc ct mss =
   822       if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss
   823 
       (* rules_of_prem: rewrite rules extracted from the assumed premise, paired with the
          assumption itself.  A premise whose maxidx is not ~1 is rejected with a trace
          message and yields no rules and no assumption. *)
   824     and rules_of_prem mss prem =
   825       if maxidx_of_term (term_of prem) <> ~1
   826       then (trace_cterm true
   827         "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
   828       else
   829         let val asm = assume prem
   830         in (extract_safe_rrules (mss, asm), Some asm) end
   831 
       (* add_rrules: insert all rules of rrss into mss and record the assumptions that are
          present (the Some entries of asms) via add_prems *)
   832     and add_rrules (rrss, asms) mss =
   833       add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)
   834 
       (* disch: from an equation eq (lhs == rhs) build an equation between prem ==> lhs and
          prem ==> rhs by instantiating Drule.imp_cong.  With r = true both sides of eq must
          themselves be implications, and Drule.swap_prems_eq is applied on either side;
          presumably this moves prem behind the first premise of lhs and rhs -- confirm
          against the statement of swap_prems_eq. *)
   835     and disch r (prem, eq) =
   836       let
   837         val (lhs, rhs) = dest_eq eq;
   838         val eq' = implies_elim (Thm.instantiate
   839           ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
   840           (implies_intr prem eq)
   841       in if not r then eq' else
   842         let
   843           val (prem', concl) = dest_implies lhs;
   844           val (prem'', _) = dest_implies rhs
   845         in transitive (transitive
   846           (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
   847              Drule.swap_prems_eq) eq')
   848           (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
   849              Drule.swap_prems_eq)
   850         end
   851       end
   852 
       (* rebuild: called by mut_impc with its accumulated (reversed) lists.  Premise by
          premise it forms prem ==> conclusion and tries rewritec at the root, using the
          rules of the premises still in the list.  When a rewrite succeeds, processing
          restarts via mut_impc0 on the rewritten term; otherwise the equation collected so
          far is discharged over prem and the next premise is considered. *)
   853     and rebuild [] _ _ _ _ eq = eq
   854       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
   855           let
   856             val mss' = add_rrules (rev rrss, rev asms) mss;
   857             val concl' =
   858               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
   859             val dprem = apsome (curry (disch false) prem)
   860           in case rewritec (prover, sign, maxidx) mss' concl' of
   861               None => rebuild prems concl' rrss asms mss (dprem eq)
   862             | Some (eq', _) => transitive2 (foldl (disch false o swap)
   863                   (the (transitive3 (dprem eq) eq'), prems))
   864                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
   865           end
   866           
       (* mut_impc0: strip the premises off concl, compute their rules and assumptions, and
          start mut_impc with empty accumulators and changed = k = ~1 *)
   867     and mut_impc0 prems concl rrss asms mss =
   868       let
   869         val prems' = strip_imp_prems concl;
   870         val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
   871       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
   872         (asms @ asms') [] [] [] [] mss ~1 ~1
   873       end
   874  
       (* mut_impc: one pass over the premises.  prems, rrss and asms are still to be visited;
          prems', rrss', asms' and eqns accumulate the visited ones in reverse order.
          changed is set to the number of already visited premises whenever a premise is
          rewritten.  k is decremented for every premise left unchanged, and once it is 0
          no further premise is simplified in this pass.
          First clause (pass finished): join the collected equations; if changed > 0 run
          another pass with k = changed, otherwise simplify the conclusion under all
          premise rules and hand over to rebuild. *)
   875     and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
   876         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
   877             (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
   878           (if changed > 0 then
   879              mut_impc (rev prems') concl (rev rrss') (rev asms')
   880                [] [] [] [] mss ~1 changed
   881            else rebuild prems' concl rrss' asms' mss
   882              (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))
   883 
       (* Second clause: simplify the current premise with the rules of all the other
          premises (visited and pending), but not with its own rules rrs *)
   884       | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
   885           prems' rrss' asms' eqns mss changed k =
   886         case (if k = 0 then None else botc skel0 (add_rrules
   887           (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
   888             None => mut_impc prems concl rrss asms (prem :: prems')
   889               (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
   890               (if k = 0 then 0 else k - 1)
   891           | Some eqn =>
   892             let
   893               val prem' = rhs_of eqn;
   894               val tprems = map term_of prems;
       (* i = 1 + the largest position among the pending premises that occurs in the
          hypotheses of eqn; the first i pending premises are discharged below *)
   895               val i = 1 + foldl Int.max (~1, map (fn p =>
   896                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
   897               val (rrs', asm') = rules_of_prem mss prem'
   898             in mut_impc prems concl rrss asms (prem' :: prems')
   899               (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
   900                 (take (i, prems), imp_cong' eqn (reflexive (Drule.list_implies
   901                   (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
   902             end
   903 
   904      (* legacy code - only for backwards compatibility *)
        (* nonmut_impc: for prem ==> conc, simplify prem only if simprem is set, add the rules
           of the (possibly simplified) premise only if useprem is set, then simplify conc
           and assemble the resulting equation for the implication *)
   905      and nonmut_impc ct mss =
   906        let val (prem, conc) = dest_implies ct;
   907            val thm1 = if simprem then botc skel0 mss prem else None;
   908            val prem1 = if_none (apsome rhs_of thm1) prem;
   909            val mss1 = if not useprem then mss else add_rrules
   910              (apsnd single (apfst single (rules_of_prem mss prem1))) mss
   911        in (case botc skel0 mss1 conc of
   912            None => (case thm1 of
   913                None => None
   914              | Some thm1' => Some (imp_cong' thm1' (reflexive conc)))
   915          | Some thm2 =>
   916            let val thm2' = disch false (prem1, thm2)
   917            in (case thm1 of
   918                None => Some thm2'
   919              | Some thm1' =>
   920                  Some (transitive (imp_cong' thm1' (reflexive conc)) thm2'))
   921            end)
   922        end
   923 
   924  in try_botc end;
   925 
   926 
   927 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
   928 
   929 (*
   930   Parameters:
   931     mode = (simplify A,
   932             use A in simplifying B,
   933             use prems of B (if B is again a meta-impl.) to simplify A)
   934            when simplifying A ==> B
   935     mss: contains equality theorems of the form [|p1,...|] ==> t==u
   936     prover: how to solve premises in conditional rewrites and congruences
         ct: the cterm to be rewritten

         The mode triple is passed unchanged to bottomc, where its components are named
         simprem, useprem and mutsimp.
   937 *)
   938 
   939 fun rewrite_cterm mode prover mss ct =
       (* signature and maxidx come from ct itself; the field t is bound but not used *)
   940   let val {sign, t, maxidx, ...} = rep_cterm ct
   941       val Mss{depth, ...} = mss
   942   in trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
       (* side effect: the global reference simp_depth is set to the depth stored in mss *)
   943      simp_depth := depth;
   944      bottomc (mode, prover, sign, maxidx) mss ct
   945   end
       (* a THM exception raised during rewriting is turned into an error that prints the
          message and the theorems involved *)
   946   handle THM (s, _, thms) =>
   947     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
   948       Pretty.string_of (Display.pretty_thms thms));
   949 
   950 (*Conversion for [A1,...,An]==>B that applies cv to those Ai for which pred i holds and
         leaves everything else unchanged -- for rewrite_goals_tac*)
   951 fun goals_conv pred cv =
   952   let
   953     fun gconv i ct =
   954       let
   955         val (A, B) = Drule.dest_implies ct;
   956         val eqA = if pred i then cv A else reflexive A;
               val under_prem = imp_cong' eqA;
             in under_prem (gconv (i + 1) B) end
             handle TERM _ => reflexive ct;   (*any TERM raised in the body above ends the descent here*)
         in gconv 1 end;
   957 
   958 (* Apply cv to A in !!x1,...,xn. A, descending through the outer quantifiers *)
   959 fun forall_conv cv ct =
   960   let
   961     val (ct1, ct2) = Thm.dest_comb ct;
   962   in
   963     (case (term_of ct1, term_of ct2) of
   964       (Const ("all", _), Abs (s, _, _)) =>
   965         let
   966           val (v, body) = Thm.dest_abs (Some "@") ct2;
   967           val quant_eq = Thm.reflexive ct1;
   968           val body_eq = Thm.abstract_rule s v (forall_conv cv body);
                in Thm.combination quant_eq body_eq end
           | _ => cv ct)   (*not a quantifier: this is the body*)
         end handle TERM _ => cv ct;   (*on any TERM from the block above, apply cv to ct itself*)
   969 
   970 (*Transform theorem th by the equation that conversion cv yields for its proposition*)
   971 fun fconv_rule cv th =
         let val eq = cv (cprop_of th)
         in equal_elim eq th end;
   972 
   973 (*Conversion rewriting a cterm with thms; with no thms it is just Thm.reflexive*)
   974 fun rewrite_aux prover full thms =
   975   (case thms of
            [] => Thm.reflexive
          | _ => rewrite_cterm (full, false, false) prover (mss_of thms));
   976 
   977 (*Rule rewriting the proposition of a theorem with thms; with no thms it is the identity*)
   978 fun simplify_aux prover full thms =
   979   (case thms of
   980      [] => I
          | _ => fconv_rule (rewrite_cterm (full, false, false) prover (mss_of thms)));
   981 
   982 fun rewrite_thm mode prover mss =   (*rewrite the proposition of a theorem with simpset mss*)
         let val conv = rewrite_cterm mode prover mss
         in fconv_rule conv end;
   983 
   984 (*Rewrite all subgoals of a proof state (a theorem) with thms; unchanged if thms is empty*)
   985 fun rewrite_goals_rule_aux prover thms th =
   986   (case thms of
   987      [] => th
   988    | _ =>
              let
                val mss = mss_of thms;
                val conv = goals_conv (K true) (rewrite_cterm (true, true, false) prover mss);
              in fconv_rule conv th end);
   989 
   990 (*Rewrite only subgoal i of a proof state (a theorem); THM is raised unless 0 < i <= nprems_of thm*)
   991 fun rewrite_goal_rule mode prover mss i thm =
   992   let
   993     val in_range = 1 <= i andalso i <= nprems_of thm;
   994     fun is_selected j = (j = i);
         in
           if not in_range then raise THM ("rewrite_goal_rule", i, [thm])
           else fconv_rule (goals_conv is_selected (rewrite_cterm mode prover mss)) thm
         end;
   995 
   996 
   997 (*plain term rewriting, no proofs: delegates to Pattern.rewrite_term with the decomposed rules*)
   998 fun rewrite_term sg rules procs =
   999   let
           val tsig = Sign.tsig_of sg;
           val rews = map decomp_simp' rules;
         in Pattern.rewrite_term tsig rews procs end;
  1000 
  1001 end;
  1002 
  1003 structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
       (*BASIC_META_SIMPLIFIER lists only trace_simp, debug_simp and simp_depth_limit;
         opening the restricted structure makes exactly these available unqualified*)
  1004 open BasicMetaSimplifier;