Python骚操作,让图片人物动起来!

开发 后端
而今天我们就将借助论文所分享的源代码,构建模型创建自己需要的人物运动。

[[352406]]

 引言:近段时间,一个让梦娜丽莎图像动起来的项目火遍了朋友圈。而今天我们就将实现让图片中的人物随着视频人物一起产生动作。

[[352407]]

其中通过在静止图像中动画对象产生视频有无数的应用跨越的领域兴趣,包括电影制作、摄影和电子商务。更准确地说,是图像动画指将提取的视频外观结合起来自动合成视频的任务一种源图像与运动模式派生的视频。

近年来,深度生成模型作为一种有效的图像动画技术出现了视频重定向。特别是,可生成的对抗网络(GANS)和变分自动编码器(VAES)已被用于在视频中人类受试者之间转换面部表情或运动模式。

根据论文FirstOrder Motion Model for Image Animation可知,在姿态迁移的大任务当中,Monkey-Net首先尝试了通过自监督范式预测关键点来表征姿态信息,测试阶段估计驱动视频的姿态关键点完成迁移工作。在此基础上,FOMM使用了相邻关键点的局部仿射变换来模拟物体运动,还额外考虑了遮挡的部分,遮挡的部分可以使用image inpainting生成。

而今天我们就将借助论文所分享的源代码,构建模型创建自己需要的人物运动。具体流程如下。

实验前的准备

首先我们使用的python版本是3.6.5所用到的模块如下:

  •  imageio模块用来控制图像的输入输出等。
  •  Matplotlib模块用来绘图。
  •  numpy模块用来处理矩阵运算。
  •  Pillow库用来加载数据处理。
  •  pytorch模块用来创建模型和模型训练等。
  •  完整模块需求参见requirements.txt文件。

模型的加载和调用

通过定义命令行参数来达到加载模型,图片等目的。

(1)首先是训练模型的读取,包括模型加载方式: 

  1. def load_checkpoints(config_path, checkpoint_path, cpu=False):  
  2.     with open(config_path) as f:  
  3.         config = yaml.load(f)  
  4.     generator = OcclusionAwareGenerator(**config[ model_params ][ generator_params ],  
  5.                                         **config[ model_params ][ common_params ])  
  6.     if not cpu: 
  7.          generator.cuda()  
  8.     kp_detector = KPDetector(**config[ model_params ][ kp_detector_params ], 
  9.                               **config[ model_params ][ common_params ])  
  10.     if not cpu:  
  11.         kp_detector.cuda()  
  12.     if cpu:  
  13.         checkpoint = torch.load(checkpoint_path, map_location=torch.device( cpu ))  
  14.     else:  
  15.         checkpoint = torch.load(checkpoint_path)  
  16.     generator.load_state_dict(checkpoint[ generator ])  
  17.     kp_detector.load_state_dict(checkpoint[ kp_detector ])  
  18.     if not cpu:  
  19.         generator = DataParallelWithCallback(generator)  
  20.         kp_detector = DataParallelWithCallback(kp_detector)  
  21.     generator.eval() 
  22.      kp_detector.eval()  
  23.     return generator, kp_detector  

(2)然后是利用模型创建产生的虚拟图像,找到最佳的脸部特征: 

  1. def make_animation(source_image, driving_video, generator, kp_detector, relative=Trueadapt_movement_scale=Truecpu=False):  
  2.     with torch.no_grad():  
  3.         predictions = []  
  4.         source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)  
  5.         if not cpu:  
  6.             sourcesource = source.cuda()  
  7.         driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)  
  8.         kp_source = kp_detector(source)  
  9.         kp_driving_initial = kp_detector(driving[:, :, 0])  
  10.         for frame_idx in tqdm(range(driving.shape[2])): 
  11.              drivingdriving_frame = driving[:, :, frame_idx]  
  12.             if not cpu:  
  13.                 driving_framedriving_frame = driving_frame.cuda()  
  14.             kp_driving = kp_detector(driving_frame)  
  15.             kp_norm = normalize_kp(kp_sourcekp_source=kp_source, kp_drivingkp_driving=kp_driving,  
  16.                                    kp_driving_initialkp_driving_initial=kp_driving_initial, use_relative_movement=relative,  
  17.                                    use_relative_jacobian=relative, adapt_movement_scaleadapt_movement_scale=adapt_movement_scale)  
  18.             out = generator(source, kp_sourcekp_source=kp_source, kp_driving=kp_norm)           predictions.append(np.transpose(out[ prediction ].data.cpu().numpy(), [0, 2, 3, 1])[0]) 
  19.     return predictions  
  20. def find_best_frame(source, driving, cpu=False):  
  21.     import face_alignment  
  22.     def normalize_kp(kp):  
  23.         kpkp = kp - kp.mean(axis=0keepdims=True 
  24.         area = ConvexHull(kp[:, :2]).volume  
  25.         area = np.sqrt(area)  
  26.         kp[:, :2] = kp[:, :2] / area  
  27.         return kp  
  28.     fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True 
  29.                                       devicecpu  if cpu else  cuda ) 
  30.      kp_source = fa.get_landmarks(255 * source)[0]  
  31.     kp_source = normalize_kp(kp_source) 
  32.      norm  = float( inf )  
  33.     frame_num = 0  
  34.     for i, image in tqdm(enumerate(driving)):  
  35.         kp_driving = fa.get_landmarks(255 * image)[0]  
  36.         kp_driving = normalize_kp(kp_driving)  
  37.         new_norm = (np.abs(kp_source - kp_driving) ** 2).sum()  
  38.         if new_norm < norm:  
  39.             norm = new_norm  
  40.             frame_num = i  
  41. return frame_num 

(3)    接着定义命令行调用参数加载图片、视频等方式: 

  1. parser = ArgumentParser()  
  2.     parser.add_argument("--config", required=Truehelp="path to config" 
  3.     parser.add_argument("--checkpoint", defaultvox-cpk.pth.tar , help="path to checkpoint to restore" 
  4.     parser.add_argument("--source_image", defaultsup-mat/source.png , help="path to source image" 
  5.     parser.add_argument("--driving_video", defaultsup-mat/source.png , help="path to driving video" 
  6.     parser.add_argument("--result_video", defaultresult.mp4 , help="path to output"
  7.      parser.add_argument("--relative", dest="relative"action="store_true"help="use relative or absolute keypoint coordinates" 
  8.     parser.add_argument("--adapt_scale", dest="adapt_scale"action="store_true"help="adapt movement scale based on convex hull of keypoints" 
  9.     parser.add_argument("--find_best_frame", dest="find_best_frame"action="store_true",  
  10.                          help="Generate from the frame that is the most alligned with source. (Only for faces, requires face_aligment lib)" 
  11.     parser.add_argument("--best_frame", dest="best_frame"type=intdefault=None,    
  12.                         help="Set frame to start from."
  13.     parser.add_argument("--cpu", dest="cpu"action="store_true"help="cpu mode." 
  14.     parser.set_defaults(relative=False 
  15.     parser.set_defaults(adapt_scale=False 
  16.     opt = parser.parse_args()  
  17.     source_image = imageio.imread(opt.source_image)  
  18.     reader = imageio.get_reader(opt.driving_video)  
  19.     fps = reader.get_meta_data()[ fps ] 
  20.     driving_video = [] 
  21.     try:  
  22.         for im in reader:  
  23.             driving_video.append(im)  
  24.     except RuntimeError:  
  25.         pass  
  26.     reader.close()  
  27.     source_image = resize(source_image, (256, 256))[..., :3]  
  28.     driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]  
  29.     generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)  
  30.     if opt.find_best_frame or opt.best_frame is not None:  
  31.         i = opt.best_frame if opt.best_frame is not None else find_best_frame(source_image, driving_video, cpu=opt.cpu)  
  32.         print ("Best frame: " + str(i))  
  33.         driving_forward = driving_video[i:]  
  34.         driving_backward = driving_video[:(i+1)][::-1]  
  35.         predictions_forward = make_animation(source_image, driving_forward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 
  36.         predictions_backward = make_animation(source_image, driving_backward, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 
  37.         predictions = predictions_backward[::-1] + predictions_forward[1:]  
  38.     else:  
  39.         predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu) 
  40.  imageio.mimsave(opt.result_video, [img_as_ubyte(frame) for frame in predictions], fpsfps=fps) 

[[352408]]

模型的搭建

整个模型训练过程是图像重建的过程,输入是源图像和驱动图像,输出是保留源图像物体信息的含有驱动图像姿态的新图像,其中输入的两张图像来源于同一个视频,即同一个物体信息,那么整个训练过程就是驱动图像的重建过程。大体上来说分成两个模块,一个是motion estimation module,另一个是imagegeneration module。

(1)其中通过定义VGG19模型建立网络层作为perceptual损失。

其中手动输入数据进行预测需要设置更多的GUI按钮,其中代码如下: 

  1. class Vgg19(torch.nn.Module):  
  2.     """  
  3.     Vgg19 network for perceptual loss. See Sec 3.3.  
  4.     """  
  5.     def __init__(self, requires_grad=False):  
  6.         super(Vgg19, self).__init__()  
  7.         vgg_pretrained_features = models.vgg19(pretrained=True).features  
  8.         self.slice1 = torch.nn.Sequential()  
  9.         self.slice2 = torch.nn.Sequential()  
  10.         self.slice3 = torch.nn.Sequential()  
  11.         self.slice4 = torch.nn.Sequential()  
  12.         self.slice5 = torch.nn.Sequential()  
  13.         for x in range(2): 
  14.             self.slice1.add_module(str(x), vgg_pretrained_features[x])  
  15.         for x in range(2, 7):  
  16.             self.slice2.add_module(str(x), vgg_pretrained_features[x])  
  17.         for x in range(7, 12):  
  18.             self.slice3.add_module(str(x), vgg_pretrained_features[x])  
  19.         for x in range(12, 21):  
  20.             self.slice4.add_module(str(x), vgg_pretrained_features[x])  
  21.         for x in range(21, 30):  
  22.             self.slice5.add_module(str(x), vgg_pretrained_features[x])  
  23.         self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),  
  24.                                        requires_grad=False 
  25.         self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),  
  26.                                       requires_grad=False 
  27.         if not requires_grad:  
  28.             for param in self.parameters():  
  29.                 param.requires_grad = False  
  30.     def forward(self, X):  
  31.         X = (X - self.mean) / self.std  
  32.         h_relu1 = self.slice1(X)  
  33.         h_relu2 = self.slice2(h_relu1)  
  34.         h_relu3 = self.slice3(h_relu2)  
  35.         h_relu4 = self.slice4(h_relu3)  
  36.         h_relu5 = self.slice5(h_relu4)  
  37.         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]  
  38.         return out 

(2)创建图像金字塔计算金字塔感知损失: 

  1. class ImagePyramide(torch.nn.Module):  
  2.     """  
  3.     Create image pyramide for computing pyramide perceptual loss. See Sec 3.3  
  4.     """  
  5.     def __init__(self, scales, num_channels):  
  6.         super(ImagePyramide, self).__init__()  
  7.         downs = {}  
  8.         for scale in scales:  
  9.             downs[str(scale).replace( . ,  - )] = AntiAliasInterpolation2d(num_channels, scale)  
  10.         self.downs = nn.ModuleDict(downs)  
  11.     def forward(self, x):  
  12.         out_dict = {} 
  13.          for scale, down_module in self.downs.items():  
  14.             out_dict[ prediction_  + str(scale).replace( - ,  . )] = down_module(x)  
  15.         return out_dict 

(3)等方差约束的随机tps变换 

  1. class Transform:  
  2.     """  
  3.     Random tps transformation for equivariance constraints. See Sec 3.3  
  4.     """  
  5.     def __init__(self, bs, **kwargs):  
  6.         noise = torch.normal(mean=0std=kwargs[ sigma_affine ] * torch.ones([bs, 2, 3]))  
  7.         self.theta = noise + torch.eye(2, 3).view(1, 2, 3)  
  8.         self.bs = bs  
  9.         if ( sigma_tps  in kwargs) and ( points_tps  in kwargs):  
  10.             self.tps = True  
  11.             self.control_points = make_coordinate_grid((kwargs[ points_tps ], kwargs[ points_tps ]), type=noise.type())  
  12.             selfself.control_points = self.control_points.unsqueeze(0)  
  13.             self.control_params = torch.normal(mean=0 
  14.                                                std=kwargs[ sigma_tps ] * torch.ones([bs, 1, kwargs[ points_tps ] ** 2]))  
  15.         else:  
  16.             self.tps = False  
  17.     def transform_frame(self, frame):  
  18.         grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)  
  19.         gridgrid = grid.view(1, frame.shape[2] * frame.shape[3], 2)  
  20.         grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)  
  21.         return F.grid_sample(frame, grid, padding_mode="reflection" 
  22.     def warp_coordinates(self, coordinates):  
  23.         theta = self.theta.type(coordinates.type())  
  24.         thetatheta = theta.unsqueeze(1)  
  25.         transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]  
  26.         transformedtransformed = transformed.squeeze(-1) 
  27.         if self.tps:  
  28.             control_points = self.control_points.type(coordinates.type())  
  29.             control_params = self.control_params.type(coordinates.type())  
  30.             distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)  
  31.             distances = torch.abs(distances).sum(-1)  
  32.             result = distances ** 2  
  33.             resultresult = result * torch.log(distances + 1e-6)  
  34.             resultresult = result * control_params  
  35.             resultresult = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)  
  36.             transformedtransformed = transformed + result  
  37.         return transformed  
  38.     def jacobian(self, coordinates):  
  39.         new_coordinates = self.warp_coordinates(coordinates)  
  40.         gradgrad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True 
  41.         gradgrad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True 
  42.         jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)  
  43.         return jacobian      

(4)生成器的定义:生成器,给定的源图像和和关键点尝试转换图像根据运动轨迹引起要点。部分代码如下: 

  1. class OcclusionAwareGenerator(nn.Module):  
  2.     def __init__(self, num_channels, num_kp, block_expansion, max_features, num_down_blocks,  
  3.                  num_bottleneck_blocks, estimate_occlusion_map=Falsedense_motion_params=Noneestimate_jacobian=False):  
  4.         super(OcclusionAwareGenerator, self).__init__()  
  5.         if dense_motion_params is not None:  
  6.             self.dense_motion_network = DenseMotionNetwork(num_kpnum_kp=num_kp, num_channelsnum_channels=num_channels,  
  7.                                                            estimate_occlusion_mapestimate_occlusion_map=estimate_occlusion_map,  
  8.                                                            **dense_motion_params) 
  9.          else:  
  10.             self.dense_motion_network = None  
  11.         self.first = SameBlock2d(num_channels, block_expansion, kernel_size=(7, 7), padding=(3, 3))  
  12.         down_blocks = []  
  13.         for i in range(num_down_blocks):  
  14.             in_features = min(max_features, block_expansion * (2 ** i))  
  15.             out_features = min(max_features, block_expansion * (2 ** (i + 1)))  
  16.             down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))  
  17.         self.down_blocks = nn.ModuleList(down_blocks)  
  18.         up_blocks = [] 
  19.          for i in range(num_down_blocks):  
  20.             in_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i)))  
  21.             out_features = min(max_features, block_expansion * (2 ** (num_down_blocks - i - 1)))  
  22.             up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))  
  23.         self.up_blocks = nn.ModuleList(up_blocks)  
  24.         self.bottleneck = torch.nn.Sequential() 
  25.          in_features = min(max_features, block_expansion * (2 ** num_down_blocks))  
  26.         for i in range(num_bottleneck_blocks):  
  27.             self.bottleneck.add_module( r  + str(i), ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))  
  28.         self.final = nn.Conv2d(block_expansion, num_channels, kernel_size=(7, 7), padding=(3, 3))  
  29.         self.estimate_occlusion_map = estimate_occlusion_map  
  30.         self.num_channels = num_channels 

(5)判别器类似于Pix2PixGenerator。 

  1. def __init__(self, num_channels=3block_expansion=64num_blocks=4max_features=512 
  2.                  sn=Falseuse_kp=Falsenum_kp=10kp_variance=0.01, **kwargs):  
  3.         super(Discriminator, self).__init__()  
  4.         down_blocks = []  
  5.         for i in range(num_blocks):  
  6.             down_blocks.append(  
  7.                 DownBlock2d(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),  
  8.                             min(max_features, block_expansion * (2 ** (i + 1))),  
  9.                             norm=(i != 0), kernel_size=4pool=(i != num_blocks - 1), snsn=sn))  
  10.         self.down_blocks = nn.ModuleList(down_blocks)  
  11.         self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1kernel_size=1 
  12.         if sn:  
  13.             self.conv = nn.utils.spectral_norm(self.conv)  
  14.         self.use_kp = use_kp  
  15.         self.kp_variance = kp_variance  
  16.     def forward(self, x, kp=None):  
  17.         feature_maps = []  
  18.         out = x  
  19.         if self.use_kp:  
  20.             heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)  
  21.             out = torch.cat([out, heatmap], dim=1 
  22.         for down_block in self.down_blocks:  
  23.             feature_maps.append(down_block(out))  
  24.             out = feature_maps[-1]  
  25.         prediction_map = self.conv(out)  
  26.         return feature_maps, prediction_map 

最终通过以下代码调用模型训练“python demo.py--config config/vox-adv-256.yaml --driving_video path/to/driving/1.mp4--source_image path/to/source/7.jpg --checkpointpath/to/checkpoint/vox-adv-cpk.pth.tar --relative --adapt_scale”

效果如下:

[[352409]][[352410]]  

 

责任编辑:庞桂玉 来源: 机器学习算法与Python学习
相关推荐

2022-06-07 09:00:32

PythonAI静态图片

2012-09-03 09:21:51

2009-06-19 11:18:51

Factory BeaSpring配置

2013-05-27 15:35:18

用友UAP移动应用移动平台

2024-03-28 13:14:00

数据训练

2022-02-24 08:30:24

操作系统CPU程序

2019-05-22 15:04:34

Python磁盘IO

2011-06-01 14:51:54

jQuery

2021-09-26 09:23:01

GC算法垃圾

2010-09-08 09:48:56

Gif播放教程Android

2018-07-26 13:53:27

2010-05-21 11:03:51

统一通信系统

2011-09-15 17:36:29

Android应用Call Cartoo动画

2021-04-12 11:47:21

人工智能知识图谱

2014-03-21 09:52:29

jQuery动画插件

2020-09-21 21:40:19

AI 数据人工智能

2012-05-21 10:53:30

HTML5

2022-07-13 15:46:57

Python数据可视化代码片段

2015-12-01 13:51:52

Webrtc

2012-05-21 10:45:30

HTML5
点赞
收藏

51CTO技术栈公众号