1. 작성 계기
•
나는 주력 프레임워크로 TF2를 사용하고 회사에서 주로 3D Pose를 다룬다. 내 Task는 Multi-Input(Ex: Camera Parameter, BBox)이 들어가는 경우도 많고, Loss도 여러 개 가중 합하고 Flag로 껐다 켜가며 사용한다. 때문에 Model의 기본 Method 대신 Trainer Class를 만들어 사용했다.
•
이 Trainer로 Model의 Step부터 Loss, Optimizer, Logger 등 학습에 관련된 부분들을 관리했는데 코드가 길어지는 건 물론이고, 모델이 바뀔 때마다 Subclass 해서 다시 만들어 줘야 했고, Distributed Startegy 같은 프레임워크에 종속된 기능을 사용할 땐 너무 손볼곳이 많았다.
•
코드 리팩토링을 하면서 Model 기본 Method만 이용해서 Train Loop를 구현했다. 학습에 문제는 없었으나 Logging이나 Multi Input을 다루는 데에 있어서 제약이 너무 많았다. 그러던 중 keras.Model 의 Method를 Overriding 해서 단계별로 커스터마이징 하는 공식 문서를 찾았고, 도움이 많이 되어 내용을 정리한다.
2. 기본 구조
tf.keras.Model
일반적으로 말하는 '모델'의 연산이 정의 된 Class이다.
1.
Sequential Class를 이용해, 순차적으로 연산을 정의해 만들 수 있다.
2.
Input Output을 연결함으로써, 그 사이의 연산을 정의해 만들 수 있다.(Functional API)
3.
Model Class를 Subclass해서, call method로 연산을 정의해 만들 수 있다.
Methods : Compile(), Fit()
1.
Optimizer, Loss Function이 Compile()을 통해 정의된다.
2.
데이터 입력 -> 모델 동작 -> Loss 계산 -> Optimize -> Callback이 Fit()을 통해 이루어진다.
다음과 같은 구조가 가장 Keras의 추상적인(High-Level) 학습 방법이다
1.
Model의 연산을 정의하고(Model)
2.
Loss와 Optimizer를 정의하고(Compile)
3.
데이터를 흘려 학습시킨다.(Fit)
3. Model.fit() 커스터마이징
위와 같이 간단한 예시들은 Keras의 장점을 잘 보여준다.
•
복잡한 동작은 알 필요 없이 Method 2개만으로 학습 과정을 수행할 수 있다. 이런 방식으로 원하는 모든 모델을 구현할 수 있다면 좋겠지만, 조금이라도 Model이나 Train Loop가 복잡해진다면 결국 커스터마이징이 필요하다.
•
커스터마이징 하는 정도와 방법을 잘 결정하면 High Level의 장점도 가져가면서 유연한 동작을 만들 수 있다. 예를 들어 Train Loop를 적절히 커스터마이징하면, fit() 이 지원하는 기본 Keras Logger나 Callback 같은 유용한 기능들도 사용하면서 Train Loop의 동작을 원하는 방식으로 구현할 수 있다.
Step 1 : 연산 Customize
1.
fit() 은 데이터 입력 -> 모델 동작 -> Loss 계산 -> Optimize -> Callback을 수행한다.
2.
Model.train_step() 이 이 중 (모델 동작 -> Loss 계산 -> Optimize) 부분을 담당한다.
3.
오버 라이딩할 때, 유의해야 할 점은 다음과 같다.
a.
Model.complie() 에서 정의된 Loss와 Metric은
self.compile_loss와 self.compile_metrics에 저장되어 있다.
b.
self.compile_loss를 이용해 Optimize 하고, self.compile_metrics의 상태를 업데이트 하고
c.
self.metrics의 결괏값을 return 해 주어야 한다 (Callbacks로 전달하기 위함)
Code Example
Step 2 : Loss, Metric Customize
•
Loss, Metric, Optimizer도 compile() 대신 train_step() 내부에서 지정해 줄 수 있다.
(다만 이 친구들은 Subclass 후 init할때 Attribute로 넣어두는게 편하다)
•
이때 주의해야 할 점은
1.
compile()을 통해 정의된 Loss는 자동으로 self.metrics에 들어가지만 이 외의 방법으로 정의했을 경우 tf.keras.metrics 객체를 이용해 추적을 해주어야 Callback에 전달할 수 있다.
2.
사용할 metrics들을 self.metrics로 set 해 주어야, 에폭마다 기록이 reset 된다.
3.
train_step()에서 Loss와 Metric을 지정했다면, test_step()도 동일하게 지정해주어야 한다.
Code Example
Step 3 : 실전 예제 GAN
•
Model Class의 오버 라이딩과 기본 Attribute를 이용해 fit()으로 학습시키는 예시이다.
•
눈여겨볼 점은 다음과 같다.
◦
모델 이외의 다양한 Trick을 train_step()에서 구현한 점
◦
Optimizer를 2개 사용하기 위해 compile()을 오버 라이딩한 점
◦
복잡한 커스터마이즈를 + Keras Model Class의 기본 구조인 compile(), fit()을 활용한 점
from tensorflow.keras import layers
# Create the discriminator
discriminator = keras.Sequential(
[
keras.Input(shape=(28, 28, 1)),
layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.GlobalMaxPooling2D(),
layers.Dense(1),
],
name="discriminator",
)
# Create the generator
latent_dim = 128
generator = keras.Sequential(
[
keras.Input(shape=(latent_dim,)),
# We want to generate 128 coefficients to reshape into a 7x7x128 map
layers.Dense(7 * 7 * 128),
layers.LeakyReLU(alpha=0.2),
layers.Reshape((7, 7, 128)),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
layers.LeakyReLU(alpha=0.2),
layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
],
name="generator",
)
class GAN(keras.Model):
def __init__(self, discriminator, generator, latent_dim):
super(GAN, self).__init__()
self.discriminator = discriminator
self.generator = generator
self.latent_dim = latent_dim
def compile(self, d_optimizer, g_optimizer, loss_fn):
super(GAN, self).compile()
self.d_optimizer = d_optimizer
self.g_optimizer = g_optimizer
self.loss_fn = loss_fn
def train_step(self, real_images):
if isinstance(real_images, tuple):
real_images = real_images[0]
# Sample random points in the latent space
batch_size = tf.shape(real_images)[0]
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Decode them to fake images
generated_images = self.generator(random_latent_vectors)
# Combine them with real images
combined_images = tf.concat([generated_images, real_images], axis=0)
# Assemble labels discriminating real from fake images
labels = tf.concat(
[tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0
)
# Add random noise to the labels - important trick!
labels += 0.05 * tf.random.uniform(tf.shape(labels))
# Train the discriminator
with tf.GradientTape() as tape:
predictions = self.discriminator(combined_images)
d_loss = self.loss_fn(labels, predictions)
grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
self.d_optimizer.apply_gradients(
zip(grads, self.discriminator.trainable_weights)
)
# Sample random points in the latent space
random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
# Assemble labels that say "all real images"
misleading_labels = tf.zeros((batch_size, 1))
# Train the generator (note that we should *not* update the weights
# of the discriminator)!
with tf.GradientTape() as tape:
predictions = self.discriminator(self.generator(random_latent_vectors))
g_loss = self.loss_fn(misleading_labels, predictions)
grads = tape.gradient(g_loss, self.generator.trainable_weights)
self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
return {"d_loss": d_loss, "g_loss": g_loss}
# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(
d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)
Python
복사