src/HOL/BNF/Tools/bnf_fp_rec_sugar.ML
changeset 53357 46b0c7a08af7
parent 53354 b7469b85ca28
child 53358 b46e6cd75dc6
equal deleted inserted replaced
53356:c5a1629d8e45 53357:46b0c7a08af7
    26 
    26 
    27 fun primrec_error str = raise Primrec_Error (str, []);
    27 fun primrec_error str = raise Primrec_Error (str, []);
    28 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
    28 fun primrec_error_eqn str eqn = raise Primrec_Error (str, [eqn]);
    29 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
    29 fun primrec_error_eqns str eqns = raise Primrec_Error (str, eqns);
    30 
    30 
       
    31 val free_name = try (fn Free (v, _) => v);
       
    32 val const_name = try (fn Const (v, _) => v);
       
    33 
    31 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
    34 fun finds eq = fold_map (fn x => List.partition (curry eq x) #>> pair x);
    32 fun abs_tuple t = if try (fst o dest_Const) t = SOME @{const_name undefined} then t else
    35 fun abs_tuple t = if const_name t = SOME @{const_name undefined} then t else
    33   strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
    36   strip_abs t |>> HOLogic.mk_tuple o map Free |-> HOLogic.tupled_lambda;
    34 
    37 
    35 val simp_attrs = @{attributes [simp]};
    38 val simp_attrs = @{attributes [simp]};
    36 
    39 
    37 
    40 
   101      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
   104      res_type = map fastype_of (left_args @ right_args) ---> fastype_of rhs,
   102      rhs_term = rhs,
   105      rhs_term = rhs,
   103      user_eqn = eqn'}
   106      user_eqn = eqn'}
   104   end;
   107   end;
   105 
   108 
   106 fun rewrite_map_arg funs_data get_indices y rec_type res_type =
   109 fun rewrite_map_arg funs_data get_indices rec_type res_type =
   107   let
   110   let
       
   111     val fun_data = hd (the (find_first (equal rec_type o #rec_type o hd) funs_data));
       
   112     val fun_name = #fun_name fun_data;
       
   113     val ctr_pos = length (#left_args fun_data);
       
   114 
   108     val pT = HOLogic.mk_prodT (rec_type, res_type);
   115     val pT = HOLogic.mk_prodT (rec_type, res_type);
   109     val fstx = fst_const pT;
   116 
   110     val sndx = snd_const pT;
   117     val maybe_suc = Option.map (fn x => x + 1);
   111 
   118     fun subst d (t as Bound d') = t |> d = SOME d' ? curry (op $) (fst_const pT)
   112     val SOME ({fun_name, left_args, ...} :: _) =
   119       | subst d (Abs (v, T, b)) = Abs (v, if d = SOME ~1 then pT else T, subst (maybe_suc d) b)
   113       find_first (equal rec_type o #rec_type o hd) funs_data;
   120       | subst d t =
   114     val ctr_pos = length left_args;
       
   115 
       
   116     fun subst _ d (t as Bound d') = t |> d = d' ? curry (op $) fstx
       
   117       | subst l d (Abs (v, T, b)) = Abs (v, if d < 0 then pT else T, subst l (d + 1) b)
       
   118       | subst l d t =
       
   119         let val (u, vs) = strip_comb t in
   121         let val (u, vs) = strip_comb t in
   120           if try (fst o dest_Free) u = SOME fun_name then
   122           if free_name u = SOME fun_name then
   121             if l andalso length vs = ctr_pos then
   123             if d = SOME ~1 andalso length vs = ctr_pos then
   122               list_comb (sndx |> permute_args ctr_pos, vs)
   124               list_comb (permute_args ctr_pos (snd_const pT), vs)
   123             else if length vs <= ctr_pos then
   125             else if length vs > ctr_pos andalso is_some d
   124               primrec_error_eqn "too few arguments in recursive call" t
   126                 andalso d = try (fn Bound n => n) (nth vs ctr_pos) then
   125             else if nth vs ctr_pos |> member (op =) [y, Bound d] then
   127               list_comb (snd_const pT $ nth vs ctr_pos, map (subst d) (nth_drop ctr_pos vs))
   126               list_comb (sndx $ nth vs ctr_pos, nth_drop ctr_pos vs |> map (subst false d))
       
   127             else
   128             else
   128               primrec_error_eqn "recursive call not directly applied to constructor argument" t
   129               primrec_error_eqn ("recursive call not directly applied to constructor argument") t
   129           else if try (fst o dest_Const) u = SOME @{const_name comp} then
   130           else if d = SOME ~1 andalso const_name u = SOME @{const_name comp} then
   130             (hd vs |> get_indices |> null orelse
   131             list_comb (map_types (K dummyT) u, map2 subst [NONE, d] vs)
   131               primrec_error_eqn "recursive call not directly applied to constructor argument" t;
       
   132             list_comb
       
   133               (u |> map_types (strip_type #>> (fn Ts => Ts
       
   134                    |> nth_map (length Ts - 1) (K pT)
       
   135                    |> nth_map (length Ts - 2) (strip_type #>> nth_map 0 (K pT) #> (op --->)))
       
   136                  #> (op --->)),
       
   137               nth_map 1 (subst l d) vs))
       
   138           else
   132           else
   139             list_comb (u, map (subst false d) vs)
   133             list_comb (u, map (subst (d |> d = SOME ~1 ? K NONE)) vs)
   140         end
   134         end
   141   in
   135   in
   142     subst true ~1
   136     subst (SOME ~1)
   143   end;
   137   end;
   144 
   138 
   145 (* FIXME get rid of funs_data or get_indices *)
   139 (* FIXME get rid of funs_data or get_indices *)
   146 fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
   140 fun subst_rec_calls lthy funs_data get_indices direct_calls indirect_calls t =
   147   let
   141   let
   162                 orelse primrec_error_eqn "too few arguments in recursive call" t;
   156                 orelse primrec_error_eqn "too few arguments in recursive call" t;
   163             list_comb (the maybe_direct_y', g_args))
   157             list_comb (the maybe_direct_y', g_args))
   164           else if is_some maybe_indirect_y' then
   158           else if is_some maybe_indirect_y' then
   165             (if contains_fun g then t else y)
   159             (if contains_fun g then t else y)
   166             |> massage_indirect_rec_call lthy contains_fun
   160             |> massage_indirect_rec_call lthy contains_fun
   167               (rewrite_map_arg funs_data get_indices y) bound_Ts y (the maybe_indirect_y')
   161               (rewrite_map_arg funs_data get_indices) bound_Ts y (the maybe_indirect_y')
   168             |> (if contains_fun g then I else curry (op $) g)
   162             |> (if contains_fun g then I else curry (op $) g)
   169           else
   163           else
   170             t
   164             t
   171         end
   165         end
   172       | subst _ t = t
   166       | subst _ t = t
   424       if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
   418       if is_none disc then the eq_ctr0 else find_index (equal (head_of (the disc))) discs;
   425     val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
   419     val ctr_no = if not_disc then 1 - ctr_no' else ctr_no';
   426     val fun_args = if is_none disc
   420     val fun_args = if is_none disc
   427       then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd
   421       then imp_rhs |> perhaps (try HOLogic.dest_not) |> HOLogic.dest_eq |> fst |> strip_comb |> snd
   428       else the disc |> the_single o snd o strip_comb
   422       else the disc |> the_single o snd o strip_comb
   429         |> (fn t => if try (fst o dest_Free o head_of) t = SOME fun_name
   423         |> (fn t => if free_name (head_of t) = SOME fun_name
   430           then snd (strip_comb t) else []);
   424           then snd (strip_comb t) else []);
   431 
   425 
   432     val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
   426     val mk_conjs = try (foldr1 HOLogic.mk_conj) #> the_default @{const True};
   433     val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
   427     val mk_disjs = try (foldr1 HOLogic.mk_disj) #> the_default @{const False};
   434     val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;
   428     val catch_all = try (fst o dest_Free o the_single) imp_lhs' = SOME Name.uu_;