Custom neural network combined with tensor train
Let's hack Pompon!
We will use the tensor-train implementation in Pompon together with a custom neural network built with Flax.
Before starting, install Flax (e.g., pip install flax). In this example, we use the flax.nnx API rather than the flax.linen API.
Import modules
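The import cell is collapsed here; a minimal set consistent with the snippets below might look like the following. The pompon paths for TensorTrain and sweep are the ones referenced later in this notebook, while _forward_basis2y and get_norm_W are assumed to come from pompon or from cells omitted in this excerpt.

import jax

jax.config.update("jax_enable_x64", True)  # float64 to match the outputs below

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import nnx
from tqdm import tqdm

from pompon.layers.tt import TensorTrain  # tensor-train cores
from pompon.optimizer.sweeper import sweep  # TT sweep optimizer

DTYPE = jnp.float64
# _forward_basis2y (contracts the basis with the TT cores) and get_norm_W
# (extracts the norm and cores from a TensorTrain) are assumed to be provided
# by pompon or defined in cells not shown in this excerpt.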
Prepare training data
We will use \[ y = \boldsymbol{x}A\boldsymbol{x}^\top \] as the true function. Since \(A\) is symmetric, the corresponding force is \(\boldsymbol{F} = -\partial y/\partial \boldsymbol{x} = -2\boldsymbol{x}A\), which is what generate_data returns as f_samples.
Code
def generate_data(num_samples, key, include_f=True):
    mean = jnp.zeros(3)
    cov = jnp.eye(3)
    k = 1.0
    A = k * jnp.linalg.inv(cov)
    x_samples = jax.random.multivariate_normal(
        key=key, mean=mean, cov=cov, shape=(num_samples,), dtype=DTYPE
    )
    y_samples = jnp.einsum("ni,ij,nj->n", x_samples, A, x_samples)
    if include_f:
        f_samples = -2.0 * x_samples @ A
        return x_samples, y_samples[:, jnp.newaxis], f_samples
    else:
        return x_samples, y_samples[:, jnp.newaxis]
num_samples = 256
key = jax.random.PRNGKey(0)
x_samples, y_samples, f_samples = generate_data(num_samples, key)
print("x samples:", x_samples.shape, x_samples.dtype)
print("y samples:", y_samples.shape, y_samples.dtype)
print("f samples:", f_samples.shape, f_samples.dtype)
x samples: (256, 3) float64
y samples: (256, 1) float64
f samples: (256, 3) float64
Define Custom Network
Define Basis function
Define the network as
\[
\hat{y}(\boldsymbol{x}) = \mathrm{norm} \sum_{\rho_1,\dots,\rho_n} \sum_{\beta_1,\dots,\beta_{n-1}}
W^{(1)}_{\rho_1 \beta_1} W^{(2)}_{\beta_1 \rho_2 \beta_2} \cdots W^{(n)}_{\beta_{n-1} \rho_n}\,
\phi_{\rho_1}(x_1)\, \phi_{\rho_2}(x_2) \cdots \phi_{\rho_n}(x_n),
\]
where the \(W^{(i)}\) are tensor-train cores, \(\phi_{\rho_i}(x_i): \mathbb{R}\to\mathbb{R}^d\), and \(d\) is the number of basis functions per mode (chunk_size).
Code
class CustomBasis(nnx.Module):
    """
    Args:
        num_chunks (int): degrees of freedom n
        hidden_size (int): hidden layer size
        chunk_size (int): number of basis functions for each mode
        n_layers (int): number of hidden layers
        rngs (nnx.Rngs): random generator
    """

    def __init__(
        self,
        num_chunks: int,
        hidden_size: int,
        chunk_size: int,
        n_layers: int,
        x_scale: jax.Array,
        rngs: nnx.Rngs,
    ):
        self.num_chunks = num_chunks
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.n_layers = n_layers
        self.x_scale = x_scale
        initializer = jax.nn.initializers.glorot_uniform()
        self.w_ini = nnx.Param(
            initializer(
                rngs.params(),
                (self.num_chunks, self.hidden_size),
                dtype=DTYPE,
            )
        )
        self.b_ini = nnx.Param(
            jnp.zeros((self.num_chunks, self.hidden_size), dtype=DTYPE)
        )
        self.w_mid = []
        self.b_mid = []
        for _ in range(n_layers):
            self.w_mid.append(
                nnx.Param(
                    initializer(
                        rngs.params(),
                        (self.hidden_size, self.hidden_size),
                        dtype=DTYPE,
                    )
                )
            )
            self.b_mid.append(
                nnx.Param(
                    jnp.zeros((self.num_chunks, self.hidden_size), dtype=DTYPE)
                )
            )
        self.w_fin = nnx.Param(
            initializer(
                rngs.params(),
                (self.hidden_size, self.chunk_size),
                dtype=DTYPE,
            )
        )
        self.b_fin = nnx.Param(
            jnp.zeros((self.num_chunks, self.chunk_size), dtype=DTYPE)
        )

    def __call__(self, x: jax.Array, i: int | None = None) -> list[jax.Array]:
        r"""
        Args:
            x (jax.Array): input position with shape (batch_size, DOFs)
            i (int, optional): If you need only the i-th basis $\phi_{\rho_i}$,
                set this integer.

        Returns:
            list[jax.Array]: list of $\phi_{\rho_i}$ with length `num_chunks`
                (or 1 when `i` is given). Each basis has shape
                `(batch_size, chunk_size)`.
        """
        if i is None:
            index: slice = slice(0, x.shape[1])
        else:
            index: slice = slice(i, i + 1)
            if x.ndim == 1:
                x = x[:, jnp.newaxis]  # batch, 1
            elif x.ndim == 2:
                assert x.shape[1] == 1
            else:
                raise ValueError(f"{x.ndim=} is invalid")
        ndim = x.shape[1]
        x = x / self.x_scale[jnp.newaxis, index]
        x = (
            self.w_ini[jnp.newaxis, index, :] * x[:, :, jnp.newaxis]
        )  # batch, num_chunks, hidden_size
        x = x + self.b_ini[jnp.newaxis, index, :]
        x = nnx.swish(x)
        for layer in range(self.n_layers):
            res = x
            x = x @ self.w_mid[layer]  # batch, num_chunks, hidden_size
            x = x + self.b_mid[layer][jnp.newaxis, index, :]
            x = nnx.swish(x) + res
        x = x @ self.w_fin  # batch, num_chunks, chunk_size
        x = x + self.b_fin[jnp.newaxis, index, :]
        x = nnx.swish(x)
        phi_chunked_array = jnp.split(x, ndim, axis=1)
        phi_chunked_list = [item.squeeze(1) for item in phi_chunked_array]
        return phi_chunked_list
x = jnp.arange(3 * 2).reshape(2, 3)
basis = CustomBasis(
    num_chunks=3,
    hidden_size=4,
    chunk_size=2,
    n_layers=3,
    x_scale=x.std(axis=0),
    rngs=nnx.Rngs(0),
)
print(basis(x=x))
print(basis(x=x[:, 0:1], i=0))
[Array([[ 0. , 0. ],
[-0.27816969, -0.27609677]], dtype=float64), Array([[0.16544799, 0.16702912],
[0.47790957, 0.3542889 ]], dtype=float64), Array([[ 0.67407335, 0.36907104],
[ 0.45613932, -0.21072378]], dtype=float64)]
[Array([[ 0. , 0. ],
[-0.27816969, -0.27609677]], dtype=float64)]
Define Custom Model including the Custom Basis
Code
class CustomModel(nnx.Module):
    """
    Args:
        num_chunks (int): degrees of freedom n
        hidden_size (int): hidden layer size
        chunk_size (int): number of basis functions for each mode
        n_layers (int): number of hidden layers
        rngs (nnx.Rngs): random generator
    """

    def __init__(
        self,
        num_chunks: int,
        hidden_size: int,
        chunk_size: int,
        n_layers: int,
        x_scale: jax.Array,
        rngs: nnx.Rngs,
    ):
        self.num_chunks = num_chunks
        self.hidden_size = hidden_size
        self.chunk_size = chunk_size
        self.n_layers = n_layers
        self.basis = CustomBasis(
            num_chunks=self.num_chunks,
            hidden_size=self.hidden_size,
            chunk_size=self.chunk_size,
            n_layers=self.n_layers,
            x_scale=x_scale,
            rngs=rngs,
        )

    def __call__(self, x: jax.Array, W: list[jax.Array], norm: jax.Array):
        """
        Args:
            x (jax.Array): input position with shape (batch_size, DOFs)
            W (list[jax.Array]): list of tensor-train cores.
                Each core has shape (bond_dim1, chunk_size, bond_dim2).
            norm (jax.Array): scalar which scales outputs.
        """
        basis = self.basis(x)
        return _forward_basis2y(basis=basis, W=W, norm=norm)
model = CustomModel(
    num_chunks=3,
    hidden_size=16,
    chunk_size=2,
    n_layers=16,
    x_scale=x_samples.std(axis=0),
    rngs=nnx.Rngs(0),
)
nnx.display(model)
CustomModel(
num_chunks=3,
hidden_size=16,
chunk_size=2,
n_layers=16,
basis=CustomBasis(
num_chunks=3,
hidden_size=16,
chunk_size=2,
n_layers=16,
x_scale=Array(shape=(3,), dtype=float64),
w_ini=Param(
value=Array(shape=(3, 16), dtype=float64)
),
b_ini=Param(
value=Array(shape=(3, 16), dtype=float64)
),
w_mid=[Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
), Param(
value=Array(shape=(16, 16), dtype=float64)
)],
b_mid=[Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
), Param(
value=Array(shape=(3, 16), dtype=float64)
)],
w_fin=Param(
value=Array(shape=(16, 2), dtype=float64)
),
b_fin=Param(
value=Array(shape=(3, 2), dtype=float64)
)
)
)
Extract required arguments from TensorTrain modules
Initialize tensor train.
See also pompon.layers.tt.TensorTrain.
Code
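The collapsed cell constructs a TensorTrain (see pompon.layers.tt.TensorTrain) and a small helper, referred to as get_norm_W in the rest of this notebook, that returns the scalar norm and the list of cores W expected by CustomModel.__call__. A rough usage sketch follows; the TT construction and the helper's internals are omitted here, so consult the pompon documentation for the actual API.

# Assumed to be defined in the collapsed cell:
#   tt         -- a pompon TensorTrain with 3 physical legs of dimension chunk_size
#   get_norm_W -- returns (norm, W): the scalar norm and the list of cores,
#                 each core shaped (bond_dim1, chunk_size, bond_dim2)
norm, W = get_norm_W(tt)
print(norm, W)                        # scalar norm and the three cores
print(model.basis(x_samples[:2]))     # basis sanity check on two samples
print(model(x_samples[:2], W, norm))  # forward pass -> shape (2, 1)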
2.738025133756779 [Array([[[ 0.269559 , 0.48521447],
[-0.66477019, -0.49998546]]], dtype=float64), Array([[[-2.69638259, 0.02057865],
[ 6.39613801, -0.46289717]],
[[-1.33265668, -0.4078536 ],
[ 5.44471885, 0.45414894]]], dtype=float64), Array([[[-0.47058099],
[ 0.88235681]],
[[-0.88235681],
[-0.47058099]]], dtype=float64)]
[Array([[-1.12089883e-02, 4.54244220e-01],
[-3.01759052e-47, 4.94457722e+00]], dtype=float64), Array([[-2.47008425e-01, -2.40598015e-01],
[-2.47244541e-19, -4.38226713e-02]], dtype=float64), Array([[-8.55275875e-18, 5.79409551e-01],
[-8.68260031e-86, -1.70533086e-02]], dtype=float64)]
[[ 0.71465496]
[-0.06264489]]
Train with energy
Train basis with fixed tensor-train
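The cells of this subsection are collapsed. The energy-only training step presumably mirrors the force version defined later in this notebook, with the force term dropped; a sketch (same names as elsewhere in the notebook, not necessarily the exact collapsed code):

@nnx.jit
def train_step(model, optimizer, x, y, W, norm):
    # Energy-only MSE loss; the TT cores W and the scalar norm are held fixed,
    # so only the basis-network parameters receive gradients.
    def _loss_fn(model):
        pred = model(x, W, norm)
        return jnp.mean((pred - y) ** 2)

    loss, grads = nnx.value_and_grad(_loss_fn)(model)
    optimizer.update(grads)
    return loss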
Visualize basis
Code
def show_basis(model, tt):
    x0 = np.linspace(-1, 1, 100)
    x = np.zeros((100, 3))
    x[:, 0] = x0
    basis = model.basis(jnp.array(x))
    for i in range(model.chunk_size):
        plt.plot(x0, np.array(basis[0][:, i]), label=f"basis-{i}")
    plt.plot(x0, x0**2, label="true")
    norm, W = get_norm_W(tt)
    y = model(jnp.array(x), W, norm)
    plt.plot(x0, np.array(y.squeeze(1)), label="prediction")
    plt.legend()
    plt.xlabel("$x_1$")
    plt.ylabel("Energy")
    plt.show()


show_basis(model, tt)
Train basis iteratively
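The collapsed cell repeats the gradient step with the tensor train frozen. A sketch that reuses the hyperparameters of the combined loop further below (the collapsed cell may use different values):

tx = optax.adam(1e-5)
optimizer = nnx.Optimizer(model, tx)
norm, W = get_norm_W(tt)

losses = []
for _ in tqdm(range(10000)):
    losses.append(train_step(model, optimizer, x_samples, y_samples, W, norm))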
Show training trace
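The collapsed plotting cell can be reproduced with something like:

plt.plot(np.asarray(losses))
plt.yscale("log")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.show()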
Train tensor train with fixed basis
See also pompon.optimizer.sweeper.sweep.
Signature:
sweep(
    *,
    tt: pompon.layers.tt.TensorTrain,
    basis: list[jax.Array],
    y: jax.Array,
    nsweeps: int = 2,
    maxdim: int | list[int] | numpy.ndarray = 30,
    cutoff: float | list[float] | numpy.ndarray = 0.01,
    optax_solver: optax._src.base.GradientTransformation | None = None,
    opt_maxiter: int = 1000,
    opt_tol: float | list[float] | numpy.ndarray | None = None,
    opt_lambda: float = 0.0,
    onedot: bool = False,
    use_CG: bool = False,
    use_scipy: bool = False,
    use_jax_scipy: bool = False,
    method: str = 'L-BFGS-B',
    ord: str = 'fro',
    auto_onedot: bool = True,
)
Docstring:
Tensor-train sweep optimization

Args:
    tt (TensorTrain): the tensor-train model.
    basis (list[Array]): the basis functions.
    y (Array): the target values.
    nsweeps (int): the number of sweeps.
    maxdim (int, list[int]): the maximum rank of the TT-sweep.
    cutoff (float, list[float]): the ratio of truncated singular values for the TT-sweep.
        When the one-dot core is optimized, this parameter is not used.
    optax_solver (optax.GradientTransformation): the optimizer for the TT-sweep.
        Defaults to None; if None, the optimizer is not used.
    opt_maxiter (int): the maximum number of iterations for the TT-sweep.
    opt_tol (float, list[float]): the convergence criterion of the gradient for the TT-sweep.
        Defaults to None, i.e., opt_tol = cutoff.
    opt_lambda (float): the L2 regularization parameter for the TT-sweep.
        Only use_CG=True is supported.
    onedot (bool, optional): whether to optimize the one-dot or the two-dot core.
        Defaults to False, i.e., two-dot core optimization.
    use_CG (bool, optional): whether to use the conjugate gradient method for the TT-sweep.
        Defaults to False. CG is suitable for one-dot core optimization.
    use_scipy (bool, optional): whether to use scipy.optimize.minimize for the TT-sweep.
        Defaults to False; when enabled, the L-BFGS-B method is used. GPU is not supported.
    use_jax_scipy (bool, optional): whether to use jax.scipy.optimize.minimize for the TT-sweep.
        Defaults to False. This optimizer only supports the BFGS method, which exhausts GPU memory.
    method (str, optional): the optimization method for scipy.optimize.minimize.
        Defaults to 'L-BFGS-B'. Note that jax.scipy.optimize.minimize only supports 'BFGS'.
    ord (str, optional): the norm for scaling the initial core. Defaults to 'fro' (Frobenius norm).
        'max' (maximum absolute value) and 'fro' (Frobenius norm) are supported.
    auto_onedot (bool, optional): whether to switch to one-dot core optimization automatically
        once the maximum rank is reached. Defaults to True.
        This will cause overfitting in the beginning of the optimization.

File: ~/GitHub/Pompon/pompon/optimizer/sweeper.py
Type: function
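With the basis held fixed, optimizing the tensor train reduces to a single sweep call. The keyword values below mirror the combined loop in the next subsection and are not necessarily those used in the collapsed cell:

basis = model.basis(x_samples)  # basis evaluated once and then frozen
sweep(
    tt=tt,
    basis=basis,
    y=y_samples,
    nsweeps=2,
    maxdim=2,
    opt_maxiter=200,
    use_CG=True,
    auto_onedot=False,
)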
Repeat tensor train optimization & basis optimization
Code
tx = optax.adam(1e-5)
optimizer = nnx.Optimizer(model, tx)
for i in tqdm(range(10000)):
    if i % 500 == 0:
        norm, W = get_norm_W(tt)
        basis = model.basis(x_samples)
        sweep(
            tt=tt,
            basis=basis,
            y=y_samples,
            nsweeps=1,
            maxdim=2,
            opt_maxiter=200,
            # optax_solver=optax.adam(1.e-04),
            use_CG=True,
            onedot=False,
            auto_onedot=False,
        )
    norm, W = get_norm_W(tt)
    loss = train_step(model, optimizer, x_samples, y_samples, W, norm)
    losses.append(loss)
Train with forces
Train basis with fixed tensor-train
Define a function that returns \(-\frac{\partial V}{\partial \boldsymbol{x}}\) by automatic differentiation
Code
from functools import partial


def get_force(x, model, norm, W):
    def func(x, model, norm, W):
        if x.ndim == 1:
            # When no-batch
            x = x[jnp.newaxis, :]
            return -1.0 * model(x, norm=norm, W=W).squeeze(0)
        else:
            return -1.0 * model(x, norm=norm, W=W)

    ener_fn = partial(func, model=model, norm=norm, W=W)
    jacobian = jax.vmap(jax.jacrev(ener_fn))(x)  # (batch, out, in)
    return jacobian.squeeze(-2)  # (batch, in)
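A quick check of get_force against the analytic force \(-2\boldsymbol{x}A\) might look like this (it assumes norm and W have already been extracted with get_norm_W):

norm, W = get_norm_W(tt)
print(get_force(x_samples[:2], model, norm, W))  # model forces on two samples
print(f_samples[:2])                             # analytic reference forces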
Define train_step with the loss function \[ \mathcal{L} = \frac{1}{2|\mathcal{D}|} \sum_{(\boldsymbol{x}, E, \boldsymbol{F}) \in \mathcal{D}} \left[\left(\hat{V}(\boldsymbol{x})-E\right)^2 + \left\|\hat{\boldsymbol{F}}(\boldsymbol{x})-\boldsymbol{F}\right\|^2\right] \]
Code
@nnx.jit
def train_step(model, optimizer, x, y, f, W, norm):
    def _loss_fn(model):
        pred = model(x, W, norm)
        loss = jnp.mean((pred - y) ** 2)
        pred_force = get_force(x, model, norm, W)
        loss += (
            jnp.sum((pred_force.flatten() - f.flatten()) ** 2)
            / pred_force.shape[0]
        )
        return loss

    loss, grads = nnx.value_and_grad(_loss_fn)(model)
    optimizer.update(grads)
    return loss
Check that train_step works.
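For instance (a sketch; the collapsed cell may use different settings):

norm, W = get_norm_W(tt)
optimizer = nnx.Optimizer(model, optax.adam(1e-5))
print(train_step(model, optimizer, x_samples, y_samples, f_samples, W, norm))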
Train basis iteratively
Code
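The collapsed loop is analogous to the energy-only case, now passing the forces as well; a sketch with the hyperparameters of the combined loop further below:

tx = optax.adam(1e-5)
optimizer = nnx.Optimizer(model, tx)
norm, W = get_norm_W(tt)

losses = []
for _ in tqdm(range(10000)):
    losses.append(
        train_step(model, optimizer, x_samples, y_samples, f_samples, W, norm)
    )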
Train tensor train with fixed basis
Prepare basis gradient \(-\frac{\partial \Phi}{\partial x_i}\)
Code
def get_partial_basis(x, i, model):
    assert x.ndim == 2  # batch, dim
    x = x[:, i : i + 1]

    def func(x, model):
        if x.ndim == 1:
            x = x[jnp.newaxis, :]
            return -1.0 * model.basis(x, i)[0].squeeze(0)
        else:
            return -1.0 * model.basis(x, i)[0]

    basis_fn = partial(func, model=model)
    # since in=1 < out=number of basis, we should use forward differentiation
    jacobian = jax.vmap(jax.jacfwd(basis_fn))(x)  # batch, out, in
    return jacobian.squeeze(2)  # batch, out
partial_basis_i = get_partial_basis(x_samples, 0, model)
partial_basis=
\(-\partial_{x_i}\phi_{\rho_i}\) with shape (batch_size, chunk_size)
basis=
\(\left[\phi_{\rho_1}, \phi_{\rho_2}, \cdots, \phi_{\rho_n} \right]\)
concat_y=
\(y_p\) where \(p=k \otimes i\) \[
y_p = y_i^{(k)} =
\begin{cases}
\boldsymbol{F}_i^{(k)} & \mathrm{for} \quad i \leq n \\
E^{(k)} & \mathrm{for} \quad i = n+1
\end{cases}
\]
Code
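The collapsed cell stacks the force components and the energies into a single target vector, in the same block order used by get_concat_basis below. A sketch that reproduces the shape shown in the output:

# Blocks: mode-1 forces, mode-2 forces, ..., mode-n forces, then energies.
concat_y = jnp.vstack(
    [f_samples[:, i : i + 1] for i in range(f_samples.shape[1])] + [y_samples]
)
print(concat_y.shape)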
(1024, 1)
basis_list=
\[
\begin{bmatrix}
[-\partial_{x_1}\phi_{\rho_1} & \phi_{\rho_2} & \cdots & \phi_{\rho_n} ], \\
[\phi_{\rho_1} & -\partial_{x_2}\phi_{\rho_2} & \cdots & \phi_{\rho_n} ], \\
& & \vdots & & \\
[\phi_{\rho_1} & \phi_{\rho_2} & \cdots & -\partial_{x_n}\phi_{\rho_n} ], \\
[\phi_{\rho_1} & \phi_{\rho_2} & \cdots & \phi_{\rho_n} ]
\end{bmatrix}
\]
concat_basis=
\(\left[\varphi_{\rho_1}^p, \varphi_{\rho_2}^p, \cdots, \varphi_{\rho_n}^p \right]\) where \[
\varphi_{\rho_j}^p = \varphi_{i, \rho_i}^{(k)} =
\begin{cases}
-\partial_{x_j}\phi_{\rho_j} & \mathrm{for} \quad i=j \\
\phi_{\rho_j} & \mathrm{otherwise}
\end{cases}
\]
Code
def get_concat_basis(x: jax.Array, model: CustomModel) -> list[jax.Array]:
    """
    Args:
        x (jax.Array): positions
        model (CustomModel): model

    Returns:
        list[jax.Array]: derivative-concatenated basis
    """
    partial_basis = [get_partial_basis(x, i, model) for i in range(x.shape[1])]
    basis = model.basis(x)
    basis_list = []
    for i in range(len(basis)):
        basis_row = [
            ϕ.copy() if k != i else dϕ.copy()
            for k, (ϕ, dϕ) in enumerate(zip(basis, partial_basis, strict=True))
        ]
        basis_list.append(basis_row)
    basis_list.append(basis)
    concat_basis = [
        jnp.vstack([φ[i] for φ in basis_list]) for i in range(len(basis))
    ]
    return concat_basis


concat_basis = get_concat_basis(x_samples, model)
print(concat_basis[0].shape)
(1024, 2)
Train tensor train with fixed basis
Only basis and y change compared with the energy-only optimization; here they are concat_basis and concat_y.
Code
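A sketch of the sweep call with the concatenated quantities (the keyword values mirror the combined loop below; the collapsed cell may differ):

concat_basis = get_concat_basis(x_samples, model)
sweep(
    tt=tt,
    basis=concat_basis,
    y=concat_y,
    nsweeps=2,
    maxdim=2,
    opt_maxiter=200,
    use_CG=True,
    onedot=True,
    auto_onedot=False,
)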
Repeat tensor train optimization & basis optimization
Code
tx = optax.adam(1e-5)
optimizer = nnx.Optimizer(model, tx)
for i in tqdm(range(10000)):
    if i % 500 == 0:
        norm, W = get_norm_W(tt)
        concat_basis = get_concat_basis(x_samples, model)
        sweep(
            tt=tt,
            basis=concat_basis,
            y=concat_y,
            nsweeps=1,
            maxdim=2,
            opt_maxiter=200,
            # optax_solver=optax.adam(1.e-04),
            use_CG=True,
            onedot=True,
            auto_onedot=False,
        )
    norm, W = get_norm_W(tt)
    loss = train_step(
        model, optimizer, x_samples, y_samples, f_samples, W, norm
    )
    losses.append(loss)
Access optimized tensor train cores
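The collapsed cell iterates over the optimized cores and prints them. A sketch is below; the attribute name tt.cores and the .data accessor are assumptions rather than the documented pompon API, so check pompon.layers.tt.TensorTrain for the actual accessors.

# Assumed accessors: tt.cores (list of Core objects) and Core.data (the array).
for i, core in enumerate(tt.cores):
    print(i, core)
    print(np.asarray(core.data))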
0 Core(shape=(1, 2, 2), leg_names=('β0', 'i1', 'β1'), dtype=float64)
[[[-0.23738906 -0.24130501]
[-0.93898443 0.06104555]]]
1 Core(shape=(2, 2, 2), leg_names=('β1', 'i2', 'β2'), dtype=float64)
[[[-0.58550331 0.05419654]
[ 0.80821735 0.03214547]]
[[-0.8099358 -0.07141367]
[-0.57994308 -0.05069637]]]
2 Core(shape=(2, 2, 1), leg_names=('β2', 'i3', 'β3'), dtype=float64)
[[[-0.99307731]
[ 0.11746261]]
[[ 0.11746261]
[ 0.99307731]]]