--- a/src/HOL/Tools/refute.ML Mon Apr 18 15:54:23 2005 +0200
+++ b/src/HOL/Tools/refute.ML Mon Apr 18 17:20:49 2005 +0200
@@ -192,7 +192,7 @@
parameters = Symtab.merge (op=) (pa1, pa2)};
fun print sg {interpreters, printers, parameters} =
Pretty.writeln (Pretty.chunks
- [Pretty.strs ("default parameters:" :: List.concat (map (fn (name,value) => [name, "=", value]) (Symtab.dest parameters))),
+ [Pretty.strs ("default parameters:" :: List.concat (map (fn (name, value) => [name, "=", value]) (Symtab.dest parameters))),
Pretty.strs ("interpreters:" :: map fst interpreters),
Pretty.strs ("printers:" :: map fst printers)]);
end;
@@ -624,6 +624,7 @@
| Const ("op +", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
| Const ("op -", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
| Const ("op *", T as Type ("fun", [Type ("nat", []), Type ("fun", [Type ("nat", []), Type ("nat", [])])])) => collect_type_axioms (axs, T)
+ | Const ("List.op @", T) => collect_type_axioms (axs, T)
(* simply-typed lambda calculus *)
| Const (s, T) =>
let
@@ -896,8 +897,8 @@
(* interpretation -> prop_formula *)
- fun toTrue (Leaf [fm,_]) = fm
- | toTrue _ = raise REFUTE ("toTrue", "interpretation does not denote a Boolean value");
+ fun toTrue (Leaf [fm, _]) = fm
+ | toTrue _ = raise REFUTE ("toTrue", "interpretation does not denote a Boolean value");
(* ------------------------------------------------------------------------- *)
(* toFalse: converts the interpretation of a Boolean value to a *)
@@ -907,8 +908,8 @@
(* interpretation -> prop_formula *)
- fun toFalse (Leaf [_,fm]) = fm
- | toFalse _ = raise REFUTE ("toFalse", "interpretation does not denote a Boolean value");
+ fun toFalse (Leaf [_, fm]) = fm
+ | toFalse _ = raise REFUTE ("toFalse", "interpretation does not denote a Boolean value");
(* ------------------------------------------------------------------------- *)
(* find_model: repeatedly calls 'interpret' with appropriate parameters, *)
@@ -1607,7 +1608,7 @@
in
SOME (intr, (typs, (t, intr)::terms), args')
end
- | Var ((x,i), Type ("set", [T])) =>
+ | Var ((x, i), Type ("set", [T])) =>
let
val (intr, _, args') = interpret thy (typs, []) args (Var ((x,i), T --> HOLogic.boolT))
in
@@ -1896,27 +1897,33 @@
result (* just keep 'result' *)
| NONE =>
if s mem (#rec_names info) then
- (* okay, we do have a recursion operator of the datatype given by 'info' *)
+ (* we do have a recursion operator of the datatype given by 'info', *)
+ (* or of a mutually recursive datatype *)
let
- val index = #index info
- val descr = #descr info
- val (_, dtyps, constrs) = (valOf o assoc) (descr, index)
- (* the total number of constructors, including those of different *)
- (* (mutually recursive) datatypes within the same descriptor 'descr' *)
- val constrs_count = sum (map (fn (_, (_, _, cs)) => length cs) descr)
- val params_count = length params
+ val index = #index info
+ val descr = #descr info
+ val (dtname, dtyps, _) = (valOf o assoc) (descr, index)
+ (* number of all constructors, including those of different *)
+ (* (mutually recursive) datatypes within the same descriptor 'descr' *)
+ val mconstrs_count = sum (map (fn (_, (_, _, cs)) => length cs) descr)
+ val params_count = length params
+ (* the type of a recursion operator: [T1, ..., Tn, IDT] ---> Tresult *)
+ val IDT = List.nth (binder_types T, mconstrs_count)
in
- if constrs_count < params_count then
+ if (fst o dest_Type) IDT <> dtname then
+ (* recursion operator of a mutually recursive datatype *)
+ NONE
+ else if mconstrs_count < params_count then
(* too many actual parameters; for now we'll use the *)
(* 'stlc_interpreter' to strip off one application *)
NONE
- else if constrs_count > params_count then
+ else if mconstrs_count > params_count then
(* too few actual parameters; we use eta expansion *)
(* Note that the resulting expansion of lambda abstractions *)
(* by the 'stlc_interpreter' may be rather slow (depending on *)
(* the argument types and the size of the IDT, of course). *)
- SOME (interpret thy model args (eta_expand t (constrs_count - params_count)))
- else (* constrs_count = params_count *)
+ SOME (interpret thy model args (eta_expand t (mconstrs_count - params_count)))
+ else (* mconstrs_count = params_count *)
let
(* interpret each parameter separately *)
val ((model', args'), p_intrs) = foldl_map (fn ((m, a), p) =>
@@ -1925,23 +1932,37 @@
in
((m', a'), i)
end) ((model, args), params)
- val (typs, terms) = model'
- (* the type of a recursion operator: [T1, ..., Tn, IDT] ---> Tresult *)
- val IDT = List.nth (binder_types T, constrs_count)
+ val (typs, _) = model'
val typ_assoc = dtyps ~~ (snd o dest_Type) IDT
- (* interpret each constructor of the datatype *)
- (* TODO: we probably need to interpret every constructor in the descriptor, *)
- (* possibly for typs' instead of typs *)
- val c_intrs = map (#1 o interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True})
- (map (fn (cname, cargs) => Const (cname, map (typ_of_dtyp descr typ_assoc) cargs ---> IDT)) constrs)
+ (* interpret each constructor in the descriptor (including *)
+ (* those of mutually recursive datatypes) *)
+ (* (int * interpretation list) list *)
+ val mc_intrs = map (fn (idx, (_, _, cs)) =>
+ let
+ val c_return_typ = typ_of_dtyp descr typ_assoc (DatatypeAux.DtRec idx)
+ in
+ (idx, map (fn (cname, cargs) =>
+ (#1 o interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True})
+ (Const (cname, map (typ_of_dtyp descr typ_assoc) cargs ---> c_return_typ))) cs)
+ end) descr
+ val _ = writeln (makestring index)
+ val _ = writeln (makestring descr)
+ val _ = writeln (makestring mc_intrs)
(* the recursion operator is a function that maps every element of *)
- (* the inductive datatype to an element of the result type *)
- val (i, _, _) = interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", IDT))
- val size = size_of_type i
- val INTRS = Array.array (size, Leaf []) (* the initial value 'Leaf []' does not matter; it will be overwritten *)
- (* takes an interpretation, and if some leaf of this interpretation *)
- (* is the 'elem'-th element of the datatype, the indices of the *)
- (* arguments leading to this leaf are returned *)
+ (* the inductive datatype (and of mutually recursive types) to an *)
+ (* element of some result type *)
+ (* (int * interpretation option Array.array) list *)
+ val INTRS = map (fn (idx, _) =>
+ let
+ val T = typ_of_dtyp descr typ_assoc (DatatypeAux.DtRec idx)
+ val (i, _, _) = interpret thy (typs, []) {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", T))
+ val size = size_of_type i
+ in
+ (idx, Array.array (size, NONE))
+ end) descr
+ (* takes an interpretation, and if some leaf of this interpretation *)
+ (* is the 'elem'-th element of the type, the indices of the arguments *)
+ (* leading to this leaf are returned *)
(* interpretation -> int -> int list option *)
fun get_args (Leaf xs) elem =
if find_index_eq True xs = elem then
@@ -1962,27 +1983,46 @@
end
(* returns the index of the constructor and indices for its *)
(* arguments that generate the 'elem'-th element of the datatype *)
- (* int -> int * int list *)
- fun get_cargs elem =
+ (* given by 'idx' *)
+ (* int -> int -> int * int list *)
+ fun get_cargs idx elem =
let
(* int * interpretation list -> int * int list *)
fun get_cargs_rec (_, []) =
- raise REFUTE ("IDT_recursion_interpreter", "no matching constructor found for element " ^ string_of_int elem)
+ raise REFUTE ("IDT_recursion_interpreter", "no matching constructor found for element " ^ string_of_int elem ^ " in datatype " ^ Sign.string_of_typ (sign_of thy) IDT ^ " (datatype index " ^ string_of_int idx ^ ")")
| get_cargs_rec (n, x::xs) =
(case get_args x elem of
SOME args => (n, args)
| NONE => get_cargs_rec (n+1, xs))
in
- get_cargs_rec (0, c_intrs)
+ get_cargs_rec (0, (valOf o assoc) (mc_intrs, idx))
end
- (* int -> unit *)
- fun compute_loop elem =
- if elem=size then
- ()
- else (* elem < size *)
+ (* returns the number of constructors in datatypes that occur in *)
+ (* the descriptor 'descr' before the datatype given by 'idx' *)
+ fun get_coffset idx =
+ let
+ fun get_coffset_acc _ [] =
+ raise REFUTE ("IDT_recursion_interpreter", "index " ^ string_of_int idx ^ " not found in descriptor")
+ | get_coffset_acc sum ((i, (_, _, cs))::descr') =
+ if i=idx then
+ sum
+ else
+ get_coffset_acc (sum + length cs) descr'
+ in
+ get_coffset_acc 0 descr
+ end
+ (* computes one entry in INTRS, and recursively all entries needed for it, *)
+ (* where 'idx' gives the datatype and 'elem' the element of it *)
+ (* int -> int -> interpretation *)
+ fun compute_array_entry idx elem =
+ case Array.sub ((valOf o assoc) (INTRS, idx), elem) of
+ SOME result =>
+ (* simply return the previously computed result *)
+ result
+ | NONE =>
let
(* int * int list *)
- val (c, args) = get_cargs elem
+ val (c, args) = get_cargs idx elem
(* interpretation * int list -> interpretation *)
fun select_subtree (tr, []) =
tr (* return the whole tree *)
@@ -1991,22 +2031,40 @@
| select_subtree (Node tr, x::xs) =
select_subtree (List.nth (tr, x), xs)
(* select the correct subtree of the parameter corresponding to constructor 'c' *)
- val p_intr = select_subtree (List.nth (p_intrs, c), args)
- (* find the indices of recursive arguments *)
- val rec_args = map snd (List.filter (DatatypeAux.is_rec_type o fst) ((snd (List.nth (constrs, c))) ~~ args))
+ val p_intr = select_subtree (List.nth (p_intrs, get_coffset idx + c), args)
+ (* find the indices of the constructor's recursive arguments *)
+ val (_, _, constrs) = (valOf o assoc) (descr, idx)
+ val constr_args = (snd o List.nth) (constrs, c)
+ val rec_args = List.filter (DatatypeAux.is_rec_type o fst) (constr_args ~~ args)
+ val rec_args' = map (fn (dtyp, elem) => (DatatypeAux.dest_DtRec dtyp, elem)) rec_args
(* apply 'p_intr' to recursively computed results *)
- val rec_p_intr = Library.foldl (fn (i, n) => interpretation_apply (i, Array.sub (INTRS, n))) (p_intr, rec_args)
+ val result = foldl (fn ((idx, elem), intr) =>
+ interpretation_apply (intr, compute_array_entry idx elem)) p_intr rec_args'
(* update 'INTRS' *)
- val _ = Array.update (INTRS, elem, rec_p_intr)
+ val _ = Array.update ((valOf o assoc) (INTRS, idx), elem, SOME result)
in
- compute_loop (elem+1)
+ result
end
+ (* compute all entries in INTRS for the current datatype (given by 'index') *)
+ val size = (Array.length o valOf o assoc) (INTRS, index)
+ (* int -> unit *)
+ fun compute_loop elem =
+ if elem=size then
+ (* terminate *)
+ ()
+ else ( (* elem < size *)
+ compute_array_entry index elem;
+ compute_loop (elem+1)
+ )
val _ = compute_loop 0
+ val _ = Array.modifyi ... ((valOf o assoc) (INTRS, index))
(* 'a Array.array -> 'a list *)
fun toList arr =
Array.foldr op:: [] arr
+ val _ = writeln (makestring INTRS)
in
- SOME (Node (toList INTRS), model', args')
+ (* return the part of 'INTRS' that corresponds to the current datatype *)
+ SOME ((Node o (map valOf) o toList o valOf o assoc) (INTRS, index), model', args')
end
end
else
@@ -2111,7 +2169,8 @@
val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
val size_nat = size_of_type i_nat
(* int -> int -> interpretation *)
- fun plus m n = let
+ fun plus m n =
+ let
val element = (m+n)+1
in
if element > size_nat then
@@ -2138,7 +2197,8 @@
val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
val size_nat = size_of_type i_nat
(* int -> int -> interpretation *)
- fun minus m n = let
+ fun minus m n =
+ let
val element = Int.max (m-n, 0) + 1
in
Leaf ((replicate (element-1) False) @ True :: (replicate (size_nat - element) False))
@@ -2162,7 +2222,8 @@
val (i_nat, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("nat", [])))
val size_nat = size_of_type i_nat
(* nat -> nat -> interpretation *)
- fun mult m n = let
+ fun mult m n =
+ let
val element = (m*n)+1
in
if element > size_nat then
@@ -2176,6 +2237,53 @@
| _ =>
NONE;
+ (* theory -> model -> arguments -> Term.term -> (interpretation * model * arguments) option *)
+
+ (* only an optimization: 'op @' could in principle be interpreted with *)
+ (* interpreters available already (using its definition), but the code *)
+ (* below is more efficient *)
+
+ fun List_append_interpreter thy model args t =
+ case t of
+ Const ("List.op @", Type ("fun", [Type ("List.list", [T]), Type ("fun", [Type ("List.list", [_]), Type ("List.list", [_])])])) =>
+ let
+ val (i_elem, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", T))
+ val size_elem = size_of_type i_elem
+ val (i_list, _, _) = interpret thy model {maxvars=0, def_eq=false, next_idx=1, bounds=[], wellformed=True} (Free ("dummy", Type ("List.list", [T])))
+ val size_list = size_of_type i_list
+ (* power (a, b) computes a^b, for a>=0, b>=0 *)
+ (* int * int -> int *)
+ fun power (a, 0) = 1
+ | power (a, 1) = a
+ | power (a, b) = let val ab = power(a, b div 2) in ab * ab * power(a, b mod 2) end
+ (* log (a, b) computes floor(log_a(b)), i.e. the largest integer x s.t. a^x <= b, for a>=2, b>=1 *)
+ (* int * int -> int *)
+ fun log (a, b) =
+ let
+ fun logloop (ax, x) =
+ if ax > b then x-1 else logloop (a * ax, x+1)
+ in
+ logloop (1, 0)
+ end
+ (* nat -> nat -> interpretation *)
+ fun append m n =
+ let
+ (* The following formula depends on the order in which lists are *)
+ (* enumerated by the 'IDT_constructor_interpreter'. It took me *)
+ (* a while to come up with this formula. *)
+ val element = n + m * (if size_elem = 1 then 1 else power (size_elem, log (size_elem, n+1))) + 1
+ in
+ if element > size_list then
+ Leaf (replicate size_list False)
+ else
+ Leaf ((replicate (element-1) False) @ True :: (replicate (size_list - element) False))
+ end
+ in
+ SOME (Node (map (fn m => Node (map (append m) (0 upto size_list-1))) (0 upto size_list-1)), model, args)
+ end
+ | _ =>
+ NONE;
+
(* ------------------------------------------------------------------------- *)
(* PRINTERS *)
@@ -2413,6 +2521,7 @@
add_interpreter "Nat.op +" Nat_plus_interpreter,
add_interpreter "Nat.op -" Nat_minus_interpreter,
add_interpreter "Nat.op *" Nat_mult_interpreter,
+ add_interpreter "List.op @" List_append_interpreter,
add_printer "stlc" stlc_printer,
add_printer "set" set_printer,
add_printer "IDT" IDT_printer];