src/HOL/ex/predicate_compile.ML
author bulwahn
Sat May 16 20:17:59 2009 +0200 (2009-05-16)
changeset 31174 f1f1e9b53c81
parent 31170 c6efe82fc652
child 31177 c39994cb152a
permissions -rw-r--r--
added new kind generated_theorem for theorems which are generated by packages to distinguish between theorems from users and packages
haftmann@30374
     1
(* Author: Lukas Bulwahn
haftmann@30374
     2
haftmann@30374
     3
(Prototype of) A compiler from predicates specified by intro/elim rules
haftmann@30374
     4
to equations.
haftmann@30374
     5
*)
haftmann@30374
     6
haftmann@30374
     7
signature PREDICATE_COMPILE =
sig
  (* a mode: the modes of the predicate parameters (NONE for parameters
     without a mode) together with the positions of the input arguments *)
  type mode = int list option list * int list

  (* global switch *)
  val do_proofs: bool ref

  (* queries on the data stored for compiled predicates *)
  val modename_of: theory -> string -> mode -> string
  val modes_of: theory -> string -> mode list
  val intro_rule: theory -> string -> mode -> thm
  val elim_rule: theory -> string -> mode -> thm
  val pred_intros : theory -> string -> thm list
  val get_nparams : theory -> string -> int

  (* term-level utilities *)
  val strip_intro_concl: term -> int -> term * (term list * term list)
  val pred_term_of : theory -> term -> term option

  (* theory transformations and user commands *)
  val prove_equation: string -> mode option -> theory -> theory
  val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*)
  val setup: theory -> theory
  val code_pred: string -> Proof.context -> Proof.state
  val code_pred_cmd: string -> Proof.context -> Proof.state
end;
haftmann@30374
    25
haftmann@31124
    26
structure Predicate_Compile : PREDICATE_COMPILE =
haftmann@30374
    27
struct
haftmann@30374
    28
haftmann@30972
    29
(** auxiliary **)
haftmann@30972
    30
haftmann@30972
    31
(* debug stuff *)
haftmann@30972
    32
haftmann@30972
    33
(* Emit a tracing message, but only when Toplevel.debug is set. *)
fun tracing msg =
  if not (! Toplevel.debug) then () else Output.tracing msg;

(* In debug mode print the proof state labelled with msg; otherwise act as
   the identity tactic Seq.single. *)
fun print_tac msg =
  if not (! Toplevel.debug) then Seq.single else Tactical.print_tac msg;

(* Trace msg (in debug mode) and return the goal state unchanged. *)
fun debug_tac msg st = (tracing msg; Seq.single st);

(* global flag; its consumers are not visible in this part of the file *)
val do_proofs = ref true;
haftmann@30972
    39
haftmann@30972
    40
haftmann@30972
    41
(** fundamentals **)
haftmann@30972
    42
haftmann@30972
    43
(* syntactic operations *)
haftmann@30972
    44
haftmann@30972
    45
(* mk_eq (x, [t1, ..., tn]) = [x = t1, ..., x = tn], where each occurrence of
   the free variable x is given the type of the term it is equated with. *)
fun mk_eq (x, xs) =
  map (fn t => HOLogic.mk_eq (Free (x, fastype_of t), t)) xs;
haftmann@30972
    50
haftmann@30972
    51
(* Right-nested product type of the given types; unit for the empty list. *)
fun mk_tupleT Ts =
  (case Ts of
    [] => HOLogic.unitT
  | _ => foldr1 HOLogic.mk_prodT Ts);

(* Right-nested tuple of the given terms; () for the empty list. *)
fun mk_tuple ts =
  (case ts of
    [] => HOLogic.unit
  | _ => foldr1 HOLogic.mk_prod ts);

(* Split a right-nested tuple into its components: () gives [], a pair
   contributes its first component and is split further in the second, any
   other term is a single component. *)
fun dest_tuple t =
  (case t of
    Const (@{const_name Product_Type.Unity}, _) => []
  | Const (@{const_name Pair}, _) $ a $ b => a :: dest_tuple b
  | _ => [t])
haftmann@30972
    60
haftmann@30972
    61
(* The type elemT Predicate.pred. *)
fun mk_pred_enumT elemT = Type ("Predicate.pred", [elemT])

(* Element type of a Predicate.pred type; raises TYPE for any other type. *)
fun dest_pred_enumT T =
  (case T of
    Type ("Predicate.pred", [elemT]) => elemT
  | _ => raise TYPE ("dest_pred_enumT", [T], []));
haftmann@30972
    65
haftmann@30972
    66
(* Build Predicate.Pred f.  f must have a function type argT => _ (otherwise
   the binding below raises Bind); the result has type argT pred. *)
fun mk_Enum f =
  let
    val fT = fastype_of f
    val Type ("fun", [argT, _]) = fT
  in Const (@{const_name Predicate.Pred}, fT --> mk_pred_enumT argT) $ f end;
haftmann@30972
    71
haftmann@30972
    72
(* Build Predicate.eval f x at type xT pred => xT => bool, where xT is the
   type of x. *)
fun mk_Eval (f, x) =
  let
    val xT = fastype_of x
    val evalT = mk_pred_enumT xT --> xT --> HOLogic.boolT
  in Const (@{const_name Predicate.eval}, evalT) $ f $ x end;
haftmann@30972
    77
haftmann@30972
    78
(* The constant Orderings.bot at type elemT pred. *)
fun mk_empty elemT = Const (@{const_name Orderings.bot}, mk_pred_enumT elemT);

(* Build Predicate.single t at type T => T pred, where T is the type of t. *)
fun mk_single t =
  let
    val elemT = fastype_of t
    val singleT = elemT --> mk_pred_enumT elemT
  in Const (@{const_name Predicate.single}, singleT) $ t end;
haftmann@30972
    83
haftmann@30972
    84
(* Build Predicate.bind x f.  f must have a function type _ => resT
   (otherwise Bind is raised); the constant gets type xT => fT => resT. *)
fun mk_bind (x, f) =
  let
    val fT = fastype_of f
    val Type ("fun", [_, resT]) = fT
    val bindT = fastype_of x --> fT --> resT
  in Const (@{const_name Predicate.bind}, bindT) $ x $ f end;
haftmann@30972
    89
haftmann@30972
    90
(* Binary sup of two terms. *)
val mk_sup = HOLogic.mk_binop @{const_name sup};

(* Predicate.if_pred cond, typed bool => unit pred. *)
fun mk_if_predenum cond =
  let val ifT = HOLogic.boolT --> mk_pred_enumT HOLogic.unitT
  in Const (@{const_name Predicate.if_pred}, ifT) $ cond end;

(* Predicate.not_pred t, typed unit pred => unit pred. *)
fun mk_not_pred t =
  let val unit_predT = mk_pred_enumT HOLogic.unitT
  in Const (@{const_name Predicate.not_pred}, unit_predT --> unit_predT) $ t end
haftmann@30972
    97
haftmann@30972
    98
haftmann@30972
    99
(* data structures *)
haftmann@30972
   100
haftmann@30972
   101
(* a mode: the modes of the predicate parameters (NONE for parameters
   without a mode) and the positions of the input arguments *)
type mode = int list option list * int list;

(* lexicographic total order on modes: first the parameter modes, then the
   input positions *)
local
  val positions_ord = list_ord int_ord
in
  val mode_ord = prod_ord (list_ord (option_ord positions_ord)) positions_ord
end;

(* tables keyed by (predicate name, mode) *)
structure PredModetab = TableFun(
  type key = string * mode
  val ord = prod_ord fast_string_ord mode_ord
);
haftmann@30374
   109
haftmann@30374
   110
haftmann@30972
   111
(*FIXME scrap boilerplate*)
haftmann@30972
   112
haftmann@30374
   113
structure IndCodegenData = TheoryDataFun
(
  (* names: map from inductive predicate and mode to function name (string).
     modes: map from inductive predicates to modes
     function_defs: map from function name to definition
     function_intros: map from function name to intro rule
     function_elims: map from function name to elim rule
     intro_rules: map from inductive predicate to alternative intro rules
     elim_rules: map from inductive predicate to alternative elimination rule
     nparams: map from const name to number of parameters
       (assuming there exist intro and elimination rules) *)
  type T = {names : string PredModetab.table,
            modes : mode list Symtab.table,
            function_defs : Thm.thm Symtab.table,
            function_intros : Thm.thm Symtab.table,
            function_elims : Thm.thm Symtab.table,
            intro_rules : Thm.thm list Symtab.table,
            elim_rules : Thm.thm Symtab.table,
            nparams : int Symtab.table
           }; (*FIXME: better group tables according to key*)
  val empty = {names = PredModetab.empty,
               modes = Symtab.empty,
               function_defs = Symtab.empty,
               function_intros = Symtab.empty,
               function_elims = Symtab.empty,
               intro_rules = Symtab.empty,
               elim_rules = Symtab.empty,
               nparams = Symtab.empty};
  val copy = I;
  val extend = I;
  (* Merge two data records table by table.  Conflicting entries make the
     respective table merge raise its DUP exception.  Two lists of intro rules
     agree only if they have the same length and are pointwise equal; the
     length check must come first because ~~ raises UnequalLengths on lists of
     different length, which would otherwise escape instead of a conflict. *)
  fun merge _ r = {names = PredModetab.merge (op =) (pairself #names r),
                   modes = Symtab.merge (op =) (pairself #modes r),
                   function_defs = Symtab.merge Thm.eq_thm (pairself #function_defs r),
                   function_intros = Symtab.merge Thm.eq_thm (pairself #function_intros r),
                   function_elims = Symtab.merge Thm.eq_thm (pairself #function_elims r),
                   intro_rules = Symtab.merge
                     (fn (ths1, ths2) =>
                       length ths1 = length ths2 andalso forall Thm.eq_thm (ths1 ~~ ths2))
                     (pairself #intro_rules r),
                   elim_rules = Symtab.merge Thm.eq_thm (pairself #elim_rules r),
                   nparams = Symtab.merge (op =) (pairself #nparams r)};
);
haftmann@30374
   152
haftmann@30374
   153
  local
    (* Apply one updater per table to the IndCodegenData record of thy;
       each map_X below updates only its own table and passes I for the rest. *)
    fun map_data (f_names, f_modes, f_defs, f_intros, f_elims,
                  f_intro_rules, f_elim_rules, f_nparams) =
      IndCodegenData.map
        (fn {names, modes, function_defs, function_intros, function_elims,
             intro_rules, elim_rules, nparams} =>
          {names = f_names names, modes = f_modes modes,
           function_defs = f_defs function_defs,
           function_intros = f_intros function_intros,
           function_elims = f_elims function_elims,
           intro_rules = f_intro_rules intro_rules,
           elim_rules = f_elim_rules elim_rules,
           nparams = f_nparams nparams})
  in
    fun map_names f thy = map_data (f, I, I, I, I, I, I, I) thy
    fun map_modes f thy = map_data (I, f, I, I, I, I, I, I) thy
    fun map_function_defs f thy = map_data (I, I, f, I, I, I, I, I) thy
    fun map_function_elims f thy = map_data (I, I, I, I, f, I, I, I) thy
    fun map_function_intros f thy = map_data (I, I, I, f, I, I, I, I) thy
    fun map_intro_rules f thy = map_data (I, I, I, I, I, f, I, I) thy
    fun map_elim_rules f thy = map_data (I, I, I, I, I, I, f, I) thy
    fun map_nparams f thy = map_data (I, I, I, I, I, I, I, f) thy
  end
haftmann@30374
   200
haftmann@30374
   201
(* Discharge subgoal i without a proof: resolve it with the schematic theorem
   ?A produced by SkipProof.make_thm. *)
fun mycheat_tac thy i st =
  let val cheat_thm = SkipProof.make_thm thy (Var (("A", 0), propT))
  in Tactic.rtac cheat_thm i st end
haftmann@30374
   204
haftmann@30374
   205
(* Lightweight mode analysis **********************************************)
haftmann@30374
   206
haftmann@30374
   207
(**************************************************************************)
haftmann@30374
   208
(* source code from old code generator ************************************)
haftmann@30374
   209
haftmann@30374
   210
(**** check if a term contains only constructor functions ****)
haftmann@30374
   211
haftmann@30374
   212
(* is_constrt thy t: t is built only from free variables and fully applied
   datatype constructors (of the datatypes registered in thy) whose result
   type is their own datatype.  The constructor table is computed once per
   partial application is_constrt thy. *)
fun is_constrt thy =
  let
    (* (constructor name, (number of arguments, datatype name)) *)
    val constr_table =
      Symtab.dest (DatatypePackage.get_datatypes thy)
      |> maps (fn (_, info) =>
           map (fn (_, (tyco, _, constrs)) =>
                 map (fn (c, dts) => (c, (length dts, tyco))) constrs)
             (#descr info))
      |> flat;
    fun check t =
      (case strip_comb t of
        (Free _, []) => true
      | (Const (c, T), args) =>
          (case AList.lookup (op =) constr_table c of
            SOME (arity, tyco) =>
              (case body_type T of
                Type (tyco', _) =>
                  length args = arity andalso tyco = tyco' andalso forall check args
              | _ => false)
          | NONE => false)
      | _ => false);
  in check end;
haftmann@30374
   224
haftmann@30972
   225
(**** check if a type is an equality type (i.e. doesn't contain fun)
haftmann@30972
   226
  FIXME this is only an approximation ****)
haftmann@30374
   227
haftmann@30374
   228
(* Approximate test for an equality type: no "fun" type constructor occurs;
   type variables are accepted. *)
fun is_eqT T =
  (case T of
    Type ("fun", _) => false
  | Type (_, argTs) => forall is_eqT argTs
  | _ => true);
haftmann@30374
   230
haftmann@30374
   231
(**** mode inference ****)
haftmann@30374
   232
haftmann@30374
   233
(* Render a mode as "X -> [1, 2] -> [1]": parameter modes first (X for a
   parameter without a mode), then the input positions. *)
fun string_of_mode (iss, is) =
  let
    fun string_of_positions NONE = "X"
      | string_of_positions (SOME js) = enclose "[" "]" (commas (map string_of_int js))
  in
    (iss @ [SOME is]) |> map string_of_positions |> space_implode " -> "
  end;

(* Trace the inferred modes, one predicate per line. *)
fun print_modes modes =
  let
    fun line_of (s, ms) = s ^ ": " ^ commas (map string_of_mode ms)
  in tracing ("Inferred modes:\n" ^ cat_lines (map line_of modes)) end;
haftmann@30374
   241
haftmann@30374
   242
(* Names of all free variables of tm, with duplicates (collected by consing
   during fold_aterms). *)
fun term_vs tm =
  let
    fun add_free (Free (name, _)) = cons name
      | add_free _ = I
  in fold_aterms add_free tm [] end;

(* Distinct free-variable names of a list of terms. *)
fun terms_vs ts = distinct (op =) (maps term_vs ts);

(* (name, type) pairs of all free variables of tm, with duplicates. *)
fun term_vTs tm =
  let
    fun add_free (Free xT) = cons xT
      | add_free _ = I
  in fold_aterms add_free tm [] end;
haftmann@30374
   248
haftmann@30374
   249
(* Split ts into (elements at the 1-based positions in is, the others),
   both in their original order. *)
fun get_args is ts =
  let
    val numbered = (1 upto length ts) ~~ ts
    val (selected, rest) = List.partition (fn (i, _) => i mem is) numbered
  in (map snd selected, map snd rest) end
haftmann@30374
   254
haftmann@30972
   255
(*FIXME this function should not be named merge... make it local instead*)
haftmann@30374
   256
(* Merge two lists of lists by decreasing length of their elements (if both
   inputs are sorted that way, so is the result); on equal length the element
   of the first list comes first. *)
fun merge xs [] = xs
  | merge [] ys = ys
  | merge (xs as x :: xs') (ys as y :: ys') =
      if length x < length y then y :: merge xs ys'
      else x :: merge xs' ys;

(* All subsets of [i, ..., j] as increasing lists, larger subsets first. *)
fun subsets i j =
  if i > j then [[]]
  else
    let
      val without_i = subsets (i + 1) j
      val with_i = map (cons i) without_i
    in merge with_i without_i end;
haftmann@30374
   265
haftmann@30374
   266
(* Cartesian product of two lists, ordered by the first component. *)
fun cprod (xs, ys) = maps (fn x => map (fn y => (x, y)) ys) xs;

(* All lists choosing one element from each list of xss, in order. *)
fun cprods xss = foldr (fn (xs, acc) => map (op ::) (cprod (xs, acc))) [[]] xss;
haftmann@30374
   270
haftmann@30972
   271
(* Mode (m, is, pms): the term is called in mode m of its head predicate;
   is are the input positions among the arguments not yet applied in the
   term (shifted past the applied ones); pms are the hmodes chosen for the
   parameter arguments (NONE for parameters without a mode).
   FIXME: possibly redundant with the type mode -- review. *)
datatype hmode = Mode of mode * int list * hmode option list;

(* All hmodes in which term t can be used, given the known modes (an
   association list from predicate names to mode lists).
   - Const heads without an entry get the default hmode: every argument input.
   - Var/Free heads without an entry raise Option (handled by callers).
   - A mode of the head is usable only if all arguments applied beyond its
     parameters are inputs; ERROR if t has fewer arguments than parameters. *)
fun modes_of_term modes t =
  let
    val ks = 1 upto length (binder_types (fastype_of t));
    val default = [Mode (([], ks), ks, [])];
    (* candidate hmodes for one parameter argument *)
    fun param_modes (NONE, _) = [NONE]
      | param_modes (SOME js, arg) =
          modes_of_term modes arg
          |> filter (fn Mode (_, js', _) => js = js')
          |> map SOME;
    (* hmodes of the application name args that use mode m of name *)
    fun modes_for_mode name args (m as (iss, is)) =
      let
        val n_params = length iss;
        val (params, applied) =
          if length args < n_params then
            error ("Too few arguments for inductive predicate " ^ name)
          else chop n_params args;
        val k = length applied;
      in
        if is_prefix (op =) (1 upto k) is then
          let val is' = map (fn i => i - k) (List.drop (is, k))
          in map (fn pms => Mode (m, is', pms)) (cprods (map param_modes (iss ~~ params))) end
        else []
      end;
    fun mk_modes name args =
      Option.map (maps (modes_for_mode name args)) (AList.lookup (op =) modes name);
  in
    (case strip_comb t of
      (Const (name, _), args) => the_default default (mk_modes name args)
    | (Var ((name, _), _), args) => the (mk_modes name args)
    | (Free (name, _), args) => the (mk_modes name args)
    | _ => default)
  end
haftmann@30374
   303
haftmann@30374
   304
(* premises of an introduction rule: a predicate applied to argument terms,
   its negation, or an arbitrary side condition *)
datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term;

(* Given the known variables vs, return SOME (p, SOME hmode) for the first
   premise p of ps that can be executed in some hmode, or NONE.  Modes are
   computed for every premise up front; a premise whose head has no modes
   yields the error "Bad predicate". *)
fun select_mode_prem thy modes vs ps =
  let
    fun modes_of t = modes_of_term modes t handle Option =>
      error ("Bad predicate: " ^ Syntax.string_of_term_global thy t);
    (* positive premise: inputs and non-constructor outputs must be known and
       the latter of equality type; variables needing an equality test must be
       of equality type *)
    fun prem_ok us t (Mode (_, is, _)) =
      let
        val (ins, outs) = get_args is us;
        val (constr_outs, checked_outs) = List.partition (is_constrt thy) outs;
        val out_vTs = maps term_vTs constr_outs;
        val eq_testedTs = map snd (duplicates (op =) out_vTs) @
          List.mapPartial (AList.lookup (op =) out_vTs) vs;
      in
        terms_vs (ins @ checked_outs) subset vs andalso
        forall (is_eqT o fastype_of) checked_outs andalso
        term_vs t subset vs andalso
        forall is_eqT eq_testedTs
      end;
    (* negated premise: all arguments must be inputs and known *)
    fun negprem_ok us t (Mode (_, is, _)) =
      length us = length is andalso
      terms_vs us subset vs andalso
      term_vs t subset vs;
    fun select (Prem (us, t)) = find_first (prem_ok us t) (modes_of t)
      | select (Negprem (us, t)) = find_first (negprem_ok us t) (modes_of t)
      | select (Sidecond t) =
          if term_vs t subset vs then SOME (Mode (([], []), [], [])) else NONE;
  in find_first (is_some o snd) (ps ~~ map select ps) end;
haftmann@30374
   332
haftmann@30374
   333
(* Does the clause (conclusion arguments ts, premises ps) admit the mode
   (iss, is)?  Premises are scheduled greedily with select_mode_prem,
   starting from the parameters and the variables of constructor-term inputs;
   in the end all variables of ts must be known. *)
fun check_mode_clause thy param_vs modes (iss, is) (ts, ps) =
  let
    (* parameters with a mode act as predicates with that single mode *)
    fun param_mode (_, NONE) = NONE
      | param_mode (v, SOME js) = SOME (v, [([], js)]);
    val modes' = modes @ List.mapPartial param_mode (param_vs ~~ iss);
    fun check_mode_prems vs [] = SOME vs
      | check_mode_prems vs ps =
          (case select_mode_prem thy modes' vs ps of
            NONE => NONE
          | SOME (prem, _) =>
              let
                val vs' = (case prem of Prem (us, _) => vs union terms_vs us | _ => vs);
              in check_mode_prems vs' (filter_out (equal prem) ps) end);
    val (constr_ins, other_ins) = List.partition (is_constrt thy) (fst (get_args is ts));
    val in_vs = terms_vs constr_ins;
    val concl_vs = terms_vs ts;
    val dups_eqT = forall is_eqT (map snd (duplicates (op =) (maps term_vTs constr_ins)));
  in
    dups_eqT andalso
    forall (is_eqT o fastype_of) other_ins andalso
    (case check_mode_prems (param_vs union in_vs) ps of
      NONE => false
    | SOME vs => concl_vs subset vs)
  end;
haftmann@30374
   354
haftmann@30374
   355
(* Keep only those modes ms of predicate p that every clause of p admits
   (clauses are looked up in preds).  The first violating clause of each
   rejected mode is reported via tracing.  Raises ERROR if preds has no
   entry for p (previously an uninformative Bind exception). *)
fun check_modes_pred thy param_vs preds modes (p, ms) =
  let
    val rs =
      (case AList.lookup (op =) preds p of
        SOME rs => rs
      | NONE => error ("check_modes_pred: no clauses for predicate " ^ p));
    fun mode_ok m =
      (case find_index (not o check_mode_clause thy param_vs modes m) rs of
        ~1 => true
      | i => (tracing ("Clause " ^ string_of_int (i+1) ^ " of " ^
          p ^ " violates mode " ^ string_of_mode m); false));
  in (p, List.filter mode_ok ms) end;
haftmann@30374
   363
haftmann@30972
   364
(* Apply f repeatedly, starting at x, until the value no longer changes. *)
fun fixp f (x : (string * mode list) list) =
  let val x' = f x
  in if x' = x then x else fixp f x' end;

(* Start with all candidate modes of every predicate (any input subset for
   each moded parameter and for the arguments, given by arities) and
   repeatedly drop modes that some clause violates, until a fixpoint. *)
fun infer_modes thy extra_modes arities param_vs preds =
  let
    fun param_candidates NONE = [NONE]
      | param_candidates (SOME k') = map SOME (subsets 1 k');
    fun all_modes (s, (ks, k)) =
      (s, cprod (cprods (map param_candidates ks), subsets 1 k));
    val initial = map all_modes arities;
    fun step modes =
      map (check_modes_pred thy param_vs preds (modes @ extra_modes)) modes;
  in fixp step initial end;
haftmann@30374
   374
haftmann@30374
   375
haftmann@30374
   376
(*****************************************************************************************)
haftmann@30374
   377
(**** end of old source code *************************************************************)
haftmann@30374
   378
(*****************************************************************************************)
haftmann@30374
   379
(**** term construction ****)
haftmann@30374
   380
haftmann@30374
   381
(* for simple modes (e.g. parameters) only: better call it param_funT *)
haftmann@30374
   382
(* or even better: remove it and only use funT'_of - some modifications to funT'_of necessary *) 
haftmann@30374
   383
fun funT_of T opt_mode =
  (case opt_mode of
    NONE => T
  | SOME is =>
      let val (inTs, outTs) = get_args is (binder_types T)
      in inTs ---> mk_pred_enumT (mk_tupleT outTs) end);

(* Type of the function compiled from a predicate of type T in mode
   (iss, is): moded parameters get their compiled types, then the input
   argument types, and the result is a pred of the tuple of output types. *)
fun funT'_of (iss, is) T =
  let
    val (paramTs, argTs) = chop (length iss) (binder_types T);
    fun param_funT (SOME js) = funT'_of ([], js)
      | param_funT NONE = I;
    val paramTs' = map2 param_funT iss paramTs;
    val (inTs, outTs) = get_args is argTs;
  in (paramTs' @ inTs) ---> mk_pred_enumT (mk_tupleT outTs) end;
haftmann@30374
   397
haftmann@30374
   398
haftmann@30374
   399
(* The first occurrence of variable s is kept as Free (s, T) and recorded in
   vs; each later occurrence is renamed to a fresh variant of s (added to
   names) and recorded in the list associated with s in vs. *)
fun mk_v (names, vs) s T =
  (case AList.lookup (op =) vs s of
    NONE => ((names, (s, []) :: vs), Free (s, T))
  | SOME prev =>
      let
        val fresh = Name.variant names s;
        val v = Free (fresh, T);
        val vs' = AList.update (op =) (s, v :: prev) vs;
      in ((fresh :: names, vs'), v) end);

(* Rename repeated free variables of a term (left to right) with mk_v, so
   that every free variable occurs only once in the result. *)
fun distinct_v (nvs, t) =
  (case t of
    Free (s, T) => mk_v nvs s T
  | f $ u =>
      let
        val (nvs1, f') = distinct_v (nvs, f);
        val (nvs2, u') = distinct_v (nvs1, u);
      in (nvs2, f' $ u') end
  | _ => (nvs, t));
haftmann@30374
   416
haftmann@30374
   417
(* Build %x. case x of (out_ts) => (if eqs then success_t else bot) | y => bot,
   where the condition is the conjunction of the equations from eqs (via
   mk_eq) and eqs'; if there are none, success_t is used directly.
   success_t must be of a pred type. *)
fun compile_match thy eqs eqs' out_ts success_t =
  let
    val all_eqs = maps mk_eq eqs @ eqs';
    val used = fold Term.add_free_names (success_t :: all_eqs @ out_ts) [];
    val x_name = Name.variant used "x";
    val y_name = Name.variant (x_name :: used) "y";
    val tupleT = mk_tupleT (map fastype_of out_ts);
    val predT = fastype_of success_t;
    val elemT = dest_pred_enumT predT;
    val x = Free (x_name, tupleT);
    val y = Free (y_name, tupleT);
    val no_result = mk_empty elemT;
    val match_rhs =
      if null all_eqs then success_t
      else
        Const (@{const_name HOL.If}, HOLogic.boolT --> predT --> predT --> predT)
          $ foldr1 HOLogic.mk_conj all_eqs $ success_t $ no_result;
    val (case_t, _) = DatatypePackage.make_case (ProofContext.init thy) false [] x
      [(mk_tuple out_ts, match_rhs), (y, no_result)];
  in lambda x case_t end;
haftmann@30374
   438
haftmann@30972
   439
(* Name of the function compiled from predicate name in mode mode, looked up
   in the names table of IndCodegenData; raises ERROR if there is none.
   The mode is rendered with string_of_mode (as in the other diagnostics of
   this file) rather than with the debugging primitive makestring. *)
fun modename_of thy name mode =
  (case PredModetab.lookup (#names (IndCodegenData.get thy)) (name, mode) of
    SOME modename => modename
  | NONE => error ("fun modename_of - definition not found: name: " ^ name ^
      " mode: " ^ string_of_mode mode));
haftmann@30374
   444
haftmann@30972
   445
(* Modes recorded for a predicate; [] if none.  The mode table of thy is
   fetched once per partial application modes_of thy. *)
fun modes_of thy =
  let val mode_table = #modes (IndCodegenData.get thy)
  in fn name => these (Symtab.lookup mode_table name) end;
haftmann@30972
   447
haftmann@30972
   448
(*FIXME function can be removed*)
haftmann@30374
   449
(* Eta-expand t over fresh variables x1 ... xn for all its argument types and
   apply f under the binders: %x1 ... xn. f (t x1 ... xn). *)
fun mk_funcomp f t =
  let
    val argTs = binder_types (fastype_of t);
    val used = Term.add_free_names t [];
    val xs = Name.variant_list used (replicate (length argTs) "x");
    val args = map Free (xs ~~ argTs);
    val body = f (list_comb (t, args));
  in fold_rev lambda args body end;
haftmann@30374
   458
haftmann@30374
   459
(* Compile a parameter argument t according to its hmode.  NONE leaves t
   unchanged.  SOME (Mode ((iss, is'), _, ms)) replaces the head of t by its
   compiled counterpart for mode (iss, is') -- a registered constant or a free
   variable -- and recursively compiles the first (length ms) arguments with
   the parameter hmodes ms; remaining arguments are kept.
   Raises ERROR for any other kind of head (previously a Match exception).
   The former catch-all clause was unreachable (NONE/SOME Mode is exhaustive)
   and has been dropped. *)
fun compile_param thy modes (NONE, t) = t
  | compile_param thy modes (SOME (Mode ((iss, is'), _, ms)), t) =
      let
        val (f, args) = strip_comb t
        val (params, args') = chop (length ms) args
        val params' = map (compile_param thy modes) (ms ~~ params)
        val f' =
          (case f of
            Const (name, T) =>
              if AList.defined op = modes name then
                Const (modename_of thy name (iss, is'), funT'_of (iss, is') T)
              else error "compile param: Not an inductive predicate with correct mode"
          | Free (name, T) => Free (name, funT_of T (SOME is'))
          | _ => error ("compile param: unexpected parameter term " ^
              Syntax.string_of_term_global thy t))
      in list_comb (f', params' @ args') end;
haftmann@30374
   472
haftmann@30374
   473
(* Compile the predicate expression t (without its remaining input arguments)
   called in hmode Mode (mode, is, ms):
   - a constant with registered modes becomes the compiled function for mode,
     applied to its compiled parameters;
   - a free variable (higher-order call) is retyped to its compiled type.
   Any other head (Var, Abs, Bound) now raises the intended ERROR instead of
   an unhandled Match exception. *)
fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) =
      (case strip_comb t of
        (Const (name, T), params) =>
          if AList.defined op = modes name then
            let
              (* argument types after the parameters, split into inputs/outputs *)
              val (Ts, Us) = get_args is (Library.drop (length ms, fst (strip_type T)))
              val params' = map (compile_param thy modes) (ms ~~ params)
              val mode_id = modename_of thy name mode
              val funT = (map fastype_of params' @ Ts) ---> mk_pred_enumT (mk_tupleT Us)
            in list_comb (Const (mode_id, funT), params') end
          else error "not a valid inductive expression"
      | (Free (name, T), args) =>
          (* higher-order mode call of a predicate variable *)
          list_comb (Free (name, funT_of T (SOME is)), args)
      | _ => error "not a valid inductive expression")
  | compile_expr _ _ _ = error "not a valid inductive expression"
haftmann@30374
   492
haftmann@30374
   493
haftmann@30374
   494
(* Compile one introduction clause (ts, ps) into a term for the given mode.
   - (iss, is): modes of the higher-order parameters and of the arguments
   - ts: argument terms of the clause head, ps: its premises
   - inp: tuple that is bound (mk_bind/mk_single) in front of the compiled
     premises
   Input/output terms that are not constructor terms (is_constrt) are replaced
   by fresh variables plus equations, which are handed to compile_match. *)
fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp =
  let
    (* parameters with a mode can be called like predicates in that mode *)
    val param_modes =
      List.mapPartial (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
        (param_vs ~~ iss);
    val modes' = modes @ param_modes;

    (* replace a non-constructor term t by a fresh Free v, recording v = t *)
    fun check_constrt ((names, eqs), t) =
      if is_constrt thy t then ((names, eqs), t)
      else
        let
          val fresh = Name.variant names "x";
          val v = Free (fresh, fastype_of t);
        in ((fresh :: names, HOLogic.mk_eq (v, t) :: eqs), v) end;
    fun abstract_terms names terms =
      (*FIXME*) Library.foldl_map check_constrt ((names, []), terms);
    fun distinct_terms names vs terms =
      (*FIXME*) Library.foldl_map distinct_v ((names, map (rpair []) vs), terms);

    val (in_ts, head_out_ts) = get_args is ts;
    val ((all_vs', head_eqs), in_ts') = abstract_terms all_vs in_ts;

    fun compile_prems outs vs names [] =
          (* no premises left: match outs and yield the head's output tuple,
             checking the head equations as well *)
          let
            val ((names', eqs'), outs') = abstract_terms names outs;
            val (nvs, outs'') = distinct_terms names' vs outs';
          in
            compile_match thy (snd nvs) (head_eqs @ eqs') outs''
              (mk_single (mk_tuple head_out_ts))
          end
      | compile_prems outs vs names prems =
          let
            val vs' = distinct (op =) (flat (vs :: map term_vs outs));
            val SOME (p, mode as SOME (Mode (_, js, _))) =
              select_mode_prem thy modes' vs' prems;
            val prems' = filter_out (equal p) prems;
            val ((names', eqs), outs') = abstract_terms names outs;
            val (nvs, outs'') = distinct_terms names' vs outs';
            (* compile a call of t on us (wrapped by wrap), then the remaining premises *)
            fun compile_call wrap (us, t) =
              let
                val (call_ins, call_outs) = get_args js us;
                val call = list_comb (compile_expr thy modes (mode, t), call_ins);
                val rest = compile_prems call_outs vs' (fst nvs) prems';
              in (wrap call, rest) end;
            val (compiled_prem, rest) =
              (case p of
                Prem call => compile_call I call
              | Negprem call => compile_call mk_not_pred call
              | Sidecond t =>
                  let val rest = compile_prems [] vs' (fst nvs) prems'
                  in (mk_if_predenum t, rest) end);
          in
            compile_match thy (snd nvs) eqs outs'' (mk_bind (compiled_prem, rest))
          end;

    val prem_t = compile_prems in_ts' param_vs all_vs' ps;
  in
    mk_bind (mk_single inp, prem_t)
  end;
haftmann@30374
   561
haftmann@30374
   562
(* Defining equation of the function for predicate s (of type T) in mode:
     mode_id params ins = sup of the compiled clauses cls
   where mode_id = modename_of thy s mode and ins are fresh input variables. *)
fun compile_pred thy all_vs param_vs modes s T cls mode =
  let
    val (param_modes, arg_mode) = mode;
    val (paramTs, argTs) = chop (length param_vs) (binder_types T);
    val paramTs' = map2 funT_of paramTs param_modes;
    val (inTs, outTs) = get_args arg_mode argTs;
    val in_names = Name.variant_list param_vs
      (map (fn i => "x" ^ string_of_int i) arg_mode);
    val in_frees = map Free (in_names ~~ inTs);
    val clause_ts =
      map (fn cl => compile_clause thy all_vs param_vs modes mode cl (mk_tuple in_frees)) cls;
    val mode_id = modename_of thy s mode;
    val funT = (paramTs' @ inTs) ---> mk_pred_enumT (mk_tupleT outTs);
    val lhs = list_comb (Const (mode_id, funT), map Free (param_vs ~~ paramTs') @ in_frees);
  in
    HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, foldr1 mk_sup clause_ts))
  end;
haftmann@30374
   582
haftmann@30374
   583
(* For each (s, (T, cls)) in preds, compile one equation per mode that modes
   lists for s (raises Option if s has no entry in modes). *)
fun compile_preds thy all_vs param_vs modes preds =
  let
    fun compile_all_modes (s, (T, cls)) =
      let val pred_modes = the (AList.lookup (op =) modes s)
      in map (compile_pred thy all_vs param_vs modes s T cls) pred_modes end;
  in map compile_all_modes preds end;
haftmann@30374
   587
haftmann@30374
   588
(* end of term construction ******************************************************)
haftmann@30374
   589
haftmann@30374
   590
(* HOL_basic_ss with its solver replaced (setSolver) by a solver that simply
   runs all_tac. *)
val HOL_basic_ss' =
  HOL_basic_ss setSolver (mk_solver "all_tac_solver" (K (K all_tac)));
haftmann@30374
   593
haftmann@30374
   594
haftmann@30374
   595
(* misc: constructing and proving tupleE rules ***********************************)
haftmann@30374
   596
haftmann@30374
   597
haftmann@30374
   598
(* Creating definitions of functional programs 
haftmann@30374
   599
   and proving intro and elim rules **********************************************) 
haftmann@30374
   600
haftmann@30374
   601
(* Is c an inductive predicate?  True if the inductive package knows c, or if
   introduction rules for c are registered in IndCodegenData.
   Uses Symtab.defined directly instead of scanning the key list. *)
fun is_ind_pred thy c =
  can (InductivePackage.the_inductive (ProofContext.init thy)) c orelse
  Symtab.defined (#intro_rules (IndCodegenData.get thy)) c
haftmann@30374
   604
haftmann@30374
   605
(* Names of the constants occurring in intrs that are inductive predicates
   (is_ind_pred) and are not among preds, in Term.add_consts order. *)
fun get_name_of_ind_calls_of_clauses thy preds intrs =
  let
    val const_names = map fst (fold Term.add_consts intrs [])
    fun is_foreign_ind_pred c = not (member (op =) preds c) andalso is_ind_pred thy c
  in filter is_foreign_ind_pred const_names end
haftmann@30374
   608
haftmann@30972
   609
(* Trace the arities, one line per entry "s: a1 -> ... -> k"; a NONE
   position is shown as "X". *)
fun print_arities arities =
  let
    fun string_of_arity NONE = "X"
      | string_of_arity (SOME k) = string_of_int k
    fun string_of_entry (s, (ks, k)) =
      s ^ ": " ^ space_implode " -> " (map string_of_arity (ks @ [SOME k]))
  in tracing ("Arities:\n" ^ cat_lines (map string_of_entry arities)) end;
haftmann@30374
   614
haftmann@30374
   615
(* Wrap a parameter x : T according to its mode.  NONE leaves x unchanged.
   SOME mode yields the abstraction over fresh x1..xn (avoiding names) of
   mk_Eval (x applied to the input args, tuple of the output args); the new
   names are prepended to names. *)
fun mk_Eval_of ((x, _), NONE) names = (x, names)
  | mk_Eval_of ((x, T), SOME mode) names =
      let
        val argTs = binder_types T
        val argnames = Name.variant_list names
          (map (fn i => "x" ^ string_of_int i) (1 upto length argTs))
        val args = map Free (argnames ~~ argTs)
        val (inargs, outargs) = get_args mode args
        val body = mk_Eval (list_comb (x, inargs), mk_tuple outargs)
      in
        (fold_rev lambda args body, argnames @ names)
      end;
haftmann@30374
   627
haftmann@30374
   628
(* Prove and register rules linking predicate pred with its mode function
   mode_id : funT (defined by defthm):
     intro: pred params' args ==> mk_Eval (mode_id params ins, outs)
     elim:  mk_Eval (mode_id params ins, outs) ==> (pred params' args ==> P) ==> P
   params' are the parameters wrapped by mk_Eval_of.  In the elim rule, outs is
   the unit variable y when there are no outputs.  Both rules are proved by
   asm_full_simp_tac with defthm, eval_pred and pair lemmas.  They are entered
   into the function_intros/function_elims tables and stored as <base>I and
   <base>E. *)
fun create_intro_rule nparams mode defthm mode_id funT pred thy =
  let
    val ctxt = ProofContext.init thy
    val (param_modes, arg_mode) = mode
    val predTs = binder_types (fastype_of pred)
    val argnames = Name.variant_list []
      (map (fn i => "x" ^ string_of_int i) (1 upto length predTs))
    val (paramTs, argTs) = chop nparams predTs
    val paramTs' = map2 funT_of paramTs param_modes
    val (params, io_args) = chop nparams (map Free (argnames ~~ (paramTs' @ argTs)))
    val (inargs, outargs) = get_args arg_mode io_args
    val (eval_params, _) = fold_map mk_Eval_of ((params ~~ paramTs) ~~ param_modes) []
    val predprop = HOLogic.mk_Trueprop (list_comb (pred, eval_params @ io_args))
    val call = list_comb (Const (mode_id, funT), params @ inargs)
    fun mk_funprop out = HOLogic.mk_Trueprop (mk_Eval (call, out))
    val funpropE =
      mk_funprop (if null outargs then Free ("y", HOLogic.unitT) else mk_tuple outargs)
    val funpropI = mk_funprop (mk_tuple outargs)
    val unfold_tac =
      Simplifier.asm_full_simp_tac (HOL_basic_ss addsimps
        [defthm, @{thm eval_pred}, @{thm "split_beta"}, @{thm "fst_conv"}, @{thm "snd_conv"}]) 1
    val introthm =
      Goal.prove ctxt (argnames @ ["y"]) [] (Logic.mk_implies (predprop, funpropI)) (K unfold_tac)
    val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
    val elimthm =
      Goal.prove ctxt (argnames @ ["y", "P"]) []
        (Logic.list_implies ([funpropE, Logic.mk_implies (predprop, P)], P)) (K unfold_tac)
    val base = Long_Name.base_name mode_id
  in
    thy
    |> map_function_intros (Symtab.update_new (mode_id, introthm))
    |> map_function_elims (Symtab.update_new (mode_id, elimthm))
    |> PureThy.store_thm (Binding.name (base ^ "I"), introthm) |> snd
    |> PureThy.store_thm (Binding.name (base ^ "E"), elimthm) |> snd
  end;
haftmann@30374
   660
haftmann@30374
   661
(* For predicate name (type looked up in preds), define one constant per mode:
   named name ^ "_" ^ <arg mode>, or name ^ "__<m1>__<m2>..." ^ "___" ^
   <arg mode> when there are higher-order parameter modes.  Each constant is
   defined as mk_Enum of the split abstraction over the outputs of
   "name params' args".  It is recorded via map_names/map_function_defs, and
   its intro/elim rules are derived by create_intro_rule. *)
fun create_definitions preds nparams (name, modes) thy =
  let
    val _ = tracing "create definitions"
    val T = the (AList.lookup (op =) preds name)
    val predTs = binder_types T
    fun string_of_mode [] = "0"
      | string_of_mode ks = space_implode "_" (map string_of_int ks)
    fun create_definition mode thy =
      let
        val (param_modes, arg_mode) = mode
        val HOmode =
          String.concat (map (fn NONE => "" | SOME m => "__" ^ string_of_mode m) param_modes)
        val raw_id =
          name ^ (if HOmode = "" then "_" else HOmode ^ "___") ^ string_of_mode arg_mode
        val (paramTs, argTs) = chop nparams predTs
        val paramTs' = map2 funT_of paramTs param_modes
        val (inTs, outTs) = get_args arg_mode argTs
        val names = Name.variant_list []
          (map (fn i => "x" ^ string_of_int i) (1 upto length predTs))
        val (xparams, xargs) = chop nparams (map Free (names ~~ (paramTs' @ argTs)))
        val (xparams', names') =
          fold_map mk_Eval_of ((xparams ~~ paramTs) ~~ param_modes) names
        val (xins, xouts) = get_args arg_mode xargs
        (* abstract the outputs: none -> fresh unit variable, one -> plain
           lambda, several -> nested split abstractions *)
        fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
          | mk_split_lambda [x] t = lambda x t
          | mk_split_lambda xs t =
              let
                fun nest [x, y] = HOLogic.mk_split (lambda x (lambda y t))
                  | nest (x :: rest) = HOLogic.mk_split (lambda x (nest rest))
              in nest xs end
        val predterm =
          mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs)))
        val funT = (paramTs' @ inTs) ---> mk_pred_enumT (mk_tupleT outTs)
        val mode_id = Sign.full_bname thy (Long_Name.base_name raw_id)
        val base = Long_Name.base_name mode_id
        val lhs = list_comb (Const (mode_id, funT), xparams @ xins)
        val def = Logic.mk_equals (lhs, predterm)
        val ([defthm], thy') = thy
          |> Sign.add_consts_i [(Binding.name base, funT, NoSyn)]
          |> PureThy.add_defs false [((Binding.name (base ^ "_def"), def), [])]
      in
        thy'
        |> map_names (PredModetab.update_new ((name, mode), mode_id))
        |> map_function_defs (Symtab.update_new (mode_id, defthm))
        |> create_intro_rule nparams mode defthm mode_id funT (Const (name, T))
      end
  in
    fold create_definition modes thy
  end;
haftmann@30374
   704
haftmann@30374
   705
(**************************************************************************************)
haftmann@30374
   706
(* Proving equivalence of term *)
haftmann@30374
   707
haftmann@30374
   708
haftmann@30972
   709
(* Introduction rule stored in function_intros for pred in mode
   (raises Option if there is none). *)
fun intro_rule thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_intros (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   711
haftmann@30972
   712
(* Elimination rule stored in function_elims for pred in mode
   (raises Option if there is none). *)
fun elim_rule thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   714
haftmann@30374
   715
(* Introduction rules of predname.  If rules are registered in IndCodegenData,
   they are returned reversed.  Otherwise the intrs of its inductive definition
   are returned, keeping only those whose conclusion has predname as head
   constant. *)
fun pred_intros thy predname =
  let
    val intro_rules = #intro_rules (IndCodegenData.get thy)
    fun is_intro_of intro =
      (concl_of intro |> HOLogic.dest_Trueprop |> strip_comb |> fst |> dest_Const |> fst)
        = predname
  in
    if Symtab.defined intro_rules predname
    then rev (Symtab.lookup_list intro_rules predname)
    else
      InductivePackage.the_inductive (ProofContext.init thy) predname
      |> snd |> #intrs |> filter is_intro_of
  end
haftmann@30374
   727
haftmann@30374
   728
(* Definition theorem stored in function_defs for pred in mode
   (raises Option if there is none). *)
fun function_definition thy pred mode =
  let val mode_id = modename_of thy pred mode
  in the (Symtab.lookup (#function_defs (IndCodegenData.get thy)) mode_id) end
haftmann@30374
   730
haftmann@30374
   731
(* true iff the type is built with the Type constructor, false otherwise *)
fun is_Type T = (case T of Type _ => true | _ => false)
haftmann@30374
   733
haftmann@30374
   734
(* Apply cv to every premise of a chain A1 ==> ... ==> An ==> C, leaving
   the final conclusion C untouched. *)
fun imp_prems_conv cv =
  let
    fun conv ct =
      (case Thm.term_of ct of
        Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) conv ct
      | _ => Conv.all_conv ct)
  in conv end
haftmann@30374
   738
haftmann@30374
   739
(* Apply cv to the argument of a Trueprop judgement; any other term fails
   with error "Trueprop_conv". *)
fun Trueprop_conv cv ct =
  let
    val is_judgement =
      (case Thm.term_of ct of Const ("Trueprop", _) $ _ => true | _ => false)
  in
    if is_judgement then Conv.arg_conv cv ct else error "Trueprop_conv"
  end
haftmann@30374
   743
bulwahn@31105
   744
(* Transfer rule to thy and, where possible, rewrite each premise with the
   symmetric form of Predicate.eq_is_eq (try_conv keeps premises that do not
   match). *)
fun preprocess_intro thy rule =
  let
    val eq_to_pred_eq =
      Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq}))
  in
    Thm.transfer thy rule
    |> Conv.fconv_rule (imp_prems_conv (Trueprop_conv eq_to_pred_eq))
  end
haftmann@30374
   749
bulwahn@31105
   750
(* Restate elimrule so that, in every case premise, the hypotheses after the
   first nargs turn HOL equations (op =) into Predicate.eq.  The new statement
   is obtained from elimrule by Thm.equal_elim, using the eq_is_eq rewrite of
   its conclusion part (Conv.implies_concl_conv). *)
fun preprocess_elim thy nargs elimrule =
  let
    fun to_pred_eq (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
          HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
      | to_pred_eq t = t
    fun preprocess_case t =
      let
        val params = Logic.strip_params t
        val (kept, converted) = chop nargs (Logic.strip_assums_hyp t)
        val hyps = kept @ map to_pred_eq converted
      in list_all (params, Logic.list_implies (hyps, Logic.strip_assums_concl t)) end
    val prems = Thm.prems_of elimrule
    val cases' = map preprocess_case (tl prems)
    val elimrule' = Logic.list_implies (hd prems :: cases', Thm.concl_of elimrule)
    val rewrite_concl =
      Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm eq_is_eq}])
  in
    Thm.equal_elim (Thm.symmetric (rewrite_concl (cterm_of thy elimrule'))) elimrule
  end;
haftmann@30374
   768
haftmann@30374
   769
haftmann@30374
   770
(* Returns true if t is an application of a datatype constructor (such a term
   is then split by case analysis), false otherwise.
   The type of t is computed once, and constructor membership uses the
   standard `member` instead of the legacy `mem_string`. *)
fun is_constructor thy t =
  (case fastype_of t of
    Type (tyco, _) =>
      (case DatatypePackage.get_datatype thy tyco of
        NONE => false
      | SOME info =>
          let
            val constr_consts = maps (fn (_, (_, _, constrs)) => map fst constrs) (#descr info)
          in
            (case strip_comb t of
              (Const (name, _), _) => member (op =) constr_consts name
            | _ => false)
          end)
  | _ => false)
haftmann@30374
   784
haftmann@30374
   785
(* MAJOR FIXME:  prove_params should be simple
 - different form of introrule for parameters ? *)
(* Tactic for a parameter argument t used in the given mode.  NONE: all_tac.
   SOME mode: unfold the mode definition of t's head constant (a Free head
   needs no unfolding), recurse into t's own parameter arguments, then close
   remaining goals by assumption.  Any other head (Abs, Var, Bound) raises an
   explicit error instead of an opaque Match exception. *)
fun prove_param thy modes (NONE, t) = all_tac
  | prove_param thy modes (SOME (Mode (mode, _, ms)), t) =
      let
        val (f, args) = strip_comb t
        val (params, _) = chop (length ms) args
        val f_tac =
          (case f of
            Const (name, _) =>
              simp_tac (HOL_basic_ss addsimps
                [@{thm eval_pred}, function_definition thy name mode]) 1
          | Free _ => all_tac
          | _ => error ("prove_param: unexpected head of parameter term " ^ makestring f))
      in
        print_tac "before simplification in prove_args:"
        THEN debug_tac ("mode" ^ makestring mode)
        THEN f_tac
        THEN print_tac "after simplification in prove_args"
        (* work with parameter arguments *)
        THEN EVERY (map (prove_param thy modes) (ms ~~ params))
        THEN REPEAT_DETERM (atac 1)
      end
haftmann@30374
   804
haftmann@30374
   805
(* Tactic for a positive premise calling t in the given mode.
   - Head constant with registered modes: after bindI, rotate the assumption
     at index premposition to the front, apply the mode's intro_rule, and
     discharge t's parameter arguments with prove_param.
   - Head constant without modes: error.
   - Any other head: bindI followed by assumption. *)
fun prove_expr thy modes (SOME (Mode (mode, _, ms)), t, _) (premposition : int) =
      (case strip_comb t of
        (Const (name, _), args) =>
          if not (AList.defined (op =) modes name) then
            error "Prove expr if case not implemented"
          else
            let
              val introrule = intro_rule thy name mode
              (* NOTE(review): this split was marked "old and probably wrong" *)
              val param_args = fst (chop (length ms) args)
              val _ = tracing ("premposition: " ^ makestring premposition)
            in
              rtac @{thm bindI} 1
              THEN print_tac "before intro rule:"
              THEN debug_tac ("mode" ^ makestring mode)
              THEN debug_tac (makestring introrule)
              THEN debug_tac ("premposition: " ^ makestring premposition)
              (* for the right assumption in first position *)
              THEN rotate_tac premposition 1
              THEN rtac introrule 1
              THEN print_tac "after intro rule"
              (* work with parameter arguments *)
              THEN EVERY (map (prove_param thy modes) (ms ~~ param_args))
              THEN REPEAT_DETERM (atac 1)
            end
      | _ => rtac @{thm bindI} 1 THEN atac 1)
  | prove_expr _ _ _ _ = error "Prove expr not implemented"
haftmann@30374
   839
haftmann@30374
   840
(* Run tac on st, keeping only results with exactly one subgoal fewer. *)
fun SOLVED tac st =
  let val goals_before = nprems_of st
  in FILTER (fn st' => nprems_of st' = goals_before - 1) tac st end;
haftmann@30374
   841
haftmann@30374
   842
(* Run tac, keeping only results without remaining subgoals. *)
fun SOLVEDALL tac = FILTER (fn st' => nprems_of st' = 0) tac
haftmann@30374
   843
haftmann@30374
   844
(* Tactic for the pattern match on out_ts.  First simplify with HOL_basic_ss',
   using unit/prod case rules plus the case_rewrites of every datatype
   constructor occurring (recursively) in out_ts.  Then optionally (TRY)
   discharge an if-then-else via if_P, proving its conjunctive condition by
   simplification. *)
fun prove_match thy (out_ts : term list) =
  let
    fun get_case_rewrite t =
      if not (is_constructor thy t) then []
      else
        let
          val tyco = (fst o dest_Type o fastype_of) t
          val case_rewrites = #case_rewrites (DatatypePackage.the_datatype thy tyco)
        in case_rewrites @ maps get_case_rewrite (snd (strip_comb t)) end
    val simprules = @{thm "unit.cases"} :: @{thm "prod.cases"} :: maps get_case_rewrite out_ts
    (* FIXME: replace TRY by determining whether it is necessary - are there
       equations when calling compile_match? *)
    val solve_by_simp = SOLVED (asm_simp_tac HOL_basic_ss 1)
    val if_tac =
      EqSubst.eqsubst_tac (ProofContext.init thy) [0] [@{thm "HOL.if_P"}] 1
      THEN REPEAT_DETERM (rtac @{thm conjI} 1 THEN solve_by_simp)
      THEN solve_by_simp
  in
    print_tac ("before prove_match rewriting: simprules = " ^ makestring simprules)
    (* FIXME: make this simpset better *)
    THEN asm_simp_tac (HOL_basic_ss' addsimps simprules) 1
    THEN print_tac "after prove_match:"
    THEN DETERM (TRY if_tac)
    THEN print_tac "after if simplification"
  end;
haftmann@30374
   863
haftmann@30374
   864
(* Corresponds to compile_fun -- maybe call that compile_sidecond as well?
   Tactic for a side condition t.  Collect the constants with modes occurring in
   t; a constant's arguments are not searched once it has a mode, and only
   application spines with a constant head are descended.  Then simplify with
   their all-input mode definitions plus eval_pred and not_False_eq_True. *)
fun prove_sidecond thy modes t =
  let
    val _ = tracing ("prove_sidecond:" ^ makestring t)
    fun preds_of t nameTs =
      (case strip_comb t of
        (Const (name, T), args) =>
          if AList.defined (op =) modes name then (name, T) :: nameTs
          else fold preds_of args nameTs
      | _ => nameTs)
    val preds = preds_of t []
    val _ = tracing ("preds: " ^ makestring preds)
    fun all_input_def (pred, T) =
      function_definition thy pred ([], 1 upto length (binder_types T))
    val defs = map all_input_def preds
    val _ = tracing ("defs: " ^ makestring defs)
  in
    (* remove not_False_eq_True when simpset in prove_match is better *)
    simp_tac (HOL_basic_ss addsimps @{thm not_False_eq_True} :: @{thm eval_pred} :: defs) 1
    (* need better control here! *)
    THEN print_tac "after sidecond simplification"
  end
haftmann@30374
   886
haftmann@30374
   887
(* Tactic proving the compiled version of clause (ts, ps) in mode (iss, is).
   Premises are handled in the order chosen by select_mode_prem:
   - calls go through prove_expr;
   - negated calls of constants unfold the mode definition and use not_predI,
     other negated calls use not_predI';
   - side conditions go through prove_sidecond.
   The final step is prove_match followed by singleI (or singleI_unit when the
   clause has no outputs).  nargs is added to a premise's index in ps to give
   the assumption position passed to prove_expr. *)
fun prove_clause thy nargs all_vs param_vs modes (iss, is) (ts, ps) =
  let
    (* parameters with a mode can be called like predicates in that mode *)
    val modes' = modes @ List.mapPartial
      (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) (param_vs ~~ iss);
    fun check_constrt ((names, eqs), t) =
      if is_constrt thy t then ((names, eqs), t)
      else
        let
          val fresh = Name.variant names "x"
          val v = Free (fresh, fastype_of t)
        in ((fresh :: names, HOLogic.mk_eq (v, t) :: eqs), v) end;
    val (in_ts, clause_out_ts) = get_args is ts;
    val (_, in_ts') = (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts);
    (* all premises consumed *)
    fun finish_tac out_ts =
      prove_match thy out_ts
      THEN asm_simp_tac HOL_basic_ss' 1
      THEN print_tac "before the last rule of singleI:"
      THEN rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1
    fun prove_prems out_ts _ [] = finish_tac out_ts
      | prove_prems out_ts vs rps =
          let
            val vs' = distinct (op =) (flat (vs :: map term_vs out_ts));
            val SOME (p, mode as SOME (Mode (call_mode as (_, js), _, param_modes))) =
              select_mode_prem thy modes' vs' rps;
            val premposition = find_index (equal p) ps + nargs;
            val rps' = filter_out (equal p) rps;
            fun continue_with outs = prove_prems outs vs' rps';
            val rest_tac =
              (case p of
                Prem (us, t) =>
                  let
                    val rec_tac = continue_with (snd (get_args js us))
                  in
                    print_tac "before clause:"
                    THEN asm_simp_tac HOL_basic_ss 1
                    THEN print_tac "before prove_expr:"
                    THEN prove_expr thy modes (mode, t, us) premposition
                    THEN print_tac "after prove_expr:"
                    THEN rec_tac
                  end
              | Negprem (us, t) =>
                  let
                    val rec_tac = continue_with (snd (get_args js us))
                    val (head, params) = strip_comb t
                  in
                    print_tac "before negated clause:"
                    THEN rtac @{thm bindI} 1
                    THEN (case head of
                            Const (c, _) =>
                              simp_tac (HOL_basic_ss addsimps
                                [function_definition thy c call_mode]) 1
                              THEN rtac @{thm not_predI} 1
                              THEN print_tac "after neg. intro rule"
                              THEN print_tac ("t = " ^ makestring t)
                              (* FIXME: work with parameter arguments *)
                              THEN EVERY (map (prove_param thy modes) (param_modes ~~ params))
                          | _ => rtac @{thm not_predI'} 1)
                    THEN REPEAT_DETERM (atac 1)
                    THEN rec_tac
                  end
              | Sidecond t =>
                  rtac @{thm bindI} 1
                  THEN rtac @{thm if_predI} 1
                  THEN print_tac "before sidecond:"
                  THEN prove_sidecond thy modes t
                  THEN print_tac "after sidecond:"
                  THEN continue_with [])
          in
            prove_match thy out_ts THEN rest_tac
          end;
    val prems_tac = prove_prems in_ts' param_vs ps
  in
    rtac @{thm bindI} 1
    THEN rtac @{thm singleI} 1
    THEN prems_tac
  end;
haftmann@30374
   962
haftmann@30374
   963
(* Rule steps selecting disjunct i of a nested sup with n disjuncts:
   supI2 (i - 1) times, then supI1 unless i = n (for 1 <= i <= n). *)
fun select_sup n i =
  if i = 1 then (if n = 1 then [] else [rtac @{thm supI1}])
  else rtac @{thm supI2} :: select_sup (n - 1) (i - 1);
haftmann@30374
   966
haftmann@30374
   967
(* Number of parameters of predicate s.  Taken from the nparams table if
   present, otherwise from the raw induction rule of its inductive definition,
   otherwise 0.  Uses a single table lookup instead of defined + lookup. *)
fun get_nparams thy s =
  let
    val _ = tracing ("get_nparams: " ^ s)
  in
    (case Symtab.lookup (#nparams (IndCodegenData.get thy)) s of
      SOME nparams => nparams
    | NONE =>
        (case try (InductivePackage.the_inductive (ProofContext.init thy)) s of
          SOME info => info |> snd |> #raw_induct |> Thm.unvarify
            |> InductivePackage.params_of |> length
        | NONE => 0 (* default value *)))
  end
haftmann@30374
   978
haftmann@30374
   979
(* Alias for InductiveSetPackage.codegen_preproc. *)
fun ind_set_codegen_preproc thy = InductiveSetPackage.codegen_preproc thy;
haftmann@30374
   980
haftmann@30374
   981
(* Elimination rule of predname: the one registered in elim_rules, otherwise
   the elim of its inductive definition at predname's position in #names.
   Uses a single table lookup instead of defined + lookup. *)
fun pred_elim thy predname =
  (case Symtab.lookup (#elim_rules (IndCodegenData.get thy)) predname of
    SOME elim => elim
  | NONE =>
      let
        val (info, result) = InductivePackage.the_inductive (ProofContext.init thy) predname
        val index = find_index (fn s => s = predname) (#names info)
      in nth (#elims result) index end)
haftmann@30374
   989
haftmann@30374
   990
(* Tactic for one direction of the equivalence for pred (type T) in mode.
   Eliminate with the mode function's elim rule, then with pred's preprocessed
   case rule.  Select the sup disjunct belonging to each clause, and prove
   every clause with prove_clause. *)
fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) =
  let
    val elim_thm = elim_rule thy pred mode
    val nargs = length (binder_types T) - get_nparams thy pred
    val pred_case_rule = singleton (ind_set_codegen_preproc thy)
      (preprocess_elim thy nargs (pred_elim thy pred))
    (* FIXME: preprocessor, e.g. Simplifier.full_simplify with Predicate.memb_code *)
    val _ = tracing ("pred_case_rule " ^ makestring pred_case_rule)
    val nclauses = length clauses
  in
    REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
    THEN etac elim_thm 1
    THEN etac pred_case_rule 1
    THEN EVERY (map (fn i => EVERY' (select_sup nclauses i) i) (1 upto nclauses))
    THEN EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses)
  end;
haftmann@30374
  1009
haftmann@30374
  1010
(*******************************************************************************************************)
haftmann@30374
  1011
(* Proof in the other direction ************************************************************************)
haftmann@30374
  1012
(*******************************************************************************************************)
haftmann@30374
  1013
haftmann@30374
  1014
(* Case-splits the assumptions along the constructor patterns of the tuple
   of output terms out_ts. For every constructor term it uses the split_asm
   rule of its datatype (prod.split_asm for pairs) and closes the impossible
   cases with botE. Finally it tries split_if_asm once. Free variables need
   no splitting. *)
fun prove_match2 thy out_ts =
  let
    fun split_term_tac (Free _) = all_tac
      | split_term_tac t =
          if not (is_constructor thy t) then all_tac
          else
            let
              val tyco = (fst o dest_Type o fastype_of) t
              val ncases = length (#case_rewrites (DatatypePackage.the_datatype thy tyco))
              (* special treatment of pairs -- because of fishing *)
              val split_rules =
                if tyco = "*" then [@{thm prod.split_asm}]
                else PureThy.get_thms thy (tyco ^ ".split_asm")
              val (_, subterms) = strip_comb t
            in
              print_tac ("splitting with t = " ^ makestring t)
              THEN Splitter.split_asm_tac split_rules 1
              THEN REPEAT_DETERM_N (ncases - 1) (etac @{thm botE} 1 ORELSE etac @{thm botE} 2)
              THEN EVERY (map split_term_tac subterms)
            end
  in
    split_term_tac (mk_tuple out_ts)
    THEN DETERM (TRY (Splitter.split_asm_tac [@{thm "split_if_asm"}] 1 THEN etac @{thm botE} 2))
  end
haftmann@30374
  1038
haftmann@30374
  1039
(* VERY LARGE SIMILARITY to function prove_param 
haftmann@30374
  1040
-- join both functions
haftmann@30374
  1041
*) 
haftmann@30374
  1042
(* Handles a higher-order argument t given with its mode. A constant head is
   unfolded with eval_pred and its function definition for that mode.
   A free head needs no unfolding. The procedure then recurses into the
   head's own parameters. Without a mode nothing is done. *)
fun prove_param2 _ _ (NONE, _) = all_tac
  | prove_param2 thy modes (SOME (Mode (mode, _, param_modes)), t) =
      let
        val (head, args) = strip_comb t
        val (params, _) = chop (length param_modes) args
        (* NOTE(review): heads other than Const/Free raise Match, as before *)
        val unfold_tac =
          (case head of
            Const (name, _) =>
              full_simp_tac (HOL_basic_ss addsimps [@{thm eval_pred}, function_definition thy name mode]) 1
          | Free _ => all_tac)
      in
        print_tac "before simplification in prove_args:"
        THEN debug_tac ("function : " ^ makestring head ^ " - mode" ^ makestring mode)
        THEN unfold_tac
        THEN print_tac "after simplification in prove_args"
        (* work with parameter arguments *)
        THEN EVERY (map (prove_param2 thy modes) (param_modes ~~ params))
      end
haftmann@30374
  1058
haftmann@30374
  1059
(* Eliminates the monadic bind produced for a premise call t in the given
   mode. For a constant with known modes it also applies that constant's
   function elimination rule and handles its parameters with prove_param2.
   Any other head only needs bindE. Missing modes raise ERROR. *)
fun prove_expr2 thy modes (SOME (Mode (mode, _, param_modes)), t) =
      (case strip_comb t of
        (Const (name, _), args) =>
          if not (AList.defined (op =) modes name) then
            error "Prove expr2 if case not implemented"
          else
            etac @{thm bindE} 1
            THEN REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"}))
            THEN etac (elim_rule thy name mode) 1
            THEN EVERY (map (prove_param2 thy modes) (param_modes ~~ args))
      | _ => etac @{thm bindE} 1)
  | prove_expr2 _ _ _ = error "Prove expr2 not implemented"
haftmann@30374
  1070
haftmann@30374
  1071
(* Simplifies the side condition t. Every outermost constant in t that has
   registered modes is unfolded. The unfolding uses that constant's function
   definition for the mode with all arguments as inputs, together with
   eval_pred. *)
fun prove_sidecond2 thy modes t =
  let
    val _ = tracing ("prove_sidecond:" ^ makestring t)
    (* collects (name, type) of constants with modes; does not descend below them *)
    fun collect u acc =
      (case strip_comb u of
        (Const (c, U), us) =>
          if AList.defined (op =) modes c then (c, U) :: acc
          else fold collect us acc
      | _ => acc)
    val preds = collect t []
    val _ = tracing ("preds: " ^ makestring preds)
    fun all_inputs_definition (c, U) =
      function_definition thy c ([], 1 upto length (binder_types U))
    val defs = map all_inputs_definition preds
  in
    (* only simplify the one assumption; need better control here! *)
    full_simp_tac (HOL_basic_ss' addsimps (@{thm eval_pred} :: defs)) 1
    THEN print_tac "after sidecond2 simplification"
  end
haftmann@30374
  1089
  
haftmann@30374
  1090
(* Handles clause number i (argument terms ts, premises ps) of pred in mode
   (iss, is). It eliminates the monadic binds (bindE, singleE', ...) of the
   compiled code premise by premise, in the order chosen by select_mode_prem.
   Finally it applies the i-th preprocessed intro rule of pred and closes the
   remaining goals by assumption or simplification. *)
fun prove_clause2 thy all_vs param_vs modes (iss, is) (ts, ps) pred i =
  let
    (* modes extended by the modes of the higher-order parameters *)
    val modes_with_params =
      modes @ List.mapPartial
        (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)]))
        (param_vs ~~ iss)
    (* replaces a non-constructor term u by a fresh variable x, recording x = u *)
    fun abstract_nonconstr ((used, eqs), u) =
      if is_constrt thy u then ((used, eqs), u)
      else
        let
          val x = Name.variant used "x"
          val xv = Free (x, fastype_of u)
        in ((x :: used, HOLogic.mk_eq (xv, u) :: eqs), xv) end
    val intro =
      nth (pred_intros thy pred) (i - 1)
      |> preprocess_intro thy
      |> (fn thm => hd (ind_set_codegen_preproc thy [thm]))
      (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *)
    val (clause_in_ts, _) = get_args is ts
    val (_, clause_in_ts') =
      (*FIXME*) Library.foldl_map abstract_nonconstr ((all_vs, []), clause_in_ts)
    (* no premises left: match the outputs and apply the intro rule *)
    fun prove_prems2 out_ts _ [] =
          print_tac "before prove_match2 - last call:"
          THEN prove_match2 thy out_ts
          THEN print_tac "after prove_match2 - last call:"
          THEN etac @{thm singleE} 1
          THEN REPEAT_DETERM (etac @{thm Pair_inject} 1)
          THEN asm_full_simp_tac HOL_basic_ss' 1
          THEN REPEAT_DETERM (etac @{thm Pair_inject} 1)
          THEN asm_full_simp_tac HOL_basic_ss' 1
          THEN SOLVED (print_tac "state before applying intro rule:"
            THEN rtac intro 1
            (* How to handle equality correctly? *)
            THEN print_tac "state before assumption matching"
            THEN REPEAT (atac 1 ORELSE
              (CHANGED (asm_full_simp_tac HOL_basic_ss' 1)
               THEN print_tac "state after simp_tac:")))
      | prove_prems2 out_ts vs prems =
          let
            val bound_vs = distinct (op =) (flat (vs :: map term_vs out_ts))
            val SOME (prem, mode as SOME (Mode ((prem_iss, prem_is), _, param_modes))) =
              select_mode_prem thy modes_with_params bound_vs prems
            val remaining = filter_out (equal prem) prems
            val rest_tac =
              (case prem of
                Prem (us, u) =>
                  let
                    val (_, prem_out_ts) = get_args prem_is us
                    val next_tac = prove_prems2 prem_out_ts bound_vs remaining
                  in
                    prove_expr2 thy modes (mode, u) THEN next_tac
                  end
              | Negprem (us, u) =>
                  let
                    val (_, prem_out_ts) = get_args prem_is us
                    val next_tac = prove_prems2 prem_out_ts bound_vs remaining
                    val (head, params) = strip_comb u
                  in
                    print_tac "before neg prem 2"
                    THEN etac @{thm bindE} 1
                    THEN (case head of
                            Const (c, _) =>
                              full_simp_tac (HOL_basic_ss addsimps
                                [function_definition thy c (prem_iss, prem_is)]) 1
                              THEN etac @{thm not_predE} 1
                              THEN EVERY (map (prove_param2 thy modes) (param_modes ~~ params))
                          | _ => etac @{thm not_predE'} 1)
                    THEN next_tac
                  end
              | Sidecond u =>
                  etac @{thm bindE} 1
                  THEN etac @{thm if_predE} 1
                  THEN prove_sidecond2 thy modes u
                  THEN prove_prems2 [] bound_vs remaining)
          in
            print_tac "before prove_match2:"
            THEN prove_match2 thy out_ts
            THEN print_tac "after prove_match2:"
            THEN rest_tac
          end
    val prems_tac = prove_prems2 clause_in_ts' param_vs ps
  in
    print_tac "starting prove_clause2"
    THEN etac @{thm bindE} 1
    THEN etac @{thm singleE'} 1
    THEN TRY (etac @{thm Pair_inject} 1)
    THEN print_tac "after singleE':"
    THEN prems_tac
  end;
haftmann@30374
  1171
 
haftmann@30374
  1172
(* Second of the two directions left by pred_iffI. It tries unit.induct,
   splits paired variables and applies the function intro rule for
   (pred, mode). It then proves every clause with prove_clause2. Each clause
   except the last is first separated off with supE. *)
fun prove_other_direction thy all_vs param_vs modes clauses (pred, mode) =
  let
    val nclauses = length clauses
    fun clause_tac (clause, i) =
      (if i < nclauses then etac @{thm supE} 1 else all_tac)
      THEN prove_clause2 thy all_vs param_vs modes mode clause pred i
  in
    DETERM (TRY (rtac @{thm unit.induct} 1))
    THEN REPEAT_DETERM (CHANGED (rewtac @{thm split_paired_all}))
    THEN rtac (intro_rule thy pred mode) 1
    THEN EVERY (map clause_tac (clauses ~~ (1 upto nclauses)))
  end;
haftmann@30374
  1182
haftmann@30374
  1183
(* Proves the equation t for pred (of type T) in mode. The free variables
   of t are generalised. The proof goes through pred_iffI and both
   directions. If !do_proofs is false, the goal is closed by mycheat_tac
   instead. *)
fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) =
  let
    val ctxt = ProofContext.init thy
    val pred_clauses = the (AList.lookup (op =) clauses pred)
    val frees = Term.fold_aterms (fn Free (x, _) => insert (op =) x | _ => I) t []
    fun real_proof _ =
      rtac @{thm pred_iffI} 1
      THEN prove_one_direction thy all_vs param_vs modes pred_clauses ((pred, T), mode)
      THEN print_tac "proved one direction"
      THEN prove_other_direction thy all_vs param_vs modes pred_clauses (pred, mode)
      THEN print_tac "proved other direction"
    fun cheat _ = mycheat_tac thy 1
  in
    Goal.prove ctxt frees [] t (if ! do_proofs then real_proof else cheat)
  end;
haftmann@30374
  1197
haftmann@30374
  1198
(* Proves one equation per ((pred, T), mode) / term pair in pmts. *)
fun prove_preds thy all_vs param_vs modes clauses pmts =
  pmts |> map (prove_pred thy all_vs param_vs modes clauses)
haftmann@30374
  1200
haftmann@30374
  1201
(* FIXME: look for other places where this functionality was implemented before and share the code *)
haftmann@30374
  1202
(* Splits the conclusion of intro: strips its outermost application
   (presumably Trueprop), then returns the head together with the first
   nparams arguments (parameters) and the remaining arguments. *)
fun strip_intro_concl intro nparams =
  let
    val _ $ concl = Logic.strip_imp_concl intro
    val (head, all_args) = strip_comb concl
  in (head, chop nparams all_args) end
haftmann@30374
  1207
haftmann@30374
  1208
(* setup for alternative introduction and elimination rules *)
haftmann@30374
  1209
haftmann@30374
  1210
(* Registers thm as an alternative intro rule of the predicate constant
   heading its conclusion (inserted with Symtab.insert_list Thm.eq_thm). *)
fun add_intro_thm thm thy =
  let
    val head = fst (strip_intro_concl (prop_of thm) 0)
    val (pred, _) = dest_Const head
  in
    thy |> map_intro_rules (Symtab.insert_list Thm.eq_thm (pred, thm))
  end
haftmann@30374
  1213
haftmann@30374
  1214
(* Registers thm as the alternative elimination rule of the predicate
   constant heading its first premise, replacing any previous one. *)
fun add_elim_thm thm thy =
  let
    val major = HOLogic.dest_Trueprop (hd (prems_of thm))
    val (pred, _) = dest_Const (fst (strip_comb major))
  in
    thy |> map_elim_rules (Symtab.update (pred, thm))
  end
haftmann@30374
  1218
haftmann@30374
  1219
haftmann@30374
  1220
(* special case: inductive predicate with no clauses *)
haftmann@30374
  1221
(* Registers rules for a predicate predname without intro clauses.
   The intro rule is "False ==> predname x1 ... xn". The elim rule is
   "predname x1 ... xn ==> (False ==> P) ==> P", proved with pred_elim. *)
fun noclause (predname, T) thy =
  let
    val ctxt = ProofContext.init thy
    val argTs = binder_types T
    val names = Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto length argTs))
    val args = map2 (curry Free) names argTs
    val concl = HOLogic.mk_Trueprop (list_comb (Const (predname, T), args))
    val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))
    val intro_thm =
      Goal.prove ctxt names [] (Logic.mk_implies (@{prop False}, concl))
        (fn _ => etac @{thm FalseE} 1)
    val elim_thm =
      Goal.prove ctxt ("P" :: names) []
        (Logic.list_implies ([concl, Logic.mk_implies (@{prop False}, P)], P))
        (fn _ => etac (pred_elim thy predname) 1)
  in
    thy
    |> add_intro_thm intro_thm
    |> add_elim_thm elim_thm
  end
haftmann@30374
  1238
haftmann@30374
  1239
(*************************************************************************************)
haftmann@30374
  1240
(* main function *********************************************************************)
haftmann@30374
  1241
(*************************************************************************************)
haftmann@30374
  1242
haftmann@31124
  1243
(* Main entry point. It compiles the predicates of ind_name into functions
   and proves their defining equations. ind_name is either an inductive
   predicate or a predicate given by registered alternative intro rules.
   The equations are stored as default code equations under
   "<base name of ind_name>.equation".
   - Predicates without intro rules first get trivial rules (noclause).
   - Called predicates that have no modes registered yet are compiled first
     by recursive calls.
   - mode: if SOME m, the inferred modes of ind_name are replaced by [m]. *)
fun prove_equation ind_name mode thy =
  let
    val _ = tracing ("starting prove_equation' with " ^ ind_name)
    val (prednames, preds) =
      (case try (InductivePackage.the_inductive (ProofContext.init thy)) ind_name of
        SOME info =>
          let val preds = info |> snd |> #preds
          in (map (fst o dest_Const) preds, map ((apsnd Logic.unvarifyT) o dest_Const) preds) end
      | NONE =>
          let
            val pred = Symtab.lookup (#intro_rules (IndCodegenData.get thy)) ind_name
              |> the |> hd |> prop_of
              |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> strip_comb
              |> fst |> dest_Const |> apsnd Logic.unvarifyT
          in ([ind_name], [pred]) end)
    val thy' = fold (fn pred as (predname, _) => fn thy =>
      if null (pred_intros thy predname) then noclause pred thy else thy) preds thy
    val intrs = map (preprocess_intro thy') (maps (pred_intros thy') prednames)
      |> ind_set_codegen_preproc thy' (*FIXME preprocessor
      |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*)
      |> map (Logic.unvarify o prop_of)
    val _ = tracing ("preprocessed intro rules:" ^ (makestring (map (cterm_of thy') intrs)))
    val name_of_calls = get_name_of_ind_calls_of_clauses thy' prednames intrs
    val _ = tracing ("calling preds: " ^ makestring name_of_calls)
    val _ = tracing "starting recursive compilations"
    (* compile every called predicate that has no modes registered yet;
       Symtab.defined replaces the former linear "mem" over Symtab.keys *)
    fun rec_call name thy =
      if Symtab.defined (#modes (IndCodegenData.get thy)) name then thy
      else prove_equation name NONE thy
    val thy'' = fold rec_call name_of_calls thy'
    val _ = tracing "returning from recursive calls"
    val _ = tracing "starting mode inference"
    val extra_modes = Symtab.dest (#modes (IndCodegenData.get thy''))
    val nparams = get_nparams thy'' ind_name
    val _ $ u = Logic.strip_imp_concl (hd intrs)
    val params = List.take (snd (strip_comb u), nparams)
    val param_vs = maps term_vs params
    val all_vs = terms_vs intrs
    (* classifies a premise. A call of a parameter or inductive predicate is
       a Prem, a negated call is a Negprem, anything else is a Sidecond.
       Double negation raises ERROR. *)
    fun dest_prem t =
      (case strip_comb t of
        (v as Free _, ts) => if member (op =) params v then Prem (ts, v) else Sidecond t
      | (c as Const (@{const_name Not}, _), [t]) =>
          (case dest_prem t of
            Prem (ts, t) => Negprem (ts, t)
          | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t)))
          | Sidecond t => Sidecond (c $ t))
      | (c as Const (s, _), ts) =>
          if is_ind_pred thy'' s then
            let val (ts1, ts2) = chop (get_nparams thy'' s) ts
            in Prem (ts2, list_comb (c, ts1)) end
          else Sidecond t
      | _ => Sidecond t)
    (* appends intr to the clause list of its predicate and records the
       predicate's arities: for each parameter SOME n if it is an n-ary
       predicate (NONE otherwise), plus the number of remaining arguments *)
    fun add_clause intr (clauses, arities) =
      let
        val _ $ t = Logic.strip_imp_concl intr
        val (Const (name, T), ts) = strip_comb t
        val (_, ts2) = chop nparams ts
        val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr)
        val (Ts, Us) = chop nparams (binder_types T)
      in
        (AList.update (op =) (name, these (AList.lookup (op =) clauses name) @
          [(ts2, prems)]) clauses,
         AList.update (op =) (name, (map (fn U => (case strip_type U of
                     (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs)
                   | _ => NONE)) Ts,
                 length Us)) arities)
      end
    val (clauses, arities) = fold add_clause intrs ([], [])
    val modes = infer_modes thy'' extra_modes arities param_vs clauses
    val _ = print_arities arities
    val _ = print_modes modes
    val modes = if is_some mode then AList.update (op =) (ind_name, [the mode]) modes else modes
    val _ = print_modes modes
    val thy''' = fold (create_definitions preds nparams) modes thy''
      |> map_modes (fold Symtab.update_new modes)
    val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses
    val _ = tracing "compiling predicates..."
    val ts = compile_preds thy''' all_vs param_vs (extra_modes @ modes) clauses'
    val _ = tracing "returned term from compile_preds"
    val pred_mode =
      maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses'
    val _ = tracing "starting proof"
    val result_thms =
      prove_preds thy''' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts))
    val (_, thy'''') = yield_singleton PureThy.add_thmss
      ((Binding.qualify true (Long_Name.base_name ind_name) (Binding.name "equation"), result_thms),
        [Attrib.attribute_i thy''' Code.add_default_eqn_attrib]) thy'''
  in
    thy''''
  end
haftmann@30374
  1328
haftmann@30374
  1329
(* Records n as the number of parameters of predicate pred. *)
fun set_nparams (pred, n) thy =
  thy |> map_nparams (Symtab.update (pred, n))
haftmann@30374
  1330
haftmann@30374
  1331
(* Diagnostic: traces the alternative intro rules and the alternative
   cases rule of every predicate that has any of them registered.
   Returns thy unchanged.
   The side effects are now driven by List.app instead of folding over unit
   and discarding the result of map. *)
fun print_alternative_rules thy =
  let
    val data = IndCodegenData.get thy
    val preds = (Symtab.keys (#intro_rules data)) union (Symtab.keys (#elim_rules data))
    val _ = tracing ("preds: " ^ makestring preds)
    fun print pred =
      (tracing ("predicate: " ^ pred);
       tracing ("introrules: ");
       (* traced in reverse order of the stored list, as before *)
       List.app (fn thm => tracing (makestring thm))
         (rev (Symtab.lookup_list (#intro_rules data) pred));
       tracing ("casesrule: ");
       tracing (makestring (Symtab.lookup (#elim_rules data) pred)))
    val _ = List.app print preds
  in thy end;
haftmann@30374
  1345
haftmann@30374
  1346
bulwahn@31106
  1347
(* generation of case rules from user-given introduction rules *)
bulwahn@31106
  1348
haftmann@31124
  1349
(* Builds the cases rule of a predicate from its intro rules introrules.
   The first rule determines the head pred and the parameters. The function
   fixes a fresh "thesis" and fresh argument variables a1 ... an in ctxt.
   It then assumes
     pred params a1 ... an
   together with one case per intro rule of the form
     !!frees. a1 = t1 ==> ... ==> an = tn ==> prems ==> thesis
   Here frees are the free variables of the rule other than the parameters.
   Returns (pred, thesis proposition, extended context).
   Resolves the former FIXME by building the argument variables with map2. *)
fun mk_casesrule introrules nparams ctxt =
  let
    val intros = map prop_of introrules
    val (pred, (params, args)) = strip_intro_concl (hd intros) nparams
    val ([propname], ctxt1) = Variable.variant_fixes ["thesis"] ctxt
    val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
    val (argnames, ctxt2) = Variable.variant_fixes
      (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt1
    val argvs = map2 (curry Free) argnames (map fastype_of args)
    fun mk_case intro =
      let
        val (_, (_, args)) = strip_intro_concl intro nparams
        val prems = Logic.strip_imp_prems intro
        val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args)
        (* free variables of the rule, excluding the parameters of pred *)
        val frees = (fold o fold_aterms)
          (fn t as Free _ =>
              if member (op aconv) params t then I else insert (op aconv) t
            | _ => I) (args @ prems) []
      in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
    val assm = HOLogic.mk_Trueprop (list_comb (pred, params @ argvs))
    val cases = map mk_case intros
    val (_, ctxt3) = ProofContext.add_assms_i Assumption.assume_export
      [((Binding.name AutoBind.assmsN, []), map (fn t => (t, [])) (assm :: cases))]
      ctxt2
  in (pred, prop, ctxt3) end;
bulwahn@31106
  1374
haftmann@31124
  1375
haftmann@31124
  1376
(** user interface **)
haftmann@31124
  1377
haftmann@31124
  1378
local

(* turns a theorem-dependent theory transformation into a declaration
   attribute (Context.mapping: f thm on theories, identity on proof contexts) *)
fun theory_attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);

val elim_attrib = theory_attrib add_elim_thm;

(* Opens the goal of the code_pred command. The user proves the cases rule
   that mk_casesrule builds from the intro rules of the constant. That rule
   is then noted as generated theorem with the elimination attribute, and
   the equations are proved by prove_equation. *)
fun generic_code_pred prep_const raw_const lthy =
  let
    val thy = ProofContext.theory_of lthy
    val const = prep_const thy raw_const
    val nparams = get_nparams thy const
    val intros = pred_intros thy const
    val ((_, fixed_intros), lthy_fixed) = Variable.import_thms true intros lthy
    val (pred, thesis, goal_ctxt) = mk_casesrule fixed_intros nparams lthy_fixed
    val predname = fst (dest_Const pred)
    fun after_qed [[cases_rule]] ctxt =
      ctxt
      |> LocalTheory.note Thm.generated_theoremK
           ((Binding.empty, [Attrib.internal (K elim_attrib)]), [cases_rule])
      |> snd
      |> LocalTheory.theory (prove_equation predname NONE)
  in
    Proof.theorem_i NONE after_qed [[(thesis, [])]] goal_ctxt
  end;

in

val code_pred = generic_code_pred (K I);
val code_pred_cmd = generic_code_pred Code.read_const

val setup =
  Attrib.setup @{binding code_ind_intros} (Scan.succeed (theory_attrib add_intro_thm))
    "adding alternative introduction rules for code generation of inductive predicates" #>
  Attrib.setup @{binding code_ind_cases} (Scan.succeed elim_attrib)
    "adding alternative elimination rules for code generation of inductive predicates";
  (*FIXME name discrepancy in attribs and ML code*)
  (*FIXME intros should be better named intro*)
  (*FIXME why distinguished attribute for cases?*)

val _ = OuterSyntax.local_theory_to_proof "code_pred"
  "prove equations for predicate specified by intro/elim rules"
  OuterKeyword.thy_goal (OuterParse.term_group >> code_pred_cmd)

end
haftmann@31124
  1425
haftmann@31124
  1426
(*FIXME
haftmann@31124
  1427
- Naming of auxiliary rules necessary?
haftmann@31124
  1428
*)
bulwahn@31106
  1429
bulwahn@31169
  1430
(* transformation for code generation *)
bulwahn@31169
  1431
bulwahn@31169
  1432
(* Transforms t = %x1 ... xk. pred params args into a call of the compiled
   function of pred. pred must be a constant with registered parameter count.
   Arguments that are bare bound variables referring to x1 ... xk are
   outputs; all other argument positions (1-based) are inputs.
   - No compiled mode matches these inputs: raises ERROR.
   - Several modes match: a warning is issued and the first one is used.
   Hence, when the function returns, the result is always SOME _.
   The former dead "NONE" after error has been removed. *)
fun pred_term_of thy t =
  let
    val (vars, body) = strip_abs t
    val (pred, all_args) = strip_comb body
    val (name, _) = dest_Const pred
    val (params, args) = chop (get_nparams thy name) all_args
    val pred_with_params = list_comb (pred, params)
    fun is_output (Bound j) = j < length vars
      | is_output _ = false
    val user_mode = flat (map_index (fn (i, arg) => if is_output arg then [] else [i + 1]) args)
    val (inargs, _) = get_args user_mode args
    val all_modes = Symtab.dest (#modes (IndCodegenData.get thy))
    val modes =
      filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term all_modes pred_with_params)
    fun compile m = list_comb (compile_expr thy all_modes (SOME m, pred_with_params), inargs)
  in
    (case modes of
      [] => error "No mode possible for this term"
    | [m] => SOME (compile m)
    | m :: _ => (warning "Multiple modes possible for this term"; SOME (compile m)))
  end;
bulwahn@31169
  1451
haftmann@30374
  1452
end;
haftmann@30374
  1453