--- a/src/Tools/Code/code_thingol.ML Thu Mar 24 18:50:11 2022 +0000
+++ b/src/Tools/Code/code_thingol.ML Thu Mar 24 22:21:24 2022 +0000
@@ -53,6 +53,7 @@
val split_pat_abs: iterm -> ((iterm * itype) * iterm) option
val unfold_pat_abs: iterm -> (iterm * itype) list * iterm
val unfold_const_app: iterm -> (const * iterm list) option
+ val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
val is_IVar: iterm -> bool
val is_IAbs: iterm -> bool
val eta_expand: int -> const * iterm list -> iterm
@@ -78,7 +79,6 @@
type program = stmt Code_Symbol.Graph.T
val unimplemented: program -> string list
val implemented_deps: program -> string list
- val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm
val map_terms_stmt: (iterm -> iterm) -> stmt -> stmt
val is_constr: program -> Code_Symbol.T -> bool
val is_case: stmt -> bool
@@ -268,6 +268,46 @@
(Name.invent_names vars "a" ((take l o drop j) tys));
in vs_tys `|==> IConst const `$$ ts @ map (IVar o fst) vs_tys end;
+fun map_terms_bottom_up f (t as IConst _) = f t
+ | map_terms_bottom_up f (t as IVar _) = f t
+ | map_terms_bottom_up f (t1 `$ t2) = f
+ (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2)
+ | map_terms_bottom_up f ((v, ty) `|=> t) = f
+ ((v, ty) `|=> map_terms_bottom_up f t)
+ | map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f
+ (ICase { term = map_terms_bottom_up f t, typ = ty,
+ clauses = (map o apply2) (map_terms_bottom_up f) clauses,
+ primitive = map_terms_bottom_up f t0 });
+
+fun adjungate_clause ctxt vs_map ts (body as IConst { sym = Constant c, ... }) =
+ if Code.is_undefined (Proof_Context.theory_of ctxt) c
+ then []
+ else [(ts, body)]
+ | adjungate_clause ctxt vs_map ts (body as ICase { term = IVar (SOME v), clauses = clauses, ... }) =
+ let
+ val vs = (fold o fold_varnames) (insert (op =)) ts [];
+ fun varnames_disjunctive pat =
+ null (inter (op =) vs (fold_varnames (insert (op =)) pat []));
+ fun purge_unused_vars_in t =
+ let
+ val vs = fold_varnames (insert (op =)) t [];
+ in
+ map_terms_bottom_up (fn IVar (SOME v) =>
+ IVar (if member (op =) vs v then SOME v else NONE) | t => t)
+ end;
+ in
+ if forall (fn (pat', body') => exists_var pat' v
+ orelse not (exists_var body' v)) clauses
+ andalso forall (varnames_disjunctive o fst) clauses
+ then case AList.lookup (op =) vs_map v
+ of SOME i => maps (fn (pat', body') =>
+ adjungate_clause ctxt (AList.delete (op =) v vs_map)
+ (nth_map i (K pat') ts |> map (purge_unused_vars_in body')) body') clauses
+ | NONE => [(ts, body)]
+ else [(ts, body)]
+ end
+ | adjungate_clause ctxt vs_map ts body = [(ts, body)];
+
fun exists_dict_var f (Dict (_, d)) = exists_plain_dict_var_pred f d
and exists_plain_dict_var_pred f (Dict_Const (_, dss)) = exists_dictss_var f dss
| exists_plain_dict_var_pred f (Dict_Var x) = f x
@@ -308,17 +348,6 @@
|> subtract (op =) (Code_Symbol.Graph.all_preds program (map Constant (unimplemented program)))
|> map_filter (fn Constant c => SOME c | _ => NONE);
-fun map_terms_bottom_up f (t as IConst _) = f t
- | map_terms_bottom_up f (t as IVar _) = f t
- | map_terms_bottom_up f (t1 `$ t2) = f
- (map_terms_bottom_up f t1 `$ map_terms_bottom_up f t2)
- | map_terms_bottom_up f ((v, ty) `|=> t) = f
- ((v, ty) `|=> map_terms_bottom_up f t)
- | map_terms_bottom_up f (ICase { term = t, typ = ty, clauses = clauses, primitive = t0 }) = f
- (ICase { term = map_terms_bottom_up f t, typ = ty,
- clauses = (map o apply2) (map_terms_bottom_up f) clauses,
- primitive = map_terms_bottom_up f t0 });
-
fun map_classparam_instances_as_term f =
(map o apfst o apsnd o apfst) (fn const => case f (IConst const) of IConst const' => const')
@@ -473,6 +502,31 @@
(annotate ctxt' algbr eqngr (c, ty) args rhs, some_abs)))) eqns
end;
+
+(* preprocessing pattern schemas *)
+
+fun preprocess_pattern_schema ctxt (t_pos, case_pats) (c_ty, ts) =
+ let
+ val thy = Proof_Context.theory_of ctxt;
+ val ty = nth (binder_types (snd c_ty)) t_pos;
+ fun select_clauses xs =
+ xs
+ |> nth_drop t_pos
+ |> curry (op ~~) case_pats
+ |> map_filter (fn (NONE, _) => NONE | (SOME _, x) => SOME x);
+ fun mk_constr c t =
+ let
+ val n = Code.args_number thy c;
+ in ((c, (take n o binder_types o fastype_of o untag_term) t ---> ty), n) end;
+ val constrs =
+ if null case_pats then []
+ else map2 mk_constr (case_pats |> map_filter I) (select_clauses ts);
+ val split_clauses =
+ if null case_pats then (fn ts => (nth ts t_pos, nth_drop t_pos ts))
+ else (fn ts => (nth ts t_pos, select_clauses ts));
+ in (ty, constrs, split_clauses) end;
+
+
(* abstract dictionary construction *)
datatype typarg_witness =
@@ -673,74 +727,33 @@
translate_const ctxt algbr eqngr permissive some_thm (c_ty, some_abs)
##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
#>> (fn (t, ts) => t `$$ ts)
-and translate_case ctxt algbr eqngr permissive some_thm (num_args, (t_pos, case_pats)) (c_ty, ts) =
+and translate_case ctxt algbr eqngr permissive some_thm (t_pos, case_pats) (c_ty, ts) =
let
- val thy = Proof_Context.theory_of ctxt;
- fun arg_types num_args ty = fst (chop num_args (binder_types ty));
- val tys = arg_types num_args (snd c_ty);
- val ty = nth tys t_pos;
- fun mk_constr NONE t = NONE
- | mk_constr (SOME c) t =
- let
- val n = Code.args_number thy c;
- in SOME ((c, arg_types n (fastype_of (untag_term t)) ---> ty), n) end;
- val constrs =
- if null case_pats then []
- else map_filter I (map2 mk_constr case_pats (nth_drop t_pos ts));
- fun disjunctive_varnames ts =
- let
- val vs = (fold o fold_varnames) (insert (op =)) ts [];
- in fn pat => null (inter (op =) vs (fold_varnames (insert (op =)) pat [])) end;
- fun purge_unused_vars_in t =
- let
- val vs = fold_varnames (insert (op =)) t [];
- in
- map_terms_bottom_up (fn IVar (SOME v) =>
- IVar (if member (op =) vs v then SOME v else NONE) | t => t)
- end;
- fun collapse_clause vs_map ts body =
- case body
- of IConst { sym = Constant c, ... } => if Code.is_undefined thy c
- then []
- else [(ts, body)]
- | ICase { term = IVar (SOME v), clauses = clauses, ... } =>
- if forall (fn (pat', body') => exists_var pat' v
- orelse not (exists_var body' v)) clauses
- andalso forall (disjunctive_varnames ts o fst) clauses
- then case AList.lookup (op =) vs_map v
- of SOME i => maps (fn (pat', body') =>
- collapse_clause (AList.delete (op =) v vs_map)
- (nth_map i (K pat') ts |> map (purge_unused_vars_in body')) body') clauses
- | NONE => [(ts, body)]
- else [(ts, body)]
- | _ => [(ts, body)];
- fun mk_clause mk tys t =
+ val (ty, constrs, split_clauses) =
+ preprocess_pattern_schema ctxt (t_pos, case_pats) (c_ty, ts);
+ fun distill_clause tys t =
let
val (vs, body) = unfold_abs_eta tys t;
val vs_map = fold_index (fn (i, (SOME v, _)) => cons (v, i) | _ => I) vs [];
val ts = map (IVar o fst) vs;
- in map mk (collapse_clause vs_map ts body) end;
- fun casify constrs ty t_app ts =
- let
- val t = nth ts t_pos;
- val ts_clause = nth_drop t_pos ts;
- val clauses = if null case_pats
- then mk_clause (fn ([t], body) => (t, body)) [ty] (the_single ts_clause)
- else maps (fn ((constr as IConst { dom = tys, ... }, n), t) =>
- mk_clause (fn (ts, body) => (constr `$$ ts, body)) (take n tys) t)
- (constrs ~~ (map_filter (fn (NONE, _) => NONE | (SOME _, t) => SOME t)
- (case_pats ~~ ts_clause)));
- in ICase { term = t, typ = ty, clauses = clauses, primitive = t_app `$$ ts } end;
+ in adjungate_clause ctxt vs_map ts body end;
+ fun mk_clauses [] ty (t, ts_clause) =
+ (t, map (fn ([t], body) => (t, body)) (distill_clause [ty] (the_single ts_clause)))
+ | mk_clauses constrs ty (t, ts_clause) =
+ (t, maps (fn ((constr as IConst { dom = tys, ... }, n), t) =>
+ map (fn (ts, body) => (constr `$$ ts, body)) (distill_clause (take n tys) t))
+ (constrs ~~ ts_clause));
in
translate_const ctxt algbr eqngr permissive some_thm (c_ty, NONE)
##>> fold_map (fn (constr, n) => translate_const ctxt algbr eqngr permissive some_thm (constr, NONE)
#>> rpair n) constrs
##>> translate_typ ctxt algbr eqngr permissive ty
##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) ts
- #>> (fn (((t, constrs), ty), ts) =>
- casify constrs ty t ts)
+ #>> (fn (((t_app, constrs), ty), ts) =>
+ case mk_clauses constrs ty (split_clauses ts) of (t, clauses) =>
+ ICase { term = t, typ = ty, clauses = clauses, primitive = t_app `$$ ts })
end
-and translate_app_case ctxt algbr eqngr permissive some_thm (case_schema as (num_args, _)) ((c, ty), ts) =
+and translate_app_case ctxt algbr eqngr permissive some_thm (num_args, pattern_schema) ((c, ty), ts) =
if length ts < num_args then
let
val k = length ts;
@@ -749,15 +762,15 @@
val vs = Name.invent_names names "a" tys;
in
fold_map (translate_typ ctxt algbr eqngr permissive) tys
- ##>> translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), ts @ map Free vs)
+ ##>> translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts @ map Free vs)
#>> (fn (tys, t) => map2 (fn (v, _) => pair (SOME v)) vs tys `|==> t)
end
else if length ts > num_args then
- translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), take num_args ts)
+ translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), take num_args ts)
##>> fold_map (translate_term ctxt algbr eqngr permissive some_thm o rpair NONE) (drop num_args ts)
#>> (fn (t, ts) => t `$$ ts)
else
- translate_case ctxt algbr eqngr permissive some_thm case_schema ((c, ty), ts)
+ translate_case ctxt algbr eqngr permissive some_thm pattern_schema ((c, ty), ts)
and translate_app ctxt algbr eqngr permissive some_thm (c_ty_ts as ((c, _), _), some_abs) =
case Code.get_case_schema (Proof_Context.theory_of ctxt) c
of SOME case_schema => translate_app_case ctxt algbr eqngr permissive some_thm case_schema c_ty_ts