simplify fixrec pattern match function
authorhuffman
Tue, 19 Oct 2010 07:05:04 -0700
changeset 40036 a81758e0394d
parent 40035 a12d35795cb9
child 40037 81e6b89d8f58
simplify fixrec pattern match function
src/HOLCF/Tools/fixrec.ML
--- a/src/HOLCF/Tools/fixrec.ML	Sun Oct 17 09:53:47 2010 -0700
+++ b/src/HOLCF/Tools/fixrec.ML	Tue Oct 19 07:05:04 2010 -0700
@@ -247,15 +247,15 @@
   end;
 
 (* builds a monadic term for matching a function definition pattern *)
-(* returns (name, arity, matcher) *)
+(* returns (constant, (vars, matcher)) *)
 fun compile_lhs match_name pat rhs vs taken =
   case pat of
     Const(@{const_name Rep_CFun}, _) $ f $ x =>
       let val (rhs', v, taken') = compile_pat match_name x rhs taken;
       in compile_lhs match_name f rhs' (v::vs) taken' end
-  | Free(_,_) => ((pat, length vs), big_lambdas vs rhs)
-  | Const(_,_) => ((pat, length vs), big_lambdas vs rhs)
-  | _ => fixrec_err ("function is not declared as constant in theory: "
+  | Free(_,_) => (pat, (vs, rhs))
+  | Const(_,_) => (pat, (vs, rhs))
+  | _ => fixrec_err ("invalid function pattern: "
                     ^ ML_Syntax.print_term pat);
 
 fun strip_alls t =
@@ -268,41 +268,24 @@
     compile_lhs match_name lhs (mk_succeed rhs) [] (taken_names eq)
   end;
 
-(* returns the sum (using +++) of the terms in ms *)
-(* also applies "run" to the result! *)
-fun fatbar arity ms =
-  let
-    fun LAM_Ts 0 t = ([], Term.fastype_of t)
-      | LAM_Ts n (_ $ Abs(_,T,t)) =
-          let val (Ts, U) = LAM_Ts (n-1) t in (T::Ts, U) end
-      | LAM_Ts _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
-    fun unLAM 0 t = t
-      | unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t
-      | unLAM _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
-    fun reLAM ([], U) t = t
-      | reLAM (T::Ts, U) t = reLAM (Ts, T ->> U) (cabs_const(T,U)$Abs("",T,t));
-    val msum = foldr1 mk_mplus (map (unLAM arity) ms);
-    val (Ts, U) = LAM_Ts arity (hd ms)
-  in
-    reLAM (rev Ts, dest_matchT U) (mk_run msum)
-  end;
-
 (* this is the pattern-matching compiler function *)
 fun compile_eqs match_name eqs =
   let
-    val ((names, arities), mats) =
-      apfst ListPair.unzip (ListPair.unzip (map (compile_eq match_name) eqs));
-    val cname =
-        case distinct (op =) names of
+    val (consts, matchers) =
+      ListPair.unzip (map (compile_eq match_name) eqs);
+    val const =
+        case distinct (op =) consts of
           [n] => n
         | _ => fixrec_err "all equations in block must define the same function";
-    val arity =
-        case distinct (op =) arities of
-          [a] => a
+    val vars =
+        case distinct (op = o pairself length) (map fst matchers) of
+          [vars] => vars
         | _ => fixrec_err "all equations in block must have the same arity";
-    val rhs = fatbar arity mats;
+    (* rename so all matchers use same free variables *)
+    fun rename (vs, t) = Term.subst_free (filter_out (op =) (vs ~~ vars)) t;
+    val rhs = big_lambdas vars (mk_run (foldr1 mk_mplus (map rename matchers)));
   in
-    mk_trp (cname === rhs)
+    mk_trp (const === rhs)
   end;
 
 (*************************************************************************)