layers.basis.Basis

layers.basis.Basis(
    self
    hidden_size
    basis_size
    activation
    key=None
    w_dist='uniform'
    w_scale=1.0
    b_dist='linspace'
    b_scale=1.0
)

Basis layer class

This class consists of a list of Phi layers, one for each mode.

Parameters

Name Type Description Default
hidden_size int number of modes \(f\) required
basis_size int number of basis functions \(N\) required
activation str activation function required
key Array random key. None
w_dist str distribution of the weight. Available distributions are “uniform”, “normal”, “ones”. 'uniform'
w_scale float scale of the weight. 1.0
b_dist str distribution of the bias. Available distributions are “uniform”, “normal”, “linspace”. 'linspace'
b_scale float scale of the bias. 1.0
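
As a rough illustration of what the `w_dist`/`w_scale` and `b_dist`/`b_scale` options could mean, here is a minimal sketch. The helper `init_param`, the symmetric range `[-scale, scale]`, and the parameter shape `(basis_size - 1,)` are all assumptions made for this example; the source does not specify how the class initializes its parameters.

```python
import jax
import jax.numpy as jnp


def init_param(key, shape, dist, scale):
    """Hypothetical initializer; NOT the library's implementation."""
    if dist == "uniform":
        # Assumed range: [-scale, scale)
        return jax.random.uniform(key, shape, minval=-scale, maxval=scale)
    if dist == "normal":
        # Zero-mean Gaussian with standard deviation `scale`
        return scale * jax.random.normal(key, shape)
    if dist == "ones":
        return scale * jnp.ones(shape)
    if dist == "linspace":
        # Evenly spaced values over [-scale, scale]; 1D shapes only
        (n,) = shape
        return jnp.linspace(-scale, scale, n)
    raise ValueError(f"unknown distribution: {dist}")


key = jax.random.PRNGKey(0)
w_key, b_key = jax.random.split(key)
basis_size = 5  # N

# Mirrors the documented defaults: w_dist='uniform', b_dist='linspace'
w = init_param(w_key, (basis_size - 1,), "uniform", 1.0)
b = init_param(b_key, (basis_size - 1,), "linspace", 1.0)
```

Evenly spaced biases place the activation centers at regular intervals, which may be why `'linspace'` is the documented default for `b_dist`.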

Attributes

Name Description
activations Activation functions, stored as tuple[Callable] because JAX cannot compile list[Callable].
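
One plausible reason for the tuple requirement (an assumption; the source does not explain it) is that arguments marked static under `jax.jit` must be hashable, and a Python list is not. The function `apply_all` below is a hypothetical helper written for this sketch, not part of the documented API.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="activations")
def apply_all(x, activations):
    # The Python loop over callables is unrolled when the function is traced
    return [act(x) for act in activations]


# A tuple of callables is hashable, so it can be passed as a static argument
activations = (jax.nn.sigmoid, jax.nn.silu)
outputs = apply_all(jnp.zeros(3), activations)

# A list of the same callables is not hashable
try:
    hash([jax.nn.sigmoid, jax.nn.silu])
    list_is_hashable = True
except TypeError:
    list_is_hashable = False
```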

Methods

Name Description
forward Forward transformation
partial Partial derivative of the basis with respect to the p-th hidden coordinate.
plot_basis Monitor the distribution of the basis to avoid the saturation of the activation function.

forward

layers.basis.Basis.forward(q, q0)

Forward transformation

Parameters

Name Type Description Default
q Array hidden coordinates with shape (D, f) where D is the size of the batch and f is the hidden dimension. required
q0 Array reference hidden coordinates with shape (N-1, f) where N is the basis size. required

Returns

Name Type Description
list[Array] basis with shape (D, N) where D is the size of the batch and N is the basis size.
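
The documented shapes can be illustrated with a minimal sketch. Everything about the functional form here is an assumption: that `q0` enters as a shift, that each mode applies `activation(w * (q - q0) + b)`, and that a constant first basis function pads the width from N-1 to N. The function `forward_sketch` and the parameter shapes `(f, N-1)` are hypothetical and only meant to show how a `(D, f)` input and an `(N-1, f)` reference can produce `f` arrays of shape `(D, N)`.

```python
import jax
import jax.numpy as jnp


def forward_sketch(q, q0, w, b, activations):
    """q: (D, f), q0: (N-1, f), w and b: (f, N-1) -> list of f arrays (D, N)."""
    D, f = q.shape
    basis = []
    for p in range(f):
        # Assumed form: shift by the reference points, affine map, activation
        z = w[p] * (q[:, p, None] - q0[None, :, p]) + b[p]  # (D, N-1)
        phi = activations[p](z)
        # Assumed constant first basis function so that the width becomes N
        basis.append(jnp.concatenate([jnp.ones((D, 1)), phi], axis=1))
    return basis


D, f, N = 4, 2, 3
q = jax.random.normal(jax.random.PRNGKey(0), (D, f))
q0 = jnp.zeros((N - 1, f))
w = jnp.ones((f, N - 1))
b = jnp.tile(jnp.linspace(-1.0, 1.0, N - 1), (f, 1))

basis = forward_sketch(q, q0, w, b, (jnp.tanh,) * f)
```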

partial

layers.basis.Basis.partial(q, q0)

Partial derivative of the basis with respect to the p-th hidden coordinate.

Parameters

Name Type Description Default
q Array hidden coordinates with shape (D, f) where D is the size of the batch and f is the hidden dimension. required
q0 Array reference hidden coordinates with shape (N-1,) where N is the basis size. required

Returns

Name Type Description
list[Array] [∂φ(w q_p + b) / ∂q_p]_{p=0}^{f-1}, where each element has shape (D, N), D is the size of the batch and N is the basis size.
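
For a single mode, the derivative of φ(wq + b) with respect to q is w φ'(wq + b) by the chain rule. The sketch below checks this with automatic differentiation, assuming a `tanh` activation and ignoring `q0`; the function `phi` and the parameter values are illustrative and not taken from the library.

```python
import jax
import jax.numpy as jnp


def phi(q_p, w, b):
    """Basis functions of one scalar coordinate; returns shape (N,)."""
    return jnp.tanh(w * q_p + b)


# jacfwd differentiates the (N,) output with respect to the scalar input;
# vmap maps that derivative over the batch dimension
dphi = jax.vmap(jax.jacfwd(phi), in_axes=(0, None, None))

w = jnp.array([0.5, 1.0, 2.0])    # N = 3 weights
b = jnp.array([-1.0, 0.0, 1.0])   # N = 3 biases
q_p = jnp.linspace(-1.0, 1.0, 4)  # D = 4 batch points for one mode

numeric = dphi(q_p, w, b)  # shape (D, N)
# Chain rule with d tanh(z)/dz = 1 - tanh(z)^2
analytic = w * (1.0 - jnp.tanh(w * q_p[:, None] + b) ** 2)
```

Repeating this for each of the f modes would give a list of f arrays of shape (D, N), matching the documented return type.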

plot_basis

layers.basis.Basis.plot_basis(q, q0)

Monitor the distribution of the basis to avoid the saturation of the activation function.
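
Saturation can also be checked numerically rather than visually. The sketch below assumes a bounded activation such as `tanh`, whose output approaches ±1 when it saturates; `saturation_fraction` and the threshold of 0.99 are hypothetical choices for this example and are not part of the documented API.

```python
import jax.numpy as jnp


def saturation_fraction(basis, threshold=0.99):
    """Fraction of entries whose magnitude exceeds `threshold` (hypothetical)."""
    return jnp.mean(jnp.abs(basis) > threshold)


q = jnp.linspace(-1.0, 1.0, 101)[:, None]  # (D, 1) sample coordinates
b = jnp.linspace(-1.0, 1.0, 4)             # 4 evenly spaced biases

# Moderate weight: pre-activations stay within [-2, 2]
healthy = jnp.tanh(1.0 * q + b)
# Very large weight: most pre-activations fall deep in the flat tails of tanh
saturated = jnp.tanh(50.0 * q + b)

healthy_frac = saturation_fraction(healthy)
saturated_frac = saturation_fraction(saturated)
```

A large fraction suggests that gradients through the activation are close to zero for most samples, which is the situation the plot is meant to reveal.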