src/HOL/BNF/Tools/bnf_fp_rec_sugar_util.ML
changeset 53866 7c23df53af01
parent 53865 cadccda5be03
child 53867 8ad44ecc0d15
equal deleted inserted replaced
53865:cadccda5be03 53866:7c23df53af01
   195       | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
   195       | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t;
   196   in
   196   in
   197     massage_call
   197     massage_call
   198   end;
   198   end;
   199 
   199 
   200 fun massage_let_if ctxt has_call massage_leaf bound_Ts U =
   200 (* TODO: also support old-style datatypes.
       
   201    (Ideally, we would have a proper registry for these things.) *)
       
   202 fun case_of ctxt =
       
   203   fp_sugar_of ctxt #> Option.map (fst o dest_Const o #casex o of_fp_sugar #ctr_sugars);
       
   204 
       
   205 fun fold_rev_let_if_case ctxt f bound_Ts =
       
   206   let
       
   207     fun fld t =
       
   208       (case Term.strip_comb t of
       
   209         (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
       
   210       | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
       
   211       | (Const (c, _), args as _ :: _) =>
       
   212         let val (branches, obj) = split_last args in
       
   213           (case fastype_of1 (bound_Ts, obj) of
       
   214             Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t
       
   215           | _ => f t)
       
   216         end
       
   217       | _ => f t)
       
   218   in
       
   219     fld
       
   220   end;
       
   221 
       
   222 fun massage_let_if_case ctxt has_call massage_leaf bound_Ts U =
   201   let
   223   let
   202     val typof = curry fastype_of1 bound_Ts;
   224     val typof = curry fastype_of1 bound_Ts;
   203     val check_obj = ((not o has_call) orf unexpected_corec_call ctxt);
   225     val check_obj = ((not o has_call) orf unexpected_corec_call ctxt);
   204 
   226 
   205     fun massage_rec t =
   227     fun massage_rec t =
   206       (case Term.strip_comb t of
   228       (case Term.strip_comb t of
   207         (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
   229         (Const (@{const_name Let}, _), [arg1, arg2]) => massage_rec (betapply (arg2, arg1))
   208       | (Const (@{const_name If}, _), obj :: branches) =>
   230       | (Const (@{const_name If}, _), obj :: branches) =>
   209         list_comb (If_const U $ tap check_obj obj, map massage_rec branches)
   231         list_comb (If_const U $ tap check_obj obj, map massage_rec branches)
   210       | (Const (@{const_name nat_case}, _), args) =>
   232       | (Const (c, _), args as _ :: _) =>
   211         (* Proof of concept -- should be extensible to all case-like constructs *)
   233         let val (branches, obj) = split_last args in
   212         let
   234           (case fastype_of1 (bound_Ts, obj) of
   213           val (branches, obj) = split_last args;
   235             Type (T_name, _) =>
   214           val branches' = map massage_rec branches
   236             if case_of ctxt T_name = SOME c then
   215           (* FIXME: bound_Ts *)
   237               let
   216           val casex' = Const (@{const_name nat_case}, map typof branches' ---> typof obj);
   238                 val branches' = map massage_rec branches;
   217         in
   239                 val casex' = Const (c, map typof branches' ---> typof obj);
   218           list_comb (casex', branches') $ tap check_obj obj
   240               in
       
   241                 list_comb (casex', branches') $ tap check_obj obj
       
   242               end
       
   243             else
       
   244               massage_leaf t
       
   245           | _ => massage_leaf t)
   219         end
   246         end
   220       | _ => massage_leaf t)
   247       | _ => massage_leaf t)
   221   in
   248   in
   222     massage_rec
   249     massage_rec
   223   end;
   250   end;
   224 
   251 
   225 val massage_direct_corec_call = massage_let_if;
   252 val massage_direct_corec_call = massage_let_if_case;
   226 
   253 
   227 fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
   254 fun massage_indirect_corec_call ctxt has_call raw_massage_call bound_Ts U t =
   228   let
   255   let
   229     val typof = curry fastype_of1 bound_Ts;
   256     val typof = curry fastype_of1 bound_Ts;
   230     val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
   257     val build_map_Inl = build_map ctxt (uncurry Inl_const o dest_sumT o snd)
   256       else
   283       else
   257         massage_map U T t
   284         massage_map U T t
   258         handle AINT_NO_MAP _ => massage_direct_fun U T t;
   285         handle AINT_NO_MAP _ => massage_direct_fun U T t;
   259 
   286 
   260     fun massage_call U T =
   287     fun massage_call U T =
   261       massage_let_if ctxt has_call (fn t =>
   288       massage_let_if_case ctxt has_call (fn t =>
   262         if has_call t then
   289         if has_call t then
   263           (case U of
   290           (case U of
   264             Type (s, Us) =>
   291             Type (s, Us) =>
   265             (case try (dest_ctr ctxt s) t of
   292             (case try (dest_ctr ctxt s) t of
   266               SOME (f, args) =>
   293               SOME (f, args) =>
   298   | NONE => raise Fail "expand_ctr_term");
   325   | NONE => raise Fail "expand_ctr_term");
   299 
   326 
   300 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   327 fun expand_corec_code_rhs ctxt has_call bound_Ts t =
   301   (case fastype_of1 (bound_Ts, t) of
   328   (case fastype_of1 (bound_Ts, t) of
   302     T as Type (s, Ts) =>
   329     T as Type (s, Ts) =>
   303     massage_let_if ctxt has_call (fn t =>
   330     massage_let_if_case ctxt has_call (fn t =>
   304       if can (dest_ctr ctxt s) t then t
   331       if can (dest_ctr ctxt s) t then
   305       else massage_let_if ctxt has_call I bound_Ts T (expand_ctr_term ctxt s Ts t)) bound_Ts T t
   332         t
       
   333       else
       
   334         massage_let_if_case ctxt has_call I bound_Ts T (expand_ctr_term ctxt s Ts t)) bound_Ts T t
   306   | _ => raise Fail "expand_corec_code_rhs");
   335   | _ => raise Fail "expand_corec_code_rhs");
   307 
   336 
   308 fun massage_corec_code_rhs ctxt massage_ctr =
   337 fun massage_corec_code_rhs ctxt massage_ctr =
   309   massage_let_if ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
   338   massage_let_if_case ctxt (K false) (uncurry massage_ctr o Term.strip_comb);
   310 
   339 
   311 (* TODO: also support old-style datatypes.
   340 fun fold_rev_corec_code_rhs ctxt f = fold_rev_let_if_case ctxt (uncurry f o Term.strip_comb);
   312    (Ideally, we would have a proper registry for these things.) *)
       
   313 fun case_of ctxt =
       
   314   fp_sugar_of ctxt #> Option.map (fst o dest_Const o #casex o of_fp_sugar #ctr_sugars);
       
   315 
       
   316 fun fold_rev_let_if ctxt f bound_Ts =
       
   317   let
       
   318     fun fld t =
       
   319       (case Term.strip_comb t of
       
   320         (Const (@{const_name Let}, _), [arg1, arg2]) => fld (betapply (arg2, arg1))
       
   321       | (Const (@{const_name If}, _), _ :: branches) => fold_rev fld branches
       
   322       | (Const (c, _), args as _ :: _) =>
       
   323         let val (branches, obj) = split_last args in
       
   324           (case fastype_of1 (bound_Ts, obj) of
       
   325             Type (T_name, _) => if case_of ctxt T_name = SOME c then fold_rev fld branches else f t
       
   326           | _ => f t)
       
   327         end
       
   328       | _ => f t)
       
   329   in
       
   330     fld
       
   331   end;
       
   332 
       
   333 fun fold_rev_corec_code_rhs ctxt f = fold_rev_let_if ctxt (uncurry f o Term.strip_comb);
       
   334 
   341 
   335 fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t'
   342 fun add_conjuncts (Const (@{const_name conj}, _) $ t $ t') = add_conjuncts t o add_conjuncts t'
   336   | add_conjuncts t = cons t;
   343   | add_conjuncts t = cons t;
   337 
   344 
   338 fun conjuncts t = add_conjuncts t [];
   345 fun conjuncts t = add_conjuncts t [];