Skip to content

Commit 5935f4a

Browse files
authored
Constrain type in to_vec(::AbstractArray/Vector) to DenseArray/Vector (#156)
1 parent c319525 commit 5935f4a

3 files changed

Lines changed: 35 additions & 24 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.2"
3+
version = "0.12.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/to_vec.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ function to_vec(x::T) where {T}
3434
v, vals_from_vec = to_vec(vals)
3535
function structtype_from_vec(v::Vector{<:Real})
3636
val_vecs = vals_from_vec(v)
37-
vals = map((b, v) -> b(v), backs, val_vecs)
38-
return T(vals...)
37+
values = map((b, v) -> b(v), backs, val_vecs)
38+
return T(values...)
3939
end
4040
return v, structtype_from_vec
4141
end
4242

43-
function to_vec(x::AbstractVector)
43+
function to_vec(x::DenseVector)
4444
x_vecs_and_backs = map(to_vec, x)
4545
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
4646
function Vector_from_vec(x_vec)
@@ -53,7 +53,7 @@ function to_vec(x::AbstractVector)
5353
return x_vec, Vector_from_vec
5454
end
5555

56-
function to_vec(x::AbstractArray)
56+
function to_vec(x::DenseArray)
5757
x_vec, from_vec = to_vec(vec(x))
5858

5959
function Array_from_vec(x_vec)
@@ -63,7 +63,6 @@ function to_vec(x::AbstractArray)
6363
return x_vec, Array_from_vec
6464
end
6565

66-
6766
# Some specific subtypes of AbstractArray.
6867
function to_vec(x::Base.ReshapedArray{<:Any, 1})
6968
x_vec, from_vec = to_vec(parent(x))

test/to_vec.jl

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ end
1111
Base.:(==)(x::DummyType, y::DummyType) = x.X == y.X
1212
Base.length(x::DummyType) = size(x.X, 1)
1313

14-
# A dummy FillVector. This is a type for which the fallback implementation of
15-
# `to_vec` should fail loudly.
14+
# A dummy FillVector
1615
struct FillVector <: AbstractVector{Float64}
1716
x::Float64
1817
len::Int
1918
end
2019

20+
Base.size(x::FillVector) = (x.len,)
21+
Base.getindex(x::FillVector, n::Int) = x.x
22+
2123
# For testing Composite{ThreeFields}
2224
struct ThreeFields
2325
a
@@ -32,10 +34,17 @@ struct Nested
3234
y::Singleton
3335
end
3436

35-
Base.size(x::FillVector) = (x.len,)
36-
Base.getindex(x::FillVector, n::Int) = x.x
37+
# For testing generic subtypes of AbstractArray
38+
struct WrapperArray{T, N, A<:AbstractArray{T, N}} <: AbstractArray{T, N}
39+
data::A
40+
end
41+
function WrapperArray(a::AbstractArray{T, N}) where {T, N}
42+
return WrapperArray{T, N, AbstractArray{T, N}}(a)
43+
end
44+
Base.size(a::WrapperArray) = size(a.data)
45+
Base.getindex(a::WrapperArray, inds...) = getindex(a.data, inds...)
3746

38-
function test_to_vec(x::T; check_inferred = true) where {T}
47+
function test_to_vec(x::T; check_inferred=true) where {T}
3948
check_inferred && @inferred to_vec(x)
4049
x_vec, back = to_vec(x)
4150
@test x_vec isa Vector
@@ -61,14 +70,14 @@ end
6170
test_to_vec(randn(T, 5, 11))
6271
test_to_vec(randn(T, 13, 17, 19))
6372
test_to_vec(randn(T, 13, 0, 19))
64-
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred = false)
65-
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred = false)
66-
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred = false)
73+
test_to_vec([1.0, randn(T, 2), randn(T, 1), 2.0]; check_inferred=false)
74+
test_to_vec([randn(T, 5, 4, 3), (5, 4, 3), 2.0]; check_inferred=false)
75+
test_to_vec(reshape([1.0, randn(T, 5, 4, 3), randn(T, 4, 3), 2.0], 2, 2); check_inferred=false)
6776
test_to_vec(UpperTriangular(randn(T, 13, 13)))
6877
test_to_vec(Diagonal(randn(T, 7)))
6978
test_to_vec(DummyType(randn(T, 2, 9)))
70-
test_to_vec(SVector{2, T}(1.0, 2.0))
71-
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0))
79+
test_to_vec(SVector{2, T}(1.0, 2.0); check_inferred=false)
80+
test_to_vec(SMatrix{2, 2, T}(1.0, 2.0, 3.0, 4.0); check_inferred=false)
7281
test_to_vec(@view randn(T, 10)[1:4]) # SubArray -- Vector
7382
test_to_vec(@view randn(T, 10, 2)[1:4, :]) # SubArray -- Matrix
7483
test_to_vec(Base.ReshapedArray(rand(T, 3, 3), (9,), ()))
@@ -111,10 +120,10 @@ end
111120
test_to_vec((5, 4))
112121
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
113122
test_to_vec((5, randn(T, 5)); check_inferred = VERSION v"1.2" && VERSION < v"1.6")
114-
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred = false)
123+
test_to_vec((randn(T, 4), randn(T, 4, 3, 2), 1); check_inferred=false)
115124
# TODO remove "< 1.6" once https://github.com/JuliaLang/julia/issues/40277
116125
test_to_vec((5, randn(T, 4, 3, 2), UpperTriangular(randn(T, 4, 4)), 2.5); check_inferred = VERSION v"1.2" && VERSION < v"1.6")
117-
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred = false)
126+
test_to_vec(((6, 5), 3, randn(T, 3, 2, 0, 1)); check_inferred=false)
118127
test_to_vec((DummyType(randn(T, 2, 7)), DummyType(randn(T, 3, 9))))
119128
test_to_vec((DummyType(randn(T, 3, 2)), randn(T, 11, 8)))
120129
end
@@ -127,9 +136,9 @@ end
127136
end
128137
@testset "Dictionary" begin
129138
if T == Float64
130-
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred = false)
139+
test_to_vec(Dict(:a=>5, :b=>randn(10, 11), :c=>(5, 4, 3)); check_inferred=false)
131140
else
132-
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred = false)
141+
test_to_vec(Dict(:a=>3 + 2im, :b=>randn(T, 10, 11), :c=>(5+im, 2-im, 1+im)); check_inferred=false)
133142
end
134143
end
135144
end
@@ -146,7 +155,7 @@ end
146155
x_inner = (2, 3)
147156
x_outer = (1, x_inner)
148157
x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3))
149-
test_to_vec(x_comp; check_inferred = false)
158+
test_to_vec(x_comp; check_inferred=false)
150159
end
151160
end
152161

@@ -173,13 +182,16 @@ end
173182
end
174183

175184
@testset "FillVector" begin
176-
x = FillVector(5.0, 10)
177-
x_vec, from_vec = to_vec(x)
178-
@test_throws MethodError from_vec(randn(10))
185+
test_to_vec(FillVector(5.0, 10); check_inferred=false)
179186
end
180187

181188
@testset "fallback" begin
182189
nested = Nested(ThreeFields(1.0, 2.0, "Three"), Singleton())
183190
test_to_vec(nested; check_inferred=false) # map
184191
end
192+
193+
@testset "WrapperArray" begin
194+
wa = WrapperArray(rand(4, 5))
195+
test_to_vec(wa; check_inferred=false)
196+
end
185197
end

0 commit comments

Comments
 (0)