[SR-13153] @derivative attribute ambiguous original declaration errors #54265

Closed
dan-zheng opened this issue Nov 26, 2019 · 1 comment
Labels: AutoDiff, compiler (The Swift compiler in itself), diagnostics QoI (Bug: Diagnostics Quality of Implementation)


@dan-zheng
Collaborator

Previous ID SR-13153
Radar None
Original Reporter @dan-zheng
Type Sub-task
Status Closed
Resolution Done
Additional Detail from JIRA
Votes 0
Component/s Compiler
Labels Sub-task, AutoDiff, DiagnosticsQoI
Assignee @dan-zheng
Priority Medium

md5: 7ea5fa61fb651909a76ff92dae522234

Parent-Task:

  • SR-13149 @derivative and @transpose type-checking diagnostic improvements

blocks:

relates to:

  • TF-1058 [AD] Make @derivative attribute support qualified original declaration names

Issue Description:

`@derivative(of:)` is not quite at feature parity with `@differentiable(jvp:vjp:)` because of ambiguity errors during `@derivative(of:)` original function lookup.

```swift
protocol P: Differentiable {}
extension Float: P {}

func max<T: P>(_ lhs: T, _ rhs: T) -> T { lhs }

@derivative(of: max(_:_:))
func vjpMax<T: P>(_ lhs: T, _ rhs: T) -> (value: T, pullback: (T) -> (T, T))
where T == T.TangentVector {
  // Type-checking is able to disambiguate the `max(_:_:)` call below.
  // But not the `max(_:_:)` reference above.
  (max(lhs, rhs), { v in (v, v) })
}
```

```
$ swiftc tf-1001.swift
tf-1001.swift:6:18: error: type 'T' does not conform to protocol 'Comparable'
@derivative(of: max(_:_:))
                 ^
```
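For comparison, the applied call inside `vjpMax` resolves fine. The plain-Swift sketch below (no AutoDiff, so `P` drops its `Differentiable` refinement; `callMax` is a hypothetical name) shows why: in a generic context where `T` is only known to conform to `P`, the standard library's `max<T: Comparable>` is not viable, so overload resolution picks the local `max`. The diagnostic above suggests that `@derivative(of:)` lookup of the unapplied `max(_:_:)` reference does not filter candidates this way and lands on `Swift.max` instead.

```swift
// Stand-in for the issue's protocol, without the `Differentiable` refinement
// so this sketch builds on toolchains without AutoDiff support.
protocol P {}
extension Float: P {}

// Same shape as the issue's `max`: always returns `lhs`.
func max<T: P>(_ lhs: T, _ rhs: T) -> T { lhs }

// Mirrors the body of `vjpMax`: an *applied* call in a generic context.
func callMax<T: P>(_ lhs: T, _ rhs: T) -> T {
  // `T` is known to conform to `P` but not `Comparable`, so the standard
  // library's `max<T: Comparable>` is rejected and the local `max` is chosen.
  return max(lhs, rhs)
}

print(callMax(Float(1), Float(2)))  // prints "1.0" (the local `max` returns `lhs`)
```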

The current workaround is to use `@differentiable(jvp:vjp:)` on the original function for derivative registration:

```swift
protocol P: Differentiable {}
extension Float: P {}

// No ambiguity here: `vjpMax` is referenced from the original function's own attribute.
@differentiable(vjp: vjpMax where T == T.TangentVector)
func max<T: P>(_ lhs: T, _ rhs: T) -> T { lhs }

func vjpMax<T: P>(_ lhs: T, _ rhs: T) -> (value: T, pullback: (T) -> (T, T))
where T == T.TangentVector {
  // Type-checking is able to disambiguate the `max(_:_:)` call below.
  (max(lhs, rhs), { v in (v, v) })
}
```

An eventual solution may be to allow qualified original function references in the `@derivative` attribute, e.g. `@derivative(of: ModuleName.max)`.
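Module qualification already disambiguates ordinary Swift expressions; the proposal would bring the same spelling to the attribute. A minimal plain-Swift sketch, assuming a hypothetical local `max` overload that shares its base name with the standard library's (`P` and `stdlibMax` are illustrative names, not from the issue):

```swift
protocol P {}
extension Float: P {}

// Hypothetical local overload sharing the base name `max` with the standard library.
func max<T: P>(_ lhs: T, _ rhs: T) -> T { lhs }

// Qualifying with the module name `Swift` restricts lookup to the standard
// library, so its `Comparable`-constrained `max` is selected even though a
// local `max` is also in scope.
func stdlibMax(_ a: Float, _ b: Float) -> Float {
  return Swift.max(a, b)
}

print(stdlibMax(1, 2))  // prints "2.0"
```

Qualifying with the defining module's name instead (e.g. `ModuleName.max`) would select the local overload, which is the reference the `@derivative(of:)` case needs.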


This issue has come up in practical use cases, e.g. ewconnell/swiftrt#4.

@dan-zheng
Collaborator Author

Done in #28892 by enabling the `@derivative` attribute to accept a qualified original declaration name.

```swift
// Assumes this file is compiled as module `MyModule` (e.g. `swiftc -module-name MyModule`).
func foo(_ x: Float) -> Float { x }

// 1. Module qualification.
// Immediately useful for derivative registration use cases: TF-1002.
// May be replaced by proper module qualification/`DeclNameRef` work:
// - https://forums.swift.org/t/pitch-fully-qualified-name-syntax/28482
// - https://forums.swift.org/t/what-the-eff-is-declnameref/31594
@derivative(of: MyModule.foo)
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  (foo(x), { $0 })
}
```

swift-ci transferred this issue from apple/swift-issues on Apr 25, 2022.
This issue was closed.