--- a/src/HOL/Imperative_HOL/Heap_Monad.thy Sun Feb 12 06:45:59 2023 +0000
+++ b/src/HOL/Imperative_HOL/Heap_Monad.thy Thu Feb 09 13:50:09 2023 +0100
@@ -646,8 +646,8 @@
val unitT = \<^type_name>\<open>unit\<close> `%% [];
val unitt =
IConst { sym = Code_Symbol.Constant \<^const_name>\<open>Unity\<close>, typargs = [], dicts = [], dom = [],
- annotation = NONE };
- fun dest_abs ((v, ty) `|=> t, _) = ((v, ty), t)
+ annotation = NONE, range = unitT };
+ fun dest_abs ((v, ty) `|=> (t, _), _) = ((v, ty), t)
| dest_abs (t, ty) =
let
val vs = fold_varnames cons t [];
@@ -667,7 +667,7 @@
else force t
| _ => force t;
fun imp_monad_bind'' ts = (SOME dummy_name, unitT) `|=>
- ICase { term = IVar (SOME dummy_name), typ = unitT, clauses = [(unitt, tr_bind'' ts)], primitive = dummy_case_term }
+ (ICase { term = IVar (SOME dummy_name), typ = unitT, clauses = [(unitt, tr_bind'' ts)], primitive = dummy_case_term }, unitT)
fun imp_monad_bind' (const as { sym = Code_Symbol.Constant c, dom = dom, ... }) ts = if is_bind c then case (ts, dom)
of ([t1, t2], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)]
| ([t1, t2, t3], ty1 :: ty2 :: _) => imp_monad_bind'' [(t1, ty1), (t2, ty2)] `$ t3
@@ -678,7 +678,7 @@
| imp_monad_bind (t as _ `$ _) = (case unfold_app t
of (IConst const, ts) => imp_monad_bind' const ts
| (t, ts) => imp_monad_bind t `$$ map imp_monad_bind ts)
- | imp_monad_bind (v_ty `|=> t) = v_ty `|=> imp_monad_bind t
+ | imp_monad_bind (v_ty `|=> t) = v_ty `|=> apfst imp_monad_bind t
| imp_monad_bind (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) =
ICase { term = imp_monad_bind t, typ = ty,
clauses = (map o apply2) imp_monad_bind clauses, primitive = imp_monad_bind t0 };
--- a/src/Tools/Code/code_thingol.ML Sun Feb 12 06:45:59 2023 +0000
+++ b/src/Tools/Code/code_thingol.ML Thu Feb 09 13:50:09 2023 +0100
@@ -25,17 +25,17 @@
`%% of string * itype list
| ITyVar of vname;
type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list,
- dom: itype list, annotation: itype option };
+ dom: itype list, range: itype, annotation: itype option };
datatype iterm =
IConst of const
| IVar of vname option
| `$ of iterm * iterm
- | `|=> of (vname option * itype) * iterm
+ | `|=> of (vname option * itype) * (iterm * itype)
| ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm };
val `-> : itype * itype -> itype;
val `--> : itype list * itype -> itype;
val `$$ : iterm * iterm list -> iterm;
- val `|==> : (vname option * itype) list * iterm -> iterm;
+ val `|==> : (vname option * itype) list * (iterm * itype) -> iterm;
type typscheme = (vname * sort) list * itype;
end;
@@ -48,6 +48,7 @@
val unfold_fun_n: int -> itype -> itype list * itype
val unfold_app: iterm -> iterm * iterm list
val unfold_abs: iterm -> (vname option * itype) list * iterm
+ val unfold_abs_typed: iterm -> ((vname option * itype) list * (iterm * itype)) option
val split_let: iterm -> (((iterm * itype) * iterm) * iterm) option
val split_let_no_pat: iterm -> (((string option * itype) * iterm) * iterm) option
val unfold_let: iterm -> ((iterm * itype) * iterm) list * iterm
@@ -158,13 +159,13 @@
in (tys3, tys2 `--> ty1) end;
type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list,
- dom: itype list, annotation: itype option };
+ dom: itype list, range: itype, annotation: itype option };
datatype iterm =
IConst of const
| IVar of vname option
| `$ of iterm * iterm
- | `|=> of (vname option * itype) * iterm
+ | `|=> of (vname option * itype) * (iterm * itype)
| ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm };
(*see also signature*)
@@ -175,16 +176,26 @@
| is_IAbs _ = false;
val op `$$ = Library.foldl (op `$);
-val op `|==> = Library.foldr (op `|=>);
+fun vs_tys `|==> body = Library.foldr
+ (fn (v_ty as (_, ty), body as (_, rty)) => (v_ty `|=> body, ty `-> rty)) (vs_tys, body)
+ |> fst;
val unfold_app = unfoldl
(fn op `$ t_t => SOME t_t
| _ => NONE);
val unfold_abs = unfoldr
- (fn op `|=> v_t => SOME v_t
+ (fn (v `|=> (t, _)) => SOME (v, t)
| _ => NONE);
+fun unfold_abs_typed (v_ty `|=> body) =
+ unfoldr
+ (fn (v_ty `|=> body, _) => SOME (v_ty, body)
+ | _ => NONE) body
+ |> apfst (cons v_ty)
+ |> SOME
+ | unfold_abs_typed _ = NONE
+
fun split_let (ICase { term = t, typ = ty, clauses = [(p, body)], ... }) =
SOME (((p, ty), t), body)
| split_let _ = NONE;
@@ -207,7 +218,7 @@
fun fold' (IConst c) = f c
| fold' (IVar _) = I
| fold' (t1 `$ t2) = fold' t1 #> fold' t2
- | fold' (_ `|=> t) = fold' t
+ | fold' (_ `|=> (t, _)) = fold' t
| fold' (ICase { term = t, clauses = clauses, ... }) =
fold' t #> fold (fn (p, body) => fold' p #> fold' body) clauses
in fold' end;
@@ -227,8 +238,8 @@
| fold_term vs (IVar (SOME v)) = if member (op =) vs v then I else f v
| fold_term _ (IVar NONE) = I
| fold_term vs (t1 `$ t2) = fold_term vs t1 #> fold_term vs t2
- | fold_term vs ((SOME v, _) `|=> t) = fold_term (insert (op =) v vs) t
- | fold_term vs ((NONE, _) `|=> t) = fold_term vs t
+ | fold_term vs ((SOME v, _) `|=> (t, _)) = fold_term (insert (op =) v vs) t
+ | fold_term vs ((NONE, _) `|=> (t, _)) = fold_term vs t
| fold_term vs (ICase { term = t, clauses = clauses, ... }) =
fold_term vs t #> fold (fold_clause vs) clauses
and fold_clause vs (p, t) = fold_term (add_vars p vs) t;
@@ -245,8 +256,8 @@
fun invent_params used tys =
(map o apfst) SOME (Name.invent_names (Name.build_context used) "a" tys);
-fun split_pat_abs ((NONE, ty) `|=> t) = SOME ((IVar NONE, ty), t)
- | split_pat_abs ((SOME v, ty) `|=> t) = SOME (case t
+fun split_pat_abs ((NONE, ty) `|=> (t, _)) = SOME ((IVar NONE, ty), t)
+ | split_pat_abs ((SOME v, ty) `|=> (t, _)) = SOME (case t
of ICase { term = IVar (SOME w), clauses = [(p, body)], ... } =>
if v = w andalso (exists_var p v orelse not (exists_var body v))
then ((p, ty), body)
@@ -257,7 +268,7 @@
val unfold_pat_abs = unfoldr split_pat_abs;
fun unfold_abs_eta [] t = ([], t)
- | unfold_abs_eta (_ :: tys) ((v, _) `|=> t) =
+ | unfold_abs_eta (_ :: tys) ((v, _) `|=> (t, _)) =
let
val (vs, t') = unfold_abs_eta tys t;
in (v :: vs, t') end
@@ -266,20 +277,21 @@
val vs = map fst (invent_params (declare_varnames t) tys);
in (vs, t `$$ map IVar vs) end;
-fun satisfied_application wanted (const as { dom = tys, ... }, ts) =
+fun satisfied_application wanted (const as { dom, range, ... }, ts) =
let
val given = length ts;
val delta = wanted - given;
val vs_tys = invent_params (fold declare_varnames ts)
- (((take delta o drop given) tys));
- in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end;
+ (((take delta o drop given) dom));
+ val (_, rty) = unfold_fun_n wanted range;
+ in vs_tys `|==> (IConst const `$$ ts @ map (IVar o fst) vs_tys, rty) end;
fun map_terms_bottom_up f (t as IConst _) = f t
| map_terms_bottom_up f (t as IVar _) = f t
| map_terms_bottom_up f (t1 `$ t2) = f
(map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2)
- | map_terms_bottom_up f ((v, ty) `|=> t) = f
- ((v, ty) `|=> map_terms_bottom_up f t)
+ | map_terms_bottom_up f ((v, ty) `|=> (t, rty)) = f
+ ((v, ty) `|=> (map_terms_bottom_up f t, rty))
| map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f
(ICase { term = map_terms_bottom_up f t, typ = ty,
clauses = (map o apply2) (map_terms_bottom_up f) clauses,
@@ -330,7 +342,7 @@
fun contains_dict_var (IConst { dicts = dss, ... }) = exists_dictss_var (K true) dss
| contains_dict_var (IVar _) = false
| contains_dict_var (t1 `$ t2) = contains_dict_var t1 orelse contains_dict_var t2
- | contains_dict_var (_ `|=> t) = contains_dict_var t
+ | contains_dict_var (_ `|=> (t, _)) = contains_dict_var t
| contains_dict_var (ICase { primitive = t, ... }) = contains_dict_var t;
val unambiguous_dictss = not o exists_dictss_var (fn { unique, ... } => not unique);
@@ -673,10 +685,12 @@
let
val ((v', _), t') = Term.dest_abs_global (Abs (Name.desymbolize (SOME false) v, ty, t));
val v'' = if Term.used_free v' t' then SOME v' else NONE
+ val rty = fastype_of_tagged_term t'
in
translate_typ ctxt algbr eqngr permissive ty
+ ##>> translate_typ ctxt algbr eqngr permissive rty
##>> translate_term ctxt algbr eqngr permissive some_thm (t', some_abs)
- #>> (fn (ty, t) => (v'', ty) `|=> t)
+ #>> (fn ((ty, rty), t) => (v'', ty) `|=> (t, rty))
end
| translate_term ctxt algbr eqngr permissive some_thm (t as _ $ _, some_abs) =
case strip_comb t
@@ -712,11 +726,11 @@
ensure_const ctxt algbr eqngr permissive c
##>> fold_map (translate_typ ctxt algbr eqngr permissive) typargs
##>> fold_map (translate_dicts ctxt algbr eqngr permissive some_thm) (typargs ~~ sorts)
- ##>> fold_map (translate_typ ctxt algbr eqngr permissive) (ty' :: dom)
- #>> (fn (((c, typargs), dss), annotation :: dom) =>
+ ##>> fold_map (translate_typ ctxt algbr eqngr permissive) (range :: dom)
+ #>> (fn (((c, typargs), dss), range :: dom) =>
IConst { sym = Constant c, typargs = typargs, dicts = dss,
- dom = dom, annotation =
- if annotate then SOME annotation else NONE })
+ dom = dom, range = range, annotation =
+ if annotate then SOME (dom `--> range) else NONE })
end
and translate_app_const ctxt algbr eqngr permissive some_thm ((c_ty, ts), some_abs) =
translate_const ctxt algbr eqngr permissive some_thm (c_ty, some_abs)
@@ -774,10 +788,12 @@
val tys = (take delta o drop given o binder_types) ty;
val used = Name.build_context ((fold o fold_aterms) Term.declare_term_frees ts);
val vs_tys = Name.invent_names used "a" tys;
+ val rty = (drop delta o binder_types) ty ---> body_type ty;
in
fold_map (translate_typ ctxt algbr eqngr permissive) tys
##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs_tys)
- #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> t)
+ ##>> translate_typ ctxt algbr eqngr permissive rty
+ #>> (fn ((tys, t), rty) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> (t, rty))
end
else if length ts > wanted then
translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take wanted ts)