import torch
import torch.nn as nn


class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=256,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device="cuda",
        max_length=100,
    ):
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )
        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        # (N, src_len) -> (N, 1, 1, src_len): broadcasts over heads and
        # query positions so attention ignores padding tokens.
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        # Lower-triangular causal mask of shape (N, 1, trg_len, trg_len):
        # position t may only attend to positions <= t.
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        )
        return trg_mask.to(self.device)

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        out = self.decoder(trg, enc_src, src_mask, trg_mask)
        return out
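To sanity-check the wiring, here is a minimal smoke test. It assumes the Encoder and Decoder modules referenced above are in scope and that the decoder projects each position to trg_vocab_size logits; the toy token IDs, the vocabulary size of 10, and pad index 0 are illustrative placeholders, not values from the original.

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Two toy "sentences" of token IDs; 0 is the padding index.
    src = torch.tensor(
        [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]]
    ).to(device)
    trg = torch.tensor(
        [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]]
    ).to(device)

    model = Transformer(
        src_vocab_size=10,
        trg_vocab_size=10,
        src_pad_idx=0,
        trg_pad_idx=0,
        device=device,
    ).to(device)

    # Feed the target shifted right by one step: the model sees
    # positions 0..t-1 and is asked to predict position t.
    out = model(src, trg[:, :-1])
    print(out.shape)  # expected: torch.Size([2, 7, 10])

Dropping the last target token before the forward pass is the standard teacher-forcing setup; the causal mask from make_trg_mask then guarantees no position can peek at later ones.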