src/HOL/Tools/primrec_package.ML
changeset 22692 1e057a3f087d
parent 22480 b20bc8029edb
child 22728 ecbbdf50df2f
equal deleted inserted replaced
22691:290454649b8c 22692:1e057a3f087d
    14     -> theory -> thm list * theory
    14     -> theory -> thm list * theory
    15   val add_primrec_i: string -> ((bstring * term) * attribute list) list
    15   val add_primrec_i: string -> ((bstring * term) * attribute list) list
    16     -> theory -> thm list * theory
    16     -> theory -> thm list * theory
    17   val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
    17   val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
    18     -> theory -> thm list * theory
    18     -> theory -> thm list * theory
       
    19   (* FIXME !? *)
    19   val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory)
    20   val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory)
    20     -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory)
    21     -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory)
    21     -> string -> ((bstring * attribute list) * term) list
    22     -> string -> ((bstring * attribute list) * term) list
    22     -> theory -> thm list * theory;
    23     -> theory -> thm list * theory;
    23 end;
    24 end;
    40 fun message s = if ! quiet_mode then () else writeln s;
    41 fun message s = if ! quiet_mode then () else writeln s;
    41 
    42 
    42 
    43 
    43 (* preprocessing of equations *)
    44 (* preprocessing of equations *)
    44 
    45 
    45 fun process_eqn thy eq rec_fns = 
    46 fun process_eqn thy eq rec_fns =
    46   let
    47   let
    47     val (lhs, rhs) = 
    48     val (lhs, rhs) =
    48       if null (term_vars eq) then
    49       if null (term_vars eq) then
    49         HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    50         HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
    50         handle TERM _ => raise RecError "not a proper equation"
    51         handle TERM _ => raise RecError "not a proper equation"
    51       else raise RecError "illegal schematic variable(s)";
    52       else raise RecError "illegal schematic variable(s)";
    52 
    53 
    53     val (recfun, args) = strip_comb lhs;
    54     val (recfun, args) = strip_comb lhs;
    54     val fnameT = dest_Const recfun handle TERM _ => 
    55     val fnameT = dest_Const recfun handle TERM _ =>
    55       raise RecError "function is not declared as constant in theory";
    56       raise RecError "function is not declared as constant in theory";
    56 
    57 
    57     val (ls', rest)  = take_prefix is_Free args;
    58     val (ls', rest)  = take_prefix is_Free args;
    58     val (middle, rs') = take_suffix is_Free rest;
    59     val (middle, rs') = take_suffix is_Free rest;
    59     val rpos = length ls';
    60     val rpos = length ls';
    71     val lfrees = ls @ rs @ cargs;
    72     val lfrees = ls @ rs @ cargs;
    72 
    73 
    73     fun check_vars _ [] = ()
    74     fun check_vars _ [] = ()
    74       | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
    75       | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars))
    75   in
    76   in
    76     if length middle > 1 then 
    77     if length middle > 1 then
    77       raise RecError "more than one non-variable in pattern"
    78       raise RecError "more than one non-variable in pattern"
    78     else
    79     else
    79      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    80      (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    80       check_vars "extra variables on rhs: "
    81       check_vars "extra variables on rhs: "
    81         (map dest_Free (term_frees rhs) \\ lfrees);
    82         (map dest_Free (term_frees rhs) \\ lfrees);
   145               (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns))
   146               (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns))
   146         | SOME (ls, cargs', rs, rhs, eq) =>
   147         | SOME (ls, cargs', rs, rhs, eq) =>
   147             let
   148             let
   148               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   149               val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
   149               val rargs = map fst recs;
   150               val rargs = map fst recs;
   150               val subs = map (rpair dummyT o fst) 
   151               val subs = map (rpair dummyT o fst)
   151                 (rev (rename_wrt_term rhs rargs));
   152                 (rev (rename_wrt_term rhs rargs));
   152               val (rhs', (fnameTs'', fnss'')) = 
   153               val (rhs', (fnameTs'', fnss'')) =
   153                   (subst (map (fn ((x, y), z) =>
   154                   (subst (map (fn ((x, y), z) =>
   154                                (Free x, (body_index y, Free z)))
   155                                (Free x, (body_index y, Free z)))
   155                           (recs ~~ subs)) rhs (fnameTs', fnss'))
   156                           (recs ~~ subs)) rhs (fnameTs', fnss'))
   156                   handle RecError s => primrec_eq_err thy s eq
   157                   handle RecError s => primrec_eq_err thy s eq
   157             in (fnameTs'', fnss'', 
   158             in (fnameTs'', fnss'',
   158                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   159                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   159             end)
   160             end)
   160 
   161 
   161   in (case AList.lookup (op =) fnameTs i of
   162   in (case AList.lookup (op =) fnameTs i of
   162       NONE =>
   163       NONE =>
   164           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   165           raise RecError ("inconsistent functions for datatype " ^ quote tname)
   165         else
   166         else
   166           let
   167           let
   167             val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
   168             val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
   168             val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
   169             val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
   169               ((i, fnameT)::fnameTs, fnss, []) 
   170               ((i, fnameT)::fnameTs, fnss, [])
   170           in
   171           in
   171             (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   172             (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   172           end
   173           end
   173     | SOME fnameT' =>
   174     | SOME fnameT' =>
   174         if fnameT = fnameT' then (fnameTs, fnss)
   175         if fnameT = fnameT' then (fnameTs, fnss)
   198   let
   199   let
   199     val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   200     val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   200                     ((map snd ls) @ [dummyT])
   201                     ((map snd ls) @ [dummyT])
   201                     (list_comb (Const (rec_name, dummyT),
   202                     (list_comb (Const (rec_name, dummyT),
   202                                 fs @ map Bound (0 ::(length ls downto 1))))
   203                                 fs @ map Bound (0 ::(length ls downto 1))))
   203     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   204     val def_name = Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def";
   204                    Logic.mk_equals (Const (fname, dummyT), rhs))
   205     val def_prop =
   205   in Theory.inferT_axm thy defpair end;
   206       singleton (ProofContext.infer_types (ProofContext.init thy))
       
   207         (Logic.mk_equals (Const (fname, dummyT), rhs), propT) |> #1;
       
   208   in (def_name, def_prop) end;
   206 
   209 
   207 
   210 
   208 (* find datatypes which contain all datatypes in tnames' *)
   211 (* find datatypes which contain all datatypes in tnames' *)
   209 
   212 
   210 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   213 fun find_dts (dt_info : datatype_info Symtab.table) _ [] = []
   234     val (eqns, atts) = split_list eqns_atts;
   237     val (eqns, atts) = split_list eqns_atts;
   235     val dt_info = DatatypePackage.get_datatypes thy;
   238     val dt_info = DatatypePackage.get_datatypes thy;
   236     val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
   239     val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
   237     val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
   240     val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
   238     val dts = find_dts dt_info tnames tnames;
   241     val dts = find_dts dt_info tnames tnames;
   239     val main_fns = 
   242     val main_fns =
   240       map (fn (tname, {index, ...}) =>
   243       map (fn (tname, {index, ...}) =>
   241         (index, 
   244         (index,
   242           (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns))
   245           (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns))
   243       dts;
   246       dts;
   244     val {descr, rec_names, rec_rewrites, ...} = 
   247     val {descr, rec_names, rec_rewrites, ...} =
   245       if null dts then
   248       if null dts then
   246         primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   249         primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   247       else snd (hd dts);
   250       else snd (hd dts);
   248     val (fnameTs, fnss) =
   251     val (fnameTs, fnss) =
   249       fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
   252       fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
   306 val add_primrec_i = gen_primrec_i thy_note (thy_def false);
   309 val add_primrec_i = gen_primrec_i thy_note (thy_def false);
   307 val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);
   310 val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true);
   308 fun gen_primrec note def alt_name specs =
   311 fun gen_primrec note def alt_name specs =
   309   gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs);
   312   gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs);
   310 
   313 
   311 end; (*local*)
   314 end;
   312 
   315 
   313 
   316 
   314 (* outer syntax *)
   317 (* outer syntax *)
   315 
   318 
   316 local structure P = OuterParse and K = OuterKeyword in
   319 local structure P = OuterParse and K = OuterKeyword in
   332 
   335 
   333 val _ = OuterSyntax.add_parsers [primrecP];
   336 val _ = OuterSyntax.add_parsers [primrecP];
   334 
   337 
   335 end;
   338 end;
   336 
   339 
   337 
       
   338 end;
   340 end;
   339