# HG changeset patch # User blanchet # Date 1383580423 -3600 # Node ID 8fdb4dc08ed10aaf4d401a2cf9e29ea1daef2dc4 # Parent f91022745c8582f9ed435e6b2359e4838c6a0d5a split 'primrec_new' and 'primcorec' code (to ease bootstrapping, e.g. dependency on datatype 'String' in 'primcorec') diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/BNF_FP_Base.thy --- a/src/HOL/BNF/BNF_FP_Base.thy Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/BNF_FP_Base.thy Mon Nov 04 16:53:43 2013 +0100 @@ -172,7 +172,5 @@ ML_file "Tools/bnf_fp_n2m.ML" ML_file "Tools/bnf_fp_n2m_sugar.ML" ML_file "Tools/bnf_fp_rec_sugar_util.ML" -ML_file "Tools/bnf_fp_rec_sugar_tactics.ML" -ML_file "Tools/bnf_fp_rec_sugar.ML" end diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/BNF_GFP.thy --- a/src/HOL/BNF/BNF_GFP.thy Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/BNF_GFP.thy Mon Nov 04 16:53:43 2013 +0100 @@ -308,6 +308,8 @@ lemma fun_rel_image2p: "(fun_rel R (image2p f g R)) f g" unfolding fun_rel_def image2p_def by auto +ML_file "Tools/bnf_gfp_rec_sugar_tactics.ML" +ML_file "Tools/bnf_gfp_rec_sugar.ML" ML_file "Tools/bnf_gfp_util.ML" ML_file "Tools/bnf_gfp_tactics.ML" ML_file "Tools/bnf_gfp.ML" diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/BNF_LFP.thy --- a/src/HOL/BNF/BNF_LFP.thy Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/BNF_LFP.thy Mon Nov 04 16:53:43 2013 +0100 @@ -230,6 +230,7 @@ lemma predicate2D_vimage2p: "\R \ vimage2p f g S; R x y\ \ S (f x) (g y)" unfolding vimage2p_def by auto +ML_file "Tools/bnf_lfp_rec_sugar.ML" ML_file "Tools/bnf_lfp_util.ML" ML_file "Tools/bnf_lfp_tactics.ML" ML_file "Tools/bnf_lfp.ML" diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_def.ML --- a/src/HOL/BNF/Tools/bnf_def.ML Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/Tools/bnf_def.ML Mon Nov 04 16:53:43 2013 +0100 @@ -81,6 +81,9 @@ val mk_rel: int -> typ list -> typ list -> term -> term val build_map: Proof.context -> (typ * typ -> term) -> typ * typ -> term val build_rel: Proof.context -> (typ * typ -> term) -> typ * typ -> term + val flatten_type_args_of_bnf: bnf -> 'a -> 'a list -> 'a list + val map_flattened_map_args: Proof.context -> string -> (term list -> 'a list) -> term list -> + 'a list val mk_witness: int list * term -> thm list -> nonemptiness_witness val minimize_wits: (''a list * 'b) list -> (''a list * 'b) list @@ -88,8 +91,6 @@ val zip_axioms: 'a -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list - val flatten_type_args_of_bnf: bnf -> 'a -> 'a list -> 'a list - datatype const_policy = Dont_Inline | Hardly_Inline | Smart_Inline | Do_Inline datatype fact_policy = Dont_Note | Note_Some | Note_All @@ -524,6 +525,14 @@ val build_map = build_map_or_rel mk_map HOLogic.id_const map_of_bnf dest_funT; val build_rel = build_map_or_rel mk_rel HOLogic.eq_const rel_of_bnf dest_pred2T; +fun map_flattened_map_args ctxt s map_args fs = + let + val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs; + val flat_fs' = map_args flat_fs; + in + permute_like (op aconv) flat_fs fs flat_fs' + end; + (* Names *) diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML Mon Nov 04 15:44:43 2013 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,1128 +0,0 @@ -(* Title: HOL/BNF/Tools/bnf_fp_rec_sugar.ML - Author: Lorenz Panny, TU Muenchen - Copyright 2013 - -Recursor and corecursor sugar. -*) - -signature BNF_FP_REC_SUGAR = -sig - val add_primrec: (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory - val add_primrec_cmd: (binding * string option * mixfix) list -> - (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory - val add_primrec_global: (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory - val add_primrec_overloaded: (string * (string * typ) * bool) list -> - (binding * typ option * mixfix) list -> - (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory - val add_primrec_simple: ((binding * typ) * mixfix) list -> term list -> - local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory - val add_primcorecursive_cmd: bool -> - (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list -> - Proof.context -> Proof.state - val add_primcorec_cmd: bool -> - (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list -> - local_theory -> local_theory -end; - -structure BNF_FP_Rec_Sugar : BNF_FP_REC_SUGAR = -struct - -open BNF_Util -open BNF_FP_Util -open BNF_FP_N2M_Sugar -open BNF_FP_Rec_Sugar_Util -open BNF_FP_Rec_Sugar_Tactics - -val codeN = "code"; -val ctrN = "ctr"; -val discN = "disc"; -val selN = "sel"; - -val nitpicksimp_attrs = @{attributes [nitpick_simp]}; -val simp_attrs = @{attributes [simp]}; -val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs; -val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; - -exception Primrec_Error of string * term list; - -fun primrec_error str = raise Primrec_Error (str, []); -fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]); -fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); - -fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); - -val free_name = try (fn Free (v, _) => v); -val const_name = try (fn Const (v, _) => v); -val undef_const = Const (@{const_name undefined}, dummyT); - -fun permute_args n t = list_comb (t, map Bound (0 :: (n downto 1))) - |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n); -val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple; -fun drop_All t = subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev, - strip_qnt_body @{const_name all} t) -fun abstract vs = - let fun a n (t $ u) = a n t $ a n u - | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b) - | a n t = let val idx = find_index (equal t) vs in - if idx < 0 then t else Bound (n + idx) end - in a 0 end; -fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u; -fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts)); - -fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes - |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE) - |> map_filter I; - - -(* Primrec *) - -type eqn_data = { - fun_name: string, - rec_type: typ, - ctr: term, - ctr_args: term list, - left_args: term list, - right_args: term list, - res_type: typ, - rhs_term: term, - user_eqn: term -}; - -fun dissect_eqn lthy fun_names eqn' = - let - val eqn = drop_All eqn' |> HOLogic.dest_Trueprop - handle TERM _ => - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn'; - val (lhs, rhs) = HOLogic.dest_eq eqn - handle TERM _ => - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn'; - val (fun_name, args) = strip_comb lhs - |>> (fn x => if is_Free x then fst (dest_Free x) - else primrec_error_eqn "malformed function equation (does not start with free)" eqn); - val (left_args, rest) = take_prefix is_Free args; - val (nonfrees, right_args) = take_suffix is_Free rest; - val num_nonfrees = length nonfrees; - val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then - primrec_error_eqn "constructor pattern missing in left-hand side" eqn else - primrec_error_eqn "more than one non-variable argument in left-hand side" eqn; - val _ = member (op =) fun_names fun_name orelse - primrec_error_eqn "malformed function equation (does not start with function name)" eqn - - val (ctr, ctr_args) = strip_comb (the_single nonfrees); - val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse - primrec_error_eqn "partially applied constructor in pattern" eqn; - val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse - primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^ - "\" in left-hand side") eqn end; - val _ = forall is_Free ctr_args orelse - primrec_error_eqn "non-primitive pattern in left-hand side" eqn; - val _ = - let val b = fold_aterms (fn x as Free (v, _) => - if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso - not (member (op =) fun_names v) andalso - not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs [] - in - null b orelse - primrec_error_eqn ("extra variable(s) in right-hand side: " ^ - commas (map (Syntax.string_of_term lthy) b)) eqn - end; - in - {fun_name = fun_name, - rec_type = body_type (type_of ctr), - ctr = ctr, - ctr_args = ctr_args, - left_args = left_args, - right_args = right_args, - res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs, - rhs_term = rhs, - user_eqn = eqn'} - end; - -fun rewrite_map_arg get_ctr_pos rec_type res_type = - let - val pT = HOLogic.mk_prodT (rec_type, res_type); - - val maybe_suc = Option.map (fn x => x + 1); - fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT) - | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b) - | subst d t = - let - val (u, vs) = strip_comb t; - val ctr_pos = try (get_ctr_pos o the) (free_name u) |> the_default ~1; - in - if ctr_pos >= 0 then - if d = SOME ~1 andalso length vs = ctr_pos then - list_comb (permute_args ctr_pos (snd_const pT), vs) - else if length vs > ctr_pos andalso is_some d - andalso d = try (fn Bound n => n) (nth vs ctr_pos) then - list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs)) - else - primrec_error_eqn ("recursive call not directly applied to constructor argument") t - else - list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs) - end - in - subst (SOME ~1) - end; - -fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls = - let - fun try_nested_rec bound_Ts y t = - AList.lookup (op =) nested_calls y - |> Option.map (fn y' => - massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t); - - fun subst bound_Ts (t as g' $ y) = - let - fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y; - val y_head = head_of y; - in - if not (member (op =) ctr_args y_head) then - subst_rec () - else - (case try_nested_rec bound_Ts y_head t of - SOME t' => t' - | NONE => - let val (g, g_args) = strip_comb g' in - (case try (get_ctr_pos o the) (free_name g) of - SOME ctr_pos => - (length g_args >= ctr_pos orelse - primrec_error_eqn "too few arguments in recursive call" t; - (case AList.lookup (op =) mutual_calls y of - SOME y' => list_comb (y', g_args) - | NONE => subst_rec ())) - | NONE => subst_rec ()) - end) - end - | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b) - | subst _ t = t - - fun subst' t = - if has_call t then - (* FIXME detect this case earlier? *) - primrec_error_eqn "recursive call not directly applied to constructor argument" t - else - try_nested_rec [] (head_of t) t |> the_default t - in - subst' o subst [] - end; - -fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec) - (maybe_eqn_data : eqn_data option) = - (case maybe_eqn_data of - NONE => undef_const - | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} => - let - val calls = #calls ctr_spec; - val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0; - - val no_calls' = tag_list 0 calls - |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p))); - val mutual_calls' = tag_list 0 calls - |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p))); - val nested_calls' = tag_list 0 calls - |> map_filter (try (apsnd (fn Nested_Rec p => p))); - - val args = replicate n_args ("", dummyT) - |> Term.rename_wrt_term t - |> map Free - |> fold (fn (ctr_arg_idx, (arg_idx, _)) => - nth_map arg_idx (K (nth ctr_args ctr_arg_idx))) - no_calls' - |> fold (fn (ctr_arg_idx, (arg_idx, T)) => - nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx)))) - mutual_calls' - |> fold (fn (ctr_arg_idx, (arg_idx, T)) => - nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx)))) - nested_calls'; - - val fun_name_ctr_pos_list = - map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data; - val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1; - val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls'; - val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls'; - in - t - |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls - |> fold_rev lambda (args @ left_args @ right_args) - end); - -fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call = - let - val n_funs = length funs_data; - - val ctr_spec_eqn_data_list' = - (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data - |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y)) - ##> (fn x => null x orelse - primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst); - val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse - primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x)); - - val ctr_spec_eqn_data_list = - ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); - - val recs = take n_funs rec_specs |> map #recx; - val rec_args = ctr_spec_eqn_data_list - |> sort ((op <) o pairself (#offset o fst) |> make_ord) - |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single)); - val ctr_poss = map (fn x => - if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then - primrec_error ("inconstant constructor pattern position for function " ^ - quote (#fun_name (hd x))) - else - hd x |> #left_args |> length) funs_data; - in - (recs, ctr_poss) - |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos) - |> Syntax.check_terms lthy - |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) - bs mxs - end; - -fun find_rec_calls ctxt has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) = - let - fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg - | find bound_Ts (t as _ $ _) ctr_arg = - let - val typof = curry fastype_of1 bound_Ts; - val (f', args') = strip_comb t; - val n = find_index (equal ctr_arg o head_of) args'; - in - if n < 0 then - find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args' - else - let - val (f, args as arg :: _) = chop n args' |>> curry list_comb f' - val (arg_head, arg_args) = Term.strip_comb arg; - in - if has_call f then - mk_partial_compN (length arg_args) (typof f) (typof arg_head) f :: - maps (fn x => find bound_Ts x ctr_arg) args - else - find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args - end - end - | find _ _ _ = []; - in - map (find [] rhs_term) ctr_args - |> (fn [] => NONE | callss => SOME (ctr, callss)) - end; - -fun prepare_primrec fixes specs lthy = - let - val (bs, mxs) = map_split (apfst fst) fixes; - val fun_names = map Binding.name_of bs; - val eqns_data = map (dissect_eqn lthy fun_names) specs; - val funs_data = eqns_data - |> partition_eq ((op =) o pairself #fun_name) - |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst - |> map (fn (x, y) => the_single y handle List.Empty => - primrec_error ("missing equations for function " ^ quote x)); - - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); - val arg_Ts = map (#rec_type o hd) funs_data; - val res_Ts = map (#res_type o hd) funs_data; - val callssss = funs_data - |> map (partition_eq ((op =) o pairself #ctr)) - |> map (maps (map_filter (find_rec_calls lthy has_call))); - - val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') = - rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; - - val actual_nn = length funs_data; - - val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in - map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse - primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^ - " is not a constructor in left-hand side") user_eqn) eqns_data end; - - val defs = build_defs lthy' bs mxs funs_data rec_specs has_call; - - fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) - (fun_data : eqn_data list) = - let - val def_thms = map (snd o snd) def_thms'; - val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs - |> fst - |> map_filter (try (fn (x, [y]) => - (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y))) - |> map (fn (user_eqn, num_extra_args, rec_thm) => - mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm - |> K |> Goal.prove lthy [] [] user_eqn - |> Thm.close_derivation); - val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data; - in - (poss, simp_thmss) - end; - - val notes = - (if n2m then map2 (fn name => fn thm => - (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else []) - |> map (fn (prefix, thmN, thms, attrs) => - ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])])); - - val common_name = mk_common_name fun_names; - - val common_notes = - (if n2m then [(inductN, [induct_thm], [])] else []) - |> map (fn (thmN, thms, attrs) => - ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); - in - (((fun_names, defs), - fn lthy => fn defs => - split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), - lthy' |> Local_Theory.notes (notes @ common_notes) |> snd) - end; - -(* primrec definition *) - -fun add_primrec_simple fixes ts lthy = - let - val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy - handle ERROR str => primrec_error str; - in - lthy - |> fold_map Local_Theory.define defs - |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs)))) - end - handle Primrec_Error (str, eqns) => - if null eqns - then error ("primrec_new error:\n " ^ str) - else error ("primrec_new error:\n " ^ str ^ "\nin\n " ^ - space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); - -local - -fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy = - let - val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) - val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d); - - val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy); - - val mk_notes = - flat ooo map3 (fn poss => fn prefix => fn thms => - let - val (bs, attrss) = map_split (fst o nth specs) poss; - val notes = - map3 (fn b => fn attrs => fn thm => - ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])])) - bs attrss thms; - in - ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes - end); - in - lthy - |> add_primrec_simple fixes (map snd specs) - |-> (fn (names, (ts, (posss, simpss))) => - Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) - #> Local_Theory.notes (mk_notes posss names simpss) - #>> pair ts o map snd) - end; - -in - -val add_primrec = gen_primrec Specification.check_spec; -val add_primrec_cmd = gen_primrec Specification.read_spec; - -end; - -fun add_primrec_global fixes specs thy = - let - val lthy = Named_Target.theory_init thy; - val ((ts, simps), lthy') = add_primrec fixes specs lthy; - val simps' = burrow (Proof_Context.export lthy' lthy) simps; - in ((ts, simps'), Local_Theory.exit_global lthy') end; - -fun add_primrec_overloaded ops fixes specs thy = - let - val lthy = Overloading.overloading ops thy; - val ((ts, simps), lthy') = add_primrec fixes specs lthy; - val simps' = burrow (Proof_Context.export lthy' lthy) simps; - in ((ts, simps'), Local_Theory.exit_global lthy') end; - - - -(* Primcorec *) - -type coeqn_data_disc = { - fun_name: string, - fun_T: typ, - fun_args: term list, - ctr: term, - ctr_no: int, (*###*) - disc: term, - prems: term list, - auto_gen: bool, - maybe_ctr_rhs: term option, - maybe_code_rhs: term option, - user_eqn: term -}; - -type coeqn_data_sel = { - fun_name: string, - fun_T: typ, - fun_args: term list, - ctr: term, - sel: term, - rhs_term: term, - user_eqn: term -}; - -datatype coeqn_data = - Disc of coeqn_data_disc | - Sel of coeqn_data_sel; - -fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) - maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss = - let - fun find_subterm p = let (* FIXME \? *) - fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v) - | f t = if p t then SOME t else NONE - in f end; - - val applied_fun = concl - |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of)) - |> the - handle Option.Option => primrec_error_eqn "malformed discriminator formula" concl; - val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free; - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; - - val discs = map #disc basic_ctr_specs; - val ctrs = map #ctr basic_ctr_specs; - val not_disc = head_of concl = @{term Not}; - val _ = not_disc andalso length ctrs <> 2 andalso - primrec_error_eqn "negated discriminator for a type with \ 2 constructors" concl; - val disc' = find_subterm (member (op =) discs o head_of) concl; - val eq_ctr0 = concl |> perhaps (try HOLogic.dest_not) |> try (HOLogic.dest_eq #> snd) - |> (fn SOME t => let val n = find_index (equal t) ctrs in - if n >= 0 then SOME n else NONE end | _ => NONE); - val _ = is_some disc' orelse is_some eq_ctr0 orelse - primrec_error_eqn "no discriminator in equation" concl; - val ctr_no' = - if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs; - val ctr_no = if not_disc then 1 - ctr_no' else ctr_no'; - val {ctr, disc, ...} = nth basic_ctr_specs ctr_no; - - val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_; - val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default []; - val prems = map (abstract (List.rev fun_args)) prems'; - val real_prems = - (if catch_all orelse seq then maps s_not_conj matchedss else []) @ - (if catch_all then [] else prems); - - val matchedsss' = AList.delete (op =) fun_name matchedsss - |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]); - - val user_eqn = - (real_prems, concl) - |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop o abstract (List.rev fun_args) - |> curry Logic.list_all (map dest_Free fun_args) o Logic.list_implies; - in - (Disc { - fun_name = fun_name, - fun_T = fun_T, - fun_args = fun_args, - ctr = ctr, - ctr_no = ctr_no, - disc = disc, - prems = real_prems, - auto_gen = catch_all, - maybe_ctr_rhs = maybe_ctr_rhs, - maybe_code_rhs = maybe_code_rhs, - user_eqn = user_eqn - }, matchedsss') - end; - -fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' - maybe_of_spec eqn = - let - val (lhs, rhs) = HOLogic.dest_eq eqn - handle TERM _ => - primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn; - val sel = head_of lhs; - val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free - handle TERM _ => - primrec_error_eqn "malformed selector argument in left-hand side" eqn; - val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name) - handle Option.Option => primrec_error_eqn "malformed selector argument in left-hand side" eqn; - val {ctr, ...} = - (case maybe_of_spec of - SOME of_spec => the (find_first (equal of_spec o #ctr) basic_ctr_specs) - | NONE => filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single - handle List.Empty => primrec_error_eqn "ambiguous selector - use \"of\"" eqn); - val user_eqn = drop_All eqn'; - in - Sel { - fun_name = fun_name, - fun_T = fun_T, - fun_args = fun_args, - ctr = ctr, - sel = sel, - rhs_term = rhs, - user_eqn = user_eqn - } - end; - -fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' - maybe_code_rhs prems concl matchedsss = - let - val (lhs, rhs) = HOLogic.dest_eq concl; - val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; - val (ctr, ctr_args) = strip_comb (unfold_let rhs); - val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs) - handle Option.Option => primrec_error_eqn "not a constructor" ctr; - - val disc_concl = betapply (disc, lhs); - val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1 - then (NONE, matchedsss) - else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss - (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss); - - val sel_concls = sels ~~ ctr_args - |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg)); - -(* -val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} concl ^ "\nto\n \ " ^ - (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n \ ")) "" ^ - space_implode "\n \ " (map (Syntax.string_of_term @{context}) sel_concls) ^ - "\nfor premise(s)\n \ " ^ - space_implode "\n \ " (map (Syntax.string_of_term @{context}) prems)); -*) - - val eqns_data_sel = - map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls; - in - (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss') - end; - -fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss = - let - val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []); - val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; - val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; - - val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ => - if member ((op =) o apsnd #ctr) basic_ctr_specs ctr - then cons (ctr, cs) - else primrec_error_eqn "not a constructor" ctr) [] rhs' [] - |> AList.group (op =); - - val ctr_premss = (case cond_ctrs of [_] => [[]] | _ => map (s_dnf o snd) cond_ctrs); - val ctr_concls = cond_ctrs |> map (fn (ctr, _) => - binder_types (fastype_of ctr) - |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args => - if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs') - |> curry list_comb ctr - |> curry HOLogic.mk_eq lhs); - in - fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn' - (SOME (abstract (List.rev fun_args) rhs))) - ctr_premss ctr_concls matchedsss - end; - -fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) - eqn' maybe_of_spec matchedsss = - let - val eqn = drop_All eqn' - handle TERM _ => primrec_error_eqn "malformed function equation" eqn'; - val (prems, concl) = Logic.strip_horn eqn - |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop; - - val head = concl - |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq)) - |> head_of; - - val maybe_rhs = concl |> perhaps (try HOLogic.dest_not) |> try (snd o HOLogic.dest_eq); - - val discs = maps (map #disc) basic_ctr_specss; - val sels = maps (maps #sels) basic_ctr_specss; - val ctrs = maps (map #ctr) basic_ctr_specss; - in - if member (op =) discs head orelse - is_some maybe_rhs andalso - member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then - dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss - |>> single - else if member (op =) sels head then - ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' maybe_of_spec concl], matchedsss) - else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso - member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then - dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss - else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso - null prems then - dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss - |>> flat - else - primrec_error_eqn "malformed function equation" eqn - end; - -fun build_corec_arg_disc (ctr_specs : corec_ctr_spec list) - ({fun_args, ctr_no, prems, ...} : coeqn_data_disc) = - if is_none (#pred (nth ctr_specs ctr_no)) then I else - s_conjs prems - |> curry subst_bounds (List.rev fun_args) - |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args) - |> K |> nth_map (the (#pred (nth ctr_specs ctr_no))); - -fun build_corec_arg_no_call (sel_eqns : coeqn_data_sel list) sel = - find_first (equal sel o #sel) sel_eqns - |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term) - |> the_default undef_const - |> K; - -fun build_corec_args_mutual_call lthy has_call (sel_eqns : coeqn_data_sel list) sel = - (case find_first (equal sel o #sel) sel_eqns of - NONE => (I, I, I) - | SOME {fun_args, rhs_term, ... } => - let - val bound_Ts = List.rev (map fastype_of fun_args); - fun rewrite_stop _ t = if has_call t then @{term False} else @{term True}; - fun rewrite_end _ t = if has_call t then undef_const else t; - fun rewrite_cont bound_Ts t = - if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const; - fun massage f _ = massage_mutual_corec_call lthy has_call f bound_Ts rhs_term - |> abs_tuple fun_args; - in - (massage rewrite_stop, massage rewrite_end, massage rewrite_cont) - end); - -fun build_corec_arg_nested_call lthy has_call (sel_eqns : coeqn_data_sel list) sel = - (case find_first (equal sel o #sel) sel_eqns of - NONE => I - | SOME {fun_args, rhs_term, ...} => - let - val bound_Ts = List.rev (map fastype_of fun_args); - fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b) - | rewrite bound_Ts U T (t as _ $ _) = - let val (u, vs) = strip_comb t in - if is_Free u andalso has_call u then - Inr_const U T $ mk_tuple1 bound_Ts vs - else if const_name u = SOME @{const_name prod_case} then - map (rewrite bound_Ts U T) vs |> chop 1 |>> HOLogic.mk_split o the_single |> list_comb - else - list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs) - end - | rewrite _ U T t = - if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t; - fun massage t = - rhs_term - |> massage_nested_corec_call lthy has_call rewrite bound_Ts (range_type (fastype_of t)) - |> abs_tuple fun_args; - in - massage - end); - -fun build_corec_args_sel lthy has_call (all_sel_eqns : coeqn_data_sel list) - (ctr_spec : corec_ctr_spec) = - (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of - [] => I - | sel_eqns => - let - val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec; - val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list; - val mutual_calls' = map_filter (try (apsnd (fn Mutual_Corec n => n))) sel_call_list; - val nested_calls' = map_filter (try (apsnd (fn Nested_Corec n => n))) sel_call_list; - in - I - #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls' - #> fold (fn (sel, (q, g, h)) => - let val (fq, fg, fh) = build_corec_args_mutual_call lthy has_call sel_eqns sel in - nth_map q fq o nth_map g fg o nth_map h fh end) mutual_calls' - #> fold (fn (sel, n) => nth_map n - (build_corec_arg_nested_call lthy has_call sel_eqns sel)) nested_calls' - end); - -fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list) - (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) = - let - val corecs = map #corec corec_specs; - val ctr_specss = map #ctr_specs corec_specs; - val corec_args = hd corecs - |> fst o split_last o binder_types o fastype_of - |> map (Const o pair @{const_name undefined}) - |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss - |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss; - fun currys [] t = t - | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0)) - |> fold_rev (Term.abs o pair Name.uu) Ts; - -(* -val _ = tracing ("corecursor arguments:\n \ " ^ - space_implode "\n \ " (map (Syntax.string_of_term lthy) corec_args)); -*) - - val exclss' = - disc_eqnss - |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x)) - #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs []) - #> maps (uncurry (map o pair) - #> map (fn ((fun_args, c, x, a), (_, c', y, a')) => - ((c, c', a orelse a'), (x, s_not (s_conjs y))) - ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop - ||> Logic.list_implies - ||> curry Logic.list_all (map dest_Free fun_args)))) - in - map (list_comb o rpair corec_args) corecs - |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss - |> map2 currys arg_Tss - |> Syntax.check_terms lthy - |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) - bs mxs - |> rpair exclss' - end; - -fun mk_real_disc_eqns fun_binding arg_Ts ({ctr_specs, ...} : corec_spec) - (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) = - if length disc_eqns <> length ctr_specs - 1 then disc_eqns else - let - val n = 0 upto length ctr_specs - |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)); - val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns) - |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options; - val extra_disc_eqn = { - fun_name = Binding.name_of fun_binding, - fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))), - fun_args = fun_args, - ctr = #ctr (nth ctr_specs n), - ctr_no = n, - disc = #disc (nth ctr_specs n), - prems = maps (s_not_conj o #prems) disc_eqns, - auto_gen = true, - maybe_ctr_rhs = NONE, - maybe_code_rhs = NONE, - user_eqn = undef_const}; - in - chop n disc_eqns ||> cons extra_disc_eqn |> (op @) - end; - -fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) = - let - val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs - |> find_index (equal sel) o #sels o the; - fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else []; - in - find rhs_term - |> K |> nth_map sel_no |> AList.map_entry (op =) ctr - end; - -fun add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy = - let - val (bs, mxs) = map_split (apfst fst) fixes; - val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list; - - val fun_names = map Binding.name_of bs; - val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts; - val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); - val eqns_data = - fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs) - maybe_of_specs [] - |> flat o fst; - - val callssss = - map_filter (try (fn Sel x => x)) eqns_data - |> partition_eq ((op =) o pairself #fun_name) - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names - |> map (flat o snd) - |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss - |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} => - (ctr, map (K []) sels))) basic_ctr_specss); - -(* -val _ = tracing ("callssss = " ^ @{make_string} callssss); -*) - - val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms, - strong_coinduct_thms), lthy') = - corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; - val actual_nn = length bs; - val corec_specs = take actual_nn corec_specs'; (*###*) - val ctr_specss = map #ctr_specs corec_specs; - - val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data - |> partition_eq ((op =) o pairself #fun_name) - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names - |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd); - val _ = disc_eqnss' |> map (fn x => - let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse - primrec_error_eqns "excess discriminator formula in definition" - (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end); - - val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data - |> partition_eq ((op =) o pairself #fun_name) - |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names - |> map (flat o snd); - - val arg_Tss = map (binder_types o snd o fst) fixes; - val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss'; - val (defs, exclss') = - build_codefs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss; - - fun excl_tac (c, c', a) = - if a orelse c = c' orelse seq then SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy []))) - else maybe_tac; - -(* -val _ = tracing ("exclusiveness properties:\n \ " ^ - space_implode "\n \ " (maps (map (Syntax.string_of_term lthy o snd)) exclss')); -*) - - val exclss'' = exclss' |> map (map (fn (idx, t) => - (idx, (Option.map (Goal.prove lthy [] [] t #> Thm.close_derivation) (excl_tac idx), t)))); - val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss''; - val (goal_idxss, goalss) = exclss'' - |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd)) - |> split_list o map split_list; - - fun prove thmss' def_thms' lthy = - let - val def_thms = map (snd o snd) def_thms'; - - val exclss' = map (op ~~) (goal_idxss ~~ thmss'); - fun mk_exclsss excls n = - (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1)) - |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm]))); - val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs) - |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs)); - - fun prove_disc ({ctr_specs, ...} : corec_spec) exclsss - ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) = - if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\x. x = x"}) then [] else - let - val {disc_corec, ...} = nth ctr_specs ctr_no; - val k = 1 + ctr_no; - val m = length prems; - val t = - list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) - |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*) - |> HOLogic.mk_Trueprop - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) - |> curry Logic.list_all (map dest_Free fun_args); - in - if prems = [@{term False}] then [] else - mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss - |> K |> Goal.prove lthy [] [] t - |> Thm.close_derivation - |> pair (#disc (nth ctr_specs ctr_no)) - |> single - end; - - fun prove_sel ({nested_maps, nested_map_idents, nested_map_comps, ctr_specs, ...} - : corec_spec) (disc_eqns : coeqn_data_disc list) exclsss - ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) = - let - val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs; - val ctr_no = find_index (equal ctr o #ctr) ctr_specs; - val prems = the_default (maps (s_not_conj o #prems) disc_eqns) - (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems); - val sel_corec = find_index (equal sel) (#sels ctr_spec) - |> nth (#sel_corecs ctr_spec); - val k = 1 + ctr_no; - val m = length prems; - val t = - list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) - |> curry betapply sel - |> rpair (abstract (List.rev fun_args) rhs_term) - |> HOLogic.mk_Trueprop o HOLogic.mk_eq - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) - |> curry Logic.list_all (map dest_Free fun_args); - val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term; - in - mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps - nested_map_idents nested_map_comps sel_corec k m exclsss - |> K |> Goal.prove lthy [] [] t - |> Thm.close_derivation - |> pair sel - end; - - fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list) - (sel_eqns : coeqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) = - (* don't try to prove theorems when some sel_eqns are missing *) - if not (exists (equal ctr o #ctr) disc_eqns) - andalso not (exists (equal ctr o #ctr) sel_eqns) - orelse - filter (equal ctr o #ctr) sel_eqns - |> fst o finds ((op =) o apsnd #sel) sels - |> exists (null o snd) - then [] else - let - val (fun_name, fun_T, fun_args, prems, maybe_rhs) = - (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns) - |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x, - #maybe_ctr_rhs x)) - ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], NONE)) - |> the o merge_options; - val m = length prems; - val t = (if is_some maybe_rhs then the maybe_rhs else - filter (equal ctr o #ctr) sel_eqns - |> fst o finds ((op =) o apsnd #sel) sels - |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract) - |> curry list_comb ctr) - |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), - map Bound (length fun_args - 1 downto 0))) - |> HOLogic.mk_Trueprop - |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) - |> curry Logic.list_all (map dest_Free fun_args); - val maybe_disc_thm = AList.lookup (op =) disc_alist disc; - val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist); - in - if prems = [@{term False}] then [] else - mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms - |> K |> Goal.prove lthy [] [] t - |> Thm.close_derivation - |> pair ctr - |> single - end; - - fun prove_code disc_eqns sel_eqns ctr_alist ctr_specs = - let - val (fun_name, fun_T, fun_args, maybe_rhs) = - (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns, - find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns) - |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x)) - ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE)) - |> the o merge_options; - - val bound_Ts = List.rev (map fastype_of fun_args); - - val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)); - val maybe_rhs_info = - (case maybe_rhs of - SOME rhs => - let - val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs; - val cond_ctrs = - fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs []; - val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs; - in SOME (rhs, raw_rhs, ctr_thms) end - | NONE => - let - fun prove_code_ctr {ctr, sels, ...} = - if not (exists (equal ctr o fst) ctr_alist) then NONE else - let - val prems = find_first (equal ctr o #ctr) disc_eqns - |> Option.map #prems |> the_default []; - val t = - filter (equal ctr o #ctr) sel_eqns - |> fst o finds ((op =) o apsnd #sel) sels - |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) - #-> abstract) - |> curry list_comb ctr; - in - SOME (prems, t) - end; - val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs; - in - if exists is_none maybe_ctr_conds_argss then NONE else - let - val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t) - maybe_ctr_conds_argss - (Const (@{const_name Code.abort}, @{typ String.literal} --> - (@{typ unit} --> body_type fun_T) --> body_type fun_T) $ - HOLogic.mk_literal fun_name $ - absdummy @{typ unit} (incr_boundvars 1 lhs)); - in SOME (rhs, rhs, map snd ctr_alist) end - end); - in - (case maybe_rhs_info of - NONE => [] - | SOME (rhs, raw_rhs, ctr_thms) => - let - val ms = map (Logic.count_prems o prop_of) ctr_thms; - val (raw_t, t) = (raw_rhs, rhs) - |> pairself - (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), - map Bound (length fun_args - 1 downto 0))) - #> HOLogic.mk_Trueprop - #> curry Logic.list_all (map dest_Free fun_args)); - val (distincts, discIs, sel_splits, sel_split_asms) = - case_thms_of_term lthy bound_Ts raw_rhs; - - val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits - sel_split_asms ms ctr_thms - |> K |> Goal.prove lthy [] [] raw_t - |> Thm.close_derivation; - in - mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm - |> K |> Goal.prove lthy [] [] t - |> Thm.close_derivation - |> single - end) - end; - - val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss; - val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss; - val disc_thmss = map (map snd) disc_alists; - val sel_thmss = map (map snd) sel_alists; - - val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss - ctr_specss; - val ctr_thmss = map (map snd) ctr_alists; - - val code_thmss = map4 prove_code disc_eqnss sel_eqnss ctr_alists ctr_specss; - - val simp_thmss = map2 append disc_thmss sel_thmss - - val common_name = mk_common_name fun_names; - - val notes = - [(coinductN, map (if n2m then single else K []) coinduct_thms, []), - (codeN, code_thmss, code_nitpicksimp_attrs), - (ctrN, ctr_thmss, []), - (discN, disc_thmss, simp_attrs), - (selN, sel_thmss, simp_attrs), - (simpsN, simp_thmss, []), - (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])] - |> maps (fn (thmN, thmss, attrs) => - map2 (fn fun_name => fn thms => - ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])])) - fun_names (take actual_nn thmss)) - |> filter_out (null o fst o hd o snd); - - val common_notes = - [(coinductN, if n2m then [coinduct_thm] else [], []), - (strong_coinductN, if n2m then [strong_coinduct_thm] else [], [])] - |> filter_out (null o #2) - |> map (fn (thmN, thms, attrs) => - ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); - in - lthy |> Local_Theory.notes (notes @ common_notes) |> snd - end; - - fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss'; - in - (goalss, after_qed, lthy') - end; - -fun add_primcorec_ursive_cmd maybe_tac seq (raw_fixes, raw_specs') lthy = - let - val (raw_specs, maybe_of_specs) = - split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy)); - val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy; - in - add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy - handle ERROR str => primrec_error str - end - handle Primrec_Error (str, eqns) => - if null eqns - then error ("primcorec error:\n " ^ str) - else error ("primcorec error:\n " ^ str ^ "\nin\n " ^ - space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); - -val add_primcorecursive_cmd = (fn (goalss, after_qed, lthy) => - lthy - |> Proof.theorem NONE after_qed goalss - |> Proof.refine (Method.primitive_text I) - |> Seq.hd) ooo add_primcorec_ursive_cmd NONE; - -val add_primcorec_cmd = (fn (goalss, after_qed, lthy) => - lthy - |> after_qed (map (fn [] => [] - | _ => primrec_error "need exclusiveness proofs - use primcorecursive instead of primcorec") - goalss)) ooo add_primcorec_ursive_cmd (SOME (fn {context = ctxt, ...} => auto_tac ctxt)); - -end; diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_fp_rec_sugar_tactics.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_tactics.ML Mon Nov 04 15:44:43 2013 +0100 +++ /dev/null Thu Jan 01 00:00:00 1970 +0000 @@ -1,142 +0,0 @@ -(* Title: HOL/BNF/Tools/bnf_fp_rec_sugar_tactics.ML - Author: Jasmin Blanchette, TU Muenchen - Copyright 2013 - -Tactics for recursor and corecursor sugar. -*) - -signature BNF_FP_REC_SUGAR_TACTICS = -sig - val mk_primcorec_assumption_tac: Proof.context -> thm list -> int -> tactic - val mk_primcorec_code_of_raw_code_tac: Proof.context -> thm list -> thm list -> thm -> tactic - val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic - val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list -> - tactic - val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list -> - thm list -> int list -> thm list -> tactic - val mk_primcorec_sel_tac: Proof.context -> thm list -> thm list -> thm list -> thm list -> - thm list -> thm list -> thm list -> thm -> int -> int -> thm list list list -> tactic - val mk_primrec_tac: Proof.context -> int -> thm list -> thm list -> thm list -> thm -> tactic -end; - -structure BNF_FP_Rec_Sugar_Tactics : BNF_FP_REC_SUGAR_TACTICS = -struct - -open BNF_Util -open BNF_Tactics - -val falseEs = @{thms not_TrueE FalseE}; -val Let_def = @{thm Let_def}; -val neq_eq_eq_contradict = @{thm neq_eq_eq_contradict}; -val split_if = @{thm split_if}; -val split_if_asm = @{thm split_if_asm}; -val split_connectI = @{thms allI impI conjI}; - -fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx = - unfold_thms_tac ctxt fun_defs THEN - HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN - unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN - HEADGOAL (rtac refl); - -fun mk_primcorec_assumption_tac ctxt discIs = - SELECT_GOAL (unfold_thms_tac ctxt - @{thms not_not not_False_eq_True not_True_eq_False de_Morgan_conj de_Morgan_disj} THEN - SOLVE (HEADGOAL (REPEAT o (rtac refl ORELSE' atac ORELSE' etac conjE ORELSE' - eresolve_tac falseEs ORELSE' - resolve_tac @{thms TrueI conjI disjI1 disjI2} ORELSE' - dresolve_tac discIs THEN' atac ORELSE' - etac notE THEN' atac ORELSE' - etac disjE)))); - -fun mk_primcorec_same_case_tac m = - HEADGOAL (if m = 0 then rtac TrueI - else REPEAT_DETERM_N (m - 1) o (rtac conjI THEN' atac) THEN' atac); - -fun mk_primcorec_different_case_tac ctxt m excl = - HEADGOAL (if m = 0 then mk_primcorec_assumption_tac ctxt [] - else dtac excl THEN' (REPEAT_DETERM_N (m - 1) o atac) THEN' mk_primcorec_assumption_tac ctxt []); - -fun mk_primcorec_cases_tac ctxt k m exclsss = - let val n = length exclsss in - EVERY (map (fn [] => if k = n then all_tac else mk_primcorec_same_case_tac m - | [excl] => mk_primcorec_different_case_tac ctxt m excl) - (take k (nth exclsss (k - 1)))) - end; - -fun mk_primcorec_prelude ctxt defs thm = - unfold_thms_tac ctxt defs THEN HEADGOAL (rtac thm) THEN - unfold_thms_tac ctxt @{thms Let_def split}; - -fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss = - mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss; - -fun mk_primcorec_sel_tac ctxt defs distincts splits split_asms maps map_idents map_comps f_sel k m - exclsss = - mk_primcorec_prelude ctxt defs (f_sel RS trans) THEN - mk_primcorec_cases_tac ctxt k m exclsss THEN - HEADGOAL (REPEAT_DETERM o (rtac refl ORELSE' rtac ext ORELSE' - eresolve_tac falseEs ORELSE' - resolve_tac split_connectI ORELSE' - Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE' - Splitter.split_tac (split_if :: splits) ORELSE' - eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac ORELSE' - etac notE THEN' atac ORELSE' - (CHANGED o SELECT_GOAL (unfold_thms_tac ctxt - (@{thms id_def o_def split_def sum.cases} @ maps @ map_comps @ map_idents))))); - -fun mk_primcorec_ctr_of_dtr_tac ctxt m collapse maybe_disc_f sel_fs = - HEADGOAL (rtac ((if null sel_fs then collapse else collapse RS sym) RS trans) THEN' - (the_default (K all_tac) (Option.map rtac maybe_disc_f)) THEN' REPEAT_DETERM_N m o atac) THEN - unfold_thms_tac ctxt (Let_def :: sel_fs) THEN HEADGOAL (rtac refl); - -fun inst_split_eq ctxt split = - (case prop_of split of - @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ (Var (_, Type (_, [T, _])) $ _) $ _) => - let - val s = Name.uu; - val eq = Abs (Name.uu, T, HOLogic.mk_eq (Free (s, T), Bound 0)); - val split' = Drule.instantiate' [] [SOME (certify ctxt eq)] split; - in - Thm.generalize ([], [s]) (Thm.maxidx_of split' + 1) split' - end - | _ => split); - -fun distinct_in_prems_tac distincts = - eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac; - -(* TODO: reduce code duplication with selector tactic above *) -fun mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms m f_ctr = - let - val splits' = - map (fn th => th RS iffD2) (@{thm split_if_eq2} :: map (inst_split_eq ctxt) splits) - in - HEADGOAL (REPEAT o (resolve_tac (splits' @ split_connectI))) THEN - mk_primcorec_prelude ctxt [] (f_ctr RS trans) THEN - HEADGOAL ((REPEAT_DETERM_N m o mk_primcorec_assumption_tac ctxt discIs) THEN' - SELECT_GOAL (SOLVE (HEADGOAL (REPEAT_DETERM o - (rtac refl ORELSE' atac ORELSE' - resolve_tac (@{thm Code.abort_def} :: split_connectI) ORELSE' - Splitter.split_tac (split_if :: splits) ORELSE' - Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE' - mk_primcorec_assumption_tac ctxt discIs ORELSE' - distinct_in_prems_tac distincts ORELSE' - (TRY o dresolve_tac discIs) THEN' etac notE THEN' atac))))) - end; - -fun mk_primcorec_raw_code_of_ctr_tac ctxt distincts discIs splits split_asms ms f_ctrs = - EVERY (map2 (mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms) ms - f_ctrs) THEN - IF_UNSOLVED (unfold_thms_tac ctxt @{thms Code.abort_def} THEN - HEADGOAL (REPEAT_DETERM o resolve_tac (refl :: split_connectI))); - -fun mk_primcorec_code_of_raw_code_tac ctxt distincts splits raw = - HEADGOAL (rtac raw ORELSE' rtac (raw RS trans) THEN' - SELECT_GOAL (unfold_thms_tac ctxt [Let_def]) THEN' REPEAT_DETERM o - (rtac refl ORELSE' atac ORELSE' - resolve_tac split_connectI ORELSE' - Splitter.split_tac (split_if :: splits) ORELSE' - distinct_in_prems_tac distincts ORELSE' - rtac sym THEN' atac ORELSE' - etac notE THEN' atac)); - -end; diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML --- a/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML Mon Nov 04 16:53:43 2013 +0100 @@ -8,504 +8,26 @@ signature BNF_FP_REC_SUGAR_UTIL = sig - datatype rec_call = - No_Rec of int * typ | - Mutual_Rec of (int * typ) (*before*) * (int * typ) (*after*) | - Nested_Rec of int * typ - - datatype corec_call = - Dummy_No_Corec of int | - No_Corec of int | - Mutual_Corec of int (*stop?*) * int (*end*) * int (*continue*) | - Nested_Corec of int - - type rec_ctr_spec = - {ctr: term, - offset: int, - calls: rec_call list, - rec_thm: thm} - - type basic_corec_ctr_spec = - {ctr: term, - disc: term, - sels: term list} - - type corec_ctr_spec = - {ctr: term, - disc: term, - sels: term list, - pred: int option, - calls: corec_call list, - discI: thm, - sel_thms: thm list, - collapse: thm, - corec_thm: thm, - disc_corec: thm, - sel_corecs: thm list} + val indexed: 'a list -> int -> int list * int + val indexedd: 'a list list -> int -> int list list * int + val indexeddd: ''a list list list -> int -> int list list list * int + val indexedddd: 'a list list list list -> int -> int list list list list * int + val find_index_eq: ''a list -> ''a -> int + val finds: ('a * 'b -> bool) -> 'a list -> 'b list -> ('a * 'b list) list * 'b list - type rec_spec = - {recx: term, - nested_map_idents: thm list, - nested_map_comps: thm list, - ctr_specs: rec_ctr_spec list} - - type corec_spec = - {corec: term, - nested_maps: thm list, - nested_map_idents: thm list, - nested_map_comps: thm list, - ctr_specs: corec_ctr_spec list} - - val s_not: term -> term - val s_not_conj: term list -> term list - val s_conjs: term list -> term - val s_disjs: term list -> term - val s_dnf: term list list -> term list - - val mk_partial_compN: int -> typ -> typ -> term -> term + val drop_All: term -> term - val massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) -> - typ list -> term -> term -> term -> term - val massage_mutual_corec_call: Proof.context -> (term -> bool) -> (typ list -> term -> term) -> - typ list -> term -> term - val massage_nested_corec_call: Proof.context -> (term -> bool) -> - (typ list -> typ -> typ -> term -> term) -> typ list -> typ -> term -> term - val fold_rev_corec_call: Proof.context -> (term list -> term -> 'a -> 'a) -> typ list -> term -> - 'a -> string list * 'a - val expand_corec_code_rhs: Proof.context -> (term -> bool) -> typ list -> term -> term - val massage_corec_code_rhs: Proof.context -> (typ list -> term -> term list -> term) -> - typ list -> term -> term - val fold_rev_corec_code_rhs: Proof.context -> (term list -> term -> term list -> 'a -> 'a) -> - typ list -> term -> 'a -> 'a - val case_thms_of_term: Proof.context -> typ list -> term -> - thm list * thm list * thm list * thm list + val mk_partial_compN: int -> typ -> term -> term + val mk_partial_comp: typ -> typ -> term -> term + val mk_compN: int -> typ list -> term * term -> term + val mk_comp: typ list -> term * term -> term - val rec_specs_of: binding list -> typ list -> typ list -> (term -> int list) -> - ((term * term list list) list) list -> local_theory -> - (bool * rec_spec list * typ list * thm * thm list) * local_theory - val basic_corec_specs_of: Proof.context -> typ -> basic_corec_ctr_spec list - val corec_specs_of: binding list -> typ list -> typ list -> (term -> int list) -> - ((term * term list list) list) list -> local_theory -> - (bool * corec_spec list * typ list * thm * thm * thm list * thm list) * local_theory + val get_indices: ((binding * typ) * 'a) list -> term -> int list end; structure BNF_FP_Rec_Sugar_Util : BNF_FP_REC_SUGAR_UTIL = struct -open Ctr_Sugar -open BNF_Util -open BNF_Def -open BNF_FP_Util -open BNF_FP_Def_Sugar -open BNF_FP_N2M_Sugar - -datatype rec_call = - No_Rec of int * typ | - Mutual_Rec of (int * typ) * (int * typ) | - Nested_Rec of int * typ; - -datatype corec_call = - Dummy_No_Corec of int | - No_Corec of int | - Mutual_Corec of int * int * int | - Nested_Corec of int; - -type rec_ctr_spec = - {ctr: term, - offset: int, - calls: rec_call list, - rec_thm: thm}; - -type basic_corec_ctr_spec = - {ctr: term, - disc: term, - sels: term list}; - -type corec_ctr_spec = - {ctr: term, - disc: term, - sels: term list, - pred: int option, - calls: corec_call list, - discI: thm, - sel_thms: thm list, - collapse: thm, - corec_thm: thm, - disc_corec: thm, - sel_corecs: thm list}; - -type rec_spec = - {recx: term, - nested_map_idents: thm list, - nested_map_comps: thm list, - ctr_specs: rec_ctr_spec list}; - -type corec_spec = - {corec: term, - nested_maps: thm list, - nested_map_idents: thm list, - nested_map_comps: thm list, - ctr_specs: corec_ctr_spec list}; - -val id_def = @{thm id_def}; - -exception AINT_NO_MAP of term; - -fun not_codatatype ctxt T = - error ("Not a codatatype: " ^ Syntax.string_of_typ ctxt T); -fun ill_formed_rec_call ctxt t = - error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t)); -fun ill_formed_corec_call ctxt t = - error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t)); -fun invalid_map ctxt t = - error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t)); -fun unexpected_rec_call ctxt t = - error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t)); -fun unexpected_corec_call ctxt t = - error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t)); - -val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; -val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; - -val conjuncts_s = filter_out (curry (op =) @{const True}) o HOLogic.conjuncts; - -fun s_not @{const True} = @{const False} - | s_not @{const False} = @{const True} - | s_not (@{const Not} $ t) = t - | s_not (@{const conj} $ t $ u) = @{const disj} $ s_not t $ s_not u - | s_not (@{const disj} $ t $ u) = @{const conj} $ s_not t $ s_not u - | s_not t = @{const Not} $ t; - -val s_not_conj = conjuncts_s o s_not o mk_conjs; - -fun s_conj c @{const True} = c - | s_conj c d = HOLogic.mk_conj (c, d); - -fun propagate_unit_pos u cs = if member (op aconv) cs u then [@{const False}] else cs; - -fun propagate_unit_neg not_u cs = remove (op aconv) not_u cs; - -fun propagate_units css = - (case List.partition (can the_single) css of - ([], _) => css - | ([u] :: uss, css') => - [u] :: propagate_units (map (propagate_unit_neg (s_not u)) - (map (propagate_unit_pos u) (uss @ css')))); - -fun s_conjs cs = - if member (op aconv) cs @{const False} then @{const False} - else mk_conjs (remove (op aconv) @{const True} cs); - -fun s_disjs ds = - if member (op aconv) ds @{const True} then @{const True} - else mk_disjs (remove (op aconv) @{const False} ds); - -fun s_dnf css0 = - let val css = propagate_units css0 in - if null css then - [@{const False}] - else if exists null css then - [] - else - map (fn c :: cs => (c, cs)) css - |> AList.coalesce (op =) - |> map (fn (c, css) => c :: s_dnf css) - |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)]) - end; - -fun mk_partial_comp gT fT g = - let val T = domain_type fT --> range_type gT in - Const (@{const_name Fun.comp}, gT --> fT --> T) $ g - end; - -fun mk_partial_compN 0 _ _ g = g - | mk_partial_compN n gT fT g = - let val g' = mk_partial_compN (n - 1) gT (range_type fT) g in - mk_partial_comp (fastype_of g') fT g' - end; - -fun mk_compN n bound_Ts (g, f) = - let val typof = curry fastype_of1 bound_Ts in - mk_partial_compN n (typof g) (typof f) g $ f - end; - -val mk_comp = mk_compN 1; - -fun factor_out_types ctxt massage destU U T = - (case try destU U of - SOME (U1, U2) => if U1 = T then massage T U2 else invalid_map ctxt - | NONE => invalid_map ctxt); - -fun map_flattened_map_args ctxt s map_args fs = - let - val flat_fs = flatten_type_args_of_bnf (the (bnf_of ctxt s)) Term.dummy fs; - val flat_fs' = map_args flat_fs; - in - permute_like (op aconv) flat_fs fs flat_fs' - end; - -fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' = - let - fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else (); - - val typof = curry fastype_of1 bound_Ts; - val build_map_fst = build_map ctxt (fst_const o fst); - - val yT = typof y; - val yU = typof y'; - - fun y_of_y' () = build_map_fst (yU, yT) $ y'; - val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t); - - fun massage_mutual_fun U T t = - (case t of - Const (@{const_name comp}, comp_T) $ t1 $ t2 => - mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2) - | _ => - if has_call t then factor_out_types ctxt raw_massage_fun HOLogic.dest_prodT U T t - else mk_comp bound_Ts (t, build_map_fst (U, T))); - - fun massage_map (Type (_, Us)) (Type (s, Ts)) t = - (case try (dest_map ctxt s) t of - SOME (map0, fs) => - let - val Type (_, ran_Ts) = range_type (typof t); - val map' = mk_map (length fs) Us ran_Ts map0; - val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs; - in - Term.list_comb (map', fs') - end - | NONE => raise AINT_NO_MAP t) - | massage_map _ _ t = raise AINT_NO_MAP t - and massage_map_or_map_arg U T t = - if T = U then - tap check_no_call t - else - massage_map U T t - handle AINT_NO_MAP _ => massage_mutual_fun U T t; - - fun massage_call (t as t1 $ t2) = - if has_call t then - if t2 = y then - massage_map yU yT (elim_y t1) $ y' - handle AINT_NO_MAP t' => invalid_map ctxt t' - else - let val (g, xs) = Term.strip_comb t2 in - if g = y then - if exists has_call xs then unexpected_rec_call ctxt t2 - else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs) - else - ill_formed_rec_call ctxt t - end - else - elim_y t - | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t; - in - massage_call - end; - -fun fold_rev_let_if_case ctxt f bound_Ts t = - let - val thy = Proof_Context.theory_of ctxt; - - fun fld conds t = - (case Term.strip_comb t of - (Const (@{const_name Let}, _), [_, _]) => fld conds (unfold_let t) - | (Const (@{const_name If}, _), [cond, then_branch, else_branch]) => - fld (conds @ conjuncts_s cond) then_branch o fld (conds @ s_not_conj [cond]) else_branch - | (Const (c, _), args as _ :: _ :: _) => - let val n = num_binder_types (Sign.the_const_type thy c) - 1 in - if n >= 0 andalso n < length args then - (case fastype_of1 (bound_Ts, nth args n) of - Type (s, Ts) => - (case dest_case ctxt s Ts t of - NONE => apsnd (f conds t) - | SOME (conds', branches) => - apfst (cons s) o fold_rev (uncurry fld) - (map (append conds o conjuncts_s) conds' ~~ branches)) - | _ => apsnd (f conds t)) - else - apsnd (f conds t) - end - | _ => apsnd (f conds t)) - in - fld [] t o pair [] - end; - -fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex); - -fun massage_let_if_case ctxt has_call massage_leaf = - let - val thy = Proof_Context.theory_of ctxt; - - fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else (); - - fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t - | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t) - | massage_abs bound_Ts m t = - let val T = domain_type (fastype_of1 (bound_Ts, t)) in - Abs (Name.uu, T, massage_abs (T :: bound_Ts) (m - 1) (incr_boundvars 1 t $ Bound 0)) - end - and massage_rec bound_Ts t = - let val typof = curry fastype_of1 bound_Ts in - (case Term.strip_comb t of - (Const (@{const_name Let}, _), [_, _]) => massage_rec bound_Ts (unfold_let t) - | (Const (@{const_name If}, _), obj :: (branches as [_, _])) => - let val branches' = map (massage_rec bound_Ts) branches in - Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches') - end - | (Const (c, _), args as _ :: _ :: _) => - (case try strip_fun_type (Sign.the_const_type thy c) of - SOME (gen_branch_Ts, gen_body_fun_T) => - let - val gen_branch_ms = map num_binder_types gen_branch_Ts; - val n = length gen_branch_ms; - in - if n < length args then - (case gen_body_fun_T of - Type (_, [Type (T_name, _), _]) => - if case_of ctxt T_name = SOME c then - let - val (branches, obj_leftovers) = chop n args; - val branches' = map2 (massage_abs bound_Ts) gen_branch_ms branches; - val branch_Ts' = map typof branches'; - val body_T' = snd (strip_typeN (hd gen_branch_ms) (hd branch_Ts')); - val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T'); - in - Term.list_comb (casex', branches' @ tap (List.app check_no_call) obj_leftovers) - end - else - massage_leaf bound_Ts t - | _ => massage_leaf bound_Ts t) - else - massage_leaf bound_Ts t - end - | NONE => massage_leaf bound_Ts t) - | _ => massage_leaf bound_Ts t) - end - in - massage_rec - end; - -val massage_mutual_corec_call = massage_let_if_case; - -fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T; - -fun massage_nested_corec_call ctxt has_call raw_massage_call bound_Ts U t = - let - fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else (); - - val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd); - - fun massage_mutual_call bound_Ts U T t = - if has_call t then factor_out_types ctxt (raw_massage_call bound_Ts) dest_sumT U T t - else build_map_Inl (T, U) $ t; - - fun massage_mutual_fun bound_Ts U T t = - (case t of - Const (@{const_name comp}, comp_T) $ t1 $ t2 => - mk_comp bound_Ts (massage_mutual_fun bound_Ts U T t1, tap check_no_call t2) - | _ => - let - val var = Var ((Name.uu, Term.maxidx_of_term t + 1), - domain_type (fastype_of1 (bound_Ts, t))); - in - Term.lambda var (massage_mutual_call bound_Ts U T (t $ var)) - end); - - fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t = - (case try (dest_map ctxt s) t of - SOME (map0, fs) => - let - val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t)); - val map' = mk_map (length fs) dom_Ts Us map0; - val fs' = - map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs; - in - Term.list_comb (map', fs') - end - | NONE => raise AINT_NO_MAP t) - | massage_map _ _ _ t = raise AINT_NO_MAP t - and massage_map_or_map_arg bound_Ts U T t = - if T = U then - tap check_no_call t - else - massage_map bound_Ts U T t - handle AINT_NO_MAP _ => massage_mutual_fun bound_Ts U T t; - - fun massage_call bound_Ts U T = - massage_let_if_case ctxt has_call (fn bound_Ts => fn t => - if has_call t then - (case U of - Type (s, Us) => - (case try (dest_ctr ctxt s) t of - SOME (f, args) => - let - val typof = curry fastype_of1 bound_Ts; - val f' = mk_ctr Us f - val f'_T = typof f'; - val arg_Ts = map typof args; - in - Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args) - end - | NONE => - (case t of - Const (@{const_name prod_case}, _) $ t' => - let - val U' = curried_type U; - val T' = curried_type T; - in - Const (@{const_name prod_case}, U' --> U) $ massage_call bound_Ts U' T' t' - end - | t1 $ t2 => - (if has_call t2 then - massage_mutual_call bound_Ts U T t - else - massage_map bound_Ts U T t1 $ t2 - handle AINT_NO_MAP _ => massage_mutual_call bound_Ts U T t) - | Abs (s, T', t') => - Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t') - | _ => massage_mutual_call bound_Ts U T t)) - | _ => ill_formed_corec_call ctxt t) - else - build_map_Inl (T, U) $ t) bound_Ts; - - val T = fastype_of1 (bound_Ts, t); - in - if has_call t then massage_call bound_Ts U T t else build_map_Inl (T, U) $ t - end; - -val fold_rev_corec_call = fold_rev_let_if_case; - -fun expand_to_ctr_term ctxt s Ts t = - (case ctr_sugar_of ctxt s of - SOME {ctrs, casex, ...} => - Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t - | NONE => raise Fail "expand_to_ctr_term"); - -fun expand_corec_code_rhs ctxt has_call bound_Ts t = - (case fastype_of1 (bound_Ts, t) of - Type (s, Ts) => - massage_let_if_case ctxt has_call (fn _ => fn t => - if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t - | _ => raise Fail "expand_corec_code_rhs"); - -fun massage_corec_code_rhs ctxt massage_ctr = - massage_let_if_case ctxt (K false) - (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb); - -fun fold_rev_corec_code_rhs ctxt f = - snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb); - -fun case_thms_of_term ctxt bound_Ts t = - let - val (caseT_names, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t (); - val ctr_sugars = map (the o ctr_sugar_of ctxt) caseT_names; - in - (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #sel_splits ctr_sugars, - maps #sel_split_asms ctr_sugars) - end; - fun indexed xs h = let val h' = h + length xs in (h upto h' - 1, h') end; fun indexedd xss = fold_map indexed xss; fun indexeddd xsss = fold_map indexedd xsss; @@ -513,224 +35,32 @@ fun find_index_eq hs h = find_index (curry (op =) h) hs; -(*FIXME: remove special cases for product and sum once they are registered as datatypes*) -fun map_thms_of_typ ctxt (Type (s, _)) = - if s = @{type_name prod} then - @{thms map_pair_simp} - else if s = @{type_name sum} then - @{thms sum_map.simps} - else - (case fp_sugar_of ctxt s of - SOME {index, mapss, ...} => nth mapss index - | NONE => []) - | map_thms_of_typ _ _ = []; - -fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy = - let - val thy = Proof_Context.theory_of lthy; - - val ((missing_arg_Ts, perm0_kks, - fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...}, - co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') = - nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy; - - val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars; - - val indices = map #index fp_sugars; - val perm_indices = map #index perm_fp_sugars; - - val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars; - val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss; - val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss; - - val nn0 = length arg_Ts; - val nn = length perm_lfpTs; - val kks = 0 upto nn - 1; - val perm_ns = map length perm_ctr_Tsss; - val perm_mss = map (map length) perm_ctr_Tsss; - - val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res)) - perm_fp_sugars; - val perm_fun_arg_Tssss = - mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1); - - fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs; - fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs; - - val induct_thms = unpermute0 (conj_dests nn induct_thm); +fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x); - val lfpTs = unpermute perm_lfpTs; - val Cs = unpermute perm_Cs; - - val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts; - val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts; - - val substA = Term.subst_TVars As_rho; - val substAT = Term.typ_subst_TVars As_rho; - val substCT = Term.typ_subst_TVars Cs_rho; - val substACT = substAT o substCT; - - val perm_Cs' = map substCT perm_Cs; - - fun offset_of_ctr 0 _ = 0 - | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) = - length ctrs + offset_of_ctr (n - 1) ctr_sugars; - - fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T) - | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T')); +fun drop_All t = + subst_bounds (strip_qnt_vars @{const_name all} t |> map Free |> rev, + strip_qnt_body @{const_name all} t); - fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm = - let - val (fun_arg_hss, _) = indexedd fun_arg_Tss 0; - val fun_arg_hs = flat_rec_arg_args fun_arg_hss; - val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss; - in - {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss, - rec_thm = rec_thm} - end; - - fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss = - let - val ctrs = #ctrs (nth ctr_sugars index); - val rec_thmss = co_rec_of (nth iter_thmsss index); - val k = offset_of_ctr index ctr_sugars; - val n = length ctrs; - in - map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss - end; - - fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...} - : fp_sugar) = - {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)), - nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs, - nested_map_comps = map map_comp_of_bnf nested_bnfs, - ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss}; - in - ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), - lthy') +fun mk_partial_comp gT fT g = + let val T = domain_type fT --> range_type gT in + Const (@{const_name Fun.comp}, gT --> fT --> T) $ g end; -fun basic_corec_specs_of ctxt res_T = - (case res_T of - Type (T_name, _) => - (case Ctr_Sugar.ctr_sugar_of ctxt T_name of - NONE => not_codatatype ctxt res_T - | SOME {ctrs, discs, selss, ...} => - let - val thy = Proof_Context.theory_of ctxt; - val gfpT = body_type (fastype_of (hd ctrs)); - val As_rho = tvar_subst thy [gfpT] [res_T]; - val substA = Term.subst_TVars As_rho; - - fun mk_spec ctr disc sels = {ctr = substA ctr, disc = substA disc, sels = map substA sels}; - in - map3 mk_spec ctrs discs selss - end) - | _ => not_codatatype ctxt res_T); - -fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy = - let - val thy = Proof_Context.theory_of lthy; - - val ((missing_res_Ts, perm0_kks, - fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...}, - co_inducts = coinduct_thms, ...} :: _, (_, gfp_sugar_thms)), lthy') = - nested_to_mutual_fps Greatest_FP bs res_Ts get_indices callssss0 lthy; - - val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars; - - val indices = map #index fp_sugars; - val perm_indices = map #index perm_fp_sugars; - - val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars; - val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss; - val perm_gfpTs = map (body_type o fastype_of o hd) perm_ctrss; - - val nn0 = length res_Ts; - val nn = length perm_gfpTs; - val kks = 0 upto nn - 1; - val perm_ns = map length perm_ctr_Tsss; - - val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o - of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars; - val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) = - mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1); - - val (perm_p_hss, h) = indexedd perm_p_Tss 0; - val (perm_q_hssss, h') = indexedddd perm_q_Tssss h; - val (perm_f_hssss, _) = indexedddd perm_f_Tssss h'; - - val fun_arg_hs = - flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss); - - fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs; - fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs; - - val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms; +fun mk_partial_compN 0 _ g = g + | mk_partial_compN n fT g = + let val g' = mk_partial_compN (n - 1) (range_type fT) g in + mk_partial_comp (fastype_of g') fT g' + end; - val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss); - val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss); - val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss); - - val f_Tssss = unpermute perm_f_Tssss; - val gfpTs = unpermute perm_gfpTs; - val Cs = unpermute perm_Cs; - - val As_rho = tvar_subst thy (take nn0 gfpTs) res_Ts; - val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts; - - val substA = Term.subst_TVars As_rho; - val substAT = Term.typ_subst_TVars As_rho; - val substCT = Term.typ_subst_TVars Cs_rho; - - val perm_Cs' = map substCT perm_Cs; - - fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] = - (if exists_subtype_in Cs T then Nested_Corec - else if nullary then Dummy_No_Corec - else No_Corec) g_i - | call_of _ [q_i] [g_i, g_i'] _ = Mutual_Corec (q_i, g_i, g_i'); - - fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm - disc_corec sel_corecs = - let val nullary = not (can dest_funT (fastype_of ctr)) in - {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho, - calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms, - collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec, - sel_corecs = sel_corecs} - end; - - fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) p_is q_isss f_isss f_Tsss coiter_thmsss - disc_coitersss sel_coiterssss = - let - val ctrs = #ctrs (nth ctr_sugars index); - val discs = #discs (nth ctr_sugars index); - val selss = #selss (nth ctr_sugars index); - val p_ios = map SOME p_is @ [NONE]; - val discIs = #discIs (nth ctr_sugars index); - val sel_thmss = #sel_thmss (nth ctr_sugars index); - val collapses = #collapses (nth ctr_sugars index); - val corec_thms = co_rec_of (nth coiter_thmsss index); - val disc_corecs = co_rec_of (nth disc_coitersss index); - val sel_corecss = co_rec_of (nth sel_coiterssss index); - in - map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses - corec_thms disc_corecs sel_corecss - end; - - fun mk_spec ({T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss, - disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...} : fp_sugar) - p_is q_isss f_isss f_Tsss = - {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)), - nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs, - nested_map_idents = map (unfold_thms lthy [id_def] o map_id0_of_bnf) nested_bnfs, - nested_map_comps = map map_comp_of_bnf nested_bnfs, - ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss - disc_coitersss sel_coiterssss}; - in - ((is_some gfp_sugar_thms, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts, - co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss, - strong_co_induct_of coinduct_thmss), lthy') +fun mk_compN n bound_Ts (g, f) = + let val typof = curry fastype_of1 bound_Ts in + mk_partial_compN n (typof f) g $ f end; +val mk_comp = mk_compN 1; + +fun get_indices fixes t = map (fst #>> Binding.name_of #> Free) fixes + |> map_index (fn (i, v) => if exists_subterm (equal v) t then SOME i else NONE) + |> map_filter I; + end; diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_gfp.ML --- a/src/HOL/BNF/Tools/bnf_gfp.ML Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/Tools/bnf_gfp.ML Mon Nov 04 16:53:43 2013 +0100 @@ -23,7 +23,7 @@ open BNF_Comp open BNF_FP_Util open BNF_FP_Def_Sugar -open BNF_FP_Rec_Sugar +open BNF_GFP_Rec_Sugar open BNF_GFP_Util open BNF_GFP_Tactics diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar.ML Mon Nov 04 16:53:43 2013 +0100 @@ -0,0 +1,1165 @@ +(* Title: HOL/BNF/Tools/bnf_gfp_rec_sugar.ML + Author: Lorenz Panny, TU Muenchen + Author: Jasmin Blanchette, TU Muenchen + Copyright 2013 + +Corecursor sugar. +*) + +signature BNF_GFP_REC_SUGAR = +sig + val add_primcorecursive_cmd: bool -> + (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list -> + Proof.context -> Proof.state + val add_primcorec_cmd: bool -> + (binding * string option * mixfix) list * ((Attrib.binding * string) * string option) list -> + local_theory -> local_theory +end; + +structure BNF_GFP_Rec_Sugar : BNF_GFP_REC_SUGAR = +struct + +open Ctr_Sugar +open BNF_Util +open BNF_Def +open BNF_FP_Util +open BNF_FP_Def_Sugar +open BNF_FP_N2M_Sugar +open BNF_FP_Rec_Sugar_Util +open BNF_GFP_Rec_Sugar_Tactics + +val codeN = "code" +val ctrN = "ctr" +val discN = "disc" +val selN = "sel" + +val nitpicksimp_attrs = @{attributes [nitpick_simp]}; +val simp_attrs = @{attributes [simp]}; +val code_nitpicksimp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs; + +exception Primcorec_Error of string * term list; + +fun primcorec_error str = raise Primcorec_Error (str, []); +fun primcorec_error_eqn str eqn = raise Primcorec_Error (str, [eqn]); +fun primcorec_error_eqns str eqns = raise Primcorec_Error (str, eqns); + +datatype corec_call = + Dummy_No_Corec of int | + No_Corec of int | + Mutual_Corec of int * int * int | + Nested_Corec of int; + +type basic_corec_ctr_spec = + {ctr: term, + disc: term, + sels: term list}; + +type corec_ctr_spec = + {ctr: term, + disc: term, + sels: term list, + pred: int option, + calls: corec_call list, + discI: thm, + sel_thms: thm list, + collapse: thm, + corec_thm: thm, + disc_corec: thm, + sel_corecs: thm list}; + +type corec_spec = + {corec: term, + nested_maps: thm list, + nested_map_idents: thm list, + nested_map_comps: thm list, + ctr_specs: corec_ctr_spec list}; + +exception AINT_NO_MAP of term; + +fun not_codatatype ctxt T = + error ("Not a codatatype: " ^ Syntax.string_of_typ ctxt T); +fun ill_formed_corec_call ctxt t = + error ("Ill-formed corecursive call: " ^ quote (Syntax.string_of_term ctxt t)); +fun invalid_map ctxt t = + error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t)); +fun unexpected_corec_call ctxt t = + error ("Unexpected corecursive call: " ^ quote (Syntax.string_of_term ctxt t)); + +val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True}; +val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False}; + +val conjuncts_s = filter_out (curry (op =) @{const True}) o HOLogic.conjuncts; + +fun s_not @{const True} = @{const False} + | s_not @{const False} = @{const True} + | s_not (@{const Not} $ t) = t + | s_not (@{const conj} $ t $ u) = @{const disj} $ s_not t $ s_not u + | s_not (@{const disj} $ t $ u) = @{const conj} $ s_not t $ s_not u + | s_not t = @{const Not} $ t; + +val s_not_conj = conjuncts_s o s_not o mk_conjs; + +fun propagate_unit_pos u cs = if member (op aconv) cs u then [@{const False}] else cs; + +fun propagate_unit_neg not_u cs = remove (op aconv) not_u cs; + +fun propagate_units css = + (case List.partition (can the_single) css of + ([], _) => css + | ([u] :: uss, css') => + [u] :: propagate_units (map (propagate_unit_neg (s_not u)) + (map (propagate_unit_pos u) (uss @ css')))); + +fun s_conjs cs = + if member (op aconv) cs @{const False} then @{const False} + else mk_conjs (remove (op aconv) @{const True} cs); + +fun s_disjs ds = + if member (op aconv) ds @{const True} then @{const True} + else mk_disjs (remove (op aconv) @{const False} ds); + +fun s_dnf css0 = + let val css = propagate_units css0 in + if null css then + [@{const False}] + else if exists null css then + [] + else + map (fn c :: cs => (c, cs)) css + |> AList.coalesce (op =) + |> map (fn (c, css) => c :: s_dnf css) + |> (fn [cs] => cs | css => [s_disjs (map s_conjs css)]) + end; + +fun fold_rev_let_if_case ctxt f bound_Ts t = + let + val thy = Proof_Context.theory_of ctxt; + + fun fld conds t = + (case Term.strip_comb t of + (Const (@{const_name Let}, _), [_, _]) => fld conds (unfold_let t) + | (Const (@{const_name If}, _), [cond, then_branch, else_branch]) => + fld (conds @ conjuncts_s cond) then_branch o fld (conds @ s_not_conj [cond]) else_branch + | (Const (c, _), args as _ :: _ :: _) => + let val n = num_binder_types (Sign.the_const_type thy c) - 1 in + if n >= 0 andalso n < length args then + (case fastype_of1 (bound_Ts, nth args n) of + Type (s, Ts) => + (case dest_case ctxt s Ts t of + NONE => apsnd (f conds t) + | SOME (conds', branches) => + apfst (cons s) o fold_rev (uncurry fld) + (map (append conds o conjuncts_s) conds' ~~ branches)) + | _ => apsnd (f conds t)) + else + apsnd (f conds t) + end + | _ => apsnd (f conds t)) + in + fld [] t o pair [] + end; + +fun case_of ctxt = ctr_sugar_of ctxt #> Option.map (fst o dest_Const o #casex); + +fun massage_let_if_case ctxt has_call massage_leaf = + let + val thy = Proof_Context.theory_of ctxt; + + fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else (); + + fun massage_abs bound_Ts 0 t = massage_rec bound_Ts t + | massage_abs bound_Ts m (Abs (s, T, t)) = Abs (s, T, massage_abs (T :: bound_Ts) (m - 1) t) + | massage_abs bound_Ts m t = + let val T = domain_type (fastype_of1 (bound_Ts, t)) in + Abs (Name.uu, T, massage_abs (T :: bound_Ts) (m - 1) (incr_boundvars 1 t $ Bound 0)) + end + and massage_rec bound_Ts t = + let val typof = curry fastype_of1 bound_Ts in + (case Term.strip_comb t of + (Const (@{const_name Let}, _), [_, _]) => massage_rec bound_Ts (unfold_let t) + | (Const (@{const_name If}, _), obj :: (branches as [_, _])) => + let val branches' = map (massage_rec bound_Ts) branches in + Term.list_comb (If_const (typof (hd branches')) $ tap check_no_call obj, branches') + end + | (Const (c, _), args as _ :: _ :: _) => + (case try strip_fun_type (Sign.the_const_type thy c) of + SOME (gen_branch_Ts, gen_body_fun_T) => + let + val gen_branch_ms = map num_binder_types gen_branch_Ts; + val n = length gen_branch_ms; + in + if n < length args then + (case gen_body_fun_T of + Type (_, [Type (T_name, _), _]) => + if case_of ctxt T_name = SOME c then + let + val (branches, obj_leftovers) = chop n args; + val branches' = map2 (massage_abs bound_Ts) gen_branch_ms branches; + val branch_Ts' = map typof branches'; + val body_T' = snd (strip_typeN (hd gen_branch_ms) (hd branch_Ts')); + val casex' = Const (c, branch_Ts' ---> map typof obj_leftovers ---> body_T'); + in + Term.list_comb (casex', branches' @ tap (List.app check_no_call) obj_leftovers) + end + else + massage_leaf bound_Ts t + | _ => massage_leaf bound_Ts t) + else + massage_leaf bound_Ts t + end + | NONE => massage_leaf bound_Ts t) + | _ => massage_leaf bound_Ts t) + end + in + massage_rec + end; + +val massage_mutual_corec_call = massage_let_if_case; + +fun curried_type (Type (@{type_name fun}, [Type (@{type_name prod}, Ts), T])) = Ts ---> T; + +fun massage_nested_corec_call ctxt has_call raw_massage_call bound_Ts U t = + let + fun check_no_call t = if has_call t then unexpected_corec_call ctxt t else (); + + val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd); + + fun massage_mutual_call bound_Ts U T t = + if has_call t then + (case try dest_sumT U of + SOME (U1, U2) => if U1 = T then raw_massage_call bound_Ts T U2 t else invalid_map ctxt t + | NONE => invalid_map ctxt t) + else + build_map_Inl (T, U) $ t; + + fun massage_mutual_fun bound_Ts U T t = + (case t of + Const (@{const_name comp}, _) $ t1 $ t2 => + mk_comp bound_Ts (massage_mutual_fun bound_Ts U T t1, tap check_no_call t2) + | _ => + let + val var = Var ((Name.uu, Term.maxidx_of_term t + 1), + domain_type (fastype_of1 (bound_Ts, t))); + in + Term.lambda var (massage_mutual_call bound_Ts U T (t $ var)) + end); + + fun massage_map bound_Ts (Type (_, Us)) (Type (s, Ts)) t = + (case try (dest_map ctxt s) t of + SOME (map0, fs) => + let + val Type (_, dom_Ts) = domain_type (fastype_of1 (bound_Ts, t)); + val map' = mk_map (length fs) dom_Ts Us map0; + val fs' = + map_flattened_map_args ctxt s (map3 (massage_map_or_map_arg bound_Ts) Us Ts) fs; + in + Term.list_comb (map', fs') + end + | NONE => raise AINT_NO_MAP t) + | massage_map _ _ _ t = raise AINT_NO_MAP t + and massage_map_or_map_arg bound_Ts U T t = + if T = U then + tap check_no_call t + else + massage_map bound_Ts U T t + handle AINT_NO_MAP _ => massage_mutual_fun bound_Ts U T t; + + fun massage_call bound_Ts U T = + massage_let_if_case ctxt has_call (fn bound_Ts => fn t => + if has_call t then + (case U of + Type (s, Us) => + (case try (dest_ctr ctxt s) t of + SOME (f, args) => + let + val typof = curry fastype_of1 bound_Ts; + val f' = mk_ctr Us f + val f'_T = typof f'; + val arg_Ts = map typof args; + in + Term.list_comb (f', map3 (massage_call bound_Ts) (binder_types f'_T) arg_Ts args) + end + | NONE => + (case t of + Const (@{const_name prod_case}, _) $ t' => + let + val U' = curried_type U; + val T' = curried_type T; + in + Const (@{const_name prod_case}, U' --> U) $ massage_call bound_Ts U' T' t' + end + | t1 $ t2 => + (if has_call t2 then + massage_mutual_call bound_Ts U T t + else + massage_map bound_Ts U T t1 $ t2 + handle AINT_NO_MAP _ => massage_mutual_call bound_Ts U T t) + | Abs (s, T', t') => + Abs (s, T', massage_call (T' :: bound_Ts) (range_type U) (range_type T) t') + | _ => massage_mutual_call bound_Ts U T t)) + | _ => ill_formed_corec_call ctxt t) + else + build_map_Inl (T, U) $ t) bound_Ts; + + val T = fastype_of1 (bound_Ts, t); + in + if has_call t then massage_call bound_Ts U T t else build_map_Inl (T, U) $ t + end; + +val fold_rev_corec_call = fold_rev_let_if_case; + +fun expand_to_ctr_term ctxt s Ts t = + (case ctr_sugar_of ctxt s of + SOME {ctrs, casex, ...} => + Term.list_comb (mk_case Ts (Type (s, Ts)) casex, map (mk_ctr Ts) ctrs) $ t + | NONE => raise Fail "expand_to_ctr_term"); + +fun expand_corec_code_rhs ctxt has_call bound_Ts t = + (case fastype_of1 (bound_Ts, t) of + Type (s, Ts) => + massage_let_if_case ctxt has_call (fn _ => fn t => + if can (dest_ctr ctxt s) t then t else expand_to_ctr_term ctxt s Ts t) bound_Ts t + | _ => raise Fail "expand_corec_code_rhs"); + +fun massage_corec_code_rhs ctxt massage_ctr = + massage_let_if_case ctxt (K false) + (fn bound_Ts => uncurry (massage_ctr bound_Ts) o Term.strip_comb); + +fun fold_rev_corec_code_rhs ctxt f = + snd ooo fold_rev_let_if_case ctxt (fn conds => uncurry (f conds) o Term.strip_comb); + +fun case_thms_of_term ctxt bound_Ts t = + let + val (caseT_names, _) = fold_rev_let_if_case ctxt (K (K I)) bound_Ts t (); + val ctr_sugars = map (the o ctr_sugar_of ctxt) caseT_names; + in + (maps #distincts ctr_sugars, maps #discIs ctr_sugars, maps #sel_splits ctr_sugars, + maps #sel_split_asms ctr_sugars) + end; + +fun basic_corec_specs_of ctxt res_T = + (case res_T of + Type (T_name, _) => + (case Ctr_Sugar.ctr_sugar_of ctxt T_name of + NONE => not_codatatype ctxt res_T + | SOME {ctrs, discs, selss, ...} => + let + val thy = Proof_Context.theory_of ctxt; + val gfpT = body_type (fastype_of (hd ctrs)); + val As_rho = tvar_subst thy [gfpT] [res_T]; + val substA = Term.subst_TVars As_rho; + + fun mk_spec ctr disc sels = {ctr = substA ctr, disc = substA disc, sels = map substA sels}; + in + map3 mk_spec ctrs discs selss + end) + | _ => not_codatatype ctxt res_T); + +(*FIXME: remove special cases for product and sum once they are registered as datatypes*) +fun map_thms_of_typ ctxt (Type (s, _)) = + if s = @{type_name prod} then + @{thms map_pair_simp} + else if s = @{type_name sum} then + @{thms sum_map.simps} + else + (case fp_sugar_of ctxt s of + SOME {index, mapss, ...} => nth mapss index + | NONE => []) + | map_thms_of_typ _ _ = []; + +fun corec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy = + let + val thy = Proof_Context.theory_of lthy; + + val ((missing_res_Ts, perm0_kks, + fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = dtor_coiters1 :: _, ...}, + co_inducts = coinduct_thms, ...} :: _, (_, gfp_sugar_thms)), lthy') = + nested_to_mutual_fps Greatest_FP bs res_Ts get_indices callssss0 lthy; + + val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars; + + val indices = map #index fp_sugars; + val perm_indices = map #index perm_fp_sugars; + + val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars; + val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss; + val perm_gfpTs = map (body_type o fastype_of o hd) perm_ctrss; + + val nn0 = length res_Ts; + val nn = length perm_gfpTs; + val kks = 0 upto nn - 1; + val perm_ns = map length perm_ctr_Tsss; + + val perm_Cs = map (domain_type o body_fun_type o fastype_of o co_rec_of o + of_fp_sugar (#xtor_co_iterss o #fp_res)) perm_fp_sugars; + val (perm_p_Tss, (perm_q_Tssss, _, perm_f_Tssss, _)) = + mk_coiter_fun_arg_types perm_ctr_Tsss perm_Cs perm_ns (co_rec_of dtor_coiters1); + + val (perm_p_hss, h) = indexedd perm_p_Tss 0; + val (perm_q_hssss, h') = indexedddd perm_q_Tssss h; + val (perm_f_hssss, _) = indexedddd perm_f_Tssss h'; + + val fun_arg_hs = + flat (map3 flat_corec_preds_predsss_gettersss perm_p_hss perm_q_hssss perm_f_hssss); + + fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs; + fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs; + + val coinduct_thmss = map (unpermute0 o conj_dests nn) coinduct_thms; + + val p_iss = map (map (find_index_eq fun_arg_hs)) (unpermute perm_p_hss); + val q_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_q_hssss); + val f_issss = map (map (map (map (find_index_eq fun_arg_hs)))) (unpermute perm_f_hssss); + + val f_Tssss = unpermute perm_f_Tssss; + val gfpTs = unpermute perm_gfpTs; + val Cs = unpermute perm_Cs; + + val As_rho = tvar_subst thy (take nn0 gfpTs) res_Ts; + val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn arg_Ts; + + val substA = Term.subst_TVars As_rho; + val substAT = Term.typ_subst_TVars As_rho; + val substCT = Term.typ_subst_TVars Cs_rho; + + val perm_Cs' = map substCT perm_Cs; + + fun call_of nullary [] [g_i] [Type (@{type_name fun}, [_, T])] = + (if exists_subtype_in Cs T then Nested_Corec + else if nullary then Dummy_No_Corec + else No_Corec) g_i + | call_of _ [q_i] [g_i, g_i'] _ = Mutual_Corec (q_i, g_i, g_i'); + + fun mk_ctr_spec ctr disc sels p_ho q_iss f_iss f_Tss discI sel_thms collapse corec_thm + disc_corec sel_corecs = + let val nullary = not (can dest_funT (fastype_of ctr)) in + {ctr = substA ctr, disc = substA disc, sels = map substA sels, pred = p_ho, + calls = map3 (call_of nullary) q_iss f_iss f_Tss, discI = discI, sel_thms = sel_thms, + collapse = collapse, corec_thm = corec_thm, disc_corec = disc_corec, + sel_corecs = sel_corecs} + end; + + fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) p_is q_isss f_isss f_Tsss coiter_thmsss + disc_coitersss sel_coiterssss = + let + val ctrs = #ctrs (nth ctr_sugars index); + val discs = #discs (nth ctr_sugars index); + val selss = #selss (nth ctr_sugars index); + val p_ios = map SOME p_is @ [NONE]; + val discIs = #discIs (nth ctr_sugars index); + val sel_thmss = #sel_thmss (nth ctr_sugars index); + val collapses = #collapses (nth ctr_sugars index); + val corec_thms = co_rec_of (nth coiter_thmsss index); + val disc_corecs = co_rec_of (nth disc_coitersss index); + val sel_corecss = co_rec_of (nth sel_coiterssss index); + in + map13 mk_ctr_spec ctrs discs selss p_ios q_isss f_isss f_Tsss discIs sel_thmss collapses + corec_thms disc_corecs sel_corecss + end; + + fun mk_spec ({T, index, ctr_sugars, co_iterss = coiterss, co_iter_thmsss = coiter_thmsss, + disc_co_itersss = disc_coitersss, sel_co_iterssss = sel_coiterssss, ...} : fp_sugar) + p_is q_isss f_isss f_Tsss = + {corec = mk_co_iter thy Greatest_FP (substAT T) perm_Cs' (co_rec_of (nth coiterss index)), + nested_maps = maps (map_thms_of_typ lthy o T_of_bnf) nested_bnfs, + nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs, + nested_map_comps = map map_comp_of_bnf nested_bnfs, + ctr_specs = mk_ctr_specs index ctr_sugars p_is q_isss f_isss f_Tsss coiter_thmsss + disc_coitersss sel_coiterssss}; + in + ((is_some gfp_sugar_thms, map5 mk_spec fp_sugars p_iss q_issss f_issss f_Tssss, missing_res_Ts, + co_induct_of coinduct_thms, strong_co_induct_of coinduct_thms, co_induct_of coinduct_thmss, + strong_co_induct_of coinduct_thmss), lthy') + end; + +val const_name = try (fn Const (v, _) => v); +val undef_const = Const (@{const_name undefined}, dummyT); + +val abs_tuple = HOLogic.tupled_lambda o HOLogic.mk_tuple; +fun abstract vs = + let fun a n (t $ u) = a n t $ a n u + | a n (Abs (v, T, b)) = Abs (v, T, a (n + 1) b) + | a n t = let val idx = find_index (equal t) vs in + if idx < 0 then t else Bound (n + idx) end + in a 0 end; +fun mk_prod1 Ts (t, u) = HOLogic.pair_const (fastype_of1 (Ts, t)) (fastype_of1 (Ts, u)) $ t $ u; +fun mk_tuple1 Ts = the_default HOLogic.unit o try (foldr1 (mk_prod1 Ts)); + +type coeqn_data_disc = { + fun_name: string, + fun_T: typ, + fun_args: term list, + ctr: term, + ctr_no: int, (*###*) + disc: term, + prems: term list, + auto_gen: bool, + maybe_ctr_rhs: term option, + maybe_code_rhs: term option, + user_eqn: term +}; + +type coeqn_data_sel = { + fun_name: string, + fun_T: typ, + fun_args: term list, + ctr: term, + sel: term, + rhs_term: term, + user_eqn: term +}; + +datatype coeqn_data = + Disc of coeqn_data_disc | + Sel of coeqn_data_sel; + +fun dissect_coeqn_disc seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) + maybe_ctr_rhs maybe_code_rhs prems' concl matchedsss = + let + fun find_subterm p = let (* FIXME \? *) + fun f (t as u $ v) = if p t then SOME t else merge_options (f u, f v) + | f t = if p t then SOME t else NONE + in f end; + + val applied_fun = concl + |> find_subterm (member ((op =) o apsnd SOME) fun_names o try (fst o dest_Free o head_of)) + |> the + handle Option.Option => primcorec_error_eqn "malformed discriminator formula" concl; + val ((fun_name, fun_T), fun_args) = strip_comb applied_fun |>> dest_Free; + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; + + val discs = map #disc basic_ctr_specs; + val ctrs = map #ctr basic_ctr_specs; + val not_disc = head_of concl = @{term Not}; + val _ = not_disc andalso length ctrs <> 2 andalso + primcorec_error_eqn "negated discriminator for a type with \ 2 constructors" concl; + val disc' = find_subterm (member (op =) discs o head_of) concl; + val eq_ctr0 = concl |> perhaps (try HOLogic.dest_not) |> try (HOLogic.dest_eq #> snd) + |> (fn SOME t => let val n = find_index (equal t) ctrs in + if n >= 0 then SOME n else NONE end | _ => NONE); + val _ = is_some disc' orelse is_some eq_ctr0 orelse + primcorec_error_eqn "no discriminator in equation" concl; + val ctr_no' = + if is_none disc' then the eq_ctr0 else find_index (equal (head_of (the disc'))) discs; + val ctr_no = if not_disc then 1 - ctr_no' else ctr_no'; + val {ctr, disc, ...} = nth basic_ctr_specs ctr_no; + + val catch_all = try (fst o dest_Free o the_single) prems' = SOME Name.uu_; + val matchedss = AList.lookup (op =) matchedsss fun_name |> the_default []; + val prems = map (abstract (List.rev fun_args)) prems'; + val real_prems = + (if catch_all orelse seq then maps s_not_conj matchedss else []) @ + (if catch_all then [] else prems); + + val matchedsss' = AList.delete (op =) fun_name matchedsss + |> cons (fun_name, if seq then matchedss @ [prems] else matchedss @ [real_prems]); + + val user_eqn = + (real_prems, concl) + |>> map HOLogic.mk_Trueprop ||> HOLogic.mk_Trueprop o abstract (List.rev fun_args) + |> curry Logic.list_all (map dest_Free fun_args) o Logic.list_implies; + in + (Disc { + fun_name = fun_name, + fun_T = fun_T, + fun_args = fun_args, + ctr = ctr, + ctr_no = ctr_no, + disc = disc, + prems = real_prems, + auto_gen = catch_all, + maybe_ctr_rhs = maybe_ctr_rhs, + maybe_code_rhs = maybe_code_rhs, + user_eqn = user_eqn + }, matchedsss') + end; + +fun dissect_coeqn_sel fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' + maybe_of_spec eqn = + let + val (lhs, rhs) = HOLogic.dest_eq eqn + handle TERM _ => + primcorec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn; + val sel = head_of lhs; + val ((fun_name, fun_T), fun_args) = dest_comb lhs |> snd |> strip_comb |> apfst dest_Free + handle TERM _ => + primcorec_error_eqn "malformed selector argument in left-hand side" eqn; + val basic_ctr_specs = the (AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name) + handle Option.Option => primcorec_error_eqn "malformed selector argument in left-hand side" eqn; + val {ctr, ...} = + (case maybe_of_spec of + SOME of_spec => the (find_first (equal of_spec o #ctr) basic_ctr_specs) + | NONE => filter (exists (equal sel) o #sels) basic_ctr_specs |> the_single + handle List.Empty => primcorec_error_eqn "ambiguous selector - use \"of\"" eqn); + val user_eqn = drop_All eqn'; + in + Sel { + fun_name = fun_name, + fun_T = fun_T, + fun_args = fun_args, + ctr = ctr, + sel = sel, + rhs_term = rhs, + user_eqn = user_eqn + } + end; + +fun dissect_coeqn_ctr seq fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) eqn' + maybe_code_rhs prems concl matchedsss = + let + val (lhs, rhs) = HOLogic.dest_eq concl; + val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; + val (ctr, ctr_args) = strip_comb (unfold_let rhs); + val {disc, sels, ...} = the (find_first (equal ctr o #ctr) basic_ctr_specs) + handle Option.Option => primcorec_error_eqn "not a constructor" ctr; + + val disc_concl = betapply (disc, lhs); + val (maybe_eqn_data_disc, matchedsss') = if length basic_ctr_specs = 1 + then (NONE, matchedsss) + else apfst SOME (dissect_coeqn_disc seq fun_names basic_ctr_specss + (SOME (abstract (List.rev fun_args) rhs)) maybe_code_rhs prems disc_concl matchedsss); + + val sel_concls = sels ~~ ctr_args + |> map (fn (sel, ctr_arg) => HOLogic.mk_eq (betapply (sel, lhs), ctr_arg)); + +(* +val _ = tracing ("reduced\n " ^ Syntax.string_of_term @{context} concl ^ "\nto\n \ " ^ + (is_some maybe_eqn_data_disc ? K (Syntax.string_of_term @{context} disc_concl ^ "\n \ ")) "" ^ + space_implode "\n \ " (map (Syntax.string_of_term @{context}) sel_concls) ^ + "\nfor premise(s)\n \ " ^ + space_implode "\n \ " (map (Syntax.string_of_term @{context}) prems)); +*) + + val eqns_data_sel = + map (dissect_coeqn_sel fun_names basic_ctr_specss eqn' (SOME ctr)) sel_concls; + in + (the_list maybe_eqn_data_disc @ eqns_data_sel, matchedsss') + end; + +fun dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss = + let + val (lhs, (rhs', rhs)) = HOLogic.dest_eq concl ||> `(expand_corec_code_rhs lthy has_call []); + val (fun_name, fun_args) = strip_comb lhs |>> fst o dest_Free; + val SOME basic_ctr_specs = AList.lookup (op =) (fun_names ~~ basic_ctr_specss) fun_name; + + val cond_ctrs = fold_rev_corec_code_rhs lthy (fn cs => fn ctr => fn _ => + if member ((op =) o apsnd #ctr) basic_ctr_specs ctr + then cons (ctr, cs) + else primcorec_error_eqn "not a constructor" ctr) [] rhs' [] + |> AList.group (op =); + + val ctr_premss = (case cond_ctrs of [_] => [[]] | _ => map (s_dnf o snd) cond_ctrs); + val ctr_concls = cond_ctrs |> map (fn (ctr, _) => + binder_types (fastype_of ctr) + |> map_index (fn (n, T) => massage_corec_code_rhs lthy (fn _ => fn ctr' => fn args => + if ctr' = ctr then nth args n else Const (@{const_name undefined}, T)) [] rhs') + |> curry list_comb ctr + |> curry HOLogic.mk_eq lhs); + in + fold_map2 (dissect_coeqn_ctr false fun_names basic_ctr_specss eqn' + (SOME (abstract (List.rev fun_args) rhs))) + ctr_premss ctr_concls matchedsss + end; + +fun dissect_coeqn lthy seq has_call fun_names (basic_ctr_specss : basic_corec_ctr_spec list list) + eqn' maybe_of_spec matchedsss = + let + val eqn = drop_All eqn' + handle TERM _ => primcorec_error_eqn "malformed function equation" eqn'; + val (prems, concl) = Logic.strip_horn eqn + |> apfst (map HOLogic.dest_Trueprop) o apsnd HOLogic.dest_Trueprop; + + val head = concl + |> perhaps (try HOLogic.dest_not) |> perhaps (try (fst o HOLogic.dest_eq)) + |> head_of; + + val maybe_rhs = concl |> perhaps (try HOLogic.dest_not) |> try (snd o HOLogic.dest_eq); + + val discs = maps (map #disc) basic_ctr_specss; + val sels = maps (maps #sels) basic_ctr_specss; + val ctrs = maps (map #ctr) basic_ctr_specss; + in + if member (op =) discs head orelse + is_some maybe_rhs andalso + member (op =) (filter (null o binder_types o fastype_of) ctrs) (the maybe_rhs) then + dissect_coeqn_disc seq fun_names basic_ctr_specss NONE NONE prems concl matchedsss + |>> single + else if member (op =) sels head then + ([dissect_coeqn_sel fun_names basic_ctr_specss eqn' maybe_of_spec concl], matchedsss) + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso + member (op =) ctrs (head_of (unfold_let (the maybe_rhs))) then + dissect_coeqn_ctr seq fun_names basic_ctr_specss eqn' NONE prems concl matchedsss + else if is_Free head andalso member (op =) fun_names (fst (dest_Free head)) andalso + null prems then + dissect_coeqn_code lthy has_call fun_names basic_ctr_specss eqn' concl matchedsss + |>> flat + else + primcorec_error_eqn "malformed function equation" eqn + end; + +fun build_corec_arg_disc (ctr_specs : corec_ctr_spec list) + ({fun_args, ctr_no, prems, ...} : coeqn_data_disc) = + if is_none (#pred (nth ctr_specs ctr_no)) then I else + s_conjs prems + |> curry subst_bounds (List.rev fun_args) + |> HOLogic.tupled_lambda (HOLogic.mk_tuple fun_args) + |> K |> nth_map (the (#pred (nth ctr_specs ctr_no))); + +fun build_corec_arg_no_call (sel_eqns : coeqn_data_sel list) sel = + find_first (equal sel o #sel) sel_eqns + |> try (fn SOME {fun_args, rhs_term, ...} => abs_tuple fun_args rhs_term) + |> the_default undef_const + |> K; + +fun build_corec_args_mutual_call lthy has_call (sel_eqns : coeqn_data_sel list) sel = + (case find_first (equal sel o #sel) sel_eqns of + NONE => (I, I, I) + | SOME {fun_args, rhs_term, ... } => + let + val bound_Ts = List.rev (map fastype_of fun_args); + fun rewrite_stop _ t = if has_call t then @{term False} else @{term True}; + fun rewrite_end _ t = if has_call t then undef_const else t; + fun rewrite_cont bound_Ts t = + if has_call t then mk_tuple1 bound_Ts (snd (strip_comb t)) else undef_const; + fun massage f _ = massage_mutual_corec_call lthy has_call f bound_Ts rhs_term + |> abs_tuple fun_args; + in + (massage rewrite_stop, massage rewrite_end, massage rewrite_cont) + end); + +fun build_corec_arg_nested_call lthy has_call (sel_eqns : coeqn_data_sel list) sel = + (case find_first (equal sel o #sel) sel_eqns of + NONE => I + | SOME {fun_args, rhs_term, ...} => + let + val bound_Ts = List.rev (map fastype_of fun_args); + fun rewrite bound_Ts U T (Abs (v, V, b)) = Abs (v, V, rewrite (V :: bound_Ts) U T b) + | rewrite bound_Ts U T (t as _ $ _) = + let val (u, vs) = strip_comb t in + if is_Free u andalso has_call u then + Inr_const U T $ mk_tuple1 bound_Ts vs + else if const_name u = SOME @{const_name prod_case} then + map (rewrite bound_Ts U T) vs |> chop 1 |>> HOLogic.mk_split o the_single |> list_comb + else + list_comb (rewrite bound_Ts U T u, map (rewrite bound_Ts U T) vs) + end + | rewrite _ U T t = + if is_Free t andalso has_call t then Inr_const U T $ HOLogic.unit else t; + fun massage t = + rhs_term + |> massage_nested_corec_call lthy has_call rewrite bound_Ts (range_type (fastype_of t)) + |> abs_tuple fun_args; + in + massage + end); + +fun build_corec_args_sel lthy has_call (all_sel_eqns : coeqn_data_sel list) + (ctr_spec : corec_ctr_spec) = + (case filter (equal (#ctr ctr_spec) o #ctr) all_sel_eqns of + [] => I + | sel_eqns => + let + val sel_call_list = #sels ctr_spec ~~ #calls ctr_spec; + val no_calls' = map_filter (try (apsnd (fn No_Corec n => n))) sel_call_list; + val mutual_calls' = map_filter (try (apsnd (fn Mutual_Corec n => n))) sel_call_list; + val nested_calls' = map_filter (try (apsnd (fn Nested_Corec n => n))) sel_call_list; + in + I + #> fold (fn (sel, n) => nth_map n (build_corec_arg_no_call sel_eqns sel)) no_calls' + #> fold (fn (sel, (q, g, h)) => + let val (fq, fg, fh) = build_corec_args_mutual_call lthy has_call sel_eqns sel in + nth_map q fq o nth_map g fg o nth_map h fh end) mutual_calls' + #> fold (fn (sel, n) => nth_map n + (build_corec_arg_nested_call lthy has_call sel_eqns sel)) nested_calls' + end); + +fun build_codefs lthy bs mxs has_call arg_Tss (corec_specs : corec_spec list) + (disc_eqnss : coeqn_data_disc list list) (sel_eqnss : coeqn_data_sel list list) = + let + val corecs = map #corec corec_specs; + val ctr_specss = map #ctr_specs corec_specs; + val corec_args = hd corecs + |> fst o split_last o binder_types o fastype_of + |> map (Const o pair @{const_name undefined}) + |> fold2 (fold o build_corec_arg_disc) ctr_specss disc_eqnss + |> fold2 (fold o build_corec_args_sel lthy has_call) sel_eqnss ctr_specss; + fun currys [] t = t + | currys Ts t = t $ mk_tuple1 (List.rev Ts) (map Bound (length Ts - 1 downto 0)) + |> fold_rev (Term.abs o pair Name.uu) Ts; + +(* +val _ = tracing ("corecursor arguments:\n \ " ^ + space_implode "\n \ " (map (Syntax.string_of_term lthy) corec_args)); +*) + + val exclss' = + disc_eqnss + |> map (map (fn x => (#fun_args x, #ctr_no x, #prems x, #auto_gen x)) + #> fst o (fn xs => fold_map (fn x => fn ys => ((x, ys), ys @ [x])) xs []) + #> maps (uncurry (map o pair) + #> map (fn ((fun_args, c, x, a), (_, c', y, a')) => + ((c, c', a orelse a'), (x, s_not (s_conjs y))) + ||> apfst (map HOLogic.mk_Trueprop) o apsnd HOLogic.mk_Trueprop + ||> Logic.list_implies + ||> curry Logic.list_all (map dest_Free fun_args)))) + in + map (list_comb o rpair corec_args) corecs + |> map2 (fn Ts => fn t => if length Ts = 0 then t $ HOLogic.unit else t) arg_Tss + |> map2 currys arg_Tss + |> Syntax.check_terms lthy + |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) + bs mxs + |> rpair exclss' + end; + +fun mk_real_disc_eqns fun_binding arg_Ts ({ctr_specs, ...} : corec_spec) + (sel_eqns : coeqn_data_sel list) (disc_eqns : coeqn_data_disc list) = + if length disc_eqns <> length ctr_specs - 1 then disc_eqns else + let + val n = 0 upto length ctr_specs + |> the o find_first (fn idx => not (exists (equal idx o #ctr_no) disc_eqns)); + val fun_args = (try (#fun_args o hd) disc_eqns, try (#fun_args o hd) sel_eqns) + |> the_default (map (curry Free Name.uu) arg_Ts) o merge_options; + val extra_disc_eqn = { + fun_name = Binding.name_of fun_binding, + fun_T = arg_Ts ---> body_type (fastype_of (#ctr (hd ctr_specs))), + fun_args = fun_args, + ctr = #ctr (nth ctr_specs n), + ctr_no = n, + disc = #disc (nth ctr_specs n), + prems = maps (s_not_conj o #prems) disc_eqns, + auto_gen = true, + maybe_ctr_rhs = NONE, + maybe_code_rhs = NONE, + user_eqn = undef_const}; + in + chop n disc_eqns ||> cons extra_disc_eqn |> (op @) + end; + +fun find_corec_calls ctxt has_call basic_ctr_specs ({ctr, sel, rhs_term, ...} : coeqn_data_sel) = + let + val sel_no = find_first (equal ctr o #ctr) basic_ctr_specs + |> find_index (equal sel) o #sels o the; + fun find t = if has_call t then snd (fold_rev_corec_call ctxt (K cons) [] t []) else []; + in + find rhs_term + |> K |> nth_map sel_no |> AList.map_entry (op =) ctr + end; + +fun add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy = + let + val (bs, mxs) = map_split (apfst fst) fixes; + val (arg_Ts, res_Ts) = map (strip_type o snd o fst #>> HOLogic.mk_tupleT) fixes |> split_list; + + val fun_names = map Binding.name_of bs; + val basic_ctr_specss = map (basic_corec_specs_of lthy) res_Ts; + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); + val eqns_data = + fold_map2 (dissect_coeqn lthy seq has_call fun_names basic_ctr_specss) (map snd specs) + maybe_of_specs [] + |> flat o fst; + + val callssss = + map_filter (try (fn Sel x => x)) eqns_data + |> partition_eq ((op =) o pairself #fun_name) + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names + |> map (flat o snd) + |> map2 (fold o find_corec_calls lthy has_call) basic_ctr_specss + |> map2 (curry (op |>)) (map (map (fn {ctr, sels, ...} => + (ctr, map (K []) sels))) basic_ctr_specss); + +(* +val _ = tracing ("callssss = " ^ @{make_string} callssss); +*) + + val ((n2m, corec_specs', _, coinduct_thm, strong_coinduct_thm, coinduct_thms, + strong_coinduct_thms), lthy') = + corec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; + val actual_nn = length bs; + val corec_specs = take actual_nn corec_specs'; (*###*) + val ctr_specss = map #ctr_specs corec_specs; + + val disc_eqnss' = map_filter (try (fn Disc x => x)) eqns_data + |> partition_eq ((op =) o pairself #fun_name) + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names + |> map (sort ((op <) o pairself #ctr_no |> make_ord) o flat o snd); + val _ = disc_eqnss' |> map (fn x => + let val d = duplicates ((op =) o pairself #ctr_no) x in null d orelse + primcorec_error_eqns "excess discriminator formula in definition" + (maps (fn t => filter (equal (#ctr_no t) o #ctr_no) x) d |> map #user_eqn) end); + + val sel_eqnss = map_filter (try (fn Sel x => x)) eqns_data + |> partition_eq ((op =) o pairself #fun_name) + |> fst o finds (fn (x, ({fun_name, ...} :: _)) => x = fun_name) fun_names + |> map (flat o snd); + + val arg_Tss = map (binder_types o snd o fst) fixes; + val disc_eqnss = map5 mk_real_disc_eqns bs arg_Tss corec_specs sel_eqnss disc_eqnss'; + val (defs, exclss') = + build_codefs lthy' bs mxs has_call arg_Tss corec_specs disc_eqnss sel_eqnss; + + fun excl_tac (c, c', a) = + if a orelse c = c' orelse seq then SOME (K (HEADGOAL (mk_primcorec_assumption_tac lthy []))) + else maybe_tac; + +(* +val _ = tracing ("exclusiveness properties:\n \ " ^ + space_implode "\n \ " (maps (map (Syntax.string_of_term lthy o snd)) exclss')); +*) + + val exclss'' = exclss' |> map (map (fn (idx, t) => + (idx, (Option.map (Goal.prove lthy [] [] t #> Thm.close_derivation) (excl_tac idx), t)))); + val taut_thmss = map (map (apsnd (the o fst)) o filter (is_some o fst o snd)) exclss''; + val (goal_idxss, goalss) = exclss'' + |> map (map (apsnd (rpair [] o snd)) o filter (is_none o fst o snd)) + |> split_list o map split_list; + + fun prove thmss' def_thms' lthy = + let + val def_thms = map (snd o snd) def_thms'; + + val exclss' = map (op ~~) (goal_idxss ~~ thmss'); + fun mk_exclsss excls n = + (excls, map (fn k => replicate k [TrueI] @ replicate (n - k) []) (0 upto n - 1)) + |-> fold (fn ((c, c', _), thm) => nth_map c (nth_map c' (K [thm]))); + val exclssss = (exclss' ~~ taut_thmss |> map (op @), fun_names ~~ corec_specs) + |-> map2 (fn excls => fn (_, {ctr_specs, ...}) => mk_exclsss excls (length ctr_specs)); + + fun prove_disc ({ctr_specs, ...} : corec_spec) exclsss + ({fun_name, fun_T, fun_args, ctr_no, prems, ...} : coeqn_data_disc) = + if Term.aconv_untyped (#disc (nth ctr_specs ctr_no), @{term "\x. x = x"}) then [] else + let + val {disc_corec, ...} = nth ctr_specs ctr_no; + val k = 1 + ctr_no; + val m = length prems; + val t = + list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) + |> curry betapply (#disc (nth ctr_specs ctr_no)) (*###*) + |> HOLogic.mk_Trueprop + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) + |> curry Logic.list_all (map dest_Free fun_args); + in + if prems = [@{term False}] then [] else + mk_primcorec_disc_tac lthy def_thms disc_corec k m exclsss + |> K |> Goal.prove lthy [] [] t + |> Thm.close_derivation + |> pair (#disc (nth ctr_specs ctr_no)) + |> single + end; + + fun prove_sel ({nested_maps, nested_map_idents, nested_map_comps, ctr_specs, ...} + : corec_spec) (disc_eqns : coeqn_data_disc list) exclsss + ({fun_name, fun_T, fun_args, ctr, sel, rhs_term, ...} : coeqn_data_sel) = + let + val SOME ctr_spec = find_first (equal ctr o #ctr) ctr_specs; + val ctr_no = find_index (equal ctr o #ctr) ctr_specs; + val prems = the_default (maps (s_not_conj o #prems) disc_eqns) + (find_first (equal ctr_no o #ctr_no) disc_eqns |> Option.map #prems); + val sel_corec = find_index (equal sel) (#sels ctr_spec) + |> nth (#sel_corecs ctr_spec); + val k = 1 + ctr_no; + val m = length prems; + val t = + list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)) + |> curry betapply sel + |> rpair (abstract (List.rev fun_args) rhs_term) + |> HOLogic.mk_Trueprop o HOLogic.mk_eq + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) + |> curry Logic.list_all (map dest_Free fun_args); + val (distincts, _, sel_splits, sel_split_asms) = case_thms_of_term lthy [] rhs_term; + in + mk_primcorec_sel_tac lthy def_thms distincts sel_splits sel_split_asms nested_maps + nested_map_idents nested_map_comps sel_corec k m exclsss + |> K |> Goal.prove lthy [] [] t + |> Thm.close_derivation + |> pair sel + end; + + fun prove_ctr disc_alist sel_alist (disc_eqns : coeqn_data_disc list) + (sel_eqns : coeqn_data_sel list) ({ctr, disc, sels, collapse, ...} : corec_ctr_spec) = + (* don't try to prove theorems when some sel_eqns are missing *) + if not (exists (equal ctr o #ctr) disc_eqns) + andalso not (exists (equal ctr o #ctr) sel_eqns) + orelse + filter (equal ctr o #ctr) sel_eqns + |> fst o finds ((op =) o apsnd #sel) sels + |> exists (null o snd) + then [] else + let + val (fun_name, fun_T, fun_args, prems, maybe_rhs) = + (find_first (equal ctr o #ctr) disc_eqns, find_first (equal ctr o #ctr) sel_eqns) + |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #prems x, + #maybe_ctr_rhs x)) + ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, [], NONE)) + |> the o merge_options; + val m = length prems; + val t = (if is_some maybe_rhs then the maybe_rhs else + filter (equal ctr o #ctr) sel_eqns + |> fst o finds ((op =) o apsnd #sel) sels + |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) #-> abstract) + |> curry list_comb ctr) + |> curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), + map Bound (length fun_args - 1 downto 0))) + |> HOLogic.mk_Trueprop + |> curry Logic.list_implies (map HOLogic.mk_Trueprop prems) + |> curry Logic.list_all (map dest_Free fun_args); + val maybe_disc_thm = AList.lookup (op =) disc_alist disc; + val sel_thms = map snd (filter (member (op =) sels o fst) sel_alist); + in + if prems = [@{term False}] then [] else + mk_primcorec_ctr_of_dtr_tac lthy m collapse maybe_disc_thm sel_thms + |> K |> Goal.prove lthy [] [] t + |> Thm.close_derivation + |> pair ctr + |> single + end; + + fun prove_code disc_eqns sel_eqns ctr_alist ctr_specs = + let + val (fun_name, fun_T, fun_args, maybe_rhs) = + (find_first (member (op =) (map #ctr ctr_specs) o #ctr) disc_eqns, + find_first (member (op =) (map #ctr ctr_specs) o #ctr) sel_eqns) + |>> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, #maybe_code_rhs x)) + ||> Option.map (fn x => (#fun_name x, #fun_T x, #fun_args x, NONE)) + |> the o merge_options; + + val bound_Ts = List.rev (map fastype_of fun_args); + + val lhs = list_comb (Free (fun_name, fun_T), map Bound (length fun_args - 1 downto 0)); + val maybe_rhs_info = + (case maybe_rhs of + SOME rhs => + let + val raw_rhs = expand_corec_code_rhs lthy has_call bound_Ts rhs; + val cond_ctrs = + fold_rev_corec_code_rhs lthy (K oo (cons oo pair)) bound_Ts raw_rhs []; + val ctr_thms = map (the o AList.lookup (op =) ctr_alist o snd) cond_ctrs; + in SOME (rhs, raw_rhs, ctr_thms) end + | NONE => + let + fun prove_code_ctr {ctr, sels, ...} = + if not (exists (equal ctr o fst) ctr_alist) then NONE else + let + val prems = find_first (equal ctr o #ctr) disc_eqns + |> Option.map #prems |> the_default []; + val t = + filter (equal ctr o #ctr) sel_eqns + |> fst o finds ((op =) o apsnd #sel) sels + |> map (snd #> (fn [x] => (List.rev (#fun_args x), #rhs_term x)) + #-> abstract) + |> curry list_comb ctr; + in + SOME (prems, t) + end; + val maybe_ctr_conds_argss = map prove_code_ctr ctr_specs; + in + if exists is_none maybe_ctr_conds_argss then NONE else + let + val rhs = fold_rev (fn SOME (prems, u) => fn t => mk_If (s_conjs prems) u t) + maybe_ctr_conds_argss + (Const (@{const_name Code.abort}, @{typ String.literal} --> + (@{typ unit} --> body_type fun_T) --> body_type fun_T) $ + HOLogic.mk_literal fun_name $ + absdummy @{typ unit} (incr_boundvars 1 lhs)); + in SOME (rhs, rhs, map snd ctr_alist) end + end); + in + (case maybe_rhs_info of + NONE => [] + | SOME (rhs, raw_rhs, ctr_thms) => + let + val ms = map (Logic.count_prems o prop_of) ctr_thms; + val (raw_t, t) = (raw_rhs, rhs) + |> pairself + (curry HOLogic.mk_eq (list_comb (Free (fun_name, fun_T), + map Bound (length fun_args - 1 downto 0))) + #> HOLogic.mk_Trueprop + #> curry Logic.list_all (map dest_Free fun_args)); + val (distincts, discIs, sel_splits, sel_split_asms) = + case_thms_of_term lthy bound_Ts raw_rhs; + + val raw_code_thm = mk_primcorec_raw_code_of_ctr_tac lthy distincts discIs sel_splits + sel_split_asms ms ctr_thms + |> K |> Goal.prove lthy [] [] raw_t + |> Thm.close_derivation; + in + mk_primcorec_code_of_raw_code_tac lthy distincts sel_splits raw_code_thm + |> K |> Goal.prove lthy [] [] t + |> Thm.close_derivation + |> single + end) + end; + + val disc_alists = map3 (maps oo prove_disc) corec_specs exclssss disc_eqnss; + val sel_alists = map4 (map ooo prove_sel) corec_specs disc_eqnss exclssss sel_eqnss; + val disc_thmss = map (map snd) disc_alists; + val sel_thmss = map (map snd) sel_alists; + + val ctr_alists = map5 (maps oooo prove_ctr) disc_alists sel_alists disc_eqnss sel_eqnss + ctr_specss; + val ctr_thmss = map (map snd) ctr_alists; + + val code_thmss = map4 prove_code disc_eqnss sel_eqnss ctr_alists ctr_specss; + + val simp_thmss = map2 append disc_thmss sel_thmss + + val common_name = mk_common_name fun_names; + + val notes = + [(coinductN, map (if n2m then single else K []) coinduct_thms, []), + (codeN, code_thmss, code_nitpicksimp_attrs), + (ctrN, ctr_thmss, []), + (discN, disc_thmss, simp_attrs), + (selN, sel_thmss, simp_attrs), + (simpsN, simp_thmss, []), + (strong_coinductN, map (if n2m then single else K []) strong_coinduct_thms, [])] + |> maps (fn (thmN, thmss, attrs) => + map2 (fn fun_name => fn thms => + ((Binding.qualify true fun_name (Binding.name thmN), attrs), [(thms, [])])) + fun_names (take actual_nn thmss)) + |> filter_out (null o fst o hd o snd); + + val common_notes = + [(coinductN, if n2m then [coinduct_thm] else [], []), + (strong_coinductN, if n2m then [strong_coinduct_thm] else [], [])] + |> filter_out (null o #2) + |> map (fn (thmN, thms, attrs) => + ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); + in + lthy |> Local_Theory.notes (notes @ common_notes) |> snd + end; + + fun after_qed thmss' = fold_map Local_Theory.define defs #-> prove thmss'; + in + (goalss, after_qed, lthy') + end; + +fun add_primcorec_ursive_cmd maybe_tac seq (raw_fixes, raw_specs') lthy = + let + val (raw_specs, maybe_of_specs) = + split_list raw_specs' ||> map (Option.map (Syntax.read_term lthy)); + val ((fixes, specs), _) = Specification.read_spec raw_fixes raw_specs lthy; + in + add_primcorec_ursive maybe_tac seq fixes specs maybe_of_specs lthy + handle ERROR str => primcorec_error str + end + handle Primcorec_Error (str, eqns) => + if null eqns + then error ("primcorec error:\n " ^ str) + else error ("primcorec error:\n " ^ str ^ "\nin\n " ^ + space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); + +val add_primcorecursive_cmd = (fn (goalss, after_qed, lthy) => + lthy + |> Proof.theorem NONE after_qed goalss + |> Proof.refine (Method.primitive_text I) + |> Seq.hd) ooo add_primcorec_ursive_cmd NONE; + +val add_primcorec_cmd = (fn (goalss, after_qed, lthy) => + lthy + |> after_qed (map (fn [] => [] + | _ => primcorec_error "need exclusiveness proofs - use primcorecursive instead of primcorec") + goalss)) ooo add_primcorec_ursive_cmd (SOME (fn {context = ctxt, ...} => auto_tac ctxt)); + +end; diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML Mon Nov 04 16:53:43 2013 +0100 @@ -0,0 +1,135 @@ +(* Title: HOL/BNF/Tools/bnf_gfp_rec_sugar_tactics.ML + Author: Jasmin Blanchette, TU Muenchen + Copyright 2013 + +Tactics for corecursor sugar. +*) + +signature BNF_GFP_REC_SUGAR_TACTICS = +sig + val mk_primcorec_assumption_tac: Proof.context -> thm list -> int -> tactic + val mk_primcorec_code_of_raw_code_tac: Proof.context -> thm list -> thm list -> thm -> tactic + val mk_primcorec_ctr_of_dtr_tac: Proof.context -> int -> thm -> thm option -> thm list -> tactic + val mk_primcorec_disc_tac: Proof.context -> thm list -> thm -> int -> int -> thm list list list -> + tactic + val mk_primcorec_raw_code_of_ctr_tac: Proof.context -> thm list -> thm list -> thm list -> + thm list -> int list -> thm list -> tactic + val mk_primcorec_sel_tac: Proof.context -> thm list -> thm list -> thm list -> thm list -> + thm list -> thm list -> thm list -> thm -> int -> int -> thm list list list -> tactic +end; + +structure BNF_GFP_Rec_Sugar_Tactics : BNF_GFP_REC_SUGAR_TACTICS = +struct + +open BNF_Util +open BNF_Tactics + +val falseEs = @{thms not_TrueE FalseE}; +val Let_def = @{thm Let_def}; +val neq_eq_eq_contradict = @{thm neq_eq_eq_contradict}; +val split_if = @{thm split_if}; +val split_if_asm = @{thm split_if_asm}; +val split_connectI = @{thms allI impI conjI}; + +fun mk_primcorec_assumption_tac ctxt discIs = + SELECT_GOAL (unfold_thms_tac ctxt + @{thms not_not not_False_eq_True not_True_eq_False de_Morgan_conj de_Morgan_disj} THEN + SOLVE (HEADGOAL (REPEAT o (rtac refl ORELSE' atac ORELSE' etac conjE ORELSE' + eresolve_tac falseEs ORELSE' + resolve_tac @{thms TrueI conjI disjI1 disjI2} ORELSE' + dresolve_tac discIs THEN' atac ORELSE' + etac notE THEN' atac ORELSE' + etac disjE)))); + +fun mk_primcorec_same_case_tac m = + HEADGOAL (if m = 0 then rtac TrueI + else REPEAT_DETERM_N (m - 1) o (rtac conjI THEN' atac) THEN' atac); + +fun mk_primcorec_different_case_tac ctxt m excl = + HEADGOAL (if m = 0 then mk_primcorec_assumption_tac ctxt [] + else dtac excl THEN' (REPEAT_DETERM_N (m - 1) o atac) THEN' mk_primcorec_assumption_tac ctxt []); + +fun mk_primcorec_cases_tac ctxt k m exclsss = + let val n = length exclsss in + EVERY (map (fn [] => if k = n then all_tac else mk_primcorec_same_case_tac m + | [excl] => mk_primcorec_different_case_tac ctxt m excl) + (take k (nth exclsss (k - 1)))) + end; + +fun mk_primcorec_prelude ctxt defs thm = + unfold_thms_tac ctxt defs THEN HEADGOAL (rtac thm) THEN + unfold_thms_tac ctxt @{thms Let_def split}; + +fun mk_primcorec_disc_tac ctxt defs disc_corec k m exclsss = + mk_primcorec_prelude ctxt defs disc_corec THEN mk_primcorec_cases_tac ctxt k m exclsss; + +fun mk_primcorec_sel_tac ctxt defs distincts splits split_asms maps map_idents map_comps f_sel k m + exclsss = + mk_primcorec_prelude ctxt defs (f_sel RS trans) THEN + mk_primcorec_cases_tac ctxt k m exclsss THEN + HEADGOAL (REPEAT_DETERM o (rtac refl ORELSE' rtac ext ORELSE' + eresolve_tac falseEs ORELSE' + resolve_tac split_connectI ORELSE' + Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE' + Splitter.split_tac (split_if :: splits) ORELSE' + eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac ORELSE' + etac notE THEN' atac ORELSE' + (CHANGED o SELECT_GOAL (unfold_thms_tac ctxt + (@{thms id_def o_def split_def sum.cases} @ maps @ map_comps @ map_idents))))); + +fun mk_primcorec_ctr_of_dtr_tac ctxt m collapse maybe_disc_f sel_fs = + HEADGOAL (rtac ((if null sel_fs then collapse else collapse RS sym) RS trans) THEN' + (the_default (K all_tac) (Option.map rtac maybe_disc_f)) THEN' REPEAT_DETERM_N m o atac) THEN + unfold_thms_tac ctxt (Let_def :: sel_fs) THEN HEADGOAL (rtac refl); + +fun inst_split_eq ctxt split = + (case prop_of split of + @{const Trueprop} $ (Const (@{const_name HOL.eq}, _) $ (Var (_, Type (_, [T, _])) $ _) $ _) => + let + val s = Name.uu; + val eq = Abs (Name.uu, T, HOLogic.mk_eq (Free (s, T), Bound 0)); + val split' = Drule.instantiate' [] [SOME (certify ctxt eq)] split; + in + Thm.generalize ([], [s]) (Thm.maxidx_of split' + 1) split' + end + | _ => split); + +fun distinct_in_prems_tac distincts = + eresolve_tac (map (fn thm => thm RS neq_eq_eq_contradict) distincts) THEN' atac; + +(* TODO: reduce code duplication with selector tactic above *) +fun mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms m f_ctr = + let + val splits' = + map (fn th => th RS iffD2) (@{thm split_if_eq2} :: map (inst_split_eq ctxt) splits) + in + HEADGOAL (REPEAT o (resolve_tac (splits' @ split_connectI))) THEN + mk_primcorec_prelude ctxt [] (f_ctr RS trans) THEN + HEADGOAL ((REPEAT_DETERM_N m o mk_primcorec_assumption_tac ctxt discIs) THEN' + SELECT_GOAL (SOLVE (HEADGOAL (REPEAT_DETERM o + (rtac refl ORELSE' atac ORELSE' + resolve_tac (@{thm Code.abort_def} :: split_connectI) ORELSE' + Splitter.split_tac (split_if :: splits) ORELSE' + Splitter.split_asm_tac (split_if_asm :: split_asms) ORELSE' + mk_primcorec_assumption_tac ctxt discIs ORELSE' + distinct_in_prems_tac distincts ORELSE' + (TRY o dresolve_tac discIs) THEN' etac notE THEN' atac))))) + end; + +fun mk_primcorec_raw_code_of_ctr_tac ctxt distincts discIs splits split_asms ms f_ctrs = + EVERY (map2 (mk_primcorec_raw_code_of_ctr_single_tac ctxt distincts discIs splits split_asms) ms + f_ctrs) THEN + IF_UNSOLVED (unfold_thms_tac ctxt @{thms Code.abort_def} THEN + HEADGOAL (REPEAT_DETERM o resolve_tac (refl :: split_connectI))); + +fun mk_primcorec_code_of_raw_code_tac ctxt distincts splits raw = + HEADGOAL (rtac raw ORELSE' rtac (raw RS trans) THEN' + SELECT_GOAL (unfold_thms_tac ctxt [Let_def]) THEN' REPEAT_DETERM o + (rtac refl ORELSE' atac ORELSE' + resolve_tac split_connectI ORELSE' + Splitter.split_tac (split_if :: splits) ORELSE' + distinct_in_prems_tac distincts ORELSE' + rtac sym THEN' atac ORELSE' + etac notE THEN' atac)); + +end; diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_lfp.ML --- a/src/HOL/BNF/Tools/bnf_lfp.ML Mon Nov 04 15:44:43 2013 +0100 +++ b/src/HOL/BNF/Tools/bnf_lfp.ML Mon Nov 04 16:53:43 2013 +0100 @@ -22,7 +22,7 @@ open BNF_Comp open BNF_FP_Util open BNF_FP_Def_Sugar -open BNF_FP_Rec_Sugar +open BNF_LFP_Rec_Sugar open BNF_LFP_Util open BNF_LFP_Tactics diff -r f91022745c85 -r 8fdb4dc08ed1 src/HOL/BNF/Tools/bnf_lfp_rec_sugar.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/BNF/Tools/bnf_lfp_rec_sugar.ML Mon Nov 04 16:53:43 2013 +0100 @@ -0,0 +1,598 @@ +(* Title: HOL/BNF/Tools/bnf_lfp_rec_sugar.ML + Author: Lorenz Panny, TU Muenchen + Author: Jasmin Blanchette, TU Muenchen + Copyright 2013 + +Recursor sugar. +*) + +signature BNF_LFP_REC_SUGAR = +sig + val add_primrec: (binding * typ option * mixfix) list -> + (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory + val add_primrec_cmd: (binding * string option * mixfix) list -> + (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory + val add_primrec_global: (binding * typ option * mixfix) list -> + (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory + val add_primrec_overloaded: (string * (string * typ) * bool) list -> + (binding * typ option * mixfix) list -> + (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory + val add_primrec_simple: ((binding * typ) * mixfix) list -> term list -> + local_theory -> (string list * (term list * (int list list * thm list list))) * local_theory +end; + +structure BNF_LFP_Rec_Sugar : BNF_LFP_REC_SUGAR = +struct + +open Ctr_Sugar +open BNF_Util +open BNF_Tactics +open BNF_Def +open BNF_FP_Util +open BNF_FP_Def_Sugar +open BNF_FP_N2M_Sugar +open BNF_FP_Rec_Sugar_Util + +val nitpicksimp_attrs = @{attributes [nitpick_simp]}; +val simp_attrs = @{attributes [simp]}; +val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; + +exception Primrec_Error of string * term list; + +fun primrec_error str = raise Primrec_Error (str, []); +fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]); +fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns); + +datatype rec_call = + No_Rec of int * typ | + Mutual_Rec of (int * typ) * (int * typ) | + Nested_Rec of int * typ; + +type rec_ctr_spec = + {ctr: term, + offset: int, + calls: rec_call list, + rec_thm: thm}; + +type rec_spec = + {recx: term, + nested_map_idents: thm list, + nested_map_comps: thm list, + ctr_specs: rec_ctr_spec list}; + +exception AINT_NO_MAP of term; + +fun ill_formed_rec_call ctxt t = + error ("Ill-formed recursive call: " ^ quote (Syntax.string_of_term ctxt t)); +fun invalid_map ctxt t = + error ("Invalid map function in " ^ quote (Syntax.string_of_term ctxt t)); +fun unexpected_rec_call ctxt t = + error ("Unexpected recursive call: " ^ quote (Syntax.string_of_term ctxt t)); + +fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' = + let + fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else (); + + val typof = curry fastype_of1 bound_Ts; + val build_map_fst = build_map ctxt (fst_const o fst); + + val yT = typof y; + val yU = typof y'; + + fun y_of_y' () = build_map_fst (yU, yT) $ y'; + val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t); + + fun massage_mutual_fun U T t = + (case t of + Const (@{const_name comp}, _) $ t1 $ t2 => + mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2) + | _ => + if has_call t then + (case try HOLogic.dest_prodT U of + SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t + | NONE => invalid_map ctxt t) + else + mk_comp bound_Ts (t, build_map_fst (U, T))); + + fun massage_map (Type (_, Us)) (Type (s, Ts)) t = + (case try (dest_map ctxt s) t of + SOME (map0, fs) => + let + val Type (_, ran_Ts) = range_type (typof t); + val map' = mk_map (length fs) Us ran_Ts map0; + val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs; + in + Term.list_comb (map', fs') + end + | NONE => raise AINT_NO_MAP t) + | massage_map _ _ t = raise AINT_NO_MAP t + and massage_map_or_map_arg U T t = + if T = U then + tap check_no_call t + else + massage_map U T t + handle AINT_NO_MAP _ => massage_mutual_fun U T t; + + fun massage_call (t as t1 $ t2) = + if has_call t then + if t2 = y then + massage_map yU yT (elim_y t1) $ y' + handle AINT_NO_MAP t' => invalid_map ctxt t' + else + let val (g, xs) = Term.strip_comb t2 in + if g = y then + if exists has_call xs then unexpected_rec_call ctxt t2 + else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs) + else + ill_formed_rec_call ctxt t + end + else + elim_y t + | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t; + in + massage_call + end; + +fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy = + let + val thy = Proof_Context.theory_of lthy; + + val ((missing_arg_Ts, perm0_kks, + fp_sugars as {nested_bnfs, fp_res = {xtor_co_iterss = ctor_iters1 :: _, ...}, + co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), lthy') = + nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy; + + val perm_fp_sugars = sort (int_ord o pairself #index) fp_sugars; + + val indices = map #index fp_sugars; + val perm_indices = map #index perm_fp_sugars; + + val perm_ctrss = map (#ctrs o of_fp_sugar #ctr_sugars) perm_fp_sugars; + val perm_ctr_Tsss = map (map (binder_types o fastype_of)) perm_ctrss; + val perm_lfpTs = map (body_type o fastype_of o hd) perm_ctrss; + + val nn0 = length arg_Ts; + val nn = length perm_lfpTs; + val kks = 0 upto nn - 1; + val perm_ns = map length perm_ctr_Tsss; + val perm_mss = map (map length) perm_ctr_Tsss; + + val perm_Cs = map (body_type o fastype_of o co_rec_of o of_fp_sugar (#xtor_co_iterss o #fp_res)) + perm_fp_sugars; + val perm_fun_arg_Tssss = + mk_iter_fun_arg_types perm_ctr_Tsss perm_ns perm_mss (co_rec_of ctor_iters1); + + fun unpermute0 perm0_xs = permute_like (op =) perm0_kks kks perm0_xs; + fun unpermute perm_xs = permute_like (op =) perm_indices indices perm_xs; + + val induct_thms = unpermute0 (conj_dests nn induct_thm); + + val lfpTs = unpermute perm_lfpTs; + val Cs = unpermute perm_Cs; + + val As_rho = tvar_subst thy (take nn0 lfpTs) arg_Ts; + val Cs_rho = map (fst o dest_TVar) Cs ~~ pad_list HOLogic.unitT nn res_Ts; + + val substA = Term.subst_TVars As_rho; + val substAT = Term.typ_subst_TVars As_rho; + val substCT = Term.typ_subst_TVars Cs_rho; + val substACT = substAT o substCT; + + val perm_Cs' = map substCT perm_Cs; + + fun offset_of_ctr 0 _ = 0 + | offset_of_ctr n (({ctrs, ...} : ctr_sugar) :: ctr_sugars) = + length ctrs + offset_of_ctr (n - 1) ctr_sugars; + + fun call_of [i] [T] = (if exists_subtype_in Cs T then Nested_Rec else No_Rec) (i, substACT T) + | call_of [i, i'] [T, T'] = Mutual_Rec ((i, substACT T), (i', substACT T')); + + fun mk_ctr_spec ctr offset fun_arg_Tss rec_thm = + let + val (fun_arg_hss, _) = indexedd fun_arg_Tss 0; + val fun_arg_hs = flat_rec_arg_args fun_arg_hss; + val fun_arg_iss = map (map (find_index_eq fun_arg_hs)) fun_arg_hss; + in + {ctr = substA ctr, offset = offset, calls = map2 call_of fun_arg_iss fun_arg_Tss, + rec_thm = rec_thm} + end; + + fun mk_ctr_specs index (ctr_sugars : ctr_sugar list) iter_thmsss = + let + val ctrs = #ctrs (nth ctr_sugars index); + val rec_thmss = co_rec_of (nth iter_thmsss index); + val k = offset_of_ctr index ctr_sugars; + val n = length ctrs; + in + map4 mk_ctr_spec ctrs (k upto k + n - 1) (nth perm_fun_arg_Tssss index) rec_thmss + end; + + fun mk_spec ({T, index, ctr_sugars, co_iterss = iterss, co_iter_thmsss = iter_thmsss, ...} + : fp_sugar) = + {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' (co_rec_of (nth iterss index)), + nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs, + nested_map_comps = map map_comp_of_bnf nested_bnfs, + ctr_specs = mk_ctr_specs index ctr_sugars iter_thmsss}; + in + ((is_some lfp_sugar_thms, map mk_spec fp_sugars, missing_arg_Ts, induct_thm, induct_thms), + lthy') + end; + +val undef_const = Const (@{const_name undefined}, dummyT); + +fun permute_args n t = + list_comb (t, map Bound (0 :: (n downto 1))) |> fold (K (Term.abs (Name.uu, dummyT))) (0 upto n); + +type eqn_data = { + fun_name: string, + rec_type: typ, + ctr: term, + ctr_args: term list, + left_args: term list, + right_args: term list, + res_type: typ, + rhs_term: term, + user_eqn: term +}; + +fun dissect_eqn lthy fun_names eqn' = + let + val eqn = drop_All eqn' |> HOLogic.dest_Trueprop + handle TERM _ => + primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn'; + val (lhs, rhs) = HOLogic.dest_eq eqn + handle TERM _ => + primrec_error_eqn "malformed function equation (expected \"lhs = rhs\")" eqn'; + val (fun_name, args) = strip_comb lhs + |>> (fn x => if is_Free x then fst (dest_Free x) + else primrec_error_eqn "malformed function equation (does not start with free)" eqn); + val (left_args, rest) = take_prefix is_Free args; + val (nonfrees, right_args) = take_suffix is_Free rest; + val num_nonfrees = length nonfrees; + val _ = num_nonfrees = 1 orelse if num_nonfrees = 0 then + primrec_error_eqn "constructor pattern missing in left-hand side" eqn else + primrec_error_eqn "more than one non-variable argument in left-hand side" eqn; + val _ = member (op =) fun_names fun_name orelse + primrec_error_eqn "malformed function equation (does not start with function name)" eqn + + val (ctr, ctr_args) = strip_comb (the_single nonfrees); + val _ = try (num_binder_types o fastype_of) ctr = SOME (length ctr_args) orelse + primrec_error_eqn "partially applied constructor in pattern" eqn; + val _ = let val d = duplicates (op =) (left_args @ ctr_args @ right_args) in null d orelse + primrec_error_eqn ("duplicate variable \"" ^ Syntax.string_of_term lthy (hd d) ^ + "\" in left-hand side") eqn end; + val _ = forall is_Free ctr_args orelse + primrec_error_eqn "non-primitive pattern in left-hand side" eqn; + val _ = + let val b = fold_aterms (fn x as Free (v, _) => + if (not (member (op =) (left_args @ ctr_args @ right_args) x) andalso + not (member (op =) fun_names v) andalso + not (Variable.is_fixed lthy v)) then cons x else I | _ => I) rhs [] + in + null b orelse + primrec_error_eqn ("extra variable(s) in right-hand side: " ^ + commas (map (Syntax.string_of_term lthy) b)) eqn + end; + in + {fun_name = fun_name, + rec_type = body_type (type_of ctr), + ctr = ctr, + ctr_args = ctr_args, + left_args = left_args, + right_args = right_args, + res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs, + rhs_term = rhs, + user_eqn = eqn'} + end; + +fun rewrite_map_arg get_ctr_pos rec_type res_type = + let + val pT = HOLogic.mk_prodT (rec_type, res_type); + + val maybe_suc = Option.map (fn x => x + 1); + fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT) + | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b) + | subst d t = + let + val (u, vs) = strip_comb t; + val ctr_pos = try (get_ctr_pos o fst o dest_Free) u |> the_default ~1; + in + if ctr_pos >= 0 then + if d = SOME ~1 andalso length vs = ctr_pos then + list_comb (permute_args ctr_pos (snd_const pT), vs) + else if length vs > ctr_pos andalso is_some d + andalso d = try (fn Bound n => n) (nth vs ctr_pos) then + list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs)) + else + primrec_error_eqn ("recursive call not directly applied to constructor argument") t + else + list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs) + end + in + subst (SOME ~1) + end; + +fun subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls = + let + fun try_nested_rec bound_Ts y t = + AList.lookup (op =) nested_calls y + |> Option.map (fn y' => + massage_nested_rec_call lthy has_call (rewrite_map_arg get_ctr_pos) bound_Ts y y' t); + + fun subst bound_Ts (t as g' $ y) = + let + fun subst_rec () = subst bound_Ts g' $ subst bound_Ts y; + val y_head = head_of y; + in + if not (member (op =) ctr_args y_head) then + subst_rec () + else + (case try_nested_rec bound_Ts y_head t of + SOME t' => t' + | NONE => + let val (g, g_args) = strip_comb g' in + (case try (get_ctr_pos o fst o dest_Free) g of + SOME ctr_pos => + (length g_args >= ctr_pos orelse + primrec_error_eqn "too few arguments in recursive call" t; + (case AList.lookup (op =) mutual_calls y of + SOME y' => list_comb (y', g_args) + | NONE => subst_rec ())) + | NONE => subst_rec ()) + end) + end + | subst bound_Ts (Abs (v, T, b)) = Abs (v, T, subst (T :: bound_Ts) b) + | subst _ t = t + + fun subst' t = + if has_call t then + (* FIXME detect this case earlier? *) + primrec_error_eqn "recursive call not directly applied to constructor argument" t + else + try_nested_rec [] (head_of t) t |> the_default t + in + subst' o subst [] + end; + +fun build_rec_arg lthy (funs_data : eqn_data list list) has_call (ctr_spec : rec_ctr_spec) + (maybe_eqn_data : eqn_data option) = + (case maybe_eqn_data of + NONE => undef_const + | SOME {ctr_args, left_args, right_args, rhs_term = t, ...} => + let + val calls = #calls ctr_spec; + val n_args = fold (Integer.add o (fn Mutual_Rec _ => 2 | _ => 1)) calls 0; + + val no_calls' = tag_list 0 calls + |> map_filter (try (apsnd (fn No_Rec p => p | Mutual_Rec (p, _) => p))); + val mutual_calls' = tag_list 0 calls + |> map_filter (try (apsnd (fn Mutual_Rec (_, p) => p))); + val nested_calls' = tag_list 0 calls + |> map_filter (try (apsnd (fn Nested_Rec p => p))); + + val args = replicate n_args ("", dummyT) + |> Term.rename_wrt_term t + |> map Free + |> fold (fn (ctr_arg_idx, (arg_idx, _)) => + nth_map arg_idx (K (nth ctr_args ctr_arg_idx))) + no_calls' + |> fold (fn (ctr_arg_idx, (arg_idx, T)) => + nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx)))) + mutual_calls' + |> fold (fn (ctr_arg_idx, (arg_idx, T)) => + nth_map arg_idx (K (retype_free T (nth ctr_args ctr_arg_idx)))) + nested_calls'; + + val fun_name_ctr_pos_list = + map (fn (x :: _) => (#fun_name x, length (#left_args x))) funs_data; + val get_ctr_pos = try (the o AList.lookup (op =) fun_name_ctr_pos_list) #> the_default ~1; + val mutual_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) mutual_calls'; + val nested_calls = map (apfst (nth ctr_args) o apsnd (nth args o fst)) nested_calls'; + in + t + |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls + |> fold_rev lambda (args @ left_args @ right_args) + end); + +fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call = + let + val n_funs = length funs_data; + + val ctr_spec_eqn_data_list' = + (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data + |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y)) + ##> (fn x => null x orelse + primrec_error_eqns "excess equations in definition" (map #rhs_term x)) #> fst); + val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse + primrec_error_eqns ("multiple equations for constructor") (map #user_eqn x)); + + val ctr_spec_eqn_data_list = + ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); + + val recs = take n_funs rec_specs |> map #recx; + val rec_args = ctr_spec_eqn_data_list + |> sort ((op <) o pairself (#offset o fst) |> make_ord) + |> map (uncurry (build_rec_arg lthy funs_data has_call) o apsnd (try the_single)); + val ctr_poss = map (fn x => + if length (distinct ((op =) o pairself (length o #left_args)) x) <> 1 then + primrec_error ("inconstant constructor pattern position for function " ^ + quote (#fun_name (hd x))) + else + hd x |> #left_args |> length) funs_data; + in + (recs, ctr_poss) + |-> map2 (fn recx => fn ctr_pos => list_comb (recx, rec_args) |> permute_args ctr_pos) + |> Syntax.check_terms lthy + |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) + bs mxs + end; + +fun find_rec_calls has_call ({ctr, ctr_args, rhs_term, ...} : eqn_data) = + let + fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg + | find bound_Ts (t as _ $ _) ctr_arg = + let + val typof = curry fastype_of1 bound_Ts; + val (f', args') = strip_comb t; + val n = find_index (equal ctr_arg o head_of) args'; + in + if n < 0 then + find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args' + else + let + val (f, args as arg :: _) = chop n args' |>> curry list_comb f' + val (arg_head, arg_args) = Term.strip_comb arg; + in + if has_call f then + mk_partial_compN (length arg_args) (typof arg_head) f :: + maps (fn x => find bound_Ts x ctr_arg) args + else + find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args + end + end + | find _ _ _ = []; + in + map (find [] rhs_term) ctr_args + |> (fn [] => NONE | callss => SOME (ctr, callss)) + end; + +fun mk_primrec_tac ctxt num_extra_args map_idents map_comps fun_defs recx = + unfold_thms_tac ctxt fun_defs THEN + HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN + unfold_thms_tac ctxt (@{thms id_def split o_def fst_conv snd_conv} @ map_comps @ map_idents) THEN + HEADGOAL (rtac refl); + +fun prepare_primrec fixes specs lthy = + let + val (bs, mxs) = map_split (apfst fst) fixes; + val fun_names = map Binding.name_of bs; + val eqns_data = map (dissect_eqn lthy fun_names) specs; + val funs_data = eqns_data + |> partition_eq ((op =) o pairself #fun_name) + |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst + |> map (fn (x, y) => the_single y handle List.Empty => + primrec_error ("missing equations for function " ^ quote x)); + + val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); + val arg_Ts = map (#rec_type o hd) funs_data; + val res_Ts = map (#res_type o hd) funs_data; + val callssss = funs_data + |> map (partition_eq ((op =) o pairself #ctr)) + |> map (maps (map_filter (find_rec_calls has_call))); + + val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') = + rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; + + val actual_nn = length funs_data; + + val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in + map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse + primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^ + " is not a constructor in left-hand side") user_eqn) eqns_data end; + + val defs = build_defs lthy' bs mxs funs_data rec_specs has_call; + + fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) + (fun_data : eqn_data list) = + let + val def_thms = map (snd o snd) def_thms'; + val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs + |> fst + |> map_filter (try (fn (x, [y]) => + (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y))) + |> map (fn (user_eqn, num_extra_args, rec_thm) => + mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm + |> K |> Goal.prove lthy [] [] user_eqn + |> Thm.close_derivation); + val poss = find_indices (fn (x, y) => #ctr x = #ctr y) fun_data eqns_data; + in + (poss, simp_thmss) + end; + + val notes = + (if n2m then map2 (fn name => fn thm => + (name, inductN, [thm], [])) fun_names (take actual_nn induct_thms) else []) + |> map (fn (prefix, thmN, thms, attrs) => + ((Binding.qualify true prefix (Binding.name thmN), attrs), [(thms, [])])); + + val common_name = mk_common_name fun_names; + + val common_notes = + (if n2m then [(inductN, [induct_thm], [])] else []) + |> map (fn (thmN, thms, attrs) => + ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); + in + (((fun_names, defs), + fn lthy => fn defs => + split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), + lthy' |> Local_Theory.notes (notes @ common_notes) |> snd) + end; + +(* primrec definition *) + +fun add_primrec_simple fixes ts lthy = + let + val (((names, defs), prove), lthy) = prepare_primrec fixes ts lthy + handle ERROR str => primrec_error str; + in + lthy + |> fold_map Local_Theory.define defs + |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs)))) + end + handle Primrec_Error (str, eqns) => + if null eqns + then error ("primrec_new error:\n " ^ str) + else error ("primrec_new error:\n " ^ str ^ "\nin\n " ^ + space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); + +local + +fun gen_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy = + let + val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) + val _ = null d orelse primrec_error ("duplicate function name(s): " ^ commas d); + + val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy); + + val mk_notes = + flat ooo map3 (fn poss => fn prefix => fn thms => + let + val (bs, attrss) = map_split (fst o nth specs) poss; + val notes = + map3 (fn b => fn attrs => fn thm => + ((Binding.qualify false prefix b, code_nitpicksimp_simp_attrs @ attrs), [([thm], [])])) + bs attrss thms; + in + ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes + end); + in + lthy + |> add_primrec_simple fixes (map snd specs) + |-> (fn (names, (ts, (posss, simpss))) => + Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) + #> Local_Theory.notes (mk_notes posss names simpss) + #>> pair ts o map snd) + end; + +in + +val add_primrec = gen_primrec Specification.check_spec; +val add_primrec_cmd = gen_primrec Specification.read_spec; + +end; + +fun add_primrec_global fixes specs thy = + let + val lthy = Named_Target.theory_init thy; + val ((ts, simps), lthy') = add_primrec fixes specs lthy; + val simps' = burrow (Proof_Context.export lthy' lthy) simps; + in ((ts, simps'), Local_Theory.exit_global lthy') end; + +fun add_primrec_overloaded ops fixes specs thy = + let + val lthy = Overloading.overloading ops thy; + val ((ts, simps), lthy') = add_primrec fixes specs lthy; + val simps' = burrow (Proof_Context.export lthy' lthy) simps; + in ((ts, simps'), Local_Theory.exit_global lthy') end; + +end;