src/Pure/meta_simplifier.ML
author skalberg
Wed Jun 30 00:42:59 2004 +0200 (2004-06-30)
changeset 15011 35be762f58f9
parent 15006 107e4dfd3b96
child 15023 0e4689f411d5
permissions -rw-r--r--
Made simplification procedures simpset-aware.
     1 (*  Title:      Pure/meta_simplifier.ML
     2     ID:         $Id$
     3     Author:     Tobias Nipkow and Stefan Berghofer
     4 
     5 Meta-level Simplification.
     6 *)
     7 
     8 infix 4  (* infix status, precedence 4, for the simpset operators declared in AUX_SIMPLIFIER *)
     9   setsubgoaler setloop addloop delloop setSSolver addSSolver setSolver
    10   addSolver addsimps delsimps addeqcongs deleqcongs addcongs delcongs
    11   setmksimps setmkeqTrue setmkcong setmksym settermless addsimprocs delsimprocs;
    12 
    13 signature BASIC_META_SIMPLIFIER =
       (* The three global switches of the meta-level simplifier. *)
    14 sig
    15   val trace_simp: bool ref   (* consulted by the trace functions of the structure *)
    16   val debug_simp: bool ref   (* consulted by the debug functions of the structure *)
    17   val simp_depth_limit: int ref   (* bound that incr_depth compares the mss depth against *)
    18 end;
    19 
    20 signature AUX_SIMPLIFIER =
       (* Simpset-level interface: simprocs, solvers, simpsets, and the infix
          operators (see the infix declaration at the top of the file) that
          return a modified simpset. *)
    21 sig
    22   type meta_simpset
    23   type simpset
    24   type simproc
       (* full_mk_simproc: the procedure also receives the simpset;
          mk_simproc: the procedure does not. *)
    25   val full_mk_simproc: string -> cterm list
    26     -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
    27   val mk_simproc: string -> cterm list
    28     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
    29   type solver
    30   val mk_solver: string -> (thm list -> int -> tactic) -> solver
    31   val empty_ss: simpset
       (* rep_ss exposes the components of a simpset as a record *)
    32   val rep_ss: simpset ->
    33    {mss: meta_simpset,
    34     mk_cong: thm -> thm,
    35     subgoal_tac: simpset -> int -> tactic,
    36     loop_tacs: (string * (int -> tactic)) list,
    37     unsafe_solvers: solver list,
    38     solvers: solver list}
    39   val from_mss: meta_simpset -> simpset
    40   val ss_of            : thm list -> simpset
    41   val print_ss: simpset -> unit
       (* infix operators: each takes a simpset on the left and yields a simpset *)
    42   val setsubgoaler: simpset *  (simpset -> int -> tactic) -> simpset
    43   val setloop:      simpset *             (int -> tactic) -> simpset
    44   val addloop:      simpset *  (string * (int -> tactic)) -> simpset
    45   val delloop:      simpset *   string                    -> simpset
    46   val setSSolver:   simpset * solver -> simpset
    47   val addSSolver:   simpset * solver -> simpset
    48   val setSolver:    simpset * solver -> simpset
    49   val addSolver:    simpset * solver -> simpset
    50   val setmksimps:   simpset * (thm -> thm list) -> simpset
    51   val setmkeqTrue:  simpset * (thm -> thm option) -> simpset
    52   val setmkcong:    simpset * (thm -> thm) -> simpset
    53   val setmksym:     simpset * (thm -> thm option) -> simpset
    54   val settermless:  simpset * (term * term -> bool) -> simpset
    55   val addsimps:     simpset * thm list -> simpset
    56   val delsimps:     simpset * thm list -> simpset
    57   val addeqcongs:   simpset * thm list -> simpset
    58   val deleqcongs:   simpset * thm list -> simpset
    59   val addcongs:     simpset * thm list -> simpset
    60   val delcongs:     simpset * thm list -> simpset
    61   val addsimprocs:  simpset * simproc list -> simpset
    62   val delsimprocs:  simpset * simproc list -> simpset
    63   val merge_ss:     simpset * simpset -> simpset
    64   val prems_of_ss:  simpset -> thm list
    65   val generic_simp_tac: bool -> bool * bool * bool -> simpset -> int -> tactic
       (* simproc / full_simproc take the lhs patterns as strings,
          the _i variants take them as terms *)
    66   val simproc: Sign.sg -> string -> string list
    67     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
    68   val simproc_i: Sign.sg -> string -> term list
    69     -> (Sign.sg -> thm list -> term -> thm option) -> simproc
    70   val full_simproc: Sign.sg -> string -> string list
    71     -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
    72   val full_simproc_i: Sign.sg -> string -> term list
    73     -> (simpset -> Sign.sg -> thm list -> term -> thm option) -> simproc
    74   val clear_ss  : simpset -> simpset
    75   val simp_thm  : bool * bool * bool -> simpset -> thm -> thm
    76   val simp_cterm: bool * bool * bool -> simpset -> cterm -> thm
    77 end;
    78 
    79 signature META_SIMPLIFIER =
       (* Full interface of structure MetaSimplifier: the two signatures above
          plus the operations on meta simpsets and the rewriting entry points. *)
    80 sig
    81   include BASIC_META_SIMPLIFIER
    82   include AUX_SIMPLIFIER
    83   exception SIMPLIFIER of string * thm
    84   exception SIMPROC_FAIL of string * exn
       (* dest_mss: rules, congruences and (name, lhs patterns) of the simprocs *)
    85   val dest_mss          : meta_simpset ->
    86     {simps: thm list, congs: thm list, procs: (string * cterm list) list}
    87   val empty_mss         : meta_simpset
    88   val clear_mss         : meta_simpset -> meta_simpset
    89   val merge_mss         : meta_simpset * meta_simpset -> meta_simpset
    90   val add_simps         : meta_simpset * thm list -> meta_simpset
    91   val del_simps         : meta_simpset * thm list -> meta_simpset
    92   val mss_of            : thm list -> meta_simpset
    93   val add_congs         : meta_simpset * thm list -> meta_simpset
    94   val del_congs         : meta_simpset * thm list -> meta_simpset
       (* simprocs are passed as (name, lhs patterns, procedure, stamp) tuples *)
    95   val add_simprocs      : meta_simpset *
    96     (string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp) list
    97       -> meta_simpset
    98   val del_simprocs      : meta_simpset *
    99     (string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp) list
   100       -> meta_simpset
   101   val add_prems         : meta_simpset * thm list -> meta_simpset
   102   val prems_of_mss      : meta_simpset -> thm list
       (* setters and getters for the three components of mk_rews, and for termless *)
   103   val set_mk_rews       : meta_simpset * (thm -> thm list) -> meta_simpset
   104   val set_mk_sym        : meta_simpset * (thm -> thm option) -> meta_simpset
   105   val set_mk_eq_True    : meta_simpset * (thm -> thm option) -> meta_simpset
   106   val get_mk_rews       : meta_simpset -> thm -> thm list
   107   val get_mk_sym        : meta_simpset -> thm -> thm option
   108   val get_mk_eq_True    : meta_simpset -> thm -> thm option
   109   val set_termless      : meta_simpset * (term * term -> bool) -> meta_simpset
       (* rewriting entry points; their definitions lie beyond this part of the file *)
   110   val rewrite_cterm: bool * bool * bool ->
   111     (meta_simpset -> thm -> thm option) -> simpset -> cterm -> thm
   112   val rewrite_aux       : (meta_simpset -> thm -> thm option) -> bool -> thm list -> cterm -> thm
   113   val simplify_aux      : (meta_simpset -> thm -> thm option) -> bool -> thm list -> thm -> thm
   114   val rewrite_thm       : bool * bool * bool
   115                           -> (meta_simpset -> thm -> thm option)
   116                           -> simpset -> thm -> thm
   117   val rewrite_goals_rule_aux: (meta_simpset -> thm -> thm option) -> thm list -> thm -> thm
   118   val rewrite_goal_rule : bool* bool * bool
   119                           -> (meta_simpset -> thm -> thm option)
   120                           -> simpset -> int -> thm -> thm
   121   val rewrite_term: Sign.sg -> thm list -> (term -> term option) list -> term -> term
   122   val asm_rewrite_goal_tac: bool*bool*bool ->
   123     (meta_simpset -> tactic) -> simpset -> int -> tactic
   124 
   125 end;
   126 
   127 structure MetaSimplifier : META_SIMPLIFIER =
   128 struct
   129 
   130 (** diagnostics **)
   131 
   132 exception SIMPLIFIER of string * thm;  (* error message plus the offending theorem; raised for malformed rewrite and congruence rules *)
   133 exception SIMPROC_FAIL of string * exn;  (* NOTE(review): presumably simproc name plus the exception it raised; no raise site in this part of the file *)
   134 
   135 val simp_depth = ref 0;  (* shown by println as a "[n]" prefix when non-zero *)
   136 val simp_depth_limit = ref 1000;  (* incr_depth gives up once the mss depth would exceed this *)
   137 
   138 local
   139 
   140 fun println a =
       (* Pass message a to `tracing', prefixed by "[n]" when !simp_depth is a non-zero n. *)
   141   tracing ((case ! simp_depth of 0 => "" | n => "[" ^ string_of_int n ^ "]") ^ a);
   142 
   143 fun prnt warn a = if warn then warning a else println a;  (* warn selects `warning' instead of the depth-prefixed trace output *)
       (* prtm / prctm: message a, a newline, then the printed term resp. cterm *)
   144 fun prtm warn a sign t = prnt warn (a ^ "\n" ^ Sign.string_of_term sign t);
   145 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
   146 
   147 in
   148 
   149 fun prthm warn a = prctm warn a o Thm.cprop_of;  (* prthm warn a thm: print message a and the proposition of thm, unconditionally *)
   150 
   151 val trace_simp = ref false;  (* enables trace, trace_term, trace_cterm *)
   152 val debug_simp = ref false;  (* enables debug, debug_term *)
   153 
   154 fun trace warn a = if !trace_simp then prnt warn a else ();  (* print a only while trace_simp is set *)
   155 fun debug warn a = if !debug_simp then prnt warn a else ();  (* print a only while debug_simp is set *)
   156 
   157 fun trace_term warn a sign t = if !trace_simp then prtm warn a sign t else ();  (* term variant of trace *)
   158 fun trace_cterm warn a t = if !trace_simp then prctm warn a t else ();  (* cterm variant of trace *)
   159 fun debug_term warn a sign t = if !debug_simp then prtm warn a sign t else ();  (* term variant of debug *)
   160 
   161 fun trace_thm a thm =
       (* Trace message a together with the proposition of thm; never issued as a warning. *)
   162   let val {sign, prop, ...} = rep_thm thm
   163   in trace_term false a sign prop end;
   164 
   165 fun trace_named_thm a (thm, name) =
       (* Like trace_thm; the message becomes a, then the quoted name if non-empty, then ":". *)
   166   trace_thm (a ^ (if name = "" then "" else " " ^ quote name) ^ ":") thm;
   167 
   168 end;
   169 
   170 
   171 (** meta simp sets **)
   172 
   173 (* basic components *)
   174 
   175 type rrule = {thm: thm, name: string, lhs: term, elhs: cterm, fo: bool, perm: bool};  (* one rewrite rule as stored in the rules net *)
   176 (* thm: the rewrite rule
   177    name: name of theorem from which rewrite rule was extracted
   178    lhs: the left-hand side
   179    elhs: the eta-contracted lhs.
   180    fo:  use first-order matching
   181    perm: the rewrite rule is permutative
   182 Remarks:
   183   - elhs is used for matching,
   184     lhs only for preservation of bound variable names.
   185   - fo is set iff
   186     either elhs is first-order (no Var is applied),
   187            in which case fo-matching is complete,
   188     or elhs is not a pattern,
   189        in which case there is nothing better to do.
   190 *)
   191 type cong = {thm: thm, lhs: cterm};  (* congruence rule and the certified lhs of its conclusion, as built by add_cong *)
   192 
   193 fun eq_rrule ({thm = thm1, ...}: rrule, {thm = thm2, ...}: rrule) =
       (* Rewrite rules count as equal when the props of their theorems are aconv; other fields are ignored. *)
   194   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   195 
   196 fun eq_cong ({thm = thm1, ...}: cong, {thm = thm2, ...}: cong) =
       (* Congruences count as equal when the props of their theorems are aconv. *)
   197   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   198 
   199 fun eq_prem (thm1, thm2) =
       (* Premises count as equal when their props are aconv. *)
   200   #prop (rep_thm thm1) aconv #prop (rep_thm thm2);
   201 
   202 
   203 (* datatype mss *)
   204 
   205 (*
   206   A "mss" contains data needed during conversion:
   207     rules: discrimination net of rewrite rules;
   208     congs: association list of congruence rules and
   209            a list of `weak' congruence constants.
   210            A congruence is `weak' if it avoids normalization of some argument.
   211     procs: discrimination net of simplification procedures
   212       (functions that prove rewrite rules on the fly);
   213     bounds: names of bound variables already used
   214       (for generating new names when rewriting under lambda abstractions);
   215     prems: current premises;
   216     mk_rews: mk: turns simplification thms into rewrite rules;
   217              mk_sym: turns == around; (needs Drule!)
   218              mk_eq_True: turns P into P == True - logic specific;
   219     termless: relation for ordered rewriting;
   220     depth: depth of conditional rewriting;
   221 *)
   222 
   223 datatype solver = Solver of string * (thm list -> int -> tactic) * stamp;  (* name, solver function, and the stamp by which eq_solver identifies it *)
   224 
   225 datatype meta_simpset =
       (* meta_simpset and simpset are declared together because they are mutually
          dependent: a simpset contains an mss, and the simprocs stored in an mss
          (type meta_simproc) take a simpset argument. *)
   226   Mss of {
   227     rules: rrule Net.net,
   228     congs: (string * cong) list * string list,
   229     procs: meta_simproc Net.net,
   230     bounds: string list,
   231     prems: thm list,
   232     mk_rews: {mk: thm -> thm list,
   233               mk_sym: thm -> thm option,
   234               mk_eq_True: thm -> thm option},
   235     termless: term * term -> bool,
   236     depth: int}
   237 and simpset =
   238   Simpset of {
   239     mss: meta_simpset,
   240     mk_cong: thm -> thm,
   241     subgoal_tac: simpset -> int -> tactic,
   242     loop_tacs: (string * (int -> tactic)) list,
   243     unsafe_solvers: solver list,
   244     solvers: solver list}
       (* meta_simproc: one entry per lhs pattern; id is the stamp compared by eq_simproc *)
   245 withtype meta_simproc =
   246  {name: string, proc: simpset -> Sign.sg -> thm list -> term -> thm option, lhs: cterm, id: stamp};
   247 
   248 fun mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth) =
       (* Positional constructor: builds an Mss from its eight components. *)
   249   Mss {rules = rules, congs = congs, procs = procs, bounds = bounds,
   250        prems=prems, mk_rews=mk_rews, termless=termless, depth=depth};
   251 
   252 fun upd_rules(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}, rules') =
       (* Replace the rules net by rules', keeping every other component. *)
   253   mk_mss(rules',congs,procs,bounds,prems,mk_rews,termless,depth);
   254 
   255 val empty_mss =
       (* No rules, congs, procs, bounds or prems; mk yields [] and mk_sym / mk_eq_True
          yield None; ordering is Term.termless; depth is 0. *)
   256   let val mk_rews = {mk = K [], mk_sym = K None, mk_eq_True = K None}
   257   in mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, Term.termless, 0) end;
   258 
   259 fun clear_mss (Mss {mk_rews, termless, ...}) =
       (* Empty all components except mk_rews and termless; depth is reset to 0. *)
   260   mk_mss (Net.empty, ([], []), Net.empty, [], [], mk_rews, termless,0);
   261 
   262 fun incr_depth(Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) =
       (* Returns Some mss with depth increased by one, or None (after a warning)
          if the new depth exceeds !simp_depth_limit. A progress warning is
          issued whenever the new depth is a multiple of 10. *)
   263   let val depth1 = depth+1
   264   in if depth1 > !simp_depth_limit
   265      then (warning "simp_depth_limit exceeded - giving up"; None)
   266      else (if depth1 mod 10 = 0
   267            then warning("Simplification depth " ^ string_of_int depth1)
   268            else ();
   269            Some(mk_mss(rules,congs,procs,bounds,prems,mk_rews,termless,depth1))
   270           )
   271   end;
   272 
   273 datatype simproc =
       (* User-level simproc: name, lhs patterns, simpset-aware procedure, identifying stamp. *)
   274   Simproc of string * cterm list * (simpset -> Sign.sg -> thm list -> term -> thm option) * stamp
   275 
   276 fun eq_simproc ({id = s1, ...}:meta_simproc, {id = s2, ...}:meta_simproc) = (s1 = s2);  (* identity of a meta_simproc is its stamp alone *)
   277 
   278 fun mk_simproc (name, proc, lhs, id) =
       (* Internal: packs a tuple into a meta_simproc record. Used by add_proc and
          del_proc; a later curried definition of the same name shadows this one. *)
   279   {name = name, proc = proc, lhs = lhs, id = id};
   280 
   281 
   282 (** simpset operations **)
   283 
   284 (* term variables *)
   285 
   286 val add_term_varnames = foldl_aterms (fn (xs, Var (x, _)) => ins_ix (x, xs) | (xs, _) => xs);  (* add the index names of the Vars of a term to xs via ins_ix; other atoms are skipped *)
   287 fun term_varnames t = add_term_varnames ([], t);  (* the same, starting from the empty list *)
   288 
   289 
   290 (* dest_mss *)
   291 
   292 fun dest_mss (Mss {rules, congs, procs, ...}) =
       (* simps: the theorems of all entries of the rules net.
          congs: the theorems of the congruence association list.
          procs: net entries are grouped by stamp (partition_eq eq_snd); each group
                 yields the name of its first entry and the lhs of every entry;
                 the result is sorted on the name via Library.sort_wrt. *)
   293   {simps = map (fn (_, {thm, ...}) => thm) (Net.dest rules),
   294    congs = map (fn (_, {thm, ...}) => thm) (fst congs),
   295    procs =
   296      map (fn (_, {name, lhs, id, ...}) => ((name, lhs), id)) (Net.dest procs)
   297      |> partition_eq eq_snd
   298      |> map (fn ps => (#1 (#1 (hd ps)), map (#2 o #1) ps))
   299      |> Library.sort_wrt #1};
   300 
   301 
   302 (* merge_mss *)       (*NOTE: ignores mk_rews, termless and depth of 2nd mss*)
   303 
   304 fun merge_mss
       (* Merge rules, congs, weak constants, procs, bounds and prems of both
          arguments, each with its own equality; mk_rews, termless and depth
          are taken from the first argument only. *)
   305  (Mss {rules = rules1, congs = (congs1,weak1), procs = procs1,
   306        bounds = bounds1, prems = prems1, mk_rews, termless, depth},
   307   Mss {rules = rules2, congs = (congs2,weak2), procs = procs2,
   308        bounds = bounds2, prems = prems2, ...}) =
   309       mk_mss
   310        (Net.merge (rules1, rules2, eq_rrule),
   311         (gen_merge_lists (eq_cong o pairself snd) congs1 congs2,
   312         merge_lists weak1 weak2),
   313         Net.merge (procs1, procs2, eq_simproc),
   314         merge_lists bounds1 bounds2,
   315         gen_merge_lists eq_prem prems1 prems2,
   316         mk_rews, termless, depth);
   317 
   318 
   319 (* add_simps *)
   320 
   321 fun mk_rrule2{thm, name, lhs, elhs, perm} =
       (* Complete a rule record with its fo flag: set when elhs is first-order
          or when elhs is not a pattern. *)
   322   let val fo = Pattern.first_order (term_of elhs) orelse not(Pattern.pattern (term_of elhs))
   323   in {thm=thm, name=name, lhs=lhs, elhs=elhs, fo=fo, perm=perm} end
   324 
   325 fun insert_rrule quiet (mss as Mss {rules,...},
       (* Trace the rule, complete it with mk_rrule2 and insert it into the rules
          net under the term of its elhs. If Net.INSERT is raised the mss is
          returned unchanged; a duplicate warning is printed unless quiet. *)
   326                  rrule as {thm,name,lhs,elhs,perm}) =
   327   (trace_named_thm "Adding rewrite rule" (thm, name);
   328    let val rrule2 as {elhs,...} = mk_rrule2 rrule
   329        val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule)
   330    in upd_rules(mss,rules') end
   331    handle Net.INSERT => if quiet then mss else
   332      (prthm true "Ignoring duplicate rewrite rule:" thm; mss));
   333 
   334 fun vperm (Var _, Var _) = true  (* any Var is accepted against any Var *)
   335   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)  (* abstractions: compare bodies only *)
   336   | vperm (t1 $ t2, u1 $ u2) = vperm (t1, u1) andalso vperm (t2, u2)
   337   | vperm (t, u) = (t = u);  (* everything else must be identical *)
   338 
   339 fun var_perm (t, u) =
       (* t and u have the same shape up to Vars (vperm) and the same set of variable names. *)
   340   vperm (t, u) andalso eq_set (term_varnames t, term_varnames u);
   341 
   342 (* FIXME: it seems that the conditions on extra variables are too liberal if
   343 prems are nonempty: does solving the prems really guarantee instantiation of
   344 all its Vars? Better: a dynamic check each time a rule is applied.
   345 *)
   346 fun rewrite_rule_extra_vars prems elhs erhs =
       (* True if erhs mentions a term variable that occurs neither in elhs nor in
          prems, or a type variable that occurs neither in elhs nor in prems. *)
   347   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   348   orelse
   349   not ((term_tvars erhs) subset
   350        (term_tvars elhs  union  List.concat(map term_tvars prems)));
   351 
   352 (*Simple test for looping rewrite rules and stupid orientations*)
       (* reorient sign prems lhs rhs is true if any of these holds:
            - rhs has extra variables (rewrite_rule_extra_vars),
            - the head of lhs is a Var,
            - Logic.occs holds for lhs paired with rhs or with one of the prems,
            - there are no prems and Pattern.matches holds for (lhs, rhs),
            - lhs is a constant but rhs is not. *)
   353 fun reorient sign prems lhs rhs =
   354    rewrite_rule_extra_vars prems lhs rhs
   355   orelse
   356    is_Var (head_of lhs)
   357   orelse
   358    (exists (apl (lhs, Logic.occs)) (rhs :: prems))
   359   orelse
   360    (null prems andalso
   361     Pattern.matches (Sign.tsig_of sign) (lhs, rhs))
   362     (*the condition "null prems" is necessary because conditional rewrites
   363       with extra variables in the conditions may terminate although
   364       the rhs is an instance of the lhs. Example: ?m < ?n ==> f(?n) == f(?m)*)
   365   orelse
   366    (is_Const lhs andalso not(is_Const rhs))
   367 
   368 fun decomp_simp thm =
       (* Split a rewrite rule into (sign, prems, lhs term, elhs cterm, rhs term, perm).
          Raises SIMPLIFIER if the conclusion is not a meta-equality.
          perm is set when elhs and the eta-contracted rhs are variable
          permutations of each other (var_perm), are not aconv, and elhs is
          not itself a Var. *)
   369   let val {sign, prop, ...} = rep_thm thm;
   370       val prems = Logic.strip_imp_prems prop;
   371       val concl = Drule.strip_imp_concl (cprop_of thm);
   372       val (lhs, rhs) = Drule.dest_equals concl handle TERM _ =>
   373         raise SIMPLIFIER ("Rewrite rule not a meta-equality", thm)
   374       val elhs = snd (Drule.dest_equals (cprop_of (Thm.eta_conversion lhs)));
   375       val elhs = if elhs=lhs then lhs else elhs (* try to share *)
   376       val erhs = Pattern.eta_contract (term_of rhs);
   377       val perm = var_perm (term_of elhs, erhs) andalso not (term_of elhs aconv erhs)
   378                  andalso not (is_Var (term_of elhs))
   379   in (sign, prems, term_of lhs, elhs, term_of rhs, perm) end;
   380 
   381 fun decomp_simp' thm =
       (* (lhs, rhs) of an unconditional rewrite rule. decomp_simp runs first, so
          its meta-equality error takes precedence; a rule with premises then
          raises SIMPLIFIER "Bad conditional rewrite rule". *)
   382   let val (_, _, lhs, _, rhs, _) = decomp_simp thm in
   383     if Thm.nprems_of thm > 0 then raise SIMPLIFIER ("Bad conditional rewrite rule", thm)
   384     else (lhs, rhs)
   385   end;
   386 
   387 fun mk_eq_True (Mss{mk_rews={mk_eq_True,...},...}) (thm, name) =
       (* Apply the mss's mk_eq_True to thm: no rules on None, otherwise a single
          non-permutative rule record built from the resulting theorem. *)
   388   case mk_eq_True thm of
   389     None => []
   390   | Some eq_True =>
   391       let val (_,_,lhs,elhs,_,_) = decomp_simp eq_True
   392       in [{thm=eq_True, name=name, lhs=lhs, elhs=elhs, perm=false}] end;
   393 
   394 (* create the rewrite rule and possibly also the ==True variant,
   395    in case there are extra vars on the rhs *)
       (* thm is the rule to store; thm2 is the theorem handed to mk_eq_True when
          rhs has term or type variables that lhs lacks. The ==True variant, if
          any, comes first in the result list. *)
   396 fun rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm2) =
   397   let val rrule = {thm=thm, name=name, lhs=lhs, elhs=elhs, perm=false}
   398   in if (term_varnames rhs)  subset (term_varnames lhs) andalso
   399         (term_tvars rhs) subset (term_tvars lhs)
   400      then [rrule]
   401      else mk_eq_True mss (thm2, name) @ [rrule]
   402   end;
   403 
   404 fun mk_rrule mss (thm, name) =
       (* Turn thm into rule records without reorienting it: a permutative rule
          is kept as is; a rule with extra variables or a Var as elhs is only
          offered to mk_eq_True; anything else goes through rrule_eq_True. *)
   405   let val (_,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   406   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}] else
   407      (* weak test for loops: *)
   408      if rewrite_rule_extra_vars prems lhs rhs orelse
   409         is_Var (term_of elhs)
   410      then mk_eq_True mss (thm, name)
   411      else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   412   end;
   413 
   414 fun orient_rrule mss (thm, name) =
       (* Like mk_rrule, but may turn the rule around. If reorient rejects
          lhs == rhs and also rhs == lhs, only the ==True variant is tried.
          If just the given orientation is rejected, the mss's mk_sym produces
          the symmetric theorem (no rules if it returns None). *)
   415   let val (sign,prems,lhs,elhs,rhs,perm) = decomp_simp thm
   416   in if perm then [{thm=thm, name=name, lhs=lhs, elhs=elhs, perm=true}]
   417      else if reorient sign prems lhs rhs
   418           then if reorient sign prems rhs lhs
   419                then mk_eq_True mss (thm, name)
   420                else let val Mss{mk_rews={mk_sym,...},...} = mss
   421                     in case mk_sym thm of
   422                          None => []
   423                        | Some thm' =>
   424                            let val (_,_,lhs',elhs',rhs',_) = decomp_simp thm'
   425                            in rrule_eq_True(thm',name,lhs',elhs',rhs',mss,thm) end
   426                     end
   427           else rrule_eq_True(thm,name,lhs,elhs,rhs,mss,thm)
   428   end;
   429 
   430 fun extract_rews(Mss{mk_rews = {mk,...},...},thms) =
       (* Apply the mss's mk to every theorem; each result is paired with the
          name of the theorem it came from. *)
   431   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   432 
   433 fun orient_comb_simps comb mk_rrule (mss,thms) =
       (* Extract rewrites from thms, turn each into rules with the given
          mk_rrule, then fold comb over the rules starting from mss. *)
   434   let val rews = extract_rews(mss,thms)
   435       val rrules = flat (map mk_rrule rews)
   436   in foldl comb (mss,rrules) end
   437 
   438 (* Add rewrite rules explicitly; do not reorient! *)
       (* Duplicates are reported, since insert_rrule is called with quiet = false. *)
   439 fun add_simps(mss,thms) =
   440   orient_comb_simps (insert_rrule false) (mk_rrule mss) (mss,thms);
   441 
   442 fun mss_of thms = foldl (insert_rrule false) (empty_mss, flat  (* build an mss from thms directly: mk_rrule on each theorem, without going through mk *)
   443   (map (fn thm => mk_rrule empty_mss (thm, Thm.name_of_thm thm)) thms));
   444 
   445 fun extract_safe_rrules(mss,thm) =
       (* Rules obtained from a single theorem via mk, each passed through orient_rrule. *)
   446   flat (map (orient_rrule mss) (extract_rews(mss,[thm])));
   447 
   448 (* del_simps *)
   449 
   450 fun del_rrule(mss as Mss {rules,...},
       (* Delete the rule from the rules net; if Net.DELETE is raised, warn and
          return the mss unchanged. *)
   451               rrule as {thm, elhs, ...}) =
   452   (upd_rules(mss, Net.delete_term ((term_of elhs, rrule), rules, eq_rrule))
   453    handle Net.DELETE =>
   454      (prthm true "Rewrite rule not in simpset:" thm; mss));
   455 
   456 fun del_simps(mss,thms) =
       (* Counterpart of add_simps; mk_rrule2 is applied so that del_rrule receives
          complete rrule records (including fo). *)
   457   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule mss) (mss,thms);
   458 
   459 
   460 (* add_congs *)
   461 
   462 fun is_full_cong_prems [] varpairs = null varpairs  (* every pair must have been consumed *)
   463   | is_full_cong_prems (p::prems) varpairs =
       (* The conclusion of p must be x xs == y xs, where x is a Var, xs are
          distinct Bounds, and (x,y) is among the remaining varpairs; the pair
          is then removed for the rest of the premises. *)
   464     (case Logic.strip_assums_concl p of
   465        Const("==",_) $ lhs $ rhs =>
   466          let val (x,xs) = strip_comb lhs and (y,ys) = strip_comb rhs
   467          in is_Var x  andalso  forall is_Bound xs  andalso
   468             null(findrep(xs))  andalso xs=ys andalso
   469             (x,y) mem varpairs andalso
   470             is_full_cong_prems prems (varpairs\(x,y))
   471          end
   472      | _ => false);
   473 
   474 fun is_full_cong thm =
       (* The conclusion must be f xs == f ys with equally many arguments and no
          repetition among xs @ ys, and the premises must relate the argument
          pairs as checked by is_full_cong_prems. *)
   475 let val prems = prems_of thm
   476     and concl = concl_of thm
   477     val (lhs,rhs) = Logic.dest_equals concl
   478     val (f,xs) = strip_comb lhs
   479     and (g,ys) = strip_comb rhs
   480 in
   481   f=g andalso null(findrep(xs@ys)) andalso length xs = length ys andalso
   482   is_full_cong_prems prems (xs ~~ ys)
   483 end
   484 
   485 fun cong_name (Const (a, _)) = Some a  (* key under which a congruence is stored: the constant name ... *)
   486   | cong_name (Free (a, _)) = Some ("Free: " ^ a)  (* ... or the free variable name with a "Free: " prefix *)
   487   | cong_name _ = None;
   488 
   489 fun add_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
       (* Register thm as the congruence rule for the head of its conclusion's lhs.
          Raises SIMPLIFIER if the conclusion is not a meta-equality or the head
          is neither a constant nor a free variable. The entry is stored with
          overwrite_warn under that key; the key is added to the weak list
          unless is_full_cong holds for thm. *)
   490   let
   491     val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (cprop_of thm)) handle TERM _ =>
   492       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   493 (*   val lhs = Pattern.eta_contract lhs; *)
   494     val a = (case cong_name (head_of (term_of lhs)) of
   495         Some a => a
   496       | None =>
   497         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   498     val (alist,weak) = congs
   499     val alist2 = overwrite_warn (alist, (a,{lhs=lhs, thm=thm}))
   500            ("Overwriting congruence rule for " ^ quote a);
   501     val weak2 = if is_full_cong thm then weak else a::weak
   502   in
   503     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   504   end;
   505 
   506 val (op add_congs) = foldl add_cong;  (* add a list of congruence rules, left to right *)
   507 
   508 
   509 (* del_congs *)
   510 
   511 fun del_cong (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thm) =
       (* Remove every congruence stored under the head of thm's conclusion lhs.
          Raises SIMPLIFIER if the conclusion is not a meta-equality or the head
          is neither a constant nor a free variable (cong_name accepts both,
          hence the error text matches the one in add_cong). The weak list is
          recomputed from the remaining entries instead of being edited. *)
   512   let
   513     val (lhs, _) = Logic.dest_equals (concl_of thm) handle TERM _ =>
   514       raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   515 (*   val lhs = Pattern.eta_contract lhs; *)
   516     val a = (case cong_name (head_of lhs) of
   517         Some a => a
   518       | None =>
   519         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm));
   520     val (alist,_) = congs
   521     val alist2 = filter (fn (x,_)=> x<>a) alist
   522     val weak2 = mapfilter (fn(a,{thm,...}) => if is_full_cong thm then None
   523                                               else Some a)
   524                    alist2
   525   in
   526     mk_mss (rules,(alist2,weak2),procs,bounds,prems,mk_rews,termless,depth)
   527   end;
   528 
   529 val (op del_congs) = foldl del_cong;  (* delete a list of congruence rules, left to right *)
   530 
   531 
   532 (* add_simprocs *)
   533 
   534 fun add_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
       (* Insert one (name, lhs, proc, id) entry into the procs net under the term
          of lhs, tracing the addition. If Net.INSERT is raised, a warning is
          issued and the procs net is left as it was. *)
   535     (name, lhs, proc, id)) =
   536   let val {sign, t, ...} = rep_cterm lhs
   537   in (trace_term false ("Adding simplification procedure " ^ quote name ^ " for")
   538       sign t;
   539     mk_mss (rules, congs,
   540       Net.insert_term ((t, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   541         handle Net.INSERT =>
   542             (warning ("Ignoring duplicate simplification procedure \""
   543                       ^ name ^ "\"");
   544              procs),
   545         bounds, prems, mk_rews, termless,depth))
   546   end;
   547 
   548 fun add_simproc (mss, (name, lhss, proc, id)) =
       (* One add_proc call per lhs pattern; all entries share name, proc and stamp. *)
   549   foldl add_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   550 
   551 val add_simprocs = foldl add_simproc;  (* add a list of simproc tuples, left to right *)
   552 
   553 
   554 (* del_simprocs *)
   555 
   556 fun del_proc (mss as Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth},
       (* Delete one entry from the procs net (entries are compared by stamp via
          eq_simproc). If Net.DELETE is raised, a warning is issued and the
          procs net is left as it was. *)
   557     (name, lhs, proc, id)) =
   558   mk_mss (rules, congs,
   559     Net.delete_term ((term_of lhs, mk_simproc (name, proc, lhs, id)), procs, eq_simproc)
   560       handle Net.DELETE =>
   561           (warning ("Simplification procedure \"" ^ name ^
   562                        "\" not in simpset"); procs),
   563       bounds, prems, mk_rews, termless, depth);
   564 
   565 fun del_simproc (mss, (name, lhss, proc, id)) =
       (* One del_proc call per lhs pattern of the simproc. *)
   566   foldl del_proc (mss, map (fn lhs => (name, lhs, proc, id)) lhss);
   567 
   568 val del_simprocs = foldl del_simproc;  (* delete a list of simproc tuples, left to right *)
   569 
   570 
   571 (* prems *)
   572 
   573 fun add_prems (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, thms) =
       (* New premises are put in front of the existing ones; no duplicate check is made here. *)
   574   mk_mss (rules, congs, procs, bounds, thms @ prems, mk_rews, termless, depth);
   575 
   576 fun prems_of_mss (Mss {prems, ...}) = prems;  (* the current premises of an mss *)
   577 
   578 
   579 (* mk_rews *)
   580 
   581 fun set_mk_rews
       (* Replace the mk component of mk_rews; mk_sym and mk_eq_True are kept. *)
   582   (Mss {rules, congs, procs, bounds, prems, mk_rews, termless, depth}, mk) =
   583     mk_mss (rules, congs, procs, bounds, prems,
   584             {mk=mk, mk_sym= #mk_sym mk_rews, mk_eq_True= #mk_eq_True mk_rews},
   585             termless, depth);
   586 
   587 fun set_mk_sym
       (* Replace the mk_sym component of mk_rews; mk and mk_eq_True are kept. *)
   588   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_sym) =
   589     mk_mss (rules, congs, procs, bounds, prems,
   590             {mk= #mk mk_rews, mk_sym= mk_sym, mk_eq_True= #mk_eq_True mk_rews},
   591             termless,depth);
   592 
   593 fun set_mk_eq_True
       (* Replace the mk_eq_True component of mk_rews; mk and mk_sym are kept. *)
   594   (Mss {rules,congs,procs,bounds,prems,mk_rews,termless,depth}, mk_eq_True) =
   595     mk_mss (rules, congs, procs, bounds, prems,
   596             {mk= #mk mk_rews, mk_sym= #mk_sym mk_rews, mk_eq_True= mk_eq_True},
   597             termless,depth);
   598 
   599 fun get_mk_rews    (Mss {mk_rews,...}) = #mk         mk_rews  (* selectors for the three components of mk_rews *)
   600 fun get_mk_sym     (Mss {mk_rews,...}) = #mk_sym     mk_rews
   601 fun get_mk_eq_True (Mss {mk_rews,...}) = #mk_eq_True mk_rews
   602 
   603 (* termless *)
   604 
   605 fun set_termless
       (* Replace the term ordering, keeping every other component. *)
   606   (Mss {rules, congs, procs, bounds, prems, mk_rews, depth, ...}, termless) =
   607     mk_mss (rules, congs, procs, bounds, prems, mk_rews, termless, depth);
   608 
   609 
   610 
   611 (** simplification procedures **)
   612 
   613 (* datatype simproc *)
   614 
   615 fun full_mk_simproc name lhss proc =
       (* Build a simproc with a fresh stamp; Logic.varify is applied to each lhs
          pattern. proc receives the simpset as its first argument. *)
   616   Simproc (name, map (Thm.cterm_fun Logic.varify) lhss, proc, stamp ());
   617 
   618 fun full_simproc sg name ss =
       (* lhs patterns given as strings, read with Thm.read_cterm at TypeInfer.logicT *)
   619   full_mk_simproc name (map (fn s => Thm.read_cterm sg (s, TypeInfer.logicT)) ss);
   620 fun full_simproc_i sg name = full_mk_simproc name o map (Thm.cterm_of sg);  (* lhs patterns given as terms, certified with sg *)
   621 
   622 fun mk_simproc name lhss proc =
       (* As full_mk_simproc, for a procedure that ignores the simpset: K proc
          discards the simpset argument. From here on this definition shadows
          the earlier tuple-based mk_simproc. *)
   623   Simproc (name, map (Thm.cterm_fun Logic.varify) lhss, K proc, stamp ());
   624 
   625 fun simproc sg name ss =
       (* lhs patterns given as strings, read with Thm.read_cterm at TypeInfer.logicT *)
   626   mk_simproc name (map (fn s => Thm.read_cterm sg (s, TypeInfer.logicT)) ss);
   627 fun simproc_i sg name = mk_simproc name o map (Thm.cterm_of sg);  (* lhs patterns given as terms, certified with sg *)
   628 
   629 fun rep_simproc (Simproc args) = args;  (* the (name, lhss, proc, stamp) tuple, i.e. the format expected by add_simprocs / del_simprocs *)
   630 
   631 
   632 
   633 (** solvers **)
   634 
   635 fun mk_solver name solver = Solver (name, solver, stamp());  (* every call yields a solver with a fresh stamp *)
   636 fun eq_solver (Solver (_, _, s1), Solver(_, _, s2)) = s1 = s2;  (* solvers are identified by stamp, not by name *)
   637 
   638 val merge_solvers = gen_merge_lists eq_solver;  (* merge two solver lists with eq_solver as the equality *)
   639 
   640 fun app_sols [] _ _ = no_tac  (* no solvers: the tactic fails *)
   641   | app_sols (Solver(_,solver,_)::sols) thms i =
       (* combine the solvers with ORELSE in list order, each applied to thms and subgoal i *)
   642        solver thms i ORELSE app_sols sols thms i;
   643 
   644 
   645 
   646 (** simplification sets **)
   647 
   648 (* type simpset *)
   649 
   650 fun make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers solvers =
       (* Curried constructor for Simpset records. *)
   651   Simpset {mss = mss, mk_cong = mk_cong, subgoal_tac = subgoal_tac,
   652     loop_tacs = loop_tacs, unsafe_solvers = unsafe_solvers, solvers = solvers};
   653 
   654 fun from_mss mss = make_ss mss I (K (K no_tac)) [] [] [];  (* wrap an mss: identity mk_cong, always-failing subgoal tactic, no loopers, no solvers *)
   655 
   656 val empty_ss = from_mss (set_mk_sym (empty_mss, Some o symmetric_fun));  (* empty_mss with mk_sym set to symmetric_fun (always Some), wrapped by from_mss *)
   657 
   658 fun rep_ss (Simpset args) = args;  (* the underlying record of a simpset *)
   659 fun prems_of_ss (Simpset {mss, ...}) = prems_of_mss mss;  (* premises of the mss inside a simpset *)
   660 
   661 
   662 (* print simpsets *)
   663 
   664 fun print_ss ss =
       (* Pretty-print three lists taken from dest_mss: simplification rules,
          simplification procedures (name followed by its lhs patterns), and
          congruences. *)
   665   let
   666     val Simpset {mss, ...} = ss;
   667     val {simps, procs, congs} = dest_mss mss;
   668 
   669     val pretty_thms = map Display.pretty_thm;
   670     fun pretty_proc (name, lhss) =
   671       Pretty.big_list (name ^ ":") (map Display.pretty_cterm lhss);
   672   in
   673     [Pretty.big_list "simplification rules:" (pretty_thms simps),
   674       Pretty.big_list "simplification procedures:" (map pretty_proc procs),
   675       Pretty.big_list "congruences:" (pretty_thms congs)]
   676     |> Pretty.chunks |> Pretty.writeln
   677   end;
   678 
   679 
   680 (* extend simpsets *)
   681 
   682 fun (Simpset {mss, mk_cong, subgoal_tac = _, loop_tacs, unsafe_solvers, solvers})
       (* ss setsubgoaler tac: replace the subgoal tactic; the old one is discarded *)
   683     setsubgoaler subgoal_tac =
   684   make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   685 
   686 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs = _, unsafe_solvers, solvers})
       (* ss setloop tac: discard all loopers and install tac as the only one, under the empty name *)
   687     setloop tac =
   688   make_ss mss mk_cong subgoal_tac [("", tac)] unsafe_solvers solvers;
   689 
   690 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
       (* ss addloop (name, tac): warn if a looper of that name is already present,
          then store the pair with `overwrite'. The case expression ends at the
          semicolon, so overwrite is evaluated in both cases. *)
   691     addloop tac = make_ss mss mk_cong subgoal_tac
   692       (case assoc_string (loop_tacs, (#1 tac)) of None => () | Some x =>
   693         warning ("overwriting looper " ^ quote (#1 tac)); overwrite (loop_tacs, tac))
   694       unsafe_solvers solvers;
   695 
   696 fun (ss as Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
       (* ss delloop name: remove all loopers called name; if there is none, warn
          and return ss unchanged *)
   697  delloop name =
   698   let val (del, rest) = partition (fn (n, _) => n = name) loop_tacs in
   699     if null del then (warning ("No such looper in simpset: " ^ name); ss)
   700     else make_ss mss mk_cong subgoal_tac rest unsafe_solvers solvers
   701   end;
   702 
   703 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, ...})
       (* ss setSSolver solver: make solver the only entry of the `solvers' component *)
   704     setSSolver solver =
   705   make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers [solver];
   706 
   707 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
       (* ss addSSolver sol: merge sol into the `solvers' component via merge_solvers *)
   708     addSSolver sol =
   709   make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers (merge_solvers solvers [sol]);
   710 
   711 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers = _, solvers})
       (* ss setSolver sol: make sol the only entry of the `unsafe_solvers' component *)
   712     setSolver unsafe_solver =
   713   make_ss mss mk_cong subgoal_tac loop_tacs [unsafe_solver] solvers;
   714 
   715 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   716     addSolver sol =
   717   make_ss mss mk_cong subgoal_tac loop_tacs (merge_solvers unsafe_solvers [sol]) solvers;
   718 
   719 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   720     setmksimps mk_simps =
   721   make_ss (set_mk_rews (mss, mk_simps)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   722 
   723 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   724     setmkeqTrue mk_eq_True =
   725   make_ss (set_mk_eq_True (mss, mk_eq_True)) mk_cong subgoal_tac loop_tacs
   726     unsafe_solvers solvers;
   727 
   728 fun (Simpset {mss, mk_cong = _, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   729     setmkcong mk_cong =
   730   make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   731 
   732 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   733     setmksym mksym =
   734   make_ss (set_mk_sym (mss, mksym)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   735 
   736 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs,  unsafe_solvers, solvers})
   737     settermless termless =
   738   make_ss (set_termless (mss, termless)) mk_cong subgoal_tac loop_tacs
   739     unsafe_solvers solvers;
   740 
   741 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   742     addsimps rews =
   743   make_ss (add_simps (mss, rews)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   744 
   745 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   746     delsimps rews =
   747   make_ss (del_simps (mss, rews)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   748 
   749 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   750     addeqcongs newcongs =
   751   make_ss (add_congs (mss, newcongs)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   752 
   753 fun (Simpset {mss, subgoal_tac, mk_cong, loop_tacs, unsafe_solvers, solvers})
   754     deleqcongs oldcongs =
   755   make_ss (del_congs (mss, oldcongs)) mk_cong subgoal_tac loop_tacs unsafe_solvers solvers;
   756 
   757 fun (ss as Simpset {mk_cong, ...}) addcongs newcongs =
   758   ss addeqcongs map mk_cong newcongs;
   759 
   760 fun (ss as Simpset {mk_cong, ...}) delcongs oldcongs =
   761   ss deleqcongs map mk_cong oldcongs;
   762 
   763 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   764     addsimprocs simprocs =
   765   make_ss (add_simprocs (mss, map rep_simproc simprocs)) mk_cong subgoal_tac loop_tacs
   766     unsafe_solvers solvers;
   767 
   768 fun (Simpset {mss, mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers})
   769     delsimprocs simprocs =
   770   make_ss (del_simprocs (mss, map rep_simproc simprocs)) mk_cong subgoal_tac
   771     loop_tacs unsafe_solvers solvers;
   772 
   773 fun clear_ss (Simpset {mss, mk_cong, subgoal_tac, loop_tacs = _, unsafe_solvers, solvers}) =
   774   make_ss (clear_mss mss) mk_cong subgoal_tac [] unsafe_solvers solvers;
   775 
   776 
   777 (* merge simpsets *)
   778 
   779 (*ignores subgoal_tac of 2nd simpset!*)
   780 
   781 fun merge_ss
   782    (Simpset {mss = mss1, mk_cong, loop_tacs = loop_tacs1, subgoal_tac,
   783              unsafe_solvers = unsafe_solvers1, solvers = solvers1},
   784     Simpset {mss = mss2, mk_cong = _, loop_tacs = loop_tacs2, subgoal_tac = _,
   785              unsafe_solvers = unsafe_solvers2, solvers = solvers2}) =
   786   make_ss (merge_mss (mss1, mss2)) mk_cong subgoal_tac
   787     (merge_alists loop_tacs1 loop_tacs2)
   788     (merge_solvers unsafe_solvers1 unsafe_solvers2)
   789     (merge_solvers solvers1 solvers2);
   790 
   791 (** rewriting **)
   792 
   793 (*
   794   Uses conversions, see:
   795     L C Paulson, A higher-order implementation of rewriting,
   796     Science of Computer Programming 3 (1983), pages 119-149.
   797 *)
   798 
   799 val dest_eq = Drule.dest_equals o cprop_of;
   800 val lhs_of = fst o dest_eq;
   801 val rhs_of = snd o dest_eq;
   802 
   803 fun check_conv msg thm thm' =
   804   let
   805     val thm'' = transitive thm (transitive
   806       (symmetric (Drule.beta_eta_conversion (lhs_of thm'))) thm')
   807   in (if msg then trace_thm "SUCCEEDED" thm' else (); Some thm'') end
   808   handle THM _ =>
   809     let val {sign, prop = _ $ _ $ prop0, ...} = rep_thm thm;
   810     in
   811       (trace_thm "Proved wrong thm (Check subgoaler?)" thm';
   812        trace_term false "Should have proved:" sign prop0;
   813        None)
   814     end;
   815 
   816 
   817 (* mk_procrule *)
   818 
   819 fun mk_procrule thm =
   820   let val (_,prems,lhs,elhs,rhs,_) = decomp_simp thm
   821   in if rewrite_rule_extra_vars prems lhs rhs
   822      then (prthm true "Extra vars on rhs:" thm; [])
   823      else [mk_rrule2{thm=thm, name="", lhs=lhs, elhs=elhs, perm=false}]
   824   end;
   825 
   826 
   827 (* conversion to apply the meta simpset to a term *)
   828 
(* Since the rewriting strategy is bottom-up, we avoid re-normalizing already
   normalized terms by carrying around the rhs of the rewrite rule just
   applied. This is called the `skeleton'. It is decomposed in parallel
   with the term. Once a Var is encountered, the corresponding term is
   already in normal form.
   skel0 is a dummy skeleton that is used to enforce complete normalization:
   it is neither a Var nor an abstraction nor an application, so skeleton
   decomposition always falls back to skel0 again for all subterms.
*)
val skel0 = Bound 0;
   837 
   838 (* Use rhs as skeleton only if the lhs does not contain unnormalized bits.
   839    The latter may happen iff there are weak congruence rules for constants
   840    in the lhs.
   841 *)
   842 fun uncond_skel((_,weak),(lhs,rhs)) =
   843   if null weak then rhs (* optimization *)
   844   else if exists_Const (fn (c,_) => c mem weak) lhs then skel0
   845        else rhs;
   846 
   847 (* Behaves like unconditional rule if rhs does not contain vars not in the lhs.
   848    Otherwise those vars may become instantiated with unnormalized terms
   849    while the premises are solved.
   850 *)
   851 fun cond_skel(args as (congs,(lhs,rhs))) =
   852   if term_varnames rhs subset term_varnames lhs then uncond_skel(args)
   853   else skel0;
   854 
(*
  Conversion that tries to rewrite the term t at its root.  The result is
  None if nothing applies, otherwise Some (thm, skel) where thm is the
  rewrite step as an equation and skel the skeleton for its rhs.

  we try in order:
    (1) beta reduction
    (2) unconditional rewrite rules
    (3) conditional rewrite rules
    (4) simplification procedures

  IMPORTANT: rewrite rules must not introduce new Vars or TVars!

  Arguments:
    prover: applied to the (depth-incremented) meta simpset and an
      instantiated conditional rule in order to solve its premises
    signt: signature against which rules and simprocs are checked
    maxt: if not ~1, rule indexes are incremented by maxt+1 before matching
*)

fun rewritec (prover, signt, maxt)
             (ss as Simpset{mss=mss as Mss{rules, procs, termless, prems, congs, depth,...},...}) t =
  let
    (*all matching is done against the eta-converted form of t*)
    val eta_thm = Thm.eta_conversion t;
    val eta_t' = rhs_of eta_thm;
    val eta_t = term_of eta_t';
    val tsigt = Sign.tsig_of signt;
    (*try a single rewrite rule; Pattern.MATCH escaping from here is
      handled in rews below*)
    fun rew {thm, name, lhs, elhs, fo, perm} =
      let
        val {sign, prop, maxidx, ...} = rep_thm thm;
        val _ = if Sign.subsig (sign, signt) then ()
                else (prthm true "Ignoring rewrite rule from different theory:" thm;
                      raise Pattern.MATCH);
        val (rthm, elhs') = if maxt = ~1 then (thm, elhs)
          else (Thm.incr_indexes (maxt+1) thm, Thm.cterm_incr_indexes (maxt+1) elhs);
        (*fo selects first-order matching instead of the general matcher*)
        val insts = if fo then Thm.cterm_first_order_match (elhs', eta_t')
                          else Thm.cterm_match (elhs', eta_t');
        val thm' = Thm.instantiate insts (Thm.rename_boundvars lhs eta_t rthm);
        val prop' = Thm.prop_of thm';
        val unconditional = (Logic.count_prems (prop',0) = 0);
        val (lhs', rhs') = Logic.dest_equals (Logic.strip_imp_concl prop')
      in
        (*a permutative rule is only applied if the instantiated rhs is
          smaller than the lhs wrt. termless*)
        if perm andalso not (termless (rhs', lhs'))
        then (trace_named_thm "Cannot apply permutative rewrite rule" (thm, name);
              trace_thm "Term does not become smaller:" thm'; None)
        else (trace_named_thm "Applying instance of rewrite rule" (thm, name);
           if unconditional
           then
             (trace_thm "Rewriting:" thm';
              let val lr = Logic.dest_equals prop;
                  (*NOTE(review): refutable binding -- raises Bind should
                    check_conv ever return None here*)
                  val Some thm'' = check_conv false eta_thm thm'
              in Some (thm'', uncond_skel (congs, lr)) end)
           else
             (trace_thm "Trying to rewrite:" thm';
              (*premises are solved at incremented depth; give up once
                incr_depth refuses*)
              case incr_depth mss of
                None => (trace_thm "FAILED - reached depth limit" thm'; None)
              | Some mss =>
              (case prover mss thm' of
                None       => (trace_thm "FAILED" thm'; None)
              | Some thm2 =>
                  (case check_conv true eta_thm thm2 of
                     None => None |
                     Some thm2' =>
                       let val concl = Logic.strip_imp_concl prop
                           val lr = Logic.dest_equals concl
                       in Some (thm2', cond_skel (congs, lr)) end))))
      end

    (*first rule that succeeds wins; a rule that fails to match is skipped*)
    fun rews [] = None
      | rews (rrule :: rrules) =
          let val opt = rew rrule handle Pattern.MATCH => None
          in case opt of None => rews rrules | some => some end;

    (*put rules whose proposition is a plain meta-equality (no premises)
      before all others; note that each of the two groups ends up in
      reverse order of the input*)
    fun sort_rrules rrs = let
      fun is_simple({thm, ...}:rrule) = case Thm.prop_of thm of
                                      Const("==",_) $ _ $ _ => true
                                      | _                   => false
      fun sort []        (re1,re2) = re1 @ re2
        | sort (rr::rrs) (re1,re2) = if is_simple rr
                                     then sort rrs (rr::re1,re2)
                                     else sort rrs (re1,rr::re2)
    in sort rrs ([],[]) end

    (*try simplification procedures in turn: a procedure is invoked with
      the simpset, signature, premises and eta_t; its result is turned
      into a rule by mk_procrule and applied via rews.  Failures inside
      proc go through transform_failure with SIMPROC_FAIL and the
      procedure name.*)
    fun proc_rews ([]:meta_simproc list) = None
      | proc_rews ({name, proc, lhs, ...} :: ps) =
          if Pattern.matches tsigt (term_of lhs, term_of t) then
            (debug_term false ("Trying procedure " ^ quote name ^ " on:") signt eta_t;
             case transform_failure (curry SIMPROC_FAIL name)
                 (fn () => proc ss signt prems eta_t) () of
               None => (debug false "FAILED"; proc_rews ps)
             | Some raw_thm =>
                 (trace_thm ("Procedure " ^ quote name ^ " produced rewrite rule:") raw_thm;
                  (case rews (mk_procrule raw_thm) of
                    None => (trace_cterm true ("IGNORED result of simproc " ^ quote name ^
                      " -- does not match") t; proc_rews ps)
                  | some => some)))
          else proc_rews ps;
  in case eta_t of
       (*(1) beta redex at the root: contract it, result needs full
         normalization*)
       Abs _ $ _ => Some (transitive eta_thm
         (beta_conversion false eta_t'), skel0)
       (*(2)-(4): rewrite rules from the net first, then simprocs*)
     | _ => (case rews (sort_rrules (Net.match_term rules eta_t)) of
               None => proc_rews (Net.match_term procs eta_t)
             | some => some)
  end;
   950 
   951 
   952 (* conversion to apply a congruence rule to a term *)
   953 
   954 fun congc (prover,signt,maxt) {thm=cong,lhs=lhs} t =
   955   let val sign = Thm.sign_of_thm cong
   956       val _ = if Sign.subsig (sign, signt) then ()
   957                  else error("Congruence rule from different theory")
   958       val rthm = if maxt = ~1 then cong else Thm.incr_indexes (maxt+1) cong;
   959       val rlhs = fst (Drule.dest_equals (Drule.strip_imp_concl (cprop_of rthm)));
   960       val insts = Thm.cterm_match (rlhs, t)
   961       (* Pattern.match can raise Pattern.MATCH;
   962          is handled when congc is called *)
   963       val thm' = Thm.instantiate insts (Thm.rename_boundvars (term_of rlhs) (term_of t) rthm);
   964       val unit = trace_thm "Applying congruence rule:" thm';
   965       fun err (msg, thm) = (trace_thm msg thm; None)
   966   in case prover thm' of
   967        None => err ("Congruence proof failed.  Could not prove", thm')
   968      | Some thm2 => (case check_conv true (Drule.beta_eta_conversion t) thm2 of
   969           None => err ("Congruence proof failed.  Should not have proved", thm2)
   970         | Some thm2' =>
   971             if op aconv (pairself term_of (dest_equals (cprop_of thm2')))
   972             then None else Some thm2')
   973   end;
   974 
   975 val (cA, (cB, cC)) =
   976   apsnd dest_equals (dest_implies (hd (cprems_of Drule.imp_cong)));
   977 
   978 fun transitive1 None None = None
   979   | transitive1 (Some thm1) None = Some thm1
   980   | transitive1 None (Some thm2) = Some thm2
   981   | transitive1 (Some thm1) (Some thm2) = Some (transitive thm1 thm2)
   982 
   983 fun transitive2 thm = transitive1 (Some thm);
   984 fun transitive3 thm = transitive1 thm o Some;
   985 
   986 fun replace_mss (Simpset{mss=_,mk_cong,subgoal_tac,loop_tacs,unsafe_solvers,solvers}) mss_new =
   987     Simpset{mss=mss_new,mk_cong=mk_cong,subgoal_tac=subgoal_tac,loop_tacs=loop_tacs,
   988 	    unsafe_solvers=unsafe_solvers,solvers=solvers};
   989 
(*Bottom-up rewriting conversion: yields the function  try_botc mss,  which
  maps a cterm t to a theorem t == t' (reflexivity if nothing changes).

  Arguments:
    (simprem, useprem, mutsimp): treatment of meta-implications A ==> B;
      mutsimp selects mut_impc0 instead of the legacy nonmut_impc, where
      simprem decides whether A is simplified and useprem whether rules
      from A are used for B
    prover: passed on to rewritec and (applied to the current mss) to congc
    sign, maxidx: passed on to rewritec and congc
    ss: simpset whose mss is the starting point; during traversal the
      current mss is re-inserted via replace_mss before calling rewritec
*)
fun bottomc ((simprem,useprem,mutsimp), prover, sign, maxidx) (ss as Simpset{mss,...}) =
  let
    (*normalize t guided by skeleton skel: a Var skeleton means t is
      already normal.  Otherwise subterms are rewritten first (subc), then
      the root (rewritec); after a successful root step the result is
      normalized again using the skeleton delivered by rewritec.
      None means "no change".*)
    fun botc skel mss t =
          if is_Var skel then None
          else
          (case subc skel mss t of
             some as Some thm1 =>
               (case rewritec (prover, sign, maxidx) (replace_mss ss mss) (rhs_of thm1) of
                  Some (thm2, skel2) =>
                    transitive2 (transitive thm1 thm2)
                      (botc skel2 mss (rhs_of thm2))
                | None => some)
           | None =>
               (case rewritec (prover, sign, maxidx) (replace_mss ss mss) t of
                  Some (thm2, skel2) => transitive2 thm2
                    (botc skel2 mss (rhs_of thm2))
                | None => None))

    (*total version of botc with the dummy skeleton*)
    and try_botc mss t =
          (case botc skel0 mss t of
             Some trec1 => trec1 | None => (reflexive t))

    (*rewrite the proper subterms of t0, decomposing skel in parallel*)
    and subc skel
          (mss as Mss{rules,congs,procs,bounds,prems,mk_rews,termless,depth}) t0 =
       (case term_of t0 of
           (*abstraction: descend into the body, recording a variant of
             the bound variable name in the bounds of the mss*)
           Abs (a, T, t) =>
             let val b = variant bounds a
                 val (v, t') = Thm.dest_abs (Some ("." ^ b)) t0
                 val mss' = mk_mss (rules, congs, procs, b :: bounds, prems, mk_rews, termless,depth)
                 val skel' = case skel of Abs (_, _, sk) => sk | _ => skel0
             in case botc skel' mss' t' of
                  Some thm => Some (abstract_rule a v thm)
                | None => None
             end
         | t $ _ => (case t of
             (*meta-implication: handled by impc*)
             Const ("==>", _) $ _  => impc t0 mss
             (*beta redex: contract first, then continue on the result*)
           | Abs _ =>
               let val thm = beta_conversion false t0
               in case subc skel0 mss (rhs_of thm) of
                    None => Some thm
                  | Some thm' => Some (transitive thm thm')
               end
           | _  =>
               (*appc: default treatment of an application -- normalize
                 function and argument independently*)
               let fun appc () =
                     let
                       val (tskel, uskel) = case skel of
                           tskel $ uskel => (tskel, uskel)
                         | _ => (skel0, skel0);
                       val (ct, cu) = Thm.dest_comb t0
                     in
                     (case botc tskel mss ct of
                        Some thm1 =>
                          (case botc uskel mss cu of
                             Some thm2 => Some (combination thm1 thm2)
                           | None => Some (combination thm1 (reflexive cu)))
                      | None =>
                          (case botc uskel mss cu of
                             Some thm1 => Some (combination (reflexive ct) thm1)
                           | None => None))
                     end
                   val (h, ts) = strip_comb t
               (*if a congruence rule is registered for the head, use it;
                 a Pattern.MATCH from congc falls back to appc*)
               in case cong_name h of
                    Some a =>
                      (case assoc_string (fst congs, a) of
                         None => appc ()
                       | Some cong =>
(* post processing: some partial applications h t1 ... tj, j <= length ts,
   may be a redex. Example: map (%x.x) = (%xs.xs) wrt map_cong *)
                          (let
                             val thm = congc (prover mss, sign, maxidx) cong t0;
                             val t = if_none (apsome rhs_of thm) t0;
                             val (cl, cr) = Thm.dest_comb t
                             (*skeleton h ?x ... ?x: only the partial
                               applications themselves are revisited, the
                               arguments count as normalized*)
                             val dVar = Var(("", 0), dummyT)
                             val skel =
                               list_comb (h, replicate (length ts) dVar)
                           in case botc skel mss cl of
                                None => thm
                              | Some thm' => transitive3 thm
                                  (combination thm' (reflexive cr))
                           end handle TERM _ => error "congc result"
                                    | Pattern.MATCH => appc ()))
                  | _ => appc ()
               end)
         | _ => None)

    (*simplification of a meta-implication, dispatching on mutsimp*)
    and impc ct mss =
      if mutsimp then mut_impc0 [] ct [] [] mss else nonmut_impc ct mss

    (*rewrite rules extracted from the assumed premise, together with the
      assumption itself; a premise with maxidx other than ~1 (i.e.
      containing unknowns) contributes nothing*)
    and rules_of_prem mss prem =
      if maxidx_of_term (term_of prem) <> ~1
      then (trace_cterm true
        "Cannot add premise as rewrite rule because it contains (type) unknowns:" prem; ([], None))
      else
        let val asm = assume prem
        in (extract_safe_rrules (mss, asm), Some asm) end

    (*extend mss by lists of premise rules and the present assumptions*)
    and add_rrules (rrss, asms) mss =
      add_prems (foldl (insert_rrule true) (mss, flat rrss), mapfilter I asms)

    (*from eq: lhs == rhs derive (prem ==> lhs) == (prem ==> rhs) via
      Drule.imp_cong, discharging prem; if r is set, the result is
      additionally composed with Drule.swap_prems_eq on both sides
      (lhs and rhs must be implications in that case)*)
    and disch r (prem, eq) =
      let
        val (lhs, rhs) = dest_eq eq;
        val eq' = implies_elim (Thm.instantiate
          ([], [(cA, prem), (cB, lhs), (cC, rhs)]) Drule.imp_cong)
          (implies_intr prem eq)
      in if not r then eq' else
        let
          val (prem', concl) = dest_implies lhs;
          val (prem'', _) = dest_implies rhs
        in transitive (transitive
          (Thm.instantiate ([], [(cA, prem'), (cB, prem), (cC, concl)])
             Drule.swap_prems_eq) eq')
          (Thm.instantiate ([], [(cA, prem), (cB, prem''), (cC, concl)])
             Drule.swap_prems_eq)
        end
      end

    (*put the premises back in front of the conclusion one at a time,
      trying to rewrite each intermediate implication at its root; after
      a successful rewrite, the result is simplified anew by mut_impc0.
      NOTE(review): the clauses are not exhaustive -- the lists of
      premises, rules and assumptions are presumably always of equal
      length; the lists appear to be in reverse premise order here, given
      the rev calls -- confirm.*)
    and rebuild [] _ _ _ _ eq = eq
      | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) mss eq =
          let
            val mss' = add_rrules (rev rrss, rev asms) mss;
            val concl' =
              Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
            val dprem = apsome (curry (disch false) prem)
          in case rewritec (prover, sign, maxidx) (replace_mss ss mss') concl' of
              None => rebuild prems concl' rrss asms mss (dprem eq)
            | Some (eq', _) => transitive2 (foldl (disch false o swap)
                  (the (transitive3 (dprem eq) eq'), prems))
                (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) mss)
          end

    (*entry point of mutual simplification: strip the premises of concl,
      append them (with their rules and assumptions) to those already
      given, and start mut_impc with empty accumulators and changed = ~1,
      k = ~1*)
    and mut_impc0 prems concl rrss asms mss =
      let
        val prems' = strip_imp_prems concl;
        val (rrss', asms') = split_list (map (rules_of_prem mss) prems')
      in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
        (asms @ asms') [] [] [] [] mss ~1 ~1
      end

    (*mut_impc todo-prems concl todo-rules todo-asms
        done-prems done-rules done-asms eqns mss changed k:
      the primed arguments accumulate the processed premises in reverse
      order, eqns the equation (if any) obtained for each of them.
      First clause (all premises processed): if changed > 0 another pass
      is made with k = changed, otherwise the conclusion is simplified
      using all premises and the implication is rebuilt.*)
    and mut_impc [] concl [] [] prems' rrss' asms' eqns mss changed k =
        transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
            (apsome (curry (disch false) prem) eq2)) (None, eqns ~~ prems'))
          (if changed > 0 then
             mut_impc (rev prems') concl (rev rrss') (rev asms')
               [] [] [] [] mss ~1 changed
           else rebuild prems' concl rrss' asms' mss
             (botc skel0 (add_rrules (rev rrss', rev asms') mss) concl))

      (*second clause: simplify the next premise using the rules of all
        other premises; this is skipped once k has counted down to 0
        (k = ~1 never reaches 0, i.e. imposes no limit).  When the premise
        changes, changed is set to the number of premises processed before
        it and k is reset to ~1.*)
      | mut_impc (prem :: prems) concl (rrs :: rrss) (asm :: asms)
          prems' rrss' asms' eqns mss changed k =
        case (if k = 0 then None else botc skel0 (add_rrules
          (rev rrss' @ rrss, rev asms' @ asms) mss) prem) of
            None => mut_impc prems concl rrss asms (prem :: prems')
              (rrs :: rrss') (asm :: asms') (None :: eqns) mss changed
              (if k = 0 then 0 else k - 1)
          | Some eqn =>
            let
              val prem' = rhs_of eqn;
              val tprems = map term_of prems;
              (*i: one more than the largest position among the remaining
                premises that occurs as hypothesis of eqn; these first i
                premises are discharged with swapping (disch true)*)
              val i = 1 + foldl Int.max (~1, map (fn p =>
                find_index_eq p tprems) (#hyps (rep_thm eqn)));
              val (rrs', asm') = rules_of_prem mss prem'
            in mut_impc prems concl rrss asms (prem' :: prems')
              (rrs' :: rrss') (asm' :: asms') (Some (foldr (disch true)
                (take (i, prems), Drule.imp_cong' eqn (reflexive (Drule.list_implies
                  (drop (i, prems), concl))))) :: eqns) mss (length prems') ~1
            end

     (* legacy code - only for backwards compatibility *)
     (*simplify prem ==> conc without mutual simplification: the premise
       is simplified iff simprem, and its rules are used for the
       conclusion iff useprem*)
     and nonmut_impc ct mss =
       let val (prem, conc) = dest_implies ct;
           val thm1 = if simprem then botc skel0 mss prem else None;
           val prem1 = if_none (apsome rhs_of thm1) prem;
           val mss1 = if not useprem then mss else add_rrules
             (apsnd single (apfst single (rules_of_prem mss prem1))) mss
       in (case botc skel0 mss1 conc of
           None => (case thm1 of
               None => None
             | Some thm1' => Some (Drule.imp_cong' thm1' (reflexive conc)))
         | Some thm2 =>
           let val thm2' = disch false (prem1, thm2)
           in (case thm1 of
               None => Some thm2'
             | Some thm1' =>
                 Some (transitive (Drule.imp_cong' thm1' (reflexive conc)) thm2'))
           end)
       end

 in try_botc mss end;
  1179 
  1180 
  1181 (*** Meta-rewriting: rewrites t to u and returns the theorem t==u ***)
  1182 
  1183 (*
  1184   Parameters:
  1185     mode = (simplify A,
  1186             use A in simplifying B,
  1187             use prems of B (if B is again a meta-impl.) to simplify A)
  1188            when simplifying A ==> B
  1189     mss: contains equality theorems of the form [|p1,...|] ==> t==u
  1190     prover: how to solve premises in conditional rewrites and congruences
  1191 *)
  1192 
  1193 fun rewrite_cterm mode prover (ss as Simpset{mss,...}) ct =
  1194   let val {sign, t, maxidx, ...} = rep_cterm ct
  1195       val Mss{depth, ...} = mss
  1196   in trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
  1197      simp_depth := depth;
  1198      bottomc (mode, prover, sign, maxidx) ss ct
  1199   end
  1200   handle THM (s, _, thms) =>
  1201     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
  1202       Pretty.string_of (Display.pretty_thms thms));
  1203 
  1204 val ss_of = from_mss o mss_of
  1205 
  1206 (*Rewrite a cterm*)
  1207 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
  1208   | rewrite_aux prover full thms = rewrite_cterm (full, false, false) prover (ss_of thms);
  1209 
  1210 (*Rewrite a theorem*)
  1211 fun simplify_aux _ _ [] = (fn th => th)
  1212   | simplify_aux prover full thms =
  1213       Drule.fconv_rule (rewrite_cterm (full, false, false) prover (ss_of thms));
  1214 
  1215 fun rewrite_thm mode prover mss = Drule.fconv_rule (rewrite_cterm mode prover mss);
  1216 
  1217 (*Rewrite the subgoals of a proof state (represented by a theorem) *)
  1218 fun rewrite_goals_rule_aux _ []   th = th
  1219   | rewrite_goals_rule_aux prover thms th =
  1220       Drule.fconv_rule (Drule.goals_conv (K true) (rewrite_cterm (true, true, false) prover
  1221         (ss_of thms))) th;
  1222 
  1223 (*Rewrite the subgoal of a proof state (represented by a theorem) *)
  1224 fun rewrite_goal_rule mode prover ss i thm =
  1225   if 0 < i  andalso  i <= nprems_of thm
  1226   then Drule.fconv_rule (Drule.goals_conv (fn j => j=i) (rewrite_cterm mode prover ss)) thm
  1227   else raise THM("rewrite_goal_rule",i,[thm]);
  1228 
  1229 
  1230 (*simple term rewriting -- without proofs*)
  1231 fun rewrite_term sg rules procs =
  1232   Pattern.rewrite_term (Sign.tsig_of sg) (map decomp_simp' rules) procs;
  1233 
  1234 (*Rewrite subgoal i only.  SELECT_GOAL avoids inefficiencies in goals_conv.*)
  1235 fun asm_rewrite_goal_tac mode prover_tac mss =
  1236   SELECT_GOAL
  1237     (PRIMITIVE (rewrite_goal_rule mode (SINGLE o prover_tac) mss 1));
  1238 
  1239 (** simplification tactics **)
  1240 
  1241 fun solve_all_tac (mk_cong, subgoal_tac, loop_tacs, unsafe_solvers) mss =
  1242   let
  1243     val ss =
  1244       make_ss mss mk_cong subgoal_tac loop_tacs unsafe_solvers unsafe_solvers;
  1245     val solve1_tac = (subgoal_tac ss THEN_ALL_NEW (K no_tac)) 1
  1246   in DEPTH_SOLVE solve1_tac end;
  1247 
  1248 fun loop_tac loop_tacs = FIRST'(map snd loop_tacs);
  1249 
  1250 (*note: may instantiate unknowns that appear also in other subgoals*)
  1251 fun generic_simp_tac safe mode =
  1252   fn (ss as Simpset {mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, solvers, ...}) =>
  1253     let
  1254       val solvs = app_sols (if safe then solvers else unsafe_solvers);
  1255       fun simp_loop_tac i =
  1256         asm_rewrite_goal_tac mode
  1257           (solve_all_tac (mk_cong, subgoal_tac, loop_tacs, unsafe_solvers))
  1258           ss i
  1259         THEN (solvs (prems_of_ss ss) i ORELSE
  1260               TRY ((loop_tac loop_tacs THEN_ALL_NEW simp_loop_tac) i))
  1261     in simp_loop_tac end;
  1262 
  1263 (** simplification rules and conversions **)
  1264 
  1265 fun simp rew mode
  1266      (ss as Simpset {mk_cong, subgoal_tac, loop_tacs, unsafe_solvers, ...}) thm =
  1267   let
  1268     val tacf = solve_all_tac (mk_cong, subgoal_tac, loop_tacs, unsafe_solvers);
  1269     fun prover m th = apsome fst (Seq.pull (tacf m th));
  1270   in rew mode prover ss thm end;
  1271 
  1272 val simp_thm = simp rewrite_thm;
  1273 val simp_cterm = simp rewrite_cterm;
  1274 
  1275 end;
  1276 
(*View of MetaSimplifier restricted to signature BASIC_META_SIMPLIFIER;
  opening it makes exactly those components available unqualified.*)
structure BasicMetaSimplifier: BASIC_META_SIMPLIFIER = MetaSimplifier;
open BasicMetaSimplifier;