Advanced Usage
Writing Custom Functions
Sometimes one may need functions that go beyond the Solverz built-in library. This guide describes how to create such custom functions in Solverz so that they can be incorporated into numerical simulations. Function customization in Solverz follows the SymPy approach, so it helps to learn the SymPy basics and to read the SymPy tutorial on custom functions for an overview.
As a motivating example for this document, let’s create a custom function class representing the \(\min\) function. We want to use \(\min\) to determine the smaller of two operands, which can be defined by
\[
\min(x,y)=\begin{cases}x, & x\leq y,\\ y, & x>y.\end{cases}
\]
We also want to extend the function to vector input, that is, to make it capable of finding the element-wise minimum of two vectors. To summarize, we shall implement a \(\min\) function that
- evaluates \(\min(x,y)\) correctly, and
- has proper derivatives with respect to \(x\) and \(y\).
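However, \(\min\) is defined piecewise, so its analytical derivatives are awkward to express directly. Instead, we use the trick of rewriting \(\min(x,y)\) as
\[
\min(x,y)=x\cdot\operatorname{lessthan}(x,y)+y\cdot\bigl(1-\operatorname{lessthan}(x,y)\bigr),
\]
where the \(\operatorname{lessthan}(x,y)\) function mathematically denotes the \(\leq\) operator and returns 1 if \(x\leq y\) else 0. Since \(\operatorname{lessthan}(x,y)\) can only be 1 or 0, the right-hand side reduces to \(x\) when \(x\leq y\) and to \(y\) otherwise, so the transformation holds.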
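As a quick sanity check of this rewriting, the following NumPy sketch evaluates \(x\cdot\operatorname{lessthan}(x,y)+y\cdot(1-\operatorname{lessthan}(x,y))\) element-wise and compares it with np.minimum. The helper names lessthan and min_rewritten are illustrative only and are not part of Solverz.

```python
import numpy as np


def lessthan(x, y):
    # Illustrative helper (not Solverz API): 1 where x <= y, 0 elsewhere, element-wise
    return (np.asarray(x) <= np.asarray(y)).astype(int)


def min_rewritten(x, y):
    # x*lessthan(x, y) + y*(1 - lessthan(x, y)) picks x where x <= y, y otherwise
    lt = lessthan(x, y)
    return x * lt + y * (1 - lt)


x = np.array([0.0, 2.0, -1.0, 5.0])
y = np.array([1.0, 2.0, 3.0, -4.0])
print(min_rewritten(x, y))  # should agree with np.minimum(x, y)
```

Because lessthan() is computed with NumPy comparisons, the same expression works for scalars and for vectors of equal length.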
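If the derivatives of \(\operatorname{lessthan}(x,y)\) with respect to both arguments are zero, then we have
\[
\frac{\partial \min(x,y)}{\partial x}=\operatorname{lessthan}(x,y)
\]
and
\[
\frac{\partial \min(x,y)}{\partial y}=1-\operatorname{lessthan}(x,y).
\]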
Hence, it suffices to have a custom \(\operatorname{lessthan}(x,y)\) function that
- evaluates \(\operatorname{lessthan}(x,y)\) correctly, and
- has zero derivatives with respect to both \(x\) and \(y\).
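The derivative formulas above can be checked numerically at points away from the tie \(x=y\), where \(\min\) is not differentiable. The plain-Python sketch below uses hypothetical helper names (lessthan, min_rewritten, central_diff) and compares central finite differences with \(\operatorname{lessthan}(x,y)\) and \(1-\operatorname{lessthan}(x,y)\).

```python
def lessthan(x, y):
    # Hypothetical scalar indicator: 1 if x <= y else 0
    return 1 if x <= y else 0


def min_rewritten(x, y):
    # min(x, y) written as x*lessthan(x, y) + y*(1 - lessthan(x, y))
    lt = lessthan(x, y)
    return x * lt + y * (1 - lt)


def central_diff(f, x, y, wrt, h=1e-6):
    # Central finite difference of f at (x, y) with respect to 'x' or 'y'
    if wrt == 'x':
        return (f(x + h, y) - f(x - h, y)) / (2 * h)
    return (f(x, y + h) - f(x, y - h)) / (2 * h)


# Points chosen away from x == y
for x, y in [(0.3, 1.2), (2.0, -1.0)]:
    lt = lessthan(x, y)
    print(central_diff(min_rewritten, x, y, 'x'), 'vs', lt)
    print(central_diff(min_rewritten, x, y, 'y'), 'vs', 1 - lt)
```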
Let us start by subclassing MulVarFunc:

from Solverz.sym_algebra.functions import MulVarFunc

class Min(MulVarFunc):
    pass

class LessThan(MulVarFunc):
    pass
MulVarFunc is the base class for multivariate functions in Solverz.
At this point, Min has no behavior defined on it. To make Min evaluate automatically, we define the class method eval(), which takes the arguments of the function and returns the value
\(x\cdot\operatorname{lessthan}(x,y)+y\cdot(1-\operatorname{lessthan}(x,y))\):
class Min(MulVarFunc):
    @classmethod
    def eval(cls, x, y):
        # Rewrite min(x, y) as x*lessthan(x, y) + y*(1 - lessthan(x, y))
        return x * LessThan(x, y) + y * (1 - LessThan(x, y))
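With LessThan fully defined as shown below, constructing Min immediately yields the rewritten expression:
>>> from Solverz import Var
>>> Min(Var('x', 0), Var('y', 0))
x*((x)<=(y)) + y*(1 - ((x)<=(y)))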
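Since Solverz borrows its customization approach from SymPy, the same eval() mechanism can be tried in plain SymPy. The sketch below assumes that sympy.Function is an adequate stand-in for MulVarFunc; the classes are local stand-ins, not Solverz's, and the numeric shortcut in LessThan.eval() is our own addition for illustration.

```python
from sympy import Function, Integer, symbols


class LessThan(Function):
    """Plain-SymPy stand-in for the tutorial's LessThan (not Solverz's class)."""

    @classmethod
    def eval(cls, x, y):
        # Assumed extra, not from Solverz: collapse to 1/0 when both arguments
        # are explicit numbers; returning None keeps LessThan(x, y) unevaluated.
        if x.is_Number and y.is_Number:
            return Integer(1) if x <= y else Integer(0)
        return None


class Min(Function):
    """Plain-SymPy stand-in for the tutorial's Min."""

    @classmethod
    def eval(cls, x, y):
        # eval() runs on construction, so Min(x, y) is replaced right away
        # by x*LessThan(x, y) + y*(1 - LessThan(x, y)).
        return x * LessThan(x, y) + y * (1 - LessThan(x, y))


x, y = symbols('x y')
print(Min(x, y))              # rewritten form, not a Min(...) node
print(Min(2, 3), Min(5, -1))  # numeric arguments collapse to a number
```

Because eval() always returns an expression, no unevaluated Min node ever appears; all later work (differentiation, code printing) happens on the rewritten form.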
Next, we define LessThan() itself; the methods below specify its derivative, its string representation, and the code it is printed to:
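from sympy import Integer
from Solverz.sym_algebra.functions import MulVarFunc

class LessThan(MulVarFunc):
    """
    Represents < operator
    """

    def _eval_derivative(self, s):
        # Zero derivative with respect to any variable
        return Integer(0)

    def _sympystr(self, printer, **kwargs):
        # String form used when printing, e.g. ((x)<=(y))
        return '(({op1})<=({op2}))'.format(op1=printer._print(self.args[0]),
                                          op2=printer._print(self.args[1]))

    def _numpycode(self, printer, **kwargs):
        # Printed as a call to the numerical function SolLessThan
        return r'SolLessThan(' + ', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')'

    def _lambdacode(self, printer, **kwargs):
        # Reuse the NumPy printing for the lambda code printer
        return self._numpycode(printer, **kwargs)

    def _pythoncode(self, printer, **kwargs):
        # Reuse the NumPy printing for the Python code printer
        return self._numpycode(printer, **kwargs)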
Here, the _sympystr() method defines the string representation of LessThan():
>>> LessThan(Var('x'), Var('y'))
((x)<=(y))
The _eval_derivative() method forces the derivatives of LessThan() to be zero, so the derivative of Min() with respect to \(x\) reduces to \(\operatorname{lessthan}(x,y)\), as derived above:
>>> from Solverz import iVar
>>> Min(Var('x', 0), Var('y', 0)).diff(iVar('x'))
((x)<=(y))
Here, iVar is the internal variable type of Solverz, and diff() is the method that computes derivatives.
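The same derivative and printing hooks can be exercised in plain SymPy, assuming sympy.Function handles _eval_derivative() and _sympystr() the way MulVarFunc does. This is a sketch with local stand-in classes, not Solverz's implementation.

```python
from sympy import Function, Integer, symbols


class LessThan(Function):
    """Plain-SymPy stand-in for the tutorial's LessThan (not Solverz's class)."""

    def _eval_derivative(self, s):
        # Treat the 0/1 indicator as piecewise constant: zero derivative
        return Integer(0)

    def _sympystr(self, printer, **kwargs):
        # Same string form as in the tutorial: ((x)<=(y))
        return '(({op1})<=({op2}))'.format(op1=printer._print(self.args[0]),
                                          op2=printer._print(self.args[1]))


class Min(Function):
    @classmethod
    def eval(cls, x, y):
        return x * LessThan(x, y) + y * (1 - LessThan(x, y))


x, y = symbols('x y')
expr = Min(x, y)
dx = expr.diff(x)  # product rule plus zero derivative of LessThan
dy = expr.diff(y)
print(dx)  # printed through _sympystr
print(dy)
```

The results match the formulas \(\partial\min/\partial x=\operatorname{lessthan}(x,y)\) and \(\partial\min/\partial y=1-\operatorname{lessthan}(x,y)\): the product rule differentiates x and y, while every term containing the derivative of LessThan vanishes.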
The _numpycode() method defines what LessThan() is printed as in generated numerical code, and _lambdacode() and _pythoncode() simply reuse it. Here, we define
SolLessThan() as the numerical implementation of LessThan(). Given the arrays [0, 2, -1] and [1, 2, 3]:
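>>> import numpy as np
>>> SolLessThan(np.array([0, 2, -1]), np.array([1, 2, 3]))
array([1, 0, 1])
Note that SolLessThan() uses the strict comparison x < y, so it returns 0 for the second entry, where x equals y. The value of Min() is unaffected at such points, because both branches then equal the common value.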
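To see how the name printed by _numpycode() is resolved to an implementation, the following sketch uses SymPy's lambdify with a namespace dictionary. It assumes sympy.Function as a stand-in for MulVarFunc and replaces the numba-compiled SolLessThan with a plain NumPy version; Solverz's own code generation and the implements_nfunc registration are not reproduced here.

```python
import numpy as np
from sympy import Function, lambdify, symbols


def SolLessThan(x, y):
    # Plain NumPy stand-in for the numba-compiled SolLessThan
    # (no @njit, no implements_nfunc registration)
    x = np.asarray(x).reshape((-1,))
    return (x < y).astype(np.int32)


class LessThan(Function):
    def _numpycode(self, printer, **kwargs):
        # Print LessThan(a, b) as the call SolLessThan(a, b)
        return 'SolLessThan(' + ', '.join(printer._print(arg, **kwargs) for arg in self.args) + ')'


class Min(Function):
    @classmethod
    def eval(cls, x, y):
        return x * LessThan(x, y) + y * (1 - LessThan(x, y))


x, y = symbols('x y')
# The dict entry binds the printed name 'SolLessThan' to the function above
f = lambdify((x, y), Min(x, y), modules=[{'SolLessThan': SolLessThan}, 'numpy'])

a = np.array([0, 2, -1])
b = np.array([1, 2, 3])
print(f(a, b))  # element-wise minimum of a and b
```

The generated code only refers to SolLessThan by name, so whichever object is bound to that name at evaluation time is the one that runs.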
Note
Numerical computation in Solverz relies mainly on widely used libraries such as NumPy and SciPy, so it is recommended to become familiar with NumPy and SciPy first.
The implementation of SolLessThan() should be put in the Solverz.num_api.custom_function module:
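import numpy as np
from numba import njit

# Placed in Solverz.num_api.custom_function; implements_nfunc must not be omitted.
@implements_nfunc('SolLessThan')
@njit(cache=True)
def SolLessThan(x, y):
    # Flatten x to a 1-D array so scalar and vector inputs are handled alike
    x = np.asarray(x).reshape((-1,))
    # Element-wise strict comparison, returned as 0/1 integers
    return (x < y).astype(np.int32)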
The implements_nfunc() decorator cannot be omitted, and the njit() decorator enables Numba-based just-in-time compilation for efficiency.