Skip to content

Commit f70a651

Browse files
authored
Merge pull request #165 from JuliaDiff/ox/covectors
Make rand_tangent on adjoint an transpose return natural
2 parents 266d6fa + 274b258 commit f70a651

5 files changed

Lines changed: 34 additions & 23 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.7"
3+
version = "0.12.8"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/difference.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent()
1818

1919
difference::Real, y::T, x::T) where {T<:Number} = (y - x) / ε
2020

21-
difference::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)
21+
# we are a bit more relaced for AbstractArrays as they naturally represent a vector space
22+
difference::Real, y::AbstractArray, x) = difference.(ε, y, x)
23+
# resolve ambiguity
24+
difference::Real, y::T, x::T) where {T<:AbstractArray} = difference.(ε, y, x)
2225

2326
function difference::Real, y::T, x::T) where {T<:Tuple}
2427
return Tangent{T}(difference.(ε, y, x)...)

src/rand_tangent.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ rand_tangent(rng::AbstractRNG, x::Integer) = NoTangent()
1313

1414
rand_tangent(rng::AbstractRNG, x::T) where {T<:Number} = randn(rng, T)
1515

16-
# TODO: right now Julia don't allow `randn(rng, BigFloat)`
16+
# TODO: right now Julia don't allow `randn(rng, BigFloat)`
1717
# see: https://github.com/JuliaLang/julia/issues/17629
1818
rand_tangent(rng::AbstractRNG, ::BigFloat) = big(randn(rng))
1919

2020
rand_tangent(rng::AbstractRNG, x::StridedArray) = rand_tangent.(Ref(rng), x)
21+
rand_tangent(rng::AbstractRNG, x::Adjoint) = adjoint(rand_tangent(rng, parent(x)))
22+
rand_tangent(rng::AbstractRNG, x::Transpose) = transpose(rand_tangent(rng, parent(x)))
2123

2224
function rand_tangent(rng::AbstractRNG, x::T) where {T<:Tuple}
2325
return Tangent{T}(rand_tangent.(Ref(rng), x)...)

test/difference.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
using FiniteDifferences: rand_tangent, difference
22

3-
function test_difference::Real, x, dx)
4-
y = x + ε * dx
5-
dx_diff = difference(ε, y, x)
6-
# TODO: `@test isapprox(dx, dx_diff)` once `isapprox` is defined appropriately
7-
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/184
8-
@test typeof(dx) == typeof(dx_diff)
9-
end
10-
113
@testset "difference" begin
124

135
@testset "Primal: $(typeof(x))" for (ε, x) in [
@@ -56,7 +48,17 @@ end
5648
(randn(), Adjoint(randn(ComplexF64, 3, 3))),
5749
(randn(), Transpose(randn(3))),
5850
]
59-
test_difference(ε, x, rand_tangent(x))
51+
# Construct a value that should be equal to the difference and check that it is
52+
dx = rand_tangent(x)
53+
y = x + ε * dx
54+
dx_diff = difference(ε, y, x)
55+
56+
if x isa AbstractArray{<:Number} || x isa Number
57+
@test x + dx x + dx_diff
58+
else
59+
# hard to check value if don't overload `≈` so for now we just check type
60+
@test typeof(dx) == typeof(dx_diff)
61+
end
6062
end
6163

6264
# Ensure struct fallback errors for non-struct types.

test/rand_tangent.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ using FiniteDifferences: rand_tangent
2424
(randn(Complex{Float32}, 5, 4), Matrix{Complex{Float32}}),
2525
([randn(5, 4), 4.0], Vector{Any}),
2626

27+
# Wrapper Arrays
28+
(randn(5, 4)', Adjoint{Float64, Matrix{Float64}}),
29+
(transpose(randn(5, 4)), Transpose{Float64, Matrix{Float64}}),
30+
31+
2732
# Tuples.
2833
((4.0, ), Tangent{Tuple{Float64}}),
2934
((5.0, randn(3)), Tangent{Tuple{Float64, Vector{Float64}}}),
@@ -66,20 +71,19 @@ using FiniteDifferences: rand_tangent
6671
Hermitian(randn(ComplexF64, 1, 1)),
6772
Tangent{Hermitian{ComplexF64, Matrix{ComplexF64}}},
6873
),
69-
(
70-
Adjoint(randn(ComplexF64, 3, 3)),
71-
Tangent{Adjoint{ComplexF64, Matrix{ComplexF64}}},
72-
),
73-
(
74-
Transpose(randn(3)),
75-
Tangent{Transpose{Float64, Vector{Float64}}},
76-
),
7774
]
7875
@test rand_tangent(rng, x) isa T_tangent
7976
@test rand_tangent(x) isa T_tangent
80-
@test x + rand_tangent(rng, x) isa typeof(x)
8177
end
8278

83-
# Ensure struct fallback errors for non-struct types.
84-
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
79+
@testset "erroring cases" begin
80+
# Ensure struct fallback errors for non-struct types.
81+
@test_throws ArgumentError invoke(rand_tangent, Tuple{AbstractRNG, Any}, rng, 5.0)
82+
end
83+
84+
@testset "compsition of addition" begin
85+
x = Foo(1.5, 2, Foo(1.1, 3, [1.7, 1.4, 0.9]))
86+
@test x + rand_tangent(x) isa typeof(x)
87+
@test x + (rand_tangent(x) + rand_tangent(x)) isa typeof(x)
88+
end
8589
end

0 commit comments

Comments
 (0)