src/HOL/Tools/primrec.ML
changeset 39812 cdee5ca9ba9e
parent 39776 cde508d2eac8
child 39813 d466bd29c887
equal deleted inserted replaced
39811:0659e84bdc5f 39812:cdee5ca9ba9e
     1 (*  Title:      HOL/Tools/primrec.ML
     1 (*  Title:      HOL/Tools/primrec.ML
     2     Author:     Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen;
     2     Author:     Norbert Voelker, FernUni Hagen
     3                 Florian Haftmann, TU Muenchen
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     4     Author:     Florian Haftmann, TU Muenchen
     5 Package for defining functions on datatypes by primitive recursion.
     5 
       
     6 Primitive recursive functions on datatypes.
     6 *)
     7 *)
     7 
     8 
     8 signature PRIMREC =
     9 signature PRIMREC =
     9 sig
    10 sig
    10   val add_primrec: (binding * typ option * mixfix) list ->
    11   val add_primrec: (binding * typ option * mixfix) list ->
    43       (fn Free (v, _) => insert (op =) v | _ => I) body []));
    44       (fn Free (v, _) => insert (op =) v | _ => I) body []));
    44     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    45     val eqn = curry subst_bounds (map2 (curry Free) vs' Ts |> rev) body;
    45     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    46     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
    46       handle TERM _ => primrec_error "not a proper equation";
    47       handle TERM _ => primrec_error "not a proper equation";
    47     val (recfun, args) = strip_comb lhs;
    48     val (recfun, args) = strip_comb lhs;
    48     val fname = case recfun of Free (v, _) => if is_fixed v then v
    49     val fname =
       
    50       (case recfun of
       
    51         Free (v, _) =>
       
    52           if is_fixed v then v
    49           else primrec_error "illegal head of function equation"
    53           else primrec_error "illegal head of function equation"
    50       | _ => primrec_error "illegal head of function equation";
    54       | _ => primrec_error "illegal head of function equation");
    51 
    55 
    52     val (ls', rest)  = take_prefix is_Free args;
    56     val (ls', rest)  = take_prefix is_Free args;
    53     val (middle, rs') = take_suffix is_Free rest;
    57     val (middle, rs') = take_suffix is_Free rest;
    54     val rpos = length ls';
    58     val rpos = length ls';
    55 
    59 
    56     val (constr, cargs') = if null middle then primrec_error "constructor missing"
    60     val (constr, cargs') =
       
    61       if null middle then primrec_error "constructor missing"
    57       else strip_comb (hd middle);
    62       else strip_comb (hd middle);
    58     val (cname, T) = dest_Const constr
    63     val (cname, T) = dest_Const constr
    59       handle TERM _ => primrec_error "ill-formed constructor";
    64       handle TERM _ => primrec_error "ill-formed constructor";
    60     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    65     val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
    61       primrec_error "cannot determine datatype associated with function"
    66       primrec_error "cannot determine datatype associated with function"
    73     else
    78     else
    74      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    79      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    75       check_vars "extra variables on rhs: "
    80       check_vars "extra variables on rhs: "
    76         (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees
    81         (map dest_Free (OldTerm.term_frees rhs) |> subtract (op =) lfrees
    77           |> filter_out (is_fixed o fst));
    82           |> filter_out (is_fixed o fst));
    78       case AList.lookup (op =) rec_fns fname of
    83       (case AList.lookup (op =) rec_fns fname of
    79         NONE =>
    84         NONE =>
    80           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns
    85           (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))])) :: rec_fns
    81       | SOME (_, rpos', eqns) =>
    86       | SOME (_, rpos', eqns) =>
    82           if AList.defined (op =) eqns cname then
    87           if AList.defined (op =) eqns cname then
    83             primrec_error "constructor already occurred as pattern"
    88             primrec_error "constructor already occurred as pattern"
    84           else if rpos <> rpos' then
    89           else if rpos <> rpos' then
    85             primrec_error "position of recursive argument inconsistent"
    90             primrec_error "position of recursive argument inconsistent"
    86           else
    91           else
    87             AList.update (op =)
    92             AList.update (op =)
    88               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns))
    93               (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn)) :: eqns))
    89               rec_fns)
    94               rec_fns))
    90   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    95   end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
    91 
    96 
    92 fun process_fun descr eqns (i, fname) (fnames, fnss) =
    97 fun process_fun descr eqns (i, fname) (fnames, fnss) =
    93   let
    98   let
    94     val (_, (tname, _, constrs)) = nth descr i;
    99     val (_, (tname, _, constrs)) = nth descr i;
   108               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   113               andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then
   109               let
   114               let
   110                 val (fname', _) = dest_Free f;
   115                 val (fname', _) = dest_Free f;
   111                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   116                 val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
   112                 val (ls, rs) = chop rpos ts
   117                 val (ls, rs) = chop rpos ts
   113                 val (x', rs') = case rs
   118                 val (x', rs') =
   114                  of x' :: rs => (x', rs)
   119                   (case rs of
   115                   | [] => primrec_error ("not enough arguments in recursive application\n"
   120                     x' :: rs => (x', rs)
   116                       ^ "of function " ^ quote fname' ^ " on rhs");
   121                   | [] => primrec_error ("not enough arguments in recursive application\n" ^
       
   122                       "of function " ^ quote fname' ^ " on rhs"));
   117                 val (x, xs) = strip_comb x';
   123                 val (x, xs) = strip_comb x';
   118               in case AList.lookup (op =) subs x
   124               in
   119                of NONE =>
   125                 (case AList.lookup (op =) subs x of
       
   126                   NONE =>
   120                     fs
   127                     fs
   121                     |> fold_map (subst subs) ts
   128                     |> fold_map (subst subs) ts
   122                     |-> (fn ts' => pair (list_comb (f, ts')))
   129                     |-> (fn ts' => pair (list_comb (f, ts')))
   123                 | SOME (i', y) =>
   130                 | SOME (i', y) =>
   124                     fs
   131                     fs
   125                     |> fold_map (subst subs) (xs @ ls @ rs')
   132                     |> fold_map (subst subs) (xs @ ls @ rs')
   126                     ||> process_fun descr eqns (i', fname')
   133                     ||> process_fun descr eqns (i', fname')
   127                     |-> (fn ts' => pair (list_comb (y, ts')))
   134                     |-> (fn ts' => pair (list_comb (y, ts'))))
   128               end
   135               end
   129             else
   136             else
   130               fs
   137               fs
   131               |> fold_map (subst subs) (f :: ts)
   138               |> fold_map (subst subs) (f :: ts)
   132               |-> (fn (f'::ts') => pair (list_comb (f', ts')))
   139               |-> (fn f' :: ts' => pair (list_comb (f', ts')))
   133           end
   140           end
   134       | subst _ t fs = (t, fs);
   141       | subst _ t fs = (t, fs);
   135 
   142 
   136     (* translate rec equations into function arguments suitable for rec comb *)
   143     (* translate rec equations into function arguments suitable for rec comb *)
   137 
   144 
   138     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   145     fun trans eqns (cname, cargs) (fnames', fnss', fns) =
   139       (case AList.lookup (op =) eqns cname of
   146       (case AList.lookup (op =) eqns cname of
   140           NONE => (warning ("No equation for constructor " ^ quote cname ^
   147         NONE => (warning ("No equation for constructor " ^ quote cname ^
   141             "\nin definition of function " ^ quote fname);
   148           "\nin definition of function " ^ quote fname);
   142               (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns))
   149             (fnames', fnss', (Const ("HOL.undefined", dummyT)) :: fns))
   143         | SOME (ls, cargs', rs, rhs, eq) =>
   150       | SOME (ls, cargs', rs, rhs, eq) =>
   144             let
   151           let
   145               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   152             val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   146               val rargs = map fst recs;
   153             val rargs = map fst recs;
   147               val subs = map (rpair dummyT o fst)
   154             val subs = map (rpair dummyT o fst)
   148                 (rev (Term.rename_wrt_term rhs rargs));
   155               (rev (Term.rename_wrt_term rhs rargs));
   149               val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   156             val (rhs', (fnames'', fnss'')) = subst (map2 (fn (x, y) => fn z =>
   150                 (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
   157               (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')
   151                   handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   158                 handle PrimrecError (s, NONE) => primrec_error_eqn s eq
   152             in (fnames'', fnss'',
   159           in (fnames'', fnss'',
   153                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   160               (list_abs_free (cargs' @ subs @ ls @ rs, rhs')) :: fns)
   154             end)
   161           end)
   155 
   162 
   156   in (case AList.lookup (op =) fnames i of
   163   in
       
   164     (case AList.lookup (op =) fnames i of
   157       NONE =>
   165       NONE =>
   158         if exists (fn (_, v) => fname = v) fnames then
   166         if exists (fn (_, v) => fname = v) fnames then
   159           primrec_error ("inconsistent functions for datatype " ^ quote tname)
   167           primrec_error ("inconsistent functions for datatype " ^ quote tname)
   160         else
   168         else
   161           let
   169           let
   162             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   170             val (_, _, eqns) = the (AList.lookup (op =) eqns fname);
   163             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   171             val (fnames', fnss', fns) = fold_rev (trans eqns) constrs
   164               ((i, fname)::fnames, fnss, [])
   172               ((i, fname) :: fnames, fnss, [])
   165           in
   173           in
   166             (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   174             (fnames', (i, (fname, #1 (snd (hd eqns)), fns)) :: fnss')
   167           end
   175           end
   168     | SOME fname' =>
   176     | SOME fname' =>
   169         if fname = fname' then (fnames, fnss)
   177         if fname = fname' then (fnames, fnss)
   170         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   178         else primrec_error ("inconsistent functions for datatype " ^ quote tname))
   171   end;
   179   end;
   172 
   180 
   173 
   181 
   174 (* prepare functions needed for definitions *)
   182 (* prepare functions needed for definitions *)
   175 
   183 
   176 fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   184 fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   177   case AList.lookup (op =) fns i of
   185   (case AList.lookup (op =) fns i of
   178      NONE =>
   186     NONE =>
   179        let
   187       let
   180          val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
   188         val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined",
   181            replicate (length cargs + length (filter is_rec_type cargs))
   189           replicate (length cargs + length (filter is_rec_type cargs))
   182              dummyT ---> HOLogic.unitT)) constrs;
   190             dummyT ---> HOLogic.unitT)) constrs;
   183          val _ = warning ("No function definition for datatype " ^ quote tname)
   191         val _ = warning ("No function definition for datatype " ^ quote tname)
   184        in
   192       in
   185          (dummy_fns @ fs, defs)
   193         (dummy_fns @ fs, defs)
   186        end
   194       end
   187    | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
   195   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs));
   188 
   196 
   189 
   197 
   190 (* make definition *)
   198 (* make definition *)
   191 
   199 
   192 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   200 fun make_def ctxt fixes fs (fname, ls, rec_name, tname) =
   201 
   209 
   202 
   210 
   203 (* find datatypes which contain all datatypes in tnames' *)
   211 (* find datatypes which contain all datatypes in tnames' *)
   204 
   212 
   205 fun find_dts (dt_info : info Symtab.table) _ [] = []
   213 fun find_dts (dt_info : info Symtab.table) _ [] = []
   206   | find_dts dt_info tnames' (tname::tnames) =
   214   | find_dts dt_info tnames' (tname :: tnames) =
   207       (case Symtab.lookup dt_info tname of
   215       (case Symtab.lookup dt_info tname of
   208           NONE => primrec_error (quote tname ^ " is not a datatype")
   216         NONE => primrec_error (quote tname ^ " is not a datatype")
   209         | SOME dt =>
   217       | SOME dt =>
   210             if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   218           if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
   211               (tname, dt)::(find_dts dt_info tnames' tnames)
   219             (tname, dt) :: (find_dts dt_info tnames' tnames)
   212             else find_dts dt_info tnames' tnames);
   220           else find_dts dt_info tnames' tnames);
   213 
   221 
   214 
   222 
   215 (* distill primitive definition(s) from primrec specification *)
   223 (* distill primitive definition(s) from primrec specification *)
   216 
   224 
   217 fun distill lthy fixes eqs = 
   225 fun distill lthy fixes eqs = 
   229     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   237     val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
   230     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   238     val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   231     val defs = map (make_def lthy fixes fs) raw_defs;
   239     val defs = map (make_def lthy fixes fs) raw_defs;
   232     val names = map snd fnames;
   240     val names = map snd fnames;
   233     val names_eqns = map fst eqns;
   241     val names_eqns = map fst eqns;
   234     val _ = if eq_set (op =) (names, names_eqns) then ()
   242     val _ =
       
   243       if eq_set (op =) (names, names_eqns) then ()
   235       else primrec_error ("functions " ^ commas_quote names_eqns ^
   244       else primrec_error ("functions " ^ commas_quote names_eqns ^
   236         "\nare not mutually recursive");
   245         "\nare not mutually recursive");
   237     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   246     val rec_rewrites' = map mk_meta_eq rec_rewrites;
   238     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   247     val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
   239     fun prove lthy defs =
   248     fun prove lthy defs =
   244         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   253         fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
   245         val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
   254         val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names);
   246       in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   255       in map (fn eq => Goal.prove lthy frees [] eq tac) eqs end;
   247   in ((prefix, (fs, defs)), prove) end
   256   in ((prefix, (fs, defs)), prove) end
   248   handle PrimrecError (msg, some_eqn) =>
   257   handle PrimrecError (msg, some_eqn) =>
   249     error ("Primrec definition error:\n" ^ msg ^ (case some_eqn
   258     error ("Primrec definition error:\n" ^ msg ^
   250      of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   259       (case some_eqn of
       
   260         SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn)
   251       | NONE => ""));
   261       | NONE => ""));
   252 
   262 
   253 
   263 
   254 (* primrec definition *)
   264 (* primrec definition *)
   255 
   265 
   304   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   314   in ((ts, simps'), Local_Theory.exit_global lthy') end;
   305 
   315 
   306 
   316 
   307 (* outer syntax *)
   317 (* outer syntax *)
   308 
   318 
   309 val primrec_decl = Parse.opt_target -- Parse.fixes -- Parse_Spec.where_alt_specs;
       
   310 
       
   311 val _ =
   319 val _ =
   312   Outer_Syntax.command "primrec" "define primitive recursive functions on datatypes"
   320   Outer_Syntax.local_theory "primrec" "define primitive recursive functions on datatypes"
   313   Keyword.thy_decl
   321     Keyword.thy_decl
   314     (primrec_decl >> (fn ((opt_target, fixes), specs) =>
   322     (Parse.fixes -- Parse_Spec.where_alt_specs
   315       Toplevel.local_theory opt_target (add_primrec_cmd fixes specs #> snd)));
   323       >> (fn (fixes, specs) => add_primrec_cmd fixes specs #> snd));
   316 
   324 
   317 end;
   325 end;