
ForwardDiff of complex erfcx fails #493

@mmikhasenko

Description

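Here is an observation: differentiating a real-valued function that calls `erfcx` on a complex argument works with Zygote but fails with ForwardDiff.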

```julia
using SpecialFunctions
using Zygote
using ForwardDiff

f(x) = real(SpecialFunctions.erfcx((x + 1.0im)^2))

# reference:
Zygote.gradient(f, 1.0)[1] # 0.3169262912276313

# call
ForwardDiff.derivative(f, 1.0) # fails with a MethodError
```

The ForwardDiff call throws the following error:

```text
ERROR: MethodError: no method matching _erfcx(::Complex{ForwardDiff.Dual{ForwardDiff.Tag{var"#13#14", Float64}, Float64, 1}})
The function `_erfcx` exists, but no method is defined for this combination of argument types.
```
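As a cross-check of the Zygote reference value, the derivative can also be worked out by hand with the chain rule on plain `ComplexF64` numbers, using the identity erfcx′(z) = 2z·erfcx(z) − 2/√π. This is a minimal sketch that assumes only SpecialFunctions; the helper name `f_prime` is hypothetical and used for illustration only:

```julia
using SpecialFunctions

# `f_prime` is a hypothetical helper: a hand-written derivative of
# f(x) = real(erfcx((x + im)^2)) using only ComplexF64 arithmetic (no AD).
function f_prime(x::Real)
    z = (x + 1.0im)^2                         # inner function z(x)
    dz_dx = 2 * (x + 1.0im)                   # dz/dx
    dw_dz = 2 * (z * erfcx(z) - 1 / sqrt(π))  # erfcx'(z) = 2z*erfcx(z) - 2/sqrt(pi)
    # x is real and erfcx is holomorphic, so d/dx real(w) = real(dw/dz * dz/dx)
    return real(dw_dz * dz_dx)
end

f_prime(1.0)  # ≈ 0.3169, consistent with the Zygote reference
```

At x = 1 the inner argument is z = 2i, so no dual numbers are involved and the existing `ComplexF64` method of `erfcx` is sufficient.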

Solution

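It can be worked around by manually telling ForwardDiff how to differentiate `erfcx` at a complex argument whose parts are dual numbers, using the identity erfcx′(z) = 2z·erfcx(z) − 2/√π and the chain rule: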

```julia
using SpecialFunctions
using ForwardDiff

# Complex number whose real and imaginary parts are ForwardDiff dual numbers
const ComplexDual{T, V, N} = Complex{ForwardDiff.Dual{T, V, N}}

function SpecialFunctions.erfcx(z::ComplexDual{T, V, N}) where {T, V, N}
    real_part = real(z)
    imag_part = imag(z)

    # Primal (non-dual) value of z
    z_val = Complex(ForwardDiff.value(real_part), ForwardDiff.value(imag_part))

    # Function value at the primal point (uses the existing complex method)
    w_val = erfcx(z_val)

    # Complex derivative: erfcx'(z) = 2z*erfcx(z) - 2/sqrt(pi)
    ∂w = 2 * (z_val * w_val - 1 / sqrt(π))

    # Partials carried by the real and imaginary parts of the input
    dr = ForwardDiff.partials(real_part)
    di = ForwardDiff.partials(imag_part)

    # Chain rule for a holomorphic function, dw = ∂w * (dr + im*di),
    # split into real and imaginary parts
    real_dual = ForwardDiff.Dual{T}(real(w_val), real(∂w) * dr - imag(∂w) * di)
    imag_dual = ForwardDiff.Dual{T}(imag(w_val), imag(∂w) * dr + real(∂w) * di)

    return Complex(real_dual, imag_dual)
end
```
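The two `ForwardDiff.Dual` constructions at the end of the workaround are the chain rule for a holomorphic function, written out in real and imaginary parts. With z = x + iy, where x and y carry the dual partials ∂x (`dr`) and ∂y (`di`), and w = erfcx(z):

```math
\begin{aligned}
\partial w &= \operatorname{erfcx}'(z)\,(\partial x + i\,\partial y), \qquad \operatorname{erfcx}'(z) = 2z\,\operatorname{erfcx}(z) - \frac{2}{\sqrt{\pi}},\\
\partial(\operatorname{Re} w) &= \operatorname{Re}\operatorname{erfcx}'(z)\,\partial x - \operatorname{Im}\operatorname{erfcx}'(z)\,\partial y,\\
\partial(\operatorname{Im} w) &= \operatorname{Im}\operatorname{erfcx}'(z)\,\partial x + \operatorname{Re}\operatorname{erfcx}'(z)\,\partial y.
\end{aligned}
```

The derivative identity follows from differentiating erfcx(z) = exp(z²)·erfc(z), since erfc′(z) = −(2/√π)·exp(−z²).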
