support for recursion over mutually recursive IDTs
authorwebertj
Mon, 18 Apr 2005 17:20:49 +0200
changeset 15767 8ed9fcc004fe
parent 15766 b08feb003f3c
child 15768 a167d50eadbb
support for recursion over mutually recursive IDTs
src/HOL/Tools/refute.ML
src/HOL/ex/Refute_Examples.thy
--- a/src/HOL/Tools/refute.ML	Mon Apr 18 15:54:23 2005 +0200
+++ b/src/HOL/Tools/refute.ML	Mon Apr 18 17:20:49 2005 +0200
@@ -192,7 +192,7 @@
 			 parameters = Symtab.merge (op=) (pa1, pa2)};
 		fun print sg {interpreters, printers, parameters} =
 			Pretty.writeln (Pretty.chunks
-				[Pretty.strs ("default parameters:" :: List.concat (map (fn (name,value) => [name, "=", value]) (Symtab.dest parameters))),
+				[Pretty.strs ("default parameters:" :: List.concat (map (fn (name, value) => [name, "=", value]) (Symtab.dest parameters))),
 				 Pretty.strs ("interpreters:" :: map fst interpreters),
 				 Pretty.strs ("printers:" :: map fst printers)]);
 	end;
@@ -624,6 +624,7 @@
 			| Const ("op +", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
 			| Const ("op -", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
 			| Const ("op *", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
+			| Const ("List.op @", T)          => collect_type_axioms (axs, T)
 			(* simply-typed lambda calculus *)
 			| Const (s, T)                    =>
 				let
@@ -896,8 +897,8 @@
 
 	(* interpretation -> prop_formula *)
 
-	fun toTrue (Leaf [fm,_]) = fm
-	  | toTrue _             = raise REFUTE ("toTrue", "interpretation does not denote a Boolean value");
+	fun toTrue (Leaf [fm, _]) = fm
+	  | toTrue _              = raise REFUTE ("toTrue", "interpretation does not denote a Boolean value");
 
 (* ------------------------------------------------------------------------- *)
 (* toFalse: converts the interpretation of a Boolean value to a              *)
@@ -907,8 +908,8 @@
 
 	(* interpretation -> prop_formula *)
 
-	fun toFalse (Leaf [_,fm]) = fm
-	  | toFalse _             = raise REFUTE ("toFalse", "interpretation does not denote a Boolean value");
+	fun toFalse (Leaf [_, fm]) = fm
+	  | toFalse _              = raise REFUTE ("toFalse", "interpretation does not denote a Boolean value");
 
 (* ------------------------------------------------------------------------- *)
 (* find_model: repeatedly calls 'interpret' with appropriate parameters,     *)
@@ -1607,7 +1608,7 @@
 				in
 					SOME (intr, (typs, (t, intr)::terms), args')
 				end
-			| Var ((x,i), Type ("set", [T])) =>
+			| Var ((x, i), Type ("set", [T])) =>
 				let
 					val (intr, _, args') = interpret thy (typs, []) args (Var ((x,i), T --> HOLogic.boolT))
 				in
@@ -1896,27 +1897,33 @@
 					result  (* just keep 'result' *)
 				| NONE =>
 					if s mem (#rec_names info) then
-						(* okay, we do have a recursion operator of the datatype given by 'info' *)
+						(* we do have a recursion operator of the datatype given by 'info', *)
+						(* or of a mutually recursive datatype                              *)
 						let
-							val index               = #index info
-							val descr               = #descr info
-							val (_, dtyps, constrs) = (valOf o assoc) (descr, index)
-							(* the total number of constructors, including those of different    *)
-							(* (mutually recursive) datatypes within the same descriptor 'descr' *)
-							val constrs_count = sum (map (fn (_, (_, _, cs)) => length cs) descr)
-							val params_count  = length params
+							val index              = #index info
+							val descr              = #descr info
+							val (dtname, dtyps, _) = (valOf o assoc) (descr, index)
+							(* number of all constructors, including those of different           *)
+							(* (mutually recursive) datatypes within the same descriptor 'descr'  *)
+							val mconstrs_count     = sum (map (fn (_, (_, _, cs)) => length cs) descr)
+							val params_count       = length params
+							(* the type of a recursion operator: [T1, ..., Tn, IDT] ---> Tresult *)
+							val IDT = List.nth (binder_types T, mconstrs_count)
 						in
-							if constrs_count < params_count then
+							if (fst o dest_Type) IDT <> dtname then
+								(* recursion operator of a mutually recursive datatype *)
+								NONE
+							else if mconstrs_count < params_count then
 								(* too many actual parameters; for now we'll use the *)
 								(* 'stlc_interpreter' to strip off one application   *)
 								NONE
-							else if constrs_count > params_count then
+							else if mconstrs_count > params_count then
 								(* too few actual parameters; we use eta expansion            *)
 								(* Note that the resulting expansion of lambda abstractions   *)
 								(* by the 'stlc_interpreter' may be rather slow (depending on *)
 								(* the argument types and the size of the IDT, of course).    *)
-								SOME (interpret thy model args (eta_expand t (constrs_count - params_count)))
-							else  (* constrs_count = params_count *)
+								SOME (interpret thy model args (eta_expand t (mconstrs_count - params_count)))
+							else  (* mconstrs_count = params_count *)
 								let
 									(* interpret each parameter separately *)
 									val ((model', args'), p_intrs) = foldl_map (fn ((m, a), p) =>
@@ -1925,23 +1932,37 @@
 										in
 											((m', a'), i)
 										end) ((model, args), params)
-									val (typs, terms) = model'
-									(* the type of a recursion operator: [T1, ..., Tn, IDT] ---> Tresult *)
-									val IDT       = List.nth (binder_types T, constrs_count)
+									val (typs, _) = model'
 									val typ_assoc = dtyps ~~ (snd o dest_Type) IDT
-									(* interpret each constructor of the datatype *)
-									(* TODO: we probably need to interpret every constructor in the descriptor, *)
-									(*       possibly for typs' instead of typs                                 *)
-									val c_intrs = map (#1 o interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True})
-										(map (fn (cname, cargs) => Const (cname, map (typ_of_dtyp descr typ_assoc) cargs ---> IDT)) constrs)
+									(* interpret each constructor in the descriptor (including *)
+									(* those of mutually recursive datatypes)                  *)
+									(* (int * interpretation list) list *)
+									val mc_intrs = map (fn (idx, (_, _, cs)) =>
+										let
+											val c_return_typ = typ_of_dtyp descr typ_assoc (DatatypeAux.DtRec idx)
+										in
+											(idx, map (fn (cname, cargs) =>
+												(#1 o interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True})
+													(Const (cname, map (typ_of_dtyp descr typ_assoc) cargs ---> c_return_typ))) cs)
+										end) descr
+									val _ = writeln (makestring index)
+									val _ = writeln (makestring descr)
+									val _ = writeln (makestring mc_intrs)
 									(* the recursion operator is a function that maps every element of *)
-									(* the inductive datatype to an element of the result type         *)
-									val (i, _, _) = interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", IDT))
-									val size      = size_of_type i
-									val INTRS     = Array.array (size, Leaf [])  (* the initial value 'Leaf []' does not matter; it will be overwritten *)
-									(* takes an interpretation, and if some leaf of this interpretation *)
-									(* is the 'elem'-th element of the datatype, the indices of the     *)
-									(* arguments leading to this leaf are returned                      *)
+									(* the inductive datatype (and of mutually recursive types) to an  *)
+									(* element of some result type                                     *)
+									(* (int * interpretation option Array.array) list *)
+									val INTRS = map (fn (idx, _) =>
+										let
+											val T         = typ_of_dtyp descr typ_assoc (DatatypeAux.DtRec idx)
+											val (i, _, _) = interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", T))
+											val size      = size_of_type i
+										in
+											(idx, Array.array (size, NONE))
+										end) descr
+									(* takes an interpretation, and if some leaf of this interpretation   *)
+									(* is the 'elem'-th element of the type, the indices of the arguments *)
+									(* leading to this leaf are returned                                  *)
 									(* interpretation -> int -> int list option *)
 									fun get_args (Leaf xs) elem =
 										if find_index_eq True xs = elem then
@@ -1962,27 +1983,46 @@
 										end
 									(* returns the index of the constructor and indices for its      *)
 									(* arguments that generate the 'elem'-th element of the datatype *)
-									(* int -> int * int list *)
-									fun get_cargs elem =
+									(* given by 'idx'                                                *)
+									(* int -> int -> int * int list *)
+									fun get_cargs idx elem =
 										let
 											(* int * interpretation list -> int * int list *)
 											fun get_cargs_rec (_, []) =
-												raise REFUTE ("IDT_recursion_interpreter", "no matching constructor found for element " ^ string_of_int elem)
+												raise REFUTE ("IDT_recursion_interpreter", "no matching constructor found for element " ^ string_of_int elem ^ " in datatype " ^ Sign.string_of_typ (sign_of thy) IDT ^ " (datatype index " ^ string_of_int idx ^ ")")
 											  | get_cargs_rec (n, x::xs) =
 												(case get_args x elem of
 												  SOME args => (n, args)
 												| NONE      => get_cargs_rec (n+1, xs))
 										in
-											get_cargs_rec (0, c_intrs)
+											get_cargs_rec (0, (valOf o assoc) (mc_intrs, idx))
 										end
-									(* int -> unit *)
-									fun compute_loop elem =
-										if elem=size then
-											()
-										else  (* elem < size *)
+									(* returns the number of constructors in datatypes that occur in *)
+									(* the descriptor 'descr' before the datatype given by 'idx'     *)
+									fun get_coffset idx =
+										let
+											fun get_coffset_acc _ [] =
+												raise REFUTE ("IDT_recursion_interpreter", "index " ^ string_of_int idx ^ " not found in descriptor")
+											  | get_coffset_acc sum ((i, (_, _, cs))::descr') =
+												if i=idx then
+													sum
+												else
+													get_coffset_acc (sum + length cs) descr'
+										in
+											get_coffset_acc 0 descr
+										end
+									(* computes one entry in INTRS, and recursively all entries needed for it, *)
+									(* where 'idx' gives the datatype and 'elem' the element of it             *)
+									(* int -> int -> interpretation *)
+									fun compute_array_entry idx elem =
+										case Array.sub ((valOf o assoc) (INTRS, idx), elem) of
+										  SOME result =>
+											(* simply return the previously computed result *)
+											result
+										| NONE =>
 											let
 												(* int * int list *)
-												val (c, args) = get_cargs elem
+												val (c, args) = get_cargs idx elem
 												(* interpretation * int list -> interpretation *)
 												fun select_subtree (tr, []) =
 													tr  (* return the whole tree *)
@@ -1991,22 +2031,40 @@
 												  | select_subtree (Node tr, x::xs) =
 													select_subtree (List.nth (tr, x), xs)
 												(* select the correct subtree of the parameter corresponding to constructor 'c' *)
-												val p_intr = select_subtree (List.nth (p_intrs, c), args)
-												(* find the indices of recursive arguments *)
-												val rec_args = map snd (List.filter (DatatypeAux.is_rec_type o fst) ((snd (List.nth (constrs, c))) ~~ args))
+												val p_intr = select_subtree (List.nth (p_intrs, get_coffset idx + c), args)
+												(* find the indices of the constructor's recursive arguments *)
+												val (_, _, constrs) = (valOf o assoc) (descr, idx)
+												val constr_args     = (snd o List.nth) (constrs, c)
+												val rec_args        = List.filter (DatatypeAux.is_rec_type o fst) (constr_args ~~ args)
+												val rec_args'       = map (fn (dtyp, elem) => (DatatypeAux.dest_DtRec dtyp, elem)) rec_args
 												(* apply 'p_intr' to recursively computed results *)
-												val rec_p_intr = Library.foldl (fn (i, n) => interpretation_apply (i, Array.sub (INTRS, n))) (p_intr, rec_args)
+												val result = foldl (fn ((idx, elem), intr) =>
+													interpretation_apply (intr, compute_array_entry idx elem)) p_intr rec_args'
 												(* update 'INTRS' *)
-												val _ = Array.update (INTRS, elem, rec_p_intr)
+												val _ = Array.update ((valOf o assoc) (INTRS, idx), elem, SOME result)
 											in
-												compute_loop (elem+1)
+												result
 											end
+									(* compute all entries in INTRS for the current datatype (given by 'index') *)
+									val size = (Array.length o valOf o assoc) (INTRS, index)
+									(* int -> unit *)
+									fun compute_loop elem =
+										if elem=size then
+											(* terminate *)
+											()
+										else (  (* elem < size *)
+											compute_array_entry index elem;
+											compute_loop (elem+1)
+										)
 									val _ = compute_loop 0
+									val _ = Array.modifyi ... ((valOf o assoc) (INTRS, index))
 									(* 'a Array.array -> 'a list *)
 									fun toList arr =
 										Array.foldr op:: [] arr
+									val _ = writeln (makestring INTRS)
 								in
-									SOME (Node (toList INTRS), model', args')
+									(* return the part of 'INTRS' that corresponds to the current datatype *)
+									SOME ((Node o (map valOf) o toList o valOf o assoc) (INTRS, index), model', args')
 								end
 						end
 					else
@@ -2111,7 +2169,8 @@
 				val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
 				val size_nat      = size_of_type i_nat
 				(* int -> int -> interpretation *)
-				fun plus m n = let
+				fun plus m n =
+					let
 						val element = (m+n)+1
 					in
 						if element > size_nat then
@@ -2138,7 +2197,8 @@
 				val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
 				val size_nat      = size_of_type i_nat
 				(* int -> int -> interpretation *)
-				fun minus m n = let
+				fun minus m n =
+					let
 						val element = Int.max (m-n, 0) + 1
 					in
 						Leaf ((replicate (element-1) False) @ True :: (replicate (size_nat - element) False))
@@ -2162,7 +2222,8 @@
 				val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
 				val size_nat      = size_of_type i_nat
 				(* nat -> nat -> interpretation *)
-				fun mult m n = let
+				fun mult m n =
+					let
 						val element = (m*n)+1
 					in
 						if element > size_nat then
@@ -2176,6 +2237,53 @@
 		| _ =>
 			NONE;
 
+	(* theory -> model -> arguments -> Term.term -> (interpretation * model * arguments) option *)
+
+	(* only an optimization: 'op @' could in principle be interpreted with    *)
+	(* interpreters available already (using its definition), but the code    *)
+	(* below is more efficient                                                *)
+
+	fun List_append_interpreter thy model args t =
+		case t of
+		  Const ("List.op @", Type ("fun", [Type ("List.list", [T]), Type ("fun", [Type ("List.list", [_]), Type ("List.list", [_])])])) =>
+			let
+				val (i_elem, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", T))
+				val size_elem      = size_of_type i_elem
+				val (i_list, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("List.list", [T])))
+				val size_list      = size_of_type i_list
+				(* power (a, b) computes a^b, for a>=0, b>=0 *)
+				(* int * int -> int *)
+				fun power (a, 0) = 1
+				  | power (a, 1) = a
+				  | power (a, b) = let val ab = power(a, b div 2) in ab * ab * power(a, b mod 2) end
+				(* log (a, b) computes floor(log_a(b)), i.e. the largest integer x s.t. a^x <= b, for a>=2, b>=1 *)
+				(* int * int -> int *)
+				fun log (a, b) =
+					let
+						fun logloop (ax, x) =
+							if ax > b then x-1 else logloop (a * ax, x+1)
+					in
+						logloop (1, 0)
+					end
+				(* nat -> nat -> interpretation *)
+				fun append m n =
+					let
+						(* The following formula depends on the order in which lists are *)
+						(* enumerated by the 'IDT_constructor_interpreter'.  It took me  *)
+						(* a while to come up with this formula.                         *)
+						val element = n + m * (if size_elem = 1 then 1 else power (size_elem, log (size_elem, n+1))) + 1
+					in
+						if element > size_list then
+							Leaf (replicate size_list False)
+						else
+							Leaf ((replicate (element-1) False) @ True :: (replicate (size_list - element) False))
+					end
+			in
+				SOME (Node (map (fn m => Node (map (append m) (0 upto size_list-1))) (0 upto size_list-1)), model, args)
+			end
+		| _ =>
+			NONE;
+
 
 (* ------------------------------------------------------------------------- *)
 (* PRINTERS                                                                  *)
@@ -2413,6 +2521,7 @@
 		 add_interpreter "Nat.op +" Nat_plus_interpreter,
 		 add_interpreter "Nat.op -" Nat_minus_interpreter,
 		 add_interpreter "Nat.op *" Nat_mult_interpreter,
+		 add_interpreter "List.op @" List_append_interpreter,
 		 add_printer "stlc" stlc_printer,
 		 add_printer "set"  set_printer,
 		 add_printer "IDT"  IDT_printer];
--- a/src/HOL/ex/Refute_Examples.thy	Mon Apr 18 15:54:23 2005 +0200
+++ b/src/HOL/ex/Refute_Examples.thy	Mon Apr 18 17:20:49 2005 +0200
@@ -161,7 +161,8 @@
 oops
 
 lemma "\<forall>a b c d e f. a=b | a=c | a=d | a=e | a=f | b=c | b=d | b=e | b=f | c=d | c=e | c=f | d=e | d=f | e=f"
-  refute
+  refute  -- {* quantification causes an expansion of the formula; the
+                previous version with free variables is refuted much faster *}
 oops
 
 text {* "Every reflexive and symmetric relation is transitive." *}
@@ -451,6 +452,8 @@
   apply (auto simp add: someI)
 done
 
+(******************************************************************************)
+
 subsection {* Subtypes (typedef), typedecl *}
 
 text {* A completely unspecified non-empty subset of @{typ "'a"}: *}
@@ -471,16 +474,16 @@
   refute
 oops
 
+(******************************************************************************)
+
 subsection {* Inductive datatypes *}
 
-text {* This is necessary because with quick\_and\_dirty set, the datatype
-package does not generate certain axioms for recursion operators.  Without
-these axioms, refute may find spurious countermodels. *}
+text {* With quick\_and\_dirty set, the datatype package does not generate
+  certain axioms for recursion operators.  Without these axioms, refute may
+  find spurious countermodels. *}
 
 ML {* reset quick_and_dirty; *}
 
-(*TODO*) ML {* set show_consts; set show_types; *}
-
 subsubsection {* unit *}
 
 lemma "P (x::unit)"
@@ -777,11 +780,11 @@
 oops
 
 lemma "P (aexp_bexp_rec_2 number ite equal x)"
-  (*TODO refute*)
+  refute
 oops
 
 lemma "P (case x of Equal a1 a2 \<Rightarrow> equal a1 a2)"
-  (*TODO: refute*)
+  refute
 oops
 
 subsubsection {* Other datatype examples *}
@@ -800,7 +803,11 @@
   refute
 oops
 
-lemma "P (Trie_rec tr x)"
+lemma "P (Trie_rec_1 a b c x)"
+  refute
+oops
+
+lemma "P (Trie_rec_2 a b c x)"
   refute
 oops
 
@@ -840,14 +847,106 @@
   refute
 oops
 
-subsubsection {* Examples involving certain functions *}
+text {* Taken from "Inductive datatypes in HOL", p.8: *}
+
+datatype ('a, 'b) T = C "'a \<Rightarrow> bool" | D "'b list"
+datatype 'c U = E "('c, 'c U) T"
+
+lemma "P (x::'c U)"
+  refute
+oops
+
+lemma "\<forall>x::'c U. P x"
+  refute
+oops
+
+lemma "P (E (C (\<lambda>a. True)))"
+  refute
+oops
+
+lemma "P (U_rec_1 e f g h i x)"
+  refute
+oops
+
+lemma "P (U_rec_2 e f g h i x)"
+  refute
+oops
+
+lemma "P (U_rec_3 e f g h i x)"
+  refute
+oops
+
+(******************************************************************************)
+
+subsection {* Records *}
+
+(*TODO: make use of pair types, rather than typedef, for record types*)
+
+record ('a, 'b) point =
+  xpos :: 'a
+  ypos :: 'b
+
+lemma "(x::('a, 'b) point) = y"
+  refute
+oops
+
+record ('a, 'b, 'c) extpoint = "('a, 'b) point" +
+  ext :: 'c
+
+lemma "(x::('a, 'b, 'c) extpoint) = y"
+  refute
+oops
+
+(******************************************************************************)
+
+subsection {* Inductively defined sets *}
+
+(*TODO: can we implement lfp/gfp more efficiently? *)
+
+consts
+  arbitrarySet :: "'a set"
+inductive arbitrarySet
+intros
+  "arbitrary : arbitrarySet"
+
+lemma "x : arbitrarySet"
+  (*TODO refute*)  -- {* unfortunately, this little example already takes too long *}
+oops
+
+consts
+  evenCard :: "'a set set"
+inductive evenCard
+intros
+  "{} : evenCard"
+  "\<lbrakk> S : evenCard; x \<notin> S; y \<notin> S; x \<noteq> y \<rbrakk> \<Longrightarrow> S \<union> {x, y} : evenCard"
+
+lemma "S : evenCard"
+  (*TODO refute*)  -- {* unfortunately, this little example already takes too long *}
+oops
+
+consts
+  even :: "nat set"
+  odd  :: "nat set"
+inductive even odd
+intros
+  "0 : even"
+  "n : even \<Longrightarrow> Suc n : odd"
+  "n : odd \<Longrightarrow> Suc n : even"
+
+lemma "n : odd"
+  (*TODO refute*)  -- {* unfortunately, this little example already takes too long *}
+oops
+
+(******************************************************************************)
+
+subsection {* Examples involving special functions *}
 
 lemma "card x = 0"
   refute
 oops
 
-lemma "P nat_rec"
-  refute
+lemma "finite x"
+  refute  -- {* no finite countermodel exists *}
 oops
 
 lemma "(x::nat) + y = 0"
@@ -874,18 +973,14 @@
   refute
 oops
 
-lemma "P (xs::'a list)"
-  refute ["List.list"=4, "'a"=2]
+lemma "a @ b = b @ a"
+  refute
 oops
 
-lemma "a @ b = b @ a"
-  (*TODO refute*)  -- {* unfortunately, this little example already takes too long *}
-oops
+(******************************************************************************)
 
 subsection {* Axiomatic type classes and overloading *}
 
-ML {* set show_consts; set show_types; set show_sorts; *}
-
 text {* A type class without axioms: *}
 
 axclass classA