The derivative below with respect to Bar.a is, correctly, '0' on the backward pass while still inside the apply() function. Somewhere between that point and its propagation out to the enclosing function, the derivative incorrectly becomes '1'.
As a possible hint, making 'HoldsKeyPaths' hold only a single keyPath instead of an array makes the derivative propagate correctly.
This bug seems to be something of an 'inverse' of https://bugs.swift.org/browse/SR-14218, in which the derivative incorrectly goes from 1 to 0. Here it goes from 0 to 1.
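For reference, the expected gradient can be worked out by hand. This is a sketch that writes b_a for `bar.a` and v for `newValue`. The function writeToBar in the reproduction overwrites `bar.a` with `newValue` and then returns `bar.a`, so its result does not depend on the incoming `bar.a` at all:

```latex
\begin{aligned}
f(b, v) &= \mathrm{writeToBar}(b, v) = v \\
\frac{\partial f}{\partial b_a} &= 0, \qquad \frac{\partial f}{\partial v} = 1
\end{aligned}
```

In terms of the custom pullback, `return bar.a` seeds the cotangent Bar(a: 1). The pullback in vjpWriteSlowTo zeros the written member and forwards its old value to `value`, which gives (Bar(a: 0), 1). The "correct derivative" print inside apply() matches this result, but the "actual derivative" Bar(a: 1.0) printed just outside apply() does not.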
import Foundation
import _Differentiation

//------------------------------------------------------------------------------------------
// Enable differentiable keyPath writes
// called 'writeSlowTo' because inout object wasn't working
@differentiable(where Object: Differentiable, Object == Object.TangentVector, Member: Differentiable, Member == Member.TangentVector)
public func writeSlowTo<Object, Member>(_ object: Object, at member: WritableKeyPath<Object, Member>, with value: Member) -> Object {
    var object = object
    object[keyPath: member] = value
    return object
}

@derivative(of: writeSlowTo)
public func vjpWriteSlowTo<Object, Member>(_ object: Object, at member: WritableKeyPath<Object, Member>, with value: Member)
    -> (value: Object, pullback: (Object.TangentVector) -> (Object.TangentVector, Member.TangentVector))
    where Object: Differentiable, Object == Object.TangentVector, Member: Differentiable, Member == Member.TangentVector {
    func pullback(_ dself: Object.TangentVector) -> (Object.TangentVector, Member.TangentVector) {
        var dself = dself
        let dWriteValue = dself[keyPath: member]
        dself[keyPath: member] = Member.zero
        return (dself, dWriteValue)
    }
    var object = object
    object = writeSlowTo(object, at: member, with: value)
    return (object, pullback)
}

//------------------------------------------------------------------------------------------
// Usage code
public struct Bar: Differentiable & AdditiveArithmetic {
    public var a: Double = 7
}

struct HoldsKeyPaths<Root, Value> {
    var keyPaths: [WritableKeyPath<Root, Value>] = []

    @differentiable(where Root: Differentiable, Root == Root.TangentVector, Value: Differentiable, Value == Value.TangentVector)
    public func apply(to root: inout Root, using value: Value) {
        for kp in self.keyPaths {
            //---
            // force cast to print gradient
            var rootbar = root as! Bar
            rootbar = rootbar.withDerivative { print("correct derivative", $0) }
            root = rootbar as! Root
            //---
            root = writeSlowTo(root, at: kp, with: value)
        }
    }
}

let bar = Bar()
let keyPaths = [\Bar.a]
let keyPathHolder = HoldsKeyPaths<Bar, Double>(keyPaths: keyPaths)

func writeToBar(bar: Bar, newValue: Double) -> Double {
    var bar = bar
    bar = bar.withDerivative { print("actual derivative", $0) }
    keyPathHolder.apply(to: &bar, using: newValue)
    return bar.a
}

let newValue: Double = 7
let valAndGrad = valueWithGradient(at: bar, newValue, in: writeToBar)
if valAndGrad.gradient.0.a != 0 { print("gradient is incorrect") }
// prints:
// correct derivative Bar(a: 0.0)
// actual derivative Bar(a: 1.0)
// gradient is incorrect
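To make the expected reverse pass concrete, below is a minimal, dependency-free Swift sketch. It does not import _Differentiation and is not what the compiler generates. The helpers `write`, `writePullback`, and `gradientOfWriteLoop` are hypothetical names introduced only for illustration:

- `write` is the forward key-path write.
- `writePullback` mirrors the hand-written pullback in vjpWriteSlowTo.
- `gradientOfWriteLoop` chains that pullback backward over an array of key paths, which is how the loop in apply() would be expected to behave.

```swift
// Standalone stand-in for the report's Bar (no Differentiable conformance here).
struct Bar: Equatable {
    var a: Double = 7
}

// Forward: overwrite one member through a WritableKeyPath (like writeSlowTo).
func write<Root, Member>(_ root: Root, at kp: WritableKeyPath<Root, Member>, with value: Member) -> Root {
    var root = root
    root[keyPath: kp] = value
    return root
}

// Hand-written pullback of `write`, mirroring vjpWriteSlowTo's pullback:
// the written member's cotangent goes to `value`, and that slot is zeroed
// in the root's cotangent because the old member value was discarded.
func writePullback<Root>(_ droot: Root, at kp: WritableKeyPath<Root, Double>) -> (droot: Root, dvalue: Double) {
    var droot = droot
    let dvalue = droot[keyPath: kp]
    droot[keyPath: kp] = 0
    return (droot, dvalue)
}

// Reverse pass over a loop of writes: pullbacks run in reverse order,
// threading the root cotangent backward and summing the cotangent of the
// shared `value` that every iteration writes.
func gradientOfWriteLoop(keyPaths: [WritableKeyPath<Bar, Double>], upstream: Bar) -> (droot: Bar, dvalue: Double) {
    var droot = upstream
    var dvalue = 0.0
    for kp in keyPaths.reversed() {
        let step = writePullback(droot, at: kp)
        droot = step.droot
        dvalue += step.dvalue
    }
    return (droot, dvalue)
}
```

With the seed Bar(a: 1) from `return bar.a`, `gradientOfWriteLoop(keyPaths: [\Bar.a], upstream: Bar(a: 1))` yields (Bar(a: 0.0), 1.0). Passing `[\Bar.a, \Bar.a]` yields the same result. Only an empty array leaves the root cotangent untouched, because no write happens in that case. In this model the result does not depend on how many key paths the array holds.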
Alternatively (and maybe more insightfully), here is a version where HoldsKeyPaths holds only one keypath, and adding a useless loop can toggle the bug:
Additional Detail from JIRA
md5: fe6a21d4d320144e6bcc478f511d7199