```python
import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        max_length=100,
    ):
        super(Transformer, self).__init__()

        # Encoder and Decoder are the classes built earlier in this post.
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # (N, src_len) -> (N, 1, 1, src_len): True wherever the token
        # is not padding, so attention ignores padded positions.
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        # Lower-triangular causal mask of shape (N, 1, trg_len, trg_len):
        # position i may only attend to positions <= i.
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
```
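As a quick sanity check, here is a minimal smoke test for the class above. It assumes the `Encoder` and `Decoder` classes from earlier in the post are in scope and that the decoder returns per-token logits over the target vocabulary; the vocabulary sizes, padding index, and token sequences are arbitrary toy values chosen for illustration.

```python
# Toy shape check: vocab sizes, pad index, and token IDs are arbitrary.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

src = torch.tensor(
    [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]
).to(device)
trg = torch.tensor(
    [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]
).to(device)

model = Transformer(
    src_vocab_size=10,
    trg_vocab_size=10,
    src_pad_idx=0,
    trg_pad_idx=0,
    device=device,
).to(device)

# Teacher forcing: feed the target shifted right (drop the last token).
out = model(src, trg[:, :-1])
print(out.shape)  # expected: (batch=2, trg_len-1=7, trg_vocab_size=10)
```

Feeding `trg[:, :-1]` during training lets the model predict each next token from the previous ones, while the causal mask from `make_trg_mask` prevents it from peeking ahead.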