layers.tt.TensorTrain

layers.tt.TensorTrain(self)

Tensor Train (TT) class

\[ A(i_1, i_2, \cdots, i_f) = \sum_{\beta_1,\beta_2,\cdots,\beta_{f-1}} \ W^{[1]}_{i_1\beta_1} W^{[2]}_{\beta_1 i_2 \beta_2} \cdots W^{[f]}_{\beta_{f-1}i_f} \]
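The formula above can be checked numerically with plain JAX. This is a minimal sketch, not pompon code. It assumes each core \(W^{[p]}\) is stored as an array of shape \((M_{p-1}, N, M_p)\) with boundary ranks \(M_0 = M_f = 1\), as in the set_custom example below. The helper names tt_element and tt_full are hypothetical.

```python
import jax
import jax.numpy as jnp


def tt_element(cores, index):
    """Evaluate one entry A(i_1, ..., i_f) of the tensor train."""
    vec = jnp.ones((1,))  # left boundary: M_0 = 1
    for core, i in zip(cores, index):
        # Sum over beta_{p-1}: (M_{p-1},) @ (M_{p-1}, M_p) -> (M_p,)
        vec = vec @ core[:, i, :]
    return vec[0]  # right boundary: M_f = 1


def tt_full(cores):
    """Contract every bond to get the full tensor (small examples only)."""
    full = cores[0]  # (1, N, M_1)
    for core in cores[1:]:
        # Contract the trailing bond of `full` with the leading bond of `core`
        full = jnp.tensordot(full, core, axes=1)
    return full[0, ..., 0]  # drop the two boundary axes of size 1


keys = jax.random.split(jax.random.PRNGKey(0), 3)
cores = [jax.random.normal(keys[0], (1, 3, 2)),
         jax.random.normal(keys[1], (2, 3, 2)),
         jax.random.normal(keys[2], (2, 3, 1))]
A = tt_full(cores)
print(A.shape)  # (3, 3, 3)
print(jnp.allclose(A[1, 0, 2], tt_element(cores, (1, 0, 2)), atol=1e-5))
```

tt_full materializes all \(N^f\) entries, so it is only useful for checking small cases. tt_element needs just \(f\) vector-matrix products per entry.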

A TensorTrain object can be initialized by one of the following methods:

  1. TensorTrain.decompose(tensor): exact tensor train decomposition

    import jax
    from pompon import TensorTrain
    tensor = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 3))
    tt = TensorTrain.decompose(tensor)

  2. TensorTrain.set_custom(cores): set custom cores

    import jax
    from pompon import TensorTrain
    cores = [jax.random.normal(jax.random.PRNGKey(0), (1, 3, 2)),
             jax.random.normal(jax.random.PRNGKey(1), (2, 3, 2)),
             jax.random.normal(jax.random.PRNGKey(2), (2, 3, 1))]
    tt = TensorTrain.set_custom(cores)

  3. TensorTrain.set_random(shape, rank): set random tensor train

    from pompon import TensorTrain
    tt = TensorTrain.set_random(shape=(3, 3, 3), rank=2)

Attributes

Name Description
ranks List of ranks [\(M_1, M_2, \cdots, M_{f-1}\)]

Methods

Name Description
decompose Initialize with a given tensor by exact tensor train decomposition
forward Evaluate the contraction of the tensor train \(A(i_1, i_2, \cdots, i_f)\)
set_blocks_batch Set left and right blocks for batch
set_center_onedot Set the center one-dot tensor
set_center_twodot Set the center two-dot tensor
set_custom Initialize with a given list of cores
set_ones Initialize with an all-ones tensor train
set_random Initialize with a random tensor train
shift_center Shift the center site to the left or right.
switch_dot When the bond dimension reaches its maximum, the center sites should be switched to a one-dot tensor.
to_canonical Convert the tensor train into canonical form

decompose

layers.tt.TensorTrain.decompose(tensor)

Initialize with a given tensor by exact tensor train decomposition

Parameters

Name Type Description Default
tensor Array tensor with shape (N, N, …, N) required

Returns

Name Type Description
TensorTrain TensorTrain TensorTrain object
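This page does not say which algorithm decompose uses. A standard way to obtain an exact tensor train decomposition is TT-SVD. It unfolds the tensor, takes an SVD, keeps the left factor as a core, and repeats on the remainder. The sketch below uses plain jax.numpy; tt_svd and tt_to_full are hypothetical helpers, not pompon's implementation. It performs no truncation, so the reconstruction is exact up to floating-point error. For a tensor of shape \((N, N, \cdots, N)\), the ranks can grow up to \(\min(N^p, N^{f-p})\).

```python
import jax
import jax.numpy as jnp


def tt_svd(tensor):
    """Exact TT-SVD: cores of shape (M_{p-1}, N_p, M_p) with M_0 = M_f = 1."""
    shape = tensor.shape
    cores = []
    rank = 1
    # Unfolding with the first mode as rows: (M_0 * N_1, N_2 * ... * N_f)
    rest = tensor.reshape(shape[0], -1)
    for p in range(len(shape) - 1):
        u, s, vt = jnp.linalg.svd(rest, full_matrices=False)
        new_rank = s.shape[0]  # no truncation, so the decomposition is exact
        cores.append(u.reshape(rank, shape[p], new_rank))
        # Push the singular values into the remainder, then unfold the next mode
        rest = (s[:, None] * vt).reshape(new_rank * shape[p + 1], -1)
        rank = new_rank
    cores.append(rest.reshape(rank, shape[-1], 1))
    return cores


def tt_to_full(cores):
    """Contract all bonds back into a dense tensor (for checking only)."""
    full = cores[0]
    for core in cores[1:]:
        full = jnp.tensordot(full, core, axes=1)
    return full[0, ..., 0]


tensor = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 3))
cores = tt_svd(tensor)
print([c.shape for c in cores])  # [(1, 3, 3), (3, 3, 3), (3, 3, 1)]
print(jnp.max(jnp.abs(tt_to_full(cores) - tensor)))  # float32 round-off only
```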

forward

layers.tt.TensorTrain.forward(basis)

Evaluate the contraction of the tensor train \(A(i_1, i_2, \cdots, i_f)\) with the input tensor \(\Phi(i_1, i_2, \cdots, i_f)\)

Parameters

Name Type Description Default
basis list[Array] | list[Tensor] Input tensor \(D\) @ \(\phi^{[p]}_{i_p}\) with shape \(f\times(D, N)\) where \(D\) is the batch size. required

Returns

Name Type Description
Array Array Output tensor \(D\) @ \(\sum_{i_1,\cdots,i_f} A(i_1,\cdots,i_f) \phi^{[1]}_{i_1} \cdots \phi^{[f]}_{i_f}\) with shape \((D,1)\)
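The batched contraction returned by forward can be sketched as a left-to-right sweep. The sweep carries a \((D, M_p)\) block and never forms the full tensor \(A\). This sketch assumes the core layout \((M_{p-1}, N, M_p)\) used in the set_custom example. tt_forward is a hypothetical name, not pompon's implementation.

```python
import jax
import jax.numpy as jnp


def tt_forward(cores, basis):
    """Batched sum over i of A(i_1..i_f) phi^{[1]}_{i_1} ... phi^{[f]}_{i_f}.

    cores[p]: (M_{p-1}, N, M_p) with M_0 = M_f = 1
    basis[p]: (D, N), basis values phi^{[p]}_{i_p} for D samples
    Returns an array of shape (D, 1).
    """
    D = basis[0].shape[0]
    block = jnp.ones((D, 1))  # left boundary block, M_0 = 1
    for core, phi in zip(cores, basis):
        # block[d, b] * W[b, i, c] * phi[d, i], summed over b and i
        block = jnp.einsum("db,bic,di->dc", block, core, phi)
    return block  # (D, M_f) = (D, 1)


keys = jax.random.split(jax.random.PRNGKey(0), 6)
cores = [jax.random.normal(keys[0], (1, 3, 2)),
         jax.random.normal(keys[1], (2, 3, 2)),
         jax.random.normal(keys[2], (2, 3, 1))]
basis = [jax.random.normal(k, (5, 3)) for k in keys[3:]]
print(tt_forward(cores, basis).shape)  # (5, 1)
```

Each step costs \(O(D N M^2)\), so the total cost grows linearly in \(f\) instead of with \(N^f\).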

set_blocks_batch

layers.tt.TensorTrain.set_blocks_batch(basis)

Set left and right blocks for batch

Parameters

Name Type Description Default
basis list[Array] List of input tensors \(D\) @ \(\phi^{[p]}_{i_p}\) with shape (D, N), where D is the batch size required
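This page defines the blocks only through the left-block recursion given under shift_center. Assuming the right blocks follow the mirror-image recursion, the blocks around a one-dot center can be sketched as below. make_blocks and center_value are hypothetical helpers. The indexing is also an assumption of this sketch: left[k] contracts the first k sites, and right[k] contracts the last k sites.

```python
import jax
import jax.numpy as jnp


def make_blocks(cores, basis, center):
    """Hypothetical helper: batched left/right blocks around a one-dot center.

    left[k]  has shape (D, bond) and contracts the first k sites.
    right[k] has shape (D, bond) and contracts the last k sites.
    """
    D = basis[0].shape[0]
    left = [jnp.ones((D, 1))]  # nothing contracted yet (M_0 = 1)
    for p in range(center):
        # L^{[p]}_{b_p} = sum_{b_{p-1}, i_p} W_{b_{p-1} i_p b_p} phi_{i_p} L^{[p-1]}_{b_{p-1}}
        left.append(jnp.einsum("db,bic,di->dc", left[-1], cores[p], basis[p]))
    right = [jnp.ones((D, 1))]  # nothing contracted yet (M_f = 1)
    for p in range(len(cores) - 1, center, -1):
        # Assumed mirror image: R_{b_{p-1}} = sum_{b_p, i_p} W_{b_{p-1} i_p b_p} phi_{i_p} R_{b_p}
        right.append(jnp.einsum("bic,di,dc->db", cores[p], basis[p], right[-1]))
    return left, right


def center_value(left, right, core, phi):
    """Output for each sample, given cached blocks and the center core."""
    return jnp.einsum("da,aic,di,dc->d", left[-1], core, phi, right[-1])


keys = jax.random.split(jax.random.PRNGKey(1), 8)
shapes = [(1, 3, 2), (2, 3, 2), (2, 3, 2), (2, 3, 1)]
cores = [jax.random.normal(k, s) for k, s in zip(keys[:4], shapes)]
basis = [jax.random.normal(k, (4, 3)) for k in keys[4:]]
left, right = make_blocks(cores, basis, center=1)
y = center_value(left, right, cores[1], basis[1])  # shape (4,)
```

With the blocks cached, the output for the current center is a single small contraction. This is presumably what makes updating one core at a time cheap.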

set_center_onedot

layers.tt.TensorTrain.set_center_onedot()

Set the center one-dot tensor

set_center_twodot

layers.tt.TensorTrain.set_center_twodot(to_right=True)

Set the center two-dot tensor
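This reference does not spell out the one-dot and two-dot tensors. The sketch below assumes the usual DMRG-style convention. Under that convention, the one-dot center is the single core \(W^{[p]}\), and the two-dot center is the contraction of \(W^{[p]}\) and \(W^{[p+1]}\) over their shared bond \(\beta_p\). two_dot is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def two_dot(core_p, core_q):
    """Hypothetical helper: merge W^{[p]} (M_{p-1}, N, M_p) and
    W^{[p+1]} (M_p, N, M_{p+1}) over beta_p into (M_{p-1}, N, N, M_{p+1})."""
    return jnp.einsum("aib,bjc->aijc", core_p, core_q)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w1 = jax.random.normal(k1, (2, 3, 4))
w2 = jax.random.normal(k2, (4, 3, 2))
print(two_dot(w1, w2).shape)  # (2, 3, 3, 2)
```

In two-site DMRG, optimizing the merged tensor lets the bond \(\beta_p\) change size when it is split again, which a one-dot update cannot do.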

set_custom

layers.tt.TensorTrain.set_custom(cores)

Initialize with a given list of cores

Parameters

Name Type Description Default
cores list[Core | Array] list of cores with shape (M, N, M) like [\(W^{[1]}, W^{[2]}, \cdots, W^{[f]}\)] required

Returns

Name Type Description
TensorTrain TensorTrain TensorTrain object

set_ones

layers.tt.TensorTrain.set_ones(shape, rank=None)

Initialize with an all-ones tensor train

set_random

layers.tt.TensorTrain.set_random(shape, rank=None, key=None)

Initialize with a random tensor train

Parameters

Name Type Description Default
shape tuple[int, …] shape of the tensor like \((N, N, ..., N)\) required
rank int maximum TT rank of the tensor train. Defaults to None. None
key Array random key. Defaults to None. None

Returns

Name Type Description
TensorTrain TensorTrain TensorTrain object
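This page does not document how set_random chooses the bond dimensions. The sketch below shows one plausible reading, in which rank is a maximum bond dimension. Each bond \(M_p\) is capped at rank and also at \(\min(N_1 \cdots N_p, N_{p+1} \cdots N_f)\), the largest rank an exact tensor train of that shape can need. random_tt is a hypothetical helper, and unlike set_random it requires an integer rank (no None default).

```python
import math

import jax


def random_tt(shape, rank, key):
    """Hypothetical sketch: random cores with every bond capped at `rank`.

    Bond p (between sites p-1 and p, 0-indexed) is also capped at
    min(prod(shape[:p]), prod(shape[p:])), the most an exact TT can need.
    """
    f = len(shape)
    ranks = [1] + [min(rank, math.prod(shape[:p]), math.prod(shape[p:]))
                   for p in range(1, f)] + [1]
    keys = jax.random.split(key, f)
    return [jax.random.normal(keys[p], (ranks[p], shape[p], ranks[p + 1]))
            for p in range(f)]


cores = random_tt((3, 3, 3), rank=2, key=jax.random.PRNGKey(0))
print([c.shape for c in cores])  # [(1, 3, 2), (2, 3, 2), (2, 3, 1)]
```

For shape=(3, 3, 3) and rank=2, this gives the same core shapes as the set_custom example.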

shift_center

layers.tt.TensorTrain.shift_center(to_right, basis, is_onedot_center=False)

Shift the center site to the left or right.

When to_right is True, self.center is shifted to self.center + 1 and the left blocks are updated as follows:

\[ \mathcal{L}^{[p]}_{\beta_{p}} = \sum_{\beta_{p-1}} \sum_{i_{p}} W^{[p]}_{\beta_{p-1} i_{p} \beta_{p}} \phi_{i_{p}}^{[p]} \mathcal{L}^{[p-1]}_{\beta_{p-1}} \]

The last term of the right blocks is then popped.

Parameters

Name Type Description Default
to_right bool If True, the center site is shifted to the right. Otherwise, the center site is shifted to the left. required
basis list[Array] f-length list of tensors \(D\) @ \(\phi^{[p]}_{i_p}\) with shape (D, N), where D is the batch size required
is_onedot_center bool If True, the center site is the one-dot tensor. False
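For to_right=True with a one-dot center, the update described above appends one new left block, computed with the left-block formula, and pops one right block. The sketch below shows that bookkeeping in plain JAX. shift_center_right is hypothetical. It assumes that left[-1] contracts all sites before the center and right[-1] contracts all sites after it. The right-block formula in the usage code is an assumed mirror of the left one.

```python
import jax
import jax.numpy as jnp


def shift_center_right(left, right, cores, basis, center):
    """Hypothetical sketch of shift_center(to_right=True) for a one-dot center.

    Assumes left[-1] (D, M) contracts all sites before `center` and
    right[-1] (D, M) contracts all sites after it.
    """
    # L^{[p]}_{b_p} = sum_{b_{p-1}} sum_{i_p} W_{b_{p-1} i_p b_p} phi_{i_p} L^{[p-1]}_{b_{p-1}}
    new_left = jnp.einsum("db,bic,di->dc", left[-1], cores[center], basis[center])
    # The block that covered the new center site is dropped ("popped")
    return left + [new_left], right[:-1], center + 1


keys = jax.random.split(jax.random.PRNGKey(2), 6)
cores = [jax.random.normal(keys[0], (1, 3, 2)),
         jax.random.normal(keys[1], (2, 3, 2)),
         jax.random.normal(keys[2], (2, 3, 1))]
basis = [jax.random.normal(k, (4, 3)) for k in keys[3:]]
ones = jnp.ones((4, 1))
# Right blocks for center 0 (0-indexed sites): [nothing, site 2, sites 1-2]
r2 = jnp.einsum("bic,di,dc->db", cores[2], basis[2], ones)
r1 = jnp.einsum("bic,di,dc->db", cores[1], basis[1], r2)
left, right, center = shift_center_right([ones], [ones, r2, r1], cores, basis, 0)
# center is now 1; left[-1] covers site 0 and right[-1] covers site 2
```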

switch_dot

layers.tt.TensorTrain.switch_dot(to_onedot, to_right, basis)

When the bond dimension reaches its maximum, the center sites should be switched to a one-dot tensor.

Parameters

Name Type Description Default
to_onedot bool If True, the center site is switched to the one-dot tensor. Otherwise, the center site is switched to the two-dot tensor. required
basis list[Array] f-length list of tensors \(D\) @ \(\phi^{[p]}_{i_p}\) with shape (D, N), where D is the batch size required
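This page only states when the switch happens. In DMRG-style algorithms, a two-dot center is split back into two cores with a truncated SVD, and that is also the step where the bond dimension is capped. The sketch below keeps at most max_rank singular values and, depending on to_right, moves them into the right or the left core. split_two_dot is a hypothetical plain jax.numpy helper, not pompon's implementation.

```python
import jax
import jax.numpy as jnp


def split_two_dot(theta, max_rank, to_right=True):
    """Hypothetical sketch: split a two-dot tensor of shape (M_l, N_1, N_2, M_r)
    into two cores joined by a bond of size at most `max_rank`."""
    m_l, n1, n2, m_r = theta.shape
    u, s, vt = jnp.linalg.svd(theta.reshape(m_l * n1, n2 * m_r), full_matrices=False)
    k = min(max_rank, s.shape[0])  # keep the k largest singular values
    u, s, vt = u[:, :k], s[:k], vt[:k, :]
    if to_right:
        # Moving right: left core stays orthonormal, weights go to the right core
        left, right = u, s[:, None] * vt
    else:
        # Moving left: right core stays orthonormal, weights go to the left core
        left, right = u * s[None, :], vt
    return left.reshape(m_l, n1, k), right.reshape(k, n2, m_r)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
w1 = jax.random.normal(k1, (2, 3, 2))
w2 = jax.random.normal(k2, (2, 3, 2))
theta = jnp.einsum("aib,bjc->aijc", w1, w2)  # true bond dimension is 2
a, b = split_two_dot(theta, max_rank=2)
print(a.shape, b.shape)  # (2, 3, 2) (2, 3, 2)
```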

to_canonical

layers.tt.TensorTrain.to_canonical(gauge='CR', ord='fro')

Convert the tensor train into canonical form

Parameters

Name Type Description Default
gauge str Gauge of the canonical form: “LC” for left-canonical form, “CR” for right-canonical form. 'CR'
ord str Order of the norm. Defaults to “fro”, the Frobenius norm. 'fro'
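The exact gauge conventions are not given here. Suppose “CR” means that every core except the first is right-orthonormal, that is \(\sum_{i_p,\beta_p} W^{[p]}_{\beta_{p-1} i_p \beta_p} W^{[p]}_{\beta'_{p-1} i_p \beta_p} = \delta_{\beta_{p-1}\beta'_{p-1}}\). Under that assumption, one way to reach the form is a right-to-left QR sweep, sketched below. right_canonicalize is a hypothetical helper. In this form, the Frobenius norm of the whole tensor equals that of the first core, which may be where the ord argument comes in.

```python
import jax
import jax.numpy as jnp


def right_canonicalize(cores):
    """Hypothetical sketch: right-to-left QR sweep so that every core except
    the first is right-orthonormal. The represented tensor is unchanged."""
    cores = list(cores)
    for p in range(len(cores) - 1, 0, -1):
        m_l, n, m_r = cores[p].shape
        # W (m_l, n*m_r) = R^T Q^T, from the QR of W^T (n*m_r, m_l) = Q R
        q, r = jnp.linalg.qr(cores[p].reshape(m_l, n * m_r).T)
        cores[p] = q.T.reshape(q.shape[1], n, m_r)  # orthonormal rows
        # Absorb R^T into the left neighbour so the product is unchanged
        cores[p - 1] = jnp.einsum("aib,bk->aik", cores[p - 1], r.T)
    return cores


keys = jax.random.split(jax.random.PRNGKey(3), 3)
cores = [jax.random.normal(keys[0], (1, 3, 2)),
         jax.random.normal(keys[1], (2, 3, 2)),
         jax.random.normal(keys[2], (2, 3, 1))]
can = right_canonicalize(cores)
full = jnp.einsum("aib,bjc,ckd->ijk", *can)
# With the other cores right-orthonormal, ||A||_F equals ||W^{[1]}||_F
print(jnp.sqrt(jnp.sum(full ** 2)), jnp.sqrt(jnp.sum(can[0] ** 2)))
```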