import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras

tf.random.set_seed(42)

vocab_size = 100
embed_size = 10
# Inputs: variable-length sequences of token ids, plus the target lengths
encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)
sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)

# One embedding layer shared by the encoder and the decoder
embeddings = keras.layers.Embedding(vocab_size, embed_size)
encoder_embeddings = embeddings(encoder_inputs)
decoder_embeddings = embeddings(decoder_inputs)

# Encoder: keep the final LSTM state to initialize the decoder
encoder = keras.layers.LSTM(512, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_embeddings)
encoder_state = [state_h, state_c]
# During training, the sampler feeds the previous *target* token to the decoder
sampler = tfa.seq2seq.sampler.TrainingSampler()
decoder_cell = keras.layers.LSTMCell(512)
output_layer = keras.layers.Dense(vocab_size)
decoder = tfa.seq2seq.basic_decoder.BasicDecoder(decoder_cell, sampler,
                                                 output_layer=output_layer)
final_outputs, final_state, final_sequence_lengths = decoder(
    decoder_embeddings, initial_state=encoder_state,
    sequence_length=sequence_lengths)
Y_proba = tf.nn.softmax(final_outputs.rnn_output)
model = keras.models.Model(
    inputs=[encoder_inputs, decoder_inputs, sequence_lengths],
    outputs=[Y_proba])
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")

# Random toy data: 1,000 input sequences of length 10, targets of length 15
X = np.random.randint(vocab_size, size=10 * 1000).reshape(1000, 10)
Y = np.random.randint(vocab_size, size=15 * 1000).reshape(1000, 15)
# Decoder inputs: the targets shifted right by one step, with 0 as the start token
X_decoder = np.c_[np.zeros((1000, 1)), Y[:, :-1]].astype(np.int32)
seq_lengths = np.full([1000], 15)
history = model.fit([X, X_decoder, seq_lengths], Y, epochs=2)
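A quick sanity check after training; the random data carries no signal, so only the output shape is meaningful here:

probas = model.predict([X[:1], X_decoder[:1], seq_lengths[:1]])
print(probas.shape)  # (1, 15, 100): one distribution over the vocabulary per target step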
sampler = tfa.seq2seq.sampler.TrainingSampler()
TrainingSampler: at each time step, it tells the decoder what the output of the previous time step was (the encoder's hidden state or a target token).
1. At inference time: the embedding of the token that was actually output at the previous step.
2. During training: the embedding of the previous target token.
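At inference the target tokens are unknown, so a different sampler re-embeds whatever the decoder just produced. A minimal sketch, assuming tfa's GreedyEmbeddingSampler and hypothetical start/end token ids (the call pattern follows the tfa.seq2seq API, which takes the embedding matrix plus start_tokens and end_token):

inference_sampler = tfa.seq2seq.sampler.GreedyEmbeddingSampler()
inference_decoder = tfa.seq2seq.basic_decoder.BasicDecoder(
    decoder_cell, inference_sampler, output_layer=output_layer,
    maximum_iterations=15)  # cap the generated length
sos_id, eos_id = 0, 1  # hypothetical start/end token ids
start_tokens = tf.fill([1000], sos_id)
outputs, _, _ = inference_decoder(
    embeddings.variables[0],  # the embedding matrix, applied to each sampled id
    start_tokens=start_tokens, end_token=eos_id,
    initial_state=encoder_state)
predicted_ids = outputs.sample_id  # greedily chosen token ids per step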
For reference, see below (Link: https://stackoverflow.com/questions/48783798/whats-the-difference-between-data-time-major-and-batch-major)
Time major vs Batch major
When it comes to RNNs, the tensors usually go up to rank 3+, but the idea stays the same. If the input is (batch_size, sequence_num, features), it's called batch major, because axis 0 is the batch size. If the input is (sequence_num, batch_size, features), it's likewise called time major. The features dimension is always last (at least in every real case I know of), so there's no further variety in naming.
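The two layouts in TensorFlow terms, with illustrative shapes; swapping between them is a transpose of axes 0 and 1:

batch_major = tf.random.normal([32, 10, 8])             # (batch_size, sequence_num, features)
time_major = tf.transpose(batch_major, perm=[1, 0, 2])  # (sequence_num, batch_size, features)
print(batch_major.shape, time_major.shape)              # (32, 10, 8) (10, 32, 8)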
tfa.seq2seq.basic_decoder.BasicDecoder: bundles the decoder cell, the sampler, and the output layer, and runs the decoding loop one time step at a time, feeding each step's sampled input to the cell.
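For orientation, a sketch of what the decoder call above returns (final_outputs is a BasicDecoderOutput namedtuple):

print(final_outputs.rnn_output.shape)  # (batch, steps, vocab_size): logits from output_layer
print(final_outputs.sample_id.shape)   # (batch, steps): token ids chosen by the sampler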