src/HOLCF/fixrec_package.ML
changeset 16387 67f6044c1891
parent 16226 c17ac524d866
child 16401 57c35ede00b9
equal deleted inserted replaced
16386:c6f5ade29608 16387:67f6044c1891
    25 (* extern_name is taken from domain/library.ML *)
    25 (* extern_name is taken from domain/library.ML *)
    26 fun extern_name con = case Symbol.explode con of 
    26 fun extern_name con = case Symbol.explode con of 
    27 		   ("o"::"p"::" "::rest) => implode rest
    27 		   ("o"::"p"::" "::rest) => implode rest
    28 		   | _ => con;
    28 		   | _ => con;
    29 
    29 
    30 (***************  This is the building functional **************)
       
    31 
       
    32 (* reads a string proposition and converts it to a term *)
       
    33 fun all eqs thy = (let val sign = sign_of (thy) in 
       
    34                   ((term_of o (Thm.read_cterm sign) o rpair propT) eqs) end);
       
    35 
       
    36 (* splits an equation term into its left- and right-hand sides *)
    30 (* splits an equation term into its left- and right-hand sides *)
    37 fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
    31 fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
    38   | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
    32   | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs))    = (lhs,rhs)
    39   | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
    33   | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
    40 
    34 
    41 (* building fixpoint functional def for an equation with only 
    35 (* these are helpful functions copied from HOLCF/domain/library.ML *)
    42    variables as parameters *)
    36 fun %: s = Free(s,dummyT);
       
    37 fun %%: s = Const(s,dummyT);
       
    38 infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
       
    39 infix 1 ===; fun S === T = %%:"op =" $ S $ T;
       
    40 infix 9 `  ; fun f ` x = %%:"Rep_CFun" $ f $ x;
       
    41 
       
    42 (* infers the type of a term *)
       
    43 fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
       
    44 
       
    45 (*************************************************************************)
       
    46 (************ fixed-point definitions and unfolding theorems *************)
       
    47 (*************************************************************************)
       
    48 
    43 fun func1 (lhs as Const(name,T), rhs) =
    49 fun func1 (lhs as Const(name,T), rhs) =
    44   let
    50   let
    45     val basename = Sign.base_name name;
    51     val basename = Sign.base_name name;
    46     val funcT = T ->> T;
    52     val funcT = T ->> T;
    47     val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
    53     val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
    53   in
    59   in
    54     (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
    60     (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
    55   end
    61   end
    56   | func1 t = sys_error "func1: not of correct form";
    62   | func1 t = sys_error "func1: not of correct form";
    57 
    63 
    58 (**************************************************************)
    64 (*************************************************************************)
    59 (* these are helpful functions copied from HOLCF/domain/library.ML *)
    65 (*********** monadic notation and pattern matching compilation ***********)
    60 fun %: s = Free(s,dummyT);
    66 (*************************************************************************)
    61 fun %%: s = Const(s,dummyT);
       
    62 infix 0 ==;  fun S ==  T = %%:"==" $ S $ T;
       
    63 infix 1 ===; fun S === T = %%:"op =" $ S $ T;
       
    64 infix 9 `  ; fun f`  x = %%:"Rep_CFun" $ f $ x;
       
    65 
       
    66 (* infers the type of a term *)
       
    67 fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
       
    68 
       
    69 (*****************************************************************)
       
    70 (* monadic notation and pattern matching *)
       
    71 (*****************************************************************)
       
    72 
    67 
    73 (* these 3 functions strip off parameters and destruct constructors *)
    68 (* these 3 functions strip off parameters and destruct constructors *)
       
    69 (*
    74 fun strip_cpair (Const("Cfun.Rep_CFun",_) $
    70 fun strip_cpair (Const("Cfun.Rep_CFun",_) $
    75       (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
    71       (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
    76         b :: strip_cpair r
    72         b :: strip_cpair r
    77   | strip_cpair c = [c];
    73   | strip_cpair c = [c];
    78 
    74 *)
    79 fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
    75 fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
    80 
    76 
    81 fun big_lambdas [] rhs = rhs
    77 fun big_lambdas [] rhs = rhs
    82   | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
    78   | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
    83 
    79 
    93   | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
    89   | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
    94   | add_names (_, bs) = bs;
    90   | add_names (_, bs) = bs;
    95 fun add_terms ts xs = foldr add_names xs ts;
    91 fun add_terms ts xs = foldr add_names xs ts;
    96 
    92 
    97 (* builds a monadic term for matching a constructor pattern *)
    93 (* builds a monadic term for matching a constructor pattern *)
    98 fun pre_build (Const("Cfun.Rep_CFun",_)$f$(v as Free(n,T))) rhs vs taken =
    94 fun pre_build pat rhs vs taken =
       
    95   case pat of
       
    96     Const("Cfun.Rep_CFun",_)$f$(v as Free(n,T)) =>
    99       pre_build f rhs (v::vs) taken
    97       pre_build f rhs (v::vs) taken
   100   | pre_build (Const("Cfun.Rep_CFun",_)$f$x) rhs vs taken =
    98   | Const("Cfun.Rep_CFun",_)$f$x =>
   101       let val (rhs', v, taken') = pre_build x rhs [] taken;
    99       let val (rhs', v, taken') = pre_build x rhs [] taken;
   102       in
   100       in pre_build f rhs' (v::vs) taken' end
   103         pre_build f rhs' (v::vs) taken'
   101   | Const(c,T) =>
   104       end
       
   105   | pre_build (Const(c,T)) rhs vs taken =
       
   106       let
   102       let
   107         val n = variant taken "v";
   103         val n = variant taken "v";
   108         fun result_type (Type(_,[_,T])) (x::xs) = result_type T xs
   104         fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs
   109           | result_type T _ = T;
   105           | result_type T _ = T;
   110         val v = Free(n, result_type T vs);
   106         val v = Free(n, result_type T vs);
   111         val m = "match_"^(extern_name(NameSpace.base c));
   107         val m = "match_"^(extern_name(NameSpace.base c));
   112         val k = lambda_tuple vs rhs;
   108         val k = lambda_tuple vs rhs;
   113       in
   109       in
   114         (%%:"bind"`(%%:m`v)`k, v, n::taken)
   110         (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
   115       end;
   111       end;
   116 
   112 
   117 (* builds a monadic term for matching a function definition pattern *)
   113 (* builds a monadic term for matching a function definition pattern *)
   118 (* returns (name, arity, matcher) *)
   114 (* returns (name, arity, matcher) *)
   119 fun building (Const("Cfun.Rep_CFun", _)$f$(v as Free(n,T))) rhs vs taken =
   115 fun building pat rhs vs taken =
       
   116   case pat of
       
   117     Const("Cfun.Rep_CFun", _)$f$(v as Free(n,T)) =>
   120       building f rhs (v::vs) taken
   118       building f rhs (v::vs) taken
   121   | building (Const("Cfun.Rep_CFun", _)$f$x) rhs vs taken =
   119   | Const("Cfun.Rep_CFun", _)$f$x =>
   122       let
   120       let val (rhs', v, taken') = pre_build x rhs [] taken;
   123         val (rhs', v, taken') = pre_build x rhs [] taken;
   121       in building f rhs' (v::vs) taken' end
   124       in
   122   | Const(_,_) => (pat, length vs, big_lambdas vs rhs)
   125         building f rhs' (v::vs) taken'
   123   | _ => sys_error "function is not declared as constant in theory";
   126       end
   124 
   127   | building (c as Const(_,_)) rhs vs taken = (c, length vs, big_lambdas vs rhs)
   125 fun match_eq eq = 
   128   | building _ _ _ _ = sys_error "function is not declared as constant in theory";
   126   let
   129 
   127     val (lhs,rhs) = dest_eqs eq;
   130 fun match_eq thy f = 
   128     val (Const(name,_), arity, term) =
   131   let
   129       building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []);
   132     val e = (all f thy);
   130   in (name, arity, term) end;
   133     val (lhs,rhs) = dest_eqs e;
       
   134     val (Const(n,_), a, t) = building lhs (%%:"return"`rhs) [] (add_terms [e] []);
       
   135   in
       
   136     (n, a, t)
       
   137   end;
       
   138 
   131 
   139 (* returns the sum (using +++) of the terms in ms *)
   132 (* returns the sum (using +++) of the terms in ms *)
   140 (* also applies "run" to the result! *)
   133 (* also applies "run" to the result! *)
   141 fun fatbar arity ms =
   134 fun fatbar arity ms =
   142   let
   135   let
   149     val msum = foldr1 mplus (map (unLAM arity) ms);
   142     val msum = foldr1 mplus (map (unLAM arity) ms);
   150   in
   143   in
   151     reLAM arity (%%:"Fixrec.run"`msum)
   144     reLAM arity (%%:"Fixrec.run"`msum)
   152   end;
   145   end;
   153 
   146 
   154 (***************************************************************)
   147 fun unzip3 [] = ([],[],[])
   155 (*** Proving associated theorems ***)
   148   | unzip3 ((x,y,z)::ts) =
   156 
   149       let val (xs,ys,zs) = unzip3 ts
   157 fun prove_thm thy unfold_thm x =
   150       in (x::xs, y::ys, z::zs) end;
       
   151 
       
   152 (* this is the pattern-matching compiler function *)
       
   153 fun compile_pats eqs = 
       
   154   let
       
   155     val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
       
   156     val cname = if forall (fn x => n=x) names then n
       
   157           else sys_error "FIXREC: all equations must define the same function";
       
   158     val arity = if forall (fn x => a=x) arities then a
       
   159           else sys_error "FIXREC: all equations must have the same arity";
       
   160     val rhs = fatbar arity mats;
       
   161   in
       
   162     HOLogic.mk_Trueprop (%%:cname === rhs)
       
   163   end;
       
   164 
       
   165 (*************************************************************************)
       
   166 (********************** Proving associated theorems **********************)
       
   167 (*************************************************************************)
       
   168 
       
   169 fun prove_thm thy unfold_thm ct =
   158   let
   170   let
   159     val ss = simpset_of thy;
   171     val ss = simpset_of thy;
   160     val thm = prove_goalw thy [] x (fn _ => [SOLVE(stac unfold_thm 1 THEN simp_tac ss 1)])
   172     val thm = prove_goalw_cterm [] ct
   161       handle _ => sys_error (x^" :: proof failed on this equation.");
   173       (fn _ => [SOLVE(stac unfold_thm 1 THEN simp_tac ss 1)])
       
   174         handle _ => sys_error (string_of_cterm ct^" :: proof failed on this equation.");
   162   in thm end;
   175   in thm end;
   163 
   176 
   164 (* this proves that each equation is a theorem *)
   177 (* this proves that each equation is a theorem *)
   165 fun prove_list thy unfold_thm [] = []
   178 fun prove_list thy unfold_thm [] = []
   166   | prove_list thy unfold_thm (x::xs) =
   179   | prove_list thy unfold_thm (x::xs) =
   182       , (basename^"_rews", rew_thms) ]
   195       , (basename^"_rews", rew_thms) ]
   183   in
   196   in
   184     (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
   197     (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
   185   end;
   198   end;
   186 
   199 
   187 fun unzip3 [] = ([],[],[])
   200 (*************************************************************************)
   188   | unzip3 ((x,y,z)::ts) =
   201 (************************* Main fixrec function **************************)
   189       let val (xs,ys,zs) = unzip3 ts
   202 (*************************************************************************)
   190       in (x::xs, y::ys, z::zs) end;
   203 
   191 
   204 (* this is the main processing function: it reads the equations, compiles the patterns, adds the fixpoint definition and returns the new state *)
   192 (* this is the main processing function *)
   205 fun add_fixrec strs thy =
   193 fun pat_fun eqs thy = 
   206   let
   194   let
   207     val sg = sign_of thy;
   195     val ((n::names),(a::arities),mats) = unzip3 (map (match_eq thy) eqs);
   208     val cts = map (Thm.read_cterm sg o rpair propT) strs;
   196     val cname = if forall (fn x => n=x) names then n
   209     val eqs = map term_of cts;
   197           else sys_error "PAT_FUN: all equations must define the same function";
   210     val funcc = infer (compile_pats eqs) thy;
   198     val arity = if forall (fn x => a=x) arities then a
   211     val _ = print_cterm (cterm_of sg funcc);
   199           else sys_error "FIXREC: all equations must have the same arity";
       
   200     val msum = fatbar arity mats;
       
   201     val v = variant (add_term_names (msum,[])) "v";
       
   202     val funcc = infer (HOLogic.mk_Trueprop (%%:cname === msum)) thy;
       
   203     val (name', fixdef_name_term) = func1 (dest_eqs funcc);
   212     val (name', fixdef_name_term) = func1 (dest_eqs funcc);
   204     val (thy', [fixdef_thm]) =
   213     val (thy', [fixdef_thm]) =
   205       PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
   214       PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
   206   in
   215     val ct = cterm_of (sign_of thy') funcc;
   207     make_simp name' eqs (cterm_of (sign_of thy') funcc) fixdef_thm thy'
   216   in
   208   end;
   217     make_simp name' cts ct fixdef_thm thy'
   209 
   218   end;
   210 (***************************************************************)
   219 
   211 
   220 (*************************************************************************)
   212 (* this calls the main processing function and then returns the new state *)
   221 (******************************** Fixpat *********************************)
   213 fun add_fixrec eqs = pat_fun eqs;
   222 (*************************************************************************)
   214 
       
   215 (*****************************************************************)
       
   216 (*** Fixpat ***)
       
   217 
       
   218 (* like Term.strip_comb, but with continuous application *)
       
   219 fun strip_cdot u : term * term list =
       
   220   let fun stripc (Const("Cfun.Rep_CFun",_)$f$t, ts) = stripc (f, t::ts)
       
   221         | stripc x = x
       
   222   in stripc(u,[]) end;
       
   223 
   223 
   224 fun fix_pat name pat thy = 
   224 fun fix_pat name pat thy = 
   225   let
   225   let
   226     val sign = sign_of thy;
   226     val sign = sign_of thy;
   227     val ct = Thm.read_cterm sign (pat, dummyT);
   227     val t = term_of (Thm.read_cterm sign (pat, dummyT));
   228     val (Const (f,_), args) = strip_cdot (term_of ct);
   228     val T = fastype_of t;
   229     val unfold_thm = Goals.get_thm thy (f^"_unfold");
   229     val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T));
   230     fun add_arg (arg,thm) = instantiate' [] [SOME arg] (thm RS cfun_fun_cong);
   230     fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f
   231     val unfold_thm' = foldl add_arg (freezeT unfold_thm) (map (cterm_of sign) args);
   231       | head_const (Const (c,_)) = c
   232     val thm = simplify (simpset_of thy) unfold_thm';
   232       | head_const _ = sys_error "FIXPAT: function is not declared as constant in theory";
       
   233     val c = head_const t;
       
   234     val unfold_thm = Goals.get_thm thy (c^"_unfold");
       
   235     val thm = prove_goalw_cterm [] (cterm_of sign eq)
       
   236           (fn _ => [stac unfold_thm 1, simp_tac (simpset_of thy) 1]);
       
   237     val _ = print_thm thm;
   233   in
   238   in
   234     (#1 o PureThy.add_thmss [Thm.no_attributes (name, [thm])]) thy
   239     (#1 o PureThy.add_thmss [Thm.no_attributes (name, [thm])]) thy
   235   end;
   240   end;
   236 
   241 
   237 fun add_fixpat (name,pat) = fix_pat name pat;
   242 fun add_fixpat (name,pat) = fix_pat name pat;
   238 
   243 
   239 (*****************************************************************)
   244 (*************************************************************************)
   240 (*** Parsers ***)
   245 (******************************** Parsers ********************************)
       
   246 (*************************************************************************)
   241 
   247 
   242 local structure P = OuterParse and K = OuterSyntax.Keyword in
   248 local structure P = OuterParse and K = OuterSyntax.Keyword in
   243 
   249 
   244 val fixrec_decl = Scan.repeat1 P.prop;
   250 val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop);
   245 
   251 
   246 (* this builds a parser for a new keyword, fixrec, whose functionality 
   252 (* this builds a parser for a new keyword, fixrec, whose functionality 
   247 is defined by add_fixrec *)
   253 is defined by add_fixrec *)
   248 val fixrecP =
   254 val fixrecP =
   249   OuterSyntax.command "fixrec" "parser for fixrec functions" K.thy_decl
   255   OuterSyntax.command "fixrec" "parser for fixrec functions" K.thy_decl