[SR-14228] [AutoDiff] "Curry thunk" differentiation regression #54819

Open
dan-zheng opened this issue Mar 19, 2020 · 4 comments
Labels
AutoDiff bug A deviation from expected or documented behavior. Also: expected but undesirable behavior. compiler The Swift compiler in itself

Comments

@dan-zheng
Collaborator

Previous ID SR-14228
Radar None
Original Reporter @dan-zheng
Type Bug
Additional Detail from JIRA
Votes 0
Component/s Compiler
Labels Bug, AutoDiff
Assignee None
Priority Medium

md5: 8635dba187654182d9292136d60a56fd

relates to:

  • TF-1030 allow serialized functions to reference implicit derivatives in some cases

Issue Description:

Curry thunks were recently rewritten as implicit AST closures instead of SILGen'd thunks: #28698

This caused regressions in curry thunk differentiation. Extracted from test/AutoDiff/downstream/generics.swift:

// TF-688: Test generic curry thunk cloning.
public struct TF_688_Struct<Scalar> {
  var x: Scalar
}
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
  @differentiable
  public static func id(x: Self) -> Self {
    return x
  }
}
@differentiable(wrt: x)
public func TF_688<Scalar: Differentiable>(
  _ x: TF_688_Struct<Scalar>,
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
) -> TF_688_Struct<Scalar> {
  reduction(x)
}

Before: no error.

// default argument 1 of TF_688<A>(_:reduction:)
sil non_abi [serialized] [ossa] @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_ : $@convention(thin) <Scalar where Scalar : Differentiable> () -> @owned @differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> {
bb0:
  %0 = metatype $@thin TF_688_Struct<Scalar>.Type // user: %2
  // function_ref curry thunk of static TF_688_Struct<A>.id(x:)
  %1 = function_ref @$s4main13TF_688_StructVAAs14DifferentiableRzlE2id1xACyxGAG_tFZTc : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %2
  %2 = apply %1<Scalar>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = differentiable_function [parameters 0] %2 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // user: %4
  return %3 : $@differentiable @callee_guaranteed (@in_guaranteed TF_688_Struct<Scalar>) -> @out TF_688_Struct<Scalar> // id: %4
} // end sil function '$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_'

After: error regarding differentiating fragile function in serialized function.
This error was introduced in #28582

$ swiftc -Xllvm -debug-only=differentiation tf-688.swift
// AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0
sil shared [serialized] @AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@thin TF_688_Struct<τ_0_0>.Type) -> @owned @differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> {
// %0                                             // users: %3, %1
bb0(%0 : $@thin TF_688_Struct<τ_0_0>.Type):
  debug_value %0 : $@thin TF_688_Struct<τ_0_0>.Type, let, name "self", argno 1 // id: %1
  // function_ref implicit closure #2 in implicit closure #1 in default argument 1 of TF_688<A>(_:reduction:)
  %2 = function_ref @$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu_A2Fcfu0_ : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %3
  %3 = partial_apply [callee_guaranteed] %2<τ_0_0>(%0) : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed TF_688_Struct<τ_0_0>, @thin TF_688_Struct<τ_0_0>.Type) -> @out TF_688_Struct<τ_0_0> // user: %4
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
  return %5 : $@differentiable @callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // id: %6
} // end sil function 'AD__$s4main6TF_688_9reductionAA0B11_688_StructVyxGAF_A2FXFts14DifferentiableRzlFfA0_A2FcAFmcfu___differentiable_curry_thunk_src_0_wrt_0'

[AD] Diagnosing non-differentiability.
[AD] For value:
  %4 = convert_function %3 : $@callee_guaranteed (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_0> to $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %5
[AD] With invoker:
(differentiation_invoker differentiable_function_inst=(  %5 = differentiable_function [parameters 0] %4 : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed TF_688_Struct<τ_0_0>) -> @out TF_688_Struct<τ_0_1> for <τ_0_0, τ_0_0> // user: %6
))
tf-688.swift:14:95: error: function is not differentiable
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                ~~~~~~~~~~~~~~^~
tf-688.swift:14:95: note: differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead
  reduction: @differentiable (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
                                                                                              ^
@swift-ci transferred this issue from apple/swift-issues Apr 25, 2022
@philipturner
Contributor

@slavapestov I was planning to fix this bug, whose reproducer is here. The end of its stack trace is somewhere in RQM. I was planning to read your entire research paper just to understand what was going on at the end, but that seems like an overblown amount of effort. Is it possible for you to examine the crash a little in LLDB and give me enough of an understanding that I can utilize my experience with other areas of the compiler to fix the bug?

@slavapestov
Member

Your generic signature is <τ_0_0, τ_0_1, τ_0_2, τ_0_3 where τ_0_0 == τ_0_1, τ_0_2 == τ_0_3>.

τ_0_0.TangentVector is not a valid type parameter in this signature because τ_0_0 does not conform to Differentiable (which is where TangentVector is declared).

That's your bug. The autodiff code is probably forgetting to add a requirement to the signature, which is probably coming from a call to buildGenericSignature() somewhere in the autodiff code.
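
To illustrate the point about valid type parameters, here is a minimal sketch in plain Swift. It uses a hypothetical protocol `HasTangent` as a stand-in for `Differentiable` (the names `HasTangent`, `Sample`, and `tangentTypeName` are assumptions made for this example, not compiler or library API). It only shows why `T.TangentVector` needs a conformance requirement in the generic signature; it does not reproduce the crash.

```swift
// Hypothetical stand-in for `Differentiable`; not the real protocol.
protocol HasTangent {
  associatedtype TangentVector
}

// Hypothetical conforming type used only for this illustration.
struct Sample: HasTangent {
  typealias TangentVector = Float
}

// Valid: the requirement `T: HasTangent` is part of the generic signature,
// so `T.TangentVector` is a valid type parameter here.
func tangentTypeName<T: HasTangent>(of _: T.Type) -> String {
  String(describing: T.TangentVector.self)
}

// Invalid (rejected by the type checker if uncommented): the signature <T>
// has no conformance requirement, so `T.TangentVector` does not exist.
// func broken<T>(of _: T.Type) -> String {
//   String(describing: T.TangentVector.self)
// }

print(tangentTypeName(of: Sample.self))
```

In ordinary source code the type checker rejects the second form. The signature printed above suggests that compiler-generated code ended up with the equivalent of the second form, which is consistent with a missing requirement when the signature was built.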

@philipturner
Contributor

philipturner commented Apr 28, 2022

@slavapestov you're the best!

For future reference, I have narrowed down the reproducer to something smaller:

import _Differentiation

struct Box<Scalar> {
  var x: Scalar
}

extension Box: Differentiable where Scalar: Differentiable {}

struct Box2<T> {
  var x2: @differentiable(reverse) (Box<T>) -> Box<T>
}
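
Following the diagnosis above, one possible way to sidestep the missing requirement in user code is to state it explicitly on the generic parameter. This is an unverified sketch, not a confirmed workaround: it assumes a toolchain that ships the `_Differentiation` module, and the constraint `T: Differentiable` on `Box2` and the name `identityBox` are additions for illustration that are not part of the reproducer.

```swift
import _Differentiation

struct Box<Scalar> {
  var x: Scalar
}

extension Box: Differentiable where Scalar: Differentiable {}

// Assumption: constraining `T` makes `Box<T>` conform to `Differentiable`,
// so `T.TangentVector` is a valid type parameter in this signature.
struct Box2<T: Differentiable> {
  var x2: @differentiable(reverse) (Box<T>) -> Box<T>
}

// Usage: store an identity closure and call it like an ordinary function.
let identityBox = Box2<Double>(x2: { $0 })
print(identityBox.x2(Box(x: 3.0)).x)
```

Whether this avoids the crash on the toolchains discussed in this thread has not been checked here. It also changes the API, since `Box2` can no longer be instantiated with a non-differentiable `T`.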

@fibrechannelscsi
Contributor

The reproducer posted on 4/28 is still broken with 2023-01-02a through to 2023-01-18a. We get:

Invalid type parameter in getReducedType()
Original type: τ_0_0.TangentVector
Simplified term: τ_0_0.[Differentiable:TangentVector]
Longest valid prefix: τ_0_0
Prefix type: τ_0_0

Requirement machine for <τ_0_0, τ_0_1>
Rewrite system: {
}
}
Property map: {
}
Conformance paths: {
}

and

1.	Apple Swift version 5.9-dev (LLVM 3f23b4ceaf01213, Swift 0763e4b98c74b5b)
2.	Compiling with the current language version
3.	While evaluating request ASTLoweringRequest(Lowering AST to SIL for module smallProject)
4.	While emitting property descriptor for 'x2' (at /Users/user/smallProject/main.swift:10:7)
