merged
authorpaulson
Thu, 24 Mar 2022 22:21:24 +0000
changeset 75330 bcb7d5f1f535
parent 75326 89d975dd39f1 (diff)
parent 75329 1fb80d2a778d (current diff)
child 75331 c3f1bf2824bc
merged
--- a/src/Tools/Code/code_thingol.ML	Thu Mar 24 18:50:11 2022 +0000
+++ b/src/Tools/Code/code_thingol.ML	Thu Mar 24 22:21:24 2022 +0000
@@ -53,6 +53,7 @@
   val split_pat_abs: iterm -> ((iterm * itype) * iterm) option
   val unfold_pat_abs: iterm -> (iterm * itype) list * iterm
   val unfold_const_app: iterm -> (const * iterm list) option
+  val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
   val is_IVar: iterm -> bool
   val is_IAbs: iterm -> bool
   val eta_expand: int -> const * iterm list -> iterm
@@ -78,7 +79,6 @@
   type program = stmt Code_Symbol.Graph.T
   val unimplemented: program -> string list
   val implemented_deps: program -> string list
-  val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
   val map_terms_stmt: (iterm -> iterm) -> stmt -> stmt
   val is_constr: program -> Code_Symbol.T -> bool
   val is_case: stmt -> bool
@@ -268,6 +268,46 @@
       (Name.invent_names vars "a" ((take l o drop j) tys));
   in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end;
 
+fun map_terms_bottom_up f (t as IConst _) = f t
+  | map_terms_bottom_up f (t as IVar _) = f t
+  | map_terms_bottom_up f (t1 `$ t2) = f
+      (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2)
+  | map_terms_bottom_up f ((v, ty) `|=> t) = f
+      ((v, ty) `|=> map_terms_bottom_up f t)
+  | map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f
+      (ICase { term = map_terms_bottom_up f t, typ = ty,
+        clauses = (map o apply2) (map_terms_bottom_up f) clauses,
+        primitive = map_terms_bottom_up f t0 });
+
+fun adjungate_clause ctxt vs_map ts (body as IConst { sym = Constant c, ... }) =
+      if Code.is_undefined (Proof_Context.theory_of ctxt) c
+      then []
+      else [(ts, body)]
+  | adjungate_clause ctxt vs_map ts (body as ICase { term = IVar (SOME v), clauses = clauses, ... }) =
+      let
+        val vs = (fold o fold_varnames) (insert (op =)) ts [];
+        fun varnames_disjunctive pat =
+          null (inter (op =) vs (fold_varnames (insert (op =)) pat []));
+        fun purge_unused_vars_in t =
+          let
+            val vs = fold_varnames (insert (op =)) t [];
+          in
+            map_terms_bottom_up (fn IVar (SOME v) =>
+              IVar (if member (op =) vs v then SOME v else NONE) | t => t)
+          end;
+      in
+        if forall (fn (pat', body') => exists_var pat' v
+          orelse not (exists_var body' v)) clauses
+          andalso forall (varnames_disjunctive o fst) clauses
+        then case AList.lookup (op =) vs_map v
+         of SOME i => maps (fn (pat', body') =>
+              adjungate_clause ctxt (AList.delete (op =) v vs_map)
+                (nth_map i (K pat') ts |> map (purge_unused_vars_in body')) body') clauses
+          | NONE => [(ts, body)]
+        else [(ts, body)]
+      end
+  | adjungate_clause ctxt vs_map ts body = [(ts, body)];
+
 fun exists_dict_var f (Dict (_, d)) = exists_plain_dict_var_pred f d
 and exists_plain_dict_var_pred f (Dict_Const (_, dss)) = exists_dictss_var f dss
   | exists_plain_dict_var_pred f (Dict_Var x) = f x
@@ -308,17 +348,6 @@
   |> subtract (op =) (Code_Symbol.Graph.all_preds program (map Constant (unimplemented program)))
   |> map_filter (fn Constant c => SOME c | _ => NONE);
 
-fun map_terms_bottom_up f (t as IConst _) = f t
-  | map_terms_bottom_up f (t as IVar _) = f t
-  | map_terms_bottom_up f (t1 `$ t2) = f
-      (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2)
-  | map_terms_bottom_up f ((v, ty) `|=> t) = f
-      ((v, ty) `|=> map_terms_bottom_up f t)
-  | map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f
-      (ICase { term = map_terms_bottom_up f t, typ = ty,
-        clauses = (map o apply2) (map_terms_bottom_up f) clauses,
-        primitive = map_terms_bottom_up f t0 });
-
 fun map_classparam_instances_as_term f =
   (map o apfst o apsnd o apfst) (fn const => case f (IConst const) of IConst const' => const')
 
@@ -473,6 +502,31 @@
       (annotate ctxt' algbr eqngr (c, ty) args rhs, some_abs)))) eqns
   end;
 
+
+(* preprocessing pattern schemas *)
+
+fun preprocess_pattern_schema ctxt (t_pos, case_pats) (c_ty, ts) =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    val ty = nth (binder_types (snd c_ty)) t_pos;
+    fun select_clauses xs =
+      xs
+      |> nth_drop t_pos
+      |> curry (op ~~) case_pats
+      |> map_filter (fn (NONE, _) => NONE | (SOME _, x) => SOME x);
+    fun mk_constr c t =
+      let
+        val n = Code.args_number thy c;
+      in ((c, (take n o binder_types o fastype_of o untag_term) t ---> ty), n) end;
+    val constrs =
+      if null case_pats then []
+      else map2 mk_constr (case_pats |> map_filter I) (select_clauses ts);
+    val split_clauses =
+      if null case_pats then (fn ts => (nth ts t_pos, nth_drop t_pos ts))
+        else (fn ts => (nth ts t_pos, select_clauses ts));
+  in (ty, constrs, split_clauses) end;
+
+
 (* abstract dictionary construction *)
 
 datatype typarg_witness =
@@ -673,74 +727,33 @@
   translate_const ctxt algbr eqngr permissive some_thm (c_ty, some_abs)
   ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
   #>> (fn (t, ts) => t `$$ ts)
-and translate_case ctxt algbr eqngr permissive some_thm (num_args, (t_pos, case_pats)) (c_ty, ts) =
+and translate_case ctxt algbr eqngr permissive some_thm (t_pos, case_pats) (c_ty, ts) =
   let
-    val thy = Proof_Context.theory_of ctxt;
-    fun arg_types num_args ty = fst (chop num_args (binder_types ty));
-    val tys = arg_types num_args (snd c_ty);
-    val ty = nth tys t_pos;
-    fun mk_constr NONE t = NONE
-      | mk_constr (SOME c) t =
-          let
-            val n = Code.args_number thy c;
-          in SOME ((c, arg_types n (fastype_of (untag_term t)) ---> ty), n) end;
-    val constrs =
-      if null case_pats then []
-      else map_filter I (map2 mk_constr case_pats (nth_drop t_pos ts));
-    fun disjunctive_varnames ts =
-      let
-        val vs = (fold o fold_varnames) (insert (op =)) ts [];
-      in fn pat => null (inter (op =) vs (fold_varnames (insert (op =)) pat [])) end;
-    fun purge_unused_vars_in t =
-      let
-        val vs = fold_varnames (insert (op =)) t [];
-      in
-        map_terms_bottom_up (fn IVar (SOME v) =>
-          IVar (if member (op =) vs v then SOME v else NONE) | t => t)
-      end;
-    fun collapse_clause vs_map ts body =
-      case body
-       of IConst { sym = Constant c, ... } => if Code.is_undefined thy c
-            then []
-            else [(ts, body)]
-        | ICase { term = IVar (SOME v), clauses = clauses, ... } =>
-            if forall (fn (pat', body') => exists_var pat' v
-              orelse not (exists_var body' v)) clauses
-              andalso forall (disjunctive_varnames ts o fst) clauses
-            then case AList.lookup (op =) vs_map v
-             of SOME i => maps (fn (pat', body') =>
-                  collapse_clause (AList.delete (op =) v vs_map)
-                    (nth_map i (K pat') ts |> map (purge_unused_vars_in body')) body') clauses
-              | NONE => [(ts, body)]
-            else [(ts, body)]
-        | _ => [(ts, body)];
-    fun mk_clause mk tys t =
+    val (ty, constrs, split_clauses) =
+      preprocess_pattern_schema ctxt (t_pos, case_pats) (c_ty, ts);
+    fun distill_clause tys t =
       let
         val (vs, body) = unfold_abs_eta tys t;
         val vs_map = fold_index (fn (i, (SOME v, _)) => cons (v, i) | _ => I) vs [];
         val ts = map (IVar o fst) vs;
-      in map mk (collapse_clause vs_map ts body) end;
-    fun casify constrs ty t_app ts =
-      let
-        val t = nth ts t_pos;
-        val ts_clause = nth_drop t_pos ts;
-        val clauses = if null case_pats
-          then mk_clause (fn ([t], body) => (t, body)) [ty] (the_single ts_clause)
-          else maps (fn ((constr as IConst { dom = tys, ... }, n), t) =>
-            mk_clause (fn (ts, body) => (constr `$$ ts, body)) (take n tys) t)
-              (constrs ~~ (map_filter (fn (NONE, _) => NONE | (SOME _, t) => SOME t)
-                (case_pats ~~ ts_clause)));
-      in ICase { term = t, typ = ty, clauses = clauses, primitive = t_app `$$ ts } end;
+      in adjungate_clause ctxt vs_map ts body end;
+    fun mk_clauses [] ty (t, ts_clause) =
+          (t, map (fn ([t], body) => (t, body)) (distill_clause [ty] (the_single ts_clause)))
+      | mk_clauses constrs ty (t, ts_clause) =
+          (t, maps (fn ((constr as IConst { dom = tys, ... }, n), t) =>
+            map (fn (ts, body) => (constr `$$ ts, body)) (distill_clause (take n tys) t))
+              (constrs ~~ ts_clause));
   in
     translate_const ctxt algbr eqngr permissive some_thm (c_ty, NONE)
     ##>> fold_map (fn (constr, n) => translate_const ctxt algbr eqngr permissive some_thm (constr, NONE)
       #>> rpair n) constrs
     ##>> translate_typ ctxt algbr eqngr permissive ty
     ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
-    #>> (fn (((t, constrs), ty), ts) =>
-      casify constrs ty t ts)
+    #>> (fn (((t_app, constrs), ty), ts) =>
+      case mk_clauses constrs ty (split_clauses ts) of (t, clauses) =>
+        ICase { term = t, typ = ty, clauses = clauses, primitive = t_app `$$ ts })
   end
-and translate_app_case ctxt algbr eqngr permissive some_thm (case_schema as (num_args, _)) ((c, ty), ts) =
+and translate_app_case ctxt algbr eqngr permissive some_thm (num_args, pattern_schema) ((c, ty), ts) =
   if length ts < num_args then
     let
       val k = length ts;
@@ -749,15 +762,15 @@
       val vs = Name.invent_names names "a" tys;
     in
       fold_map (translate_typ ctxt algbr eqngr permissive) tys
-      ##>> translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), ts @ map Free vs)
+      ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs)
       #>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs tys `|==> t)
     end
   else if length ts > num_args then
-    translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), take num_args ts)
+    translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take num_args ts)
     ##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop num_args ts)
     #>> (fn (t, ts) => t `$$ ts)
   else
-    translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), ts)
+    translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts)
 and translate_app ctxt algbr eqngr permissive some_thm (c_ty_ts as ((c, _), _), some_abs) =
   case Code.get_case_schema (Proof_Context.theory_of ctxt) c
    of SOME case_schema => translate_app_case ctxt algbr eqngr permissive some_thm case_schema c_ty_ts