# Source code for pympo.operators

from __future__ import annotations

import re
from bisect import bisect_left
from typing import Iterator, Sequence, overload

import numpy as np
import sympy
from numpy.typing import NDArray


class OpSite:
    """
    Represents an operator acting on a specific site in a quantum system,
    e.g. the operator z_i acting on site i.

    Attributes:
        symbol (sympy.Basic): The symbolic representation of the operator.
        isite (int): The site index on which the operator acts.
        value (NDArray | None): The numerical value of the operator, if available.
            A 1-D array is interpreted as the diagonal of the operator.
        isdiag (bool): Indicates if the operator is diagonal.
    """

    symbol: sympy.Basic
    isite: int
    value: NDArray | None
    isdiag: bool

    def __init__(
        self,
        symbol: sympy.Basic | str,
        isite: int,
        *,
        value: NDArray | None = None,
        isdiag: bool = False,
    ) -> None:
        """
        Args:
            symbol (sympy.Basic | str): Symbol of the operator. A string is
                converted into a non-commutative ``sympy.Symbol``.
            isite (int): The site index on which the operator acts.
            value (NDArray | None, optional): Numerical representation.
            isdiag (bool, optional): Diagonal flag. Only used when ``value``
                is None; otherwise it is derived from ``value.ndim == 1``.

        Raises:
            ValueError: If ``symbol`` is neither a sympy object nor a string.
        """
        if isinstance(symbol, sympy.Basic):
            self.symbol = symbol
        elif isinstance(symbol, str):
            self.symbol = sympy.Symbol(symbol, commutative=False)
        else:
            raise ValueError("Invalid type", type(symbol))
        self.isite = isite
        self.value = value
        if value is not None:
            # The shape of the value is authoritative over the flag.
            self.isdiag = value.ndim == 1
        else:
            self.isdiag = isdiag

    def __repr__(self) -> str:
        retval = self.symbol
        assert isinstance(retval, sympy.Basic)
        return repr(retval)

    def __str__(self) -> str:
        return self.__repr__()

    @staticmethod
    def _as_matrix(value: NDArray) -> NDArray:
        """Expand a 1-D diagonal into a square matrix; pass matrices through."""
        return np.diag(value) if value.ndim == 1 else value

    def _merge_same_site(self, other: OpSite) -> OpSite:
        """Return the product ``self * other`` of two operators on one site."""
        isdiag = self.isdiag and other.isdiag
        value: NDArray | None = None
        if self.value is not None and other.value is not None:
            if isdiag:
                # Product of two diagonals is the element-wise product.
                value = self.value * other.value
            else:
                value = self._as_matrix(self.value) @ self._as_matrix(
                    other.value
                )
        return OpSite(
            self.symbol * other.symbol, self.isite, value=value, isdiag=isdiag
        )

    def __mul__(
        self,
        other: int | float | complex | sympy.Basic | OpSite | OpProductSite,
    ) -> OpSite | OpProductSite:
        """
        Multiply by a scalar, another site operator or an operator product.

        Returns:
            OpSite: when ``other`` is an ``OpSite`` on the same site.
            OpProductSite: otherwise.

        Raises:
            ValueError: If the type of ``other`` is not supported.
        """
        if isinstance(other, int | float | complex | sympy.Basic):
            retval = OpProductSite([self]) * other
            assert isinstance(retval, OpProductSite)
            return retval
        elif isinstance(other, OpSite):
            if self.isite < other.isite:
                return OpProductSite([self, other])
            elif self.isite > other.isite:
                return OpProductSite([other, self])
            return self._merge_same_site(other)
        elif isinstance(other, OpProductSite):
            # Delegating to the product keeps the coefficient of ``other``
            # and merges operators that share this site instead of failing
            # with "Duplicate site index".
            retval = OpProductSite([self]) * other
            assert isinstance(retval, OpProductSite)
            return retval
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __rmul__(
        self,
        other: int | float | complex | sympy.Basic | OpSite | OpProductSite,
    ) -> OpSite | OpProductSite:
        if isinstance(other, int | float | complex | sympy.Basic):
            # Scalars commute with operators.
            return self.__mul__(other)
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __truediv__(
        self, other: int | float | complex | sympy.Basic
    ) -> OpSite | OpProductSite:
        """Divide by a scalar, implemented as multiplication by ``1 / other``."""
        if isinstance(other, int | float | complex | sympy.Basic):
            return self.__mul__(1 / other)
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __rtruediv__(
        self, other: int | float | complex | sympy.Basic
    ) -> OpSite | OpProductSite:
        # NOTE(review): ``other / self`` is evaluated as ``self / other``,
        # which is not the inverse of the operator. Behaviour is kept for
        # backward compatibility -- confirm whether this should raise.
        if isinstance(other, int | float | complex | sympy.Basic):
            return self.__truediv__(other)
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __add__(
        self,
        other: OpSite
        | OpProductSite
        | SumOfProducts
        | int
        | float
        | complex
        | sympy.Basic,
    ) -> OpSite | SumOfProducts:
        """
        Add another operator, a sum of operators or a scalar.

        A scalar ``c`` is represented as ``c`` times the identity on this
        site. A new ``SumOfProducts`` is always returned; ``other`` is never
        modified.

        Raises:
            ValueError: If the type of ``other`` is not supported.
        """
        if isinstance(other, OpSite | OpProductSite):
            return SumOfProducts([self, other])
        elif isinstance(other, SumOfProducts):
            op_product = OpProductSite([self])
            # Build a fresh sum so that ``a + b`` does not append to ``b``.
            result = SumOfProducts()
            result.ops = [*other.ops, op_product]
            result.coefs = [*other.coefs, op_product.coef]
            result.symbols = [*other.symbols, op_product.symbol]
            return result
        elif isinstance(other, int | float | complex | sympy.Basic):
            if isinstance(self.value, np.ndarray):
                n_basis = self.value.shape[0]
            else:
                n_basis = None
            eye = get_eye_site(self.isite, n_basis=n_basis)
            # The identity must carry the scalar as its coefficient;
            # adding the bare identity would silently turn ``op + c``
            # into ``op + 1``.
            return self + eye * other
        else:
            raise ValueError("Invalid type")

    def __radd__(
        self,
        other: OpSite
        | OpProductSite
        | SumOfProducts
        | int
        | float
        | complex
        | sympy.Basic,
    ) -> OpSite | SumOfProducts:
        # Addition is commutative, so ``c + op`` is handled like ``op + c``.
        return self.__add__(other)

    def __sub__(
        self, other: OpSite | OpProductSite | SumOfProducts
    ) -> OpSite | SumOfProducts:
        """Subtract ``other`` by adding ``(-1) * other``."""
        if isinstance(other, OpSite | OpProductSite | SumOfProducts):
            return self + (-1) * other
        else:
            raise ValueError("Invalid type")
def get_eye_site(i: int, n_basis: int | None = None) -> OpSite:
    """
    Create an identity operator site.

    Parameters:
    -----------
    i (int): The index of the site.
    n_basis (int | None, optional): The number of basis states.
        If provided, an array of ones with length `n_basis` is created and
        stored as the (diagonal) value. Python and NumPy integers are both
        accepted. Defaults to None, in which case no value is attached.

    Returns:
    --------
    OpSite: An operator site with the identity operator.
    """
    value: NDArray | None = None
    # NumPy integer scalars (e.g. np.int64) are not instances of the
    # built-in int, so they are checked explicitly; otherwise they would be
    # silently ignored and the identity would carry no numerical value.
    if isinstance(n_basis, (int, np.integer)):
        value = np.ones(int(n_basis))
    return OpSite(
        sympy.Symbol(r"\hat{1}_" + f"{i}"), i, value=value, isdiag=True
    )
def omit_eye_site(latex_symbol: str) -> str:
    r"""
    Strip identity operators from a latex operator string.

    Args:
        latex_symbol (str): The latex symbol of the operator
            like $\hat{1}_0\hat{z}_1$.

    Returns:
        str: The latex symbol of the operator without the identity operator
            like $\hat{z}_1$. When nothing but identities was present, a
            placeholder is returned instead: the "left" identity if the
            input contained the identity on site 0, the "right" identity
            otherwise.
    """
    without_eyes = re.sub(r"\\hat\{1\}_[0-9]+", "", latex_symbol)
    only_delimiters_left = re.match(r"\$[ ]*\$", without_eyes) is not None
    if not only_delimiters_left:
        return without_eyes
    if re.search(r"\\hat\{1\}_0", latex_symbol) is not None:
        return r"$\hat{1}_{\text{left}}$"
    return r"$\hat{1}_{\text{right}}$"
[docs] class OpProductSite: """ Represents a product of operators acting on multiple sites, such as z_i * z_j * z_k. Attributes: coef (int | float | complex | sympy.Basic): Coefficient of the operator product. symbol (sympy.Basic): Symbolic representation of the operator product. ops (list[OpSite]): List of operators in the product. sites (list[int]): List of site indices where the operators act. Product of operators acting on multiple sites like z_i * z_j * z_k """ coef: int | float | complex | sympy.Basic symbol: sympy.Basic ops: list[OpSite] sites: list[int] def __init__(self, ops: list[OpSite]) -> None: argsrt = np.argsort([op.isite for op in ops]) self.ops = [ops[i] for i in argsrt] self.sites = [] self.symbol = 1 self.coef = 1 for op in self.ops: self.symbol *= op.symbol self.sites.append(op.isite) if self._is_duplicated(): raise ValueError("Duplicate site index") if not self._is_sorted(): raise ValueError("Site index is not sorted") self.left_product: list[sympy.Basic] = None # type: ignore self.right_product: list[sympy.Basic] = None # type: ignore def _set_left_product(self) -> None: self.left_product = [self.ops[0].symbol] k = 0 for i in range(self.sites[0] + 1, self.sites[-1] + 1): if i in self.sites[1:]: k += 1 self.left_product.append( self.left_product[-1] * self.ops[k].symbol ) else: self.left_product.append(self.left_product[-1]) assert k == len(self.ops) - 1, f"{k=}, {len(self.ops)=}, {self.sites=}" def _set_right_product(self) -> None: self.right_product = [self.ops[-1].symbol] k = len(self.ops) - 1 for i in range(self.sites[-1] - 1, self.sites[0] - 1, -1): if i in self.sites[:-1]: k -= 1 self.right_product.append( self.ops[k].symbol * self.right_product[-1] ) else: self.right_product.append(self.right_product[-1]) self.right_product = self.right_product[::-1] assert k == 0, f"{k=}, {self.sites=}"
[docs] def replace(self, new_op: OpSite) -> None: """ Replace an existing operator in the list with a new operator. Args: new_op (OpSite): The new operator to replace the existing one. Raises: AssertionError: If the site of the new operator is not found in the existing operators. Modifies: self.symbol: Updates the symbol by multiplying the symbols of all operators. self.ops: Replaces the operator at the matching site with the new operator. """ self.symbol = 1 is_replaced = False for i, op in enumerate(self.ops): if op.isite == new_op.isite: self.ops[i] = new_op is_replaced = True self.symbol *= self.ops[i].symbol assert is_replaced, f"{new_op.isite=} is not found in {self.sites=}"
def __repr__(self) -> str: return " * ".join([op.__repr__() for op in self.ops]) def __str__(self) -> str: return " * ".join([op.__str__() for op in self.ops]) @overload def __mul__( self, other: int | float | complex | sympy.Basic | OpSite | OpProductSite, ) -> OpProductSite: ... @overload def __mul__(self, other: SumOfProducts) -> SumOfProducts: ... def __mul__( self, other: int | float | complex | sympy.Basic | OpSite | OpProductSite | SumOfProducts, ) -> OpProductSite | SumOfProducts: if isinstance(other, int | float | complex | sympy.Basic): self.coef *= other return self elif isinstance(other, OpSite): if other.isite in self.sites: idx = bisect_left(self.sites, other.isite) self.symbol *= other.symbol same_site_op = self.ops[idx] isdiag = same_site_op.isdiag and other.isdiag if same_site_op.value is not None and other.value is not None: if isdiag: value = same_site_op.value * other.value else: if same_site_op.value.ndim == 1: value1 = np.diag(same_site_op.value) else: value1 = same_site_op.value if other.value.ndim == 1: value2 = np.diag(other.value) else: value2 = other.value value = value1 @ value2 else: value = None site_symbol = same_site_op.symbol * other.symbol new_op = OpSite( site_symbol, same_site_op.isite, value=value, isdiag=isdiag, ) self.ops[idx] = new_op return self else: idx = bisect_left(self.sites, other.isite) self.ops.insert(idx, other) return OpProductSite(self.ops) * self.coef elif isinstance(other, OpProductSite): coef = self.coef * other.coef new_product = OpProductSite(self.ops) for op in other.ops: assert isinstance(op, OpSite) _new_product = new_product * op assert isinstance(_new_product, OpProductSite) new_product = _new_product new_product.coef = coef return new_product elif isinstance(other, SumOfProducts): ops = [] opproduct1 = OpProductSite(self.ops) opproduct1.coef = self.coef for opproduct2 in other.ops: assert isinstance(opproduct2, OpProductSite) _new_product = opproduct1 * opproduct2 assert isinstance(_new_product, 
OpProductSite) ops.append(_new_product) return SumOfProducts(ops) else: raise ValueError(f"Invalid type: {type(other)=}") def __rmul__( self, other: int | float | complex | sympy.Basic | OpSite | OpProductSite, ) -> OpProductSite: if isinstance(other, int | float | complex | sympy.Basic): retval = self.__mul__(other) elif isinstance(other, OpSite | OpProductSite): retval = other.__mul__(self) # type: ignore else: raise ValueError(f"Invalid type: {type(other)=}") assert isinstance(retval, OpProductSite) return retval def __truediv__( self, other: int | float | complex | sympy.Basic ) -> OpProductSite: if isinstance(other, int | float | complex | sympy.Basic): retval = self.__mul__(1 / other) assert isinstance(retval, OpProductSite) return retval else: raise ValueError(f"Invalid type: {type(other)=}") def __add__( self, other: OpSite | OpProductSite | SumOfProducts | int | float | complex | sympy.Basic, ) -> SumOfProducts: if isinstance(other, OpSite): return SumOfProducts([self, other]) elif isinstance(other, OpProductSite): return SumOfProducts([self, other]) elif isinstance(other, SumOfProducts): other.ops.append(self) other.coefs.append(self.coef) other.symbols.append(self.symbol) return other elif isinstance(other, int | float | complex | sympy.Basic): const = get_eye_site(self.sites[0]) * other return SumOfProducts([self, const]) else: raise ValueError(f"Invalid type: {type(other)=}") def __radd__( self, other: OpSite | OpProductSite | SumOfProducts | int | float | complex | sympy.Basic, ) -> SumOfProducts: return self.__add__(other) def __sub__( self, other: OpSite | OpProductSite | SumOfProducts | int | float | complex | sympy.Basic, ) -> SumOfProducts: if isinstance( other, OpSite | OpProductSite | SumOfProducts | int | float | complex | sympy.Basic, ): return self + (-1) * other else: raise ValueError(f"Invalid type: {type(other)=}") def __rsub__( self, other: OpSite | OpProductSite | SumOfProducts | int | float | complex | sympy.Basic, ) -> SumOfProducts: 
return self.__sub__((-1) * other) def _is_duplicated(self) -> bool: return len(self.sites) != len(set(self.sites)) def _is_sorted(self) -> bool: return self.sites == sorted(self.sites) and self.sites == [ op.isite for op in self.ops ]
[docs] def get_symbol_interval( self, start_site: int, end_site: int ) -> sympy.Basic: """ Get the symbol of the operator acting on the sites between start_site and end_site. When the operator is symbol = z_1 * z_6, - get_symbol_interval(0, 3) returns 1_0 * z_1 * 1_2, - get_symbol_interval(3, 8) returns 1_3 * 1_4 * 1_5 * z_6 * 1_7. Args: start_site (int): The start site. end_site (int): The end site. Returns: sympy.Basic: The symbol of the operator acting on the sites between start_site and end_site. To Do: - Improve the performance of the function by memoization. """ if start_site + 1 == end_site: if start_site in self.sites: return self.ops[bisect_left(self.sites, start_site)].symbol else: return get_eye_site(start_site).symbol if start_site <= self.sites[0]: case_left = 0 elif start_site <= self.sites[-1]: case_left = 1 else: case_left = 2 if end_site <= self.sites[0]: case_right = 0 elif end_site <= self.sites[-1]: case_right = 1 else: case_right = 2 match (case_left, case_right): case (0, 0): return get_eye_site(end_site).symbol case (0, 1): if self.left_product is None: self._set_left_product() return self.left_product[end_site - self.sites[0]] case (0, 2): if self.left_product is None: self._set_left_product() return self.left_product[-1] case (1, 1): return self._get_symbol_interval_intermidiate( start_site, end_site ) case (1, 2): if self.right_product is None: self._set_right_product() return self.right_product[start_site - self.sites[0]] case (2, 2): return get_eye_site(start_site).symbol case _: raise ValueError( f"{case_left=}, {case_right=}, {start_site=}, {end_site=}" )
def _get_symbol_interval_intermidiate( self, start_site: int, end_site: int ) -> sympy.Basic: symbol = 1 idx = bisect_left(self.sites, start_site) is_eye = True for i in range(start_site, end_site): if len(self.sites) > idx and self.sites[idx] == i: # If operator is 1_1 * 1_2 * 1_3 * z_4 * 1_5 ... # skip 1_1 * 1_2 * 1_3 if is_eye: symbol = self.ops[idx].symbol is_eye = False else: symbol *= self.ops[idx].symbol idx += 1 else: if symbol == 1: # To reduce cost for symbolic computation, # identity operator is only used when symbol == 1. symbol = get_eye_site(i).symbol else: pass if symbol == 1: symbol = get_eye_site(start_site).symbol assert isinstance(symbol, sympy.Basic) # self.symbol_intervals[(start_site, end_site)] = symbol return symbol def __getitem__(self, key: int | slice) -> sympy.Basic: if isinstance(key, slice): start = key.start stop = key.stop if start is None: start = 0 if stop is None: raise ValueError("Invalid slice. End index is not trivial.") return self.get_symbol_interval(start, stop) elif isinstance(key, int): return self.get_symbol_interval(key, key + 1) else: raise ValueError("Invalid type")
[docs] def get_site_value(self, isite: int, n_basis: int, isdiag: bool) -> NDArray: """ Get the value of the operator acting on the site isite. Args: isite (int): The site index. n_basis (int): The number of basis. Returns: NDArray: The value of the operator acting on the site isite. """ idx = bisect_left(self.sites, isite) if len(self.sites) > idx and self.sites[idx] == isite: value = self.ops[idx].value assert isinstance(value, np.ndarray) if isdiag: assert value.shape == ( n_basis, ), f"{value.shape=} while {n_basis=}" elif value.ndim == 1: value = np.diag(value) assert value.shape == ( n_basis, n_basis, ), f"{value.shape=} while {n_basis=}" return value else: if isdiag: return np.ones(n_basis) else: return np.eye(n_basis)
class SumOfProducts:
    """
    Sum of products of operators acting on multiple sites
    like z_i * z_j + z_k * z_l

    ``+``, ``-`` and ``*`` return new sums and leave their operands
    untouched; ``+=`` appends to the existing sum in place.

    Args:
        ops (Sequence[OpProductSite | OpSite], optional): List of operator
            products. Defaults to an empty sequence.

    Attributes:
        coefs (list[int | float | complex | sympy.Basic]): Coefficients of the operator products.
        ops (list[OpProductSite]): List of operator products.
        symbols (list[sympy.Basic]): List of symbolic representations of the operator products.
    """

    coefs: list[int | float | complex | sympy.Basic]
    ops: list[OpProductSite]
    symbols: list[sympy.Basic]

    def __init__(self, ops: Sequence[OpProductSite | OpSite] = ()) -> None:
        self.coefs = []
        self.ops = []
        self.symbols = []
        for op in ops:
            if isinstance(op, OpSite):
                op = OpProductSite([op])
            assert isinstance(op, OpProductSite)
            self.ops.append(op)
            self.coefs.append(op.coef)
            self.symbols.append(op.symbol)

    @staticmethod
    def _clone_term(
        term: OpProductSite, coef: int | float | complex | sympy.Basic
    ) -> OpProductSite:
        """Copy a term so that later changes cannot leak into the original."""
        clone = OpProductSite(term.ops)
        clone.symbol = term.symbol
        clone.coef = coef
        return clone

    @staticmethod
    def _single_difference(
        term_a: OpProductSite, term_b: OpProductSite
    ) -> tuple[OpSite, OpSite] | None:
        """
        Return the only pair of site operators whose symbols differ, or
        None when the terms differ on no site or on more than one site.
        """
        differing: tuple[OpSite, OpSite] | None = None
        for op_a, op_b in zip(term_a.ops, term_b.ops, strict=True):
            if op_a.symbol != op_b.symbol:
                if differing is not None:
                    return None
                differing = (op_a, op_b)
        return differing

    @staticmethod
    def _sum_same_site(op_a: OpSite, op_b: OpSite) -> OpSite:
        """Return the operator ``op_a + op_b`` for two operators on one site."""
        isdiag = op_a.isdiag and op_b.isdiag
        value: NDArray | None = None
        if op_a.value is not None and op_b.value is not None:
            if isdiag:
                value = op_a.value + op_b.value
            else:
                # Mixed layouts: expand diagonals before adding matrices.
                value_a = (
                    np.diag(op_a.value) if op_a.value.ndim == 1 else op_a.value
                )
                value_b = (
                    np.diag(op_b.value) if op_b.value.ndim == 1 else op_b.value
                )
                value = value_a + value_b
        return OpSite(
            op_a.symbol + op_b.symbol, op_a.isite, value=value, isdiag=isdiag
        )

    def simplify(self) -> SumOfProducts:
        """
        Concatenate common operator such as
        q_i * a_j + q_i * a^dagger_j -> q_i * (a_j + a^dagger_j)

        Two terms are merged when they act on the same sites, have equal
        coefficients and differ in the operator of exactly one site. The
        terms of this sum are left unchanged; the result is a new sum.

        Note that the computational complexity is O(n^2)
        where n is the number of operators.
        """
        n_terms = len(self.ops)
        absorbed = [False] * n_terms
        new_ops = []
        for i in range(n_terms):
            if absorbed[i]:
                continue
            # Work on a copy: merging replaces site operators in place and
            # must not rewrite the terms of the original sum.
            current = self._clone_term(self.ops[i], self.ops[i].coef)
            for j in range(i + 1, n_terms):
                if absorbed[j]:
                    continue
                candidate = self.ops[j]
                if current.sites != candidate.sites:
                    continue
                if current.coef != candidate.coef:
                    continue
                pair = self._single_difference(current, candidate)
                if pair is None:
                    # Either two or more operators are not common
                    # or no differing operator is found.
                    continue
                absorbed[j] = True
                # ``current`` now holds the merged operator, so a third
                # term can be merged into the same site later on.
                current.replace(self._sum_same_site(*pair))
            new_ops.append(current)
        return SumOfProducts(new_ops)

    @property
    def symbol(self) -> sympy.Basic | int | float | complex:
        """Symbolic expression of the whole sum; 0 for an empty sum."""
        symbol = 0
        for op, coef in zip(self.ops, self.coefs, strict=True):
            symbol += op.symbol * coef
        assert isinstance(
            symbol, sympy.Basic | int | float | complex
        ), f"{symbol=}"
        return symbol

    @property
    def ndim(self) -> int:
        """Largest site index used by any term plus one; 0 for an empty sum."""
        return max((max(op.sites) + 1 for op in self.ops), default=0)

    def get_unique_ops_site(self, i: int) -> set[sympy.Basic]:
        """
        Collect the distinct symbols that the terms place on site ``i``
        (the identity symbol for terms that do not act on that site).
        """
        return {op[i] for op in self.ops}

    @property
    def nops(self) -> int:
        """Number of terms in the sum."""
        return len(self.ops)

    @property
    def nbasis_list(self) -> list[int]:
        """
        Number of basis functions of every site, read from the first axis of
        the operator values.

        Raises:
            ValueError: If an operator has no value, if two operators on the
                same site disagree, or if no operator acts on a site.
        """
        nbasis_list = [0] * self.ndim
        for opproduct in self.ops:
            for isite, opsite in zip(
                opproduct.sites, opproduct.ops, strict=True
            ):
                if opsite.value is None:
                    raise ValueError(f"Value at {opsite=} is not defined.")
                if nbasis_list[isite] == 0:
                    nbasis_list[isite] = opsite.value.shape[0]
                elif nbasis_list[isite] != opsite.value.shape[0]:
                    raise ValueError(
                        f"Number of basis at {isite=} is not consistent with {opsite=} and {nbasis_list[isite]=}"
                    )
        for i, nbasis in enumerate(nbasis_list):
            if nbasis == 0:
                raise ValueError(f"Number of basis at {i=} is ambiguous.")
        return nbasis_list

    @property
    def isdiag_list(self) -> list[bool]:
        """Per site: True when every operator acting on the site is diagonal."""
        isdiag_list = [True] * self.ndim
        for opproduct in self.ops:
            for isite, opsite in zip(
                opproduct.sites, opproduct.ops, strict=True
            ):
                isdiag_list[isite] &= opsite.isdiag
        return isdiag_list

    def _as_terms(
        self,
        other: OpSite
        | OpProductSite
        | SumOfProducts
        | int
        | float
        | complex
        | sympy.Basic,
    ) -> tuple[
        list[OpProductSite],
        list[int | float | complex | sympy.Basic],
        list[sympy.Basic],
    ]:
        """Convert an addend into new (ops, coefs, symbols) lists."""
        if isinstance(other, int | float | complex | sympy.Basic):
            # A scalar is the scalar times the identity on the first site
            # of the first term (site 0 for an empty sum).
            site = self.ops[0].sites[0] if self.ops else 0
            term = OpProductSite([get_eye_site(site)]) * other
            return [term], [term.coef], [term.symbol]
        elif isinstance(other, OpSite):
            term = OpProductSite([other])
            return [term], [term.coef], [term.symbol]
        elif isinstance(other, OpProductSite):
            return [other], [other.coef], [other.symbol]
        elif isinstance(other, SumOfProducts):
            return list(other.ops), list(other.coefs), list(other.symbols)
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __add__(
        self,
        other: OpSite
        | OpProductSite
        | SumOfProducts
        | int
        | float
        | complex
        | sympy.Basic,
    ) -> SumOfProducts:
        """
        Return a new sum containing the terms of ``self`` followed by
        ``other``. Neither operand is modified.

        Raises:
            ValueError: If the type of ``other`` is not supported.
        """
        ops, coefs, symbols = self._as_terms(other)
        result = SumOfProducts()
        result.ops = self.ops + ops
        result.coefs = self.coefs + coefs
        result.symbols = self.symbols + symbols
        return result

    def __iadd__(
        self,
        other: OpSite
        | OpProductSite
        | SumOfProducts
        | int
        | float
        | complex
        | sympy.Basic,
    ) -> SumOfProducts:
        """
        Append ``other`` to this sum in place, so that building a sum with
        ``+=`` in a loop does not copy the term lists on every step.
        """
        ops, coefs, symbols = self._as_terms(other)
        self.ops.extend(ops)
        self.coefs.extend(coefs)
        self.symbols.extend(symbols)
        return self

    def __sub__(
        self, other: OpSite | OpProductSite | SumOfProducts
    ) -> SumOfProducts:
        """Subtract ``other`` by adding ``(-1) * other``."""
        if isinstance(other, OpSite | OpProductSite | SumOfProducts):
            return self + (-1) * other
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __mul__(
        self,
        other: int
        | float
        | complex
        | sympy.Basic
        | OpSite
        | OpProductSite
        | SumOfProducts,
    ) -> SumOfProducts:
        """
        Multiply every term by ``other`` and return the result as a new sum.

        The terms of ``self`` are copied before they are multiplied, so
        that e.g. ``a - b`` (evaluated through ``(-1) * b``) does not flip
        the sign of ``b``, and ``coefs``/``symbols`` of the result always
        match its terms.

        Raises:
            NotImplementedError: If ``other`` is a ``SumOfProducts``.
            ValueError: If the type of ``other`` is not supported.
        """
        if isinstance(other, int | float | complex | sympy.Basic):
            return SumOfProducts(
                [
                    self._clone_term(term, term.coef * other)
                    for term in self.ops
                ]
            )
        elif isinstance(other, OpSite):
            return SumOfProducts(
                [self._clone_term(term, term.coef) * other for term in self.ops]
            )
        elif isinstance(other, OpProductSite):
            return SumOfProducts([term * other for term in self.ops])
        elif isinstance(other, SumOfProducts):
            raise NotImplementedError()
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def __rmul__(
        self, other: int | float | complex | sympy.Basic
    ) -> SumOfProducts:
        if isinstance(other, int | float | complex | sympy.Basic):
            # Scalars commute with operators.
            return self.__mul__(other)
        else:
            raise ValueError(f"Invalid type: {type(other)=}")

    def to_mpo(self) -> list[NDArray]:
        """Not implemented in this class."""
        raise NotImplementedError()

    def __iter__(self) -> Iterator[OpProductSite]:
        return iter(self.ops)