diff -r 6cad6ed2700a -r 6bdd125d932b src/Tools/Code/code_thingol.ML --- a/src/Tools/Code/code_thingol.ML Sun Feb 12 06:45:59 2023 +0000 +++ b/src/Tools/Code/code_thingol.ML Thu Feb 09 13:50:09 2023 +0100 @@ -25,17 +25,17 @@ `%% of string * itype list | ITyVar of vname; type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list, - dom: itype list, annotation: itype option }; + dom: itype list, range: itype, annotation: itype option }; datatype iterm = IConst of const | IVar of vname option | `$ of iterm * iterm - | `|=> of (vname option * itype) * iterm + | `|=> of (vname option * itype) * (iterm * itype) | ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm }; val `-> : itype * itype -> itype; val `--> : itype list * itype -> itype; val `$$ : iterm * iterm list -> iterm; - val `|==> : (vname option * itype) list * iterm -> iterm; + val `|==> : (vname option * itype) list * (iterm * itype) -> iterm; type typscheme = (vname * sort) list * itype; end; @@ -48,6 +48,7 @@ val unfold_fun_n: int -> itype -> itype list * itype val unfold_app: iterm -> iterm * iterm list val unfold_abs: iterm -> (vname option * itype) list * iterm + val unfold_abs_typed: iterm -> ((vname option * itype) list * (iterm * itype)) option val split_let: iterm -> (((iterm * itype) * iterm) * iterm) option val split_let_no_pat: iterm -> (((string option * itype) * iterm) * iterm) option val unfold_let: iterm -> ((iterm * itype) * iterm) list * iterm @@ -158,13 +159,13 @@ in (tys3, tys2 `--> ty1) end; type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list, - dom: itype list, annotation: itype option }; + dom: itype list, range: itype, annotation: itype option }; datatype iterm = IConst of const | IVar of vname option | `$ of iterm * iterm - | `|=> of (vname option * itype) * iterm + | `|=> of (vname option * itype) * (iterm * itype) | ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm }; (*see also signature*) @@ -175,16 +176,26 @@ | is_IAbs _ = false; val op `$$ = Library.foldl (op `$); -val op `|==> = Library.foldr (op `|=>); +fun vs_tys `|==> body = Library.foldr + (fn (v_ty as (_, ty), body as (_, rty)) => (v_ty `|=> body, ty `-> rty)) (vs_tys, body) + |> fst; val unfold_app = unfoldl (fn op `$ t_t => SOME t_t | _ => NONE); val unfold_abs = unfoldr - (fn op `|=> v_t => SOME v_t + (fn (v `|=> (t, _)) => SOME (v, t) | _ => NONE); +fun unfold_abs_typed (v_ty `|=> body) = + unfoldr + (fn (v_ty `|=> body, _) => SOME (v_ty, body) + | _ => NONE) body + |> apfst (cons v_ty) + |> SOME + | unfold_abs_typed _ = NONE + fun split_let (ICase { term = t, typ = ty, clauses = [(p, body)], ... }) = SOME (((p, ty), t), body) | split_let _ = NONE; @@ -207,7 +218,7 @@ fun fold' (IConst c) = f c | fold' (IVar _) = I | fold' (t1 `$ t2) = fold' t1 #> fold' t2 - | fold' (_ `|=> t) = fold' t + | fold' (_ `|=> (t, _)) = fold' t | fold' (ICase { term = t, clauses = clauses, ... }) = fold' t #> fold (fn (p, body) => fold' p #> fold' body) clauses in fold' end; @@ -227,8 +238,8 @@ | fold_term vs (IVar (SOME v)) = if member (op =) vs v then I else f v | fold_term _ (IVar NONE) = I | fold_term vs (t1 `$ t2) = fold_term vs t1 #> fold_term vs t2 - | fold_term vs ((SOME v, _) `|=> t) = fold_term (insert (op =) v vs) t - | fold_term vs ((NONE, _) `|=> t) = fold_term vs t + | fold_term vs ((SOME v, _) `|=> (t, _)) = fold_term (insert (op =) v vs) t + | fold_term vs ((NONE, _) `|=> (t, _)) = fold_term vs t | fold_term vs (ICase { term = t, clauses = clauses, ... }) = fold_term vs t #> fold (fold_clause vs) clauses and fold_clause vs (p, t) = fold_term (add_vars p vs) t; @@ -245,8 +256,8 @@ fun invent_params used tys = (map o apfst) SOME (Name.invent_names (Name.build_context used) "a" tys); -fun split_pat_abs ((NONE, ty) `|=> t) = SOME ((IVar NONE, ty), t) - | split_pat_abs ((SOME v, ty) `|=> t) = SOME (case t +fun split_pat_abs ((NONE, ty) `|=> (t, _)) = SOME ((IVar NONE, ty), t) + | split_pat_abs ((SOME v, ty) `|=> (t, _)) = SOME (case t of ICase { term = IVar (SOME w), clauses = [(p, body)], ... } => if v = w andalso (exists_var p v orelse not (exists_var body v)) then ((p, ty), body) @@ -257,7 +268,7 @@ val unfold_pat_abs = unfoldr split_pat_abs; fun unfold_abs_eta [] t = ([], t) - | unfold_abs_eta (_ :: tys) ((v, _) `|=> t) = + | unfold_abs_eta (_ :: tys) ((v, _) `|=> (t, _)) = let val (vs, t') = unfold_abs_eta tys t; in (v :: vs, t') end @@ -266,20 +277,21 @@ val vs = map fst (invent_params (declare_varnames t) tys); in (vs, t `$$ map IVar vs) end; -fun satisfied_application wanted (const as { dom = tys, ... }, ts) = +fun satisfied_application wanted (const as { dom, range, ... }, ts) = let val given = length ts; val delta = wanted - given; val vs_tys = invent_params (fold declare_varnames ts) - (((take delta o drop given) tys)); - in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end; + (((take delta o drop given) dom)); + val (_, rty) = unfold_fun_n wanted range; + in vs_tys `|==> (IConst const `$$ ts @ map (IVar o fst) vs_tys, rty) end; fun map_terms_bottom_up f (t as IConst _) = f t | map_terms_bottom_up f (t as IVar _) = f t | map_terms_bottom_up f (t1 `$ t2) = f (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2) - | map_terms_bottom_up f ((v, ty) `|=> t) = f - ((v, ty) `|=> map_terms_bottom_up f t) + | map_terms_bottom_up f ((v, ty) `|=> (t, rty)) = f + ((v, ty) `|=> (map_terms_bottom_up f t, rty)) | map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f (ICase { term = map_terms_bottom_up f t, typ = ty, clauses = (map o apply2) (map_terms_bottom_up f) clauses, @@ -330,7 +342,7 @@ fun contains_dict_var (IConst { dicts = dss, ... }) = exists_dictss_var (K true) dss | contains_dict_var (IVar _) = false | contains_dict_var (t1 `$ t2) = contains_dict_var t1 orelse contains_dict_var t2 - | contains_dict_var (_ `|=> t) = contains_dict_var t + | contains_dict_var (_ `|=> (t, _)) = contains_dict_var t | contains_dict_var (ICase { primitive = t, ... }) = contains_dict_var t; val unambiguous_dictss = not o exists_dictss_var (fn { unique, ... } => not unique); @@ -673,10 +685,12 @@ let val ((v', _), t') = Term.dest_abs_global (Abs (Name.desymbolize (SOME false) v, ty, t)); val v'' = if Term.used_free v' t' then SOME v' else NONE + val rty = fastype_of_tagged_term t' in translate_typ ctxt algbr eqngr permissive ty + ##>> translate_typ ctxt algbr eqngr permissive rty ##>> translate_term ctxt algbr eqngr permissive some_thm (t', some_abs) - #>> (fn (ty, t) => (v'', ty) `|=> t) + #>> (fn ((ty, rty), t) => (v'', ty) `|=> (t, rty)) end | translate_term ctxt algbr eqngr permissive some_thm (t as _ $ _, some_abs) = case strip_comb t @@ -712,11 +726,11 @@ ensure_const ctxt algbr eqngr permissive c ##>> fold_map (translate_typ ctxt algbr eqngr permissive) typargs ##>> fold_map (translate_dicts ctxt algbr eqngr permissive some_thm) (typargs ~~ sorts) - ##>> fold_map (translate_typ ctxt algbr eqngr permissive) (ty' :: dom) - #>> (fn (((c, typargs), dss), annotation :: dom) => + ##>> fold_map (translate_typ ctxt algbr eqngr permissive) (range :: dom) + #>> (fn (((c, typargs), dss), range :: dom) => IConst { sym = Constant c, typargs = typargs, dicts = dss, - dom = dom, annotation = - if annotate then SOME annotation else NONE }) + dom = dom, range = range, annotation = + if annotate then SOME (dom `--> range) else NONE }) end and translate_app_const ctxt algbr eqngr permissive some_thm ((c_ty, ts), some_abs) = translate_const ctxt algbr eqngr permissive some_thm (c_ty, some_abs) @@ -774,10 +788,12 @@ val tys = (take delta o drop given o binder_types) ty; val used = Name.build_context ((fold o fold_aterms) Term.declare_term_frees ts); val vs_tys = Name.invent_names used "a" tys; + val rty = (drop delta o binder_types) ty ---> body_type ty; in fold_map (translate_typ ctxt algbr eqngr permissive) tys ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs_tys) - #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> t) + ##>> translate_typ ctxt algbr eqngr permissive rty + #>> (fn ((tys, t), rty) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> (t, rty)) end else if length ts > wanted then translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take wanted ts)