--- a/src/HOL/Tools/inductive_codegen.ML Thu Dec 20 14:57:54 2001 +0100
+++ b/src/HOL/Tools/inductive_codegen.ML Thu Dec 20 14:58:18 2001 +0100
@@ -8,6 +8,7 @@
signature INDUCTIVE_CODEGEN =
sig
+ val add : theory attribute
val setup : (theory -> theory) list
end;
@@ -16,10 +17,48 @@
open Codegen;
-exception Modes of (string * int list list) list * (string * int list list) list;
+(**** theory data ****)
+
+structure CodegenArgs =
+struct
+ val name = "HOL/inductive_codegen";
+ type T = thm list Symtab.table;
+ val empty = Symtab.empty;
+ val copy = I;
+ val prep_ext = I;
+ val merge = Symtab.merge_multi eq_thm;
+ fun print _ _ = ();
+end;
+
+structure CodegenData = TheoryDataFun(CodegenArgs);
+
+fun warn thm = warning ("InductiveCodegen: Not a proper clause:\n" ^
+ string_of_thm thm);
-datatype indprem = Prem of string * term list * term list
- | Sidecond of term;
+fun add (p as (thy, thm)) =
+ let
+ val tsig = Sign.tsig_of (sign_of thy);
+ val tab = CodegenData.get thy;
+ val matches = curry (Pattern.matches tsig o pairself concl_of);
+
+ in (case concl_of thm of
+ _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of
+ Const (s, _) => (CodegenData.put (Symtab.update ((s,
+ filter_out (matches thm) (if_none (Symtab.lookup (tab, s)) []) @
+ [thm]), tab)) thy, thm)
+ | _ => (warn thm; p))
+ | _ => (warn thm; p))
+ end handle Pattern.Pattern => (warn thm; p);
+
+fun get_clauses thy s =
+ (case Symtab.lookup (CodegenData.get thy, s) of
+ None => (case InductivePackage.get_inductive thy s of
+ None => None
+ | Some ({names, ...}, {intrs, ...}) => Some (names, intrs))
+ | Some thms => Some ([s], thms));
+
+
+(**** improper tuples ****)
fun prod_factors p (Const ("Pair", _) $ t $ u) =
p :: prod_factors (1::p) t @ prod_factors (2::p) u
@@ -30,10 +69,44 @@
split_prod (1::p) ps t @ split_prod (2::p) ps u
| _ => error "Inconsistent use of products") else [t];
+datatype factors = FVar of int list list | FFix of int list list;
+
+exception Factors;
+
+fun mg_factor (FVar f) (FVar f') = FVar (f inter f')
+ | mg_factor (FVar f) (FFix f') =
+ if f' subset f then FFix f' else raise Factors
+ | mg_factor (FFix f) (FVar f') =
+ if f subset f' then FFix f else raise Factors
+ | mg_factor (FFix f) (FFix f') =
+ if f subset f' andalso f' subset f then FFix f else raise Factors;
+
+fun dest_factors (FVar f) = f
+ | dest_factors (FFix f) = f;
+
+fun infer_factors sg extra_fs (fs, (optf, t)) =
+ let fun err s = error (s ^ "\n" ^ Sign.string_of_term sg t)
+ in (case (optf, strip_comb t) of
+ (Some f, (Const (name, _), args)) =>
+ (case assoc (extra_fs, name) of
+ None => overwrite (fs, (name, if_none
+ (apsome (mg_factor f) (assoc (fs, name))) f))
+ | Some (fs', f') => (mg_factor f (FFix f');
+ foldl (infer_factors sg extra_fs)
+ (fs, map (apsome FFix) fs' ~~ args)))
+ | (Some f, (Var ((name, _), _), [])) =>
+ overwrite (fs, (name, if_none
+ (apsome (mg_factor f) (assoc (fs, name))) f))
+ | (None, _) => fs
+ | _ => err "Illegal term")
+ handle Factors => err "Product factor mismatch in"
+ end;
+
fun string_of_factors p ps = if p mem ps then
"(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")"
else "_";
+
(**** check if a term contains only constructor functions ****)
fun is_constrt thy =
@@ -81,9 +154,32 @@
in merge (map (fn ks => i::ks) is) is end
else [[]];
+fun cprod ([], ys) = []
+ | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys);
+
+fun cprods xss = foldr (map op :: o cprod) (xss, [[]]);
+
+datatype mode = Mode of (int list option list * int list) * mode option list;
+
+fun modes_of modes t =
+ let
+ fun mk_modes name args = flat
+ (map (fn (m as (iss, is)) => map (Mode o pair m) (cprods (map
+ (fn (None, _) => [None]
+ | (Some js, arg) => map Some
+ (filter (fn Mode ((_, js'), _) => js=js') (modes_of modes arg)))
+ (iss ~~ args)))) (the (assoc (modes, name))))
+
+ in (case strip_comb t of
+ (Const (name, _), args) => mk_modes name args
+ | (Var ((name, _), _), args) => mk_modes name args)
+ end;
+
+datatype indprem = Prem of term list * term | Sidecond of term;
+
fun select_mode_prem thy modes vs ps =
find_first (is_some o snd) (ps ~~ map
- (fn Prem (s, us, args) => find_first (fn is =>
+ (fn Prem (us, t) => find_first (fn Mode ((_, is), _) =>
let
val (_, out_ts) = get_args is 1 us;
val vTs = flat (map term_vTs out_ts);
@@ -92,21 +188,25 @@
in
is subset known_args vs 1 us andalso
forall (is_constrt thy) (snd (get_args is 1 us)) andalso
- terms_vs args subset vs andalso
+ term_vs t subset vs andalso
forall is_eqT dupTs
end)
- (the (assoc (modes, s)))
- | Sidecond t => if term_vs t subset vs then Some [] else None) ps);
+ (modes_of modes t)
+ | Sidecond t => if term_vs t subset vs then Some (Mode (([], []), []))
+ else None) ps);
-fun check_mode_clause thy arg_vs modes mode (ts, ps) =
+fun check_mode_clause thy arg_vs modes (iss, is) (ts, ps) =
let
+ val modes' = modes @ mapfilter
+ (fn (_, None) => None | (v, Some js) => Some (v, [([], js)]))
+ (arg_vs ~~ iss);
fun check_mode_prems vs [] = Some vs
- | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of
+ | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of
None => None
| Some (x, _) => check_mode_prems
- (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs)
+ (case x of Prem (us, _) => vs union terms_vs us | _ => vs)
(filter_out (equal x) ps));
- val (in_ts', _) = get_args mode 1 ts;
+ val (in_ts', _) = get_args is 1 ts;
val in_ts = filter (is_constrt thy) in_ts';
val in_vs = terms_vs in_ts;
val concl_vs = terms_vs ts
@@ -125,9 +225,12 @@
let val y = f x
in if x = y then x else fixp f y end;
-fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes =>
+fun infer_modes thy extra_modes factors arg_vs preds = fixp (fn modes =>
map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes)
- (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds);
+ (map (fn (s, (fs, f)) => (s, cprod (cprods (map
+ (fn None => [None]
+ | Some f' => map Some (subsets 1 (length f' + 1))) fs),
+ subsets 1 (length f + 1)))) factors);
(**** code generation ****)
@@ -167,17 +270,37 @@
[Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"]))
end;
-fun modename thy s mode = space_implode "_"
- (mk_const_id (sign_of thy) s :: map string_of_int mode);
+fun modename thy s (iss, is) = space_implode "__"
+ (mk_const_id (sign_of thy) s ::
+ map (space_implode "_" o map string_of_int) (mapfilter I iss @ [is]));
-fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) =
+fun compile_expr thy dep brack (gr, (None, t)) =
+ apsnd single (invoke_codegen thy dep brack (gr, t))
+ | compile_expr _ _ _ (gr, (Some _, Var ((name, _), _))) =
+ (gr, [Pretty.str name])
+ | compile_expr thy dep brack (gr, (Some (Mode (mode, ms)), t)) =
+ let
+ val (Const (name, _), args) = strip_comb t;
+ val (gr', ps) = foldl_map
+ (compile_expr thy dep true) (gr, ms ~~ args);
+ in (gr', (if brack andalso not (null ps) then
+ single o parens o Pretty.block else I)
+ (flat (separate [Pretty.brk 1]
+ ([Pretty.str (modename thy name mode)] :: ps))))
+ end;
+
+fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) =
let
+ val modes' = modes @ mapfilter
+ (fn (_, None) => None | (v, Some js) => Some (v, [([], js)]))
+ (arg_vs ~~ iss);
+
fun check_constrt ((names, eqs), t) =
if is_constrt thy t then ((names, eqs), t) else
let val s = variant names "x";
in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end;
- val (in_ts, out_ts) = get_args mode 1 ts;
+ val (in_ts, out_ts) = get_args is 1 ts;
val ((all_vs', eqs), in_ts') =
foldl_map check_constrt ((all_vs, []), in_ts);
@@ -200,27 +323,25 @@
| compile_prems out_ts vs names gr ps =
let
val vs' = distinct (flat (vs :: map term_vs out_ts));
- val Some (p, Some mode') =
- select_mode_prem thy modes (arg_vs union vs') ps;
+ val Some (p, mode as Some (Mode ((_, js), _))) =
+ select_mode_prem thy modes' (arg_vs union vs') ps;
val ps' = filter_out (equal p) ps;
in
(case p of
- Prem (s, us, args) =>
+ Prem (us, t) =>
let
- val (in_ts, out_ts') = get_args mode' 1 us;
+ val (in_ts, out_ts') = get_args js 1 us;
val (gr1, in_ps) = foldl_map
(invoke_codegen thy dep false) (gr, in_ts);
- val (gr2, arg_ps) = foldl_map
- (invoke_codegen thy dep true) (gr1, args);
val (nvs, out_ts'') = foldl_map distinct_v
((names, map (fn x => (x, [x])) vs), out_ts);
- val (gr3, out_ps) = foldl_map
- (invoke_codegen thy dep false) (gr2, out_ts'')
+ val (gr2, out_ps) = foldl_map
+ (invoke_codegen thy dep false) (gr1, out_ts'');
+ val (gr3, ps) = compile_expr thy dep false (gr2, (mode, t));
val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps';
in
(gr4, compile_match (snd nvs) [] out_ps
- (Pretty.block (separate (Pretty.brk 1)
- (Pretty.str (modename thy s mode') :: arg_ps) @
+ (Pretty.block (ps @
[Pretty.brk 1, mk_tuple in_ps,
Pretty.str " :->", Pretty.brk 1, rest]))
(Pretty.str "Seq.empty"))
@@ -269,69 +390,91 @@
(**** processing of introduction rules ****)
-val string_of_mode = enclose "[" "]" o commas o map string_of_int;
+exception Modes of
+ (string * (int list option list * int list) list) list *
+ (string * (int list list option list * int list list)) list;
+
+fun lookup_modes gr dep = apfst flat (apsnd flat (ListPair.unzip
+ (map ((fn (Some (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr)
+ (Graph.all_preds gr [dep]))));
+
+fun string_of_mode (iss, is) = space_implode " -> " (map
+ (fn None => "X"
+ | Some js => enclose "[" "]" (commas (map string_of_int js)))
+ (iss @ [Some is]));
fun print_modes modes = message ("Inferred modes:\n" ^
space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map
string_of_mode ms)) modes));
fun print_factors factors = message ("Factors:\n" ^
- space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors));
-
-fun get_modes (Some (Modes x), _) = x
- | get_modes _ = ([], []);
+ space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^
+ space_implode " -> " (map
+ (fn None => "X" | Some f' => string_of_factors [] f')
+ (fs @ [Some f]))) factors));
-fun mk_ind_def thy gr dep names intrs =
+fun mk_extra_defs thy gr dep names ts =
+ foldl (fn (gr, name) =>
+ if name mem names then gr
+ else (case get_clauses thy name of
+ None => gr
+ | Some (names, intrs) =>
+ mk_ind_def thy gr dep names intrs))
+ (gr, foldr add_term_consts (ts, []))
+
+and mk_ind_def thy gr dep names intrs =
let val ids = map (mk_const_id (sign_of thy)) names
in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ =>
let
- fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) =
- (case strip_comb u of
- (Const (name, _), args) =>
- (case InductivePackage.get_inductive thy name of
- None => (gr, Sidecond t')
- | Some ({names=names', ...}, {intrs=intrs', ...}) =>
- (if names = names' then gr
- else mk_ind_def thy gr (hd ids) names' intrs',
- Prem (name, split_prod []
- (the (assoc (factors, name))) t, args)))
- | _ => (gr, Sidecond t'))
- | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) =
- (gr, Prem ("eq", [t, u], []))
- | process_prem factors (gr, _ $ t) = (gr, Sidecond t);
+ fun dest_prem factors (_ $ (Const ("op :", _) $ t $ u)) =
+ (case head_of u of
+ Const (name, _) => Prem (split_prod []
+ (the (assoc (factors, name))) t, u)
+ | Var ((name, _), _) => Prem (split_prod []
+ (the (assoc (factors, name))) t, u))
+ | dest_prem factors (_ $ ((eq as Const ("op =", _)) $ t $ u)) =
+ Prem ([t, u], eq)
+ | dest_prem factors (_ $ t) = Sidecond t;
- fun add_clause factors ((clauses, gr), intr) =
+ fun add_clause factors (clauses, intr) =
let
val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr;
- val (Const (name, _), args) = strip_comb u;
- val (gr', prems) = foldl_map (process_prem factors)
- (gr, Logic.strip_imp_prems intr);
+ val Const (name, _) = head_of u;
+ val prems = map (dest_prem factors) (Logic.strip_imp_prems intr);
in
(overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @
- [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr')
+ [(split_prod [] (the (assoc (factors, name))) t, prems)])))
end;
- fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) =
- (case strip_comb u of
- (Const (name, _), _) =>
- let val f = prod_factors [] t
- in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end
- | _ => fs)
- | add_prod_factors (fs, _) = fs;
+ fun add_prod_factors extra_fs (fs, _ $ (Const ("op :", _) $ t $ u)) =
+ infer_factors (sign_of thy) extra_fs
+ (fs, (Some (FVar (prod_factors [] t)), u))
+ | add_prod_factors _ (fs, _) = fs;
val intrs' = map (rename_term o #prop o rep_thm o standard) intrs;
- val factors = foldl add_prod_factors ([], flat (map (fn t =>
- Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'));
- val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep)
- (Graph.new_node (hd ids, (None, "")) gr)), intrs');
val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs');
val (_, args) = strip_comb u;
val arg_vs = flat (map term_vs args);
- val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map
- (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids])));
- val modes = infer_modes thy extra_modes arg_vs clauses;
+ val gr' = mk_extra_defs thy
+ (Graph.add_edge (hd ids, dep)
+ (Graph.new_node (hd ids, (None, "")) gr)) (hd ids) names intrs';
+ val (extra_modes', extra_factors) = lookup_modes gr' (hd ids);
+ val extra_modes =
+ ("op =", [([], [1]), ([], [2]), ([], [1, 2])]) :: extra_modes';
+ val fs = map (apsnd dest_factors)
+ (foldl (add_prod_factors extra_factors) ([], flat (map (fn t =>
+ Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs')));
+ val _ = (case map fst fs \\ names \\ arg_vs of
+ [] => ()
+ | xs => error ("Non-inductive sets: " ^ commas_quote xs));
+ val factors = mapfilter (fn (name, f) =>
+ if name mem arg_vs then None
+ else Some (name, (map (curry assoc fs) arg_vs, f))) fs;
+ val clauses =
+ foldl (add_clause (fs @ map (apsnd snd) extra_factors)) ([], intrs');
+ val modes = infer_modes thy extra_modes factors arg_vs clauses;
+ val _ = print_factors factors;
val _ = print_modes modes;
- val _ = print_factors factors;
val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs
(modes @ extra_modes) clauses;
in
@@ -339,31 +482,33 @@
end
end;
-fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of
- (Const (s, _), args) => (case InductivePackage.get_inductive thy s of
+fun mk_ind_call thy gr dep t u is_query = (case head_of u of
+ Const (s, _) => (case get_clauses thy s of
None => None
- | Some ({names, ...}, {intrs, ...}) =>
+ | Some (names, intrs) =>
let
fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1)
| mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1)
| mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1);
- val gr1 = mk_ind_def thy gr dep names intrs;
- val (modes, factors) = pairself flat (ListPair.unzip
- (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep])));
- val ts = split_prod [] (the (assoc (factors, s))) t;
- val (ts', mode) = if is_query then
+ val gr1 = mk_extra_defs thy
+ (mk_ind_def thy gr dep names intrs) dep names [u];
+ val (modes, factors) = lookup_modes gr1 dep;
+ val ts = split_prod [] (snd (the (assoc (factors, s)))) t;
+ val (ts', is) = if is_query then
fst (foldl mk_mode ((([], []), 1), ts))
else (ts, 1 upto length ts);
- val _ = if mode mem the (assoc (modes, s)) then () else
- error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode);
+ val mode = (case find_first (fn Mode ((_, js), _) => is=js)
+ (modes_of modes u) of
+ None => error ("No such mode for " ^ s ^ ": " ^
+ string_of_mode ([], is))
+ | mode => mode);
val (gr2, in_ps) = foldl_map
(invoke_codegen thy dep false) (gr1, ts');
- val (gr3, arg_ps) = foldl_map
- (invoke_codegen thy dep true) (gr2, args);
+ val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u))
in
- Some (gr3, Pretty.block (separate (Pretty.brk 1)
- (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps])))
+ Some (gr3, Pretty.block
+ (ps @ [Pretty.brk 1, mk_tuple in_ps]))
end)
| _ => None);
@@ -376,7 +521,10 @@
mk_ind_call thy gr dep t u true
| inductive_codegen thy gr dep brack _ = None;
-val setup = [add_codegen "inductive" inductive_codegen];
+val setup =
+ [add_codegen "inductive" inductive_codegen,
+ CodegenData.init,
+ add_attribute "ind" add];
end;
@@ -394,6 +542,8 @@
fun ?! s = is_some (Seq.pull s);
-fun eq_1 x = Seq.single x;
+fun op__61__1 x = Seq.single x;
-val eq_2 = eq_1;
+val op__61__2 = op__61__1;
+
+fun op__61__1_2 (x, y) = ?? (x = y);