src/Pure/Isar/code.ML
author haftmann
Wed Jan 13 12:20:37 2010 +0100 (2010-01-13)
changeset 34895 19fd499cddff
parent 34894 fadbdd350dd1
child 34901 0d6a2ae86525
permissions -rw-r--r--
explicit abstract type of code certificates
     1 (*  Title:      Pure/Isar/code.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Abstract executable ingredients of theory.  Management of data
     5 dependent on executable ingredients as synchronized cache; purged
     6 on any change of underlying executable ingredients.
     7 *)
     8 
     9 signature CODE =
    10 sig
        (*Public interface of the code generator's store of executable
          ingredients of a theory.*)
    11   (*constants*)
          (*check_const and read_const return the constant name after
            AxClass.unoverload_const; read_bare_const returns name and type as
            read.  const_typ prefers a code-specific signature over the logical
            type of the constant.*)
    12   val check_const: theory -> term -> string
    13   val read_bare_const: theory -> string -> string * typ
    14   val read_const: theory -> string -> string
    15   val string_of_const: theory -> string -> string
    16   val cert_signature: theory -> typ -> typ
    17   val read_signature: theory -> string -> typ
    18   val const_typ: theory -> string -> typ
    19   val subst_signatures: theory -> term -> term
    20   val args_number: theory -> string -> int
    21 
    22   (*constructor sets*)
          (*result: (type constructor, (type parameters with sorts,
            constructors with their argument types))*)
    23   val constrset_of_consts: theory -> (string * typ) list
    24     -> string * ((string * sort) list * (string * typ list) list)
    25 
    26   (*code equations and certificates*)
          (*The bool paired with a thm is its "proper" flag.  mk_eqn_warning and
            mk_eqn_liberal yield NONE for rejected theorems (with resp. without
            a warning) and set the flag according to left-linearity.  cert is an
            abstract certificate bundling all code equations of one constant.*)
    27   val mk_eqn: theory -> thm * bool -> thm * bool
    28   val mk_eqn_warning: theory -> thm -> (thm * bool) option
    29   val mk_eqn_liberal: theory -> thm -> (thm * bool) option
    30   val assert_eqn: theory -> thm * bool -> thm * bool
    31   val const_typ_eqn: theory -> thm -> string * typ
    32   val expand_eta: theory -> int -> thm -> thm
    33   type cert
    34   val empty_cert: theory -> string -> cert
    35   val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
    36   val constrain_cert: theory -> sort list -> cert -> cert
    37   val typscheme_cert: theory -> cert -> (string * sort) list * typ
    38   val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
    39   val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
    40   val pretty_cert: theory -> cert -> Pretty.T list
    41 
    42   (*executable code*)
          (*get_datatype yields a specification without constructors for a type
            constructor without registered datatype; get_cert builds the
            certificate of a constant after applying the given transformation
            to its stored equations*)
    43   val add_type: string -> theory -> theory
    44   val add_type_cmd: string -> theory -> theory
    45   val add_signature: string * typ -> theory -> theory
    46   val add_signature_cmd: string * string -> theory -> theory
    47   val add_datatype: (string * typ) list -> theory -> theory
    48   val add_datatype_cmd: string list -> theory -> theory
    49   val type_interpretation:
    50     (string * ((string * sort) list * (string * typ list) list)
    51       -> theory -> theory) -> theory -> theory
    52   val add_eqn: thm -> theory -> theory
    53   val add_nbe_eqn: thm -> theory -> theory
    54   val add_default_eqn: thm -> theory -> theory
    55   val add_default_eqn_attribute: attribute
    56   val add_default_eqn_attrib: Attrib.src
    57   val del_eqn: thm -> theory -> theory
    58   val del_eqns: string -> theory -> theory
    59   val add_case: thm -> theory -> theory
    60   val add_undefined: string -> theory -> theory
    61   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
    62   val get_datatype_of_constr: theory -> string -> string option
    63   val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
    64   val get_case_scheme: theory -> string -> (int * (int * string list)) option
    65   val undefineds: theory -> string list
    66   val print_codesetup: theory -> unit
    67 
    68   (*infrastructure*)
          (*purge_data discards the cache of data depending on executable code*)
    69   val set_code_target_attr: (string -> thm -> theory -> theory) -> theory -> theory
    70   val purge_data: theory -> theory
    71 end;
    72 
        (*Arguments for a data slot that depends on executable code: the slot's
          type and its initial value.  NOTE(review): presumably consumed by a
          functor outside this view that registers empty via Code.declare_data.*)
    73 signature CODE_DATA_ARGS =
    74 sig
    75   type T
    76   val empty: T
    77 end;
    78 
        (*Operations on a data slot that depends on executable code.  Judging
          from the types, change applies a function to the slot's value for the
          given theory, and change_yield additionally returns an extra result.
          NOTE(review): implementation lies outside this view -- presumably via
          Code.change_data / Code.change_yield_data.*)
    79 signature CODE_DATA =
    80 sig
    81   type T
    82   val change: theory -> (T -> T) -> T
    83   val change_yield: theory -> (T -> 'a * T) -> 'a * T
    84 end;
    85 
        (*CODE plus the operations for data slots depending on executable code.*)
    86 signature PRIVATE_CODE =
    87 sig
    88   include CODE
        (*declare_data registers a slot with its initial value and returns its
          identifier; change_data and change_yield_data update a slot given an
          (identifier, injection into Object.T, projection from Object.T)
          triple*)
    89   val declare_data: Object.T -> serial
    90   val change_data: serial * ('a -> Object.T) * (Object.T -> 'a)
    91     -> theory -> ('a -> 'a) -> 'a
    92   val change_yield_data: serial * ('a -> Object.T) * (Object.T -> 'a)
    93     -> theory -> ('a -> 'b * 'a) -> 'b * 'a
    94 end;
    95 
    96 structure Code : PRIVATE_CODE =
    97 struct
    98 
    99 (** auxiliary **)
   100 
   101 (* printing *)
   102 
   103 fun string_of_typ thy = setmp_CRITICAL show_sorts true (Syntax.string_of_typ_global thy);
   104 
   105 fun string_of_const thy c = case AxClass.inst_of_param thy c
   106  of SOME (c, tyco) => Sign.extern_const thy c ^ " " ^ enclose "[" "]" (Sign.extern_type thy tyco)
   107   | NONE => Sign.extern_const thy c;
   108 
   109 
   110 (* constants *)
   111 
   112 fun typ_equiv tys = Type.raw_instance tys andalso Type.raw_instance (swap tys);
   113 
   114 fun check_bare_const thy t = case try dest_Const t
   115  of SOME c_ty => c_ty
   116   | NONE => error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   117 
   118 fun check_const thy = AxClass.unoverload_const thy o check_bare_const thy;
   119 
   120 fun read_bare_const thy = check_bare_const thy o Syntax.read_term_global thy;
   121 
   122 fun read_const thy = AxClass.unoverload_const thy o read_bare_const thy;
   123 
   124 
   125 
   126 (** data store **)
   127 
   128 (* code equations *)
   129 
        (*Code equations of one constant.  The default flag governs how further
          equations are added (see add_thm): a non-default equation replaces a
          list of default equations.*)
   130 type eqns = bool * (thm * bool) list;
   131   (*default flag, theorems with proper flag *)
   132 
   133 fun add_drop_redundant thy (thm, proper) thms =
   134   let
   135     val args_of = snd o strip_comb o map_types Type.strip_sorts
   136       o fst o Logic.dest_equals o Thm.plain_prop_of;
   137     val args = args_of thm;
   138     val incr_idx = Logic.incr_indexes ([], Thm.maxidx_of thm + 1);
   139     fun matches_args args' = length args <= length args' andalso
   140       Pattern.matchess thy (args, (map incr_idx o take (length args)) args');
   141     fun drop (thm', proper') = if (proper orelse not proper')
   142       andalso matches_args (args_of thm') then 
   143         (warning ("Code generator: dropping redundant code equation\n" ^
   144             Display.string_of_thm_global thy thm'); true)
   145       else false;
   146   in (thm, proper) :: filter_out drop thms end;
   147 
   148 fun add_thm thy _ thm (false, thms) = (false, add_drop_redundant thy thm thms)
   149   | add_thm thy true thm (true, thms) = (true, thms @ [thm])
   150   | add_thm thy false thm (true, thms) = (false, [thm]);
   151 
   152 fun del_thm thm = apsnd (remove (eq_fst Thm.eq_thm_prop) (thm, true));
   153 
   154 
   155 (* executable code data *)
   156 
        (*Executable code of a theory.
          signatures: arities of type constructors declared for code generation
            (read by arity_number) and code signatures of constants (read by
            lookup_typ);
          eqns: per constant ((changed flag, current equations), history);
            conclude_history pushes (serial, equations) onto the history of
            every changed entry;
          dtyps: per type constructor a history of (serial, datatype spec),
            the first entry being the current one (see get_datatype);
          cases: case schemes of case constants, and the undefined constants.*)
   157 datatype spec = Spec of {
   158   history_concluded: bool,
   159   signatures: int Symtab.table * typ Symtab.table,
   160   eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
   161     (*with explicit history*),
   162   dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
   163     (*with explicit history*),
   164   cases: (int * (int * string list)) Symtab.table * unit Symtab.table
   165 };
   166 
   167 fun make_spec (history_concluded, ((signatures, eqns), (dtyps, cases))) =
   168   Spec { history_concluded = history_concluded,
   169     signatures = signatures, eqns = eqns, dtyps = dtyps, cases = cases };
   170 fun map_spec f (Spec { history_concluded = history_concluded, signatures = signatures,
   171   eqns = eqns, dtyps = dtyps, cases = cases }) =
   172   make_spec (f (history_concluded, ((signatures, eqns), (dtyps, cases))));
        (*Merge two specs; the result has history_concluded = false.
          Arity tables are merged with (op =), signature tables up to
          typ_equiv.  Equation histories are merged by serial; entries holding
          default equations are dropped unless nothing else remains, the head
          of the merged history becomes the current equations and the changed
          flag is reset.  Datatype histories are merged by serial; case tables
          accept either side on conflict (K true).*)
   173 fun merge_spec (Spec { history_concluded = _, signatures = (tycos1, sigs1), eqns = eqns1,
   174     dtyps = dtyps1, cases = (cases1, undefs1) },
   175   Spec { history_concluded = _, signatures = (tycos2, sigs2), eqns = eqns2,
   176     dtyps = dtyps2, cases = (cases2, undefs2) }) =
   177   let
   178     val signatures = (Symtab.merge (op =) (tycos1, tycos2),
   179       Symtab.merge typ_equiv (sigs1, sigs2));
            (*NOTE(review): hd fails if both histories are empty; presumably
              impossible, since change_eqns marks entries changed and
              conclude_history then records a history entry for them*)
   180     fun merge_eqns ((_, history1), (_, history2)) =
   181       let
   182         val raw_history = AList.merge (op = : serial * serial -> bool)
   183           (K true) (history1, history2)
                (*fst o snd selects the default flag of an entry's equations*)
   184         val filtered_history = filter_out (fst o snd) raw_history
   185         val history = if null filtered_history
   186           then raw_history else filtered_history;
   187       in ((false, (snd o hd) history), history) end;
   188     val eqns = Symtab.join (K merge_eqns) (eqns1, eqns2);
   189     val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
   190     val cases = (Symtab.merge (K true) (cases1, cases2),
   191       Symtab.merge (K true) (undefs1, undefs2));
   192   in make_spec (false, ((signatures, eqns), (dtyps, cases))) end;
   193 
   194 fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
   195 fun the_signatures (Spec { signatures, ... }) = signatures;
   196 fun the_eqns (Spec { eqns, ... }) = eqns;
   197 fun the_dtyps (Spec { dtyps, ... }) = dtyps;
   198 fun the_cases (Spec { cases, ... }) = cases;
   199 val map_history_concluded = map_spec o apfst;
   200 val map_signatures = map_spec o apsnd o apfst o apfst;
   201 val map_eqns = map_spec o apsnd o apfst o apsnd;
   202 val map_dtyps = map_spec o apsnd o apsnd o apfst;
   203 val map_cases = map_spec o apsnd o apsnd o apsnd;
   204 
   205 
   206 (* data slots dependent on executable code *)
   207 
   208 (*private copy avoids potential conflict of table exceptions*)
   209 structure Datatab = Table(type key = int val ord = int_ord);
   210 
   211 local
   212 
   213 type kind = { empty: Object.T };
   214 
   215 val kinds = Unsynchronized.ref (Datatab.empty: kind Datatab.table);
   216 
   217 fun invoke f k = case Datatab.lookup (! kinds) k
   218  of SOME kind => f kind
   219   | NONE => sys_error "Invalid code data identifier";
   220 
   221 in
   222 
   223 fun declare_data empty =
   224   let
   225     val k = serial ();
   226     val kind = { empty = empty };
   227     val _ = CRITICAL (fn () => Unsynchronized.change kinds (Datatab.update (k, kind)));
   228   in k end;
   229 
   230 fun invoke_init k = invoke (fn kind => #empty kind) k;
   231 
   232 end; (*local*)
   233 
   234 
   235 (* theory store *)
   236 
        (*Theory data: the executable spec paired with a synchronized cache of
          dependent data slots, tagged with the theory it was computed for.
          Any change of the spec resets the cache.*)
   237 local
   238 
   239 type data = Object.T Datatab.table;
   240 fun empty_dataref () = Synchronized.var "code data" (NONE : (data * theory_ref) option);
   241 
   242 structure Code_Data = Theory_Data
   243 (
   244   type T = spec * (data * theory_ref) option Synchronized.var;
   245   val empty = (make_spec (false, (((Symtab.empty, Symtab.empty), Symtab.empty),
   246     (Symtab.empty, (Symtab.empty, Symtab.empty)))), empty_dataref ());
   247   val extend = I
   248   fun merge ((spec1, _), (spec2, _)) =
   249     (merge_spec (spec1, spec2), empty_dataref ());
   250 );
   251 
   252 in
   253 
   254 (* access to executable code *)
   255 
   256 val the_exec = fst o Code_Data.get;
   257 
        (*update the spec and reset the dependent-data cache*)
   258 fun map_exec_purge f = Code_Data.map (fn (exec, _) => (f exec, empty_dataref ()));
   259 
        (*reset the dependent-data cache only*)
   260 val purge_data = (Code_Data.map o apsnd) (fn _ => empty_dataref ());
   261 
        (*Apply f to the equations of constant c and mark them changed.  With
          delete, only an existing entry is touched; otherwise a missing entry
          is first created as ((false, (true, [])), []).*)
   262 fun change_eqns delete c f = (map_exec_purge o map_eqns
   263   o (if delete then Symtab.map_entry c else Symtab.map_default (c, ((false, (true, [])), [])))
   264     o apfst) (fn (_, eqns) => (true, f eqns));
   265 
   266 
   267 (* tackling equation history *)
   268 
        (*On theory begin: reopen a concluded history.  Returns NONE when
          nothing changes (presumably the at_begin convention -- confirm).*)
   269 fun continue_history thy = if (history_concluded o the_exec) thy
   270   then thy
   271     |> (Code_Data.map o apfst o map_history_concluded) (K false)
   272     |> SOME
   273   else NONE;
   274 
        (*On theory end: for every constant whose equations changed, push
          (fresh serial, current equations) onto its history, clear the changed
          flag and mark the history concluded.*)
   275 fun conclude_history thy = if (history_concluded o the_exec) thy
   276   then NONE
   277   else thy
   278     |> (Code_Data.map o apfst)
   279         ((map_eqns o Symtab.map) (fn ((changed, current), history) =>
   280           ((false, current),
   281             if changed then (serial (), current) :: history else history))
   282         #> map_history_concluded (K true))
   283     |> SOME;
   284 
   285 val _ = Context.>> (Context.map_theory (Theory.at_begin continue_history #> Theory.at_end conclude_history));
   286 
   287 
   288 (* access to data dependent on abstract executable code *)
   289 
        (*Apply f to data slot kind of the cache of theory.  A cache belonging
          to a different theory is discarded; a missing slot starts from its
          declared initial value.  NOTE(review): reading (Synchronized.value)
          and writing (Synchronized.change with K) are separate steps, so a
          concurrent update in between is overwritten.*)
   290 fun change_yield_data (kind, mk, dest) theory f =
   291   let
   292     val dataref = (snd o Code_Data.get) theory;
   293     val (datatab, thy_ref) = case Synchronized.value dataref
   294      of SOME (datatab, thy_ref) => if Theory.eq_thy (theory, Theory.deref thy_ref)
   295           then (datatab, thy_ref)
   296           else (Datatab.empty, Theory.check_thy theory)
   297       | NONE => (Datatab.empty, Theory.check_thy theory)
   298     val data = case Datatab.lookup datatab kind
   299      of SOME data => data
   300       | NONE => invoke_init kind;
   301     val result as (x, data') = f (dest data);
   302     val _ = Synchronized.change dataref
   303       ((K o SOME) (Datatab.update (kind, mk data') datatab, thy_ref));
   304   in result end;
   305 
        (*like change_yield_data, without an extra result*)
   306 fun change_data ops theory f = change_yield_data ops theory (f #> pair ()) |> snd;
   307 
   308 end; (*local*)
   309 
   310 
   311 (** foundation **)
   312 
   313 (* constants *)
   314 
   315 fun arity_number thy tyco = case Symtab.lookup ((fst o the_signatures o the_exec) thy) tyco
   316  of SOME n => n
   317   | NONE => Sign.arity_number thy tyco;
   318 
   319 fun build_tsig thy =
   320   let
   321     val (tycos, _) = (the_signatures o the_exec) thy;
   322     val decls = (#types o Type.rep_tsig o Sign.tsig_of) thy
   323       |> snd 
   324       |> Symtab.fold (fn (tyco, n) =>
   325           Symtab.update (tyco, Type.LogicalType n)) tycos;
   326   in
   327     Type.empty_tsig
   328     |> Symtab.fold (fn (tyco, Type.LogicalType n) => Type.add_type Name_Space.default_naming
   329         (Binding.qualified_name tyco, n) | _ => I) decls
   330   end;
   331 
   332 fun cert_signature thy = Logic.varifyT o Type.cert_typ (build_tsig thy) o Type.no_tvars;
   333 
   334 fun read_signature thy = cert_signature thy o Type.strip_sorts
   335   o Syntax.parse_typ (ProofContext.init thy);
   336 
   337 fun expand_signature thy = Type.cert_typ_mode Type.mode_syntax (Sign.tsig_of thy);
   338 
   339 fun lookup_typ thy = Symtab.lookup ((snd o the_signatures o the_exec) thy);
   340 
   341 fun const_typ thy c = case lookup_typ thy c
   342  of SOME ty => ty
   343   | NONE => (Type.strip_sorts o Sign.the_const_type thy) c;
   344 
   345 fun subst_signature thy c ty =
   346   let
   347     fun mk_subst (Type (tyco, tys1)) (ty2 as Type (tyco2, tys2)) =
   348           fold2 mk_subst tys1 tys2
   349       | mk_subst ty (TVar (v, sort)) = Vartab.update (v, ([], ty))
   350   in case lookup_typ thy c
   351    of SOME ty' => Envir.subst_type (mk_subst ty (expand_signature thy ty') Vartab.empty) ty'
   352     | NONE => ty
   353   end;
   354 
   355 fun subst_signatures thy = map_aterms (fn Const (c, ty) => Const (c, subst_signature thy c ty) | t => t);
   356 
   357 fun args_number thy = length o fst o strip_type o const_typ thy;
   358 
   359 
   360 (* datatypes *)
   361 
   362 fun constrset_of_consts thy cs =
   363   let
   364     val _ = map (fn (c, _) => if (is_some o AxClass.class_of_param thy) c
   365       then error ("Is a class parameter: " ^ string_of_const thy c) else ()) cs;
   366     fun no_constr s (c, ty) = error ("Not a datatype constructor:\n" ^ string_of_const thy c
   367       ^ " :: " ^ string_of_typ thy ty ^ "\n" ^ enclose "(" ")" s);
   368     fun last_typ c_ty ty =
   369       let
   370         val tfrees = Term.add_tfreesT ty [];
   371         val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) ty
   372           handle TYPE _ => no_constr "bad type" c_ty
   373         val _ = if has_duplicates (eq_fst (op =)) vs
   374           then no_constr "duplicate type variables in datatype" c_ty else ();
   375         val _ = if length tfrees <> length vs
   376           then no_constr "type variables missing in datatype" c_ty else ();
   377       in (tyco, vs) end;
   378     fun ty_sorts (c, raw_ty) =
   379       let
   380         val ty = subst_signature thy c raw_ty;
   381         val ty_decl = (Logic.unvarifyT o const_typ thy) c;
   382         val (tyco, _) = last_typ (c, ty) ty_decl;
   383         val (_, vs) = last_typ (c, ty) ty;
   384       in ((tyco, map snd vs), (c, (map fst vs, ty))) end;
   385     fun add ((tyco', sorts'), c) ((tyco, sorts), cs) =
   386       let
   387         val _ = if (tyco' : string) <> tyco
   388           then error "Different type constructors in constructor set"
   389           else ();
   390         val sorts'' = map2 (curry (Sorts.inter_sort (Sign.classes_of thy))) sorts' sorts
   391       in ((tyco, sorts), c :: cs) end;
   392     fun inst vs' (c, (vs, ty)) =
   393       let
   394         val the_v = the o AList.lookup (op =) (vs ~~ vs');
   395         val ty' = map_atyps (fn TFree (v, _) => TFree (the_v v)) ty;
   396       in (c, (fst o strip_type) ty') end;
   397     val c' :: cs' = map ty_sorts cs;
   398     val ((tyco, sorts), cs'') = fold add cs' (apsnd single c');
   399     val vs = Name.names Name.context Name.aT sorts;
   400     val cs''' = map (inst vs) cs'';
   401   in (tyco, (vs, rev cs''')) end;
   402 
   403 fun get_datatype thy tyco =
   404   case these (Symtab.lookup ((the_dtyps o the_exec) thy) tyco)
   405    of (_, spec) :: _ => spec
   406     | [] => arity_number thy tyco
   407         |> Name.invents Name.context Name.aT
   408         |> map (rpair [])
   409         |> rpair [];
   410 
   411 fun get_datatype_of_constr thy c =
   412   case (snd o strip_type o const_typ thy) c
   413    of Type (tyco, _) => if member (op =) ((map fst o snd o get_datatype thy) tyco) c
   414        then SOME tyco else NONE
   415     | _ => NONE;
   416 
   417 fun is_constr thy = is_some o get_datatype_of_constr thy;
   418 
   419 
   420 (* bare code equations *)
   421 
   422 exception BAD_THM of string;
   423 fun bad_thm msg = raise BAD_THM msg;
   424 fun error_thm f thm = f thm handle BAD_THM msg => error msg;
   425 fun warning_thm f thm = SOME (f thm) handle BAD_THM msg => (warning msg; NONE)
   426 fun try_thm f thm = SOME (f thm) handle BAD_THM _ => NONE;
   427 
   428 fun is_linear thm =
   429   let val (_, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm
   430   in not (has_duplicates (op =) ((fold o fold_aterms)
   431     (fn Var (v, _) => cons v | _ => I) args [])) end;
   432 
        (*Check that (thm, proper) is an admissible code equation and return it
          unchanged; violations raise BAD_THM (via bad / bad_thm).  In order:
          thm is a meta-equation; no Free variables and no TFree type variables
          occur; variables and type variables of the rhs occur on the lhs; the
          lhs is headed by a constant; lhs arguments contain no abstractions,
          no applied variables and only fully applied constants, which must be
          constructors when proper and check_patterns hold; proper equations
          are left-linear; the head is neither a class parameter nor a
          constructor; the head type is equivalent to the declared type of the
          constant (sorts stripped).*)
   433 fun gen_assert_eqn thy check_patterns (thm, proper) =
   434   let
   435     fun bad s = bad_thm (s ^ ":\n" ^ Display.string_of_thm_global thy thm);
   436     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
   437       handle TERM _ => bad "Not an equation"
   438            | THM _ => bad "Not an equation";
   439     fun vars_of t = fold_aterms (fn Var (v, _) => insert (op =) v
   440       | Free _ => bad "Illegal free variable in equation"
   441       | _ => I) t [];
   442     fun tvars_of t = fold_term_types (fn _ =>
   443       fold_atyps (fn TVar (v, _) => insert (op =) v
   444         | TFree _ => bad "Illegal free type variable in equation")) t [];
   445     val lhs_vs = vars_of lhs;
   446     val rhs_vs = vars_of rhs;
   447     val lhs_tvs = tvars_of lhs;
   448     val rhs_tvs = tvars_of rhs;
   449     val _ = if null (subtract (op =) lhs_vs rhs_vs)
   450       then ()
   451       else bad "Free variables on right hand side of equation";
   452     val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
   453       then ()
   454       else bad "Free type variables on right hand side of equation";
   455     val (head, args) = strip_comb lhs;
   456     val (c, ty) = case head
   457      of Const (c_ty as (_, ty)) => (AxClass.unoverload_const thy c_ty, ty)
   458       | _ => bad "Equation not headed by constant";
            (*check n t: t occurs applied to n further arguments; a constant must
              receive exactly as many arguments as its (signature) type takes*)
   459     fun check _ (Abs _) = bad "Abstraction on left hand side of equation"
   460       | check 0 (Var _) = ()
   461       | check _ (Var _) = bad "Variable with application on left hand side of equation"
   462       | check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
   463       | check n (Const (c_ty as (c, ty))) =
   464           let
   465             val c' = AxClass.unoverload_const thy c_ty
   466           in if n = (length o fst o strip_type o subst_signature thy c') ty
   467             then if not proper orelse not check_patterns orelse is_constr thy c'
   468               then ()
   469               else bad (quote c ^ " is not a constructor, on left hand side of equation")
   470             else bad ("Partially applied constant " ^ quote c ^ " on left hand side of equation")
   471           end;
   472     val _ = map (check 0) args;
   473     val _ = if not proper orelse is_linear thm then ()
   474       else bad "Duplicate variables on left hand side of equation";
   475     val _ = if (is_none o AxClass.class_of_param thy) c then ()
   476       else bad "Overloaded constant as head in equation";
   477     val _ = if not (is_constr thy c) then ()
   478       else bad "Constructor as head in equation";
   479     val ty_decl = Sign.the_const_type thy c;
            (*bad_thm directly: this message has its own layout*)
   480     val _ = if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
   481       then () else bad_thm ("Type\n" ^ string_of_typ thy ty
   482         ^ "\nof equation\n"
   483         ^ Display.string_of_thm_global thy thm
   484         ^ "\nis incompatible with declared function type\n"
   485         ^ string_of_typ thy ty_decl)
   486   in (thm, proper) end;
   487 
   488 fun assert_eqn thy = error_thm (gen_assert_eqn thy true);
   489 
   490 fun meta_rewrite thy = LocalDefs.meta_rewrite_rule (ProofContext.init thy);
   491 
   492 fun mk_eqn thy = error_thm (gen_assert_eqn thy false) o
   493   apfst (meta_rewrite thy);
   494 
   495 fun mk_eqn_warning thy = Option.map (fn (thm, _) => (thm, is_linear thm))
   496   o warning_thm (gen_assert_eqn thy false) o rpair false o meta_rewrite thy;
   497 
   498 fun mk_eqn_liberal thy = Option.map (fn (thm, _) => (thm, is_linear thm))
   499   o try_thm (gen_assert_eqn thy false) o rpair false o meta_rewrite thy;
   500 
   501 val head_eqn = dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
   502 
   503 fun const_typ_eqn thy thm =
   504   let
   505     val (c, ty) = head_eqn thm;
   506     val c' = AxClass.unoverload_const thy (c, ty);
   507       (*permissive wrt. to overloaded constants!*)
   508   in (c', ty) end;
   509 
   510 fun const_eqn thy = fst o const_typ_eqn thy;
   511 
   512 fun logical_typscheme thy (c, ty) =
   513   (map dest_TFree (Sign.const_typargs thy (c, ty)), Type.strip_sorts ty);
   514 
   515 fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
   516 
   517 
   518 (* technical transformations of code equations *)
   519 
   520 fun expand_eta thy k thm =
   521   let
   522     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
   523     val (_, args) = strip_comb lhs;
   524     val l = if k = ~1
   525       then (length o fst o strip_abs) rhs
   526       else Int.max (0, k - length args);
   527     val (raw_vars, _) = Term.strip_abs_eta l rhs;
   528     val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
   529       raw_vars;
   530     fun expand (v, ty) thm = Drule.fun_cong_rule thm
   531       (Thm.cterm_of thy (Var ((v, 0), ty)));
   532   in
   533     thm
   534     |> fold expand vars
   535     |> Conv.fconv_rule Drule.beta_eta_conversion
   536   end;
   537 
   538 fun same_arity thy thms =
   539   let
   540     val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
   541     val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
   542   in map (expand_eta thy k) thms end;
   543 
   544 fun mk_desymbolization pre post mk vs =
   545   let
   546     val names = map (pre o fst o fst) vs
   547       |> map (Name.desymbolize false)
   548       |> Name.variant_list []
   549       |> map post;
   550   in map_filter (fn (((v, i), x), v') =>
   551     if v = v' andalso i = 0 then NONE
   552     else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
   553   end;
   554 
   555 fun desymbolize_tvars thy thms =
   556   let
   557     val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
   558     val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
   559   in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
   560 
   561 fun desymbolize_vars thy thm =
   562   let
   563     val vs = Term.add_vars (Thm.prop_of thm) [];
   564     val var_subst = mk_desymbolization I I Var vs;
   565   in Thm.certify_instantiate ([], var_subst) thm end;
   566 
   567 fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy);
   568 
   569 
   570 (* code equation certificates *)
   571 
   572 fun build_head thy (c, ty) =
   573   Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
   574 
   575 fun get_head thy cert_thm =
   576   let
   577     val [head] = (#hyps o Thm.crep_thm) cert_thm;
   578     val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
   579   in (typscheme thy (c, ty), head) end;
   580 
        (*Certificate of the code equations of one constant: a single theorem
          that is a balanced conjunction of the equations, with the head
          constant replaced by Free "HEAD" under the hypothesis HEAD == c (see
          build_head), together with the proper flags of the equations.  An
          empty flag list marks a certificate without equations.*)
   581 abstype cert = Cert of thm * bool list with
   582 
        (*Certificate without equations: the type variables of c's type are
          replaced by fresh TFrees; a class parameter uses a single TFree with
          the class as sort.  NOTE(review): the class case assumes the type has
          exactly one type variable.*)
   583 fun empty_cert thy c = 
   584   let
   585     val raw_ty = const_typ thy c;
   586     val tvars = Term.add_tvar_namesT raw_ty [];
   587     val tvars' = case AxClass.class_of_param thy c
   588      of SOME class => [TFree (Name.aT, [class])]
   589       | NONE => Name.invent_list [] Name.aT (length tvars)
   590           |> map (fn v => TFree (v, []));
   591     val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
   592     val chead = build_head thy (c, ty);
   593   in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
   594 
        (*Certificate from code equations for c.  Canonizes the equations and
          checks each one (assert_eqn, head must be c).  It then instantiates
          the head type variables of all equations to common fresh TFrees whose
          sorts intersect the per-equation sorts, and rewrites each lhs head
          constant to HEAD.*)
   595 fun cert_of_eqns thy c [] = empty_cert thy c
   596   | cert_of_eqns thy c raw_eqns = 
   597       let
   598         val eqns = burrow_fst (canonize_thms thy) raw_eqns;
   599         val _ = map (assert_eqn thy) eqns;
   600         val (thms, propers) = split_list eqns;
   601         val _ = map (fn thm => if c = const_eqn thy thm then ()
   602           else error ("Wrong head of code equation,\nexpected constant "
   603             ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms;
   604         fun tvars_of T = rev (Term.add_tvarsT T []);
   605         val vss = map (tvars_of o snd o head_eqn) thms;
   606         fun inter_sorts vs =
   607           fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs [];
   608         val sorts = map_transpose inter_sorts vss;
   609         val vts = Name.names Name.context Name.aT sorts;
   610         val thms as thm :: _ =
   611           map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
            (*c == HEAD, used to rewrite the head of each lhs*)
   612         val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms))));
   613         fun head_conv ct = if can Thm.dest_comb ct
   614           then Conv.fun_conv head_conv ct
   615           else Conv.rewr_conv head_thm ct;
   616         val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
   617         val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
   618       in Cert (cert_thm, propers) end;
   619 
        (*Intersect the sorts of the certificate's type parameters with the
          given sorts (one per parameter), adjusting the HEAD hypothesis.*)
   620 fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
   621   let
   622     val ((vs, _), head) = get_head thy cert_thm;
   623     val subst = map2 (fn (v, sort) => fn sort' =>
   624       (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
   625     val head' = Thm.term_of head
   626       |> (map_types o map_atyps)
   627           (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
   628       |> Thm.cterm_of thy;
   629     val inst = map2 (fn (v, sort) => fn (_, sort') =>
   630       (((v, 0), sort), TFree (v, sort'))) vs subst;
   631     val cert_thm' = cert_thm
   632       |> Thm.implies_intr head
   633       |> Thm.varifyT
   634       |> Thm.certify_instantiate (inst, [])
   635       |> Thm.elim_implies (Thm.assume head');
   636   in (Cert (cert_thm', propers)) end;
   637 
        (*type scheme of the certified constant*)
   638 fun typscheme_cert thy (Cert (cert_thm, _)) =
   639   fst (get_head thy cert_thm);
   640 
        (*type scheme and equations as (lhs arguments, rhs) pairs*)
   641 fun equations_cert thy (cert as Cert (cert_thm, propers)) =
   642   let
   643     val tyscm = typscheme_cert thy cert;
   644     val equations = if null propers then [] else
   645       Thm.prop_of cert_thm
   646       |> Logic.dest_conjunction_balanced (length propers)
   647       |> map Logic.dest_equals
   648       |> (map o apfst) (snd o strip_comb)
   649   in (tyscm, equations) end;
   650 
        (*like equations_cert, additionally pairing each equation with its
          theorem (HEAD expanded back via LocalDefs.expand) and proper flag*)
   651 fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
   652   let
   653     val (tyscm, equations) = equations_cert thy cert;
   654     val thms = if null propers then [] else
   655       cert_thm
   656       |> LocalDefs.expand [snd (get_head thy cert_thm)]
   657       |> Thm.varifyT
   658       |> Conjunction.elim_balanced (length propers)
   659   in (tyscm, equations ~~ (thms ~~ propers)) end;
   660 
        (*pretty-print the certified theorems, re-overloaded via AxClass.overload*)
   661 fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
   662   o snd o equations_thms_cert thy;
   663 
   664 end;
   665 
   666 
   667 (* code equation access *)
   668 
   669 fun get_cert thy f c =
   670   Symtab.lookup ((the_eqns o the_exec) thy) c
   671   |> Option.map (snd o snd o fst)
   672   |> these
   673   |> (map o apfst) (Thm.transfer thy)
   674   |> f
   675   |> (map o apfst) (AxClass.unoverload thy)
   676   |> cert_of_eqns thy c;
   677 
   678 
   679 (* cases *)
   680 
        (*Analyze a case certificate, i.e. a theorem of the form
            (head == %x. case_const ... x ...) ==> (eq_1 &&& ... &&& eq_k)
          where head is a Free or Var.  x occurs as the n-th (0-based) argument
          of case_const, and all other arguments are Vars, the case parameters.
          Each eq_i reads  head (co args) == param args.
          Returns (case_const, (n, cos)), where cos lists for each parameter in
          order the constructor it handles.  The single-equation form
          head arg == param arg  (a "let") yields cos = [].
          Malformed input raises TERM or Bind, and presumably Option.Option
          from "the" if a parameter has no equation; see case_cert.*)
   681 fun case_certificate thm =
   682   let
   683     val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals
   684       o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.plain_prop_of) thm;
   685     val _ = case head of Free _ => true
   686       | Var _ => true
   687       | _ => raise TERM ("case_cert", []);
   688     val ([(case_var, _)], case_expr) = Term.strip_abs_eta 1 raw_case_expr;
   689     val (Const (case_const, _), raw_params) = strip_comb case_expr;
   690     val n = find_index (fn Free (v, _) => v = case_var | _ => false) raw_params;
   691     val _ = if n = ~1 then raise TERM ("case_cert", []) else ();
   692     val params = map (fst o dest_Var) (nth_drop n raw_params);
            (*one case equation: (parameter, constructor)*)
   693     fun dest_case t =
   694       let
   695         val (head' $ t_co, rhs) = Logic.dest_equals t;
   696         val _ = if head' = head then () else raise TERM ("case_cert", []);
   697         val (Const (co, _), args) = strip_comb t_co;
   698         val (Var (param, _), args') = strip_comb rhs;
   699         val _ = if args' = args then () else raise TERM ("case_cert", []);
   700       in (param, co) end;
   701     fun analyze_cases cases =
   702       let
   703         val co_list = fold (AList.update (op =) o dest_case) cases [];
   704       in map (the o AList.lookup (op =) co_list) params end;
            (*degenerate "let" certificate with a single parameter*)
   705     fun analyze_let t =
   706       let
   707         val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t;
   708         val _ = if head' = head then () else raise TERM ("case_cert", []);
   709         val _ = if arg' = arg then () else raise TERM ("case_cert", []);
   710         val _ = if [param'] = params then () else raise TERM ("case_cert", []);
   711       in [] end;
            (*a single equation that fails to destruct as a constructor case is
              retried as a "let"*)
   712     fun analyze (cases as [let_case]) =
   713           (analyze_cases cases handle Bind => analyze_let let_case)
   714       | analyze cases = analyze_cases cases;
   715   in (case_const, (n, analyze cases)) end;
   716 
   717 fun case_cert thm = case_certificate thm
   718   handle Bind => error "bad case certificate"
   719        | TERM _ => error "bad case certificate";
   720 
   721 fun get_case_scheme thy = Symtab.lookup ((fst o the_cases o the_exec) thy);
   722 
   723 val undefineds = Symtab.keys o snd o the_cases o the_exec;
   724 
   725 
   726 (* diagnostic *)
   727 
   728 fun print_codesetup thy =
   729   let
   730     val ctxt = ProofContext.init thy;
   731     val exec = the_exec thy;
   732     fun pretty_eqns (s, (_, eqns)) =
   733       (Pretty.block o Pretty.fbreaks) (
   734         Pretty.str s :: map (Display.pretty_thm ctxt o fst) eqns
   735       );
   736     fun pretty_dtyp (s, []) =
   737           Pretty.str s
   738       | pretty_dtyp (s, cos) =
   739           (Pretty.block o Pretty.breaks) (
   740             Pretty.str s
   741             :: Pretty.str "="
   742             :: separate (Pretty.str "|") (map (fn (c, []) => Pretty.str (string_of_const thy c)
   743                  | (c, tys) =>
   744                      (Pretty.block o Pretty.breaks)
   745                         (Pretty.str (string_of_const thy c)
   746                           :: Pretty.str "of"
   747                           :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
   748           );
   749     val eqns = the_eqns exec
   750       |> Symtab.dest
   751       |> (map o apfst) (string_of_const thy)
   752       |> (map o apsnd) (snd o fst)
   753       |> sort (string_ord o pairself fst);
   754     val dtyps = the_dtyps exec
   755       |> Symtab.dest
   756       |> map (fn (dtco, (_, (vs, cos)) :: _) =>
   757           (string_of_typ thy (Type (dtco, map TFree vs)), cos))
   758       |> sort (string_ord o pairself fst)
   759   in
   760     (Pretty.writeln o Pretty.chunks) [
   761       Pretty.block (
   762         Pretty.str "code equations:"
   763         :: Pretty.fbrk
   764         :: (Pretty.fbreaks o map pretty_eqns) eqns
   765       ),
   766       Pretty.block (
   767         Pretty.str "datatypes:"
   768         :: Pretty.fbrk
   769         :: (Pretty.fbreaks o map pretty_dtyp) dtyps
   770       )
   771     ]
   772   end;
   773 
   774 
   775 (** declaring executable ingredients **)
   776 
   777 (* constant signatures *)
   778 
   779 fun add_type tyco thy =
   780   case Symtab.lookup ((snd o #types o Type.rep_tsig o Sign.tsig_of) thy) tyco
   781    of SOME (Type.Abbreviation (vs, _, _)) =>
   782           (map_exec_purge o map_signatures o apfst)
   783             (Symtab.update (tyco, length vs)) thy
   784     | _ => error ("No such type abbreviation: " ^ quote tyco);
   785 
   786 fun add_type_cmd s thy = add_type (Sign.intern_type thy s) thy;
   787 
   788 fun gen_add_signature prep_const prep_signature (raw_c, raw_ty) thy =
   789   let
   790     val c = prep_const thy raw_c;
   791     val ty = prep_signature thy raw_ty;
   792     val ty' = expand_signature thy ty;
   793     val ty'' = Sign.the_const_type thy c;
   794     val _ = if typ_equiv (ty', ty'') then () else
   795       error ("Illegal constant signature: " ^ Syntax.string_of_typ_global thy ty);
   796   in
   797     thy
   798     |> (map_exec_purge o map_signatures o apsnd) (Symtab.update (c, ty))
   799   end;
   800 
   801 val add_signature = gen_add_signature (K I) cert_signature;
   802 val add_signature_cmd = gen_add_signature read_const read_signature;
   803 
   804 
   805 (* code equations *)
   806 
   807 fun gen_add_eqn default (thm, proper) thy =
   808   let
   809     val thm' = Thm.close_derivation thm;
   810     val c = const_eqn thy thm';
   811   in change_eqns false c (add_thm thy default (thm', proper)) thy end;
   812 
   813 fun add_eqn thm thy =
   814   gen_add_eqn false (mk_eqn thy (thm, true)) thy;
   815 
   816 fun add_warning_eqn thm thy =
   817   case mk_eqn_warning thy thm
   818    of SOME eqn => gen_add_eqn false eqn thy
   819     | NONE => thy;
   820 
   821 fun add_default_eqn thm thy =
   822   case mk_eqn_liberal thy thm
   823    of SOME eqn => gen_add_eqn true eqn thy
   824     | NONE => thy;
   825 
   826 fun add_nbe_eqn thm thy =
   827   gen_add_eqn false (mk_eqn thy (thm, false)) thy;
   828 
   829 val add_default_eqn_attribute = Thm.declaration_attribute
   830   (fn thm => Context.mapping (add_default_eqn thm) I);
   831 val add_default_eqn_attrib = Attrib.internal (K add_default_eqn_attribute);
   832 
   833 fun del_eqn thm thy = case mk_eqn_liberal thy thm
   834  of SOME (thm, _) => change_eqns true (const_eqn thy thm) (del_thm thm) thy
   835   | NONE => thy;
   836 
   837 fun del_eqns c = change_eqns true c (K (false, []));
   838 
   839 
   840 (* cases *)
   841 
   842 fun add_case thm thy =
   843   let
   844     val (c, (k, case_pats)) = case_cert thm;
   845     val _ = case filter_out (is_constr thy) case_pats
   846      of [] => ()
   847       | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs));
   848     val entry = (1 + Int.max (1, length case_pats), (k, case_pats))
   849   in (map_exec_purge o map_cases o apfst) (Symtab.update (c, entry)) thy end;
   850 
   851 fun add_undefined c thy =
   852   (map_exec_purge o map_cases o apsnd) (Symtab.update (c, ())) thy;
   853 
   854 
   855 (* datatypes *)
   856 
   857 structure Type_Interpretation =
   858   Interpretation(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);
   859 
   860 fun add_datatype raw_cs thy =
   861   let
   862     val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs;
   863     val (tyco, vs_cos) = constrset_of_consts thy cs;
   864     val old_cs = (map fst o snd o get_datatype thy) tyco;
   865     fun drop_outdated_cases cases = fold Symtab.delete_safe
   866       (Symtab.fold (fn (c, (_, (_, cos))) =>
   867         if exists (member (op =) old_cs) cos
   868           then insert (op =) c else I) cases []) cases;
   869   in
   870     thy
   871     |> fold (del_eqns o fst) cs
   872     |> map_exec_purge
   873         ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos))
   874         #> (map_cases o apfst) drop_outdated_cases)
   875     |> Type_Interpretation.data (tyco, serial ())
   876   end;
   877 
   878 fun type_interpretation f =  Type_Interpretation.interpretation
   879   (fn (tyco, _) => fn thy => f (tyco, get_datatype thy tyco) thy);
   880 
   881 fun add_datatype_cmd raw_cs thy =
   882   let
   883     val cs = map (read_bare_const thy) raw_cs;
   884   in add_datatype cs thy end;
   885 
   886 
(* cf. src/HOL/Tools/recfun_codegen.ML *)
   888 
   889 structure Code_Target_Attr = Theory_Data
   890 (
   891   type T = (string -> thm -> theory -> theory) option;
   892   val empty = NONE;
   893   val extend = I;
   894   fun merge (f1, f2) = if is_some f1 then f1 else f2;
   895 );
   896 
   897 fun set_code_target_attr f = Code_Target_Attr.map (K (SOME f));
   898 
   899 fun code_target_attr prefix thm thy =
   900   let
   901     val attr = the_default ((K o K) I) (Code_Target_Attr.get thy);
   902   in thy |> add_warning_eqn thm |> attr prefix thm end;

(* setup *)
   904 
   905 val _ = Context.>> (Context.map_theory
   906   (let
   907     fun mk_attribute f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);
   908     val code_attribute_parser =
   909       Args.del |-- Scan.succeed (mk_attribute del_eqn)
   910       || Args.$$$ "nbe" |-- Scan.succeed (mk_attribute add_nbe_eqn)
   911       || (Args.$$$ "target" |-- Args.colon |-- Args.name >>
   912            (mk_attribute o code_target_attr))
   913       || Scan.succeed (mk_attribute add_warning_eqn);
   914   in
   915     Type_Interpretation.init
   916     #> Attrib.setup (Binding.name "code") (Scan.lift code_attribute_parser)
   917         "declare theorems for code generation"
   918   end));
   919 
   920 end; (*struct*)
   921 
   922 
   923 (** type-safe interfaces for data dependent on executable code **)
   924 
   925 functor Code_Data(Data: CODE_DATA_ARGS): CODE_DATA =
   926 struct
   927 
   928 type T = Data.T;
   929 exception Data of T;
   930 fun dest (Data x) = x
   931 
   932 val kind = Code.declare_data (Data Data.empty);
   933 
   934 val data_op = (kind, Data, dest);
   935 
   936 val change = Code.change_data data_op;
   937 fun change_yield thy = Code.change_yield_data data_op thy;
   938 
   939 end;
   940 
(*Re-bind Code under signature CODE.  From here on, only the components
  listed in CODE are accessible through the name Code; helpers used above,
  such as declare_data, are no longer reachable unless CODE lists them.*)
structure Code : CODE = struct open Code; end;