diff -r 5611f178a747 -r a2a9018843ae src/HOL/Tools/inductive_codegen.ML --- a/src/HOL/Tools/inductive_codegen.ML Wed Apr 20 15:55:34 2011 +0200 +++ b/src/HOL/Tools/inductive_codegen.ML Wed Apr 20 16:18:47 2011 +0200 @@ -40,8 +40,9 @@ ); -fun warn thm = warning ("Inductive_Codegen: Not a proper clause:\n" ^ - Display.string_of_thm_without_context thm); +fun warn thy thm = + warning ("Inductive_Codegen: Not a proper clause:\n" ^ + Display.string_of_thm_global thy thm); fun add_node x g = Graph.new_node (x, ()) g handle Graph.DUP _ => g; @@ -50,14 +51,15 @@ val {intros, graph, eqns} = CodegenData.get thy; fun thyname_of s = (case optmod of NONE => Codegen.thyname_of_const thy s | SOME s => s); - in (case Option.map strip_comb (try HOLogic.dest_Trueprop (concl_of thm)) of + in + (case Option.map strip_comb (try HOLogic.dest_Trueprop (concl_of thm)) of SOME (Const (@{const_name HOL.eq}, _), [t, _]) => (case head_of t of Const (s, _) => CodegenData.put {intros = intros, graph = graph, eqns = eqns |> Symtab.map_default (s, []) (AList.update Thm.eq_thm_prop (thm, thyname_of s))} thy - | _ => (warn thm; thy)) + | _ => (warn thy thm; thy)) | SOME (Const (s, _), _) => let val cs = fold Term.add_const_names (Thm.prems_of thm) []; @@ -80,25 +82,26 @@ graph = fold_rev (Graph.add_edge o pair s) cs (fold add_node (s :: cs) graph), eqns = eqns} thy end - | _ => (warn thm; thy)) + | _ => (warn thy thm; thy)) end) I); fun get_clauses thy s = - let val {intros, graph, ...} = CodegenData.get thy - in case Symtab.lookup intros s of - NONE => (case try (Inductive.the_inductive (Proof_Context.init_global thy)) s of - NONE => NONE - | SOME ({names, ...}, {intrs, raw_induct, ...}) => - SOME (names, Codegen.thyname_of_const thy s, - length (Inductive.params_of raw_induct), - Codegen.preprocess thy intrs)) + let val {intros, graph, ...} = CodegenData.get thy in + (case Symtab.lookup intros s of + NONE => + (case try (Inductive.the_inductive (Proof_Context.init_global thy)) s of + NONE => NONE + | SOME ({names, ...}, {intrs, raw_induct, ...}) => + SOME (names, Codegen.thyname_of_const thy s, + length (Inductive.params_of raw_induct), + Codegen.preprocess thy intrs)) | SOME _ => let val SOME names = find_first (fn xs => member (op =) xs s) (Graph.strong_conn graph); val intrs as (_, (thyname, nparms)) :: _ = maps (the o Symtab.lookup intros) names; - in SOME (names, thyname, nparms, Codegen.preprocess thy (map fst (rev intrs))) end + in SOME (names, thyname, nparms, Codegen.preprocess thy (map fst (rev intrs))) end) end; @@ -109,19 +112,23 @@ val cnstrs = flat (maps (map (fn (_, (_, _, cs)) => map (apsnd length) cs) o #descr o snd) (Symtab.dest (Datatype_Data.get_all thy))); - fun check t = (case strip_comb t of + fun check t = + (case strip_comb t of (Var _, []) => true - | (Const (s, _), ts) => (case AList.lookup (op =) cnstrs s of + | (Const (s, _), ts) => + (case AList.lookup (op =) cnstrs s of NONE => false | SOME i => length ts = i andalso forall check ts) - | _ => false) + | _ => false); in check end; + (**** check if a type is an equality type (i.e. doesn't contain fun) ****) fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts | is_eqT _ = true; + (**** mode inference ****) fun string_of_mode (iss, is) = space_implode " -> " (map @@ -189,7 +196,8 @@ end end)) (AList.lookup op = modes name) - in (case strip_comb t of + in + (case strip_comb t of (Const (@{const_name HOL.eq}, Type (_, [T, _])), _) => [Mode ((([], [1]), false), [1], []), Mode ((([], [2]), false), [2], [])] @ (if is_eqT T then [Mode ((([], [1, 2]), false), [1, 2], [])] else []) @@ -247,43 +255,46 @@ (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [(([], js), false)])) (arg_vs ~~ iss); fun check_mode_prems vs rnd [] = SOME (vs, rnd) - | check_mode_prems vs rnd ps = (case select_mode_prem thy modes' vs ps of - (x, (m, []) :: _) :: _ => check_mode_prems - (case x of Prem (us, _, _) => union (op =) vs (terms_vs us) | _ => vs) - (rnd orelse needs_random m) - (filter_out (equal x) ps) - | (_, (_, vs') :: _) :: _ => - if use_random codegen_mode then - check_mode_prems (union (op =) vs (map (fst o fst) vs')) true ps - else NONE - | _ => NONE); + | check_mode_prems vs rnd ps = + (case select_mode_prem thy modes' vs ps of + (x, (m, []) :: _) :: _ => + check_mode_prems + (case x of Prem (us, _, _) => union (op =) vs (terms_vs us) | _ => vs) + (rnd orelse needs_random m) + (filter_out (equal x) ps) + | (_, (_, vs') :: _) :: _ => + if use_random codegen_mode then + check_mode_prems (union (op =) vs (map (fst o fst) vs')) true ps + else NONE + | _ => NONE); val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (get_args is 1 ts)); val in_vs = terms_vs in_ts; in if forall is_eqT (map snd (duplicates (op =) (maps term_vTs in_ts))) andalso forall (is_eqT o fastype_of) in_ts' - then (case check_mode_prems (union (op =) arg_vs in_vs) rnd ps of - NONE => NONE - | SOME (vs, rnd') => - let val missing_vs = missing_vars vs ts - in - if null missing_vs orelse - use_random codegen_mode andalso monomorphic_vars missing_vs - then SOME (rnd' orelse not (null missing_vs)) - else NONE - end) + then + (case check_mode_prems (union (op =) arg_vs in_vs) rnd ps of + NONE => NONE + | SOME (vs, rnd') => + let val missing_vs = missing_vars vs ts + in + if null missing_vs orelse + use_random codegen_mode andalso monomorphic_vars missing_vs + then SOME (rnd' orelse not (null missing_vs)) + else NONE + end) else NONE end; fun check_modes_pred thy codegen_mode arg_vs preds modes (p, ms) = - let val SOME rs = AList.lookup (op =) preds p - in (p, List.mapPartial (fn m as (m', _) => - let val xs = map (check_mode_clause thy codegen_mode arg_vs modes m) rs - in case find_index is_none xs of - ~1 => SOME (m', exists (fn SOME b => b) xs) - | i => (Codegen.message ("Clause " ^ string_of_int (i+1) ^ " of " ^ - p ^ " violates mode " ^ string_of_mode m'); NONE) - end) ms) + let val SOME rs = AList.lookup (op =) preds p in + (p, List.mapPartial (fn m as (m', _) => + let val xs = map (check_mode_clause thy codegen_mode arg_vs modes m) rs in + (case find_index is_none xs of + ~1 => SOME (m', exists (fn SOME b => b) xs) + | i => (Codegen.message ("Clause " ^ string_of_int (i+1) ^ " of " ^ + p ^ " violates mode " ^ string_of_mode m'); NONE)) + end) ms) end; fun fixp f (x : (string * ((int list option list * int list) * bool) list) list) = @@ -297,16 +308,19 @@ | SOME k' => map SOME (subsets 1 k')) ks), subsets 1 k)))) arities); + (**** code generation ****) fun mk_eq (x::xs) = - let fun mk_eqs _ [] = [] - | mk_eqs a (b::cs) = Codegen.str (a ^ " = " ^ b) :: mk_eqs b cs + let + fun mk_eqs _ [] = [] + | mk_eqs a (b :: cs) = Codegen.str (a ^ " = " ^ b) :: mk_eqs b cs; in mk_eqs x xs end; -fun mk_tuple xs = Pretty.block (Codegen.str "(" :: - flat (separate [Codegen.str ",", Pretty.brk 1] (map single xs)) @ - [Codegen.str ")"]); +fun mk_tuple xs = + Pretty.block (Codegen.str "(" :: + flat (separate [Codegen.str ",", Pretty.brk 1] (map single xs)) @ + [Codegen.str ")"]); fun mk_v s (names, vs) = (case AList.lookup (op =) vs s of @@ -362,28 +376,34 @@ apfst single (Codegen.invoke_codegen thy codegen_mode defs dep module brack t gr) | compile_expr _ _ _ _ _ _ _ (SOME _, Var ((name, _), _)) gr = ([Codegen.str name], gr) - | compile_expr thy codegen_mode defs dep module brack modes (SOME (Mode ((mode, _), _, ms)), t) gr = + | compile_expr thy codegen_mode + defs dep module brack modes (SOME (Mode ((mode, _), _, ms)), t) gr = (case strip_comb t of - (Const (name, _), args) => - if name = @{const_name HOL.eq} orelse AList.defined op = modes name then - let - val (args1, args2) = chop (length ms) args; - val ((ps, mode_id), gr') = gr |> fold_map - (compile_expr thy codegen_mode defs dep module true modes) (ms ~~ args1) - ||>> modename module name mode; - val (ps', gr'') = (case mode of + (Const (name, _), args) => + if name = @{const_name HOL.eq} orelse AList.defined op = modes name then + let + val (args1, args2) = chop (length ms) args; + val ((ps, mode_id), gr') = + gr |> fold_map + (compile_expr thy codegen_mode defs dep module true modes) (ms ~~ args1) + ||>> modename module name mode; + val (ps', gr'') = + (case mode of ([], []) => ([Codegen.str "()"], gr') | _ => fold_map - (Codegen.invoke_codegen thy codegen_mode defs dep module true) args2 gr') - in ((if brack andalso not (null ps andalso null ps') then - single o Codegen.parens o Pretty.block else I) - (flat (separate [Pretty.brk 1] - ([Codegen.str mode_id] :: ps @ map single ps'))), gr') + (Codegen.invoke_codegen thy codegen_mode defs dep module true) args2 gr'); + in + ((if brack andalso not (null ps andalso null ps') then + single o Codegen.parens o Pretty.block else I) + (flat (separate [Pretty.brk 1] + ([Codegen.str mode_id] :: ps @ map single ps'))), gr') end - else apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t)))) - (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr) - | _ => apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t)))) - (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr)); + else + apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t)))) + (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr) + | _ => + apfst (single o mk_funcomp brack "??" (length (binder_types (fastype_of t)))) + (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr)); fun compile_clause thy codegen_mode defs dep module all_vs arg_vs modes (iss, is) (ts, ps) inp gr = let @@ -407,13 +427,15 @@ fun compile_prems out_ts' vs names [] gr = let val (out_ps, gr2) = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) out_ts gr; + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) + out_ts gr; val (eq_ps, gr3) = fold_map compile_eq eqs gr2; val (out_ts'', (names', eqs')) = fold_map check_constrt out_ts' (names, []); val (out_ts''', nvs) = fold_map distinct_v out_ts'' (names', map (fn x => (x, [x])) vs); val (out_ps', gr4) = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) out_ts''' gr3; + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) + out_ts''' gr3; val (eq_ps', gr5) = fold_map compile_eq eqs' gr4; val vs' = distinct (op =) (flat (vs :: map term_vs out_ts')); val missing_vs = missing_vars vs' out_ts; @@ -425,17 +447,18 @@ final_p (exists (not o is_exhaustive) out_ts'''), gr5) else let - val (pat_p, gr6) = Codegen.invoke_codegen thy codegen_mode defs dep module true - (HOLogic.mk_tuple (map Var missing_vs)) gr5; + val (pat_p, gr6) = + Codegen.invoke_codegen thy codegen_mode defs dep module true + (HOLogic.mk_tuple (map Var missing_vs)) gr5; val gen_p = Codegen.mk_gen gr6 module true [] "" - (HOLogic.mk_tupleT (map snd missing_vs)) + (HOLogic.mk_tupleT (map snd missing_vs)); in (compile_match (snd nvs) eq_ps' out_ps' - (Pretty.block [Codegen.str "DSeq.generator ", gen_p, - Codegen.str " :->", Pretty.brk 1, - compile_match [] eq_ps [pat_p] final_p false]) - (exists (not o is_exhaustive) out_ts'''), + (Pretty.block [Codegen.str "DSeq.generator ", gen_p, + Codegen.str " :->", Pretty.brk 1, + compile_match [] eq_ps [pat_p] final_p false]) + (exists (not o is_exhaustive) out_ts'''), gr6) end end @@ -443,65 +466,68 @@ let val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); val (out_ts', (names', eqs)) = fold_map check_constrt out_ts (names, []); - val (out_ts'', nvs) = fold_map distinct_v out_ts' (names', map (fn x => (x, [x])) vs); + val (out_ts'', nvs) = + fold_map distinct_v out_ts' (names', map (fn x => (x, [x])) vs); val (out_ps, gr0) = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) out_ts'' gr; + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module false) + out_ts'' gr; val (eq_ps, gr1) = fold_map compile_eq eqs gr0; in (case hd (select_mode_prem thy modes' vs' ps) of - (p as Prem (us, t, is_set), (mode as Mode (_, js, _), []) :: _) => - let - val ps' = filter_out (equal p) ps; - val (in_ts, out_ts''') = get_args js 1 us; - val (in_ps, gr2) = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) in_ts gr1; - val (ps, gr3) = - if not is_set then - apfst (fn ps => ps @ - (if null in_ps then [] else [Pretty.brk 1]) @ - separate (Pretty.brk 1) in_ps) - (compile_expr thy codegen_mode defs dep module false modes - (SOME mode, t) gr2) - else - apfst (fn p => - Pretty.breaks [Codegen.str "DSeq.of_list", Codegen.str "(case", p, - Codegen.str "of", Codegen.str "Set", Codegen.str "xs", Codegen.str "=>", - Codegen.str "xs)"]) - (*this is a very strong assumption about the generated code!*) - (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr2); + (p as Prem (us, t, is_set), (mode as Mode (_, js, _), []) :: _) => + let + val ps' = filter_out (equal p) ps; + val (in_ts, out_ts''') = get_args js 1 us; + val (in_ps, gr2) = + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) + in_ts gr1; + val (ps, gr3) = + if not is_set then + apfst (fn ps => ps @ + (if null in_ps then [] else [Pretty.brk 1]) @ + separate (Pretty.brk 1) in_ps) + (compile_expr thy codegen_mode defs dep module false modes + (SOME mode, t) gr2) + else + apfst (fn p => + Pretty.breaks [Codegen.str "DSeq.of_list", Codegen.str "(case", p, + Codegen.str "of", Codegen.str "Set", Codegen.str "xs", Codegen.str "=>", + Codegen.str "xs)"]) + (*this is a very strong assumption about the generated code!*) + (Codegen.invoke_codegen thy codegen_mode defs dep module true t gr2); val (rest, gr4) = compile_prems out_ts''' vs' (fst nvs) ps' gr3; in (compile_match (snd nvs) eq_ps out_ps - (Pretty.block (ps @ - [Codegen.str " :->", Pretty.brk 1, rest])) - (exists (not o is_exhaustive) out_ts''), gr4) + (Pretty.block (ps @ + [Codegen.str " :->", Pretty.brk 1, rest])) + (exists (not o is_exhaustive) out_ts''), gr4) end - | (p as Sidecond t, [(_, [])]) => - let - val ps' = filter_out (equal p) ps; - val (side_p, gr2) = + | (p as Sidecond t, [(_, [])]) => + let + val ps' = filter_out (equal p) ps; + val (side_p, gr2) = Codegen.invoke_codegen thy codegen_mode defs dep module true t gr1; - val (rest, gr3) = compile_prems [] vs' (fst nvs) ps' gr2; - in - (compile_match (snd nvs) eq_ps out_ps - (Pretty.block [Codegen.str "?? ", side_p, - Codegen.str " :->", Pretty.brk 1, rest]) - (exists (not o is_exhaustive) out_ts''), gr3) - end - | (_, (_, missing_vs) :: _) => - let - val T = HOLogic.mk_tupleT (map snd missing_vs); - val (_, gr2) = + val (rest, gr3) = compile_prems [] vs' (fst nvs) ps' gr2; + in + (compile_match (snd nvs) eq_ps out_ps + (Pretty.block [Codegen.str "?? ", side_p, + Codegen.str " :->", Pretty.brk 1, rest]) + (exists (not o is_exhaustive) out_ts''), gr3) + end + | (_, (_, missing_vs) :: _) => + let + val T = HOLogic.mk_tupleT (map snd missing_vs); + val (_, gr2) = Codegen.invoke_tycodegen thy codegen_mode defs dep module false T gr1; - val gen_p = Codegen.mk_gen gr2 module true [] "" T; - val (rest, gr3) = compile_prems - [HOLogic.mk_tuple (map Var missing_vs)] vs' (fst nvs) ps gr2 - in - (compile_match (snd nvs) eq_ps out_ps - (Pretty.block [Codegen.str "DSeq.generator", Pretty.brk 1, - gen_p, Codegen.str " :->", Pretty.brk 1, rest]) - (exists (not o is_exhaustive) out_ts''), gr3) - end) + val gen_p = Codegen.mk_gen gr2 module true [] "" T; + val (rest, gr3) = compile_prems + [HOLogic.mk_tuple (map Var missing_vs)] vs' (fst nvs) ps gr2; + in + (compile_match (snd nvs) eq_ps out_ps + (Pretty.block [Codegen.str "DSeq.generator", Pretty.brk 1, + gen_p, Codegen.str " :->", Pretty.brk 1, rest]) + (exists (not o is_exhaustive) out_ts''), gr3) + end) end; val (prem_p, gr') = compile_prems in_ts' arg_vs all_vs' ps gr ; @@ -560,9 +586,9 @@ fun constrain cs [] = [] | constrain cs ((s, xs) :: ys) = (s, - case AList.lookup (op =) cs (s : string) of + (case AList.lookup (op =) cs (s : string) of NONE => xs - | SOME xs' => inter (op = o apfst fst) xs' xs) :: constrain cs ys; + | SOME xs' => inter (op = o apfst fst) xs' xs)) :: constrain cs ys; fun mk_extra_defs thy codegen_mode defs gr dep names module ts = fold (fn name => fn gr => @@ -595,7 +621,8 @@ Prem ([t, u], eq, false) | dest_prem (_ $ t) = (case strip_comb t of - (v as Var _, ts) => if member (op =) args v then Prem (ts, v, false) else Sidecond t + (v as Var _, ts) => + if member (op =) args v then Prem (ts, v, false) else Sidecond t | (c as Const (s, _), ts) => (case get_nparms s of NONE => Sidecond t @@ -614,9 +641,10 @@ in (AList.update op = (name, these (AList.lookup op = clauses name) @ [(ts2, prems)]) clauses, - AList.update op = (name, (map (fn U => (case strip_type U of - (Rs as _ :: _, @{typ bool}) => SOME (length Rs) - | _ => NONE)) Ts, + AList.update op = (name, (map (fn U => + (case strip_type U of + (Rs as _ :: _, @{typ bool}) => SOME (length Rs) + | _ => NONE)) Ts, length Us)) arities) end; @@ -629,31 +657,35 @@ (infer_modes thy codegen_mode extra_modes arities arg_vs clauses); val _ = print_arities arities; val _ = print_modes modes; - val (s, gr'') = compile_preds thy codegen_mode defs (hd names) module (terms_vs intrs) - arg_vs (modes @ extra_modes) clauses gr'; + val (s, gr'') = + compile_preds thy codegen_mode defs (hd names) module (terms_vs intrs) + arg_vs (modes @ extra_modes) clauses gr'; in (Codegen.map_node (hd names) (K (SOME (Modes (modes, arities)), module, s)) gr'') end; -fun find_mode gr dep s u modes is = (case find_first (fn Mode (_, js, _) => is=js) - (modes_of modes u handle Option => []) of - NONE => Codegen.codegen_error gr dep - ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is)) - | mode => mode); +fun find_mode gr dep s u modes is = + (case find_first (fn Mode (_, js, _) => is = js) (modes_of modes u handle Option => []) of + NONE => + Codegen.codegen_error gr dep + ("No such mode for " ^ s ^ ": " ^ string_of_mode ([], is)) + | mode => mode); fun mk_ind_call thy codegen_mode defs dep module is_query s T ts names thyname k intrs gr = let val (ts1, ts2) = chop k ts; val u = list_comb (Const (s, T), ts1); - fun mk_mode (Const (@{const_name dummy_pattern}, _)) ((ts, mode), i) = ((ts, mode), i + 1) + fun mk_mode (Const (@{const_name dummy_pattern}, _)) ((ts, mode), i) = + ((ts, mode), i + 1) | mk_mode t ((ts, mode), i) = ((ts @ [t], mode @ [i]), i + 1); val module' = Codegen.if_library codegen_mode thyname module; - val gr1 = mk_extra_defs thy codegen_mode defs - (mk_ind_def thy codegen_mode defs gr dep names module' - [] (prep_intrs intrs) k) dep names module' [u]; + val gr1 = + mk_extra_defs thy codegen_mode defs + (mk_ind_def thy codegen_mode defs gr dep names module' + [] (prep_intrs intrs) k) dep names module' [u]; val (modes, _) = lookup_modes gr1 dep; val (ts', is) = if is_query then fst (fold mk_mode ts2 (([], []), 1)) @@ -662,8 +694,10 @@ val _ = if is_query orelse not (needs_random (the mode)) then () else warning ("Illegal use of random data generators in " ^ s); val (in_ps, gr2) = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) ts' gr1; - val (ps, gr3) = compile_expr thy codegen_mode defs dep module false modes (mode, u) gr2; + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) + ts' gr1; + val (ps, gr3) = + compile_expr thy codegen_mode defs dep module false modes (mode, u) gr2; in (Pretty.block (ps @ (if null in_ps then [] else [Pretty.brk 1]) @ separate (Pretty.brk 1) in_ps), gr3) @@ -680,32 +714,34 @@ end; fun mk_fun thy codegen_mode defs name eqns dep module module' gr = - case try (Codegen.get_node gr) name of + (case try (Codegen.get_node gr) name of NONE => - let - val clauses = map clause_of_eqn eqns; - val pname = name ^ "_aux"; - val arity = length (snd (strip_comb (fst (HOLogic.dest_eq - (HOLogic.dest_Trueprop (concl_of (hd eqns))))))); - val mode = 1 upto arity; - val ((fun_id, mode_id), gr') = gr |> - Codegen.mk_const_id module' name ||>> - modename module' pname ([], mode); - val vars = map (fn i => Codegen.str ("x" ^ string_of_int i)) mode; - val s = Codegen.string_of (Pretty.block - [Codegen.mk_app false (Codegen.str ("fun " ^ snd fun_id)) vars, Codegen.str " =", - Pretty.brk 1, Codegen.str "DSeq.hd", Pretty.brk 1, - Codegen.parens (Pretty.block (separate (Pretty.brk 1) (Codegen.str mode_id :: - vars)))]) ^ ";\n\n"; - val gr'' = mk_ind_def thy codegen_mode defs (Codegen.add_edge (name, dep) - (Codegen.new_node (name, (NONE, module', s)) gr')) name [pname] module' - [(pname, [([], mode)])] clauses 0; - val (modes, _) = lookup_modes gr'' dep; - val _ = find_mode gr'' dep pname (head_of (HOLogic.dest_Trueprop - (Logic.strip_imp_concl (hd clauses)))) modes mode - in (Codegen.mk_qual_id module fun_id, gr'') end + let + val clauses = map clause_of_eqn eqns; + val pname = name ^ "_aux"; + val arity = + length (snd (strip_comb (fst (HOLogic.dest_eq + (HOLogic.dest_Trueprop (concl_of (hd eqns))))))); + val mode = 1 upto arity; + val ((fun_id, mode_id), gr') = gr |> + Codegen.mk_const_id module' name ||>> + modename module' pname ([], mode); + val vars = map (fn i => Codegen.str ("x" ^ string_of_int i)) mode; + val s = Codegen.string_of (Pretty.block + [Codegen.mk_app false (Codegen.str ("fun " ^ snd fun_id)) vars, Codegen.str " =", + Pretty.brk 1, Codegen.str "DSeq.hd", Pretty.brk 1, + Codegen.parens (Pretty.block (separate (Pretty.brk 1) (Codegen.str mode_id :: + vars)))]) ^ ";\n\n"; + val gr'' = mk_ind_def thy codegen_mode defs (Codegen.add_edge (name, dep) + (Codegen.new_node (name, (NONE, module', s)) gr')) name [pname] module' + [(pname, [([], mode)])] clauses 0; + val (modes, _) = lookup_modes gr'' dep; + val _ = find_mode gr'' dep pname (head_of (HOLogic.dest_Trueprop + (Logic.strip_imp_concl (hd clauses)))) modes mode + in (Codegen.mk_qual_id module fun_id, gr'') end | SOME _ => - (Codegen.mk_qual_id module (Codegen.get_const_id gr name), Codegen.add_edge (name, dep) gr); + (Codegen.mk_qual_id module (Codegen.get_const_id gr name), + Codegen.add_edge (name, dep) gr)); (* convert n-tuple to nested pairs *) @@ -730,10 +766,11 @@ else p end; -fun inductive_codegen thy codegen_mode defs dep module brack t gr = (case strip_comb t of +fun inductive_codegen thy codegen_mode defs dep module brack t gr = + (case strip_comb t of (Const (@{const_name Collect}, _), [u]) => - let val (r, Ts, fs) = HOLogic.strip_psplits u - in case strip_comb r of + let val (r, Ts, fs) = HOLogic.strip_psplits u in + (case strip_comb r of (Const (s, T), ts) => (case (get_clauses thy s, Codegen.get_assoc_code thy (s, T)) of (SOME (names, thyname, k, intrs), NONE) => @@ -742,47 +779,55 @@ val ts2' = map (fn Bound i => Term.dummy_pattern (nth Ts (length Ts - i - 1)) | t => t) ts2; val (ots, its) = List.partition is_Bound ts2; - val closed = forall (not o Term.is_open) + val closed = forall (not o Term.is_open); in if null (duplicates op = ots) andalso closed ts1 andalso closed its then - let val (call_p, gr') = mk_ind_call thy codegen_mode defs dep module true - s T (ts1 @ ts2') names thyname k intrs gr - in SOME ((if brack then Codegen.parens else I) (Pretty.block - [Codegen.str "Set", Pretty.brk 1, Codegen.str "(DSeq.list_of", Pretty.brk 1, - Codegen.str "(", conv_ntuple fs ots call_p, Codegen.str "))"]), - (*this is a very strong assumption about the generated code!*) - gr') + let + val (call_p, gr') = + mk_ind_call thy codegen_mode defs dep module true + s T (ts1 @ ts2') names thyname k intrs gr; + in + SOME ((if brack then Codegen.parens else I) (Pretty.block + [Codegen.str "Set", Pretty.brk 1, Codegen.str "(DSeq.list_of", Pretty.brk 1, + Codegen.str "(", conv_ntuple fs ots call_p, Codegen.str "))"]), + (*this is a very strong assumption about the generated code!*) + gr') end else NONE end | _ => NONE) - | _ => NONE + | _ => NONE) end | (Const (s, T), ts) => - (case Symtab.lookup (#eqns (CodegenData.get thy)) s of - NONE => - (case (get_clauses thy s, Codegen.get_assoc_code thy (s, T)) of - (SOME (names, thyname, k, intrs), NONE) => - if length ts < k then NONE else SOME - (let val (call_p, gr') = mk_ind_call thy codegen_mode defs dep module false - s T (map Term.no_dummy_patterns ts) names thyname k intrs gr - in (mk_funcomp brack "?!" - (length (binder_types T) - length ts) (Codegen.parens call_p), gr') - end handle TERM _ => mk_ind_call thy codegen_mode defs dep module true - s T ts names thyname k intrs gr ) - | _ => NONE) - | SOME eqns => - let - val (_, thyname) :: _ = eqns; - val (id, gr') = - mk_fun thy codegen_mode defs s (Codegen.preprocess thy (map fst (rev eqns))) - dep module (Codegen.if_library codegen_mode thyname module) gr; - val (ps, gr'') = - fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) ts gr'; - in SOME (Codegen.mk_app brack (Codegen.str id) ps, gr'') - end) + (case Symtab.lookup (#eqns (CodegenData.get thy)) s of + NONE => + (case (get_clauses thy s, Codegen.get_assoc_code thy (s, T)) of + (SOME (names, thyname, k, intrs), NONE) => + if length ts < k then NONE else + SOME + (let + val (call_p, gr') = mk_ind_call thy codegen_mode defs dep module false + s T (map Term.no_dummy_patterns ts) names thyname k intrs gr + in + (mk_funcomp brack "?!" + (length (binder_types T) - length ts) (Codegen.parens call_p), gr') + end + handle TERM _ => + mk_ind_call thy codegen_mode defs dep module true + s T ts names thyname k intrs gr) + | _ => NONE) + | SOME eqns => + let + val (_, thyname) :: _ = eqns; + val (id, gr') = + mk_fun thy codegen_mode defs s (Codegen.preprocess thy (map fst (rev eqns))) + dep module (Codegen.if_library codegen_mode thyname module) gr; + val (ps, gr'') = + fold_map (Codegen.invoke_codegen thy codegen_mode defs dep module true) + ts gr'; + in SOME (Codegen.mk_app brack (Codegen.str id) ps, gr'') end) | _ => NONE); val setup = @@ -812,7 +857,8 @@ fun deepen bound f i = if i > bound then NONE - else (case f i of + else + (case f i of NONE => deepen bound f (i + 1) | SOME x => SOME x);