@@ -39,10 +39,11 @@ for T in (:AbstractThunk, :Tangent, :Any)
3939 @eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = x
4040 @eval LinearAlgebra. dot (:: $T , x:: NotImplemented ) = x
4141end
42+ # unary :- is the same as multiplication by -1
43+ Base.:- (x:: NotImplemented ) = x
4244
4345# subtraction throws an exception: in AD we add tangents but do not subtract them
4446# subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
45- Base.:- (x:: NotImplemented ) = throw (NotImplementedException (x))
4647Base.:- (x:: NotImplemented , :: NotImplemented ) = throw (NotImplementedException (x))
4748for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
4849 @eval Base.:- (x:: NotImplemented , :: $T ) = throw (NotImplementedException (x))
144145Base.:+ (a:: Dict , d:: Tangent{P} ) where {P} = merge (+ , a, backing (d))
145146Base.:+ (a:: Tangent{P} , b:: P ) where {P} = b + a
146147
148+ Base.:- (tangent:: Tangent{P} ) where {P} = map (- , tangent)
149+
147150# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
148151# In general one doesn't have to represent multiplications of 2 differentials
149152# Only of a differential and a scaling factor (generally `Real`)
0 commit comments