SR-14056 Generalize Array: Differentiable conformance to more Collection types
Issue Description:
Generalize Array.differentiableReduce to an appropriately constrained Collection.differentiableReduce function.
It's not currently possible to write code (e.g. using loops) to implement reduction over a Differentiable-conforming Collection type otherwise.
Some code adapted from Array.differentiableReduce:
    extension Collection
    where
      Self: Differentiable, Element: Differentiable,
      TangentVector: MutableCollection,
      TangentVector.Element == Element.TangentVector
    {
      @differentiable(wrt: (self, initialResult))
      public func differentiableReduce<Result: Differentiable>(
        _ initialResult: Result,
        _ nextPartialResult: @differentiable (Result, Element) -> Result
      ) -> Result {
        reduce(initialResult, nextPartialResult)
      }

      @inlinable
      @derivative(of: differentiableReduce)
      internal func _vjpDifferentiableReduce<Result: Differentiable>(
        _ initialResult: Result,
        _ nextPartialResult: @differentiable (Result, Element) -> Result
      ) -> (
        value: Result,
        pullback: (Result.TangentVector) -> (TangentVector, Result.TangentVector)
      ) {
        var pullbacks: [(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)] = []
        let count = self.count
        pullbacks.reserveCapacity(count)
        var result = initialResult
        for element in self {
          let (y, pb) = valueWithPullback(at: result, element, in: nextPartialResult)
          result = y
          pullbacks.append(pb)
        }
        return (
          value: result,
          pullback: { [selfTan = zeroTangentVector] tangent in
            var selfTangent = selfTan
            var resultTangent = tangent
            var index = selfTangent.startIndex
            for pullback in pullbacks.reversed() {
              let (newResultTangent, elementTangent) = pullback(resultTangent)
              resultTangent = newResultTangent
              selfTangent[index] += elementTangent
              index = selfTangent.index(after: index)
            }
            return (selfTangent, resultTangent)
          }
        )
      }
    }

    // Example:
    @differentiable
    func foo(_ array: [Float]) -> Float {
      return array.differentiableReduce(0, +)
    }
    print(gradient(at: [1, 2, 3, 4], in: foo))
    // Doesn't currently work because `Array.TangentVector` is not a `MutableCollection`.

We need to change array-specific code to use general differentiation helpers, like AdditiveArithmetic.+.
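As a rough illustration of what "general helpers" could mean here, the following plain-Swift sketch accumulates per-element contributions into any `MutableCollection` whose elements are `AdditiveArithmetic`, using only `+=` and index traversal. The function name `accumulate(_:into:)` is hypothetical and not part of the Differentiation library. The sketch does not use `@differentiable` at all.

```swift
// Hypothetical helper (not a library API): adds `contributions` element-wise
// into `tangent`, relying only on MutableCollection and AdditiveArithmetic.
func accumulate<T: MutableCollection, S: Sequence>(
  _ contributions: S, into tangent: inout T
) where T.Element: AdditiveArithmetic, S.Element == T.Element {
  var index = tangent.startIndex
  for contribution in contributions {
    // Stop if there are more contributions than slots in `tangent`.
    guard index != tangent.endIndex else { break }
    tangent[index] += contribution
    index = tangent.index(after: index)
  }
}

var tangent: [Float] = [0, 0, 0]
let first: [Float] = [1, 2, 3]
let second: [Float] = [0.5, 0.5, 0.5]
accumulate(first, into: &tangent)
accumulate(second, into: &tangent)
print(tangent)  // [1.5, 2.5, 3.5]
```

Nothing in this helper is specific to `Array`, which is the property the generalized pullback would need.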
One roadblock may be converting [Element.TangentVector] to Self.TangentVector in the pullback body. I'm not sure how to best do this.
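To make the roadblock concrete, here is a minimal plain-Swift sketch of the same reverse-mode pattern with hand-written pullbacks and `Double` values only. It produces the element tangents as an array, which is the `[Element.TangentVector]` side of the conversion. How to turn that array into `Self.TangentVector` for an arbitrary `Collection` is left open, as in the issue. The name `reduceWithPullback` and its signature are assumptions made for illustration and do not exist in the Differentiation library.

```swift
// Hypothetical sketch: `step` returns the next partial result together with a
// pullback mapping an upstream derivative to
// (derivative w.r.t. previous result, derivative w.r.t. element).
func reduceWithPullback<C: Collection>(
  _ collection: C,
  _ initialResult: Double,
  _ step: (Double, Double) -> (value: Double, pullback: (Double) -> (Double, Double))
) -> (value: Double, pullback: (Double) -> (elements: [Double], initial: Double))
where C.Element == Double {
  var pullbacks: [(Double) -> (Double, Double)] = []
  pullbacks.reserveCapacity(collection.count)
  var result = initialResult
  for element in collection {
    let (y, pb) = step(result, element)
    result = y
    pullbacks.append(pb)
  }
  let storedPullbacks = pullbacks
  return (
    value: result,
    pullback: { upstream in
      var resultTangent = upstream
      var elementTangents: [Double] = []
      elementTangents.reserveCapacity(storedPullbacks.count)
      // Pullbacks run last-to-first, so tangents are collected in reverse.
      for pb in storedPullbacks.reversed() {
        let (newResultTangent, elementTangent) = pb(resultTangent)
        resultTangent = newResultTangent
        elementTangents.append(elementTangent)
      }
      // Restore iteration order so tangent i belongs to element i.
      elementTangents.reverse()
      return (elements: elementTangents, initial: resultTangent)
    }
  )
}

// Product of the elements, with the pullback of `acc * x` written by hand.
let (product, productPullback) = reduceWithPullback([2.0, 3.0, 4.0], 1.0) { acc, x in
  (value: acc * x, pullback: { v in (v * x, v * acc) })
}
let gradients = productPullback(1.0)
print(product)             // 24.0
print(gradients.elements)  // [12.0, 8.0, 6.0]
print(gradients.initial)   // 24.0
```

Note that the tangents are reversed back into iteration order before being returned. The prototype above appears to advance `index` forward from `startIndex` while applying the pullbacks in reverse, so a generalized version would likely need to account for that ordering as well.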
The second code snippet works as is because differentiableReduce is currently defined as a method on Array.
This issue tracks removing Array.differentiableReduce from the Differentiation library and generalizing it to Collection.differentiableReduce, as prototyped in the first snippet. When this is done, I believe the second snippet fails to compile with the error in the issue description.
This can be verified by renaming differentiableReduce, in both its definition and its usage, to something else in both snippets, e.g. differentiableReduce_.
Additional Detail from JIRA
md5: 99f51c59606b6059ec9f1e4b34a6280a
Parent-Task:
Array: Differentiable conformance to more Collection types