# token_embedder.py
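"""Token embedders that map word- and character-level inputs to fixed-size
token representations: an LSTM-based variant (LstmTokenEmbedder) and a
convolutional variant (ConvTokenEmbedder). Both concatenate their word and
character features and project them to the encoder's projection_dim.
"""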
from __future__ import absolute_import
from __future__ import unicode_literals
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from .highway import Highway


class LstmTokenEmbedder(nn.Module):
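  """Concatenates a word embedding with the final states of a bidirectional
  character-level LSTM and projects the result to the encoder's projection_dim.
  """
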
  def __init__(self, config, word_emb_layer, char_emb_layer, use_cuda=False):
    super(LstmTokenEmbedder, self).__init__()
    self.config = config
    self.use_cuda = use_cuda
    self.word_emb_layer = word_emb_layer
    self.char_emb_layer = char_emb_layer
    self.output_dim = config['encoder']['projection_dim']
    emb_dim = 0
    if word_emb_layer is not None:
      emb_dim += word_emb_layer.n_d

    if char_emb_layer is not None:
      emb_dim += char_emb_layer.n_d * 2
      # Bidirectional character-level LSTM; note that `dropout` has no effect on a
      # single-layer LSTM (PyTorch warns and ignores it).
      self.char_lstm = nn.LSTM(char_emb_layer.n_d, char_emb_layer.n_d, num_layers=1, bidirectional=True,
                               batch_first=True, dropout=config['dropout'])

    self.projection = nn.Linear(emb_dim, self.output_dim, bias=True)

  def forward(self, word_inp, chars_inp, shape):
    embs = []
    batch_size, seq_len = shape
    if self.word_emb_layer is not None:
      word_emb = self.word_emb_layer(Variable(word_inp).cuda() if self.use_cuda else Variable(word_inp))
      embs.append(word_emb)

    if self.char_emb_layer is not None:
      chars_inp = chars_inp.view(batch_size * seq_len, -1)
      chars_emb = self.char_emb_layer(Variable(chars_inp).cuda() if self.use_cuda else Variable(chars_inp))
      _, (chars_outputs, __) = self.char_lstm(chars_emb)
      # The final hidden state is (2, batch_size * seq_len, char_dim); concatenate the
      # two directions and restore (batch, seq, 2 * char_dim) for the cat below.
      chars_outputs = torch.cat([chars_outputs[0], chars_outputs[1]], dim=-1)
      chars_outputs = chars_outputs.view(batch_size, seq_len, self.config['token_embedder']['char_dim'] * 2)
      embs.append(chars_outputs)

    token_embedding = torch.cat(embs, dim=2)

    return self.projection(token_embedding)


class ConvTokenEmbedder(nn.Module):
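  """Concatenates a word embedding with character CNN features (per-width,
  max-pooled convolutions followed by highway layers) and projects the result
  to the encoder's projection_dim.
  """
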
  def __init__(self, config, word_emb_layer, char_emb_layer, use_cuda):
    super(ConvTokenEmbedder, self).__init__()
    self.config = config
    self.use_cuda = use_cuda

    self.word_emb_layer = word_emb_layer
    self.char_emb_layer = char_emb_layer

    self.output_dim = config['encoder']['projection_dim']
    self.emb_dim = 0
    if word_emb_layer is not None:
      self.emb_dim += word_emb_layer.n_d

    if char_emb_layer is not None:
      self.convolutions = []
      cnn_config = config['token_embedder']
      filters = cnn_config['filters']
      char_embed_dim = cnn_config['char_dim']

      for i, (width, num) in enumerate(filters):
        conv = torch.nn.Conv1d(
          in_channels=char_embed_dim,
          out_channels=num,
          kernel_size=width,
          bias=True
        )
        self.convolutions.append(conv)

      self.convolutions = nn.ModuleList(self.convolutions)
      
      self.n_filters = sum(f[1] for f in filters)
      self.n_highway = cnn_config['n_highway']

      self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu)
      self.emb_dim += self.n_filters

    self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True)
    
  def forward(self, word_inp, chars_inp, shape):
    embs = []
    batch_size, seq_len = shape
    if self.word_emb_layer is not None:
      batch_size, seq_len = word_inp.size(0), word_inp.size(1)
      word_emb = self.word_emb_layer(Variable(word_inp).cuda() if self.use_cuda else Variable(word_inp))
      embs.append(word_emb)

    if self.char_emb_layer is not None:
      # Treat every token's character sequence as an independent row.
      chars_inp = chars_inp.view(batch_size * seq_len, -1)

      character_embedding = self.char_emb_layer(Variable(chars_inp).cuda() if self.use_cuda else Variable(chars_inp))

      # (batch * seq, max_chars, char_dim) -> (batch * seq, char_dim, max_chars) for Conv1d.
      character_embedding = torch.transpose(character_embedding, 1, 2)

      cnn_config = self.config['token_embedder']
      if cnn_config['activation'] == 'tanh':
        activation = torch.tanh
      elif cnn_config['activation'] == 'relu':
        activation = torch.relu
      else:
        raise ValueError("Unknown activation: {}".format(cnn_config['activation']))

      convs = []
      for conv in self.convolutions:
        convolved = conv(character_embedding)
        # (batch_size * sequence_length, n_filters for this width)
        convolved, _ = torch.max(convolved, dim=-1)
        convolved = activation(convolved)
        convs.append(convolved)
      char_emb = torch.cat(convs, dim=-1)
      char_emb = self.highways(char_emb)

      embs.append(char_emb.view(batch_size, -1, self.n_filters))
      
    token_embedding = torch.cat(embs, dim=2)

    return self.projection(token_embedding)
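

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module: the stub embedding
# layer, the config values and the tensor sizes below are illustrative
# assumptions chosen only to show the shapes the embedders expect; the real
# project builds its embedding layers and config elsewhere. Because `Highway`
# is imported relatively, run this as a module inside its package
# (e.g. `python -m <package>.token_embedder`).
if __name__ == '__main__':
  class _StubEmbeddingLayer(nn.Module):
    """Stand-in embedding layer exposing the `n_d` attribute the embedders use."""
    def __init__(self, n_V, n_d):
      super(_StubEmbeddingLayer, self).__init__()
      self.n_d = n_d
      self.embedding = nn.Embedding(n_V, n_d)

    def forward(self, x):
      return self.embedding(x)

  config = {
    'dropout': 0.1,
    'encoder': {'projection_dim': 512},
    'token_embedder': {
      'char_dim': 50,
      'filters': [[1, 32], [2, 32], [3, 64]],
      'n_highway': 2,
      'activation': 'relu',
    },
  }
  word_layer = _StubEmbeddingLayer(1000, 100)
  char_layer = _StubEmbeddingLayer(200, config['token_embedder']['char_dim'])

  batch_size, seq_len, max_chars = 4, 7, 10
  word_inp = torch.randint(0, 1000, (batch_size, seq_len))
  chars_inp = torch.randint(0, 200, (batch_size, seq_len, max_chars))

  conv_embedder = ConvTokenEmbedder(config, word_layer, char_layer, use_cuda=False)
  lstm_embedder = LstmTokenEmbedder(config, word_layer, char_layer, use_cuda=False)
  print(conv_embedder(word_inp, chars_inp, (batch_size, seq_len)).shape)  # torch.Size([4, 7, 512])
  print(lstm_embedder(word_inp, chars_inp, (batch_size, seq_len)).shape)  # torch.Size([4, 7, 512])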