Recurrent Neural Networks for Text Generation

Posted on March 20, 2019 by Boris Dayma, ML Engineer

Understanding Recurrent Neural Networks for Text Generation

Having worked mainly with images and ConvNets, I recently became interested in recurrent neural networks (RNNs). I wanted to revisit text generation as detailed in The Unreasonable Effectiveness of Recurrent Neural Networks and try to better understand RNNs and how to optimize them.

I like to think of this project as building the modern lorem ipsum, the placeholder text used in all kinds of templates. We can do a much better job than plain lorem ipsum by reading a large piece of text and generating “random” text in the same style.

Here are examples of text generated by the model we will be building.

I’ve probably been spending way too much time reading made-up text instead of actual books!

This project was developed with PyTorch; every single run was logged on Weights & Biases and is fully reproducible.

You can follow the project implementation in the notebook.

Nowadays, it’s pretty easy to implement an RNN with custom cells and multiple layers with dropout between them.

self.rnn = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout)

That’s it for the definition! Unfortunately, there are quite a few more details to take care of…

The main idea is to read each character, encode it as a one-hot tensor, feed it to our RNN, and use the output as a probability distribution for the next character. Using a cross-entropy loss and backpropagation, we can then improve our predictions.
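
As a rough sketch of this idea (the class name, the linear decoder, and the shapes below are my own assumptions, not necessarily the exact code of the notebook), the LSTM only needs a linear layer on top to turn its output into a distribution over the next character:

import torch
import torch.nn.functional as F

class CharRNN(torch.nn.Module):
    # Hypothetical minimal character-level model: LSTM + linear decoder
    def __init__(self, n_chars, hidden_size, num_layers, dropout):
        super().__init__()
        self.rnn = torch.nn.LSTM(input_size=n_chars, hidden_size=hidden_size,
                                 num_layers=num_layers, dropout=dropout)
        self.decoder = torch.nn.Linear(hidden_size, n_chars)

    def forward(self, x, hidden=None):
        # x: one-hot tensor of shape (seq_len, batch, n_chars)
        out, hidden = self.rnn(x, hidden)
        return self.decoder(out), hidden  # logits over the next character

# one step: encode a character index and predict the next character
model = CharRNN(n_chars=100, hidden_size=256, num_layers=2, dropout=0.1)
char_idx = torch.tensor([[42]])                   # shape (seq_len=1, batch=1)
x = F.one_hot(char_idx, num_classes=100).float()  # shape (1, 1, 100)
logits, hidden = model(x)
probs = F.softmax(logits[-1, 0], dim=-1)          # distribution over the next character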



Some components were essential to get the model running:

  • encoding to one-hot tensors needs to be done on the fly during training to avoid using too much memory
  • the loss needs to be accumulated over a long sequence of characters prior to backpropagation, otherwise the model does not learn anything useful (this stabilizes the weight updates); see the training-loop sketch after this list
  • multiple sequences of data must be read in parallel for efficient learning

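Putting these constraints together, the training loop might look roughly like the sketch below. It assumes the text has already been encoded as a LongTensor of character indices and cut into chunks of shape (seq_len + 1, batch), as described a bit further down; the function name and default values are my own, not necessarily those of the notebook.

import torch
import torch.nn.functional as F

def train_epoch(model, optimizer, batches, n_chars):
    # Hypothetical training loop: one-hot encoding done on the fly and a loss
    # accumulated over a whole sequence of characters before each backward pass.
    hidden = None
    for chunk in batches:                        # chunk: LongTensor (seq_len + 1, batch)
        inputs, targets = chunk[:-1], chunk[1:]  # predict the next character
        x = F.one_hot(inputs, num_classes=n_chars).float()  # encoded on the fly
        logits, hidden = model(x, hidden)
        # loss over the whole sequence (~100 characters) before backpropagation
        loss = F.cross_entropy(logits.reshape(-1, n_chars), targets.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # detach so gradients do not flow back through previous chunks
        hidden = tuple(h.detach() for h in hidden)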

Developing a complete prototype in a notebook is very helpful for debugging and for spotting critical items that must be implemented before any optimization. Without it, I would probably have kept some of those essential “details” for later (or maybe even for never)…

To read the text in parallel batches, we split it at equidistant intervals so that the weights do not specialize on a specific context (i.e. a single section of the text). Each new batch continues the sequences of the previous one, so the hidden state stays relevant.
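
One possible way to build those parallel batches (again, my own sketch rather than the exact notebook code) is to cut the encoded text into batch_size equidistant streams and walk through all of them in parallel, seq_len characters at a time:

import torch

def make_batches(encoded, batch_size=8, seq_len=100):
    # encoded: 1-D LongTensor of character indices for the whole text.
    # Cut it into batch_size equidistant streams, then yield contiguous chunks
    # so that each new batch continues the sequences of the previous one.
    stream_len = len(encoded) // batch_size
    streams = encoded[: stream_len * batch_size].view(batch_size, stream_len).t()
    for i in range(0, stream_len - 1, seq_len):
        yield streams[i : i + seq_len + 1]  # seq_len inputs plus the shifted targets

Because each chunk starts exactly where the previous one ended inside every stream, the hidden state carried over between batches always refers to the text that immediately precedes the new inputs.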



After only 5 minutes of training and without much fine-tuning, we obtain promising results:


"The dear was heard on a man."

“The would did not now he wish rement the grave of the young man with the last of the Paris, and her
carry so a son of the door."

“I do you will be four a moment to the prove you will a corntion, and where the doung the still of the sconvers
of the miscondens. I will to de Valentine, and the discontort of the count of the tell me to him."

“The young man he was a mistors, and with the moment.”


Now we go into the inner depths of recurrent neural networks, hoping to discover how to best optimize them.



Here are the conclusions I reached after analyzing and comparing every logged run:

  • GRU is more efficient than LSTM and vanilla RNN (see the sketch after this list)
  • 2 to 3 layers is deep enough for this RNN
  • width has a much larger effect than depth
  • use dropout only after signs of overfitting
  • the loss should be calculated over sequences of about 100 characters (there is not much improvement from using longer sequences)
  • multiple sequences per batch improve the results, but 8 is enough to avoid overfitting
  • the network architecture needs to be adapted to the length of the text
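
Switching the cell type, for instance, is a one-line change on the hypothetical CharRNN sketch above; note that a GRU returns a single hidden tensor instead of a tuple, so detaching it becomes hidden = hidden.detach():

# GRU turned out to be more efficient than LSTM in these experiments
self.rnn = torch.nn.GRU(input_size=n_chars, hidden_size=hidden_size,
                        num_layers=num_layers, dropout=dropout)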


Please refer to the W&B report for the justifications and supporting graphs. I built the report as I was launching new experiments to easily keep track of my insights.

It is interesting to note that some texts are more predictable than others. For example, the model learns the Internal Revenue Code much faster than The Count of Monte-Cristo, even after cutting the Code down to a tenth of its length (to end up with a similar number of characters).

I observed that the text generated by a model can give you a good idea of what it is based on. For example, text generated from the Internal Revenue Code shows what the network has retained from it: lots of amendments, sections, subsections, properties, and taxes!

The fact that the syntax is mostly well preserved in every type of text is really impressive and could fool anyone not paying close attention.

Feel free to clone the GitHub project repository and share your results and conclusions!

Project links:


