src/Pure/meta_simplifier.ML
changeset 16042 8e15ff79851a
parent 15574 b1d1b5bfc464
child 16305 5e7b6731b004
equal deleted inserted replaced
16041:5a8736668ced 16042:8e15ff79851a
    13 signature BASIC_META_SIMPLIFIER =
    13 signature BASIC_META_SIMPLIFIER =
    14 sig
    14 sig
    15   val debug_simp: bool ref
    15   val debug_simp: bool ref
    16   val trace_simp: bool ref
    16   val trace_simp: bool ref
    17   val simp_depth_limit: int ref
    17   val simp_depth_limit: int ref
       
    18   val trace_simp_depth_limit: int ref
    18   type rrule
    19   type rrule
    19   type cong
    20   type cong
    20   type solver
    21   type solver
    21   val mk_solver: string -> (thm list -> int -> tactic) -> solver
    22   val mk_solver: string -> (thm list -> int -> tactic) -> solver
    22   type simpset
    23   type simpset
    23   type proc
    24   type proc
    24   val rep_ss: simpset ->
    25   val rep_ss: simpset ->
    25    {rules: rrule Net.net,
    26    {rules: rrule Net.net,
    26     prems: thm list,
    27     prems: thm list,
    27     bounds: int,
    28     bounds: int} *
    28     depth: int} *
       
    29    {congs: (string * cong) list * string list,
    29    {congs: (string * cong) list * string list,
    30     procs: proc Net.net,
    30     procs: proc Net.net,
    31     mk_rews:
    31     mk_rews:
    32      {mk: thm -> thm list,
    32      {mk: thm -> thm list,
    33       mk_cong: thm -> thm,
    33       mk_cong: thm -> thm,
   105 
   105 
   106 val debug_simp = ref false;
   106 val debug_simp = ref false;
   107 val trace_simp = ref false;
   107 val trace_simp = ref false;
   108 val simp_depth = ref 0;
   108 val simp_depth = ref 0;
   109 val simp_depth_limit = ref 1000;
   109 val simp_depth_limit = ref 1000;
       
   110 val trace_simp_depth_limit = ref 1000;
   110 
   111 
   111 local
   112 local
   112 
   113 
   113 fun println a =
   114 fun println a =
   114   tracing (case ! simp_depth of 0 => a | n => enclose "[" "]" (string_of_int n) ^ a);
   115   if !simp_depth > !trace_simp_depth_limit then ()
       
   116   else tracing (enclose "[" "]" (string_of_int(!simp_depth)) ^ a);
   115 
   117 
   116 fun prnt warn a = if warn then warning a else println a;
   118 fun prnt warn a = if warn then warning a else println a;
   117 fun prtm warn a sg t = prnt warn (a ^ "\n" ^ Sign.string_of_term sg t);
   119 fun prtm warn a sg t = prnt warn (a ^ "\n" ^ Sign.string_of_term sg t);
   118 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
   120 fun prctm warn a t = prnt warn (a ^ "\n" ^ Display.string_of_cterm t);
   119 
   121 
   193 (*A simpset contains data required during conversion:
   195 (*A simpset contains data required during conversion:
   194     rules: discrimination net of rewrite rules;
   196     rules: discrimination net of rewrite rules;
   195     prems: current premises;
   197     prems: current premises;
   196     bounds: maximal index of bound variables already used
   198     bounds: maximal index of bound variables already used
   197       (for generating new names when rewriting under lambda abstractions);
   199       (for generating new names when rewriting under lambda abstractions);
   198     depth: depth of conditional rewriting;
       
   199     congs: association list of congruence rules and
   200     congs: association list of congruence rules and
   200            a list of `weak' congruence constants.
   201            a list of `weak' congruence constants.
   201            A congruence is `weak' if it avoids normalization of some argument.
   202            A congruence is `weak' if it avoids normalization of some argument.
   202     procs: discrimination net of simplification procedures
   203     procs: discrimination net of simplification procedures
   203       (functions that prove rewrite rules on the fly);
   204       (functions that prove rewrite rules on the fly);
   216 
   217 
   217 datatype simpset =
   218 datatype simpset =
   218   Simpset of
   219   Simpset of
   219    {rules: rrule Net.net,
   220    {rules: rrule Net.net,
   220     prems: thm list,
   221     prems: thm list,
   221     bounds: int,
   222     bounds: int} *
   222     depth: int} *
       
   223    {congs: (string * cong) list * string list,
   223    {congs: (string * cong) list * string list,
   224     procs: proc Net.net,
   224     procs: proc Net.net,
   225     mk_rews: mk_rews,
   225     mk_rews: mk_rews,
   226     termless: term * term -> bool,
   226     termless: term * term -> bool,
   227     subgoal_tac: simpset -> int -> tactic,
   227     subgoal_tac: simpset -> int -> tactic,
   236 
   236 
   237 fun eq_proc (Proc {id = id1, ...}, Proc {id = id2, ...}) = (id1 = id2);
   237 fun eq_proc (Proc {id = id1, ...}, Proc {id = id2, ...}) = (id1 = id2);
   238 
   238 
   239 fun rep_ss (Simpset args) = args;
   239 fun rep_ss (Simpset args) = args;
   240 
   240 
   241 fun make_ss1 (rules, prems, bounds, depth) =
   241 fun make_ss1 (rules, prems, bounds) =
   242   {rules = rules, prems = prems, bounds = bounds, depth = depth};
   242   {rules = rules, prems = prems, bounds = bounds};
   243 
   243 
   244 fun map_ss1 f {rules, prems, bounds, depth} =
   244 fun map_ss1 f {rules, prems, bounds} =
   245   make_ss1 (f (rules, prems, bounds, depth));
   245   make_ss1 (f (rules, prems, bounds));
   246 
   246 
   247 fun make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =
   247 fun make_ss2 (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =
   248   {congs = congs, procs = procs, mk_rews = mk_rews, termless = termless,
   248   {congs = congs, procs = procs, mk_rews = mk_rews, termless = termless,
   249     subgoal_tac = subgoal_tac, loop_tacs = loop_tacs, solvers = solvers};
   249     subgoal_tac = subgoal_tac, loop_tacs = loop_tacs, solvers = solvers};
   250 
   250 
   251 fun map_ss2 f {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers} =
   251 fun map_ss2 f {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers} =
   252   make_ss2 (f (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers));
   252   make_ss2 (f (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers));
   253 
   253 
   254 fun make_simpset (args1, args2) = Simpset (make_ss1 args1, make_ss2 args2);
   254 fun make_simpset (args1, args2) = Simpset (make_ss1 args1, make_ss2 args2);
   255 
   255 
   256 fun map_simpset f (Simpset ({rules, prems, bounds, depth},
   256 fun map_simpset f (Simpset ({rules, prems, bounds},
   257     {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers})) =
   257     {congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers})) =
   258   make_simpset (f ((rules, prems, bounds, depth),
   258   make_simpset (f ((rules, prems, bounds),
   259     (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers)));
   259     (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers)));
   260 
   260 
   261 fun map_simpset1 f (Simpset (r1, r2)) = Simpset (map_ss1 f r1, r2);
   261 fun map_simpset1 f (Simpset (r1, r2)) = Simpset (map_ss1 f r1, r2);
   262 fun map_simpset2 f (Simpset (r1, r2)) = Simpset (r1, map_ss2 f r2);
   262 fun map_simpset2 f (Simpset (r1, r2)) = Simpset (r1, map_ss2 f r2);
   263 
   263 
   295 (* empty simpsets *)
   295 (* empty simpsets *)
   296 
   296 
   297 local
   297 local
   298 
   298 
   299 fun init_ss mk_rews termless subgoal_tac solvers =
   299 fun init_ss mk_rews termless subgoal_tac solvers =
   300   make_simpset ((Net.empty, [], 0, 0),
   300   make_simpset ((Net.empty, [], 0),
   301     (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers));
   301     (([], []), Net.empty, mk_rews, termless, subgoal_tac, [], solvers));
   302 
   302 
   303 val basic_mk_rews: mk_rews =
   303 val basic_mk_rews: mk_rews =
   304  {mk = fn th => if can Logic.dest_equals (Thm.concl_of th) then [th] else [],
   304  {mk = fn th => if can Logic.dest_equals (Thm.concl_of th) then [th] else [],
   305   mk_cong = I,
   305   mk_cong = I,
   318 
   318 
   319 (* merge simpsets *)            (*NOTE: ignores some fields of 2nd simpset*)
   319 (* merge simpsets *)            (*NOTE: ignores some fields of 2nd simpset*)
   320 
   320 
   321 fun merge_ss (ss1, ss2) =
   321 fun merge_ss (ss1, ss2) =
   322   let
   322   let
   323     val Simpset ({rules = rules1, prems = prems1, bounds = bounds1, depth},
   323     val Simpset ({rules = rules1, prems = prems1, bounds = bounds1},
   324      {congs = (congs1, weak1), procs = procs1, mk_rews, termless, subgoal_tac,
   324      {congs = (congs1, weak1), procs = procs1, mk_rews, termless, subgoal_tac,
   325       loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
   325       loop_tacs = loop_tacs1, solvers = (unsafe_solvers1, solvers1)}) = ss1;
   326     val Simpset ({rules = rules2, prems = prems2, bounds = bounds2, depth = _},
   326     val Simpset ({rules = rules2, prems = prems2, bounds = bounds2},
   327      {congs = (congs2, weak2), procs = procs2, mk_rews = _, termless = _, subgoal_tac = _,
   327      {congs = (congs2, weak2), procs = procs2, mk_rews = _, termless = _, subgoal_tac = _,
   328       loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
   328       loop_tacs = loop_tacs2, solvers = (unsafe_solvers2, solvers2)}) = ss2;
   329 
   329 
   330     val rules' = Net.merge (rules1, rules2, eq_rrule);
   330     val rules' = Net.merge (rules1, rules2, eq_rrule);
   331     val prems' = gen_merge_lists Drule.eq_thm_prop prems1 prems2;
   331     val prems' = gen_merge_lists Drule.eq_thm_prop prems1 prems2;
   335     val procs' = Net.merge (procs1, procs2, eq_proc);
   335     val procs' = Net.merge (procs1, procs2, eq_proc);
   336     val loop_tacs' = merge_alists loop_tacs1 loop_tacs2;
   336     val loop_tacs' = merge_alists loop_tacs1 loop_tacs2;
   337     val unsafe_solvers' = merge_solvers unsafe_solvers1 unsafe_solvers2;
   337     val unsafe_solvers' = merge_solvers unsafe_solvers1 unsafe_solvers2;
   338     val solvers' = merge_solvers solvers1 solvers2;
   338     val solvers' = merge_solvers solvers1 solvers2;
   339   in
   339   in
   340     make_simpset ((rules', prems', bounds', depth), ((congs', weak'), procs',
   340     make_simpset ((rules', prems', bounds'), ((congs', weak'), procs',
   341       mk_rews, termless, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
   341       mk_rews, termless, subgoal_tac, loop_tacs', (unsafe_solvers', solvers')))
   342   end;
   342   end;
   343 
   343 
   344 
   344 
   345 (* simprocs *)
   345 (* simprocs *)
   361 
   361 
   362 (** simpset operations **)
   362 (** simpset operations **)
   363 
   363 
   364 (* bounds and prems *)
   364 (* bounds and prems *)
   365 
   365 
   366 val incr_bounds = map_simpset1 (fn (rules, prems, bounds, depth) =>
   366 val incr_bounds = map_simpset1 (fn (rules, prems, bounds) =>
   367   (rules, prems, bounds + 1, depth));
   367   (rules, prems, bounds + 1));
   368 
   368 
   369 fun add_prems ths = map_simpset1 (fn (rules, prems, bounds, depth) =>
   369 fun add_prems ths = map_simpset1 (fn (rules, prems, bounds) =>
   370   (rules, ths @ prems, bounds, depth));
   370   (rules, ths @ prems, bounds));
   371 
   371 
   372 fun prems_of_ss (Simpset ({prems, ...}, _)) = prems;
   372 fun prems_of_ss (Simpset ({prems, ...}, _)) = prems;
   373 
   373 
   374 
   374 
   375 (* addsimps *)
   375 (* addsimps *)
   379     val fo = Pattern.first_order (term_of elhs) orelse not (Pattern.pattern (term_of elhs))
   379     val fo = Pattern.first_order (term_of elhs) orelse not (Pattern.pattern (term_of elhs))
   380   in {thm = thm, name = name, lhs = lhs, elhs = elhs, fo = fo, perm = perm} end;
   380   in {thm = thm, name = name, lhs = lhs, elhs = elhs, fo = fo, perm = perm} end;
   381 
   381 
   382 fun insert_rrule quiet (ss, rrule as {thm, name, lhs, elhs, perm}) =
   382 fun insert_rrule quiet (ss, rrule as {thm, name, lhs, elhs, perm}) =
   383  (trace_named_thm "Adding rewrite rule" (thm, name);
   383  (trace_named_thm "Adding rewrite rule" (thm, name);
   384   ss |> map_simpset1 (fn (rules, prems, bounds, depth) =>
   384   ss |> map_simpset1 (fn (rules, prems, bounds) =>
   385     let
   385     let
   386       val rrule2 as {elhs, ...} = mk_rrule2 rrule;
   386       val rrule2 as {elhs, ...} = mk_rrule2 rrule;
   387       val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule);
   387       val rules' = Net.insert_term ((term_of elhs, rrule2), rules, eq_rrule);
   388     in (rules', prems, bounds, depth) end)
   388     in (rules', prems, bounds) end)
   389   handle Net.INSERT =>
   389   handle Net.INSERT =>
   390     (if quiet then () else warn_thm "Ignoring duplicate rewrite rule:" thm; ss));
   390     (if quiet then () else warn_thm "Ignoring duplicate rewrite rule:" thm; ss));
   391 
   391 
   392 fun vperm (Var _, Var _) = true
   392 fun vperm (Var _, Var _) = true
   393   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   393   | vperm (Abs (_, _, s), Abs (_, _, t)) = vperm (s, t)
   504 
   504 
   505 
   505 
   506 (* delsimps *)
   506 (* delsimps *)
   507 
   507 
   508 fun del_rrule (ss, rrule as {thm, elhs, ...}) =
   508 fun del_rrule (ss, rrule as {thm, elhs, ...}) =
   509   ss |> map_simpset1 (fn (rules, prems, bounds, depth) =>
   509   ss |> map_simpset1 (fn (rules, prems, bounds) =>
   510     (Net.delete_term ((term_of elhs, rrule), rules, eq_rrule), prems, bounds, depth))
   510     (Net.delete_term ((term_of elhs, rrule), rules, eq_rrule), prems, bounds))
   511   handle Net.DELETE => (warn_thm "Rewrite rule not in simpset:" thm; ss);
   511   handle Net.DELETE => (warn_thm "Rewrite rule not in simpset:" thm; ss);
   512 
   512 
   513 fun ss delsimps thms =
   513 fun ss delsimps thms =
   514   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule ss) (ss, thms);
   514   orient_comb_simps del_rrule (map mk_rrule2 o mk_rrule ss) (ss, thms);
   515 
   515 
   754   while the premises are solved.*)
   754   while the premises are solved.*)
   755 
   755 
   756 fun cond_skel (args as (congs, (lhs, rhs))) =
   756 fun cond_skel (args as (congs, (lhs, rhs))) =
   757   if term_varnames rhs subset term_varnames lhs then uncond_skel args
   757   if term_varnames rhs subset term_varnames lhs then uncond_skel args
   758   else skel0;
   758   else skel0;
   759 
       
   760 fun incr_depth ss =
       
   761   let
       
   762     val ss' = ss |> map_simpset1 (fn (rules, prems, bounds, depth) =>
       
   763       (rules, prems, bounds, depth + 1));
       
   764     val Simpset ({depth = depth', ...}, _) = ss';
       
   765   in
       
   766     if depth' > ! simp_depth_limit
       
   767     then (warning "simp_depth_limit exceeded - giving up"; NONE)
       
   768     else
       
   769      (if depth' mod 10 = 0
       
   770       then warning ("Simplification depth " ^ string_of_int depth')
       
   771       else ();
       
   772       SOME ss')
       
   773   end;
       
   774 
   759 
   775 (*
   760 (*
   776   Rewriting -- we try in order:
   761   Rewriting -- we try in order:
   777     (1) beta reduction
   762     (1) beta reduction
   778     (2) unconditional rewrite rules
   763     (2) unconditional rewrite rules
   811               let val lr = Logic.dest_equals prop;
   796               let val lr = Logic.dest_equals prop;
   812                   val SOME thm'' = check_conv false eta_thm thm'
   797                   val SOME thm'' = check_conv false eta_thm thm'
   813               in SOME (thm'', uncond_skel (congs, lr)) end)
   798               in SOME (thm'', uncond_skel (congs, lr)) end)
   814            else
   799            else
   815              (trace_thm "Trying to rewrite:" thm';
   800              (trace_thm "Trying to rewrite:" thm';
   816               case incr_depth ss of
   801               if !simp_depth > !simp_depth_limit
   817                 NONE => (trace_thm "FAILED - reached depth limit" thm'; NONE)
   802               then let val s = "simp_depth_limit exceeded - giving up"
   818               | SOME ss' =>
   803                    in trace false s; warning s; NONE end
   819               (case prover ss' thm' of
   804               else
       
   805               case prover ss thm' of
   820                 NONE => (trace_thm "FAILED" thm'; NONE)
   806                 NONE => (trace_thm "FAILED" thm'; NONE)
   821               | SOME thm2 =>
   807               | SOME thm2 =>
   822                   (case check_conv true eta_thm thm2 of
   808                   (case check_conv true eta_thm thm2 of
   823                      NONE => NONE |
   809                      NONE => NONE |
   824                      SOME thm2' =>
   810                      SOME thm2' =>
   825                        let val concl = Logic.strip_imp_concl prop
   811                        let val concl = Logic.strip_imp_concl prop
   826                            val lr = Logic.dest_equals concl
   812                            val lr = Logic.dest_equals concl
   827                        in SOME (thm2', cond_skel (congs, lr)) end))))
   813                        in SOME (thm2', cond_skel (congs, lr)) end)))
   828       end
   814       end
   829 
   815 
   830     fun rews [] = NONE
   816     fun rews [] = NONE
   831       | rews (rrule :: rrules) =
   817       | rews (rrule :: rrules) =
   832           let val opt = rew rrule handle Pattern.MATCH => NONE
   818           let val opt = rew rrule handle Pattern.MATCH => NONE
  1097            when simplifying A ==> B
  1083            when simplifying A ==> B
  1098     prover: how to solve premises in conditional rewrites and congruences
  1084     prover: how to solve premises in conditional rewrites and congruences
  1099 *)
  1085 *)
  1100 
  1086 
  1101 fun rewrite_cterm mode prover ss ct =
  1087 fun rewrite_cterm mode prover ss ct =
  1102   let
  1088   (simp_depth := !simp_depth + 1;
  1103     val Simpset ({depth, ...}, _) = ss;
  1089    if !simp_depth mod 10 = 0
  1104     val {sign, t, maxidx, ...} = Thm.rep_cterm ct;
  1090    then warning ("Simplification depth " ^ string_of_int (!simp_depth))
  1105   in
  1091    else ();
  1106     trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
  1092    trace_cterm false "SIMPLIFIER INVOKED ON THE FOLLOWING TERM:" ct;
  1107     simp_depth := depth;
  1093    let val {sign, t, maxidx, ...} = Thm.rep_cterm ct
  1108     bottomc (mode, prover, sign, maxidx) ss ct
  1094        val res = bottomc (mode, prover, sign, maxidx) ss ct
  1109   end handle THM (s, _, thms) =>
  1095          handle THM (s, _, thms) =>
  1110     error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
  1096          error ("Exception THM was raised in simplifier:\n" ^ s ^ "\n" ^
  1111       Pretty.string_of (Display.pretty_thms thms));
  1097            Pretty.string_of (Display.pretty_thms thms))
       
  1098    in simp_depth := !simp_depth - 1; res end
       
  1099   ) handle exn => (simp_depth := 0; raise exn);
  1112 
  1100 
  1113 (*Rewrite a cterm*)
  1101 (*Rewrite a cterm*)
  1114 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
  1102 fun rewrite_aux _ _ [] = (fn ct => Thm.reflexive ct)
  1115   | rewrite_aux prover full thms =
  1103   | rewrite_aux prover full thms =
  1116       rewrite_cterm (full, false, false) prover (empty_ss addsimps thms);
  1104       rewrite_cterm (full, false, false) prover (empty_ss addsimps thms);