June 9, 2024

Efficient WaveRNN: Optimizing Nonlinearities

Implementing efficient nonlinearities for WaveRNN CPU inference is tricky but critical.

In this series of posts, we're going to go through the WaveRNN neural vocoder for audio waveform synthesis, along with a variety of implementation details and commonly used extensions. For a real implementation, check out the gibiansky/wavernn repository.

Posts in the Series:

Accelerating Nonlinearities

In previous posts, we talked about a wide variety of optimizations, including rewriting the compute kernel in C++, batching GRU input matrix multiplication, using block-sparse matrix-vector multiplication, SIMD intrinsics, and quantization. Implementing all of these can create a faster-than-realtime synthesis kernel for WaveRNN, but there's still room to squeeze more out of our processors. Benchmarking our kernel, we observed that 10-20% of the time is spent in nonlinearities, including the tanh and sigmoid in the GRU and the softmax in the output layer.

Approximating Tanh and Sigmoid

The core autoregressive component of WaveRNN is a Gated Recurrent Unit (GRU) RNN. A GRU uses a state update function which requires two sigmoids and one tanh evaluation per state dimension. To compute these, you can use C++ standard library functions (very slow) or the Intel MKL library on x86 (faster). In either case, these nonlinearities will end up being a significant percentage of your inference time.

To speed these up, we can implement our own tanh and sigmoid which are slightly less accurate than standard library or MKL variants, but are significantly faster. First of all, we can write sigmoid as a rescaled tanh:

$$\sigma(x) = \frac{1}{2} \tanh\left(\frac{x}{2}\right) + \frac{1}{2}.$$
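A quick numerical check of this identity, as a minimal Python sketch (the function names are illustrative and are not taken from the WaveRNN implementation):

```python
import math

def sigmoid_reference(x: float) -> float:
    """Textbook definition of the logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_via_tanh(x: float) -> float:
    """Sigmoid written as a rescaled, shifted tanh."""
    return 0.5 * math.tanh(0.5 * x) + 0.5

# The two forms agree up to floating-point rounding.
for x in (-5.0, -1.0, 0.0, 0.5, 3.0):
    assert abs(sigmoid_reference(x) - sigmoid_via_tanh(x)) < 1e-12
```

In practice this means a single fast tanh routine is enough to cover both GRU nonlinearities.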

We will then use a Padé approximation of tanh. A Padé approximation of order [p/q] for a function $f$ is a rational function $F(x)$ of the form

$$F(x) = \frac{a_0 + a_1 x + a_2 x^2 + \cdots + a_p x^p}{1 + b_1 x + b_2 x^2 + \cdots + b_q x^q}.$$

Higher values of $p$ and $q$ allow for more precise approximations at the expense of additional computation.

The coefficients of a Padé approximation are chosen so that $F$ and its first $p + q$ derivatives match those of tanh at zero:

\begin{align*}F(0) &= \tanh(0) \\ F'(0) &= \tanh'(0) \\ F''(0) &= \tanh''(0) \\ \vdots \\ F^{(p+q)}(0) &= \tanh^{(p+q)}(0) \end{align*}

This system of equations can be solved to yield unique coefficients for this approximation. That said, I'm lazy and will instead take advantage of some resources to make this easy:

Given these resources, we can implement an efficient tanh approximation with AVX, AVX-512, or NEON. This approximation can be fused with the sparse GRU GEMV if necessary and results in a 3-4X speedup for these nonlinearities over Intel MKL (which itself is a huge speedup over C++ standard library functions).
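As a scalar illustration of the idea, here is a minimal Python sketch of the [3/2] Padé approximant of tanh, $\tanh(x) \approx \frac{x (15 + x^2)}{15 + 6 x^2}$, which can be obtained by truncating the continued fraction expansion of tanh. The order, the output clamp, and the function names are assumptions chosen for illustration. This post does not say which order the actual kernel uses, and a production version would be written with SIMD intrinsics rather than scalar Python.

```python
import math

def pade_tanh(x: float) -> float:
    """[3/2] Pade approximant of tanh, clamped to [-1, 1]."""
    x2 = x * x
    y = x * (15.0 + x2) / (15.0 + 6.0 * x2)
    # The rational function keeps growing for large |x| while tanh saturates,
    # so clamp the result to the valid output range of tanh.
    return max(-1.0, min(1.0, y))

def pade_sigmoid(x: float) -> float:
    """Sigmoid via the rescaled-tanh identity, reusing the approximation."""
    return 0.5 * pade_tanh(0.5 * x) + 0.5

def max_abs_error(lo: float = -8.0, hi: float = 8.0, steps: int = 1601) -> float:
    """Largest |pade_tanh - tanh| on an evenly spaced grid (illustrative)."""
    worst = 0.0
    for k in range(steps):
        x = lo + (hi - lo) * k / (steps - 1)
        worst = max(worst, abs(pade_tanh(x) - math.tanh(x)))
    return worst
```

With this low order, the error is tiny near zero and grows to a couple of hundredths near the point where the clamp takes over. Higher orders tighten this at the cost of more arithmetic, which is the trade-off described above.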

Faster Sampling with Gumbel-Softmax

In addition to the tanh and sigmoid in the GRU layer, our WaveRNN performs a softmax after its output layer and prior to sampling. Traditionally, this process has the following steps:

1. Given the final layer outputs $x_i$, compute the maximum value $x_{\text{max}}$.
2. Subtract the maximum value from each $x_i$.
3. Compute $e^{x_i - x_{\text{max}}}$ for every $x_i$ in the logits.
4. Compute the normalization term, $\left(\sum_i e^{x_i - x_{\text{max}}}\right)^{-1}$.
5. Multiply the exponentiated values by the normalization to get a probability distribution $p_i$.
6. Sample a random value $v \sim U(0, 1)$ from the uniform distribution [0, 1]. (In order to accelerate sampling in a tight inner loop, pre-compute several thousand random numbers and cycle through them.)
7. Find the first index $i$ at which the cumulative sum of the probabilities, up to and including $p_i$, is greater than $v$.

The index $i$ is a sample from your discrete probability distribution.
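The steps above can be sketched in plain Python as follows. This is an illustrative reference version, not the optimized kernel; the function name and the choice to pass the uniform sample `v` in as an argument are assumptions made for clarity.

```python
import math

def sample_softmax(logits, v):
    """Sample an index from softmax(logits) given a uniform sample v in [0, 1)."""
    # Steps 1-3: find the maximum, subtract it for stability, and exponentiate.
    x_max = max(logits)
    exps = [math.exp(x - x_max) for x in logits]
    # Step 4: normalization term.
    norm = 1.0 / sum(exps)
    # Steps 5-7: accumulate probabilities until the running sum exceeds v.
    cumulative = 0.0
    for i, e in enumerate(exps):
        cumulative += e * norm
        if cumulative > v:
            return i
    # Guard against rounding leaving the final cumulative sum just below v.
    return len(logits) - 1
```

A call such as `sample_softmax(logits, random.random())` draws one sample. Note the separate passes over the logits for the maximum, the exponentials, and the cumulative sum.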

This sampling procedure has a few downsides. It requires scanning through your logits at least three times: once to find the maximum, once to exponentiate and compute normalization terms, and once to sample the final value. It also requires computing $e^x$, which is an expensive nonlinearity to compute accurately.

We can address both of these downsides using Gumbel-Softmax. Gumbel-Softmax was originally introduced in order to approximate sampling during training of a neural network, so a discrete sampling step could be introduced in an intermediate layer of a network. The key point for our purposes, however, is the following:

• Sample $v_i \sim U(0, 1)$ from the uniform distribution [0, 1].
• Compute $g_i = -\ln(-\ln v_i)$. $g_i$ is a sample from the Gumbel distribution.
• Compute modified logits ${\hat x}_i = x_i + g_i$. (You must sample a distinct $g_i$ for each element of the logits.)
• The index $i = \text{argmax}\; \hat x$ is a sample from the discrete distribution $\text{softmax}(x)$.

You can find the derivation for this neat property in the Gumbel-Softmax paper or the concrete distribution paper (which refers to the same distribution).
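For reference, here is a short sketch of why this works (the argmax step on its own is often called the Gumbel-max trick). Since $v_i$ is uniform, $g_i = -\ln(-\ln v_i)$ has CDF $\exp(-e^{-g})$, so ${\hat x}_i = x_i + g_i$ has CDF $\exp(-e^{x_i - t})$ and density $e^{x_i - t} \exp(-e^{x_i - t})$. Writing $Z = \sum_j e^{x_j}$ and substituting $u = e^{-t}$:

\begin{align*}
P(\text{argmax}\; \hat x = k) &= \int_{-\infty}^{\infty} e^{x_k - t} \exp\left(-e^{x_k - t}\right) \prod_{j \neq k} \exp\left(-e^{x_j - t}\right) dt \\
&= e^{x_k} \int_{-\infty}^{\infty} e^{-t} \exp\left(-Z e^{-t}\right) dt \\
&= e^{x_k} \int_0^{\infty} e^{-Z u} \, du = \frac{e^{x_k}}{Z}
\end{align*}

which is exactly the softmax probability of index $k$.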

Using this property, we can do our sampling in a single pass. Given a vector of logits $x_i$ and a vector of pre-computed samples $g_i$ from the Gumbel distribution, we can compute $x_i + g_i$. As we compute this sum, keep track of the maximum value so far and the index of the maximum value. When you reach the end of $x$, the resulting index is a sample from your distribution.
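A minimal Python sketch of this single-pass sampler is below. The function names are hypothetical, and for simplicity the noise is drawn fresh on each call rather than cycled from a pre-computed table as described above.

```python
import math
import random

def gumbel_noise(n, rng):
    """Draw n samples g = -ln(-ln v) with v uniform on the open interval (0, 1)."""
    noise = []
    for _ in range(n):
        v = rng.random()
        while v == 0.0:  # random() can return 0.0, where ln(v) is undefined
            v = rng.random()
        noise.append(-math.log(-math.log(v)))
    return noise

def sample_gumbel_max(logits, gumbel):
    """Return argmax_i (logits[i] + gumbel[i]) in a single pass."""
    best_index = 0
    best_value = logits[0] + gumbel[0]
    for i in range(1, len(logits)):
        value = logits[i] + gumbel[i]
        if value > best_value:
            best_value = value
            best_index = i
    return best_index

rng = random.Random(0)  # fixed seed so the example is repeatable
logits = [0.0, math.log(3.0)]  # softmax(logits) = [0.25, 0.75]
index = sample_gumbel_max(logits, gumbel_noise(len(logits), rng))
```

The inner loop needs only an add and a compare per element. No exponentials and no normalization are required at sampling time, because the logarithms are paid for when the noise is generated.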

Although sampling was a small fraction (5%) of our inference costs, applying this small optimization sped it up by a factor of five, making the cost of sampling negligible.

Summary

Once the matrix-vector multiplies in WaveRNN are sufficiently optimized, the nonlinearities (tanh, sigmoid, and softmax) start being a significant fraction of the WaveRNN inference time. Using Padé approximants, we can approximate tanh and sigmoid by an easy-to-compute rational function which can be computed in only a few arithmetic instructions. We can speed up sampling from a softmax by using the Gumbel-Softmax trick, drawing samples from a Gumbel distribution and taking an argmax of the sampled value plus the logit in order to sample from the original softmax distribution. These two optimizations, while small, can speed up inference by 10-15%.

Check out the implementation at gibiansky/wavernn or any of the other blog posts in this series: