From abb9a811796ff232684b2dadad41fe644dbe9d17 Mon Sep 17 00:00:00 2001
From: sgrvinod
Date: Wed, 17 Jul 2019 08:12:31 +0530
Subject: [PATCH] changes

---
 README.md |   2 +-
 model.py  | 114 ++++++++++++++++++++++++++----------------------
 train.py  |   2 +-
 3 files changed, 57 insertions(+), 61 deletions(-)

diff --git a/README.md b/README.md
index 0772449..a2418a8 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@ If you're new to PyTorch, first read [Deep Learning with PyTorch: A 60 Minute Bl
 
 Questions, suggestions, or corrections can be posted as issues.
 
-I'm using `PyTorch 0.4` in `Python 3.6`.
+I'm using `PyTorch 1.1` in `Python 3.6`.
 
 # Contents
 
diff --git a/model.py b/model.py
index 6b09dd7..16498f0 100644
--- a/model.py
+++ b/model.py
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
 
 
@@ -46,7 +45,7 @@ def forward(self, documents, sentences_per_document, words_per_sentence):
         """
         # Apply sentence-level attention module (and in turn, word-level attention module) to get document embeddings
         document_embeddings, word_alphas, sentence_alphas = self.sentence_attention(documents, sentences_per_document,
-                                                                                    words_per_sentence)  # (n_documents, 2 * sentence_rnn_size), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), # (n_documents, max_doc_len_in_batch)
+                                                                                    words_per_sentence)  # (n_documents, 2 * sentence_rnn_size), (n_documents, max(sentences_per_document), max(words_per_sentence)), (n_documents, max(sentences_per_document))
 
         # Classify
         scores = self.fc(self.dropout(document_embeddings))  # (n_documents, n_classes)
@@ -105,33 +104,33 @@ def forward(self, documents, sentences_per_document, words_per_sentence):
         :param words_per_sentence: sentence lengths, a tensor of dimensions (n_documents, sent_pad_len)
         :return: document embeddings, attention weights of words, attention weights of sentences
         """
-        # Sort documents by decreasing document lengths (SORTING #1)
-        sentences_per_document, doc_sort_ind = sentences_per_document.sort(dim=0, descending=True)
-        documents = documents[doc_sort_ind]  # (n_documents, sent_pad_len, word_pad_len)
-        words_per_sentence = words_per_sentence[doc_sort_ind]  # (n_documents, sent_pad_len)
-
-        # Re-arrange as sentences by removing pad-sentences (DOCUMENTS -> SENTENCES)
-        sentences, bs = pack_padded_sequence(documents,
-                                             lengths=sentences_per_document.tolist(),
-                                             batch_first=True)  # (n_sentences, word_pad_len), bs is the effective batch size at each sentence-timestep
+        # Re-arrange as sentences by removing sentence-pads (DOCUMENTS -> SENTENCES)
+        packed_sentences = pack_padded_sequence(documents,
+                                                lengths=sentences_per_document.tolist(),
+                                                batch_first=True,
+                                                enforce_sorted=False)  # a PackedSequence object, where 'data' is the flattened sentences (n_sentences, word_pad_len)
 
         # Re-arrange sentence lengths in the same way (DOCUMENTS -> SENTENCES)
-        words_per_sentence, _ = pack_padded_sequence(words_per_sentence,
-                                                     lengths=sentences_per_document.tolist(),
-                                                     batch_first=True)  # (n_sentences), '_' is the same as 'bs' in the earlier step
+        packed_words_per_sentence = pack_padded_sequence(words_per_sentence,
+                                                         lengths=sentences_per_document.tolist(),
+                                                         batch_first=True,
+                                                         enforce_sorted=False)  # a PackedSequence object, where 'data' is the flattened sentence lengths (n_sentences)
 
         # Find sentence embeddings by applying the word-level attention module
-        sentences, word_alphas = self.word_attention(sentences,
-                                                     words_per_sentence)  # (n_sentences, 2 * word_rnn_size), (n_sentences, max_sent_len_in_batch)
+        sentences, word_alphas = self.word_attention(packed_sentences.data,
+                                                     packed_words_per_sentence.data)  # (n_sentences, 2 * word_rnn_size), (n_sentences, max(words_per_sentence))
         sentences = self.dropout(sentences)
 
-        # Apply the sentence-level RNN over the sentence embeddings (PyTorch automatically applies it on the packed_sequence using the effective batch_size)
-        (sentences, _), _ = self.sentence_rnn(
-            PackedSequence(sentences, bs))  # (n_sentences, 2 * sentence_rnn_size), (max(sent_lens))
+        # Apply the sentence-level RNN over the sentence embeddings (PyTorch automatically applies it on the packed_sequence)
+        packed_sentences, _ = self.sentence_rnn(PackedSequence(data=sentences,
+                                                               batch_sizes=packed_sentences.batch_sizes,
+                                                               sorted_indices=packed_sentences.sorted_indices,
+                                                               unsorted_indices=packed_sentences.unsorted_indices))  # a PackedSequence object, where 'data' is the output of the RNN (n_sentences, 2 * sentence_rnn_size)
 
-        # Find attention vectors by applying the attention linear layer
-        att_s = self.sentence_attention(sentences)  # (n_sentences, att_size)
-        att_s = F.tanh(att_s)  # (n_sentences, att_size)
+        # Find attention vectors by applying the attention linear layer on the output of the RNN
+        att_s = self.sentence_attention(packed_sentences.data)  # (n_sentences, att_size)
+        att_s = torch.tanh(att_s)  # (n_sentences, att_size)
 
         # Take the dot-product of the attention vectors with the context vector (i.e. parameter of linear layer)
         att_s = self.sentence_context_vector(att_s).squeeze(1)  # (n_sentences)
@@ -143,30 +142,30 @@ def forward(self, documents, sentences_per_document, words_per_sentence):
         att_s = torch.exp(att_s - max_value)  # (n_sentences)
 
         # Re-arrange as documents by re-padding with 0s (SENTENCES -> DOCUMENTS)
-        att_s, _ = pad_packed_sequence(PackedSequence(att_s, bs),
-                                       batch_first=True)  # (n_documents, max_doc_len_in_batch)
+        att_s, _ = pad_packed_sequence(PackedSequence(data=att_s,
+                                                      batch_sizes=packed_sentences.batch_sizes,
+                                                      sorted_indices=packed_sentences.sorted_indices,
+                                                      unsorted_indices=packed_sentences.unsorted_indices),
+                                       batch_first=True)  # (n_documents, max(sentences_per_document))
 
         # Calculate softmax values
-        sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True)  # (n_documents, max_doc_len_in_batch)
+        sentence_alphas = att_s / torch.sum(att_s, dim=1, keepdim=True)  # (n_documents, max(sentences_per_document))
 
         # Similarly re-arrange sentence-level RNN outputs as documents by re-padding with 0s (SENTENCES -> DOCUMENTS)
-        documents, _ = pad_packed_sequence(PackedSequence(sentences, bs),
-                                           batch_first=True)  # (n_documents, max_doc_len_in_batch, 2 * sentence_rnn_size)
+        documents, _ = pad_packed_sequence(packed_sentences,
+                                           batch_first=True)  # (n_documents, max(sentences_per_document), 2 * sentence_rnn_size)
 
         # Find document embeddings
         documents = documents * sentence_alphas.unsqueeze(
-            2)  # (n_documents, max_doc_len_in_batch, 2 * sentence_rnn_size)
+            2)  # (n_documents, max(sentences_per_document), 2 * sentence_rnn_size)
         documents = documents.sum(dim=1)  # (n_documents, 2 * sentence_rnn_size)
 
         # Also re-arrange word_alphas (SENTENCES -> DOCUMENTS)
-        word_alphas, _ = pad_packed_sequence(PackedSequence(word_alphas, bs),
-                                             batch_first=True)  # (n_documents, max_doc_len_in_batch, max_sent_len_in_batch)
-
-        # Unsort documents into the original order (INVERSE OF SORTING #1)
-        _, doc_unsort_ind = doc_sort_ind.sort(dim=0, descending=False)  # (n_documents)
-        documents = documents[doc_unsort_ind]  # (n_documents, 2 * sentence_rnn_size)
-        sentence_alphas = sentence_alphas[doc_unsort_ind]  # (n_documents, max_doc_len_in_batch)
-        word_alphas = word_alphas[doc_unsort_ind]  # (n_documents, max_doc_len_in_batch, max_sent_len_in_batch)
+        word_alphas, _ = pad_packed_sequence(PackedSequence(data=word_alphas,
+                                                            batch_sizes=packed_sentences.batch_sizes,
+                                                            sorted_indices=packed_sentences.sorted_indices,
+                                                            unsorted_indices=packed_sentences.unsorted_indices),
+                                             batch_first=True)  # (n_documents, max(sentences_per_document), max(words_per_sentence))
 
         return documents, word_alphas, sentence_alphas
 
@@ -231,24 +230,23 @@ def forward(self, sentences, words_per_sentence):
         :param words_per_sentence: sentence lengths, a tensor of dimension (n_sentences)
         :return: sentence embeddings, attention weights of words
         """
-        # Sort sentences by decreasing sentence lengths (SORTING #2)
-        words_per_sentence, sent_sort_ind = words_per_sentence.sort(dim=0, descending=True)
-        sentences = sentences[sent_sort_ind]  # (n_sentences, word_pad_len, emb_size)
-
         # Get word embeddings, apply dropout
         sentences = self.dropout(self.embeddings(sentences))  # (n_sentences, word_pad_len, emb_size)
 
-        # Re-arrange as words by removing pad-words (SENTENCES -> WORDS)
-        words, bw = pack_padded_sequence(sentences,
-                                         lengths=words_per_sentence.tolist(),
-                                         batch_first=True)  # (n_words, emb_size), bw is the effective batch size at each word-timestep
+        # Re-arrange as words by removing word-pads (SENTENCES -> WORDS)
+        packed_words = pack_padded_sequence(sentences,
+                                            lengths=words_per_sentence.tolist(),
+                                            batch_first=True,
+                                            enforce_sorted=False)  # a PackedSequence object, where 'data' is the flattened words (n_words, word_emb)
 
-        # Apply the word-level RNN over the word embeddings (PyTorch automatically applies it on the packed_sequence using the effective batch_size)
-        (words, _), _ = self.word_rnn(PackedSequence(words, bw))  # (n_words, 2 * word_rnn_size), (max(sent_lens))
+        # Apply the word-level RNN over the word embeddings (PyTorch automatically applies it on the PackedSequence)
+        packed_words, _ = self.word_rnn(
+            packed_words)  # a PackedSequence object, where 'data' is the output of the RNN (n_words, 2 * word_rnn_size)
 
-        # Find attention vectors by applying the attention linear layer
-        att_w = self.word_attention(words)  # (n_words, att_size)
-        att_w = F.tanh(att_w)  # (n_words, att_size)
+        # Find attention vectors by applying the attention linear layer on the output of the RNN
+        att_w = self.word_attention(packed_words.data)  # (n_words, att_size)
+        att_w = torch.tanh(att_w)  # (n_words, att_size)
 
         # Take the dot-product of the attention vectors with the context vector (i.e. parameter of linear layer)
         att_w = self.word_context_vector(att_w).squeeze(1)  # (n_words)
@@ -260,23 +258,21 @@ def forward(self, sentences, words_per_sentence):
         att_w = torch.exp(att_w - max_value)  # (n_words)
 
         # Re-arrange as sentences by re-padding with 0s (WORDS -> SENTENCES)
-        att_w, _ = pad_packed_sequence(PackedSequence(att_w, bw),
-                                       batch_first=True)  # (n_sentences, max_sent_len_in_batch)
+        att_w, _ = pad_packed_sequence(PackedSequence(data=att_w,
+                                                      batch_sizes=packed_words.batch_sizes,
+                                                      sorted_indices=packed_words.sorted_indices,
+                                                      unsorted_indices=packed_words.unsorted_indices),
+                                       batch_first=True)  # (n_sentences, max(words_per_sentence))
 
-        # Calculate softmax values
-        word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)  # (n_sentences, max_sent_len_in_batch)
+        # Calculate softmax values as now words are arranged in their respective sentences
+        word_alphas = att_w / torch.sum(att_w, dim=1, keepdim=True)  # (n_sentences, max(words_per_sentence))
 
         # Similarly re-arrange word-level RNN outputs as sentences by re-padding with 0s (WORDS -> SENTENCES)
-        sentences, _ = pad_packed_sequence(PackedSequence(words, bw),
-                                           batch_first=True)  # (n_sentences, max_sent_len_in_batch, 2 * word_rnn_size)
+        sentences, _ = pad_packed_sequence(packed_words,
+                                           batch_first=True)  # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)
 
         # Find sentence embeddings
-        sentences = sentences * word_alphas.unsqueeze(2)  # (n_sentences, max_sent_len_in_batch, 2 * word_rnn_size)
+        sentences = sentences * word_alphas.unsqueeze(2)  # (n_sentences, max(words_per_sentence), 2 * word_rnn_size)
         sentences = sentences.sum(dim=1)  # (n_sentences, 2 * word_rnn_size)
 
-        # Unsort sentences into the original order (INVERSE OF SORTING #2)
-        _, sent_unsort_ind = sent_sort_ind.sort(dim=0, descending=False)  # (n_sentences)
-        sentences = sentences[sent_unsort_ind]  # (n_sentences, 2 * word_rnn_size)
-        word_alphas = word_alphas[sent_unsort_ind]  # (n_sentences, max_sent_len_in_batch)
-
         return sentences, word_alphas
 
diff --git a/train.py b/train.py
index bfbb1c5..ec0ebf7 100644
--- a/train.py
+++ b/train.py
@@ -29,7 +29,7 @@
 lr = 1e-3  # learning rate
 momentum = 0.9  # momentum
 workers = 4  # number of workers for loading data in the DataLoader
-epochs = 2  # number of epochs to run without early-stopping
+epochs = 2  # number of epochs to run
 grad_clip = None  # clip gradients at this value
 print_freq = 2000  # print training or validation status every __ batches
 checkpoint = None  # path to model checkpoint, None if none
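
The change that drives this diff is the `enforce_sorted=False` argument that PyTorch 1.1 added to `pack_padded_sequence`: the returned PackedSequence now carries its own `sorted_indices` and `unsorted_indices`, so the manual SORTING #1 / SORTING #2 blocks and their inverse unsort steps can be dropped. A minimal standalone sketch of that pattern follows; it is not part of the patch, and the tensor sizes and variable names are illustrative only:

    # Pack an unsorted padded batch with enforce_sorted=False (PyTorch >= 1.1),
    # run an RNN over it, and re-pad the output in the original batch order.
    import torch
    import torch.nn as nn
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    batch = torch.randn(3, 5, 8)       # (batch_size, pad_len, emb_size), zero-padded
    lengths = torch.tensor([2, 5, 3])  # true lengths, deliberately not in decreasing order

    # enforce_sorted=False stores sorted_indices/unsorted_indices inside the
    # PackedSequence, so no manual sort of the batch is needed beforehand.
    packed = pack_padded_sequence(batch, lengths=lengths.tolist(),
                                  batch_first=True, enforce_sorted=False)

    rnn = nn.GRU(input_size=8, hidden_size=4, bidirectional=True, batch_first=True)
    packed_out, _ = rnn(packed)        # packed_out.data: (sum(lengths), 2 * hidden_size)

    # pad_packed_sequence restores the original (unsorted) batch order automatically.
    out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)
    print(out.shape, out_lengths)      # torch.Size([3, 5, 8]) tensor([2, 5, 3])

The patch reuses the same bookkeeping for intermediate tensors: each new PackedSequence is rebuilt from the `batch_sizes`, `sorted_indices`, and `unsorted_indices` of `packed_sentences` or `packed_words`, so attention scores and RNN outputs can be re-padded into the original document and sentence order without any explicit unsorting.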