Skip to content

Commit 274b258

Browse files
committed
make difference relaxed about arrays and smarten tests
1 parent 8b708ae commit 274b258

2 files changed

Lines changed: 15 additions & 10 deletions

File tree

src/difference.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ difference(::Real, ::T, ::T) where {T<:Integer} = NoTangent()
1818

1919
difference::Real, y::T, x::T) where {T<:Number} = (y - x) / ε
2020

21-
difference::Real, y::T, x::T) where {T<:StridedArray} = difference.(ε, y, x)
21+
# we are a bit more relaced for AbstractArrays as they naturally represent a vector space
22+
difference::Real, y::AbstractArray, x) = difference.(ε, y, x)
23+
# resolve ambiguity
24+
difference::Real, y::T, x::T) where {T<:AbstractArray} = difference.(ε, y, x)
2225

2326
function difference::Real, y::T, x::T) where {T<:Tuple}
2427
return Tangent{T}(difference.(ε, y, x)...)

test/difference.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,5 @@
11
using FiniteDifferences: rand_tangent, difference
22

3-
function test_difference::Real, x, dx)
4-
y = x + ε * dx
5-
dx_diff = difference(ε, y, x)
6-
# TODO: `@test isapprox(dx, dx_diff)` once `isapprox` is defined appropriately
7-
# see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/184
8-
@test typeof(dx) == typeof(dx_diff)
9-
end
10-
113
@testset "difference" begin
124

135
@testset "Primal: $(typeof(x))" for (ε, x) in [
@@ -56,7 +48,17 @@ end
5648
(randn(), Adjoint(randn(ComplexF64, 3, 3))),
5749
(randn(), Transpose(randn(3))),
5850
]
59-
test_difference(ε, x, rand_tangent(x))
51+
# Construct a value that should be equal to the difference and check that it is
52+
dx = rand_tangent(x)
53+
y = x + ε * dx
54+
dx_diff = difference(ε, y, x)
55+
56+
if x isa AbstractArray{<:Number} || x isa Number
57+
@test x + dx x + dx_diff
58+
else
59+
# hard to check value if don't overload `≈` so for now we just check type
60+
@test typeof(dx) == typeof(dx_diff)
61+
end
6062
end
6163

6264
# Ensure struct fallback errors for non-struct types.

0 commit comments

Comments
 (0)