A Problem about RNNs

I want to write here about a problem I became interested in after meeting Boris Hanin at the Simons Institute Foundations of Deep Learning program a few summers ago. I talked briefly with Boris about his work on the number of affine pieces of random neural networks.

Boris's work showed that the number of affine pieces in a vanilla feedforward ReLU network is limited: roughly, at initialization the number of pieces grows with the total number of neurons rather than exponentially with depth. I began to wonder whether this was true for RNNs as well. RNNs apply the same map over and over, and it seems like this recursion could make their outputs much more complicated. With carefully chosen weights they can even approximate very weird functions like the devil's staircase. I ended up with the question: is there some input data you can feed a ReLU RNN so that it has a chance of learning a function with infinitely many affine pieces?
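To make the piece counting concrete, here is a minimal Python sketch of one standard iteration that converges to the devil's staircase (the Cantor function). The post does not say which iteration it has in mind, so the recursion below, starting from \(C_0(x) = x\), is an assumption, and the names `staircase` and `count_pieces` are illustrative. Counting pieces exactly with rational arithmetic shows that the \(k\)th iterate has \(2^{k+1}-1\) affine pieces, so the count grows exponentially with the number of iterations, while the limit is not piecewise affine with finitely many pieces at all.

```python
from fractions import Fraction


def staircase(x, k):
    """k-th iterate C_k of one devil's staircase construction (assumed form).

    C_0(x) = x, and for k >= 1:
      C_k(x) = C_{k-1}(3x) / 2          on [0, 1/3]
      C_k(x) = 1/2                      on [1/3, 2/3]
      C_k(x) = 1/2 + C_{k-1}(3x - 2)/2  on [2/3, 1]
    """
    if k == 0:
        return x
    if x <= Fraction(1, 3):
        return staircase(3 * x, k - 1) / 2
    if x < Fraction(2, 3):
        return Fraction(1, 2)
    return Fraction(1, 2) + staircase(3 * x - 2, k - 1) / 2


def count_pieces(k):
    """Exact number of affine pieces of C_k on [0, 1].

    Every breakpoint of C_k lies on the grid j / 3**k, so C_k is affine on
    each grid cell; neighbouring cells with equal slope form one piece.
    """
    n = 3 ** k
    values = [staircase(Fraction(j, n), k) for j in range(n + 1)]
    slopes = [(values[j + 1] - values[j]) * n for j in range(n)]
    # A new piece starts wherever the slope changes.
    return 1 + sum(1 for a, b in zip(slopes, slopes[1:]) if a != b)


for k in range(5):
    print(k, count_pieces(k))  # pieces: 1, 3, 7, 15, 31, i.e. 2**(k + 1) - 1
```

The count follows the recursion \(p_k = 2p_{k-1} + 1\): two scaled copies of the previous iterate plus the new flat step in the middle.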

Here is the conjecture I had, in formal language, and my thoughts:

Define a \(k\)-depth \(w\)-width 1-hidden-layer ReLU RNN, parametrized by the weight matrix \(W \in \mathbb{R}^{w \times w}\), to be the function $$ f_k(W, x) = (\text{ReLU} \circ W \circ \cdots \circ \text{ReLU} \circ W)(x)_1, \qquad x \in \mathbb{R}^w, $$ with \(k\) applications of \(W\) and no bias terms. The subscript 1 denotes that the output of the network is the first component of the output layer vector.
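As a sketch, the definition transcribes directly into plain Python. Here `W` is a list of rows and `x` a length-\(w\) list; the helper names `relu`, `matvec`, and `f_k` are illustrative, not from any library.

```python
import random


def relu(v):
    """Elementwise ReLU on a list of floats."""
    return [max(0.0, a) for a in v]


def matvec(W, v):
    """Matrix-vector product with W stored as a list of rows."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in W]


def f_k(W, x, k):
    """First coordinate of (ReLU o W o ... o ReLU o W)(x), with k copies of W.

    No bias terms, matching the definition in the text.
    """
    h = list(x)
    for _ in range(k):
        h = relu(matvec(W, h))
    return h[0]  # "subscript 1" in the text; Python indexes from 0


# Example: a 3 x 3 weight matrix with i.i.d. standard normal entries.
rng = random.Random(0)
w = 3
W = [[rng.gauss(0.0, 1.0) for _ in range(w)] for _ in range(w)]
print(f_k(W, [1.0, -0.5, 2.0], k=4))
```

Because there are no biases, \(f_k\) is positively homogeneous, \(f_k(W, c\,x) = c\,f_k(W, x)\) for \(c \ge 0\), which is easy to confirm with this function.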

Conjecture: There exists some finite data set \(\{(X_i, y_i)\}_{i=1}^n\) such that, if the entries of \(W\) are initialized i.i.d. standard normal and the network is trained by gradient flow on the squared loss, then the expected number of affine pieces of the trained network \(x \mapsto f_k(W, x)\) is exponential in \(k\).
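One rough way to probe this numerically, sketched here under my own assumptions: restrict \(x \mapsto f_k(W, x)\) to a line segment in \(\mathbb{R}^w\) and count how often the ReLU activation pattern changes along it. This is only a heuristic proxy (regions thinner than the sample spacing are missed, and a pattern change need not change the affine piece of the first output coordinate), and the example looks only at a freshly initialized \(W\); testing the conjecture itself would also require training. The names `activation_pattern` and `estimate_pieces` are hypothetical.

```python
import random


def activation_pattern(W, x, k):
    """On/off state of every ReLU at every step of the recurrence."""
    h = list(x)
    pattern = []
    for _ in range(k):
        pre = [sum(wij * hj for wij, hj in zip(row, h)) for row in W]
        pattern.extend(p > 0 for p in pre)
        h = [max(0.0, p) for p in pre]
    return tuple(pattern)


def estimate_pieces(W, x0, x1, k, samples=1000):
    """Heuristic count of regions met by the segment from x0 to x1.

    Counts activation-pattern changes between consecutive sample points.
    Regions narrower than the sample spacing are missed, and a pattern change
    need not change the affine piece of the first output, so this is a rough
    proxy rather than an exact count of affine pieces.
    """
    count, prev = 1, None
    for i in range(samples + 1):
        t = i / samples
        x = [(1 - t) * a + t * b for a, b in zip(x0, x1)]
        pattern = activation_pattern(W, x, k)
        if prev is not None and pattern != prev:
            count += 1
        prev = pattern
    return count


# At initialization only: average the estimate over random W and segments.
rng = random.Random(0)
w, trials = 3, 5
for k in range(1, 5):
    total = 0
    for _ in range(trials):
        W = [[rng.gauss(0.0, 1.0) for _ in range(w)] for _ in range(w)]
        x0 = [rng.gauss(0.0, 1.0) for _ in range(w)]
        x1 = [rng.gauss(0.0, 1.0) for _ in range(w)]
        total += estimate_pieces(W, x0, x1, k)
    print(k, total / trials)
```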

Proof attempt (wrong). Consider the devil's staircase function. Note that there is a choice of parameters \(W^*\) such that \(f_k\) computes the \(k\)th function in the devil's staircase iteration. Choose \(n \gg w^2\) datapoints \((x, y)\) that lie on the graph of the limiting staircase function, none of them at a breakpoint (a point where the slope changes) of \(f_k(W^*, \cdot)\). Consider the loss landscape in the vicinity of \(W^*\). The landscape is smooth here: since no datapoint sits at a breakpoint, the ReLU activation pattern at every datapoint stays fixed for \(W\) near \(W^*\), so the loss is a polynomial in the entries of \(W\), and it has a minimum at \(W^*\). Furthermore, the loss on each individual datapoint presumably vanishes only on a \((w^2-1)\)-dimensional surface through \(W^*\) and is positive elsewhere. Since there are more datapoints than parameters, we choose the datapoints so that, on a sufficiently small neighborhood of \(W^*\), the intersection of all of these zero-loss surfaces is the single point \(W^*\).
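The step "the loss is a polynomial in the entries of \(W\)" can be seen by freezing the activation pattern. The sketch below (hypothetical helper names, plain Python) records the 0/1 masks of a forward pass; with those masks held fixed, the output is \(e_1^\top D_k W D_{k-1} W \cdots D_1 W x\) for diagonal 0/1 matrices \(D_j\), a polynomial of degree at most \(k\) in the entries of \(W\), so the squared loss over finitely many such datapoints has degree at most \(2k\). The post does not write down \(W^*\), so the check uses a random \(W\) instead.

```python
import random


def forward_with_masks(W, x, k):
    """Run h -> ReLU(W h) k times; return f_k(W, x) and the 0/1 masks used."""
    h = list(x)
    masks = []
    for _ in range(k):
        pre = [sum(wij * hj for wij, hj in zip(row, h)) for row in W]
        mask = [1.0 if p > 0 else 0.0 for p in pre]
        masks.append(mask)
        h = [m * p for m, p in zip(mask, pre)]  # same values as ReLU(pre)
    return h[0], masks


def frozen_pattern_output(W, x, masks):
    """e_1^T D_k W ... D_1 W x with the diagonal 0/1 matrices D_j held fixed.

    For fixed masks this is a product of k copies of W, hence a polynomial
    of degree at most k in the entries of W.
    """
    h = list(x)
    for mask in masks:
        h = [m * sum(wij * hj for wij, hj in zip(row, h))
             for m, row in zip(mask, W)]
    return h[0]


rng = random.Random(1)
w, k = 3, 4
W = [[rng.gauss(0.0, 1.0) for _ in range(w)] for _ in range(w)]
x = [0.3, -1.2, 0.8]
out, masks = forward_with_masks(W, x, k)

# Nudge W slightly. If no pre-activation changes sign, the masks are
# unchanged and the true output equals the frozen-pattern polynomial.
W2 = [[wij + 1e-6 * rng.gauss(0.0, 1.0) for wij in row] for row in W]
out2, masks2 = forward_with_masks(W2, x, k)
print(masks2 == masks, out2, frozen_pattern_output(W2, x, masks))
```

For \(w = 1\) and an input \(x > 0\) with \(W > 0\), the frozen polynomial is simply \(W^k x\).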

There are obviously a few holes in this proof. The \((w^2-1)\)-dimensional zero-loss surfaces for the individual datapoints are presumably curved, so the individual per-point losses are not convex, and I don't know whether their sum is convex near \(W^*\), or how one would show it. Furthermore, some of the parameters may not affect the loss on any datapoint in the vicinity of \(W^*\), since they may only control the widths of the steps rather than their heights or slopes; in that case the zero-loss surfaces for all datapoints may intersect in a line rather than a point. Trying to move the datapoints to the corners of the staircase function opens up a host of new problems, such as pathological behavior of the loss gradient.
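A toy instance of the nonconvexity worry, with \(w = 2\), \(k = 2\), and the single datapoint \(x = (1, 0)\), \(y = 1\), all chosen here purely for illustration and not near any \(W^*\). Writing \(W = \begin{pmatrix} a & b \\ c & d \end{pmatrix}\), on the region where the relevant pre-activations are positive the output is \(a^2 + bc\), so the zero-loss set there is the curved 3-dimensional surface \(a^2 + bc = 1\). Exact rational arithmetic shows two zero-loss weight matrices whose midpoint has positive loss, so this per-point loss is not convex. The entry \(d\) never enters the output for this input, a tiny version of the worry about parameters that do not affect the loss.

```python
from fractions import Fraction


def relu(a):
    return a if a > 0 else Fraction(0)


def f2(W, x):
    """f_2(W, x) for w = 2: first coordinate of ReLU(W ReLU(W x))."""
    h = [relu(W[i][0] * x[0] + W[i][1] * x[1]) for i in range(2)]
    return relu(W[0][0] * h[0] + W[0][1] * h[1])


def point_loss(W, x, y):
    """Squared loss on a single datapoint (x, y)."""
    return (f2(W, x) - y) ** 2


x, y = (Fraction(1), Fraction(0)), Fraction(1)

# Two hypothetical weight matrices that both fit (x, y) exactly ...
W_a = [[Fraction(1), Fraction(0)], [Fraction(0), Fraction(0)]]
W_b = [[Fraction(0), Fraction(1)], [Fraction(1), Fraction(0)]]
# ... and their midpoint.
W_mid = [[(p + q) / 2 for p, q in zip(ra, rb)] for ra, rb in zip(W_a, W_b)]

print(point_loss(W_a, x, y), point_loss(W_b, x, y), point_loss(W_mid, x, y))
# 0 0 1/4: the loss at the midpoint exceeds the average of the endpoint
# losses, so this single-datapoint loss is not convex in W.
```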

Does the devil's staircase work, or is a piecewise affine approximation of a quadratic function better? The function \(x \mapsto x(1-x)\) has the representation [see Yarotsky's 2016 paper] $$ x(1-x) = \frac14 t_1(x) + \frac{1}{16} t_2(x) + \frac{1}{64} t_4(x) + \dots = \sum_{s=1}^{\infty} 4^{-s}\, t_{2^{s-1}}(x), $$ where \(t_i\) is the triangle wave with \(i\) cycles on \([0,1]\), ranging between 0 and 1.
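A quick numerical check of this representation, as a sketch. Following Yarotsky's construction, \(t_{2^{s-1}}\) is the \(s\)-fold composition of the tooth map \(g(x) = 2x\) on \([0, 1/2]\), \(g(x) = 2(1-x)\) on \([1/2, 1]\), and the partial sum through \(s = m\) is the piecewise linear interpolation of \(x(1-x)\) on a grid of spacing \(2^{-m}\), so its maximum error is \(4^{-m}/4\). Repeatedly composing one fixed map is what an RNN does, although writing \(g\) with ReLUs, \(g(x) = 2\,\text{ReLU}(x) - 4\,\text{ReLU}(x - 1/2)\) on \([0,1]\), needs a bias term that \(f_k\) as defined above does not have. The function names are illustrative.

```python
def tooth(x):
    """Yarotsky's tooth map g on [0, 1]: 2x on [0, 1/2], 2(1 - x) on [1/2, 1]."""
    return 2 * x if x < 0.5 else 2 * (1 - x)


def tooth_relu(x):
    """The same map on [0, 1] written with ReLUs (note the bias -1/2)."""
    return 2 * max(0.0, x) - 4 * max(0.0, x - 0.5)


def triangle(x, s):
    """g composed s times: the triangle wave t_{2^(s-1)}, peak height 1."""
    for _ in range(s):
        x = tooth(x)
    return x


def partial_sum(x, m):
    """t_1/4 + t_2/16 + t_4/64 + ..., stopping after m terms."""
    return sum(triangle(x, s) / 4 ** s for s in range(1, m + 1))


grid = [i / 1024 for i in range(1025)]
for m in range(1, 6):
    err = max(abs(x * (1 - x) - partial_sum(x, m)) for x in grid)
    print(m, err)  # equals 4**(-m) / 4 on this dyadic grid
```

All grid points are dyadic rationals, so the arithmetic here is exact in floating point and the printed errors match the interpolation bound exactly.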

When I was studying this problem two years ago, I remember thinking a good candidate function would be $$ \frac14 t_1(x) + \frac{1}{32} t_2(x) + \frac{1}{128} t_4(x) + \dots $$ because (I think) this would admit a \(W\) with functional norm \(< 1\), which was important somehow.

Anyway, I have stopped doing research on the theory of deep learning, so it is unlikely I will ever solve this problem. If you ever manage to prove this conjecture, or disprove it, I would be happy to hear about it!