Tensor & Tensor Train module

See also the API docs for pompon.layers.tensor.Tensor and pompon.layers.tt.TensorTrain.

Code
import jax.numpy as jnp

from pompon import DTYPE
from pompon.layers.tensor import Tensor
from pompon.layers.tt import TensorTrain

Data Type

JAX arrays default to float32, but pompon uses float64.

Code
print(f"{DTYPE=}")
DTYPE=<class 'jax.numpy.float64'>

If you want float32 instead, change pompon.__dtype__.DTYPE to jnp.float32 before importing pompon.
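
For context, plain JAX only produces float64 arrays when its 64-bit mode is enabled. The sketch below uses JAX alone and makes no claim about how pompon configures this internally; it only shows the standard flag and its effect on the default dtype.

```python
import jax

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

a = jnp.ones(2)                     # default float dtype is now float64
b = jnp.ones(2, dtype=jnp.float32)  # float32 is still available on request
print(a.dtype, b.dtype)  # float64 float32
```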

Tensor

1d tensor (vector) \(v_i\)

Code
v = Tensor(data=jnp.ones((2), dtype=DTYPE), leg_names=["i"])
print(f"{v=}")
print(f"{v.data=}")
v=Tensor(shape=(2,), leg_names=['i'], dtype=float64)
v.data=Array([1., 1.], dtype=float64)

2d tensor (matrix) \(M_{ji}\)

Code
M = Tensor(data=jnp.ones((3, 2), dtype=DTYPE), leg_names=["j", "i"])
print(f"{M=}")
print(f"{M.data=}")
M=Tensor(shape=(3, 2), leg_names=['j', 'i'], dtype=float64)
M.data=Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float64)

3d tensor \(T_{ijk}\)

Code
T = Tensor(data=jnp.ones((2, 3, 4), dtype=DTYPE), leg_names=["i", "j", "k"])
print(f"{T=}")
print(f"{T.data=}")
T=Tensor(shape=(2, 3, 4), leg_names=['i', 'j', 'k'], dtype=float64)
T.data=Array([[[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]],

       [[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]]], dtype=float64)

Tensor with batch index \(x_{Dij}\) and \(y_{Di}\)

The batch index is the leg named 'D', and it must be the first index.

Code
x = Tensor(data=jnp.ones((10, 2, 3), dtype=DTYPE), leg_names=["D", "i", "j"])
y = Tensor(data=jnp.ones((10, 2), dtype=DTYPE), leg_names=["D", "i"])
print(f"{x=}")
print(f"{y=}")
x=Tensor(shape=(10, 2, 3), leg_names=['D', 'i', 'j'], dtype=float64)
y=Tensor(shape=(10, 2), leg_names=['D', 'i'], dtype=float64)

Tensor contraction

  • \(u_j = \sum_{i}M_{ji}v_i\) (matrix * vector)
Code
u = M @ v
print(f"{u=}")
assert jnp.allclose(u.data, jnp.dot(M.data, v.data))
u=Tensor(shape=(3,), leg_names=('j',), dtype=float64)
  • \(u_k = \sum_{ij} T_{ijk} M_{ji}\)
Code
u = T @ M
print(f"{u=}")
u=Tensor(shape=(4,), leg_names=('k',), dtype=float64)
  • \(u_k = \sum_{ij} T_{ijk} M_{ji} v_{i}\)

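This is not the same as T @ M @ v, which is evaluated left to right and gives \(L_{ki} = \left(\sum_{i'j}T_{i'jk} M_{ji'}\right) v_{i}\). There the index \(i\) of \(v\) is no longer shared with anything, so it remains as a free leg.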

Code
from pompon.layers.tensor import dot

u = dot(T, M, v)
print(f"{u=}")
L = T @ M @ v
print(f"{L=}")
u=Tensor(shape=(4,), leg_names=('k',), dtype=float64)
L=Tensor(shape=(4, 2), leg_names=('k', 'i'), dtype=float64)
  • \(z_{Dj} = \sum_{i}x_{Dij}y_{Di}\) (index \(D\) will remain)
Code
z = x @ y
print(f"{z=}")
z=Tensor(shape=(10, 3), leg_names=('D', 'j'), dtype=float64)
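
The contractions above can be reproduced with jnp.einsum. The sketch below is not pompon's implementation: contract is a hypothetical helper, and it assumes the rule suggested by the outputs above, namely that a leg shared by several operands is summed, while the batch leg 'D' and unshared legs are kept.

```python
from collections import Counter

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def contract(*operands):
    """Contract (array, leg_names) pairs; returns an (array, leg_names) pair.

    Assumed rule: a leg that appears in more than one operand is summed,
    except the batch leg 'D', which is always kept.
    """
    counts = Counter(leg for _, legs in operands for leg in legs)
    # One einsum letter per leg name, in order of first appearance.
    letter = {leg: chr(ord("a") + n) for n, leg in enumerate(counts)}
    out_legs = [leg for leg in counts if counts[leg] == 1 or leg == "D"]
    inputs = ",".join("".join(letter[l] for l in legs) for _, legs in operands)
    output = "".join(letter[l] for l in out_legs)
    data = jnp.einsum(f"{inputs}->{output}", *(arr for arr, _ in operands))
    return data, out_legs


v = (jnp.ones(2), ["i"])
M = (jnp.ones((3, 2)), ["j", "i"])
T = (jnp.ones((2, 3, 4)), ["i", "j", "k"])
x = (jnp.ones((10, 2, 3)), ["D", "i", "j"])
y = (jnp.ones((10, 2)), ["D", "i"])

u, u_legs = contract(T, M, v)            # simultaneous sum over i and j
L, L_legs = contract(contract(T, M), v)  # pairwise, left to right
z, z_legs = contract(x, y)               # 'D' is kept, i is summed

print(u.shape, u_legs)
print(L.shape, L_legs)
print(z.shape, z_legs)
```

The pairwise version first sums i and j away, so the later v_i has nothing left to contract with and becomes an outer-product leg. This matches the shapes (4,) and (4, 2) shown above.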

Tensor Train

A tensor train (also called a matrix product state) is written as \[ A(i_1, i_2, \cdots, i_f) = \sum_{\beta_1,\beta_2,\cdots,\beta_{f-1}} \ W^{[1]}_{i_1\beta_1} W^{[2]}_{\beta_1 i_2 \beta_2} \cdots W^{[f]}_{\beta_{f-1}i_f} \]
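
To make the formula concrete, the following sketch builds random cores with plain JAX and sums over the bond indices to recover the full tensor. The helpers random_cores and tt_to_full are hypothetical names, not part of pompon. The cores are stored with boundary bonds of size 1, following the core shapes pompon prints for this example.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def random_cores(shape, rank, seed=0):
    """Random cores W[k] of shape (r_{k-1}, n_k, r_k), with r_0 = r_f = 1."""
    ranks = [1] + [rank] * (len(shape) - 1) + [1]
    keys = jax.random.split(jax.random.PRNGKey(seed), len(shape))
    return [
        jax.random.normal(key, (ranks[k], n, ranks[k + 1]))
        for k, (key, n) in enumerate(zip(keys, shape))
    ]


def tt_to_full(cores):
    """Sum over every bond index to obtain A(i_1, ..., i_f)."""
    full = cores[0]  # shape (1, n_1, r_1)
    for core in cores[1:]:
        # Contract the last axis of `full` with the first axis of `core`.
        full = jnp.tensordot(full, core, axes=1)
    return full[0, ..., 0]  # drop the two size-1 boundary bonds


cores = random_cores((3, 3, 3, 3), rank=2)
A = tt_to_full(cores)
print(A.shape)  # (3, 3, 3, 3)
```

A single element A(i_1, ..., i_f) is the product of the matrices obtained by fixing the physical index of each core, which is where the name matrix product state comes from.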

Code
tt = TensorTrain.set_random(shape=(3, 3, 3, 3), rank=2)
print(tt)
print(f"{tt=}")
print(f"{tt.ndim=}")
print(f"{tt.ranks=}")
[Core(shape=(1, 3, 2), leg_names=('β0', 'i1', 'β1'), dtype=float64), Core(shape=(2, 3, 2), leg_names=('β1', 'i2', 'β2'), dtype=float64), Core(shape=(2, 3, 2), leg_names=('β2', 'i3', 'β3'), dtype=float64), Core(shape=(2, 3, 1), leg_names=('β3', 'i4', 'β4'), dtype=float64), ]
tt=TensorTrain(shape=(3, 3, 3, 3), ranks=[2, 2, 2])
tt.ndim=4
tt.ranks=[2, 2, 2]
Code
tt.cores
[Core(shape=(1, 3, 2), leg_names=('β0', 'i1', 'β1'), dtype=float64),
 Core(shape=(2, 3, 2), leg_names=('β1', 'i2', 'β2'), dtype=float64),
 Core(shape=(2, 3, 2), leg_names=('β2', 'i3', 'β3'), dtype=float64),
 Core(shape=(2, 3, 1), leg_names=('β3', 'i4', 'β4'), dtype=float64)]

Each core is an instance of Tensor

Code
isinstance(tt[0], Tensor)
True

Two adjacent cores can be contracted into a TwodotCore

Code
W0, W1 = tt[0:2]
B = W0 @ W1
print(f"{B=}")
B=TwodotCore(shape=(1, 3, 3, 2), leg_names=('β0', 'i1', 'i2', 'β2'))

TwodotCore has a method svd(), which splits it back into two Cores

Code
W0_next, W1_next = B.svd(truncation=0.99)
print(f"{W0_next=}")
print(f"{W1_next=}")
# Write the updated cores back into the tensor train
tt.cores[0] = W0_next
tt.cores[1] = W1_next
W0_next=Core(shape=(1, 3, 2), leg_names=('β0', 'i1', 'β1'), dtype=float64)
W1_next=Core(shape=(2, 3, 2), leg_names=('β1', 'i2', 'β2'), dtype=float64)
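
The merge-and-split step can be sketched with plain JAX as follows. This is an illustration under stated assumptions, not pompon's code: merge and split are hypothetical helpers, and keep is a hypothetical parameter meaning the fraction of squared singular-value weight to retain. Whether the truncation argument of svd() uses the same convention is not stated here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def merge(W0, W1):
    """Contract the shared bond: (a, i, b) x (b, j, c) -> (a, i, j, c)."""
    return jnp.einsum("aib,bjc->aijc", W0, W1)


def split(B, keep=0.99):
    """Split a two-site tensor into two cores with a truncated SVD."""
    a, i, j, c = B.shape
    U, s, Vh = jnp.linalg.svd(B.reshape(a * i, j * c), full_matrices=False)
    # Smallest rank whose cumulative squared weight reaches `keep`.
    weight = jnp.cumsum(s**2) / jnp.sum(s**2)
    rank = min(int(jnp.sum(weight < keep)) + 1, s.shape[0])
    W0 = U[:, :rank].reshape(a, i, rank)
    # Singular values are absorbed into the right core.
    W1 = (s[:rank, None] * Vh[:rank]).reshape(rank, j, c)
    return W0, W1


k0, k1 = jax.random.split(jax.random.PRNGKey(0))
W0 = jax.random.normal(k0, (1, 3, 2))
W1 = jax.random.normal(k1, (2, 3, 2))

B = merge(W0, W1)
W0_next, W1_next = split(B, keep=0.99)
error = jnp.linalg.norm(merge(W0_next, W1_next) - B) / jnp.linalg.norm(B)
print(W0_next.shape, W1_next.shape, float(error))
```

In this sketch the singular values go into the right core, so the left core has orthonormal columns when reshaped to a matrix. Absorbing them into the left core instead is an equally valid choice.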

Forward pass of TensorTrain with a basis batch

Code
import jax

basis = [
    Tensor(
        data=jax.random.normal(jax.random.PRNGKey(k), (10, 3), dtype=DTYPE),
        leg_names=["D", f"i{k}"],
    ).as_basis_batch()
    for k in range(1, 5)  # one basis batch per site i1..i4
]
basis
[BasisBatch(shape=(10, 3), leg_names=['D', 'i1'], dtype=float64),
 BasisBatch(shape=(10, 3), leg_names=['D', 'i2'], dtype=float64),
 BasisBatch(shape=(10, 3), leg_names=['D', 'i3'], dtype=float64),
 BasisBatch(shape=(10, 3), leg_names=['D', 'i4'], dtype=float64)]
Code
y = tt.forward(basis)
y
Array([[-0.15648396],
       [-3.02572681],
       [-0.65708564],
       [ 2.10014884],
       [ 0.00536333],
       [-0.99552358],
       [-0.14399027],
       [-0.24901398],
       [-0.76588784],
       [ 0.09033759]], dtype=float64)
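
One plausible reading of the forward pass, consistent with the output shape (10, 1) above, is that each core is contracted with its basis batch and the results are chained along the bond indices for every sample D. The sketch below implements that reading with plain JAX and checks it against the full tensor. The name tt_forward is hypothetical, and the actual tt.forward may be implemented differently.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def tt_forward(cores, basis):
    """y_D = sum_{i1..if} A(i1..if) * phi1[D, i1] * ... * phif[D, if]."""
    left = jnp.ones((basis[0].shape[0], 1))  # (D, beta_0), beta_0 has size 1
    for W, phi in zip(cores, basis):
        # Sum over the incoming bond a and the site index i:
        # (D, a) x (a, i, b) x (D, i) -> (D, b)
        left = jnp.einsum("da,aib,di->db", left, W, phi)
    return left  # shape (D, 1)


keys = jax.random.split(jax.random.PRNGKey(0), 8)
bonds = [1, 2, 2, 2, 1]
cores = [jax.random.normal(keys[k], (bonds[k], 3, bonds[k + 1])) for k in range(4)]
basis = [jax.random.normal(keys[4 + k], (10, 3)) for k in range(4)]

y = tt_forward(cores, basis)

# Cross-check against the explicitly reconstructed full tensor.
A = jnp.einsum("aib,bjc,ckd,dle->ijkl", *cores)
y_ref = jnp.einsum("ijkl,di,dj,dk,dl->d", A, *basis)
print(y.shape, bool(jnp.allclose(y[:, 0], y_ref)))
```

Sweeping through the cores one at a time never forms the full tensor, so the cost grows linearly with the number of sites rather than exponentially.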