src/HOL/Tools/primrec_package.ML
changeset 16765 b8b1f310877f
parent 16646 666774b0d1b0
child 17057 0934ac31985f
equal deleted inserted replaced
16764:ca81a99c5bc1 16765:b8b1f310877f
    41 	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    41 	    HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    42 	      handle TERM _ => raise RecError "not a proper equation"
    42 	      handle TERM _ => raise RecError "not a proper equation"
    43 	else raise RecError "illegal schematic variable(s)";
    43 	else raise RecError "illegal schematic variable(s)";
    44 
    44 
    45     val (recfun, args) = strip_comb lhs;
    45     val (recfun, args) = strip_comb lhs;
    46     val (fname, _) = dest_Const recfun handle TERM _ => 
    46     val fnameT = dest_Const recfun handle TERM _ => 
    47       raise RecError "function is not declared as constant in theory";
    47       raise RecError "function is not declared as constant in theory";
    48 
    48 
    49     val (ls', rest)  = take_prefix is_Free args;
    49     val (ls', rest)  = take_prefix is_Free args;
    50     val (middle, rs') = take_suffix is_Free rest;
    50     val (middle, rs') = take_suffix is_Free rest;
    51     val rpos = length ls';
    51     val rpos = length ls';
    70       raise RecError "more than one non-variable in pattern"
    70       raise RecError "more than one non-variable in pattern"
    71     else
    71     else
    72      (check_vars "repeated variable names in pattern: " (duplicates lfrees);
    72      (check_vars "repeated variable names in pattern: " (duplicates lfrees);
    73       check_vars "extra variables on rhs: "
    73       check_vars "extra variables on rhs: "
    74         (map dest_Free (term_frees rhs) \\ lfrees);
    74         (map dest_Free (term_frees rhs) \\ lfrees);
    75       case assoc (rec_fns, fname) of
    75       case assoc (rec_fns, fnameT) of
    76         NONE =>
    76         NONE =>
    77           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
    77           (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns
    78       | SOME (_, rpos', eqns) =>
    78       | SOME (_, rpos', eqns) =>
    79           if isSome (assoc (eqns, cname)) then
    79           if isSome (assoc (eqns, cname)) then
    80             raise RecError "constructor already occurred as pattern"
    80             raise RecError "constructor already occurred as pattern"
    81           else if rpos <> rpos' then
    81           else if rpos <> rpos' then
    82             raise RecError "position of recursive argument inconsistent"
    82             raise RecError "position of recursive argument inconsistent"
    83           else
    83           else
    84             overwrite (rec_fns, 
    84             overwrite (rec_fns, 
    85 		       (fname, 
    85 		       (fnameT, 
    86 			(tname, rpos,
    86 			(tname, rpos,
    87 			 (cname, (ls, cargs, rs, rhs, eq))::eqns))))
    87 			 (cname, (ls, cargs, rs, rhs, eq))::eqns))))
    88   end
    88   end
    89   handle RecError s => primrec_eq_err sign s eq;
    89   handle RecError s => primrec_eq_err sign s eq;
    90 
    90 
    91 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) =
    91 fun process_fun sign descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
    92   let
    92   let
    93     val (_, (tname, _, constrs)) = List.nth (descr, i);
    93     val (_, (tname, _, constrs)) = List.nth (descr, i);
    94 
    94 
    95     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    95     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    96 
    96 
    99           let val (fs', t') = subst subs (fs, t)
    99           let val (fs', t') = subst subs (fs, t)
   100           in (fs', Abs (a, T, t')) end
   100           in (fs', Abs (a, T, t')) end
   101       | subst subs (fs, t as (_ $ _)) =
   101       | subst subs (fs, t as (_ $ _)) =
   102           let val (f, ts) = strip_comb t;
   102           let val (f, ts) = strip_comb t;
   103           in
   103           in
   104             if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then
   104             if is_Const f andalso dest_Const f mem map fst rec_eqns then
   105               let
   105               let
   106                 val (fname', _) = dest_Const f;
   106                 val fnameT' as (fname', _) = dest_Const f;
   107                 val (_, rpos, _) = valOf (assoc (rec_eqns, fname'));
   107                 val (_, rpos, _) = valOf (assoc (rec_eqns, fnameT'));
   108                 val ls = Library.take (rpos, ts);
   108                 val ls = Library.take (rpos, ts);
   109                 val rest = Library.drop (rpos, ts);
   109                 val rest = Library.drop (rpos, ts);
   110                 val (x', rs) = (hd rest, tl rest)
   110                 val (x', rs) = (hd rest, tl rest)
   111                   handle Empty => raise RecError ("not enough arguments\
   111                   handle Empty => raise RecError ("not enough arguments\
   112                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   112                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
   118                         val (fs', ts') = foldl_map (subst subs) (fs, ts)
   118                         val (fs', ts') = foldl_map (subst subs) (fs, ts)
   119                       in (fs', list_comb (f, ts')) end
   119                       in (fs', list_comb (f, ts')) end
   120                   | SOME (i', y) =>
   120                   | SOME (i', y) =>
   121                       let
   121                       let
   122                         val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
   122                         val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
   123                         val fs'' = process_fun sign descr rec_eqns ((i', fname'), fs')
   123                         val fs'' = process_fun sign descr rec_eqns ((i', fnameT'), fs')
   124                       in (fs'', list_comb (y, ts'))
   124                       in (fs'', list_comb (y, ts'))
   125                       end)
   125                       end)
   126               end
   126               end
   127             else
   127             else
   128               let
   128               let
   131           end
   131           end
   132       | subst _ x = x;
   132       | subst _ x = x;
   133 
   133 
   134     (* translate rec equations into function arguments suitable for rec comb *)
   134     (* translate rec equations into function arguments suitable for rec comb *)
   135 
   135 
   136     fun trans eqns ((cname, cargs), (fnames', fnss', fns)) =
   136     fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
   137       (case assoc (eqns, cname) of
   137       (case assoc (eqns, cname) of
   138           NONE => (warning ("No equation for constructor " ^ quote cname ^
   138           NONE => (warning ("No equation for constructor " ^ quote cname ^
   139             "\nin definition of function " ^ quote fname);
   139             "\nin definition of function " ^ quote fname);
   140               (fnames', fnss', (Const ("arbitrary", dummyT))::fns))
   140               (fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
   141         | SOME (ls, cargs', rs, rhs, eq) =>
   141         | SOME (ls, cargs', rs, rhs, eq) =>
   142             let
   142             let
   143               val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
   143               val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
   144               val rargs = map fst recs;
   144               val rargs = map fst recs;
   145               val subs = map (rpair dummyT o fst) 
   145               val subs = map (rpair dummyT o fst) 
   146 		             (rev (rename_wrt_term rhs rargs));
   146 		             (rev (rename_wrt_term rhs rargs));
   147               val ((fnames'', fnss''), rhs') = 
   147               val ((fnameTs'', fnss''), rhs') = 
   148 		  (subst (map (fn ((x, y), z) =>
   148 		  (subst (map (fn ((x, y), z) =>
   149 			       (Free x, (body_index y, Free z)))
   149 			       (Free x, (body_index y, Free z)))
   150 			  (recs ~~ subs))
   150 			  (recs ~~ subs))
   151 		   ((fnames', fnss'), rhs))
   151 		   ((fnameTs', fnss'), rhs))
   152                   handle RecError s => primrec_eq_err sign s eq
   152                   handle RecError s => primrec_eq_err sign s eq
   153             in (fnames'', fnss'', 
   153             in (fnameTs'', fnss'', 
   154 		(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   154 		(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   155             end)
   155             end)
   156 
   156 
   157   in (case assoc (fnames, i) of
   157   in (case assoc (fnameTs, i) of
   158       NONE =>
   158       NONE =>
   159         if exists (equal fname o snd) fnames then
   159         if exists (equal fnameT o snd) fnameTs then
   160           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   160           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   161         else
   161         else
   162           let
   162           let
   163             val (_, _, eqns) = valOf (assoc (rec_eqns, fname));
   163             val (_, _, eqns) = valOf (assoc (rec_eqns, fnameT));
   164             val (fnames', fnss', fns) = foldr (trans eqns)
   164             val (fnameTs', fnss', fns) = foldr (trans eqns)
   165               ((i, fname)::fnames, fnss, []) constrs
   165               ((i, fnameT)::fnameTs, fnss, []) constrs
   166           in
   166           in
   167             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   167             (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   168           end
   168           end
   169     | SOME fname' =>
   169     | SOME fnameT' =>
   170         if fname = fname' then (fnames, fnss)
   170         if fnameT = fnameT' then (fnameTs, fnss)
   171         else raise RecError ("inconsistent functions for datatype " ^ quote tname))
   171         else raise RecError ("inconsistent functions for datatype " ^ quote tname))
   172   end;
   172   end;
   173 
   173 
   174 
   174 
   175 (* prepare functions needed for definitions *)
   175 (* prepare functions needed for definitions *)
   239     val {descr, rec_names, rec_rewrites, ...} = 
   239     val {descr, rec_names, rec_rewrites, ...} = 
   240 	if null dts then
   240 	if null dts then
   241 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
   241 	    primrec_err ("datatypes " ^ commas_quote tnames ^ 
   242 			 "\nare not mutually recursive")
   242 			 "\nare not mutually recursive")
   243 	else snd (hd dts);
   243 	else snd (hd dts);
   244     val (fnames, fnss) = foldr (process_fun sg descr rec_eqns)
   244     val (fnameTs, fnss) = foldr (process_fun sg descr rec_eqns)
   245 	                       ([], []) main_fns;
   245 	                       ([], []) main_fns;
   246     val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
   246     val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
   247     val defs' = map (make_def sg fs) defs;
   247     val defs' = map (make_def sg fs) defs;
   248     val names1 = map snd fnames;
   248     val nameTs1 = map snd fnameTs;
   249     val names2 = map fst rec_eqns;
   249     val nameTs2 = map fst rec_eqns;
   250     val primrec_name =
   250     val primrec_name =
   251       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   251       if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name;
   252     val (thy', defs_thms') = thy |> Theory.add_path primrec_name |>
   252     val (thy', defs_thms') = thy |> Theory.add_path primrec_name |>
   253       (if eq_set (names1, names2) then (PureThy.add_defs_i false o map Thm.no_attributes) defs'
   253       (if eq_set (nameTs1, nameTs2) then (PureThy.add_defs_i false o map Thm.no_attributes) defs'
   254        else primrec_err ("functions " ^ commas_quote names2 ^
   254        else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^
   255          "\nare not mutually recursive"));
   255          "\nare not mutually recursive"));
   256     val rewrites = (map mk_meta_eq rec_rewrites) @ defs_thms';
   256     val rewrites = (map mk_meta_eq rec_rewrites) @ defs_thms';
   257     val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names1 ^ " ...");
   257     val _ = message ("Proving equations for primrec function(s) " ^
       
   258       commas_quote (map fst nameTs1) ^ " ...");
   258     val simps = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (Theory.sign_of thy') t)
   259     val simps = map (fn (_, t) => prove_goalw_cterm rewrites (cterm_of (Theory.sign_of thy') t)
   259         (fn _ => [rtac refl 1])) eqns;
   260         (fn _ => [rtac refl 1])) eqns;
   260     val (thy'', simps') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy';
   261     val (thy'', simps') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy';
   261     val thy''' = thy''
   262     val thy''' = thy''
   262       |> (#1 o PureThy.add_thmss [(("simps", simps'),
   263       |> (#1 o PureThy.add_thmss [(("simps", simps'),