Generating natural-looking digits with neural networks



In this post I will show you how I got a neural network to write digits that look like they've been handwritten. In fact, the digits you can see in the first image of this post were generated by said neural network. The code that did all this is here.

It all started when I read, some time ago, about a funny way in which you can use neural networks: you train a network that takes a vector of size $N$ as input, and you teach it to do nothing. That is, you teach the network to behave as the identity function... but with a twist! If the input has size $N$, somewhere in the middle of the network you place a layer with fewer than $N$ neurons. When you do this, you can picture the input vector $x$ being squeezed through that tight layer and then expanding again to become the output of the network.
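To make the idea concrete, here is a minimal sketch of such a network in PyTorch; the layer sizes are assumptions for the sake of illustration, not necessarily the ones I used:

import torch.nn as nn

# Sketch of the idea: squeeze N inputs through a tight layer, expand back.
# For 28x28 images, N = 784; the sizes here are illustrative assumptions.
N, BOTTLENECK = 784, 10
autoencoder = nn.Sequential(
    nn.Linear(N, 128), nn.ReLU(),
    nn.Linear(128, BOTTLENECK), nn.ReLU(),  # the "squeeze it tight" layer
    nn.Linear(BOTTLENECK, 128), nn.ReLU(),
    nn.Linear(128, N), nn.Sigmoid(),        # output has the input's size
)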

If you now slice the network in half, the part from the input up to the "squeeze it tight" layer compresses your inputs, and the second half, from the "squeeze it tight" layer to the output, decompresses your data. (For those of you interested in looking this up: a network like this is called an autoencoder, and the two halves are known as the encoder and the decoder.)

Let that sink in for a second.
Now that I "know" how to compress things, imagine this scenario: I teach a neural network to compress digit images into a vector of size $10$ (one component per digit), and let us call this network the encoder. To train the encoder, all I have to do is teach it to classify the images of the digits. After training the encoder, I create a new network (let's call it the decoder) that takes $10$ values as input and returns an image as output. Then I teach the decoder to undo whatever the encoder does.
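Sketched in code, that second stage might look roughly like this (the architectures are simplified stand-ins, and train_loader is an assumed loader of digit image batches):

import torch
import torch.nn as nn

# Simplified stand-ins for the two networks (not the exact architectures).
encoder = nn.Sequential(  # trained beforehand as a 10-class classifier
    nn.Flatten(), nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
)
decoder = nn.Sequential(  # learns to map those 10 values back to an image
    nn.Linear(10, 128), nn.ReLU(), nn.Linear(128, 784), nn.Sigmoid()
)

opt = torch.optim.Adam(decoder.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

for images, _ in train_loader:  # assumed loader of 28x28 digit images
    with torch.no_grad():       # the encoder stays frozen at this point
        codes = encoder(images)
    reconstruction = decoder(codes)
    loss = loss_fn(reconstruction, images.view(images.size(0), -1))
    opt.zero_grad()
    loss.backward()
    opt.step()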

After doing that for a bit, say I take this image of a $2$:

If I feed that image into my encoder, this is what comes out:
tensor([[-3.2723,  3.7296,  9.6557,  4.9145, -8.6327, -3.1320,
         -3.5934,  6.0783, -0.8307, -5.3434]])

If I now take that vector and feed it into the decoder, this is the image that comes out:

Looks good, right? Of course it is not exactly the same, but take a second to appreciate that the original image was $28 \times 28$, and hence had $784$ values; those $784$ values were compressed into the $10$ numbers you saw above, and then those $10$ numbers were used to reconstruct this beautiful $2$.

Of course, to build this $2$ we first had to have a $2$ go through the encoder, but I don't really like that step. What I had envisioned was a system that only made use of a single network.

One possible first idea would be to feed the decoder a vector of zeroes with a single $1$ in the right place. The only reason that does not work so well is that the encoder doesn't really compress the images into such nice-looking vectors, so the decoder doesn't really know how to handle them. Instead, I have a little helper function that creates a random vector of size $10$, where each entry is randomly chosen in $[-2, 1]$, except for the coordinate that corresponds to the digit I want to create: that one is a random number in $[13, 14]$.
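That helper could look roughly like this (a sketch; the name make_decoder_input is mine, but the intervals are the ones just described):

import torch

def make_decoder_input(digit):
    # Every entry is drawn uniformly from [-2, 1], except the coordinate
    # of the digit we want, which is drawn uniformly from [13, 14].
    v = torch.empty(1, 10).uniform_(-2.0, 1.0)
    v[0, digit] = torch.empty(1).uniform_(13.0, 14.0).item()
    return v

# e.g. decoder(make_decoder_input(7)) should produce an image of a 7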

Using that $5$ times for each digit, I was able to create these "handwritten" digits:


It is true that they look practically the same, but they are not!
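For what it's worth, each row above came from something like the following sketch, reusing the hypothetical make_decoder_input helper from before:

# Sketch: 5 samples of each digit 0..9, assuming the decoder and the
# make_decoder_input helper from the sketches above
with torch.no_grad():
    samples = [
        [decoder(make_decoder_input(d)).view(28, 28) for _ in range(5)]
        for d in range(10)
    ]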

Then I looked at those pictures, chose one digit out of each row, and used some basic thresholding (sketched in code further below) to set the background to black and the digit to white, to see if I could make the digits easier to read. This is what I ended up with:


which is not too shabby, in my opinion.
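The thresholding itself can be as simple as this sketch (the cutoff value is an assumption; anything sensible works):

def threshold(img, cutoff=0.5):
    # img: a 28x28 tensor with values roughly in [0, 1]; pixels above
    # the cutoff become white (1.0), everything else black (0.0)
    return (img > cutoff).float()

# e.g. threshold(decoder(make_decoder_input(2)).view(28, 28))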

As future work, there is still everything to be done! Trying different neural network architectures, studying whether the encoder's accuracy matters much, finding more clever ways of feeding input to the decoder, finding more clever ways of enhancing the images the decoder produces, trying to understand whether the network can generate more distinct digits, etc.

All of this was written in a Jupyter notebook that you can find here. I used PyTorch, so if you want to run it you'll need to have that installed. Otherwise, you can still look at the pictures, check the specifications of what I did, and try to reproduce this with your own packages and whatnot. TL;DR: the images throughout this post are of digits that were generated by a neural network.

  - RGS
