from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train.parallel_utils import ParallelMode
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
                                       _get_parallel_mode)
import mindspore as ms
import mindspore.context as context
import mindspore.nn.layer as ml
import mindspore.train.callback as callback
import mindspore.nn.optim as moptim
from mindspore.nn import Cell
import mindspore.ops.functional as F
import mindspore.ops.operations as P
import mindspore.ops.composite as C
import mindspore.dataset as ds
import mindspore.dataset.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import urllib.request
from urllib.parse import urlparse
import gzip
import time


def unzipfile(gzip_path):
    """Unzip a gzip-compressed dataset file.

    Args:
        gzip_path: Path to the .gz file; the decompressed copy is written
            next to it with the .gz suffix stripped.
    """
    with gzip.GzipFile(gzip_path) as gz_file:
        with open(gzip_path.replace('.gz', ''), 'wb') as out_file:
            out_file.write(gz_file.read())


def download_dataset():
    """Download the MNIST dataset from http://yann.lecun.com/exdb/mnist/."""
    train_path = "./MNIST_Data/train/"
    test_path = "./MNIST_Data/test/"
    if os.path.exists(train_path) or os.path.exists(test_path):
        return
    os.makedirs(train_path)
    os.makedirs(test_path)
    print("******Downloading the MNIST dataset******")
    train_url = {"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
                 "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz"}
    test_url = {"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
                "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"}
    for target_path, urls in ((train_path, train_url), (test_path, test_url)):
        for url in urls:
            file_name = os.path.join(target_path, urlparse(url).path.split('/')[-1])
            if not os.path.exists(file_name.replace('.gz', '')):
                urllib.request.urlretrieve(url, file_name)
                unzipfile(file_name)
                os.remove(file_name)


def create_dataset(data_path, noise_dim, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create the training dataset.

    Args:
        data_path: Path to the MNIST data directory.
        noise_dim: Length of the latent noise vector paired with each image.
        batch_size: Number of data records in each batch.
        repeat_size: Number of times the data is repeated.
        num_parallel_workers: Number of parallel workers for each map op.
    """
    mnist_ds = ds.MnistDataset(data_path)
    hwc2chw_op = transforms.vision.c_transforms.HWC2CHW()
    # Scale pixels to [-1, 1] (matching the generator's Tanh output range),
    # convert HWC -> CHW, and pair every image with a latent noise vector.
    mnist_ds = (mnist_ds
                .map(operations=lambda x: ((x - 127.5) / 127.5).astype('float32'),
                     input_columns="image",
                     num_parallel_workers=num_parallel_workers)
                .map(operations=hwc2chw_op,
                     input_columns="image",
                     num_parallel_workers=num_parallel_workers)
                .map(operations=lambda x: (x, np.random.randn(noise_dim).astype('float32')),
                     input_columns="image",
                     output_columns=["image", "noise"],
                     columns_order=["image", "noise"],
                     num_parallel_workers=num_parallel_workers))
    buffer_size = 60000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds
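

# Hedged usage sketch (not part of the original script): a quick way to check
# the column layout the pipeline produces. Assumes the dataset has already
# been downloaded; each batch should yield an "image" tensor of shape
# (batch, 1, 28, 28) in [-1, 1] plus a (noise_dim,) "noise" vector per sample.
#
#   check_ds = create_dataset("./MNIST_Data/train", noise_dim=100, batch_size=4)
#   for batch in check_ds.create_dict_iterator():
#       print(batch["image"].shape, batch["noise"].shape)
#       break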


class Reshape(Cell):
    """Wraps F.reshape so it can be used inside a SequentialCell."""

    def __init__(self, shape: tuple) -> None:
        super().__init__()
        self.shape = shape

    def construct(self, x):
        return F.reshape(x, self.shape)


def make_generator_model(noise_dim):
    """DCGAN generator: project the noise to 7x7x256, then upsample to 28x28."""
    model = ml.SequentialCell(
        ml.Dense(noise_dim, 7 * 7 * 256, has_bias=False),
        Reshape((-1, 256, 7, 7)),
        ml.BatchNorm2d(256),
        ml.LeakyReLU(),
        ml.Conv2dTranspose(256, 128, (5, 5), stride=(1, 1), pad_mode='same', has_bias=False),
        ml.BatchNorm2d(128),
        ml.LeakyReLU(),
        ml.Conv2dTranspose(128, 64, (5, 5), stride=(2, 2), pad_mode='same', has_bias=False),
        ml.BatchNorm2d(64),
        ml.LeakyReLU(),
        ml.Conv2dTranspose(64, 1, (5, 5), stride=(2, 2), pad_mode='same', has_bias=False),
        ml.Tanh()
    )
    return model
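

# Hedged sanity check (assumption, not from the original): the generator maps
# a (N, noise_dim) batch to (N, 1, 28, 28) images in [-1, 1] via the final Tanh.
#
#   gen = make_generator_model(100)
#   z = ms.Tensor(np.random.randn(2, 100), ms.float32)
#   print(gen(z).shape)  # expected: (2, 1, 28, 28)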


def make_discriminator_model():
    """DCGAN discriminator: two strided convolutions, then a sigmoid score."""
    model = ml.SequentialCell(
        ml.Conv2d(1, 64, (5, 5), stride=(2, 2), pad_mode='same'),
        ml.LeakyReLU(),
        ml.Dropout(0.3),
        ml.Conv2d(64, 128, (5, 5), stride=(2, 2), pad_mode='same'),
        ml.LeakyReLU(),
        ml.Dropout(0.3),
        ml.Flatten(),
        ml.Dense(128 * 7 * 7, 1),
        ml.Sigmoid()
    )
    return model
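

# Hedged sanity check (assumption): the discriminator maps a (N, 1, 28, 28)
# batch to (N, 1) realness scores in (0, 1) via the final Sigmoid.
#
#   disc = make_discriminator_model()
#   fake = ms.Tensor(np.random.randn(2, 1, 28, 28), ms.float32)
#   print(disc(fake).shape)  # expected: (2, 1)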


class GANBaseNet(Cell):
    """Bundles the generator and discriminator into a single forward pass."""

    def __init__(self, noise_dim) -> None:
        super().__init__(auto_prefix=True)
        self.generator = make_generator_model(noise_dim)
        self.discriminator = make_discriminator_model()

    def construct(self, images, noise):
        generated_images = self.generator(noise)
        real_output = self.discriminator(images)
        fake_output = self.discriminator(generated_images)
        return real_output, fake_output


class GANWithLoss(Cell):
    """Computes the generator and discriminator losses from one forward pass."""

    def __init__(self, base_net: GANBaseNet) -> None:
        super().__init__(auto_prefix=True)
        self.base_net = base_net
        self.cross_entropy = P.BinaryCrossEntropy()

    def discriminator_loss(self, real_output, fake_output, weight):
        # Real images should score 1, generated images 0.
        real_loss = self.cross_entropy(real_output, F.ones_like(real_output), weight)
        fake_loss = self.cross_entropy(fake_output, F.zeros_like(fake_output), weight)
        total_loss = real_loss + fake_loss
        return total_loss

    def generator_loss(self, fake_output, weight):
        # The generator wants its output to be scored as real (1).
        return self.cross_entropy(fake_output, F.ones_like(fake_output), weight)

    def construct(self, images, noise):
        real_output, fake_output = self.base_net(images, noise)
        weight = F.ones_like(real_output)
        gen_loss = self.generator_loss(fake_output, weight)
        disc_loss = self.discriminator_loss(real_output, fake_output, weight)
        return gen_loss, disc_loss
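

# For reference, the binary cross-entropy used above is
# BCE(p, y) = -mean(y * log(p) + (1 - y) * log(1 - p)).
# A hedged NumPy spot-check (assumption: the default 'mean' reduction):
#
#   p = np.array([0.9, 0.1], dtype=np.float32)  # discriminator scores
#   y = np.array([1.0, 0.0], dtype=np.float32)  # real/fake targets
#   print(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))
#   # ~0.105: confident, correct scores give a small loss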


class IthOutputCell(Cell):
    """Exposes a single output of the wrapped network so that each optimizer
    gets its own, explicitly specified backward graph."""

    def __init__(self, network, output_index):
        super(IthOutputCell, self).__init__()
        self.network = network
        self.output_index = output_index

    def construct(self, image, noise):
        predict = self.network(image, noise)[self.output_index]
        return predict


class TrainStepWrap(Cell):
    """One training step: updates the generator and the discriminator with
    their own optimizers from a single forward pass."""

    def __init__(self, network: GANWithLoss, g_optimizer: moptim.Optimizer,
                 d_optimizer: moptim.Optimizer, sens=1.0):
        super(TrainStepWrap, self).__init__(auto_prefix=True)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.g_weights = g_optimizer.parameters
        self.g_optimizer = g_optimizer
        self.d_weights = d_optimizer.parameters
        self.d_optimizer = d_optimizer
        self.g_grad = C.GradOperation('g_grad', get_by_list=True, sens_param=True)
        self.d_grad = C.GradOperation('d_grad', get_by_list=True, sens_param=True)

        # One single-output cell per loss keeps the two backward graphs separate.
        self.g_loss_net = IthOutputCell(network, output_index=0)
        self.d_loss_net = IthOutputCell(network, output_index=1)

        self.sens = sens
        self.reducer_flag = False
        self.g_grad_reducer = None
        self.d_grad_reducer = None
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_mirror_mean()
            degree = _get_device_num()
            self.g_grad_reducer = DistributedGradReducer(g_optimizer.parameters, mean, degree)
            self.d_grad_reducer = DistributedGradReducer(d_optimizer.parameters, mean, degree)

    def update_model(self, image, noise, loss, loss_net, grad, optimizer,
                     weights, grad_reducer):
        # Seed the backward pass with sens, take gradients of this loss only,
        # and apply the optimizer; F.depend ties the update to the loss value.
        sens = F.fill(F.dtype(loss), F.shape(loss), self.sens)
        grads = grad(loss_net, weights)(image, noise, sens)
        if self.reducer_flag:
            grads = grad_reducer(grads)
        return F.depend(loss, optimizer(grads))

    def construct(self, image, noise):
        g_loss, d_loss = self.network(image, noise)
        g_out = self.update_model(image, noise, g_loss, self.g_loss_net,
                                  self.g_grad, self.g_optimizer,
                                  self.g_weights, self.g_grad_reducer)
        d_out = self.update_model(image, noise, d_loss, self.d_loss_net,
                                  self.d_grad, self.d_optimizer,
                                  self.d_weights, self.d_grad_reducer)
        return g_out, d_out
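

# Hedged usage sketch (the real wiring happens in __main__ below): one call
# to the wrapper runs both parameter updates and returns both losses.
#
#   base = GANBaseNet(100)
#   step = TrainStepWrap(GANWithLoss(base),
#                        moptim.Adam(base.generator.trainable_params(), 1e-4),
#                        moptim.Adam(base.discriminator.trainable_params(), 1e-4))
#   g_loss, d_loss = step(image_batch, noise_batch)  # both ms.Tensor batches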


class GANLossMonitor(callback.LossMonitor):
    """LossMonitor variant that reports both GAN losses."""

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        g_loss, d_loss = cb_params.net_outputs
        g_loss = np.mean(g_loss.asnumpy())
        d_loss = np.mean(d_loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        # np.mean returns a NumPy scalar, not a Python float, so check for
        # NaN/Inf directly instead of gating on isinstance(g_loss, float).
        if np.isnan(g_loss) or np.isinf(g_loss):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("epoch: %s step: %s, g_loss %s d_loss %s" %
                  (cb_params.cur_epoch_num, cur_step_in_epoch, g_loss, d_loss),
                  flush=True)
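

# Note: GANLossMonitor is defined but not added to the callbacks list in
# __main__ below. A hedged usage sketch (assumption: per_print_times is
# LossMonitor's constructor argument):
#
#   model.train(EPOCHS, ds_train,
#               callbacks=[GANLossMonitor(per_print_times=100), Timer(),
#                          GANImageSave(net.generator, NOISE_DIM)],
#               dataset_sink_mode=sink_mode)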


class GANImageSave(callback.Callback):
    """Saves a 4x4 grid of generated samples after every epoch, always from
    the same fixed noise so progress across epochs is comparable."""

    def __init__(self, generator: Cell, noise_dim) -> None:
        super().__init__()
        self.generator = generator
        self.seed = ms.Tensor(np.random.randn(16, noise_dim), ms.float32)
        if not os.path.exists('./log'):
            os.mkdir('./log')

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        predictions = self.generator(self.seed).asnumpy()

        fig = plt.figure(figsize=(4, 4))
        for i in range(predictions.shape[0]):
            plt.subplot(4, 4, i + 1)
            # Map the Tanh output range [-1, 1] back to [0, 255] grayscale.
            plt.imshow(predictions[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
            plt.axis('off')

        plt.savefig('./log/image_at_epoch_{:04d}.png'.format(cb_params.cur_epoch_num))
        plt.close(fig)  # release the figure; one is created per epoch


class Timer(callback.Callback):
    """Prints the wall-clock duration of every epoch."""

    def epoch_begin(self, run_context):
        self.start = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        print('Time for epoch {} is {} sec'.format(
            cb_params.cur_epoch_num, time.time() - self.start))


if __name__ == "__main__":
    EPOCHS = 50
    NOISE_DIM = 100
    BATCH_SIZE = 256
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    sink_mode = True
""" set dataset ~ """ download_dataset() mnist_path = "./MNIST_Data" ds_train = create_dataset(os.path.join(mnist_path, "train"), NOISE_DIM, BATCH_SIZE, 1)
""" define model ~ """ net = GANBaseNet(NOISE_DIM) net_loss = GANWithLoss(net) generator_optimizer = moptim.Adam(net.generator.trainable_params(), 1e-4) discriminator_optimizer = moptim.Adam(net.discriminator.trainable_params(), 1e-4) net_train_step = TrainStepWrap(net_loss, generator_optimizer, discriminator_optimizer) model = ms.train.Model(net_train_step, amp_level='O2')
""" trianing ~ """ model.train(EPOCHS, ds_train, callbacks=[Timer(), GANImageSave(net.generator, NOISE_DIM)], dataset_sink_mode=sink_mode) """ make gif ~ """ anim_file = 'dcgan.gif' import imageio import glob with imageio.get_writer(anim_file, mode='I') as writer: filenames = glob.glob('./log/image*.png') filenames = sorted(filenames) last = -1 for i, filename in enumerate(filenames): frame = 2 * (i**0.5) if round(frame) > round(last): last = frame else: continue image = imageio.imread(filename) writer.append_data(image) image = imageio.imread(filename) writer.append_data(image) """ 14.3 sec/epoch , GPU mem 1544Mb """