# HG changeset patch # User berghofe # Date 999269346 -7200 # Node ID e007d35359c316b6ffe2a527f274dcc4a57c73c3 # Parent 6adf4d53267917445401b97711364ea02ba97b05 New code generators for HOL. diff -r 6adf4d532679 -r e007d35359c3 src/HOL/Tools/basic_codegen.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/basic_codegen.ML Fri Aug 31 16:49:06 2001 +0200 @@ -0,0 +1,249 @@ +(* Title: Pure/HOL/basic_codegen.ML + ID: $Id$ + Author: Stefan Berghofer + Copyright 2000 TU Muenchen + +Code generator for inductive datatypes and recursive functions +*) + +signature BASIC_CODEGEN = +sig + val setup: (theory -> theory) list +end; + +structure BasicCodegen : BASIC_CODEGEN = +struct + +open Codegen; + +fun mk_poly_id thy (s, T) = mk_const_id (sign_of thy) s ^ + (case get_defn thy s T of + Some (_, Some i) => "_def" ^ string_of_int i + | _ => ""); + +fun mk_tuple [p] = p + | mk_tuple ps = Pretty.block (Pretty.str "(" :: + flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @ + [Pretty.str ")"]); + +fun add_rec_funs thy dep (gr, eqs) = + let + fun dest_eq t = + let val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop + (Logic.strip_imp_concl (rename_term t))) + in + (mk_poly_id thy (dest_Const (head_of lhs)), (lhs, rhs)) + end; + val eqs' = sort (string_ord o pairself fst) (map dest_eq eqs); + val (dname, _) :: _ = eqs'; + + fun mk_fundef fname prfx gr [] = (gr, []) + | mk_fundef fname prfx gr ((fname', (lhs, rhs))::xs) = + let + val (gr1, pl) = invoke_codegen thy gr dname false lhs; + val (gr2, pr) = invoke_codegen thy gr1 dname false rhs; + val (gr3, rest) = mk_fundef fname' "and " gr2 xs + in + (gr3, Pretty.blk (4, [Pretty.str (if fname=fname' then " | " else prfx), + pl, Pretty.str " =", Pretty.brk 1, pr]) :: rest) + end + + in + (Graph.add_edge (dname, dep) gr handle Graph.UNDEF _ => + let + val gr1 = Graph.add_edge (dname, dep) + (Graph.new_node (dname, (None, "")) gr); + val (gr2, fundef) = mk_fundef "" "fun " gr1 eqs' + in + Graph.map_node dname (K (None, Pretty.string_of (Pretty.blk (0, + separate Pretty.fbrk fundef @ [Pretty.str ";"])) ^ "\n\n")) gr2 + end) + end; + + +(**** generate functions for datatypes specified by descr ****) +(**** (i.e. constructors and case combinators) ****) + +fun mk_typ _ _ (TVar ((s, i), _)) = + Pretty.str (s ^ (if i=0 then "" else string_of_int i)) + | mk_typ _ _ (TFree (s, _)) = Pretty.str s + | mk_typ sg types (Type ("fun", [T, U])) = Pretty.block [Pretty.str "(", + mk_typ sg types T, Pretty.str " ->", Pretty.brk 1, + mk_typ sg types U, Pretty.str ")"] + | mk_typ sg types (Type (s, Ts)) = Pretty.block ((if null Ts then [] else + [mk_tuple (map (mk_typ sg types) Ts), Pretty.str " "]) @ + [Pretty.str (if_none (assoc (types, s)) (mk_type_id sg s))]); + +fun add_dt_defs thy dep (gr, descr) = + let + val sg = sign_of thy; + val tab = DatatypePackage.get_datatypes thy; + + val descr' = filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr; + + val (_, (_, _, (cname, _) :: _)) :: _ = descr'; + val dname = mk_const_id sg cname; + + fun mk_dtdef gr prfx [] = (gr, []) + | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) = + let + val types = get_assoc_types thy; + val tvs = map DatatypeAux.dest_DtTFree dts; + val sorts = map (rpair []) tvs; + val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs; + val tycons = foldr add_typ_tycons (flat (map snd cs'), []) \\ + ("fun" :: map fst types); + val descrs = map (fn s => case Symtab.lookup (tab, s) of + None => error ("Not a datatype: " ^ s ^ "\nrequired by:\n" ^ + commas (Graph.all_succs gr [dep])) + | Some info => #descr info) tycons; + val gr' = foldl (add_dt_defs thy dname) (gr, descrs); + val (gr'', rest) = mk_dtdef gr' "and " xs + in + (gr'', + Pretty.block (Pretty.str prfx :: + (if null tvs then [] else + [mk_tuple (map Pretty.str tvs), Pretty.str " "]) @ + [Pretty.str (mk_type_id sg tname ^ " ="), Pretty.brk 1] @ + flat (separate [Pretty.brk 1, Pretty.str "| "] + (map (fn (cname, cargs) => [Pretty.block + (Pretty.str (mk_const_id sg cname) :: + (if null cargs then [] else + flat ([Pretty.str " of", Pretty.brk 1] :: + separate [Pretty.str " *", Pretty.brk 1] + (map (single o mk_typ sg types) cargs))))]) cs'))) :: rest) + end + in + ((Graph.add_edge_acyclic (dname, dep) gr + handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ => + let + val gr1 = Graph.add_edge (dname, dep) + (Graph.new_node (dname, (None, "")) gr); + val (gr2, dtdef) = mk_dtdef gr1 "datatype " descr'; + in + Graph.map_node dname (K (None, + Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @ + [Pretty.str ";"])) ^ "\n\n")) gr2 + end) + end; + + +(**** generate code for applications of constructors and case ****) +(**** combinators for datatypes ****) + +fun pretty_case thy gr dep brack constrs (c as Const (_, T)) ts = + let val i = length constrs + in if length ts <= i then + invoke_codegen thy gr dep brack (eta_expand c ts (i+1)) + else + let + val ts1 = take (i, ts); + val t :: ts2 = drop (i, ts); + val names = foldr add_term_names (ts1, + map (fst o fst o dest_Var) (foldr add_term_vars (ts1, []))); + val (Ts, dT) = split_last (take (i+1, fst (strip_type T))); + + fun pcase gr [] [] [] = ([], gr) + | pcase gr ((cname, cargs)::cs) (t::ts) (U::Us) = + let + val j = length cargs; + val (Ts, _) = strip_type (fastype_of t); + val xs = variantlist (replicate j "x", names); + val Us' = take (j, fst (strip_type U)); + val frees = map Free (xs ~~ Us'); + val (gr0, cp) = invoke_codegen thy gr dep false + (list_comb (Const (cname, Us' ---> dT), frees)); + val t' = Envir.beta_norm (list_comb (t, frees)); + val (gr1, p) = invoke_codegen thy gr0 dep false t'; + val (ps, gr2) = pcase gr1 cs ts Us; + in + ([Pretty.block [cp, Pretty.str " =>", Pretty.brk 1, p]] :: ps, gr2) + end; + + val (ps1, gr1) = pcase gr constrs ts1 Ts; + val ps = flat (separate [Pretty.brk 1, Pretty.str "| "] ps1); + val (gr2, p) = invoke_codegen thy gr1 dep false t; + val (gr3, ps2) = foldl_map + (fn (gr, t) => invoke_codegen thy gr dep true t) (gr2, ts2) + in (gr3, (if not (null ts2) andalso brack then parens else I) + (Pretty.block (separate (Pretty.brk 1) + (Pretty.block ([Pretty.str "(case ", p, Pretty.str " of", + Pretty.brk 1] @ ps @ [Pretty.str ")"]) :: ps2)))) + end + end; + + +fun pretty_constr thy gr dep brack args (c as Const (s, _)) ts = + let val i = length args + in if length ts < i then + invoke_codegen thy gr dep brack (eta_expand c ts i) + else + let + val id = mk_const_id (sign_of thy) s; + val (gr', ps) = foldl_map + (fn (gr, t) => invoke_codegen thy gr dep (i = 1) t) (gr, ts); + in (case args of + [] => (gr', Pretty.str id) + | [_] => (gr', mk_app brack (Pretty.str id) ps) + | _ => (gr', (if brack then parens else I) (Pretty.block + ([Pretty.str id, Pretty.brk 1, Pretty.str "("] @ + flat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @ + [Pretty.str ")"])))) + end + end; + + +fun mk_recfun thy gr dep brack s T ts eqns = + let val (gr', ps) = foldl_map + (fn (gr, t) => invoke_codegen thy gr dep true t) (gr, ts) + in + Some (add_rec_funs thy dep (gr', map (#prop o rep_thm) eqns), + mk_app brack (Pretty.str (mk_poly_id thy (s, T))) ps) + end; + + +fun datatype_codegen thy gr dep brack t = (case strip_comb t of + (c as Const (s, T), ts) => + (case find_first (fn (_, {index, descr, case_name, rec_names, ...}) => + s = case_name orelse s mem rec_names orelse + is_some (assoc (#3 (the (assoc (descr, index))), s))) + (Symtab.dest (DatatypePackage.get_datatypes thy)) of + None => None + | Some (tname, {index, descr, case_name, rec_names, rec_rewrites, ...}) => + if is_some (get_assoc_code thy s T) then None else + let + val Some (_, _, constrs) = assoc (descr, index); + val gr1 = + if exists (equal tname o fst) (get_assoc_types thy) then gr + else add_dt_defs thy dep (gr, descr); + in + (case assoc (constrs, s) of + None => if s mem rec_names then + mk_recfun thy gr1 dep brack s T ts rec_rewrites + else Some (pretty_case thy gr1 dep brack constrs c ts) + | Some args => Some (pretty_constr thy gr1 dep brack args c ts)) + end) + | _ => None); + + +(**** generate code for primrec and recdef ****) + +fun recfun_codegen thy gr dep brack t = (case strip_comb t of + (Const (s, T), ts) => + (case PrimrecPackage.get_primrec thy s of + Some ps => (case find_first (fn (_, thm::_) => + is_instance thy T (snd (dest_Const (head_of + (fst (HOLogic.dest_eq + (HOLogic.dest_Trueprop (#prop (rep_thm thm))))))))) ps of + Some (_, thms) => mk_recfun thy gr dep brack s T ts thms + | None => None) + | None => case RecdefPackage.get_recdef thy s of + Some {simps, ...} => mk_recfun thy gr dep brack s T ts simps + | None => None) + | _ => None); + + +val setup = [add_codegen "datatype" datatype_codegen, + add_codegen "primrec+recdef" recfun_codegen]; + +end; diff -r 6adf4d532679 -r e007d35359c3 src/HOL/Tools/inductive_codegen.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/inductive_codegen.ML Fri Aug 31 16:49:06 2001 +0200 @@ -0,0 +1,381 @@ +(* Title: Pure/HOL/inductive_codegen.ML + ID: $Id$ + Author: Stefan Berghofer + Copyright 2000 TU Muenchen + +Code generator for inductive predicates +*) + +signature INDUCTIVE_CODEGEN = +sig + val setup : (theory -> theory) list +end; + +structure InductiveCodegen : INDUCTIVE_CODEGEN = +struct + +open Codegen; + +exception Modes of (string * int list list) list * (string * int list list) list; + +datatype indprem = Prem of string * term list * term list + | Sidecond of term; + +fun prod_factors p (Const ("Pair", _) $ t $ u) = + p :: prod_factors (1::p) t @ prod_factors (2::p) u + | prod_factors p _ = []; + +fun split_prod p ps t = if p mem ps then (case t of + Const ("Pair", _) $ t $ u => + split_prod (1::p) ps t @ split_prod (2::p) ps u + | _ => error "Inconsistent use of products") else [t]; + +fun string_of_factors p ps = if p mem ps then + "(" ^ string_of_factors (1::p) ps ^ ", " ^ string_of_factors (2::p) ps ^ ")" + else "_"; + +(**** check if a term contains only constructor functions ****) + +fun is_constrt thy = + let + val cnstrs = flat (flat (map + (map (fn (_, (_, _, cs)) => map (apsnd length) cs) o #descr o snd) + (Symtab.dest (DatatypePackage.get_datatypes thy)))); + fun check t = (case strip_comb t of + (Var _, []) => true + | (Const (s, _), ts) => (case assoc (cnstrs, s) of + None => false + | Some i => length ts = i andalso forall check ts) + | _ => false) + in check end; + +(**** check if a type is an equality type (i.e. doesn't contain fun) ****) + +fun is_eqT (Type (s, Ts)) = s <> "fun" andalso forall is_eqT Ts + | is_eqT _ = true; + +(**** mode inference ****) + +val term_vs = map (fst o fst o dest_Var) o term_vars; +val terms_vs = distinct o flat o (map term_vs); + +(** collect all Vars in a term (with duplicates!) **) +fun term_vTs t = map (apfst fst o dest_Var) + (filter is_Var (foldl_aterms (op :: o Library.swap) ([], t))); + +fun known_args _ _ [] = [] + | known_args vs i (t::ts) = if term_vs t subset vs then i::known_args vs (i+1) ts + else known_args vs (i+1) ts; + +fun get_args _ _ [] = ([], []) + | get_args is i (x::xs) = (if i mem is then apfst else apsnd) (cons x) + (get_args is (i+1) xs); + +fun merge xs [] = xs + | merge [] ys = ys + | merge (x::xs) (y::ys) = if length x >= length y then x::merge xs (y::ys) + else y::merge (x::xs) ys; + +fun subsets i j = if i <= j then + let val is = subsets (i+1) j + in merge (map (fn ks => i::ks) is) is end + else [[]]; + +fun select_mode_prem thy modes vs ps = + find_first (is_some o snd) (ps ~~ map + (fn Prem (s, us, args) => find_first (fn is => + let + val (_, out_ts) = get_args is 1 us; + val vTs = flat (map term_vTs out_ts); + val dupTs = map snd (duplicates vTs) @ + mapfilter (curry assoc vTs) vs; + in + is subset known_args vs 1 us andalso + forall (is_constrt thy) (snd (get_args is 1 us)) andalso + terms_vs args subset vs andalso + forall is_eqT dupTs + end) + (the (assoc (modes, s))) + | Sidecond t => if term_vs t subset vs then Some [] else None) ps); + +fun check_mode_clause thy arg_vs modes mode (ts, ps) = + let + fun check_mode_prems vs [] = Some vs + | check_mode_prems vs ps = (case select_mode_prem thy modes vs ps of + None => None + | Some (x, _) => check_mode_prems + (case x of Prem (_, us, _) => vs union terms_vs us | _ => vs) + (filter_out (equal x) ps)); + val (in_ts', _) = get_args mode 1 ts; + val in_ts = filter (is_constrt thy) in_ts'; + val in_vs = terms_vs in_ts; + val concl_vs = terms_vs ts + in + forall is_eqT (map snd (duplicates (flat (map term_vTs in_ts')))) andalso + (case check_mode_prems (arg_vs union in_vs) ps of + None => false + | Some vs => concl_vs subset vs) + end; + +fun check_modes_pred thy arg_vs preds modes (p, ms) = + let val Some rs = assoc (preds, p) + in (p, filter (fn m => forall (check_mode_clause thy arg_vs modes m) rs) ms) end + +fun fixp f x = + let val y = f x + in if x = y then x else fixp f y end; + +fun infer_modes thy extra_modes arg_vs preds = fixp (fn modes => + map (check_modes_pred thy arg_vs preds (modes @ extra_modes)) modes) + (map (fn (s, (ts, _)::_) => (s, subsets 1 (length ts))) preds); + +(**** code generation ****) + +fun mk_eq (x::xs) = + let fun mk_eqs _ [] = [] + | mk_eqs a (b::cs) = Pretty.str (a ^ " = " ^ b) :: mk_eqs b cs + in mk_eqs x xs end; + +fun mk_tuple xs = Pretty.block (Pretty.str "(" :: + flat (separate [Pretty.str ",", Pretty.brk 1] (map single xs)) @ + [Pretty.str ")"]); + +fun mk_v ((names, vs), s) = (case assoc (vs, s) of + None => ((names, (s, [s])::vs), s) + | Some xs => let val s' = variant names s in + ((s'::names, overwrite (vs, (s, s'::xs))), s') end); + +fun distinct_v (nvs, Var ((s, 0), T)) = + apsnd (Var o rpair T o rpair 0) (mk_v (nvs, s)) + | distinct_v (nvs, t $ u) = + let + val (nvs', t') = distinct_v (nvs, t); + val (nvs'', u') = distinct_v (nvs', u); + in (nvs'', t' $ u') end + | distinct_v x = x; + +fun compile_match nvs eq_ps out_ps success_p fail_p = + let val eqs = flat (separate [Pretty.str " andalso", Pretty.brk 1] + (map single (flat (map (mk_eq o snd) nvs) @ eq_ps))); + in + Pretty.block + ([Pretty.str "(fn ", mk_tuple out_ps, Pretty.str " =>", Pretty.brk 1] @ + (Pretty.block ((if eqs=[] then [] else Pretty.str "if " :: + [Pretty.block eqs, Pretty.brk 1, Pretty.str "then "]) @ + (success_p :: + (if eqs=[] then [] else [Pretty.brk 1, Pretty.str "else ", fail_p]))) :: + [Pretty.brk 1, Pretty.str "| _ => ", fail_p, Pretty.str ")"])) + end; + +fun modename thy s mode = space_implode "_" + (mk_const_id (sign_of thy) s :: map string_of_int mode); + +fun compile_clause thy gr dep all_vs arg_vs modes mode (ts, ps) = + let + fun check_constrt ((names, eqs), t) = + if is_constrt thy t then ((names, eqs), t) else + let val s = variant names "x"; + in ((s::names, (s, t)::eqs), Var ((s, 0), fastype_of t)) end; + + val (in_ts, out_ts) = get_args mode 1 ts; + val ((all_vs', eqs), in_ts') = + foldl_map check_constrt ((all_vs, []), in_ts); + + fun compile_prems out_ts' vs names gr [] = + let + val (gr2, out_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr, out_ts); + val (gr3, eq_ps) = foldl_map (fn (gr, (s, t)) => + apsnd (Pretty.block o cons (Pretty.str (s ^ " = ")) o single) + (invoke_codegen thy gr dep false t)) (gr2, eqs); + val (nvs, out_ts'') = foldl_map distinct_v + ((names, map (fn x => (x, [x])) vs), out_ts'); + val (gr4, out_ps') = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr3, out_ts''); + in + (gr4, compile_match (snd nvs) eq_ps out_ps' + (Pretty.block [Pretty.str "Seq.single", Pretty.brk 1, mk_tuple out_ps]) + (Pretty.str "Seq.empty")) + end + | compile_prems out_ts vs names gr ps = + let + val vs' = distinct (flat (vs :: map term_vs out_ts)); + val Some (p, Some mode') = + select_mode_prem thy modes (arg_vs union vs') ps; + val ps' = filter_out (equal p) ps; + in + (case p of + Prem (s, us, args) => + let + val (in_ts, out_ts') = get_args mode' 1 us; + val (gr1, in_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr, in_ts); + val (gr2, arg_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep true t) (gr1, args); + val (nvs, out_ts'') = foldl_map distinct_v + ((names, map (fn x => (x, [x])) vs), out_ts); + val (gr3, out_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr2, out_ts'') + val (gr4, rest) = compile_prems out_ts' vs' (fst nvs) gr3 ps'; + in + (gr4, compile_match (snd nvs) [] out_ps + (Pretty.block (separate (Pretty.brk 1) + (Pretty.str (modename thy s mode') :: arg_ps) @ + [Pretty.brk 1, mk_tuple in_ps, + Pretty.str " :->", Pretty.brk 1, rest])) + (Pretty.str "Seq.empty")) + end + | Sidecond t => + let + val (gr1, side_p) = invoke_codegen thy gr dep true t; + val (nvs, out_ts') = foldl_map distinct_v + ((names, map (fn x => (x, [x])) vs), out_ts); + val (gr2, out_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr1, out_ts') + val (gr3, rest) = compile_prems [] vs' (fst nvs) gr2 ps'; + in + (gr3, compile_match (snd nvs) [] out_ps + (Pretty.block [Pretty.str "?? ", side_p, + Pretty.str " :->", Pretty.brk 1, rest]) + (Pretty.str "Seq.empty")) + end) + end; + + val (gr', prem_p) = compile_prems in_ts' [] all_vs' gr ps; + in + (gr', Pretty.block [Pretty.str "Seq.single inp :->", Pretty.brk 1, prem_p]) + end; + +fun compile_pred thy gr dep prfx all_vs arg_vs modes s cls mode = + let val (gr', cl_ps) = foldl_map (fn (gr, cl) => + compile_clause thy gr dep all_vs arg_vs modes mode cl) (gr, cls) + in + ((gr', "and "), Pretty.block + ([Pretty.block (separate (Pretty.brk 1) + (Pretty.str (prfx ^ modename thy s mode) :: map Pretty.str arg_vs) @ + [Pretty.str " inp ="]), + Pretty.brk 1] @ + flat (separate [Pretty.str " ++", Pretty.brk 1] (map single cl_ps)))) + end; + +fun compile_preds thy gr dep all_vs arg_vs modes preds = + let val ((gr', _), prs) = foldl_map (fn ((gr, prfx), (s, cls)) => + foldl_map (fn ((gr', prfx'), mode) => + compile_pred thy gr' dep prfx' all_vs arg_vs modes s cls mode) + ((gr, prfx), the (assoc (modes, s)))) ((gr, "fun "), preds) + in + (gr', space_implode "\n\n" (map Pretty.string_of (flat prs)) ^ ";\n\n") + end; + +(**** processing of introduction rules ****) + +val string_of_mode = enclose "[" "]" o commas o map string_of_int; + +fun print_modes modes = message ("Inferred modes:\n" ^ + space_implode "\n" (map (fn (s, ms) => s ^ ": " ^ commas (map + string_of_mode ms)) modes)); + +fun print_factors factors = message ("Factors:\n" ^ + space_implode "\n" (map (fn (s, fs) => s ^ ": " ^ string_of_factors [] fs) factors)); + +fun get_modes (Some (Modes x), _) = x + | get_modes _ = ([], []); + +fun mk_ind_def thy gr dep names intrs = + let val ids = map (mk_const_id (sign_of thy)) names + in Graph.add_edge (hd ids, dep) gr handle Graph.UNDEF _ => + let + fun process_prem factors (gr, t' as _ $ (Const ("op :", _) $ t $ u)) = + (case strip_comb u of + (Const (name, _), args) => + (case InductivePackage.get_inductive thy name of + None => (gr, Sidecond t') + | Some ({names=names', ...}, {intrs=intrs', ...}) => + (if names = names' then gr + else mk_ind_def thy gr (hd ids) names' intrs', + Prem (name, split_prod [] + (the (assoc (factors, name))) t, args))) + | _ => (gr, Sidecond t')) + | process_prem factors (gr, _ $ (Const ("op =", _) $ t $ u)) = + (gr, Prem ("eq", [t, u], [])) + | process_prem factors (gr, _ $ t) = (gr, Sidecond t); + + fun add_clause factors ((clauses, gr), intr) = + let + val _ $ (_ $ t $ u) = Logic.strip_imp_concl intr; + val (Const (name, _), args) = strip_comb u; + val (gr', prems) = foldl_map (process_prem factors) + (gr, Logic.strip_imp_prems intr); + in + (overwrite (clauses, (name, if_none (assoc (clauses, name)) [] @ + [(split_prod [] (the (assoc (factors, name))) t, prems)])), gr') + end; + + fun add_prod_factors (fs, x as _ $ (Const ("op :", _) $ t $ u)) = + (case strip_comb u of + (Const (name, _), _) => + let val f = prod_factors [] t + in overwrite (fs, (name, f inter if_none (assoc (fs, name)) f)) end + | _ => fs) + | add_prod_factors (fs, _) = fs; + + val intrs' = map (rename_term o #prop o rep_thm o standard) intrs; + val factors = foldl add_prod_factors ([], flat (map (fn t => + Logic.strip_imp_concl t :: Logic.strip_imp_prems t) intrs')); + val (clauses, gr') = foldl (add_clause factors) (([], Graph.add_edge (hd ids, dep) + (Graph.new_node (hd ids, (None, "")) gr)), intrs'); + val _ $ (_ $ _ $ u) = Logic.strip_imp_concl (hd intrs'); + val (_, args) = strip_comb u; + val arg_vs = flat (map term_vs args); + val extra_modes = ("eq", [[1], [2], [1,2]]) :: (flat (map + (fst o get_modes o Graph.get_node gr') (Graph.all_preds gr' [hd ids]))); + val modes = infer_modes thy extra_modes arg_vs clauses; + val _ = print_modes modes; + val _ = print_factors factors; + val (gr'', s) = compile_preds thy gr' (hd ids) (terms_vs intrs') arg_vs + (modes @ extra_modes) clauses; + in + (Graph.map_node (hd ids) (K (Some (Modes (modes, factors)), s)) gr'') + end + end; + +fun mk_ind_call thy gr dep t u is_query = (case strip_comb u of + (Const (s, _), args) => (case InductivePackage.get_inductive thy s of + None => None + | Some ({names, ...}, {intrs, ...}) => + let + fun mk_mode (((ts, mode), i), Var _) = ((ts, mode), i+1) + | mk_mode (((ts, mode), i), Free _) = ((ts, mode), i+1) + | mk_mode (((ts, mode), i), t) = ((ts @ [t], mode @ [i]), i+1); + + val gr1 = mk_ind_def thy gr dep names intrs; + val (modes, factors) = pairself flat (ListPair.unzip + (map (get_modes o Graph.get_node gr1) (Graph.all_preds gr1 [dep]))); + val ts = split_prod [] (the (assoc (factors, s))) t; + val (ts', mode) = if is_query then + fst (foldl mk_mode ((([], []), 1), ts)) + else (ts, 1 upto length ts); + val _ = if mode mem the (assoc (modes, s)) then () else + error ("No such mode for " ^ s ^ ": " ^ string_of_mode mode); + val (gr2, in_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep false t) (gr1, ts'); + val (gr3, arg_ps) = foldl_map (fn (gr, t) => + invoke_codegen thy gr dep true t) (gr2, args); + in + Some (gr3, Pretty.block (separate (Pretty.brk 1) + (Pretty.str (modename thy s mode) :: arg_ps @ [mk_tuple in_ps]))) + end) + | _ => None); + +fun inductive_codegen thy gr dep brack (Const ("op :", _) $ t $ u) = + (case mk_ind_call thy gr dep t u false of + None => None + | Some (gr', call_p) => Some (gr', (if brack then parens else I) + (Pretty.block [Pretty.str "nonempty (", call_p, Pretty.str ")"]))) + | inductive_codegen thy gr dep brack (Free ("query", _) $ (Const ("op :", _) $ t $ u)) = + mk_ind_call thy gr dep t u true + | inductive_codegen thy gr dep brack _ = None; + +val setup = [add_codegen "inductive" inductive_codegen]; + +end;