# transformer_only models of LTX-2.3 (experimental)
- You can get the rest of the model from Kijai/LTX2.3_comfy.
- You may need to fetch PR #12978 on ComfyUI for LoRA support.
- nvfp4 and mxfp8 are only fast on RTX 5000 series GPUs.
## Updates
2026/03/16 - The upstream nvfp4 model was updated, so I remade this one.
## Samples
Note that I converted the sample videos to animated WebP, so their quality is lower than the raw output.
## Generation Speed

640x960, 121 frames, CFG 1, 8 steps, on an RTX 5090.
### nvfp4
```
Model LTXAV prepared for dynamic VRAM loading. 16747MB Staged. 1660 patches attached.
100%|███████████████████| 8/8 [00:11<00:00, 1.43s/it]
```
### nvfp4mixed_input_scaled
```
Model LTXAV prepared for dynamic VRAM loading. 19295MB Staged. 1660 patches attached.
100%|███████████████████| 8/8 [00:11<00:00, 1.44s/it]
```
### fp8_input_scaled
```
Model LTXAV prepared for dynamic VRAM loading. 23838MB Staged. 1660 patches attached.
100%|███████████████████| 8/8 [00:13<00:00, 1.73s/it]
```
### mxfp8
```
Model LTXAV prepared for dynamic VRAM loading. 24345MB Staged. 1660 patches attached.
100%|███████████████████| 8/8 [00:13<00:00, 1.67s/it]
```
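
In short, the nvfp4 variants stage roughly 5-7 GB less VRAM than the fp8 variants (16.7-19.3 GB vs. 23.8-24.3 GB) and run about 15-20% faster per step on this setup.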
## How to reproduce
### nvfp4

Cut from Lightricks/LTX-2.3-nvfp4 with the script below.
```python
import sys
import json
import torch
from safetensors.torch import safe_open, save_file

def cut_safetensors(input_path, output_path):
    with safe_open(input_path, framework="pt", device="cpu") as f:
        # Drop the VAE, audio VAE, and vocoder entries from the embedded config.
        metadata = f.metadata()
        config = json.loads(metadata.get('config', '{}'))
        for key in ['vae', 'audio_vae', 'vocoder']:
            if key in config:
                del config[key]
        metadata['config'] = json.dumps(config)

        # Move the global quantization metadata into per-layer tensors.
        quant_meta = json.loads(metadata.get('_quantization_metadata', '{"layers": {}}'))
        quant_layers = quant_meta.get("layers", {})
        metadata.pop('_quantization_metadata', None)

        # Keep only the diffusion transformer weights.
        new_state_dict = {}
        prefix = "model.diffusion_model."
        for key in f.keys():
            if key.startswith(prefix):
                new_state_dict[key] = f.get_tensor(key)
                base_key = key.replace(".weight", "")
                if base_key in quant_layers:
                    # Encode each layer's quantization info as a uint8 tensor
                    # holding the JSON bytes (ComfyUI's comfy_quant convention).
                    quant_info = quant_layers[base_key]
                    json_data = json.dumps(quant_info).encode("utf-8")
                    new_tensor = torch.tensor(list(json_data), dtype=torch.uint8)
                    new_state_dict[f"{base_key}.comfy_quant"] = new_tensor
        save_file(new_state_dict, output_path, metadata=metadata)

if __name__ == "__main__":
    input_path, output_path = sys.argv[1:3]
    cut_safetensors(input_path, output_path)
```
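
As a quick sanity check, the per-layer comfy_quant tensors in the cut file can be decoded back into their JSON configs. A minimal sketch, assuming a placeholder output filename:

```python
# Sanity check (placeholder filename): decode a comfy_quant tensor back to JSON.
import json
from safetensors.torch import safe_open

with safe_open("ltx2.3_transformer_nvfp4.safetensors", framework="pt", device="cpu") as f:
    quant_keys = [k for k in f.keys() if k.endswith(".comfy_quant")]
    print(f"{len(quant_keys)} quantized layers")
    raw = f.get_tensor(quant_keys[0])  # uint8 tensor holding JSON bytes
    print(json.loads(bytes(raw.tolist()).decode("utf-8")))
```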
### nvfp4mixed_input_scaled

Mixed layers from Lightricks/LTX-2.3-nvfp4 onto Lightricks/LTX-2.3-fp8 with the script below.
```python
import json
import os
from safetensors.torch import load_file, save_file, safe_open

def parse_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("src1")  # fp8
    parser.add_argument("src2")  # nvfp4
    parser.add_argument("dst")
    return parser.parse_args()

def main():
    args = parse_args()
    # Keep the fp8 model's metadata as the base.
    with safe_open(args.src1, framework="pt") as f:
        original_metadata = f.metadata() or {}

    new_state_dict = {}
    quantization_layers = {}
    # Attention and feed-forward layers are taken from the nvfp4 model...
    block_names = [".to_", ".ff.net"]
    # ...except the first/last two blocks and these sensitive layers,
    # which stay fp8.
    exception_names = [
        "blocks.0.", "blocks.1.", "blocks.46.", "blocks.47.",
        ".to_gate_logits", ".to_out.0", ".audio_ff.net", ".attn2."
    ]

    # Everything outside the targeted layers comes from the fp8 model.
    state_dict1 = load_file(args.src1)
    for key, tensor in state_dict1.items():
        if any(b in key for b in block_names) and not any(e in key for e in exception_names):
            continue
        new_state_dict[key] = tensor
        if key.endswith(".weight_scale"):
            layer_name = key[: -len(".weight_scale")]
            quantization_layers[layer_name] = {"format": "float8_e4m3fn"}

    # The targeted layers come from the nvfp4 model.
    state_dict2 = load_file(args.src2)
    for key, tensor in state_dict2.items():
        if any(b in key for b in block_names) and not any(e in key for e in exception_names):
            new_state_dict[key] = tensor
            if key.endswith(".weight_scale"):
                layer_name = key[: -len(".weight_scale")]
                quantization_layers[layer_name] = {"format": "nvfp4"}

    # Record the merged per-layer formats in the metadata; the source
    # metadata still describes the all-fp8 checkpoint.
    quant_meta = json.loads(original_metadata.get('_quantization_metadata', '{}'))
    quant_meta['layers'] = quantization_layers
    original_metadata['_quantization_metadata'] = json.dumps(quant_meta)

    save_file(new_state_dict, args.dst, metadata=original_metadata)
    total_bytes = os.path.getsize(args.dst)
    print(f"Output: {args.dst} ({round(total_bytes / (1024**3), 2)}GB)")

if __name__ == "__main__":
    main()
```
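
With the per-layer formats written back into the metadata as above, the resulting format split can be checked after the fact. A minimal sketch, assuming a placeholder output filename:

```python
# Count how many layers ended up in each format (placeholder filename).
import json
from collections import Counter
from safetensors import safe_open

with safe_open("ltx2.3_transformer_nvfp4mixed.safetensors", framework="pt") as f:
    layers = json.loads(f.metadata()["_quantization_metadata"])["layers"]
print(Counter(info["format"] for info in layers.values()))
```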
Then input_scale tensors were added to the nvfp4 layers with the script below.
```python
import sys
from safetensors import safe_open
from safetensors.torch import save_file

def main():
    src1_path, src2_path, output_path = sys.argv[1:4]
    out_tensors = {}
    # Copy the mixed model (src1) unchanged.
    with safe_open(src1_path, framework="pt") as f1:
        for k in f1.keys():
            out_tensors[k] = f1.get_tensor(k)
        metadata = f1.metadata() or {}
    # Add input_scale tensors from src2, but only for layers that are
    # quantized there (i.e. that carry a comfy_quant tensor).
    with safe_open(src2_path, framework="pt") as f2:
        keys = set(f2.keys())
        for k in keys:
            if k.endswith(".input_scale"):
                base_key = k.removesuffix(".input_scale")
                if f"{base_key}.comfy_quant" in keys:
                    out_tensors[k] = f2.get_tensor(k)
    save_file(out_tensors, output_path, metadata=metadata)

if __name__ == "__main__":
    main()
```
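
A quick diff against the mixed input confirms which input_scale tensors were added; a minimal sketch, with placeholder paths:

```python
# List input_scale keys present in the output but not in the mixed source
# (placeholder paths).
from safetensors import safe_open

def key_set(path):
    with safe_open(path, framework="pt") as f:
        return set(f.keys())

added = key_set("output.safetensors") - key_set("mixed.safetensors")
print(sorted(k for k in added if k.endswith(".input_scale"))[:5])
```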
### mxfp8

Quantized with comfy-dit-quantizer using the config file below.
```json
{
    "format": "comfy_quant",
    "block_names": ["transformer_blocks"],
    "rules": [
        { "policy": "keep", "match": ["blocks.0.", "blocks.1.", "blocks.46.", "blocks.47."] },
        { "policy": "mxfp8", "match": [".to_", "ff.net."] }
    ]
}
```
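
For intuition, here is an illustrative sketch (not the quantizer's actual code) of how the rules above would classify layer names, assuming first-match-wins semantics and unmatched layers keeping their original dtype:

```python
# Illustrative only: first-match-wins classification of layer names
# under the rules from the config above.
rules = [
    {"policy": "keep", "match": ["blocks.0.", "blocks.1.", "blocks.46.", "blocks.47."]},
    {"policy": "mxfp8", "match": [".to_", "ff.net."]},
]

def policy_for(key, default="keep"):
    for rule in rules:
        if any(m in key for m in rule["match"]):
            return rule["policy"]
    return default

for key in ["transformer_blocks.0.attn1.to_q",
            "transformer_blocks.20.attn1.to_q",
            "transformer_blocks.20.ff.net.0.proj"]:
    print(key, "->", policy_for(key))
```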