PyTorch样式转移的优化过程(图解)

本文概述

  • 优化的迭代过程
  • 绘制内容, 样式和最终目标图像
  • 完整的代码
【PyTorch样式转移的优化过程(图解)】我们拥有所有三个图像, 现在, 我们可以执行优化过程。要执行优化过程, 我们必须执行以下步骤:
步骤1:
第一步, 我们定义一些基本参数, 这些参数可以帮助我们直观地了解培训过程, 并有助于我们简化培训过程。第一个参数每次都向我们显示我们的目标图像, 以便我们可以检查优化过程。我们用目标图像定义Adam优化器, 并设置目标学习率。最后但并非最不重要的一点是, 我们定义了培训过程应采取的优化步骤的数量。
我们需要在结果和时间效率之间取得平衡, 因为培训过程可能需要很长时间才能完成。因此, 我们将定义步骤, 在本例中, 我们将步骤限制为2100。
show_every=300optimizer=optim.Adam([target], lr=0.003)steps=2100

第2步:
现在, 我们实现了一些代码行用于数据可视化。我们定义了一个图像阵列, 它将在整个训练过程中存储目标图像。训练过程结束后, 我们可以从这些图像中创建一个视频, 以直观了解样式和内容图像如何组合以优化目标图像。我们将解开目标图像的形状。
height, width, channels=im_convert(target).shapeimage_array=np.empty(shape=(300, height, width, channels))

我们将定义一个捕获帧, 这有助于我们每次捕获一个帧。最后, 我们将定义一个计数器变量, 该变量将跟踪数组索引。
capture_frame=steps/300counter=0

优化的迭代过程
#Defining a loop statement from 1 to steps+1for ii in range(1, steps+1): #To ensure that our loop runs for the defined number of steps # Extracting feature for our current target image target_features=get_features(target, vgg)#Calculating the content loss for the iterationcontent_loss=torch.mean((target_features['conv4_2']content_features['conv4_2'])**2)#Initializing style loss style_loss=0 #The style loss is the result of a combine loss from five different layer within our model. #For this reason we iterate through the five style features to get the error at each layer. for layer in style_weights:#Collecting the target feature for the specific layer from the target feature variable target_feature=target_features[layer]#Applying gram matrix function to our target featuretarget_gram=gram_matrix(target_feature)#Getting style_gram value for our style image from the style grams variablestyle_gram=style_grams[layer]#Calculating the layer style loss as content losslayer_style_loss=style_weights[layer]*torch.mean((target_gram-style_gram)**2)#Obtaining feature dimensions_, d, h, w=target_feature.shape #Calculating total style lossstyle_loss += layer_style_loss/(d*h*w)#Calculating total losstotal_loss=content_weight*content_loss+style_weight*style_loss#Using the optimizer to update parameters within our target image optimizer.zero_grad()total_loss.backward()optimizer.step()#Processfor visualization throughout the training process#Comparing the iteration variable with our show everyif ii % show_every==0: #Printing total lossprint('Total loss:', total_loss.item())#Printing the iteration print('Iteration', ii)#Printting the target images plt.imshow(im_convert(target))#Removing the axis on the image plt.axis('off')# Showing image plt.show()#Comparing the iteration variable with our capture frame variable if ii%capture_frame==0: # Capturing a frame at every 700 iteration#Storing the target image into the image_arrayimage_array[counter]=im_convert(target)# Increment in the counter variable counter=counter+1

当我们运行代码时, 它将为我们提供预期的输出:
PyTorch样式转移的优化过程(图解)

文章图片
PyTorch样式转移的优化过程(图解)

文章图片
PyTorch样式转移的优化过程(图解)

文章图片
绘制内容, 样式和最终目标图像
#Making a grid arrangement with a single row and three columns for our three imagesfig, (ax1, ax2, ax3)=plt.subplots(1, 3, figsize=(20, 10))#Plotting content image ax1.imshow(im_convert(content))ax1.axis('off')#Plotting style imageax2.imshow(im_convert(style))ax2.axis('off')#Plotting target imageax3.imshow(im_convert(target))ax3.axis('off')

PyTorch样式转移的优化过程(图解)

文章图片
完整的代码
#Required Librariesimport torchimport torch.optim as optimfrom torchvision import transforms, modelsfrom PIL import Imageimport matplotlib.pyplot as pltimport numpy as np#Creating Modelvgg=models.vgg19(pretrained=True).featuresfor param in vgg.parameters():param.requires_grad_(False)#Add model to devicedevice=torch.device("cuda" if torch.cuda.is_available() else "cpu")vgg.to(device)#Load Iamgedef load_image(img_path, max_size=400, shape=None):image=Image.open(img_path).convert('RGB')if max(image.size)> max_size:size=max_sizeelse:size=max(image.size)if shape is not None:size=shapein_transform=transforms.Compose([transforms.Resize(size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])image=in_transform(image).unsqueeze(0)return imagecontent=load_image('ab.jpg').to(device)style=load_image('abc.jpg', shape=content.shape[-2:]).to(device)#Image Conversiondef im_convert(tensor):image=tensor.cpu().clone().detach().numpy()image=image.squeeze()image=image.transpose(1, 2, 0)image=image*np.array((0.5, 0.5, 0.5))+np.array((0.5, 0.5, 0.5))image=image.clip(0, 1)return image #Plotting Imagesfig, (ax1, ax2)=plt.subplots(1, 2, figsize=(20, 10))ax1.imshow(im_convert(content))ax1.axis('off')ax2.imshow(im_convert(style))ax2.axis('off')#Getting Featuresdef get_features(image, model):layers={'0':'conv1_1', '5':'conv2_1', '10':'conv3_1', '19':'conv4_1', '21':'conv4_2', '28':'conv5_1', }features={}for name, layer in model._modules.items():image=layer(image)if name in layers:features[layers[name]]=imagereturn features#Making content and style featurescontent_features=get_features(content, vgg)style_features=get_features(style, vgg)#Creating gram matrixdef gram_matrix(tensor):_, d, h, w=tensor.size()tensor=tensor.view(d, h*w)gram=torch.mm(tensor, tensor.t())return gram#Creating style gramsstyle_grams={layer:gram_matrix(style_features[layer]) for layer in style_features}#Initializing style weightsstyle_weights={'conv1_1':1., 'conv2_1':0.75, 'conv3_1':0.2, 'conv4_1':0.2, 'conv5_1':0.2}content_weight=1style_weight=1e6target=content.clone().requires_grad_(True).to(device)#Performing optimizationshow_every=300optimizer=optim.Adam([target], lr=0.003)steps=2100height, width, channels=im_convert(target).shapeimage_array=np.empty(shape=(300, height, width, channels))capture_frame=steps/300counter=0for ii in range(1, steps+1):target_features=get_features(target, vgg)content_loss=torch.mean((target_features['conv4_2']-content_features['conv4_2'])**2)style_loss=0for layer in style_weights:target_feature=target_features[layer]target_gram=gram_matrix(target_feature)style_gram=style_grams[layer]layer_style_loss=style_weights[layer]*torch.mean((target_gram-style_gram)**2)_, d, h, w=target_feature.shapestyle_loss += layer_style_loss/(d*h*w)total_loss=content_weight*content_loss+style_weight*style_lossoptimizer.zero_grad()total_loss.backward()optimizer.step()#Plotting output imagesif ii % show_every==0:print('Total loss:', total_loss.item())print('Iteration', ii)plt.imshow(im_convert(target))plt.axis('off')plt.show()if ii%capture_frame==0:image_array[counter]=im_convert(target)counter=counter+1#Plotting content, style and target imagesfig, (ax1, ax2, ax3)=plt.subplots(1, 3, figsize=(20, 10))ax1.imshow(im_convert(content))ax1.axis('off')ax2.imshow(im_convert(style))ax2.axis('off')ax3.imshow(im_convert(target))ax3.axis('off')

输出
PyTorch样式转移的优化过程(图解)

文章图片

    推荐阅读