# HG changeset patch # User haftmann # Date 1676184358 0 # Node ID 04571037ed3310a32c03d8f6c5e82284120d8d3b # Parent 2d26af072990e2c942f5698a8d92f52702c0726a tuned diff -r 2d26af072990 -r 04571037ed33 src/Tools/Code/code_haskell.ML --- a/src/Tools/Code/code_haskell.ML Fri Feb 10 14:51:51 2023 +0000 +++ b/src/Tools/Code/code_haskell.ML Sun Feb 12 06:45:58 2023 +0000 @@ -252,7 +252,7 @@ |> intro_vars (map_filter I (s :: vs)); val lhs = IConst { sym = Constant classparam, typargs = [], dicts = [], dom = dom, annotation = NONE } `$$ map IVar vs; - (*dictionaries are not relevant at this late stage, + (*dictionaries are not relevant in Haskell, and these consts never need type annotations for disambiguation *) in semicolon [ diff -r 2d26af072990 -r 04571037ed33 src/Tools/Code/code_thingol.ML --- a/src/Tools/Code/code_thingol.ML Fri Feb 10 14:51:51 2023 +0000 +++ b/src/Tools/Code/code_thingol.ML Sun Feb 12 06:45:58 2023 +0000 @@ -9,6 +9,7 @@ infix 4 `$; infix 4 `$$; infixr 3 `->; +infixr 3 `-->; infixr 3 `|=>; infixr 3 `|==>; @@ -32,6 +33,7 @@ | `|=> of (vname option * itype) * iterm | ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm }; val `-> : itype * itype -> itype; + val `--> : itype list * itype -> itype; val `$$ : iterm * iterm list -> iterm; val `|==> : (vname option * itype) list * iterm -> iterm; type typscheme = (vname * sort) list * itype; @@ -53,7 +55,6 @@ val split_pat_abs: iterm -> ((iterm * itype) * iterm) option val unfold_pat_abs: iterm -> (iterm * itype) list * iterm val unfold_const_app: iterm -> (const * iterm list) option - val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm val is_IVar: iterm -> bool val is_IAbs: iterm -> bool val eta_expand: int -> const * iterm list -> iterm @@ -144,6 +145,8 @@ fun ty1 `-> ty2 = "fun" `%% [ty1, ty2]; +val op `--> = Library.foldr (op `->); + val unfold_fun = unfoldr (fn "fun" `%% [ty1, ty2] => SOME (ty1, ty2) | _ => NONE); @@ -152,8 +155,7 @@ let val (tys1, ty1) = unfold_fun ty; val (tys3, tys2) = chop n tys1; - val ty3 = Library.foldr (op `->) (tys2, ty1); - in (tys3, ty3) end; + in (tys3, tys2 `--> ty1) end; type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list, dom: itype list, annotation: itype option }; @@ -176,20 +178,20 @@ val op `|==> = Library.foldr (op `|=>); val unfold_app = unfoldl - (fn op `$ t => SOME t + (fn op `$ t_t => SOME t_t | _ => NONE); val unfold_abs = unfoldr - (fn op `|=> t => SOME t + (fn op `|=> v_t => SOME v_t | _ => NONE); -val split_let = - (fn ICase { term = t, typ = ty, clauses = [(p, body)], ... } => SOME (((p, ty), t), body) - | _ => NONE); +fun split_let (ICase { term = t, typ = ty, clauses = [(p, body)], ... }) = + SOME (((p, ty), t), body) + | split_let _ = NONE; -val split_let_no_pat = - (fn ICase { term = t, typ = ty, clauses = [(IVar v, body)], ... } => SOME (((v, ty), t), body) - | _ => NONE); +fun split_let_no_pat (ICase { term = t, typ = ty, clauses = [(IVar v, body)], ... }) = + SOME (((v, ty), t), body) + | split_let_no_pat _ = NONE; val unfold_let = unfoldr split_let; @@ -206,8 +208,8 @@ | fold' (IVar _) = I | fold' (t1 `$ t2) = fold' t1 #> fold' t2 | fold' (_ `|=> t) = fold' t - | fold' (ICase { term = t, clauses = clauses, ... }) = fold' t - #> fold (fn (p, body) => fold' p #> fold' body) clauses + | fold' (ICase { term = t, clauses = clauses, ... }) = + fold' t #> fold (fn (p, body) => fold' p #> fold' body) clauses in fold' end; val add_constsyms = fold_constexprs (fn { sym, ... } => insert (op =) sym); @@ -240,6 +242,9 @@ fun exists_var t v = fold_varnames (fn w => fn b => v = w orelse b) t false; +fun invent_params used tys = + (map o apfst) SOME (Name.invent_names (Name.build_context used) "a" tys); + fun split_pat_abs ((NONE, ty) `|=> t) = SOME ((IVar NONE, ty), t) | split_pat_abs ((SOME v, ty) `|=> t) = SOME (case t of ICase { term = IVar (SOME w), clauses = [(p, body)], ... } => @@ -252,25 +257,21 @@ val unfold_pat_abs = unfoldr split_pat_abs; fun unfold_abs_eta [] t = ([], t) - | unfold_abs_eta (_ :: tys) (v_ty `|=> t) = + | unfold_abs_eta (_ :: tys) ((v, _) `|=> t) = let - val (vs_tys, t') = unfold_abs_eta tys t; - in (v_ty :: vs_tys, t') end + val (vs, t') = unfold_abs_eta tys t; + in (v :: vs, t') end | unfold_abs_eta tys t = let - val ctxt = Name.build_context (declare_varnames t); - val vs_tys = (map o apfst) SOME (Name.invent_names ctxt "a" tys); - in (vs_tys, t `$$ map (IVar o fst) vs_tys) end; + val vs = map fst (invent_params (declare_varnames t) tys); + in (vs, t `$$ map IVar vs) end; -fun eta_expand k (const as { dom = tys, ... }, ts) = +fun eta_expand wanted (const as { dom = tys, ... }, ts) = let - val j = length ts; - val l = k - j; - val _ = if l > length tys - then error "Impossible eta-expansion" else (); - val vars = Name.build_context (fold declare_varnames ts); - val vs_tys = (map o apfst) SOME - (Name.invent_names vars "a" ((take l o drop j) tys)); + val given = length ts; + val delta = wanted - given; + val vs_tys = invent_params (fold declare_varnames ts) + (((take delta o drop given) tys)); in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end; fun map_terms_bottom_up f (t as IConst _) = f t @@ -316,8 +317,7 @@ |> the_default [(pat_args, body)] | NONE => [(pat_args, body)]) | distill vs_map pat_args body = [(pat_args, body)]; - val (vTs, body) = unfold_abs_eta tys t; - val vs = map fst vTs; + val (vs, body) = unfold_abs_eta tys t; val vs_map = build (fold_index (fn (i, SOME v) => cons (v, i) | _ => I) vs); in distill vs_map (map IVar vs) body end; @@ -639,7 +639,7 @@ fun translate_classparam_instance (c, ty) = let val raw_const = Const (c, map_type_tfree (K arity_typ') ty); - val dom_length = length (fst (strip_type ty)) + val dom_length = length (binder_types ty); val thm = Axclass.unoverload_conv ctxt (Thm.cterm_of ctxt raw_const); val const = (apsnd Logic.unvarifyT_global o dest_Const o snd o Logic.dest_equals o Thm.prop_of) thm; @@ -766,21 +766,22 @@ clauses = (filter_out (is_undefined_clause ctxt) o distill_clauses constrs o project_cases) ts, primitive = t_app `$$ ts }) end -and translate_app_case ctxt algbr eqngr permissive some_thm (num_args, pattern_schema) ((c, ty), ts) = - if length ts < num_args then +and translate_app_case ctxt algbr eqngr permissive some_thm (wanted, pattern_schema) ((c, ty), ts) = + if length ts < wanted then let - val k = length ts; - val tys = (take (num_args - k) o drop k o fst o strip_type) ty; - val names = Name.build_context (ts |> (fold o fold_aterms) Term.declare_term_frees); - val vs = Name.invent_names names "a" tys; + val given = length ts; + val delta = wanted - given; + val tys = (take delta o drop given o binder_types) ty; + val used = Name.build_context ((fold o fold_aterms) Term.declare_term_frees ts); + val vs_tys = Name.invent_names used "a" tys; in fold_map (translate_typ ctxt algbr eqngr permissive) tys - ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs) - #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs tys `|==> t) + ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs_tys) + #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> t) end - else if length ts > num_args then - translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take num_args ts) - ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop num_args ts) + else if length ts > wanted then + translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take wanted ts) + ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop wanted ts) #>> (fn (t, ts) => t `$$ ts) else translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts)