src/HOL/SMT/Tools/smt_monomorph.ML
changeset 34010 ac78f5cdc430
parent 33047 69780aef0531
child 36896 c030819254d3
equal deleted inserted replaced
33937:b5ca587d0885 34010:ac78f5cdc430
    72   in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
    72   in (envs @ ces, (fold (insert (op aconv)) us ts, ns')) end
    73 
    73 
    74 
    74 
    (* incr_tvar_indices i t: traverses term t and applies incrT to the type
       stored in every Const, Free and Abs node, descending into Abs bodies and
       into both sides of an application.
       NOTE(review): this listing is a side-by-side changeset diff; each row is
       shown twice, first under the line number of the parent revision and then
       under that of the new revision, so the text does not compile as-is.
       Parent revision: incrT is Logic.incr_tvar i, every visited node is
       rebuilt, and the catch-all case returns t.
       New revision: incrT is Logic.incr_tvar_same i; the Same.SAME handlers
       reuse the original type or subterm whenever one component raises, the
       catch-all case calls Same.same t, and the outermost handler returns the
       input t. Presumably Same.SAME means that nothing changed and incrT
       shifts type variable indices by i; confirm against the Same and Logic
       structures. *)
    75 fun incr_tvar_indices i t =
    75 fun incr_tvar_indices i t =
    76   let
    76   let
    77     val incrT = Logic.incr_tvar i
    77     val incrT = Logic.incr_tvar_same i
    78 
    78 
    79     fun incr t =
    79     fun incr t =
    80       (case t of
    80       (case t of
    81         Const (n, T) => Const (n, incrT T)
    81         Const (n, T) => Const (n, incrT T)
    82       | Free (n, T) => Free (n, incrT T)
    82       | Free (n, T) => Free (n, incrT T)
    83       | Abs (n, T, t1) => Abs (n, incrT T, incr t1)
    83       | Abs (n, T, t1) => (Abs (n, incrT T, incr t1 handle Same.SAME => t1)
    84       | t1 $ t2 => incr t1 $ incr t2
    84           handle Same.SAME => Abs (n, T, incr t1))
    85       | _ => t)
    85       | t1 $ t2 => (incr t1 $ (incr t2 handle Same.SAME => t2)
    86   in incr t end
    86           handle Same.SAME => t1 $ incr t2)
       
    87       | _ => Same.same t)
       
    88   in incr t handle Same.SAME => t end
    87 
    89 
    88 
    90 
    (* Bound on the recursion counter of mono inside monomorph: once the
       counter exceeds this value, mono issues a warning and returns the terms
       computed so far instead of recursing again. *)
    89 val monomorph_limit = 10
    91 val monomorph_limit = 10
    90 
    92 
    91 (* Instantiate all polymorphic constants (i.e., constants occurring both with
    93 (* Instantiate all polymorphic constants (i.e., constants occurring both with
    92    ground types and type variables) with all (necessary) ground types; thereby
    94    ground types and type variables) with all (necessary) ground types; thereby
    93    create copies of terms containing those constants.
    95    create copies of terms containing those constants.
    94    To prevent non-termination, there is an upper limit for the number of
    96    To prevent non-termination, there is an upper limit for the number of
    95    recursions involved in the fixpoint construction. *)
    97    recursions involved in the fixpoint construction. *)
    (* NOTE(review): side-by-side changeset diff; every row is shown once per
       revision, parent line number first and new line number second.
       Parent revision: monomorph thy ts splits ts with term_has_tvars into
       ps and ms, renumbers ps through incr and fold_map starting at index 0,
       pairs each renumbered term with the consts_of entries kept by with_tvar,
       and starts mono with counter 0.
       New revision: monomorph thy is a function assembled as a pipeline
       of partition, incr_indices, extract_consts_with_tvar and mono_all; rps
       becomes an explicit parameter of mono, and mono_all returns ms unchanged
       when rps is empty instead of calling mono.
       incr shifts one term by the running index idx and advances idx by
       Term.maxidx_of_term t + 1. with_tvar keeps only the types satisfying
       typ_has_tvars and drops a constant entirely when none remain.
       mono returns ts' when the is' produced by the fold is empty, or, with a
       warning, when count exceeds monomorph_limit; otherwise it recurses with
       count + 1. *)
    96 fun monomorph thy ts =
    98 fun monomorph thy =
    97   let
    99   let
    98     val (ps, ms) = List.partition term_has_tvars ts
   100     fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
       
   101     fun incr_indices ts = fst (fold_map incr ts 0)
    99 
   102 
   100     fun with_tvar (n, Ts) =
   103     fun with_tvar (n, Ts) =
   101       let val Ts' = filter typ_has_tvars Ts
   104       let val Ts' = filter typ_has_tvars Ts
   102       in if null Ts' then NONE else SOME (n, Ts') end
   105       in if null Ts' then NONE else SOME (n, Ts') end
   103     fun incr t idx = (incr_tvar_indices idx t, idx + Term.maxidx_of_term t + 1)
   106     fun extract_consts_with_tvar t = (t, map_filter with_tvar (consts_of [t]))
   104     val rps = fst (fold_map incr ps 0)
       
   105       |> map (fn r => (r, map_filter with_tvar (consts_of [r])))
       
   106 
   107 
   107     fun mono count is ces cs ts =
   108     fun mono rps count is ces cs ts =
   108       let
   109       let
   109         val spec = specialize thy cs is
   110         val spec = specialize thy cs is
   110         val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
   111         val (ces', (ts', is')) = fold_map spec (rps ~~ ces) (ts, [])
   111         val cs' = join_consts is cs
   112         val cs' = join_consts is cs
   112       in
   113       in
   113         if null is' then ts'
   114         if null is' then ts'
   114         else if count > monomorph_limit then
   115         else if count > monomorph_limit then
   115           (warning "monomorphization limit reached"; ts')
   116           (warning "monomorphization limit reached"; ts')
   116         else mono (count + 1) is' ces' cs' ts'
   117         else mono rps (count + 1) is' ces' cs' ts'
   117       end
   118       end
   118   in mono 0 (consts_of ms) (map (K []) rps) [] ms end
   119     fun mono_all rps ms = if null rps then ms
       
   120       else mono rps 0 (consts_of ms) (map (K []) rps) [] ms
       
   121   in
       
   122     List.partition term_has_tvars
       
   123     #>> incr_indices
       
   124     #>> map extract_consts_with_tvar
       
   125     #-> mono_all
       
   126   end
   119 
   127 
   120 end
   128 end