Stiefel manifold
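This notebook reproduces the results of the paper "Riemannian geometry and automatic differentiation for optimization problems of quantum physics and quantum technologies" (doi-link). As a benchmark, it minimizes $\mathrm{Re}\,\mathrm{tr}(V^\dagger H V)$ over the Stiefel manifold $\mathrm{St}(n,k)$ of $n\times k$ matrices $V$ with orthonormal columns, and compares how fast different parametrizations of the manifold converge; the exact minimum is the sum of the $k$ smallest eigenvalues of the Hermitian matrix $H$.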
See also the paper "A Global Cayley Parametrization of Stiefel Manifold for Direct Utilization of Optimization Mechanisms Over Vector Spaces" (doi-link).
In [1]:
Copied!
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
try:
import numqi
except ImportError:
%pip install numqi
import numqi
np_rng = np.random.default_rng(234) #fix seed for documentation
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
try:
import numqi
except ImportError:
%pip install numqi
import numqi
np_rng = np.random.default_rng(234) #fix seed for documentation
In [2]:
Copied!
class LowEnergySpectrum(torch.nn.Module):
    """Variational model for the low-energy subspace of a Hermitian matrix.

    The trainable point ``EVC`` lives on the Stiefel manifold St(dim, rank)
    (``dim x rank`` matrices with orthonormal columns) and is produced by a
    ``numqi.manifold.Stiefel`` module. The loss is ``Re tr(EVC^H @ matH @ EVC)``;
    for Hermitian ``matH`` its minimum over the manifold equals the sum of the
    ``rank`` smallest eigenvalues of ``matH``.
    """

    def __init__(self, dim: int, rank: int, method: str = 'cholesky'):
        """
        Parameters:
            dim (int): size of the (dim, dim) matrix ``matH``.
            rank (int): number of orthonormal columns being optimized.
            method (str): Stiefel-manifold parametrization, forwarded unchanged to
                ``numqi.manifold.Stiefel`` (this notebook uses 'so-exp',
                'so-cayley', 'qr', 'polar' and 'choleskyL').
        """
        # NOTE(review): the experiment below passes 'choleskyL'; confirm that the
        # default 'cholesky' is also accepted by numqi.manifold.Stiefel.
        super().__init__()
        self.manifold = numqi.manifold.Stiefel(dim, rank, dtype=torch.complex128, method=method)
        # Matrix to minimize against; must be supplied via set_matH() before forward().
        self.matH = None

    def set_matH(self, matH):
        """Store ``matH`` as a complex128 tensor (always a private copy).

        Parameters:
            matH (np.ndarray | torch.Tensor | array-like): the (dim, dim) matrix.
        """
        if isinstance(matH, torch.Tensor):
            # torch.tensor() on an existing tensor emits a UserWarning; copy explicitly.
            self.matH = matH.detach().clone().to(torch.complex128)
        else:
            self.matH = torch.tensor(matH, dtype=torch.complex128)

    def forward(self):
        """Return ``Re tr(EVC^H @ matH @ EVC)`` for the current manifold point.

        Raises:
            RuntimeError: if ``set_matH()`` has not been called yet.
        """
        if self.matH is None:
            raise RuntimeError('matH is not set; call set_matH() before forward()')
        EVC = self.manifold()
        # The trace of a Hermitian product is real up to round-off; keep the real part.
        loss = torch.trace(EVC.T.conj() @ self.matH @ EVC).real
        return loss
In [3]:
Copied!
dim = 128
rank = 32

# Random Hermitian matrix H = U diag(EVL) U^H. U is the eigenbasis of the random
# positive semidefinite matrix G G^H (G complex Gaussian). The eigenvalues
# exp(x) - exp(max x) with x ~ U[-4, 0] are all <= 0, and the largest is exactly 0.
gauss = np_rng.normal(size=(dim, dim)) + 1j * np_rng.normal(size=(dim, dim))
matU = np.linalg.eigh(gauss @ gauss.T.conj())[1]
log_evl = np_rng.uniform(-4, 0, size=dim)
EVL = np.exp(log_evl) - np.exp(log_evl.max())
matH = (matU * EVL) @ matU.T.conj()

# Exact optimum for reference: eigvalsh returns eigenvalues in ascending order,
# so the minimum over St(dim, rank) is the sum of the `rank` smallest ones.
ret_ = np.linalg.eigvalsh(matH)[:rank].sum()

method_list = ['so-exp', 'so-cayley', 'qr', 'polar', 'choleskyL']
# seed=np_rng is shared, so each method starts from a different random theta0.
kwargs = dict(maxiter=600, theta0=('uniform', -0.1, 0.1), num_repeat=1, tol=1e-16, seed=np_rng)

# result_dict[method] = (loss history from callback.state['fval'], wall-clock seconds)
result_dict = dict()
for method in method_list:
    print(method)
    model = LowEnergySpectrum(dim, rank, method=method)
    model.set_matH(matH)
    callback = numqi.optimize.MinimizeCallback(print_freq=1, tag_print=False)
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t_start = time.perf_counter()
    theta_optim = numqi.optimize.minimize(model, callback=callback, **kwargs)
    result_dict[method] = np.array(callback.state['fval']), time.perf_counter() - t_start
so-exp
[round=0] min(f)=-30.901894423783205, current(f)=-30.901894423783205
so-cayley
[round=0] min(f)=-30.901894423785713, current(f)=-30.901894423785713
qr
[round=0] min(f)=-30.901894423776174, current(f)=-30.901894423776174
polar
[round=0] min(f)=-30.901894423783762, current(f)=-30.901894423783762
choleskyL
[round=0] min(f)=-30.899152486112257, current(f)=-30.899152486112257
In [4]:
Copied!
# Convergence of every parametrization: distance of the loss from the exact
# optimum ret_ at each iteration, with total optimization time in the legend.
fig, ax = plt.subplots()
for method in method_list:
    fval_history, elapsed = result_dict[method]
    # NOTE: iterations where round-off pushes fval below ret_ give non-positive
    # gaps, which cannot be drawn faithfully on the log-scale axis.
    ax.plot(fval_history - ret_, label=f'{method} ({elapsed:.2f}s)')
ax.set_yscale('log')
ax.set_xlabel('iteration')
ax.set_ylabel('loss - optimal loss')
ax.set_title(f'St({dim},{rank}) manifold')
ax.grid()
ax.legend()
fig.tight_layout()