-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revert "Add generic fallback to all scalar functions" #86
base: master
Are you sure you want to change the base?
Conversation
This reverts commit 05dbeac.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #86 +/- ##
==========================================
- Coverage 92.66% 92.19% -0.47%
==========================================
Files 1 1
Lines 150 141 -9
==========================================
- Hits 139 130 -9
Misses 11 11 ☔ View full report in Codecov by Sentry. |
This has merge conflicts and doesn't make sense. There are checks in the functions for the domains. |
The point of NaNMath for me (and others I guess) is that The |
Only after #85 was merged. When the PR was opened it had no conflicts. |
But before they also errored with |
That's besides the point - of course, if something is not supported by e.g. But falling back to some function in all cases completely breaks the safety promises of NaNMath. The main point is to not throw ArgumentErrors or DomainErrors. |
I can see the philosophical argument, but I can't see a practical application other than real-time controls that wouldn't just want to have the fallback. What about a middle ground: a Preference "strict=false" that can be turned to true that removes the fallbacks? |
I agree with @ChrisRackauckas, the fallbacks are good. NaNs are an instance of floats, so it makes sense to only define NaNs for those and then pass thru for everything else. NaNMath provides an interface if a package does want to add special handling for their type, but the default is to propagate the DomainError. It just feels like what I'd expect for any interface on an abstract number. Maybe another interface is to just define a function log(x)
if applicable(zero, x) & applicable(<, x, zero(x)) && x < zero(x)
return nan(x)
else
return Base.log(x)
end
end
nan(::Float64) = NaN
# etc Then a custom type only needs to ensure that |
If I understand your point, I guess that same argument could be applied for NaNMath as a whole? To throw in my two cents as the original NaNMath author, I agree with @devmotion's point. Providing fallbacks that could throw domain errors breaks what I thought was the API contract. (If you'd like to change the contract, that's potentially reasonable but should be more explicit.) |
Or maybe an interface like valid_domain(::typeof(log), _) = true
valid_domain(::typeof(log), x::AbstractFloat) = x < zero(x) |
I want to use NaNMath in my MTK-based ODE models to ensure that function evaluation never crashes but only returns NaN if there's a numerical issue. This safety aspect is the main point of using NaNMath IMO - otherwise I could directly use the Base methods and hope for the best. The main API guarantee of Thus IMO NaNMath functions must only be extended explicitly and in such a way that this API guarantee is satisfied. This means that
It's just not feasible to do 2. in NaNMath because the maintenance burden on NaNMath would be too high; it seems much easier to deal with this (there are only a handful of functions anyway!) in the respective packages and thereby distribute the maintenance burden (maintainers of these packages should also be the ones most familiar with their number types). I also don't expect that there are too many packages that would have to do 2. since the number of relevant custom number types is not too large in my experience. |
I'd personally vote to change the contract. Because of the speed issues noted in #63 I ended up implementing an in-house version of NaNMath from scratch (actually started before I noticed this library existed): As mentioned, DomainErrors will still get hit when necessary, should such a user type even generate them. I don't think it's a pragmatic goal to pursue one error over another because both tracebacks give you a clear error with information about what method needs to be added. |
I'll vote against a change for the reasons I outlined above. If a different behaviour is desired in SymbolicRegressions, then probably NaNMath is not the right tool for its use case and a different package is needed. I want the safety guarantees of NaNMath and will happily accept if the number of supported types is restricted to a safe set. I'll also accept if it's slower than unsafe alternatives. In my use case, otherwise errors might be thrown deep inside of the ODE solver, even within subsequent time steps, and there's no good way for me to handle these. I mean there's surely a reason for why MTK by default creates ODE functions with NaNMath. |
I wouldn't say it's necessarily safer in either setting. It's just hitting a different error. I do not think a MethodError deep in a call stack is any safer or more helpful than a DomainError. If some end user is passing a non-Real type then I think a |
I think the instances that have come up are ForwardDiff, BigFloat, Tracker, and ReverseDiff. For the autodiff ones, the issue is that not supporting the autodiff stuff entirely makes using NaNMath more generally pretty difficult. And on the other hand, one of the big reasons for NaNMath is for a solver to "try a step", and then pull back / change dt / d alpha, whatever the step size is in an ODE, Newton method, gradient decent, etc. In that context, the gradient calculation is unlikely to be the part calculating the NaNs, it's the "try at a new point" phase that is most likely the thing that has now stepped out of bounds. For that context, we then have that NaNMath not having full support on these AD types as causing an error where the algorithm would have otherwise worked, which is why we've had 3 or 4 PRs / issues over the years to add these fallbacks to the library. People keep running into it, seeing that there is a trivial fix, and wondering why we won't just add it. But on the other side, I do agree that if you do that, you no longer can guarantee that NaNMath will never error. Once one fallback can error, there is no 100% guarantee. And that's not great either. But if you go with a Preference system, then at least it's possible for both choices to be made. |
ForwardDiff, DiffRules, ReverseDiff, and Tracker have all supported NaNMath for many years, and also defined derivatives for NaNMath functions. So I don't think AD is necessarily a problem with NaNMath. The only problem I ran into and which was recently fixed both in ForwardDiff and Symbolics is that the derivative of
I'm not generally against it but given the experience with ForwardDiff and the original/current API contract of NaNMath the default setting should be the NaN-safe one without fallbacks. I'm a bit worried that it will be difficult for users to discover the preference setting - e.g., my impression is that NaN-safe mode in ForwardDiff is not known very widely even though it is explained in the official docs. Another consequence of a preference setting might be that other packages would be pressured into adding support for NaNMath anyway (to support the setting without fallbacks), which might make the unsafe setting less attractive and useful in general. |
Please me correct me if I'm wrong, but if I remember correctly, the use of preferences is meant to be used by end users, a package developer can offer preferences and according to those preferences, but not change preferences in other packages. So, for package developers, changing the preferences in ForwardDiff (or NaNMath, if they are enabled) is a problem. In particular, i would like opt-in support, but more in the likes of #signalling support for NaNMath, that is,
#if there is a function f(x::T1)::T2 with a domain for x,
#then NaNMath.f(x::T1)::T2 will return f(x) if x is in the domain of f, and T2(NaN) otherwise.
can_nan(::Type{Float64}) = true
can_nan(::typeof(Base.log), ::Type{Float64}) = true #if we only support a limited set of functions
#we need to convert to a valid input type
nan_promote(x::Float64) = x
nan_promote(x::Int) = float(x)
#TODO: reasonable defaults for Number, error on anything else
can_nan(x::Tuple) = mapreduce(can_nan, &, x)
can_nan(x) = can_nan(typeof(nan_promote(x)))
can_nan(x::Type) = false
#TODO: maybe define another function for more specific type selection? Also, is this granularity ok?
can_nan(f, x) = can_nan(x)
log(x::T) where T = nan_log(x, Val{can_nan(Base.log, T)}())
#maybe this function should be inside a module?
function nan_log(x::T, ::Val{true}) where T
x < 0 ? T(NaN) : Base.log(x)
end
function nan_log(x::T, ::Val{false})
throw(ArgumentError("""
NaNMath.log explicitly does not support $T. If this type
can support calculating NaN, define NaNMath.can_nan(::typeof(log), x::Type{$T}
"""
end In this way, NaNMath acts as just another composable layer, instead of another set of rules that need to be defined for packages and intersection of packages. |
Okay, how about the following idea? I like this approach a lot. Basically, you could have a submodule NaNMath.log # no fallback
NaNMath.Generic.log # has fallback These are separate functions. The module NaNMath
#= existing code =#
module Generic
import ..log as _log
log(x) = applicable(_log, x) ? _log(x) : Base.log(x)
#= other functions =#
end
end Thus, should a user define a custom NaNMath extension, it will also be accessible to the generic method.
Wdyt? Could also be inverted and have a |
Please let's just merge this PR and move this discussion to an issue. Breaking the core API guarantees and putting this change in a non-breaking release is basically the worst scenario, regardless of which API design you'd prefer. |
Well, then it seems you'd want the API without generic fallback 😄 I should also emphasize that even without generic fallback by no means every number type
You can just define such a method in your package(s) if you want this generality. No need to put it in NaNMath and make the API messy/confusing for users and downstream packages (eg which version should MTK use?). |
Let's simplify this discussion a little bit: can you give one example of using the current master/release version where you cause the |
It broke the API guarantees for any special number type that had not opt in to the NaNMath API. For instance, julia> using NaNMath, Unitful
julia> NaNMath.sqrt(u"-1s^2")
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).
Stacktrace:
[1] throw_complex_domainerror(f::Symbol, x::Float64)
@ Base.Math ./math.jl:33
[2] sqrt
@ ./math.jl:608 [inlined]
[3] sqrt
@ ./math.jl:1531 [inlined]
[4] sqrt
@ ~/.julia/packages/Unitful/nwwOk/src/quantities.jl:205 [inlined]
[5] sqrt(x::Quantity{Int64, 𝐓², Unitful.FreeUnits{(s²,), 𝐓², nothing}})
@ NaNMath ~/.julia/packages/NaNMath/h9tir/src/NaNMath.jl:19
[6] top-level scope
@ REPL[19]:1 julia> using NaNMath, DynamicQuantities
julia> NaNMath.sqrt(u"-1s^2")
ERROR: DomainError with -1.0:
sqrt was called with a negative real argument but will only return a complex result if called with a complex argument. Try sqrt(Complex(x)).
Stacktrace:
[1] throw_complex_domainerror(f::Symbol, x::Float64)
@ Base.Math ./math.jl:33
[2] sqrt
@ ./math.jl:608 [inlined]
[3] sqrt
@ ~/.julia/packages/DynamicQuantities/OrITh/src/math.jl:146 [inlined]
[4] sqrt(x::Quantity{Float64, Dimensions{FixedRational{Int32, 25200}}})
@ NaNMath ~/.julia/packages/NaNMath/rftLo/src/NaNMath.jl:19
[5] top-level scope
@ REPL[4]:1 -1 is a bit extreme but in a numerical solver it's easy to hit slightly negative values. Without the NaNMath API guarantees, why should I use NaNMath at all? But I'm digressing. My main point right now is just: This change completely broke existing API guarantees and hence should not have been released in a non-breaking release. It should be reverted. |
How about: using NaNMath, Unitful
NaNMath.sqrt(x::Real) = sqrt(float(x))
NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::T) where {T<:Number} = x < zero(x) ? T(NaN) : Base.sqrt(float(x)) julia> NaNMath.sqrt(u"-1.0s^2")
NaN s² ? Maybe the thing that NaNMath actually needs is just an interface function NaNMath.sqrt(x::T) where {T<:Number} =
hasnan(T) || error("NaNMath not supported for this type as no NaN for this type is possible")
x < zero(x) ? nangen(x) : Base.sqrt(float(x))
end and so then you need to define |
That's similar to how I would imagine it.
I don't think that's needed. I would use NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::Number) = x < zero(x) ? typeof(Base.sqrt(-float(x)))(NaN) : Base.sqrt(float(x)) should always be fine: julia> using Unitful
julia> NaNMath.sqrt(-1)
NaN
julia> NaNMath.sqrt(-1.0)
NaN
julia> NaNMath.sqrt(big"-1")
NaN
julia> typeof(NaNMath.sqrt(big"-1"))
BigFloat
julia> NaNMath.sqrt(u"-1s^2")
NaN s Edit: I edited the initial version such that unitful numbers are supported as well. Note that the units should be Or maybe slightly simpler: NaNMath.sqrt(x::Complex) = Base.sqrt(float(x))
NaNMath.sqrt(x::Number) = x < zero(x) ? Base.sqrt(typeof(float(x))(NaN)) : Base.sqrt(float(x)) |
I would be fine with this form. It's close to what's in there but handles the non AbstractFloat stuff better. Could you change to this form? If we all agree this is sufficiently safe then I think we get the best of both worlds? |
I propose to revert #71 which added generally unsafe fallback definitions. There's no guarantee in general that they do not error since argument values are not checked and no type constraints are applied.