# HG changeset patch # User berghofe # Date 1008856698 -3600 # Node ID bb2e4689347e8d139660e60896e9437585ddc8e1 # Parent 5e2593ef0a4499a36274fa750d11227d17ecec22 Implemented higher order modes. diff -r 5e2593ef0a44 -r bb2e4689347e src/HOL/Tools/inductive_codegen.ML --- a/src/HOL/Tools/inductive_codegen.ML Thu Dec 20 14:57:54 2001 +0100 +++ b/src/HOL/Tools/inductive_codegen.ML Thu Dec 20 14:58:18 2001 +0100 @@ -8,6 +8,7 @@ signature INDUCTIVE_CODEGEN = sig + val add : theory attribute val setup : (theory -> theory) list end; @@ -16,10 +17,48 @@ open Codegen; -exception Modes of (string * int list list) list * (string * int list list) list; +(**** theory data ****) + +structure CodegenArgs = +struct + val name = "HOL/inductive_codegen"; + type T = thm list Symtab.table; + val empty = Symtab.empty; + val copy = I; + val prep_ext = I; + val merge = Symtab.merge_multi eq_thm; + fun print _ _ = (); +end; + +structure CodegenData = TheoryDataFun(CodegenArgs); + +fun warn thm = warning ("InductiveCodegen: Not a proper clause:\n" ^ + string_of_thm thm); -datatype indprem = Prem of string * term list * term list - | Sidecond of term; +fun add (p as (thy, thm)) = + let + val tsig = Sign.tsig_of (sign_of thy); + val tab = CodegenData.get thy; + val matches = curry (Pattern.matches tsig o pairself concl_of); + + in (case concl_of thm of + _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of + Const (s, _) => (CodegenData.put (Symtab.update ((s, + filter_out (matches thm) (if_none (Symtab.lookup (tab, s)) []) @ + [thm]), tab)) thy, thm) + | _ => (warn thm; p)) + | _ => (warn thm; p)) + end handle Pattern.Pattern => (warn thm; p); + +fun get_clauses thy s = + (case Symtab.lookup (CodegenData.get thy, s) of + None => (case InductivePackage.get_inductive thy s of + None => None + | Some ({names, ...}, {intrs, ...}) => Some (names, intrs)) + | Some thms => Some ([s], thms)); + + +(**** improper tuples ****) fun prod_factors p (Const ("Pair", _) $ t $ u) = p :: prod_factors (1::p) t @ prod_factors (2::p) u @@ -30,10 +69,44 @@ split_prod (1::p) ps t @ split_prod (2::p) ps u | _ => error "Inconsistent use of products") else [t]; +datatype factors = FVar of int list list | FFix of int list list; + +exception Factors; + +fun mg_factor (FVar f) (FVar f') = FVar (f inter f') + | mg_factor (FVar f) (FFix f') = + if f' subset f then FFix f' else raise Factors + | mg_factor (FFix f) (FVar f') = + if f subset f' then FFix f else raise Factors + | mg_factor (FFix f) (FFix f') = + if f subset f' andalso f' subset f then FFix f else raise Factors; + +fun dest_factors (FVar f) = f + | dest_factors (FFix f) = f; + +fun infer_factors sg extra_fs (fs, (optf, t)) = + let fun err s = error (s ^ "\n" ^ Sign.string_of_term sg t) + in (case (optf, strip_comb t) of + (Some f, (Const (name, _), args)) => + (case assoc (extra_fs, name) of + None => overwrite (fs, (name, if_none + (apsome (mg_factor f) (assoc (fs, name))) f)) + | Some (fs', f') => (mg_factor f (FFix f'); + foldl (infer_factors sg extra_fs) + (fs, map (apsome FFix) fs' ~~ args))) + | (Some f, (Var ((name, _), _), [])) => + overwrite (fs, (name, if_none + (apsome (mg_factor f) (assoc (fs, name))) f)) + | (None, _) => fs + | _ => err "Illegal term") + handle Factors => err "Product factor mismatch in" + end; + fun string_of_factors p ps = if p mem ps then "(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")" else "_"; + (**** check if a term contains only constructor functions ****) fun is_constrt thy = @@ -81,9 +154,32 @@ in merge (map (fn ks => i::ks) is) is end else [[]]; +fun cprod ([], ys) = [] + | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys); + +fun cprods xss = foldr (map op :: o cprod) (xss, [[]]); + +datatype mode = Mode of (int list option list * int list) * mode option list; + +fun modes_of modes t = + let + fun mk_modes name args = flat + (map (fn (m as (iss, is)) => map (Mode o pair m) (cprods (map + (fn (None, _) => [None] + | (Some js, arg) => map Some + (filter (fn Mode ((_, js'), _) => js=js') (modes_of modes arg))) + (iss ~~ args)))) (the (assoc (modes, name)))) + + in (case strip_comb t of + (Const (name, _), args) => mk_modes name args + | (Var ((name, _), _), args) => mk_modes name args) + end; + +datatype indprem = Prem of term list * term | Sidecond of term; + fun select_mode_prem thy modes vs ps = find_first (is_some o snd) (ps ~~ map - (fn Prem (s, us, args) => find_first (fn is => + (fn Prem (us, t) => find_first (fn Mode ((_, is), _) => let val (_, out_ts) = get_args is 1 us; val vTs = flat (map term_vTs out_ts); @@ -92,21 +188,25 @@ in is subset known_args vs 1 us andalso forall (is_constrt thy) (snd (get_args is 1 us)) andalso - terms_vs args subset vs andalso + term_vs t subset vs andalso forall is_eqT dupTs end) - (the (assoc (modes, s))) - | Sidecond t => if term_vs t subset vs then Some [] else None) ps); + (modes_of modes t) + | Sidecond t => if term_vs t subset vs then Some (Mode (([], []), [])) + else None) ps); -fun check_mode_clause thy arg_vs modes mode (ts, ps) = +fun check_mode_clause thy arg_vs modes (iss, is) (ts, ps) = let + val modes' = modes @ mapfilter + (fn (_, None) => None | (v, Some js) => Some (v, [([], js)])) + (arg_vs ~~ iss); fun check_mode_prems vs [] = Some vs - | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of + | check_mode_prems vs ps = (case select_mode_prem thy modes' vs ps of None => None | Some (x, _) => check_mode_prems - (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs) + (case x of Prem (us, _) => vs union terms_vs us | _ => vs) (filter_out (equal x) ps)); - val (in_ts', _) = get_args mode 1 ts; + val (in_ts', _) = get_args is 1 ts; val in_ts = filter (is_constrt thy) in_ts'; val in_vs = terms_vs in_ts; val concl_vs = terms_vs ts @@ -125,9 +225,12 @@ let val y = f x in if x = y then x else fixp f y end; -fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes => +fun infer_modes thy extra_modes factors arg_vs preds = fixp (fn modes => map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes) - (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds); + (map (fn (s, (fs, f)) => (s, cprod (cprods (map + (fn None => [None] + | Some f' => map Some (subsets 1 (length f' + 1))) fs), + subsets 1 (length f + 1)))) factors); (**** code generation ****) @@ -167,17 +270,37 @@ [Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"])) end; -fun modename thy s mode = space_implode "_" - (mk_const_id (sign_of thy) s :: map string_of_int mode); +fun modename thy s (iss, is) = space_implode "__" + (mk_const_id (sign_of thy) s :: + map (space_implode "_" o map string_of_int) (mapfilter I iss @ [is])); -fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) = +fun compile_expr thy dep brack (gr, (None, t)) = + apsnd single (invoke_codegen thy dep brack (gr, t)) + | compile_expr _ _ _ (gr, (Some _, Var ((name, _), _))) = + (gr, [Pretty.str name]) + | compile_expr thy dep brack (gr, (Some (Mode (mode, ms)), t)) = + let + val (Const (name, _), args) = strip_comb t; + val (gr', ps) = foldl_map + (compile_expr thy dep true) (gr, ms ~~ args); + in (gr', (if brack andalso not (null ps) then + single o parens o Pretty.block else I) + (flat (separate [Pretty.brk 1] + ([Pretty.str (modename thy name mode)] :: ps)))) + end; + +fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) = let + val modes' = modes @ mapfilter + (fn (_, None) => None | (v, Some js) => Some (v, [([], js)])) + (arg_vs ~~ iss); + fun check_constrt ((names, eqs), t) = if is_constrt thy t then ((names, eqs), t) else let val s = variant names "x"; in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end; - val (in_ts, out_ts) = get_args mode 1 ts; + val (in_ts, out_ts) = get_args is 1 ts; val ((all_vs', eqs), in_ts') = foldl_map check_constrt ((all_vs, []), in_ts); @@ -200,27 +323,25 @@ | compile_prems out_ts vs names gr ps = let val vs' = distinct (flat (vs :: map term_vs out_ts)); - val Some (p, Some mode') = - select_mode_prem thy modes (arg_vs union vs') ps; + val Some (p, mode as Some (Mode ((_, js), _))) = + select_mode_prem thy modes' (arg_vs union vs') ps; val ps' = filter_out (equal p) ps; in (case p of - Prem (s, us, args) => + Prem (us, t) => let - val (in_ts, out_ts') = get_args mode' 1 us; + val (in_ts, out_ts') = get_args js 1 us; val (gr1, in_ps) = foldl_map (invoke_codegen thy dep false) (gr, in_ts); - val (gr2, arg_ps) = foldl_map - (invoke_codegen thy dep true) (gr1, args); val (nvs, out_ts'') = foldl_map distinct_v ((names, map (fn x => (x, [x])) vs), out_ts); - val (gr3, out_ps) = foldl_map - (invoke_codegen thy dep false) (gr2, out_ts'') + val (gr2, out_ps) = foldl_map + (invoke_codegen thy dep false) (gr1, out_ts''); + val (gr3, ps) = compile_expr thy dep false (gr2, (mode, t)); val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps'; in (gr4, compile_match (snd nvs) [] out_ps - (Pretty.block (separate (Pretty.brk 1) - (Pretty.str (modename thy s mode') :: arg_ps) @ + (Pretty.block (ps @ [Pretty.brk 1, mk_tuple in_ps, Pretty.str " :->", Pretty.brk 1, rest])) (Pretty.str "Seq.empty")) @@ -269,69 +390,91 @@ (**** processing of introduction rules ****) -val string_of_mode = enclose "[" "]" o commas o map string_of_int; +exception Modes of + (string * (int list option list * int list) list) list * + (string * (int list list option list * int list list)) list; + +fun lookup_modes gr dep = apfst flat (apsnd flat (ListPair.unzip + (map ((fn (Some (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr) + (Graph.all_preds gr [dep])))); + +fun string_of_mode (iss, is) = space_implode " -> " (map + (fn None => "X" + | Some js => enclose "[" "]" (commas (map string_of_int js))) + (iss @ [Some is])); fun print_modes modes = message ("Inferred modes:\n" ^ space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map string_of_mode ms)) modes)); fun print_factors factors = message ("Factors:\n" ^ - space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors)); - -fun get_modes (Some (Modes x), _) = x - | get_modes _ = ([], []); + space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^ + space_implode " -> " (map + (fn None => "X" | Some f' => string_of_factors [] f') + (fs @ [Some f]))) factors)); -fun mk_ind_def thy gr dep names intrs = +fun mk_extra_defs thy gr dep names ts = + foldl (fn (gr, name) => + if name mem names then gr + else (case get_clauses thy name of + None => gr + | Some (names, intrs) => + mk_ind_def thy gr dep names intrs)) + (gr, foldr add_term_consts (ts, [])) + +and mk_ind_def thy gr dep names intrs = let val ids = map (mk_const_id (sign_of thy)) names in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ => let - fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) = - (case strip_comb u of - (Const (name, _), args) => - (case InductivePackage.get_inductive thy name of - None => (gr, Sidecond t') - | Some ({names=names', ...}, {intrs=intrs', ...}) => - (if names = names' then gr - else mk_ind_def thy gr (hd ids) names' intrs', - Prem (name, split_prod [] - (the (assoc (factors, name))) t, args))) - | _ => (gr, Sidecond t')) - | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) = - (gr, Prem ("eq", [t, u], [])) - | process_prem factors (gr, _ $ t) = (gr, Sidecond t); + fun dest_prem factors (_ $ (Const ("op :", _) $ t $ u)) = + (case head_of u of + Const (name, _) => Prem (split_prod [] + (the (assoc (factors, name))) t, u) + | Var ((name, _), _) => Prem (split_prod [] + (the (assoc (factors, name))) t, u)) + | dest_prem factors (_ $ ((eq as Const ("op =", _)) $ t $ u)) = + Prem ([t, u], eq) + | dest_prem factors (_ $ t) = Sidecond t; - fun add_clause factors ((clauses, gr), intr) = + fun add_clause factors (clauses, intr) = let val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr; - val (Const (name, _), args) = strip_comb u; - val (gr', prems) = foldl_map (process_prem factors) - (gr, Logic.strip_imp_prems intr); + val Const (name, _) = head_of u; + val prems = map (dest_prem factors) (Logic.strip_imp_prems intr); in (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @ - [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr') + [(split_prod [] (the (assoc (factors, name))) t, prems)]))) end; - fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) = - (case strip_comb u of - (Const (name, _), _) => - let val f = prod_factors [] t - in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end - | _ => fs) - | add_prod_factors (fs, _) = fs; + fun add_prod_factors extra_fs (fs, _ $ (Const ("op :", _) $ t $ u)) = + infer_factors (sign_of thy) extra_fs + (fs, (Some (FVar (prod_factors [] t)), u)) + | add_prod_factors _ (fs, _) = fs; val intrs' = map (rename_term o #prop o rep_thm o standard) intrs; - val factors = foldl add_prod_factors ([], flat (map (fn t => - Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs')); - val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep) - (Graph.new_node (hd ids, (None, "")) gr)), intrs'); val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs'); val (_, args) = strip_comb u; val arg_vs = flat (map term_vs args); - val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map - (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids]))); - val modes = infer_modes thy extra_modes arg_vs clauses; + val gr' = mk_extra_defs thy + (Graph.add_edge (hd ids, dep) + (Graph.new_node (hd ids, (None, "")) gr)) (hd ids) names intrs'; + val (extra_modes', extra_factors) = lookup_modes gr' (hd ids); + val extra_modes = + ("op =", [([], [1]), ([], [2]), ([], [1, 2])]) :: extra_modes'; + val fs = map (apsnd dest_factors) + (foldl (add_prod_factors extra_factors) ([], flat (map (fn t => + Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs'))); + val _ = (case map fst fs \\ names \\ arg_vs of + [] => () + | xs => error ("Non-inductive sets: " ^ commas_quote xs)); + val factors = mapfilter (fn (name, f) => + if name mem arg_vs then None + else Some (name, (map (curry assoc fs) arg_vs, f))) fs; + val clauses = + foldl (add_clause (fs @ map (apsnd snd) extra_factors)) ([], intrs'); + val modes = infer_modes thy extra_modes factors arg_vs clauses; + val _ = print_factors factors; val _ = print_modes modes; - val _ = print_factors factors; val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs (modes @ extra_modes) clauses; in @@ -339,31 +482,33 @@ end end; -fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of - (Const (s, _), args) => (case InductivePackage.get_inductive thy s of +fun mk_ind_call thy gr dep t u is_query = (case head_of u of + Const (s, _) => (case get_clauses thy s of None => None - | Some ({names, ...}, {intrs, ...}) => + | Some (names, intrs) => let fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1) | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1) | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1); - val gr1 = mk_ind_def thy gr dep names intrs; - val (modes, factors) = pairself flat (ListPair.unzip - (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep]))); - val ts = split_prod [] (the (assoc (factors, s))) t; - val (ts', mode) = if is_query then + val gr1 = mk_extra_defs thy + (mk_ind_def thy gr dep names intrs) dep names [u]; + val (modes, factors) = lookup_modes gr1 dep; + val ts = split_prod [] (snd (the (assoc (factors, s)))) t; + val (ts', is) = if is_query then fst (foldl mk_mode ((([], []), 1), ts)) else (ts, 1 upto length ts); - val _ = if mode mem the (assoc (modes, s)) then () else - error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode); + val mode = (case find_first (fn Mode ((_, js), _) => is=js) + (modes_of modes u) of + None => error ("No such mode for " ^ s ^ ": " ^ + string_of_mode ([], is)) + | mode => mode); val (gr2, in_ps) = foldl_map (invoke_codegen thy dep false) (gr1, ts'); - val (gr3, arg_ps) = foldl_map - (invoke_codegen thy dep true) (gr2, args); + val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u)) in - Some (gr3, Pretty.block (separate (Pretty.brk 1) - (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps]))) + Some (gr3, Pretty.block + (ps @ [Pretty.brk 1, mk_tuple in_ps])) end) | _ => None); @@ -376,7 +521,10 @@ mk_ind_call thy gr dep t u true | inductive_codegen thy gr dep brack _ = None; -val setup = [add_codegen "inductive" inductive_codegen]; +val setup = + [add_codegen "inductive" inductive_codegen, + CodegenData.init, + add_attribute "ind" add]; end; @@ -394,6 +542,8 @@ fun ?! s = is_some (Seq.pull s); -fun eq_1 x = Seq.single x; +fun op__61__1 x = Seq.single x; -val eq_2 = eq_1; +val op__61__2 = op__61__1; + +fun op__61__1_2 (x, y) = ?? (x = y);