どすえのブログ

ソフトウェア開発ブログ

時系列モデリング手法 HiPPO を読み解く(2)

本記事では、時系列モデリング手法HiPPOの理解を目指し、著者実装をstep-by-stepで動かす。

参考にする著者実装はこちら。

github.com

なお、HiPPOの理論は第一部の記事にまとめたのでそちらも参照されたい。

dosuex.com

必要モジュールのインポート

from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import numpy as np
from scipy import signal
from scipy import linalg as la
from scipy import special as ss
import matplotlib.pyplot as plt

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

テスト用データ作成

ノイズを加えた時系列データを作成するための、whitesignal()が用意されている。ノイズの周波数の上限をfreq引数で制御できるため、過度に急峻なノイズが加わることを防ぐことができる。

def whitesignal(period, dt, freq, rms=0.5, batch_shape=()):
    """
    Produces output signal of length period / dt, band-limited to frequency freq
    Output shape (*batch_shape, period/dt)
    Adapted from the nengo library
    """

    if freq is not None and freq < 1. / period:
        raise ValueError(f"Make ``{freq} >= 1. / {period}`` to produce a non-zero signal",)

    nyquist_cutoff = 0.5 / dt
    if freq > nyquist_cutoff:
        raise ValueError(f"{freq} must not exceed the Nyquist frequency for the given dt ({nyquist_cutoff:0.3f})")

    n_coefficients = int(np.ceil(period / dt / 2.))
    shape = batch_shape + (n_coefficients + 1,)
    sigma = rms * np.sqrt(0.5)
    coefficients = 1j * np.random.normal(0., sigma, size=shape)
    coefficients[..., -1] = 0.
    coefficients += np.random.normal(0., sigma, size=shape)
    coefficients[..., 0] = 0.

    set_to_zero = np.fft.rfftfreq(2 * n_coefficients, d=dt) > freq
    coefficients *= (1-set_to_zero)
    power_correction = np.sqrt(1. - np.sum(set_to_zero, dtype=float) / n_coefficients)
    if power_correction > 0.: coefficients /= power_correction
    coefficients *= np.sqrt(2 * n_coefficients)
    signal = np.fft.irfft(coefficients, axis=-1)
    signal = signal - signal[..., :1]  # Start from 0
    return signal

時系列データを作成してプロットしてみよう。

np.random.seed(0)

T=4
dt=5e-4
N=64
freq=20.0

vals = np.arange(0.0, T+dt, dt)
L = int(T / dt) + 1

u = torch.FloatTensor(whitesignal(T, dt, freq=freq))
u = F.pad(u, (1, 0))
u = u + torch.FloatTensor(
    np.sin(1.5 * np.pi / T * np.arange(0, T + dt, dt))
)  # add 3/4 of a sin cycle
u = u.to(device)

plt.figure(figsize=(12, 4))
offset = 0.0
plt.plot(vals, u.cpu() + offset, "k", linewidth=1.0)

テスト用時系列信号

HiPPO-LegSのクラスを用意

以下の関数を持つHiPPOScaleクラスを作成する。

  • __init__() : LTI方程式の係数行列 A,B を計算しておく
  • forward() : 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算する
  • reconstruct() : 係数ベクトル c から現在までの時系列信号を再構成する
class HiPPOScale(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""
    
    def __init__(self, N, max_length=1024):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        # HiPPO行列 A,B を計算しておく
    
    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        # 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算する
        ...
        
    def reconstruct(self, c):
        # 係数ベクトル c から、現在までの時系列信号を再構成する
        ...

HiPPO行列の計算

LegS測度

 
\mu^{(t)}=\frac{1}{t} \mathbb{I}_{[0, t]}

のもとでLTI方程式

 
\frac{d}{d t} c(t)=-\frac{1}{t} A c(t)+\frac{1}{t} B f(t)

の係数行列は次のように与えられるのであった。

 
A_{n k}=\left\{\begin{array}{ll}
(2 n+1)^{1 / 2}(2 k+1)^{1 / 2} \quad \text { if } n>k \\
n+1 \quad \text { if } n=k, \\
0 \quad \text { if } n \lt k
\end{array} \quad B_n=(2 n+1)^{\frac{1}{2}}\right.

これは以下のtransition()で計算される。

def transition(N):
    # Legendre (scaled)
    # q = np.arange(N, dtype=np.float64)
    q = np.arange(N, dtype=np.float32)
    col, row = np.meshgrid(q, q)
    r = 2 * q + 1
    M = -(np.where(row >= col, r, 0) - np.diag(q))
    T = np.sqrt(np.diag(2 * q + 1))
    A = T @ M @ np.linalg.inv(T)
    B = np.diag(T)[:, None]
    B = (
        B.copy()
    )  # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B)
    return A, B

試しにN=3でHiPPO行列を計算してみる。

A, B = transition(3)
B = B.squeeze(-1)

まず行列Aの成分を確認しよう。

# print(A)
[[-1.         0.         0.       ]
 [-1.7320508 -1.9999999  0.       ]
 [-2.236068  -3.8729835 -3.       ]]

 n=0,1,2 に関して対角成分は n+1
下三角の非対角成分は (2n+1)^{\frac{1}{2}}(2k+1)^{\frac{1}{2}} であることが確認できる。

次に行列Bの成分を確認しよう。

# print(B)
[1.        1.7320508 2.236068 ]

各成分が (2n+1)^{\frac{1}{2}} であることが確認できる。

離散化した常微分方程式による時間発展

Bilinear法では以下の式で係数cが更新される。

 
c(t+\Delta t) = (I-\Delta t / 2A)^{-1}(I+\Delta t/2A)c(t)+\Delta t(I-\Delta t / 2A)^{-1}B f(t) \tag{1}
max_length = L
A, B = transition(N)
B = B.squeeze(-1)

A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
B_stacked = np.empty((max_length, N), dtype=B.dtype)

solve_triangular(a,b)は、 aが三角行列であると仮定して、方程式 a x = b xについて解く
以下では、方程式

 
(I-\Delta t / 2A) x = (I+\Delta t/2A)

および

 
(I-\Delta t / 2A) x = B

 xについて解くことで各時刻の (I-\Delta t / 2A)^{-1}(I+\Delta t/2A)および (I-\Delta t / 2A)^{-1}Bの値を求めている。 計算された値はA_stackedとB_stackedにそれぞれ保存される。

for t in range(1, max_length + 1):
    At = A / t
    Bt = B / t
    # bilinear
    A_stacked[t - 1] = la.solve_triangular(
        np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True
    )
    B_stacked[t - 1] = la.solve_triangular(
        np.eye(N) - At / 2, Bt, lower=True
    )

それではこれをHiPPO-LegSの初期化操作に追加しよう。 以下のコードではPyTorchのregister_buffer()を使用する。 register_buffer()によって、 nn.Moduleの中で、最適化されるパラメータ以外のパラメータを定義することができる。例えば、以下のコードではA_stackedやB_stackedはmodel.parametersでは出てこない(最適化パラメータではない)が、model.state_dictでは出てくる(保存が可能)。

class HiPPOScale(nn.Module):
    """Vanilla HiPPO-LegS model (scale invariant instead of time invariant)"""
    
    def __init__(self, N, max_length=1024):
        """
        max_length: maximum sequence length
        """
        super().__init__()
        self.N = N
        A, B = transition(N)
        B = B.squeeze(-1)
        A_stacked = np.empty((max_length, N, N), dtype=A.dtype)
        B_stacked = np.empty((max_length, N), dtype=B.dtype)
        for t in range(1, max_length + 1):
            At = A / t
            Bt = B / t
            # bilinear
            A_stacked[t - 1] = la.solve_triangular(
                np.eye(N) - At / 2, np.eye(N) + At / 2, lower=True
            )
            B_stacked[t - 1] = la.solve_triangular(
                np.eye(N) - At / 2, Bt, lower=True
            )
        self.register_buffer("A_stacked", torch.Tensor(A_stacked))  # (max_length, N, N)
        self.register_buffer("B_stacked", torch.Tensor(B_stacked))  # (max_length, N)
    
    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        # 時系列信号を入力としてLTI方程式を解き、係数ベクトル c の時間発展を計算する
        ...
        
    def reconstruct(self, c):
        # 係数ベクトル c から、現在までの時系列信号を再構成する
        ...

forward()の実装

forward()では式(1)の計算を行い、各時刻の係数ベクトルcを逐次的に求める。次のように実装されている。

class HiPPOScale(nn.Module):
    ...
    def forward(self, inputs):
        """
        inputs : (length, ...)
        output : (length, ..., N) where N is the order of the HiPPO projection
        """
        L = inputs.shape[0]

        inputs = inputs.unsqueeze(-1)
        x = torch.transpose(inputs, 0, -2)
        x = x * self.B_stacked[:L]
        x = torch.transpose(x, 0, -2)  # (length, ..., N)

        c = torch.zeros(x.shape[1:]).to(inputs)
        cs = []
        for t, f in enumerate(inputs):
            c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f
            cs.append(c)
        return torch.stack(cs, dim=0)
    ...

forward()の前半では係数cの初期ベクトルを作成する。

L = inputs.shape[0]

inputs = inputs.unsqueeze(-1)
x = torch.transpose(inputs, 0, -2)
x = x * self.B_stacked[:L]
x = torch.transpose(x, 0, -2)  # (length, ..., N)

c = torch.zeros(x.shape[1:]).to(inputs)

後半で式(1)の計算を逐次的に行い、最後にcsをTensor形式で結合する。 torch.stackでは新しいdimを作成し、そのdimに沿ってテンソルを連結することができる。

cs = []
for t, f in enumerate(inputs):
    c = F.linear(c, self.A_stacked[t]) + self.B_stacked[t] * f
    cs.append(c)
torch.stack(cs, dim=0)

それでは時系列信号uを入力してHiPPO係数を計算してみよう。

hippo = HiPPOScale(N=N, max_length=int(T / dt)+1).to(device)
cs = hippo(u)

HiPPO係数は以下のように計算される。

# print(cs.shape)
# print(cs)
torch.Size([8001, 64])
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.3562e-04,  2.7207e-04,  1.5053e-04,  ...,  1.9469e-11,
         -2.3281e-11,  1.7233e-11],
        [ 1.1005e-02,  1.4125e-02,  9.9080e-03,  ...,  6.9807e-10,
         -2.7768e-10, -5.0084e-10],
        ...,
        [ 6.1107e-01, -4.5706e-01, -3.7144e-01,  ...,  4.1828e-03,
         -2.0363e-02, -1.4803e-02],
        [ 6.1086e-01, -4.5731e-01, -3.7155e-01,  ...,  4.1581e-03,
         -2.0260e-02, -1.4424e-02],
        [ 6.1066e-01, -4.5755e-01, -3.7165e-01,  ...,  4.1496e-03,
         -2.0141e-02, -1.4033e-02]])

信号の再構成

最後にHiPPO係数から時系列信号を再構成してみる。これはHiPPOによって作成された記憶を頼りに、過去の信号履歴を復元する操作である。

まず [0,1]の区間で時間グリッドを用意する。

vals = np.linspace(0.0, 1.0, max_length)

測度関数

def measure_fn(c=0.0):
    # legs
    fn = lambda x: np.heaviside(x, 1.0) * np.exp(-x)
    
    fn_tilted = lambda x: np.exp(c * x) * fn(x)
    return fn_tilted

measure = measure_fn()(vals)

基底関数

LegSで用いられる直交基底であるLegendre多項式を計算しておく。

def basis(N, vals, c=0.0, truncate_measure=True):
    """
    vals: list of times (forward in time)
    returns: shape (T, N) where T is length of vals
    """
    # legs
    _vals = np.exp(-vals)
    eval_matrix = ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * _vals).T  # (L, N)
    eval_matrix *= (2 * np.arange(N) + 1) ** 0.5 * (-1) ** np.arange(N)
    
    if truncate_measure:
        eval_matrix[measure_fn()(vals) == 0.0] = 0.0
    
    p = torch.tensor(eval_matrix)
    p *= np.exp(-c * vals)[:, None]  # [::-1, None]
    return p

eval_matrix = torch.Tensor((B[:, None] * ss.eval_legendre(np.arange(N)[:, None], 2 * vals - 1)).T)

0~6次のLegendre多項式を以下に図示する。

Legendre多項式

再構成

HiPPOにより求めた係数cから時系列信号を再構成する。これは係数と基底行列の積を取ることで求められる。

rec = eval_matrix.to(cs) @ cs.unsqueeze(-1)
plt.figure(figsize=(14, 5))
offset = 0.0
plt.plot(vals, u.cpu() + offset, "k", linewidth=1.0, label="input $u(t)$")
plt.plot(vals, rec[-1], label="HiPPO reconstruction", color="red")
plt.xlim(0,1)
plt.legend()

HiPPOによる時系列入力の再構成

まとめ

以上、HiPPOによる著者実装を見てきた。実装は極めてシンプルで、直交基底に対する各基底の係数を式(1)で時間発展させることで、時系列データをモデリングすることができた。このように簡略な数値計算で強力な時系列モデリングを実現する理論的枠組みの強力さを再認識した。

参考文献

  1. HiPPO: Recurrent Memory with Optimal Polynomial Projections
  2. GitHub - HazyResearch/state-spaces: Sequence Modeling with Structured State Spaces