src/ZF/Tools/primrec_package.ML
changeset 6050 b3eb3de3a288
child 6065 3b4a29166f26
equal deleted inserted replaced
6049:7fef0169ab5e 6050:b3eb3de3a288
       
     1 (*  Title:      ZF/Tools/primrec_package.ML
       
     2     ID:         $Id$
       
     3     Author:     Stefan Berghofer and Norbert Voelker
       
     4     Copyright   1998  TU Muenchen
       
     5     ZF version by Lawrence C Paulson (Cambridge)
       
     6 
       
     7 Package for defining functions on datatypes by primitive recursion
       
     8 *)
       
     9 
       
    10 signature PRIMREC_PACKAGE =
       
    11 sig
       
    12   val add_primrec_i : (string * term) list -> theory -> theory * thm list
       
    13   val add_primrec   : (string * string) list -> theory -> theory * thm list
       
    14 end;
       
    15 
       
    16 structure PrimrecPackage : PRIMREC_PACKAGE =
       
    17 struct
       
    18 
       
    19 exception RecError of string;
       
    20 
       
    21 (* FIXME: move? *)
       
    22 
       
    23 fun dest_eq (Const ("Trueprop", _) $ (Const ("op =", _) $ lhs $ rhs)) = (lhs, rhs)
       
    24   | dest_eq t = raise TERM ("dest_eq", [t])
       
    25 
       
    26 fun primrec_err s = error ("Primrec definition error:\n" ^ s);
       
    27 
       
    28 fun primrec_eq_err sign s eq =
       
    29   primrec_err (s ^ "\nin equation\n" ^ Sign.string_of_term sign eq);
       
    30 
       
    31 (* preprocessing of equations *)
       
    32 
       
    33 (*rec_fn_opt records equations already noted for this function*)
       
    34 fun process_eqn thy (eq, rec_fn_opt) = 
       
    35   let
       
    36     val (lhs, rhs) = if null (term_vars eq) then
       
    37         dest_eq eq handle _ => raise RecError "not a proper equation"
       
    38       else raise RecError "illegal schematic variable(s)";
       
    39 
       
    40     val (recfun, args) = strip_comb lhs;
       
    41     val (fname, ftype) = dest_Const recfun handle _ => 
       
    42       raise RecError "function is not declared as constant in theory";
       
    43 
       
    44     val (ls_frees, rest)  = take_prefix is_Free args;
       
    45     val (middle, rs_frees) = take_suffix is_Free rest;
       
    46 
       
    47     val (constr, cargs_frees) = 
       
    48       if null middle then raise RecError "constructor missing"
       
    49       else strip_comb (hd middle);
       
    50     val (cname, _) = dest_Const constr
       
    51       handle _ => raise RecError "ill-formed constructor";
       
    52     val con_info = the (Symtab.lookup (ConstructorsData.get thy, cname))
       
    53       handle _ =>
       
    54       raise RecError "cannot determine datatype associated with function"
       
    55 
       
    56     val (ls, cargs, rs) = (map dest_Free ls_frees, 
       
    57 			   map dest_Free cargs_frees, 
       
    58 			   map dest_Free rs_frees)
       
    59       handle _ => raise RecError "illegal argument in pattern";
       
    60     val lfrees = ls @ rs @ cargs;
       
    61 
       
    62     (*Constructor, frees to left of pattern, pattern variables,
       
    63       frees to right of pattern, rhs of equation, full original equation. *)
       
    64     val new_eqn = (cname, (rhs, cargs, eq))
       
    65 
       
    66   in
       
    67     if not (null (duplicates lfrees)) then 
       
    68       raise RecError "repeated variable name in pattern" 
       
    69     else if not ((map dest_Free (term_frees rhs)) subset lfrees) then
       
    70       raise RecError "extra variables on rhs"
       
    71     else if length middle > 1 then 
       
    72       raise RecError "more than one non-variable in pattern"
       
    73     else case rec_fn_opt of
       
    74         None => Some (fname, ftype, ls, rs, con_info, [new_eqn])
       
    75       | Some (fname', _, ls', rs', con_info': constructor_info, eqns) => 
       
    76 	  if is_some (assoc (eqns, cname)) then
       
    77 	    raise RecError "constructor already occurred as pattern"
       
    78 	  else if (ls <> ls') orelse (rs <> rs') then
       
    79 	    raise RecError "non-recursive arguments are inconsistent"
       
    80 	  else if #big_rec_name con_info <> #big_rec_name con_info' then
       
    81 	     raise RecError ("Mixed datatypes for function " ^ fname)
       
    82 	  else if fname <> fname' then
       
    83 	     raise RecError ("inconsistent functions for datatype " ^ 
       
    84 			     #big_rec_name con_info)
       
    85 	  else Some (fname, ftype, ls, rs, con_info, new_eqn::eqns)
       
    86   end
       
    87   handle RecError s => primrec_eq_err (sign_of thy) s eq;
       
    88 
       
    89 
       
    90 (*Instantiates a recursor equation with constructor arguments*)
       
    91 fun inst_recursor ((_ $ constr, rhs), cargs') = 
       
    92     subst_atomic (#2 (strip_comb constr) ~~ map Free cargs') rhs;
       
    93 
       
    94 
       
    95 (*Convert a list of recursion equations into a recursor call*)
       
    96 fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
       
    97   let
       
    98     val fconst = Const(fname, ftype)
       
    99     val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs)
       
   100     and {big_rec_name, constructors, rec_rewrites, ...} = con_info
       
   101 
       
   102     (*Replace X_rec(args,t) by fname(ls,t,rs) *)
       
   103     fun use_fabs (_ $ t) = subst_bound (t, fabs)
       
   104       | use_fabs t       = t
       
   105 
       
   106     val cnames         = map (#1 o dest_Const) constructors
       
   107     and recursor_pairs = map (dest_eq o concl_of) rec_rewrites
       
   108 
       
   109     fun absterm (Free(a,T), body) = absfree (a,T,body)
       
   110       | absterm (t,body)          = Abs("rec", iT, abstract_over (t, body))
       
   111 
       
   112     (*Translate rec equations into function arguments suitable for recursor.
       
   113       Missing cases are replaced by 0 and all cases are put into order.*)
       
   114     fun add_case ((cname, recursor_pair), cases) =
       
   115       let val (rhs, recursor_rhs, eq) = 
       
   116 	    case assoc (eqns, cname) of
       
   117 		None => (warning ("no equation for constructor " ^ cname ^
       
   118 				  "\nin definition of function " ^ fname);
       
   119 			 (Const ("0", iT), #2 recursor_pair, Const ("0", iT)))
       
   120 	      | Some (rhs, cargs', eq) =>
       
   121 		    (rhs, inst_recursor (recursor_pair, cargs'), eq)
       
   122 	  val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs))
       
   123 	  val abs = foldr absterm (allowed_terms, rhs)
       
   124       in 
       
   125           if !Ind_Syntax.trace then
       
   126 	      writeln ("recursor_rhs = " ^ 
       
   127 		       Sign.string_of_term (sign_of thy) recursor_rhs ^
       
   128 		       "\nabs = " ^ Sign.string_of_term (sign_of thy) abs)
       
   129           else();
       
   130 	  if Logic.occs (fconst, abs) then 
       
   131 	      primrec_eq_err (sign_of thy) 
       
   132 	           ("illegal recursive occurrences of " ^ fname)
       
   133 		   eq
       
   134 	  else abs :: cases
       
   135       end
       
   136 
       
   137     val recursor = head_of (#1 (hd recursor_pairs))
       
   138 
       
   139     (** make definition **)
       
   140 
       
   141     (*the recursive argument*)
       
   142     val rec_arg = Free (variant (map #1 (ls@rs)) (Sign.base_name big_rec_name),
       
   143 			iT)
       
   144 
       
   145     val def_tm = Logic.mk_equals
       
   146 	            (subst_bound (rec_arg, fabs),
       
   147 		     list_comb (recursor,
       
   148 				foldr add_case (cnames ~~ recursor_pairs, []))
       
   149 		     $ rec_arg)
       
   150 
       
   151   in
       
   152       writeln ("def = " ^ Sign.string_of_term (sign_of thy) def_tm);
       
   153       (Sign.base_name fname ^ "_" ^ Sign.base_name big_rec_name ^ "_def",
       
   154        def_tm)
       
   155   end;
       
   156 
       
   157 
       
   158 
       
   159 (* prepare functions needed for definitions *)
       
   160 
       
   161 (*Each equation is paired with an optional name, which is "_" (ML wildcard)
       
   162   if omitted.*)
       
   163 fun add_primrec_i recursion_eqns thy =
       
   164   let
       
   165     val Some (fname, ftype, ls, rs, con_info, eqns) = 
       
   166 	foldr (process_eqn thy) (map snd recursion_eqns, None);
       
   167     val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns) 
       
   168     val thy' = thy |> Theory.add_path (Sign.base_name (#1 def))
       
   169                    |> Theory.add_defs_i [def]
       
   170     val rewrites = get_axiom thy' (#1 def) ::
       
   171 	           map mk_meta_eq (#rec_rewrites con_info)
       
   172     val _ = writeln ("Proving equations for primrec function " ^ fname);
       
   173     val char_thms = 
       
   174 	map (fn (_,t) => 
       
   175 	     prove_goalw_cterm rewrites
       
   176 	       (Ind_Syntax.traceIt "next primrec equation = "
       
   177 		(cterm_of (sign_of thy') t))
       
   178 	     (fn _ => [rtac refl 1]))
       
   179 	recursion_eqns;
       
   180     val tsimps = Attribute.tthms_of char_thms;
       
   181     val thy'' = thy' 
       
   182       |> PureThy.add_tthmss [(("simps", tsimps), [Simplifier.simp_add_global])]
       
   183       |> PureThy.add_tthms (map (rpair [])
       
   184          (filter_out (equal "_" o fst) (map fst recursion_eqns ~~ tsimps)))
       
   185       |> Theory.parent_path;
       
   186   in
       
   187     (thy'', char_thms)
       
   188   end;
       
   189 
       
   190 fun add_primrec eqns thy =
       
   191   add_primrec_i (map (apsnd (readtm (sign_of thy) propT)) eqns) thy;
       
   192 
       
   193 end;