src/HOL/SMT/Tools/smt_normalize.ML
changeset 33010 39f73a59e855
parent 32740 9dd0a2f83429
child 33243 17014b1b9353
equal deleted inserted replaced
33008:b0ff69f0a248 33010:39f73a59e855
   271   val meta_eq = @{cpat "op =="}
   271   val meta_eq = @{cpat "op =="}
   272   val meta_eqT = hd (Thm.dest_ctyp (Thm.ctyp_of_term meta_eq))
   272   val meta_eqT = hd (Thm.dest_ctyp (Thm.ctyp_of_term meta_eq))
   273   fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
   273   fun inst_meta cT = Thm.instantiate_cterm ([(meta_eqT, cT)], []) meta_eq
   274   fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
   274   fun mk_meta_eq ct cu = Thm.mk_binop (inst_meta (Thm.ctyp_of_term ct)) ct cu
   275 
   275 
   276   fun lambda_conv conv =
   276   val fresh_name = yield_singleton Name.variants
   277     let
       
   278       fun sub_conv cvs ctxt ct =
       
   279         (case Thm.term_of ct of
       
   280           Const (@{const_name All}, _) $ Abs _ => quant_conv cvs ctxt
       
   281         | Const (@{const_name Ex}, _) $ Abs _ => quant_conv cvs ctxt
       
   282         | Const _ $ Abs _ => Conv.arg_conv (at_lambda_conv cvs ctxt)
       
   283         | Const (@{const_name Let}, _) $ _ $ Abs _ => Conv.combination_conv
       
   284             (Conv.arg_conv (sub_conv cvs ctxt)) (abs_conv cvs ctxt)
       
   285         | Abs _ => at_lambda_conv cvs ctxt
       
   286         | _ $ _ => Conv.comb_conv (sub_conv cvs ctxt)
       
   287         | _ => Conv.all_conv) ct
       
   288       and abs_conv cvs = Conv.abs_conv (fn (cv, cx) => sub_conv (cv::cvs) cx)
       
   289       and quant_conv cvs ctxt = Conv.arg_conv (abs_conv cvs ctxt)
       
   290       and at_lambda_conv cvs ctxt = abs_conv cvs ctxt then_conv conv cvs ctxt
       
   291     in sub_conv [] end
       
   292 
   277 
   293   fun used_vars cvs ct =
   278   fun used_vars cvs ct =
   294     let
   279     let
   295       val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
   280       val lookup = AList.lookup (op aconv) (map (` Thm.term_of) cvs)
   296       val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
   281       val add = (fn (SOME ct) => insert (op aconvc) ct | _ => I)
   297     in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
   282     in Term.fold_aterms (add o lookup) (Thm.term_of ct) [] end
   298 
   283   fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
   299   val rev_int_fst_ord = rev_order o int_ord o pairself fst
   284   fun add_def ct thm = Termtab.update (Thm.term_of ct, (serial (), thm))
   300   fun ordered_values tab =
   285 
   301     Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
   286   fun replace ctxt cvs ct (cx as (nctxt, defs)) =
   302     |> map snd
   287     let
       
   288       val cvs' = used_vars cvs ct
       
   289       val ct' = fold Thm.cabs cvs' ct
       
   290       val mk_repl = fold (fn ct => fn cu => Thm.capply cu ct) cvs'
       
   291     in
       
   292       (case Termtab.lookup defs (Thm.term_of ct') of
       
   293         SOME (_, eq) => (make_def cvs' eq, cx)
       
   294       | NONE =>
       
   295           let
       
   296             val {t, T, ...} = Thm.rep_cterm ct'
       
   297             val (n, nctxt') = fresh_name "" nctxt
       
   298             val eq = Thm.assume (mk_meta_eq (cert ctxt (Free (n, T))) ct')
       
   299           in (make_def cvs' eq, (nctxt', add_def ct' eq defs)) end)
       
   300     end
       
   301 
       
   302   fun none ct cx = (Thm.reflexive ct, cx)
       
   303   fun in_comb f g ct cx =
       
   304     let val (cu1, cu2) = Thm.dest_comb ct
       
   305     in cx |> f cu1 ||>> g cu2 |>> uncurry Thm.combination end
       
   306   fun in_arg f = in_comb none f
       
   307   fun in_abs f cvs ct (nctxt, defs) =
       
   308     let
       
   309       val (n, nctxt') = fresh_name Name.uu nctxt
       
   310       val (cv, cu) = Thm.dest_abs (SOME n) ct
       
   311     in f (cv :: cvs) cu (nctxt', defs) |>> Thm.abstract_rule n cv end
       
   312 
       
   313   fun replace_lambdas ctxt =
       
   314     let
       
   315       fun repl cvs ct =
       
   316         (case Thm.term_of ct of
       
   317           Const (@{const_name All}, _) $ Abs _ => in_arg (in_abs repl cvs)
       
   318         | Const (@{const_name Ex}, _) $ Abs _ => in_arg (in_abs repl cvs)
       
   319         | Const _ $ Abs _ => in_arg (at_lambda cvs)
       
   320         | Const (@{const_name Let}, _) $ _ $ Abs _ =>
       
   321             in_comb (in_arg (repl cvs)) (in_abs repl cvs)
       
   322         | Abs _ => at_lambda cvs
       
   323         | _ $ _ => in_comb (repl cvs) (repl cvs)
       
   324         | _ => none) ct
       
   325       and at_lambda cvs ct cx =
       
   326         let
       
   327           val (thm1, cx') = in_abs repl cvs ct cx
       
   328           val (thm2, cx'') = replace ctxt cvs (Thm.rhs_of thm1) cx'
       
   329         in (Thm.transitive thm1 thm2, cx'') end
       
   330     in repl [] end
   303 in
   331 in
   304 fun lift_lambdas ctxt thms =
   332 fun lift_lambdas ctxt thms =
   305   let
   333   let
   306     val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
   334     val declare_frees = fold (Thm.fold_terms Term.declare_term_frees)
   307     val names = Unsynchronized.ref (declare_frees thms (Name.make_context []))
   335     fun rewrite f thm cx =
   308     val fresh_name = Unsynchronized.change_result names o yield_singleton Name.variants
   336       let val (thm', cx') = f (Thm.cprop_of thm) cx
   309 
   337       in (Thm.equal_elim thm' thm, cx') end
   310     val defs = Unsynchronized.ref (Termtab.empty : (int * thm) Termtab.table)
   338 
   311     fun add_def t thm = Unsynchronized.change defs (Termtab.update (t, (serial (), thm)))
   339     val rev_int_fst_ord = rev_order o int_ord o pairself fst
   312     fun make_def cvs eq = Thm.symmetric (fold norm_meta_def cvs eq)
   340     fun ordered_values tab =
   313     fun def_conv cvs ctxt ct =
   341       Termtab.fold (fn (_, x) => OrdList.insert rev_int_fst_ord x) tab []
   314       let
   342       |> map snd
   315         val cvs' = used_vars cvs ct
   343 
   316         val ct' = fold Thm.cabs cvs' ct
   344     val (thms', (_, defs)) =
   317       in
   345       (declare_frees thms (Name.make_context []), Termtab.empty)
   318         (case Termtab.lookup (!defs) (Thm.term_of ct') of
   346       |> fold_map (rewrite (replace_lambdas ctxt)) thms
   319           SOME (_, eq) => make_def cvs' eq
   347     val eqs = ordered_values defs
   320         | NONE =>
       
   321             let
       
   322               val {t, T, ...} = Thm.rep_cterm ct'
       
   323               val eq = mk_meta_eq (cert ctxt (Free (fresh_name "", T))) ct'
       
   324               val thm = Thm.assume eq
       
   325             in (add_def t thm; make_def cvs' thm) end)
       
   326       end
       
   327     val thms' = map (Conv.fconv_rule (lambda_conv def_conv ctxt)) thms
       
   328     val eqs = ordered_values (!defs)
       
   329   in
   348   in
   330     (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
   349     (maps (#hyps o Thm.crep_thm) eqs, map (normalize_rule ctxt) eqs @ thms')
   331   end
   350   end
   332 end
   351 end
   333 
   352