Extending Matrix Calculus

This guide explains how to add new operations to the matrix-vector calculus module. It is intended for developers, including those working with LLM assistance.

Architecture Overview

The matrix calculus module consists of four main components:

User Expression (SymPy)
    │
    ▼
obtain_TExpr()          ← Convert SymPy nodes to tensor nodes
    │
    ▼
ToComGraph()            ← Build computation DAG
    │
    ▼
TensorExpr.diff()       ← Forward-mode differentiation on DAG
    │
    ▼
TMul2Mul()              ← Resolve tensor indices to matrix operations
    │
    ▼
Symbolic Derivative (SymPy)
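The idea behind the forward-mode stage can be sketched numerically, outside Solverz. This minimal sketch evaluates tanh(A@x) node by node and carries each node's Jacobian with respect to x alongside its value. The node names and the dict used as the DAG are illustrative assumptions. The real module works symbolically on tensor nodes, not on NumPy arrays.

```python
import numpy as np

def forward_jacobian(A, x):
    """Propagate (value, Jacobian w.r.t. x) through the DAG x -> A@x -> tanh(A@x)."""
    n = x.size
    # Leaf node x: its Jacobian with respect to itself is the identity.
    nodes = {"x": (x, np.eye(n))}
    # Node A@x: linear, so its Jacobian is A times the child's Jacobian.
    val, jac = nodes["x"]
    nodes["Ax"] = (A @ val, A @ jac)
    # Node tanh(.): element-wise, so the chain rule scales each row
    # by the local derivative (the part fdiff() supplies symbolically).
    val, jac = nodes["Ax"]
    local = 1 - np.tanh(val) ** 2
    nodes["tanh"] = (np.tanh(val), local[:, None] * jac)
    return nodes["tanh"]

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
x = rng.standard_normal(3)
value, J = forward_jacobian(A, x)
# The propagated Jacobian matches the closed form diag(1 - tanh(A@x)**2) @ A.
assert np.allclose(J, np.diag(1 - np.tanh(A @ x) ** 2) @ A)
```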

Key files:

  • Solverz/sym_algebra/matrix_calculus.py — tensor nodes, DAG, differentiation, TMul2Mul

  • Solverz/sym_algebra/functions.py — symbolic function definitions with fdiff()

  • Solverz/sym_algebra/test/test_matrix_calculus.py — tests

Adding a New Element-wise Unary Function

Element-wise unary functions (such as exp, sin, and tanh) are handled automatically by the TUnaryFunc node. You only need to define the function and add a test:

Step 1: Define the function in functions.py

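# In Solverz/sym_algebra/functions.py, where UniVarFunc is defined.
from sympy.core.function import ArgumentIndexError


class tanh(UniVarFunc):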
    r"""Hyperbolic tangent function."""

    def fdiff(self, argindex=1):
        if argindex == 1:
            return 1 - tanh(self.args[0])**2
        raise ArgumentIndexError(self, argindex)

    def _numpycode(self, printer, **kwargs):
        return r'np.tanh(' + printer._print(self.args[0]) + r')'

Key requirements:

  • Inherit from UniVarFunc

  • Implement fdiff() to return the derivative with respect to the function's own argument, without applying the chain rule (the tensor calculus applies the chain rule automatically)

  • Implement _numpycode() for numerical code generation
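You can check the split between fdiff() and the chain rule with plain SymPy. The sketch uses sympy.tanh rather than the Solverz class, on the assumption that both follow the same convention. fdiff() supplies only the outer derivative with respect to the argument. The caller then multiplies by the derivative of the inner expression.

```python
import sympy as sp

u, x = sp.symbols("u x")

# What fdiff() supplies: the derivative of tanh with respect to its own argument u.
outer = 1 - sp.tanh(u) ** 2
assert sp.simplify(sp.diff(sp.tanh(u), u) - outer) == 0

# The chain rule is applied separately: d/dx tanh(g(x)) = outer(g(x)) * g'(x).
g = x ** 2
full = sp.diff(sp.tanh(g), x)
chained = outer.subs(u, g) * sp.diff(g, x)
assert sp.simplify(full - chained) == 0
```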

Step 2: No changes needed in matrix_calculus.py

obtain_TExpr() checks isinstance(expr, UniVarFunc), so every UniVarFunc subclass is wrapped in a TUnaryFunc node automatically.
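Why no registration step is needed can be illustrated with hypothetical stand-in classes. UniVarFunc, TUnaryFunc and obtain_TExpr_toy below are simplified mock-ups, not the Solverz definitions. A single isinstance() check against the base class matches every subclass, including ones written later.

```python
class UniVarFunc:                 # stand-in for the Solverz base class
    def __init__(self, arg):
        self.args = (arg,)

class tanh(UniVarFunc): ...       # a new subclass; nothing else is registered

class TUnaryFunc:                 # stand-in for the tensor node
    def __init__(self, expr):
        self.expr = expr

def obtain_TExpr_toy(expr):
    # One isinstance() check covers every present and future subclass.
    if isinstance(expr, UniVarFunc):
        return TUnaryFunc(expr)
    raise NotImplementedError(type(expr).__name__)

node = obtain_TExpr_toy(tanh("y"))
assert isinstance(node, TUnaryFunc)
```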

Step 3: Add a test

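# Import locations below are assumed; match them to test_matrix_calculus.py.
from Solverz.sym_algebra.symbols import iVar, Para
from Solverz.sym_algebra.functions import Mat_Mul, tanh
from Solverz.sym_algebra.matrix_calculus import TensorExpr


def test_tanh():
    A = Para('A', dim=2)
    x = iVar('x')
    expr = tanh(Mat_Mul(A, x))
    te = TensorExpr(expr)
    # d/dx tanh(A@x) = diag(1 - tanh(A@x)**2) @ A
    assert repr(te.diff(x)) == 'diag(1 - tanh(A@x)**2)@A'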
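You can also check the expected result in the test comment numerically, independent of Solverz. This sketch compares central finite differences against diag(1 - tanh(A@x)**2) @ A for a random, deliberately non-square A. The step size and tolerance are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(42)
A = rng.standard_normal((4, 3))   # non-square on purpose
x = rng.standard_normal(3)

def f(v):
    return np.tanh(A @ v)

# Central differences: one Jacobian column per unit vector e in R^3.
h = 1e-6
J_fd = np.column_stack([(f(x + h * e) - f(x - h * e)) / (2 * h) for e in np.eye(3)])

# Closed form from the test comment.
J_sym = np.diag(1 - np.tanh(A @ x) ** 2) @ A
assert np.allclose(J_fd, J_sym, atol=1e-8)
```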

Adding a New Node Type

For operations that are not element-wise unary functions, you need to create a new tensor node class. The steps below use a hypothetical matrix trace operation as an example. They assume a SymPy-level Trace function has already been defined in functions.py, just as tanh was above.

Step 1: Create the T-node class

# In Solverz/sym_algebra/matrix_calculus.py, alongside the existing T-node classes.
class TTrace:
    """Trace node: tr(A) = sum of diagonal elements.
    
    If A has index (i,j), tr(A) is a scalar (no index).
    Derivative: d tr(A)/dA = I (identity matrix).
    """
    def __init__(self, expr, index: TensorIndex):
        self.index = index
        self.expr = expr
        self.args = expr.args

    def __repr__(self):
        return f"$Tr_{self.index}$"

    def __hash__(self):
        return hash(tuple([*self.args, self.__repr__()]))

    def __eq__(self, other):
        if not isinstance(other, TTrace):
            return False
        return self.args == other.args and self.__repr__() == other.__repr__()
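__hash__ and __eq__ matter because T-nodes act as graph vertices. Two nodes built from the same arguments and index compare equal and collapse into one vertex, so a repeated subexpression is handled once. ToyTrace below is a stripped-down stand-in, not the Solverz class. It uses plain strings for the arguments and the index, which is an assumption made for illustration.

```python
class ToyTrace:
    """Stand-in that mirrors TTrace's __hash__/__eq__ contract."""
    def __init__(self, args, index):
        self.args = tuple(args)
        self.index = index

    def __repr__(self):
        return f"$Tr_{self.index}$"

    def __hash__(self):
        return hash((*self.args, repr(self)))

    def __eq__(self, other):
        return (isinstance(other, ToyTrace)
                and self.args == other.args and repr(self) == repr(other))

a = ToyTrace(["A"], "ij")
b = ToyTrace(["A"], "ij")   # built separately, same content
c = ToyTrace(["B"], "ij")

graph = {a: []}             # dict keyed by node, like a DAG's vertex set
graph.setdefault(b, [])     # b equals a, so no new vertex is created
graph.setdefault(c, [])
assert len(graph) == 2
```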

Step 2: Add to obtain_TExpr()

def obtain_TExpr(expr, index):
    # ... existing cases ...
    elif isinstance(expr, Trace):
        return TTrace(expr, index)
    # ...

Step 3: Add DAG building logic in ToComGraph()

elif isinstance(Texpr, TTrace):
    # tr(A): child A has matrix index, output is scalar
    child_index = TensorIndex([0, 1])  # assign fresh 2D index
    queue.append((Texpr.args[0], child_index, Texpr))
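The queue pattern above builds the DAG breadth-first: each entry pairs a subexpression with its parent. The parent gets an edge to the child, and a child is expanded only the first time it is seen. This is a self-contained sketch that assumes expressions are nested tuples ("op", child, ...) with string leaves. The real ToComGraph() also carries a TensorIndex in each queue entry, which is omitted here.

```python
from collections import deque

def build_dag(expr):
    """Breadth-first DAG construction; each queue entry is (subexpression, parent)."""
    graph = {}                        # node -> list of children (successors)
    queue = deque([(expr, None)])
    while queue:
        node, parent = queue.popleft()
        if parent is not None:
            graph[parent].append(node)
        if node in graph:             # shared subexpression: already expanded
            continue
        graph[node] = []
        if isinstance(node, tuple):   # operator node: ("op", child, ...)
            for child in node[1:]:
                queue.append((child, node))
    return graph

# x is shared by tanh(x) and exp(x) but appears as a single vertex.
g = build_dag(("mul", ("tanh", "x"), ("exp", "x")))
assert len(g) == 4
```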

Step 4: Add differentiation rule in TensorExpr.diff()

elif isinstance(node, TTrace):
    succ = list(self.ComGraph.successors(node))
    Aprime = derivatives[succ[0]][0]
    Aprime_idx = derivatives[succ[0]][1]
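    # d tr(A)/dA_ij = delta_ij, so by the chain rule the result is the sum of
    # the diagonal of A' over its two matrix indices: sum_i A'_ii
    # ... implement the specific derivative rule, producing result_expr and result_idx ...
    derivatives[node] = (result_expr, result_idx)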
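A quick numerical check of the trace rule, assuming A' is stored with the two matrix indices first and the differentiation index last. The sketch takes A(x) = x xᵀ, whose trace is x·x with gradient 2x. It then confirms that summing the diagonal of A' over its matrix indices recovers that gradient.

```python
import numpy as np

x = np.array([1.0, -2.0, 0.5])
n = x.size
I = np.eye(n)

# A(x) = x x^T, so A'[i, j, k] = dA_ij/dx_k = delta_ik * x_j + x_i * delta_jk
Aprime = np.einsum("ik,j->ijk", I, x) + np.einsum("i,jk->ijk", x, I)

# Trace rule: d tr(A)/dx_k = sum_i A'[i, i, k]
grad = np.einsum("iik->k", Aprime)

# tr(x x^T) = x . x, whose gradient is 2x.
assert np.allclose(grad, 2 * x)
```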

Step 5: Add TMul2Mul rules if needed

If differentiation produces index combinations that TMul2Mul does not yet handle (see the TMul2Mul Index Reference below), add a rule for each:

def TMul2Mul(x, y, index):
    # ... existing cases ...
    # (new_pattern) → corresponding matrix operation
    if new_condition:
        return corresponding_matrix_expression
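A toy version of the dispatch shows the shape of each rule: match the pattern, return the matrix expression, and otherwise fall through to NotImplementedError. toy_tmul2mul is hypothetical. It encodes indices as strings ("i", "ij", "" for a scalar), which is an illustrative assumption rather than the TensorIndex representation, and it covers only a few rows of the reference table.

```python
import numpy as np

def toy_tmul2mul(x, y, ix, iy, iout):
    """Map an index pattern (ix, iy -> iout) to a matrix operation.

    Hypothetical encoding: indices are strings such as "i" or "ij";
    "" marks a scalar. Only a few rules are shown.
    """
    if ix == "" or iy == "":
        return x * y                               # scalar scaling
    if len(ix) == 2 and len(iy) == 2 and ix[1] == iy[0] and iout == ix[0] + iy[1]:
        return x @ y                               # (ij, jk) -> ik: matrix multiply
    if len(ix) == 2 and len(iy) == 1 and iy == ix[1] and iout == ix[0]:
        return x @ y                               # (ij, j) -> i: matrix-vector
    if len(ix) == 2 and len(iy) == 1 and iy == ix[0] and iout == ix[1]:
        return x.T @ y                             # (ij, i) -> j: transpose(x) @ y
    # Report the exact pattern so the missing rule is easy to identify.
    raise NotImplementedError(f"unhandled pattern: ({ix!r}, {iy!r}) -> {iout!r}")
```

For example, toy_tmul2mul(X, v, "ij", "i", "j") returns X.T @ v, while an unlisted pattern such as ("ij", "kl") -> "il" raises NotImplementedError.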

Step 6: Add tests

Always test:

  1. The derivative result matches the expected symbolic expression

  2. The result is mathematically correct (verify against a manual derivation or matrixcalculus.org)
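One way to verify a manual derivation is to compare it with SymPy's component-wise Jacobian on a small explicit instance. This sketch does so for x * tanh(A@x), with 2×2 symbolic entries. The expression and the size are illustrative choices. The hand-derived form, diag(tanh(A@x)) + diag(x * (1 - tanh(A@x)**2)) @ A, combines the product rule with the tanh rule from earlier in this guide.

```python
import sympy as sp

n = 2
x = sp.Matrix(sp.symbols(f"x0:{n}"))
A = sp.Matrix(n, n, sp.symbols(f"a0:{n * n}"))
Ax = A * x
f = sp.Matrix([x[i] * sp.tanh(Ax[i]) for i in range(n)])

# Reference Jacobian, computed entry by entry.
J_ref = f.jacobian(x)

# Hand-derived matrix form: diag(tanh(Ax)) + diag(x * (1 - tanh(Ax)**2)) @ A
t = Ax.applyfunc(sp.tanh)
J_hand = sp.diag(*t) + sp.diag(*[x[i] * (1 - t[i] ** 2) for i in range(n)]) * A

assert (J_ref - J_hand).applyfunc(sp.simplify) == sp.zeros(n, n)
```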

TMul2Mul Index Reference

Complete list of currently handled index patterns:

s1 (x)   s2 (y)     s3 (result)   Operation           Condition
------   --------   -----------   -----------------   -------------------
scalar   any        any           x * y               s1 or s2 is scalar
i        i          i             x * y               element-wise
i        j (j!=i)   ij            x @ transpose(y)    outer product
i        ii         ii            Diag(x * y)         vector → diagonal
i        ij         ij            Diag(x) @ y         i == j[0]
i        ji         ji            y @ Diag(x)         i == j[1]
ii       i          i             Diag(x).diag * y    requires x is Diag
ii       ii         ii            Diag(x * y)         diagonal × diagonal
ii       ij         ij            Diag(x) @ y         i == j[0]
ij       j          i             x @ y               matrix-vector
ij       i          j             transpose(x) @ y
ij       ij         ij            x * y               element-wise matrix
ij       jj         ij            x @ Diag(y)         only if y is Diag
ij       jk         ik            x @ y               matrix multiply
ij       kj         ik            x @ transpose(y)
ji       jk         ik            transpose(x) @ y

When an unhandled pattern is encountered, TMul2Mul raises NotImplementedError with the specific index values, which helps identify the new rule needed.
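Each non-diagonal row of the table can be read as an einsum contraction: repeated indices that are absent from s3 are summed, and the others are kept. This NumPy sketch checks the listed operations against the corresponding einsum on random data. The rows whose s1 or s2 is "ii" are omitted because they assume the operand is stored as a Diag.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3
x, y = rng.standard_normal(n), rng.standard_normal(n)
X, Y = rng.standard_normal((n, n)), rng.standard_normal((n, n))
D = np.diag(rng.standard_normal(n))      # diagonal matrix for the "jj" row

# (einsum of the index pattern, operation listed in the table)
checks = [
    (np.einsum("i,i->i", x, y),    x * y),
    (np.einsum("i,j->ij", x, y),   np.outer(x, y)),   # x @ transpose(y)
    (np.einsum("i,ij->ij", x, Y),  np.diag(x) @ Y),
    (np.einsum("i,ji->ji", x, Y),  Y @ np.diag(x)),
    (np.einsum("ij,j->i", X, y),   X @ y),
    (np.einsum("ij,i->j", X, y),   X.T @ y),
    (np.einsum("ij,ij->ij", X, Y), X * Y),
    (np.einsum("ij,jj->ij", X, D), X @ D),
    (np.einsum("ij,jk->ik", X, Y), X @ Y),
    (np.einsum("ij,kj->ik", X, Y), X @ Y.T),
    (np.einsum("ji,jk->ik", X, Y), X.T @ Y),
]
assert all(np.allclose(a, b) for a, b in checks)
```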