src/HOLCF/Tools/fixrec.ML
changeset 40036 a81758e0394d
parent 39806 d59b9531d6b0
child 40041 1f09b4c7b85e
equal deleted inserted replaced
40035:a12d35795cb9 40036:a81758e0394d
   245   in
   245   in
   246     comp_pat pat rhs taken
   246     comp_pat pat rhs taken
   247   end;
   247   end;
   248 
   248 
   249 (* builds a monadic term for matching a function definition pattern *)
   249 (* builds a monadic term for matching a function definition pattern *)
   250 (* returns (name, arity, matcher) *)
   250 (* returns (constant, (vars, matcher)) *)
   251 fun compile_lhs match_name pat rhs vs taken =
   251 fun compile_lhs match_name pat rhs vs taken =
   252   case pat of
   252   case pat of
   253     Const(@{const_name Rep_CFun}, _) $ f $ x =>
   253     Const(@{const_name Rep_CFun}, _) $ f $ x =>
   254       let val (rhs', v, taken') = compile_pat match_name x rhs taken;
   254       let val (rhs', v, taken') = compile_pat match_name x rhs taken;
   255       in compile_lhs match_name f rhs' (v::vs) taken' end
   255       in compile_lhs match_name f rhs' (v::vs) taken' end
   256   | Free(_,_) => ((pat, length vs), big_lambdas vs rhs)
   256   | Free(_,_) => (pat, (vs, rhs))
   257   | Const(_,_) => ((pat, length vs), big_lambdas vs rhs)
   257   | Const(_,_) => (pat, (vs, rhs))
   258   | _ => fixrec_err ("function is not declared as constant in theory: "
   258   | _ => fixrec_err ("invalid function pattern: "
   259                     ^ ML_Syntax.print_term pat);
   259                     ^ ML_Syntax.print_term pat);
   260 
   260 
   261 fun strip_alls t =
   261 fun strip_alls t =
   262   if Logic.is_all t then strip_alls (snd (Logic.dest_all t)) else t;
   262   if Logic.is_all t then strip_alls (snd (Logic.dest_all t)) else t;
   263 
   263 
   266     val (lhs,rhs) = dest_eqs (Logic.strip_imp_concl (strip_alls eq));
   266     val (lhs,rhs) = dest_eqs (Logic.strip_imp_concl (strip_alls eq));
   267   in
   267   in
   268     compile_lhs match_name lhs (mk_succeed rhs) [] (taken_names eq)
   268     compile_lhs match_name lhs (mk_succeed rhs) [] (taken_names eq)
   269   end;
   269   end;
   270 
   270 
   271 (* returns the sum (using +++) of the terms in ms *)
       
   272 (* also applies "run" to the result! *)
       
   273 fun fatbar arity ms =
       
   274   let
       
   275     fun LAM_Ts 0 t = ([], Term.fastype_of t)
       
   276       | LAM_Ts n (_ $ Abs(_,T,t)) =
       
   277           let val (Ts, U) = LAM_Ts (n-1) t in (T::Ts, U) end
       
   278       | LAM_Ts _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
       
   279     fun unLAM 0 t = t
       
   280       | unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t
       
   281       | unLAM _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
       
   282     fun reLAM ([], U) t = t
       
   283       | reLAM (T::Ts, U) t = reLAM (Ts, T ->> U) (cabs_const(T,U)$Abs("",T,t));
       
   284     val msum = foldr1 mk_mplus (map (unLAM arity) ms);
       
   285     val (Ts, U) = LAM_Ts arity (hd ms)
       
   286   in
       
   287     reLAM (rev Ts, dest_matchT U) (mk_run msum)
       
   288   end;
       
   289 
       
   290 (* this is the pattern-matching compiler function *)
   271 (* this is the pattern-matching compiler function *)
   291 fun compile_eqs match_name eqs =
   272 fun compile_eqs match_name eqs =
   292   let
   273   let
   293     val ((names, arities), mats) =
   274     val (consts, matchers) =
   294       apfst ListPair.unzip (ListPair.unzip (map (compile_eq match_name) eqs));
   275       ListPair.unzip (map (compile_eq match_name) eqs);
   295     val cname =
   276     val const =
   296         case distinct (op =) names of
   277         case distinct (op =) consts of
   297           [n] => n
   278           [n] => n
   298         | _ => fixrec_err "all equations in block must define the same function";
   279         | _ => fixrec_err "all equations in block must define the same function";
   299     val arity =
   280     val vars =
   300         case distinct (op =) arities of
   281         case distinct (op = o pairself length) (map fst matchers) of
   301           [a] => a
   282           [vars] => vars
   302         | _ => fixrec_err "all equations in block must have the same arity";
   283         | _ => fixrec_err "all equations in block must have the same arity";
   303     val rhs = fatbar arity mats;
   284     (* rename so all matchers use same free variables *)
   304   in
   285     fun rename (vs, t) = Term.subst_free (filter_out (op =) (vs ~~ vars)) t;
   305     mk_trp (cname === rhs)
   286     val rhs = big_lambdas vars (mk_run (foldr1 mk_mplus (map rename matchers)));
       
   287   in
       
   288     mk_trp (const === rhs)
   306   end;
   289   end;
   307 
   290 
   308 (*************************************************************************)
   291 (*************************************************************************)
   309 (********************** Proving associated theorems **********************)
   292 (********************** Proving associated theorems **********************)
   310 (*************************************************************************)
   293 (*************************************************************************)