202005200123

The Form of the Loss in Matrix Completion

tags: [neural_networks, matrix_completion]

Key: everything goes through the loss function!

What’s wacky is that, usually when we think of neural networks, we have some data pairs \((x, y)\): you feed your input \(x\) into the network, and out pops your prediction of \(y\). Another way to put it is that you assume \(y = f(x)\) for some unknown function \(f\), and the goal of training a neural network is to approximate that unknown \(f\).

In particular, if you ignore the non-linear components and consider a fully connected network, then what it boils down to is repeatedly hitting the input vector \(x\) with weight matrices \(W_1, W_2, \ldots, W_N\).
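Concretely, since composing linear maps is just matrix multiplication, the whole network collapses to a single matrix:

\[ \begin{align*} f(x) = W_N W_{N-1} \cdots W_2 W_1 \, x. \end{align*} \]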

Everything gets turned on its head, it feels, when you consider the problem of Matrix Completion. The architecture is exactly the same, except now the \(x\)’s have disappeared; what you’re left with are just the weight matrices, and the way you use them runs entirely through this new loss function:

\[ \begin{align*} \left\lVert P_{\Omega}(\hat{W}) - P_{\Omega}(W) \right\rVert_{2}^2, \end{align*} \]

where \(W\) is the partially observed target matrix, \(P_{\Omega}\) picks out the observed entries (those in the index set \(\Omega\)) and vectorizes them, and \(\hat{W} = W_N W_{N-1} \cdots W_1\) is precisely the product of all the individual weight matrices (since the network is linear, this is equivalent to hitting \(x\) with one matrix). Thus there’s not really a nice one-to-one story for what \((x, y)\) are here. But those are ultimately superfluous.

Really, everything can be derived from the loss function. That is, the usual supervised loss looks like:

\[ \begin{align*} \sum_{i} L(y_i, f(x_i)). \end{align*} \]
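In matrix completion, by contrast, the same loss from above expands into a sum over observed entries rather than over data pairs:

\[ \begin{align*} \sum_{(i, j) \in \Omega} \left( \hat{W}_{ij} - W_{ij} \right)^2, \end{align*} \]

and the “prediction” \(\hat{W}_{ij}\) is a function of the weights alone; no input \(x\) appears anywhere.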

The diagrams we build to describe our networks can sometimes be misleading.

Remark. I think what’s actually confusing is that the PyTorch code for matrix completion looks the same as if you were building a normal fully connected neural network. So it feels like you’re creating all this scaffolding for the input and everything, but you never actually use it. Everything goes through how you construct the loss function.
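Here’s a minimal sketch of that, assuming a square \(n \times n\) target and a random mask of observed entries (the names `W_target` and `mask`, the depth of 3, and the optimizer settings are all illustrative choices, not from any particular reference):

```python
import torch

torch.manual_seed(0)

n, depth = 20, 3

# Partially observed target: W_target is the ground truth,
# mask marks the observed index set Omega.
W_target = torch.randn(n, n)
mask = (torch.rand(n, n) < 0.3).float()

# The "network" is nothing but its weight matrices W_1, ..., W_N.
# There is no input x and no forward(x) anywhere.
factors = [torch.nn.Parameter(0.1 * torch.randn(n, n)) for _ in range(depth)]
opt = torch.optim.Adam(factors, lr=1e-2)

for step in range(2000):
    opt.zero_grad()
    # W_hat is the product of all the weight matrices:
    # the whole linear network collapsed into one matrix.
    W_hat = factors[0]
    for W_i in factors[1:]:
        W_hat = W_i @ W_hat
    # The loss only ever touches the observed entries, i.e. P_Omega(.).
    loss = ((mask * (W_hat - W_target)) ** 2).sum()
    loss.backward()
    opt.step()
```

The line to stare at is the loss: \(P_{\Omega}\) shows up as the elementwise mask, \(\hat{W}\) as the running product of the factors, and the scaffolding of inputs and outputs never gets used.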