improvements to abstraction generation
authorpaulson
Thu, 31 Aug 2006 10:20:22 +0200
changeset 20445 b222d9939e00
parent 20444 6c5e38a73db0
child 20446 7e616709bca2
improvements to abstraction generation
src/HOL/Tools/res_axioms.ML
--- a/src/HOL/Tools/res_axioms.ML	Thu Aug 31 10:18:26 2006 +0200
+++ b/src/HOL/Tools/res_axioms.ML	Thu Aug 31 10:20:22 2006 +0200
@@ -5,6 +5,7 @@
 Transformation of axiom rules (elim/intro/etc) into CNF forms.    
 *)
 
+(*FIXME: does this signature serve any purpose?*)
 signature RES_AXIOMS =
   sig
   val elimRule_tac : thm -> Tactical.tactic
@@ -37,6 +38,12 @@
 
 val trace_abs = ref false;
 
+(*Store definitions of abstraction functions, ensuring that identical right-hand
+  sides are denoted by the same functions and thereby reducing the need for
+  extensionality in proofs.
+  FIXME!  Store in theory data!!*)
+val abstraction_cache = ref Net.empty : thm Net.net ref;
+
 (**** Transformation of Elimination Rules into First-Order Formulas****)
 
 (* a tactic used to prove an elim-rule. *)
@@ -187,14 +194,14 @@
 
 (*Returns the vars of a theorem*)
 fun vars_of_thm th =
-  map (Thm.cterm_of (Thm.theory_of_thm th) o Var) (Drule.fold_terms Term.add_vars th []);
+  map (Thm.cterm_of (theory_of_thm th) o Var) (Drule.fold_terms Term.add_vars th []);
 
 (*Make a version of fun_cong with a given variable name*)
 local
     val fun_cong' = fun_cong RS asm_rl; (*renumber f, g to prevent clashes with (a,0)*)
     val cx = hd (vars_of_thm fun_cong');
     val ty = typ_of (ctyp_of_term cx);
-    val thy = Thm.theory_of_thm fun_cong;
+    val thy = theory_of_thm fun_cong;
     fun mkvar a = cterm_of thy (Var((a,0),ty));
 in
 fun xfun_cong x = Thm.instantiate ([], [(cx, mkvar x)]) fun_cong'
@@ -219,18 +226,23 @@
 fun eta_conversion_rule th =
   equal_elim (eta_conversion (cprop_of th)) th;
   
-fun crhs th =
+fun crhs_of th =
   case Drule.strip_comb (cprop_of th) of
       (f, [_, rhs]) => 
           (case term_of f of
                Const ("==", _) => rhs
-             | _ => raise THM ("crhs", 0, [th]))
-    | _ => raise THM ("crhs", 1, [th]);
+             | _ => raise THM ("crhs_of", 0, [th]))
+    | _ => raise THM ("crhs_of", 1, [th]);
+
+fun rhs_of th =
+  case #prop (rep_thm th) of
+      (Const("==",_) $ _ $ rhs) => rhs
+   | _ => raise THM ("rhs_of", 1, [th]);
 
 (*Apply a function definition to an argument, beta-reducing the result.*)
 fun beta_comb cf x =
   let val th1 = combination cf (reflexive x)
-      val th2 = beta_conversion false (crhs th1)
+      val th2 = beta_conversion false (crhs_of th1)
   in  transitive th1 th2  end;
 
 (*Apply a function definition to arguments, beta-reducing along the way.*)
@@ -246,31 +258,59 @@
      else error ("Eta redex in term: " ^ string_of_cterm ct)
   end;
 
+fun eq_absdef (th1, th2) = 
+    Context.joinable (theory_of_thm th1, theory_of_thm th2)  andalso
+    rhs_of th1 aconv rhs_of th2;
+
+fun lambda_free (Abs _) = false
+  | lambda_free (t $ u) = lambda_free t andalso lambda_free u
+  | lambda_free _ = true;
+
+fun lambda_free_thm th = lambda_free (#prop (rep_thm th));
+  
 (*Traverse a theorem, declaring abstraction function definitions. String s is the suggested
   prefix for the constants. Resulting theory is returned in the first theorem. *)
 fun declare_absfuns th =
-  let fun abstract thy ct = case term_of ct of
+  let fun abstract thy ct = 
+        if lambda_free (term_of ct) then (transfer thy (reflexive ct), [])
+        else
+        case term_of ct of
           Abs (_,T,u) =>
 	    let val cname = gensym "abs_"
 	        val _ = assert_eta_free ct;
 		val (cv,cta) = Thm.dest_abs NONE ct
 		val v = (#1 o dest_Free o term_of) cv
 		val (u'_th,defs) = abstract thy cta
-                val cu' = crhs u'_th
+                val cu' = crhs_of u'_th
 		val abs_v_u = lambda (term_of cv) (term_of cu')
 		(*get the formal parameters: ALL variables free in the term*)
 		val args = term_frees abs_v_u
-		val Ts = map type_of args
-		val cT = Ts ---> (T --> typ_of (ctyp_of_term cu'))
-		val thy = theory_of_thm u'_th
-		val c = Const (Sign.full_name thy cname, cT)
-		val thy = Theory.add_consts_i [(cname, cT, NoSyn)] thy
-		           (*Theory is augmented with the constant, then its def*)
 		val rhs = list_abs_free (map dest_Free args, abs_v_u)
 		      (*Forms a lambda-abstraction over the formal parameters*)
-		val cdef = cname ^ "_def"
-		val thy = Theory.add_defs_i false false [(cdef, equals cT $ c $ rhs)] thy		      
-		val def = #1 (Drule.freeze_thaw (get_axiom thy cdef))
+		val v_rhs = Logic.varify rhs
+		val (ax,thy) = 
+		 case List.find (fn ax => v_rhs aconv rhs_of ax) 
+			(Net.match_term (!abstraction_cache) v_rhs) of
+		     SOME ax => (ax,thy)   (*cached axiom, current theory*)
+		   | NONE =>
+		      let val Ts = map type_of args
+			  val cT = Ts ---> (T --> typ_of (ctyp_of_term cu'))
+			  val thy = theory_of_thm u'_th
+			  val c = Const (Sign.full_name thy cname, cT)
+			  val thy = Theory.add_consts_i [(cname, cT, NoSyn)] thy
+				     (*Theory is augmented with the constant, 
+				       then its definition*)
+			  val cdef = cname ^ "_def"
+			  val thy = Theory.add_defs_i false false 
+				       [(cdef, equals cT $ c $ rhs)] thy		      
+			  val ax = get_axiom thy cdef
+			  val _ = abstraction_cache := Net.insert_term eq_absdef (v_rhs,ax) 
+				    (!abstraction_cache)
+			    handle Net.INSERT => 
+			      raise THM ("declare_absfuns: INSERT", 0, [th,u'_th,ax])
+		       in  (ax,thy)  end
+		val _ = assert (v_rhs aconv rhs_of ax) "declare_absfuns: rhs mismatch"
+		val def = #1 (Drule.freeze_thaw ax)
 		val def_args = list_combination def (map (cterm_of thy) args)
 	    in (transitive (abstract_rule v cv u'_th) (symmetric def_args), 
 	        def :: defs) end
@@ -279,7 +319,6 @@
 	        val (th1,defs1) = abstract thy ct1
 		val (th2,defs2) = abstract (theory_of_thm th1) ct2
 	    in  (combination th1 th2, defs1@defs2)  end
-	| _ => (transfer thy (reflexive ct), [])
       val _ = if !trace_abs then warning (string_of_thm th) else ();
       val (eqth,defs) = abstract (theory_of_thm th) (cprop_of th)
       val ths = equal_elim eqth th ::
@@ -287,23 +326,40 @@
   in  (theory_of_thm eqth, ths)  end;
 
 fun assume_absfuns th =
-  let val cterm = cterm_of (Thm.theory_of_thm th)
-      fun abstract vs ct = case term_of ct of
+  let val thy = theory_of_thm th
+      val cterm = cterm_of thy
+      fun abstract vs ct = 
+        if lambda_free (term_of ct) then (reflexive ct, [])
+        else
+        case term_of ct of
           Abs (_,T,u) =>
 	    let val (cv,cta) = Thm.dest_abs NONE ct
 	        val _ = assert_eta_free ct;
 		val v = (#1 o dest_Free o term_of) cv
 		val (u'_th,defs) = abstract (v::vs) cta
-                val cu' = crhs u'_th
+                val cu' = crhs_of u'_th
 		val abs_v_u = Thm.cabs cv cu'
 		(*get the formal parameters: bound variables also present in the term*)
 		val args = filter (valid_name vs) (term_frees (term_of abs_v_u))
-		val Ts = map type_of args
-		val const_ty = Ts ---> (T --> typ_of (ctyp_of_term cu'))
-		val c = Free (gensym "abs_", const_ty)
-		val rhs = list_cabs (map cterm args, abs_v_u)
+		val crhs = list_cabs (map cterm args, abs_v_u)
 		      (*Forms a lambda-abstraction over the formal parameters*)
-		val def = assume (Thm.capply (cterm (equals const_ty $ c)) rhs)
+		val rhs = term_of crhs
+		val def =  (*FIXME: can we also use the const-abstractions?*)
+		 case List.find (fn ax => rhs aconv rhs_of ax andalso
+					  Context.joinable (thy, theory_of_thm ax)) 
+			(Net.match_term (!abstraction_cache) rhs) of
+		     SOME ax => ax
+		   | NONE =>
+		      let val Ts = map type_of args
+			  val const_ty = Ts ---> (T --> typ_of (ctyp_of_term cu'))
+			  val c = Free (gensym "abs_", const_ty)
+			  val ax = assume (Thm.capply (cterm (equals const_ty $ c)) crhs)
+			  val _ = abstraction_cache := Net.insert_term eq_absdef (rhs,ax) 
+				    (!abstraction_cache)
+			    handle Net.INSERT => 
+			      raise THM ("assume_absfuns: INSERT", 0, [th,u'_th,ax])
+		      in ax end
+		val _ = assert (rhs aconv rhs_of def) "assume_absfuns: rhs mismatch"
 		val def_args = list_combination def (map cterm args)
 	    in (transitive (abstract_rule v cv u'_th) (symmetric def_args), 
 	        def :: defs) end
@@ -312,7 +368,6 @@
 	        val (t1',defs1) = abstract vs ct1
 		val (t2',defs2) = abstract vs ct2
 	    in  (combination t1' t2', defs1@defs2)  end
-	| _ => (reflexive ct, [])
       val (eqth,defs) = abstract [] (cprop_of th)
   in  equal_elim eqth th ::
       map (forall_intr_vars o strip_lambdas o object_eq) defs
@@ -369,17 +424,25 @@
 fun skolem_of_nnf th =
   map (skolem_of_def o assume o (cterm_of (theory_of_thm th))) (assume_skofuns th);
 
+fun assume_abstract th =
+  if lambda_free_thm th then [th]
+  else th |> eta_conversion_rule |> assume_absfuns 
+          |> tap (fn ths => assert ((forall lambda_free_thm) ths) "assume_abstract: lambdas")
+
 (*Replace lambdas by assumed function definitions in the theorems*)
-fun assume_abstract ths =
-  if abstract_lambdas then List.concat (map (assume_absfuns o eta_conversion_rule) ths)
+fun assume_abstract_list ths =
+  if abstract_lambdas then List.concat (map assume_abstract ths)
   else map eta_conversion_rule ths;
 
 (*Replace lambdas by declared function definitions in the theorems*)
 fun declare_abstract' (thy, []) = (thy, [])
   | declare_abstract' (thy, th::ths) =
       let val (thy', th_defs) = 
-            th |> zero_var_indexes |> Drule.freeze_thaw |> #1
-               |> eta_conversion_rule |> transfer thy |> declare_absfuns
+            if lambda_free_thm th then (thy, [th])
+            else
+		th |> zero_var_indexes |> Drule.freeze_thaw |> #1
+		   |> eta_conversion_rule |> transfer thy |> declare_absfuns
+	  val _ = assert ((forall lambda_free_thm) th_defs) "declare_abstract: lambdas"
 	  val (thy'', ths') = declare_abstract' (thy', ths)
       in  (thy'', th_defs @ ths')  end;
 
@@ -393,7 +456,7 @@
 fun skolem_thm th = 
   let val nnfth = to_nnf th
   in  Meson.make_cnf (skolem_of_nnf nnfth) nnfth
-      |> assume_abstract |> Meson.finish_cnf |> rm_redundant_cls
+      |> assume_abstract_list |> Meson.finish_cnf |> rm_redundant_cls
   end
   handle THM _ => [];
 
@@ -420,7 +483,10 @@
 	(case skolem thy (name, Thm.transfer thy th) of
 	     NONE => ([th],thy)
 	   | SOME (thy',cls) => 
-	       (change clause_cache (Symtab.update (name, (th, cls))); (cls,thy')))
+	       (if null cls then warning ("skolem_cache: empty clause set for " ^ name)
+	        else ();
+	        change clause_cache (Symtab.update (name, (th, cls))); 
+	        (cls,thy')))
     | SOME (th',cls) =>
         if eq_thm(th,th') then (cls,thy)
 	else (Output.debug ("skolem_cache: Ignoring variant of theorem " ^ name); 
@@ -532,8 +598,9 @@
 
 fun conj2_rule (th1,th2) = conjI OF [th1,th2];
 
-(*Conjoin a list of clauses to recreate a single theorem*)
-val conj_rule = foldr1 conj2_rule;
+(*Conjoin a list of theorems to recreate a single theorem*)
+fun conj_rule []  = raise THM ("conj_rule", 0, []) 
+  | conj_rule ths = foldr1 conj2_rule ths;
 
 fun skolem_attr (Context.Theory thy, th) =
       let val name = Thm.name_of_thm th