@@ -25,6 +25,18 @@ struct Tangent{P,T} <: AbstractTangent
2525 # Note: If T is a Tuple/Dict, then P is also a Tuple/Dict
2626 # (but potentially a different one, as it doesn't contain differentials)
2727 backing:: T
28+
29+ function Tangent {P,T} (backing) where {P,T}
30+ if P <: Tuple
31+ T <: Tuple || _backing_error (P, T, Tuple)
32+ elseif P <: AbstractDict
33+ T <: AbstractDict || _backing_error (P, T, AbstractDict)
34+ elseif P === Any # can be anything
35+ else # Any other struct (including NamedTuple)
36+ T <: NamedTuple || _backing_error (P, T, NamedTuple)
37+ end
38+ return new (backing)
39+ end
2840end
2941
3042function Tangent {P} (; kwargs... ) where {P}
@@ -45,6 +57,11 @@ function Tangent{P}(d::Dict) where {P<:Dict}
4557 return Tangent {P,typeof(d)} (d)
4658end
4759
60+ function _backing_error (P, G, E)
61+ msg = " Tangent for the primal $P should be backed by a $E type, not by $G ."
62+ return throw (ArgumentError (msg))
63+ end
64+
4865function Base.:(== )(a:: Tangent{P,T} , b:: Tangent{P,T} ) where {P,T}
4966 return backing (a) == backing (b)
5067end
0 commit comments