@@ -272,18 +272,55 @@ end
272272# ####
273273
274274# Ref
275+ # Note that Ref is mutable. This causes Zygote to represent its structral tangent not as a NamedTuple,
276+ # but as `Ref{Any}((x=val,))`. Here we use a Tangent, there is at present no mutable version, but see
277+ # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/105
275278function ProjectTo (x:: Ref )
276279 sub = ProjectTo (x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
277- if sub isa ProjectTo{<: AbstractZero }
280+ return ProjectTo {Tangent{typeof(x)}} (; x= sub)
281+ end
282+ (project:: ProjectTo{<:Tangent{<:Ref}} )(dx:: Tangent ) = project (Ref (first (backing (dx))))
283+ function (project:: ProjectTo{<:Tangent{<:Ref}} )(dx:: Ref )
284+ dy = project. x (dx[])
285+ return project_type (project)(; x= dy)
286+ end
287+ # Since this works like a zero-array in broadcasting, it should also accept a number:
288+ (project:: ProjectTo{<:Tangent{<:Ref}} )(dx:: Number ) = project (Ref (dx))
289+
290+ # Tuple
291+ function ProjectTo (x:: Tuple )
292+ elements = map (ProjectTo, x)
293+ if elements isa NTuple{<: Any ,ProjectTo{<: AbstractZero }}
278294 return ProjectTo {NoTangent} ()
279295 else
280- return ProjectTo {Ref} (; type = typeof (x), x = sub )
296+ return ProjectTo {Tangent{ typeof(x)}} (; elements = elements )
281297 end
282298end
283- (project:: ProjectTo{Ref} )(dx:: Tangent{<:Ref} ) = Tangent {project.type} (; x= project. x (dx. x))
284- (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.type} (; x= project. x (dx[]))
285- # Since this works like a zero-array in broadcasting, it should also accept a number:
286- (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.type} (; x= project. x (dx))
299+ # This method means that projection is re-applied to the contents of a Tangent.
300+ # We're not entirely sure whether this is every necessary; but it should be safe,
301+ # and should often compile away:
302+ (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tangent ) = project (backing (dx))
303+ function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: Tuple )
304+ len = length (project. elements)
305+ if length (dx) != len
306+ str = " tuple with length(x) == $len cannot have a gradient with length(dx) == $(length (dx)) "
307+ throw (DimensionMismatch (str))
308+ end
309+ # Here map will fail if the lengths don't match, but gives a much less helpful error:
310+ dy = map ((f, x) -> f (x), project. elements, dx)
311+ return project_type (project)(dy... )
312+ end
313+ function (project:: ProjectTo{<:Tangent{<:Tuple}} )(dx:: AbstractArray )
314+ for d in 1 : ndims (dx)
315+ if size (dx, d) != get (length (project. elements), d, 1 )
316+ throw (_projection_mismatch (axes (project. elements), size (dx)))
317+ end
318+ end
319+ dy = reshape (dx, axes (project. elements)) # allows for dx::OffsetArray
320+ dz = ntuple (i -> project. elements[i](dy[i]), length (project. elements))
321+ return project_type (project)(dz... )
322+ end
323+
287324
288325# ####
289326# #### `LinearAlgebra`
0 commit comments