本記事では、時系列モデリング手法HiPPOの理解を目指し、著者実装をstep-by-stepで動かす。
参考にする著者実装はこちら。
なお、HiPPOの理論は第一部の記事にまとめたのでそちらも参照されたい。
必要モジュールのインポート
from functools import partial import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.data as data import numpy as np from scipy import signal from scipy import linalg as la from scipy import special as ss import matplotlib.pyplot as plt device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
テスト用データ作成
ノイズを加えた時系列データを作成するための、whitesignal()が用意されている。ノイズの周波数の上限をfreq引数で制御できるため、過度に急峻なノイズが加わることを防ぐことができる。
def whitesignal(period, dt, freq, rms=0.5, batch_shape=()): """ Produces output signal of length period / dt, band-limited to frequency freq Output shape (*batch_shape, period/dt) Adapted from the nengo library """ if freq is not None and freq < 1. / period: raise ValueError(f"Make ``{freq} >= 1. / {period}`` to produce a non-zero signal",) nyquist_cutoff = 0.5 / dt if freq > nyquist_cutoff: raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})") n_coefficients = int(np.ceil(period / dt / 2.)) shape = batch_shape + (n_coefficients + 1,) sigma = rms * np.sqrt(0.5) coefficients = 1j * np.random.normal(0., sigma, size=shape) coefficients[..., -1] = 0. coefficients += np.random.normal(0., sigma, size=shape) coefficients[..., 0] = 0. set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq coefficients *= (1-set_to_zero) power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients) if power_correction > 0.: coefficients /= power_correction coefficients *= np.sqrt(2 * n_coefficients) signal = np.fft.irfft(coefficients, axis=-1) signal = signal - signal[..., :1] # Start from 0 return signal
時系列データを作成してプロットしてみよう。
np.random.seed(0) T=4 dt=5e-4 N=64 freq=20.0 vals = np.arange(0.0, T+dt, dt) L = int(T / dt) + 1 u = torch.FloatTensor(whitesignal(T, dt, freq=freq)) u = F.pad(u, (1, 0)) u = u + torch.FloatTensor( np.sin(1.5 * np.pi / T * np.arange(0, T + dt, dt)) ) # add 3/4 of a sin cycle u = u.to(device) plt.figure(figsize=(12, 4)) offset = 0.0 plt.plot(vals, u.cpu() + offset, "k", linewidth=1.0)
HiPPO-LegSのクラスを用意
以下の関数を持つHiPPOScale
クラスを作成する。
__init__()
: LTI方程式の係数行列 A,B を計算しておくforward()
: 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算するreconstruct()
: 係数ベクトル c から現在までの時系列信号を再構成する
class HiPPOScale(nn.Module): """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)""" def __init__(self, N, max_length=1024): """ max_length: maximum sequence length """ super().__init__() # HiPPO行列 A,B を計算しておく def forward(self, inputs): """ inputs : (length, ...) output : (length, ..., N) where N is the order of the HiPPO projection """ # 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算する ... def reconstruct(self, c): # 係数ベクトル c から、現在までの時系列信号を再構成する ...
HiPPO行列の計算
LegS測度
のもとでLTI方程式
の係数行列は次のように与えられるのであった。
これは以下のtransition()
で計算される。
def transition(N): # Legendre (scaled) # q = np.arange(N, dtype=np.float64) q = np.arange(N, dtype=np.float32) col, row = np.meshgrid(q, q) r = 2 * q + 1 M = -(np.where(row >= col, r, 0) - np.diag(q)) T = np.sqrt(np.diag(2 * q + 1)) A = T @ M @ np.linalg.inv(T) B = np.diag(T)[:, None] B = ( B.copy() ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) return A, B
試しにN=3でHiPPO行列を計算してみる。
A, B = transition(3) B = B.squeeze(-1)
まず行列Aの成分を確認しよう。
# print(A) [[-1. 0. 0. ] [-1.7320508 -1.9999999 0. ] [-2.236068 -3.8729835 -3. ]]
に関して対角成分は
下三角の非対角成分は であることが確認できる。
次に行列Bの成分を確認しよう。
# print(B) [1. 1.7320508 2.236068 ]
各成分がであることが確認できる。
離散化した常微分方程式による時間発展
Bilinear法では以下の式で係数cが更新される。
max_length = L
A, B = transition(N)
B = B.squeeze(-1)
A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
B_stacked = np.empty((max_length, N), dtype=B.dtype)
solve_triangular(a,b)
は、が三角行列であると仮定して、方程式
を
について解く
以下では、方程式
および
をについて解くことで各時刻の
および
の値を求めている。
計算された値はA_stackedとB_stackedにそれぞれ保存される。
for t in range(1, max_length + 1): At = A / t Bt = B / t # bilinear A_stacked[t - 1] = la.solve_triangular( np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True ) B_stacked[t - 1] = la.solve_triangular( np.eye(N) - At / 2, Bt, lower=True )
それではこれをHiPPO-LegSの初期化操作に追加しよう。 以下のコードではPyTorchのregister_buffer()を使用する。 register_buffer()によって、 nn.Moduleの中で、最適化されるパラメータ以外のパラメータを定義することができる。例えば、以下のコードではA_stackedやB_stackedはmodel.parametersでは出てこない(最適化パラメータではない)が、model.state_dictでは出てくる(保存が可能)。
class HiPPOScale(nn.Module): """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)""" def __init__(self, N, max_length=1024): """ max_length: maximum sequence length """ super().__init__() self.N = N A, B = transition(N) B = B.squeeze(-1) A_stacked = np.empty((max_length, N, N), dtype=A.dtype) B_stacked = np.empty((max_length, N), dtype=B.dtype) for t in range(1, max_length + 1): At = A / t Bt = B / t # bilinear A_stacked[t - 1] = la.solve_triangular( np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True ) B_stacked[t - 1] = la.solve_triangular( np.eye(N) - At / 2, Bt, lower=True ) self.register_buffer("A_stacked", torch.Tensor(A_stacked)) # (max_length, N, N) self.register_buffer("B_stacked", torch.Tensor(B_stacked)) # (max_length, N) def forward(self, inputs): """ inputs : (length, ...) output : (length, ..., N) where N is the order of the HiPPO projection """ # 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算する ... def reconstruct(self, c): # 係数ベクトル c から、現在までの時系列信号を再構成する ...
forward()の実装
forward()
では式(1)の計算を行い、各時刻の係数ベクトルcを逐次的に求める。次のように実装されている。
class HiPPOScale(nn.Module): ... def forward(self, inputs): """ inputs : (length, ...) output : (length, ..., N) where N is the order of the HiPPO projection """ L = inputs.shape[0] inputs = inputs.unsqueeze(-1) x = torch.transpose(inputs, 0, -2) x = x * self.B_stacked[:L] x = torch.transpose(x, 0, -2) # (length, ..., N) c = torch.zeros(x.shape[1:]).to(inputs) cs = [] for t, f in enumerate(inputs): c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f cs.append(c) return torch.stack(cs, dim=0) ...
forward()
の前半では係数cの初期ベクトルを作成する。
L = inputs.shape[0] inputs = inputs.unsqueeze(-1) x = torch.transpose(inputs, 0, -2) x = x * self.B_stacked[:L] x = torch.transpose(x, 0, -2) # (length, ..., N) c = torch.zeros(x.shape[1:]).to(inputs)
後半で式(1)の計算を逐次的に行い、最後にcsをTensor形式で結合する。 torch.stack
では新しいdimを作成し、そのdimに沿ってテンソルを連結することができる。
cs = [] for t, f in enumerate(inputs): c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f cs.append(c) torch.stack(cs, dim=0)
それでは時系列信号u
を入力してHiPPO係数を計算してみよう。
hippo = HiPPOScale(N=N, max_length=int(T / dt)+1).to(device) cs = hippo(u)
HiPPO係数は以下のように計算される。
# print(cs.shape) # print(cs) torch.Size([8001, 64]) tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, 0.0000e+00], [ 2.3562e-04, 2.7207e-04, 1.5053e-04, ..., 1.9469e-11, -2.3281e-11, 1.7233e-11], [ 1.1005e-02, 1.4125e-02, 9.9080e-03, ..., 6.9807e-10, -2.7768e-10, -5.0084e-10], ..., [ 6.1107e-01, -4.5706e-01, -3.7144e-01, ..., 4.1828e-03, -2.0363e-02, -1.4803e-02], [ 6.1086e-01, -4.5731e-01, -3.7155e-01, ..., 4.1581e-03, -2.0260e-02, -1.4424e-02], [ 6.1066e-01, -4.5755e-01, -3.7165e-01, ..., 4.1496e-03, -2.0141e-02, -1.4033e-02]])
信号の再構成
最後にHiPPO係数から時系列信号を再構成してみる。これはHiPPOによって作成された記憶を頼りに、過去の信号履歴を復元する操作である。
まず]の区間で時間グリッドを用意する。
vals = np.linspace(0.0, 1.0, max_length)
測度関数
def measure_fn(c=0.0): # legs fn = lambda x: np.heaviside(x, 1.0) * np.exp(-x) fn_tilted = lambda x: np.exp(c * x) * fn(x) return fn_tilted measure = measure_fn()(vals)
基底関数
LegSで用いられる直交基底であるLegendre多項式を計算しておく。
def basis(N, vals, c=0.0, truncate_measure=True): """ vals: list of times (forward in time) returns: shape (T, N) where T is length of vals """ # legs _vals = np.exp(-vals) eval_matrix = ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * _vals).T # (L, N) eval_matrix *= (2 * np.arange(N) + 1) ** 0.5 * (-1) ** np.arange(N) if truncate_measure: eval_matrix[measure_fn()(vals) == 0.0] = 0.0 p = torch.tensor(eval_matrix) p *= np.exp(-c * vals)[:, None] # [::-1, None] return p eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)
0~6次のLegendre多項式を以下に図示する。
再構成
HiPPOにより求めた係数cから時系列信号を再構成する。これは係数と基底行列の積を取ることで求められる。
rec = eval_matrix.to(cs) @ cs.unsqueeze(-1)
plt.figure(figsize=(14, 5)) offset = 0.0 plt.plot(vals, u.cpu() + offset, "k", linewidth=1.0, label="input $u(t)$") plt.plot(vals, rec[-1], label="HiPPO reconstruction", color="red") plt.xlim(0,1) plt.legend()
まとめ
以上、HiPPOによる著者実装を見てきた。実装は極めてシンプルで、直交基底に対する各基底の係数を式(1)で時間発展させることで、時系列データをモデリングすることができた。このように簡略な数値計算で強力な時系列モデリングを実現する理論的枠組みの強力さを再認識した。