[SR-9596] [AD] Support the trailing 'where' clause in @differentiable attributes #52043
Labels
bug
swift for tensorflow
Issue Description:
Currently, we require the VJP to be defined in the same generic context as the original function. However, this is mathematically wrong and practically limiting.
In this example, the original function is defined under the constraint `Scalar : Numeric`, but it should not be differentiable unless `Scalar : Differentiable`. To express this, we already have syntax support for a trailing `where` clause in `@differentiable`.
However, the type checker currently ignores this clause. We should support it now to unblock defining primitives for tensor methods that take or return `Scalar`. Once `where` clauses are type-checked, the VJP/adjoint should be allowed to be defined in a more constrained generic context than the original function.
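For illustration, here is a minimal sketch of the intended behavior. It assumes a current Swift toolchain that ships the `_Differentiation` module, so it uses today's spelling (`@differentiable(reverse where ...)` plus a separate `@derivative(of:)` function) rather than the Swift for TensorFlow-era `@differentiable(vjp: ...)` form this issue was written against. `cube` and `vjpCube` are hypothetical stand-ins for the tensor methods mentioned above, not code from the issue. The original function stays generic over any `Numeric` scalar, while its differentiability and its custom VJP exist only under the stricter constraints.

```swift
import _Differentiation

// Hypothetical stand-in for a tensor method: defined for any `Numeric` scalar
// (e.g. `Int`), but declared differentiable only when the trailing `where`
// clause holds, i.e. for differentiable floating-point scalars.
@differentiable(reverse where Scalar: Differentiable & FloatingPoint, Scalar == Scalar.TangentVector)
func cube<Scalar: Numeric>(_ x: Scalar) -> Scalar {
  x * x * x
}

// Hypothetical custom VJP registered in a *more constrained* generic context
// than `cube` itself: it only applies when `Scalar` is a differentiable
// floating-point type whose tangent vector is `Scalar`.
@derivative(of: cube)
func vjpCube<Scalar>(_ x: Scalar) -> (value: Scalar, pullback: (Scalar) -> Scalar)
where Scalar: FloatingPoint & Differentiable, Scalar == Scalar.TangentVector {
  // d/dx x^3 = 3x^2; the pullback scales the incoming cotangent by it.
  return (cube(x), { v in 3 * x * x * v })
}

// `Int` is `Numeric` but not `Differentiable`, so `cube` is still callable on it.
print(cube(2))  // prints "8"

// `Double` satisfies the `where` clause, so `cube` is differentiable here.
let (value, grad) = valueWithGradient(at: 2.0) { x in cube(x) }
print(value, grad)  // prints "8.0 12.0"
```

The design point is that the derivative's generic signature is a strict refinement of the original's: callers that only need `Numeric` are unaffected, and differentiation is available exactly where it is mathematically meaningful.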