跳转到内容

DiffSynth Studio

DiffSynth-StudioModelScope 推出的一个开源的扩散模型引擎,专注于图像与视频的风格迁移与生成任务。它通过优化架构设计(如文本编码器、UNet、VAE 等组件),在保持与开源社区模型兼容性的同时,显著提升计算性能,为用户提供高效、灵活的创作工具。

DiffSynth Studio 支持多种扩散模型,包括 Wan-Video、StepVideo、HunyuanVideo、CogVideoX、FLUX、ExVideo、Kolors、Stable Diffusion 3 等。

你可以使用DiffSynth Studio快速进行Diffusion模型训练,同时使用SwanLab进行实验跟踪与可视化。

准备工作

1. 克隆仓库并安装环境

bash
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
pip install swanlab
pip install lightning lightning_fabric

2. 准备数据集

DiffSynth Studio 的数据集需要按下面的格式进行构建,比如将图像数据存放在data/dog目录下:

bash
data/dog/
└── train
    ├── 00.jpg
    ├── 01.jpg
    ├── 02.jpg
    ├── 03.jpg
    ├── 04.jpg
    └── metadata.csv

metadata.csv 文件需要按下面的格式进行构建:

csv
file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
02.jpg,一只小狗
03.jpg,一只小狗
04.jpg,一只小狗

这里有一份整理好格式的火影忍者数据集,百度云,供参考与测试

3. 准备模型

这里以Kolors模型为例,下载模型权重和VAE权重:

bash
modelscope download --model=Kwai-Kolors/Kolors --local_dir models/kolors/Kolors
modelscope download --model=AI-ModelScope/sdxl-vae-fp16-fix --local_dir models/kolors/sdxl-vae-fp16-fix

设置SwanLab参数

在运行训练脚本时,添加--use_swanlab,即可将训练过程记录到SwanLab平台。

如果你需要离线记录,可以添加--swanlab_mode "local"

bash
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
...
--use_swanlab \ 
--swanlab_mode "cloud"

开启训练

使用下面的命令即可开启训练,并使用SwanLab记录超参数、训练日志、loss曲线等信息:

bash
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision "16-mixed" \
--use_swanlab \
--swanlab_mode "cloud"

补充

如果你想要自定义SwanLab的项目名、实验名等参数,可以:

1. 文生图任务

DiffSynth-Studio/diffsynth/trainers/text_to_image.py文件中,找到swanlab_logger变量的位置,修改projectname参数:

python
if args.use_swanlab:
    from swanlab.integration.pytorch_lightning import SwanLabLogger
    swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
    swanlab_config.update(vars(args))
    swanlab_logger = SwanLabLogger(
        project="diffsynth_studio", 
        name="diffsynth_studio",
        config=swanlab_config,
        mode=args.swanlab_mode,
        logdir=args.output_path,
    )
    logger = [swanlab_logger]

2. Wan-Video文生视频任务

DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py文件中,找到swanlab_logger变量的位置,修改projectname参数:

python
if args.use_swanlab:
    from swanlab.integration.pytorch_lightning import SwanLabLogger
    swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
    swanlab_config.update(vars(args))
    swanlab_logger = SwanLabLogger(
        project="wan", 
        name="wan",
        config=swanlab_config,
        mode=args.swanlab_mode,
        logdir=args.output_path,
    )
    logger = [swanlab_logger]