improvements to abstraction, ensuring more re-use of abstraction functions
authorpaulson
Thu, 05 Oct 2006 10:42:39 +0200
changeset 20863 4ee61dbf192d
parent 20862 1059f2316f88
child 20864 bb75b876b260
improvements to abstraction, ensuring more re-use of abstraction functions moving some functions to Pure/drule.ML
src/HOL/Tools/res_axioms.ML
--- a/src/HOL/Tools/res_axioms.ML	Thu Oct 05 10:41:27 2006 +0200
+++ b/src/HOL/Tools/res_axioms.ML	Thu Oct 05 10:42:39 2006 +0200
@@ -38,11 +38,29 @@
 
 val trace_abs = ref false;
 
+(*FIXME: move some of these functions to Pure/drule.ML *)
+
+fun freeze_thm th = #1 (Drule.freeze_thaw th);
+
+fun lhs_of th =
+  case prop_of th of (Const("==",_) $ lhs $ _) => lhs
+    | _ => raise THM ("lhs_of", 1, [th]);
+
+fun rhs_of th =
+  case prop_of th of (Const("==",_) $ _ $ rhs) => rhs
+    | _ => raise THM ("rhs_of", 1, [th]);
+
 (*Store definitions of abstraction functions, ensuring that identical right-hand
   sides are denoted by the same functions and thereby reducing the need for
   extensionality in proofs.
   FIXME!  Store in theory data!!*)
-val abstraction_cache = ref Net.empty : thm Net.net ref;
+
+fun seed th net =
+  let val (_,ct) = Thm.dest_abs NONE (Drule.crhs_of th)
+  in  Net.insert_term eq_thm (term_of ct, th) net end;
+  
+val abstraction_cache = ref 
+  (seed (thm"COMBI1") (seed (thm"COMBB1") (seed (thm"COMBK1") Net.empty)));
 
 (**** Transformation of Elimination Rules into First-Order Formulas****)
 
@@ -86,7 +104,7 @@
 
 (* convert an elim rule into an equivalent formula, of type term. *)
 fun elimR2Fol elimR =
-  let val elimR' = #1 (Drule.freeze_thaw elimR)
+  let val elimR' = freeze_thm elimR
       val (prems,concl) = (prems_of elimR', concl_of elimR')
       val cv = case concl of    (*conclusion variable*)
                   Const("Trueprop",_) $ (v as Free(_,Type("bool",[]))) => v
@@ -207,41 +225,24 @@
 fun xfun_cong x = Thm.instantiate ([], [(cx, mkvar x)]) fun_cong'
 end;
 
-(*Removes the lambdas from an equation of the form t = (%x. u)*)
-fun strip_lambdas th =
-  case prop_of th of
-      _ $ (Const ("op =", _) $ _ $ Abs (x,_,_)) =>
-          strip_lambdas (#1 (Drule.freeze_thaw (th RS xfun_cong x)))
-    | _ => th;
+(*Removes the lambdas from an equation of the form t = (%x. u).  A non-negative n,
+  serves as an upper bound on how many to remove.*)
+fun strip_lambdas 0 th = th
+  | strip_lambdas n th = 
+      case prop_of th of
+	  _ $ (Const ("op =", _) $ _ $ Abs (x,_,_)) =>
+	      strip_lambdas (n-1) (freeze_thm (th RS xfun_cong x))
+	| _ => th;
 
 (*Convert meta- to object-equality. Fails for theorems like split_comp_eq,
   where some types have the empty sort.*)
-fun object_eq th = th RS def_imp_eq
+fun mk_object_eq th = th RS def_imp_eq
     handle THM _ => error ("Theorem contains empty sort: " ^ string_of_thm th);
 
-(*Contract all eta-redexes in the theorem, lest they give rise to needless abstractions*)
-fun eta_conversion_rule th =
-  equal_elim (eta_conversion (cprop_of th)) th;
-
-fun crhs_of th =
-  case Drule.strip_comb (cprop_of th) of
-      (f, [_, rhs]) =>
-          (case term_of f of Const ("==", _) => rhs
-             | _ => raise THM ("crhs_of", 0, [th]))
-    | _ => raise THM ("crhs_of", 1, [th]);
-
-fun lhs_of th =
-  case prop_of th of (Const("==",_) $ lhs $ _) => lhs
-    | _ => raise THM ("lhs_of", 1, [th]);
-
-fun rhs_of th =
-  case prop_of th of (Const("==",_) $ _ $ rhs) => rhs
-    | _ => raise THM ("rhs_of", 1, [th]);
-
 (*Apply a function definition to an argument, beta-reducing the result.*)
 fun beta_comb cf x =
   let val th1 = combination cf (reflexive x)
-      val th2 = beta_conversion false (crhs_of th1)
+      val th2 = beta_conversion false (Drule.crhs_of th1)
   in  transitive th1 th2  end;
 
 (*Apply a function definition to arguments, beta-reducing along the way.*)
@@ -281,6 +282,26 @@
   | abstract_rule_list (v::vs) (ct::cts) th = abstract_rule v ct (abstract_rule_list vs cts th)
   | abstract_rule_list _ _ th = raise THM ("abstract_rule_list", 0, [th]);
 
+
+val Envir.Envir {asol = tenv0, iTs = tyenv0, ...} = Envir.empty 0
+
+(*Does an existing abstraction definition have an RHS that matches the one we need now?*)
+fun match_rhs thy0 t th =
+  let val thy = theory_of_thm th
+      val _ = if !trace_abs then warning ("match_rhs: " ^ string_of_cterm (cterm_of thy t) ^ 
+                                          " against\n" ^ string_of_thm th) else ();
+      val (tyenv,tenv) = if Context.joinable (thy0,thy) then
+                            Pattern.first_order_match thy (rhs_of th, t) (tyenv0,tenv0)
+                         else raise Pattern.MATCH
+      val term_insts = map Meson.term_pair_of (Vartab.dest tenv)
+      val ct_pairs = if forall lambda_free (map #2 term_insts) then
+                         map (pairself (cterm_of thy)) term_insts
+                     else raise Pattern.MATCH (*Cannot allow lambdas in the instantiation*)
+      fun ctyp2 (ixn, (S, T)) = (ctyp_of thy (TVar (ixn, S)), ctyp_of thy T)
+      val th' = cterm_instantiate ct_pairs th
+  in  SOME (th, instantiate (map ctyp2 (Vartab.dest tyenv), []) th')  end
+  handle _ => NONE;
+
 (*Traverse a theorem, declaring abstraction function definitions. String s is the suggested
   prefix for the constants. Resulting theory is returned in the first theorem. *)
 fun declare_absfuns th =
@@ -293,20 +314,28 @@
                 val _ = assert_eta_free ct;
                 val (cvs,cta) = dest_abs_list ct
                 val (vs,Tvs) = ListPair.unzip (map (dest_Free o term_of) cvs)
+                val _ = if !trace_abs then warning ("Nested lambda: " ^ string_of_cterm cta) else ();
                 val (u'_th,defs) = abstract thy cta
-                val cu' = crhs_of u'_th
-                val abs_v_u = lambda_list (map term_of cvs) (term_of cu')
+                val _ = if !trace_abs then warning ("Returned " ^ string_of_thm u'_th) else ();
+                val cu' = Drule.crhs_of u'_th
+                val u' = term_of cu'
+                val abs_v_u = lambda_list (map term_of cvs) u'
                 (*get the formal parameters: ALL variables free in the term*)
                 val args = term_frees abs_v_u
+                val _ = if !trace_abs then warning (Int.toString (length args) ^ " arguments") else ();
                 val rhs = list_abs_free (map dest_Free args, abs_v_u)
                       (*Forms a lambda-abstraction over the formal parameters*)
                 val v_rhs = Logic.varify rhs
-                val (ax,thy) =
-                 case List.find (fn ax => v_rhs aconv rhs_of ax)
-                        (Net.match_term (!abstraction_cache) v_rhs) of
-                     SOME ax => (ax,thy)   (*cached axiom, current theory*)
-                   | NONE =>
-                      let val Ts = map type_of args
+                val _ = if !trace_abs then warning ("Looking up " ^ string_of_cterm cu') else ();
+                val (ax,ax',thy) =
+                 case List.mapPartial (match_rhs thy abs_v_u) (Net.match_term (!abstraction_cache) u')
+                        of
+                     (ax,ax')::_ => 
+                       (if !trace_abs then warning ("Re-using axiom " ^ string_of_thm ax) else ();
+                        (ax,ax',thy))
+                   | [] =>
+                      let val _ = if !trace_abs then warning "Lookup was empty" else ();
+                          val Ts = map type_of args
                           val cT = Ts ---> (Tvs ---> typ_of (ctyp_of_term cu'))
                           val thy = theory_of_thm u'_th
                           val c = Const (Sign.full_name thy cname, cT)
@@ -316,27 +345,34 @@
                           val cdef = cname ^ "_def"
                           val thy = Theory.add_defs_i false false
                                        [(cdef, equals cT $ c $ rhs)] thy
-                          val ax = get_axiom thy cdef
-                          val _ = abstraction_cache := Net.insert_term eq_absdef (v_rhs,ax)
-                                    (!abstraction_cache)
+                          val _ = if !trace_abs then (warning ("Definition is " ^ 
+                                                      string_of_thm (get_axiom thy cdef))) 
+                                  else ();
+                          val ax = get_axiom thy cdef |> freeze_thm
+                                     |> mk_object_eq |> strip_lambdas (length args)
+                                     |> mk_meta_eq |> Meson.generalize
+                          val (_,ax') = Option.valOf (match_rhs thy abs_v_u ax)
+                          val _ = if !trace_abs then 
+                                    (warning ("Declaring: " ^ string_of_thm ax);
+                                     warning ("Instance: " ^ string_of_thm ax')) 
+                                  else ();
+                          val _ = abstraction_cache := Net.insert_term eq_absdef 
+                                            ((Logic.varify u'), ax) (!abstraction_cache)
                             handle Net.INSERT =>
                               raise THM ("declare_absfuns: INSERT", 0, [th,u'_th,ax])
-                       in  (ax,thy)  end
-                val _ = assert (v_rhs aconv rhs_of ax) "declare_absfuns: rhs mismatch"
-                val def = #1 (Drule.freeze_thaw ax)
-                val def_args = list_combination def (map (cterm_of thy) args)
-            in (transitive (abstract_rule_list vs cvs u'_th) (symmetric def_args),
-                def :: defs) end
+                       in  (ax,ax',thy)  end
+            in if !trace_abs then warning ("Lookup result: " ^ string_of_thm ax') else ();
+               (transitive (abstract_rule_list vs cvs u'_th) (symmetric ax'), ax::defs) end
         | (t1$t2) =>
             let val (ct1,ct2) = Thm.dest_comb ct
                 val (th1,defs1) = abstract thy ct1
                 val (th2,defs2) = abstract (theory_of_thm th1) ct2
             in  (combination th1 th2, defs1@defs2)  end
-      val _ = if !trace_abs then warning (string_of_thm th) else ();
+      val _ = if !trace_abs then warning ("declare_absfuns, Abstracting: " ^ string_of_thm th) else ();
       val (eqth,defs) = abstract (theory_of_thm th) (cprop_of th)
-      val ths = equal_elim eqth th ::
-                map (forall_intr_vars o strip_lambdas o object_eq) defs
-  in  (theory_of_thm eqth, ths)  end;
+      val ths = equal_elim eqth th :: map (strip_lambdas ~1 o mk_object_eq o freeze_thm) defs
+      val _ = if !trace_abs then warning ("declare_absfuns, Result: " ^ string_of_thm (hd ths)) else ();
+  in  (theory_of_thm eqth, map Drule.eta_contraction_rule ths)  end;
 
 fun name_of def = SOME (#1 (dest_Free (lhs_of def))) handle _ => NONE;
 
@@ -356,7 +392,8 @@
                 val (cvs,cta) = dest_abs_list ct
                 val (vs,Tvs) = ListPair.unzip (map (dest_Free o term_of) cvs)
                 val (u'_th,defs) = abstract cta
-                val cu' = crhs_of u'_th
+                val cu' = Drule.crhs_of u'_th
+                val u' = term_of cu'
                 (*Could use Thm.cabs instead of lambda to work at level of cterms*)
                 val abs_v_u = lambda_list (map term_of cvs) (term_of cu')
                 (*get the formal parameters: free variables not present in the defs
@@ -365,34 +402,37 @@
                 val crhs = list_cabs (map cterm args, cterm abs_v_u)
                       (*Forms a lambda-abstraction over the formal parameters*)
                 val rhs = term_of crhs
-                val def =  (*FIXME: can we also reuse the const-abstractions?*)
-                 case List.find (fn ax => rhs aconv rhs_of ax andalso
-                                          Context.joinable (thy, theory_of_thm ax))
-                        (Net.match_term (!abstraction_cache) rhs) of
-                     SOME ax => ax
-                   | NONE =>
+                val (ax,ax') =
+                 case List.mapPartial (match_rhs thy abs_v_u) 
+                        (Net.match_term (!abstraction_cache) u') of
+                     (ax,ax')::_ => 
+                       (if !trace_abs then warning ("Re-using axiom " ^ string_of_thm ax) else ();
+                        (ax,ax'))
+                   | [] =>
                       let val Ts = map type_of args
                           val const_ty = Ts ---> (Tvs ---> typ_of (ctyp_of_term cu'))
                           val c = Free (gensym "abs_", const_ty)
                           val ax = assume (Thm.capply (cterm (equals const_ty $ c)) crhs)
+                                     |> mk_object_eq |> strip_lambdas (length args)
+                                     |> mk_meta_eq |> Meson.generalize
+                          val (_,ax') = Option.valOf (match_rhs thy abs_v_u ax)
                           val _ = abstraction_cache := Net.insert_term eq_absdef (rhs,ax)
                                     (!abstraction_cache)
                             handle Net.INSERT =>
                               raise THM ("assume_absfuns: INSERT", 0, [th,u'_th,ax])
-                      in ax end
-                val _ = assert (rhs aconv rhs_of def) "assume_absfuns: rhs mismatch"
-                val def_args = list_combination def (map cterm args)
-            in (transitive (abstract_rule_list vs cvs u'_th) (symmetric def_args),
-                def :: defs) end
+                      in (ax,ax') end
+            in if !trace_abs then warning ("Lookup result: " ^ string_of_thm ax') else ();
+               (transitive (abstract_rule_list vs cvs u'_th) (symmetric ax'), ax::defs) end
         | (t1$t2) =>
             let val (ct1,ct2) = Thm.dest_comb ct
                 val (t1',defs1) = abstract ct1
                 val (t2',defs2) = abstract ct2
             in  (combination t1' t2', defs1@defs2)  end
+      val _ = if !trace_abs then warning ("assume_absfuns, Abstracting: " ^ string_of_thm th) else ();
       val (eqth,defs) = abstract (cprop_of th)
-  in  equal_elim eqth th ::
-      map (forall_intr_vars o strip_lambdas o object_eq) defs
-  end;
+      val ths = equal_elim eqth th :: map (strip_lambdas ~1 o mk_object_eq o freeze_thm) defs
+      val _ = if !trace_abs then warning ("assume_absfuns, Result: " ^ string_of_thm (hd ths)) else ();
+  in  map Drule.eta_contraction_rule ths  end;
 
 
 (*cterms are used throughout for efficiency*)
@@ -412,7 +452,7 @@
   an existential formula by a use of that function.
    Example: "EX x. x : A & x ~: B ==> sko A B : A & sko A B ~: B"  [.] *)
 fun skolem_of_def def =
-  let val (c,rhs) = Drule.dest_equals (cprop_of (#1 (Drule.freeze_thaw def)))
+  let val (c,rhs) = Drule.dest_equals (cprop_of (freeze_thm def))
       val (ch, frees) = c_variant_abs_multi (rhs, [])
       val (chilbert,cabs) = Thm.dest_comb ch
       val {sign,t, ...} = rep_cterm chilbert
@@ -428,12 +468,11 @@
        |> Thm.varifyT
   end;
 
-(*Converts an Isabelle theorem (intro, elim or simp format) into nnf.*)
-(*It now works for HOL too. *)
+(*Converts an Isabelle theorem (intro, elim or simp format, even higher-order) into NNF.*)
 fun to_nnf th =
     th |> transfer_to_Reconstruction
-       |> transform_elim |> zero_var_indexes |> Drule.freeze_thaw |> #1
-       |> ObjectLogic.atomize_thm |> make_nnf |> strip_lambdas;
+       |> transform_elim |> zero_var_indexes |> freeze_thm
+       |> ObjectLogic.atomize_thm |> make_nnf |> strip_lambdas ~1;
 
 (*The cache prevents repeated clausification of a theorem,
   and also repeated declaration of Skolem functions*)
@@ -445,17 +484,20 @@
 fun skolem_of_nnf th =
   map (skolem_of_def o assume o (cterm_of (theory_of_thm th))) (assume_skofuns th);
 
-fun assert_lambda_free ths = assert (forall (lambda_free o prop_of) ths);
+fun assert_lambda_free ths msg = 
+  case filter (not o lambda_free o prop_of) ths of
+      [] => ()
+     | ths' => error (msg ^ "\n" ^ space_implode "\n" (map string_of_thm ths'));
 
 fun assume_abstract th =
   if lambda_free (prop_of th) then [th]
-  else th |> eta_conversion_rule |> assume_absfuns
+  else th |> Drule.eta_contraction_rule |> assume_absfuns
           |> tap (fn ths => assert_lambda_free ths "assume_abstract: lambdas")
 
 (*Replace lambdas by assumed function definitions in the theorems*)
 fun assume_abstract_list ths =
   if abstract_lambdas then List.concat (map assume_abstract ths)
-  else map eta_conversion_rule ths;
+  else map Drule.eta_contraction_rule ths;
 
 (*Replace lambdas by declared function definitions in the theorems*)
 fun declare_abstract' (thy, []) = (thy, [])
@@ -463,8 +505,8 @@
       let val (thy', th_defs) =
             if lambda_free (prop_of th) then (thy, [th])
             else
-                th |> zero_var_indexes |> Drule.freeze_thaw |> #1
-                   |> eta_conversion_rule |> transfer thy |> declare_absfuns
+                th |> zero_var_indexes |> freeze_thm
+                   |> Drule.eta_contraction_rule |> transfer thy |> declare_absfuns
           val _ = assert_lambda_free th_defs "declare_abstract: lambdas"
           val (thy'', ths') = declare_abstract' (thy', ths)
       in  (thy'', th_defs @ ths')  end;
@@ -472,7 +514,7 @@
 (*FIXME DELETE if we decide to switch to abstractions*)
 fun declare_abstract (thy, ths) =
   if abstract_lambdas then declare_abstract' (thy, ths)
-  else (thy, map eta_conversion_rule ths);
+  else (thy, map Drule.eta_contraction_rule ths);
 
 (*Skolemize a named theorem, with Skolem functions as additional premises.*)
 (*also works for HOL*)
@@ -603,7 +645,7 @@
 fun skolem_cache (name,th) thy =
   let val prop = Thm.prop_of th
   in
-      if lambda_free prop orelse monomorphic prop
+      if lambda_free prop (*orelse monomorphic prop*)
       then thy    (*monomorphic theorems can be Skolemized on demand*)
       else #2 (skolem_cache_thm (name,th) thy)
   end;