Transformer: Principles and Code Explained
Other personal links: github, blog
Resources
Full code with detailed comments: github
Reference paper: Attention Is All You Need
Reference implementation: TensorFlow 2.0 official tutorials/text/transformer
Principles
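The Transformer model comes from the paper Attention Is All You Need. It was designed for machine translation: with self-attention and positional encoding, it replaces the RNN structure of a traditional Seq2Seq model. Because of its strong performance, later models were built on it: OpenAI GPT uses the Transformer decoder, and BERT uses the Transformer encoder.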
How the Transformer works:
Input: inputs, targets
For example:
inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
targets = ‘SOS imagination is more important than knowledge EOS’
Training
Training uses teacher forcing:
inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
targets = ‘SOS imagination is more important than knowledge EOS’
The targets are split into tar_inp and tar_real. tar_inp is passed to the decoder as input. tar_real is the same sequence shifted by one: at each position of tar_inp, tar_real holds the next token that should be predicted.
tar_inp = ‘SOS imagination is more important than knowledge’
tar_real = ‘imagination is more important than knowledge EOS’
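In other words, the encoder encodes inputs into a representation of the source sentence; the decoder starts from SOS and, moving left to right, predicts a probability distribution over the next word. Because training uses teacher forcing, the ground-truth previous tokens (not the model's own predictions) are used to predict each next word.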
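The one-token shift between tar_inp and tar_real can be sketched on a plain token list. This is only an illustration (the helper name `split_target` is made up here); the model code later does the same thing on id tensors with `tar[:, :-1]` and `tar[:, 1:]`.

```python
def split_target(tokens):
    """Teacher-forcing split: the decoder sees tokens[:-1] and must predict tokens[1:]."""
    tar_inp = tokens[:-1]   # drop the last token (EOS)
    tar_real = tokens[1:]   # drop the first token (SOS), i.e. shift left by one
    return tar_inp, tar_real


targets = "SOS imagination is more important than knowledge EOS".split()
tar_inp, tar_real = split_target(targets)

# At position t the decoder has seen tar_inp[:t+1]; the label is tar_real[t].
for t, (seen, label) in enumerate(zip(tar_inp, tar_real)):
    print(t, seen, "->", label)
```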
Predicted output
tar_pred = ‘imagination is more important than knowledge EOS’
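This is the ideal case, where the prediction matches tar_real exactly; early in training the predictions will not be this accurate.
Loss: cross-entropy
The cross-entropy loss is computed between tar_pred and tar_real.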
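A minimal pure-Python sketch of this per-token cross-entropy, assuming the model outputs unnormalized logits (as with `from_logits=True` in the code below) and that token id 0 is padding. Padded positions are skipped and the loss is averaged over real tokens only. The function names and the toy logits are illustrative, not from the original code.

```python
import math


def log_softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_cross_entropy(real_ids, logits_per_step, pad_id=0):
    """Mean negative log-likelihood over the non-padding positions."""
    total, count = 0.0, 0
    for target, logits in zip(real_ids, logits_per_step):
        if target == pad_id:
            continue  # padding contributes nothing to the loss
        total += -log_softmax(logits)[target]
        count += 1
    return total / count


# Vocabulary of 3 ids (0 = padding); the last position is padding.
real = [1, 2, 0]
logits = [[0.0, 2.0, 0.0],   # favors id 1 (correct)
          [0.0, 0.0, 2.0],   # favors id 2 (correct)
          [5.0, 0.0, 0.0]]   # ignored: its target is padding
print(masked_cross_entropy(real, logits))
```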
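How is the trained model used for prediction?
SOS is the symbol marking the start of a sentence, and EOS the symbol marking its end.
Encoder stage: inputs = ‘SOS 想象力 比 知识 更 重要 EOS’
Decoder stage: predict in a loop
Input [SOS] and predict the next token: imagination
Input [SOS, imagination] and predict the next token: is
…
Input [SOS, imagination, is, more, important, than, knowledge] and predict the next token: EOS. Decoding ends.
Decoding stops under either of two conditions: EOS is predicted, or the maximum target length target_seq_len is reached.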
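The loop above is greedy decoding. A sketch with a stand-in `toy_model` (hypothetical: in the real model this call would run the Transformer on the encoder input plus the tokens generated so far and take the argmax at the last position):

```python
def greedy_decode(next_token, sos="SOS", eos="EOS", max_len=10):
    """Feed the growing prefix back in until EOS or max_len new tokens."""
    output = [sos]
    for _ in range(max_len):
        token = next_token(output)   # predict the token that follows the prefix
        if token == eos:             # stop condition 1: end-of-sentence token
            break
        output.append(token)
    return output                    # stop condition 2: max_len reached


reference = "imagination is more important than knowledge EOS".split()


def toy_model(prefix):
    # Hypothetical stand-in for the Transformer: returns the next reference token.
    return reference[len(prefix) - 1]


print(greedy_decode(toy_model))
print(greedy_decode(toy_model, max_len=3))
```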
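Network architecture
[Figure: network architecture from the original paper]
The architecture as implemented here:
Encoder:
Argument order used in the pseudocode below: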
MultiHeadAttention(v, k, q, mask)
Encoder block
It consists of two sub-layers:
- Multi-head attention (with padding mask)
- Point-wise feed-forward network, which is simply two fully connected layers
The input x is the input sentence, shape (batch_size, seq_len, d_model)
- out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
- out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
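In the original paper every sub-layer is wrapped as LayerNorm(x + Sublayer(x)) ("Add & Norm", post-LN). Layer normalization works per position over the d_model features, independent of the batch. A pure-Python sketch on a single position vector, without the learnable scale/offset and with dropout omitted (as at inference); `add_and_norm` and the toy sub-layer are illustrative only:

```python
import math


def layer_norm(vec, eps=1e-6):
    # Normalize one position's features to zero mean and unit variance.
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


def add_and_norm(x, sublayer):
    # out = LayerNorm(x + sublayer(x)); during training dropout would be applied to sublayer(x).
    y = sublayer(x)
    return layer_norm([a + b for a, b in zip(x, y)])


x = [1.0, 2.0, 3.0, 4.0]
out = add_and_norm(x, lambda v: [0.5 * t for t in v])  # toy sub-layer
print([round(v, 3) for v in out])
```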
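Decoder:
Unlike the encoder, the decoder first applies self-attention to its own input. That result is used as the query, with the encoder output as key and value, in an ordinary (encoder-decoder) attention, and the output of that attention is the input to the feed-forward network.
A decoder block needs these sub-layers:
- Masked multi-head attention (look-ahead mask and padding mask)
- Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output as input; Q (query) receives the output of the masked multi-head attention sub-layer.
- Point-wise feed-forward network
The input x is the target sentence, shape (batch_size, seq_len, d_model)
- out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
- out2 = LayerNormalization( out1 + (MultiHeadAttention(enc_output, enc_output, out1) => dropout) )
- out3 = LayerNormalization( out2 + (ffn(out2) => dropout) )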
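Because the query comes from the decoder and the keys/values from the encoder, the encoder-decoder attention weights have shape (seq_len_tar, seq_len_inp) and the output has one row per target position. A single-head, pure-Python sketch of scaled dot-product attention with made-up numbers (the function names are illustrative):

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def attention(q, k, v):
    """softmax(q k^T / sqrt(d)) v, with q, k, v given as lists of row vectors."""
    d = len(k[0])
    weights = [softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k])
               for qi in q]
    out = [[sum(w * vj[c] for w, vj in zip(wi, v)) for c in range(len(v[0]))]
           for wi in weights]
    return out, weights


enc_output = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 source positions, d_model = 2
dec_state = [[1.0, 0.0], [0.0, 1.0]]               # 2 target positions (the query)
out, w = attention(dec_state, enc_output, enc_output)
print(len(out), len(w), len(w[0]))  # 2 target rows, each attending over 3 source positions
```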
Code implementation
Position
import numpy as np
import tensorflow as tf


def get_angles(pos, i, d_model):
    '''
    :param pos: position of the token in the sentence
    :param i: index of the dimension within the embedding vector
    :param d_model: embedding dimension
    :return: angles pos / 10000^(2*(i//2)/d_model) for every (pos, i) pair
    '''
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates


def positional_encoding(position, d_model):
    '''
    :param position: maximum position (number of positions to encode)
    :param d_model: embedding dimension
    :return: tensor of shape [1, position, d_model], later added to the embedding matrix
    '''
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)
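To check the formula by hand: PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The pure-Python sketch below (hypothetical name `positional_encoding_py`) mirrors `get_angles` + `positional_encoding` above, sin on even columns and cos on odd columns, without the leading batch axis.

```python
import math


def positional_encoding_py(max_pos, d_model):
    """Return a max_pos x d_model table: sin on even columns, cos on odd columns."""
    table = []
    for pos in range(max_pos):
        row = []
        for i in range(d_model):
            # Columns 2k and 2k+1 share the same frequency 1 / 10000^(2k / d_model).
            angle = pos / 10000 ** (2 * (i // 2) / d_model)
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


pe = positional_encoding_py(50, 8)
print(pe[0])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
```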
point_wise_feed_forward_network
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
        tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])
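"Point-wise" means the same two Dense layers are applied to every position independently: FFN(x) = max(0, x W1 + b1) W2 + b2. A pure-Python sketch with made-up 2 → 3 → 2 weights (names and numbers are illustrative) shows that each output row depends only on the matching input row:

```python
def dense(x, w, b):
    # w[i][j] is the weight from input feature i to output feature j (in_dim x out_dim).
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]


def ffn(x, w1, b1, w2, b2):
    hidden = [max(0.0, h) for h in dense(x, w1, b1)]  # ReLU
    return dense(hidden, w2, b2)


w1 = [[1.0, -1.0, 0.5], [0.0, 1.0, 1.0]]   # d_model = 2 -> dff = 3
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # dff = 3 -> d_model = 2
b2 = [0.0, 0.0]

seq = [[1.0, 2.0], [3.0, 0.0]]             # 2 positions
out = [ffn(pos, w1, b1, w2, b2) for pos in seq]
# Changing the second position leaves the first position's output unchanged.
out2 = [ffn(pos, w1, b1, w2, b2) for pos in [[1.0, 2.0], [-5.0, 4.0]]]
print(out, out[0] == out2[0])
```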
Attention
MultiHeadAttention actually splits along d_model (the word-embedding dimension) into num_heads pieces and then runs attention on each piece.
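Splitting into heads is just a reshape of the last axis: with d_model = 8 and num_heads = 2, each head gets depth = 4 consecutive features, and concatenating the heads restores the original vector (what `split_heads` and the final `tf.reshape` below do for every position). A pure-Python sketch for one position, with illustrative helper names:

```python
def split_heads(vec, num_heads):
    depth = len(vec) // num_heads
    assert depth * num_heads == len(vec), "d_model must be divisible by num_heads"
    return [vec[h * depth:(h + 1) * depth] for h in range(num_heads)]


def concat_heads(heads):
    return [v for head in heads for v in head]


x = list(range(8))         # one position, d_model = 8
heads = split_heads(x, 2)  # 2 heads, depth = 4
print(heads)               # [[0, 1, 2, 3], [4, 5, 6, 7]]
```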
def scaled_dot_product_attention(q, k, v, mask=None):
    '''Compute attention.
    q, k, v must have matching leading (batch) dimensions.
    q and k must have the same last dimension (depth).
    k and v must have the same second-to-last dimension: seq_len_k == seq_len_v.
    Args:
        q: query, shape == (..., seq_len_q, depth)
        k: key, shape == (..., seq_len_k, depth)
        v: value, shape == (..., seq_len_v, depth_v)
        mask: float tensor broadcastable to
              (..., seq_len_q, seq_len_k). Defaults to None.
    Returns:
        output, attention_weights
    '''
    # (batch_size, num_heads, seq_len_q, depth) dot (batch_size, num_heads, depth, seq_len_k)
    # = (batch_size, num_heads, seq_len_q, seq_len_k)
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    # Scale matmul_qk by sqrt(depth)
    dk = tf.cast(tf.shape(k)[-1], dtype=tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    # Add the mask to the scaled tensor: masked positions get about -1e9,
    # so their softmax weight becomes ~0.
    if mask is not None:
        # (batch_size, num_heads, seq_len_q, seq_len_k) + (batch_size, 1, 1, seq_len_k), broadcast
        scaled_attention_logits += (mask * -1e9)
    # Softmax over the last axis (seq_len_k): (batch_size, num_heads, seq_len_q, seq_len_k)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    # Each of the seq_len_q positions is a weighted sum over v
    # (batch_size, num_heads, seq_len_q, seq_len_k) dot (batch_size, num_heads, seq_len_v, depth_v)
    # = (batch_size, num_heads, seq_len_q, depth_v)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights


class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert (d_model > num_heads) and (d_model % num_heads == 0)
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.qw = tf.keras.layers.Dense(d_model)
        self.kw = tf.keras.layers.Dense(d_model)
        self.vw = tf.keras.layers.Dense(d_model)
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))  # (batch_size, seq_len, num_heads, depth)
        return tf.transpose(x, perm=(0, 2, 1, 3))  # (batch_size, num_heads, seq_len, depth)

    def call(self, v, k, q, mask=None):
        batch_size = tf.shape(q)[0]
        # Linear projections of q, k, v
        q = self.qw(q)  # (batch_size, seq_len_q, d_model)
        k = self.kw(k)  # (batch_size, seq_len_k, d_model)
        v = self.vw(v)  # (batch_size, seq_len_v, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention: (batch_size, num_heads, seq_len_q, depth)
        # attention_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=(0, 2, 1, 3))  # (batch_size, seq_len_q, num_heads, depth)
        # Concatenate the heads back into d_model
        concat_attention = tf.reshape(scaled_attention, shape=(batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
Encoder
Input:
- inputs (batch_size, seq_len_inp): token ids; the Encoder embeds them into (batch_size, seq_len_inp, d_model)
- mask (batch_size, 1, 1, seq_len_inp): input sequences are padded to the same length, so padded positions must be masked in self-attention. The mask has shape (batch_size, 1, 1, seq_len_inp) because MultiHeadAttention splits the inputs into (batch_size, num_heads, seq_len_inp, d_model // num_heads), and the resulting attention weights have shape (batch_size, num_heads, seq_len_inp, seq_len_inp); when the mask is applied, it broadcasts automatically to (batch_size, num_heads, seq_len_inp, seq_len_inp).
Output:
- encode_output (batch_size, seq_len_inp, d_model)
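Why the padding mask works: the attention code adds `mask * -1e9` to the logits before the softmax, so a padded key ends up with a weight of (numerically) zero. A pure-Python sketch for one query row over 4 keys, where the last key is padding (id 0); the logits are made up. The same mask row is reused for every query and head, which is what broadcasting the (batch_size, 1, 1, seq_len_inp) mask achieves.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


tokens = [12, 7, 33, 0]                          # 0 = padding
mask = [1.0 if t == 0 else 0.0 for t in tokens]  # like create_padding_mask: 1 marks padding
logits = [0.5, 1.0, -0.2, 2.0]                   # toy scaled q.k scores for one query
masked = [l + m * -1e9 for l, m in zip(logits, mask)]
weights = softmax(masked)
print([round(w, 4) for w in weights])
```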
class EncoderLayer(tf.keras.layers.Layer):
    '''Encoder block
    Two sub-layers: 1. multi-head attention (with padding mask) 2. point-wise feed-forward network.
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (ffn(out1) => dropout) )
    '''
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        # Layer normalization (per position, over d_model), as in the original paper
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layer_norm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layer_norm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
        return out2


class Encoder(tf.keras.layers.Layer):
    '''
    Input Embedding
    Positional Encoding
    N encoder layers
    The input is embedded, and the embedding is added to the positional encoding. The sum is the input to the encoder layers. The encoder output is an input to the decoder.
    '''
    def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, maximum_position_encoding, rate=0.1):
        super(Encoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.enc_layer = [EncoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask):
        # x.shape == (batch_size, seq_len)
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        # Scale the embedding by sqrt(d_model) before adding the positional encoding
        x *= tf.math.sqrt(tf.cast(self.d_model, dtype=tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x = self.enc_layer[i](x, training, mask)
        return x  # (batch_size, input_seq_len, d_model)
Decoder
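Input:
- targets_inp (batch_size, seq_len_tar): token ids; the Decoder embeds them into (batch_size, seq_len_tar, d_model)
- encode_output (batch_size, seq_len_inp, d_model)
- self_mask, the combined look-ahead and padding mask (batch_size, 1, seq_len_tar, seq_len_tar); enc_output_mask (batch_size, 1, 1, seq_len_inp)
Output:
- decode_output (batch_size, seq_len_tar, d_model); the final Dense layer in Transformer maps it to (batch_size, seq_len_tar, target_vocab_size)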
class DecoderLayer(tf.keras.layers.Layer):
    ''' Decoder block
    Sub-layers:
    1. Masked multi-head attention (look-ahead mask and padding mask)
    2. Multi-head attention (with padding mask). V (value) and K (key) receive the encoder output; Q (query) receives the output of the masked multi-head attention sub-layer.
    3. Point-wise feed-forward network
    out1 = LayerNormalization( x + (MultiHeadAttention(x, x, x) => dropout) )
    out2 = LayerNormalization( out1 + (MultiHeadAttention(enc_output, enc_output, out1) => dropout) )
    out3 = LayerNormalization( out2 + (ffn(out2) => dropout) )
    '''
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)  # masked self-attention
        self.mha2 = MultiHeadAttention(d_model, num_heads)  # encoder-decoder attention
        self.ffn = point_wise_feed_forward_network(d_model, dff)
        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape == (batch_size, target_seq_len, d_model)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layer_norm1(x + attn1)
        # The second attention has its own weights (mha2): v = k = enc_output, q = out1
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layer_norm2(out1 + attn2)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layer_norm3(out2 + ffn_output)  # (batch_size, target_seq_len, d_model)
        return out3, attn_weights_block1, attn_weights_block2


class Decoder(tf.keras.layers.Layer):
    '''The decoder consists of:
    Output Embedding
    Positional Encoding
    N decoder layers
    The target is embedded, and the embedding is added to the positional encoding. The sum is the input to the decoder layers. The decoder output is the input to the final linear layer.
    '''
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
        self.dec_layer = [DecoderLayer(d_model, num_heads, dff, rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # x.shape == (batch_size, target_seq_len)
        # enc_output.shape == (batch_size, input_seq_len, d_model)
        seq_len = tf.shape(x)[1]
        attention_weights = {}
        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]
        x = self.dropout(x, training=training)
        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layer[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i + 1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i + 1)] = block2
        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights
Transformer
class Transformer(tf.keras.Model):
    def __init__(self, params):
        super(Transformer, self).__init__()
        self.encoder = Encoder(params['num_layers'], params['d_model'], params['num_heads'], params['dff'],
                               params['input_vocab_size'], params['pe_input'], params['rate'])
        self.decoder = Decoder(params['num_layers'], params['d_model'], params['num_heads'], params['dff'],
                               params['target_vocab_size'], params['pe_target'], params['rate'])
        self.final_layer = tf.keras.layers.Dense(params['target_vocab_size'])

    def call(self, inp, tar, training, enc_padding_mask=None, look_ahead_mask=None, dec_padding_mask=None):
        # (batch_size, inp_seq_len, d_model)
        enc_output = self.encoder(inp, training, enc_padding_mask)
        # (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output, attention_weights
Mask
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # Add extra dimensions so the padding mask can be added
    # to the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)


def create_look_ahead_mask(size):
    '''
    e.g.
    x = tf.random.uniform((1, 3))
    temp = create_look_ahead_mask(x.shape[1])
    temp: <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
    array([[0., 1., 1.],
           [0., 0., 1.],
           [0., 0., 0.]], dtype=float32)>
    '''
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)


def create_masks(inp, tar):
    # Encoder padding mask
    enc_padding_mask = create_padding_mask(inp)
    # Used in the decoder's second attention block
    # to mask the encoder output.
    dec_padding_mask = create_padding_mask(inp)
    # Used in the decoder's first attention block
    # to pad and mask future tokens in the input received by the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])  # (tar_seq_len, tar_seq_len)
    dec_target_padding_mask = create_padding_mask(tar)  # (batch_size, 1, 1, tar_seq_len)
    # Broadcasting: look_ahead_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    # dec_target_padding_mask ==> (batch_size, 1, tar_seq_len, tar_seq_len)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return enc_padding_mask, combined_mask, dec_padding_mask
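The decoder's combined mask can be checked by hand. A pure-Python sketch (illustrative helper names) for one target sequence `[5, 9, 0]` whose last position is padding: the look-ahead mask hides future positions, the padding mask hides the padded column, and the element-wise maximum, as in `tf.maximum` above, hides both (1 = masked).

```python
def look_ahead_mask(size):
    # 1 above the diagonal: position i may not attend to any j > i.
    return [[1.0 if j > i else 0.0 for j in range(size)] for i in range(size)]


def padding_mask(seq, pad_id=0):
    return [1.0 if t == pad_id else 0.0 for t in seq]


def combined_mask(tar):
    la = look_ahead_mask(len(tar))
    pad = padding_mask(tar)  # the same row is broadcast over every query position
    return [[max(la[i][j], pad[j]) for j in range(len(tar))] for i in range(len(tar))]


for row in combined_mask([5, 9, 0]):
    print(row)
```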
Putting it all together
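# ==============================================================
import time

import tensorflow as tf

# tokenizer_pt, tokenizer_en, CustomSchedule, checkout_dir and MAX_LENGTH are not defined
# in this excerpt; presumably they come from the full code linked at the top.
params = {
    'num_layers': 4,
    'd_model': 128,
    'dff': 512,
    'num_heads': 8,
    'input_vocab_size': tokenizer_pt.vocab_size + 2,
    'target_vocab_size': tokenizer_en.vocab_size + 2,
    'pe_input': tokenizer_pt.vocab_size + 2,
    'pe_target': tokenizer_en.vocab_size + 2,
    'rate': 0.1,
    'epochs': 20,  # assumed value: train() reads params['epochs'], which was missing
    'checkpoint_path': './checkpoints/train',
    'checkpoint_do_delete': False
}
print('input_vocab_size is {}, target_vocab_size is {}'.format(params['input_vocab_size'], params['target_vocab_size']))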
class ModelHelper:
    def __init__(self):
        self.transformer = Transformer(params)
        # optimizer
        learning_rate = CustomSchedule(params['d_model'])
        self.optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
        # Accumulates the loss of every batch in an epoch; its mean is the epoch loss
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        # Accumulates the accuracy of every batch in an epoch; its mean is the epoch accuracy
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
        self.test_loss = tf.keras.metrics.Mean(name='test_loss')
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
        # Checkpoint directory: create params['checkpoint_path'] if it does not exist;
        # if it exists and checkpoint_do_delete=True, delete it first and then recreate it
        checkout_dir(dir_path=params['checkpoint_path'], do_delete=params.get('checkpoint_do_delete', False))
        # Checkpoint
        ckpt = tf.train.Checkpoint(transformer=self.transformer,
                                   optimizer=self.optimizer)
        self.ckpt_manager = tf.train.CheckpointManager(ckpt, params['checkpoint_path'], max_to_keep=5)
        # If a checkpoint exists, restore the latest one.
        if self.ckpt_manager.latest_checkpoint:
            ckpt.restore(self.ckpt_manager.latest_checkpoint)
            print('Latest checkpoint restored!!')
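    def loss_function(self, real, pred):
        # Ignore padding positions (token id 0)
        mask = tf.math.logical_not(tf.math.equal(real, 0))
        loss_ = self.loss_object(real, pred)
        mask = tf.cast(mask, dtype=loss_.dtype)
        loss_ *= mask
        # Average over the real (non-padding) tokens only
        return tf.reduce_sum(loss_) / tf.reduce_sum(mask)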
    train_step_signature = [
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
        tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    ]

    @tf.function(input_signature=train_step_signature)
    def train_step(self, inp, tar):
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
        with tf.GradientTape() as tape:
            predictions, _ = self.transformer(inp, tar_inp,
                                              True,
                                              enc_padding_mask,
                                              combined_mask,
                                              dec_padding_mask)
            loss = self.loss_function(tar_real, predictions)
        gradients = tape.gradient(loss, self.transformer.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.transformer.trainable_variables))
        self.train_loss(loss)
        self.train_accuracy(tar_real, predictions)
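    @tf.function(input_signature=train_step_signature)
    def test_step(self, inp, tar):
        # Same teacher-forced forward pass as train_step, without gradients or dropout.
        # (predict() decodes a raw sentence string and cannot be used on batched ids here.)
        tar_inp = tar[:, :-1]
        tar_real = tar[:, 1:]
        enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)
        predictions, _ = self.transformer(inp, tar_inp,
                                          False,
                                          enc_padding_mask,
                                          combined_mask,
                                          dec_padding_mask)
        self.test_loss(self.loss_function(tar_real, predictions))
        self.test_accuracy(tar_real, predictions)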
    def train(self, train_dataset):
        for epoch in range(params['epochs']):
            start = time.time()
            self.train_loss.reset_states()
            self.train_accuracy.reset_states()
            # inp -> portuguese, tar -> english
            for (batch, (inp, tar)) in enumerate(train_dataset):
                self.train_step(inp, tar)
                if batch % 50 == 0:
                    print('Epoch {} Batch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, batch, self.train_loss.result(), self.train_accuracy.result()))
            if (epoch + 1) % 5 == 0:
                ckpt_save_path = self.ckpt_manager.save()
                print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
            print('Epoch {} Loss {:.4f} Accuracy {:.4f}'.format(epoch + 1, self.train_loss.result(), self.train_accuracy.result()))
            print('Time taken for 1 epoch: {} secs\n'.format(time.time() - start))
    # Evaluation
    def predict(self, inp_sentence):
        start_token = [tokenizer_pt.vocab_size]
        end_token = [tokenizer_pt.vocab_size + 1]
        # The input sentence is Portuguese; add the start and end tokens
        inp_sentence = start_token + tokenizer_pt.encode(inp_sentence) + end_token
        encoder_input = tf.expand_dims(inp_sentence, 0)
        # The target is English, so the first token fed to the transformer
        # is the English start token.
        decoder_input = [tokenizer_en.vocab_size]
        output = tf.expand_dims(decoder_input, 0)
        for i in range(MAX_LENGTH):
            enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
                encoder_input, output)
            # predictions.shape == (batch_size, seq_len, vocab_size)
            predictions, attention_weights = self.transformer(encoder_input,
                                                              output,
                                                              False,
                                                              enc_padding_mask,
                                                              combined_mask,
                                                              dec_padding_mask)
            # Take the last token along the seq_len dimension
            predictions = predictions[:, -1:, :]  # (batch_size, 1, vocab_size)
            predicted_id = tf.cast(tf.argmax(predictions, axis=-1), tf.int32)
            # If predicted_id is the end token, return the result
            if predicted_id == tokenizer_en.vocab_size + 1:
                return tf.squeeze(output, axis=0), attention_weights
            # Append predicted_id to the output and feed it back into the decoder.
            output = tf.concat([output, predicted_id], axis=-1)
        # MAX_LENGTH reached: return the same (tokens, attention_weights) pair as above
        return tf.squeeze(output, axis=0), attention_weights
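`CustomSchedule` is used above but not shown in this excerpt. Assuming it follows the TensorFlow tutorial this code is based on, it implements the warm-up schedule from the paper, lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5), with warmup_steps = 4000 in the paper. A pure-Python sketch of that formula only (the function name `transformer_lr` is hypothetical, not the author's class):

```python
def transformer_lr(step, d_model=128, warmup_steps=4000):
    """Learning rate at a (1-based) training step, per the paper's warm-up schedule."""
    step = max(step, 1)  # avoid 0 ** -0.5
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# Linear warm-up for the first warmup_steps, then decay proportional to 1/sqrt(step).
for s in (1, 1000, 4000, 16000):
    print(s, transformer_lr(s))
```

The peak is reached at step == warmup_steps, where the two terms inside `min` are equal.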