Home

정보 : Tensorflow. Fit Customize

1. 작성 계기

나는 주력 프레임워크로 TF2를 사용하고 회사에서 주로 3D Pose를 다룬다. 내 Task는 Multi-Input(Ex: Camera Parameter, BBox)이 들어가는 경우도 많고, Loss도 여러 개 가중 합하고 Flag로 껐다 켜가며 사용한다. 때문에 Model의 기본 Method 대신 Trainer Class를 만들어 사용했다.
이 Trainer로 Model의 Step부터 Loss, Optimizer, Logger 등 학습에 관련된 부분들을 관리했는데 코드가 길어지는 건 물론이고, 모델이 바뀔 때마다 Subclass 해서 다시 만들어 줘야 했고, Distributed Startegy 같은 프레임워크에 종속된 기능을 사용할 땐 너무 손볼곳이 많았다.
코드 리팩토링을 하면서 Model 기본 Method만 이용해서 Train Loop를 구현했다. 학습에 문제는 없었으나 Logging이나 Multi Input을 다루는 데에 있어서 제약이 너무 많았다. 그러던 중 keras.Model 의 Method를 Overriding 해서 단계별로 커스터마이징 하는 공식 문서를 찾았고, 도움이 많이 되어 내용을 정리한다.

2. 기본 구조

tf.keras.Model

일반적으로 말하는 '모델'의 연산이 정의 된 Class이다.
1.
Sequential Class를 이용해, 순차적으로 연산을 정의해 만들 수 있다.
2.
Input Output을 연결함으로써, 그 사이의 연산을 정의해 만들 수 있다.(Functional API)
3.
Model Class를 Subclass해서, call method로 연산을 정의해 만들 수 있다.
Methods : Compile(), Fit()
1.
Optimizer, Loss Function이 Compile()을 통해 정의된다.
2.
데이터 입력 -> 모델 동작 -> Loss 계산 -> Optimize -> Callback이 Fit()을 통해 이루어진다.
다음과 같은 구조가 가장 Keras의 추상적인(High-Level) 학습 방법이다
1.
Model의 연산을 정의하고(Model)
2.
Loss와 Optimizer를 정의하고(Compile)
3.
데이터를 흘려 학습시킨다.(Fit)

3. Model.fit() 커스터마이징

위와 같이 간단한 예시들은 Keras의 장점을 잘 보여준다.
복잡한 동작은 알 필요 없이 Method 2개만으로 학습 과정을 수행할 수 있다. 이런 방식으로 원하는 모든 모델을 구현할 수 있다면 좋겠지만, 조금이라도 Model이나 Train Loop가 복잡해진다면 결국 커스터마이징이 필요하다. 
커스터마이징 하는 정도와 방법을 잘 결정하면 High Level의 장점도 가져가면서 유연한 동작을 만들 수 있다. 예를 들어 Train Loop를 적절히 커스터마이징하면, fit() 이 지원하는 기본 Keras Logger나 Callback 같은 유용한 기능들도 사용하면서 Train Loop의 동작을 원하는 방식으로 구현할 수 있다.
Step 1 : 연산 Customize
1.
fit() 은 데이터 입력 -> 모델 동작 -> Loss 계산 -> Optimize -> Callback을 수행한다.
2.
Model.train_step() 이 이 중 (모델 동작 -> Loss 계산 -> Optimize) 부분을 담당한다.
3.
오버 라이딩할 때, 유의해야 할 점은 다음과 같다.
a.
Model.complie() 에서 정의된 Loss와 Metric은  self.compile_loss와 self.compile_metrics 저장되어 있다.
b.
self.compile_loss를 이용해 Optimize 하고, self.compile_metrics의 상태를 업데이트 하고
c.
self.metrics의 결괏값을 return 해 주어야 한다 (Callbacks로 전달하기 위함)
Code Example
Step 2 : Loss, Metric Customize
Loss, Metric, Optimizer도 compile() 대신 train_step() 내부에서 지정해 줄 수 있다. (다만 이 친구들은 Subclass 후 init할때 Attribute로 넣어두는게 편하다)
이때 주의해야 할 점은
1.
compile()을 통해 정의된 Loss는 자동으로 self.metrics에 들어가지만 이 외의 방법으로 정의했을 경우 tf.keras.metrics 객체를 이용해 추적을 해주어야 Callback에 전달할 수 있다.
2.
사용할 metrics들을 self.metrics로 set 해 주어야, 에폭마다 기록이 reset 된다.
3.
train_step()에서 Loss와 Metric을 지정했다면, test_step()도 동일하게 지정해주어야 한다.
Code Example
Step 3 : 실전 예제 GAN
Model Class의 오버 라이딩과 기본 Attribute를 이용해 fit()으로 학습시키는 예시이다.
눈여겨볼 점은 다음과 같다.
모델 이외의 다양한 Trick을 train_step()에서 구현한 점
Optimizer를 2개 사용하기 위해 compile()을 오버 라이딩한 점
복잡한 커스터마이즈를 + Keras Model Class의 기본 구조인 compile(), fit()을 활용한 점
from tensorflow.keras import layers # Create the discriminator discriminator = keras.Sequential( [ keras.Input(shape=(28, 28, 1)), layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.GlobalMaxPooling2D(), layers.Dense(1), ], name="discriminator", ) # Create the generator latent_dim = 128 generator = keras.Sequential( [ keras.Input(shape=(latent_dim,)), # We want to generate 128 coefficients to reshape into a 7x7x128 map layers.Dense(7 * 7 * 128), layers.LeakyReLU(alpha=0.2), layers.Reshape((7, 7, 128)), layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"), layers.LeakyReLU(alpha=0.2), layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"), ], name="generator", ) class GAN(keras.Model): def __init__(self, discriminator, generator, latent_dim): super(GAN, self).__init__() self.discriminator = discriminator self.generator = generator self.latent_dim = latent_dim def compile(self, d_optimizer, g_optimizer, loss_fn): super(GAN, self).compile() self.d_optimizer = d_optimizer self.g_optimizer = g_optimizer self.loss_fn = loss_fn def train_step(self, real_images): if isinstance(real_images, tuple): real_images = real_images[0] # Sample random points in the latent space batch_size = tf.shape(real_images)[0] random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) # Decode them to fake images generated_images = self.generator(random_latent_vectors) # Combine them with real images combined_images = tf.concat([generated_images, real_images], axis=0) # Assemble labels discriminating real from fake images labels = tf.concat( [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0 ) # Add random noise to the labels - important trick! labels += 0.05 * tf.random.uniform(tf.shape(labels)) # Train the discriminator with tf.GradientTape() as tape: predictions = self.discriminator(combined_images) d_loss = self.loss_fn(labels, predictions) grads = tape.gradient(d_loss, self.discriminator.trainable_weights) self.d_optimizer.apply_gradients( zip(grads, self.discriminator.trainable_weights) ) # Sample random points in the latent space random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim)) # Assemble labels that say "all real images" misleading_labels = tf.zeros((batch_size, 1)) # Train the generator (note that we should *not* update the weights # of the discriminator)! with tf.GradientTape() as tape: predictions = self.discriminator(self.generator(random_latent_vectors)) g_loss = self.loss_fn(misleading_labels, predictions) grads = tape.gradient(g_loss, self.generator.trainable_weights) self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights)) return {"d_loss": d_loss, "g_loss": g_loss} # Prepare the dataset. We use both the training & test MNIST digits. batch_size = 64 (x_train, _), (x_test, _) = keras.datasets.mnist.load_data() all_digits = np.concatenate([x_train, x_test]) all_digits = all_digits.astype("float32") / 255.0 all_digits = np.reshape(all_digits, (-1, 28, 28, 1)) dataset = tf.data.Dataset.from_tensor_slices(all_digits) dataset = dataset.shuffle(buffer_size=1024).batch(batch_size) gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim) gan.compile( d_optimizer=keras.optimizers.Adam(learning_rate=0.0003), g_optimizer=keras.optimizers.Adam(learning_rate=0.0003), loss_fn=keras.losses.BinaryCrossentropy(from_logits=True), ) # To limit the execution time, we only train on 100 batches. You can train on # the entire dataset. You will need about 20 epochs to get nice results. gan.fit(dataset.take(100), epochs=1)
Python
복사