diff -r 701218c1301c -r a152d6b21c31 src/HOL/Tools/inductive_codegen.ML --- a/src/HOL/Tools/inductive_codegen.ML Fri Jul 01 14:11:06 2005 +0200 +++ b/src/HOL/Tools/inductive_codegen.ML Fri Jul 01 14:13:40 2005 +0200 @@ -7,7 +7,7 @@ signature INDUCTIVE_CODEGEN = sig - val add : theory attribute + val add : string option -> theory attribute val setup : (theory -> theory) list end; @@ -22,18 +22,20 @@ (struct val name = "HOL/inductive_codegen"; type T = - {intros : thm list Symtab.table, + {intros : (thm * string) list Symtab.table, graph : unit Graph.T, - eqns : thm list Symtab.table}; + eqns : (thm * string) list Symtab.table}; val empty = {intros = Symtab.empty, graph = Graph.empty, eqns = Symtab.empty}; val copy = I; val extend = I; fun merge _ ({intros=intros1, graph=graph1, eqns=eqns1}, {intros=intros2, graph=graph2, eqns=eqns2}) = - {intros = Symtab.merge_multi Drule.eq_thm_prop (intros1, intros2), + {intros = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst) + (intros1, intros2), graph = Graph.merge (K true) (graph1, graph2), - eqns = Symtab.merge_multi Drule.eq_thm_prop (eqns1, eqns2)}; + eqns = Symtab.merge_multi (Drule.eq_thm_prop o pairself fst) + (eqns1, eqns2)}; fun print _ _ = (); end); @@ -43,15 +45,19 @@ fun add_node (g, x) = Graph.new_node (x, ()) g handle Graph.DUP _ => g; -fun add (p as (thy, thm)) = - let val {intros, graph, eqns} = CodegenData.get thy; +fun add optmod (p as (thy, thm)) = + let + val {intros, graph, eqns} = CodegenData.get thy; + fun thyname_of s = (case optmod of + NONE => thyname_of_const s thy | SOME s => s); in (case concl_of thm of _ $ (Const ("op :", _) $ _ $ t) => (case head_of t of Const (s, _) => let val cs = foldr add_term_consts [] (prems_of thm) in (CodegenData.put {intros = Symtab.update ((s, - getOpt (Symtab.lookup (intros, s), []) @ [thm]), intros), + getOpt (Symtab.lookup (intros, s), []) @ + [(thm, thyname_of s)]), intros), graph = foldr (uncurry (Graph.add_edge o pair s)) (Library.foldl add_node (graph, s :: cs)) cs, eqns = eqns} thy, thm) @@ -61,7 +67,8 @@ Const (s, _) => (CodegenData.put {intros = intros, graph = graph, eqns = Symtab.update ((s, - getOpt (Symtab.lookup (eqns, s), []) @ [thm]), eqns)} thy, thm) + getOpt (Symtab.lookup (eqns, s), []) @ + [(thm, thyname_of s)]), eqns)} thy, thm) | _ => (warn thm; p)) | _ => (warn thm; p)) end; @@ -71,13 +78,17 @@ in case Symtab.lookup (intros, s) of NONE => (case InductivePackage.get_inductive thy s of NONE => NONE - | SOME ({names, ...}, {intrs, ...}) => SOME (names, preprocess thy intrs)) + | SOME ({names, ...}, {intrs, ...}) => + SOME (names, thyname_of_const s thy, + preprocess thy intrs)) | SOME _ => - let val SOME names = find_first - (fn xs => s mem xs) (Graph.strong_conn graph) - in SOME (names, preprocess thy - (List.concat (map (fn s => valOf (Symtab.lookup (intros, s))) names))) - end + let + val SOME names = find_first + (fn xs => s mem xs) (Graph.strong_conn graph); + val intrs = List.concat (map + (fn s => valOf (Symtab.lookup (intros, s))) names); + val (_, (_, thyname)) = split_last intrs + in SOME (names, thyname, preprocess thy (map fst intrs)) end end; @@ -364,26 +375,30 @@ else [Pretty.str ")"]))) end; -fun modename thy s (iss, is) = space_implode "__" - (mk_const_id (sign_of thy) s :: +fun strip_spaces s = implode (fst (take_suffix (equal " ") (explode s))); + +fun modename thy thyname thyname' s (iss, is) = space_implode "__" + (mk_const_id (sign_of thy) thyname thyname' (strip_spaces s) :: map (space_implode "_" o map string_of_int) (List.mapPartial I iss @ [is])); -fun compile_expr thy dep brack (gr, (NONE, t)) = - apsnd single (invoke_codegen thy dep brack (gr, t)) - | compile_expr _ _ _ (gr, (SOME _, Var ((name, _), _))) = +fun compile_expr thy defs dep thyname brack thynames (gr, (NONE, t)) = + apsnd single (invoke_codegen thy defs dep thyname brack (gr, t)) + | compile_expr _ _ _ _ _ _ (gr, (SOME _, Var ((name, _), _))) = (gr, [Pretty.str name]) - | compile_expr thy dep brack (gr, (SOME (Mode (mode, ms)), t)) = + | compile_expr thy defs dep thyname brack thynames (gr, (SOME (Mode (mode, ms)), t)) = let val (Const (name, _), args) = strip_comb t; val (gr', ps) = foldl_map - (compile_expr thy dep true) (gr, ms ~~ args); + (compile_expr thy defs dep thyname true thynames) (gr, ms ~~ args); in (gr', (if brack andalso not (null ps) then single o parens o Pretty.block else I) (List.concat (separate [Pretty.brk 1] - ([Pretty.str (modename thy name mode)] :: ps)))) + ([Pretty.str (modename thy thyname + (if name = "op =" then "" + else the (assoc (thynames, name))) name mode)] :: ps)))) end; -fun compile_clause thy gr dep all_vs arg_vs modes (iss, is) (ts, ps) = +fun compile_clause thy defs gr dep thyname all_vs arg_vs modes thynames (iss, is) (ts, ps) = let val modes' = modes @ List.mapPartial (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) @@ -396,7 +411,7 @@ fun compile_eq (gr, (s, t)) = apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single) - (invoke_codegen thy dep false (gr, t)); + (invoke_codegen thy defs dep thyname false (gr, t)); val (in_ts, out_ts) = get_args is 1 ts; val ((all_vs', eqs), in_ts') = @@ -409,14 +424,14 @@ fun compile_prems out_ts' vs names gr [] = let val (gr2, out_ps) = foldl_map - (invoke_codegen thy dep false) (gr, out_ts); + (invoke_codegen thy defs dep thyname false) (gr, out_ts); val (gr3, eq_ps) = foldl_map compile_eq (gr2, eqs); val ((names', eqs'), out_ts'') = foldl_map check_constrt ((names, []), out_ts'); val (nvs, out_ts''') = foldl_map distinct_v ((names', map (fn x => (x, [x])) vs), out_ts''); val (gr4, out_ps') = foldl_map - (invoke_codegen thy dep false) (gr3, out_ts'''); + (invoke_codegen thy defs dep thyname false) (gr3, out_ts'''); val (gr5, eq_ps') = foldl_map compile_eq (gr4, eqs') in (gr5, compile_match (snd nvs) (eq_ps @ eq_ps') out_ps' @@ -434,7 +449,7 @@ val (nvs, out_ts'') = foldl_map distinct_v ((names', map (fn x => (x, [x])) vs), out_ts'); val (gr0, out_ps) = foldl_map - (invoke_codegen thy dep false) (gr, out_ts''); + (invoke_codegen thy defs dep thyname false) (gr, out_ts''); val (gr1, eq_ps) = foldl_map compile_eq (gr0, eqs) in (case p of @@ -442,14 +457,15 @@ let val (in_ts, out_ts''') = get_args js 1 us; val (gr2, in_ps) = foldl_map - (invoke_codegen thy dep false) (gr1, in_ts); + (invoke_codegen thy defs dep thyname false) (gr1, in_ts); val (gr3, ps) = if is_ind t then apsnd (fn ps => ps @ [Pretty.brk 1, mk_tuple in_ps]) - (compile_expr thy dep false (gr2, (mode, t))) + (compile_expr thy defs dep thyname false thynames + (gr2, (mode, t))) else apsnd (fn p => conv_ntuple us t [Pretty.str "Seq.of_list", Pretty.brk 1, p]) - (invoke_codegen thy dep true (gr2, t)); + (invoke_codegen thy defs dep thyname true (gr2, t)); val (gr4, rest) = compile_prems out_ts''' vs' (fst nvs) gr3 ps'; in (gr4, compile_match (snd nvs) eq_ps out_ps @@ -459,7 +475,7 @@ end | Sidecond t => let - val (gr2, side_p) = invoke_codegen thy dep true (gr1, t); + val (gr2, side_p) = invoke_codegen thy defs dep thyname true (gr1, t); val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps'; in (gr3, compile_match (snd nvs) eq_ps out_ps @@ -474,22 +490,23 @@ (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p]) end; -fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode = - let val (gr', cl_ps) = foldl_map (fn (gr, cl) => - compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls) +fun compile_pred thy defs gr dep thyname prfx all_vs arg_vs modes thynames s cls mode = + let val (gr', cl_ps) = foldl_map (fn (gr, cl) => compile_clause thy defs + gr dep thyname all_vs arg_vs modes thynames mode cl) (gr, cls) in ((gr', "and "), Pretty.block ([Pretty.block (separate (Pretty.brk 1) - (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @ + (Pretty.str (prfx ^ modename thy thyname thyname s mode) :: + map Pretty.str arg_vs) @ [Pretty.str " inp ="]), Pretty.brk 1] @ List.concat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps)))) end; -fun compile_preds thy gr dep all_vs arg_vs modes preds = +fun compile_preds thy defs gr dep thyname all_vs arg_vs modes thynames preds = let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) => - foldl_map (fn ((gr', prfx'), mode) => - compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode) + foldl_map (fn ((gr', prfx'), mode) => compile_pred thy defs gr' + dep thyname prfx' all_vs arg_vs modes thynames s cls mode) ((gr, prfx), valOf (assoc (modes, s)))) ((gr, "fun "), preds) in (gr', space_implode "\n\n" (map Pretty.string_of (List.concat prs)) ^ ";\n\n") @@ -499,11 +516,13 @@ exception Modes of (string * (int list option list * int list) list) list * - (string * (int list list option list * int list list)) list; + (string * (int list list option list * int list list)) list * + string; -fun lookup_modes gr dep = apfst List.concat (apsnd List.concat (ListPair.unzip - (map ((fn (SOME (Modes x), _) => x | _ => ([], [])) o Graph.get_node gr) - (Graph.all_preds gr [dep])))); +fun lookup_modes gr dep = foldl (fn ((xs, ys, z), (xss, yss, zss)) => + (xss @ xs, yss @ ys, zss @ map (rpair z o fst) ys)) ([], [], []) + (map ((fn (SOME (Modes x), _, _) => x | _ => ([], [], "")) o Graph.get_node gr) + (Graph.all_preds gr [dep])); fun print_factors factors = message ("Factors:\n" ^ space_implode "\n" (map (fn (s, (fs, f)) => s ^ ": " ^ @@ -518,18 +537,17 @@ NONE => xs | SOME xs' => xs inter xs') :: constrain cs ys; -fun mk_extra_defs thy gr dep names ts = +fun mk_extra_defs thy defs gr dep names ts = Library.foldl (fn (gr, name) => if name mem names then gr else (case get_clauses thy name of NONE => gr - | SOME (names, intrs) => - mk_ind_def thy gr dep names [] [] (prep_intrs intrs))) + | SOME (names, thyname, intrs) => + mk_ind_def thy defs gr dep names thyname [] [] (prep_intrs intrs))) (gr, foldr add_term_consts [] ts) -and mk_ind_def thy gr dep names modecs factorcs intrs = - let val ids = map (mk_const_id (sign_of thy)) names - in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ => +and mk_ind_def thy defs gr dep names thyname modecs factorcs intrs = + Graph.add_edge (hd names, dep) gr handle Graph.UNDEF _ => let val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs); val (_, args) = strip_comb u; @@ -565,10 +583,10 @@ else fs | add_prod_factors _ (fs, _) = fs; - val gr' = mk_extra_defs thy - (Graph.add_edge (hd ids, dep) - (Graph.new_node (hd ids, (NONE, "")) gr)) (hd ids) names intrs; - val (extra_modes, extra_factors) = lookup_modes gr' (hd ids); + val gr' = mk_extra_defs thy defs + (Graph.add_edge (hd names, dep) + (Graph.new_node (hd names, (NONE, "", "")) gr)) (hd names) names intrs; + val (extra_modes, extra_factors, extra_thynames) = lookup_modes gr' (hd names); val fs = constrain factorcs (map (apsnd dest_factors) (Library.foldl (add_prod_factors extra_factors) ([], List.concat (map (fn t => Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs)))); @@ -581,38 +599,40 @@ (infer_modes thy extra_modes factors arg_vs clauses); val _ = print_factors factors; val _ = print_modes modes; - val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs) arg_vs - (modes @ extra_modes) clauses; + val (gr'', s) = compile_preds thy defs gr' (hd names) thyname (terms_vs intrs) + arg_vs (modes @ extra_modes) + (map (rpair thyname o fst) factors @ extra_thynames) clauses; in - (Graph.map_node (hd ids) (K (SOME (Modes (modes, factors)), s)) gr'') - end - end; + (Graph.map_node (hd names) + (K (SOME (Modes (modes, factors, thyname)), thyname, s)) gr'') + end; fun find_mode s u modes is = (case find_first (fn Mode ((_, js), _) => is=js) (modes_of modes u handle Option => []) of NONE => error ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is)) | mode => mode); -fun mk_ind_call thy gr dep t u is_query = (case head_of u of +fun mk_ind_call thy defs gr dep thyname t u is_query = (case head_of u of Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of (NONE, _) => NONE - | (SOME (names, intrs), NONE) => + | (SOME (names, thyname', intrs), NONE) => let fun mk_mode (((ts, mode), i), Const ("dummy_pattern", _)) = ((ts, mode), i+1) | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1); - val gr1 = mk_extra_defs thy - (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u]; - val (modes, factors) = lookup_modes gr1 dep; + val gr1 = mk_extra_defs thy defs + (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u]; + val (modes, factors, thynames) = lookup_modes gr1 dep; val ts = split_prod [] (snd (valOf (assoc (factors, s)))) t; val (ts', is) = if is_query then fst (Library.foldl mk_mode ((([], []), 1), ts)) else (ts, 1 upto length ts); val mode = find_mode s u modes is; val (gr2, in_ps) = foldl_map - (invoke_codegen thy dep false) (gr1, ts'); - val (gr3, ps) = compile_expr thy dep false (gr2, (mode, u)) + (invoke_codegen thy defs dep thyname false) (gr1, ts'); + val (gr3, ps) = + compile_expr thy defs dep thyname false thynames (gr2, (mode, u)) in SOME (gr3, Pretty.block (ps @ [Pretty.brk 1, mk_tuple in_ps])) @@ -620,16 +640,17 @@ | _ => NONE) | _ => NONE); -fun list_of_indset thy gr dep brack u = (case head_of u of +fun list_of_indset thy defs gr dep thyname brack u = (case head_of u of Const (s, T) => (case (get_clauses thy s, get_assoc_code thy s T) of (NONE, _) => NONE - | (SOME (names, intrs), NONE) => + | (SOME (names, thyname', intrs), NONE) => let - val gr1 = mk_extra_defs thy - (mk_ind_def thy gr dep names [] [] (prep_intrs intrs)) dep names [u]; - val (modes, factors) = lookup_modes gr1 dep; + val gr1 = mk_extra_defs thy defs + (mk_ind_def thy defs gr dep names thyname' [] [] (prep_intrs intrs)) dep names [u]; + val (modes, factors, thynames) = lookup_modes gr1 dep; val mode = find_mode s u modes []; - val (gr2, ps) = compile_expr thy dep false (gr1, (mode, u)) + val (gr2, ps) = + compile_expr thy defs dep thyname false thynames (gr1, (mode, u)) in SOME (gr2, (if brack then parens else I) (Pretty.block ([Pretty.str "Seq.list_of", Pretty.brk 1, @@ -650,58 +671,63 @@ in rename_term (Logic.list_implies (prems_of eqn, HOLogic.mk_Trueprop (HOLogic.mk_mem - (foldr1 HOLogic.mk_prod (ts @ [u]), Const (Sign.base_name s ^ "_aux", + (foldr1 HOLogic.mk_prod (ts @ [u]), Const (s ^ " ", HOLogic.mk_setT (foldr1 HOLogic.mk_prodT (Ts @ [U]))))))) end; -fun mk_fun thy name eqns dep gr = - let val id = mk_const_id (sign_of thy) name - in Graph.add_edge (id, dep) gr handle Graph.UNDEF _ => +fun mk_fun thy defs name eqns dep thyname thyname' gr = + let + val fun_id = mk_const_id (sign_of thy) thyname' thyname' name; + val call_id = mk_const_id (sign_of thy) thyname thyname' name + in (Graph.add_edge (name, dep) gr handle Graph.UNDEF _ => let val clauses = map clause_of_eqn eqns; - val pname = mk_const_id (sign_of thy) (Sign.base_name name ^ "_aux"); + val pname = name ^ " "; val arity = length (snd (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (concl_of (hd eqns))))))); val mode = 1 upto arity; val vars = map (fn i => Pretty.str ("x" ^ string_of_int i)) mode; val s = Pretty.string_of (Pretty.block - [mk_app false (Pretty.str ("fun " ^ id)) vars, Pretty.str " =", + [mk_app false (Pretty.str ("fun " ^ fun_id)) vars, Pretty.str " =", Pretty.brk 1, Pretty.str "Seq.hd", Pretty.brk 1, - parens (Pretty.block [Pretty.str (modename thy pname ([], mode)), + parens (Pretty.block [Pretty.str (modename thy thyname' thyname' pname ([], mode)), Pretty.brk 1, mk_tuple vars])]) ^ ";\n\n"; - val gr' = mk_ind_def thy (Graph.add_edge (id, dep) - (Graph.new_node (id, (NONE, s)) gr)) id [pname] + val gr' = mk_ind_def thy defs (Graph.add_edge (name, dep) + (Graph.new_node (name, (NONE, thyname', s)) gr)) name [pname] thyname' [(pname, [([], mode)])] [(pname, map (fn i => replicate i 2) (0 upto arity-1))] clauses; - val (modes, _) = lookup_modes gr' dep; + val (modes, _, _) = lookup_modes gr' dep; val _ = find_mode pname (snd (HOLogic.dest_mem (HOLogic.dest_Trueprop (Logic.strip_imp_concl (hd clauses))))) modes mode - in gr' end + in gr' end, call_id) end; -fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) = - ((case mk_ind_call thy gr dep (Term.no_dummy_patterns t) u false of +fun inductive_codegen thy defs gr dep thyname brack (Const ("op :", _) $ t $ u) = + ((case mk_ind_call thy defs gr dep thyname (Term.no_dummy_patterns t) u false of NONE => NONE | SOME (gr', call_p) => SOME (gr', (if brack then parens else I) (Pretty.block [Pretty.str "?! (", call_p, Pretty.str ")"]))) - handle TERM _ => mk_ind_call thy gr dep t u true) - | inductive_codegen thy gr dep brack t = (case strip_comb t of + handle TERM _ => mk_ind_call thy defs gr dep thyname t u true) + | inductive_codegen thy defs gr dep thyname brack t = (case strip_comb t of (Const (s, _), ts) => (case Symtab.lookup (#eqns (CodegenData.get thy), s) of - NONE => list_of_indset thy gr dep brack t + NONE => list_of_indset thy defs gr dep thyname brack t | SOME eqns => let - val gr' = mk_fun thy s (preprocess thy eqns) dep gr - val (gr'', ps) = foldl_map (invoke_codegen thy dep true) (gr', ts); - in SOME (gr'', mk_app brack (Pretty.str (mk_const_id - (sign_of thy) s)) ps) + val (_, (_, thyname')) = split_last eqns; + val (gr', id) = mk_fun thy defs s (preprocess thy (map fst eqns)) + dep thyname thyname' gr; + val (gr'', ps) = foldl_map + (invoke_codegen thy defs dep thyname true) (gr', ts); + in SOME (gr'', mk_app brack (Pretty.str id) ps) end) | _ => NONE); val setup = [add_codegen "inductive" inductive_codegen, CodegenData.init, - add_attribute "ind" (Scan.succeed add)]; + add_attribute "ind" + (Scan.option (Args.$$$ "target" |-- Args.colon |-- Args.name) >> add)]; end;