src/Pure/Isar/code.ML
author wenzelm
Thu Jul 09 22:01:41 2009 +0200 (2009-07-09)
changeset 31971 8c1b845ed105
parent 31962 baa8dce5bc45
child 31998 2c7a24f74db9
permissions -rw-r--r--
renamed functor TableFun to Table, and GraphFun to Graph;
     1 (*  Title:      Pure/Isar/code.ML
     2     Author:     Florian Haftmann, TU Muenchen
     3 
     4 Abstract executable code of theory.  Management of data dependent on
     5 executable code.  Cache assumes non-concurrent processing of a single theory.
     6 *)
     7 
     8 signature CODE =
     9 sig
    10   (*constants*)
    11   val check_const: theory -> term -> string
    12   val read_bare_const: theory -> string -> string * typ
    13   val read_const: theory -> string -> string
    14   val string_of_const: theory -> string -> string
    15   val args_number: theory -> string -> int
    16   val typscheme: theory -> string * typ -> (string * sort) list * typ
    17 
    18   (*constructor sets*)
    19   val constrset_of_consts: theory -> (string * typ) list
    20     -> string * ((string * sort) list * (string * typ list) list)
    21 
    22   (*constant aliasses*)
    23   val add_const_alias: thm -> theory -> theory
    24   val triv_classes: theory -> class list
    25   val resubst_alias: theory -> string -> string
    26 
    27   (*code equations*)
    28   val mk_eqn: theory -> thm * bool -> thm * bool
    29   val mk_eqn_warning: theory -> thm -> (thm * bool) option
    30   val mk_eqn_liberal: theory -> thm -> (thm * bool) option
    31   val assert_eqn: theory -> thm * bool -> thm * bool
    32   val assert_eqns_const: theory -> string
    33     -> (thm * bool) list -> (thm * bool) list
    34   val const_typ_eqn: theory -> thm -> string * typ
    35   val typscheme_eqn: theory -> thm -> (string * sort) list * typ
    36   val expand_eta: theory -> int -> thm -> thm
    37   val norm_args: theory -> thm list -> thm list 
    38   val norm_varnames: theory -> thm list -> thm list
    39 
    40   (*executable code*)
    41   val add_datatype: (string * typ) list -> theory -> theory
    42   val add_datatype_cmd: string list -> theory -> theory
    43   val type_interpretation:
    44     (string * ((string * sort) list * (string * typ list) list)
    45       -> theory -> theory) -> theory -> theory
    46   val add_eqn: thm -> theory -> theory
    47   val add_eqnl: string * (thm * bool) list lazy -> theory -> theory
    48   val add_nbe_eqn: thm -> theory -> theory
    49   val add_default_eqn: thm -> theory -> theory
    50   val add_default_eqn_attribute: attribute
    51   val add_default_eqn_attrib: Attrib.src
    52   val del_eqn: thm -> theory -> theory
    53   val del_eqns: string -> theory -> theory
    54   val add_case: thm -> theory -> theory
    55   val add_undefined: string -> theory -> theory
    56   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
    57   val get_datatype_of_constr: theory -> string -> string option
    58   val these_eqns: theory -> string -> (thm * bool) list
    59   val all_eqns: theory -> (thm * bool) list
    60   val get_case_scheme: theory -> string -> (int * (int * string list)) option
    61   val undefineds: theory -> string list
    62   val print_codesetup: theory -> unit
    63 
    64   (*infrastructure*)
    65   val add_attribute: string * attribute parser -> theory -> theory
    66   val purge_data: theory -> theory
    67 end;
    68 
    69 signature CODE_DATA_ARGS =
    70 sig
    71   type T
    72   val empty: T
    73   val purge: theory -> string list -> T -> T
    74 end;
    75 
    76 signature CODE_DATA =
    77 sig
    78   type T
    79   val get: theory -> T
    80   val change: theory -> (T -> T) -> T
    81   val change_yield: theory -> (T -> 'a * T) -> 'a * T
    82 end;
    83 
    84 signature PRIVATE_CODE =
    85 sig
    86   include CODE
    87   val declare_data: Object.T -> (theory -> string list -> Object.T -> Object.T)
    88     -> serial
    89   val get_data: serial * ('a -> Object.T) * (Object.T -> 'a)
    90     -> theory -> 'a
    91   val change_data: serial * ('a -> Object.T) * (Object.T -> 'a)
    92     -> theory -> ('a -> 'a) -> 'a
    93   val change_yield_data: serial * ('a -> Object.T) * (Object.T -> 'a)
    94     -> theory -> ('a -> 'b * 'a) -> 'b * 'a
    95 end;
    96 
    97 structure Code : PRIVATE_CODE =
    98 struct
    99 
   100 (** auxiliary **)
   101 
   102 (* printing *)
   103 
   104 fun string_of_typ thy = setmp show_sorts true (Syntax.string_of_typ_global thy);
   105 
   106 fun string_of_const thy c = case AxClass.inst_of_param thy c
   107  of SOME (c, tyco) => Sign.extern_const thy c ^ " " ^ enclose "[" "]" (Sign.extern_type thy tyco)
   108   | NONE => Sign.extern_const thy c;
   109 
   110 
   111 (* constants *)
   112 
   113 fun check_bare_const thy t = case try dest_Const t
   114  of SOME c_ty => c_ty
   115   | NONE => error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   116 
   117 fun check_const thy = AxClass.unoverload_const thy o check_bare_const thy;
   118 
   119 fun read_bare_const thy = check_bare_const thy o Syntax.read_term_global thy;
   120 
   121 fun read_const thy = AxClass.unoverload_const thy o read_bare_const thy;
   122 
   123 fun typscheme thy (c, ty) =
   124   let
   125     val ty' = Logic.unvarifyT ty;
   126   in (map dest_TFree (Sign.const_typargs thy (c, ty')), Type.strip_sorts ty') end;
   127 
   128 
   129 (* code equation transformations *)
   130 
   131 fun expand_eta thy k thm =
   132   let
   133     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
   134     val (head, args) = strip_comb lhs;
   135     val l = if k = ~1
   136       then (length o fst o strip_abs) rhs
   137       else Int.max (0, k - length args);
   138     val used = Name.make_context (map (fst o fst) (Term.add_vars lhs []));
   139     fun get_name _ 0 = pair []
   140       | get_name (Abs (v, ty, t)) k =
   141           Name.variants [v]
   142           ##>> get_name t (k - 1)
   143           #>> (fn ([v'], vs') => (v', ty) :: vs')
   144       | get_name t k = 
   145           let
   146             val (tys, _) = (strip_type o fastype_of) t
   147           in case tys
   148            of [] => raise TERM ("expand_eta", [t])
   149             | ty :: _ =>
   150                 Name.variants [""]
   151                 #-> (fn [v] => get_name (t $ Var ((v, 0), ty)) (k - 1)
   152                 #>> (fn vs' => (v, ty) :: vs'))
   153           end;
   154     val (vs, _) = get_name rhs l used;
   155     fun expand (v, ty) thm = Drule.fun_cong_rule thm
   156       (Thm.cterm_of thy (Var ((v, 0), ty)));
   157   in
   158     thm
   159     |> fold expand vs
   160     |> Conv.fconv_rule Drule.beta_eta_conversion
   161   end;
   162 
   163 fun norm_args thy thms =
   164   let
   165     val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
   166     val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
   167   in
   168     thms
   169     |> map (expand_eta thy k)
   170     |> map (Conv.fconv_rule Drule.beta_eta_conversion)
   171   end;
   172 
   173 fun canonical_tvars thy thm =
   174   let
   175     val ctyp = Thm.ctyp_of thy;
   176     val purify_tvar = unprefix "'" #> Name.desymbolize false #> prefix "'";
   177     fun tvars_subst_for thm = (fold_types o fold_atyps)
   178       (fn TVar (v_i as (v, _), sort) => let
   179             val v' = purify_tvar v
   180           in if v = v' then I
   181           else insert (op =) (v_i, (v', sort)) end
   182         | _ => I) (prop_of thm) [];
   183     fun mk_inst (v_i, (v', sort)) (maxidx, acc) =
   184       let
   185         val ty = TVar (v_i, sort)
   186       in
   187         (maxidx + 1, (ctyp ty, ctyp (TVar ((v', maxidx), sort))) :: acc)
   188       end;
   189     val maxidx = Thm.maxidx_of thm + 1;
   190     val (_, inst) = fold mk_inst (tvars_subst_for thm) (maxidx + 1, []);
   191   in Thm.instantiate (inst, []) thm end;
   192 
   193 fun canonical_vars thy thm =
   194   let
   195     val cterm = Thm.cterm_of thy;
   196     val purify_var = Name.desymbolize false;
   197     fun vars_subst_for thm = fold_aterms
   198       (fn Var (v_i as (v, _), ty) => let
   199             val v' = purify_var v
   200           in if v = v' then I
   201           else insert (op =) (v_i, (v', ty)) end
   202         | _ => I) (prop_of thm) [];
   203     fun mk_inst (v_i as (v, i), (v', ty)) (maxidx, acc) =
   204       let
   205         val t = Var (v_i, ty)
   206       in
   207         (maxidx + 1, (cterm t, cterm (Var ((v', maxidx), ty))) :: acc)
   208       end;
   209     val maxidx = Thm.maxidx_of thm + 1;
   210     val (_, inst) = fold mk_inst (vars_subst_for thm) (maxidx + 1, []);
   211   in Thm.instantiate ([], inst) thm end;
   212 
   213 fun canonical_absvars thm =
   214   let
   215     val t = Thm.plain_prop_of thm;
   216     val purify_var = Name.desymbolize false;
   217     val t' = Term.map_abs_vars purify_var t;
   218   in Thm.rename_boundvars t t' thm end;
   219 
   220 fun norm_varnames thy thms =
   221   let
   222     fun burrow_thms f [] = []
   223       | burrow_thms f thms =
   224           thms
   225           |> Conjunction.intr_balanced
   226           |> f
   227           |> Conjunction.elim_balanced (length thms)
   228   in
   229     thms
   230     |> map (canonical_vars thy)
   231     |> map canonical_absvars
   232     |> map Drule.zero_var_indexes
   233     |> burrow_thms (canonical_tvars thy)
   234     |> Drule.zero_var_indexes_list
   235   end;
   236 
   237 
   238 (** code attributes **)
   239 
   240 structure Code_Attr = TheoryDataFun (
   241   type T = (string * attribute parser) list;
   242   val empty = [];
   243   val copy = I;
   244   val extend = I;
   245   fun merge _ = AList.merge (op = : string * string -> bool) (K true);
   246 );
   247 
   248 fun add_attribute (attr as (name, _)) =
   249   let
   250     fun add_parser ("", parser) attrs = attrs |> rev |> AList.update (op =) ("", parser) |> rev
   251       | add_parser (name, parser) attrs = (name, Args.$$$ name |-- parser) :: attrs;
   252   in Code_Attr.map (fn attrs => if not (name = "") andalso AList.defined (op =) attrs name
   253     then error ("Code attribute " ^ name ^ " already declared") else add_parser attr attrs)
   254   end;
   255 
   256 val _ = Context.>> (Context.map_theory
   257   (Attrib.setup (Binding.name "code")
   258     (Scan.peek (fn context =>
   259       List.foldr op || Scan.fail (map snd (Code_Attr.get (Context.theory_of context)))))
   260     "declare theorems for code generation"));
   261 
   262 
   263 (** data store **)
   264 
   265 (* code equations *)
   266 
   267 type eqns = bool * (thm * bool) list lazy;
   268   (*default flag, theorems with proper flag (perhaps lazy)*)
   269 
   270 fun pretty_lthms ctxt r = case Lazy.peek r
   271  of SOME thms => map (ProofContext.pretty_thm ctxt o fst) (Exn.release thms)
   272   | NONE => [Pretty.str "[...]"];
   273 
   274 fun certificate thy f r =
   275   case Lazy.peek r
   276    of SOME thms => (Lazy.value o f thy) (Exn.release thms)
   277     | NONE => let
   278         val thy_ref = Theory.check_thy thy;
   279       in Lazy.lazy (fn () => (f (Theory.deref thy_ref) o Lazy.force) r) end;
   280 
   281 fun add_drop_redundant thy (thm, proper) thms =
   282   let
   283     val args_of = snd o strip_comb o map_types Type.strip_sorts
   284       o fst o Logic.dest_equals o Thm.plain_prop_of;
   285     val args = args_of thm;
   286     val incr_idx = Logic.incr_indexes ([], Thm.maxidx_of thm + 1);
   287     fun matches_args args' = length args <= length args' andalso
   288       Pattern.matchess thy (args, (map incr_idx o curry Library.take (length args)) args');
   289     fun drop (thm', proper') = if (proper orelse not proper')
   290       andalso matches_args (args_of thm') then 
   291         (warning ("Code generator: dropping redundant code equation\n" ^ Display.string_of_thm thm'); true)
   292       else false;
   293   in (thm, proper) :: filter_out drop thms end;
   294 
   295 fun add_thm thy _ thm (false, thms) = (false, Lazy.map_force (add_drop_redundant thy thm) thms)
   296   | add_thm thy true thm (true, thms) = (true, Lazy.map_force (fn thms => thms @ [thm]) thms)
   297   | add_thm thy false thm (true, thms) = (false, Lazy.value [thm]);
   298 
   299 fun add_lthms lthms _ = (false, lthms);
   300 
   301 fun del_thm thm = (apsnd o Lazy.map_force) (remove (eq_fst Thm.eq_thm_prop) (thm, true));
   302 
   303 
   304 (* executable code data *)
   305 
   306 datatype spec = Spec of {
   307   history_concluded: bool,
   308   aliasses: ((string * string) * thm) list * class list,
   309   eqns: ((bool * eqns) * (serial * eqns) list) Symtab.table
   310     (*with explicit history*),
   311   dtyps: ((serial * ((string * sort) list * (string * typ list) list)) list) Symtab.table
   312     (*with explicit history*),
   313   cases: (int * (int * string list)) Symtab.table * unit Symtab.table
   314 };
   315 
   316 fun make_spec ((history_concluded, aliasses), (eqns, (dtyps, cases))) =
   317   Spec { history_concluded = history_concluded, aliasses = aliasses,
   318     eqns = eqns, dtyps = dtyps, cases = cases };
   319 fun map_spec f (Spec { history_concluded = history_concluded, aliasses = aliasses, eqns = eqns,
   320   dtyps = dtyps, cases = cases }) =
   321   make_spec (f ((history_concluded, aliasses), (eqns, (dtyps, cases))));
   322 fun merge_spec (Spec { history_concluded = _, aliasses = aliasses1, eqns = eqns1,
   323     dtyps = dtyps1, cases = (cases1, undefs1) },
   324   Spec { history_concluded = _, aliasses = aliasses2, eqns = eqns2,
   325     dtyps = dtyps2, cases = (cases2, undefs2) }) =
   326   let
   327     val aliasses = (Library.merge (eq_snd Thm.eq_thm_prop) (pairself fst (aliasses1, aliasses2)),
   328       Library.merge (op =) (pairself snd (aliasses1, aliasses2)));
   329     fun merge_eqns ((_, history1), (_, history2)) =
   330       let
   331         val raw_history = AList.merge (op = : serial * serial -> bool)
   332           (K true) (history1, history2)
   333         val filtered_history = filter_out (fst o snd) raw_history
   334         val history = if null filtered_history
   335           then raw_history else filtered_history;
   336       in ((false, (snd o hd) history), history) end;
   337     val eqns = Symtab.join (K merge_eqns) (eqns1, eqns2);
   338     val dtyps = Symtab.join (K (AList.merge (op =) (K true))) (dtyps1, dtyps2);
   339     val cases = (Symtab.merge (K true) (cases1, cases2),
   340       Symtab.merge (K true) (undefs1, undefs2));
   341   in make_spec ((false, aliasses), (eqns, (dtyps, cases))) end;
   342 
   343 fun history_concluded (Spec { history_concluded, ... }) = history_concluded;
   344 fun the_aliasses (Spec { aliasses, ... }) = aliasses;
   345 fun the_eqns (Spec { eqns, ... }) = eqns;
   346 fun the_dtyps (Spec { dtyps, ... }) = dtyps;
   347 fun the_cases (Spec { cases, ... }) = cases;
   348 val map_history_concluded = map_spec o apfst o apfst;
   349 val map_aliasses = map_spec o apfst o apsnd;
   350 val map_eqns = map_spec o apsnd o apfst;
   351 val map_dtyps = map_spec o apsnd o apsnd o apfst;
   352 val map_cases = map_spec o apsnd o apsnd o apsnd;
   353 
   354 
   355 (* data slots dependent on executable code *)
   356 
   357 (*private copy avoids potential conflict of table exceptions*)
   358 structure Datatab = Table(type key = int val ord = int_ord);
   359 
   360 local
   361 
   362 type kind = {
   363   empty: Object.T,
   364   purge: theory -> string list -> Object.T -> Object.T
   365 };
   366 
   367 val kinds = ref (Datatab.empty: kind Datatab.table);
   368 val kind_keys = ref ([]: serial list);
   369 
   370 fun invoke f k = case Datatab.lookup (! kinds) k
   371  of SOME kind => f kind
   372   | NONE => sys_error "Invalid code data identifier";
   373 
   374 in
   375 
   376 fun declare_data empty purge =
   377   let
   378     val k = serial ();
   379     val kind = {empty = empty, purge = purge};
   380     val _ = change kinds (Datatab.update (k, kind));
   381     val _ = change kind_keys (cons k);
   382   in k end;
   383 
   384 fun invoke_init k = invoke (fn kind => #empty kind) k;
   385 
   386 fun invoke_purge_all thy cs =
   387   fold (fn k => Datatab.map_entry k
   388     (invoke (fn kind => #purge kind thy cs) k)) (! kind_keys);
   389 
   390 end; (*local*)
   391 
   392 
   393 (* theory store *)
   394 
   395 local
   396 
   397 type data = Object.T Datatab.table;
   398 val empty_data = Datatab.empty : data;
   399 
   400 structure Code_Data = TheoryDataFun
   401 (
   402   type T = spec * data ref;
   403   val empty = (make_spec ((false, ([], [])),
   404     (Symtab.empty, (Symtab.empty, (Symtab.empty, Symtab.empty)))), ref empty_data);
   405   fun copy (spec, data) = (spec, ref (! data));
   406   val extend = copy;
   407   fun merge pp ((spec1, data1), (spec2, data2)) =
   408     (merge_spec (spec1, spec2), ref empty_data);
   409 );
   410 
   411 fun thy_data f thy = f ((snd o Code_Data.get) thy);
   412 
   413 fun get_ensure_init kind data_ref =
   414   case Datatab.lookup (! data_ref) kind
   415    of SOME x => x
   416     | NONE => let val y = invoke_init kind
   417         in (change data_ref (Datatab.update (kind, y)); y) end;
   418 
   419 in
   420 
   421 (* access to executable code *)
   422 
   423 val the_exec = fst o Code_Data.get;
   424 
   425 fun complete_class_params thy cs =
   426   fold (fn c => case AxClass.inst_of_param thy c
   427    of NONE => insert (op =) c
   428     | SOME (c', _) => insert (op =) c' #> insert (op =) c) cs [];
   429 
   430 fun map_exec_purge touched f thy =
   431   Code_Data.map (fn (exec, data) => (f exec, ref (case touched
   432    of SOME cs => invoke_purge_all thy (complete_class_params thy cs) (! data)
   433     | NONE => empty_data))) thy;
   434 
   435 val purge_data = (Code_Data.map o apsnd) (K (ref empty_data));
   436 
   437 fun change_eqns delete c f = (map_exec_purge (SOME [c]) o map_eqns
   438   o (if delete then Symtab.map_entry c else Symtab.map_default (c, ((false, (true, Lazy.value [])), [])))
   439     o apfst) (fn (_, eqns) => (true, f eqns));
   440 
   441 fun del_eqns c = change_eqns true c (K (false, Lazy.value []));
   442 
   443 
   444 (* tackling equation history *)
   445 
   446 fun get_eqns thy c =
   447   Symtab.lookup ((the_eqns o the_exec) thy) c
   448   |> Option.map (Lazy.force o snd o snd o fst)
   449   |> these;
   450 
   451 fun continue_history thy = if (history_concluded o the_exec) thy
   452   then thy
   453     |> (Code_Data.map o apfst o map_history_concluded) (K false)
   454     |> SOME
   455   else NONE;
   456 
   457 fun conclude_history thy = if (history_concluded o the_exec) thy
   458   then NONE
   459   else thy
   460     |> (Code_Data.map o apfst)
   461         ((map_eqns o Symtab.map) (fn ((changed, current), history) =>
   462           ((false, current),
   463             if changed then (serial (), current) :: history else history))
   464         #> map_history_concluded (K true))
   465     |> SOME;
   466 
   467 val _ = Context.>> (Context.map_theory (Code_Data.init
   468   #> Theory.at_begin continue_history
   469   #> Theory.at_end conclude_history));
   470 
   471 
   472 (* access to data dependent on abstract executable code *)
   473 
   474 fun get_data (kind, _, dest) = thy_data (get_ensure_init kind #> dest);
   475 
   476 fun change_data (kind, mk, dest) =
   477   let
   478     fun chnge data_ref f =
   479       let
   480         val data = get_ensure_init kind data_ref;
   481         val data' = f (dest data);
   482       in (change data_ref (Datatab.update (kind, mk data')); data') end;
   483   in thy_data chnge end;
   484 
   485 fun change_yield_data (kind, mk, dest) =
   486   let
   487     fun chnge data_ref f =
   488       let
   489         val data = get_ensure_init kind data_ref;
   490         val (x, data') = f (dest data);
   491       in (x, (change data_ref (Datatab.update (kind, mk data')); data')) end;
   492   in thy_data chnge end;
   493 
   494 end; (*local*)
   495 
   496 
   497 (** retrieval interfaces **)
   498 
   499 (* constant aliasses *)
   500 
   501 fun resubst_alias thy =
   502   let
   503     val alias = (fst o the_aliasses o the_exec) thy;
   504     val subst_inst_param = Option.map fst o AxClass.inst_of_param thy;
   505     fun subst_alias c =
   506       get_first (fn ((c', c''), _) => if c = c'' then SOME c' else NONE) alias;
   507   in
   508     perhaps subst_inst_param
   509     #> perhaps subst_alias
   510   end;
   511 
   512 val triv_classes = snd o the_aliasses o the_exec;
   513 
   514 
   515 (** foundation **)
   516 
   517 (* constants *)
   518 
   519 fun args_number thy = length o fst o strip_type o Sign.the_const_type thy;
   520 
   521 
   522 (* datatypes *)
   523 
   524 fun constrset_of_consts thy cs =
   525   let
   526     val _ = map (fn (c, _) => if (is_some o AxClass.class_of_param thy) c
   527       then error ("Is a class parameter: " ^ string_of_const thy c) else ()) cs;
   528     fun no_constr (c, ty) = error ("Not a datatype constructor: " ^ string_of_const thy c
   529       ^ " :: " ^ string_of_typ thy ty);
   530     fun last_typ c_ty ty =
   531       let
   532         val frees = OldTerm.typ_tfrees ty;
   533         val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) ty
   534           handle TYPE _ => no_constr c_ty
   535         val _ = if has_duplicates (eq_fst (op =)) vs then no_constr c_ty else ();
   536         val _ = if length frees <> length vs then no_constr c_ty else ();
   537       in (tyco, vs) end;
   538     fun ty_sorts (c, ty) =
   539       let
   540         val ty_decl = (Logic.unvarifyT o Sign.the_const_type thy) c;
   541         val (tyco, _) = last_typ (c, ty) ty_decl;
   542         val (_, vs) = last_typ (c, ty) ty;
   543       in ((tyco, map snd vs), (c, (map fst vs, ty))) end;
   544     fun add ((tyco', sorts'), c) ((tyco, sorts), cs) =
   545       let
   546         val _ = if tyco' <> tyco
   547           then error "Different type constructors in constructor set"
   548           else ();
   549         val sorts'' = map2 (curry (Sorts.inter_sort (Sign.classes_of thy))) sorts' sorts
   550       in ((tyco, sorts), c :: cs) end;
   551     fun inst vs' (c, (vs, ty)) =
   552       let
   553         val the_v = the o AList.lookup (op =) (vs ~~ vs');
   554         val ty' = map_atyps (fn TFree (v, _) => TFree (the_v v)) ty;
   555       in (c, (fst o strip_type) ty') end;
   556     val c' :: cs' = map ty_sorts cs;
   557     val ((tyco, sorts), cs'') = fold add cs' (apsnd single c');
   558     val vs = Name.names Name.context Name.aT sorts;
   559     val cs''' = map (inst vs) cs'';
   560   in (tyco, (vs, rev cs''')) end;
   561 
   562 fun get_datatype thy tyco =
   563   case these (Symtab.lookup ((the_dtyps o the_exec) thy) tyco)
   564    of (_, spec) :: _ => spec
   565     | [] => Sign.arity_number thy tyco
   566         |> Name.invents Name.context Name.aT
   567         |> map (rpair [])
   568         |> rpair [];
   569 
   570 fun get_datatype_of_constr thy c =
   571   case (snd o strip_type o Sign.the_const_type thy) c
   572    of Type (tyco, _) => if member (op =) ((map fst o snd o get_datatype thy) tyco) c
   573        then SOME tyco else NONE
   574     | _ => NONE;
   575 
   576 fun is_constr thy = is_some o get_datatype_of_constr thy;
   577 
   578 
   579 (* code equations *)
   580 
   581 exception BAD_THM of string;
   582 fun bad_thm msg = raise BAD_THM msg;
   583 fun error_thm f thm = f thm handle BAD_THM msg => error msg;
   584 fun warning_thm f thm = SOME (f thm) handle BAD_THM msg => (warning msg; NONE)
   585 fun try_thm f thm = SOME (f thm) handle BAD_THM _ => NONE;
   586 
   587 fun is_linear thm =
   588   let val (_, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm
   589   in not (has_duplicates (op =) ((fold o fold_aterms)
   590     (fn Var (v, _) => cons v | _ => I) args [])) end;
   591 
   592 fun gen_assert_eqn thy is_constr_pat (thm, proper) =
   593   let
   594     val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm
   595       handle TERM _ => bad_thm ("Not an equation: " ^ Display.string_of_thm thm)
   596            | THM _ => bad_thm ("Not an equation: " ^ Display.string_of_thm thm);
   597     fun vars_of t = fold_aterms (fn Var (v, _) => insert (op =) v
   598       | Free _ => bad_thm ("Illegal free variable in equation\n"
   599           ^ Display.string_of_thm thm)
   600       | _ => I) t [];
   601     fun tvars_of t = fold_term_types (fn _ =>
   602       fold_atyps (fn TVar (v, _) => insert (op =) v
   603         | TFree _ => bad_thm 
   604       ("Illegal free type variable in equation\n" ^ Display.string_of_thm thm))) t [];
   605     val lhs_vs = vars_of lhs;
   606     val rhs_vs = vars_of rhs;
   607     val lhs_tvs = tvars_of lhs;
   608     val rhs_tvs = tvars_of rhs;
   609     val _ = if null (subtract (op =) lhs_vs rhs_vs)
   610       then ()
   611       else bad_thm ("Free variables on right hand side of equation\n"
   612         ^ Display.string_of_thm thm);
   613     val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
   614       then ()
   615       else bad_thm ("Free type variables on right hand side of equation\n"
   616         ^ Display.string_of_thm thm)    val (head, args) = (strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
   617     val (c, ty) = case head
   618      of Const (c_ty as (_, ty)) => (AxClass.unoverload_const thy c_ty, ty)
   619       | _ => bad_thm ("Equation not headed by constant\n" ^ Display.string_of_thm thm);
   620     fun check _ (Abs _) = bad_thm
   621           ("Abstraction on left hand side of equation\n"
   622             ^ Display.string_of_thm thm)
   623       | check 0 (Var _) = ()
   624       | check _ (Var _) = bad_thm
   625           ("Variable with application on left hand side of equation\n"
   626             ^ Display.string_of_thm thm)
   627       | check n (t1 $ t2) = (check (n+1) t1; check 0 t2)
   628       | check n (Const (c_ty as (c, ty))) = if n = (length o fst o strip_type) ty
   629           then if not proper orelse is_constr_pat (AxClass.unoverload_const thy c_ty)
   630             then ()
   631             else bad_thm (quote c ^ " is not a constructor, on left hand side of equation\n"
   632               ^ Display.string_of_thm thm)
   633           else bad_thm
   634             ("Partially applied constant " ^ quote c ^ " on left hand side of equation\n"
   635                ^ Display.string_of_thm thm);
   636     val _ = map (check 0) args;
   637     val _ = if not proper orelse is_linear thm then ()
   638       else bad_thm ("Duplicate variables on left hand side of equation\n"
   639         ^ Display.string_of_thm thm);
   640     val _ = if (is_none o AxClass.class_of_param thy) c
   641       then ()
   642       else bad_thm ("Polymorphic constant as head in equation\n"
   643         ^ Display.string_of_thm thm)
   644     val _ = if not (is_constr thy c)
   645       then ()
   646       else bad_thm ("Constructor as head in equation\n"
   647         ^ Display.string_of_thm thm)
   648     val ty_decl = Sign.the_const_type thy c;
   649     val _ = if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
   650       then () else bad_thm ("Type\n" ^ string_of_typ thy ty
   651            ^ "\nof equation\n"
   652            ^ Display.string_of_thm thm
   653            ^ "\nis incompatible with declared function type\n"
   654            ^ string_of_typ thy ty_decl)
   655   in (thm, proper) end;
   656 
   657 fun assert_eqn thy = error_thm (gen_assert_eqn thy (is_constr thy));
   658 
   659 fun meta_rewrite thy = LocalDefs.meta_rewrite_rule (ProofContext.init thy);
   660 
   661 fun mk_eqn thy = error_thm (gen_assert_eqn thy (K true)) o
   662   apfst (meta_rewrite thy);
   663 
   664 fun mk_eqn_warning thy = Option.map (fn (thm, _) => (thm, is_linear thm))
   665   o warning_thm (gen_assert_eqn thy (K true)) o rpair false o meta_rewrite thy;
   666 
   667 fun mk_eqn_liberal thy = Option.map (fn (thm, _) => (thm, is_linear thm))
   668   o try_thm (gen_assert_eqn thy (K true)) o rpair false o meta_rewrite thy;
   669 
(*the following functions are permissive wrt. overloaded constants!*)
   671 
   672 fun const_typ_eqn thy thm =
   673   let
   674     val (c, ty) = (dest_Const o fst o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of) thm;
   675     val c' = AxClass.unoverload_const thy (c, ty);
   676   in (c', ty) end;
   677 
   678 fun typscheme_eqn thy = typscheme thy o const_typ_eqn thy;
   679 fun const_eqn thy = fst o const_typ_eqn thy;
   680 
   681 fun assert_eqns_const thy c eqns =
   682   let
   683     fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
   684       then eqn else error ("Wrong head of code equation,\nexpected constant "
   685         ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm thm)
   686   in map (cert o assert_eqn thy) eqns end;
   687 
   688 fun common_typ_eqns thy [] = []
   689   | common_typ_eqns thy [thm] = [thm]
   690   | common_typ_eqns thy (thms as thm :: _) = (*FIXME is too general*)
   691       let
   692         fun incr_thm thm max =
   693           let
   694             val thm' = incr_indexes max thm;
   695             val max' = Thm.maxidx_of thm' + 1;
   696           in (thm', max') end;
   697         val (thms', maxidx) = fold_map incr_thm thms 0;
   698         val ty1 :: tys = map (snd o const_typ_eqn thy) thms';
   699         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
   700           handle Type.TUNIFY =>
   701             error ("Type unificaton failed, while unifying code equations\n"
   702             ^ (cat_lines o map Display.string_of_thm) thms
   703             ^ "\nwith types\n"
   704             ^ (cat_lines o map (string_of_typ thy)) (ty1 :: tys));
   705         val (env, _) = fold unify tys (Vartab.empty, maxidx)
   706         val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
   707           cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
   708       in map (Thm.instantiate (instT, [])) thms' end;
   709 
   710 fun these_eqns thy c =
   711   get_eqns thy c
   712   |> (map o apfst) (Thm.transfer thy)
   713   |> burrow_fst (common_typ_eqns thy);
   714 
   715 fun all_eqns thy =
   716   Symtab.dest ((the_eqns o the_exec) thy)
   717   |> maps (Lazy.force o snd o snd o fst o snd);
   718 
   719 
   720 (* cases *)
   721 
   722 fun case_certificate thm =
         (*Analyses a case certificate theorem.  The proposition is taken apart as an implication
           whose premise is an equation "head == raw_case_expr" and whose conclusion is a
           conjunction of clauses.  Result: (case_const, (n, constructors)), where n is the
           position of the scrutinee among the arguments of case_const.
           Malformed input shows up either as TERM (explicit checks) or as Bind (failing val
           patterns); case_cert turns both into a user error.*)
   723   let
   724     val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals
   725       o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.prop_of) thm;
   726     val _ = case head of Free _ => true  (*head has to be a free or schematic variable*)
   727       | Var _ => true
   728       | _ => raise TERM ("case_cert", []);
   729     val ([(case_var, _)], case_expr) = Term.strip_abs_eta 1 raw_case_expr;  (*exactly one abstracted variable: the scrutinee*)
   730     val (Const (case_const, _), raw_params) = strip_comb case_expr;
   731     val n = find_index (fn Free (v, _) => v = case_var | _ => false) raw_params;  (*where the scrutinee occurs*)
   732     val _ = if n = ~1 then raise TERM ("case_cert", []) else ();
   733     val params = map (fst o dest_Var) (nth_drop n raw_params);  (*all other arguments are destructed as schematic variables*)
   734     fun dest_case t =  (*clause "head (co args) == param args": yields (param, co)*)
   735       let
   736         val (head' $ t_co, rhs) = Logic.dest_equals t;
   737         val _ = if head' = head then () else raise TERM ("case_cert", []);
   738         val (Const (co, _), args) = strip_comb t_co;
   739         val (Var (param, _), args') = strip_comb rhs;
   740         val _ = if args' = args then () else raise TERM ("case_cert", []);
   741       in (param, co) end;
   742     fun analyze_cases cases =  (*constructor of each parameter, listed in parameter order*)
   743       let
   744         val co_list = fold (AList.update (op =) o dest_case) cases [];
   745       in map (the o AList.lookup (op =) co_list) params end;
   746     fun analyze_let t =  (*single clause "head arg == param arg" without constructor: no patterns*)
   747       let
   748         val (head' $ arg, Var (param', _) $ arg') = Logic.dest_equals t;
   749         val _ = if head' = head then () else raise TERM ("case_cert", []);
   750         val _ = if arg' = arg then () else raise TERM ("case_cert", []);
   751         val _ = if [param'] = params then () else raise TERM ("case_cert", []);
   752       in [] end;
   753     fun analyze (cases as [let_case]) =  (*one clause only: retry as let-style clause if the constructor patterns fail with Bind*)
   754           (analyze_cases cases handle Bind => analyze_let let_case)
   755       | analyze cases = analyze_cases cases;
   756   in (case_const, (n, analyze cases)) end;
   757 
   758 fun case_cert thm =  (*case_certificate with Bind and TERM failures reported as a plain error*)
   759   let fun bad () = error "bad case certificate"
   760   in case_certificate thm handle Bind => bad () | TERM _ => bad () end;
   761 
   762 fun get_case_scheme thy = Symtab.lookup (fst (the_cases (the_exec thy)));  (*lookup in the first component of the cases data; still expects the constant name*)
   763 
   764 fun undefineds thy = Symtab.keys (snd (the_cases (the_exec thy)));  (*keys of the second component of the cases data, i.e. what add_undefined stores*)
   765 
   766 
   767 (* diagnostic *)
   768 
   769 fun print_codesetup thy =
         (*Diagnostic output: writes two chunks, "code equations:" and "datatypes:", built from
           the executable data of thy.  Both listings are sorted by the printed name.*)
   770   let
   771     val ctxt = ProofContext.init thy;
   772     val exec = the_exec thy;
   773     fun pretty_eqn (s, (_, lthms)) =  (*constant name followed by its equations, separated by forced breaks*)
   774       (Pretty.block o Pretty.fbreaks) (
   775         Pretty.str s :: pretty_lthms ctxt lthms
   776       );
   777     fun pretty_dtyp (s, []) =  (*type without constructors: just its name*)
   778           Pretty.str s
   779       | pretty_dtyp (s, cos) =  (*"ty = c1 | c2 of "ty1" "ty2" | ..."*)
   780           (Pretty.block o Pretty.breaks) (
   781             Pretty.str s
   782             :: Pretty.str "="
   783             :: separate (Pretty.str "|") (map (fn (c, []) => Pretty.str (string_of_const thy c)
   784                  | (c, tys) =>
   785                      (Pretty.block o Pretty.breaks)
   786                         (Pretty.str (string_of_const thy c)
   787                           :: Pretty.str "of"
   788                           :: map (Pretty.quote o Syntax.pretty_typ_global thy) tys)) cos)
   789           );
   790     val eqns = the_eqns exec
   791       |> Symtab.dest
   792       |> (map o apfst) (string_of_const thy)  (*printable constant names*)
   793       |> (map o apsnd) (snd o fst)  (*keep only the part of the entry that pretty_eqn consumes*)
   794       |> sort (string_ord o pairself fst);
   795     val dtyps = the_dtyps exec
   796       |> Symtab.dest
   797       |> map (fn (dtco, (_, (vs, cos)) :: _) =>  (*only the head of each specification list is shown; the pattern does not cover an empty list*)
   798           (string_of_typ thy (Type (dtco, map TFree vs)), cos))
   799       |> sort (string_ord o pairself fst)
   800   in
   801     (Pretty.writeln o Pretty.chunks) [
   802       Pretty.block (
   803         Pretty.str "code equations:"
   804         :: Pretty.fbrk
   805         :: (Pretty.fbreaks o map pretty_eqn) eqns
   806       ),
   807       Pretty.block (
   808         Pretty.str "datatypes:"
   809         :: Pretty.fbrk
   810         :: (Pretty.fbreaks o map pretty_dtyp) dtyps
   811       )
   812     ]
   813   end;
   814 
   815 
   816 (** declaring executable ingredients **)
   817 
   818 (* constant aliases *)
   819 
   820 fun add_const_alias thm thy =
         (*Registers a constant alias certified by thm.  The proposition of thm is split by
           Logic.dest_equals into an of-class part and an equation part; each validation step
           below fails with its own error message.  On success the pair of constants together
           with thm is added to the alias list and the class is inserted into the class list.*)
   821   let
   822     val (ofclass, eqn) = case try Logic.dest_equals (Thm.prop_of thm)
   823      of SOME ofclass_eq => ofclass_eq
   824       | _ => error ("Bad certificate: " ^ Display.string_of_thm thm);
   825     val (T, class) = case try Logic.dest_of_class ofclass
   826      of SOME T_class => T_class
   827       | _ => error ("Bad certificate: " ^ Display.string_of_thm thm);
   828     val tvar = case try Term.dest_TVar T  (*the type of the of-class part must be a schematic type variable*)
   829      of SOME tvar => tvar
   830       | _ => error ("Bad type: " ^ Display.string_of_thm thm);
   831     val _ = if Term.add_tvars eqn [] = [tvar] then ()  (*the equation mentions exactly that type variable*)
   832       else error ("Inconsistent type: " ^ Display.string_of_thm thm);
   833     val lhs_rhs = case try Logic.dest_equals eqn
   834      of SOME lhs_rhs => lhs_rhs
   835       | _ => error ("Not an equation: " ^ Syntax.string_of_term_global thy eqn);
   836     val c_c' = case try (pairself (check_const thy)) lhs_rhs  (*both sides must pass check_const*)
   837      of SOME c_c' => c_c'
   838       | _ => error ("Not an equation with two constants: "
   839           ^ Syntax.string_of_term_global thy eqn);
   840     val _ = if the_list (AxClass.class_of_param thy (snd c_c')) = [class] then ()  (*right-hand constant must be a parameter of class*)
   841       else error ("Inconsistent class: " ^ Display.string_of_thm thm);
   842   in thy |>
   843     (map_exec_purge NONE o map_aliasses) (fn (alias, classes) =>
   844       ((c_c', thm) :: alias, insert (op =) class classes))
   845   end;
   846 
   847 
   848 (* datatypes *)
   849 
   850 structure Type_Interpretation = InterpretationFun(type T = string * serial val eq = eq_snd (op =) : T * T -> bool);  (*entries pair a type constructor name with a serial; eq_snd (op =) presumably compares the serial only -- confirm against the library*)
   851 
   852 fun add_datatype raw_cs thy =
         (*Declares the constants raw_cs as the constructor set of their type constructor, as
           determined by constrset_of_consts.  Applies del_eqns to every constructor, drops case
           schemes that refer to constructors of the previous specification, and announces the
           type constructor to Type_Interpretation under a fresh serial.*)
   853   let
   854     val cs = map (fn c_ty as (_, ty) => (AxClass.unoverload_const thy c_ty, ty)) raw_cs;
   855     val (tyco, vs_cos) = constrset_of_consts thy cs;
   856     val old_cs = (map fst o snd o get_datatype thy) tyco;  (*constructor names of the specification valid so far*)
   857     fun drop_outdated_cases cases = fold Symtab.delete_safe  (*delete every scheme whose constructor list meets old_cs*)
   858       (Symtab.fold (fn (c, (_, (_, cos))) =>
   859         if exists (member (op =) old_cs) cos
   860           then insert (op =) c else I) cases []) cases;
   861   in
   862     thy
   863     |> fold (del_eqns o fst) cs
   864     |> map_exec_purge NONE
   865         ((map_dtyps o Symtab.map_default (tyco, [])) (cons (serial (), vs_cos))  (*new specification is put in front of earlier ones for tyco*)
   866         #> (map_cases o apfst) drop_outdated_cases)
   867     |> Type_Interpretation.data (tyco, serial ())
   868   end;
   869 
   870 fun type_interpretation f =  (*hands f each type constructor from Type_Interpretation together with get_datatype of it*)
   871   let fun visit (tyco, _) thy = f (tyco, get_datatype thy tyco) thy in Type_Interpretation.interpretation visit end;
   872 
   873 fun add_datatype_cmd raw_cs thy =  (*add_datatype for constants given as strings, read by read_bare_const*)
   874   thy
   875   |> add_datatype
   876       (map (read_bare_const thy) raw_cs);
   877 
   878 
   879 (* code equations *)
   880 
   881 fun gen_add_eqn default (eqn as (thm, _)) thy =  (*add eqn to the equations of the constant given by const_eqn*)
   882   thy
   883   |> change_eqns false (const_eqn thy thm) (add_thm thy default eqn);
   884 
   885 fun add_eqn thm thy =  (*non-default equation; mk_eqn is called with flag true*)
   886   let val eqn = mk_eqn thy (thm, true) in gen_add_eqn false eqn thy end;
   887 
   888 fun add_warning_eqn thm thy =  (*thy is returned unchanged when mk_eqn_warning yields NONE*)
   889   (case mk_eqn_warning thy thm of
   890     NONE => thy
   891   | SOME eqn => gen_add_eqn false eqn thy);
   892 
   893 fun add_default_eqn thm thy =  (*default equation; thy is returned unchanged when mk_eqn_liberal yields NONE*)
   894   (case mk_eqn_liberal thy thm of
   895     NONE => thy
   896   | SOME eqn => gen_add_eqn true eqn thy);
   897 
   898 fun add_nbe_eqn thm thy =  (*non-default equation; mk_eqn is called with flag false*)
   899   let val eqn = mk_eqn thy (thm, false) in gen_add_eqn false eqn thy end;
   900 
   901 fun add_eqnl (c, lthms) thy =  (*add the equations lthms for c, wrapped by certificate with an assert_eqns_const check*)
   902   let
   903     fun check thy' = assert_eqns_const thy' c;
   904   in change_eqns false c (add_lthms (certificate thy check lthms)) thy end;
   905 
   906 val add_default_eqn_attribute =  (*declaration attribute that applies add_default_eqn via Context.mapping*)
   907   Thm.declaration_attribute (fn thm => Context.mapping (add_default_eqn thm) I);
   908 val add_default_eqn_attrib = Attrib.internal (fn _ => add_default_eqn_attribute);  (*the same attribute, wrapped by Attrib.internal*)
   909 
   910 fun del_eqn thm thy =  (*delete the equation obtained from thm; thy is unchanged when mk_eqn_liberal yields NONE*)
   911   (case mk_eqn_liberal thy thm of
   912     NONE => thy | SOME (eqn, _) => change_eqns true (const_eqn thy eqn) (del_thm eqn) thy);
   913 
   914 val _ = Context.>> (Context.map_theory
         (*Theory setup performed when this definition is evaluated: initialises
           Type_Interpretation and registers two attributes through add_attribute, the one named
           "" (adding via add_warning_eqn, deleting via del_eqn) and "nbe" (add_nbe_eqn).*)
   915   (let
   916     fun mk_attribute f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I);  (*f thm is applied through Context.mapping, with I as the other argument*)
   917     fun add_simple_attribute (name, f) =  (*attribute without arguments*)
   918       add_attribute (name, Scan.succeed (mk_attribute f));
   919     fun add_del_attribute (name, (add, del)) =  (*a leading "del" argument selects del, otherwise add is used*)
   920       add_attribute (name, Args.del |-- Scan.succeed (mk_attribute del)
   921         || Scan.succeed (mk_attribute add))
   922   in
   923     Type_Interpretation.init
   924     #> add_del_attribute ("", (add_warning_eqn, del_eqn))
   925     #> add_simple_attribute ("nbe", add_nbe_eqn)
   926   end));
   927 
   928 
   929 (* cases *)
   930 
   931 fun add_case thm thy =
         (*Registers the case combinator certified by thm.  Fails with an error if case_cert
           rejects thm or if one of the constructors it reports is not accepted by is_constr.*)
   932   let
   933     val (c, (k, case_pats)) = case_cert thm;
   934     val _ = case filter_out (is_constr thy) case_pats
   935      of [] => ()
   936       | cs => error ("Non-constructor(s) in case certificate: " ^ commas (map quote cs));
   937     val entry = (1 + Int.max (1, length case_pats), (k, case_pats))  (*an empty pattern list still counts as one clause; the sum is presumably the arity of c -- confirm*)
   938   in (map_exec_purge (SOME [c]) o map_cases o apfst) (Symtab.update (c, entry)) thy end;
   939 
   940 fun add_undefined c thy =  (*enter c into the second component of the cases data*)
   941   thy |> map_exec_purge (SOME [c]) (map_cases (apsnd (Symtab.update (c, ()))));
   942 
   943 end; (*struct*)
   944 
   945 
   946 (** type-safe interfaces for data dependent on executable code **)
   947 
   948 functor Code_Data_Fun(Data: CODE_DATA_ARGS): CODE_DATA =
         (*Typed access to one kind of data kept by Code.  Values of Data.T are wrapped in the
           exception constructor Data declared here, and the resulting kind is declared to Code
           together with Data.empty and Data.purge.*)
   949 struct
   950 
   951 type T = Data.T;
   952 exception Data of T;  (*injection of T into the extensible type exn*)
   953 fun dest (Data x) = x  (*projection; the pattern covers this constructor only*)
   954 
   955 val kind = Code.declare_data (Data Data.empty)  (*evaluated once per functor application*)
   956   (fn thy => fn cs => fn Data x => Data (Data.purge thy cs x));
   957 
   958 val data_op = (kind, Data, dest);  (*kind plus injection and projection, as passed to the Code operations below*)
   959 
   960 val get = Code.get_data data_op;
   961 val change = Code.change_data data_op;
   962 fun change_yield thy = Code.change_yield_data data_op thy;
   963 
   964 end;
   965 
   966 structure Code : CODE = struct open Code; end;  (*rebinds Code under the signature CODE, so that from here on only the components listed in CODE are visible*)