
Commit

changes
sgrvinod committed Jul 17, 2019
1 parent 5cacf31 commit abb9a81
Showing 3 changed files with 57 additions and 61 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -8,7 +8,7 @@ If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Bl

Questions, suggestions, or corrections can be posted as issues.

-I'm using `PyTorch 0.4` in `Python 3.6`.
+I'm using `PyTorch 1.1` in `Python 3.6`.

# Contents

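The version bump above is what motivates the `model.py` rewrite below: since PyTorch 1.1, `pack_padded_sequence` accepts `enforce_sorted=False`, so batches no longer need to be manually sorted by length before packing. A minimal sketch of the difference, with made-up toy tensors (not code from this repository):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Toy batch: 3 padded sequences of feature size 2, lengths in arbitrary (unsorted) order.
seqs = torch.randn(3, 5, 2)        # (batch, pad_len, features)
lengths = torch.tensor([3, 5, 2])  # true lengths, not sorted

# PyTorch 0.4 style: sort the batch by decreasing length before packing.
sorted_lengths, sort_ind = lengths.sort(dim=0, descending=True)
packed_old = pack_padded_sequence(seqs[sort_ind], lengths=sorted_lengths.tolist(), batch_first=True)

# PyTorch 1.1 style: let pack_padded_sequence do the sorting internally.
packed_new = pack_padded_sequence(seqs, lengths=lengths.tolist(), batch_first=True, enforce_sorted=False)

print(torch.equal(packed_old.data, packed_new.data))           # True: same packed data
print(packed_new.sorted_indices, packed_new.unsorted_indices)  # bookkeeping used to undo the sort later
```

The new `PackedSequence` carries `sorted_indices`/`unsorted_indices`, which is what lets the code below drop its explicit sort/unsort steps.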
114 changes: 55 additions & 59 deletions model.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence


@@ -46,7 +45,7 @@ def forward(self, documents, sentences_per_document, words_per_sentence):
"""
# Apply sentence-level attention module (and in turn, word-level attention module) to get document embeddings
document_embeddings, word_alphas, sentence_alphas = self.sentence_attention(documents, sentences_per_document,
-words_per_sentence) # (n_documents, 2 * sentence_rnn_size), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), # (n_documents, max_doc_len_in_batch)
+words_per_sentence) # (n_documents, 2 * sentence_rnn_size), (n_documents, max(sentences_per_document), max(words_per_sentence)), (n_documents, max(sentences_per_document))

# Classify
scores = self.fc(self.dropout(document_embeddings)) # (n_documents, n_classes)
@@ -105,33 +104,33 @@ def forward(self, documents, sentences_per_document, words_per_sentence):
:param words_per_sentence: sentence lengths, a tensor of dimensions (n_documents, sent_pad_len)
:return: document embeddings, attention weights of words, attention weights of sentences
"""
-# Sort documents by decreasing document lengths (SORTING #1)
-sentences_per_document, doc_sort_ind = sentences_per_document.sort(dim=0, descending=True)
-documents = documents[doc_sort_ind] # (n_documents, sent_pad_len, word_pad_len)
-words_per_sentence = words_per_sentence[doc_sort_ind] # (n_documents, sent_pad_len)

-# Re-arrange as sentences by removing pad-sentences (DOCUMENTS -> SENTENCES)
-sentences, bs = pack_padded_sequence(documents,
-lengths=sentences_per_document.tolist(),
-batch_first=True) # (n_sentences, word_pad_len), bs is the effective batch size at each sentence-timestep
+# Re-arrange as sentences by removing sentence-pads (DOCUMENTS -> SENTENCES)
+packed_sentences = pack_padded_sequence(documents,
+lengths=sentences_per_document.tolist(),
+batch_first=True,
+enforce_sorted=False) # a PackedSequence object, where 'data' is the flattened sentences (n_sentences, word_pad_len)

# Re-arrange sentence lengths in the same way (DOCUMENTS -> SENTENCES)
-words_per_sentence, _ = pack_padded_sequence(words_per_sentence,
-lengths=sentences_per_document.tolist(),
-batch_first=True) # (n_sentences), '_' is the same as 'bs' in the earlier step
+packed_words_per_sentence = pack_padded_sequence(words_per_sentence,
+lengths=sentences_per_document.tolist(),
+batch_first=True,
+enforce_sorted=False) # a PackedSequence object, where 'data' is the flattened sentence lengths (n_sentences)

# Find sentence embeddings by applying the word-level attention module
-sentences, word_alphas = self.word_attention(sentences,
-words_per_sentence) # (n_sentences, 2 * word_rnn_size), (n_sentences, max_sent_len_in_batch)
+sentences, word_alphas = self.word_attention(packed_sentences.data,
+packed_words_per_sentence.data) # (n_sentences, 2 * word_rnn_size), (n_sentences, max(words_per_sentence))
sentences = self.dropout(sentences)

-# Apply the sentence-level RNN over the sentence embeddings (PyTorch automatically applies it on the packed_sequence using the effective batch_size)
-(sentences, _), _ = self.sentence_rnn(
-PackedSequence(sentences, bs)) # (n_sentences, 2 * sentence_rnn_size), (max(sent_lens))
+# Apply the sentence-level RNN over the sentence embeddings (PyTorch automatically applies it on the packed_sequence)
+packed_sentences, _ = self.sentence_rnn(PackedSequence(data=sentences,
+batch_sizes=packed_sentences.batch_sizes,
+sorted_indices=packed_sentences.sorted_indices,
+unsorted_indices=packed_sentences.unsorted_indices)) # a PackedSequence object, where 'data' is the output of the RNN (n_sentences, 2 * sentence_rnn_size)

-# Find attention vectors by applying the attention linear layer
-att_s = self.sentence_attention(sentences) # (n_sentences, att_size)
-att_s = F.tanh(att_s) # (n_sentences, att_size)
+# Find attention vectors by applying the attention linear layer on the output of the RNN
+att_s = self.sentence_attention(packed_sentences.data) # (n_sentences, att_size)
+att_s = torch.tanh(att_s) # (n_sentences, att_size)
# Take the dot-product of the attention vectors with the context vector (i.e. parameter of linear layer)
att_s = self.sentence_context_vector(att_s).squeeze(1) # (n_sentences)

att_s = torch.exp(att_s - max_value) # (n_sentences)

# Re-arrange as documents by re-padding with 0s (SENTENCES -> DOCUMENTS)
-att_s, _ = pad_packed_sequence(PackedSequence(att_s, bs),
-batch_first=True) # (n_documents, max_doc_len_in_batch)
+att_s, _ = pad_packed_sequence(PackedSequence(data=att_s,
+batch_sizes=packed_sentences.batch_sizes,
+sorted_indices=packed_sentences.sorted_indices,
+unsorted_indices=packed_sentences.unsorted_indices),
+batch_first=True) # (n_documents, max(sentences_per_document))

# Calculate softmax values
-sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True) # (n_documents, max_doc_len_in_batch)
+sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True) # (n_documents, max(sentences_per_document))

# Similarly re-arrange sentence-level RNN outputs as documents by re-padding with 0s (SENTENCES -> DOCUMENTS)
-documents, _ = pad_packed_sequence(PackedSequence(sentences, bs),
-batch_first=True) # (n_documents, max_doc_len_in_batch, 2 * sentence_rnn_size)
+documents, _ = pad_packed_sequence(packed_sentences,
+batch_first=True) # (n_documents, max(sentences_per_document), 2 * sentence_rnn_size)

# Find document embeddings
documents = documents * sentence_alphas.unsqueeze(
-2) # (n_documents, max_doc_len_in_batch, 2 * sentence_rnn_size)
+2) # (n_documents, max(sentences_per_document), 2 * sentence_rnn_size)
documents = documents.sum(dim=1) # (n_documents, 2 * sentence_rnn_size)

# Also re-arrange word_alphas (SENTENCES -> DOCUMENTS)
-word_alphas, _ = pad_packed_sequence(PackedSequence(word_alphas, bs),
-batch_first=True) # (n_documents, max_doc_len_in_batch, max_sent_len_in_batch)
-
-# Unsort documents into the original order (INVERSE OF SORTING #1)
-_, doc_unsort_ind = doc_sort_ind.sort(dim=0, descending=False) # (n_documents)
-documents = documents[doc_unsort_ind] # (n_documents, 2 * sentence_rnn_size)
-sentence_alphas = sentence_alphas[doc_unsort_ind] # (n_documents, max_doc_len_in_batch)
-word_alphas = word_alphas[doc_unsort_ind] # (n_documents, max_doc_len_in_batch, max_sent_len_in_batch)
+word_alphas, _ = pad_packed_sequence(PackedSequence(data=word_alphas,
+batch_sizes=packed_sentences.batch_sizes,
+sorted_indices=packed_sentences.sorted_indices,
+unsorted_indices=packed_sentences.unsorted_indices),
+batch_first=True) # (n_documents, max(sentences_per_document), max(words_per_sentence))

return documents, word_alphas, sentence_alphas
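
Both attention modules now follow the same pattern: apply layers to `packed.data` (the flat tensor of real, non-padded timesteps), then wrap the result back into a `PackedSequence` that reuses the original `batch_sizes`, `sorted_indices`, and `unsorted_indices` so that `pad_packed_sequence` can restore the per-document layout. A small self-contained sketch of that round trip, with illustrative sizes rather than the tutorial's:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence

rnn = nn.GRU(input_size=4, hidden_size=3, batch_first=True, bidirectional=True)
attention = nn.Linear(2 * 3, 1)

docs = torch.randn(2, 6, 4)           # (n_documents, sent_pad_len, features)
sents_per_doc = torch.tensor([6, 4])

packed = pack_padded_sequence(docs, lengths=sents_per_doc.tolist(),
                              batch_first=True, enforce_sorted=False)

# The RNN consumes the PackedSequence directly; its output 'data' is flat: (n_sentences, 2 * hidden).
packed_out, _ = rnn(packed)
att = attention(packed_out.data).squeeze(1)  # (n_sentences,) one score per real sentence

# Wrap the flat scores in a PackedSequence with the same bookkeeping, then re-pad
# to recover the (n_documents, max(sents_per_doc)) layout in the original order.
att_padded, _ = pad_packed_sequence(
    PackedSequence(data=att,
                   batch_sizes=packed_out.batch_sizes,
                   sorted_indices=packed_out.sorted_indices,
                   unsorted_indices=packed_out.unsorted_indices),
    batch_first=True)
print(att_padded.shape)  # torch.Size([2, 6])
```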

@@ -231,24 +230,23 @@ def forward(self, sentences, words_per_sentence):
:param words_per_sentence: sentence lengths, a tensor of dimension (n_sentences)
:return: sentence embeddings, attention weights of words
"""
-# Sort sentences by decreasing sentence lengths (SORTING #2)
-words_per_sentence, sent_sort_ind = words_per_sentence.sort(dim=0, descending=True)
-sentences = sentences[sent_sort_ind] # (n_sentences, word_pad_len, emb_size)

# Get word embeddings, apply dropout
sentences = self.dropout(self.embeddings(sentences)) # (n_sentences, word_pad_len, emb_size)

-# Re-arrange as words by removing pad-words (SENTENCES -> WORDS)
-words, bw = pack_padded_sequence(sentences,
-lengths=words_per_sentence.tolist(),
-batch_first=True) # (n_words, emb_size), bw is the effective batch size at each word-timestep
+# Re-arrange as words by removing word-pads (SENTENCES -> WORDS)
+packed_words = pack_padded_sequence(sentences,
+lengths=words_per_sentence.tolist(),
+batch_first=True,
+enforce_sorted=False) # a PackedSequence object, where 'data' is the flattened words (n_words, word_emb)

-# Apply the word-level RNN over the word embeddings (PyTorch automatically applies it on the packed_sequence using the effective batch_size)
-(words, _), _ = self.word_rnn(PackedSequence(words, bw)) # (n_words, 2 * word_rnn_size), (max(sent_lens))
+# Apply the word-level RNN over the word embeddings (PyTorch automatically applies it on the PackedSequence)
+packed_words, _ = self.word_rnn(
+packed_words) # a PackedSequence object, where 'data' is the output of the RNN (n_words, 2 * word_rnn_size)

-# Find attention vectors by applying the attention linear layer
-att_w = self.word_attention(words) # (n_words, att_size)
-att_w = F.tanh(att_w) # (n_words, att_size)
+# Find attention vectors by applying the attention linear layer on the output of the RNN
+att_w = self.word_attention(packed_words.data) # (n_words, att_size)
+att_w = torch.tanh(att_w) # (n_words, att_size)
# Take the dot-product of the attention vectors with the context vector (i.e. parameter of linear layer)
att_w = self.word_context_vector(att_w).squeeze(1) # (n_words)

att_w = torch.exp(att_w - max_value) # (n_words)

# Re-arrange as sentences by re-padding with 0s (WORDS -> SENTENCES)
-att_w, _ = pad_packed_sequence(PackedSequence(att_w, bw),
-batch_first=True) # (n_sentences, max_sent_len_in_batch)
+att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
+batch_sizes=packed_words.batch_sizes,
+sorted_indices=packed_words.sorted_indices,
+unsorted_indices=packed_words.unsorted_indices),
+batch_first=True) # (n_sentences, max(words_per_sentence))

-# Calculate softmax values
-word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True) # (n_sentences, max_sent_len_in_batch)
+# Calculate softmax values as now words are arranged in their respective sentences
+word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True) # (n_sentences, max(words_per_sentence))

# Similarly re-arrange word-level RNN outputs as sentences by re-padding with 0s (WORDS -> SENTENCES)
-sentences, _ = pad_packed_sequence(PackedSequence(words, bw),
-batch_first=True) # (n_sentences, max_sent_len_in_batch, 2 * word_rnn_size)
+sentences, _ = pad_packed_sequence(packed_words,
+batch_first=True) # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)

# Find sentence embeddings
-sentences = sentences * word_alphas.unsqueeze(2) # (n_sentences, max_sent_len_in_batch, 2 * word_rnn_size)
+sentences = sentences * word_alphas.unsqueeze(2) # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)
sentences = sentences.sum(dim=1) # (n_sentences, 2 * word_rnn_size)

-# Unsort sentences into the original order (INVERSE OF SORTING #2)
-_, sent_unsort_ind = sent_sort_ind.sort(dim=0, descending=False) # (n_sentences)
-sentences = sentences[sent_unsort_ind] # (n_sentences, 2 * word_rnn_size)
-word_alphas = word_alphas[sent_unsort_ind] # (n_sentences, max_sent_len_in_batch)

return sentences, word_alphas
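
The net effect in `WordAttention.forward` (mirroring the sentence-level module above) is that the explicit SORTING #1 / SORTING #2 passes and their inverses are gone: with `enforce_sorted=False`, the pack/unpack pair records the ordering and returns results in the original batch order by itself. A quick sanity-check sketch with made-up tensors:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

x = torch.arange(12, dtype=torch.float).view(3, 4)  # 3 "sentences", each padded to length 4
lengths = torch.tensor([2, 4, 3])                    # unsorted true lengths

packed = pack_padded_sequence(x, lengths=lengths.tolist(),
                              batch_first=True, enforce_sorted=False)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)

# Rows come back in the original order, with positions beyond each length zeroed out.
print(out_lengths)  # tensor([2, 4, 3])
print(unpacked)
```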
2 changes: 1 addition & 1 deletion train.py
@@ -29,7 +29,7 @@
lr = 1e-3 # learning rate
momentum = 0.9 # momentum
workers = 4 # number of workers for loading data in the DataLoader
-epochs = 2 # number of epochs to run without early-stopping
+epochs = 2 # number of epochs to run
grad_clip = None # clip gradients at this value
print_freq = 2000 # print training or validation status every __ batches
checkpoint = None # path to model checkpoint, None if none
