#码力全开·技术π对#KerasCV的StableDiffusion实现如何优化多GPU推理吞吐量?

启用`jit_compile=True`后出现显存碎片,如何平衡XLA优化与批次大小?

KerasCV
Jimaks
2025-05-29 08:22:51
浏览
收藏 0
回答 1
待解决
回答 1
按赞同
/
按时间
尔等氏人
尔等氏人
  1. 使用​​tf.distribute.MirroredStrategy​​进行多GPU分布式推理,合理设置批次大小以提升吞吐量。
  2. 启用XLA优化时,可尝试调整​​jit_compile=True​​的同时,使用​​tf.config.optimizer.set_jit(True)​​控制全局JIT行为。
  3. 针对显存碎片问题,建议限制每个GPU的显存增长:​​tf.config.experimental.set_memory_growth(physical_devices, True)​​。
  4. 平衡批次大小与XLA优化,可通过梯度累积或微批次方式缓解显存压力。
分享
微博
QQ
微信https://www.51cto.com/aigc/
回复
2025-05-30 08:33:49
发布
相关问题
提问