Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
### Changed
- Move magic methods (`__radd__`, `__sub__`, `__rsub__`, `__rmul__`, `__richcmp__`, `__neg__`, and `__rtruediv__`) to `ExprLike` base class (#1204)
- Speed up `Expr.__add__` and `Expr.__iadd__` via the C-level API
- Replace Python math with C-level math functions and refactor unary expressions.
### Removed

## 6.2.1 - 2026.05.16
Expand Down
79 changes: 57 additions & 22 deletions src/pyscipopt/expr.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
# which should, in princple, modify the expr. However, since we do not implement __isub__, __sub__
# gets called (I guess) and so a copy is returned.
# Modifying the expression directly would be a bug, given that the expression might be re-used by the user. </pre>
import math
from typing import TYPE_CHECKING, Literal, Union

import numpy as np
Expand All @@ -54,6 +53,12 @@ from cpython.number cimport PyNumber_Check
from cpython.object cimport Py_LE, Py_EQ, Py_GE, Py_TYPE
from cpython.ref cimport PyObject
from cpython.tuple cimport PyTuple_GET_ITEM
from libc.math cimport cos as c_cos
from libc.math cimport exp as c_exp
from libc.math cimport fabs as c_fabs
from libc.math cimport log as c_log
from libc.math cimport sqrt as c_sqrt
from libc.math cimport sin as c_sin

cimport numpy as cnp
from pyscipopt.scip cimport Variable, Solution
Expand Down Expand Up @@ -278,23 +283,23 @@ cdef class ExprLike:
def __neg__(self, /) -> Union[Expr, GenExpr]:
return self * -1.0

def __abs__(self) -> GenExpr:
return UnaryExpr(Operator.fabs, buildGenExprObj(self))
def __abs__(self, /) -> AbsExpr:
return AbsExpr(Operator.fabs, buildGenExprObj(self))

def exp(self) -> GenExpr:
return UnaryExpr(Operator.exp, buildGenExprObj(self))
def exp(self, /) -> ExpExpr:
return ExpExpr(Operator.exp, buildGenExprObj(self))

def log(self) -> GenExpr:
return UnaryExpr(Operator.log, buildGenExprObj(self))
def log(self, /) -> LogExpr:
return LogExpr(Operator.log, buildGenExprObj(self))

def sqrt(self) -> GenExpr:
return UnaryExpr(Operator.sqrt, buildGenExprObj(self))
def sqrt(self, /) -> SqrtExpr:
return SqrtExpr(Operator.sqrt, buildGenExprObj(self))

def sin(self) -> GenExpr:
return UnaryExpr(Operator.sin, buildGenExprObj(self))
def sin(self, /) -> SinExpr:
return SinExpr(Operator.sin, buildGenExprObj(self))

def cos(self) -> GenExpr:
return UnaryExpr(Operator.cos, buildGenExprObj(self))
def cos(self, /) -> CosExpr:
return CosExpr(Operator.cos, buildGenExprObj(self))


##@details Polynomial expressions of variables with operator overloading. \n
Expand Down Expand Up @@ -799,24 +804,54 @@ cdef class PowExpr(GenExpr):
return (<GenExpr>self.children[0])._evaluate(sol) ** self.expo


# Exp, Log, Sqrt, Sin, Cos Expressions
cdef class UnaryExpr(GenExpr):

def __init__(self, op, expr):
self.children = []
self.children.append(expr)
self._op = op

def __abs__(self) -> UnaryExpr:
if self._op == "abs":
return <UnaryExpr>self.copy()
return UnaryExpr(Operator.fabs, self)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExprLike.__abs__ can replace UnaryExpr(Operator.fabs, self)


def __repr__(self):
def __repr__(self) -> str:
return self._op + "(" + self.children[0].__repr__() + ")"


cdef class AbsExpr(UnaryExpr):

def __abs__(self) -> AbsExpr:
return <AbsExpr>self.copy()

cpdef double _evaluate(self, Solution sol) except *:
return c_fabs((<GenExpr>self.children[0])._evaluate(sol))

Comment thread
Zeroto521 marked this conversation as resolved.

cdef class ExpExpr(UnaryExpr):

cpdef double _evaluate(self, Solution sol) except *:
return c_exp((<GenExpr>self.children[0])._evaluate(sol))


cdef class LogExpr(UnaryExpr):

cpdef double _evaluate(self, Solution sol) except *:
return c_log((<GenExpr>self.children[0])._evaluate(sol))


cdef class SqrtExpr(UnaryExpr):

cpdef double _evaluate(self, Solution sol) except *:
return c_sqrt((<GenExpr>self.children[0])._evaluate(sol))


cdef class SinExpr(UnaryExpr):

cpdef double _evaluate(self, Solution sol) except *:
return c_sin((<GenExpr>self.children[0])._evaluate(sol))


cdef class CosExpr(UnaryExpr):

cpdef double _evaluate(self, Solution sol) except *:
cdef double res = (<GenExpr>self.children[0])._evaluate(sol)
return math.fabs(res) if self._op == "abs" else getattr(math, self._op)(res)
return c_cos((<GenExpr>self.children[0])._evaluate(sol))


# class for constant expressions
Expand Down
20 changes: 13 additions & 7 deletions src/pyscipopt/scip.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,12 @@ class ExprLike:
def __rmul__(self, other: object, /) -> Incomplete: ...
def __rtruediv__(self, other: object, /) -> GenExpr: ...
def __neg__(self, /) -> Union[Expr, GenExpr]: ...
def __abs__(self) -> GenExpr: ...
def exp(self) -> GenExpr: ...
def log(self) -> GenExpr: ...
def sqrt(self) -> GenExpr: ...
def sin(self) -> GenExpr: ...
def cos(self) -> GenExpr: ...
def __abs__(self, /) -> AbsExpr: ...
def exp(self, /) -> ExpExpr: ...
def log(self, /) -> LogExpr: ...
def sqrt(self, /) -> SqrtExpr: ...
def sin(self, /) -> SinExpr: ...
def cos(self, /) -> CosExpr: ...

@disjoint_base
class Expr(ExprLike):
Expand Down Expand Up @@ -2262,7 +2262,13 @@ class Term:

class UnaryExpr(GenExpr):
def __init__(self, *args: Incomplete, **kwargs: Incomplete) -> None: ...
def __abs__(self) -> GenExpr: ...

class AbsExpr(UnaryExpr): ...
class ExpExpr(UnaryExpr): ...
class LogExpr(UnaryExpr): ...
class SqrtExpr(UnaryExpr): ...
class SinExpr(UnaryExpr): ...
class CosExpr(UnaryExpr): ...

@disjoint_base
class VarExpr(GenExpr):
Expand Down
10 changes: 9 additions & 1 deletion tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,16 @@ def test_getVal_with_GenExpr():
assert m.getVal(y / x) == 2
# test "**(prod(1.0,**(sum(0.0,prod(1.0,x)),-1)),2)"
assert m.getVal((1 / x) ** 2) == 1
# test "sin(sum(0.0,prod(1.0,x)))"

# test C-level math functions
assert m.getVal(abs(x)) == 1
assert m.getVal(abs(-x)) == 1
assert m.getVal(abs(abs(-x))) == 1
assert round(m.getVal(exp(x)), 6) == round(math.exp(1), 6)
assert round(m.getVal(log(x)), 6) == round(math.log(1), 6)
assert round(m.getVal(sqrt(x)), 6) == round(math.sqrt(1), 6)
assert round(m.getVal(sin(x)), 6) == round(math.sin(1), 6)
assert round(m.getVal(cos(x)), 6) == round(math.cos(1), 6)

with pytest.raises(TypeError):
m.getVal(1)
Expand Down
Loading