Extending Matrix Calculus
This guide explains how to add new operations to the matrix-vector calculus module. It is intended for developers and LLM-assisted development.
Architecture Overview
The matrix calculus module processes an expression in four stages:
User Expression (SymPy)
│
▼
obtain_TExpr() ← Convert SymPy nodes to tensor nodes
│
▼
ToComGraph() ← Build computation DAG
│
▼
TensorExpr.diff() ← Forward-mode differentiation on DAG
│
▼
TMul2Mul() ← Resolve tensor indices to matrix operations
│
▼
Symbolic Derivative (SymPy)
Key files:
Solverz/sym_algebra/matrix_calculus.py: tensor nodes, DAG, differentiation, TMul2Mul
Solverz/sym_algebra/functions.py: symbolic function definitions with fdiff()
Solverz/sym_algebra/test/test_matrix_calculus.py: tests
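To make the "differentiate on a DAG" stage concrete, here is a minimal scalar sketch of forward-mode differentiation over a computation graph. It is not the Solverz implementation: the graph layout, the op names, and the function `forward_mode` are all hypothetical, and it works on plain floats rather than indexed tensor nodes. It only illustrates the idea of visiting nodes in dependency order while storing one derivative per node.

```python
import math
from graphlib import TopologicalSorter  # standard library, Python 3.9+

# Hypothetical toy DAG for y = sin(x) * x: node -> (operation, input nodes).
graph = {
    "x": ("input", []),
    "s": ("sin", ["x"]),
    "y": ("mul", ["s", "x"]),
}


def forward_mode(graph, x_value):
    """Return the value and the derivative w.r.t. x of every node."""
    # TopologicalSorter takes node -> predecessors, so inputs are visited first.
    order = TopologicalSorter(
        {node: inputs for node, (_, inputs) in graph.items()}
    ).static_order()
    values, derivs = {}, {}
    for node in order:
        op, inputs = graph[node]
        if op == "input":
            values[node], derivs[node] = x_value, 1.0
        elif op == "sin":
            (a,) = inputs
            values[node] = math.sin(values[a])
            derivs[node] = math.cos(values[a]) * derivs[a]  # chain rule
        elif op == "mul":
            a, b = inputs
            values[node] = values[a] * values[b]
            derivs[node] = derivs[a] * values[b] + values[a] * derivs[b]  # product rule
        else:
            raise NotImplementedError(f"no differentiation rule for op {op!r}")
    return values, derivs


values, derivs = forward_mode(graph, 0.5)
print(derivs["y"])  # derivative of sin(x) * x at x = 0.5
```

Adding a new operation to this toy corresponds to adding one more branch with its own derivative rule, which mirrors the per-node-type rules described later in this guide.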
Adding a New Element-wise Unary Function
Element-wise unary functions (like exp, sin, tanh) are automatically handled by
the TUnaryFunc node. You only need to:
Step 1: Define the function in functions.py
class tanh(UniVarFunc):
    r"""Hyperbolic tangent function."""

    def fdiff(self, argindex=1):
        if argindex == 1:
            return 1 - tanh(self.args[0])**2
        raise ArgumentIndexError(self, argindex)

    def _numpycode(self, printer, **kwargs):
        return r'np.tanh(' + printer._print(self.args[0]) + r')'
Key requirements:
Inherit from UniVarFunc
Implement fdiff() returning the derivative without the chain rule (the chain rule is handled automatically by the tensor calculus)
Implement _numpycode() for numerical code generation
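The "without the chain rule" requirement matches the usual SymPy convention for fdiff(). The sketch below uses plain `sympy.Function` instead of `UniVarFunc` so that it runs on its own, and the class name `my_tanh` is hypothetical (chosen to avoid clashing with SymPy's built-in tanh). It shows that fdiff() returns only the outer derivative, while the inner factor is supplied by whatever performs the differentiation, here SymPy's own `diff`.

```python
import sympy as sp
from sympy.core.function import ArgumentIndexError


class my_tanh(sp.Function):
    """Hypothetical stand-in for a UniVarFunc subclass."""

    def fdiff(self, argindex=1):
        if argindex == 1:
            # Derivative w.r.t. the argument only; no inner derivative here.
            return 1 - my_tanh(self.args[0])**2
        raise ArgumentIndexError(self, argindex)


x = sp.Symbol('x')
expr = my_tanh(3 * x)

outer = expr.fdiff()     # outer derivative only
full = sp.diff(expr, x)  # SymPy multiplies by d(3*x)/dx = 3

print(outer)
print(full)
```

If fdiff() multiplied by the inner derivative itself, that factor would be applied twice.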
Step 2: No changes needed in matrix_calculus.py
The TUnaryFunc node automatically recognizes all UniVarFunc subclasses
via isinstance(expr, UniVarFunc) in obtain_TExpr().
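The reason no change is needed can be illustrated with a toy dispatcher. Everything in this sketch is a stand-in: `UniVarFunc` is redefined locally so the block runs alone, `softsign` is a hypothetical new function, and `node_kind` is not a Solverz function. The point is only that a check against the base class also matches subclasses defined later.

```python
class UniVarFunc:
    """Stand-in for the real base class, so the sketch runs on its own."""

    def __init__(self, arg):
        self.arg = arg


class tanh(UniVarFunc):
    pass


class softsign(UniVarFunc):
    """Hypothetical function added later; the dispatcher below is untouched."""


def node_kind(expr):
    """Toy dispatcher: name the tensor node type that would wrap expr."""
    if isinstance(expr, UniVarFunc):
        return "TUnaryFunc"
    raise NotImplementedError(f"no tensor node for {type(expr).__name__}")


kinds = [node_kind(tanh("x")), node_kind(softsign("x"))]
print(kinds)
```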
Step 3: Add a test
# Assumes Para, iVar, Mat_Mul and TensorExpr are already in scope.
from Solverz.sym_algebra.functions import tanh

A = Para('A', dim=2)
x = iVar('x')
expr = tanh(Mat_Mul(A, x))
te = TensorExpr(expr)
# d/dx tanh(A@x) = diag(1 - tanh(A@x)**2) @ A
assert repr(te.diff(x)) == 'diag(1 - tanh(A@x)**2)@A'
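The symbolic test above checks the printed form. A numerical cross-check of the same identity can be written independently of Solverz. The following is a sketch in plain Python with hypothetical helper names (`f`, `analytic_jacobian`, `numeric_jacobian`) and arbitrary sample values; it compares diag(1 - tanh(A@x)**2) @ A against a central finite-difference Jacobian.

```python
import math


def f(A, x):
    """Evaluate tanh(A @ x) element-wise for a list-of-lists A and a list x."""
    return [math.tanh(sum(a * xj for a, xj in zip(row, x))) for row in A]


def analytic_jacobian(A, x):
    """diag(1 - tanh(A@x)**2) @ A: row i of A scaled by 1 - tanh(z_i)**2."""
    y = f(A, x)
    return [[(1.0 - yi**2) * a for a in row] for yi, row in zip(y, A)]


def numeric_jacobian(A, x, h=1e-6):
    """Central finite differences, one column per component of x."""
    m, n = len(A), len(x)
    J = [[0.0] * n for _ in range(m)]
    for j in range(n):
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += h
        x_minus[j] -= h
        f_plus, f_minus = f(A, x_plus), f(A, x_minus)
        for i in range(m):
            J[i][j] = (f_plus[i] - f_minus[i]) / (2.0 * h)
    return J


A = [[1.0, 2.0], [-0.5, 0.3]]  # arbitrary sample matrix
x = [0.2, -0.1]                # arbitrary sample point
Ja = analytic_jacobian(A, x)
Jn = numeric_jacobian(A, x)
max_err = max(abs(a - b) for ra, rb in zip(Ja, Jn) for a, b in zip(ra, rb))
print(max_err)  # should be close to zero
```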
Adding a New Node Type
For operations that are not element-wise unary functions, you need to create a new tensor node class. Here is a step-by-step guide using a hypothetical matrix trace operation as an example.
Step 1: Create the T-node class
class TTrace:
    """Trace node: tr(A) = sum of diagonal elements.

    If A has index (i,j), tr(A) is a scalar (no index).
    Derivative: d tr(A)/dA = I (identity matrix).
    """

    def __init__(self, expr, index: TensorIndex):
        self.index = index
        self.expr = expr
        self.args = expr.args

    def __repr__(self):
        return f"$Tr_{self.index}$"

    def __hash__(self):
        return hash(tuple([*self.args, self.__repr__()]))

    def __eq__(self, other):
        if not isinstance(other, TTrace):
            return False
        return self.args == other.args and self.__repr__() == other.__repr__()
Step 2: Add to obtain_TExpr()
def obtain_TExpr(expr, index):
    # ... existing cases ...
    elif isinstance(expr, Trace):
        return TTrace(expr, index)
    # ...
Step 3: Add DAG building logic in ToComGraph()
elif isinstance(Texpr, TTrace):
    # tr(A): child A has matrix index, output is scalar
    child_index = TensorIndex([0, 1])  # assign fresh 2D index
    queue.append((Texpr.args[0], child_index, Texpr))
Step 4: Add differentiation rule in TensorExpr.diff()
elif isinstance(node, TTrace):
    succ = list(self.ComGraph.successors(node))
    Aprime = derivatives[succ[0]][0]
    Aprime_idx = derivatives[succ[0]][1]
    # d tr(A)/dA_ij = delta_ij, so the result picks the diagonal of A'
    # ... implement the specific derivative rule ...
    derivatives[node] = (result_expr, result_idx)
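The comment "picks the diagonal of A'" follows from one application of the chain rule. As a short worked derivation, assuming A depends on a vector x with components x_k (the symbol x_k is introduced here for illustration only):

```latex
\frac{\partial\,\mathrm{tr}(A)}{\partial x_k}
  = \sum_{i,j} \frac{\partial\,\mathrm{tr}(A)}{\partial A_{ij}}
               \frac{\partial A_{ij}}{\partial x_k}
  = \sum_{i,j} \delta_{ij}\,\frac{\partial A_{ij}}{\partial x_k}
  = \sum_{i} \frac{\partial A_{ii}}{\partial x_k}
```

The Kronecker delta removes every off-diagonal term, so only the diagonal entries of the child derivative A' contribute to the result.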
Step 5: Add TMul2Mul rules if needed
If the differentiation produces new index combinations not already handled
by TMul2Mul, add them:
def TMul2Mul(x, y, index):
    # ... existing cases ...
    # (new_pattern) → corresponding matrix operation
    if new_condition:
        return corresponding_matrix_expression
Step 6: Add tests
Always test:
The derivative result matches the expected symbolic expression
The result is mathematically correct (verify against manual derivation or matrixcalculus.org)
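For the hypothetical trace example, the second check can be done numerically without any Solverz code. The sketch below uses an arbitrary matrix function A(t) of a scalar t, with its derivative written out by hand. All names and values are illustrative assumptions. It confirms that the derivative of tr(A(t)) equals the sum of the diagonal of A'(t).

```python
import math


def A(t):
    """Arbitrary 2x2 matrix function of a scalar t (illustrative only)."""
    return [[t**2, math.sin(t)], [3.0 * t, math.exp(t)]]


def A_prime(t):
    """Element-wise derivative of A(t), written out by hand."""
    return [[2.0 * t, math.cos(t)], [3.0, math.exp(t)]]


def trace(M):
    return sum(M[i][i] for i in range(len(M)))


def numeric_derivative(func, t, h=1e-6):
    """Central finite difference of a scalar function."""
    return (func(t + h) - func(t - h)) / (2.0 * h)


t0 = 0.7
expected = trace(A_prime(t0))  # rule: d tr(A)/dt = sum of diagonal of A'
measured = numeric_derivative(lambda t: trace(A(t)), t0)
error = abs(expected - measured)
print(error)  # should be close to zero
```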
TMul2Mul Index Reference
Complete list of currently handled index patterns:
| s1 (x) | s2 (y) | s3 (result) | Operation | Condition |
|---|---|---|---|---|
| scalar | any | any | | s1 or s2 is scalar |
| | | | | element-wise |
| | | | | outer product |
| | | | | vector → diagonal |
| | | | | i == j[0] |
| | | | | i == j[1] |
| | | | | requires x is Diag |
| | | | | diagonal × diagonal |
| | | | | i == j[0] |
| | | | | matrix-vector |
| | | | | |
| | | | | element-wise matrix |
| | | | | only if y is Diag |
| | | | | matrix multiply |
| | | | | |
| | | | | |
When an unhandled pattern is encountered, TMul2Mul raises NotImplementedError
with the specific index values, which helps identify the new rule needed.
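The error-reporting behaviour can be sketched with a toy dispatcher. This is not the TMul2Mul source: the function name `resolve_indices`, the tuple encoding of indices (with `()` standing for a scalar), and the returned label are assumptions made for illustration. Only the scalar rule from the first table row is modelled. Every other pattern falls through to the error.

```python
def resolve_indices(s1, s2, s3):
    """Toy stand-in for an index dispatcher; () denotes a scalar index."""
    if s1 == () or s2 == ():
        # Mirrors the first table row: a scalar operand is always handled.
        return "scalar rule"
    # Any pattern without a rule is reported with its concrete index values.
    raise NotImplementedError(
        f"unhandled index pattern: s1={s1}, s2={s2}, s3={s3}"
    )


handled = resolve_indices((), (0, 1), (0, 1))

try:
    resolve_indices((0, 1), (2, 3), (0, 3))
    message = None
except NotImplementedError as exc:
    message = str(exc)

print(handled)
print(message)
```

Including the index values in the message means the failing pattern can be read directly from the traceback when a new rule has to be written.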