src/HOL/Tools/primrec_package.ML
changeset 15570 8d8c70b41bab
parent 15531 08c8dad8e399
child 15574 b1d1b5bfc464
equal deleted inserted replaced
15569:1b3115d1a8df 15570:8d8c70b41bab
    74         (map dest_Free (term_frees rhs) \\ lfrees);
    74         (map dest_Free (term_frees rhs) \\ lfrees);
    75       case assoc (rec_fns, fname) of
    75       case assoc (rec_fns, fname) of
    76         NONE =>
    76         NONE =>
    77           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
    77           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
    78       | SOME (_, rpos', eqns) =>
    78       | SOME (_, rpos', eqns) =>
    79           if is_some (assoc (eqns, cname)) then
    79           if isSome (assoc (eqns, cname)) then
    80             raise RecError "constructor already occurred as pattern"
    80             raise RecError "constructor already occurred as pattern"
    81           else if rpos <> rpos' then
    81           else if rpos <> rpos' then
    82             raise RecError "position of recursive argument inconsistent"
    82             raise RecError "position of recursive argument inconsistent"
    83           else
    83           else
    84             overwrite (rec_fns, 
    84             overwrite (rec_fns, 
    88   end
    88   end
    89   handle RecError s => primrec_eq_err sign s eq;
    89   handle RecError s => primrec_eq_err sign s eq;
    90 
    90 
    91 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
    91 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
    92   let
    92   let
    93     val (_, (tname, _, constrs)) = nth_elem (i, descr);
    93     val (_, (tname, _, constrs)) = List.nth (descr, i);
    94 
    94 
    95     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    95     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    96 
    96 
    97     fun subst [] x = x
    97     fun subst [] x = x
    98       | subst subs (fs, Abs (a, T, t)) =
    98       | subst subs (fs, Abs (a, T, t)) =
   102           let val (f, ts) = strip_comb t;
   102           let val (f, ts) = strip_comb t;
   103           in
   103           in
   104             if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
   104             if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
   105               let
   105               let
   106                 val (fname', _) = dest_Const f;
   106                 val (fname', _) = dest_Const f;
   107                 val (_, rpos, _) = the (assoc (rec_eqns, fname'));
   107                 val (_, rpos, _) = valOf (assoc (rec_eqns, fname'));
   108                 val ls = take (rpos, ts);
   108                 val ls = Library.take (rpos, ts);
   109                 val rest = drop (rpos, ts);
   109                 val rest = Library.drop (rpos, ts);
   110                 val (x', rs) = (hd rest, tl rest)
   110                 val (x', rs) = (hd rest, tl rest)
   111                   handle LIST _ => raise RecError ("not enough arguments\
   111                   handle Empty => raise RecError ("not enough arguments\
   112                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   112                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   113                 val (x, xs) = strip_comb x'
   113                 val (x, xs) = strip_comb x'
   114               in 
   114               in 
   115                 (case assoc (subs, x) of
   115                 (case assoc (subs, x) of
   116                     NONE =>
   116                     NONE =>
   138           NONE => (warning ("No equation for constructor " ^ quote cname ^
   138           NONE => (warning ("No equation for constructor " ^ quote cname ^
   139             "\nin definition of function " ^ quote fname);
   139             "\nin definition of function " ^ quote fname);
   140               (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
   140               (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
   141         | SOME (ls, cargs', rs, rhs, eq) =>
   141         | SOME (ls, cargs', rs, rhs, eq) =>
   142             let
   142             let
   143               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   143               val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
   144               val rargs = map fst recs;
   144               val rargs = map fst recs;
   145               val subs = map (rpair dummyT o fst) 
   145               val subs = map (rpair dummyT o fst) 
   146 		             (rev (rename_wrt_term rhs rargs));
   146 		             (rev (rename_wrt_term rhs rargs));
   147               val ((fnames'', fnss''), rhs') = 
   147               val ((fnames'', fnss''), rhs') = 
   148 		  (subst (map (fn ((x, y), z) =>
   148 		  (subst (map (fn ((x, y), z) =>
   158       NONE =>
   158       NONE =>
   159         if exists (equal fname o snd) fnames then
   159         if exists (equal fname o snd) fnames then
   160           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   160           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   161         else
   161         else
   162           let
   162           let
   163             val (_, _, eqns) = the (assoc (rec_eqns, fname));
   163             val (_, _, eqns) = valOf (assoc (rec_eqns, fname));
   164             val (fnames', fnss', fns) = foldr (trans eqns)
   164             val (fnames', fnss', fns) = Library.foldr (trans eqns)
   165               (constrs, ((i, fname)::fnames, fnss, []))
   165               (constrs, ((i, fname)::fnames, fnss, []))
   166           in
   166           in
   167             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   167             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   168           end
   168           end
   169     | SOME fname' =>
   169     | SOME fname' =>
   177 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
   177 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
   178   case assoc (fns, i) of
   178   case assoc (fns, i) of
   179      NONE =>
   179      NONE =>
   180        let
   180        let
   181          val dummy_fns = map (fn (_, cargs) => Const ("arbitrary",
   181          val dummy_fns = map (fn (_, cargs) => Const ("arbitrary",
   182            replicate ((length cargs) + (length (filter is_rec_type cargs)))
   182            replicate ((length cargs) + (length (List.filter is_rec_type cargs)))
   183              dummyT ---> HOLogic.unitT)) constrs;
   183              dummyT ---> HOLogic.unitT)) constrs;
   184          val _ = warning ("No function definition for datatype " ^ quote tname)
   184          val _ = warning ("No function definition for datatype " ^ quote tname)
   185        in
   185        in
   186          (dummy_fns @ fs, defs)
   186          (dummy_fns @ fs, defs)
   187        end
   187        end
   190 
   190 
   191 (* make definition *)
   191 (* make definition *)
   192 
   192 
   193 fun make_def sign fs (fname, ls, rec_name, tname) =
   193 fun make_def sign fs (fname, ls, rec_name, tname) =
   194   let
   194   let
   195     val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
   195     val rhs = Library.foldr (fn (T, t) => Abs ("", T, t)) 
   196 	            ((map snd ls) @ [dummyT],
   196 	            ((map snd ls) @ [dummyT],
   197 		     list_comb (Const (rec_name, dummyT),
   197 		     list_comb (Const (rec_name, dummyT),
   198 				fs @ map Bound (0 ::(length ls downto 1))));
   198 				fs @ map Bound (0 ::(length ls downto 1))));
   199     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   199     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   200 		   Logic.mk_equals (Const (fname, dummyT), rhs))
   200 		   Logic.mk_equals (Const (fname, dummyT), rhs))
   214 
   214 
   215 fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns =
   215 fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns =
   216   let
   216   let
   217     fun constrs_of (_, (_, _, cs)) =
   217     fun constrs_of (_, (_, _, cs)) =
   218       map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   218       map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs;
   219     val params_of = Library.assocs (flat (map constrs_of rec_eqns));
   219     val params_of = Library.assocs (List.concat (map constrs_of rec_eqns));
   220   in
   220   in
   221     induction
   221     induction
   222     |> RuleCases.rename_params (map params_of (flat (map (map #1 o #3 o #2) descr)))
   222     |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr)))
   223     |> RuleCases.save induction
   223     |> RuleCases.save induction
   224   end;
   224   end;
   225 
   225 
   226 fun add_primrec_i alt_name eqns_atts thy =
   226 fun add_primrec_i alt_name eqns_atts thy =
   227   let
   227   let
   228     val (eqns, atts) = split_list eqns_atts;
   228     val (eqns, atts) = split_list eqns_atts;
   229     val sg = Theory.sign_of thy;
   229     val sg = Theory.sign_of thy;
   230     val dt_info = DatatypePackage.get_datatypes thy;
   230     val dt_info = DatatypePackage.get_datatypes thy;
   231     val rec_eqns = foldr (process_eqn sg) (map snd eqns, []);
   231     val rec_eqns = Library.foldr (process_eqn sg) (map snd eqns, []);
   232     val tnames = distinct (map (#1 o snd) rec_eqns);
   232     val tnames = distinct (map (#1 o snd) rec_eqns);
   233     val dts = find_dts dt_info tnames tnames;
   233     val dts = find_dts dt_info tnames tnames;
   234     val main_fns = 
   234     val main_fns = 
   235 	map (fn (tname, {index, ...}) =>
   235 	map (fn (tname, {index, ...}) =>
   236 	     (index, 
   236 	     (index, 
   237 	      fst (the (find_first (fn f => #1 (snd f) = tname) rec_eqns))))
   237 	      fst (valOf (find_first (fn f => #1 (snd f) = tname) rec_eqns))))
   238 	dts;
   238 	dts;
   239     val {descr, rec_names, rec_rewrites, ...} = 
   239     val {descr, rec_names, rec_rewrites, ...} = 
   240 	if null dts then
   240 	if null dts then
   241 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
   241 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
   242 			 "\nare not mutually recursive")
   242 			 "\nare not mutually recursive")
   243 	else snd (hd dts);
   243 	else snd (hd dts);
   244     val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
   244     val (fnames, fnss) = Library.foldr (process_fun sg descr rec_eqns)
   245 	                       (main_fns, ([], []));
   245 	                       (main_fns, ([], []));
   246     val (fs, defs) = foldr (get_fns fnss) (descr ~~ rec_names, ([], []));
   246     val (fs, defs) = Library.foldr (get_fns fnss) (descr ~~ rec_names, ([], []));
   247     val defs' = map (make_def sg fs) defs;
   247     val defs' = map (make_def sg fs) defs;
   248     val names1 = map snd fnames;
   248     val names1 = map snd fnames;
   249     val names2 = map fst rec_eqns;
   249     val names2 = map fst rec_eqns;
   250     val primrec_name =
   250     val primrec_name =
   251       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   251       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;