11using ChainRulesCore, Test
22using LinearAlgebra, SparseArrays
3- using OffsetArrays, BenchmarkTools
3+ using OffsetArrays, StaticArrays, BenchmarkTools
44
55# Like ForwardDiff.jl's Dual
66struct Dual{T<: Real } <: Real
@@ -295,7 +295,7 @@ struct NoSuperType end
295295 # ####
296296
297297 @testset " OffsetArrays" begin
298- # While there is no code for this, the rule that it checks axes(x) == axes(dx) else
298+ # While there is no code for this, the rule that it checks axes(x) === axes(dx) else
299299 # reshape means that it restores offsets. (It throws an error on nontrivial size mismatch.)
300300
301301 poffv = ProjectTo (OffsetArray (rand (3 ), 0 : 2 ))
@@ -304,8 +304,34 @@ struct NoSuperType end
304304
305305 @test axes (poffv (OffsetArray (rand (3 ), 0 : 2 ))) == (0 : 2 ,)
306306 @test axes (poffv (OffsetArray (rand (3 , 1 ), 0 : 2 , 0 : 0 ))) == (0 : 2 ,)
307+
308+ pvec3 = ProjectTo ([1 , 2 , 3 ])
309+ @test axes (pvec3 (OffsetArray (rand (3 ), 0 : 2 ))) == (1 : 3 ,)
310+ @test pvec3 (OffsetArray (rand (3 ), 0 : 2 )) isa Vector # relies on axes === axes test
311+ @test pvec3 (OffsetArray (rand (3 ,1 ), 0 : 2 , 0 : 0 )) isa Vector
307312 end
308313
314+ # ####
315+ # #### `StaticArrays`
316+ # ####
317+
318+ @testset " StaticArrays" begin
319+ # There is no code for this, but when argument isa StaticArray, axes(x) === axes(dx)
320+ # implies a check, and reshape will wrap a Vector into a static SizedVector:
321+ pstat = ProjectTo (SA[1 , 2 , 3 ])
322+ @test axes (pstat (rand (3 ))) === (SOneTo (3 ),)
323+
324+ # This recurses into structured arrays:
325+ pst = ProjectTo (transpose (SA[1 , 2 , 3 ]))
326+ @test axes (pst (rand (1 ,3 ))) === (SOneTo (1 ), SOneTo (3 ))
327+ @test pst (rand (1 ,3 )) isa Transpose
328+
329+ # When the argument is an ordinary Array, static gradients are allowed to pass,
330+ # like FillArrays. Collecting to an Array would cost a copy.
331+ pvec3 = ProjectTo ([1 , 2 , 3 ])
332+ @test pvec3 (SA[1 , 2 , 3 ]) isa StaticArray
333+ end
334+
309335 # ####
310336 # #### `ChainRulesCore`
311337 # ####
0 commit comments