extending predicate compiler and proof procedure to support tuples; testing predicate wirh HOL-MicroJava semantics
authorbulwahn
Wed, 23 Sep 2009 16:20:12 +0200
changeset 32665 8bf46a97ff79
parent 32664 5d4f32b02450
child 32666 fd96d5f49d59
extending predicate compiler and proof procedure to support tuples; testing predicate wirh HOL-MicroJava semantics
src/HOL/ex/Predicate_Compile_ex.thy
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/Predicate_Compile_ex.thy	Wed Sep 23 16:20:12 2009 +0200
+++ b/src/HOL/ex/Predicate_Compile_ex.thy	Wed Sep 23 16:20:12 2009 +0200
@@ -44,7 +44,7 @@
     "partition f [] [] []"
   | "f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) (x # ys) zs"
   | "\<not> f x \<Longrightarrow> partition f xs ys zs \<Longrightarrow> partition f (x # xs) ys (x # zs)"
-
+ML {* set Toplevel.debug *} 
 code_pred partition .
 
 thm partition.equation
--- a/src/HOL/ex/predicate_compile.ML	Wed Sep 23 16:20:12 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Wed Sep 23 16:20:12 2009 +0200
@@ -672,7 +672,7 @@
   let
     val Ts = binder_types T
     val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
-    val paramTs' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss paramTs 
+    val paramTs' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss paramTs
   in
     (paramTs' @ inargTs) ---> (mk_predT compfuns (mk_tupleT outargTs))
   end;
@@ -1335,14 +1335,23 @@
 
 fun compile_pred compfuns mk_fun_of use_size thy all_vs param_vs s T mode moded_cls =
   let
-    val (Ts1, (Us1, Us2)) = split_modeT mode (binder_types T)
-    val funT_of = if use_size then sizelim_funT_of else funT_of 
+	  val (Ts1, Ts2) = chop (length (fst mode)) (binder_types T)
+    val (Us1, Us2) = split_smodeT (snd mode) Ts2
+    val funT_of = if use_size then sizelim_funT_of else funT_of
     val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
-    val xnames = Name.variant_list (all_vs @ param_vs)
-      (map (fn (i, NONE) => "x" ^ string_of_int i | (i, SOME s) => error "pair mode") (snd mode));
-    val size_name = Name.variant (all_vs @ param_vs @ xnames) "size"
-    (* termify code: val xs = map2 (fn s => fn T => Free (s, termifyT T)) xnames Us1; *)
-    val xs = map2 (fn s => fn T => Free (s, T)) xnames Us1;
+    val size_name = Name.variant (all_vs @ param_vs) "size"
+  	fun mk_input_term (i, NONE) =
+		    [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
+		  | mk_input_term (i, SOME pis) = case HOLogic.strip_tupleT (nth Ts2 (i - 1)) of
+						   [] => error "strange unit input"
+					   | [T] => [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
+						 | Ts => let
+							 val vnames = Name.variant_list (all_vs @ param_vs)
+								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
+									pis)
+						 in if null pis then []
+						   else [HOLogic.mk_tuple (map Free (vnames ~~ map (fn j => nth Ts (j - 1)) pis))] end
+		val in_ts = maps mk_input_term (snd mode)
     val params = map2 (fn s => fn T => Free (s, T)) param_vs Ts1'
     val size = Free (size_name, @{typ "code_numeral"})
     val decr_size =
@@ -1353,7 +1362,7 @@
         NONE
     val cl_ts =
       map (compile_clause compfuns decr_size (fn out_ts => mk_single compfuns (mk_tuple out_ts))
-        thy all_vs param_vs mode (mk_tuple xs)) moded_cls;
+        thy all_vs param_vs mode (mk_tuple in_ts)) moded_cls;
     val t = foldr1 (mk_sup compfuns) cl_ts
     val T' = mk_predT compfuns (mk_tupleT Us2)
     val size_t = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
@@ -1361,16 +1370,17 @@
       $ mk_bot compfuns (dest_predT compfuns T') $ t
     val fun_const = mk_fun_of compfuns thy (s, T) mode
     val eq = if use_size then
-      (list_comb (fun_const, params @ xs @ [size]), size_t)
+      (list_comb (fun_const, params @ in_ts @ [size]), size_t)
     else
-      (list_comb (fun_const, params @ xs), t)
+      (list_comb (fun_const, params @ in_ts), t)
   in
     HOLogic.mk_Trueprop (HOLogic.mk_eq eq)
   end;
   
 (* special setup for simpset *)                  
-val HOL_basic_ss' = HOL_basic_ss addsimps @{thms "HOL.simp_thms"} setSolver 
-  (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
+val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms "HOL.simp_thms"} @ [@{thm Pair_eq}])
+  setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
+	setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
 
 (* Definition of executable functions and their intro and elim rules *)
 
@@ -1381,35 +1391,93 @@
         (ks @ [SOME k]))) arities));
 
 fun mk_Eval_of ((x, T), NONE) names = (x, names)
-  | mk_Eval_of ((x, T), SOME mode) names = let
-  val Ts = binder_types T
-  val argnames = Name.variant_list names
+  | mk_Eval_of ((x, T), SOME mode) names =
+	let
+    val Ts = binder_types T
+    (*val argnames = Name.variant_list names
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-  val args = map Free (argnames ~~ Ts)
-  val (inargs, outargs) = split_smode mode args
-  val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs), mk_tuple outargs)
-  val t = fold_rev lambda args r 
-in
-  (t, argnames @ names)
-end;
+    val args = map Free (argnames ~~ Ts)
+    val (inargs, outargs) = split_smode mode args*)
+		fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
+			| mk_split_lambda [x] t = lambda x t
+			| mk_split_lambda xs t =
+			let
+				fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
+					| mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
+			in
+				mk_split_lambda' xs t
+			end;
+  	fun mk_arg (i, T) =
+		  let
+	  	  val vname = Name.variant names ("x" ^ string_of_int i)
+		    val default = Free (vname, T)
+		  in 
+		    case AList.lookup (op =) mode i of
+		      NONE => (([], [default]), [default])
+			  | SOME NONE => (([default], []), [default])
+			  | SOME (SOME pis) =>
+				  case HOLogic.strip_tupleT T of
+						[] => error "pair mode but unit tuple" (*(([default], []), [default])*)
+					| [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
+					| Ts =>
+					  let
+							val vnames = Name.variant_list names
+								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
+									(1 upto length Ts))
+							val args = map Free (vnames ~~ Ts)
+							fun split_args (i, arg) (ins, outs) =
+							  if member (op =) pis i then
+							    (arg::ins, outs)
+								else
+								  (ins, arg::outs)
+							val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
+							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
+						in ((tuple inargs, tuple outargs), args) end
+			end
+		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
+    val (inargs, outargs) = pairself flat (split_list inoutargs)
+		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs), mk_tuple outargs)
+    val t = fold_rev mk_split_lambda args r
+  in
+    (t, names)
+  end;
 
 fun create_intro_elim_rule (mode as (iss, is)) defthm mode_id funT pred thy =
 let
   val Ts = binder_types (fastype_of pred)
   val funtrm = Const (mode_id, funT)
-  val argnames = Name.variant_list []
-        (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   val (Ts1, Ts2) = chop (length iss) Ts;
   val Ts1' = map2 (fn NONE => I | SOME is => funT_of (PredicateCompFuns.compfuns) ([], is)) iss Ts1
-  val args = map Free (argnames ~~ (Ts1' @ Ts2))
-  val (params, ioargs) = chop (length iss) args
-  val (inargs, outargs) = split_smode is ioargs
-  val param_names = Name.variant_list argnames
+	val param_names = Name.variant_list []
+    (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1)));
+  val params = map Free (param_names ~~ Ts1')
+	fun mk_args (i, T) argnames =
+    let
+		  val vname = Name.variant (param_names @ argnames) ("x" ^ string_of_int (length Ts1' + i))
+		  val default = (Free (vname, T), vname :: argnames)
+	  in
+  	  case AList.lookup (op =) is i of
+						 NONE => default
+					 | SOME NONE => default
+        	 | SOME (SOME pis) =>
+					   case HOLogic.strip_tupleT T of
+						   [] => default
+					   | [_] => default
+						 | Ts => 
+						let
+							val vnames = Name.variant_list (param_names @ argnames)
+								(map (fn j => "x" ^ string_of_int (length Ts1' + i) ^ "p" ^ string_of_int j)
+									(1 upto (length Ts)))
+						 in (HOLogic.mk_tuple (map Free (vnames ~~ Ts)), vnames  @ argnames) end
+		end
+	val (args, argnames) = fold_map mk_args (1 upto (length Ts2) ~~ Ts2) []
+  val (inargs, outargs) = split_smode is args
+  val param_names' = Name.variant_list (param_names @ argnames)
     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
-  val param_vs = map Free (param_names ~~ Ts1)
+  val param_vs = map Free (param_names' ~~ Ts1)
   val (params', names) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ iss) []
-  val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ ioargs))
-  val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ ioargs))
+  val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
+  val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
   val funargs = params @ inargs
   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
@@ -1418,13 +1486,16 @@
                    mk_tuple outargs))
   val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
   val simprules = [defthm, @{thm eval_pred},
-                   @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]
+	  @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}, @{thm pair_collapse}]
   val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1
-  val introthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ ["y"]) [] introtrm (fn {...} => unfolddef_tac)
+  val introthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ param_names' @ ["y"]) [] introtrm (fn {...} => unfolddef_tac)
   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
   val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predpropE, P)], P)
-  val elimthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ ["y", "P"]) [] elimtrm (fn {...} => unfolddef_tac)
-in 
+  val elimthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ param_names' @ ["y", "P"]) [] elimtrm (fn {...} => unfolddef_tac)
+	val _ = Output.tracing (Display.string_of_thm_global thy elimthm)
+	val _ = Output.tracing (Display.string_of_thm_global thy introthm)
+
+in
   (introthm, elimthm)
 end;
 
@@ -1439,7 +1510,33 @@
     (Sign.full_bname thy (prefix ^ (Long_Name.base_name name))) ^
       (if HOmode = "" then "_" else "_for_" ^ HOmode ^ "_yields_") ^ (string_of_mode (snd mode))
   end;
-  
+
+fun split_tupleT is T =
+	let
+		fun split_tuple' _ _ [] = ([], [])
+			| split_tuple' is i (T::Ts) =
+			(if i mem is then apfst else apsnd) (cons T)
+				(split_tuple' is (i+1) Ts)
+	in
+	  split_tuple' is 1 (HOLogic.strip_tupleT T)
+  end
+	
+fun mk_arg xin xout pis T =
+  let
+	  val n = length (HOLogic.strip_tupleT T)
+		val ni = length pis
+	  fun mk_proj i j t =
+		  (if i = j then I else HOLogic.mk_fst)
+			  (funpow (i - 1) HOLogic.mk_snd t)
+	  fun mk_arg' i (si, so) = if i mem pis then
+		    (mk_proj si ni xin, (si+1, so))
+		  else
+			  (mk_proj so (n - ni) xout, (si, so+1))
+	  val (args, _) = fold_map mk_arg' (1 upto n) (1, 1)
+	in
+	  HOLogic.mk_tuple args
+	end
+
 fun create_definitions preds (name, modes) thy =
   let
     val compfuns = PredicateCompFuns.compfuns
@@ -1454,10 +1551,50 @@
       val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (mk_tupleT Us2))
       val names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-      val xs = map Free (names ~~ (Ts1' @ Ts2))
+			(* old *)
+			(*
+		  val xs = map Free (names ~~ (Ts1' @ Ts2))
       val (xparams, xargs) = chop (length iss) xs
       val (xins, xouts) = split_smode is xargs
-      val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names
+			*)
+			(* new *)
+			val param_names = Name.variant_list []
+			  (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1')))
+		  val xparams = map Free (param_names ~~ Ts1')
+      fun mk_vars (i, T) names =
+			  let
+				  val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
+				in
+					case AList.lookup (op =) is i of
+						 NONE => ((([], [Free (vname, T)]), Free (vname, T)), vname :: names)
+					 | SOME NONE => ((([Free (vname, T)], []), Free (vname, T)), vname :: names)
+        	 | SOME (SOME pis) =>
+					   let
+						   val (Tins, Touts) = split_tupleT pis T
+							 val name_in = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "in")
+							 val name_out = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "out")
+						   val xin = Free (name_in, HOLogic.mk_tupleT Tins)
+							 val xout = Free (name_out, HOLogic.mk_tupleT Touts)
+							 val xarg = mk_arg xin xout pis T
+						 in (((if null Tins then [] else [xin], if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
+						(* HOLogic.strip_tupleT T of
+						[] => 
+							in (Free (vname, T), vname :: names) end
+					| [_] => let val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
+							in (Free (vname, T), vname :: names) end
+					| Ts =>
+						let
+							val vnames = Name.variant_list names
+								(map (fn j => "x" ^ string_of_int (length Ts1' + i) ^ "p" ^ string_of_int j)
+									(1 upto (length Ts)))
+						 in (HOLogic.mk_tuple (map Free (vnames ~~ Ts)), vnames @ names) end *)
+				end
+   	  val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
+      val (xinout, xargs) = split_list xinoutargs
+			val (xins, xouts) = pairself flat (split_list xinout)
+			(*val (xins, xouts) = split_smode is xargs*)
+			val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names
+			val _ = Output.tracing ("xargs:" ^ commas (map (Syntax.string_of_term_global thy) xargs))
       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
         | mk_split_lambda [x] t = lambda x t
         | mk_split_lambda xs t =
@@ -1471,12 +1608,15 @@
         (list_comb (Const (name, T), xparams' @ xargs)))
       val lhs = list_comb (Const (mode_cname, funT), xparams @ xins)
       val def = Logic.mk_equals (lhs, predterm)
+			val _ = Output.tracing ("def:" ^ (Syntax.string_of_term_global thy def))
       val ([definition], thy') = thy |>
         Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
         PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
       val (intro, elim) =
         create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
-      in thy' |> add_predfun name mode (mode_cname, definition, intro, elim)
+			val _ = Output.tracing (Display.string_of_thm_global thy' definition)
+      in thy'
+			  |> add_predfun name mode (mode_cname, definition, intro, elim)
         |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
         |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim)  |> snd
         |> Theory.checkpoint
@@ -1544,8 +1684,9 @@
     val (params, _) = chop (length ms) args
     val f_tac = case f of
       Const (name, T) => simp_tac (HOL_basic_ss addsimps 
-         (@{thm eval_pred}::(predfun_definition_of thy name mode)::
-         @{thm "Product_Type.split_conv"}::[])) 1
+         ([@{thm eval_pred}, (predfun_definition_of thy name mode),
+         @{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
+				 @{thm "snd_conv"}, @{thm pair_collapse}, @{thm "Product_Type.split_conv"}])) 1
     | Free _ => TRY (rtac @{thm refl} 1)
     | Abs _ => error "prove_param: No valid parameter term"
   in
@@ -1578,7 +1719,13 @@
         THEN (EVERY (map (prove_param thy) (ms ~~ args1)))
         THEN (REPEAT_DETERM (atac 1))
       end
-  | _ => rtac @{thm bindI} 1 THEN atac 1
+  | _ => rtac @{thm bindI} 1
+	  THEN asm_full_simp_tac
+		  (HOL_basic_ss' addsimps [@{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
+				 @{thm "snd_conv"}, @{thm pair_collapse}]) 1
+	  THEN (atac 1)
+	  THEN print_tac "after prove parameter call"
+		
 
 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; 
 
@@ -1595,7 +1742,7 @@
 (* replace TRY by determining if it necessary - are there equations when calling compile match? *)
 in
    (* make this simpset better! *)
-  asm_simp_tac (HOL_basic_ss' addsimps simprules) 1
+  asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
   THEN print_tac "after prove_match:"
   THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
          THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss 1))))
@@ -1629,7 +1776,9 @@
     val (in_ts, clause_out_ts) = split_smode is ts;
     fun prove_prems out_ts [] =
       (prove_match thy out_ts)
-      THEN asm_simp_tac HOL_basic_ss' 1
+			THEN print_tac "before simplifying assumptions"
+      THEN asm_full_simp_tac HOL_basic_ss' 1
+			THEN print_tac "before single intro rule"
       THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
     | prove_prems out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) =
       let
@@ -1694,6 +1843,7 @@
     val pred_case_rule = the_elim_of thy pred
   in
     REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
+		THEN print_tac "before applying elim rule"
     THEN etac (predfun_elim_of thy pred mode) 1
     THEN etac pred_case_rule 1
     THEN (EVERY (map
@@ -1805,7 +1955,8 @@
       (* How to handle equality correctly? *)
       THEN (print_tac "state before assumption matching")
       THEN (REPEAT (atac 1 ORELSE 
-         (CHANGED (asm_full_simp_tac HOL_basic_ss' 1)
+         (CHANGED (asm_full_simp_tac (HOL_basic_ss' addsimps
+					 [@{thm split_eta}, @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}, @{thm pair_collapse}]) 1)
           THEN print_tac "state after simp_tac:"))))
     | prove_prems2 out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) =
       let
@@ -1879,6 +2030,7 @@
       (if !do_proofs then
         (fn _ =>
         rtac @{thm pred_iffI} 1
+				THEN print_tac "after pred_iffI"
         THEN prove_one_direction thy clauses preds modes pred mode moded_clauses
         THEN print_tac "proved one direction"
         THEN prove_other_direction thy modes pred mode moded_clauses
@@ -2245,4 +2397,3 @@
 end;
 
 end;
-