src/Pure/meta_simplifier.ML
changeset 15570 8d8c70b41bab
parent 15531 08c8dad8e399
child 15574 b1d1b5bfc464
equal deleted inserted replaced
15569:1b3115d1a8df 15570:8d8c70b41bab
   400 (* FIXME: it seems that the conditions on extra variables are too liberal if
   400 (* FIXME: it seems that the conditions on extra variables are too liberal if
   401 prems are nonempty: does solving the prems really guarantee instantiation of
   401 prems are nonempty: does solving the prems really guarantee instantiation of
   402 all its Vars? Better: a dynamic check each time a rule is applied.
   402 all its Vars? Better: a dynamic check each time a rule is applied.
   403 *)
   403 *)
   404 fun rewrite_rule_extra_vars prems elhs erhs =
   404 fun rewrite_rule_extra_vars prems elhs erhs =
   405   not (term_varnames erhs subset foldl add_term_varnames (term_varnames elhs, prems))
   405   not (term_varnames erhs subset Library.foldl add_term_varnames (term_varnames elhs, prems))
   406   orelse
   406   orelse
   407   not (term_tvars erhs subset (term_tvars elhs union List.concat (map term_tvars prems)));
   407   not (term_tvars erhs subset (term_tvars elhs union List.concat (map term_tvars prems)));
   408 
   408 
   409 (*simple test for looping rewrite rules and stupid orientations*)
   409 (*simple test for looping rewrite rules and stupid orientations*)
   410 fun reorient sign prems lhs rhs =
   410 fun reorient sign prems lhs rhs =
   485         end
   485         end
   486     else rrule_eq_True (thm, name, lhs, elhs, rhs, ss, thm)
   486     else rrule_eq_True (thm, name, lhs, elhs, rhs, ss, thm)
   487   end;
   487   end;
   488 
   488 
   489 fun extract_rews (Simpset (_, {mk_rews = {mk, ...}, ...}), thms) =
   489 fun extract_rews (Simpset (_, {mk_rews = {mk, ...}, ...}), thms) =
   490   flat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   490   List.concat (map (fn thm => map (rpair (Thm.name_of_thm thm)) (mk thm)) thms);
   491 
   491 
   492 fun orient_comb_simps comb mk_rrule (ss, thms) =
   492 fun orient_comb_simps comb mk_rrule (ss, thms) =
   493   let
   493   let
   494     val rews = extract_rews (ss, thms);
   494     val rews = extract_rews (ss, thms);
   495     val rrules = flat (map mk_rrule rews);
   495     val rrules = List.concat (map mk_rrule rews);
   496   in foldl comb (ss, rrules) end;
   496   in Library.foldl comb (ss, rrules) end;
   497 
   497 
   498 fun extract_safe_rrules (ss, thm) =
   498 fun extract_safe_rrules (ss, thm) =
   499   flat (map (orient_rrule ss) (extract_rews (ss, [thm])));
   499   List.concat (map (orient_rrule ss) (extract_rews (ss, [thm])));
   500 
   500 
   501 (*add rewrite rules explicitly; do not reorient!*)
   501 (*add rewrite rules explicitly; do not reorient!*)
   502 fun ss addsimps thms =
   502 fun ss addsimps thms =
   503   orient_comb_simps (insert_rrule false) (mk_rrule ss) (ss, thms);
   503   orient_comb_simps (insert_rrule false) (mk_rrule ss) (ss, thms);
   504 
   504 
   549   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
   549   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
   550     let
   550     let
   551       val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (Thm.cprop_of thm))
   551       val (lhs, _) = Drule.dest_equals (Drule.strip_imp_concl (Thm.cprop_of thm))
   552         handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   552         handle TERM _ => raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   553     (*val lhs = Pattern.eta_contract lhs;*)
   553     (*val lhs = Pattern.eta_contract lhs;*)
   554       val a = the (cong_name (head_of (term_of lhs))) handle Option =>
   554       val a = valOf (cong_name (head_of (term_of lhs))) handle Option =>
   555         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm);
   555         raise SIMPLIFIER ("Congruence must start with a constant or free variable", thm);
   556       val (alist, weak) = congs;
   556       val (alist, weak) = congs;
   557       val alist2 = overwrite_warn (alist, (a, {lhs = lhs, thm = thm}))
   557       val alist2 = overwrite_warn (alist, (a, {lhs = lhs, thm = thm}))
   558         ("Overwriting congruence rule for " ^ quote a);
   558         ("Overwriting congruence rule for " ^ quote a);
   559       val weak2 = if is_full_cong thm then weak else a :: weak;
   559       val weak2 = if is_full_cong thm then weak else a :: weak;
   563   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
   563   map_simpset2 (fn (congs, procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) =>
   564     let
   564     let
   565       val (lhs, _) = Logic.dest_equals (Thm.concl_of thm) handle TERM _ =>
   565       val (lhs, _) = Logic.dest_equals (Thm.concl_of thm) handle TERM _ =>
   566         raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   566         raise SIMPLIFIER ("Congruence not a meta-equality", thm);
   567     (*val lhs = Pattern.eta_contract lhs;*)
   567     (*val lhs = Pattern.eta_contract lhs;*)
   568       val a = the (cong_name (head_of lhs)) handle Option =>
   568       val a = valOf (cong_name (head_of lhs)) handle Option =>
   569         raise SIMPLIFIER ("Congruence must start with a constant", thm);
   569         raise SIMPLIFIER ("Congruence must start with a constant", thm);
   570       val (alist, _) = congs;
   570       val (alist, _) = congs;
   571       val alist2 = filter (fn (x, _) => x <> a) alist;
   571       val alist2 = List.filter (fn (x, _) => x <> a) alist;
   572       val weak2 = alist2 |> mapfilter (fn (a, {thm, ...}) =>
   572       val weak2 = alist2 |> List.mapPartial (fn (a, {thm, ...}) =>
   573         if is_full_cong thm then NONE else SOME a);
   573         if is_full_cong thm then NONE else SOME a);
   574     in ((alist2, weak2), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) end);
   574     in ((alist2, weak2), procs, mk_rews, termless, subgoal_tac, loop_tacs, solvers) end);
   575 
   575 
   576 fun mk_cong (Simpset (_, {mk_rews = {mk_cong = f, ...}, ...})) = f;
   576 fun mk_cong (Simpset (_, {mk_rews = {mk_cong = f, ...}, ...})) = f;
   577 
   577 
   578 in
   578 in
   579 
   579 
   580 val (op addeqcongs) = foldl add_cong;
   580 val (op addeqcongs) = Library.foldl add_cong;
   581 val (op deleqcongs) = foldl del_cong;
   581 val (op deleqcongs) = Library.foldl del_cong;
   582 
   582 
   583 fun ss addcongs congs = ss addeqcongs map (mk_cong ss) congs;
   583 fun ss addcongs congs = ss addeqcongs map (mk_cong ss) congs;
   584 fun ss delcongs congs = ss deleqcongs map (mk_cong ss) congs;
   584 fun ss delcongs congs = ss deleqcongs map (mk_cong ss) congs;
   585 
   585 
   586 end;
   586 end;
   605   handle Net.DELETE =>
   605   handle Net.DELETE =>
   606     (warning ("Simplification procedure " ^ quote name ^ " not in simpset"); ss);
   606     (warning ("Simplification procedure " ^ quote name ^ " not in simpset"); ss);
   607 
   607 
   608 in
   608 in
   609 
   609 
   610 val (op addsimprocs) = foldl (fn (ss, Simproc procs) => foldl add_proc (ss, procs));
   610 val (op addsimprocs) = Library.foldl (fn (ss, Simproc procs) => Library.foldl add_proc (ss, procs));
   611 val (op delsimprocs) = foldl (fn (ss, Simproc procs) => foldl del_proc (ss, procs));
   611 val (op delsimprocs) = Library.foldl (fn (ss, Simproc procs) => Library.foldl del_proc (ss, procs));
   612 
   612 
   613 end;
   613 end;
   614 
   614 
   615 
   615 
   616 (* mk_rews *)
   616 (* mk_rews *)
   964                        | SOME cong =>
   964                        | SOME cong =>
   965   (*post processing: some partial applications h t1 ... tj, j <= length ts,
   965   (*post processing: some partial applications h t1 ... tj, j <= length ts,
   966     may be a redex. Example: map (%x. x) = (%xs. xs) wrt map_cong*)
   966     may be a redex. Example: map (%x. x) = (%xs. xs) wrt map_cong*)
   967                           (let
   967                           (let
   968                              val thm = congc (prover ss, sign, maxidx) cong t0;
   968                              val thm = congc (prover ss, sign, maxidx) cong t0;
   969                              val t = if_none (apsome rhs_of thm) t0;
   969                              val t = getOpt (Option.map rhs_of thm, t0);
   970                              val (cl, cr) = Thm.dest_comb t
   970                              val (cl, cr) = Thm.dest_comb t
   971                              val dVar = Var(("", 0), dummyT)
   971                              val dVar = Var(("", 0), dummyT)
   972                              val skel =
   972                              val skel =
   973                                list_comb (h, replicate (length ts) dVar)
   973                                list_comb (h, replicate (length ts) dVar)
   974                            in case botc skel ss cl of
   974                            in case botc skel ss cl of
   991       else
   991       else
   992         let val asm = assume prem
   992         let val asm = assume prem
   993         in (extract_safe_rrules (ss, asm), SOME asm) end
   993         in (extract_safe_rrules (ss, asm), SOME asm) end
   994 
   994 
   995     and add_rrules (rrss, asms) ss =
   995     and add_rrules (rrss, asms) ss =
   996       foldl (insert_rrule true) (ss, flat rrss) |> add_prems (mapfilter I asms)
   996       Library.foldl (insert_rrule true) (ss, List.concat rrss) |> add_prems (List.mapPartial I asms)
   997 
   997 
   998     and disch r (prem, eq) =
   998     and disch r (prem, eq) =
   999       let
   999       let
  1000         val (lhs, rhs) = dest_eq eq;
  1000         val (lhs, rhs) = dest_eq eq;
  1001         val eq' = implies_elim (Thm.instantiate
  1001         val eq' = implies_elim (Thm.instantiate
  1016     and rebuild [] _ _ _ _ eq = eq
  1016     and rebuild [] _ _ _ _ eq = eq
  1017       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) ss eq =
  1017       | rebuild (prem :: prems) concl (rrs :: rrss) (asm :: asms) ss eq =
  1018           let
  1018           let
  1019             val ss' = add_rrules (rev rrss, rev asms) ss;
  1019             val ss' = add_rrules (rev rrss, rev asms) ss;
  1020             val concl' =
  1020             val concl' =
  1021               Drule.mk_implies (prem, if_none (apsome rhs_of eq) concl);
  1021               Drule.mk_implies (prem, getOpt (Option.map rhs_of eq, concl));
  1022             val dprem = apsome (curry (disch false) prem)
  1022             val dprem = Option.map (curry (disch false) prem)
  1023           in case rewritec (prover, sign, maxidx) ss' concl' of
  1023           in case rewritec (prover, sign, maxidx) ss' concl' of
  1024               NONE => rebuild prems concl' rrss asms ss (dprem eq)
  1024               NONE => rebuild prems concl' rrss asms ss (dprem eq)
  1025             | SOME (eq', _) => transitive2 (foldl (disch false o swap)
  1025             | SOME (eq', _) => transitive2 (Library.foldl (disch false o swap)
  1026                   (the (transitive3 (dprem eq) eq'), prems))
  1026                   (valOf (transitive3 (dprem eq) eq'), prems))
  1027                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) ss)
  1027                 (mut_impc0 (rev prems) (rhs_of eq') (rev rrss) (rev asms) ss)
  1028           end
  1028           end
  1029 
  1029 
  1030     and mut_impc0 prems concl rrss asms ss =
  1030     and mut_impc0 prems concl rrss asms ss =
  1031       let
  1031       let
  1034       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
  1034       in mut_impc (prems @ prems') (strip_imp_concl concl) (rrss @ rrss')
  1035         (asms @ asms') [] [] [] [] ss ~1 ~1
  1035         (asms @ asms') [] [] [] [] ss ~1 ~1
  1036       end
  1036       end
  1037 
  1037 
  1038     and mut_impc [] concl [] [] prems' rrss' asms' eqns ss changed k =
  1038     and mut_impc [] concl [] [] prems' rrss' asms' eqns ss changed k =
  1039         transitive1 (foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
  1039         transitive1 (Library.foldl (fn (eq2, (eq1, prem)) => transitive1 eq1
  1040             (apsome (curry (disch false) prem) eq2)) (NONE, eqns ~~ prems'))
  1040             (Option.map (curry (disch false) prem) eq2)) (NONE, eqns ~~ prems'))
  1041           (if changed > 0 then
  1041           (if changed > 0 then
  1042              mut_impc (rev prems') concl (rev rrss') (rev asms')
  1042              mut_impc (rev prems') concl (rev rrss') (rev asms')
  1043                [] [] [] [] ss ~1 changed
  1043                [] [] [] [] ss ~1 changed
  1044            else rebuild prems' concl rrss' asms' ss
  1044            else rebuild prems' concl rrss' asms' ss
  1045              (botc skel0 (add_rrules (rev rrss', rev asms') ss) concl))
  1045              (botc skel0 (add_rrules (rev rrss', rev asms') ss) concl))
  1053               (if k = 0 then 0 else k - 1)
  1053               (if k = 0 then 0 else k - 1)
  1054           | SOME eqn =>
  1054           | SOME eqn =>
  1055             let
  1055             let
  1056               val prem' = rhs_of eqn;
  1056               val prem' = rhs_of eqn;
  1057               val tprems = map term_of prems;
  1057               val tprems = map term_of prems;
  1058               val i = 1 + foldl Int.max (~1, map (fn p =>
  1058               val i = 1 + Library.foldl Int.max (~1, map (fn p =>
  1059                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
  1059                 find_index_eq p tprems) (#hyps (rep_thm eqn)));
  1060               val (rrs', asm') = rules_of_prem ss prem'
  1060               val (rrs', asm') = rules_of_prem ss prem'
  1061             in mut_impc prems concl rrss asms (prem' :: prems')
  1061             in mut_impc prems concl rrss asms (prem' :: prems')
  1062               (rrs' :: rrss') (asm' :: asms') (SOME (foldr (disch true)
  1062               (rrs' :: rrss') (asm' :: asms') (SOME (Library.foldr (disch true)
  1063                 (take (i, prems), Drule.imp_cong' eqn (reflexive (Drule.list_implies
  1063                 (Library.take (i, prems), Drule.imp_cong' eqn (reflexive (Drule.list_implies
  1064                   (drop (i, prems), concl))))) :: eqns) ss (length prems') ~1
  1064                   (Library.drop (i, prems), concl))))) :: eqns) ss (length prems') ~1
  1065             end
  1065             end
  1066 
  1066 
  1067      (*legacy code - only for backwards compatibility*)
  1067      (*legacy code - only for backwards compatibility*)
  1068      and nonmut_impc ct ss =
  1068      and nonmut_impc ct ss =
  1069        let val (prem, conc) = dest_implies ct;
  1069        let val (prem, conc) = dest_implies ct;
  1070            val thm1 = if simprem then botc skel0 ss prem else NONE;
  1070            val thm1 = if simprem then botc skel0 ss prem else NONE;
  1071            val prem1 = if_none (apsome rhs_of thm1) prem;
  1071            val prem1 = getOpt (Option.map rhs_of thm1, prem);
  1072            val ss1 = if not useprem then ss else add_rrules
  1072            val ss1 = if not useprem then ss else add_rrules
  1073              (apsnd single (apfst single (rules_of_prem ss prem1))) ss
  1073              (apsnd single (apfst single (rules_of_prem ss prem1))) ss
  1074        in (case botc skel0 ss1 conc of
  1074        in (case botc skel0 ss1 conc of
  1075            NONE => (case thm1 of
  1075            NONE => (case thm1 of
  1076                NONE => NONE
  1076                NONE => NONE
  1167 
  1167 
  1168 fun simp rew mode ss thm =
  1168 fun simp rew mode ss thm =
  1169   let
  1169   let
  1170     val Simpset (_, {solvers = (unsafe_solvers, _), ...}) = ss;
  1170     val Simpset (_, {solvers = (unsafe_solvers, _), ...}) = ss;
  1171     val tacf = solve_all_tac unsafe_solvers;
  1171     val tacf = solve_all_tac unsafe_solvers;
  1172     fun prover s th = apsome #1 (Seq.pull (tacf s th));
  1172     fun prover s th = Option.map #1 (Seq.pull (tacf s th));
  1173   in rew mode prover ss thm end;
  1173   in rew mode prover ss thm end;
  1174 
  1174 
  1175 val simp_thm = simp rewrite_thm;
  1175 val simp_thm = simp rewrite_thm;
  1176 val simp_cterm = simp rewrite_cterm;
  1176 val simp_cterm = simp rewrite_cterm;
  1177 
  1177