#!/usr/bin/env python
import os
import random
import logging

import torch
import torch.nn as nn
from torch.autograd import Variable

from .modules.elmo import ElmobiLm
from .modules.lstm import LstmbiLm
from .modules.token_embedder import ConvTokenEmbedder, LstmTokenEmbedder

logger = logging.getLogger('elmoformanylangs')


def create_one_batch(x, word2id, char2id, config, oov='<oov>', pad='<pad>', sort=True):
    """
    Create one batch of input.

    :param x: List[List[str]]
    :param word2id: Dict | None
    :param char2id: Dict | None
    :param config: Dict
    :param oov: str, the form of the OOV token.
    :param pad: str, the form of the padding token.
    :param sort: bool, whether to sort the sentences by length (longest first).
    :return: (batch_w, batch_c, lens, masks)
    """
    batch_size = len(x)
    # lst represents the order of the sentences
    lst = list(range(batch_size))
    if sort:
        lst.sort(key=lambda l: -len(x[l]))

    # reorder the sentences according to lst
    x = [x[i] for i in lst]
    lens = [len(x[i]) for i in lst]
    max_len = max(lens)

    # build a batch of word ids of shape (batch_size x max_len)
    if word2id is not None:
        oov_id, pad_id = word2id.get(oov, None), word2id.get(pad, None)
        assert oov_id is not None and pad_id is not None
        batch_w = torch.LongTensor(batch_size, max_len).fill_(pad_id)
        for i, x_i in enumerate(x):
            for j, x_ij in enumerate(x_i):
                batch_w[i][j] = word2id.get(x_ij, oov_id)
    else:
        batch_w = None

    # build a batch of character ids of shape (batch_size x max_len x max_chars)
    if char2id is not None:
        bow_id, eow_id, oov_id, pad_id = [
            char2id.get(key, None) for key in ('<bow>', '<eow>', oov, pad)]
        assert bow_id is not None and eow_id is not None and oov_id is not None and pad_id is not None

        if config['token_embedder']['name'].lower() == 'cnn':
            max_chars = config['token_embedder']['max_characters_per_token']
            assert max([len(w) for i in lst for w in x[i]]) + 2 <= max_chars
        elif config['token_embedder']['name'].lower() == 'lstm':
            # + 2 to count the <bow> and <eow> characters
            max_chars = max([len(w) for i in lst for w in x[i]]) + 2
        else:
            raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name']))

        batch_c = torch.LongTensor(batch_size, max_len, max_chars).fill_(pad_id)
        for i, x_i in enumerate(x):
            for j, x_ij in enumerate(x_i):
                batch_c[i][j][0] = bow_id
                if x_ij == '<bos>' or x_ij == '<eos>':
                    # sentence boundary markers are kept as single "characters"
                    batch_c[i][j][1] = char2id.get(x_ij)
                    batch_c[i][j][2] = eow_id
                else:
                    for k, c in enumerate(x_ij):
                        batch_c[i][j][k + 1] = char2id.get(c, oov_id)
                    batch_c[i][j][len(x_ij) + 1] = eow_id
    else:
        batch_c = None

    # masks[0] is a (batch_size x max_len) matrix in which an entry is 1 if the
    # position holds a valid token id and 0 if it is padding.
    # masks[1] stores the flattened positions (i * max_len + j) of tokens that
    # have a valid next token.
    # masks[2] stores the flattened positions of tokens that have a valid
    # previous token.
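    # Illustrative example (not part of the original source): for a batch of
    # two sentences with lengths 3 and 2, max_len is 3 and the masks come out as
    #   masks[0] = [[1, 1, 1],
    #               [1, 1, 0]]
    #   masks[1] = [0, 1, 3]   # flat positions followed by another token
    #   masks[2] = [1, 2, 4]   # flat positions preceded by another token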
    masks = [torch.LongTensor(batch_size, max_len).fill_(0), [], []]

    for i, x_i in enumerate(x):
        for j in range(len(x_i)):
            masks[0][i][j] = 1
            if j + 1 < len(x_i):
                masks[1].append(i * max_len + j)
            if j > 0:
                masks[2].append(i * max_len + j)

    assert len(masks[1]) <= batch_size * max_len
    assert len(masks[2]) <= batch_size * max_len

    masks[1] = torch.LongTensor(masks[1])
    masks[2] = torch.LongTensor(masks[2])

    return batch_w, batch_c, lens, masks


# shuffle training examples and create mini-batches
def create_batches(x, batch_size, word2id, char2id, config, perm=None,
                   shuffle=True, sort=True, text=None):
    """
    :param x: List[List[str]]
    :param batch_size: int
    :param word2id: Dict | None
    :param char2id: Dict | None
    :param config: Dict
    :param perm: List[int] | None, a pre-computed ordering of the sentences.
    :param shuffle: bool, whether to shuffle the sentences before batching.
    :param sort: bool, whether to sort the sentences by length.
    :param text: List[List[str]] | None, raw text kept alongside each batch.
    :return: lists of per-batch word ids, char ids, lengths and masks
             (plus the raw text when `text` is given).
    """
    lst = perm or list(range(len(x)))
    if shuffle:
        random.shuffle(lst)

    if sort:
        lst.sort(key=lambda l: -len(x[l]))

    x = [x[i] for i in lst]
    if text is not None:
        text = [text[i] for i in lst]

    sum_len = 0.0
    batches_w, batches_c, batches_lens, batches_masks, batches_text = [], [], [], [], []
    size = batch_size
    nbatch = (len(x) - 1) // size + 1
    for i in range(nbatch):
        start_id, end_id = i * size, (i + 1) * size
        bw, bc, blens, bmasks = create_one_batch(x[start_id: end_id], word2id, char2id,
                                                 config, sort=sort)
        sum_len += sum(blens)
        batches_w.append(bw)
        batches_c.append(bc)
        batches_lens.append(blens)
        batches_masks.append(bmasks)
        if text is not None:
            batches_text.append(text[start_id: end_id])

    if sort:
        # shuffle the order of the length-sorted batches
        perm = list(range(nbatch))
        random.shuffle(perm)
        batches_w = [batches_w[i] for i in perm]
        batches_c = [batches_c[i] for i in perm]
        batches_lens = [batches_lens[i] for i in perm]
        batches_masks = [batches_masks[i] for i in perm]
        if text is not None:
            batches_text = [batches_text[i] for i in perm]

    logger.info("{} batches, avg len: {:.1f}".format(nbatch, sum_len / len(x)))
    if text is not None:
        return batches_w, batches_c, batches_lens, batches_masks, batches_text
    return batches_w, batches_c, batches_lens, batches_masks
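

# Usage sketch (illustrative only; `word2id`, `char2id` and `config` are assumed
# to come from a loaded pre-trained model elsewhere in this package):
#
#   sents = [['the', 'cat', 'sat'], ['hello']]
#   bw, bc, blens, bmasks = create_batches(
#       sents, batch_size=32, word2id=word2id, char2id=char2id, config=config)
#   # each bw[i] is a (batch x max_len) LongTensor of word ids, bc[i] is a
#   # (batch x max_len x max_chars) LongTensor of character ids, and bmasks[i]
#   # is the three-element mask structure documented in create_one_batch.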


class Model(nn.Module):
    def __init__(self, config, word_emb_layer, char_emb_layer, use_cuda=False):
        super(Model, self).__init__()
        self.use_cuda = use_cuda
        self.config = config

        if config['token_embedder']['name'].lower() == 'cnn':
            self.token_embedder = ConvTokenEmbedder(
                config, word_emb_layer, char_emb_layer, use_cuda)
        elif config['token_embedder']['name'].lower() == 'lstm':
            self.token_embedder = LstmTokenEmbedder(
                config, word_emb_layer, char_emb_layer, use_cuda)

        if config['encoder']['name'].lower() == 'elmo':
            self.encoder = ElmobiLm(config, use_cuda)
        elif config['encoder']['name'].lower() == 'lstm':
            self.encoder = LstmbiLm(config, use_cuda)

        self.output_dim = config['encoder']['projection_dim']

    def forward(self, word_inp, chars_package, mask_package):
        """
        :param word_inp: the batch of word ids produced by create_one_batch.
        :param chars_package: the batch of character ids produced by create_one_batch.
        :param mask_package: the masks produced by create_one_batch.
        :return: the encoder output; for the ELMo encoder the token embedding
                 is prepended as the first layer.
        """
        token_embedding = self.token_embedder(word_inp, chars_package,
                                              (mask_package[0].size(0), mask_package[0].size(1)))
        if self.config['encoder']['name'] == 'elmo':
            mask = Variable(mask_package[0]).cuda() if self.use_cuda else Variable(mask_package[0])
            encoder_output = self.encoder(token_embedding, mask)
            sz = encoder_output.size()
            # duplicate the token embedding so its last dimension matches the
            # bidirectional encoder layers, then prepend it as layer 0
            token_embedding = torch.cat(
                [token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3])
            encoder_output = torch.cat(
                [token_embedding, encoder_output], dim=0)
        elif self.config['encoder']['name'] == 'lstm':
            encoder_output = self.encoder(token_embedding)
        else:
            raise ValueError('Unknown encoder: {0}'.format(self.config['encoder']['name']))

        return encoder_output

    def load_model(self, path):
        self.token_embedder.load_state_dict(
            torch.load(os.path.join(path, 'token_embedder.pkl'),
                       map_location=lambda storage, loc: storage))
        self.encoder.load_state_dict(
            torch.load(os.path.join(path, 'encoder.pkl'),
                       map_location=lambda storage, loc: storage))
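

# End-to-end sketch (illustrative only; `config`, `word_emb_layer`,
# `char_emb_layer` and the batch tensors are assumed to be built elsewhere in
# this package from a pre-trained model directory):
#
#   model = Model(config, word_emb_layer, char_emb_layer, use_cuda=False)
#   model.load_model('/path/to/pretrained/model')
#   model.eval()
#   with torch.no_grad():
#       output = model(batch_w, batch_c, batch_masks)
#   # with the ELMo encoder, `output` stacks the duplicated token embedding and
#   # the biLM layers: (n_layers + 1, batch, max_len, 2 * projection_dim).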