src/HOL/ex/predicate_compile.ML
changeset 32667 09546e654222
parent 32666 fd96d5f49d59
child 32668 b2de45007537
equal deleted inserted replaced
32666:fd96d5f49d59 32667:09546e654222
     1 (* Author: Lukas Bulwahn, TU Muenchen
       
     2 
       
     3 (Prototype of) A compiler from predicates specified by intro/elim rules
       
     4 to equations.
       
     5 *)
       
     6 
       
     7 signature PREDICATE_COMPILE =
       
     8 sig
       
     9   type smode = (int * int list option) list
       
    10   type mode = smode option list * smode
       
    11   datatype tmode = Mode of mode * smode * tmode option list;
       
    12   (*val add_equations_of: bool -> string list -> theory -> theory *)
       
    13   val register_predicate : (thm list * thm * int) -> theory -> theory
       
    14   val is_registered : theory -> string -> bool
       
    15  (* val fetch_pred_data : theory -> string -> (thm list * thm * int)  *)
       
    16   val predfun_intro_of: theory -> string -> mode -> thm
       
    17   val predfun_elim_of: theory -> string -> mode -> thm
       
    18   val strip_intro_concl: int -> term -> term * (term list * term list)
       
    19   val predfun_name_of: theory -> string -> mode -> string
       
    20   val all_preds_of : theory -> string list
       
    21   val modes_of: theory -> string -> mode list
       
    22   val string_of_mode : mode -> string
       
    23   val intros_of: theory -> string -> thm list
       
    24   val nparams_of: theory -> string -> int
       
    25   val add_intro: thm -> theory -> theory
       
    26   val set_elim: thm -> theory -> theory
       
    27   val setup: theory -> theory
       
    28   val code_pred: string -> Proof.context -> Proof.state
       
    29   val code_pred_cmd: string -> Proof.context -> Proof.state
       
    30   val print_stored_rules: theory -> unit
       
    31   val print_all_modes: theory -> unit
       
    32   val do_proofs: bool ref
       
    33   val mk_casesrule : Proof.context -> int -> thm list -> term
       
    34   val analyze_compr: theory -> term -> term
       
    35   val eval_ref: (unit -> term Predicate.pred) option ref
       
    36   val add_equations : string list -> theory -> theory
       
    37   val code_pred_intros_attrib : attribute
       
    38   (* used by Quickcheck_Generator *) 
       
    39   (*val funT_of : mode -> typ -> typ
       
    40   val mk_if_pred : term -> term
       
    41   val mk_Eval : term * term -> term*)
       
    42   val mk_tupleT : typ list -> typ
       
    43 (*  val mk_predT :  typ -> typ *)
       
    44   (* temporary for testing of the compilation *)
       
    45   datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term |
       
    46     GeneratorPrem of term list * term | Generator of (string * typ);
       
    47  (* val prepare_intrs: theory -> string list ->
       
    48     (string * typ) list * int * string list * string list * (string * mode list) list *
       
    49     (string * (term list * indprem list) list) list * (string * (int option list * int)) list*)
       
    50   datatype compilation_funs = CompilationFuns of {
       
    51     mk_predT : typ -> typ,
       
    52     dest_predT : typ -> typ,
       
    53     mk_bot : typ -> term,
       
    54     mk_single : term -> term,
       
    55     mk_bind : term * term -> term,
       
    56     mk_sup : term * term -> term,
       
    57     mk_if : term -> term,
       
    58     mk_not : term -> term,
       
    59     mk_map : typ -> typ -> term -> term -> term,
       
    60     lift_pred : term -> term
       
    61   };  
       
    62   type moded_clause = term list * (indprem * tmode) list
       
    63   type 'a pred_mode_table = (string * (mode * 'a) list) list
       
    64   val infer_modes : theory -> (string * mode list) list
       
    65     -> (string * mode list) list
       
    66     -> string list
       
    67     -> (string * (term list * indprem list) list) list
       
    68     -> (moded_clause list) pred_mode_table
       
    69   val infer_modes_with_generator : theory -> (string * mode list) list
       
    70     -> (string * mode list) list
       
    71     -> string list
       
    72     -> (string * (term list * indprem list) list) list
       
    73     -> (moded_clause list) pred_mode_table  
       
    74   (*val compile_preds : theory -> compilation_funs -> string list -> string list
       
    75     -> (string * typ) list -> (moded_clause list) pred_mode_table -> term pred_mode_table
       
    76   val rpred_create_definitions :(string * typ) list -> string * mode list
       
    77     -> theory -> theory 
       
    78   val split_smode : int list -> term list -> (term list * term list) *)
       
    79   val print_moded_clauses :
       
    80     theory -> (moded_clause list) pred_mode_table -> unit
       
    81   val print_compiled_terms : theory -> term pred_mode_table -> unit
       
    82   (*val rpred_prove_preds : theory -> term pred_mode_table -> thm pred_mode_table*)
       
    83   val rpred_compfuns : compilation_funs
       
    84   val dest_funT : typ -> typ * typ
       
    85  (* val depending_preds_of : theory -> thm list -> string list *)
       
    86   val add_quickcheck_equations : string list -> theory -> theory
       
    87   val add_sizelim_equations : string list -> theory -> theory
       
    88   val is_inductive_predicate : theory -> string -> bool
       
    89   val terms_vs : term list -> string list
       
    90   val subsets : int -> int -> int list list
       
    91   val check_mode_clause : bool -> theory -> string list ->
       
    92     (string * mode list) list -> (string * mode list) list -> mode -> (term list * indprem list)
       
    93       -> (term list * (indprem * tmode) list) option
       
    94   val string_of_moded_prem : theory -> (indprem * tmode) -> string
       
    95   val all_modes_of : theory -> (string * mode list) list
       
    96   val all_generator_modes_of : theory -> (string * mode list) list
       
    97   val compile_clause : compilation_funs -> term option -> (term list -> term) ->
       
    98     theory -> string list -> string list -> mode -> term -> moded_clause -> term
       
    99   val preprocess_intro : theory -> thm -> thm
       
   100   val is_constrt : theory -> term -> bool
       
   101   val is_predT : typ -> bool
       
   102   val guess_nparams : typ -> int
       
   103   val cprods_subset : 'a list list -> 'a list list
       
   104 end;
       
   105 
       
   106 structure Predicate_Compile : PREDICATE_COMPILE =
       
   107 struct
       
   108 
       
   109 (** auxiliary **)
       
   110 
       
   111 (* debug stuff *)
       
   112 
       
   113 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
       
   114 
       
   115 fun print_tac s = Seq.single; (*Tactical.print_tac s;*) (* (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); *)
       
   116 fun debug_tac msg = Seq.single; (* (fn st => (Output.tracing msg; Seq.single st)); *)
       
   117 
       
   118 val do_proofs = ref true;
       
   119 
       
   120 fun mycheat_tac thy i st =
       
   121   (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st
       
   122 
       
   123 fun remove_last_goal thy st =
       
   124   (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) (nprems_of st)) st
       
   125 
       
   126 (* reference to preprocessing of InductiveSet package *)
       
   127 
       
   128 val ind_set_codegen_preproc = Inductive_Set.codegen_preproc;
       
   129 
       
   130 (** fundamentals **)
       
   131 
       
   132 (* syntactic operations *)
       
   133 
       
   134 fun mk_eq (x, xs) =
       
   135   let fun mk_eqs _ [] = []
       
   136         | mk_eqs a (b::cs) =
       
   137             HOLogic.mk_eq (Free (a, fastype_of b), b) :: mk_eqs a cs
       
   138   in mk_eqs x xs end;
       
   139 
       
   140 fun mk_tupleT [] = HOLogic.unitT
       
   141   | mk_tupleT Ts = foldr1 HOLogic.mk_prodT Ts;
       
   142 
       
   143 fun dest_tupleT (Type (@{type_name Product_Type.unit}, [])) = []
       
   144   | dest_tupleT (Type (@{type_name "*"}, [T1, T2])) = T1 :: (dest_tupleT T2)
       
   145   | dest_tupleT t = [t]
       
   146 
       
   147 fun mk_tuple [] = HOLogic.unit
       
   148   | mk_tuple ts = foldr1 HOLogic.mk_prod ts;
       
   149 
       
   150 fun dest_tuple (Const (@{const_name Product_Type.Unity}, _)) = []
       
   151   | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
       
   152   | dest_tuple t = [t]
       
   153 
       
   154 fun mk_scomp (t, u) =
       
   155   let
       
   156     val T = fastype_of t
       
   157     val U = fastype_of u
       
   158     val [A] = binder_types T
       
   159     val D = body_type U 
       
   160   in 
       
   161     Const (@{const_name "scomp"}, T --> U --> A --> D) $ t $ u
       
   162   end;
       
   163 
       
   164 fun dest_funT (Type ("fun",[S, T])) = (S, T)
       
   165   | dest_funT T = raise TYPE ("dest_funT", [T], [])
       
   166  
       
   167 fun mk_fun_comp (t, u) =
       
   168   let
       
   169     val (_, B) = dest_funT (fastype_of t)
       
   170     val (C, A) = dest_funT (fastype_of u)
       
   171   in
       
   172     Const(@{const_name "Fun.comp"}, (A --> B) --> (C --> A) --> C --> B) $ t $ u
       
   173   end;
       
   174 
       
   175 fun dest_randomT (Type ("fun", [@{typ Random.seed},
       
   176   Type ("*", [Type ("*", [T, @{typ "unit => Code_Eval.term"}]) ,@{typ Random.seed}])])) = T
       
   177   | dest_randomT T = raise TYPE ("dest_randomT", [T], [])
       
   178 
       
   179 (* destruction of intro rules *)
       
   180 
       
   181 (* FIXME: look for other place where this functionality was used before *)
       
   182 fun strip_intro_concl nparams intro = let
       
   183   val _ $ u = Logic.strip_imp_concl intro
       
   184   val (pred, all_args) = strip_comb u
       
   185   val (params, args) = chop nparams all_args
       
   186 in (pred, (params, args)) end
       
   187 
       
   188 (** data structures **)
       
   189 
       
   190 type smode = (int * int list option) list;
       
   191 type mode = smode option list * smode;
       
   192 datatype tmode = Mode of mode * smode * tmode option list;
       
   193 
       
   194 fun gen_split_smode (mk_tuple, strip_tuple) smode ts =
       
   195   let
       
   196     fun split_tuple' _ _ [] = ([], [])
       
   197     | split_tuple' is i (t::ts) =
       
   198       (if i mem is then apfst else apsnd) (cons t)
       
   199         (split_tuple' is (i+1) ts)
       
   200     fun split_tuple is t = split_tuple' is 1 (strip_tuple t)
       
   201     fun split_smode' _ _ [] = ([], [])
       
   202       | split_smode' smode i (t::ts) =
       
   203         (if i mem (map fst smode) then
       
   204           case (the (AList.lookup (op =) smode i)) of
       
   205             NONE => apfst (cons t)
       
   206             | SOME is =>
       
   207               let
       
   208                 val (ts1, ts2) = split_tuple is t
       
   209                 fun cons_tuple ts = if null ts then I else cons (mk_tuple ts)
       
   210                 in (apfst (cons_tuple ts1)) o (apsnd (cons_tuple ts2)) end
       
   211           else apsnd (cons t))
       
   212         (split_smode' smode (i+1) ts)
       
   213   in split_smode' smode 1 ts end
       
   214 
       
   215 val split_smode = gen_split_smode (HOLogic.mk_tuple, HOLogic.strip_tuple)   
       
   216 val split_smodeT = gen_split_smode (HOLogic.mk_tupleT, HOLogic.strip_tupleT)
       
   217 
       
   218 fun gen_split_mode split_smode (iss, is) ts =
       
   219   let
       
   220     val (t1, t2) = chop (length iss) ts 
       
   221   in (t1, split_smode is t2) end
       
   222 
       
   223 val split_mode = gen_split_mode split_smode
       
   224 val split_modeT = gen_split_mode split_smodeT
       
   225 
       
   226 fun string_of_smode js =
       
   227     commas (map
       
   228       (fn (i, is) =>
       
   229         string_of_int i ^ (case is of NONE => ""
       
   230     | SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
       
   231 
       
   232 fun string_of_mode (iss, is) = space_implode " -> " (map
       
   233   (fn NONE => "X"
       
   234     | SOME js => enclose "[" "]" (string_of_smode js))
       
   235        (iss @ [SOME is]));
       
   236 
       
   237 fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
       
   238   "predmode: " ^ (string_of_mode predmode) ^ 
       
   239   (if null param_modes then "" else
       
   240     "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
       
   241     
       
   242 datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term |
       
   243   GeneratorPrem of term list * term | Generator of (string * typ);
       
   244 
       
   245 type moded_clause = term list * (indprem * tmode) list
       
   246 type 'a pred_mode_table = (string * (mode * 'a) list) list
       
   247 
       
   248 datatype predfun_data = PredfunData of {
       
   249   name : string,
       
   250   definition : thm,
       
   251   intro : thm,
       
   252   elim : thm
       
   253 };
       
   254 
       
   255 fun rep_predfun_data (PredfunData data) = data;
       
   256 fun mk_predfun_data (name, definition, intro, elim) =
       
   257   PredfunData {name = name, definition = definition, intro = intro, elim = elim}
       
   258 
       
   259 datatype function_data = FunctionData of {
       
   260   name : string,
       
   261   equation : thm option (* is not used at all? *)
       
   262 };
       
   263 
       
   264 fun rep_function_data (FunctionData data) = data;
       
   265 fun mk_function_data (name, equation) =
       
   266   FunctionData {name = name, equation = equation}
       
   267 
       
   268 datatype pred_data = PredData of {
       
   269   intros : thm list,
       
   270   elim : thm option,
       
   271   nparams : int,
       
   272   functions : (mode * predfun_data) list,
       
   273   generators : (mode * function_data) list,
       
   274   sizelim_functions : (mode * function_data) list 
       
   275 };
       
   276 
       
   277 fun rep_pred_data (PredData data) = data;
       
   278 fun mk_pred_data ((intros, elim, nparams), (functions, generators, sizelim_functions)) =
       
   279   PredData {intros = intros, elim = elim, nparams = nparams,
       
   280     functions = functions, generators = generators, sizelim_functions = sizelim_functions}
       
   281 fun map_pred_data f (PredData {intros, elim, nparams, functions, generators, sizelim_functions}) =
       
   282   mk_pred_data (f ((intros, elim, nparams), (functions, generators, sizelim_functions)))
       
   283   
       
   284 fun eq_option eq (NONE, NONE) = true
       
   285   | eq_option eq (SOME x, SOME y) = eq (x, y)
       
   286   | eq_option eq _ = false
       
   287   
       
   288 fun eq_pred_data (PredData d1, PredData d2) = 
       
   289   eq_list (Thm.eq_thm) (#intros d1, #intros d2) andalso
       
   290   eq_option (Thm.eq_thm) (#elim d1, #elim d2) andalso
       
   291   #nparams d1 = #nparams d2
       
   292   
       
   293 structure PredData = TheoryDataFun
       
   294 (
       
   295   type T = pred_data Graph.T;
       
   296   val empty = Graph.empty;
       
   297   val copy = I;
       
   298   val extend = I;
       
   299   fun merge _ = Graph.merge eq_pred_data;
       
   300 );
       
   301 
       
   302 (* queries *)
       
   303 
       
   304 fun lookup_pred_data thy name =
       
   305   Option.map rep_pred_data (try (Graph.get_node (PredData.get thy)) name)
       
   306 
       
   307 fun the_pred_data thy name = case lookup_pred_data thy name
       
   308  of NONE => error ("No such predicate " ^ quote name)  
       
   309   | SOME data => data;
       
   310 
       
   311 val is_registered = is_some oo lookup_pred_data 
       
   312 
       
   313 val all_preds_of = Graph.keys o PredData.get
       
   314 
       
   315 fun intros_of thy = map (Thm.transfer thy) o #intros o the_pred_data thy
       
   316 
       
   317 fun the_elim_of thy name = case #elim (the_pred_data thy name)
       
   318  of NONE => error ("No elimination rule for predicate " ^ quote name)
       
   319   | SOME thm => Thm.transfer thy thm 
       
   320   
       
   321 val has_elim = is_some o #elim oo the_pred_data;
       
   322 
       
   323 val nparams_of = #nparams oo the_pred_data
       
   324 
       
   325 val modes_of = (map fst) o #functions oo the_pred_data
       
   326 
       
   327 fun all_modes_of thy = map (fn name => (name, modes_of thy name)) (all_preds_of thy) 
       
   328 
       
   329 val is_compiled = not o null o #functions oo the_pred_data
       
   330 
       
   331 fun lookup_predfun_data thy name mode =
       
   332   Option.map rep_predfun_data (AList.lookup (op =)
       
   333   (#functions (the_pred_data thy name)) mode)
       
   334 
       
   335 fun the_predfun_data thy name mode = case lookup_predfun_data thy name mode
       
   336   of NONE => error ("No function defined for mode " ^ string_of_mode mode ^ " of predicate " ^ name)
       
   337    | SOME data => data;
       
   338 
       
   339 val predfun_name_of = #name ooo the_predfun_data
       
   340 
       
   341 val predfun_definition_of = #definition ooo the_predfun_data
       
   342 
       
   343 val predfun_intro_of = #intro ooo the_predfun_data
       
   344 
       
   345 val predfun_elim_of = #elim ooo the_predfun_data
       
   346 
       
   347 fun lookup_generator_data thy name mode = 
       
   348   Option.map rep_function_data (AList.lookup (op =)
       
   349   (#generators (the_pred_data thy name)) mode)
       
   350   
       
   351 fun the_generator_data thy name mode = case lookup_generator_data thy name mode
       
   352   of NONE => error ("No generator defined for mode " ^ string_of_mode mode ^ " of predicate " ^ name)
       
   353    | SOME data => data
       
   354 
       
   355 val generator_name_of = #name ooo the_generator_data
       
   356 
       
   357 val generator_modes_of = (map fst) o #generators oo the_pred_data
       
   358 
       
   359 fun all_generator_modes_of thy =
       
   360   map (fn name => (name, generator_modes_of thy name)) (all_preds_of thy) 
       
   361 
       
   362 fun lookup_sizelim_function_data thy name mode =
       
   363   Option.map rep_function_data (AList.lookup (op =)
       
   364   (#sizelim_functions (the_pred_data thy name)) mode)
       
   365 
       
   366 fun the_sizelim_function_data thy name mode = case lookup_sizelim_function_data thy name mode
       
   367   of NONE => error ("No size-limited function defined for mode " ^ string_of_mode mode
       
   368     ^ " of predicate " ^ name)
       
   369    | SOME data => data
       
   370 
       
   371 val sizelim_function_name_of = #name ooo the_sizelim_function_data
       
   372 
       
   373 (*val generator_modes_of = (map fst) o #generators oo the_pred_data*)
       
   374      
       
   375 (* diagnostic display functions *)
       
   376 
       
   377 fun print_modes modes = Output.tracing ("Inferred modes:\n" ^
       
   378   cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
       
   379     string_of_mode ms)) modes));
       
   380 
       
   381 fun print_pred_mode_table string_of_entry thy pred_mode_table =
       
   382   let
       
   383     fun print_mode pred (mode, entry) =  "mode : " ^ (string_of_mode mode)
       
   384       ^ (string_of_entry pred mode entry)  
       
   385     fun print_pred (pred, modes) =
       
   386       "predicate " ^ pred ^ ": " ^ cat_lines (map (print_mode pred) modes)
       
   387     val _ = Output.tracing (cat_lines (map print_pred pred_mode_table))
       
   388   in () end;
       
   389 
       
   390 fun string_of_moded_prem thy (Prem (ts, p), tmode) =
       
   391     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
       
   392     "(" ^ (string_of_tmode tmode) ^ ")"
       
   393   | string_of_moded_prem thy (GeneratorPrem (ts, p), Mode (predmode, is, _)) =
       
   394     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
       
   395     "(generator_mode: " ^ (string_of_mode predmode) ^ ")"
       
   396   | string_of_moded_prem thy (Generator (v, T), _) =
       
   397     "Generator for " ^ v ^ " of Type " ^ (Syntax.string_of_typ_global thy T)
       
   398   | string_of_moded_prem thy (Negprem (ts, p), Mode (_, is, _)) =
       
   399     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
       
   400     "(negative mode: " ^ string_of_smode is ^ ")"
       
   401   | string_of_moded_prem thy (Sidecond t, Mode (_, is, _)) =
       
   402     (Syntax.string_of_term_global thy t) ^
       
   403     "(sidecond mode: " ^ string_of_smode is ^ ")"    
       
   404   | string_of_moded_prem _ _ = error "string_of_moded_prem: unimplemented"
       
   405      
       
   406 fun print_moded_clauses thy =
       
   407   let        
       
   408     fun string_of_clause pred mode clauses =
       
   409       cat_lines (map (fn (ts, prems) => (space_implode " --> "
       
   410         (map (string_of_moded_prem thy) prems)) ^ " --> " ^ pred ^ " "
       
   411         ^ (space_implode " " (map (Syntax.string_of_term_global thy) ts))) clauses)
       
   412   in print_pred_mode_table string_of_clause thy end;
       
   413 
       
   414 fun print_compiled_terms thy =
       
   415   print_pred_mode_table (fn _ => fn _ => Syntax.string_of_term_global thy) thy
       
   416     
       
   417 fun print_stored_rules thy =
       
   418   let
       
   419     val preds = (Graph.keys o PredData.get) thy
       
   420     fun print pred () = let
       
   421       val _ = writeln ("predicate: " ^ pred)
       
   422       val _ = writeln ("number of parameters: " ^ string_of_int (nparams_of thy pred))
       
   423       val _ = writeln ("introrules: ")
       
   424       val _ = fold (fn thm => fn u => writeln (Display.string_of_thm_global thy thm))
       
   425         (rev (intros_of thy pred)) ()
       
   426     in
       
   427       if (has_elim thy pred) then
       
   428         writeln ("elimrule: " ^ Display.string_of_thm_global thy (the_elim_of thy pred))
       
   429       else
       
   430         writeln ("no elimrule defined")
       
   431     end
       
   432   in
       
   433     fold print preds ()
       
   434   end;
       
   435 
       
   436 fun print_all_modes thy =
       
   437   let
       
   438     val _ = writeln ("Inferred modes:")
       
   439     fun print (pred, modes) u =
       
   440       let
       
   441         val _ = writeln ("predicate: " ^ pred)
       
   442         val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
       
   443       in u end  
       
   444   in
       
   445     fold print (all_modes_of thy) ()
       
   446   end
       
   447   
       
   448 (** preprocessing rules **)  
       
   449 
       
   450 fun imp_prems_conv cv ct =
       
   451   case Thm.term_of ct of
       
   452     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
       
   453   | _ => Conv.all_conv ct
       
   454 
       
   455 fun Trueprop_conv cv ct =
       
   456   case Thm.term_of ct of
       
   457     Const ("Trueprop", _) $ _ => Conv.arg_conv cv ct  
       
   458   | _ => error "Trueprop_conv"
       
   459 
       
   460 fun preprocess_intro thy rule =
       
   461   Conv.fconv_rule
       
   462     (imp_prems_conv
       
   463       (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
       
   464     (Thm.transfer thy rule)
       
   465 
       
   466 fun preprocess_elim thy nparams elimrule =
       
   467   let
       
   468     val _ = Output.tracing ("Preprocessing elimination rule "
       
   469       ^ (Display.string_of_thm_global thy elimrule))
       
   470     fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
       
   471        HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
       
   472      | replace_eqs t = t
       
   473     val prems = Thm.prems_of elimrule
       
   474     val nargs = length (snd (strip_comb (HOLogic.dest_Trueprop (hd prems)))) - nparams
       
   475     fun preprocess_case t =
       
   476      let
       
   477        val params = Logic.strip_params t
       
   478        val (assums1, assums2) = chop nargs (Logic.strip_assums_hyp t)
       
   479        val assums_hyp' = assums1 @ (map replace_eqs assums2)
       
   480      in
       
   481        list_all (params, Logic.list_implies (assums_hyp', Logic.strip_assums_concl t))
       
   482      end 
       
   483     val cases' = map preprocess_case (tl prems)
       
   484     val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
       
   485     (*
       
   486     (*val _ =  Output.tracing ("elimrule': "^ (Syntax.string_of_term_global thy elimrule'))*)
       
   487     val bigeq = (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm Predicate.eq_is_eq}])
       
   488          (cterm_of thy elimrule')))
       
   489     val _ = Output.tracing ("bigeq:" ^ (Display.string_of_thm_global thy bigeq))   
       
   490     val res = 
       
   491     Thm.equal_elim bigeq
       
   492       
       
   493       elimrule
       
   494     *)
       
   495     val t = (fn {...} => mycheat_tac thy 1)
       
   496     val eq = Goal.prove (ProofContext.init thy) [] [] (Logic.mk_equals ((Thm.prop_of elimrule), elimrule')) t
       
   497     val _ = Output.tracing "Preprocessed elimination rule"
       
   498   in
       
   499     Thm.equal_elim eq elimrule
       
   500   end;
       
   501 
       
   502 (* special case: predicate with no introduction rule *)
       
   503 fun noclause thy predname elim = let
       
   504   val T = (Logic.unvarifyT o Sign.the_const_type thy) predname
       
   505   val Ts = binder_types T
       
   506   val names = Name.variant_list []
       
   507         (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts)))
       
   508   val vs = map2 (curry Free) names Ts
       
   509   val clausehd = HOLogic.mk_Trueprop (list_comb (Const (predname, T), vs))
       
   510   val intro_t = Logic.mk_implies (@{prop False}, clausehd)
       
   511   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
       
   512   val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P)
       
   513   val intro = Goal.prove (ProofContext.init thy) names [] intro_t
       
   514         (fn {...} => etac @{thm FalseE} 1)
       
   515   val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t
       
   516         (fn {...} => etac elim 1) 
       
   517 in
       
   518   ([intro], elim)
       
   519 end
       
   520 
       
   521 fun fetch_pred_data thy name =
       
   522   case try (Inductive.the_inductive (ProofContext.init thy)) name of
       
   523     SOME (info as (_, result)) => 
       
   524       let
       
   525         fun is_intro_of intro =
       
   526           let
       
   527             val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
       
   528           in (fst (dest_Const const) = name) end;      
       
   529         val intros = ind_set_codegen_preproc thy ((map (preprocess_intro thy))
       
   530           (filter is_intro_of (#intrs result)))
       
   531         val pre_elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info)))
       
   532         val nparams = length (Inductive.params_of (#raw_induct result))
       
   533         val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
       
   534         val (intros, elim) = if null intros then noclause thy name elim else (intros, elim)
       
   535       in
       
   536         mk_pred_data ((intros, SOME elim, nparams), ([], [], []))
       
   537       end                                                                    
       
   538   | NONE => error ("No such predicate: " ^ quote name)
       
   539   
       
   540 (* updaters *)
       
   541 
       
   542 fun apfst3 f (x, y, z) =  (f x, y, z)
       
   543 fun apsnd3 f (x, y, z) =  (x, f y, z)
       
   544 fun aptrd3 f (x, y, z) =  (x, y, f z)
       
   545 
       
   546 fun add_predfun name mode data =
       
   547   let
       
   548     val add = (apsnd o apfst3 o cons) (mode, mk_predfun_data data)
       
   549   in PredData.map (Graph.map_node name (map_pred_data add)) end
       
   550 
       
   551 fun is_inductive_predicate thy name =
       
   552   is_some (try (Inductive.the_inductive (ProofContext.init thy)) name)
       
   553 
       
   554 fun depending_preds_of thy (key, value) =
       
   555   let
       
   556     val intros = (#intros o rep_pred_data) value
       
   557   in
       
   558     fold Term.add_const_names (map Thm.prop_of intros) []
       
   559       |> filter (fn c => (not (c = key)) andalso (is_inductive_predicate thy c orelse is_registered thy c))
       
   560   end;
       
   561     
       
   562     
       
   563 (* code dependency graph *)    
       
   564 (*
       
   565 fun dependencies_of thy name =
       
   566   let
       
   567     val (intros, elim, nparams) = fetch_pred_data thy name 
       
   568     val data = mk_pred_data ((intros, SOME elim, nparams), ([], [], []))
       
   569     val keys = depending_preds_of thy intros
       
   570   in
       
   571     (data, keys)
       
   572   end;
       
   573 *)
       
   574 (* guessing number of parameters *)
       
   575 fun find_indexes pred xs =
       
   576   let
       
   577     fun find is n [] = is
       
   578       | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
       
   579   in rev (find [] 0 xs) end;
       
   580 
       
   581 fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
       
   582   | is_predT _ = false
       
   583   
       
   584 fun guess_nparams T =
       
   585   let
       
   586     val argTs = binder_types T
       
   587     val nparams = fold (curry Int.max)
       
   588       (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
       
   589   in nparams end;
       
   590 
       
   (* add_intro thm thy: registers thm as an introduction rule of the predicate constant at
      the head of its conclusion (strip_intro_concl 0).  If PredData already has a node for
      that constant, thm is prepended to its intro rules; otherwise a new node is created
      with [thm], no elimination rule, and a parameter count taken from fetch_pred_data when
      that succeeds, else guessed from the constant's type by guess_nparams. *)
   fun add_intro thm thy =
     let
       val (name, T) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
       fun prepend_intro (intros, elim, nparams) = (thm :: intros, elim, nparams)
       fun cons_intro gr =
         if is_some (try (Graph.get_node gr) name) then
           Graph.map_node name (map_pred_data (apfst prepend_intro)) gr
         else
           let
             val nparams = the_default (guess_nparams T)
               (try (#nparams o rep_pred_data o fetch_pred_data thy) name)
           in
             Graph.new_node (name, mk_pred_data (([thm], NONE, nparams), ([], [], []))) gr
           end
     in PredData.map cons_intro thy end
       
   602 
       
   (* set_elim thm: theory transformer that stores thm as the elimination rule of the
      predicate constant heading thm's first premise (under Trueprop).  The node must
      already exist in PredData. *)
   fun set_elim thm =
     let
       val first_prem = HOLogic.dest_Trueprop (hd (prems_of thm))
       val (name, _) = dest_Const (fst (strip_comb first_prem))
     in
       PredData.map (Graph.map_node name
         (map_pred_data (apfst (fn (intros, _, nparams) => (intros, SOME thm, nparams)))))
     end
       
   608 
       
   (* set_nparams name nparams: theory transformer that overwrites the recorded parameter
      count of the existing PredData node name. *)
   fun set_nparams name nparams =
     PredData.map (Graph.map_node name
       (map_pred_data (apfst (fn (intros, elim, _) => (intros, elim, nparams)))))
       
   612     
       
   (* register_predicate (pre_intros, pre_elim, nparams) thy: preprocesses the given intro
      rules and elimination rule and adds a new PredData node for the predicate constant at
      the head of the first intro rule's conclusion.  Raises if pre_intros is empty (hd) or
      if the node already exists (Graph.new_node). *)
   fun register_predicate (pre_intros, pre_elim, nparams) thy =
     let
       val concl_head = fst (strip_intro_concl nparams (prop_of (hd pre_intros)))
       val (name, _) = dest_Const concl_head
       val intros = ind_set_codegen_preproc thy (map (preprocess_intro thy) pre_intros)
       val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
       val data = mk_pred_data ((intros, SOME elim, nparams), ([], [], []))
     in
       PredData.map (Graph.new_node (name, data)) thy
     end
       
   622 
       
   (* set_generator_name pred mode name: theory transformer that prepends
      (mode, mk_function_data (name, NONE)) to the second list of the second component of
      pred's data (presumably the per-mode generator table -- confirm against
      mk_pred_data / rep_pred_data). *)
   fun set_generator_name pred mode name =
     let
       val entry = (mode, mk_function_data (name, NONE))
     in
       PredData.map (Graph.map_node pred (map_pred_data (apsnd (apsnd3 (cons entry)))))
     end
       
   629 
       
   (* set_sizelim_function_name pred mode name: theory transformer that prepends
      (mode, mk_function_data (name, NONE)) to the third list of the second component of
      pred's data (presumably the per-mode size-limited function table -- confirm against
      mk_pred_data / rep_pred_data). *)
   fun set_sizelim_function_name pred mode name =
     let
       val entry = (mode, mk_function_data (name, NONE))
     in
       PredData.map (Graph.map_node pred (map_pred_data (apsnd (aptrd3 (cons entry)))))
     end
       
   636 
       
   637 (** data structures for generic compilation for different monads **)
       
   638 
       
   (* maybe rename the functions to be more generic:
     mk_predT -> mk_monadT; dest_predT -> dest_monadT;
     mk_single -> mk_return (?)
   *)
       
   (* Monad-specific term and type constructors used by the generic compiler.
      This part of the file provides instances for Predicate.pred (PredicateCompFuns) and
      for RPred (RPredCompFuns). *)
   datatype compilation_funs = CompilationFuns of {
     mk_predT : typ -> typ,                      (* element type => monad type *)
     dest_predT : typ -> typ,                    (* monad type => element type *)
     mk_bot : typ -> term,                       (* empty computation at an element type *)
     mk_single : term -> term,                   (* computation returning the given term *)
     mk_bind : term * term -> term,              (* bind (x, f) *)
     mk_sup : term * term -> term,               (* sup/union of two computations *)
     mk_if : term -> term,                       (* unit computation built from a bool *)
     mk_not : term -> term,                      (* negation of a unit computation *)
     mk_map : typ -> typ -> term -> term -> term, (* map tf over a computation *)
     lift_pred : term -> term                    (* embed a Predicate.pred term *)
   };
       
   657 
       
   (* Field selectors for compilation_funs. *)
   fun mk_predT (CompilationFuns {mk_predT = f, ...}) = f
   fun dest_predT (CompilationFuns {dest_predT = f, ...}) = f
   fun mk_bot (CompilationFuns {mk_bot = f, ...}) = f
   fun mk_single (CompilationFuns {mk_single = f, ...}) = f
   fun mk_bind (CompilationFuns {mk_bind = f, ...}) = f
   fun mk_sup (CompilationFuns {mk_sup = f, ...}) = f
   fun mk_if (CompilationFuns {mk_if = f, ...}) = f
   fun mk_not (CompilationFuns {mk_not = f, ...}) = f
   fun mk_map (CompilationFuns {mk_map = f, ...}) = f
   fun lift_pred (CompilationFuns {lift_pred = f, ...}) = f
       
   670 
       
   (* funT_of compfuns (iss, is) T: type of the function compiled from a predicate of type T
      in mode (iss, is).  split_modeT separates parameter, input and output argument types.
      Parameters that have a mode become compiled-function types themselves.  Inputs stay
      as arguments, and the outputs form the tuple inside the monad given by compfuns. *)
   fun funT_of compfuns (iss, is) T =
     let
       val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) (binder_types T)
       fun param_funT (SOME js) paramT = funT_of compfuns ([], js) paramT
         | param_funT NONE paramT = paramT
       val argTs = map2 param_funT iss paramTs @ inargTs
     in
       argTs ---> mk_predT compfuns (mk_tupleT outargTs)
     end;
       
   679 
       
   (* sizelim_funT_of compfuns (iss, is) T: like funT_of, but for the size-limited variant.
      An extra final code_numeral argument is added, and moded parameters are given
      size-limited function types as well. *)
   fun sizelim_funT_of compfuns (iss, is) T =
     let
       val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) (binder_types T)
       fun param_funT (SOME js) paramT = sizelim_funT_of compfuns ([], js) paramT
         | param_funT NONE paramT = paramT
       val argTs = map2 param_funT iss paramTs @ inargTs @ [@{typ "code_numeral"}]
     in
       argTs ---> mk_predT compfuns (mk_tupleT outargTs)
     end;
       
   688 
       
   (* mk_fun_of compfuns thy (name, T) mode: constant for the compiled function of predicate
      name in mode, typed by funT_of. *)
   fun mk_fun_of compfuns thy (name, T) mode =
     let val fun_name = predfun_name_of thy name mode
     in Const (fun_name, funT_of compfuns mode T) end
       
   691 
       
   (* mk_sizelim_fun_of compfuns thy (name, T) mode: constant for the size-limited compiled
      function of predicate name in mode, typed by sizelim_funT_of. *)
   fun mk_sizelim_fun_of compfuns thy (name, T) mode =
     let val fun_name = sizelim_function_name_of thy name mode
     in Const (fun_name, sizelim_funT_of compfuns mode T) end
       
   694   
       
   (* mk_generator_of compfuns thy (name, T) mode: constant for the generator of predicate
      name in mode; it has the same type as the size-limited function (sizelim_funT_of). *)
   fun mk_generator_of compfuns thy (name, T) mode =
     let val gen_name = generator_name_of thy name mode
     in Const (gen_name, sizelim_funT_of compfuns mode T) end
       
   697 
       
   698 
       
   (* Compilation functions for the Predicate.pred monad (compfuns), plus the helpers
      mk_Enum and mk_Eval, which are not part of the compilation_funs record. *)
   structure PredicateCompFuns =
   struct

   (* T => T Predicate.pred *)
   fun mk_predT T = Type (@{type_name "Predicate.pred"}, [T])

   (* element type of a Predicate.pred type; raises TYPE for any other type *)
   fun dest_predT T =
     (case T of
       Type (@{type_name "Predicate.pred"}, [elemT]) => elemT
     | _ => raise TYPE ("dest_predT", [T], []));

   (* the constant bot at type T Predicate.pred *)
   fun mk_bot T = Const (@{const_name Orderings.bot}, mk_predT T);

   (* Predicate.single t *)
   fun mk_single t =
     let val elemT = fastype_of t
     in Const (@{const_name Predicate.single}, elemT --> mk_predT elemT) $ t end;

   (* Predicate.bind x f; f must have a function type *)
   fun mk_bind (x, f) =
     let
       val fT as Type ("fun", [_, resultT]) = fastype_of f
     in
       Const (@{const_name Predicate.bind}, fastype_of x --> fT --> resultT) $ x $ f
     end;

   (* sup t u *)
   fun mk_sup (t, u) = HOLogic.mk_binop @{const_name sup} (t, u);

   (* Predicate.if_pred cond, of type unit Predicate.pred *)
   fun mk_if cond =
     Const (@{const_name Predicate.if_pred}, HOLogic.boolT --> mk_predT HOLogic.unitT) $ cond;

   (* Predicate.not_pred t, for t of type unit Predicate.pred *)
   fun mk_not t =
     let val unit_predT = mk_predT HOLogic.unitT
     in Const (@{const_name Predicate.not_pred}, unit_predT --> unit_predT) $ t end

   (* Predicate.Pred f, turning f : T => _ into a T Predicate.pred *)
   fun mk_Enum f =
     let val fT as Type ("fun", [elemT, _]) = fastype_of f
     in Const (@{const_name Predicate.Pred}, fT --> mk_predT elemT) $ f end;

   (* Predicate.eval f x, of type bool *)
   fun mk_Eval (f, x) =
     let val elemT = fastype_of x
     in Const (@{const_name Predicate.eval}, mk_predT elemT --> elemT --> HOLogic.boolT) $ f $ x end;

   (* Predicate.map tf tp, with tf : T1 => T2 and tp : T1 Predicate.pred *)
   fun mk_map T1 T2 tf tp =
     Const (@{const_name Predicate.map}, (T1 --> T2) --> mk_predT T1 --> mk_predT T2) $ tf $ tp;

   (* Predicate.pred is the base monad, so lifting is the identity *)
   fun lift_pred t = t;

   val compfuns = CompilationFuns {mk_predT = mk_predT, dest_predT = dest_predT,
     mk_bot = mk_bot, mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup,
     mk_if = mk_if, mk_not = mk_not, mk_map = mk_map, lift_pred = lift_pred};

   end;
       
   750 
       
   751 (* termify_code:
       
   752 val termT = Type ("Code_Eval.term", []);
       
   753 fun termifyT T = HOLogic.mk_prodT (T, HOLogic.unitT --> termT)
       
   754 *)
       
   755 (*
       
   756 fun lift_random random =
       
   757   let
       
   758     val T = dest_randomT (fastype_of random)
       
   759   in
       
   760     mk_scomp (random,
       
   761       mk_fun_comp (HOLogic.pair_const (PredicateCompFuns.mk_predT T) @{typ Random.seed},
       
   762         mk_fun_comp (Const (@{const_name Predicate.single}, T --> (PredicateCompFuns.mk_predT T)),
       
   763           Const (@{const_name "fst"}, HOLogic.mk_prodT (T, @{typ "unit => term"}) --> T)))) 
       
   764   end;
       
   765 *)
       
   766  
       
   (* Compilation functions for the RPred monad, whose values have type
      Random.seed => T Predicate.pred * Random.seed (see mk_rpredT). *)
   structure RPredCompFuns =
   struct

   fun mk_rpredT T =
     @{typ "Random.seed"} --> HOLogic.mk_prodT (PredicateCompFuns.mk_predT T, @{typ "Random.seed"})

   (* element type of an RPred type; raises TYPE for any other type *)
   fun dest_rpredT T =
     (case T of
       Type ("fun", [_, Type (@{type_name "*"}, [Type (@{type_name "Predicate.pred"}, [elemT]), _])]) =>
         elemT
     | _ => raise TYPE ("dest_rpredT", [T], []));

   fun mk_bot T = Const (@{const_name RPred.bot}, mk_rpredT T)

   fun mk_single t =
     let val elemT = fastype_of t
     in Const (@{const_name RPred.return}, elemT --> mk_rpredT elemT) $ t end;

   (* RPred.bind x f; f must have a function type *)
   fun mk_bind (x, f) =
     let val fT as Type ("fun", [_, resultT]) = fastype_of f
     in Const (@{const_name RPred.bind}, fastype_of x --> fT --> resultT) $ x $ f end

   fun mk_sup (t, u) = HOLogic.mk_binop @{const_name RPred.supp} (t, u)

   fun mk_if cond =
     Const (@{const_name RPred.if_rpred}, HOLogic.boolT --> mk_rpredT HOLogic.unitT) $ cond;

   (* negation is not supported for RPred *)
   fun mk_not _ = error "Negation is not defined for RPred"

   (* map is not implemented for RPred yet *)
   fun mk_map _ = error "FIXME" (*FIXME*)

   (* RPred.lift_pred t, embedding a Predicate.pred term into RPred *)
   fun lift_pred t =
     let
       val elemT = PredicateCompFuns.dest_predT (fastype_of t)
       val liftT = PredicateCompFuns.mk_predT elemT --> mk_rpredT elemT
     in Const (@{const_name "RPred.lift_pred"}, liftT) $ t end;

   val compfuns = CompilationFuns {mk_predT = mk_rpredT, dest_predT = dest_rpredT,
     mk_bot = mk_bot, mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup,
     mk_if = mk_if, mk_not = mk_not, mk_map = mk_map, lift_pred = lift_pred};

   end;
       
   (* the RPred compilation functions, exposed at top level for interactive use *)
   val rpred_compfuns : compilation_funs = RPredCompFuns.compfuns;
       
   817 
       
   (* lift_random random: applies the HOL constant lift_random to random.  Here random must
      have type Random.seed => ((T * (unit => term)) * Random.seed), with T extracted by
      dest_randomT, and the result has RPred type mk_rpredT T. *)
   fun lift_random random =
     let
       val T = dest_randomT (fastype_of random)
       val randomT = @{typ Random.seed} -->
         HOLogic.mk_prodT (HOLogic.mk_prodT (T, @{typ "unit => term"}), @{typ Random.seed})
     in
       Const (@{const_name lift_random}, randomT --> RPredCompFuns.mk_rpredT T) $ random
     end;
       
   826  
       
   827 (* Mode analysis *)
       
   828 
       
   829 (*** check if a term contains only constructor functions ***)
       
   (*** check if a term contains only constructor functions ***)
   (* is_constrt thy t: true iff t is built only from free variables and fully applied
      datatype constructors registered in thy (arity and result datatype must match). *)
   fun is_constrt thy =
     let
       (* (constructor name, (arity, datatype name)) for every datatype known to thy *)
       val cnstrs =
         maps (fn (_, info) =>
             maps (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) (#descr info))
           (Symtab.dest (Datatype.get_all thy))
       fun check t =
         (case strip_comb t of
           (Free _, []) => true
         | (Const (s, T), ts) =>
             (case (AList.lookup (op =) cnstrs s, body_type T) of
               (SOME (i, Tname), Type (Tname', _)) =>
                 length ts = i andalso Tname = Tname' andalso forall check ts
             | _ => false)
         | _ => false)
     in check end;
       
   842 
       
   (* is_eqT T: approximates whether T is an equality type by checking that no function
      type occurs in it; type variables count as equality types.
      FIXME this is only an approximation *)
   fun is_eqT T =
     (case T of
       Type ("fun", _) => false
     | Type (_, Ts) => forall is_eqT Ts
     | _ => true);
       
   847 
       
   (* names of all Free occurrences in tm, with duplicates *)
   fun term_vs tm =
     fold_aterms (fn t => fn acc => case t of Free (x, _) => x :: acc | _ => acc) tm [];
   (* names of the Frees occurring in any of the given terms, without duplicates *)
   fun terms_vs ts = distinct (op =) (maps term_vs ts);
       
   850 
       
   (** collect all Frees (name, type) in a term (with duplicates!) **)
   fun term_vTs tm =
     fold_aterms (fn t => fn acc => case t of Free xT => xT :: acc | _ => acc) tm [];
       
   854 
       
   (* merge xs ys: merges two lists of lists, taking the longer head first (on equal length
      the head of xs wins).  If both inputs are ordered by non-increasing length, so is the
      result.
      FIXME this function should not be named merge... make it local instead *)
   fun merge xs [] = xs
     | merge [] ys = ys
     | merge (xs as x :: xs') (ys as y :: ys') =
         if length y > length x then y :: merge xs ys'
         else x :: merge xs' ys;
       
   860 
       
   (* subsets i j: all subsets of the integer range i..j as ascending lists, ordered by
      non-increasing size (via merge); [[]] when i > j. *)
   fun subsets i j =
     if i > j then [[]]
     else
       let val rest = subsets (i + 1) j
       in merge (map (fn ks => i :: ks) rest) rest end;
       
   865      
       
   (* cprod (xs, ys): cartesian product as a list of pairs; all ys are paired with the first
      x, then with the second x, and so on.
      FIXME: should be in library - map_prod *)
   fun cprod (xs, ys) = maps (fn x => map (pair x) ys) xs;
       
   869 
       
   (* cprods xss: all lists obtained by picking one element from each list in xss *)
   fun cprods xss = foldr (fn (xs, acc) => map (op ::) (cprod (xs, acc))) [[]] xss;
       
   871 
       
   (* cprods_subset xss: like cprods, but each list of xss may also contribute nothing.
      Combinations that pick from the first list come before those that skip it. *)
   fun cprods_subset [] = [[]]
     | cprods_subset (xs :: xss) =
         let
           val rest = cprods_subset xss
           val with_xs = maps (fn ys => map (fn x => x :: ys) xs) rest
         in with_xs @ rest end
       
   877   
       
   878 (*TODO: cleanup function and put together with modes_of_term *)
       
   879 (*
       
   880 fun modes_of_param default modes t = let
       
   881     val (vs, t') = strip_abs t
       
   882     val b = length vs
       
   883     fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
       
   884         let
       
   885           val (args1, args2) =
       
   886             if length args < length iss then
       
   887               error ("Too few arguments for inductive predicate " ^ name)
       
   888             else chop (length iss) args;
       
   889           val k = length args2;
       
   890           val perm = map (fn i => (find_index_eq (Bound (b - i)) args2) + 1)
       
   891             (1 upto b)  
       
   892           val partial_mode = (1 upto k) \\ perm
       
   893         in
       
   894           if not (partial_mode subset is) then [] else
       
   895           let
       
   896             val is' = 
       
   897             (fold_index (fn (i, j) => if j mem is then cons (i + 1) else I) perm [])
       
   898             |> fold (fn i => if i > k then cons (i - k + b) else I) is
       
   899               
       
   900            val res = map (fn x => Mode (m, is', x)) (cprods (map
       
   901             (fn (NONE, _) => [NONE]
       
   902               | (SOME js, arg) => map SOME (filter
       
   903                   (fn Mode (_, js', _) => js=js') (modes_of_term modes arg)))
       
   904                     (iss ~~ args1)))
       
   905           in res end
       
   906         end)) (AList.lookup op = modes name)
       
   907   in case strip_comb t' of
       
   908     (Const (name, _), args) => the_default default (mk_modes name args)
       
   909     | (Var ((name, _), _), args) => the (mk_modes name args)
       
   910     | (Free (name, _), args) => the (mk_modes name args)
       
   911     | _ => default end
       
   912   
       
   913 and
       
   914 *)
       
   (* modes_of_term modes t: the modes (as Mode values) in which the predicate term t can be
      used, given modes, an association list from constant/variable names to mode lists.
      After eta-contraction, the head of t decides:
      - Const: the modes derived from its entry in modes, or the default mode if it has none;
      - Var/Free: the modes derived from its entry (raises Option if there is none);
      - Abs with no arguments: error "Abs at param position";
      - anything else: the default mode. *)
   fun modes_of_term modes t =
     let
       (* NOTE(review): ks numbers argument positions from 0, whereas the prefix check in
          mk_modes below uses positions 1..k -- confirm which convention is intended. *)
       val ks = map_index (fn (i, T) => (i, NONE)) (binder_types (fastype_of t));
       val default = [Mode (([], ks), ks, [])];
       (* For each recorded mode m = (iss, is) of name: the first length iss arguments are
          parameters (error if there are fewer arguments); the remaining k arguments are
          already applied, and m is usable only if is starts with (1, NONE) ... (k, NONE).
          Each parameter with mode SOME js is combined with every mode of its argument term
          whose smode equals js (cprods over all parameters). *)
       fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
           let
             val (args1, args2) =
               if length args < length iss then
                 error ("Too few arguments for inductive predicate " ^ name)
               else chop (length iss) args;
             val k = length args2;
             val prfx = map (rpair NONE) (1 upto k)
           in
             if not (is_prefix op = prfx is) then [] else
             let val is' = List.drop (is, k)
             in map (fn x => Mode (m, is', x)) (cprods (map
               (fn (NONE, _) => [NONE]
                 | (SOME js, arg) => map SOME (filter
                     (fn Mode (_, js', _) => js=js') (modes_of_term modes arg)))
                       (iss ~~ args1)))
             end
           end)) (AList.lookup op = modes name)

     in
       case strip_comb (Envir.eta_contract t) of
         (Const (name, T), args) => the_default default (mk_modes name args)
       | (Var ((name, _), _), args) => the (mk_modes name args)
       | (Free (name, T), args) => the (mk_modes name args)
       | (Abs _, []) => error "Abs at param position" (* modes_of_param default modes t *)
       | _ => default
     end
       
   946   
       
   (* select_mode_prem thy modes vs ps: finds the first premise in ps that can be executed
      when the variables named in vs are already known, together with the first suitable
      mode; the result is SOME (premise, SOME mode) or NONE.  Candidate modes are computed
      for every premise (map over all of ps) before the first match is chosen.
      - Prem (us, t): a mode of t whose input arguments, and whose non-constructor output
        arguments, use only known variables.  The non-constructor outputs must also have
        equality types, and t itself must use only known variables.  Variables repeated in
        the constructor outputs, or already known, must have equality types.
      - Negprem (us, t): a mode of t whose smode has one entry per argument (presumably: all
        arguments are inputs -- confirm), with all argument and predicate variables known.
      - Sidecond t: selected with the empty mode once all free variables of t are known.
      An Option raised by modes_of_term is turned into the error "Bad predicate: ...". *)
   fun select_mode_prem thy modes vs ps =
     find_first (is_some o snd) (ps ~~ map
       (fn Prem (us, t) => find_first (fn Mode (_, is, _) =>
             let
               val (in_ts, out_ts) = split_smode is us;
               (* outputs that are constructor terms can be matched; the others are
                  treated like inputs *)
               val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts;
               val vTs = maps term_vTs out_ts';
               (* types of variables repeated in the matched outputs or already known *)
               val dupTs = map snd (duplicates (op =) vTs) @
                 List.mapPartial (AList.lookup (op =) vTs) vs;
             in
               terms_vs (in_ts @ in_ts') subset vs andalso
               forall (is_eqT o fastype_of) in_ts' andalso
               term_vs t subset vs andalso
               forall is_eqT dupTs
             end)
               (modes_of_term modes t handle Option =>
                  error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
         | Negprem (us, t) => find_first (fn Mode (_, is, _) =>
               length us = length is andalso
               terms_vs us subset vs andalso
               term_vs t subset vs)
               (modes_of_term modes t handle Option =>
                  error ("Bad predicate: " ^ Syntax.string_of_term_global thy t))
         | Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], []))
             else NONE
         ) ps);
       
   973 
       
   (* fold_prem f prem: folds f over the argument terms of a (negated) premise, or applies f
      to the term of a side condition.  Other premise kinds raise Match. *)
   fun fold_prem f prem =
     (case prem of
       Prem (args, _) => fold f args
     | Negprem (args, _) => fold f args
     | Sidecond t => f t)
       
   977 
       
   (* all_subsets xs: every sublist of xs (element order preserved); those without the head
      come before those containing it. *)
   fun all_subsets [] = [[]]
     | all_subsets (x :: xs) =
         let val without_x = all_subsets xs
         in without_x @ map (fn ys => x :: ys) without_x end
       
   980 
       
   (* generator vTs v: a Generator entry for variable v, typed by looking v up in vTs
      (raises Option if v is absent), paired with the empty mode. *)
   fun generator vTs v =
     (Generator (v, the (AList.lookup (op =) vTs v)), Mode (([], []), [], []));
       
   987 
       
   (* gen_prem: turns a positive premise into a generator premise; any other input is an error *)
   fun gen_prem p =
     (case p of
       Prem (us, t) => GeneratorPrem (us, t)
     | _ => error "gen_prem : invalid input for gen_prem")
       
   990 
       
   (* param_gen_prem param_vs p: a premise whose predicate is one of the parameter variables
      param_vs becomes a generator premise; every other premise is returned unchanged. *)
   fun param_gen_prem param_vs p =
     (case p of
       Prem (us, t as Free (v, _)) =>
         if member (op =) param_vs v then GeneratorPrem (us, t) else p
     | _ => p)
       
   996   
       
   (* check_mode_clause with_generator thy param_vs modes gen_modes (iss, is) (ts, ps):
      checks whether a clause with conclusion arguments ts and premises ps can be executed
      in mode (iss, is), where iss gives the modes of the parameter variables param_vs.
      Premises are scheduled greedily with select_mode_prem.  Scheduling starts from the
      parameters and from the variables of those input arguments that are constructor terms.
      Returns SOME (ts, moded premises in execution order) on success and NONE otherwise.
      If with_generator is set, a premise that cannot be moded may instead be handled in
      one of two ways: through the generator modes gen_modes, or by adding Generator
      entries for a smallest set of still-unknown premise variables that makes some premise
      modable.  Conclusion variables that are still unknown at the end also get generators. *)
   fun check_mode_clause with_generator thy param_vs modes gen_modes (iss, is) (ts, ps) =
     let
       (* modes of the clause's own higher-order parameters; computed once and added to
          both the ordinary and the generator mode tables *)
       val param_modes = List.mapPartial
         (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
           (param_vs ~~ iss);
       val modes' = modes @ param_modes;
       val gen_modes' = gen_modes @ param_modes;
       val vTs = distinct (op =) ((fold o fold_prem) Term.add_frees ps (fold Term.add_frees ts []))
       val prem_vs = distinct (op =) ((fold o fold_prem) Term.add_free_names ps [])
       (* acc_ps: moded premises in reverse execution order; vs: variables known so far *)
       fun check_mode_prems acc_ps vs [] = SOME (acc_ps, vs)
         | check_mode_prems acc_ps vs ps = (case select_mode_prem thy modes' vs ps of
             NONE =>
               (if with_generator then
                 (case select_mode_prem thy gen_modes' vs ps of
                     SOME (p, SOME mode) => check_mode_prems ((gen_prem p, mode) :: acc_ps)
                     (case p of Prem (us, _) => vs union terms_vs us | _ => vs)
                     (filter_out (equal p) ps)
                   | NONE =>
                     let
                       (* candidate sets of variables to generate, smallest first *)
                       val all_generator_vs = all_subsets (prem_vs \\ vs) |> sort (int_ord o (pairself length))
                     in
                       case (find_first (fn generator_vs => is_some
                         (select_mode_prem thy modes' (vs union generator_vs) ps)) all_generator_vs) of
                         SOME generator_vs => check_mode_prems ((map (generator vTs) generator_vs) @ acc_ps)
                           (vs union generator_vs) ps
                       | NONE => NONE
                     end)
               else
                 NONE)
           | SOME (p, SOME mode) => check_mode_prems ((if with_generator then param_gen_prem param_vs p else p, mode) :: acc_ps)
               (case p of Prem (us, _) => vs union terms_vs us | _ => vs)
               (filter_out (equal p) ps))
       val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (split_smode is ts));
       val in_vs = terms_vs in_ts;
       val concl_vs = terms_vs ts
     in
       (* variables repeated in constructor inputs, and inputs that are not constructor
          terms, must have equality types *)
       if forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso
       forall (is_eqT o fastype_of) in_ts' then
         case check_mode_prems [] (param_vs union in_vs) ps of
            NONE => NONE
          | SOME (acc_ps, vs) =>
            if with_generator then
              SOME (ts, (rev acc_ps) @ (map (generator vTs) (concl_vs \\ vs)))
            else
              if concl_vs subset vs then SOME (ts, rev acc_ps) else NONE
       else NONE
     end;
       
  1045 
       
  (* check_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms):
     keeps the candidate modes in ms under which every clause of predicate p passes
     check_mode_clause.  For each rejected mode, the first violating clause and its
     conclusion arguments are reported via Output.tracing.
     Raises an error (instead of an anonymous Bind) if p has no entry in clauses. *)
  fun check_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms) =
    let
      val rs =
        (case AList.lookup (op =) clauses p of
          SOME rs => rs
        | NONE => error ("check_modes_pred: no clauses for predicate " ^ p))
      fun mode_ok m =
        (case find_index
            (is_none o check_mode_clause with_generator thy param_vs modes gen_modes m) rs of
          ~1 => true
        | i =>
            (Output.tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
               p ^ " violates mode " ^ string_of_mode m);
             Output.tracing (commas (map (Syntax.string_of_term_global thy) (fst (nth rs i))));
             false))
    in (p, List.filter mode_ok ms) end;
       
  1055 
       
  (* get_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms):
     for each mode m in ms, pairs m with the clauses of p as returned by check_mode_clause
     (premises ordered for execution in m).  Every clause must check in every given mode;
     otherwise `the` raises Option.  Raises an error (instead of an anonymous Bind) if p
     has no entry in clauses. *)
  fun get_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms) =
    let
      val rs =
        (case AList.lookup (op =) clauses p of
          SOME rs => rs
        | NONE => error ("get_modes_pred: no clauses for predicate " ^ p))
      val check = check_mode_clause with_generator thy param_vs modes gen_modes
    in
      (p, map (fn m => (m, map (the o check m) rs)) ms)
    end;
       
  1063   
       
  (* fixp f x: applies f repeatedly, starting from x, until the result no longer changes
     (structural equality), and returns that fixed point. *)
  fun fixp f (x : (string * mode list) list) =
    let val x' = f x
    in if x' = x then x else fixp f x' end;
       
  1067 
       
  (* infer_modes thy extra_modes all_modes param_vs clauses: starting from the candidate
     modes all_modes, repeatedly drops modes whose clauses fail check_modes_pred (with the
     current modes plus extra_modes available), until nothing changes.  Then returns, for
     each predicate and each surviving mode, the moded clauses (get_modes_pred). *)
  fun infer_modes thy extra_modes all_modes param_vs clauses =
    let
      fun refine modes =
        map (check_modes_pred false thy param_vs clauses (modes @ extra_modes) []) modes
      val modes = fixp refine all_modes
    in
      map (get_modes_pred false thy param_vs clauses (modes @ extra_modes) []) modes
    end;
       
  1077 
       
  (* remove_from rem xs: for every (key, values) entry of xs, removes the values listed
     under the same key in rem; entries whose key is not in rem are kept unchanged. *)
  fun remove_from rem xs =
    map (fn (k, vs) =>
      (case AList.lookup (op =) rem k of
        NONE => (k, vs)
      | SOME vs' => (k, vs \\ vs'))) xs
       
  1084     
       
  (* infer_modes_with_generator thy extra_modes all_modes param_vs clauses: like
     infer_modes, but clauses may be completed with generators.  The candidate modes are
     all_modes minus the modes already recorded in thy (all_modes_of thy).  The generator
     modes of all predicates not defined by clauses are available during checking.
     NOTE(review): the extra_modes argument is not used; all_modes_of thy is used in its
     place -- confirm this is intended. *)
  fun infer_modes_with_generator thy extra_modes all_modes param_vs clauses =
    let
      val prednames = map fst clauses
      val known_modes = all_modes_of thy
      val gen_modes = filter_out (fn (name, _) => member (op =) prednames name)
        (all_generator_modes_of thy)
      val starting_modes = remove_from known_modes all_modes
      fun refine modes =
        map (check_modes_pred true thy param_vs clauses known_modes (gen_modes @ modes)) modes
      val modes = fixp refine starting_modes
    in
      map (get_modes_pred true thy param_vs clauses known_modes (gen_modes @ modes)) modes
    end;
       
  1099 
       
  1100 (* term construction *)
       
  1101 
       
  (* mk_v (names, vs) s T: a free variable for name s and type T, plus the updated naming
     state.  The first request for s returns Free (s, T) and records s in vs with no
     variants.  Later requests return a fresh variant of s avoiding names; the variant is
     added to names and recorded under s in vs. *)
  fun mk_v (names, vs) s T =
    (case AList.lookup (op =) vs s of
      SOME variants =>
        let
          val fresh = Name.variant names s
          val v = Free (fresh, T)
        in
          (v, (fresh :: names, AList.update (op =) (s, v :: variants) vs))
        end
    | NONE => (Free (s, T), (names, (s, []) :: vs)));
       
  1111 
       
  (* distinct_v t nvs: replaces every Free in t by the variable chosen by mk_v.  Applications
     are traversed function part first, and the naming state nvs is threaded through. *)
  fun distinct_v t nvs =
    (case t of
      Free (s, T) => mk_v nvs s T
    | t1 $ t2 =>
        let
          val (t1', nvs1) = distinct_v t1 nvs
          val (t2', nvs2) = distinct_v t2 nvs1
        in (t1' $ t2', nvs2) end
    | _ => (t, nvs));
       
  1119 
       
  (* compile_match thy compfuns eqs eqs' out_ts success_t: builds the matching function
       %x. case x of <tuple of out_ts> => (if <eqs''> then success_t else bot) | y => bot
     where eqs'' is the equations obtained from eqs via mk_eq followed by eqs', combined
     with conjunction; if eqs'' is empty the first branch is just success_t.
     bot is mk_bot compfuns at the element type of success_t's monad type.
     x and y are fresh names avoiding the free names of success_t, eqs'' and out_ts. *)
  fun compile_match thy compfuns eqs eqs' out_ts success_t =
    let
      val eqs'' = maps mk_eq eqs @ eqs'
      val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
      val name = Name.variant names "x";
      val name' = Name.variant (name :: names) "y";
      (* type of the matched value: tuple of the output types *)
      val T = mk_tupleT (map fastype_of out_ts);
      val U = fastype_of success_t;
      val U' = dest_predT compfuns U;
      val v = Free (name, T);
      val v' = Free (name', T);
    in
      lambda v (fst (Datatype.make_case
        (ProofContext.init thy) false [] v
        [(mk_tuple out_ts,
          if null eqs'' then success_t
          else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
            foldr1 HOLogic.mk_conj eqs'' $ success_t $
              mk_bot compfuns U'),
         (* catch-all pattern *)
         (v', mk_bot compfuns U')]))
    end;
       
  1141 
       
  (* mk_funcomp f t: eta-expands t over all of its argument types with fresh free variables
     (based on "x", avoiding t's free names) and applies f to the fully applied body,
     i.e. builds %x1 ... xn. f (t x1 ... xn).
     FIXME function can be removed *)
  fun mk_funcomp f t =
    let
      val used = Term.add_free_names t []
      val argTs = binder_types (fastype_of t)
      val args = map Free (Name.variant_list used (replicate (length argTs) "x") ~~ argTs)
    in
      fold_rev lambda args (f (list_comb (t, args)))
    end;
       
  1152 (*
       
  1153 fun compile_param_ext thy compfuns modes (NONE, t) = t
       
  1154   | compile_param_ext thy compfuns modes (m as SOME (Mode ((iss, is'), is, ms)), t) =
       
  1155       let
       
  1156         val (vs, u) = strip_abs t
       
  1157         val (ivs, ovs) = split_mode is vs    
       
  1158         val (f, args) = strip_comb u
       
  1159         val (params, args') = chop (length ms) args
       
  1160         val (inargs, outargs) = split_mode is' args'
       
  1161         val b = length vs
       
  1162         val perm = map (fn i => (find_index_eq (Bound (b - i)) args') + 1) (1 upto b)
       
  1163         val outp_perm =
       
  1164           snd (split_mode is perm)
       
  1165           |> map (fn i => i - length (filter (fn x => x < i) is'))
       
  1166         val names = [] -- TODO
       
  1167         val out_names = Name.variant_list names (replicate (length outargs) "x")
       
  1168         val f' = case f of
       
  1169             Const (name, T) =>
       
  1170               if AList.defined op = modes name then
       
  1171                 mk_predfun_of thy compfuns (name, T) (iss, is')
       
  1172               else error "compile param: Not an inductive predicate with correct mode"
       
  1173           | Free (name, T) => Free (name, param_funT_of compfuns T (SOME is'))
       
  1174         val outTs = dest_tupleT (dest_predT compfuns (body_type (fastype_of f')))
       
  1175         val out_vs = map Free (out_names ~~ outTs)
       
  1176         val params' = map (compile_param thy modes) (ms ~~ params)
       
  1177         val f_app = list_comb (f', params' @ inargs)
       
  1178         val single_t = (mk_single compfuns (mk_tuple (map (fn i => nth out_vs (i - 1)) outp_perm)))
       
  1179         val match_t = compile_match thy compfuns [] [] out_vs single_t
       
  1180       in list_abs (ivs,
       
  1181         mk_bind compfuns (f_app, match_t))
       
  1182       end
       
  1183   | compile_param_ext _ _ _ _ = error "compile params"
       
  1184 *)
       
  1185 
       
  (* compile_param size thy compfuns (mode, t): compiles a higher-order argument t.
     Arguments without a mode (NONE) are returned unchanged.  Otherwise t is eta-contracted
     and its head, which must be a constant or a free variable, is replaced by the compiled
     function for mode (iss, is'), using the size-limited variant when size is SOME _.
     Its first length ms arguments are compiled recursively as parameters. *)
  fun compile_param size thy compfuns (NONE, t) = t
    | compile_param size thy compfuns (SOME (Mode ((iss, is'), _, ms)), t) =
        let
          val (head, args) = strip_comb (Envir.eta_contract t)
          val (params, rest_args) = chop (length ms) args
          val compiled_params = map (compile_param size thy compfuns) (ms ~~ params)
          val compiled_head =
            (case (head, size) of
              (Const (name, T), NONE) => mk_fun_of compfuns thy (name, T) (iss, is')
            | (Const (name, T), SOME _) => mk_sizelim_fun_of compfuns thy (name, T) (iss, is')
            | (Free (name, T), NONE) => Free (name, funT_of compfuns (iss, is') T)
            | (Free (name, T), SOME _) => Free (name, sizelim_funT_of compfuns (iss, is') T)
            | _ => error ("PredicateCompiler: illegal parameter term"))
        in list_comb (compiled_head, compiled_params @ rest_args) end
       
  1201    
       
  (* compile_expr size thy (Mode (mode, is, ms), t): compiles the predicate expression t,
     used in the given mode, into a Predicate.pred-valued term.
     - Const head: its compiled function (size-limited if size is SOME _) applied to its
       compiled higher-order parameters;
     - Free head: retyped as a compiled function for mode ([], is) and applied to its
       original arguments.
     Other heads are not handled (raises Match). *)
  fun compile_expr size thy (Mode (mode, is, ms), t) =
    (case strip_comb t of
      (Const (name, T), params) =>
        let
          val compiled_params =
            map (compile_param size thy PredicateCompFuns.compfuns) (ms ~~ params)
          val head =
            (case size of
              NONE => mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode
            | SOME _ => mk_sizelim_fun_of PredicateCompFuns.compfuns thy (name, T) mode)
        in list_comb (head, compiled_params) end
    | (Free (name, T), args) =>
        let
          val fT =
            (case size of
              NONE => funT_of PredicateCompFuns.compfuns ([], is) T
            | SOME _ => sizelim_funT_of PredicateCompFuns.compfuns ([], is) T)
        in list_comb (Free (name, fT), args) end);
       
  1217        
       
  1218 fun compile_gen_expr size thy compfuns ((Mode (mode, is, ms)), t) =
       
  1219   case strip_comb t of
       
  1220     (Const (name, T), params) =>
       
  1221       let
       
  1222         val params' = map (compile_param size thy compfuns) (ms ~~ params)
       
  1223       in
       
  1224         list_comb (mk_generator_of compfuns thy (name, T) mode, params')
       
  1225       end
       
  1226     | (Free (name, T), args) =>
       
  1227       list_comb (Free (name, sizelim_funT_of RPredCompFuns.compfuns ([], is) T), args)
       
  1228           
       
  1229 (** specific rpred functions -- move them to the correct place in this file *)
       
  1230 
       
   1231 (* commented-out termify code; causes more trouble than expected at first *) 
       
  1232 (*
       
  1233 fun mk_valtermify_term (t as Const (c, T)) = HOLogic.mk_prod (t, Abs ("u", HOLogic.unitT, HOLogic.reflect_term t))
       
  1234   | mk_valtermify_term (Free (x, T)) = Free (x, termifyT T) 
       
  1235   | mk_valtermify_term (t1 $ t2) =
       
  1236     let
       
  1237       val T = fastype_of t1
       
  1238       val (T1, T2) = dest_funT T
       
  1239       val t1' = mk_valtermify_term t1
       
  1240       val t2' = mk_valtermify_term t2
       
  1241     in
       
  1242       Const ("Code_Eval.valapp", termifyT T --> termifyT T1 --> termifyT T2) $ t1' $ t2'
       
  1243     end
       
  1244   | mk_valtermify_term _ = error "Not a valid term for mk_valtermify_term"
       
  1245 *)
       
  1246 
       
   (* compile_clause compfuns size final_term thy all_vs param_vs (iss, is) inp (ts, moded_ps):
      compile one moded clause into a term built with the monad operations of
      compfuns (mk_single, mk_bind, compile_match, ...).
      - ts are the arguments of the clause conclusion; split_smode is splits them
        into input and output terms (iss is not used in this function).
      - moded_ps are the premises, each paired with the mode it is called in.
      - all_vs are names already in use; fresh names come from Name.variant.
      - size: if SOME size_t, size_t is appended to the arguments of compiled
        Prem and GeneratorPrem calls; Negprem calls are compiled with NONE.
      - final_term is applied to the output terms of the clause conclusion once
        all premises have been compiled.
      The result binds inp (via mk_single and mk_bind) to the compiled premises. *)
   1247 fun compile_clause compfuns size final_term thy all_vs param_vs (iss, is) inp (ts, moded_ps) =
   1248   let
         (* Keep t if it is a constructor term (is_constrt); otherwise replace it by
            a fresh variable v and record the equation v = t, to be checked later
            in compile_match. *)
   1249     fun check_constrt t (names, eqs) =
   1250       if is_constrt thy t then (t, (names, eqs)) else
   1251         let
   1252           val s = Name.variant names "x";
   1253           val v = Free (s, fastype_of t)
   1254         in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end;
   1255 
   1256     val (in_ts, out_ts) = split_smode is ts;
         (* eqs: equations for the non-constructor input terms.  They are only used
            in the final match of compile_prems (the [] case below). *)
   1257     val (in_ts', (all_vs', eqs)) =
   1258       fold_map check_constrt in_ts (all_vs, []);
   1259 
         (* compile_prems terms vs names prems: match the given terms (vs are the
            variables passed to distinct_v, names the names in use), then compile
            the remaining premises prems. *)
   1260     fun compile_prems out_ts' vs names [] =
   1261           let
   1262             val (out_ts'', (names', eqs')) =
   1263               fold_map check_constrt out_ts' (names, []);
   1264             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
   1265               out_ts'' (names', map (rpair []) vs);
   1266           in
   1267           (* termify code:
   1268             compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
   1269               (mk_single compfuns (mk_tuple (map mk_valtermify_term out_ts)))
   1270            *)
             (* NOTE: out_ts here is the clause conclusion's output terms bound by
                "split_smode is ts" above, not this case's parameter out_ts';
                eqs are the input equations from the outer scope. *)
   1271             compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
   1272               (final_term out_ts)
   1273           end
         (* In this case the pattern variable is (the premise's argument mode)
            shadows the clause's is, and eqs is rebound below to the equations of
            the current terms only. *)
   1274       | compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
   1275           let
             (* the variables of the terms being matched are added for later premises *)
   1276             val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
   1277             val (out_ts', (names', eqs)) =
   1278               fold_map check_constrt out_ts (names, [])
   1279             val (out_ts'', (names'', constr_vs')) = fold_map distinct_v
   1280               out_ts' ((names', map (rpair []) vs))
   1281             val (compiled_clause, rest) = case p of
                  (* ordinary premise: call the compiled expression on its inputs
                     (plus size_t, if any) and lift the result with lift_pred *)
   1282                Prem (us, t) =>
   1283                  let
   1284                    val (in_ts, out_ts''') = split_smode is us;
   1285                    val args = case size of
   1286                      NONE => in_ts
   1287                    | SOME size_t => in_ts @ [size_t]
   1288                    val u = lift_pred compfuns
   1289                      (list_comb (compile_expr size thy (mode, t), args))                     
   1290                    val rest = compile_prems out_ts''' vs' names'' ps
   1291                  in
   1292                    (u, rest)
   1293                  end
                  (* negated premise: compiled in the predicate monad without a size
                     argument, negated with mk_not, then lifted with lift_pred *)
   1294              | Negprem (us, t) =>
   1295                  let
   1296                    val (in_ts, out_ts''') = split_smode is us
   1297                    val u = lift_pred compfuns
   1298                      (mk_not PredicateCompFuns.compfuns (list_comb (compile_expr NONE thy (mode, t), in_ts)))
   1299                    val rest = compile_prems out_ts''' vs' names'' ps
   1300                  in
   1301                    (u, rest)
   1302                  end
                  (* side condition: a guard built with mk_if; it produces no terms,
                     so the rest starts with nothing to match *)
   1303              | Sidecond t =>
   1304                  let
   1305                    val rest = compile_prems [] vs' names'' ps;
   1306                  in
   1307                    (mk_if compfuns t, rest)
   1308                  end
                  (* generator premise: compiled with compile_gen_expr and applied to
                     its inputs (plus size_t, if any); no lift_pred here *)
   1309              | GeneratorPrem (us, t) =>
   1310                  let
   1311                    val (in_ts, out_ts''') = split_smode is us;
   1312                    val args = case size of
   1313                      NONE => in_ts
   1314                    | SOME size_t => in_ts @ [size_t]
   1315                    val u = list_comb (compile_gen_expr size thy compfuns (mode, t), args)
   1316                    val rest = compile_prems out_ts''' vs' names'' ps
   1317                  in
   1318                    (u, rest)
   1319                  end
                  (* free variable v :: T drawn from HOLogic.mk_random T applied to the
                     fixed argument 1::code_numeral; NOTE(review): presumably a size
                     bound that is hard-coded here -- confirm *)
   1320              | Generator (v, T) =>
   1321                  let
   1322                    val u = lift_random (HOLogic.mk_random T @{term "1::code_numeral"})
   1323                    val rest = compile_prems [Free (v, T)]  vs' names'' ps;
   1324                  in
   1325                    (u, rest)
   1326                  end
   1327           in
             (* match the current terms, then bind the premise's result to the rest *)
   1328             compile_match thy compfuns constr_vs' eqs out_ts'' 
   1329               (mk_bind compfuns (compiled_clause, rest))
   1330           end
         (* initial call: the constructor-checked input terms are the first terms
            to match, vs starts as param_vs and names as all_vs' *)
   1331     val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
   1332   in
   1333     mk_bind compfuns (mk_single compfuns inp, prem_t)
   1334   end
       
  1335 
       
  1336 fun compile_pred compfuns mk_fun_of use_size thy all_vs param_vs s T mode moded_cls =
       
  1337   let
       
  1338 	  val (Ts1, Ts2) = chop (length (fst mode)) (binder_types T)
       
  1339     val (Us1, Us2) = split_smodeT (snd mode) Ts2
       
  1340     val funT_of = if use_size then sizelim_funT_of else funT_of
       
  1341     val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
       
  1342     val size_name = Name.variant (all_vs @ param_vs) "size"
       
  1343   	fun mk_input_term (i, NONE) =
       
  1344 		    [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
       
  1345 		  | mk_input_term (i, SOME pis) = case HOLogic.strip_tupleT (nth Ts2 (i - 1)) of
       
  1346 						   [] => error "strange unit input"
       
  1347 					   | [T] => [Free (Name.variant (all_vs @ param_vs) ("x" ^ string_of_int i), nth Ts2 (i - 1))]
       
  1348 						 | Ts => let
       
  1349 							 val vnames = Name.variant_list (all_vs @ param_vs)
       
  1350 								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
       
  1351 									pis)
       
  1352 						 in if null pis then []
       
  1353 						   else [HOLogic.mk_tuple (map Free (vnames ~~ map (fn j => nth Ts (j - 1)) pis))] end
       
  1354 		val in_ts = maps mk_input_term (snd mode)
       
  1355     val params = map2 (fn s => fn T => Free (s, T)) param_vs Ts1'
       
  1356     val size = Free (size_name, @{typ "code_numeral"})
       
  1357     val decr_size =
       
  1358       if use_size then
       
  1359         SOME (Const ("HOL.minus_class.minus", @{typ "code_numeral => code_numeral => code_numeral"})
       
  1360           $ size $ Const ("HOL.one_class.one", @{typ "Code_Numeral.code_numeral"}))
       
  1361       else
       
  1362         NONE
       
  1363     val cl_ts =
       
  1364       map (compile_clause compfuns decr_size (fn out_ts => mk_single compfuns (mk_tuple out_ts))
       
  1365         thy all_vs param_vs mode (mk_tuple in_ts)) moded_cls;
       
  1366     val t = foldr1 (mk_sup compfuns) cl_ts
       
  1367     val T' = mk_predT compfuns (mk_tupleT Us2)
       
  1368     val size_t = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
       
  1369       $ HOLogic.mk_eq (size, @{term "0 :: code_numeral"})
       
  1370       $ mk_bot compfuns (dest_predT compfuns T') $ t
       
  1371     val fun_const = mk_fun_of compfuns thy (s, T) mode
       
  1372     val eq = if use_size then
       
  1373       (list_comb (fun_const, params @ in_ts @ [size]), size_t)
       
  1374     else
       
  1375       (list_comb (fun_const, params @ in_ts), t)
       
  1376   in
       
  1377     HOLogic.mk_Trueprop (HOLogic.mk_eq eq)
       
  1378   end;
       
  1379   
       
  1380 (* special setup for simpset *)                  
       
  1381 val HOL_basic_ss' = HOL_basic_ss addsimps (@{thms "HOL.simp_thms"} @ [@{thm Pair_eq}])
       
  1382   setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac))
       
  1383 	setSolver (mk_solver "True_solver" (fn _ => rtac @{thm TrueI}))
       
  1384 
       
  1385 (* Definition of executable functions and their intro and elim rules *)
       
  1386 
       
  1387 fun print_arities arities = tracing ("Arities:\n" ^
       
  1388   cat_lines (map (fn (s, (ks, k)) => s ^ ": " ^
       
  1389     space_implode " -> " (map
       
  1390       (fn NONE => "X" | SOME k' => string_of_int k')
       
  1391         (ks @ [SOME k]))) arities));
       
  1392 
       
   (* mk_Eval_of ((x, T), param_mode) names:
      turn the parameter x :: T into a term usable in place of the original
      predicate parameter.  Without a mode (NONE), x is returned unchanged.
      With SOME mode (a simple mode for the arguments of T), the result is a
      lambda over T's arguments (tuple arguments abstracted with split lambdas)
      whose body is PredicateCompFuns.mk_Eval applied to x on the input
      variables and to the tuple of output variables.
      Fresh names are chosen with Name.variant against names; names itself is
      returned unchanged (all generated variables end up bound). *)
   1393 fun mk_Eval_of ((x, T), NONE) names = (x, names)
   1394   | mk_Eval_of ((x, T), SOME mode) names =
   1395 	let
   1396     val Ts = binder_types T
   1397     (*val argnames = Name.variant_list names
   1398         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
   1399     val args = map Free (argnames ~~ Ts)
   1400     val (inargs, outargs) = split_smode mode args*)
         (* Abstract t over the variables xs: no variable -> a fresh unit-typed
            variable; one -> a plain lambda; several -> nested HOLogic.mk_split
            lambdas. *)
   1401 		fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
   1402 			| mk_split_lambda [x] t = lambda x t
   1403 			| mk_split_lambda xs t =
   1404 			let
   1405 				fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
   1406 					| mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
   1407 			in
   1408 				mk_split_lambda' xs t
   1409 			end;
         (* mk_arg (i, T): variables for argument i of type T, returned as
            ((inputs, outputs), all variables of the argument).
            i absent from mode -> output; SOME NONE -> input;
            SOME (SOME pis) -> T must be a tuple; its components at the positions
            in pis are inputs, the others outputs, each group re-tupled (an empty
            group is dropped). *)
   1410   	fun mk_arg (i, T) =
   1411 		  let
   1412 	  	  val vname = Name.variant names ("x" ^ string_of_int i)
   1413 		    val default = Free (vname, T)
   1414 		  in 
   1415 		    case AList.lookup (op =) mode i of
   1416 		      NONE => (([], [default]), [default])
   1417 			  | SOME NONE => (([default], []), [default])
   1418 			  | SOME (SOME pis) =>
   1419 				  case HOLogic.strip_tupleT T of
   1420 						[] => error "pair mode but unit tuple" (*(([default], []), [default])*)
   1421 					| [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
                    (* this Ts (tuple components) shadows the outer binder_types Ts *)
   1422 					| Ts =>
   1423 					  let
   1424 							val vnames = Name.variant_list names
   1425 								(map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
   1426 									(1 upto length Ts))
   1427 							val args = map Free (vnames ~~ Ts)
   1428 							fun split_args (i, arg) (ins, outs) =
   1429 							  if member (op =) pis i then
   1430 							    (arg::ins, outs)
   1431 								else
   1432 								  (ins, arg::outs)
   1433 							val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
   1434 							fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
   1435 						in ((tuple inargs, tuple outargs), args) end
   1436 			end
   1437 		val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
   1438     val (inargs, outargs) = pairself flat (split_list inoutargs)
         (* body: x applied to the input variables, evaluated at the output tuple *)
   1439 		val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs), mk_tuple outargs)
         (* abstract over each argument's variables, innermost = last argument *)
   1440     val t = fold_rev mk_split_lambda args r
   1441   in
   1442     (t, names)
   1443   end;
       
  1444 
       
  1445 fun create_intro_elim_rule (mode as (iss, is)) defthm mode_id funT pred thy =
       
  1446 let
       
  1447   val Ts = binder_types (fastype_of pred)
       
  1448   val funtrm = Const (mode_id, funT)
       
  1449   val (Ts1, Ts2) = chop (length iss) Ts;
       
  1450   val Ts1' = map2 (fn NONE => I | SOME is => funT_of (PredicateCompFuns.compfuns) ([], is)) iss Ts1
       
  1451 	val param_names = Name.variant_list []
       
  1452     (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1)));
       
  1453   val params = map Free (param_names ~~ Ts1')
       
  1454 	fun mk_args (i, T) argnames =
       
  1455     let
       
  1456 		  val vname = Name.variant (param_names @ argnames) ("x" ^ string_of_int (length Ts1' + i))
       
  1457 		  val default = (Free (vname, T), vname :: argnames)
       
  1458 	  in
       
  1459   	  case AList.lookup (op =) is i of
       
  1460 						 NONE => default
       
  1461 					 | SOME NONE => default
       
  1462         	 | SOME (SOME pis) =>
       
  1463 					   case HOLogic.strip_tupleT T of
       
  1464 						   [] => default
       
  1465 					   | [_] => default
       
  1466 						 | Ts => 
       
  1467 						let
       
  1468 							val vnames = Name.variant_list (param_names @ argnames)
       
  1469 								(map (fn j => "x" ^ string_of_int (length Ts1' + i) ^ "p" ^ string_of_int j)
       
  1470 									(1 upto (length Ts)))
       
  1471 						 in (HOLogic.mk_tuple (map Free (vnames ~~ Ts)), vnames  @ argnames) end
       
  1472 		end
       
  1473 	val (args, argnames) = fold_map mk_args (1 upto (length Ts2) ~~ Ts2) []
       
  1474   val (inargs, outargs) = split_smode is args
       
  1475   val param_names' = Name.variant_list (param_names @ argnames)
       
  1476     (map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
       
  1477   val param_vs = map Free (param_names' ~~ Ts1)
       
  1478   val (params', names) = fold_map mk_Eval_of ((params ~~ Ts1) ~~ iss) []
       
  1479   val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
       
  1480   val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
       
  1481   val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
       
  1482   val funargs = params @ inargs
       
  1483   val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
       
  1484                   if null outargs then Free("y", HOLogic.unitT) else mk_tuple outargs))
       
  1485   val funpropI = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs),
       
  1486                    mk_tuple outargs))
       
  1487   val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI)
       
  1488   val simprules = [defthm, @{thm eval_pred},
       
  1489 	  @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}, @{thm pair_collapse}]
       
  1490   val unfolddef_tac = Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps simprules) 1
       
  1491   val introthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ param_names' @ ["y"]) [] introtrm (fn {...} => unfolddef_tac)
       
  1492   val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT));
       
  1493   val elimtrm = Logic.list_implies ([funpropE, Logic.mk_implies (predpropE, P)], P)
       
  1494   val elimthm = Goal.prove (ProofContext.init thy) (argnames @ param_names @ param_names' @ ["y", "P"]) [] elimtrm (fn {...} => unfolddef_tac)
       
  1495 	val _ = Output.tracing (Display.string_of_thm_global thy elimthm)
       
  1496 	val _ = Output.tracing (Display.string_of_thm_global thy introthm)
       
  1497 
       
  1498 in
       
  1499   (introthm, elimthm)
       
  1500 end;
       
  1501 
       
  1502 fun create_constname_of_mode thy prefix name mode = 
       
  1503   let
       
  1504     fun string_of_mode mode = if null mode then "0"
       
  1505       else space_implode "_" (map (fn (i, NONE) => string_of_int i | (i, SOME pis) => string_of_int i ^ "p"
       
  1506         ^ space_implode "p" (map string_of_int pis)) mode)
       
  1507     val HOmode = space_implode "_and_"
       
  1508       (fold (fn NONE => I | SOME mode => cons (string_of_mode mode)) (fst mode) [])
       
  1509   in
       
  1510     (Sign.full_bname thy (prefix ^ (Long_Name.base_name name))) ^
       
  1511       (if HOmode = "" then "_" else "_for_" ^ HOmode ^ "_yields_") ^ (string_of_mode (snd mode))
       
  1512   end;
       
  1513 
       
  1514 fun split_tupleT is T =
       
  1515 	let
       
  1516 		fun split_tuple' _ _ [] = ([], [])
       
  1517 			| split_tuple' is i (T::Ts) =
       
  1518 			(if i mem is then apfst else apsnd) (cons T)
       
  1519 				(split_tuple' is (i+1) Ts)
       
  1520 	in
       
  1521 	  split_tuple' is 1 (HOLogic.strip_tupleT T)
       
  1522   end
       
  1523 	
       
  1524 fun mk_arg xin xout pis T =
       
  1525   let
       
  1526 	  val n = length (HOLogic.strip_tupleT T)
       
  1527 		val ni = length pis
       
  1528 	  fun mk_proj i j t =
       
  1529 		  (if i = j then I else HOLogic.mk_fst)
       
  1530 			  (funpow (i - 1) HOLogic.mk_snd t)
       
  1531 	  fun mk_arg' i (si, so) = if i mem pis then
       
  1532 		    (mk_proj si ni xin, (si+1, so))
       
  1533 		  else
       
  1534 			  (mk_proj so (n - ni) xout, (si, so+1))
       
  1535 	  val (args, _) = fold_map mk_arg' (1 upto n) (1, 1)
       
  1536 	in
       
  1537 	  HOLogic.mk_tuple args
       
  1538 	end
       
  1539 
       
  1540 fun create_definitions preds (name, modes) thy =
       
  1541   let
       
  1542     val compfuns = PredicateCompFuns.compfuns
       
  1543     val T = AList.lookup (op =) preds name |> the
       
  1544     fun create_definition (mode as (iss, is)) thy = let
       
  1545       val mode_cname = create_constname_of_mode thy "" name mode
       
  1546       val mode_cbasename = Long_Name.base_name mode_cname
       
  1547       val Ts = binder_types T
       
  1548       val (Ts1, Ts2) = chop (length iss) Ts
       
  1549       val (Us1, Us2) =  split_smodeT is Ts2
       
  1550       val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss Ts1
       
  1551       val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (mk_tupleT Us2))
       
  1552       val names = Name.variant_list []
       
  1553         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
       
  1554 			(* old *)
       
  1555 			(*
       
  1556 		  val xs = map Free (names ~~ (Ts1' @ Ts2))
       
  1557       val (xparams, xargs) = chop (length iss) xs
       
  1558       val (xins, xouts) = split_smode is xargs
       
  1559 			*)
       
  1560 			(* new *)
       
  1561 			val param_names = Name.variant_list []
       
  1562 			  (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1')))
       
  1563 		  val xparams = map Free (param_names ~~ Ts1')
       
  1564       fun mk_vars (i, T) names =
       
  1565 			  let
       
  1566 				  val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
       
  1567 				in
       
  1568 					case AList.lookup (op =) is i of
       
  1569 						 NONE => ((([], [Free (vname, T)]), Free (vname, T)), vname :: names)
       
  1570 					 | SOME NONE => ((([Free (vname, T)], []), Free (vname, T)), vname :: names)
       
  1571         	 | SOME (SOME pis) =>
       
  1572 					   let
       
  1573 						   val (Tins, Touts) = split_tupleT pis T
       
  1574 							 val name_in = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "in")
       
  1575 							 val name_out = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "out")
       
  1576 						   val xin = Free (name_in, HOLogic.mk_tupleT Tins)
       
  1577 							 val xout = Free (name_out, HOLogic.mk_tupleT Touts)
       
  1578 							 val xarg = mk_arg xin xout pis T
       
  1579 						 in (((if null Tins then [] else [xin], if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
       
  1580 						(* HOLogic.strip_tupleT T of
       
  1581 						[] => 
       
  1582 							in (Free (vname, T), vname :: names) end
       
  1583 					| [_] => let val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
       
  1584 							in (Free (vname, T), vname :: names) end
       
  1585 					| Ts =>
       
  1586 						let
       
  1587 							val vnames = Name.variant_list names
       
  1588 								(map (fn j => "x" ^ string_of_int (length Ts1' + i) ^ "p" ^ string_of_int j)
       
  1589 									(1 upto (length Ts)))
       
  1590 						 in (HOLogic.mk_tuple (map Free (vnames ~~ Ts)), vnames @ names) end *)
       
  1591 				end
       
  1592    	  val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
       
  1593       val (xinout, xargs) = split_list xinoutargs
       
  1594 			val (xins, xouts) = pairself flat (split_list xinout)
       
  1595 			(*val (xins, xouts) = split_smode is xargs*)
       
  1596 			val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names
       
  1597 			val _ = Output.tracing ("xargs:" ^ commas (map (Syntax.string_of_term_global thy) xargs))
       
  1598       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
       
  1599         | mk_split_lambda [x] t = lambda x t
       
  1600         | mk_split_lambda xs t =
       
  1601         let
       
  1602           fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
       
  1603             | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
       
  1604         in
       
  1605           mk_split_lambda' xs t
       
  1606         end;
       
  1607       val predterm = PredicateCompFuns.mk_Enum (mk_split_lambda xouts
       
  1608         (list_comb (Const (name, T), xparams' @ xargs)))
       
  1609       val lhs = list_comb (Const (mode_cname, funT), xparams @ xins)
       
  1610       val def = Logic.mk_equals (lhs, predterm)
       
  1611 			val _ = Output.tracing ("def:" ^ (Syntax.string_of_term_global thy def))
       
  1612       val ([definition], thy') = thy |>
       
  1613         Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
       
  1614         PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
       
  1615       val (intro, elim) =
       
  1616         create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
       
  1617 			val _ = Output.tracing (Display.string_of_thm_global thy' definition)
       
  1618       in thy'
       
  1619 			  |> add_predfun name mode (mode_cname, definition, intro, elim)
       
  1620         |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
       
  1621         |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim)  |> snd
       
  1622         |> Theory.checkpoint
       
  1623       end;
       
  1624   in
       
  1625     fold create_definition modes thy
       
  1626   end;
       
  1627 
       
  1628 fun sizelim_create_definitions preds (name, modes) thy =
       
  1629   let
       
  1630     val T = AList.lookup (op =) preds name |> the
       
  1631     fun create_definition mode thy =
       
  1632       let
       
  1633         val mode_cname = create_constname_of_mode thy "sizelim_" name mode
       
  1634         val funT = sizelim_funT_of PredicateCompFuns.compfuns mode T
       
  1635       in
       
  1636         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
       
  1637         |> set_sizelim_function_name name mode mode_cname 
       
  1638       end;
       
  1639   in
       
  1640     fold create_definition modes thy
       
  1641   end;
       
  1642     
       
  1643 fun rpred_create_definitions preds (name, modes) thy =
       
  1644   let
       
  1645     val T = AList.lookup (op =) preds name |> the
       
  1646     fun create_definition mode thy =
       
  1647       let
       
  1648         val mode_cname = create_constname_of_mode thy "gen_" name mode
       
  1649         val funT = sizelim_funT_of RPredCompFuns.compfuns mode T
       
  1650       in
       
  1651         thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)]
       
  1652         |> set_generator_name name mode mode_cname 
       
  1653       end;
       
  1654   in
       
  1655     fold create_definition modes thy
       
  1656   end;
       
  1657   
       
   1658 (* Proving equivalence of terms *)
       
  1659 
       
  1660 fun is_Type (Type _) = true
       
  1661   | is_Type _ = false
       
  1662 
       
  1663 (* returns true if t is an application of an datatype constructor *)
       
  1664 (* which then consequently would be splitted *)
       
  1665 (* else false *)
       
  1666 fun is_constructor thy t =
       
  1667   if (is_Type (fastype_of t)) then
       
  1668     (case Datatype.get_info thy ((fst o dest_Type o fastype_of) t) of
       
  1669       NONE => false
       
  1670     | SOME info => (let
       
  1671       val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
       
  1672       val (c, _) = strip_comb t
       
  1673       in (case c of
       
  1674         Const (name, _) => name mem_string constr_consts
       
  1675         | _ => false) end))
       
  1676   else false
       
  1677 
       
  1678 (* MAJOR FIXME:  prove_params should be simple
       
  1679  - different form of introrule for parameters ? *)
       
  1680 fun prove_param thy (NONE, t) = TRY (rtac @{thm refl} 1)
       
  1681   | prove_param thy (m as SOME (Mode (mode, is, ms)), t) =
       
  1682   let
       
  1683     val  (f, args) = strip_comb (Envir.eta_contract t)
       
  1684     val (params, _) = chop (length ms) args
       
  1685     val f_tac = case f of
       
  1686       Const (name, T) => simp_tac (HOL_basic_ss addsimps 
       
  1687          ([@{thm eval_pred}, (predfun_definition_of thy name mode),
       
  1688          @{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
       
  1689 				 @{thm "snd_conv"}, @{thm pair_collapse}, @{thm "Product_Type.split_conv"}])) 1
       
  1690     | Free _ => TRY (rtac @{thm refl} 1)
       
  1691     | Abs _ => error "prove_param: No valid parameter term"
       
  1692   in
       
  1693     REPEAT_DETERM (etac @{thm thin_rl} 1)
       
  1694     THEN REPEAT_DETERM (rtac @{thm ext} 1)
       
  1695     THEN print_tac "prove_param"
       
  1696     THEN f_tac
       
  1697     THEN print_tac "after simplification in prove_args"
       
  1698     THEN (EVERY (map (prove_param thy) (ms ~~ params)))
       
  1699     THEN (REPEAT_DETERM (atac 1))
       
  1700   end
       
  1701 
       
  1702 fun prove_expr thy (Mode (mode, is, ms), t, us) (premposition : int) =
       
  1703   case strip_comb t of
       
  1704     (Const (name, T), args) =>  
       
  1705       let
       
  1706         val introrule = predfun_intro_of thy name mode
       
  1707         val (args1, args2) = chop (length ms) args
       
  1708       in
       
  1709         rtac @{thm bindI} 1
       
  1710         THEN print_tac "before intro rule:"
       
  1711         (* for the right assumption in first position *)
       
  1712         THEN rotate_tac premposition 1
       
  1713         THEN debug_tac (Display.string_of_thm (ProofContext.init thy) introrule)
       
  1714         THEN rtac introrule 1
       
  1715         THEN print_tac "after intro rule"
       
  1716         (* work with parameter arguments *)
       
  1717         THEN (atac 1)
       
  1718         THEN (print_tac "parameter goal")
       
  1719         THEN (EVERY (map (prove_param thy) (ms ~~ args1)))
       
  1720         THEN (REPEAT_DETERM (atac 1))
       
  1721       end
       
  1722   | _ => rtac @{thm bindI} 1
       
  1723 	  THEN asm_full_simp_tac
       
  1724 		  (HOL_basic_ss' addsimps [@{thm "split_eta"}, @{thm "split_beta"}, @{thm "fst_conv"},
       
  1725 				 @{thm "snd_conv"}, @{thm pair_collapse}]) 1
       
  1726 	  THEN (atac 1)
       
  1727 	  THEN print_tac "after prove parameter call"
       
  1728 		
       
  1729 
       
  1730 fun SOLVED tac st = FILTER (fn st' => nprems_of st' = nprems_of st - 1) tac st; 
       
  1731 
       
  1732 fun SOLVEDALL tac st = FILTER (fn st' => nprems_of st' = 0) tac st
       
  1733 
       
  1734 fun prove_match thy (out_ts : term list) = let
       
  1735   fun get_case_rewrite t =
       
  1736     if (is_constructor thy t) then let
       
  1737       val case_rewrites = (#case_rewrites (Datatype.the_info thy
       
  1738         ((fst o dest_Type o fastype_of) t)))
       
  1739       in case_rewrites @ (flat (map get_case_rewrite (snd (strip_comb t)))) end
       
  1740     else []
       
  1741   val simprules = @{thm "unit.cases"} :: @{thm "prod.cases"} :: (flat (map get_case_rewrite out_ts))
       
  1742 (* replace TRY by determining if it necessary - are there equations when calling compile match? *)
       
  1743 in
       
  1744    (* make this simpset better! *)
       
  1745   asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
       
  1746   THEN print_tac "after prove_match:"
       
  1747   THEN (DETERM (TRY (EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
       
  1748          THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss 1))))
       
  1749          THEN (SOLVED (asm_simp_tac HOL_basic_ss 1)))))
       
  1750   THEN print_tac "after if simplification"
       
  1751 end;
       
  1752 
       
  1753 (* corresponds to compile_fun -- maybe call that also compile_sidecond? *)
       
  1754 
       
  1755 fun prove_sidecond thy modes t =
       
  1756   let
       
  1757     fun preds_of t nameTs = case strip_comb t of 
       
  1758       (f as Const (name, T), args) =>
       
  1759         if AList.defined (op =) modes name then (name, T) :: nameTs
       
  1760           else fold preds_of args nameTs
       
  1761       | _ => nameTs
       
  1762     val preds = preds_of t []
       
  1763     val defs = map
       
  1764       (fn (pred, T) => predfun_definition_of thy pred
       
  1765         ([], map (rpair NONE) (1 upto (length (binder_types T)))))
       
  1766         preds
       
  1767   in 
       
  1768     (* remove not_False_eq_True when simpset in prove_match is better *)
       
  1769     simp_tac (HOL_basic_ss addsimps
       
  1770       (@{thms "HOL.simp_thms"} @ (@{thm not_False_eq_True} :: @{thm eval_pred} :: defs))) 1 
       
  1771     (* need better control here! *)
       
  1772   end
       
  1773 
       
(* prove_clause thy nargs modes (iss, is) (_, clauses) (ts, moded_ps):
   tactic for one clause, used by prove_one_direction. bindI and singleI are applied
   to subgoal 1. The moded premises moded_ps are then handled left to right. Each step
   first runs prove_match on the current output terms, then treats the premise by its
   kind (Prem, Negprem or Sidecond). When no premises remain, the clause is closed
   with singleI_unit if it has no output terms, and with singleI otherwise. *)
fun prove_clause thy nargs modes (iss, is) (_, clauses) (ts, moded_ps) =
  let
    (* split the clause arguments into input and output terms according to is *)
    val (in_ts, clause_out_ts) = split_smode is ts;
    fun prove_prems out_ts [] =
      (prove_match thy out_ts)
			THEN print_tac "before simplifying assumptions"
      THEN asm_full_simp_tac HOL_basic_ss' 1
			THEN print_tac "before single intro rule"
      THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
    | prove_prems out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) =
      (* note: iss and is bound here shadow the clause-level mode parameters *)
      let
        (* position of p in clauses, offset by nargs; passed on to prove_expr *)
        val premposition = (find_index (equal p) clauses) + nargs
        val rest_tac = (case p of Prem (us, t) =>
            let
              (* the output terms of this premise are matched in the next step *)
              val (_, out_ts''') = split_smode is us
              val rec_tac = prove_prems out_ts''' ps
            in
              print_tac "before clause:"
              THEN asm_simp_tac HOL_basic_ss 1
              THEN print_tac "before prove_expr:"
              THEN prove_expr thy (mode, t, us) premposition
              THEN print_tac "after prove_expr:"
              THEN rec_tac
            end
          | Negprem (us, t) =>
            let
              val (_, out_ts''') = split_smode is us
              val rec_tac = prove_prems out_ts''' ps
              (* head constant of the negated premise, if any *)
              val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
              val (_, params) = strip_comb t
            in
              rtac @{thm bindI} 1
              THEN (if (is_some name) then
                  simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1
                  THEN rtac @{thm not_predI} 1
                  THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
                  THEN (REPEAT_DETERM (atac 1))
                  (* FIXME: work with parameter arguments *)
                  THEN (EVERY (map (prove_param thy) (param_modes ~~ params)))
                else
                  rtac @{thm not_predI'} 1)
                  (* NOTE(review): this simp_tac follows the whole if-expression, so in the
                     then-branch not_False_eq_True is applied a second time -- confirm intent *)
                  THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
              THEN rec_tac
            end
          | Sidecond t =>
           rtac @{thm bindI} 1
           THEN rtac @{thm if_predI} 1
           THEN print_tac "before sidecond:"
           THEN prove_sidecond thy modes t
           THEN print_tac "after sidecond:"
           (* a side condition has no output terms to match *)
           THEN prove_prems [] ps)
      in (prove_match thy out_ts)
          THEN rest_tac
      end;
    (* the input terms of the clause are the first terms to match *)
    val prems_tac = prove_prems in_ts moded_ps
  in
    rtac @{thm bindI} 1
    THEN rtac @{thm singleI} 1
    THEN prems_tac
  end;
       
  1834 
       
  1835 fun select_sup 1 1 = []
       
  1836   | select_sup _ 1 = [rtac @{thm supI1}]
       
  1837   | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1));
       
  1838 
       
  1839 fun prove_one_direction thy clauses preds modes pred mode moded_clauses =
       
  1840   let
       
  1841     val T = the (AList.lookup (op =) preds pred)
       
  1842     val nargs = length (binder_types T) - nparams_of thy pred
       
  1843     val pred_case_rule = the_elim_of thy pred
       
  1844   in
       
  1845     REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
       
  1846 		THEN print_tac "before applying elim rule"
       
  1847     THEN etac (predfun_elim_of thy pred mode) 1
       
  1848     THEN etac pred_case_rule 1
       
  1849     THEN (EVERY (map
       
  1850            (fn i => EVERY' (select_sup (length moded_clauses) i) i) 
       
  1851              (1 upto (length moded_clauses))))
       
  1852     THEN (EVERY (map2 (prove_clause thy nargs modes mode) clauses moded_clauses))
       
  1853     THEN print_tac "proved one direction"
       
  1854   end;
       
  1855 
       
  1856 (** Proof in the other direction **)
       
  1857 
       
  1858 fun prove_match2 thy out_ts = let
       
  1859   fun split_term_tac (Free _) = all_tac
       
  1860     | split_term_tac t =
       
  1861       if (is_constructor thy t) then let
       
  1862         val info = Datatype.the_info thy ((fst o dest_Type o fastype_of) t)
       
  1863         val num_of_constrs = length (#case_rewrites info)
       
  1864         (* special treatment of pairs -- because of fishing *)
       
  1865         val split_rules = case (fst o dest_Type o fastype_of) t of
       
  1866           "*" => [@{thm prod.split_asm}] 
       
  1867           | _ => PureThy.get_thms thy (((fst o dest_Type o fastype_of) t) ^ ".split_asm")
       
  1868         val (_, ts) = strip_comb t
       
  1869       in
       
  1870         (Splitter.split_asm_tac split_rules 1)
       
  1871 (*        THEN (Simplifier.asm_full_simp_tac HOL_basic_ss 1)
       
  1872           THEN (DETERM (TRY (etac @{thm Pair_inject} 1))) *)
       
  1873         THEN (REPEAT_DETERM_N (num_of_constrs - 1) (etac @{thm botE} 1 ORELSE etac @{thm botE} 2))
       
  1874         THEN (EVERY (map split_term_tac ts))
       
  1875       end
       
  1876     else all_tac
       
  1877   in
       
  1878     split_term_tac (mk_tuple out_ts)
       
  1879     THEN (DETERM (TRY ((Splitter.split_asm_tac [@{thm "split_if_asm"}] 1) THEN (etac @{thm botE} 2))))
       
  1880   end
       
  1881 
       
(* VERY LARGE SIMILARITY to function prove_param
-- join both functions
*)
       
  1885 (* TODO: remove function *)
       
  1886 
       
  1887 fun prove_param2 thy (NONE, t) = all_tac 
       
  1888   | prove_param2 thy (m as SOME (Mode (mode, is, ms)), t) = let
       
  1889     val  (f, args) = strip_comb (Envir.eta_contract t)
       
  1890     val (params, _) = chop (length ms) args
       
  1891     val f_tac = case f of
       
  1892         Const (name, T) => full_simp_tac (HOL_basic_ss addsimps 
       
  1893            (@{thm eval_pred}::(predfun_definition_of thy name mode)
       
  1894            :: @{thm "Product_Type.split_conv"}::[])) 1
       
  1895       | Free _ => all_tac
       
  1896       | _ => error "prove_param2: illegal parameter term"
       
  1897   in  
       
  1898     print_tac "before simplification in prove_args:"
       
  1899     THEN f_tac
       
  1900     THEN print_tac "after simplification in prove_args"
       
  1901     THEN (EVERY (map (prove_param2 thy) (ms ~~ params)))
       
  1902   end
       
  1903 
       
  1904 
       
  1905 fun prove_expr2 thy (Mode (mode, is, ms), t) = 
       
  1906   (case strip_comb t of
       
  1907     (Const (name, T), args) =>
       
  1908       etac @{thm bindE} 1
       
  1909       THEN (REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})))
       
  1910       THEN print_tac "prove_expr2-before"
       
  1911       THEN (debug_tac (Syntax.string_of_term_global thy
       
  1912         (prop_of (predfun_elim_of thy name mode))))
       
  1913       THEN (etac (predfun_elim_of thy name mode) 1)
       
  1914       THEN print_tac "prove_expr2"
       
  1915       THEN (EVERY (map (prove_param2 thy) (ms ~~ args)))
       
  1916       THEN print_tac "finished prove_expr2"      
       
  1917     | _ => etac @{thm bindE} 1)
       
  1918     
       
  1919 (* FIXME: what is this for? *)
       
  1920 (* replace defined by has_mode thy pred *)
       
  1921 (* TODO: rewrite function *)
       
  1922 fun prove_sidecond2 thy modes t = let
       
  1923   fun preds_of t nameTs = case strip_comb t of 
       
  1924     (f as Const (name, T), args) =>
       
  1925       if AList.defined (op =) modes name then (name, T) :: nameTs
       
  1926         else fold preds_of args nameTs
       
  1927     | _ => nameTs
       
  1928   val preds = preds_of t []
       
  1929   val defs = map
       
  1930     (fn (pred, T) => predfun_definition_of thy pred 
       
  1931       ([], map (rpair NONE) (1 upto (length (binder_types T)))))
       
  1932       preds
       
  1933   in
       
  1934    (* only simplify the one assumption *)
       
  1935    full_simp_tac (HOL_basic_ss' addsimps @{thm eval_pred} :: defs) 1 
       
  1936    (* need better control here! *)
       
  1937    THEN print_tac "after sidecond2 simplification"
       
  1938    end
       
  1939   
       
(* prove_clause2 thy modes pred (iss, is) (ts, ps) i: tactic for the i-th clause,
   used by prove_other_direction. bindE, singleE' and (if applicable) Pair_inject are
   applied to subgoal 1. The moded premises ps are then handled left to right, each step
   starting with prove_match2 on the current output terms. When no premises remain, the
   goal must be closed (SOLVED) by the i-th introduction rule of pred, followed by
   assumption / simplification steps. *)
fun prove_clause2 thy modes pred (iss, is) (ts, ps) i =
  let
    (* i is 1-based, nth is 0-based *)
    val pred_intro_rule = nth (intros_of thy pred) (i - 1)
    (* NOTE(review): clause_out_ts is not used below *)
    val (in_ts, clause_out_ts) = split_smode is ts;
    fun prove_prems2 out_ts [] =
      print_tac "before prove_match2 - last call:"
      THEN prove_match2 thy out_ts
      THEN print_tac "after prove_match2 - last call:"
      THEN (etac @{thm singleE} 1)
      THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
      THEN (asm_full_simp_tac HOL_basic_ss' 1)
      THEN (REPEAT_DETERM (etac @{thm Pair_inject} 1))
      THEN (asm_full_simp_tac HOL_basic_ss' 1)
      (* everything from here on must close the goal completely *)
      THEN SOLVED (print_tac "state before applying intro rule:"
      THEN (rtac pred_intro_rule 1)
      (* How to handle equality correctly? *)
      THEN (print_tac "state before assumption matching")
      THEN (REPEAT (atac 1 ORELSE 
         (CHANGED (asm_full_simp_tac (HOL_basic_ss' addsimps
					 [@{thm split_eta}, @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}, @{thm pair_collapse}]) 1)
          THEN print_tac "state after simp_tac:"))))
    | prove_prems2 out_ts ((p, mode as Mode ((iss, is), _, param_modes)) :: ps) =
      (* note: iss and is bound here shadow the clause-level mode parameters *)
      let
        val rest_tac = (case p of
          Prem (us, t) =>
          let
            (* the output terms of this premise are matched in the next step *)
            val (_, out_ts''') = split_smode is us
            val rec_tac = prove_prems2 out_ts''' ps
          in
            (prove_expr2 thy (mode, t)) THEN rec_tac
          end
        | Negprem (us, t) =>
          let
            val (_, out_ts''') = split_smode is us
            val rec_tac = prove_prems2 out_ts''' ps
            (* head constant of the negated premise, if any *)
            val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE)
            val (_, params) = strip_comb t
          in
            print_tac "before neg prem 2"
            THEN etac @{thm bindE} 1
            THEN (if is_some name then
                full_simp_tac (HOL_basic_ss addsimps [predfun_definition_of thy (the name) (iss, is)]) 1 
                THEN etac @{thm not_predE} 1
                THEN simp_tac (HOL_basic_ss addsimps [@{thm not_False_eq_True}]) 1
                THEN (EVERY (map (prove_param2 thy) (param_modes ~~ params)))
              else
                etac @{thm not_predE'} 1)
            THEN rec_tac
          end 
        | Sidecond t =>
          etac @{thm bindE} 1
          THEN etac @{thm if_predE} 1
          THEN prove_sidecond2 thy modes t 
          (* a side condition has no output terms to match *)
          THEN prove_prems2 [] ps)
      in print_tac "before prove_match2:"
         THEN prove_match2 thy out_ts
         THEN print_tac "after prove_match2:"
         THEN rest_tac
      end;
    (* the input terms of the clause are the first terms to match *)
    val prems_tac = prove_prems2 in_ts ps 
  in
    print_tac "starting prove_clause2"
    THEN etac @{thm bindE} 1
    THEN (etac @{thm singleE'} 1)
    THEN (TRY (etac @{thm Pair_inject} 1))
    THEN print_tac "after singleE':"
    THEN prems_tac
  end;
       
  2008  
       
  2009 fun prove_other_direction thy modes pred mode moded_clauses =
       
  2010   let
       
  2011     fun prove_clause clause i =
       
  2012       (if i < length moded_clauses then etac @{thm supE} 1 else all_tac)
       
  2013       THEN (prove_clause2 thy modes pred mode clause i)
       
  2014   in
       
  2015     (DETERM (TRY (rtac @{thm unit.induct} 1)))
       
  2016      THEN (REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all})))
       
  2017      THEN (rtac (predfun_intro_of thy pred mode) 1)
       
  2018      THEN (REPEAT_DETERM (rtac @{thm refl} 2))
       
  2019      THEN (EVERY (map2 prove_clause moded_clauses (1 upto (length moded_clauses))))
       
  2020   end;
       
  2021 
       
  2022 (** proof procedure **)
       
  2023 
       
  2024 fun prove_pred thy clauses preds modes pred mode (moded_clauses, compiled_term) =
       
  2025   let
       
  2026     val ctxt = ProofContext.init thy
       
  2027     val clauses = the (AList.lookup (op =) clauses pred)
       
  2028   in
       
  2029     Goal.prove ctxt (Term.add_free_names compiled_term []) [] compiled_term
       
  2030       (if !do_proofs then
       
  2031         (fn _ =>
       
  2032         rtac @{thm pred_iffI} 1
       
  2033 				THEN print_tac "after pred_iffI"
       
  2034         THEN prove_one_direction thy clauses preds modes pred mode moded_clauses
       
  2035         THEN print_tac "proved one direction"
       
  2036         THEN prove_other_direction thy modes pred mode moded_clauses
       
  2037         THEN print_tac "proved other direction")
       
  2038        else (fn _ => mycheat_tac thy 1))
       
  2039   end;
       
  2040 
       
  2041 (* composition of mode inference, definition, compilation and proof *)
       
  2042 
       
(** auxiliary combinators for tables of preds and modes **)
       
  2044 
       
  2045 fun map_preds_modes f preds_modes_table =
       
  2046   map (fn (pred, modes) =>
       
  2047     (pred, map (fn (mode, value) => (mode, f pred mode value)) modes)) preds_modes_table
       
  2048 
       
  2049 fun join_preds_modes table1 table2 =
       
  2050   map_preds_modes (fn pred => fn mode => fn value =>
       
  2051     (value, the (AList.lookup (op =) (the (AList.lookup (op =) table2 pred)) mode))) table1
       
  2052     
       
  2053 fun maps_modes preds_modes_table =
       
  2054   map (fn (pred, modes) =>
       
  2055     (pred, map (fn (mode, value) => value) modes)) preds_modes_table  
       
  2056     
       
  2057 fun compile_preds compfuns mk_fun_of use_size thy all_vs param_vs preds moded_clauses =
       
  2058   map_preds_modes (fn pred => compile_pred compfuns mk_fun_of use_size thy all_vs param_vs pred
       
  2059       (the (AList.lookup (op =) preds pred))) moded_clauses  
       
  2060   
       
  2061 fun prove thy clauses preds modes moded_clauses compiled_terms =
       
  2062   map_preds_modes (prove_pred thy clauses preds modes)
       
  2063     (join_preds_modes moded_clauses compiled_terms)
       
  2064 
       
  2065 fun prove_by_skip thy _ _ _ _ compiled_terms =
       
  2066   map_preds_modes (fn pred => fn mode => fn t => Drule.standard (SkipProof.make_thm thy t))
       
  2067     compiled_terms
       
  2068     
       
  2069 fun prepare_intrs thy prednames =
       
  2070   let
       
  2071     val intrs = maps (intros_of thy) prednames
       
  2072       |> map (Logic.unvarify o prop_of)
       
  2073     val nparams = nparams_of thy (hd prednames)
       
  2074     val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
       
  2075     val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
       
  2076     val _ $ u = Logic.strip_imp_concl (hd intrs);
       
  2077     val params = List.take (snd (strip_comb u), nparams);
       
  2078     val param_vs = maps term_vs params
       
  2079     val all_vs = terms_vs intrs
       
  2080     fun dest_prem t =
       
  2081       (case strip_comb t of
       
  2082         (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t
       
  2083       | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of          
       
  2084           Prem (ts, t) => Negprem (ts, t)
       
  2085         | Negprem _ => error ("Double negation not allowed in premise: " ^ (Syntax.string_of_term_global thy (c $ t))) 
       
  2086         | Sidecond t => Sidecond (c $ t))
       
  2087       | (c as Const (s, _), ts) =>
       
  2088         if is_registered thy s then
       
  2089           let val (ts1, ts2) = chop (nparams_of thy s) ts
       
  2090           in Prem (ts2, list_comb (c, ts1)) end
       
  2091         else Sidecond t
       
  2092       | _ => Sidecond t)
       
  2093     fun add_clause intr (clauses, arities) =
       
  2094     let
       
  2095       val _ $ t = Logic.strip_imp_concl intr;
       
  2096       val (Const (name, T), ts) = strip_comb t;
       
  2097       val (ts1, ts2) = chop nparams ts;
       
  2098       val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr);
       
  2099       val (Ts, Us) = chop nparams (binder_types T)
       
  2100     in
       
  2101       (AList.update op = (name, these (AList.lookup op = clauses name) @
       
  2102         [(ts2, prems)]) clauses,
       
  2103        AList.update op = (name, (map (fn U => (case strip_type U of
       
  2104                  (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs)
       
  2105                | _ => NONE)) Ts,
       
  2106              length Us)) arities)
       
  2107     end;
       
  2108     val (clauses, arities) = fold add_clause intrs ([], []);
       
  2109     fun modes_of_arities arities =
       
  2110       (map (fn (s, (ks, k)) => (s, cprod (cprods (map
       
  2111             (fn NONE => [NONE]
       
  2112               | SOME k' => map SOME (map (map (rpair NONE)) (subsets 1 k'))) ks),
       
  2113        map (map (rpair NONE)) (subsets 1 k)))) arities)
       
  2114     fun modes_of_typ T =
       
  2115       let
       
  2116         val (Ts, Us) = chop nparams (binder_types T)
       
  2117         fun all_smodes_of_typs Ts = cprods_subset (
       
  2118           map_index (fn (i, U) =>
       
  2119             case HOLogic.strip_tupleT U of
       
  2120               [] => [(i + 1, NONE)]
       
  2121             | [U] => [(i + 1, NONE)]
       
  2122 	    | Us =>  map (pair (i + 1) o SOME) ((subsets 1 (length Us)) \\ [[], 1 upto (length Us)]))
       
  2123           Ts)
       
  2124       in
       
  2125         cprod (cprods (map (fn T => case strip_type T of
       
  2126           (Rs as _ :: _, Type ("bool", [])) => map SOME (all_smodes_of_typs Rs) | _ => [NONE]) Ts),
       
  2127            all_smodes_of_typs Us)
       
  2128       end
       
  2129     val all_modes = map (fn (s, T) => (s, modes_of_typ T)) preds
       
  2130   in (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) end;
       
  2131 
       
  2132 (** main function of predicate compiler **)
       
  2133 
       
  2134 fun add_equations_of steps prednames thy =
       
  2135   let
       
  2136     val _ = Output.tracing ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
       
  2137     val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
       
  2138       prepare_intrs thy prednames
       
  2139     val _ = Output.tracing "Infering modes..."
       
  2140     val moded_clauses = #infer_modes steps thy extra_modes all_modes param_vs clauses 
       
  2141     val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
       
  2142     val _ = print_modes modes
       
  2143     val _ = print_moded_clauses thy moded_clauses
       
  2144     val _ = Output.tracing "Defining executable functions..."
       
  2145     val thy' = fold (#create_definitions steps preds) modes thy
       
  2146       |> Theory.checkpoint
       
  2147     val _ = Output.tracing "Compiling equations..."
       
  2148     val compiled_terms =
       
  2149       (#compile_preds steps) thy' all_vs param_vs preds moded_clauses
       
  2150     val _ = print_compiled_terms thy' compiled_terms
       
  2151     val _ = Output.tracing "Proving equations..."
       
  2152     val result_thms = #prove steps thy' clauses preds (extra_modes @ modes)
       
  2153       moded_clauses compiled_terms
       
  2154     val qname = #qname steps
       
  2155     (* val attrib = gn thy => Attrib.attribute_i thy Code.add_eqn_attrib *)
       
  2156     val attrib = fn thy => Attrib.attribute_i thy (Attrib.internal (K (Thm.declaration_attribute
       
  2157       (fn thm => Context.mapping (Code.add_eqn thm) I))))
       
  2158     val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss
       
  2159       [((Binding.qualify true (Long_Name.base_name name) (Binding.name qname), result_thms),
       
  2160         [attrib thy ])] thy))
       
  2161       (maps_modes result_thms) thy'
       
  2162       |> Theory.checkpoint
       
  2163   in
       
  2164     thy''
       
  2165   end
       
  2166 
       
  2167 fun extend' value_of edges_of key (G, visited) =
       
  2168   let
       
  2169     val (G', v) = case try (Graph.get_node G) key of
       
  2170         SOME v => (G, v)
       
  2171       | NONE => (Graph.new_node (key, value_of key) G, value_of key)
       
  2172     val (G'', visited') = fold (extend' value_of edges_of) (edges_of (key, v) \\ visited)
       
  2173       (G', key :: visited) 
       
  2174   in
       
  2175     (fold (Graph.add_edge o (pair key)) (edges_of (key, v)) G'', visited')
       
  2176   end;
       
  2177 
       
  2178 fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, [])) 
       
  2179   
       
  2180 fun gen_add_equations steps names thy =
       
  2181   let
       
  2182     val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
       
  2183       |> Theory.checkpoint;
       
  2184     fun strong_conn_of gr keys =
       
  2185       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
       
  2186     val scc = strong_conn_of (PredData.get thy') names
       
  2187     val thy'' = fold_rev
       
  2188       (fn preds => fn thy =>
       
  2189         if #are_not_defined steps thy preds then add_equations_of steps preds thy else thy)
       
  2190       scc thy' |> Theory.checkpoint
       
  2191   in thy'' end
       
  2192 
       
(* different instantiations of the predicate compiler *)
       
  2194 
       
  2195 val add_equations = gen_add_equations
       
  2196   {infer_modes = infer_modes,
       
  2197   create_definitions = create_definitions,
       
  2198   compile_preds = compile_preds PredicateCompFuns.compfuns mk_fun_of false,
       
  2199   prove = prove,
       
  2200   are_not_defined = (fn thy => forall (null o modes_of thy)),
       
  2201   qname = "equation"}
       
  2202 
       
  2203 val add_sizelim_equations = gen_add_equations
       
  2204   {infer_modes = infer_modes,
       
  2205   create_definitions = sizelim_create_definitions,
       
  2206   compile_preds = compile_preds PredicateCompFuns.compfuns mk_sizelim_fun_of true,
       
  2207   prove = prove_by_skip,
       
  2208   are_not_defined = (fn thy => fn preds => true), (* TODO *)
       
  2209   qname = "sizelim_equation"
       
  2210   }
       
  2211 
       
  2212 val add_quickcheck_equations = gen_add_equations
       
  2213   {infer_modes = infer_modes_with_generator,
       
  2214   create_definitions = rpred_create_definitions,
       
  2215   compile_preds = compile_preds RPredCompFuns.compfuns mk_generator_of true,
       
  2216   prove = prove_by_skip,
       
  2217   are_not_defined = (fn thy => fn preds => true), (* TODO *)
       
  2218   qname = "rpred_equation"}
       
  2219 
       
  2220 (** user interface **)
       
  2221 
       
  2222 (* generation of case rules from user-given introduction rules *)
       
  2223 
       
  2224 fun mk_casesrule ctxt nparams introrules =
       
  2225   let
       
  2226     val intros = map (Logic.unvarify o prop_of) introrules
       
  2227     val (pred, (params, args)) = strip_intro_concl nparams (hd intros)
       
  2228     val ([propname], ctxt1) = Variable.variant_fixes ["thesis"] ctxt
       
  2229     val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
       
  2230     val (argnames, ctxt2) = Variable.variant_fixes
       
  2231       (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt1
       
  2232     val argvs = map2 (curry Free) argnames (map fastype_of args)
       
  2233     fun mk_case intro =
       
  2234       let
       
  2235         val (_, (_, args)) = strip_intro_concl nparams intro
       
  2236         val prems = Logic.strip_imp_prems intro
       
  2237         val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args)
       
  2238         val frees = (fold o fold_aterms)
       
  2239           (fn t as Free _ =>
       
  2240               if member (op aconv) params t then I else insert (op aconv) t
       
  2241            | _ => I) (args @ prems) []
       
  2242       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
       
  2243     val assm = HOLogic.mk_Trueprop (list_comb (pred, params @ argvs))
       
  2244     val cases = map mk_case intros
       
  2245   in Logic.list_implies (assm :: cases, prop) end;
       
  2246 
       
  2247 (* code_pred_intro attribute *)
       
  2248 
       
  2249 fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);
       
  2250 
       
  2251 val code_pred_intros_attrib = attrib add_intro;
       
  2252 
       
  2253 local
       
  2254 
       
  2255 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *)
       
  2256 fun generic_code_pred prep_const raw_const lthy =
       
  2257   let
       
  2258     val thy = ProofContext.theory_of lthy
       
  2259     val const = prep_const thy raw_const
       
  2260     val lthy' = LocalTheory.theory (PredData.map
       
  2261         (extend (fetch_pred_data thy) (depending_preds_of thy) const)) lthy
       
  2262       |> LocalTheory.checkpoint
       
  2263     val thy' = ProofContext.theory_of lthy'
       
  2264     val preds = Graph.all_preds (PredData.get thy') [const] |> filter_out (has_elim thy')
       
  2265     fun mk_cases const =
       
  2266       let
       
  2267         val nparams = nparams_of thy' const
       
  2268         val intros = intros_of thy' const
       
  2269       in mk_casesrule lthy' nparams intros end  
       
  2270     val cases_rules = map mk_cases preds
       
  2271     val cases =
       
  2272       map (fn case_rule => RuleCases.Case {fixes = [],
       
  2273         assumes = [("", Logic.strip_imp_prems case_rule)],
       
  2274         binds = [], cases = []}) cases_rules
       
  2275     val case_env = map2 (fn p => fn c => (Long_Name.base_name p, SOME c)) preds cases
       
  2276     val lthy'' = lthy'
       
  2277       |> fold Variable.auto_fixes cases_rules 
       
  2278       |> ProofContext.add_cases true case_env
       
  2279     fun after_qed thms goal_ctxt =
       
  2280       let
       
  2281         val global_thms = ProofContext.export goal_ctxt
       
  2282           (ProofContext.init (ProofContext.theory_of goal_ctxt)) (map the_single thms)
       
  2283       in
       
  2284         goal_ctxt |> LocalTheory.theory (fold set_elim global_thms #> add_equations [const])
       
  2285       end  
       
  2286   in
       
  2287     Proof.theorem_i NONE after_qed (map (single o (rpair [])) cases_rules) lthy''
       
  2288   end;
       
  2289 
       
  2290 structure P = OuterParse
       
  2291 
       
  2292 in
       
  2293 
       
  2294 val code_pred = generic_code_pred (K I);
       
  2295 val code_pred_cmd = generic_code_pred Code.read_const
       
  2296 
       
  2297 val setup = PredData.put (Graph.empty) #>
       
  2298   Attrib.setup @{binding code_pred_intros} (Scan.succeed (attrib add_intro))
       
  2299     "adding alternative introduction rules for code generation of inductive predicates"
       
  2300 (*  Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib)
       
  2301     "adding alternative elimination rules for code generation of inductive predicates";
       
  2302     *)
       
  2303   (*FIXME name discrepancy in attribs and ML code*)
       
  2304   (*FIXME intros should be better named intro*)
       
  2305   (*FIXME why distinguished attribute for cases?*)
       
  2306 
       
  2307 val _ = OuterSyntax.local_theory_to_proof "code_pred"
       
  2308   "prove equations for predicate specified by intro/elim rules"
       
  2309   OuterKeyword.thy_goal (P.term_group >> code_pred_cmd)
       
  2310 
       
  2311 end
       
  2312 
       
  2313 (*FIXME
       
  2314 - Naming of auxiliary rules necessary?
       
  2315 - add default code equations P x y z = P_i_i_i x y z
       
  2316 *)
       
  2317 
       
  2318 (* transformation for code generation *)
       
  2319 
       
  2320 val eval_ref = ref (NONE : (unit -> term Predicate.pred) option);
       
  2321 
       
  (*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
  (* analyze_compr thy t_compr: compile the set comprehension t_compr,
     i.e. Collect applied to a (possibly tuple-splitting) abstraction whose
     body is a predicate constant applied to arguments, into a term of the
     predicate monad (PredicateCompFuns.compfuns).

     - The first (nparams_of thy name) arguments are the predicate's
       parameters.
     - Among the remaining arguments, those that are bound variables of the
       comprehension are outputs; all others are inputs.
     - A mode of the predicate with exactly this input pattern is selected,
       and the predicate is compiled for it and applied to the inputs.
     - The tuples of outputs are then rearranged (via mk_map) into the tuple
       shape of the comprehension.

     Raises ERROR if t_compr is not a comprehension, if the body is not
     headed by a constant, or if no mode fits.  Warns and takes the first
     mode if several fit. *)
  fun analyze_compr thy t_compr =
    let
      (* printed only on the error/warning paths *)
      fun string_of_compr () = Syntax.string_of_term_global thy t_compr;
      val split = case t_compr of
          Const (@{const_name Collect}, _) $ t => t
        | _ => error ("Not a set comprehension: " ^ string_of_compr ());
      val (body, Ts, fp) = HOLogic.strip_psplits split;
      val (pred, all_args) = strip_comb body;
      (* previously a partial val-binding that died with an opaque Bind
         exception for e.g. a free predicate variable *)
      val name = case pred of
          Const (name, _) => name
        | _ => error ("Not a predicate constant in comprehension: " ^ string_of_compr ());
      val (params, args) = chop (nparams_of thy name) all_args;
      val nbounds = length Ts;
      (* 1-based positions of the input arguments: every argument except the
         bound variables of the comprehension (Bound j with j < nbounds) *)
      val user_mode = map_filter I (map_index
        (fn (i, Bound j) => if j < nbounds then NONE else SOME (i + 1)
          | (i, _) => SOME (i + 1)) args); (*FIXME dangling bounds should not occur*)
      val user_mode' = map (rpair NONE) user_mode;
      val modes = filter (fn Mode (_, is, _) => is = user_mode')
        (modes_of_term (all_modes_of thy) (list_comb (pred, params)));
      val m = case modes
       of [] => error ("No mode possible for comprehension " ^ string_of_compr ())
        | [m] => m
        | m :: _ :: _ =>
            (warning ("Multiple modes possible for comprehension " ^ string_of_compr ()); m);
      val (inargs, outargs) = split_smode user_mode' args;
      val t_pred = list_comb (compile_expr NONE thy (m, list_comb (pred, params)), inargs);
      val t_eval = if null outargs then t_pred else let
          (* outputs are expected to be bound variables of the comprehension,
             assuming split_smode returns exactly the arguments whose positions
             are not listed in user_mode' *)
          val outargs_bounds = map
            (fn Bound i => i
              | t => raise TERM ("analyze_compr: output argument is not a bound variable", [t]))
            outargs;
          val outargsTs = map (nth Ts) outargs_bounds;
          val T_pred = HOLogic.mk_tupleT outargsTs;
          val T_compr = HOLogic.mk_ptupleT fp Ts;
          (* positions within outargs, ordered by the de Bruijn index each one
             carries *)
          val arrange_bounds = map_index I outargs_bounds
            |> sort (prod_ord (K EQUAL) int_ord)
            |> map fst;
          val arrange = funpow (length outargs_bounds - 1) HOLogic.mk_split
            (Term.list_abs (map (pair "") outargsTs,
              HOLogic.mk_ptuple fp T_compr (map Bound arrange_bounds)))
        in mk_map PredicateCompFuns.compfuns T_pred T_compr arrange t_pred end
    in t_eval end;
       
  2357 
       
  (* eval thy t_compr: compile the comprehension with analyze_compr.
     - Map HOLogic.term_of_const over the compiled predicate so that its
       elements become HOL terms.
     - Evaluate the result with Code_ML.eval, using eval_ref as the result
       slot.
     Returns the element type of the comprehension paired with the evaluated
     predicate. *)
  fun eval thy t_compr =
    let
      val compiled = analyze_compr thy t_compr;
      val elemT = dest_predT PredicateCompFuns.compfuns (fastype_of compiled);
      val reified =
        mk_map PredicateCompFuns.compfuns elemT HOLogic.termT
          (HOLogic.term_of_const elemT) compiled;
      val result =
        Code_ML.eval NONE ("Predicate_Compile.eval_ref", eval_ref) Predicate.map thy reified [];
    in (elemT, result) end;
       
  2364 
       
  (* values ctxt k t_compr: evaluate the comprehension and take up to k of
     its elements via Predicate.yieldn.  k = ~1 is treated as "no limit"
     (presumably yieldn then enumerates everything; confirm).
     - If k = ~1, or fewer than k elements came back, the result is the
       finite set term of those elements.
     - Otherwise the result is that set united with the original
       comprehension. *)
  fun values ctxt k t_compr =
    let
      val (elemT, pred) = eval (ProofContext.theory_of ctxt) t_compr;
      val (elems, _) = Predicate.yieldn k pred;
      val elems_set = HOLogic.mk_set elemT elems;
      val complete = (k = ~1) orelse (length elems < k);
    in
      if complete then elems_set
      else
        let val setT = HOLogic.mk_setT elemT
        in Const (@{const_name Set.union}, setT --> setT --> setT) $ elems_set $ t_compr end
    end;
       
  2375 
       
  (* values_cmd modes k raw_t state: read raw_t in the toplevel context of
     state and compute its value with values.  Then print the resulting term
     and its type under the print modes given in modes. *)
  fun values_cmd modes k raw_t state =
    let
      val ctxt = Toplevel.context_of state;
      val result = values ctxt k (Syntax.read_term ctxt raw_t);
      val resultT = Term.type_of result;
      val ctxt' = Variable.auto_fixes result ctxt;
      fun pretty_result () =
        Pretty.block
          [Pretty.quote (Syntax.pretty_term ctxt' result), Pretty.fbrk,
           Pretty.str "::", Pretty.brk 1,
           Pretty.quote (Syntax.pretty_typ ctxt' resultT)];
    in Pretty.writeln (PrintMode.with_modes modes pretty_result ()) end;
       
  2387 
       
  local
    structure P = OuterParse;
    (* build the toplevel transition from the parsed ((modes, k), term) *)
    fun values_trans ((modes, k), t) =
      Toplevel.no_timing o Toplevel.keep (values_cmd modes k t);
  in

  (* optional parenthesised, non-empty list of print-mode names *)
  val opt_modes =
    Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

  (* diagnostic command "values": optional print modes, optional limit k
     (default ~1), then the comprehension term *)
  val _ =
    OuterSyntax.improper_command "values" "enumerate and print comprehensions"
      OuterKeyword.diag
      ((opt_modes -- Scan.optional P.nat ~1 -- P.term) >> values_trans);

  end;
       
  2397 end;
       
  2398 
       
  2399 end;