
[SR-14235] [AutoDiff] Cross-file differentiation error: derivative configuration not registered for @derivative attribute #55170

Closed
dan-zheng opened this issue May 3, 2020 · 4 comments · Fixed by #58644 or #58965
Assignees
Labels
AutoDiff bug A deviation from expected or documented behavior. Also: expected but undesirable behavior. compiler The Swift compiler in itself multiple files Flag: An issue whose reproduction requires multiple files serialization Area → compiler: Serialization & deserialization type checker Area → compiler: Semantic analysis


@dan-zheng
Collaborator

Previous ID SR-14235
Radar None
Original Reporter @dan-zheng
Type Bug
Additional Detail from JIRA
Votes 0
Component/s Compiler
Labels Bug, AutoDiff
Assignee None
Priority Medium

md5: 1c36955f37f690185a211308ac4034a9

Issue Description:

// a.swift
import _Differentiation

@differentiable(reverse)
func crossFileDerivativeAttr<T: DifferentiableTensorProtocol>(
  _ input: T
) -> T {
  return input.identityDerivativeAttr()
}

// b.swift
import _Differentiation

protocol DifferentiableTensorProtocol: Differentiable {}

extension DifferentiableTensorProtocol {
  func identityDerivativeAttr() -> Self { self }

  // Test cross-file `@derivative` attribute.
  @derivative(of: identityDerivativeAttr)
  func vjpIdentityDerivativeAttr() -> (
    value: Self, pullback: (TangentVector) -> TangentVector
  ) {
    // Identity function: the pullback passes the cotangent through unchanged.
    return (value: self, pullback: { $0 })
  }
}

The error occurs because AbstractFunctionDecl::getDerivativeConfigurations is not sufficiently requestified (i.e., not fully driven by the request evaluator) to type-check @derivative attributes declared in other files.

$ swiftc a.swift b.swift
a.swift:20:16: error: expression is not differentiable
  return input.identityDerivativeAttr()
               ^
a.swift:20:16: note: cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files
  return input.identityDerivativeAttr()
               ^
swift-ci transferred this issue from apple/swift-issues Apr 25, 2022
asl self-assigned this May 3, 2022
asl added a commit that referenced this issue May 3, 2022
Look-up for functions with @derivative attributes defined in non-primary source files

Fixes #55170
@asl
Collaborator

asl commented May 8, 2022

@dan-zheng Apparently there is a more serious problem here that is not addressed by #58644. Consider the following case:

  • a.swift uses function A
  • b.swift defines function A
  • c.swift defines custom derivative vjpA for A
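A concrete instance of this three-file layout might look like the sketch below. The names useSquare (the caller in a.swift), square (function A) and vjpSquare (its custom derivative vjpA) are hypothetical, and the sketch assumes a toolchain that ships the _Differentiation module. As in the original repro, each "// x.swift" comment marks the start of a separate file.

```swift
// a.swift: uses function A (`square`) inside a differentiable function.
import _Differentiation

@differentiable(reverse)
func useSquare(_ x: Double) -> Double {
  // `square` is defined in b.swift; its derivative is registered in c.swift.
  return square(x)
}

// b.swift: defines function A with no differentiability attribute.
func square(_ x: Double) -> Double {
  return x * x
}

// c.swift: registers a custom derivative (vjpA) for `square`.
import _Differentiation

@derivative(of: square)
func vjpSquare(_ x: Double) -> (value: Double, pullback: (Double) -> Double) {
  // d/dx (x * x) = 2x, scaled by the incoming cotangent v.
  return (value: x * x, pullback: { v in 2 * x * v })
}
```

Split into three files, this is the shape of the case that #58644 does not cover. Concatenated into a single file it compiles, and gradient(at: 3.0, of: useSquare) evaluates the custom pullback, giving 6.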

#58644 handles the case where A and vjpA are both defined in the same file. However, it does not handle the case above. There is one important special case: when A is a runtime function or is defined in an external module, we effectively won't be able to register a custom derivative for it. Consider, e.g., the following:

a.swift:

import _Differentiation

@differentiable(reverse)
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
    return max(min(value, upperBound), lowerBound)
}

b.swift:

import _Differentiation

@inlinable
@derivative(of: min)
func minVJP<T: Comparable & Differentiable>(
    _ x: T,
    _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x <= y {
            return (v, .zero)
        }
        else {
            return (.zero, v)
        }
    }
    return (value: min(x, y), pullback: pullback)
}

@inlinable
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(
    _ x: T,
    _ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
    func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
        if x < y {
            return (.zero, v)
        }
        else {
            return (v, .zero)
        }
    }
    return (value: max(x, y), pullback: pullback)
}

It looks like we'd need a dedicated pass over all non-primary source files to pull in these custom derivatives :( or something similar...
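To make the idea concrete, here is a toy model in plain Swift (no compiler APIs; every type and function name is hypothetical and does not mirror the compiler's real AST or registration code). It contrasts a derivative registry built from the primary file alone with one built by a pass over all source files, using the a/b/c layout described above.

```swift
// Toy model only: hypothetical types, not the Swift compiler's data structures.

struct FunctionDecl {
  let name: String
  /// Name of the original function if this decl carries `@derivative(of:)`.
  let derivativeOf: String?
}

struct SourceFile {
  let name: String
  let decls: [FunctionDecl]
}

/// Maps an original function name to the names of its custom derivatives.
typealias DerivativeRegistry = [String: [String]]

/// Scans `files` and records every `@derivative(of:)` registration found.
func collectDerivatives(in files: [SourceFile]) -> DerivativeRegistry {
  var registry: DerivativeRegistry = [:]
  for file in files {
    for decl in file.decls {
      if let original = decl.derivativeOf {
        registry[original, default: []].append(decl.name)
      }
    }
  }
  return registry
}

// a.swift uses A, b.swift defines A, c.swift defines vjpA.
let a = SourceFile(name: "a.swift", decls: [FunctionDecl(name: "usesA", derivativeOf: nil)])
let b = SourceFile(name: "b.swift", decls: [FunctionDecl(name: "A", derivativeOf: nil)])
let c = SourceFile(name: "c.swift", decls: [FunctionDecl(name: "vjpA", derivativeOf: "A")])

// Looking only at the primary file (a.swift) finds no derivative for A.
let primaryOnly = collectDerivatives(in: [a])

// A dedicated pass over the primary and all non-primary files finds vjpA.
let allFiles = collectDerivatives(in: [a, b, c])

print(primaryOnly["A"] ?? [])  // []
print(allFiles["A"] ?? [])     // ["vjpA"]
```

The fix that eventually landed (#58965) is described in its commit message as looking up custom derivatives in non-primary source files after type-checking of the primary source finishes, so they are registered before the autodiff transformations run. The toy model ignores ordering, generics, and derivative configurations entirely.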

@dan-zheng
Collaborator Author

dan-zheng commented May 8, 2022

#58644 handles the case where A and vjpA are both defined in the same file. However, it does not handle the case above. There is one important special case: when A is a runtime function or is defined in an external module, we effectively won't be able to register a custom derivative for it. Consider, e.g., the following:

Are you sure this is not currently working as intended?

My understanding of the test case:

  • a.swift defines clamp which uses max and min.
  • b.swift defines custom derivatives for max and min.
  • I believe derivative functions are already defined for max and min in the _Differentiation library.
    • If not, compilation of a.swift should fail: the body of clamp cannot be differentiated because there are no registered derivatives for the external functions max and min.

Expected behavior:

  • If max and min are not defined in a.swift, then compilation of a.swift should produce extern references to their derivative functions.
  • Files that import a.swift but not b.swift will use the max and min derivatives defined in the _Differentiation library.
  • Files that import both a.swift and b.swift will use the max and min derivatives defined in b.swift.
    • This is like a "multiple conformances" scenario. I'm not sure we have logic for deterministically resolving the right "registered derivative" when multiple options exist, or if multiple registrations are even allowed.

If you'd like to investigate this case, writing a multi-file test would be helpful!

@asl
Collaborator

asl commented May 8, 2022

Are you sure this is not currently working as intended?

Well, I wouldn't have posted this comment without checking the test case first ;)

# ./swiftc a.swift b.swift
a.swift:5:12: error: expression is not differentiable
    return max(min(value, upperBound), lowerBound)
           ^
a.swift:5:12: note: cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files
    return max(min(value, upperBound), lowerBound)
           ^
a.swift:5:16: error: expression is not differentiable
    return max(min(value, upperBound), lowerBound)
               ^
a.swift:5:16: note: cannot differentiate functions that have not been marked '@differentiable' and that are defined in other files
    return max(min(value, upperBound), lowerBound)

Compiling both files together makes this error disappear, as expected.

asl added a commit that referenced this issue May 11, 2022
Look-up for functions with @derivative attributes defined in non-primary source files

Fixes #55170
@asl
Collaborator

asl commented May 11, 2022

I'm reopening this issue because a1e138b implemented only a partial fix.

asl reopened this May 11, 2022
asl added a commit to asl/swift that referenced this issue May 18, 2022
…heck is finished for the primary source.

This registers all custom derivatives before autodiff transformations and makes them available to them.

Fully resolves apple#55170
drexin pushed a commit to drexin/swift that referenced this issue Jun 3, 2022
Look-up for functions with @derivative attributes defined in non-primary source files

Fixes apple#55170
asl added a commit that referenced this issue Jul 18, 2022
…58965)

* Lookup for custom derivatives in non-primary source files after typecheck is finished for the primary source.

This registers all custom derivatives before autodiff transformations and makes them available to them.

Fully resolves #55170
Catfish-Man pushed a commit to Catfish-Man/swift that referenced this issue Jul 28, 2022
…pple#58965)

* Lookup for custom derivatives in non-primary source files after typecheck is finished for the primary source.

This registers all custom derivatives before autodiff transformations and makes them available to them.

Fully resolves apple#55170
AnthonyLatsis added multiple files Flag: An issue whose reproduction requires multiple files type checker Area → compiler: Semantic analysis serialization Area → compiler: Serialization & deserialization and removed type checker Area → compiler: Semantic analysis labels Jan 23, 2023