tsqn committed on
Commit
11396da
·
1 Parent(s): 4e19d7a

upload fp8_scaled versions

Browse files
convert_fp8.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ from tqdm.auto import tqdm
5
+ from safetensors.torch import load_file, save_file
6
+ from torch import dtype
7
+
8
+
9
def get_fp_maxval(bits=8, mantissa_bit=3, sign_bits=1):
    """Return the largest representable magnitude of a low-bit float format.

    Defaults describe E4M3 (1 sign bit, 4 exponent bits, 3 mantissa bits),
    for which the result is 448.0.
    """
    total_bits = torch.tensor(bits)
    man_bits = torch.tensor(mantissa_bit)
    sgn_bits = torch.tensor(sign_bits)
    # Mantissa width is clipped to [1, bits - sign_bits]; the remainder is exponent.
    M = torch.clamp(torch.round(man_bits), 1, total_bits - sgn_bits)
    E = total_bits - sgn_bits - M
    bias = 2 ** (E - 1) - 1
    # Largest mantissa: 1 + 1/2 + ... + 1/2^(mantissa_bit-1).
    # NOTE(review): only mantissa_bit-1 fraction terms are summed — this matches
    # E4M3(FN), where the all-ones mantissa pattern is reserved; confirm if
    # reusing for other formats.
    mantissa = 1
    for k in range(mantissa_bit - 1):
        mantissa += 2.0 ** -(k + 1)
    return mantissa * 2 ** (2 ** E - 1 - bias)
21
+
22
+
23
def quantize_to_fp8(x, bits=8, mantissa_bit=3, sign_bits=1):
    """
    Fake-quantize ``x`` to a low-bit floating-point format. Default is E4M3.

    Returns ``(qdq_out, log_scales)``: the quantize-dequantize reconstruction
    of ``x`` (same dtype/shape as ``x``) and the per-element power-of-two
    rounding step that was used.
    """
    bits = torch.tensor(bits)
    mantissa_bit = torch.tensor(mantissa_bit)
    sign_bits = torch.tensor(sign_bits)
    # Split the bit budget: mantissa clipped to [1, bits - sign_bits], rest exponent.
    M = torch.clamp(torch.round(mantissa_bit), 1, bits - sign_bits)
    E = bits - sign_bits - M
    bias = 2 ** (E - 1) - 1
    # Largest mantissa: 1 + 1/2 + ... (mantissa_bit - 1 fraction terms; matches E4M3FN).
    mantissa = 1
    for i in range(mantissa_bit - 1):
        mantissa += 1 / (2 ** (i+1))
    maxval = mantissa * 2 ** (2**E - 1 - bias)
    # Fix: removed a dead `minval = - maxval` that was immediately overwritten.
    # Signed formats clamp to [-maxval, maxval]; unsigned ones to [0, maxval].
    minval = - maxval if sign_bits == 1 else torch.zeros_like(maxval)
    input_clamp = torch.min(torch.max(x, minval), maxval)
    # Per-element biased exponent; clamp(min=1.0) keeps zeros from yielding
    # -inf (log2(0) = -inf) and flushes subnormals onto the smallest grid.
    log_scales = torch.clamp(
        (torch.floor(torch.log2(torch.abs(input_clamp)) + bias)).detach(), 1.0)
    # Convert the exponent into the mantissa-grid step size at that magnitude.
    log_scales = 2.0 ** (log_scales - M - bias.type(x.dtype))
    # dequant: round onto the grid, then scale back.
    qdq_out = torch.round(input_clamp / log_scales) * log_scales
    return qdq_out, log_scales
46
+
47
+
48
def fp8_tensor_quant(x, scale, bits=8, mantissa_bit=3, sign_bits=1):
    """Scale ``x`` down by ``scale`` and fake-quantize it (default E4M3).

    Returns ``(quant_dequant_x, scale, log_scales)`` — the reconstructed
    scaled tensor, the scale reshaped for broadcasting, and the per-element
    rounding step sizes from :func:`quantize_to_fp8`.
    """
    # Append singleton dims so `scale` broadcasts across every trailing axis of x.
    trailing = x.dim() - 1
    for _ in range(trailing):
        scale = scale.unsqueeze(-1)
    scaled_x = x / scale
    quant_dequant_x, log_scales = quantize_to_fp8(
        scaled_x, bits=bits, mantissa_bit=mantissa_bit, sign_bits=sign_bits)
    return quant_dequant_x, scale, log_scales
55
+
56
+
57
def parse_args():
    """Build and parse the command-line arguments for the conversion script."""
    parser = argparse.ArgumentParser(
        description="Convert safetensors to fp8 scaled",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Input .safetensors file to convert",
    )
    parser.add_argument(
        "--base_dtype",
        type=str,
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="dtype to use for anything that can't be converted to fp8",
    )
    # A --ban_list flag was sketched here but never enabled; the ban list is
    # currently hard-coded in main().
    return parser.parse_args()
84
+
85
+
86
def main(args):
    """Convert a .safetensors checkpoint to an fp8-scaled checkpoint.

    Each 2-D tensor whose key contains no banned substring is scaled by its
    absolute maximum, fake-quantized, and stored as float8_e4m3fn alongside a
    ``<prefix>.scale_weight`` tensor. Everything else is stored in the
    requested base dtype; banned 2-D tensors still receive an identity scale
    of 1.0 so every matrix has a scale key. Output is written next to the
    input with a ``_fp8_scaled`` suffix.
    """
    input_path = os.path.normpath(args.file)
    output_path = os.path.splitext(input_path)[0] + "_fp8_scaled.safetensors"

    orig_state_dict = load_file(input_path)
    new_state_dict = {}

    dtype_map = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    model_dtype: dtype = dtype_map.get(args.base_dtype)
    if model_dtype is None:
        raise Exception(f"unknown dtype: {args.base_dtype}")

    # Keys containing any of these substrings stay in the base dtype. Earlier
    # runs used other lists for the diffusion transformer (see repo history);
    # this one targets the text encoder.
    ban_list = ["norm", "embed_tokens"]  # for text encoder

    maxval = get_fp_maxval()

    for key in tqdm(orig_state_dict.keys()):
        tensor = orig_state_dict[key]
        # Only 2-D (linear-layer) weights that are not banned get quantized.
        is_matrix = tensor.dim() == 2
        convert = is_matrix and not any(ban in key for ban in ban_list)

        scale_key = key.rsplit(".", 1)[0] + ".scale_weight"

        if convert:
            scale = torch.max(torch.abs(tensor.flatten())) / maxval
            qdq_weight, scale, log_scales = fp8_tensor_quant(tensor, scale)
            new_state_dict[scale_key] = scale
            new_state_dict[key] = qdq_weight.to(dtype=torch.float8_e4m3fn)
        else:
            if is_matrix:
                # Banned matrices get an identity scale so scaled-fp8 loaders
                # find a scale for every 2-D weight.
                new_state_dict[scale_key] = torch.ones(1)
            new_state_dict[key] = tensor.to(dtype=model_dtype)

    # NOTE(review): presumably a marker tensor that scaled-fp8-aware loaders
    # check for — confirm against the target loader's expectations.
    new_state_dict["scaled_fp8"] = torch.zeros(2).to(dtype=torch.float8_e4m3fn)

    save_file(new_state_dict, output_path)
143
+
144
+
145
if __name__ == "__main__":
    cli_args = parse_args()
    # Echo the chosen base dtype before the (long) conversion starts.
    print(cli_args.base_dtype)
    main(cli_args)
qwen_3_4b_bf16_fp8_scaled.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc27405fbac59ac9998e52dfc34e66d09d0e0fd1025ead198f05b98760e3f68
3
+ size 4411692798
z_image_turbo_bf16_fp8_scaled_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ca63396607ef00904f3bd33e8ec3d4e362f225424f8c2e36b1aeb2d7c6390ee
3
+ size 6293681826
z_image_turbo_bf16_fp8_scaled_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:900e96b082a2010fc885fcd674fd10304ea0420de496fe0bedea169512679743
3
+ size 8299083842
z_image_turbo_fp16_fp8_scaled_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3822b59769b179d4e96e93ffd1cb968b38fadb1c882e1faec9ff4096818a1728
3
+ size 6293681538
z_image_turbo_fp16_fp8_scaled_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f251a69d1d6d0260c403dfaedeadffe360df785746960d27c58c045644cc3992
3
+ size 8299083490
z_image_turbo_fp32_fp8_scaled_1.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90024a47bc8f070a8916d367a25012f4a6a67d6974bee3ffe69e3db39450d823
3
+ size 6571092026
z_image_turbo_fp32_fp8_scaled_2.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d56f1cfe5409c4902d3d668717de7472d8cd6c592a06ffb1cc1dfad0210cd792
3
+ size 12587297706