tuned
authorhaftmann
Sun, 12 Feb 2023 06:45:58 +0000
changeset 77231 04571037ed33
parent 77230 2d26af072990
child 77232 6cad6ed2700a
tuned
src/Tools/Code/code_haskell.ML
src/Tools/Code/code_thingol.ML
--- a/src/Tools/Code/code_haskell.ML	Fri Feb 10 14:51:51 2023 +0000
+++ b/src/Tools/Code/code_haskell.ML	Sun Feb 12 06:45:58 2023 +0000
@@ -252,7 +252,7 @@
                       |> intro_vars (map_filter I (s :: vs));
                     val lhs = IConst { sym = Constant classparam, typargs = [],
                       dicts = [], dom = dom, annotation = NONE } `$$ map IVar vs;
-                      (*dictionaries are not relevant at this late stage,
+                      (*dictionaries are not relevant in Haskell,
                         and these consts never need type annotations for disambiguation *)
                   in
                     semicolon [
--- a/src/Tools/Code/code_thingol.ML	Fri Feb 10 14:51:51 2023 +0000
+++ b/src/Tools/Code/code_thingol.ML	Sun Feb 12 06:45:58 2023 +0000
@@ -9,6 +9,7 @@
 infix 4 `$;
 infix 4 `$$;
 infixr 3 `->;
+infixr 3 `-->;
 infixr 3 `|=>;
 infixr 3 `|==>;
 
@@ -32,6 +33,7 @@
     | `|=> of (vname option * itype) * iterm
     | ICase of { term: iterm, typ: itype, clauses: (iterm * iterm) list, primitive: iterm };
   val `-> : itype * itype -> itype;
+  val `--> : itype list * itype -> itype;
   val `$$ : iterm * iterm list -> iterm;
   val `|==> : (vname option * itype) list * iterm -> iterm;
   type typscheme = (vname * sort) list * itype;
@@ -53,7 +55,6 @@
   val split_pat_abs: iterm -> ((iterm * itype) * iterm) option
   val unfold_pat_abs: iterm -> (iterm * itype) list * iterm
   val unfold_const_app: iterm -> (const * iterm list) option
-  val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
   val is_IVar: iterm -> bool
   val is_IAbs: iterm -> bool
   val eta_expand: int -> const * iterm list -> iterm
@@ -144,6 +145,8 @@
 
 fun ty1 `-> ty2 = "fun" `%% [ty1, ty2];
 
+val op `--> = Library.foldr (op `->);
+
 val unfold_fun = unfoldr
   (fn "fun" `%% [ty1, ty2] => SOME (ty1, ty2)
     | _ => NONE);
@@ -152,8 +155,7 @@
   let
     val (tys1, ty1) = unfold_fun ty;
     val (tys3, tys2) = chop n tys1;
-    val ty3 = Library.foldr (op `->) (tys2, ty1);
-  in (tys3, ty3) end;
+  in (tys3, tys2 `--> ty1) end;
 
 type const = { sym: Code_Symbol.T, typargs: itype list, dicts: dict list list,
   dom: itype list, annotation: itype option };
@@ -176,20 +178,20 @@
 val op `|==> = Library.foldr (op `|=>);
 
 val unfold_app = unfoldl
-  (fn op `$ t => SOME t
+  (fn op `$ t_t => SOME t_t
     | _ => NONE);
 
 val unfold_abs = unfoldr
-  (fn op `|=> t => SOME t
+  (fn op `|=> v_t => SOME v_t
     | _ => NONE);
 
-val split_let = 
-  (fn ICase { term = t, typ = ty, clauses = [(p, body)], ... } => SOME (((p, ty), t), body)
-    | _ => NONE);
+fun split_let (ICase { term = t, typ = ty, clauses = [(p, body)], ... }) =
+      SOME (((p, ty), t), body)
+  | split_let _ = NONE;
 
-val split_let_no_pat = 
-  (fn ICase { term = t, typ = ty, clauses = [(IVar v, body)], ... } => SOME (((v, ty), t), body)
-    | _ => NONE);
+fun split_let_no_pat (ICase { term = t, typ = ty, clauses = [(IVar v, body)], ... }) =
+      SOME (((v, ty), t), body)
+  | split_let_no_pat _ = NONE;
 
 val unfold_let = unfoldr split_let;
 
@@ -206,8 +208,8 @@
       | fold' (IVar _) = I
       | fold' (t1 `$ t2) = fold' t1 #> fold' t2
       | fold' (_ `|=> t) = fold' t
-      | fold' (ICase { term = t, clauses = clauses, ... }) = fold' t
-          #> fold (fn (p, body) => fold' p #> fold' body) clauses
+      | fold' (ICase { term = t, clauses = clauses, ... }) =
+          fold' t #> fold (fn (p, body) => fold' p #> fold' body) clauses
   in fold' end;
 
 val add_constsyms = fold_constexprs (fn { sym, ... } => insert (op =) sym);
@@ -240,6 +242,9 @@
 
 fun exists_var t v = fold_varnames (fn w => fn b => v = w orelse b) t false;
 
+fun invent_params used tys =
+  (map o apfst) SOME (Name.invent_names (Name.build_context used) "a" tys);
+
 fun split_pat_abs ((NONE, ty) `|=> t) = SOME ((IVar NONE, ty), t)
   | split_pat_abs ((SOME v, ty) `|=> t) = SOME (case t
      of ICase { term = IVar (SOME w), clauses = [(p, body)], ... } =>
@@ -252,25 +257,21 @@
 val unfold_pat_abs = unfoldr split_pat_abs;
 
 fun unfold_abs_eta [] t = ([], t)
-  | unfold_abs_eta (_ :: tys) (v_ty `|=> t) =
+  | unfold_abs_eta (_ :: tys) ((v, _) `|=> t) =
       let
-        val (vs_tys, t') = unfold_abs_eta tys t;
-      in (v_ty :: vs_tys, t') end
+        val (vs, t') = unfold_abs_eta tys t;
+      in (v :: vs, t') end
   | unfold_abs_eta tys t =
       let
-        val ctxt = Name.build_context (declare_varnames t);
-        val vs_tys = (map o apfst) SOME (Name.invent_names ctxt "a" tys);
-      in (vs_tys, t `$$ map (IVar o fst) vs_tys) end;
+        val vs = map fst (invent_params (declare_varnames t) tys);
+      in (vs, t `$$ map IVar vs) end;
 
-fun eta_expand k (const as { dom = tys, ... }, ts) =
+fun eta_expand wanted (const as { dom = tys, ... }, ts) =
   let
-    val j = length ts;
-    val l = k - j;
-    val _ = if l > length tys
-      then error "Impossible eta-expansion" else ();
-    val vars = Name.build_context (fold declare_varnames ts);
-    val vs_tys = (map o apfst) SOME
-      (Name.invent_names vars "a" ((take l o drop j) tys));
+    val given = length ts;
+    val delta = wanted - given;
+    val vs_tys = invent_params (fold declare_varnames ts)
+      (((take delta o drop given) tys));
   in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end;
 
 fun map_terms_bottom_up f (t as IConst _) = f t
@@ -316,8 +317,7 @@
                 |> the_default [(pat_args, body)]
             | NONE => [(pat_args, body)])
       | distill vs_map pat_args body = [(pat_args, body)];
-    val (vTs, body) = unfold_abs_eta tys t;
-    val vs = map fst vTs;
+    val (vs, body) = unfold_abs_eta tys t;
     val vs_map =
       build (fold_index (fn (i, SOME v) => cons (v, i) | _ => I) vs);
   in distill vs_map (map IVar vs) body end;
@@ -639,7 +639,7 @@
     fun translate_classparam_instance (c, ty) =
       let
         val raw_const = Const (c, map_type_tfree (K arity_typ') ty);
-        val dom_length = length (fst (strip_type ty))
+        val dom_length = length (binder_types ty);
         val thm = Axclass.unoverload_conv ctxt (Thm.cterm_of ctxt raw_const);
         val const = (apsnd Logic.unvarifyT_global o dest_Const o snd
           o Logic.dest_equals o Thm.prop_of) thm;
@@ -766,21 +766,22 @@
               clauses = (filter_out (is_undefined_clause ctxt) o distill_clauses constrs o project_cases) ts,
               primitive = t_app `$$ ts })
       end
-and translate_app_case ctxt algbr eqngr permissive some_thm (num_args, pattern_schema) ((c, ty), ts) =
-  if length ts < num_args then
+and translate_app_case ctxt algbr eqngr permissive some_thm (wanted, pattern_schema) ((c, ty), ts) =
+  if length ts < wanted then
     let
-      val k = length ts;
-      val tys = (take (num_args - k) o drop k o fst o strip_type) ty;
-      val names = Name.build_context (ts |> (fold o fold_aterms) Term.declare_term_frees);
-      val vs = Name.invent_names names "a" tys;
+      val given = length ts;
+      val delta = wanted - given;
+      val tys = (take delta o drop given o binder_types) ty;
+      val used = Name.build_context ((fold o fold_aterms) Term.declare_term_frees ts);
+      val vs_tys = Name.invent_names used "a" tys;
     in
       fold_map (translate_typ ctxt algbr eqngr permissive) tys
-      ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs)
-      #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs tys `|==> t)
+      ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs_tys)
+      #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs_tys tys `|==> t)
     end
-  else if length ts > num_args then
-    translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take num_args ts)
-    ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop num_args ts)
+  else if length ts > wanted then
+    translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take wanted ts)
+    ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop wanted ts)
     #>> (fn (t, ts) => t `$$ ts)
   else
     translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts)