DiffSynth Studio
DiffSynth-Studio 是 ModelScope 推出的一个开源的扩散模型引擎,专注于图像与视频的风格迁移与生成任务。它通过优化架构设计(如文本编码器、UNet、VAE 等组件),在保持与开源社区模型兼容性的同时,显著提升计算性能,为用户提供高效、灵活的创作工具。
DiffSynth Studio 支持多种扩散模型,包括 Wan-Video、StepVideo、HunyuanVideo、CogVideoX、FLUX、ExVideo、Kolors、Stable Diffusion 3 等。
你可以使用DiffSynth Studio快速进行Diffusion模型训练,同时使用SwanLab进行实验跟踪与可视化。
准备工作
1. 克隆仓库并安装环境
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
pip install swanlab
pip install lightning lightning_fabric
2. 准备数据集
DiffSynth Studio 的数据集需要按下面的格式进行构建,比如将图像数据存放在data/dog
目录下:
data/dog/
└── train
├── 00.jpg
├── 01.jpg
├── 02.jpg
├── 03.jpg
├── 04.jpg
└── metadata.csv
metadata.csv
文件需要按下面的格式进行构建:
file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
02.jpg,一只小狗
03.jpg,一只小狗
04.jpg,一只小狗
这里有一份整理好格式的火影忍者数据集,百度云,供参考与测试
3. 准备模型
这里以Kolors模型为例,下载模型权重和VAE权重:
modelscope download --model=Kwai-Kolors/Kolors --local_dir models/kolors/Kolors
modelscope download --model=AI-ModelScope/sdxl-vae-fp16-fix --local_dir models/kolors/sdxl-vae-fp16-fix
设置SwanLab参数
在运行训练脚本时,添加--use_swanlab
,即可将训练过程记录到SwanLab平台。
如果你需要离线记录,可以添加--swanlab_mode "local"
。
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
...
--use_swanlab \
--swanlab_mode "cloud"
开启训练
使用下面的命令即可开启训练,并使用SwanLab记录超参数、训练日志、loss曲线等信息:
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision "16-mixed" \
--use_swanlab \
--swanlab_mode "cloud"
补充
如果你想要自定义SwanLab的项目名、实验名等参数,可以:
1. 文生图任务
在DiffSynth-Studio/diffsynth/trainers/text_to_image.py
文件中,找到swanlab_logger
变量的位置,修改project
和name
参数:
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="diffsynth_studio",
name="diffsynth_studio",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
)
logger = [swanlab_logger]
2. Wan-Video文生视频任务
在DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py
文件中,找到swanlab_logger
变量的位置,修改project
和name
参数:
if args.use_swanlab:
from swanlab.integration.pytorch_lightning import SwanLabLogger
swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
swanlab_config.update(vars(args))
swanlab_logger = SwanLabLogger(
project="wan",
name="wan",
config=swanlab_config,
mode=args.swanlab_mode,
logdir=args.output_path,
)
logger = [swanlab_logger]