import functools
from abc import ABC, abstractmethod
from typing import Any, Callable
import matplotlib.pyplot as plt
import numpy as np
from scipy import integrate
from scipy.linalg import eigh
[docs]
class DVRPrimitivesMixin(ABC):
"""
Abstract DVR primitive function class
"""
def __init__(self, ngrid: int):
if not isinstance(ngrid, int):
raise TypeError(
"ngrid argument must be integer but {ngrid} is given."
)
self.ngrid: int = ngrid
self.nprim: int = ngrid
self.lb: float = -100.0
self.ub: float = 100.0
self.label: str = "abc"
def __call__(self, n: int, q: float) -> float | complex:
return self.dvr_func(n, q)
def __iter__(self):
for grid in self.get_grids():
yield grid
def __len__(self) -> int:
return self.ngrid
[docs]
@abstractmethod
def fbr_func(self, n: int, q: float) -> float | complex:
r"""
fbr Primitive function :math:`\varphi_n(q)`, such as HO, Sine, etc
"""
pass
[docs]
@abstractmethod
def get_pos_rep_matrix(self) -> np.ndarray:
r"""
Numerical integral of :math:`\langle\varphi_j|\hat{q}|\varphi_k\rangle`
If analytical integral is known, implemented in inheritance class
"""
if not hasattr(self, "pos_rep_matrix"):
self.pos_rep_matrix: np.ndarray[Any, np.dtype[np.bool_]] = np.zeros(
(self.ngrid, self.ngrid), dtype=np.complex128
)
avg_error = 0.0
for j in range(self.ngrid):
bra_func = functools.partial(self.fbr_func, j)
for k in range(j, self.ngrid):
ket_func = functools.partial(self.fbr_func, k)
def integrand(x, bra_func=bra_func, ket_func=ket_func):
return np.conjugate(bra_func(x)) * x * ket_func(x)
self.pos_rep_matrix[j, k], error = integrate.quad(
integrand, self.lb, self.ub
)
avg_error += error
avg_error /= self.ngrid * (self.ngrid - 1) // 2
print(f"average numerical integral error : {avg_error}")
self.pos_rep_matrix += np.conjugate(
self.pos_rep_matrix.T
) - np.diag(np.diag(self.pos_rep_matrix))
return self.pos_rep_matrix
[docs]
@abstractmethod
def get_1st_derivative_matrix_fbr(self) -> np.ndarray:
r"""
Numerical integral of
:math:`\langle\varphi_j|\frac{d}{dq}|\varphi_k\rangle`
If analytical integral is known, implemented in inheritance class
"""
raise NotImplementedError
[docs]
def get_1st_derivative_matrix_dvr(self) -> np.ndarray:
r""":math:`\langle\chi_\alpha|\frac{d}{dq}|\chi_\beta\rangle`"""
if not hasattr(self, "first_derivative_matrix_dvr"):
self.first_derivative_matrix_dvr = (
self.get_unitary().conj().T
[docs]
@ self.get_1st_derivative_matrix_fbr()
@ self.get_unitary()
)
return self.first_derivative_matrix_dvr
@abstractmethod
def get_2nd_derivative_matrix_fbr(self) -> np.ndarray:
r"""
Numerical integral of
:math:`\langle\varphi_j|\frac{d^2}{dq^2}|\varphi_k\rangle`
If analytical integral is known, implemented in inheritance class
"""
raise NotImplementedError
[docs]
def get_2nd_derivative_matrix_dvr(self) -> np.ndarray:
r""":math:`\langle\chi_\alpha|\frac{d^2}{dq^2}|\chi_\beta\rangle`"""
if not hasattr(self, "second_derivative_matrix_dvr"):
self.second_derivative_matrix_dvr = (
self.get_unitary().conj().T
[docs]
@ self.get_2nd_derivative_matrix_fbr()
@ self.get_unitary()
)
return self.second_derivative_matrix_dvr
@abstractmethod
def diagnalize_pos_rep_matrix(self) -> None:
"""Numerical diagonalization of `pos_rep_matrix`.
If analytical diagonalization is known, implemented in inheritance class
"""
if not hasattr(self, "grids"):
eig_val, eig_vec = eigh(self.get_pos_rep_matrix())
self.grids = list(eig_val)
self.unitary = eig_vec
self.get_sqrt_weights()
[docs]
def get_sqrt_weights(self, k: int = 0) -> list[float]:
r""":math:`\sqrt{w_\alpha}=U_{k\alpha}^{\ast}/\varphi_k(x_\alpha)`"""
if not hasattr(self, "sqrt_weights"):
self.sqrt_weights = [
(
np.conjugate(self.get_unitary()[k, alpha])
/ self.fbr_func(k, self.get_grids()[alpha])
).real
for alpha in range(self.ngrid)
]
for alpha in range(self.ngrid):
if self.sqrt_weights[alpha].real < 0:
self.sqrt_weights[alpha] *= -1.0
self.unitary[:, alpha] *= -1.0
return self.sqrt_weights
[docs]
def get_grids(self) -> list[float]:
r"""
grids :math:`x_\alpha` correspond to eigenvalue of `pos_rep_matrix`
"""
if not hasattr(self, "grids"):
self.diagnalize_pos_rep_matrix()
return self.grids
[docs]
def get_unitary(self) -> np.ndarray:
r"""Get Unitary Matrix which diagonalize `pos_rep_matrix`
Returns:
np.ndarray : `u[alpha,j]` =
:math:`(U_{j\alpha})^\dagger` = :math:`(U^\dagger)_{\alpha j}`
where,
.. math::
\sum_{j,k}
U_{j\alpha}\langle\varphi_j|\hat{q}|
\varphi_k\rangle U_{k\beta}^\dagger
= x_\alpha \delta_{\alpha\beta}
"""
if not hasattr(self, "unitary"):
self.diagnalize_pos_rep_matrix()
return self.unitary
[docs]
def dvr_func(self, n: int, q: float) -> float:
r"""DVR function
.. math::
\chi_\alpha=\sum_{j=0}^{N-1}\varphi_j(q)U_{j\alpha}
\quad (\alpha=0,\ldots, N-1)
In other words,
.. math::
|\chi_\alpha\rangle =U^\dagger |\varphi_j\rangle
"""
if not (0 <= n < self.ngrid):
ValueError
dum = 0.0
for j in range(self.ngrid):
dum += self.fbr_func(j, q) * self.get_unitary()[j, n]
return dum
[docs]
def plot_fbr(
self, n: int | None = None, q: np.ndarray | None = None
) -> None:
r"""Plot FBR :math:`\{\varphi_n(q)\}`"""
plt.title(f"{self.label}-FBR funtions")
self._plot(self.fbr_func, n, q, name="fbr-func")
[docs]
def plot_dvr(
self, n: int | None = None, q: np.ndarray | None = None
) -> None:
r"""Plot DVR functions :math:`\{\chi_n(q)\}`"""
plt.title(f"{self.label}-DVR functions")
self._plot(self.dvr_func, n, q, name="dvr-func")
def _plot(
self,
func: Callable,
n: int | None = None,
q: np.ndarray | None = None,
name: str | None = None,
) -> None:
if q is None:
q = np.linspace(self.lb, self.ub, 100)
if n is None:
for n in range(self.ngrid):
array = func(n, q)
# if imaginary part is small, plot real part
if np.max(np.abs(array.imag)) < 1e-10:
plt.plot(q, array.real, label=f"{n}")
else:
plt.plot(q, array.real, label=f"{n} real")
plt.plot(q, array.imag, label=f"{n} imag", linestyle="--")
else:
array = func(n, q)
# if imaginary part is small, plot real part
if np.max(np.abs(array.imag)) < 1e-10:
plt.plot(q, array.real, label=f"{n}")
else:
plt.plot(q, array.real, label=f"{n} real")
plt.plot(q, array.imag, label=f"{n} imag")
plt.legend(loc="upper right")
if isinstance(name, str):
plt.savefig(f"{name}.pdf")
plt.show()